package infra import ( "encoding/json" "net/http" "net/http/httptest" "testing" ) func TestRateLimiterByKey(t *testing.T) { base := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) keyFunc := func(r *http.Request) string { return r.Header.Get("X-API-Key") } t.Run("aplica limit por la clave devuelta por keyFunc", func(t *testing.T) { rl := RateLimiterNew(1, 2) mw := RateLimiterByKey(rl, keyFunc) handler := mw(base) // Agotar tokens de api-key-A for i := 0; i < 2; i++ { req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-API-Key", "key-A") handler.ServeHTTP(httptest.NewRecorder(), req) } // Tercer rechazado rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-API-Key", "key-A") handler.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Errorf("status=%d, want 429", rec.Code) } // key-B intacta recB := httptest.NewRecorder() reqB := httptest.NewRequest("GET", "/", nil) reqB.Header.Set("X-API-Key", "key-B") handler.ServeHTTP(recB, reqB) if recB.Code != http.StatusOK { t.Errorf("key-B status=%d, want 200", recB.Code) } }) t.Run("key vacia salta el limit", func(t *testing.T) { rl := RateLimiterNew(1, 1) mw := RateLimiterByKey(rl, keyFunc) handler := mw(base) // Sin X-API-Key, hagamos muchos requests for i := 0; i < 10; i++ { rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("request %d status=%d, want 200 (sin key, sin limit)", i, rec.Code) } } }) t.Run("responde 429 con body JSON al exceder", func(t *testing.T) { rl := RateLimiterNew(1, 1) mw := RateLimiterByKey(rl, keyFunc) handler := mw(base) // Agotar req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-API-Key", "key-X") handler.ServeHTTP(httptest.NewRecorder(), req) // Rechazado rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Fatalf("status=%d, want 429", rec.Code) } var body map[string]any if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { t.Fatalf("body no es JSON: %v", err) } // HTTPError se serializa con campos PascalCase (sin tags JSON) if body["Code"] != "rate_limited" { t.Errorf("Code=%v, want rate_limited", body["Code"]) } }) t.Run("headers X-RateLimit-* siempre presentes en respuesta", func(t *testing.T) { rl := RateLimiterNew(10, 20) mw := RateLimiterByKey(rl, keyFunc) handler := mw(base) rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-API-Key", "key-Y") handler.ServeHTTP(rec, req) if rec.Header().Get("X-RateLimit-Limit") == "" { t.Error("X-RateLimit-Limit ausente") } if rec.Header().Get("X-RateLimit-Remaining") == "" { t.Error("X-RateLimit-Remaining ausente") } if rec.Header().Get("X-RateLimit-Reset") == "" { t.Error("X-RateLimit-Reset ausente") } }) }