package infra

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// --- helpers ----------------------------------------------------------------

// newTestMCPHandler builds an MCP HTTP handler named "test-server" (version
// "0.0.1") that exposes a single tool, "greet".
//
// The tool handler returns {"hello":"world"} for "greet". For any other tool
// name it returns nil, true, and an "unknown tool" error. The bool is
// presumably the tool-level isError flag; confirm against MCPHTTPOpts.Handler.
// auth is forwarded as MCPHTTPOpts.Auth; pass nil when a test does not
// exercise authentication.
func newTestMCPHandler(auth MCPHTTPAuthFunc) http.Handler {
	tools := []MCPToolDef{
		{
			Name:        "greet",
			Description: "Returns a greeting",
			InputSchema: json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`),
		},
	}
	handler := func(_ context.Context, name string, _ json.RawMessage) (any, bool, error) {
		if name == "greet" {
			return map[string]string{"hello": "world"}, false, nil
		}
		return nil, true, errors.New("unknown tool")
	}
	return MCPHTTPHandler(MCPHTTPOpts{
		Name:    "test-server",
		Version: "0.0.1",
		Tools:   tools,
		Handler: handler,
		Auth:    auth,
	})
}

// postMCP sends body to h as a POST to /mcp with a JSON Content-Type and
// returns the recorded response.
func postMCP(h http.Handler, body string) *httptest.ResponseRecorder {
	r := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(body))
	r.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w
}

// decodeMCPResponse fails the test unless w has HTTP status 200 and a body
// that decodes as a single jsonrpcResponse, which it returns.
func decodeMCPResponse(t *testing.T, w *httptest.ResponseRecorder) jsonrpcResponse {
	t.Helper()
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp jsonrpcResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decode: %v", err)
	}
	return resp
}

// mcpResultObject fails the test if resp carries a JSON-RPC error or if its
// result is not a JSON object; otherwise it returns the result object.
// The comma-ok assertion keeps a malformed response from panicking, since a
// panic would abort the whole test binary instead of failing this one test.
func mcpResultObject(t *testing.T, resp jsonrpcResponse) map[string]any {
	t.Helper()
	if resp.Error != nil {
		t.Fatalf("unexpected rpc error: %+v", resp.Error)
	}
	result, ok := resp.Result.(map[string]any)
	if !ok {
		t.Fatalf("result is not an object: %T", resp.Result)
	}
	return result
}

// --- tests ------------------------------------------------------------------

func TestMCPHTTPHandler_Initialize(t *testing.T) {
	h := newTestMCPHandler(nil)
	w := postMCP(h, `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}`)
	result := mcpResultObject(t, decodeMCPResponse(t, w))

	if _, ok := result["protocolVersion"]; !ok {
		t.Error("missing protocolVersion in result")
	}
	si, ok := result["serverInfo"].(map[string]any)
	if !ok {
		t.Fatalf("missing or malformed serverInfo: %v", result["serverInfo"])
	}
	if si["name"] != "test-server" {
		t.Errorf("serverInfo.name = %v, want test-server", si["name"])
	}
}

func TestMCPHTTPHandler_ToolsList(t *testing.T) {
	h := newTestMCPHandler(nil)
	w := postMCP(h, `{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`)
	result := mcpResultObject(t, decodeMCPResponse(t, w))

	tools, ok := result["tools"].([]any)
	if !ok || len(tools) == 0 {
		t.Fatalf("expected non-empty tools array, got %v", result["tools"])
	}
}

func TestMCPHTTPHandler_ToolsCall(t *testing.T) {
	h := newTestMCPHandler(nil)
	w := postMCP(h, `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"greet","arguments":{}}}`)
	result := mcpResultObject(t, decodeMCPResponse(t, w))

	content, ok := result["content"].([]any)
	if !ok || len(content) == 0 {
		t.Fatalf("expected content array, got %v", result["content"])
	}
	first, ok := content[0].(map[string]any)
	if !ok {
		t.Fatalf("content[0] is not an object: %T", content[0])
	}
	// The tool's map result is expected to come back JSON-encoded in "text".
	const want = `{"hello":"world"}`
	if first["text"] != want {
		t.Errorf("content[0].text = %v, want %s", first["text"], want)
	}
}

func TestMCPHTTPHandler_BadAuth(t *testing.T) {
	auth := func(_ *http.Request) (context.Context, error) {
		return nil, errors.New("bad token")
	}
	h := newTestMCPHandler(auth)
	w := postMCP(h, `{"jsonrpc":"2.0","id":4,"method":"initialize","params":{}}`)
	if w.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", w.Code)
	}
}

func TestMCPHTTPHandler_BodyTooLarge(t *testing.T) {
	h := newTestMCPHandler(nil)
	// The padding alone exceeds the limit, so the full envelope certainly does.
	big := strings.Repeat("x", mcpHTTPBodyLimit+1)
	body := `{"jsonrpc":"2.0","id":5,"method":"initialize","params":{"x":"` + big + `"}}`
	w := postMCP(h, body)
	if w.Code != http.StatusRequestEntityTooLarge {
		t.Fatalf("expected 413, got %d", w.Code)
	}
}

func TestMCPHTTPHandler_ParseError(t *testing.T) {
	h := newTestMCPHandler(nil)
	w := postMCP(h, `not valid json`)
	// A parse error is reported in-band as a JSON-RPC error over HTTP 200.
	resp := decodeMCPResponse(t, w)
	if resp.Error == nil {
		t.Fatal("expected JSON-RPC error, got nil")
	}
	if resp.Error.Code != -32700 {
		t.Errorf("error code = %d, want -32700", resp.Error.Code)
	}
}

func TestMCPHTTPHandler_Notification(t *testing.T) {
	h := newTestMCPHandler(nil)
	// A notification has no "id" key at all, so no JSON-RPC response is owed:
	// the handler should answer 202 Accepted with an empty body.
	w := postMCP(h, `{"jsonrpc":"2.0","method":"initialized","params":{}}`)
	if w.Code != http.StatusAccepted {
		t.Fatalf("expected 202 for notification, got %d", w.Code)
	}
	if w.Body.Len() != 0 {
		t.Errorf("expected empty body for notification, got %q", w.Body.String())
	}
}

func TestMCPHTTPHandler_MethodNotAllowed(t *testing.T) {
	h := newTestMCPHandler(nil)
	for _, method := range []string{http.MethodGet, http.MethodDelete} {
		t.Run(method, func(t *testing.T) {
			r := httptest.NewRequest(method, "/mcp", nil)
			w := httptest.NewRecorder()
			h.ServeHTTP(w, r)
			if w.Code != http.StatusMethodNotAllowed {
				t.Errorf("%s: expected 405, got %d", method, w.Code)
			}
		})
	}
}