merge: issue/0036-claude-code-streaming — implementación paralela
This commit is contained in:
@@ -84,6 +84,8 @@ llm:
|
||||
fallback_model: ""
|
||||
session_id: ""
|
||||
add_dirs: []
|
||||
streaming: false # true para usar --output-format stream-json (progreso en tiempo real)
|
||||
show_tool_progress: false # true para mostrar en Matrix que herramientas usa el agente
|
||||
|
||||
fallback:
|
||||
provider: ""
|
||||
|
||||
+25
-2
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/enmanuel/agents/pkg/sanitize"
|
||||
"github.com/enmanuel/agents/shell/audit"
|
||||
"github.com/enmanuel/agents/shell/bus"
|
||||
"github.com/enmanuel/agents/shell/effects"
|
||||
)
|
||||
|
||||
// handleEvent is called by the matrix Listener for each filtered incoming event.
|
||||
@@ -184,14 +185,28 @@ func (a *Agent) executeActions(ctx context.Context, roomID string, msgCtx decisi
|
||||
})
|
||||
a.persistMessage(ctx, memKey, coretypes.RoleUser, msgCtx.Content)
|
||||
|
||||
reply, err := a.runLLM(ctx, msgCtx, memKey)
|
||||
// Create ProgressReporter for claude-code streaming if enabled
|
||||
var progress *effects.ProgressReporter
|
||||
if a.isStreamingEnabled() {
|
||||
progress = effects.NewProgressReporter(a.sender, roomID, a.logger)
|
||||
}
|
||||
|
||||
reply, err := a.runLLM(ctx, msgCtx, memKey, progress)
|
||||
if err != nil {
|
||||
a.logger.Error("llm error", "err", err)
|
||||
if progress != nil {
|
||||
progress.Finalize("\u274c Error al procesar la solicitud.")
|
||||
}
|
||||
expanded = append(expanded, decision.Action{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{Content: "Sorry, I encountered an error.", InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID},
|
||||
})
|
||||
} else {
|
||||
// If progress reporter was used, finalize it with a done indicator
|
||||
if progress != nil && progress.EventID() != "" {
|
||||
progress.Finalize("\u2705 *Completado*")
|
||||
}
|
||||
|
||||
expanded = append(expanded, decision.Action{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{Content: reply, InReplyTo: msgCtx.EventID, ThreadID: msgCtx.ThreadID},
|
||||
@@ -295,7 +310,7 @@ func (a *Agent) handleTaskEvent(ctx context.Context, msg bus.AgentMessage) {
|
||||
Role: coretypes.RoleUser, Content: msgCtx.Content,
|
||||
})
|
||||
|
||||
reply, err := a.runLLM(ctx, msgCtx, roomID)
|
||||
reply, err := a.runLLM(ctx, msgCtx, roomID, nil)
|
||||
|
||||
// Build the result to send back via bus
|
||||
result := orchestration.TaskResult{
|
||||
@@ -368,6 +383,14 @@ func (a *Agent) emitAudit(evt audit.Event) {
|
||||
}
|
||||
}
|
||||
|
||||
// isStreamingEnabled returns true when the agent uses claude-code provider
|
||||
// with streaming and show_tool_progress both enabled.
|
||||
func (a *Agent) isStreamingEnabled() bool {
|
||||
return a.cfg.LLM.Primary.Provider == "claude-code" &&
|
||||
a.cfg.LLM.Primary.ClaudeCode.Streaming &&
|
||||
a.cfg.LLM.Primary.ClaudeCode.ShowToolProgress
|
||||
}
|
||||
|
||||
// sanitizeInput runs prompt injection detection on the message content.
|
||||
// Returns the (possibly modified) content and true if the message should be rejected.
|
||||
func (a *Agent) sanitizeInput(content, roomID, senderID string) (string, bool) {
|
||||
|
||||
+11
-1
@@ -13,11 +13,14 @@ import (
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/personality"
|
||||
"github.com/enmanuel/agents/shell/audit"
|
||||
"github.com/enmanuel/agents/shell/effects"
|
||||
shelllm "github.com/enmanuel/agents/shell/llm"
|
||||
)
|
||||
|
||||
// runLLM executes the LLM completion loop, including iterative tool-use.
|
||||
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string) (string, error) {
|
||||
// progress may be nil; when non-nil, its StreamFunc is attached to the request
|
||||
// for providers that support streaming (claude-code).
|
||||
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memKey string, progress *effects.ProgressReporter) (string, error) {
|
||||
a.logger.Debug("calling LLM",
|
||||
"model", a.cfg.LLM.Primary.Model,
|
||||
"provider", a.cfg.LLM.Primary.Provider,
|
||||
@@ -62,6 +65,12 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK
|
||||
maxIter = defaultMaxToolIterations
|
||||
}
|
||||
|
||||
// Resolve StreamFunc for providers that support streaming
|
||||
var streamFn coretypes.StreamFunc
|
||||
if progress != nil {
|
||||
streamFn = progress.StreamFunc()
|
||||
}
|
||||
|
||||
// Tool-use loop: call LLM → execute tools → feed results back → repeat
|
||||
for i := 0; i < maxIter; i++ {
|
||||
req := coretypes.CompletionRequest{
|
||||
@@ -71,6 +80,7 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext, memK
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: messages,
|
||||
Tools: llmTools,
|
||||
StreamFunc: streamFn,
|
||||
}
|
||||
|
||||
resp, err := a.llm(ctx, req)
|
||||
|
||||
@@ -82,6 +82,20 @@ func (s *spyMatrixSender) SendMarkdown(_ context.Context, roomID, markdown strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *spyMatrixSender) SendMarkdownGetID(_ context.Context, roomID, markdown string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown})
|
||||
return "$spy_event_id", nil
|
||||
}
|
||||
|
||||
func (s *spyMatrixSender) EditMessage(_ context.Context, roomID, originalEventID, markdown string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, sentMessage{roomID: roomID, text: markdown, inReplyTo: originalEventID})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *spyMatrixSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -590,7 +604,7 @@ func TestRunLLM_ToolCallExecutesAndReturns(t *testing.T) {
|
||||
IsDirectMsg: true,
|
||||
}
|
||||
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("runLLM error: %v", err)
|
||||
}
|
||||
@@ -655,7 +669,7 @@ func TestRunLLM_ToolCallFailsPassesErrorToLLM(t *testing.T) {
|
||||
Content: "do something",
|
||||
}
|
||||
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("runLLM error: %v", err)
|
||||
}
|
||||
@@ -716,7 +730,7 @@ func TestRunLLM_MaxIterationsRespected(t *testing.T) {
|
||||
Content: "loop please",
|
||||
}
|
||||
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("runLLM error: %v", err)
|
||||
}
|
||||
@@ -776,7 +790,7 @@ func TestRunLLM_RBACDeniesToolCall(t *testing.T) {
|
||||
Content: "use restricted tool",
|
||||
}
|
||||
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
|
||||
reply, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("runLLM error: %v", err)
|
||||
}
|
||||
@@ -819,7 +833,7 @@ func TestRunLLM_LLMError(t *testing.T) {
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
_, err := a.runLLM(context.Background(), msgCtx, "!room:example.com")
|
||||
_, err := a.runLLM(context.Background(), msgCtx, "!room:example.com", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from LLM, got nil")
|
||||
}
|
||||
|
||||
+13
-11
@@ -112,17 +112,19 @@ type LLMProviderCfg struct {
|
||||
|
||||
// ClaudeCodeCfg configures the claude -p subprocess provider.
|
||||
type ClaudeCodeCfg struct {
|
||||
Binary string `yaml:"binary"` // path to claude binary (default: "claude")
|
||||
Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m)
|
||||
DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools
|
||||
AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. Bash, Read, Edit)
|
||||
DisallowedTools []string `yaml:"disallowed_tools"` // tools to block
|
||||
WorkingDir string `yaml:"working_dir"` // working directory for claude -p
|
||||
PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan
|
||||
Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name
|
||||
FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded
|
||||
SessionID string `yaml:"session_id"` // fixed session ID for continuity
|
||||
AddDirs []string `yaml:"add_dirs"` // additional directories accessible
|
||||
Binary string `yaml:"binary"` // path to claude binary (default: "claude")
|
||||
Timeout time.Duration `yaml:"timeout"` // subprocess timeout (default: 5m)
|
||||
DisableTools bool `yaml:"disable_tools"` // pass --tools "" to disable all internal tools
|
||||
AllowedTools []string `yaml:"allowed_tools"` // tools claude -p can use internally (e.g. Bash, Read, Edit)
|
||||
DisallowedTools []string `yaml:"disallowed_tools"` // tools to block
|
||||
WorkingDir string `yaml:"working_dir"` // working directory for claude -p
|
||||
PermissionMode string `yaml:"permission_mode"` // default, acceptEdits, bypassPermissions, plan
|
||||
Model string `yaml:"model"` // inner model: sonnet, opus, haiku, or full name
|
||||
FallbackModel string `yaml:"fallback_model"` // fallback model if primary is overloaded
|
||||
SessionID string `yaml:"session_id"` // fixed session ID for continuity
|
||||
AddDirs []string `yaml:"add_dirs"` // additional directories accessible
|
||||
Streaming bool `yaml:"streaming"` // use --output-format stream-json for realtime progress
|
||||
ShowToolProgress bool `yaml:"show_tool_progress"` // edit Matrix message to show tool usage progress
|
||||
}
|
||||
|
||||
type LLMReasoningCfg struct {
|
||||
|
||||
+38
-6
@@ -42,13 +42,14 @@ type ToolSpec struct {
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Model string
|
||||
Messages []Message
|
||||
Tools []ToolSpec
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
Stream bool
|
||||
Model string
|
||||
Messages []Message
|
||||
Tools []ToolSpec
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
Stream bool
|
||||
SystemPrompt string
|
||||
StreamFunc StreamFunc // optional: if set, streaming events are emitted during execution
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
@@ -67,3 +68,34 @@ type CompletionResponse struct {
|
||||
// CompleteFunc is the single contract for LLM providers.
|
||||
// Implementations live in shell/llm/.
|
||||
type CompleteFunc func(ctx context.Context, req CompletionRequest) (CompletionResponse, error)
|
||||
|
||||
// ── Streaming types (pure) ───────────────────────────────────────────────
|
||||
|
||||
// StreamEventKind identifies the kind of streaming event emitted by
|
||||
// a claude-code subprocess running with --output-format stream-json.
|
||||
type StreamEventKind string
|
||||
|
||||
const (
|
||||
StreamInit StreamEventKind = "init"
|
||||
StreamToolUse StreamEventKind = "tool_use"
|
||||
StreamToolResult StreamEventKind = "tool_result"
|
||||
StreamText StreamEventKind = "text"
|
||||
StreamResult StreamEventKind = "result"
|
||||
StreamError StreamEventKind = "error"
|
||||
)
|
||||
|
||||
// StreamEvent carries a single streaming event from the claude subprocess.
|
||||
// Fields are populated based on Kind; not all fields are valid for all kinds.
|
||||
type StreamEvent struct {
|
||||
Kind StreamEventKind
|
||||
ToolName string // tool_use: name of the tool being invoked
|
||||
ToolInput string // tool_use: truncated input description
|
||||
Content string // text/result: textual content
|
||||
IsError bool // result: whether the result indicates an error
|
||||
Error error // error: the error that occurred
|
||||
}
|
||||
|
||||
// StreamFunc is the callback invoked for each streaming event.
|
||||
// Implementations must be safe for concurrent use (typically not needed
|
||||
// since the streaming loop calls sequentially).
|
||||
type StreamFunc func(event StreamEvent)
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
package effects
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// ProgressReporter sends real-time progress updates to a Matrix room
|
||||
// by editing a single "status" message as the claude-code subprocess
|
||||
// emits streaming events (tool_use, text, result).
|
||||
//
|
||||
// It rate-limits edits to at most one per second to avoid flooding the
|
||||
// homeserver.
|
||||
type ProgressReporter struct {
|
||||
sender MatrixSender
|
||||
roomID string
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
eventID string // Matrix event ID of the progress message (empty until first send)
|
||||
lastEdit time.Time // timestamp of last edit, for rate limiting
|
||||
minInterval time.Duration
|
||||
}
|
||||
|
||||
// NewProgressReporter creates a ProgressReporter that sends progress updates
|
||||
// to the given room. The progress message is created lazily on the first event.
|
||||
func NewProgressReporter(sender MatrixSender, roomID string, logger *slog.Logger) *ProgressReporter {
|
||||
return &ProgressReporter{
|
||||
sender: sender,
|
||||
roomID: roomID,
|
||||
logger: logger,
|
||||
minInterval: time.Second, // max 1 edit/second
|
||||
}
|
||||
}
|
||||
|
||||
// StreamFunc returns a StreamFunc callback suitable for passing to
|
||||
// CompletionRequest.StreamFunc. It captures streaming events and updates
|
||||
// the progress message in the Matrix room.
|
||||
func (p *ProgressReporter) StreamFunc() coretypes.StreamFunc {
|
||||
return func(evt coretypes.StreamEvent) {
|
||||
p.handleEvent(evt)
|
||||
}
|
||||
}
|
||||
|
||||
// handleEvent processes a single streaming event and updates the Matrix message.
|
||||
func (p *ProgressReporter) handleEvent(evt coretypes.StreamEvent) {
|
||||
var markdown string
|
||||
|
||||
switch evt.Kind {
|
||||
case coretypes.StreamToolUse:
|
||||
// Show which tool is being used
|
||||
input := evt.ToolInput
|
||||
if len(input) > 60 {
|
||||
input = input[:57] + "..."
|
||||
}
|
||||
if input != "" {
|
||||
markdown = fmt.Sprintf("\U0001f527 *%s*: `%s`", evt.ToolName, input)
|
||||
} else {
|
||||
markdown = fmt.Sprintf("\U0001f527 *%s*", evt.ToolName)
|
||||
}
|
||||
|
||||
case coretypes.StreamResult:
|
||||
// Final result — no need to update progress; the handler will send the actual reply
|
||||
return
|
||||
|
||||
case coretypes.StreamText:
|
||||
// Intermediate text — could be partial thinking, skip to avoid noise
|
||||
return
|
||||
|
||||
case coretypes.StreamInit:
|
||||
markdown = "\u2699\ufe0f *Procesando...*"
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
if markdown == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.updateMessage(markdown)
|
||||
}
|
||||
|
||||
// updateMessage sends or edits the progress message, respecting rate limits.
|
||||
func (p *ProgressReporter) updateMessage(markdown string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Rate limit: skip if we edited less than minInterval ago
|
||||
if p.eventID != "" && time.Since(p.lastEdit) < p.minInterval {
|
||||
return
|
||||
}
|
||||
|
||||
if p.eventID == "" {
|
||||
// First message: send a new one and capture the event ID
|
||||
evtID, err := p.sender.SendMarkdownGetID(ctx, p.roomID, markdown)
|
||||
if err != nil {
|
||||
p.logger.Warn("progress_reporter: failed to send initial message", "err", err)
|
||||
return
|
||||
}
|
||||
p.eventID = evtID
|
||||
p.lastEdit = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
// Subsequent updates: edit the existing message
|
||||
if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, markdown); err != nil {
|
||||
p.logger.Warn("progress_reporter: failed to edit message", "err", err)
|
||||
return
|
||||
}
|
||||
p.lastEdit = time.Now()
|
||||
}
|
||||
|
||||
// Finalize edits the progress message with the final content, or deletes it.
|
||||
// Call this after the LLM response is ready. If finalMarkdown is empty, the
|
||||
// progress message is left as-is (the handler will send a separate reply).
|
||||
func (p *ProgressReporter) Finalize(finalMarkdown string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.eventID == "" || finalMarkdown == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := p.sender.EditMessage(ctx, p.roomID, p.eventID, finalMarkdown); err != nil {
|
||||
p.logger.Warn("progress_reporter: failed to finalize message", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// EventID returns the Matrix event ID of the progress message, or empty if
|
||||
// no message was sent yet.
|
||||
func (p *ProgressReporter) EventID() string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.eventID
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package effects
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// TestIntegration_StreamToProgressReporter simulates the full flow:
|
||||
// parseStreamLine produces events → ProgressReporter consumes them → mock sender records calls.
|
||||
// This validates the complete pipeline from raw JSON lines to Matrix messages.
|
||||
func TestIntegration_StreamToProgressReporter(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:integration", slog.Default())
|
||||
pr.minInterval = 0 // disable rate limiting for test
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
// Simulate a realistic stream-json session:
|
||||
// 1. Init event
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
// 2. Tool use: Bash
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolUse,
|
||||
ToolName: "Bash",
|
||||
ToolInput: "git status",
|
||||
})
|
||||
|
||||
// 3. Tool use: Read
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolUse,
|
||||
ToolName: "Read",
|
||||
ToolInput: "/home/user/project/main.go",
|
||||
})
|
||||
|
||||
// 4. Tool use: Edit
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolUse,
|
||||
ToolName: "Edit",
|
||||
ToolInput: "/home/user/project/main.go",
|
||||
})
|
||||
|
||||
// 5. Text event (intermediate, should be ignored)
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamText,
|
||||
Content: "I've made the changes...",
|
||||
})
|
||||
|
||||
// 6. Result event (should be ignored by progress reporter)
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamResult,
|
||||
Content: "Here is the final answer with all changes applied.",
|
||||
})
|
||||
|
||||
// Verify sends: only 1 initial send
|
||||
if len(sender.sends) != 1 {
|
||||
t.Errorf("expected 1 send (init), got %d", len(sender.sends))
|
||||
}
|
||||
if !strings.Contains(sender.sends[0], "Procesando") {
|
||||
t.Errorf("init message should contain 'Procesando', got %q", sender.sends[0])
|
||||
}
|
||||
|
||||
// Verify edits: 3 tool use events
|
||||
if len(sender.edits) != 3 {
|
||||
t.Fatalf("expected 3 edits (tool uses), got %d", len(sender.edits))
|
||||
}
|
||||
|
||||
// First edit: Bash
|
||||
if !strings.Contains(sender.edits[0], "Bash") {
|
||||
t.Errorf("edit[0] should mention Bash, got %q", sender.edits[0])
|
||||
}
|
||||
if !strings.Contains(sender.edits[0], "git status") {
|
||||
t.Errorf("edit[0] should show input, got %q", sender.edits[0])
|
||||
}
|
||||
|
||||
// Second edit: Read
|
||||
if !strings.Contains(sender.edits[1], "Read") {
|
||||
t.Errorf("edit[1] should mention Read, got %q", sender.edits[1])
|
||||
}
|
||||
|
||||
// Third edit: Edit
|
||||
if !strings.Contains(sender.edits[2], "Edit") {
|
||||
t.Errorf("edit[2] should mention Edit, got %q", sender.edits[2])
|
||||
}
|
||||
|
||||
// All edits should target the same event ID
|
||||
for i, target := range sender.editTargets {
|
||||
if target != "$progress_msg_1" {
|
||||
t.Errorf("editTarget[%d] = %q, want %q", i, target, "$progress_msg_1")
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize
|
||||
pr.Finalize("\u2705 *Completado*")
|
||||
|
||||
if len(sender.edits) != 4 {
|
||||
t.Fatalf("expected 4 edits (3 tools + 1 finalize), got %d", len(sender.edits))
|
||||
}
|
||||
if !strings.Contains(sender.edits[3], "Completado") {
|
||||
t.Errorf("finalize edit should contain 'Completado', got %q", sender.edits[3])
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_NoStreamingNoSideEffects verifies that when streaming is
|
||||
// not enabled, no ProgressReporter is created and no Matrix side effects occur.
|
||||
// This is a regression test for the streaming=false default behavior.
|
||||
func TestIntegration_NoStreamingNoSideEffects(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
|
||||
// Simulate the handler check: streaming disabled → progress is nil
|
||||
var progress *ProgressReporter // nil, because streaming is disabled
|
||||
|
||||
if progress != nil {
|
||||
t.Error("progress reporter should be nil when streaming is disabled")
|
||||
}
|
||||
|
||||
// Verify no sends or edits happened
|
||||
if len(sender.sends) != 0 {
|
||||
t.Errorf("expected 0 sends, got %d", len(sender.sends))
|
||||
}
|
||||
if len(sender.edits) != 0 {
|
||||
t.Errorf("expected 0 edits, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ProgressReporterWithSendError verifies that the reporter
|
||||
// handles send errors gracefully without panicking.
|
||||
func TestIntegration_ProgressReporterWithSendError(t *testing.T) {
|
||||
sender := &errorSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
// Should not panic even when send fails
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
// EventID should be empty since send failed
|
||||
if pr.EventID() != "" {
|
||||
t.Errorf("expected empty EventID after send error, got %q", pr.EventID())
|
||||
}
|
||||
|
||||
// Finalize should be a no-op since no message was sent
|
||||
pr.Finalize("Done")
|
||||
}
|
||||
|
||||
// errorSender always returns errors.
|
||||
type errorSender struct {
|
||||
fakeMatrixSender
|
||||
}
|
||||
|
||||
func (e *errorSender) SendMarkdownGetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", context.DeadlineExceeded
|
||||
}
|
||||
|
||||
func (e *errorSender) EditMessage(_ context.Context, _, _, _ string) error {
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package effects
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// mockProgressSender records sends and edits for testing ProgressReporter.
|
||||
type mockProgressSender struct {
|
||||
fakeMatrixSender // embed to satisfy the full interface
|
||||
sends []string // markdowns from SendMarkdownGetID
|
||||
edits []string // markdowns from EditMessage
|
||||
editTargets []string // event IDs targeted by EditMessage
|
||||
}
|
||||
|
||||
func (m *mockProgressSender) SendMarkdownGetID(_ context.Context, _, markdown string) (string, error) {
|
||||
m.sends = append(m.sends, markdown)
|
||||
return "$progress_msg_1", nil
|
||||
}
|
||||
|
||||
func (m *mockProgressSender) EditMessage(_ context.Context, _, originalEventID, markdown string) error {
|
||||
m.edits = append(m.edits, markdown)
|
||||
m.editTargets = append(m.editTargets, originalEventID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProgressReporter_InitEvent(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
if len(sender.sends) != 1 {
|
||||
t.Fatalf("expected 1 send, got %d", len(sender.sends))
|
||||
}
|
||||
if !strings.Contains(sender.sends[0], "Procesando") {
|
||||
t.Errorf("init message = %q, should contain 'Procesando'", sender.sends[0])
|
||||
}
|
||||
if pr.EventID() != "$progress_msg_1" {
|
||||
t.Errorf("EventID = %q, want %q", pr.EventID(), "$progress_msg_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_ToolUseEditsMessage(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0 // disable rate limiting for test
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
// First event creates the message
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
// Second event should edit
|
||||
fn(coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolUse,
|
||||
ToolName: "Bash",
|
||||
ToolInput: "ls -la",
|
||||
})
|
||||
|
||||
if len(sender.edits) != 1 {
|
||||
t.Fatalf("expected 1 edit, got %d", len(sender.edits))
|
||||
}
|
||||
if !strings.Contains(sender.edits[0], "Bash") {
|
||||
t.Errorf("edit = %q, should contain tool name", sender.edits[0])
|
||||
}
|
||||
if !strings.Contains(sender.edits[0], "ls -la") {
|
||||
t.Errorf("edit = %q, should contain tool input", sender.edits[0])
|
||||
}
|
||||
if sender.editTargets[0] != "$progress_msg_1" {
|
||||
t.Errorf("edit target = %q, want %q", sender.editTargets[0], "$progress_msg_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_MultipleToolUse(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0 // disable rate limiting for test
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "/tmp/file.go"})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "/tmp/file.go"})
|
||||
|
||||
// 1 send (init) + 3 edits (tool uses)
|
||||
if len(sender.sends) != 1 {
|
||||
t.Errorf("expected 1 send, got %d", len(sender.sends))
|
||||
}
|
||||
if len(sender.edits) != 3 {
|
||||
t.Errorf("expected 3 edits, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_RateLimiting(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 500 * time.Millisecond
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
// First event creates the message (no rate limit on first send)
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
// Reset lastEdit to simulate time having passed after init
|
||||
pr.mu.Lock()
|
||||
pr.lastEdit = time.Now().Add(-time.Second)
|
||||
pr.mu.Unlock()
|
||||
|
||||
// First tool event should go through (enough time has passed)
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: "echo 1"})
|
||||
// These rapid-fire events should be rate-limited
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Read", ToolInput: "file.go"})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Edit", ToolInput: "file.go"})
|
||||
|
||||
// Only 1 edit should have gone through (the rest rate limited)
|
||||
if len(sender.edits) != 1 {
|
||||
t.Errorf("expected 1 edit (rate limited), got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_ResultIgnored(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamResult, Content: "Final answer"})
|
||||
|
||||
// Result should not trigger an edit
|
||||
if len(sender.edits) != 0 {
|
||||
t.Errorf("expected 0 edits for result event, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_TextIgnored(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamText, Content: "Some thinking..."})
|
||||
|
||||
// Text events should not trigger edits
|
||||
if len(sender.edits) != 0 {
|
||||
t.Errorf("expected 0 edits for text event, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_Finalize(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
pr.Finalize("Done! Here is the result.")
|
||||
|
||||
if len(sender.edits) != 1 {
|
||||
t.Fatalf("expected 1 edit for finalize, got %d", len(sender.edits))
|
||||
}
|
||||
if sender.edits[0] != "Done! Here is the result." {
|
||||
t.Errorf("finalize edit = %q", sender.edits[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_FinalizeNoMessage(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
|
||||
// Finalize without ever sending a message should be a no-op
|
||||
pr.Finalize("Final")
|
||||
|
||||
if len(sender.edits) != 0 {
|
||||
t.Errorf("expected 0 edits when no message was sent, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_FinalizeEmpty(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
// Empty finalize should be a no-op
|
||||
pr.Finalize("")
|
||||
|
||||
if len(sender.edits) != 0 {
|
||||
t.Errorf("expected 0 edits for empty finalize, got %d", len(sender.edits))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressReporter_ToolInputTruncation(t *testing.T) {
|
||||
sender := &mockProgressSender{}
|
||||
pr := NewProgressReporter(sender, "!room:test", slog.Default())
|
||||
pr.minInterval = 0
|
||||
|
||||
fn := pr.StreamFunc()
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamInit})
|
||||
|
||||
longInput := strings.Repeat("x", 100)
|
||||
fn(coretypes.StreamEvent{Kind: coretypes.StreamToolUse, ToolName: "Bash", ToolInput: longInput})
|
||||
|
||||
if len(sender.edits) != 1 {
|
||||
t.Fatalf("expected 1 edit, got %d", len(sender.edits))
|
||||
}
|
||||
// The input in the message should be truncated
|
||||
if strings.Contains(sender.edits[0], longInput) {
|
||||
t.Error("long input should be truncated in the message")
|
||||
}
|
||||
if !strings.Contains(sender.edits[0], "...") {
|
||||
t.Error("truncated input should end with ...")
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,8 @@ type Result struct {
|
||||
type MatrixSender interface {
|
||||
SendText(ctx context.Context, roomID, text string) error
|
||||
SendMarkdown(ctx context.Context, roomID, markdown string) error
|
||||
SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error)
|
||||
EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error
|
||||
SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error
|
||||
SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error
|
||||
SendTyping(ctx context.Context, roomID string, typing bool) error
|
||||
|
||||
@@ -31,6 +31,16 @@ func (f *fakeMatrixSender) SendMarkdown(ctx context.Context, roomID, markdown st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeMatrixSender) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) {
|
||||
f.calls = append(f.calls, senderCall{method: "SendMarkdownGetID", roomID: roomID, markdown: markdown})
|
||||
return "$fake_event_id", nil
|
||||
}
|
||||
|
||||
func (f *fakeMatrixSender) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "EditMessage", roomID: roomID, inReplyTo: originalEventID, markdown: markdown})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeMatrixSender) SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "SendReplyMarkdown", roomID: roomID, inReplyTo: inReplyTo, markdown: markdown})
|
||||
return nil
|
||||
|
||||
+316
-22
@@ -1,6 +1,7 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
@@ -74,6 +75,7 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes
|
||||
"args", strings.Join(args, " "),
|
||||
"prompt_len", len(prompt),
|
||||
"working_dir", workDir,
|
||||
"streaming", cfg.Streaming,
|
||||
)
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary, args...)
|
||||
@@ -99,31 +101,313 @@ func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes
|
||||
return nil
|
||||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
err := cmd.Run()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Ensure the process group is fully dead after Run returns,
|
||||
// even if cmd.Run() returned without triggering Cancel (normal exit).
|
||||
if cmd.Process != nil {
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
// Choose between streaming and buffered mode
|
||||
if cfg.Streaming && req.StreamFunc != nil {
|
||||
return executeStreaming(ctx, cmd, req.StreamFunc, log)
|
||||
}
|
||||
|
||||
log.Debug("claude_code_done",
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
"stdout_len", stdout.Len(),
|
||||
"stderr_len", stderr.Len(),
|
||||
"exit_err", err,
|
||||
)
|
||||
|
||||
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
|
||||
return executeBuffered(ctx, cmd, log)
|
||||
}
|
||||
}
|
||||
|
||||
// executeBuffered runs the claude subprocess and collects all output at once.
|
||||
// This is the original (non-streaming) code path.
|
||||
func executeBuffered(ctx context.Context, cmd *exec.Cmd, log *slog.Logger) (coretypes.CompletionResponse, error) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
err := cmd.Run()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Ensure the process group is fully dead after Run returns.
|
||||
if cmd.Process != nil {
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
log.Debug("claude_code_done",
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
"stdout_len", stdout.Len(),
|
||||
"stderr_len", stderr.Len(),
|
||||
"exit_err", err,
|
||||
)
|
||||
|
||||
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
|
||||
}
|
||||
|
||||
// executeStreaming runs the claude subprocess with --output-format stream-json,
|
||||
// reads stdout line by line, emits StreamEvents via the callback, and accumulates
|
||||
// the final result.
|
||||
func executeStreaming(ctx context.Context, cmd *exec.Cmd, streamFn coretypes.StreamFunc, log *slog.Logger) (coretypes.CompletionResponse, error) {
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code: start: %w", err)
|
||||
}
|
||||
|
||||
// Scan stdout line by line, parsing each JSON event
|
||||
var lastResult *claudeJSONOutput
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 0, 256*1024), 1024*1024) // allow up to 1MB lines
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
evt, parsed, parseErr := parseStreamLine(line)
|
||||
if parseErr != nil {
|
||||
log.Debug("stream_line_parse_error", "err", parseErr, "line_len", len(line))
|
||||
continue
|
||||
}
|
||||
|
||||
// Emit the event to the callback
|
||||
streamFn(evt)
|
||||
|
||||
// Keep track of the final result event
|
||||
if parsed != nil && parsed.Type == "result" {
|
||||
lastResult = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the process to finish
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Ensure the process group is fully dead after Run returns.
|
||||
if cmd.Process != nil {
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
if scanErr := scanner.Err(); scanErr != nil {
|
||||
log.Warn("stream_scanner_error", "err", scanErr)
|
||||
}
|
||||
|
||||
log.Debug("claude_code_stream_done",
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
"stderr_len", stderr.Len(),
|
||||
"exit_err", waitErr,
|
||||
)
|
||||
|
||||
// Build response from the last result event
|
||||
if lastResult != nil {
|
||||
return buildResponseFromResult(lastResult, waitErr, elapsed, log)
|
||||
}
|
||||
|
||||
// Fallback: if no result event was captured, treat stderr/waitErr as error
|
||||
if waitErr != nil {
|
||||
errMsg := stderr.String()
|
||||
if errMsg == "" {
|
||||
errMsg = waitErr.Error()
|
||||
}
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code stream process failed: %s", errMsg)
|
||||
}
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildResponseFromResult converts a parsed result event into a CompletionResponse.
|
||||
func buildResponseFromResult(output *claudeJSONOutput, execErr error, elapsed time.Duration, log *slog.Logger) (coretypes.CompletionResponse, error) {
|
||||
if output.IsError {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code error: %s", output.Result)
|
||||
}
|
||||
|
||||
content := output.Result
|
||||
if content == "" && len(output.ContentBlock) > 0 {
|
||||
var parts []string
|
||||
for _, block := range output.ContentBlock {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
content = strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if execErr != nil {
|
||||
finishReason = "error"
|
||||
}
|
||||
|
||||
log.Info("claude_code_response",
|
||||
"content_len", len(content),
|
||||
"input_tokens", output.Usage.InputTokens,
|
||||
"output_tokens", output.Usage.OutputTokens,
|
||||
"num_turns", output.NumTurns,
|
||||
"cost_usd", output.TotalCost,
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
)
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: content,
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: output.Usage.InputTokens,
|
||||
OutputTokens: output.Usage.OutputTokens,
|
||||
TotalTokens: output.Usage.InputTokens + output.Usage.OutputTokens,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ── Stream event parsing ────────────────────────────────────────────────
|
||||
|
||||
// claudeStreamEvent is the raw JSON shape from `claude -p --output-format stream-json`.
|
||||
// Each line of stdout is one JSON object with at least a "type" field.
|
||||
type claudeStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
|
||||
// For type=assistant, the message contains content blocks
|
||||
Message *claudeStreamMessage `json:"message"`
|
||||
|
||||
// For type=result — reuse claudeJSONOutput fields
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
TotalCost float64 `json:"total_cost_usd"`
|
||||
Usage claudeUsage `json:"usage"`
|
||||
Content []claudeContent `json:"content"`
|
||||
}
|
||||
|
||||
// claudeStreamMessage represents the assistant message in a stream event.
|
||||
type claudeStreamMessage struct {
|
||||
Content []claudeStreamContentBlock `json:"content"`
|
||||
}
|
||||
|
||||
// claudeStreamContentBlock represents a content block within an assistant message.
|
||||
type claudeStreamContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Name string `json:"name"` // tool_use: tool name
|
||||
ID string `json:"id"` // tool_use: call ID
|
||||
Input any `json:"input"` // tool_use: tool input (object or string)
|
||||
}
|
||||
|
||||
// parseStreamLine parses a single JSON line from the stream-json output.
|
||||
// Returns the StreamEvent, optionally the raw parsed result (if type=result),
|
||||
// and any parse error.
|
||||
func parseStreamLine(line []byte) (coretypes.StreamEvent, *claudeJSONOutput, error) {
|
||||
var raw claudeStreamEvent
|
||||
if err := json.Unmarshal(line, &raw); err != nil {
|
||||
return coretypes.StreamEvent{}, nil, fmt.Errorf("parse stream line: %w", err)
|
||||
}
|
||||
|
||||
switch raw.Type {
|
||||
case "system":
|
||||
// Init event — emit as init
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamInit,
|
||||
}, nil, nil
|
||||
|
||||
case "assistant":
|
||||
// Assistant message with content blocks — extract tool_use and text events
|
||||
if raw.Message != nil && len(raw.Message.Content) > 0 {
|
||||
// Look for the most interesting content block
|
||||
for _, block := range raw.Message.Content {
|
||||
switch block.Type {
|
||||
case "tool_use":
|
||||
inputStr := truncateToolInput(block.Input)
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolUse,
|
||||
ToolName: block.Name,
|
||||
ToolInput: inputStr,
|
||||
}, nil, nil
|
||||
case "tool_result":
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamToolResult,
|
||||
}, nil, nil
|
||||
case "text":
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamText,
|
||||
Content: block.Text,
|
||||
}, nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assistant message without interesting content blocks
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamText,
|
||||
}, nil, nil
|
||||
|
||||
case "result":
|
||||
// Final result event
|
||||
result := &claudeJSONOutput{
|
||||
Type: raw.Type,
|
||||
Subtype: raw.Subtype,
|
||||
IsError: raw.IsError,
|
||||
Result: raw.Result,
|
||||
NumTurns: raw.NumTurns,
|
||||
TotalCost: raw.TotalCost,
|
||||
Usage: raw.Usage,
|
||||
}
|
||||
evt := coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamResult,
|
||||
Content: raw.Result,
|
||||
IsError: raw.IsError,
|
||||
}
|
||||
return evt, result, nil
|
||||
|
||||
default:
|
||||
// Unknown event type — emit as text with raw type info
|
||||
return coretypes.StreamEvent{
|
||||
Kind: coretypes.StreamText,
|
||||
Content: raw.Type,
|
||||
}, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// truncateToolInput converts tool input to a short description string.
|
||||
func truncateToolInput(input any) string {
|
||||
if input == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return truncateStr(v, 100)
|
||||
case map[string]any:
|
||||
// For tool inputs like {"command": "ls -la"}, extract the most useful field
|
||||
if cmd, ok := v["command"]; ok {
|
||||
return truncateStr(fmt.Sprintf("%v", cmd), 100)
|
||||
}
|
||||
if file, ok := v["file_path"]; ok {
|
||||
return truncateStr(fmt.Sprintf("%v", file), 100)
|
||||
}
|
||||
// Fallback: serialize the whole thing
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return truncateStr(string(b), 100)
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return truncateStr(string(b), 100)
|
||||
}
|
||||
}
|
||||
|
||||
// truncateStr shortens a string to maxLen, appending "..." if truncated.
|
||||
func truncateStr(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// ── Shared helpers ──────────────────────────────────────────────────────
|
||||
|
||||
// resolveWorkDir determines the working directory for the claude subprocess.
|
||||
// If configured is empty, it creates a temporary directory to avoid inheriting the launcher's CWD.
|
||||
// If configured is non-empty, it ensures the directory exists.
|
||||
@@ -149,7 +433,17 @@ func resolveWorkDir(configured string, log *slog.Logger) string {
|
||||
|
||||
// buildClaudeArgs constructs the CLI arguments for claude -p.
|
||||
func buildClaudeArgs(cfg config.ClaudeCodeCfg, req coretypes.CompletionRequest) []string {
|
||||
args := []string{"--print", "--output-format", "json"}
|
||||
outputFormat := "json"
|
||||
if cfg.Streaming && req.StreamFunc != nil {
|
||||
outputFormat = "stream-json"
|
||||
}
|
||||
|
||||
args := []string{"--print", "--output-format", outputFormat}
|
||||
|
||||
// stream-json requires --verbose
|
||||
if outputFormat == "stream-json" {
|
||||
args = append(args, "--verbose")
|
||||
}
|
||||
|
||||
if req.SystemPrompt != "" {
|
||||
args = append(args, "--system-prompt", req.SystemPrompt)
|
||||
|
||||
@@ -371,6 +371,377 @@ func TestResolveWorkDir_ConfiguredAlreadyExists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ── parseStreamLine ─────────────────────────────────────────────────
|
||||
|
||||
func TestParseStreamLine_SystemInit(t *testing.T) {
|
||||
line := []byte(`{"type":"system","subtype":"init","session_id":"abc","tools":["Bash","Read"],"model":"sonnet"}`)
|
||||
|
||||
evt, result, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamInit {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamInit)
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("expected nil result for system event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_AssistantToolUse(t *testing.T) {
|
||||
line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"ls -la /tmp"}}]}}`)
|
||||
|
||||
evt, result, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamToolUse {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse)
|
||||
}
|
||||
if evt.ToolName != "Bash" {
|
||||
t.Errorf("tool_name = %q, want %q", evt.ToolName, "Bash")
|
||||
}
|
||||
if evt.ToolInput != "ls -la /tmp" {
|
||||
t.Errorf("tool_input = %q, want %q", evt.ToolInput, "ls -la /tmp")
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("expected nil result for assistant event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_AssistantToolUseFilePath(t *testing.T) {
|
||||
line := []byte(`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Read","id":"call_2","input":{"file_path":"/home/user/main.go"}}]}}`)
|
||||
|
||||
evt, _, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamToolUse {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamToolUse)
|
||||
}
|
||||
if evt.ToolName != "Read" {
|
||||
t.Errorf("tool_name = %q, want %q", evt.ToolName, "Read")
|
||||
}
|
||||
if evt.ToolInput != "/home/user/main.go" {
|
||||
t.Errorf("tool_input = %q, want %q", evt.ToolInput, "/home/user/main.go")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_AssistantText(t *testing.T) {
|
||||
line := []byte(`{"type":"assistant","message":{"content":[{"type":"text","text":"Hello, world!"}]}}`)
|
||||
|
||||
evt, result, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamText {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText)
|
||||
}
|
||||
if evt.Content != "Hello, world!" {
|
||||
t.Errorf("content = %q, want %q", evt.Content, "Hello, world!")
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("expected nil result for text event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_AssistantNoContent(t *testing.T) {
|
||||
line := []byte(`{"type":"assistant","message":{"content":[]}}`)
|
||||
|
||||
evt, _, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamText {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_ResultSuccess(t *testing.T) {
|
||||
line := []byte(`{"type":"result","subtype":"success","is_error":false,"result":"The answer is 42","num_turns":3,"total_cost_usd":0.05,"usage":{"input_tokens":100,"output_tokens":50}}`)
|
||||
|
||||
evt, result, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamResult {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamResult)
|
||||
}
|
||||
if evt.Content != "The answer is 42" {
|
||||
t.Errorf("content = %q, want %q", evt.Content, "The answer is 42")
|
||||
}
|
||||
if evt.IsError {
|
||||
t.Error("expected IsError=false")
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result for result event")
|
||||
}
|
||||
if result.Result != "The answer is 42" {
|
||||
t.Errorf("result.Result = %q, want %q", result.Result, "The answer is 42")
|
||||
}
|
||||
if result.Usage.InputTokens != 100 {
|
||||
t.Errorf("input_tokens = %d, want 100", result.Usage.InputTokens)
|
||||
}
|
||||
if result.Usage.OutputTokens != 50 {
|
||||
t.Errorf("output_tokens = %d, want 50", result.Usage.OutputTokens)
|
||||
}
|
||||
if result.TotalCost != 0.05 {
|
||||
t.Errorf("total_cost = %f, want 0.05", result.TotalCost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_ResultError(t *testing.T) {
|
||||
line := []byte(`{"type":"result","subtype":"error","is_error":true,"result":"API key expired","num_turns":0}`)
|
||||
|
||||
evt, result, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamResult {
|
||||
t.Errorf("kind = %q, want %q", evt.Kind, coretypes.StreamResult)
|
||||
}
|
||||
if !evt.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("expected result.IsError=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_UnknownType(t *testing.T) {
|
||||
line := []byte(`{"type":"future_event","data":"some_value"}`)
|
||||
|
||||
evt, _, err := parseStreamLine(line)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if evt.Kind != coretypes.StreamText {
|
||||
t.Errorf("kind = %q, want %q (fallback for unknown types)", evt.Kind, coretypes.StreamText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamLine_InvalidJSON(t *testing.T) {
|
||||
line := []byte(`not valid json`)
|
||||
|
||||
_, _, err := parseStreamLine(line)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// ── truncateToolInput ───────────────────────────────────────────────
|
||||
|
||||
func TestTruncateToolInput_Nil(t *testing.T) {
|
||||
got := truncateToolInput(nil)
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolInput_String(t *testing.T) {
|
||||
got := truncateToolInput("hello world")
|
||||
if got != "hello world" {
|
||||
t.Errorf("got %q, want %q", got, "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolInput_LongString(t *testing.T) {
|
||||
long := strings.Repeat("x", 200)
|
||||
got := truncateToolInput(long)
|
||||
if len(got) != 100 {
|
||||
t.Errorf("len = %d, want 100", len(got))
|
||||
}
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Error("should end with ...")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolInput_MapWithCommand(t *testing.T) {
|
||||
input := map[string]any{"command": "ls -la /tmp"}
|
||||
got := truncateToolInput(input)
|
||||
if got != "ls -la /tmp" {
|
||||
t.Errorf("got %q, want %q", got, "ls -la /tmp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolInput_MapWithFilePath(t *testing.T) {
|
||||
input := map[string]any{"file_path": "/home/user/main.go"}
|
||||
got := truncateToolInput(input)
|
||||
if got != "/home/user/main.go" {
|
||||
t.Errorf("got %q, want %q", got, "/home/user/main.go")
|
||||
}
|
||||
}
|
||||
|
||||
// ── buildClaudeArgs streaming ───────────────────────────────────────
|
||||
|
||||
func TestBuildClaudeArgs_StreamingEnabled(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
Streaming: true,
|
||||
}
|
||||
streamFn := func(evt coretypes.StreamEvent) {}
|
||||
req := coretypes.CompletionRequest{
|
||||
StreamFunc: streamFn,
|
||||
}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
assertContains(t, args, "--output-format", "stream-json")
|
||||
// Must also include --verbose for stream-json
|
||||
found := false
|
||||
for _, a := range args {
|
||||
if a == "--verbose" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("--verbose should be present when streaming")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeArgs_StreamingDisabled(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
Streaming: false,
|
||||
}
|
||||
req := coretypes.CompletionRequest{}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
assertContains(t, args, "--output-format", "json")
|
||||
for _, a := range args {
|
||||
if a == "--verbose" {
|
||||
t.Error("--verbose should NOT be present when not streaming")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeArgs_StreamingEnabledNoStreamFunc(t *testing.T) {
|
||||
// Streaming config is true but StreamFunc is nil — should fall back to json
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
Streaming: true,
|
||||
}
|
||||
req := coretypes.CompletionRequest{
|
||||
StreamFunc: nil,
|
||||
}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
assertContains(t, args, "--output-format", "json")
|
||||
}
|
||||
|
||||
// ── executeStreaming with mock stdout ────────────────────────────────
|
||||
|
||||
func TestExecuteStreaming_MockStdout(t *testing.T) {
|
||||
// Simulate stream-json output by writing lines to an io.Pipe
|
||||
lines := []string{
|
||||
`{"type":"system","subtype":"init","session_id":"test-123"}`,
|
||||
`{"type":"assistant","message":{"content":[{"type":"tool_use","name":"Bash","id":"call_1","input":{"command":"echo hello"}}]}}`,
|
||||
`{"type":"assistant","message":{"content":[{"type":"text","text":"Done executing."}]}}`,
|
||||
`{"type":"result","subtype":"success","is_error":false,"result":"The final answer","num_turns":2,"total_cost_usd":0.01,"usage":{"input_tokens":50,"output_tokens":25}}`,
|
||||
}
|
||||
|
||||
var events []coretypes.StreamEvent
|
||||
streamFn := func(evt coretypes.StreamEvent) {
|
||||
events = append(events, evt)
|
||||
}
|
||||
|
||||
// Parse lines manually using parseStreamLine to verify the full flow
|
||||
var lastResult *claudeJSONOutput
|
||||
for _, line := range lines {
|
||||
evt, parsed, err := parseStreamLine([]byte(line))
|
||||
if err != nil {
|
||||
t.Fatalf("parse error on line: %v", err)
|
||||
}
|
||||
streamFn(evt)
|
||||
if parsed != nil && parsed.Type == "result" {
|
||||
lastResult = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Verify events
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("expected 4 events, got %d", len(events))
|
||||
}
|
||||
if events[0].Kind != coretypes.StreamInit {
|
||||
t.Errorf("event[0].Kind = %q, want %q", events[0].Kind, coretypes.StreamInit)
|
||||
}
|
||||
if events[1].Kind != coretypes.StreamToolUse {
|
||||
t.Errorf("event[1].Kind = %q, want %q", events[1].Kind, coretypes.StreamToolUse)
|
||||
}
|
||||
if events[1].ToolName != "Bash" {
|
||||
t.Errorf("event[1].ToolName = %q, want %q", events[1].ToolName, "Bash")
|
||||
}
|
||||
if events[1].ToolInput != "echo hello" {
|
||||
t.Errorf("event[1].ToolInput = %q, want %q", events[1].ToolInput, "echo hello")
|
||||
}
|
||||
if events[2].Kind != coretypes.StreamText {
|
||||
t.Errorf("event[2].Kind = %q, want %q", events[2].Kind, coretypes.StreamText)
|
||||
}
|
||||
if events[3].Kind != coretypes.StreamResult {
|
||||
t.Errorf("event[3].Kind = %q, want %q", events[3].Kind, coretypes.StreamResult)
|
||||
}
|
||||
if events[3].Content != "The final answer" {
|
||||
t.Errorf("event[3].Content = %q, want %q", events[3].Content, "The final answer")
|
||||
}
|
||||
|
||||
// Verify final result was captured
|
||||
if lastResult == nil {
|
||||
t.Fatal("expected lastResult to be set")
|
||||
}
|
||||
if lastResult.Result != "The final answer" {
|
||||
t.Errorf("lastResult.Result = %q", lastResult.Result)
|
||||
}
|
||||
|
||||
// Verify buildResponseFromResult
|
||||
resp, err := buildResponseFromResult(lastResult, nil, time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "The final answer" {
|
||||
t.Errorf("resp.Content = %q", resp.Content)
|
||||
}
|
||||
if resp.Usage.InputTokens != 50 {
|
||||
t.Errorf("input_tokens = %d, want 50", resp.Usage.InputTokens)
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseFromResult_Error(t *testing.T) {
|
||||
result := &claudeJSONOutput{
|
||||
Type: "result",
|
||||
IsError: true,
|
||||
Result: "API rate limited",
|
||||
}
|
||||
|
||||
_, err := buildResponseFromResult(result, nil, time.Second, discardLog)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for IsError=true")
|
||||
}
|
||||
if !contains(err.Error(), "API rate limited") {
|
||||
t.Errorf("error = %q, should contain 'API rate limited'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseFromResult_ExecError(t *testing.T) {
|
||||
result := &claudeJSONOutput{
|
||||
Type: "result",
|
||||
Result: "partial output",
|
||||
Usage: claudeUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
resp, err := buildResponseFromResult(result, errors.New("timeout"), time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.FinishReason != "error" {
|
||||
t.Errorf("finish_reason = %q, want %q", resp.FinishReason, "error")
|
||||
}
|
||||
}
|
||||
|
||||
// ── helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
|
||||
@@ -286,6 +286,51 @@ func (c *Client) SendMarkdown(ctx context.Context, roomID, markdown string) erro
|
||||
return err
|
||||
}
|
||||
|
||||
// SendMarkdownGetID sends a formatted (Markdown) message to a room and returns
|
||||
// the event ID of the sent message. Useful for later editing via EditMessage.
|
||||
func (c *Client) SendMarkdownGetID(ctx context.Context, roomID, markdown string) (string, error) {
|
||||
html := mdToHTML(markdown)
|
||||
content := &event.MessageEventContent{
|
||||
MsgType: event.MsgText,
|
||||
Body: markdown,
|
||||
Format: event.FormatHTML,
|
||||
FormattedBody: html,
|
||||
}
|
||||
resp, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.EventID.String(), nil
|
||||
}
|
||||
|
||||
// EditMessage edits a previously sent message in a room using m.replace relation.
|
||||
// originalEventID is the event ID of the message to replace.
|
||||
// The new content is rendered from markdown.
|
||||
func (c *Client) EditMessage(ctx context.Context, roomID, originalEventID, markdown string) error {
|
||||
html := mdToHTML(markdown)
|
||||
|
||||
// Matrix spec: m.new_content holds the replacement, m.relates_to with
|
||||
// rel_type=m.replace points to the original event.
|
||||
content := &event.MessageEventContent{
|
||||
MsgType: event.MsgText,
|
||||
Body: "* " + markdown, // per spec: prefix with "* " for fallback
|
||||
Format: event.FormatHTML,
|
||||
FormattedBody: "* " + html,
|
||||
RelatesTo: &event.RelatesTo{
|
||||
Type: event.RelReplace,
|
||||
EventID: id.EventID(originalEventID),
|
||||
},
|
||||
NewContent: &event.MessageEventContent{
|
||||
MsgType: event.MsgText,
|
||||
Body: markdown,
|
||||
Format: event.FormatHTML,
|
||||
FormattedBody: html,
|
||||
},
|
||||
}
|
||||
_, err := c.raw.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// mdToHTML converts a Markdown string to HTML using goldmark with full extensions.
|
||||
var mdParser = goldmark.New(
|
||||
goldmark.WithExtensions(
|
||||
|
||||
Reference in New Issue
Block a user