package main

import (
	"database/sql"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"

	_ "github.com/mattn/go-sqlite3"
)

// setupTestDB creates a small SQLite database (an "items" table with two rows
// plus an FTS5 index over it) in a per-test temporary directory. It returns a
// DBPool with that database registered under the alias "testdb", together with
// the temporary directory. Callers are responsible for closing the pool.
//
// NOTE(review): FTS5 in mattn/go-sqlite3 may require the sqlite_fts5 build
// tag — confirm against the project's test configuration.
func setupTestDB(t *testing.T) (*DBPool, string) {
	t.Helper()
	dir := t.TempDir()
	dbPath := filepath.Join(dir, "test.db")

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open setup db: %v", err)
	}
	stmts := []string{
		`CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT, kind TEXT)`,
		`INSERT INTO items VALUES ('a', 'alpha', 'first')`,
		`INSERT INTO items VALUES ('b', 'beta', 'second')`,
		`CREATE VIRTUAL TABLE items_fts USING fts5(id, name, kind, content=items, content_rowid=rowid)`,
		`INSERT INTO items_fts(items_fts) VALUES('rebuild')`,
	}
	for _, s := range stmts {
		if _, err := db.Exec(s); err != nil {
			// Release the setup connection before aborting the test.
			db.Close()
			t.Fatalf("setup sql %q: %v", s, err)
		}
	}
	if err := db.Close(); err != nil {
		t.Fatalf("close setup db: %v", err)
	}

	pool := NewDBPool()
	pool.Register(DBEntry{Alias: "testdb", Path: dbPath, Kind: "test"})
	return pool, dir
}

// newTestMux builds a ServeMux with the server's routes registered for pool.
func newTestMux(pool *DBPool) *http.ServeMux {
	srv := NewServer(pool, "")
	mux := http.NewServeMux()
	srv.Routes(mux)
	return mux
}

// doTestRequest sends one request through h and returns the recorded response.
// A non-empty body is sent with a JSON Content-Type header.
func doTestRequest(h http.Handler, method, target, body string) *httptest.ResponseRecorder {
	var req *http.Request
	if body == "" {
		req = httptest.NewRequest(method, target, nil)
	} else {
		req = httptest.NewRequest(method, target, strings.NewReader(body))
		req.Header.Set("Content-Type", "application/json")
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)
	return w
}

// decodeTestJSON unmarshals the recorded response body into v and fails the
// test (showing the raw body) if it is not valid JSON for v.
func decodeTestJSON(t *testing.T, w *httptest.ResponseRecorder, v any) {
	t.Helper()
	// Unmarshal from Bytes() (instead of a streaming Decoder) so the full body
	// is still available for the failure message.
	if err := json.Unmarshal(w.Body.Bytes(), v); err != nil {
		t.Fatalf("decode response %q: %v", w.Body.String(), err)
	}
}

func TestHealthEndpoint(t *testing.T) {
	pool := NewDBPool()
	defer pool.Close()
	mux := newTestMux(pool)

	w := doTestRequest(mux, "GET", "/health", "")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]any
	decodeTestJSON(t, w, &resp)
	if resp["status"] != "ok" {
		t.Fatalf("expected status ok, got %v", resp["status"])
	}
}

func TestDatabasesEndpoint(t *testing.T) {
	pool := NewDBPool()
	defer pool.Close()
	pool.Register(DBEntry{Alias: "registry", Path: "/fake/path", Kind: "registry"})
	pool.Register(DBEntry{Alias: "ops:myapp", Path: "/fake/path2", Kind: "operations"})
	mux := newTestMux(pool)

	w := doTestRequest(mux, "GET", "/api/databases", "")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp []map[string]any
	decodeTestJSON(t, w, &resp)
	if len(resp) != 2 {
		t.Fatalf("expected 2 databases, got %d", len(resp))
	}
}

func TestQueryEndpoint(t *testing.T) {
	pool, _ := setupTestDB(t)
	defer pool.Close()
	mux := newTestMux(pool)

	body := `{"sql": "SELECT id, name FROM items ORDER BY id"}`
	w := doTestRequest(mux, "POST", "/api/databases/testdb/query", body)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]any
	decodeTestJSON(t, w, &resp)

	// JSON numbers decode as float64 into `any`; check the type explicitly
	// rather than letting a bad payload panic the test.
	count, ok := resp["count"].(float64)
	if !ok || int(count) != 2 {
		t.Fatalf("expected 2 rows, got %v", resp["count"])
	}
	cols, ok := resp["columns"].([]any)
	if !ok || len(cols) != 2 || cols[0] != "id" || cols[1] != "name" {
		t.Fatalf("unexpected columns: %v", resp["columns"])
	}
}

