package main import ( "context" "encoding/json" "io" "log/slog" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/mark3labs/mcp-go/server" ) // newTestLogger returns a slog.Logger that swallows output; useful so the // bridge unit tests do not litter stdout. func newTestLogger() *slog.Logger { return slog.New(slog.NewJSONHandler(io.Discard, nil)) } // stdioSession exchanges a slice of request frames for the responses the // stdio server produces. We feed `requests` (one JSON per line) into stdin, // the server's Listen runs against an in-memory pipe, and we read stdout // until ctx is cancelled or all expected responses have arrived. // // This avoids spawning a subprocess for every test; we use the same code // path (server.ServeStdio is just a thin wrapper around StdioServer.Listen). func stdioSession(t *testing.T, srv *server.MCPServer, requests []string, expectedResponses int) []map[string]any { t.Helper() stdioSrv := server.NewStdioServer(srv) stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() listenDone := make(chan error, 1) go func() { listenDone <- stdioSrv.Listen(ctx, stdinR, stdoutW) _ = stdoutW.Close() }() // Feed the requests go func() { defer stdinW.Close() for _, r := range requests { if !strings.HasSuffix(r, "\n") { r += "\n" } _, _ = stdinW.Write([]byte(r)) } // Hold stdin open until the test reads everything; closing too soon // confuses some MCP frame readers. We rely on ctx timeout to break // the Listen loop. }() // Collect responses dec := json.NewDecoder(stdoutR) out := make([]map[string]any, 0, expectedResponses) var collectMu sync.Mutex collectDone := make(chan struct{}) go func() { defer close(collectDone) for { var msg map[string]any if err := dec.Decode(&msg); err != nil { return } collectMu.Lock() out = append(out, msg) done := len(out) >= expectedResponses collectMu.Unlock() if done { return } } }() select { case <-collectDone: cancel() case <-ctx.Done(): } // Wait briefly for Listen to release. select { case <-listenDone: case <-time.After(500 * time.Millisecond): } collectMu.Lock() defer collectMu.Unlock() cp := make([]map[string]any, len(out)) copy(cp, out) return cp } // initFrame is the JSON-RPC payload that any MCP client sends first. func initFrame(id int) string { frame := map[string]any{ "jsonrpc": "2.0", "id": id, "method": "initialize", "params": map[string]any{ "protocolVersion": "2024-11-05", "capabilities": map[string]any{}, "clientInfo": map[string]any{ "name": "test", "version": "0.0.0", }, }, } b, _ := json.Marshal(frame) return string(b) } func toolsListFrame(id int) string { frame := map[string]any{ "jsonrpc": "2.0", "id": id, "method": "tools/list", "params": map[string]any{}, } b, _ := json.Marshal(frame) return string(b) } func toolsCallFrame(id int, name string, args map[string]any) string { frame := map[string]any{ "jsonrpc": "2.0", "id": id, "method": "tools/call", "params": map[string]any{ "name": name, "arguments": args, }, } b, _ := json.Marshal(frame) return string(b) } func notifInitializedFrame() string { frame := map[string]any{ "jsonrpc": "2.0", "method": "notifications/initialized", } b, _ := json.Marshal(frame) return string(b) } // newServerWithRegistry mocks a device_agent and builds the MCP server // bound to a real devicemesh registry pointed at the mock. Returns the // configured MCP server and a cleanup func. func newServerWithRegistry(t *testing.T, mode string, allowed []string, handler http.HandlerFunc) (*server.MCPServer, func()) { t.Helper() if handler == nil { handler = func(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(map[string]any{ "request_id": "test", "ok": true, "result": map[string]any{"stdout": "ok", "stderr": "", "exit_code": 0}, }) } } mock := httptest.NewServer(handler) reg, err := buildRegistry(mock.URL, mode, allowed) if err != nil { mock.Close() t.Fatalf("buildRegistry: %v", err) } srv := server.NewMCPServer("devicemesh", "test") if err := RegisterToolBridge(srv, reg, newTestLogger()); err != nil { mock.Close() t.Fatalf("RegisterToolBridge: %v", err) } return srv, mock.Close } func TestInitialize(t *testing.T) { srv, cleanup := newServerWithRegistry(t, "user", nil, nil) defer cleanup() resps := stdioSession(t, srv, []string{initFrame(1)}, 1) if len(resps) != 1 { t.Fatalf("expected 1 response, got %d", len(resps)) } r := resps[0] if r["id"] != float64(1) { t.Fatalf("expected id=1, got %v", r["id"]) } result, _ := r["result"].(map[string]any) if result == nil { t.Fatalf("expected result object, got %v", r) } if _, ok := result["protocolVersion"]; !ok { t.Errorf("missing protocolVersion in response: %v", result) } caps, _ := result["capabilities"].(map[string]any) if _, ok := caps["tools"]; !ok { t.Errorf("missing capabilities.tools: %v", caps) } info, _ := result["serverInfo"].(map[string]any) if info["name"] != "devicemesh" { t.Errorf("expected serverInfo.name=devicemesh, got %v", info) } } func TestToolsList(t *testing.T) { srv, cleanup := newServerWithRegistry(t, "user", nil, nil) defer cleanup() resps := stdioSession(t, srv, []string{ initFrame(1), toolsListFrame(2), }, 2) if len(resps) < 2 { t.Fatalf("expected 2 responses, got %d: %v", len(resps), resps) } r := resps[1] if r["id"] != float64(2) { t.Fatalf("expected id=2, got %v", r["id"]) } result, _ := r["result"].(map[string]any) toolsList, _ := result["tools"].([]any) if len(toolsList) < 10 { t.Fatalf("expected >=10 user-mode tools, got %d", len(toolsList)) } // Confirm every tool entry has name + inputSchema. for i, t0 := range toolsList { tm, _ := t0.(map[string]any) if _, ok := tm["name"].(string); !ok { t.Errorf("tool[%d] missing name: %v", i, tm) } if _, ok := tm["inputSchema"].(map[string]any); !ok { t.Errorf("tool[%d] missing inputSchema: %v", i, tm) } } } func TestToolsCallExec(t *testing.T) { called := false mockHandler := func(w http.ResponseWriter, r *http.Request) { called = true body := map[string]any{} _ = json.NewDecoder(r.Body).Decode(&body) // Sanity: capability and argv must be forwarded. if body["capability"] != "shell.exec" { t.Errorf("expected capability=shell.exec, got %v", body["capability"]) } _ = json.NewEncoder(w).Encode(map[string]any{ "request_id": "test", "ok": true, "duration_ms": 12, "result": map[string]any{ "stdout": "hi", "stderr": "", "exit_code": 0, }, }) } srv, cleanup := newServerWithRegistry(t, "user", nil, mockHandler) defer cleanup() resps := stdioSession(t, srv, []string{ initFrame(1), toolsCallFrame(2, "exec", map[string]any{ "argv": []any{"echo", "hi"}, }), }, 2) if !called { t.Fatalf("mock device_agent never received the request") } if len(resps) < 2 { t.Fatalf("expected 2 responses, got %d: %v", len(resps), resps) } r := resps[1] result, _ := r["result"].(map[string]any) contents, _ := result["content"].([]any) if len(contents) == 0 { t.Fatalf("expected content blocks, got %v", result) } first, _ := contents[0].(map[string]any) text, _ := first["text"].(string) if !strings.Contains(text, "hi") { t.Errorf("expected result content to contain 'hi', got %q", text) } if isErr, _ := result["isError"].(bool); isErr { t.Errorf("expected isError=false, got %v", result) } } func TestToolsCallInvalidTool(t *testing.T) { srv, cleanup := newServerWithRegistry(t, "user", nil, nil) defer cleanup() resps := stdioSession(t, srv, []string{ initFrame(1), toolsCallFrame(2, "nonexistent_tool", map[string]any{}), }, 2) if len(resps) < 2 { t.Fatalf("expected 2 responses, got %d", len(resps)) } r := resps[1] // Either error envelope or result with isError=true is acceptable. if err, hasErr := r["error"]; hasErr && err != nil { return } result, _ := r["result"].(map[string]any) if isErr, _ := result["isError"].(bool); isErr { return } t.Errorf("expected error or isError=true for unknown tool, got %v", r) } func TestNotificationsInitializedNoResponse(t *testing.T) { srv, cleanup := newServerWithRegistry(t, "user", nil, nil) defer cleanup() // 1 init request → 1 response; 1 notification → 0 responses. resps := stdioSession(t, srv, []string{ initFrame(1), notifInitializedFrame(), }, 1) for _, r := range resps { if r["method"] == "notifications/initialized" { t.Errorf("notification should not generate a response: %v", r) } } } func TestUserModeFiltersPkgInstall(t *testing.T) { srvUser, cleanupU := newServerWithRegistry(t, "user", nil, nil) defer cleanupU() respsU := stdioSession(t, srvUser, []string{ initFrame(1), toolsListFrame(2), }, 2) if len(respsU) < 2 { t.Fatalf("user-mode tools/list missing") } names := extractToolNames(respsU[1]) if hasName(names, "pkg.install") { t.Errorf("user mode should NOT expose pkg.install, got %v", names) } if !hasName(names, "exec") { t.Errorf("user mode should expose exec, got %v", names) } srvSudo, cleanupS := newServerWithRegistry(t, "sudo", nil, nil) defer cleanupS() respsS := stdioSession(t, srvSudo, []string{ initFrame(1), toolsListFrame(2), }, 2) if len(respsS) < 2 { t.Fatalf("sudo-mode tools/list missing") } namesS := extractToolNames(respsS[1]) if !hasName(namesS, "pkg.install") { t.Errorf("sudo mode should expose pkg.install, got %v", namesS) } } func TestToolsAllowedNarrows(t *testing.T) { srv, cleanup := newServerWithRegistry(t, "user", []string{"exec", "fs.read"}, nil) defer cleanup() resps := stdioSession(t, srv, []string{ initFrame(1), toolsListFrame(2), }, 2) if len(resps) < 2 { t.Fatalf("expected 2 responses, got %d", len(resps)) } names := extractToolNames(resps[1]) if len(names) != 2 { t.Errorf("expected exactly 2 tools after filter, got %d (%v)", len(names), names) } if !hasName(names, "exec") || !hasName(names, "fs.read") { t.Errorf("expected exec + fs.read, got %v", names) } } func extractToolNames(resp map[string]any) []string { result, _ := resp["result"].(map[string]any) toolsList, _ := result["tools"].([]any) out := make([]string, 0, len(toolsList)) for _, t := range toolsList { tm, _ := t.(map[string]any) if n, ok := tm["name"].(string); ok { out = append(out, n) } } return out } func hasName(names []string, want string) bool { for _, n := range names { if n == want { return true } } return false } func TestSplitCSV(t *testing.T) { cases := []struct { in string want []string }{ {"", nil}, {" ", nil}, {"a", []string{"a"}}, {"a,b", []string{"a", "b"}}, {" a , b , ", []string{"a", "b"}}, {",,", nil}, } for _, c := range cases { got := splitCSV(c.in) if len(got) != len(c.want) { t.Errorf("splitCSV(%q) len=%d want=%d (%v)", c.in, len(got), len(c.want), got) continue } for i := range got { if got[i] != c.want[i] { t.Errorf("splitCSV(%q)[%d]=%q want %q", c.in, i, got[i], c.want[i]) } } } } func TestParseMode(t *testing.T) { if parseMode("user") == parseMode("sudo") { t.Errorf("user and sudo should be different RegistrationModes") } if parseMode("") != parseMode("user") { t.Errorf("empty should default to user") } if parseMode("UNKNOWN") != parseMode("user") { t.Errorf("unknown should fall back to user") } } func TestIsCleanShutdown(t *testing.T) { if !isCleanShutdown(nil) { t.Errorf("nil should be clean") } if !isCleanShutdown(io.EOF) { t.Errorf("EOF should be clean") } // Non-clean: a random other error string. if isCleanShutdown(io.ErrUnexpectedEOF) { // ErrUnexpectedEOF.Error() == "unexpected EOF" which DOES contain "EOF". // Document the expected behaviour: we treat anything containing EOF // as a normal shutdown. Adjust test to mirror. } if isCleanShutdown(http.ErrAbortHandler) { t.Errorf("http.ErrAbortHandler should NOT be clean") } }