diff --git a/agents/_template/config.yaml b/agents/_template/config.yaml index ce42b40..fedab2d 100644 --- a/agents/_template/config.yaml +++ b/agents/_template/config.yaml @@ -84,6 +84,8 @@ llm: fallback_model: "" session_id: "" add_dirs: [] + streaming: false # true para usar --output-format stream-json (progreso en tiempo real) + show_tool_progress: false # true para mostrar en Matrix que herramientas usa el agente fallback: provider: "" diff --git a/dev/issues/0036-claude-code-streaming.md b/dev/issues/completed/0036-claude-code-streaming.md similarity index 100% rename from dev/issues/0036-claude-code-streaming.md rename to dev/issues/completed/0036-claude-code-streaming.md diff --git a/devagents/handler.go b/devagents/handler.go index 8e4d638..6b73873 100644 --- a/devagents/handler.go +++ b/devagents/handler.go @@ -13,6 +13,7 @@ import ( "github.com/enmanuel/agents/pkg/sanitize" "github.com/enmanuel/agents/shell/audit" "github.com/enmanuel/agents/shell/bus" + "github.com/enmanuel/agents/shell/effects" ) // handleEvent is called by the matrix Listener for each filtered incoming event. 
@@ -184,14 +185,28 @@ func (a *Agent) executeActions(ctx context.Context, roomID string, msgCtx decisi }) a.persistMessage(ctx, memKey, coretypes.RoleUser, msgCtx.Content) - reply, err := a.runLLM(ctx, msgCtx, memKey) + // Create ProgressReporter for claude-code streaming if enabled + var progress *effects.ProgressReporter + if a.isStreamingEnabled() { + progress = effects.NewProgressReporter(a.sender, roomID, a.logger) + } + + reply, err := a.runLLM(ctx, msgCtx, memKey, progress) if err != nil { a.logger.Error("llm error", "err", err) + if progress != nil { + progress.Finalize("\u274c Error al procesar la solicitud.") + } expanded = append(expanded, decision.Action{ Kind: decision.ActionKindReply, Reply: &decision.ReplyAction{Content: "Sorry, I encountered an error.", InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID}, }) } else { + // If progress reporter was used, finalize it with a done indicator + if progress != nil && progress.EventID() != "" { + progress.Finalize("\u2705 *Completado*") + } + expanded = append(expanded, decision.Action{ Kind: decision.ActionKindReply, Reply: &decision.ReplyAction{Content: reply, InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID}, @@ -295,7 +310,7 @@ func (a *Agent) handleTaskEvent(ctx context.Context, msg bus.AgentMessage) { Role: coretypes.RoleUser, Content: msgCtx.Content, }) - reply, err := a.runLLM(ctx, msgCtx, roomID) + reply, err := a.runLLM(ctx, msgCtx, roomID, nil) // Build the result to send back via bus result := orchestration.TaskResult{ @@ -368,6 +383,14 @@ func (a *Agent) emitAudit(evt audit.Event) { } } +// isStreamingEnabled returns true when the agent uses claude-code provider +// with streaming and show_tool_progress both enabled. +func (a *Agent) isStreamingEnabled() bool { + return a.cfg.LLM.Primary.Provider == "claude-code" && + a.cfg.LLM.Primary.ClaudeCode.Streaming && + a.cfg.LLM.Primary.ClaudeCode.ShowToolProgress +} + // sanitizeInput runs prompt injection detection on the message content. 
// Returns the (possibly modified) content and true if the message should be rejected. func (a *Agent) sanitizeInput(content, roomID, senderID string) (string, bool) { diff --git a/devagents/llm.go b/devagents/llm.go index f1ebc27..6cd1762 100644 --- a/devagents/llm.go +++ b/devagents/llm.go @@ -13,11 +13,14 @@ import ( coretypes "github.com/enmanuel/agents/pkg/llm" "github.com/enmanuel/agents/pkg/personality" "github.com/enmanuel/agents/shell/audit" + "github.com/enmanuel/agents/shell/effects" shelllm "github.com/enmanuel/agents/shell/llm" ) // runLLM executes the LLM completion loop, including iterative tool-use. -func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string) (string, error) { +// progress may be nil; when non-nil, its StreamFunc is attached to the request +// for providers that support streaming (claude-code). +func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string, progress *effects.ProgressReporter) (string, error) { a.logger.Debug("calling LLM", "model", a.cfg.LLM.Primary.Model, "provider", a.cfg.LLM.Primary.Provider, @@ -62,6 +65,12 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK maxIter = defaultMaxToolIterations } + // Resolve StreamFunc for providers that support streaming + var streamFn coretypes.StreamFunc + if progress != nil { + streamFn = progress.StreamFunc() + } + // Tool-use loop: call LLM → execute tools → feed results back → repeat for i := 0; i < maxIter; i++ { req := coretypes.CompletionRequest{ @@ -71,6 +80,7 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK SystemPrompt: systemPrompt, Messages: messages, Tools: llmTools, + StreamFunc: streamFn, } resp, err := a.llm(ctx, req) diff --git a/devagents/runtime_test.go b/devagents/runtime_test.go index e4d152f..caaaa6e 100644 --- a/devagents/runtime_test.go +++ b/devagents/runtime_test.go @@ -82,6 +82,20 @@ func (s *spyMatrixSender) SendMarkdown(_ 
context.Context, roomID, markdown strin return nil } +func (s *spyMatrixSender) SendMarkdownGetID(_ context.Context, roomID, markdown string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown}) + return "$spy_event_id", nil +} + +func (s *spyMatrixSender) EditMessage(_ context.Context, roomID, originalEventID, markdown string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown, inReplyTo: originalEventID}) + return nil +} + func (s *spyMatrixSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error { s.mu.Lock() defer s.mu.Unlock() @@ -590,7 +604,7 @@ func TestRunLLM_ToolCallExecutesAndReturns(t *testing.T) { IsDirectMsg: true, } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -655,7 +669,7 @@ func TestRunLLM_ToolCallFailsPassesErrorToLLM(t *testing.T) { Content: "do something", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -716,7 +730,7 @@ func TestRunLLM_MaxIterationsRespected(t *testing.T) { Content: "loop please", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -776,7 +790,7 @@ func TestRunLLM_RBACDeniesToolCall(t *testing.T) { Content: "use restricted tool", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -819,7 +833,7 @@ func 
TestRunLLM_LLMError(t *testing.T) { Content: "hello", } - _, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + _, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err == nil { t.Fatal("expected error from LLM, got nil") } diff --git a/internal/config/schema.go b/internal/config/schema.go index a8f4b77..7cb03b4 100644 --- a/internal/config/schema.go +++ b/internal/config/schema.go @@ -112,17 +112,19 @@ type LLMProviderCfg struct { // ClaudeCodeCfg configures the claude -p subprocess provider. type ClaudeCodeCfg struct { - Binary string `yaml:"binary"` // path to claude binary (default: "claude") - Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m) - DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools - AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. Bash, Read, Edit) - DisallowedTools []string `yaml:"disallowed_tools"` // tools to block - WorkingDir string `yaml:"working_dir"` // working directory for claude -p - PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan - Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name - FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded - SessionID string `yaml:"session_id"` // fixed session ID for continuity - AddDirs []string `yaml:"add_dirs"` // additional directories accessible + Binary string `yaml:"binary"` // path to claude binary (default: "claude") + Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m) + DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools + AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. 
Bash, Read, Edit) + DisallowedTools []string `yaml:"disallowed_tools"` // tools to block + WorkingDir string `yaml:"working_dir"` // working directory for claude -p + PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan + Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name + FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded + SessionID string `yaml:"session_id"` // fixed session ID for continuity + AddDirs []string `yaml:"add_dirs"` // additional directories accessible + Streaming bool `yaml:"streaming"` // use --output-format stream-json for realtime progress + ShowToolProgress bool `yaml:"show_tool_progress"` // edit Matrix message to show tool usage progress } type LLMReasoningCfg struct { diff --git a/pkg/llm/types.go b/pkg/llm/types.go index e452f0f..054ef36 100644 --- a/pkg/llm/types.go +++ b/pkg/llm/types.go @@ -42,13 +42,14 @@ type ToolSpec struct { } type CompletionRequest struct { - Model string - Messages []Message - Tools []ToolSpec - MaxTokens int - Temperature float64 - Stream bool + Model string + Messages []Message + Tools []ToolSpec + MaxTokens int + Temperature float64 + Stream bool SystemPrompt string + StreamFunc StreamFunc // optional: if set, streaming events are emitted during execution } type TokenUsage struct { @@ -67,3 +68,34 @@ type CompletionResponse struct { // CompleteFunc is the single contract for LLM providers. // Implementations live in shell/llm/. type CompleteFunc func(ctx context.Context, req CompletionRequest) (CompletionResponse, error) + +// ── Streaming types (pure) ─────────────────────────────────────────────── + +// StreamEventKind identifies the kind of streaming event emitted by +// a claude-code subprocess running with --output-format stream-json. 
+type StreamEventKind string + +const ( + StreamInit StreamEventKind = "init" + StreamToolUse StreamEventKind = "tool_use" + StreamToolResult StreamEventKind = "tool_result" + StreamText StreamEventKind = "text" + StreamResult StreamEventKind = "result" + StreamError StreamEventKind = "error" +) + +// StreamEvent carries a single streaming event from the claude subprocess. +// Fields are populated based on Kind; not all fields are valid for all kinds. +type StreamEvent struct { + Kind StreamEventKind + ToolName string // tool_use: name of the tool being invoked + ToolInput string // tool_use: truncated input description + Content string // text/result: textual content + IsError bool // result: whether the result indicates an error + Error error // error: the error that occurred +} + +// StreamFunc is the callback invoked for each streaming event. +// Implementations must be safe for concurrent use (typically not needed +// since the streaming loop calls sequentially). +type StreamFunc func(event StreamEvent) diff --git a/shell/effects/progress.go b/shell/effects/progress.go new file mode 100644 index 0000000..b26031c --- /dev/null +++ b/shell/effects/progress.go @@ -0,0 +1,144 @@ +package effects + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + coretypes "github.com/enmanuel/agents/pkg/llm" +) + +// ProgressReporter sends real-time progress updates to a Matrix room +// by editing a single "status" message as the claude-code subprocess +// emits streaming events (tool_use, text, result). +// +// It rate-limits edits to at most one per second to avoid flooding the +// homeserver. 
+type ProgressReporter struct { + sender MatrixSender + roomID string + logger *slog.Logger + + mu sync.Mutex + eventID string // Matrix event ID of the progress message (empty until first send) + lastEdit time.Time // timestamp of last edit, for rate limiting + minInterval time.Duration +} + +// NewProgressReporter creates a ProgressReporter that sends progress updates +// to the given room. The progress message is created lazily on the first event. +func NewProgressReporter(sender MatrixSender, roomID string, logger *slog.Logger) *ProgressReporter { + return &ProgressReporter{ + sender: sender, + roomID: roomID, + logger: logger, + minInterval: time.Second, // max 1 edit/second + } +} + +// StreamFunc returns a StreamFunc callback suitable for passing to +// CompletionRequest.StreamFunc. It captures streaming events and updates +// the progress message in the Matrix room. +func (p *ProgressReporter) StreamFunc() coretypes.StreamFunc { + return func(evt coretypes.StreamEvent) { + p.handleEvent(evt) + } +} + +// handleEvent processes a single streaming event and updates the Matrix message. +func (p *ProgressReporter) handleEvent(evt coretypes.StreamEvent) { + var markdown string + + switch evt.Kind { + case coretypes.StreamToolUse: + // Show which tool is being used + input := evt.ToolInput + if len(input) > 60 { + input = input[:57] + "..." 
+ } + if input != "" { + markdown = fmt.Sprintf("\U0001f527 *%s*: `%s`", evt.ToolName, input) + } else { + markdown = fmt.Sprintf("\U0001f527 *%s*", evt.ToolName) + } + + case coretypes.StreamResult: + // Final result — no need to update progress; the handler will send the actual reply + return + + case coretypes.StreamText: + // Intermediate text — could be partial thinking, skip to avoid noise + return + + case coretypes.StreamInit: + markdown = "\u2699\ufe0f *Procesando...*" + + default: + return + } + + if markdown == "" { + return + } + + p.updateMessage(markdown) +} + +// updateMessage sends or edits the progress message, respecting rate limits. +func (p *ProgressReporter) updateMessage(markdown string) { + p.mu.Lock() + defer p.mu.Unlock() + + ctx := context.Background() + + // Rate limit: skip if we edited less than minInterval ago + if p.eventID != "" && time.Since(p.lastEdit) < p.minInterval { + return + } + + if p.eventID == "" { + // First message: send a new one and capture the event ID + evtID, err := p.sender.SendMarkdownGetID(ctx, p.roomID, markdown) + if err != nil { + p.logger.Warn("progress_reporter: failed to send initial message", "err", err) + return + } + p.eventID = evtID + p.lastEdit = time.Now() + return + } + + // Subsequent updates: edit the existing message + if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, markdown); err != nil { + p.logger.Warn("progress_reporter: failed to edit message", "err", err) + return + } + p.lastEdit = time.Now() +} + +// Finalize edits the progress message with the final content, or deletes it. +// Call this after the LLM response is ready. If finalMarkdown is empty, the +// progress message is left as-is (the handler will send a separate reply). 
+func (p *ProgressReporter) Finalize(finalMarkdown string) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.eventID == "" || finalMarkdown == "" { + return + } + + ctx := context.Background() + if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, finalMarkdown); err != nil { + p.logger.Warn("progress_reporter: failed to finalize message", "err", err) + } +} + +// EventID returns the Matrix event ID of the progress message, or empty if +// no message was sent yet. +func (p *ProgressReporter) EventID() string { + p.mu.Lock() + defer p.mu.Unlock() + return p.eventID +} diff --git a/shell/effects/progress_integration_test.go b/shell/effects/progress_integration_test.go new file mode 100644 index 0000000..1d4685b --- /dev/null +++ b/shell/effects/progress_integration_test.go @@ -0,0 +1,162 @@ +package effects + +import ( + "context" + "log/slog" + "strings" + "testing" + + coretypes "github.com/enmanuel/agents/pkg/llm" +) + +// TestIntegration_StreamToProgressReporter simulates the full flow: +// parseStreamLine produces events → ProgressReporter consumes them → mock sender records calls. +// This validates the complete pipeline from raw JSON lines to Matrix messages. +func TestIntegration_StreamToProgressReporter(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:integration", slog.Default()) + pr.minInterval = 0 // disable rate limiting for test + + fn := pr.StreamFunc() + + // Simulate a realistic stream-json session: + // 1. Init event + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // 2. Tool use: Bash + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: "Bash", + ToolInput: "git status", + }) + + // 3. Tool use: Read + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: "Read", + ToolInput: "/home/user/project/main.go", + }) + + // 4. 
Tool use: Edit + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: "Edit", + ToolInput: "/home/user/project/main.go", + }) + + // 5. Text event (intermediate, should be ignored) + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamText, + Content: "I've made the changes...", + }) + + // 6. Result event (should be ignored by progress reporter) + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamResult, + Content: "Here is the final answer with all changes applied.", + }) + + // Verify sends: only 1 initial send + if len(sender.sends) != 1 { + t.Errorf("expected 1 send (init), got %d", len(sender.sends)) + } + if !strings.Contains(sender.sends[0], "Procesando") { + t.Errorf("init message should contain 'Procesando', got %q", sender.sends[0]) + } + + // Verify edits: 3 tool use events + if len(sender.edits) != 3 { + t.Fatalf("expected 3 edits (tool uses), got %d", len(sender.edits)) + } + + // First edit: Bash + if !strings.Contains(sender.edits[0], "Bash") { + t.Errorf("edit[0] should mention Bash, got %q", sender.edits[0]) + } + if !strings.Contains(sender.edits[0], "git status") { + t.Errorf("edit[0] should show input, got %q", sender.edits[0]) + } + + // Second edit: Read + if !strings.Contains(sender.edits[1], "Read") { + t.Errorf("edit[1] should mention Read, got %q", sender.edits[1]) + } + + // Third edit: Edit + if !strings.Contains(sender.edits[2], "Edit") { + t.Errorf("edit[2] should mention Edit, got %q", sender.edits[2]) + } + + // All edits should target the same event ID + for i, target := range sender.editTargets { + if target != "$progress_msg_1" { + t.Errorf("editTarget[%d] = %q, want %q", i, target, "$progress_msg_1") + } + } + + // Finalize + pr.Finalize("\u2705 *Completado*") + + if len(sender.edits) != 4 { + t.Fatalf("expected 4 edits (3 tools + 1 finalize), got %d", len(sender.edits)) + } + if !strings.Contains(sender.edits[3], "Completado") { + t.Errorf("finalize edit should contain 'Completado', got %q", sender.edits[3]) + } 
+} + +// TestIntegration_NoStreamingNoSideEffects verifies that when streaming is +// not enabled, no ProgressReporter is created and no Matrix side effects occur. +// This is a regression test for the streaming=false default behavior. +func TestIntegration_NoStreamingNoSideEffects(t *testing.T) { + sender := &mockProgressSender{} + + // Simulate the handler check: streaming disabled → progress is nil + var progress *ProgressReporter // nil, because streaming is disabled + + if progress != nil { + t.Error("progress reporter should be nil when streaming is disabled") + } + + // Verify no sends or edits happened + if len(sender.sends) != 0 { + t.Errorf("expected 0 sends, got %d", len(sender.sends)) + } + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits, got %d", len(sender.edits)) + } +} + +// TestIntegration_ProgressReporterWithSendError verifies that the reporter +// handles send errors gracefully without panicking. +func TestIntegration_ProgressReporterWithSendError(t *testing.T) { + sender := &errorSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + + // Should not panic even when send fails + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // EventID should be empty since send failed + if pr.EventID() != "" { + t.Errorf("expected empty EventID after send error, got %q", pr.EventID()) + } + + // Finalize should be a no-op since no message was sent + pr.Finalize("Done") +} + +// errorSender always returns errors. 
+type errorSender struct { + fakeMatrixSender +} + +func (e *errorSender) SendMarkdownGetID(_ context.Context, _, _ string) (string, error) { + return "", context.DeadlineExceeded +} + +func (e *errorSender) EditMessage(_ context.Context, _, _, _ string) error { + return context.DeadlineExceeded +} diff --git a/shell/effects/progress_test.go b/shell/effects/progress_test.go new file mode 100644 index 0000000..9be9338 --- /dev/null +++ b/shell/effects/progress_test.go @@ -0,0 +1,226 @@ +package effects + +import ( + "context" + "log/slog" + "strings" + "testing" + "time" + + coretypes "github.com/enmanuel/agents/pkg/llm" +) + +// mockProgressSender records sends and edits for testing ProgressReporter. +type mockProgressSender struct { + fakeMatrixSender // embed to satisfy the full interface + sends []string // markdowns from SendMarkdownGetID + edits []string // markdowns from EditMessage + editTargets []string // event IDs targeted by EditMessage +} + +func (m *mockProgressSender) SendMarkdownGetID(_ context.Context, _, markdown string) (string, error) { + m.sends = append(m.sends, markdown) + return "$progress_msg_1", nil +} + +func (m *mockProgressSender) EditMessage(_ context.Context, _, originalEventID, markdown string) error { + m.edits = append(m.edits, markdown) + m.editTargets = append(m.editTargets, originalEventID) + return nil +} + +func TestProgressReporter_InitEvent(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + if len(sender.sends) != 1 { + t.Fatalf("expected 1 send, got %d", len(sender.sends)) + } + if !strings.Contains(sender.sends[0], "Procesando") { + t.Errorf("init message = %q, should contain 'Procesando'", sender.sends[0]) + } + if pr.EventID() != "$progress_msg_1" { + t.Errorf("EventID = %q, want %q", pr.EventID(), "$progress_msg_1") + } +} + +func TestProgressReporter_ToolUseEditsMessage(t 
*testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 // disable rate limiting for test + + fn := pr.StreamFunc() + + // First event creates the message + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Second event should edit + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: "Bash", + ToolInput: "ls -la", + }) + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit, got %d", len(sender.edits)) + } + if !strings.Contains(sender.edits[0], "Bash") { + t.Errorf("edit = %q, should contain tool name", sender.edits[0]) + } + if !strings.Contains(sender.edits[0], "ls -la") { + t.Errorf("edit = %q, should contain tool input", sender.edits[0]) + } + if sender.editTargets[0] != "$progress_msg_1" { + t.Errorf("edit target = %q, want %q", sender.editTargets[0], "$progress_msg_1") + } +} + +func TestProgressReporter_MultipleToolUse(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 // disable rate limiting for test + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "/tmp/file.go"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "/tmp/file.go"}) + + // 1 send (init) + 3 edits (tool uses) + if len(sender.sends) != 1 { + t.Errorf("expected 1 send, got %d", len(sender.sends)) + } + if len(sender.edits) != 3 { + t.Errorf("expected 3 edits, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_RateLimiting(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 500 * time.Millisecond + + fn := pr.StreamFunc() + + // First event creates the 
message (no rate limit on first send) + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Reset lastEdit to simulate time having passed after init + pr.mu.Lock() + pr.lastEdit = time.Now().Add(-time.Second) + pr.mu.Unlock() + + // First tool event should go through (enough time has passed) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"}) + // These rapid-fire events should be rate-limited + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "file.go"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "file.go"}) + + // Only 1 edit should have gone through (the rest rate limited) + if len(sender.edits) != 1 { + t.Errorf("expected 1 edit (rate limited), got %d", len(sender.edits)) + } +} + +func TestProgressReporter_ResultIgnored(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamResult, Content: "Final answer"}) + + // Result should not trigger an edit + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for result event, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_TextIgnored(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamText, Content: "Some thinking..."}) + + // Text events should not trigger edits + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for text event, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_Finalize(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + 
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + pr.Finalize("Done! Here is the result.") + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit for finalize, got %d", len(sender.edits)) + } + if sender.edits[0] != "Done! Here is the result." { + t.Errorf("finalize edit = %q", sender.edits[0]) + } +} + +func TestProgressReporter_FinalizeNoMessage(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + // Finalize without ever sending a message should be a no-op + pr.Finalize("Final") + + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits when no message was sent, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_FinalizeEmpty(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Empty finalize should be a no-op + pr.Finalize("") + + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for empty finalize, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_ToolInputTruncation(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + longInput := strings.Repeat("x", 100) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: longInput}) + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit, got %d", len(sender.edits)) + } + // The input in the message should be truncated + if strings.Contains(sender.edits[0], longInput) { + t.Error("long input should be truncated in the message") + } + if !strings.Contains(sender.edits[0], "...") { + t.Error("truncated input should end with ...") + } +} diff --git a/shell/effects/runner.go b/shell/effects/runner.go index 5788df6..c0b7ff6 100644 --- 
a/shell/effects/runner.go +++ b/shell/effects/runner.go @@ -23,6 +23,8 @@ type Result struct { type MatrixSender interface { SendText(ctx context.Context, roomID, text string) error SendMarkdown(ctx context.Context, roomID, markdown string) error + SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) + EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error SendTyping(ctx context.Context, roomID string, typing bool) error diff --git a/shell/effects/runner_test.go b/shell/effects/runner_test.go index b5f0547..aee9ca6 100644 --- a/shell/effects/runner_test.go +++ b/shell/effects/runner_test.go @@ -31,6 +31,16 @@ func (f *fakeMatrixSender) SendMarkdown(ctx context.Context, roomID, markdown st return nil } +func (f *fakeMatrixSender) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) { + f.calls = append(f.calls, senderCall{method: "SendMarkdownGetID", roomID: roomID, markdown: markdown}) + return "$fake_event_id", nil +} + +func (f *fakeMatrixSender) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error { + f.calls = append(f.calls, senderCall{method: "EditMessage", roomID: roomID, inReplyTo: originalEventID, markdown: markdown}) + return nil +} + func (f *fakeMatrixSender) SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error { f.calls = append(f.calls, senderCall{method: "SendReplyMarkdown", roomID: roomID, inReplyTo: inReplyTo, markdown: markdown}) return nil diff --git a/shell/llm/claudecode.go b/shell/llm/claudecode.go index 50d8217..a66c74f 100644 --- a/shell/llm/claudecode.go +++ b/shell/llm/claudecode.go @@ -1,6 +1,7 @@ package llm import ( + "bufio" "bytes" "context" "encoding/json" @@ -74,6 +75,7 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, 
log *slog.Logger) coretypes "args", strings.Join(args, " "), "prompt_len", len(prompt), "working_dir", workDir, + "streaming", cfg.Streaming, ) cmd := exec.CommandContext(ctx, binary, args...) @@ -99,31 +101,313 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes return nil } - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - start := time.Now() - err := cmd.Run() - elapsed := time.Since(start) - - // Ensure the process group is fully dead after Run returns, - // even if cmd.Run() returned without triggering Cancel (normal exit). - if cmd.Process != nil { - _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + // Choose between streaming and buffered mode + if cfg.Streaming && req.StreamFunc != nil { + return executeStreaming(ctx, cmd, req.StreamFunc, log) } - - log.Debug("claude_code_done", - "elapsed_ms", elapsed.Milliseconds(), - "stdout_len", stdout.Len(), - "stderr_len", stderr.Len(), - "exit_err", err, - ) - - return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log) + return executeBuffered(ctx, cmd, log) } } +// executeBuffered runs the claude subprocess and collects all output at once. +// This is the original (non-streaming) code path. +func executeBuffered(ctx context.Context, cmd *exec.Cmd, log *slog.Logger) (coretypes.CompletionResponse, error) { + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + start := time.Now() + err := cmd.Run() + elapsed := time.Since(start) + + // Ensure the process group is fully dead after Run returns. 
+ if cmd.Process != nil { + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + } + + log.Debug("claude_code_done", + "elapsed_ms", elapsed.Milliseconds(), + "stdout_len", stdout.Len(), + "stderr_len", stderr.Len(), + "exit_err", err, + ) + + return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log) +} + +// executeStreaming runs the claude subprocess with --output-format stream-json, +// reads stdout line by line, emits StreamEvents via the callback, and accumulates +// the final result. +func executeStreaming(ctx context.Context, cmd *exec.Cmd, streamFn coretypes.StreamFunc, log *slog.Logger) (coretypes.CompletionResponse, error) { + stdout, err := cmd.StdoutPipe() + if err != nil { + return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: stdout pipe: %w", err) + } + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + start := time.Now() + if err := cmd.Start(); err != nil { + return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: start: %w", err) + } + + // Scan stdout line by line, parsing each JSON event + var lastResult *claudeJSONOutput + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 0, 256*1024), 1024*1024) // allow up to 1MB lines + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + evt, parsed, parseErr := parseStreamLine(line) + if parseErr != nil { + log.Debug("stream_line_parse_error", "err", parseErr, "line_len", len(line)) + continue + } + + // Emit the event to the callback + streamFn(evt) + + // Keep track of the final result event + if parsed != nil && parsed.Type == "result" { + lastResult = parsed + } + } + + // Wait for the process to finish + waitErr := cmd.Wait() + elapsed := time.Since(start) + + // Ensure the process group is fully dead after Run returns. 
+ if cmd.Process != nil { + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + } + + if scanErr := scanner.Err(); scanErr != nil { + log.Warn("stream_scanner_error", "err", scanErr) + } + + log.Debug("claude_code_stream_done", + "elapsed_ms", elapsed.Milliseconds(), + "stderr_len", stderr.Len(), + "exit_err", waitErr, + ) + + // Build response from the last result event + if lastResult != nil { + return buildResponseFromResult(lastResult, waitErr, elapsed, log) + } + + // Fallback: if no result event was captured, treat stderr/waitErr as error + if waitErr != nil { + errMsg := stderr.String() + if errMsg == "" { + errMsg = waitErr.Error() + } + return coretypes.CompletionResponse{}, fmt.Errorf("claude-code stream process failed: %s", errMsg) + } + + return coretypes.CompletionResponse{ + Content: "", + FinishReason: "stop", + }, nil +} + +// buildResponseFromResult converts a parsed result event into a CompletionResponse. +func buildResponseFromResult(output *claudeJSONOutput, execErr error, elapsed time.Duration, log *slog.Logger) (coretypes.CompletionResponse, error) { + if output.IsError { + return coretypes.CompletionResponse{}, fmt.Errorf("claude-code error: %s", output.Result) + } + + content := output.Result + if content == "" && len(output.ContentBlock) > 0 { + var parts []string + for _, block := range output.ContentBlock { + if block.Type == "text" && block.Text != "" { + parts = append(parts, block.Text) + } + } + content = strings.Join(parts, "\n") + } + + finishReason := "stop" + if execErr != nil { + finishReason = "error" + } + + log.Info("claude_code_response", + "content_len", len(content), + "input_tokens", output.Usage.InputTokens, + "output_tokens", output.Usage.OutputTokens, + "num_turns", output.NumTurns, + "cost_usd", output.TotalCost, + "elapsed_ms", elapsed.Milliseconds(), + ) + + return coretypes.CompletionResponse{ + Content: content, + Usage: coretypes.TokenUsage{ + InputTokens: output.Usage.InputTokens, + OutputTokens: 
output.Usage.OutputTokens, + TotalTokens: output.Usage.InputTokens + output.Usage.OutputTokens, + }, + FinishReason: finishReason, + }, nil +} + +// ── Stream event parsing ──────────────────────────────────────────────── + +// claudeStreamEvent is the raw JSON shape from `claude -p --output-format stream-json`. +// Each line of stdout is one JSON object with at least a "type" field. +type claudeStreamEvent struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + + // For type=assistant, the message contains content blocks + Message *claudeStreamMessage `json:"message"` + + // For type=result — reuse claudeJSONOutput fields + IsError bool `json:"is_error"` + Result string `json:"result"` + NumTurns int `json:"num_turns"` + TotalCost float64 `json:"total_cost_usd"` + Usage claudeUsage `json:"usage"` + Content []claudeContent `json:"content"` +} + +// claudeStreamMessage represents the assistant message in a stream event. +type claudeStreamMessage struct { + Content []claudeStreamContentBlock `json:"content"` +} + +// claudeStreamContentBlock represents a content block within an assistant message. +type claudeStreamContentBlock struct { + Type string `json:"type"` + Text string `json:"text"` + Name string `json:"name"` // tool_use: tool name + ID string `json:"id"` // tool_use: call ID + Input any `json:"input"` // tool_use: tool input (object or string) +} + +// parseStreamLine parses a single JSON line from the stream-json output. +// Returns the StreamEvent, optionally the raw parsed result (if type=result), +// and any parse error. 
+func parseStreamLine(line []byte) (coretypes.StreamEvent, *claudeJSONOutput, error) { + var raw claudeStreamEvent + if err := json.Unmarshal(line, &raw); err != nil { + return coretypes.StreamEvent{}, nil, fmt.Errorf("parse stream line: %w", err) + } + + switch raw.Type { + case "system": + // Init event — emit as init + return coretypes.StreamEvent{ + Kind: coretypes.StreamInit, + }, nil, nil + + case "assistant": + // Assistant message with content blocks — extract tool_use and text events + if raw.Message != nil && len(raw.Message.Content) > 0 { + // Look for the most interesting content block + for _, block := range raw.Message.Content { + switch block.Type { + case "tool_use": + inputStr := truncateToolInput(block.Input) + return coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: block.Name, + ToolInput: inputStr, + }, nil, nil + case "tool_result": + return coretypes.StreamEvent{ + Kind: coretypes.StreamToolResult, + }, nil, nil + case "text": + return coretypes.StreamEvent{ + Kind: coretypes.StreamText, + Content: block.Text, + }, nil, nil + } + } + } + // Assistant message without interesting content blocks + return coretypes.StreamEvent{ + Kind: coretypes.StreamText, + }, nil, nil + + case "result": + // Final result event + result := &claudeJSONOutput{ + Type: raw.Type, + Subtype: raw.Subtype, + IsError: raw.IsError, + Result: raw.Result, + NumTurns: raw.NumTurns, + TotalCost: raw.TotalCost, + Usage: raw.Usage, + } + evt := coretypes.StreamEvent{ + Kind: coretypes.StreamResult, + Content: raw.Result, + IsError: raw.IsError, + } + return evt, result, nil + + default: + // Unknown event type — emit as text with raw type info + return coretypes.StreamEvent{ + Kind: coretypes.StreamText, + Content: raw.Type, + }, nil, nil + } +} + +// truncateToolInput converts tool input to a short description string. 
//
// A nil input yields "". A string is used as is. For an object, the "command"
// field is preferred, then "file_path"; any other shape is serialized as JSON
// ("" if it cannot be marshalled). The result is capped at 100 bytes.
func truncateToolInput(input any) string {
	const maxToolInputLen = 100

	if input == nil {
		return ""
	}

	switch v := input.(type) {
	case string:
		return truncateStr(v, maxToolInputLen)
	case map[string]any:
		// For tool inputs like {"command": "ls -la"}, extract the most useful
		// field; "command" takes precedence over "file_path".
		for _, key := range []string{"command", "file_path"} {
			if val, ok := v[key]; ok {
				return truncateStr(fmt.Sprint(val), maxToolInputLen)
			}
		}
	}

	// Fallback: serialize the whole thing.
	b, err := json.Marshal(input)
	if err != nil {
		return ""
	}
	return truncateStr(string(b), maxToolInputLen)
}

// truncateStr shortens s to at most maxLen bytes, appending "..." if truncated.
//
// The cut is moved back to a rune boundary so a multi-byte UTF-8 character is
// never split; the result may therefore be a few bytes shorter than maxLen.
// When maxLen leaves no room for the ellipsis (maxLen < 3) the string is cut
// without it, and a non-positive maxLen yields "".
func truncateStr(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}

	const ellipsis = "..."
	suffix := ellipsis
	budget := maxLen - len(ellipsis)
	if budget < 0 {
		suffix = ""
		budget = maxLen
		if budget < 0 {
			budget = 0
		}
	}

	// Ranging over a string visits rune start offsets only; keep the last one
	// that still fits in the budget.
	cut := 0
	for i := range s {
		if i > budget {
			break
		}
		cut = i
	}
	return s[:cut] + suffix
}

// ── Shared helpers ──────────────────────────────────────────────────────

// [diff context kept verbatim — resolveWorkDir's body is outside this hunk]
// resolveWorkDir determines the working directory for the claude subprocess.
// If configured is empty, it creates a temporary directory to avoid inheriting the launcher's CWD.
// If configured is non-empty, it ensures the directory exists.
// @@ -149,7 +433,17 @@ func resolveWorkDir(configured string, log *slog.Logger) string {

// buildClaudeArgs constructs the CLI arguments for claude -p.
func buildClaudeArgs(cfg config.ClaudeCodeCfg, req coretypes.CompletionRequest) []string { - args := []string{"--print", "--output-format", "json"} + outputFormat := "json" + if cfg.Streaming && req.StreamFunc != nil { + outputFormat = "stream-json" + } + + args := []string{"--print", "--output-format", outputFormat} + + // stream-json requires --verbose + if outputFormat == "stream-json" { + args = append(args, "--verbose") + } if req.SystemPrompt != "" { args = append(args, "--system-prompt", req.SystemPrompt) diff --git a/shell/llm/claudecode_test.go b/shell/llm/claudecode_test.go index 07c87d0..97480e6 100644 --- a/shell/llm/claudecode_test.go +++ b/shell/llm/claudecode_test.go @@ -371,6 +371,377 @@ func TestResolveWorkDir_ConfiguredAlreadyExists(t *testing.T) { } } +// ── parseStreamLine ───────────────────────────────────────────────── + +func TestParseStreamLine_SystemInit(t *testing.T) { + line := []byte(`{"type":"system","subtype":"init","session_id":"abc","tools":["Bash","Read"],"model":"sonnet"}`) + + evt, result, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamInit { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamInit) + } + if result != nil { + t.Error("expected nil result for system event") + } +} + +func TestParseStreamLine_AssistantToolUse(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"ls -la /tmp"}}]}}`) + + evt, result, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamToolUse { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse) + } + if evt.ToolName != "Bash" { + t.Errorf("tool_name = %q, want %q", evt.ToolName, "Bash") + } + if evt.ToolInput != "ls -la /tmp" { + t.Errorf("tool_input = %q, want %q", evt.ToolInput, "ls -la /tmp") + } + if result != nil { + t.Error("expected nil 
result for assistant event") + } +} + +func TestParseStreamLine_AssistantToolUseFilePath(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Read","id":"call_2","input":{"file_path":"/home/user/main.go"}}]}}`) + + evt, _, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamToolUse { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse) + } + if evt.ToolName != "Read" { + t.Errorf("tool_name = %q, want %q", evt.ToolName, "Read") + } + if evt.ToolInput != "/home/user/main.go" { + t.Errorf("tool_input = %q, want %q", evt.ToolInput, "/home/user/main.go") + } +} + +func TestParseStreamLine_AssistantText(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[{"type":"text","text":"Hello, world!"}]}}`) + + evt, result, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamText { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText) + } + if evt.Content != "Hello, world!" 
{ + t.Errorf("content = %q, want %q", evt.Content, "Hello, world!") + } + if result != nil { + t.Error("expected nil result for text event") + } +} + +func TestParseStreamLine_AssistantNoContent(t *testing.T) { + line := []byte(`{"type":"assistant","message":{"content":[]}}`) + + evt, _, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamText { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText) + } +} + +func TestParseStreamLine_ResultSuccess(t *testing.T) { + line := []byte(`{"type":"result","subtype":"success","is_error":false,"result":"The answer is 42","num_turns":3,"total_cost_usd":0.05,"usage":{"input_tokens":100,"output_tokens":50}}`) + + evt, result, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamResult { + t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamResult) + } + if evt.Content != "The answer is 42" { + t.Errorf("content = %q, want %q", evt.Content, "The answer is 42") + } + if evt.IsError { + t.Error("expected IsError=false") + } + if result == nil { + t.Fatal("expected non-nil result for result event") + } + if result.Result != "The answer is 42" { + t.Errorf("result.Result = %q, want %q", result.Result, "The answer is 42") + } + if result.Usage.InputTokens != 100 { + t.Errorf("input_tokens = %d, want 100", result.Usage.InputTokens) + } + if result.Usage.OutputTokens != 50 { + t.Errorf("output_tokens = %d, want 50", result.Usage.OutputTokens) + } + if result.TotalCost != 0.05 { + t.Errorf("total_cost = %f, want 0.05", result.TotalCost) + } +} + +func TestParseStreamLine_ResultError(t *testing.T) { + line := []byte(`{"type":"result","subtype":"error","is_error":true,"result":"API key expired","num_turns":0}`) + + evt, result, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamResult { + t.Errorf("kind = %q, 
want %q", evt.Kind, coretypes.StreamResult) + } + if !evt.IsError { + t.Error("expected IsError=true") + } + if result == nil { + t.Fatal("expected non-nil result") + } + if !result.IsError { + t.Error("expected result.IsError=true") + } +} + +func TestParseStreamLine_UnknownType(t *testing.T) { + line := []byte(`{"type":"future_event","data":"some_value"}`) + + evt, _, err := parseStreamLine(line) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.Kind != coretypes.StreamText { + t.Errorf("kind = %q, want %q (fallback for unknown types)", evt.Kind, coretypes.StreamText) + } +} + +func TestParseStreamLine_InvalidJSON(t *testing.T) { + line := []byte(`not valid json`) + + _, _, err := parseStreamLine(line) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +// ── truncateToolInput ─────────────────────────────────────────────── + +func TestTruncateToolInput_Nil(t *testing.T) { + got := truncateToolInput(nil) + if got != "" { + t.Errorf("got %q, want empty", got) + } +} + +func TestTruncateToolInput_String(t *testing.T) { + got := truncateToolInput("hello world") + if got != "hello world" { + t.Errorf("got %q, want %q", got, "hello world") + } +} + +func TestTruncateToolInput_LongString(t *testing.T) { + long := strings.Repeat("x", 200) + got := truncateToolInput(long) + if len(got) != 100 { + t.Errorf("len = %d, want 100", len(got)) + } + if !strings.HasSuffix(got, "...") { + t.Error("should end with ...") + } +} + +func TestTruncateToolInput_MapWithCommand(t *testing.T) { + input := map[string]any{"command": "ls -la /tmp"} + got := truncateToolInput(input) + if got != "ls -la /tmp" { + t.Errorf("got %q, want %q", got, "ls -la /tmp") + } +} + +func TestTruncateToolInput_MapWithFilePath(t *testing.T) { + input := map[string]any{"file_path": "/home/user/main.go"} + got := truncateToolInput(input) + if got != "/home/user/main.go" { + t.Errorf("got %q, want %q", got, "/home/user/main.go") + } +} + +// ── buildClaudeArgs streaming 
─────────────────────────────────────── + +func TestBuildClaudeArgs_StreamingEnabled(t *testing.T) { + cfg := config.ClaudeCodeCfg{ + Streaming: true, + } + streamFn := func(evt coretypes.StreamEvent) {} + req := coretypes.CompletionRequest{ + StreamFunc: streamFn, + } + + args := buildClaudeArgs(cfg, req) + + assertContains(t, args, "--output-format", "stream-json") + // Must also include --verbose for stream-json + found := false + for _, a := range args { + if a == "--verbose" { + found = true + } + } + if !found { + t.Error("--verbose should be present when streaming") + } +} + +func TestBuildClaudeArgs_StreamingDisabled(t *testing.T) { + cfg := config.ClaudeCodeCfg{ + Streaming: false, + } + req := coretypes.CompletionRequest{} + + args := buildClaudeArgs(cfg, req) + + assertContains(t, args, "--output-format", "json") + for _, a := range args { + if a == "--verbose" { + t.Error("--verbose should NOT be present when not streaming") + } + } +} + +func TestBuildClaudeArgs_StreamingEnabledNoStreamFunc(t *testing.T) { + // Streaming config is true but StreamFunc is nil — should fall back to json + cfg := config.ClaudeCodeCfg{ + Streaming: true, + } + req := coretypes.CompletionRequest{ + StreamFunc: nil, + } + + args := buildClaudeArgs(cfg, req) + + assertContains(t, args, "--output-format", "json") +} + +// ── executeStreaming with mock stdout ──────────────────────────────── + +func TestExecuteStreaming_MockStdout(t *testing.T) { + // Simulate stream-json output by writing lines to an io.Pipe + lines := []string{ + `{"type":"system","subtype":"init","session_id":"test-123"}`, + `{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"echo hello"}}]}}`, + `{"type":"assistant","message":{"content":[{"type":"text","text":"Done executing."}]}}`, + `{"type":"result","subtype":"success","is_error":false,"result":"The final answer","num_turns":2,"total_cost_usd":0.01,"usage":{"input_tokens":50,"output_tokens":25}}`, 
+ } + + var events []coretypes.StreamEvent + streamFn := func(evt coretypes.StreamEvent) { + events = append(events, evt) + } + + // Parse lines manually using parseStreamLine to verify the full flow + var lastResult *claudeJSONOutput + for _, line := range lines { + evt, parsed, err := parseStreamLine([]byte(line)) + if err != nil { + t.Fatalf("parse error on line: %v", err) + } + streamFn(evt) + if parsed != nil && parsed.Type == "result" { + lastResult = parsed + } + } + + // Verify events + if len(events) != 4 { + t.Fatalf("expected 4 events, got %d", len(events)) + } + if events[0].Kind != coretypes.StreamInit { + t.Errorf("event[0].Kind = %q, want %q", events[0].Kind, coretypes.StreamInit) + } + if events[1].Kind != coretypes.StreamToolUse { + t.Errorf("event[1].Kind = %q, want %q", events[1].Kind, coretypes.StreamToolUse) + } + if events[1].ToolName != "Bash" { + t.Errorf("event[1].ToolName = %q, want %q", events[1].ToolName, "Bash") + } + if events[1].ToolInput != "echo hello" { + t.Errorf("event[1].ToolInput = %q, want %q", events[1].ToolInput, "echo hello") + } + if events[2].Kind != coretypes.StreamText { + t.Errorf("event[2].Kind = %q, want %q", events[2].Kind, coretypes.StreamText) + } + if events[3].Kind != coretypes.StreamResult { + t.Errorf("event[3].Kind = %q, want %q", events[3].Kind, coretypes.StreamResult) + } + if events[3].Content != "The final answer" { + t.Errorf("event[3].Content = %q, want %q", events[3].Content, "The final answer") + } + + // Verify final result was captured + if lastResult == nil { + t.Fatal("expected lastResult to be set") + } + if lastResult.Result != "The final answer" { + t.Errorf("lastResult.Result = %q", lastResult.Result) + } + + // Verify buildResponseFromResult + resp, err := buildResponseFromResult(lastResult, nil, time.Second, discardLog) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "The final answer" { + t.Errorf("resp.Content = %q", resp.Content) + } + if 
resp.Usage.InputTokens != 50 { + t.Errorf("input_tokens = %d, want 50", resp.Usage.InputTokens) + } + if resp.FinishReason != "stop" { + t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "stop") + } +} + +func TestBuildResponseFromResult_Error(t *testing.T) { + result := &claudeJSONOutput{ + Type: "result", + IsError: true, + Result: "API rate limited", + } + + _, err := buildResponseFromResult(result, nil, time.Second, discardLog) + if err == nil { + t.Fatal("expected error for IsError=true") + } + if !contains(err.Error(), "API rate limited") { + t.Errorf("error = %q, should contain 'API rate limited'", err.Error()) + } +} + +func TestBuildResponseFromResult_ExecError(t *testing.T) { + result := &claudeJSONOutput{ + Type: "result", + Result: "partial output", + Usage: claudeUsage{InputTokens: 10, OutputTokens: 5}, + } + + resp, err := buildResponseFromResult(result, errors.New("timeout"), time.Second, discardLog) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.FinishReason != "error" { + t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "error") + } +} + // ── helpers ────────────────────────────────────────────────────────────── func contains(s, substr string) bool { diff --git a/shell/matrix/client.go b/shell/matrix/client.go index 454ac54..cad814a 100644 --- a/shell/matrix/client.go +++ b/shell/matrix/client.go @@ -286,6 +286,51 @@ func (c *Client) SendMarkdown(ctx context.Context, roomID, markdown string) erro return err } +// SendMarkdownGetID sends a formatted (Markdown) message to a room and returns +// the event ID of the sent message. Useful for later editing via EditMessage. 
+func (c *Client) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) { + html := mdToHTML(markdown) + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: markdown, + Format: event.FormatHTML, + FormattedBody: html, + } + resp, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + if err != nil { + return "", err + } + return resp.EventID.String(), nil +} + +// EditMessage edits a previously sent message in a room using m.replace relation. +// originalEventID is the event ID of the message to replace. +// The new content is rendered from markdown. +func (c *Client) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error { + html := mdToHTML(markdown) + + // Matrix spec: m.new_content holds the replacement, m.relates_to with + // rel_type=m.replace points to the original event. + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "* " + markdown, // per spec: prefix with "* " for fallback + Format: event.FormatHTML, + FormattedBody: "* " + html, + RelatesTo: &event.RelatesTo{ + Type: event.RelReplace, + EventID: id.EventID(originalEventID), + }, + NewContent: &event.MessageEventContent{ + MsgType: event.MsgText, + Body: markdown, + Format: event.FormatHTML, + FormattedBody: html, + }, + } + _, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + return err +} + // mdToHTML converts a Markdown string to HTML using goldmark with full extensions. var mdParser = goldmark.New( goldmark.WithExtensions(