func TestQueryRejectsWrite(t *testing.T) {
	pool, _ := setupTestDB(t)
	defer pool.Close()
	mux := newTestMux(pool)

	cases := []string{
		`{"sql": "INSERT INTO items VALUES ('c', 'gamma', 'third')"}`,
		`{"sql": "UPDATE items SET name = 'x' WHERE id = 'a'"}`,
		`{"sql": "DELETE FROM items WHERE id = 'a'"}`,
		`{"sql": "DROP TABLE items"}`,
	}
	for _, body := range cases {
		w := doTestRequest(mux, "POST", "/api/databases/testdb/query", body)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected 400 for %s, got %d", body, w.Code)
		}
	}
}

func TestTablesEndpoint(t *testing.T) {
	pool, _ := setupTestDB(t)
	defer pool.Close()
	mux := newTestMux(pool)

	w := doTestRequest(mux, "GET", "/api/databases/testdb/tables", "")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp []map[string]any
	decodeTestJSON(t, w, &resp)
	if len(resp) == 0 {
		t.Fatal("expected at least one table")
	}
	found := false
	for _, tbl := range resp {
		if tbl["name"] == "items" {
			found = true
			break
		}
	}
	if !found {
		t.Fatal("expected 'items' table in response")
	}
}

func TestSchemaEndpoint(t *testing.T) {
	pool, _ := setupTestDB(t)
	defer pool.Close()
	mux := newTestMux(pool)

	w := doTestRequest(mux, "GET", "/api/databases/testdb/schema", "")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]any
	decodeTestJSON(t, w, &resp)
	count, ok := resp["count"].(float64)
	if !ok {
		t.Fatalf("expected numeric count, got %v", resp["count"])
	}
	if int(count) == 0 {
		t.Fatal("expected at least one schema statement")
	}
}

func TestNotFoundDB(t *testing.T) {
	pool := NewDBPool()
	defer pool.Close()
	mux := newTestMux(pool)

	w := doTestRequest(mux, "POST", "/api/databases/nonexistent/query", `{"sql": "SELECT 1"}`)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
}

func TestValidateQuery(t *testing.T) {
	valid := []string{
		"SELECT * FROM t",
		"  select id from t",
		"PRAGMA table_info(t)",
		"WITH cte AS (SELECT 1) SELECT * FROM cte",
		"EXPLAIN SELECT * FROM t",
		"-- comment\nSELECT 1",
	}
	for _, q := range valid {
		if err := ValidateQuery(q); err != nil {
			t.Errorf("expected valid: %q, got error: %v", q, err)
		}
	}

	invalid := []string{
		"INSERT INTO t VALUES (1)",
		"UPDATE t SET x = 1",
		"DELETE FROM t",
		"DROP TABLE t",
		"CREATE TABLE t (id INT)",
		"ALTER TABLE t ADD COLUMN x INT",
	}
	for _, q := range invalid {
		if err := ValidateQuery(q); err == nil {
			t.Errorf("expected invalid: %q, got nil error", q)
		}
	}
}

func TestDiscoverDatabases(t *testing.T) {
	dir := t.TempDir()

	// touch creates an empty file at dir/parts..., creating parent directories,
	// and fails the test immediately if the filesystem setup does not succeed.
	touch := func(parts ...string) {
		t.Helper()
		p := filepath.Join(append([]string{dir}, parts...)...)
		if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
			t.Fatalf("mkdir %q: %v", filepath.Dir(p), err)
		}
		if err := os.WriteFile(p, nil, 0o644); err != nil {
			t.Fatalf("write %q: %v", p, err)
		}
	}
	touch("registry.db")
	touch("apps", "myapp", "operations.db")
	touch("projects", "proj1", "apps", "papp", "operations.db")

	entries := DiscoverDatabases(dir)
	if len(entries) != 3 {
		t.Fatalf("expected 3 entries, got %d: %+v", len(entries), entries)
	}
	aliases := make(map[string]bool, len(entries))
	for _, e := range entries {
		aliases[e.Alias] = true
	}
	for _, want := range []string{"registry", "ops:myapp", "ops:papp"} {
		if !aliases[want] {
			t.Errorf("missing alias %q", want)
		}
	}
}