package devicemesh

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
)

// TestToolRegistry_RegisterListGet checks that registered specs can be looked
// up by name, that unknown names are reported as missing, and that Names
// returns every registered name in sorted order.
func TestToolRegistry_RegisterListGet(t *testing.T) {
	reg := NewToolRegistry(nil)
	// Register out of lexical order so the Names assertion below exercises
	// sorting rather than merely echoing insertion order.
	reg.Register(ToolSpec{Name: "b", Capability: "x.b"})
	reg.Register(ToolSpec{Name: "a", Capability: "x.a"})

	got, ok := reg.Get("a")
	if !ok {
		t.Fatalf("Get(a) not found")
	}
	if got.Capability != "x.a" {
		t.Errorf("capability: %q", got.Capability)
	}
	if _, ok := reg.Get("missing"); ok {
		t.Errorf("Get(missing) unexpectedly found")
	}

	names := reg.Names()
	if len(names) != 2 || names[0] != "a" || names[1] != "b" {
		t.Errorf("Names sort: %v", names)
	}
}

// TestToolRegistry_Call_UnknownTool checks that calling an unregistered tool
// fails with an "unknown tool" error.
func TestToolRegistry_Call_UnknownTool(t *testing.T) {
	reg := NewToolRegistry(NewClient("http://nowhere.invalid"))
	_, err := reg.Call(context.Background(), "no.such.tool", nil)
	if err == nil {
		t.Fatalf("expected error for unknown tool")
	}
	// The message check matters: the unreachable URL would also yield an error.
	if !strings.Contains(err.Error(), "unknown tool") {
		t.Errorf("error message: %v", err)
	}
}

// TestToolRegistry_Call_NilClient checks that Call on a registry built with a
// nil client returns an error.
func TestToolRegistry_Call_NilClient(t *testing.T) {
	reg := NewToolRegistry(nil)
	reg.Register(ToolSpec{Name: "x", Capability: "x.y"})
	_, err := reg.Call(context.Background(), "x", nil)
	if err == nil {
		t.Fatalf("expected error when client is nil")
	}
}

// TestToolRegistry_Call_InvalidInput checks that input violating the tool's
// InputSchema (missing required field, wrong type, extra property) is rejected
// by Call before any request reaches the device.
func TestToolRegistry_Call_InvalidInput(t *testing.T) {
	// Count device round-trips. An unreachable URL would also make Call fail,
	// which would let this test pass even with no schema validation at all;
	// a live server that answers OK makes missing validation observable both
	// as a nil error and as a non-zero hit count.
	var hits int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		// Best-effort reply; the hit count below is the actual assertion.
		_ = json.NewEncoder(w).Encode(CapabilityResponse{OK: true})
	}))
	defer srv.Close()

	reg := NewToolRegistry(NewClient(srv.URL))
	reg.Register(ToolSpec{
		Name:       "needs_string",
		Capability: "x.y",
		InputSchema: map[string]any{
			"type":     "object",
			"required": []string{"foo"},
			"properties": map[string]any{
				"foo": map[string]any{"type": "string"},
			},
			"additionalProperties": false,
		},
	})

	cases := []struct {
		name string
		in   map[string]any
	}{
		{name: "missing required field", in: map[string]any{}},
		{name: "wrong type", in: map[string]any{"foo": 42}},
		{name: "additional property", in: map[string]any{"foo": "bar", "extra": 1}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := reg.Call(context.Background(), "needs_string", tc.in); err == nil {
				t.Errorf("expected error for %s", tc.name)
			}
		})
	}

	if n := atomic.LoadInt32(&hits); n != 0 {
		t.Errorf("invalid input reached the device %d time(s); schema validation should reject it first", n)
	}
}

// TestToolRegistry_Call_HappyPath checks the full round trip: input passes
// the schema, ArgMapping rewrites it, the device echoes the args back, and
// ResultMapping extracts the final value.
func TestToolRegistry_Call_HappyPath(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (not Fatalf) is safe from the handler goroutine.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("server: read body: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		var req CapabilityRequest
		if err := json.Unmarshal(body, &req); err != nil {
			t.Errorf("server: decode request: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Echo back the args under "received".
		if err := json.NewEncoder(w).Encode(CapabilityResponse{
			RequestID: req.RequestID,
			OK:        true,
			Result:    map[string]any{"received": req.Args},
		}); err != nil {
			t.Errorf("server: encode response: %v", err)
		}
	}))
	defer srv.Close()

	reg := NewToolRegistry(NewClient(srv.URL))
	reg.Register(ToolSpec{
		Name:       "echo",
		Capability: "x.echo",
		InputSchema: map[string]any{
			"type":     "object",
			"required": []string{"msg"},
			"properties": map[string]any{
				"msg": map[string]any{"type": "string"},
			},
		},
		ArgMapping: func(in map[string]any) (map[string]any, error) {
			msg, ok := in["msg"].(string)
			if !ok {
				return nil, fmt.Errorf("msg: want string, got %T", in["msg"])
			}
			return map[string]any{"upper_msg": strings.ToUpper(msg)}, nil
		},
		ResultMapping: func(res map[string]any) (any, error) {
			received, ok := res["received"].(map[string]any)
			if !ok {
				return nil, fmt.Errorf("received: want object, got %T", res["received"])
			}
			return received["upper_msg"], nil
		},
	})

	out, err := reg.Call(context.Background(), "echo", map[string]any{"msg": "hola"})
	if err != nil {
		t.Fatalf("call: %v", err)
	}
	if out != "HOLA" {
		t.Errorf("expected HOLA, got %v", out)
	}
}

// TestToolRegistry_Call_DeviceErrorPropagates checks that a device response
// with OK=false surfaces as a Call error that preserves the device's message.
func TestToolRegistry_Call_DeviceErrorPropagates(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewEncoder(w).Encode(CapabilityResponse{
			OK:    false,
			Error: "binary not whitelisted",
		}); err != nil {
			t.Errorf("server: encode response: %v", err)
		}
	}))
	defer srv.Close()

	reg := NewToolRegistry(NewClient(srv.URL))
	reg.Register(ToolSpec{Name: "exec", Capability: "shell.exec"})

	_, err := reg.Call(context.Background(), "exec", nil)
	if err == nil {
		t.Fatalf("expected device-side error to propagate")
	}
	if !strings.Contains(err.Error(), "binary not whitelisted") {
		t.Errorf("error message lost: %v", err)
	}
}