merge: issue/0036-claude-code-streaming — implementación paralela

This commit is contained in:
2026-04-09 23:00:21 +00:00
15 changed files with 1384 additions and 47 deletions
+2
View File
@@ -84,6 +84,8 @@ llm:
fallback_model: ""
session_id: ""
add_dirs: []
streaming: false # true para usar --output-format stream-json (progreso en tiempo real)
show_tool_progress: false # true para mostrar en Matrix que herramientas usa el agente
fallback:
provider: ""
+25 -2
View File
@@ -13,6 +13,7 @@ import (
"github.com/enmanuel/agents/pkg/sanitize"
"github.com/enmanuel/agents/shell/audit"
"github.com/enmanuel/agents/shell/bus"
"github.com/enmanuel/agents/shell/effects"
)
// handleEvent is called by the matrix Listener for each filtered incoming event.
@@ -184,14 +185,28 @@ func (a *Agent) executeActions(ctx context.Context, roomID string, msgCtx decisi
})
a.persistMessage(ctx, memKey, coretypes.RoleUser, msgCtx.Content)
reply, err := a.runLLM(ctx, msgCtx, memKey)
// Create ProgressReporter for claude-code streaming if enabled
var progress *effects.ProgressReporter
if a.isStreamingEnabled() {
progress = effects.NewProgressReporter(a.sender, roomID, a.logger)
}
reply, err := a.runLLM(ctx, msgCtx, memKey, progress)
if err != nil {
a.logger.Error("llm error", "err", err)
if progress != nil {
progress.Finalize("\u274c Error al procesar la solicitud.")
}
expanded = append(expanded, decision.Action{
Kind: decision.ActionKindReply,
Reply: &decision.ReplyAction{Content: "Sorry, I encountered an error.", InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID},
})
} else {
// If progress reporter was used, finalize it with a done indicator
if progress != nil && progress.EventID() != "" {
progress.Finalize("\u2705 *Completado*")
}
expanded = append(expanded, decision.Action{
Kind: decision.ActionKindReply,
Reply: &decision.ReplyAction{Content: reply, InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID},
@@ -295,7 +310,7 @@ func (a *Agent) handleTaskEvent(ctx context.Context, msg bus.AgentMessage) {
Role: coretypes.RoleUser, Content: msgCtx.Content,
})
reply, err := a.runLLM(ctx, msgCtx, roomID)
reply, err := a.runLLM(ctx, msgCtx, roomID, nil)
// Build the result to send back via bus
result := orchestration.TaskResult{
@@ -368,6 +383,14 @@ func (a *Agent) emitAudit(evt audit.Event) {
}
}
// isStreamingEnabled returns true when the agent uses claude-code provider
// with streaming and show_tool_progress both enabled.
func (a *Agent) isStreamingEnabled() bool {
return a.cfg.LLM.Primary.Provider == "claude-code" &&
a.cfg.LLM.Primary.ClaudeCode.Streaming &&
a.cfg.LLM.Primary.ClaudeCode.ShowToolProgress
}
// sanitizeInput runs prompt injection detection on the message content.
// Returns the (possibly modified) content and true if the message should be rejected.
func (a *Agent) sanitizeInput(content, roomID, senderID string) (string, bool) {
+11 -1
View File
@@ -13,11 +13,14 @@ import (
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/pkg/personality"
"github.com/enmanuel/agents/shell/audit"
"github.com/enmanuel/agents/shell/effects"
shelllm "github.com/enmanuel/agents/shell/llm"
)
// runLLM executes the LLM completion loop, including iterative tool-use.
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string) (string, error) {
// progress may be nil; when non-nil, its StreamFunc is attached to the request
// for providers that support streaming (claude-code).
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string, progress *effects.ProgressReporter) (string, error) {
a.logger.Debug("calling LLM",
"model", a.cfg.LLM.Primary.Model,
"provider", a.cfg.LLM.Primary.Provider,
@@ -62,6 +65,12 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK
maxIter = defaultMaxToolIterations
}
// Resolve StreamFunc for providers that support streaming
var streamFn coretypes.StreamFunc
if progress != nil {
streamFn = progress.StreamFunc()
}
// Tool-use loop: call LLM → execute tools → feed results back → repeat
for i := 0; i < maxIter; i++ {
req := coretypes.CompletionRequest{
@@ -71,6 +80,7 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK
SystemPrompt: systemPrompt,
Messages: messages,
Tools: llmTools,
StreamFunc: streamFn,
}
resp, err := a.llm(ctx, req)
+19 -5
View File
@@ -82,6 +82,20 @@ func (s *spyMatrixSender) SendMarkdown(_ context.Context, roomID, markdown strin
return nil
}
func (s *spyMatrixSender) SendMarkdownGetID(_ context.Context, roomID, markdown string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown})
return "$spy_event_id", nil
}
func (s *spyMatrixSender) EditMessage(_ context.Context, roomID, originalEventID, markdown string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown, inReplyTo: originalEventID})
return nil
}
func (s *spyMatrixSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -590,7 +604,7 @@ func TestRunLLM_ToolCallExecutesAndReturns(t *testing.T) {
IsDirectMsg: true,
}
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
if err != nil {
t.Fatalf("runLLM error: %v", err)
}
@@ -655,7 +669,7 @@ func TestRunLLM_ToolCallFailsPassesErrorToLLM(t *testing.T) {
Content: "do something",
}
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
if err != nil {
t.Fatalf("runLLM error: %v", err)
}
@@ -716,7 +730,7 @@ func TestRunLLM_MaxIterationsRespected(t *testing.T) {
Content: "loop please",
}
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
if err != nil {
t.Fatalf("runLLM error: %v", err)
}
@@ -776,7 +790,7 @@ func TestRunLLM_RBACDeniesToolCall(t *testing.T) {
Content: "use restricted tool",
}
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
if err != nil {
t.Fatalf("runLLM error: %v", err)
}
@@ -819,7 +833,7 @@ func TestRunLLM_LLMError(t *testing.T) {
Content: "hello",
}
_, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
_, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
if err == nil {
t.Fatal("expected error from LLM, got nil")
}
+13 -11
View File
@@ -112,17 +112,19 @@ type LLMProviderCfg struct {
// ClaudeCodeCfg configures the claude -p subprocess provider.
type ClaudeCodeCfg struct {
Binary string `yaml:"binary"` // path to claude binary (default: "claude")
Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m)
DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools
AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. Bash, Read, Edit)
DisallowedTools []string `yaml:"disallowed_tools"` // tools to block
WorkingDir string `yaml:"working_dir"` // working directory for claude -p
PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan
Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name
FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded
SessionID string `yaml:"session_id"` // fixed session ID for continuity
AddDirs []string `yaml:"add_dirs"` // additional directories accessible
Binary string `yaml:"binary"` // path to claude binary (default: "claude")
Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m)
DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools
AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. Bash, Read, Edit)
DisallowedTools []string `yaml:"disallowed_tools"` // tools to block
WorkingDir string `yaml:"working_dir"` // working directory for claude -p
PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan
Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name
FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded
SessionID string `yaml:"session_id"` // fixed session ID for continuity
AddDirs []string `yaml:"add_dirs"` // additional directories accessible
Streaming bool `yaml:"streaming"` // use --output-format stream-json for realtime progress
ShowToolProgress bool `yaml:"show_tool_progress"` // edit Matrix message to show tool usage progress
}
type LLMReasoningCfg struct {
+38 -6
View File
@@ -42,13 +42,14 @@ type ToolSpec struct {
}
type CompletionRequest struct {
Model string
Messages []Message
Tools []ToolSpec
MaxTokens int
Temperature float64
Stream bool
Model string
Messages []Message
Tools []ToolSpec
MaxTokens int
Temperature float64
Stream bool
SystemPrompt string
StreamFunc StreamFunc // optional: if set, streaming events are emitted during execution
}
type TokenUsage struct {
@@ -67,3 +68,34 @@ type CompletionResponse struct {
// CompleteFunc is the single contract for LLM providers.
// Implementations live in shell/llm/.
type CompleteFunc func(ctx context.Context, req CompletionRequest) (CompletionResponse, error)
// ── Streaming types (pure) ───────────────────────────────────────────────
// StreamEventKind identifies the kind of streaming event emitted by
// a claude-code subprocess running with --output-format stream-json.
type StreamEventKind string
const (
StreamInit StreamEventKind = "init"
StreamToolUse StreamEventKind = "tool_use"
StreamToolResult StreamEventKind = "tool_result"
StreamText StreamEventKind = "text"
StreamResult StreamEventKind = "result"
StreamError StreamEventKind = "error"
)
// StreamEvent carries a single streaming event from the claude subprocess.
// Fields are populated based on Kind; not all fields are valid for all kinds.
type StreamEvent struct {
Kind StreamEventKind
ToolName string // tool_use: name of the tool being invoked
ToolInput string // tool_use: truncated input description
Content string // text/result: textual content
IsError bool // result: whether the result indicates an error
Error error // error: the error that occurred
}
// StreamFunc is the callback invoked for each streaming event.
// Implementations must be safe for concurrent use (typically not needed
// since the streaming loop calls sequentially).
type StreamFunc func(event StreamEvent)
+144
View File
@@ -0,0 +1,144 @@
package effects
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// ProgressReporter sends real-time progress updates to a Matrix room
// by editing a single "status" message as the claude-code subprocess
// emits streaming events (tool_use, text, result).
//
// It rate-limits edits to at most one per second to avoid flooding the
// homeserver.
type ProgressReporter struct {
sender MatrixSender
roomID string
logger *slog.Logger
mu sync.Mutex
eventID string // Matrix event ID of the progress message (empty until first send)
lastEdit time.Time // timestamp of last edit, for rate limiting
minInterval time.Duration
}
// NewProgressReporter creates a ProgressReporter that sends progress updates
// to the given room. The progress message is created lazily on the first event.
func NewProgressReporter(sender MatrixSender, roomID string, logger *slog.Logger) *ProgressReporter {
return &ProgressReporter{
sender: sender,
roomID: roomID,
logger: logger,
minInterval: time.Second, // max 1 edit/second
}
}
// StreamFunc returns a StreamFunc callback suitable for passing to
// CompletionRequest.StreamFunc. It captures streaming events and updates
// the progress message in the Matrix room.
func (p *ProgressReporter) StreamFunc() coretypes.StreamFunc {
return func(evt coretypes.StreamEvent) {
p.handleEvent(evt)
}
}
// handleEvent processes a single streaming event and updates the Matrix message.
func (p *ProgressReporter) handleEvent(evt coretypes.StreamEvent) {
var markdown string
switch evt.Kind {
case coretypes.StreamToolUse:
// Show which tool is being used
input := evt.ToolInput
if len(input) > 60 {
input = input[:57] + "..."
}
if input != "" {
markdown = fmt.Sprintf("\U0001f527 *%s*: `%s`", evt.ToolName, input)
} else {
markdown = fmt.Sprintf("\U0001f527 *%s*", evt.ToolName)
}
case coretypes.StreamResult:
// Final result — no need to update progress; the handler will send the actual reply
return
case coretypes.StreamText:
// Intermediate text — could be partial thinking, skip to avoid noise
return
case coretypes.StreamInit:
markdown = "\u2699\ufe0f *Procesando...*"
default:
return
}
if markdown == "" {
return
}
p.updateMessage(markdown)
}
// updateMessage sends or edits the progress message, respecting rate limits.
func (p *ProgressReporter) updateMessage(markdown string) {
p.mu.Lock()
defer p.mu.Unlock()
ctx := context.Background()
// Rate limit: skip if we edited less than minInterval ago
if p.eventID != "" && time.Since(p.lastEdit) < p.minInterval {
return
}
if p.eventID == "" {
// First message: send a new one and capture the event ID
evtID, err := p.sender.SendMarkdownGetID(ctx, p.roomID, markdown)
if err != nil {
p.logger.Warn("progress_reporter: failed to send initial message", "err", err)
return
}
p.eventID = evtID
p.lastEdit = time.Now()
return
}
// Subsequent updates: edit the existing message
if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, markdown); err != nil {
p.logger.Warn("progress_reporter: failed to edit message", "err", err)
return
}
p.lastEdit = time.Now()
}
// Finalize edits the progress message with the final content, or deletes it.
// Call this after the LLM response is ready. If finalMarkdown is empty, the
// progress message is left as-is (the handler will send a separate reply).
func (p *ProgressReporter) Finalize(finalMarkdown string) {
p.mu.Lock()
defer p.mu.Unlock()
if p.eventID == "" || finalMarkdown == "" {
return
}
ctx := context.Background()
if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, finalMarkdown); err != nil {
p.logger.Warn("progress_reporter: failed to finalize message", "err", err)
}
}
// EventID returns the Matrix event ID of the progress message, or empty if
// no message was sent yet.
func (p *ProgressReporter) EventID() string {
p.mu.Lock()
defer p.mu.Unlock()
return p.eventID
}
+162
View File
@@ -0,0 +1,162 @@
package effects
import (
"context"
"log/slog"
"strings"
"testing"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// TestIntegration_StreamToProgressReporter simulates the full flow:
// parseStreamLine produces events → ProgressReporter consumes them → mock sender records calls.
// This validates the complete pipeline from raw JSON lines to Matrix messages.
func TestIntegration_StreamToProgressReporter(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:integration", slog.Default())
pr.minInterval = 0 // disable rate limiting for test
fn := pr.StreamFunc()
// Simulate a realistic stream-json session:
// 1. Init event
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
// 2. Tool use: Bash
fn(coretypes.StreamEvent{
Kind: coretypes.StreamToolUse,
ToolName: "Bash",
ToolInput: "git status",
})
// 3. Tool use: Read
fn(coretypes.StreamEvent{
Kind: coretypes.StreamToolUse,
ToolName: "Read",
ToolInput: "/home/user/project/main.go",
})
// 4. Tool use: Edit
fn(coretypes.StreamEvent{
Kind: coretypes.StreamToolUse,
ToolName: "Edit",
ToolInput: "/home/user/project/main.go",
})
// 5. Text event (intermediate, should be ignored)
fn(coretypes.StreamEvent{
Kind: coretypes.StreamText,
Content: "I've made the changes...",
})
// 6. Result event (should be ignored by progress reporter)
fn(coretypes.StreamEvent{
Kind: coretypes.StreamResult,
Content: "Here is the final answer with all changes applied.",
})
// Verify sends: only 1 initial send
if len(sender.sends) != 1 {
t.Errorf("expected 1 send (init), got %d", len(sender.sends))
}
if !strings.Contains(sender.sends[0], "Procesando") {
t.Errorf("init message should contain 'Procesando', got %q", sender.sends[0])
}
// Verify edits: 3 tool use events
if len(sender.edits) != 3 {
t.Fatalf("expected 3 edits (tool uses), got %d", len(sender.edits))
}
// First edit: Bash
if !strings.Contains(sender.edits[0], "Bash") {
t.Errorf("edit[0] should mention Bash, got %q", sender.edits[0])
}
if !strings.Contains(sender.edits[0], "git status") {
t.Errorf("edit[0] should show input, got %q", sender.edits[0])
}
// Second edit: Read
if !strings.Contains(sender.edits[1], "Read") {
t.Errorf("edit[1] should mention Read, got %q", sender.edits[1])
}
// Third edit: Edit
if !strings.Contains(sender.edits[2], "Edit") {
t.Errorf("edit[2] should mention Edit, got %q", sender.edits[2])
}
// All edits should target the same event ID
for i, target := range sender.editTargets {
if target != "$progress_msg_1" {
t.Errorf("editTarget[%d] = %q, want %q", i, target, "$progress_msg_1")
}
}
// Finalize
pr.Finalize("\u2705 *Completado*")
if len(sender.edits) != 4 {
t.Fatalf("expected 4 edits (3 tools + 1 finalize), got %d", len(sender.edits))
}
if !strings.Contains(sender.edits[3], "Completado") {
t.Errorf("finalize edit should contain 'Completado', got %q", sender.edits[3])
}
}
// TestIntegration_NoStreamingNoSideEffects verifies that when streaming is
// not enabled, no ProgressReporter is created and no Matrix side effects occur.
// This is a regression test for the streaming=false default behavior.
func TestIntegration_NoStreamingNoSideEffects(t *testing.T) {
sender := &mockProgressSender{}
// Simulate the handler check: streaming disabled → progress is nil
var progress *ProgressReporter // nil, because streaming is disabled
if progress != nil {
t.Error("progress reporter should be nil when streaming is disabled")
}
// Verify no sends or edits happened
if len(sender.sends) != 0 {
t.Errorf("expected 0 sends, got %d", len(sender.sends))
}
if len(sender.edits) != 0 {
t.Errorf("expected 0 edits, got %d", len(sender.edits))
}
}
// TestIntegration_ProgressReporterWithSendError verifies that the reporter
// handles send errors gracefully without panicking.
func TestIntegration_ProgressReporterWithSendError(t *testing.T) {
sender := &errorSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0
fn := pr.StreamFunc()
// Should not panic even when send fails
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
// EventID should be empty since send failed
if pr.EventID() != "" {
t.Errorf("expected empty EventID after send error, got %q", pr.EventID())
}
// Finalize should be a no-op since no message was sent
pr.Finalize("Done")
}
// errorSender always returns errors.
type errorSender struct {
fakeMatrixSender
}
func (e *errorSender) SendMarkdownGetID(_ context.Context, _, _ string) (string, error) {
return "", context.DeadlineExceeded
}
func (e *errorSender) EditMessage(_ context.Context, _, _, _ string) error {
return context.DeadlineExceeded
}
+226
View File
@@ -0,0 +1,226 @@
package effects
import (
"context"
"log/slog"
"strings"
"testing"
"time"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// mockProgressSender records sends and edits for testing ProgressReporter.
type mockProgressSender struct {
fakeMatrixSender // embed to satisfy the full interface
sends []string // markdowns from SendMarkdownGetID
edits []string // markdowns from EditMessage
editTargets []string // event IDs targeted by EditMessage
}
func (m *mockProgressSender) SendMarkdownGetID(_ context.Context, _, markdown string) (string, error) {
m.sends = append(m.sends, markdown)
return "$progress_msg_1", nil
}
func (m *mockProgressSender) EditMessage(_ context.Context, _, originalEventID, markdown string) error {
m.edits = append(m.edits, markdown)
m.editTargets = append(m.editTargets, originalEventID)
return nil
}
func TestProgressReporter_InitEvent(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
if len(sender.sends) != 1 {
t.Fatalf("expected 1 send, got %d", len(sender.sends))
}
if !strings.Contains(sender.sends[0], "Procesando") {
t.Errorf("init message = %q, should contain 'Procesando'", sender.sends[0])
}
if pr.EventID() != "$progress_msg_1" {
t.Errorf("EventID = %q, want %q", pr.EventID(), "$progress_msg_1")
}
}
func TestProgressReporter_ToolUseEditsMessage(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0 // disable rate limiting for test
fn := pr.StreamFunc()
// First event creates the message
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
// Second event should edit
fn(coretypes.StreamEvent{
Kind: coretypes.StreamToolUse,
ToolName: "Bash",
ToolInput: "ls -la",
})
if len(sender.edits) != 1 {
t.Fatalf("expected 1 edit, got %d", len(sender.edits))
}
if !strings.Contains(sender.edits[0], "Bash") {
t.Errorf("edit = %q, should contain tool name", sender.edits[0])
}
if !strings.Contains(sender.edits[0], "ls -la") {
t.Errorf("edit = %q, should contain tool input", sender.edits[0])
}
if sender.editTargets[0] != "$progress_msg_1" {
t.Errorf("edit target = %q, want %q", sender.editTargets[0], "$progress_msg_1")
}
}
func TestProgressReporter_MultipleToolUse(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0 // disable rate limiting for test
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"})
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "/tmp/file.go"})
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "/tmp/file.go"})
// 1 send (init) + 3 edits (tool uses)
if len(sender.sends) != 1 {
t.Errorf("expected 1 send, got %d", len(sender.sends))
}
if len(sender.edits) != 3 {
t.Errorf("expected 3 edits, got %d", len(sender.edits))
}
}
func TestProgressReporter_RateLimiting(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 500 * time.Millisecond
fn := pr.StreamFunc()
// First event creates the message (no rate limit on first send)
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
// Reset lastEdit to simulate time having passed after init
pr.mu.Lock()
pr.lastEdit = time.Now().Add(-time.Second)
pr.mu.Unlock()
// First tool event should go through (enough time has passed)
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"})
// These rapid-fire events should be rate-limited
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "file.go"})
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "file.go"})
// Only 1 edit should have gone through (the rest rate limited)
if len(sender.edits) != 1 {
t.Errorf("expected 1 edit (rate limited), got %d", len(sender.edits))
}
}
func TestProgressReporter_ResultIgnored(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
fn(coretypes.StreamEvent{Kind: coretypes.StreamResult, Content: "Final answer"})
// Result should not trigger an edit
if len(sender.edits) != 0 {
t.Errorf("expected 0 edits for result event, got %d", len(sender.edits))
}
}
func TestProgressReporter_TextIgnored(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
fn(coretypes.StreamEvent{Kind: coretypes.StreamText, Content: "Some thinking..."})
// Text events should not trigger edits
if len(sender.edits) != 0 {
t.Errorf("expected 0 edits for text event, got %d", len(sender.edits))
}
}
func TestProgressReporter_Finalize(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
pr.Finalize("Done! Here is the result.")
if len(sender.edits) != 1 {
t.Fatalf("expected 1 edit for finalize, got %d", len(sender.edits))
}
if sender.edits[0] != "Done! Here is the result." {
t.Errorf("finalize edit = %q", sender.edits[0])
}
}
func TestProgressReporter_FinalizeNoMessage(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
// Finalize without ever sending a message should be a no-op
pr.Finalize("Final")
if len(sender.edits) != 0 {
t.Errorf("expected 0 edits when no message was sent, got %d", len(sender.edits))
}
}
func TestProgressReporter_FinalizeEmpty(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
// Empty finalize should be a no-op
pr.Finalize("")
if len(sender.edits) != 0 {
t.Errorf("expected 0 edits for empty finalize, got %d", len(sender.edits))
}
}
func TestProgressReporter_ToolInputTruncation(t *testing.T) {
sender := &mockProgressSender{}
pr := NewProgressReporter(sender, "!room:test", slog.Default())
pr.minInterval = 0
fn := pr.StreamFunc()
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
longInput := strings.Repeat("x", 100)
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: longInput})
if len(sender.edits) != 1 {
t.Fatalf("expected 1 edit, got %d", len(sender.edits))
}
// The input in the message should be truncated
if strings.Contains(sender.edits[0], longInput) {
t.Error("long input should be truncated in the message")
}
if !strings.Contains(sender.edits[0], "...") {
t.Error("truncated input should end with ...")
}
}
+2
View File
@@ -23,6 +23,8 @@ type Result struct {
type MatrixSender interface {
SendText(ctx context.Context, roomID, text string) error
SendMarkdown(ctx context.Context, roomID, markdown string) error
SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error)
EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error
SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error
SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error
SendTyping(ctx context.Context, roomID string, typing bool) error
+10
View File
@@ -31,6 +31,16 @@ func (f *fakeMatrixSender) SendMarkdown(ctx context.Context, roomID, markdown st
return nil
}
func (f *fakeMatrixSender) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) {
f.calls = append(f.calls, senderCall{method: "SendMarkdownGetID", roomID: roomID, markdown: markdown})
return "$fake_event_id", nil
}
func (f *fakeMatrixSender) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error {
f.calls = append(f.calls, senderCall{method: "EditMessage", roomID: roomID, inReplyTo: originalEventID, markdown: markdown})
return nil
}
func (f *fakeMatrixSender) SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error {
f.calls = append(f.calls, senderCall{method: "SendReplyMarkdown", roomID: roomID, inReplyTo: inReplyTo, markdown: markdown})
return nil
+316 -22
View File
@@ -1,6 +1,7 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -74,6 +75,7 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes
"args", strings.Join(args, " "),
"prompt_len", len(prompt),
"working_dir", workDir,
"streaming", cfg.Streaming,
)
cmd := exec.CommandContext(ctx, binary, args...)
@@ -99,31 +101,313 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes
return nil
}
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
start := time.Now()
err := cmd.Run()
elapsed := time.Since(start)
// Ensure the process group is fully dead after Run returns,
// even if cmd.Run() returned without triggering Cancel (normal exit).
if cmd.Process != nil {
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
// Choose between streaming and buffered mode
if cfg.Streaming && req.StreamFunc != nil {
return executeStreaming(ctx, cmd, req.StreamFunc, log)
}
log.Debug("claude_code_done",
"elapsed_ms", elapsed.Milliseconds(),
"stdout_len", stdout.Len(),
"stderr_len", stderr.Len(),
"exit_err", err,
)
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
return executeBuffered(ctx, cmd, log)
}
}
// executeBuffered runs the claude subprocess and collects all output at once.
// This is the original (non-streaming) code path.
func executeBuffered(ctx context.Context, cmd *exec.Cmd, log *slog.Logger) (coretypes.CompletionResponse, error) {
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
start := time.Now()
err := cmd.Run()
elapsed := time.Since(start)
// Ensure the process group is fully dead after Run returns.
if cmd.Process != nil {
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
}
log.Debug("claude_code_done",
"elapsed_ms", elapsed.Milliseconds(),
"stdout_len", stdout.Len(),
"stderr_len", stderr.Len(),
"exit_err", err,
)
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
}
// executeStreaming runs the claude subprocess with --output-format stream-json,
// reads stdout line by line, emits StreamEvents via the callback, and accumulates
// the final result.
func executeStreaming(ctx context.Context, cmd *exec.Cmd, streamFn coretypes.StreamFunc, log *slog.Logger) (coretypes.CompletionResponse, error) {
stdout, err := cmd.StdoutPipe()
if err != nil {
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: stdout pipe: %w", err)
}
var stderr bytes.Buffer
cmd.Stderr = &stderr
start := time.Now()
if err := cmd.Start(); err != nil {
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: start: %w", err)
}
// Scan stdout line by line, parsing each JSON event
var lastResult *claudeJSONOutput
scanner := bufio.NewScanner(stdout)
scanner.Buffer(make([]byte, 0, 256*1024), 1024*1024) // allow up to 1MB lines
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
evt, parsed, parseErr := parseStreamLine(line)
if parseErr != nil {
log.Debug("stream_line_parse_error", "err", parseErr, "line_len", len(line))
continue
}
// Emit the event to the callback
streamFn(evt)
// Keep track of the final result event
if parsed != nil && parsed.Type == "result" {
lastResult = parsed
}
}
// Wait for the process to finish
waitErr := cmd.Wait()
elapsed := time.Since(start)
// Ensure the process group is fully dead after Run returns.
if cmd.Process != nil {
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
}
if scanErr := scanner.Err(); scanErr != nil {
log.Warn("stream_scanner_error", "err", scanErr)
}
log.Debug("claude_code_stream_done",
"elapsed_ms", elapsed.Milliseconds(),
"stderr_len", stderr.Len(),
"exit_err", waitErr,
)
// Build response from the last result event
if lastResult != nil {
return buildResponseFromResult(lastResult, waitErr, elapsed, log)
}
// Fallback: if no result event was captured, treat stderr/waitErr as error
if waitErr != nil {
errMsg := stderr.String()
if errMsg == "" {
errMsg = waitErr.Error()
}
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code stream process failed: %s", errMsg)
}
return coretypes.CompletionResponse{
Content: "",
FinishReason: "stop",
}, nil
}
// buildResponseFromResult converts a parsed result event into a CompletionResponse.
func buildResponseFromResult(output *claudeJSONOutput, execErr error, elapsed time.Duration, log *slog.Logger) (coretypes.CompletionResponse, error) {
if output.IsError {
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code error: %s", output.Result)
}
content := output.Result
if content == "" && len(output.ContentBlock) > 0 {
var parts []string
for _, block := range output.ContentBlock {
if block.Type == "text" && block.Text != "" {
parts = append(parts, block.Text)
}
}
content = strings.Join(parts, "\n")
}
finishReason := "stop"
if execErr != nil {
finishReason = "error"
}
log.Info("claude_code_response",
"content_len", len(content),
"input_tokens", output.Usage.InputTokens,
"output_tokens", output.Usage.OutputTokens,
"num_turns", output.NumTurns,
"cost_usd", output.TotalCost,
"elapsed_ms", elapsed.Milliseconds(),
)
return coretypes.CompletionResponse{
Content: content,
Usage: coretypes.TokenUsage{
InputTokens: output.Usage.InputTokens,
OutputTokens: output.Usage.OutputTokens,
TotalTokens: output.Usage.InputTokens + output.Usage.OutputTokens,
},
FinishReason: finishReason,
}, nil
}
// ── Stream event parsing ────────────────────────────────────────────────
// claudeStreamEvent is the raw JSON shape from `claude -p --output-format stream-json`.
// Each line of stdout is one JSON object with at least a "type" field.
type claudeStreamEvent struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
// For type=assistant, the message contains content blocks
Message *claudeStreamMessage `json:"message"`
// For type=result — reuse claudeJSONOutput fields
IsError bool `json:"is_error"`
Result string `json:"result"`
NumTurns int `json:"num_turns"`
TotalCost float64 `json:"total_cost_usd"`
Usage claudeUsage `json:"usage"`
Content []claudeContent `json:"content"`
}
// claudeStreamMessage represents the assistant message in a stream event.
type claudeStreamMessage struct {
Content []claudeStreamContentBlock `json:"content"`
}
// claudeStreamContentBlock represents a content block within an assistant message.
type claudeStreamContentBlock struct {
Type string `json:"type"`
Text string `json:"text"`
Name string `json:"name"` // tool_use: tool name
ID string `json:"id"` // tool_use: call ID
Input any `json:"input"` // tool_use: tool input (object or string)
}
// parseStreamLine parses a single JSON line from the stream-json output.
// Returns the StreamEvent, optionally the raw parsed result (if type=result),
// and any parse error.
func parseStreamLine(line []byte) (coretypes.StreamEvent, *claudeJSONOutput, error) {
var raw claudeStreamEvent
if err := json.Unmarshal(line, &raw); err != nil {
return coretypes.StreamEvent{}, nil, fmt.Errorf("parse stream line: %w", err)
}
switch raw.Type {
case "system":
// Init event — emit as init
return coretypes.StreamEvent{
Kind: coretypes.StreamInit,
}, nil, nil
case "assistant":
// Assistant message with content blocks — extract tool_use and text events
if raw.Message != nil && len(raw.Message.Content) > 0 {
// Look for the most interesting content block
for _, block := range raw.Message.Content {
switch block.Type {
case "tool_use":
inputStr := truncateToolInput(block.Input)
return coretypes.StreamEvent{
Kind: coretypes.StreamToolUse,
ToolName: block.Name,
ToolInput: inputStr,
}, nil, nil
case "tool_result":
return coretypes.StreamEvent{
Kind: coretypes.StreamToolResult,
}, nil, nil
case "text":
return coretypes.StreamEvent{
Kind: coretypes.StreamText,
Content: block.Text,
}, nil, nil
}
}
}
// Assistant message without interesting content blocks
return coretypes.StreamEvent{
Kind: coretypes.StreamText,
}, nil, nil
case "result":
// Final result event
result := &claudeJSONOutput{
Type: raw.Type,
Subtype: raw.Subtype,
IsError: raw.IsError,
Result: raw.Result,
NumTurns: raw.NumTurns,
TotalCost: raw.TotalCost,
Usage: raw.Usage,
}
evt := coretypes.StreamEvent{
Kind: coretypes.StreamResult,
Content: raw.Result,
IsError: raw.IsError,
}
return evt, result, nil
default:
// Unknown event type — emit as text with raw type info
return coretypes.StreamEvent{
Kind: coretypes.StreamText,
Content: raw.Type,
}, nil, nil
}
}
// truncateToolInput converts tool input to a short description string.
func truncateToolInput(input any) string {
if input == nil {
return ""
}
switch v := input.(type) {
case string:
return truncateStr(v, 100)
case map[string]any:
// For tool inputs like {"command": "ls -la"}, extract the most useful field
if cmd, ok := v["command"]; ok {
return truncateStr(fmt.Sprintf("%v", cmd), 100)
}
if file, ok := v["file_path"]; ok {
return truncateStr(fmt.Sprintf("%v", file), 100)
}
// Fallback: serialize the whole thing
b, err := json.Marshal(v)
if err != nil {
return ""
}
return truncateStr(string(b), 100)
default:
b, err := json.Marshal(v)
if err != nil {
return ""
}
return truncateStr(string(b), 100)
}
}
// truncateStr shortens a string to maxLen, appending "..." if truncated.
func truncateStr(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen-3] + "..."
}
// ── Shared helpers ──────────────────────────────────────────────────────
// resolveWorkDir determines the working directory for the claude subprocess.
// If configured is empty, it creates a temporary directory to avoid inheriting the launcher's CWD.
// If configured is non-empty, it ensures the directory exists.
@@ -149,7 +433,17 @@ func resolveWorkDir(configured string, log *slog.Logger) string {
// buildClaudeArgs constructs the CLI arguments for claude -p.
func buildClaudeArgs(cfg config.ClaudeCodeCfg, req coretypes.CompletionRequest) []string {
args := []string{"--print", "--output-format", "json"}
outputFormat := "json"
if cfg.Streaming && req.StreamFunc != nil {
outputFormat = "stream-json"
}
args := []string{"--print", "--output-format", outputFormat}
// stream-json requires --verbose
if outputFormat == "stream-json" {
args = append(args, "--verbose")
}
if req.SystemPrompt != "" {
args = append(args, "--system-prompt", req.SystemPrompt)
+371
View File
@@ -371,6 +371,377 @@ func TestResolveWorkDir_ConfiguredAlreadyExists(t *testing.T) {
}
}
// ── parseStreamLine ─────────────────────────────────────────────────
func TestParseStreamLine_SystemInit(t *testing.T) {
line := []byte(`{"type":"system","subtype":"init","session_id":"abc","tools":["Bash","Read"],"model":"sonnet"}`)
evt, result, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamInit {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamInit)
}
if result != nil {
t.Error("expected nil result for system event")
}
}
func TestParseStreamLine_AssistantToolUse(t *testing.T) {
line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"ls -la /tmp"}}]}}`)
evt, result, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamToolUse {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse)
}
if evt.ToolName != "Bash" {
t.Errorf("tool_name = %q, want %q", evt.ToolName, "Bash")
}
if evt.ToolInput != "ls -la /tmp" {
t.Errorf("tool_input = %q, want %q", evt.ToolInput, "ls -la /tmp")
}
if result != nil {
t.Error("expected nil result for assistant event")
}
}
func TestParseStreamLine_AssistantToolUseFilePath(t *testing.T) {
line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Read","id":"call_2","input":{"file_path":"/home/user/main.go"}}]}}`)
evt, _, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamToolUse {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse)
}
if evt.ToolName != "Read" {
t.Errorf("tool_name = %q, want %q", evt.ToolName, "Read")
}
if evt.ToolInput != "/home/user/main.go" {
t.Errorf("tool_input = %q, want %q", evt.ToolInput, "/home/user/main.go")
}
}
func TestParseStreamLine_AssistantText(t *testing.T) {
line := []byte(`{"type":"assistant","message":{"content":[{"type":"text","text":"Hello, world!"}]}}`)
evt, result, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamText {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText)
}
if evt.Content != "Hello, world!" {
t.Errorf("content = %q, want %q", evt.Content, "Hello, world!")
}
if result != nil {
t.Error("expected nil result for text event")
}
}
func TestParseStreamLine_AssistantNoContent(t *testing.T) {
line := []byte(`{"type":"assistant","message":{"content":[]}}`)
evt, _, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamText {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText)
}
}
func TestParseStreamLine_ResultSuccess(t *testing.T) {
line := []byte(`{"type":"result","subtype":"success","is_error":false,"result":"The answer is 42","num_turns":3,"total_cost_usd":0.05,"usage":{"input_tokens":100,"output_tokens":50}}`)
evt, result, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamResult {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamResult)
}
if evt.Content != "The answer is 42" {
t.Errorf("content = %q, want %q", evt.Content, "The answer is 42")
}
if evt.IsError {
t.Error("expected IsError=false")
}
if result == nil {
t.Fatal("expected non-nil result for result event")
}
if result.Result != "The answer is 42" {
t.Errorf("result.Result = %q, want %q", result.Result, "The answer is 42")
}
if result.Usage.InputTokens != 100 {
t.Errorf("input_tokens = %d, want 100", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 50 {
t.Errorf("output_tokens = %d, want 50", result.Usage.OutputTokens)
}
if result.TotalCost != 0.05 {
t.Errorf("total_cost = %f, want 0.05", result.TotalCost)
}
}
func TestParseStreamLine_ResultError(t *testing.T) {
line := []byte(`{"type":"result","subtype":"error","is_error":true,"result":"API key expired","num_turns":0}`)
evt, result, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamResult {
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamResult)
}
if !evt.IsError {
t.Error("expected IsError=true")
}
if result == nil {
t.Fatal("expected non-nil result")
}
if !result.IsError {
t.Error("expected result.IsError=true")
}
}
func TestParseStreamLine_UnknownType(t *testing.T) {
line := []byte(`{"type":"future_event","data":"some_value"}`)
evt, _, err := parseStreamLine(line)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if evt.Kind != coretypes.StreamText {
t.Errorf("kind = %q, want %q (fallback for unknown types)", evt.Kind, coretypes.StreamText)
}
}
func TestParseStreamLine_InvalidJSON(t *testing.T) {
line := []byte(`not valid json`)
_, _, err := parseStreamLine(line)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
// ── truncateToolInput ───────────────────────────────────────────────
func TestTruncateToolInput_Nil(t *testing.T) {
got := truncateToolInput(nil)
if got != "" {
t.Errorf("got %q, want empty", got)
}
}
func TestTruncateToolInput_String(t *testing.T) {
got := truncateToolInput("hello world")
if got != "hello world" {
t.Errorf("got %q, want %q", got, "hello world")
}
}
func TestTruncateToolInput_LongString(t *testing.T) {
long := strings.Repeat("x", 200)
got := truncateToolInput(long)
if len(got) != 100 {
t.Errorf("len = %d, want 100", len(got))
}
if !strings.HasSuffix(got, "...") {
t.Error("should end with ...")
}
}
func TestTruncateToolInput_MapWithCommand(t *testing.T) {
input := map[string]any{"command": "ls -la /tmp"}
got := truncateToolInput(input)
if got != "ls -la /tmp" {
t.Errorf("got %q, want %q", got, "ls -la /tmp")
}
}
func TestTruncateToolInput_MapWithFilePath(t *testing.T) {
input := map[string]any{"file_path": "/home/user/main.go"}
got := truncateToolInput(input)
if got != "/home/user/main.go" {
t.Errorf("got %q, want %q", got, "/home/user/main.go")
}
}
// ── buildClaudeArgs streaming ───────────────────────────────────────
func TestBuildClaudeArgs_StreamingEnabled(t *testing.T) {
cfg := config.ClaudeCodeCfg{
Streaming: true,
}
streamFn := func(evt coretypes.StreamEvent) {}
req := coretypes.CompletionRequest{
StreamFunc: streamFn,
}
args := buildClaudeArgs(cfg, req)
assertContains(t, args, "--output-format", "stream-json")
// Must also include --verbose for stream-json
found := false
for _, a := range args {
if a == "--verbose" {
found = true
}
}
if !found {
t.Error("--verbose should be present when streaming")
}
}
func TestBuildClaudeArgs_StreamingDisabled(t *testing.T) {
cfg := config.ClaudeCodeCfg{
Streaming: false,
}
req := coretypes.CompletionRequest{}
args := buildClaudeArgs(cfg, req)
assertContains(t, args, "--output-format", "json")
for _, a := range args {
if a == "--verbose" {
t.Error("--verbose should NOT be present when not streaming")
}
}
}
func TestBuildClaudeArgs_StreamingEnabledNoStreamFunc(t *testing.T) {
// Streaming config is true but StreamFunc is nil — should fall back to json
cfg := config.ClaudeCodeCfg{
Streaming: true,
}
req := coretypes.CompletionRequest{
StreamFunc: nil,
}
args := buildClaudeArgs(cfg, req)
assertContains(t, args, "--output-format", "json")
}
// ── executeStreaming with mock stdout ────────────────────────────────
func TestExecuteStreaming_MockStdout(t *testing.T) {
// Simulate stream-json output by writing lines to an io.Pipe
lines := []string{
`{"type":"system","subtype":"init","session_id":"test-123"}`,
`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"echo hello"}}]}}`,
`{"type":"assistant","message":{"content":[{"type":"text","text":"Done executing."}]}}`,
`{"type":"result","subtype":"success","is_error":false,"result":"The final answer","num_turns":2,"total_cost_usd":0.01,"usage":{"input_tokens":50,"output_tokens":25}}`,
}
var events []coretypes.StreamEvent
streamFn := func(evt coretypes.StreamEvent) {
events = append(events, evt)
}
// Parse lines manually using parseStreamLine to verify the full flow
var lastResult *claudeJSONOutput
for _, line := range lines {
evt, parsed, err := parseStreamLine([]byte(line))
if err != nil {
t.Fatalf("parse error on line: %v", err)
}
streamFn(evt)
if parsed != nil && parsed.Type == "result" {
lastResult = parsed
}
}
// Verify events
if len(events) != 4 {
t.Fatalf("expected 4 events, got %d", len(events))
}
if events[0].Kind != coretypes.StreamInit {
t.Errorf("event[0].Kind = %q, want %q", events[0].Kind, coretypes.StreamInit)
}
if events[1].Kind != coretypes.StreamToolUse {
t.Errorf("event[1].Kind = %q, want %q", events[1].Kind, coretypes.StreamToolUse)
}
if events[1].ToolName != "Bash" {
t.Errorf("event[1].ToolName = %q, want %q", events[1].ToolName, "Bash")
}
if events[1].ToolInput != "echo hello" {
t.Errorf("event[1].ToolInput = %q, want %q", events[1].ToolInput, "echo hello")
}
if events[2].Kind != coretypes.StreamText {
t.Errorf("event[2].Kind = %q, want %q", events[2].Kind, coretypes.StreamText)
}
if events[3].Kind != coretypes.StreamResult {
t.Errorf("event[3].Kind = %q, want %q", events[3].Kind, coretypes.StreamResult)
}
if events[3].Content != "The final answer" {
t.Errorf("event[3].Content = %q, want %q", events[3].Content, "The final answer")
}
// Verify final result was captured
if lastResult == nil {
t.Fatal("expected lastResult to be set")
}
if lastResult.Result != "The final answer" {
t.Errorf("lastResult.Result = %q", lastResult.Result)
}
// Verify buildResponseFromResult
resp, err := buildResponseFromResult(lastResult, nil, time.Second, discardLog)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "The final answer" {
t.Errorf("resp.Content = %q", resp.Content)
}
if resp.Usage.InputTokens != 50 {
t.Errorf("input_tokens = %d, want 50", resp.Usage.InputTokens)
}
if resp.FinishReason != "stop" {
t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "stop")
}
}
func TestBuildResponseFromResult_Error(t *testing.T) {
result := &claudeJSONOutput{
Type: "result",
IsError: true,
Result: "API rate limited",
}
_, err := buildResponseFromResult(result, nil, time.Second, discardLog)
if err == nil {
t.Fatal("expected error for IsError=true")
}
if !contains(err.Error(), "API rate limited") {
t.Errorf("error = %q, should contain 'API rate limited'", err.Error())
}
}
func TestBuildResponseFromResult_ExecError(t *testing.T) {
result := &claudeJSONOutput{
Type: "result",
Result: "partial output",
Usage: claudeUsage{InputTokens: 10, OutputTokens: 5},
}
resp, err := buildResponseFromResult(result, errors.New("timeout"), time.Second, discardLog)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.FinishReason != "error" {
t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "error")
}
}
// ── helpers ──────────────────────────────────────────────────────────────
func contains(s, substr string) bool {
+45
View File
@@ -286,6 +286,51 @@ func (c *Client) SendMarkdown(ctx context.Context, roomID, markdown string) erro
return err
}
// SendMarkdownGetID sends a formatted (Markdown) message to a room and returns
// the event ID of the sent message. Useful for later editing via EditMessage.
func (c *Client) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) {
html := mdToHTML(markdown)
content := &event.MessageEventContent{
MsgType: event.MsgText,
Body: markdown,
Format: event.FormatHTML,
FormattedBody: html,
}
resp, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content)
if err != nil {
return "", err
}
return resp.EventID.String(), nil
}
// EditMessage edits a previously sent message in a room using m.replace relation.
// originalEventID is the event ID of the message to replace.
// The new content is rendered from markdown.
func (c *Client) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error {
html := mdToHTML(markdown)
// Matrix spec: m.new_content holds the replacement, m.relates_to with
// rel_type=m.replace points to the original event.
content := &event.MessageEventContent{
MsgType: event.MsgText,
Body: "* " + markdown, // per spec: prefix with "* " for fallback
Format: event.FormatHTML,
FormattedBody: "* " + html,
RelatesTo: &event.RelatesTo{
Type: event.RelReplace,
EventID: id.EventID(originalEventID),
},
NewContent: &event.MessageEventContent{
MsgType: event.MsgText,
Body: markdown,
Format: event.FormatHTML,
FormattedBody: html,
},
}
_, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content)
return err
}
// mdToHTML converts a Markdown string to HTML using goldmark with full extensions.
var mdParser = goldmark.New(
goldmark.WithExtensions(