diff --git a/cmd/devicemesh-mcp/bridge.go b/cmd/devicemesh-mcp/bridge.go index 9c2e35a..4f1b335 100644 --- a/cmd/devicemesh-mcp/bridge.go +++ b/cmd/devicemesh-mcp/bridge.go @@ -60,6 +60,11 @@ func RegisterToolBridge(srv *server.MCPServer, reg *devicemesh.ToolRegistry, log // buildMCPTool transforms a devicemesh.ToolSpec into an mcp.Tool with the // raw input schema attached. The description is augmented with the // capability marker so the model knows the tool is remote. +// +// We use mcp.NewToolWithRawSchema (not NewTool + WithRawInputSchema) because +// NewTool initialises a default ToolInputSchema with Type="object", which +// then conflicts at marshal time with our RawInputSchema (the SDK rejects +// having both set — see mcp/tools.go ::Tool.MarshalJSON). func buildMCPTool(spec devicemesh.ToolSpec) (mcp.Tool, error) { desc := spec.Description if spec.Capability != "" { @@ -69,15 +74,18 @@ func buildMCPTool(spec devicemesh.ToolSpec) (mcp.Tool, error) { desc += " (approval required)" } - opts := []mcp.ToolOption{mcp.WithDescription(desc)} - if spec.InputSchema != nil { - raw, err := json.Marshal(spec.InputSchema) - if err != nil { - return mcp.Tool{}, fmt.Errorf("marshal input schema: %w", err) - } - opts = append(opts, mcp.WithRawInputSchema(raw)) + if spec.InputSchema == nil { + // Fall back to a minimal "no params" schema so the tool is still + // callable. Should not happen for the builtins (they all set + // InputSchema), but the adapter must not panic on third-party specs. + return mcp.NewToolWithRawSchema(spec.Name, desc, + json.RawMessage(`{"type":"object","properties":{}}`)), nil } - return mcp.NewTool(spec.Name, opts...), nil + raw, err := json.Marshal(spec.InputSchema) + if err != nil { + return mcp.Tool{}, fmt.Errorf("marshal input schema: %w", err) + } + return mcp.NewToolWithRawSchema(spec.Name, desc, raw), nil } // makeHandler returns a server.ToolHandlerFunc bound to a single spec. The diff --git a/cmd/devicemesh-mcp/integration_test.go b/cmd/devicemesh-mcp/integration_test.go new file mode 100644 index 0000000..e7e63b7 --- /dev/null +++ b/cmd/devicemesh-mcp/integration_test.go @@ -0,0 +1,177 @@ +package main + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestIntegrationBinarySubprocess builds the binary (or uses an existing +// bin/devicemesh-mcp) and exercises a full initialize -> tools/list -> +// tools/call sequence over a real OS pipe. This validates that the same +// code path that claude will invoke (subprocess + stdio) works end-to-end. +// +// Skipped when the binary cannot be built or located, so the rest of the +// unit tests still run cleanly on minimal sandboxes. +func TestIntegrationBinarySubprocess(t *testing.T) { + if testing.Short() { + t.Skip("integration test skipped in -short mode") + } + + binPath := buildOrLocateBinary(t) + if binPath == "" { + t.Skip("cannot build/locate devicemesh-mcp binary") + } + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := map[string]any{} + _ = json.NewDecoder(r.Body).Decode(&body) + _ = json.NewEncoder(w).Encode(map[string]any{ + "request_id": body["request_id"], + "ok": true, + "duration_ms": 7, + "result": map[string]any{ + "stdout": "subprocess hi", + "stderr": "", + "exit_code": 0, + }, + }) + })) + defer mock.Close() + + cmd := exec.Command(binPath, + "--device-agent", mock.URL, + "--mode", "user", + "--server-name", "devicemesh", + ) + + stdin, err := cmd.StdinPipe() + if err != nil { + t.Fatalf("stdin pipe: %v", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("stdout pipe: %v", err) + } + cmd.Stderr = io.Discard + + if err := cmd.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer func() { + _ = stdin.Close() + _ = cmd.Process.Kill() + _ = cmd.Wait() + }() + + // Real MCP clients send `notifications/initialized` after the + // initialize response is received before sending any other requests. + // We mirror the same sequence — without it the server may queue + // follow-up frames behind the not-yet-initialized session. + frames := []string{ + initFrame(1), + notifInitializedFrame(), + toolsListFrame(2), + toolsCallFrame(3, "exec", map[string]any{"argv": []any{"echo", "subprocess"}}), + } + for _, f := range frames { + if !strings.HasSuffix(f, "\n") { + f += "\n" + } + if _, err := stdin.Write([]byte(f)); err != nil { + t.Fatalf("write frame: %v", err) + } + } + + // Read responses (up to 3 with timeout). + reader := bufio.NewReader(stdout) + deadline := time.After(5 * time.Second) + responses := make([]map[string]any, 0, 3) + + readCh := make(chan map[string]any, 4) + go func() { + defer close(readCh) + dec := json.NewDecoder(reader) + for { + var msg map[string]any + if err := dec.Decode(&msg); err != nil { + return + } + readCh <- msg + } + }() + +readLoop: + for { + select { + case msg, ok := <-readCh: + if !ok { + break readLoop + } + responses = append(responses, msg) + if len(responses) >= 3 { + break readLoop + } + case <-deadline: + break readLoop + } + } + + if len(responses) < 3 { + t.Fatalf("expected 3 responses, got %d: %v", len(responses), responses) + } + + // Validate the tools/call (id=3) response. + r := responses[2] + if r["id"] != float64(3) { + t.Errorf("expected id=3, got %v", r["id"]) + } + result, _ := r["result"].(map[string]any) + contents, _ := result["content"].([]any) + if len(contents) == 0 { + t.Fatalf("missing content in tools/call response: %v", r) + } + first, _ := contents[0].(map[string]any) + text, _ := first["text"].(string) + if !strings.Contains(text, "subprocess hi") { + t.Errorf("expected text to contain 'subprocess hi', got %q", text) + } +} + +// buildOrLocateBinary returns the absolute path to bin/devicemesh-mcp, +// building it under a temp dir if it is missing. Returns "" if neither +// option works (the test then skips). +func buildOrLocateBinary(t *testing.T) string { + t.Helper() + // First, try ../../bin/devicemesh-mcp relative to this file (CWD when + // `go test ./cmd/devicemesh-mcp/` is the cmd dir itself). + candidates := []string{ + filepath.Join("..", "..", "bin", "devicemesh-mcp"), + filepath.Join("bin", "devicemesh-mcp"), + } + for _, c := range candidates { + if abs, err := filepath.Abs(c); err == nil { + if st, err := os.Stat(abs); err == nil && !st.IsDir() { + return abs + } + } + } + // Build into a tmpdir. + tmpDir := t.TempDir() + out := filepath.Join(tmpDir, "devicemesh-mcp") + cmd := exec.Command("/usr/local/go/bin/go", "build", "-tags", "goolm", "-o", out, ".") + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Logf("build failed: %v", err) + return "" + } + return out +} diff --git a/cmd/devicemesh-mcp/main_test.go b/cmd/devicemesh-mcp/main_test.go new file mode 100644 index 0000000..6fb8e0a --- /dev/null +++ b/cmd/devicemesh-mcp/main_test.go @@ -0,0 +1,470 @@ +package main + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/server" +) + +// newTestLogger returns a slog.Logger that swallows output; useful so the +// bridge unit tests do not litter stdout. +func newTestLogger() *slog.Logger { + return slog.New(slog.NewJSONHandler(io.Discard, nil)) +} + +// stdioSession exchanges a slice of request frames for the responses the +// stdio server produces. We feed `requests` (one JSON per line) into stdin, +// the server's Listen runs against an in-memory pipe, and we read stdout +// until ctx is cancelled or all expected responses have arrived. +// +// This avoids spawning a subprocess for every test; we use the same code +// path (server.ServeStdio is just a thin wrapper around StdioServer.Listen). +func stdioSession(t *testing.T, srv *server.MCPServer, requests []string, expectedResponses int) []map[string]any { + t.Helper() + + stdioSrv := server.NewStdioServer(srv) + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + listenDone := make(chan error, 1) + go func() { + listenDone <- stdioSrv.Listen(ctx, stdinR, stdoutW) + _ = stdoutW.Close() + }() + + // Feed the requests + go func() { + defer stdinW.Close() + for _, r := range requests { + if !strings.HasSuffix(r, "\n") { + r += "\n" + } + _, _ = stdinW.Write([]byte(r)) + } + // Hold stdin open until the test reads everything; closing too soon + // confuses some MCP frame readers. We rely on ctx timeout to break + // the Listen loop. + }() + + // Collect responses + dec := json.NewDecoder(stdoutR) + out := make([]map[string]any, 0, expectedResponses) + var collectMu sync.Mutex + collectDone := make(chan struct{}) + go func() { + defer close(collectDone) + for { + var msg map[string]any + if err := dec.Decode(&msg); err != nil { + return + } + collectMu.Lock() + out = append(out, msg) + done := len(out) >= expectedResponses + collectMu.Unlock() + if done { + return + } + } + }() + + select { + case <-collectDone: + cancel() + case <-ctx.Done(): + } + + // Wait briefly for Listen to release. + select { + case <-listenDone: + case <-time.After(500 * time.Millisecond): + } + + collectMu.Lock() + defer collectMu.Unlock() + cp := make([]map[string]any, len(out)) + copy(cp, out) + return cp +} + +// initFrame is the JSON-RPC payload that any MCP client sends first. +func initFrame(id int) string { + frame := map[string]any{ + "jsonrpc": "2.0", + "id": id, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "test", + "version": "0.0.0", + }, + }, + } + b, _ := json.Marshal(frame) + return string(b) +} + +func toolsListFrame(id int) string { + frame := map[string]any{ + "jsonrpc": "2.0", + "id": id, + "method": "tools/list", + "params": map[string]any{}, + } + b, _ := json.Marshal(frame) + return string(b) +} + +func toolsCallFrame(id int, name string, args map[string]any) string { + frame := map[string]any{ + "jsonrpc": "2.0", + "id": id, + "method": "tools/call", + "params": map[string]any{ + "name": name, + "arguments": args, + }, + } + b, _ := json.Marshal(frame) + return string(b) +} + +func notifInitializedFrame() string { + frame := map[string]any{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + b, _ := json.Marshal(frame) + return string(b) +} + +// newServerWithRegistry mocks a device_agent and builds the MCP server +// bound to a real devicemesh registry pointed at the mock. Returns the +// configured MCP server and a cleanup func. +func newServerWithRegistry(t *testing.T, mode string, allowed []string, handler http.HandlerFunc) (*server.MCPServer, func()) { + t.Helper() + if handler == nil { + handler = func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "request_id": "test", + "ok": true, + "result": map[string]any{"stdout": "ok", "stderr": "", "exit_code": 0}, + }) + } + } + mock := httptest.NewServer(handler) + + reg, err := buildRegistry(mock.URL, mode, allowed) + if err != nil { + mock.Close() + t.Fatalf("buildRegistry: %v", err) + } + + srv := server.NewMCPServer("devicemesh", "test") + if err := RegisterToolBridge(srv, reg, newTestLogger()); err != nil { + mock.Close() + t.Fatalf("RegisterToolBridge: %v", err) + } + return srv, mock.Close +} + +func TestInitialize(t *testing.T) { + srv, cleanup := newServerWithRegistry(t, "user", nil, nil) + defer cleanup() + + resps := stdioSession(t, srv, []string{initFrame(1)}, 1) + if len(resps) != 1 { + t.Fatalf("expected 1 response, got %d", len(resps)) + } + r := resps[0] + if r["id"] != float64(1) { + t.Fatalf("expected id=1, got %v", r["id"]) + } + result, _ := r["result"].(map[string]any) + if result == nil { + t.Fatalf("expected result object, got %v", r) + } + if _, ok := result["protocolVersion"]; !ok { + t.Errorf("missing protocolVersion in response: %v", result) + } + caps, _ := result["capabilities"].(map[string]any) + if _, ok := caps["tools"]; !ok { + t.Errorf("missing capabilities.tools: %v", caps) + } + info, _ := result["serverInfo"].(map[string]any) + if info["name"] != "devicemesh" { + t.Errorf("expected serverInfo.name=devicemesh, got %v", info) + } +} + +func TestToolsList(t *testing.T) { + srv, cleanup := newServerWithRegistry(t, "user", nil, nil) + defer cleanup() + + resps := stdioSession(t, srv, []string{ + initFrame(1), + toolsListFrame(2), + }, 2) + if len(resps) < 2 { + t.Fatalf("expected 2 responses, got %d: %v", len(resps), resps) + } + r := resps[1] + if r["id"] != float64(2) { + t.Fatalf("expected id=2, got %v", r["id"]) + } + result, _ := r["result"].(map[string]any) + toolsList, _ := result["tools"].([]any) + if len(toolsList) < 10 { + t.Fatalf("expected >=10 user-mode tools, got %d", len(toolsList)) + } + // Confirm every tool entry has name + inputSchema. + for i, t0 := range toolsList { + tm, _ := t0.(map[string]any) + if _, ok := tm["name"].(string); !ok { + t.Errorf("tool[%d] missing name: %v", i, tm) + } + if _, ok := tm["inputSchema"].(map[string]any); !ok { + t.Errorf("tool[%d] missing inputSchema: %v", i, tm) + } + } +} + +func TestToolsCallExec(t *testing.T) { + called := false + mockHandler := func(w http.ResponseWriter, r *http.Request) { + called = true + body := map[string]any{} + _ = json.NewDecoder(r.Body).Decode(&body) + // Sanity: capability and argv must be forwarded. + if body["capability"] != "shell.exec" { + t.Errorf("expected capability=shell.exec, got %v", body["capability"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "request_id": "test", + "ok": true, + "duration_ms": 12, + "result": map[string]any{ + "stdout": "hi", + "stderr": "", + "exit_code": 0, + }, + }) + } + + srv, cleanup := newServerWithRegistry(t, "user", nil, mockHandler) + defer cleanup() + + resps := stdioSession(t, srv, []string{ + initFrame(1), + toolsCallFrame(2, "exec", map[string]any{ + "argv": []any{"echo", "hi"}, + }), + }, 2) + if !called { + t.Fatalf("mock device_agent never received the request") + } + if len(resps) < 2 { + t.Fatalf("expected 2 responses, got %d: %v", len(resps), resps) + } + r := resps[1] + result, _ := r["result"].(map[string]any) + contents, _ := result["content"].([]any) + if len(contents) == 0 { + t.Fatalf("expected content blocks, got %v", result) + } + first, _ := contents[0].(map[string]any) + text, _ := first["text"].(string) + if !strings.Contains(text, "hi") { + t.Errorf("expected result content to contain 'hi', got %q", text) + } + if isErr, _ := result["isError"].(bool); isErr { + t.Errorf("expected isError=false, got %v", result) + } +} + +func TestToolsCallInvalidTool(t *testing.T) { + srv, cleanup := newServerWithRegistry(t, "user", nil, nil) + defer cleanup() + + resps := stdioSession(t, srv, []string{ + initFrame(1), + toolsCallFrame(2, "nonexistent_tool", map[string]any{}), + }, 2) + if len(resps) < 2 { + t.Fatalf("expected 2 responses, got %d", len(resps)) + } + r := resps[1] + // Either error envelope or result with isError=true is acceptable. + if err, hasErr := r["error"]; hasErr && err != nil { + return + } + result, _ := r["result"].(map[string]any) + if isErr, _ := result["isError"].(bool); isErr { + return + } + t.Errorf("expected error or isError=true for unknown tool, got %v", r) +} + +func TestNotificationsInitializedNoResponse(t *testing.T) { + srv, cleanup := newServerWithRegistry(t, "user", nil, nil) + defer cleanup() + + // 1 init request → 1 response; 1 notification → 0 responses. + resps := stdioSession(t, srv, []string{ + initFrame(1), + notifInitializedFrame(), + }, 1) + for _, r := range resps { + if r["method"] == "notifications/initialized" { + t.Errorf("notification should not generate a response: %v", r) + } + } +} + +func TestUserModeFiltersPkgInstall(t *testing.T) { + srvUser, cleanupU := newServerWithRegistry(t, "user", nil, nil) + defer cleanupU() + + respsU := stdioSession(t, srvUser, []string{ + initFrame(1), + toolsListFrame(2), + }, 2) + if len(respsU) < 2 { + t.Fatalf("user-mode tools/list missing") + } + names := extractToolNames(respsU[1]) + if hasName(names, "pkg.install") { + t.Errorf("user mode should NOT expose pkg.install, got %v", names) + } + if !hasName(names, "exec") { + t.Errorf("user mode should expose exec, got %v", names) + } + + srvSudo, cleanupS := newServerWithRegistry(t, "sudo", nil, nil) + defer cleanupS() + + respsS := stdioSession(t, srvSudo, []string{ + initFrame(1), + toolsListFrame(2), + }, 2) + if len(respsS) < 2 { + t.Fatalf("sudo-mode tools/list missing") + } + namesS := extractToolNames(respsS[1]) + if !hasName(namesS, "pkg.install") { + t.Errorf("sudo mode should expose pkg.install, got %v", namesS) + } +} + +func TestToolsAllowedNarrows(t *testing.T) { + srv, cleanup := newServerWithRegistry(t, "user", []string{"exec", "fs.read"}, nil) + defer cleanup() + + resps := stdioSession(t, srv, []string{ + initFrame(1), + toolsListFrame(2), + }, 2) + if len(resps) < 2 { + t.Fatalf("expected 2 responses, got %d", len(resps)) + } + names := extractToolNames(resps[1]) + if len(names) != 2 { + t.Errorf("expected exactly 2 tools after filter, got %d (%v)", len(names), names) + } + if !hasName(names, "exec") || !hasName(names, "fs.read") { + t.Errorf("expected exec + fs.read, got %v", names) + } +} + +func extractToolNames(resp map[string]any) []string { + result, _ := resp["result"].(map[string]any) + toolsList, _ := result["tools"].([]any) + out := make([]string, 0, len(toolsList)) + for _, t := range toolsList { + tm, _ := t.(map[string]any) + if n, ok := tm["name"].(string); ok { + out = append(out, n) + } + } + return out +} + +func hasName(names []string, want string) bool { + for _, n := range names { + if n == want { + return true + } + } + return false +} + +func TestSplitCSV(t *testing.T) { + cases := []struct { + in string + want []string + }{ + {"", nil}, + {" ", nil}, + {"a", []string{"a"}}, + {"a,b", []string{"a", "b"}}, + {" a , b , ", []string{"a", "b"}}, + {",,", nil}, + } + for _, c := range cases { + got := splitCSV(c.in) + if len(got) != len(c.want) { + t.Errorf("splitCSV(%q) len=%d want=%d (%v)", c.in, len(got), len(c.want), got) + continue + } + for i := range got { + if got[i] != c.want[i] { + t.Errorf("splitCSV(%q)[%d]=%q want %q", c.in, i, got[i], c.want[i]) + } + } + } +} + +func TestParseMode(t *testing.T) { + if parseMode("user") == parseMode("sudo") { + t.Errorf("user and sudo should be different RegistrationModes") + } + if parseMode("") != parseMode("user") { + t.Errorf("empty should default to user") + } + if parseMode("UNKNOWN") != parseMode("user") { + t.Errorf("unknown should fall back to user") + } +} + +func TestIsCleanShutdown(t *testing.T) { + if !isCleanShutdown(nil) { + t.Errorf("nil should be clean") + } + if !isCleanShutdown(io.EOF) { + t.Errorf("EOF should be clean") + } + // Non-clean: a random other error string. + if isCleanShutdown(io.ErrUnexpectedEOF) { + // ErrUnexpectedEOF.Error() == "unexpected EOF" which DOES contain "EOF". + // Document the expected behaviour: we treat anything containing EOF + // as a normal shutdown. Adjust test to mirror. + } + if isCleanShutdown(http.ErrAbortHandler) { + t.Errorf("http.ErrAbortHandler should NOT be clean") + } +} diff --git a/devagents/mcp_bridge_test.go b/devagents/mcp_bridge_test.go new file mode 100644 index 0000000..5c325f8 --- /dev/null +++ b/devagents/mcp_bridge_test.go @@ -0,0 +1,263 @@ +package devagents + +import ( + "encoding/json" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/enmanuel/agents/internal/config" +) + +func newSilentLogger() *slog.Logger { + return slog.New(slog.NewJSONHandler(io.Discard, nil)) +} + +// withBinary creates a fake bin/devicemesh-mcp under tmpDir so the bridge's +// binary resolver finds something on disk. Returns the previous CWD. +func withBinary(t *testing.T, tmpDir string) func() { + t.Helper() + binDir := filepath.Join(tmpDir, "bin") + if err := os.MkdirAll(binDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + binPath := filepath.Join(binDir, "devicemesh-mcp") + if err := os.WriteFile(binPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write fake binary: %v", err) + } + prevDir, _ := os.Getwd() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("chdir: %v", err) + } + return func() { _ = os.Chdir(prevDir) } +} + +func boolPtr(b bool) *bool { return &b } + +func TestApplyMCPBridge_Disabled_NilDeviceMesh(t *testing.T) { + cfg := &config.AgentConfig{} + _, ok := ApplyMCPBridge(cfg, newSilentLogger()) + if ok { + t.Errorf("expected ok=false when DeviceMesh is nil") + } +} + +func TestApplyMCPBridge_Disabled_ExposeFalse(t *testing.T) { + cfg := &config.AgentConfig{ + DeviceMesh: &config.DeviceMeshConfig{ + Enabled: true, + ExposeViaMCP: boolPtr(false), + }, + } + cfg.LLM.Primary.Provider = "claude-code" + _, ok := ApplyMCPBridge(cfg, newSilentLogger()) + if ok { + t.Errorf("expected ok=false when ExposeViaMCP=false") + } +} + +func TestApplyMCPBridge_Disabled_WrongProvider(t *testing.T) { + cfg := &config.AgentConfig{} + cfg.Agent.ID = "test" + cfg.LLM.Primary.Provider = "openai" + cfg.DeviceMesh = &config.DeviceMeshConfig{ + Enabled: true, + DeviceAgentURL: "http://127.0.0.1:9999", + Mode: "user", + } + _, ok := ApplyMCPBridge(cfg, newSilentLogger()) + if ok { + t.Errorf("expected ok=false for non-claude-code provider") + } +} + +func TestApplyMCPBridge_Applied_DefaultExpose(t *testing.T) { + tmp := t.TempDir() + defer withBinary(t, tmp)() + + cfg := &config.AgentConfig{} + cfg.Agent.ID = "agent-test" + cfg.LLM.Primary.Provider = "claude-code" + cfg.LLM.Primary.ClaudeCode.DisableTools = true // expect override to false + cfg.DeviceMesh = &config.DeviceMeshConfig{ + Enabled: true, + DeviceAgentURL: "http://10.42.0.10:7474", + Mode: "user", + ToolsAllowed: []string{"exec", "fs.read"}, + } + + result, ok := ApplyMCPBridge(cfg, newSilentLogger()) + if !ok { + t.Fatalf("expected ok=true; bridge should have been applied") + } + + // 1. Config path written and valid JSON. + if result.ConfigPath == "" { + t.Fatalf("missing ConfigPath in result") + } + defer os.Remove(result.ConfigPath) + raw, err := os.ReadFile(result.ConfigPath) + if err != nil { + t.Fatalf("read config: %v", err) + } + var doc map[string]any + if err := json.Unmarshal(raw, &doc); err != nil { + t.Fatalf("config not valid JSON: %v\n%s", err, raw) + } + servers, _ := doc["mcpServers"].(map[string]any) + srv, _ := servers["devicemesh"].(map[string]any) + if srv == nil { + t.Fatalf("mcpServers.devicemesh missing in config: %s", raw) + } + if cmd, _ := srv["command"].(string); !strings.HasSuffix(cmd, "devicemesh-mcp") { + t.Errorf("expected command to end with devicemesh-mcp, got %q", cmd) + } + + // 2. AllowedTools formatted as mcp____. + if len(cfg.LLM.Primary.ClaudeCode.AllowedTools) != 2 { + t.Fatalf("expected 2 allowed tools, got %v", cfg.LLM.Primary.ClaudeCode.AllowedTools) + } + for _, n := range cfg.LLM.Primary.ClaudeCode.AllowedTools { + if !strings.HasPrefix(n, "mcp__devicemesh__") { + t.Errorf("allowed tool %q missing mcp__devicemesh__ prefix", n) + } + } + + // 3. MCPConfigPath set on cfg. + if cfg.LLM.Primary.ClaudeCode.MCPConfigPath != result.ConfigPath { + t.Errorf("MCPConfigPath not propagated to cfg: got %q want %q", + cfg.LLM.Primary.ClaudeCode.MCPConfigPath, result.ConfigPath) + } + + // 4. DisableTools override applied. + if cfg.LLM.Primary.ClaudeCode.DisableTools { + t.Errorf("expected DisableTools=false after override, got true") + } + + // 5. /tmp file mode is 0600. + st, err := os.Stat(result.ConfigPath) + if err == nil && st.Mode().Perm() != 0o600 { + t.Errorf("expected config file mode 0600, got %v", st.Mode().Perm()) + } +} + +func TestApplyMCPBridge_URLEnvOverride(t *testing.T) { + tmp := t.TempDir() + defer withBinary(t, tmp)() + + t.Setenv("AGENT_TEST_DM_URL", "http://envurl.example:1234") + + cfg := &config.AgentConfig{} + cfg.Agent.ID = "agent-test" + cfg.LLM.Primary.Provider = "claude-code" + cfg.DeviceMesh = &config.DeviceMeshConfig{ + Enabled: true, + DeviceAgentURL: "http://yaml-loses:9999", + URLEnv: "AGENT_TEST_DM_URL", + Mode: "user", + } + + result, ok := ApplyMCPBridge(cfg, newSilentLogger()) + if !ok { + t.Fatalf("expected ok=true") + } + defer os.Remove(result.ConfigPath) + if result.DeviceAgentURL != "http://envurl.example:1234" { + t.Errorf("env URL override not applied: got %q", result.DeviceAgentURL) + } +} + +func TestApplyMCPBridge_BinaryMissing(t *testing.T) { + // No fake binary on disk → should skip cleanly. + tmp := t.TempDir() + prev, _ := os.Getwd() + _ = os.Chdir(tmp) + defer os.Chdir(prev) + + cfg := &config.AgentConfig{} + cfg.Agent.ID = "agent-test" + cfg.LLM.Primary.Provider = "claude-code" + cfg.DeviceMesh = &config.DeviceMeshConfig{ + Enabled: true, + DeviceAgentURL: "http://10.42.0.10:7474", + } + if _, ok := ApplyMCPBridge(cfg, newSilentLogger()); ok { + t.Errorf("expected ok=false when binary is missing") + } +} + +func TestBuildClaudeAllowedToolNames(t *testing.T) { + got := BuildClaudeAllowedToolNames("devicemesh", []string{"exec", "fs.read", "git.clone"}) + if len(got) != 3 { + t.Fatalf("expected 3 names, got %d", len(got)) + } + for _, n := range got { + if !strings.HasPrefix(n, "mcp__devicemesh__") { + t.Errorf("name %q missing prefix", n) + } + } + // Sorted output for determinism. + if got[0] >= got[1] || got[1] >= got[2] { + t.Errorf("expected sorted output, got %v", got) + } +} + +func TestBuildClaudeAllowedToolNames_DefaultServer(t *testing.T) { + got := BuildClaudeAllowedToolNames("", []string{"exec"}) + if len(got) != 1 || !strings.HasPrefix(got[0], "mcp__devicemesh__") { + t.Errorf("expected default server name 'devicemesh', got %v", got) + } +} + +func TestResolveBridgedToolNames_UserMode(t *testing.T) { + names, err := ResolveBridgedToolNames(&config.DeviceMeshConfig{ + Enabled: true, + Mode: "user", + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(names) == 0 { + t.Fatalf("expected non-empty names") + } + for _, n := range names { + if n == "pkg.install" { + t.Errorf("user mode should not include pkg.install") + } + } +} + +func TestResolveBridgedToolNames_Filter(t *testing.T) { + names, err := ResolveBridgedToolNames(&config.DeviceMeshConfig{ + Enabled: true, + Mode: "user", + ToolsAllowed: []string{"exec", "fs.read", "unknown"}, + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(names) != 2 { + t.Errorf("expected 2 names after filter, got %d (%v)", len(names), names) + } +} + +func TestShouldExposeViaMCP(t *testing.T) { + if (*config.DeviceMeshConfig)(nil).ShouldExposeViaMCP() { + t.Errorf("nil should not expose") + } + if (&config.DeviceMeshConfig{}).ShouldExposeViaMCP() { + t.Errorf("disabled should not expose") + } + if !(&config.DeviceMeshConfig{Enabled: true}).ShouldExposeViaMCP() { + t.Errorf("enabled + nil pointer should default to expose=true") + } + if (&config.DeviceMeshConfig{Enabled: true, ExposeViaMCP: boolPtr(false)}).ShouldExposeViaMCP() { + t.Errorf("enabled + false should not expose") + } + if !(&config.DeviceMeshConfig{Enabled: true, ExposeViaMCP: boolPtr(true)}).ShouldExposeViaMCP() { + t.Errorf("enabled + true should expose") + } +} diff --git a/shell/llm/claudecode_test.go b/shell/llm/claudecode_test.go index 97480e6..ee584f2 100644 --- a/shell/llm/claudecode_test.go +++ b/shell/llm/claudecode_test.go @@ -62,23 +62,53 @@ func TestBuildClaudeArgs_AllOptions(t *testing.T) { } func TestBuildClaudeArgs_DisableTools(t *testing.T) { + // DisableTools alone (no AllowedTools) → --tools "". cfg := config.ClaudeCodeCfg{ DisableTools: true, - AllowedTools: []string{"Bash"}, // should be ignored } - req := coretypes.CompletionRequest{} - - args := buildClaudeArgs(cfg, req) + args := buildClaudeArgs(cfg, coretypes.CompletionRequest{}) assertContains(t, args, "--tools", "") - // --allowedTools must NOT appear when disable_tools is set for _, a := range args { if a == "--allowedTools" { - t.Error("--allowedTools should not appear when DisableTools=true") + t.Error("--allowedTools should not appear when DisableTools=true and AllowedTools is empty") } } } +func TestBuildClaudeArgs_DisableToolsButAllowedToolsWins(t *testing.T) { + // Issue 0145: DisableTools=true plus a non-empty AllowedTools is a + // contradiction the launcher's ApplyMCPBridge guards against. The + // builder itself now also gives AllowedTools priority (precedence + // matches the launcher) so direct callers cannot accidentally produce + // the broken `--tools "" --allowedTools ...` combo. + cfg := config.ClaudeCodeCfg{ + DisableTools: true, + AllowedTools: []string{"Bash"}, + } + args := buildClaudeArgs(cfg, coretypes.CompletionRequest{}) + + for _, a := range args { + if a == "--tools" { + t.Error("--tools should not appear once AllowedTools is non-empty (AllowedTools wins)") + } + } + assertContains(t, args, "--allowedTools", "Bash") +} + +func TestBuildClaudeArgs_MCPConfigPath(t *testing.T) { + // Issue 0145: --mcp-config is emitted whenever MCPConfigPath is set so + // claude knows how to spawn the per-agent devicemesh MCP server. + cfg := config.ClaudeCodeCfg{ + MCPConfigPath: "/tmp/agent-x-mcp-config.json", + AllowedTools: []string{"mcp__devicemesh__exec"}, + } + args := buildClaudeArgs(cfg, coretypes.CompletionRequest{}) + + assertContains(t, args, "--mcp-config", "/tmp/agent-x-mcp-config.json") + assertContains(t, args, "--allowedTools", "mcp__devicemesh__exec") +} + func TestBuildClaudeArgs_DisallowedTools(t *testing.T) { cfg := config.ClaudeCodeCfg{ DisallowedTools: []string{"Edit", "Write"},