diff --git a/devagents/handler.go b/devagents/handler.go index 8e4d638..6b73873 100644 --- a/devagents/handler.go +++ b/devagents/handler.go @@ -13,6 +13,7 @@ import ( "github.com/enmanuel/agents/pkg/sanitize" "github.com/enmanuel/agents/shell/audit" "github.com/enmanuel/agents/shell/bus" + "github.com/enmanuel/agents/shell/effects" ) // handleEvent is called by the matrix Listener for each filtered incoming event. @@ -184,14 +185,28 @@ func (a *Agent) executeActions(ctx context.Context, roomID string, msgCtx decisi }) a.persistMessage(ctx, memKey, coretypes.RoleUser, msgCtx.Content) - reply, err := a.runLLM(ctx, msgCtx, memKey) + // Create ProgressReporter for claude-code streaming if enabled + var progress *effects.ProgressReporter + if a.isStreamingEnabled() { + progress = effects.NewProgressReporter(a.sender, roomID, a.logger) + } + + reply, err := a.runLLM(ctx, msgCtx, memKey, progress) if err != nil { a.logger.Error("llm error", "err", err) + if progress != nil { + progress.Finalize("\u274c Error al procesar la solicitud.") + } expanded = append(expanded, decision.Action{ Kind: decision.ActionKindReply, Reply: &decision.ReplyAction{Content: "Sorry, I encountered an error.", InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID}, }) } else { + // If progress reporter was used, finalize it with a done indicator + if progress != nil && progress.EventID() != "" { + progress.Finalize("\u2705 *Completado*") + } + expanded = append(expanded, decision.Action{ Kind: decision.ActionKindReply, Reply: &decision.ReplyAction{Content: reply, InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID}, @@ -295,7 +310,7 @@ func (a *Agent) handleTaskEvent(ctx context.Context, msg bus.AgentMessage) { Role: coretypes.RoleUser, Content: msgCtx.Content, }) - reply, err := a.runLLM(ctx, msgCtx, roomID) + reply, err := a.runLLM(ctx, msgCtx, roomID, nil) // Build the result to send back via bus result := orchestration.TaskResult{ @@ -368,6 +383,14 @@ func (a *Agent) emitAudit(evt audit.Event) { } } +// isStreamingEnabled returns true when the agent uses claude-code provider +// with streaming and show_tool_progress both enabled. +func (a *Agent) isStreamingEnabled() bool { + return a.cfg.LLM.Primary.Provider == "claude-code" && + a.cfg.LLM.Primary.ClaudeCode.Streaming && + a.cfg.LLM.Primary.ClaudeCode.ShowToolProgress +} + // sanitizeInput runs prompt injection detection on the message content. // Returns the (possibly modified) content and true if the message should be rejected. func (a *Agent) sanitizeInput(content, roomID, senderID string) (string, bool) { diff --git a/devagents/llm.go b/devagents/llm.go index f1ebc27..6cd1762 100644 --- a/devagents/llm.go +++ b/devagents/llm.go @@ -13,11 +13,14 @@ import ( coretypes "github.com/enmanuel/agents/pkg/llm" "github.com/enmanuel/agents/pkg/personality" "github.com/enmanuel/agents/shell/audit" + "github.com/enmanuel/agents/shell/effects" shelllm "github.com/enmanuel/agents/shell/llm" ) // runLLM executes the LLM completion loop, including iterative tool-use. -func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string) (string, error) { +// progress may be nil; when non-nil, its StreamFunc is attached to the request +// for providers that support streaming (claude-code). +func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string, progress *effects.ProgressReporter) (string, error) { a.logger.Debug("calling LLM", "model", a.cfg.LLM.Primary.Model, "provider", a.cfg.LLM.Primary.Provider, @@ -62,6 +65,12 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK maxIter = defaultMaxToolIterations } + // Resolve StreamFunc for providers that support streaming + var streamFn coretypes.StreamFunc + if progress != nil { + streamFn = progress.StreamFunc() + } + // Tool-use loop: call LLM → execute tools → feed results back → repeat for i := 0; i < maxIter; i++ { req := coretypes.CompletionRequest{ @@ -71,6 +80,7 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK SystemPrompt: systemPrompt, Messages: messages, Tools: llmTools, + StreamFunc: streamFn, } resp, err := a.llm(ctx, req) diff --git a/devagents/runtime_test.go b/devagents/runtime_test.go index e4d152f..caaaa6e 100644 --- a/devagents/runtime_test.go +++ b/devagents/runtime_test.go @@ -82,6 +82,20 @@ func (s *spyMatrixSender) SendMarkdown(_ context.Context, roomID, markdown strin return nil } +func (s *spyMatrixSender) SendMarkdownGetID(_ context.Context, roomID, markdown string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown}) + return "$spy_event_id", nil +} + +func (s *spyMatrixSender) EditMessage(_ context.Context, roomID, originalEventID, markdown string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown, inReplyTo: originalEventID}) + return nil +} + func (s *spyMatrixSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error { s.mu.Lock() defer s.mu.Unlock() @@ -590,7 +604,7 @@ func TestRunLLM_ToolCallExecutesAndReturns(t *testing.T) { IsDirectMsg: true, } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -655,7 +669,7 @@ func TestRunLLM_ToolCallFailsPassesErrorToLLM(t *testing.T) { Content: "do something", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -716,7 +730,7 @@ func TestRunLLM_MaxIterationsRespected(t *testing.T) { Content: "loop please", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -776,7 +790,7 @@ func TestRunLLM_RBACDeniesToolCall(t *testing.T) { Content: "use restricted tool", } - reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err != nil { t.Fatalf("runLLM error: %v", err) } @@ -819,7 +833,7 @@ func TestRunLLM_LLMError(t *testing.T) { Content: "hello", } - _, err := a.runLLM(context.Background(), msgCtx, "!room:example.com") + _, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil) if err == nil { t.Fatal("expected error from LLM, got nil") } diff --git a/shell/effects/progress.go b/shell/effects/progress.go new file mode 100644 index 0000000..b26031c --- /dev/null +++ b/shell/effects/progress.go @@ -0,0 +1,144 @@ +package effects + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + coretypes "github.com/enmanuel/agents/pkg/llm" +) + +// ProgressReporter sends real-time progress updates to a Matrix room +// by editing a single "status" message as the claude-code subprocess +// emits streaming events (tool_use, text, result). +// +// It rate-limits edits to at most one per second to avoid flooding the +// homeserver. +type ProgressReporter struct { + sender MatrixSender + roomID string + logger *slog.Logger + + mu sync.Mutex + eventID string // Matrix event ID of the progress message (empty until first send) + lastEdit time.Time // timestamp of last edit, for rate limiting + minInterval time.Duration +} + +// NewProgressReporter creates a ProgressReporter that sends progress updates +// to the given room. The progress message is created lazily on the first event. +func NewProgressReporter(sender MatrixSender, roomID string, logger *slog.Logger) *ProgressReporter { + return &ProgressReporter{ + sender: sender, + roomID: roomID, + logger: logger, + minInterval: time.Second, // max 1 edit/second + } +} + +// StreamFunc returns a StreamFunc callback suitable for passing to +// CompletionRequest.StreamFunc. It captures streaming events and updates +// the progress message in the Matrix room. +func (p *ProgressReporter) StreamFunc() coretypes.StreamFunc { + return func(evt coretypes.StreamEvent) { + p.handleEvent(evt) + } +} + +// handleEvent processes a single streaming event and updates the Matrix message. +func (p *ProgressReporter) handleEvent(evt coretypes.StreamEvent) { + var markdown string + + switch evt.Kind { + case coretypes.StreamToolUse: + // Show which tool is being used + input := evt.ToolInput + if len(input) > 60 { + input = input[:57] + "..." + } + if input != "" { + markdown = fmt.Sprintf("\U0001f527 *%s*: `%s`", evt.ToolName, input) + } else { + markdown = fmt.Sprintf("\U0001f527 *%s*", evt.ToolName) + } + + case coretypes.StreamResult: + // Final result — no need to update progress; the handler will send the actual reply + return + + case coretypes.StreamText: + // Intermediate text — could be partial thinking, skip to avoid noise + return + + case coretypes.StreamInit: + markdown = "\u2699\ufe0f *Procesando...*" + + default: + return + } + + if markdown == "" { + return + } + + p.updateMessage(markdown) +} + +// updateMessage sends or edits the progress message, respecting rate limits. +func (p *ProgressReporter) updateMessage(markdown string) { + p.mu.Lock() + defer p.mu.Unlock() + + ctx := context.Background() + + // Rate limit: skip if we edited less than minInterval ago + if p.eventID != "" && time.Since(p.lastEdit) < p.minInterval { + return + } + + if p.eventID == "" { + // First message: send a new one and capture the event ID + evtID, err := p.sender.SendMarkdownGetID(ctx, p.roomID, markdown) + if err != nil { + p.logger.Warn("progress_reporter: failed to send initial message", "err", err) + return + } + p.eventID = evtID + p.lastEdit = time.Now() + return + } + + // Subsequent updates: edit the existing message + if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, markdown); err != nil { + p.logger.Warn("progress_reporter: failed to edit message", "err", err) + return + } + p.lastEdit = time.Now() +} + +// Finalize edits the progress message with the final content, or deletes it. +// Call this after the LLM response is ready. If finalMarkdown is empty, the +// progress message is left as-is (the handler will send a separate reply). +func (p *ProgressReporter) Finalize(finalMarkdown string) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.eventID == "" || finalMarkdown == "" { + return + } + + ctx := context.Background() + if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, finalMarkdown); err != nil { + p.logger.Warn("progress_reporter: failed to finalize message", "err", err) + } +} + +// EventID returns the Matrix event ID of the progress message, or empty if +// no message was sent yet. +func (p *ProgressReporter) EventID() string { + p.mu.Lock() + defer p.mu.Unlock() + return p.eventID +} diff --git a/shell/effects/progress_test.go b/shell/effects/progress_test.go new file mode 100644 index 0000000..9be9338 --- /dev/null +++ b/shell/effects/progress_test.go @@ -0,0 +1,226 @@ +package effects + +import ( + "context" + "log/slog" + "strings" + "testing" + "time" + + coretypes "github.com/enmanuel/agents/pkg/llm" +) + +// mockProgressSender records sends and edits for testing ProgressReporter. +type mockProgressSender struct { + fakeMatrixSender // embed to satisfy the full interface + sends []string // markdowns from SendMarkdownGetID + edits []string // markdowns from EditMessage + editTargets []string // event IDs targeted by EditMessage +} + +func (m *mockProgressSender) SendMarkdownGetID(_ context.Context, _, markdown string) (string, error) { + m.sends = append(m.sends, markdown) + return "$progress_msg_1", nil +} + +func (m *mockProgressSender) EditMessage(_ context.Context, _, originalEventID, markdown string) error { + m.edits = append(m.edits, markdown) + m.editTargets = append(m.editTargets, originalEventID) + return nil +} + +func TestProgressReporter_InitEvent(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + if len(sender.sends) != 1 { + t.Fatalf("expected 1 send, got %d", len(sender.sends)) + } + if !strings.Contains(sender.sends[0], "Procesando") { + t.Errorf("init message = %q, should contain 'Procesando'", sender.sends[0]) + } + if pr.EventID() != "$progress_msg_1" { + t.Errorf("EventID = %q, want %q", pr.EventID(), "$progress_msg_1") + } +} + +func TestProgressReporter_ToolUseEditsMessage(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 // disable rate limiting for test + + fn := pr.StreamFunc() + + // First event creates the message + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Second event should edit + fn(coretypes.StreamEvent{ + Kind: coretypes.StreamToolUse, + ToolName: "Bash", + ToolInput: "ls -la", + }) + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit, got %d", len(sender.edits)) + } + if !strings.Contains(sender.edits[0], "Bash") { + t.Errorf("edit = %q, should contain tool name", sender.edits[0]) + } + if !strings.Contains(sender.edits[0], "ls -la") { + t.Errorf("edit = %q, should contain tool input", sender.edits[0]) + } + if sender.editTargets[0] != "$progress_msg_1" { + t.Errorf("edit target = %q, want %q", sender.editTargets[0], "$progress_msg_1") + } +} + +func TestProgressReporter_MultipleToolUse(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 // disable rate limiting for test + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "/tmp/file.go"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "/tmp/file.go"}) + + // 1 send (init) + 3 edits (tool uses) + if len(sender.sends) != 1 { + t.Errorf("expected 1 send, got %d", len(sender.sends)) + } + if len(sender.edits) != 3 { + t.Errorf("expected 3 edits, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_RateLimiting(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 500 * time.Millisecond + + fn := pr.StreamFunc() + + // First event creates the message (no rate limit on first send) + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Reset lastEdit to simulate time having passed after init + pr.mu.Lock() + pr.lastEdit = time.Now().Add(-time.Second) + pr.mu.Unlock() + + // First tool event should go through (enough time has passed) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"}) + // These rapid-fire events should be rate-limited + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "file.go"}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "file.go"}) + + // Only 1 edit should have gone through (the rest rate limited) + if len(sender.edits) != 1 { + t.Errorf("expected 1 edit (rate limited), got %d", len(sender.edits)) + } +} + +func TestProgressReporter_ResultIgnored(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamResult, Content: "Final answer"}) + + // Result should not trigger an edit + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for result event, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_TextIgnored(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + fn := pr.StreamFunc() + + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + fn(coretypes.StreamEvent{Kind: coretypes.StreamText, Content: "Some thinking..."}) + + // Text events should not trigger edits + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for text event, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_Finalize(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + pr.Finalize("Done! Here is the result.") + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit for finalize, got %d", len(sender.edits)) + } + if sender.edits[0] != "Done! Here is the result." { + t.Errorf("finalize edit = %q", sender.edits[0]) + } +} + +func TestProgressReporter_FinalizeNoMessage(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + + // Finalize without ever sending a message should be a no-op + pr.Finalize("Final") + + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits when no message was sent, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_FinalizeEmpty(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + // Empty finalize should be a no-op + pr.Finalize("") + + if len(sender.edits) != 0 { + t.Errorf("expected 0 edits for empty finalize, got %d", len(sender.edits)) + } +} + +func TestProgressReporter_ToolInputTruncation(t *testing.T) { + sender := &mockProgressSender{} + pr := NewProgressReporter(sender, "!room:test", slog.Default()) + pr.minInterval = 0 + + fn := pr.StreamFunc() + fn(coretypes.StreamEvent{Kind: coretypes.StreamInit}) + + longInput := strings.Repeat("x", 100) + fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: longInput}) + + if len(sender.edits) != 1 { + t.Fatalf("expected 1 edit, got %d", len(sender.edits)) + } + // The input in the message should be truncated + if strings.Contains(sender.edits[0], longInput) { + t.Error("long input should be truncated in the message") + } + if !strings.Contains(sender.edits[0], "...") { + t.Error("truncated input should end with ...") + } +} diff --git a/shell/effects/runner.go b/shell/effects/runner.go index 5788df6..c0b7ff6 100644 --- a/shell/effects/runner.go +++ b/shell/effects/runner.go @@ -23,6 +23,8 @@ type Result struct { type MatrixSender interface { SendText(ctx context.Context, roomID, text string) error SendMarkdown(ctx context.Context, roomID, markdown string) error + SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) + EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error SendTyping(ctx context.Context, roomID string, typing bool) error diff --git a/shell/effects/runner_test.go b/shell/effects/runner_test.go index b5f0547..aee9ca6 100644 --- a/shell/effects/runner_test.go +++ b/shell/effects/runner_test.go @@ -31,6 +31,16 @@ func (f *fakeMatrixSender) SendMarkdown(ctx context.Context, roomID, markdown st return nil } +func (f *fakeMatrixSender) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) { + f.calls = append(f.calls, senderCall{method: "SendMarkdownGetID", roomID: roomID, markdown: markdown}) + return "$fake_event_id", nil +} + +func (f *fakeMatrixSender) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error { + f.calls = append(f.calls, senderCall{method: "EditMessage", roomID: roomID, inReplyTo: originalEventID, markdown: markdown}) + return nil +} + func (f *fakeMatrixSender) SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error { f.calls = append(f.calls, senderCall{method: "SendReplyMarkdown", roomID: roomID, inReplyTo: inReplyTo, markdown: markdown}) return nil diff --git a/shell/matrix/client.go b/shell/matrix/client.go index 454ac54..cad814a 100644 --- a/shell/matrix/client.go +++ b/shell/matrix/client.go @@ -286,6 +286,51 @@ func (c *Client) SendMarkdown(ctx context.Context, roomID, markdown string) erro return err } +// SendMarkdownGetID sends a formatted (Markdown) message to a room and returns +// the event ID of the sent message. Useful for later editing via EditMessage. +func (c *Client) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) { + html := mdToHTML(markdown) + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: markdown, + Format: event.FormatHTML, + FormattedBody: html, + } + resp, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + if err != nil { + return "", err + } + return resp.EventID.String(), nil +} + +// EditMessage edits a previously sent message in a room using m.replace relation. +// originalEventID is the event ID of the message to replace. +// The new content is rendered from markdown. +func (c *Client) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error { + html := mdToHTML(markdown) + + // Matrix spec: m.new_content holds the replacement, m.relates_to with + // rel_type=m.replace points to the original event. + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "* " + markdown, // per spec: prefix with "* " for fallback + Format: event.FormatHTML, + FormattedBody: "* " + html, + RelatesTo: &event.RelatesTo{ + Type: event.RelReplace, + EventID: id.EventID(originalEventID), + }, + NewContent: &event.MessageEventContent{ + MsgType: event.MsgText, + Body: markdown, + Format: event.FormatHTML, + FormattedBody: html, + }, + } + _, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + return err +} + // mdToHTML converts a Markdown string to HTML using goldmark with full extensions. var mdParser = goldmark.New( goldmark.WithExtensions(