package devicemesh

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// TestClient_Call_RoundTrip checks that Call POSTs to /capability, fills in
// request_id (with a "req_" prefix), nonce and ts when the caller leaves them
// empty, forwards the capability name, and decodes the JSON response.
func TestClient_Call_RoundTrip(t *testing.T) {
	var received CapabilityRequest
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This runs on a server goroutine: t.Fatal/FailNow must not be called
		// here, so failures are reported with t.Errorf and the handler bails out.
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		if r.URL.Path != "/capability" {
			t.Errorf("expected /capability path, got %s", r.URL.Path)
		}
		body, err := io.ReadAll(r.Body)
		if err == nil {
			err = json.Unmarshal(body, &received)
		}
		if err != nil {
			t.Errorf("decode body: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		err = json.NewEncoder(w).Encode(CapabilityResponse{
			RequestID:  received.RequestID,
			OK:         true,
			Result:     map[string]any{"echo": "ok"},
			DurationMs: 5,
			AuditHash:  "abc123",
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	resp, err := c.Call(context.Background(), CapabilityRequest{
		Capability: "shell.exec",
		Args:       map[string]any{"argv": []string{"ls"}},
	})
	if err != nil {
		t.Fatalf("call: %v", err)
	}
	if !resp.OK {
		t.Fatalf("expected ok=true, got %+v", resp)
	}
	if resp.AuditHash != "abc123" {
		t.Errorf("audit hash mismatch: %q", resp.AuditHash)
	}
	if received.RequestID == "" {
		t.Errorf("expected client to populate request_id")
	}
	if !strings.HasPrefix(received.RequestID, "req_") {
		t.Errorf("request_id should have req_ prefix, got %q", received.RequestID)
	}
	if received.Nonce == "" {
		t.Errorf("expected client to populate nonce")
	}
	if received.Timestamp == 0 {
		t.Errorf("expected client to populate ts")
	}
	if received.Capability != "shell.exec" {
		t.Errorf("capability mismatch: %q", received.Capability)
	}
}

// TestClient_Call_PreservesProvidedIDs checks that caller-supplied
// request_id, nonce and ts are sent unchanged rather than regenerated.
func TestClient_Call_PreservesProvidedIDs(t *testing.T) {
	var received CapabilityRequest
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err == nil {
			err = json.Unmarshal(body, &received)
		}
		if err != nil {
			t.Errorf("decode body: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		err = json.NewEncoder(w).Encode(CapabilityResponse{RequestID: received.RequestID, OK: true})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	_, err := c.Call(context.Background(), CapabilityRequest{
		RequestID:  "req_custom_123",
		Capability: "fs.read",
		Args:       map[string]any{"path": "/tmp/x"},
		Nonce:      "fixed_nonce",
		Timestamp:  1234567890,
	})
	if err != nil {
		t.Fatalf("call: %v", err)
	}
	if received.RequestID != "req_custom_123" {
		t.Errorf("request_id overwritten: %q", received.RequestID)
	}
	if received.Nonce != "fixed_nonce" {
		t.Errorf("nonce overwritten: %q", received.Nonce)
	}
	if received.Timestamp != 1234567890 {
		t.Errorf("ts overwritten: %d", received.Timestamp)
	}
}

// TestClient_Call_OKFalseSurfacedNotError checks that a non-2xx status with a
// parseable CapabilityResponse body is returned as ok=false, not as an error.
func TestClient_Call_OKFalseSurfacedNotError(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Device returns 500 with body; mimics device_agent capability handler.
		w.WriteHeader(http.StatusInternalServerError)
		err := json.NewEncoder(w).Encode(CapabilityResponse{
			RequestID: "req_x",
			OK:        false,
			Error:     "binary not whitelisted",
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	resp, err := c.Call(context.Background(), CapabilityRequest{Capability: "shell.exec"})
	if err != nil {
		t.Fatalf("expected nil error (body parseable), got: %v", err)
	}
	if resp.OK {
		t.Errorf("expected ok=false")
	}
	if resp.Error == "" {
		t.Errorf("expected error message populated")
	}
}

// TestClient_Call_HTTPErrorWithUnparseableBody checks that a non-2xx status
// whose body is not a CapabilityResponse is reported as an error.
func TestClient_Call_HTTPErrorWithUnparseableBody(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadGateway)
		// A failed write only makes the client-side error more likely, which
		// is what this test expects anyway.
		_, _ = w.Write([]byte("nginx html garbage"))
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	_, err := c.Call(context.Background(), CapabilityRequest{Capability: "shell.exec"})
	if err == nil {
		t.Fatalf("expected error for unparseable 502 body")
	}
}

// TestClient_Call_ContextCancel checks that Call honours the context deadline
// instead of waiting for a slow device.
func TestClient_Call_ContextCancel(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Stall well past the client deadline, but return as soon as the
		// client disconnects: srv.Close blocks until in-flight handlers
		// finish, so an unconditional sleep would slow the whole test down.
		select {
		case <-r.Context().Done():
		case <-time.After(500 * time.Millisecond):
		}
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	_, err := c.Call(ctx, CapabilityRequest{Capability: "shell.exec"})
	if err == nil {
		t.Fatalf("expected timeout error, got nil")
	}
	// The client may or may not wrap the context error with %w, so also
	// accept a message that merely mentions the deadline/context.
	if !errors.Is(err, context.DeadlineExceeded) &&
		!strings.Contains(err.Error(), "deadline") &&
		!strings.Contains(err.Error(), "context") {
		t.Errorf("expected context-related error, got: %v", err)
	}
}

// TestClient_Call_RejectsEmptyCapability checks that a request without a
// capability name is rejected with a capability-related error.
func TestClient_Call_RejectsEmptyCapability(t *testing.T) {
	c := NewClient("http://nowhere.invalid")
	_, err := c.Call(context.Background(), CapabilityRequest{})
	if err == nil {
		t.Fatalf("expected error for empty capability")
	}
	if !strings.Contains(err.Error(), "capability") {
		t.Errorf("expected capability-related error, got: %v", err)
	}
}

// TestClient_Call_RejectsEmptyBaseURL checks that a zero-value Client (no
// BaseURL) returns an error from Call.
func TestClient_Call_RejectsEmptyBaseURL(t *testing.T) {
	c := &Client{}
	_, err := c.Call(context.Background(), CapabilityRequest{Capability: "shell.exec"})
	if err == nil {
		t.Fatalf("expected error for empty BaseURL")
	}
}

// TestClient_Health checks that Health queries /health and returns the
// device_id and version fields from the JSON body.
func TestClient_Health(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/health" {
			t.Errorf("expected /health, got %s", r.URL.Path)
		}
		err := json.NewEncoder(w).Encode(map[string]string{
			"device_id": "home-wsl",
			"version":   "0.2.0",
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	id, v, err := c.Health(context.Background())
	if err != nil {
		t.Fatalf("health: %v", err)
	}
	if id != "home-wsl" {
		t.Errorf("device_id mismatch: %q", id)
	}
	if v != "0.2.0" {
		t.Errorf("version mismatch: %q", v)
	}
}

// TestClient_Call_NoRetry confirms that a single failure does NOT trigger a
// retry — POC behavior per the README.
func TestClient_Call_NoRetry(t *testing.T) {
	// The handler counts hits. The counter is written on the server goroutine
	// and read on the test goroutine, so it is accessed atomically.
	var hits int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write([]byte("oops"))
	}))
	defer srv.Close()

	c := NewClient(srv.URL)
	// The returned error is covered by TestClient_Call_HTTPErrorWithUnparseableBody;
	// only the number of requests matters here.
	_, _ = c.Call(context.Background(), CapabilityRequest{Capability: "shell.exec"})
	if got := atomic.LoadInt32(&hits); got != 1 {
		t.Errorf("expected exactly 1 hit (no retry), got %d", got)
	}
}

// TestRandomRequestID_UniqueAndPrefixed checks that two generated request IDs
// differ and that both carry the "req_" prefix.
func TestRandomRequestID_UniqueAndPrefixed(t *testing.T) {
	a, err := randomRequestID()
	if err != nil {
		t.Fatalf("randomRequestID: %v", err)
	}
	b, err := randomRequestID()
	if err != nil {
		t.Fatalf("randomRequestID: %v", err)
	}
	if a == b {
		t.Errorf("collision: %q == %q", a, b)
	}
	for _, id := range []string{a, b} {
		if !strings.HasPrefix(id, "req_") {
			t.Errorf("missing req_ prefix: %q", id)
		}
	}
}