package infra

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// MCPHTTPAuthFunc validates an incoming HTTP request and returns an enriched
// context (e.g. with user_id) or an error. When it returns an error the
// handler replies 401 Unauthorized without invoking the tool handler.
// Returning a nil context together with a nil error is tolerated: the
// request's own context is used instead.
// If nil, no auth is performed.
type MCPHTTPAuthFunc func(r *http.Request) (context.Context, error)

// MCPHTTPOpts configures the Streamable HTTP MCP handler.
type MCPHTTPOpts struct {
	Name    string          // server name reported to the client in initialize
	Version string          // server version reported to the client in initialize
	Tools   []MCPToolDef    // reuses MCPToolDef from mcp_server_stdio.go
	Handler MCPToolHandler  // reuses MCPToolHandler from mcp_server_stdio.go; tools/call answers -32603 when nil
	Auth    MCPHTTPAuthFunc // optional; if nil, no auth
	Logger  io.Writer       // optional log sink; discards when nil
}

// mcpHTTPBodyLimit caps the size of a single POST body so a client cannot
// make the handler buffer an arbitrarily large payload.
const mcpHTTPBodyLimit = 1 << 20 // 1 MiB

// MCPHTTPHandler returns an http.Handler that implements the Streamable HTTP
// MCP transport (spec 2025-03-26).
//
// Mount at any single path (e.g. /mcp). Handles POST for client→server
// JSON-RPC 2.0 requests. Every other method, GET and DELETE included,
// returns 405 Method Not Allowed with an "Allow: POST" header (SSE
// server→client streaming and session close are not implemented — see
// Gotchas in the .md).
//
// For each POST the handler:
//   - runs opts.Auth when set (401 with a JSON body on rejection);
//   - reads at most mcpHTTPBodyLimit bytes (413 when larger, 400 when the
//     body cannot be read);
//   - decodes a single JSON-RPC message (-32700 on failure);
//   - acknowledges notifications with 202 Accepted and no body;
//   - dispatches initialize, initialized, tools/list, tools/call and ping,
//     answering anything else with -32601.
//
// Protocol-level failures are JSON-RPC error objects sent with HTTP 200.
// JSON-RPC batches (a top-level array) are not supported: they fail to decode
// into a single request and are answered with -32700.
//
// The handler is safe for concurrent use; it carries no shared mutable state.
func MCPHTTPHandler(opts MCPHTTPOpts) http.Handler {
	logf := func(format string, args ...any) {
		if opts.Logger != nil {
			fmt.Fprintf(opts.Logger, "[mcp-http] "+format+"\n", args...)
		}
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			// GET (SSE server→client stream) and DELETE (session close) are
			// not implemented yet. RFC 9110 requires a 405 response to list
			// the methods the resource does support.
			w.Header().Set("Allow", http.MethodPost)
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// Optional auth: validate request and (optionally) enrich context.
		ctx := r.Context()
		if opts.Auth != nil {
			enriched, err := opts.Auth(r)
			if err != nil {
				logf("auth rejected: %v", err)
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusUnauthorized)
				// Best effort: the status is already committed, so a write
				// failure here cannot be reported to the client anyway.
				_ = json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
				return
			}
			// An Auth func returning (nil, nil) must not hand a nil
			// context.Context to the tool handler.
			if enriched != nil {
				ctx = enriched
			}
		}

		// Read body with size limit (anti-DoS). One byte past the limit is
		// requested so an oversized body can be told apart from one that is
		// exactly mcpHTTPBodyLimit bytes long.
		body, err := io.ReadAll(io.LimitReader(r.Body, mcpHTTPBodyLimit+1))
		if err != nil {
			// A transport-level failure (e.g. client disconnect, malformed
			// chunked encoding), not an oversized payload.
			logf("body read error: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if len(body) > mcpHTTPBodyLimit {
			logf("body exceeds %d bytes", mcpHTTPBodyLimit)
			http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
			return
		}
		// NOTE(review): this logs the full request, including tool
		// arguments; make sure opts.Logger is not a sink for secrets.
		logf("recv: %s", body)

		// Parse JSON-RPC request. On parse failure respond HTTP 200 with
		// JSON-RPC error -32700 (per MCP spec — not HTTP 400).
		var req jsonrpcRequest
		if err := json.Unmarshal(body, &req); err != nil {
			logf("json parse error: %v", err)
			writeJSONRPCError(w, nil, -32700, "parse error: "+err.Error())
			return
		}

		// Notifications are recognised by the absence of an "id" key in the
		// raw JSON. The check runs on the bytes rather than on req.ID,
		// presumably so a missing id is not conflated with "id": null.
		// They are acknowledged with 202 Accepted and no body.
		if !jsonHasKey(body, "id") {
			w.WriteHeader(http.StatusAccepted)
			return
		}

		switch req.Method {
		case "initialize":
			// NOTE(review): advertises protocol version 2024-11-05 although
			// the transport follows the 2025-03-26 Streamable HTTP spec, and
			// the client's requested version is not inspected — confirm this
			// is intentional before changing it.
			writeJSONRPCResult(w, req.ID, map[string]any{
				"protocolVersion": "2024-11-05",
				"capabilities": map[string]any{
					"tools": map[string]any{},
				},
				"serverInfo": map[string]any{
					"name":    opts.Name,
					"version": opts.Version,
				},
			})

		case "initialized":
			// Should not arrive as a non-notification, but handle gracefully.
			writeJSONRPCResult(w, req.ID, map[string]any{})

		case "tools/list":
			tools := opts.Tools
			if tools == nil {
				// Encode an empty JSON array rather than null.
				tools = []MCPToolDef{}
			}
			writeJSONRPCResult(w, req.ID, map[string]any{"tools": tools})

		case "tools/call":
			if opts.Handler == nil {
				logf("tools/call received but no Handler is configured")
				writeJSONRPCError(w, req.ID, -32603, "no tool handler configured")
				return
			}
			var p mcpCallParams
			if err := json.Unmarshal(req.Params, &p); err != nil {
				writeJSONRPCError(w, req.ID, -32602, "invalid params: "+err.Error())
				return
			}
			args := p.Arguments
			if args == nil {
				// Tools always receive an object, even when the client
				// omitted "arguments".
				args = json.RawMessage(`{}`)
			}
			toolResult, isErr, handlerErr := opts.Handler(ctx, p.Name, args)
			if handlerErr != nil {
				logf("handler error for %q: %v", p.Name, handlerErr)
				writeJSONRPCError(w, req.ID, -32603, handlerErr.Error())
				return
			}
			// The tool result is sent back as JSON text in a single text
			// content item. A result that cannot be marshaled is an internal
			// error, not an empty successful reply.
			resultText, marshalErr := json.Marshal(toolResult)
			if marshalErr != nil {
				logf("marshal result of %q: %v", p.Name, marshalErr)
				writeJSONRPCError(w, req.ID, -32603, "marshal tool result: "+marshalErr.Error())
				return
			}
			writeJSONRPCResult(w, req.ID, map[string]any{
				"content": []map[string]any{
					{
						"type": "text",
						"text": string(resultText),
					},
				},
				"isError": isErr,
			})

		case "ping":
			writeJSONRPCResult(w, req.ID, map[string]any{})

		default:
			logf("unknown method %q", req.Method)
			writeJSONRPCError(w, req.ID, -32601, "method not found: "+req.Method)
		}
	})
}

// writeJSONRPCResult writes a JSON-RPC 2.0 success response with HTTP 200.
func writeJSONRPCResult(w http.ResponseWriter, id any, result any) {
	w.Header().Set("Content-Type", "application/json")
	// Best effort: this helper has no logger, and a failed write usually
	// means the client has already gone away.
	_ = json.NewEncoder(w).Encode(jsonrpcResponse{
		JSONRPC: "2.0",
		ID:      id,
		Result:  result,
	})
}

// writeJSONRPCError writes a JSON-RPC 2.0 error response with HTTP 200.
// Per the MCP Streamable HTTP spec, protocol errors still use HTTP 200 so
// the client can parse the JSON-RPC error object (not HTTP status codes).
func writeJSONRPCError(w http.ResponseWriter, id any, code int, message string) {
	w.Header().Set("Content-Type", "application/json")
	// Best effort: this helper has no logger, and a failed write usually
	// means the client has already gone away.
	_ = json.NewEncoder(w).Encode(jsonrpcResponse{
		JSONRPC: "2.0",
		ID:      id,
		Error:   &jsonrpcError{Code: code, Message: message},
	})
}