package infra import ( "bufio" "context" "encoding/json" "fmt" "io" "os" "sync" ) // MCPToolDef describes a tool exported by the MCP server. // InputSchema must be a valid JSON Schema object with "type":"object" and // "properties" describing the tool arguments. type MCPToolDef struct { Name string `json:"name"` Description string `json:"description"` InputSchema json.RawMessage `json:"inputSchema"` } // MCPToolHandler executes a tool. input is the raw JSON of the arguments // sent by the MCP client (the value of params.arguments). // Returns result (any JSON-serializable value), isError (true when the tool // itself reports a logical error, not a protocol error), and err (internal // failure that results in a JSON-RPC error response with code -32603). type MCPToolHandler func(ctx context.Context, name string, input json.RawMessage) (result any, isError bool, err error) // MCPServerOpts configures the MCP stdio server. type MCPServerOpts struct { Name string // server name reported to the client in initialize Version string // server version reported to the client in initialize Tools []MCPToolDef Handler MCPToolHandler // single dispatcher for all tools In io.Reader // defaults to os.Stdin when nil Out io.Writer // defaults to os.Stdout when nil Logger io.Writer // optional log sink (e.g. os.Stderr); discards when nil } // jsonrpcRequest is the wire format for an incoming JSON-RPC 2.0 message. type jsonrpcRequest struct { JSONRPC string `json:"jsonrpc"` ID any `json:"id"` // number, string, or null; absent for notifications Method string `json:"method"` Params json.RawMessage `json:"params"` } // jsonrpcResponse is the wire format for an outgoing JSON-RPC 2.0 response. 
type jsonrpcResponse struct { JSONRPC string `json:"jsonrpc"` ID any `json:"id,omitempty"` Result any `json:"result,omitempty"` Error *jsonrpcError `json:"error,omitempty"` } type jsonrpcError struct { Code int `json:"code"` Message string `json:"message"` } // mcpCallParams is params.arguments unwrapped from a tools/call request. type mcpCallParams struct { Name string `json:"name"` Arguments json.RawMessage `json:"arguments"` } // ServeMCP runs the JSON-RPC 2.0 loop over stdio implementing the minimum MCP // protocol surface: initialize, initialized (notification), tools/list, // tools/call, and ping. It reads newline-delimited JSON from opts.In and writes // newline-delimited JSON to opts.Out. // // ServeMCP returns nil when the client closes stdin (EOF) or when ctx is // cancelled. It returns an error only on unrecoverable write failures. func ServeMCP(ctx context.Context, opts MCPServerOpts) error { in := opts.In if in == nil { in = os.Stdin } out := opts.Out if out == nil { out = os.Stdout } logf := func(format string, args ...any) { if opts.Logger != nil { fmt.Fprintf(opts.Logger, "[mcp] "+format+"\n", args...) 
} } var mu sync.Mutex writeLine := func(v any) error { b, err := json.Marshal(v) if err != nil { return fmt.Errorf("mcp marshal: %w", err) } mu.Lock() defer mu.Unlock() if _, err := out.Write(append(b, '\n')); err != nil { return fmt.Errorf("mcp write: %w", err) } return nil } sendResult := func(id any, result any) error { return writeLine(jsonrpcResponse{ JSONRPC: "2.0", ID: id, Result: result, }) } sendError := func(id any, code int, message string) error { return writeLine(jsonrpcResponse{ JSONRPC: "2.0", ID: id, Error: &jsonrpcError{Code: code, Message: message}, }) } scanner := bufio.NewScanner(in) scanner.Buffer(make([]byte, 4*1024*1024), 4*1024*1024) scanCh := make(chan string) scanErr := make(chan error, 1) go func() { for scanner.Scan() { line := scanner.Text() if len(line) == 0 { continue } select { case scanCh <- line: case <-ctx.Done(): return } } if err := scanner.Err(); err != nil { scanErr <- err } close(scanCh) }() for { select { case <-ctx.Done(): logf("context cancelled, stopping") return nil case err := <-scanErr: return fmt.Errorf("mcp scanner: %w", err) case line, ok := <-scanCh: if !ok { logf("stdin closed, stopping") return nil } logf("recv: %s", line) var req jsonrpcRequest if err := json.Unmarshal([]byte(line), &req); err != nil { logf("json parse error: %v", err) // id is unknown; respond with null id if err2 := sendError(nil, -32700, "parse error: "+err.Error()); err2 != nil { return err2 } continue } // Notifications have no id field. After unmarshal, ID is nil only // when the key was absent (not when explicitly null). We distinguish // by checking whether "id" key appears in the raw message. 
isNotification := !jsonHasKey([]byte(line), "id") switch req.Method { case "initialize": if isNotification { continue } result := map[string]any{ "protocolVersion": "2024-11-05", "capabilities": map[string]any{ "tools": map[string]any{}, }, "serverInfo": map[string]any{ "name": opts.Name, "version": opts.Version, }, } if err := sendResult(req.ID, result); err != nil { return err } case "initialized": // notification — ignore, no response case "tools/list": if isNotification { continue } tools := opts.Tools if tools == nil { tools = []MCPToolDef{} } result := map[string]any{ "tools": tools, } if err := sendResult(req.ID, result); err != nil { return err } case "tools/call": if isNotification { continue } var p mcpCallParams if err := json.Unmarshal(req.Params, &p); err != nil { if err2 := sendError(req.ID, -32602, "invalid params: "+err.Error()); err2 != nil { return err2 } continue } args := p.Arguments if args == nil { args = json.RawMessage(`{}`) } toolResult, isErr, handlerErr := opts.Handler(ctx, p.Name, args) if handlerErr != nil { logf("handler error for %q: %v", p.Name, handlerErr) if err2 := sendError(req.ID, -32603, handlerErr.Error()); err2 != nil { return err2 } continue } // Serialize the result value to JSON text for the content block. resultText, _ := json.Marshal(toolResult) callResult := map[string]any{ "content": []map[string]any{ { "type": "text", "text": string(resultText), }, }, "isError": isErr, } if err := sendResult(req.ID, callResult); err != nil { return err } case "ping": if isNotification { continue } if err := sendResult(req.ID, map[string]any{}); err != nil { return err } default: if isNotification { logf("unknown notification %q, ignoring", req.Method) continue } logf("unknown method %q", req.Method) if err2 := sendError(req.ID, -32601, "method not found: "+req.Method); err2 != nil { return err2 } } } } } // jsonHasKey reports whether the JSON object b contains the given top-level key. 
// It returns false whenever b is not a JSON object: invalid JSON, an array,
// a scalar, or null. Only top-level members are examined; nested objects are
// not searched.
func jsonHasKey(b []byte, key string) bool {
	var members map[string]json.RawMessage
	if json.Unmarshal(b, &members) != nil {
		return false
	}
	_, present := members[key]
	return present
}