feat: import agents_and_robots platform as unibots (Matrix-out, unibus transport)
Reemplaza el scaffold del echobot por la plataforma completa de bots traida desde ~/DataProyects/Github/agents_and_robots tras la operacion Matrix-out: los bots ya no hablan por Matrix sino por el bus unibus (modelo todo-rooms + E2E via shell/transportunibus sobre github.com/enmanuel/unibus/pkg/client). - go.mod: replace de unibus -> ../unibus y de fn-registry -> ../../../.. (paths relativos reajustados a la nueva ubicacion dentro de fn_registry). - app.md: bump a 0.2.0, descripcion + arquitectura + comandos + gotchas reales. - modulo Go conservado como github.com/enmanuel/agents (sin reescribir imports). agents_and_robots queda archivado como museo de la era Matrix.
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
// Package bus provides in-process agent-to-agent message passing.
|
||||
package bus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Well-known message kinds used by the orchestrator.
|
||||
const (
|
||||
KindTask = "task" // orchestrator → bot: handle this question
|
||||
KindTaskResult = "task_result" // bot → orchestrator: here is my answer
|
||||
)
|
||||
|
||||
// AgentID identifies an agent.
|
||||
type AgentID string
|
||||
|
||||
// AgentMessage is a message between agents.
|
||||
type AgentMessage struct {
|
||||
From AgentID
|
||||
To AgentID
|
||||
Kind string
|
||||
Payload map[string]string
|
||||
}
|
||||
|
||||
// Bus manages channels for inter-agent communication.
|
||||
type Bus struct {
|
||||
mu sync.RWMutex
|
||||
channels map[AgentID]chan AgentMessage
|
||||
|
||||
replyMu sync.Mutex
|
||||
replyChs map[string]chan AgentMessage // taskID → one-shot reply channel
|
||||
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Bus.
|
||||
func New(logger *slog.Logger) *Bus {
|
||||
return &Bus{
|
||||
channels: make(map[AgentID]chan AgentMessage),
|
||||
replyChs: make(map[string]chan AgentMessage),
|
||||
logger: logger.With("component", "bus"),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers an agent and returns its receive channel.
|
||||
func (b *Bus) Subscribe(id AgentID) <-chan AgentMessage {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
ch := make(chan AgentMessage, 64)
|
||||
b.channels[id] = ch
|
||||
b.logger.Info("bus_subscribe", "agent", id)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Send delivers a message to an agent's channel.
|
||||
func (b *Bus) Send(msg AgentMessage) error {
|
||||
b.mu.RLock()
|
||||
ch, ok := b.channels[msg.To]
|
||||
b.mu.RUnlock()
|
||||
if !ok {
|
||||
b.logger.Warn("bus_not_found", "to", msg.To, "from", msg.From, "kind", msg.Kind)
|
||||
return fmt.Errorf("agent %q not registered on bus", msg.To)
|
||||
}
|
||||
select {
|
||||
case ch <- msg:
|
||||
b.logger.Debug("bus_send", "from", msg.From, "to", msg.To, "kind", msg.Kind)
|
||||
return nil
|
||||
default:
|
||||
b.logger.Warn("bus_queue_full", "to", msg.To, "from", msg.From, "kind", msg.Kind)
|
||||
return fmt.Errorf("agent %q message queue full", msg.To)
|
||||
}
|
||||
}
|
||||
|
||||
// SendAndWait sends a task message and blocks until a reply with the matching
|
||||
// taskID arrives or the context expires. The caller must ensure the reply is
|
||||
// routed via Reply().
|
||||
func (b *Bus) SendAndWait(ctx context.Context, msg AgentMessage, taskID string, timeout time.Duration) (AgentMessage, error) {
|
||||
ch := make(chan AgentMessage, 1)
|
||||
b.replyMu.Lock()
|
||||
b.replyChs[taskID] = ch
|
||||
b.replyMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
b.replyMu.Lock()
|
||||
delete(b.replyChs, taskID)
|
||||
b.replyMu.Unlock()
|
||||
}()
|
||||
|
||||
if err := b.Send(msg); err != nil {
|
||||
return AgentMessage{}, err
|
||||
}
|
||||
|
||||
b.logger.Debug("bus_send_and_wait", "task", taskID, "to", msg.To, "timeout", timeout)
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case reply := <-ch:
|
||||
return reply, nil
|
||||
case <-timer.C:
|
||||
b.logger.Warn("bus_timeout", "task", taskID, "to", msg.To, "timeout", timeout)
|
||||
return AgentMessage{}, fmt.Errorf("task %s: delegation timeout after %s", taskID, timeout)
|
||||
case <-ctx.Done():
|
||||
return AgentMessage{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Reply routes a task_result message to the waiting SendAndWait caller.
|
||||
// If no one is waiting for this taskID, it falls back to regular Send.
|
||||
func (b *Bus) Reply(taskID string, msg AgentMessage) error {
|
||||
b.replyMu.Lock()
|
||||
ch, ok := b.replyChs[taskID]
|
||||
b.replyMu.Unlock()
|
||||
|
||||
if ok {
|
||||
select {
|
||||
case ch <- msg:
|
||||
b.logger.Debug("bus_reply", "task", taskID, "from", msg.From)
|
||||
return nil
|
||||
default:
|
||||
b.logger.Warn("bus_reply_full", "task", taskID)
|
||||
return fmt.Errorf("reply channel full for task %s", taskID)
|
||||
}
|
||||
}
|
||||
// Fallback: deliver via regular channel
|
||||
return b.Send(msg)
|
||||
}
|
||||
|
||||
// Unsubscribe removes an agent from the bus.
|
||||
func (b *Bus) Unsubscribe(id AgentID) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if ch, ok := b.channels[id]; ok {
|
||||
close(ch)
|
||||
delete(b.channels, id)
|
||||
b.logger.Info("bus_unsubscribe", "agent", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package bus_test
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/enmanuel/agents/shell/bus"
|
||||
)
|
||||
|
||||
func newBus() *bus.Bus {
|
||||
return bus.New(slog.Default())
|
||||
}
|
||||
|
||||
func TestSubscribeAndSend(t *testing.T) {
|
||||
b := newBus()
|
||||
ch := b.Subscribe("agent-a")
|
||||
|
||||
msg := bus.AgentMessage{From: "orch", To: "agent-a", Kind: bus.KindTask, Payload: map[string]string{"k": "v"}}
|
||||
if err := b.Send(msg); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
|
||||
got := <-ch
|
||||
if got.Kind != bus.KindTask || got.Payload["k"] != "v" {
|
||||
t.Fatalf("unexpected message: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeClosesChannel(t *testing.T) {
|
||||
b := newBus()
|
||||
ch := b.Subscribe("agent-b")
|
||||
|
||||
b.Unsubscribe("agent-b")
|
||||
|
||||
// Channel must be closed — reading from a closed channel returns zero value + ok=false.
|
||||
_, ok := <-ch
|
||||
if ok {
|
||||
t.Fatal("expected channel to be closed after Unsubscribe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeRemovesFromBus(t *testing.T) {
|
||||
b := newBus()
|
||||
b.Subscribe("agent-c")
|
||||
b.Unsubscribe("agent-c")
|
||||
|
||||
// Sending after unsubscribe must return an error, not panic.
|
||||
err := b.Send(bus.AgentMessage{To: "agent-c", Kind: "ping"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when sending to unsubscribed agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeIdempotent(t *testing.T) {
|
||||
b := newBus()
|
||||
b.Subscribe("agent-d")
|
||||
// Double unsubscribe must not panic.
|
||||
b.Unsubscribe("agent-d")
|
||||
b.Unsubscribe("agent-d")
|
||||
}
|
||||
|
||||
func TestUnsubscribeNonExistent(t *testing.T) {
|
||||
b := newBus()
|
||||
// Unsubscribing an ID that was never subscribed must not panic.
|
||||
b.Unsubscribe("does-not-exist")
|
||||
}
|
||||
|
||||
func TestSendToUnknownAgent(t *testing.T) {
|
||||
b := newBus()
|
||||
err := b.Send(bus.AgentMessage{To: "ghost", Kind: "hello"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when sending to unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResubscribeAfterUnsubscribe(t *testing.T) {
|
||||
b := newBus()
|
||||
b.Subscribe("agent-e")
|
||||
b.Unsubscribe("agent-e")
|
||||
|
||||
// Re-subscribe must work and deliver messages.
|
||||
ch2 := b.Subscribe("agent-e")
|
||||
msg := bus.AgentMessage{To: "agent-e", Kind: "ping"}
|
||||
if err := b.Send(msg); err != nil {
|
||||
t.Fatalf("Send after re-subscribe: %v", err)
|
||||
}
|
||||
got := <-ch2
|
||||
if got.Kind != "ping" {
|
||||
t.Fatalf("unexpected kind: %q", got.Kind)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
const actionKindSendMessage = "send_message"
|
||||
const actionKindLLMPrompt = "llm_prompt"
|
||||
|
||||
// handler is a function that fires when a schedule triggers.
|
||||
type handler func(ctx context.Context, room string)
|
||||
|
||||
// buildHandler returns the handler for a schedule, or nil for unsupported kinds.
|
||||
func (s *Scheduler) buildHandler(sc config.ScheduleCfg) handler {
|
||||
switch sc.Action.Kind {
|
||||
case actionKindSendMessage:
|
||||
return s.sendMessageHandler(sc)
|
||||
case actionKindLLMPrompt:
|
||||
return s.llmPromptHandler(sc)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessageHandler returns a handler that sends a static message to a Matrix room.
|
||||
// The message content is resolved in priority order: Message > Template file.
|
||||
func (s *Scheduler) sendMessageHandler(sc config.ScheduleCfg) handler {
|
||||
return func(ctx context.Context, room string) {
|
||||
content, err := resolveContent(sc.Action.Message, sc.Action.Template)
|
||||
if err != nil {
|
||||
s.logger.Error("send_message: failed to resolve content",
|
||||
"name", sc.Name, "err", err)
|
||||
return
|
||||
}
|
||||
if content == "" {
|
||||
s.logger.Warn("send_message: empty content, skipping", "name", sc.Name)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("cron_fire", "name", sc.Name, "kind", actionKindSendMessage, "room", room)
|
||||
if err := s.sender.SendMarkdown(ctx, room, content); err != nil {
|
||||
s.logger.Error("send_message: bus send failed",
|
||||
"name", sc.Name, "room", room, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// llmPromptHandler returns a handler that calls the LLM with a prompt and sends
|
||||
// the response to a Matrix room.
|
||||
func (s *Scheduler) llmPromptHandler(sc config.ScheduleCfg) handler {
|
||||
return func(ctx context.Context, room string) {
|
||||
if s.llm == nil {
|
||||
s.logger.Warn("llm_prompt: no LLM configured, skipping", "name", sc.Name)
|
||||
return
|
||||
}
|
||||
|
||||
prompt, err := resolveContent(sc.Action.Prompt, sc.Action.Template)
|
||||
if err != nil {
|
||||
s.logger.Error("llm_prompt: failed to resolve prompt",
|
||||
"name", sc.Name, "err", err)
|
||||
return
|
||||
}
|
||||
if prompt == "" {
|
||||
s.logger.Warn("llm_prompt: empty prompt, skipping", "name", sc.Name)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("cron_fire", "name", sc.Name, "kind", actionKindLLMPrompt, "room", room)
|
||||
|
||||
req := coretypes.CompletionRequest{
|
||||
Model: s.model,
|
||||
Messages: []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: prompt},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := s.llm(ctx, req)
|
||||
if err != nil {
|
||||
s.logger.Error("llm_prompt: LLM call failed",
|
||||
"name", sc.Name, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if content == "" {
|
||||
s.logger.Warn("llm_prompt: LLM returned empty response", "name", sc.Name)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.sender.SendMarkdown(ctx, room, content); err != nil {
|
||||
s.logger.Error("llm_prompt: bus send failed",
|
||||
"name", sc.Name, "room", room, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveContent returns the inline text if non-empty, otherwise reads the file at templatePath.
|
||||
func resolveContent(inline, templatePath string) (string, error) {
|
||||
if inline != "" {
|
||||
return inline, nil
|
||||
}
|
||||
if templatePath == "" {
|
||||
return "", nil
|
||||
}
|
||||
data, err := os.ReadFile(templatePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading template %q: %w", templatePath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
// Package cron provides a scheduler for autonomous bot activity.
|
||||
// It is part of the impure shell: it reads files, calls LLMs, and sends messages
|
||||
// over the bot's transport (unibus).
|
||||
package cron
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// Sender is the subset of the bot's transport sender needed by the scheduler.
|
||||
type Sender interface {
|
||||
SendMarkdown(ctx context.Context, roomID, markdown string) error
|
||||
}
|
||||
|
||||
// Scheduler fires configured schedules and executes send_message or llm_prompt actions.
|
||||
type Scheduler struct {
|
||||
cfg []config.ScheduleCfg
|
||||
sender Sender
|
||||
llm coretypes.CompleteFunc // nil when agent has no LLM
|
||||
model string
|
||||
logger *slog.Logger
|
||||
cron *cron.Cron
|
||||
}
|
||||
|
||||
// New creates a Scheduler. llm and model are optional (nil/empty for agents without LLM).
|
||||
func New(
|
||||
cfg []config.ScheduleCfg,
|
||||
sender Sender,
|
||||
llm coretypes.CompleteFunc,
|
||||
model string,
|
||||
logger *slog.Logger,
|
||||
) *Scheduler {
|
||||
return &Scheduler{
|
||||
cfg: cfg,
|
||||
sender: sender,
|
||||
llm: llm,
|
||||
model: model,
|
||||
logger: logger.With("component", "cron"),
|
||||
cron: cron.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Fire immediately executes the action for the given schedule, bypassing the cron timer.
|
||||
// Useful for tests and manual triggering from CLI.
|
||||
func (s *Scheduler) Fire(ctx context.Context, sc config.ScheduleCfg) {
|
||||
room := sc.OutputRoom
|
||||
if room == "" {
|
||||
s.logger.Warn("Fire: schedule has no output_room, skipping", "name", sc.Name)
|
||||
return
|
||||
}
|
||||
handler := s.buildHandler(sc)
|
||||
if handler == nil {
|
||||
s.logger.Warn("Fire: unsupported action kind", "name", sc.Name, "kind", sc.Action.Kind)
|
||||
return
|
||||
}
|
||||
handler(ctx, room)
|
||||
}
|
||||
|
||||
// Start registers all schedules and starts the cron loop.
|
||||
// It returns when ctx is cancelled, stopping the cron runner.
|
||||
func (s *Scheduler) Start(ctx context.Context) {
|
||||
for _, sc := range s.cfg {
|
||||
sc := sc // capture range var
|
||||
if sc.Cron == "" || sc.Action.Kind == "" {
|
||||
s.logger.Warn("skipping invalid schedule", "name", sc.Name, "cron", sc.Cron, "kind", sc.Action.Kind)
|
||||
continue
|
||||
}
|
||||
|
||||
room := sc.OutputRoom
|
||||
if room == "" {
|
||||
s.logger.Warn("schedule has no output_room, skipping", "name", sc.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
handler := s.buildHandler(sc)
|
||||
if handler == nil {
|
||||
s.logger.Warn("unsupported action kind, skipping", "name", sc.Name, "kind", sc.Action.Kind)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err := s.cron.AddFunc(sc.Cron, func() {
|
||||
handler(ctx, room)
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Error("failed to register schedule",
|
||||
"name", sc.Name,
|
||||
"cron", sc.Cron,
|
||||
"err", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info("schedule registered", "name", sc.Name, "cron", sc.Cron, "kind", sc.Action.Kind, "room", room)
|
||||
}
|
||||
|
||||
s.cron.Start()
|
||||
s.logger.Info("cron scheduler started", "schedules", len(s.cfg))
|
||||
|
||||
<-ctx.Done()
|
||||
s.logger.Info("cron scheduler stopping")
|
||||
cronCtx := s.cron.Stop()
|
||||
<-cronCtx.Done()
|
||||
s.logger.Info("cron scheduler stopped")
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
package cron_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
shellcron "github.com/enmanuel/agents/shell/cron"
|
||||
)
|
||||
|
||||
// ── fakes ──────────────────────────────────────────────────────────────────
|
||||
|
||||
type fakeSender struct {
|
||||
calls atomic.Int32
|
||||
lastMD string
|
||||
lastRM string
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendMarkdown(_ context.Context, room, md string) error {
|
||||
f.calls.Add(1)
|
||||
f.lastRM = room
|
||||
f.lastMD = md
|
||||
return nil
|
||||
}
|
||||
|
||||
type errSender struct{}
|
||||
|
||||
func (e *errSender) SendMarkdown(_ context.Context, _, _ string) error {
|
||||
return errors.New("matrix unavailable")
|
||||
}
|
||||
|
||||
func fakeLLM(reply string) coretypes.CompleteFunc {
|
||||
return func(_ context.Context, _ coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
|
||||
return coretypes.CompletionResponse{Content: reply}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func newTestLogger(t *testing.T) *slog.Logger {
|
||||
t.Helper()
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
}
|
||||
|
||||
// ── helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
// waitCalls blocks until the sender has received at least n calls or the deadline passes.
|
||||
func waitCalls(t *testing.T, f *fakeSender, n int32) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if f.calls.Load() >= n {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("expected %d call(s) to SendMarkdown, got %d", n, f.calls.Load())
|
||||
}
|
||||
|
||||
// ── cron-based tests (require timer) ──────────────────────────────────────
|
||||
|
||||
func TestScheduler_SendMessage_Inline(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
cfg := []config.ScheduleCfg{
|
||||
{
|
||||
Name: "test-inline",
|
||||
Cron: "@every 100ms",
|
||||
OutputRoom: "!room:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "send_message",
|
||||
Message: "hola mundo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := shellcron.New(cfg, sender, nil, "", newTestLogger(t))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.Start(ctx)
|
||||
}()
|
||||
|
||||
waitCalls(t, sender, 1)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
if sender.lastRM != "!room:server.com" {
|
||||
t.Errorf("unexpected room: %s", sender.lastRM)
|
||||
}
|
||||
if sender.lastMD != "hola mundo" {
|
||||
t.Errorf("unexpected message: %s", sender.lastMD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduler_SendMessage_Template(t *testing.T) {
|
||||
// Write a temporary template file
|
||||
dir := t.TempDir()
|
||||
tmpl := filepath.Join(dir, "greeting.md")
|
||||
if err := os.WriteFile(tmpl, []byte("buenos días"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sender := &fakeSender{}
|
||||
cfg := []config.ScheduleCfg{
|
||||
{
|
||||
Name: "test-template",
|
||||
Cron: "@every 100ms",
|
||||
OutputRoom: "!room2:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "send_message",
|
||||
Template: tmpl,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := shellcron.New(cfg, sender, nil, "", newTestLogger(t))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.Start(ctx)
|
||||
}()
|
||||
|
||||
waitCalls(t, sender, 1)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
if sender.lastMD != "buenos días" {
|
||||
t.Errorf("unexpected message: %q", sender.lastMD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduler_LLMPrompt(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
cfg := []config.ScheduleCfg{
|
||||
{
|
||||
Name: "test-llm",
|
||||
Cron: "@every 100ms",
|
||||
OutputRoom: "!room3:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "llm_prompt",
|
||||
Prompt: "resume el día",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
llm := fakeLLM("resumen generado por LLM")
|
||||
s := shellcron.New(cfg, sender, llm, "gpt-4o", newTestLogger(t))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.Start(ctx)
|
||||
}()
|
||||
|
||||
waitCalls(t, sender, 1)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
if sender.lastMD != "resumen generado por LLM" {
|
||||
t.Errorf("unexpected LLM reply: %q", sender.lastMD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduler_MatrixSendError(t *testing.T) {
|
||||
// If matrix.SendMarkdown returns an error, the scheduler should log it and not panic.
|
||||
cfg := []config.ScheduleCfg{
|
||||
{
|
||||
Name: "err-send",
|
||||
Cron: "@every 100ms",
|
||||
OutputRoom: "!room:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "send_message",
|
||||
Message: "trigger error",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := shellcron.New(cfg, &errSender{}, nil, "", newTestLogger(t))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.Start(ctx)
|
||||
}()
|
||||
|
||||
// Let it fire at least once without panicking
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
cancel()
|
||||
<-done
|
||||
}
|
||||
|
||||
// ── Fire() tests (deterministic, no timer) ─────────────────────────────────
|
||||
|
||||
func TestFire_SendMessage_Inline(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
s := shellcron.New(nil, sender, nil, "", newTestLogger(t))
|
||||
|
||||
sc := config.ScheduleCfg{
|
||||
Name: "fire-inline",
|
||||
Cron: "0 9 * * *",
|
||||
OutputRoom: "!fireroom:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "send_message",
|
||||
Message: "buenos días via Fire",
|
||||
},
|
||||
}
|
||||
|
||||
s.Fire(context.Background(), sc)
|
||||
|
||||
if sender.calls.Load() != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", sender.calls.Load())
|
||||
}
|
||||
if sender.lastRM != "!fireroom:server.com" {
|
||||
t.Errorf("unexpected room: %s", sender.lastRM)
|
||||
}
|
||||
if sender.lastMD != "buenos días via Fire" {
|
||||
t.Errorf("unexpected message: %s", sender.lastMD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFire_LLMPrompt(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
llm := fakeLLM("respuesta del LLM via Fire")
|
||||
s := shellcron.New(nil, sender, llm, "gpt-4o", newTestLogger(t))
|
||||
|
||||
sc := config.ScheduleCfg{
|
||||
Name: "fire-llm",
|
||||
Cron: "0 18 * * *",
|
||||
OutputRoom: "!llmroom:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "llm_prompt",
|
||||
Prompt: "resume el día",
|
||||
},
|
||||
}
|
||||
|
||||
s.Fire(context.Background(), sc)
|
||||
|
||||
if sender.calls.Load() != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", sender.calls.Load())
|
||||
}
|
||||
if sender.lastMD != "respuesta del LLM via Fire" {
|
||||
t.Errorf("unexpected LLM reply: %q", sender.lastMD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFire_NoOutputRoom_Skips(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
s := shellcron.New(nil, sender, nil, "", newTestLogger(t))
|
||||
|
||||
sc := config.ScheduleCfg{
|
||||
Name: "fire-no-room",
|
||||
Cron: "0 9 * * *",
|
||||
OutputRoom: "", // intentionally empty
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "send_message",
|
||||
Message: "should not send",
|
||||
},
|
||||
}
|
||||
|
||||
s.Fire(context.Background(), sc)
|
||||
|
||||
if sender.calls.Load() != 0 {
|
||||
t.Errorf("expected 0 calls when output_room is empty, got %d", sender.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFire_LLMPrompt_NoLLM_Skips(t *testing.T) {
|
||||
// When no LLM is configured, Fire with llm_prompt should not send anything.
|
||||
sender := &fakeSender{}
|
||||
s := shellcron.New(nil, sender, nil, "", newTestLogger(t))
|
||||
|
||||
sc := config.ScheduleCfg{
|
||||
Name: "fire-no-llm",
|
||||
Cron: "0 9 * * *",
|
||||
OutputRoom: "!room:server.com",
|
||||
Action: config.ScheduledAction{
|
||||
Kind: "llm_prompt",
|
||||
Prompt: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
s.Fire(context.Background(), sc)
|
||||
|
||||
if sender.calls.Load() != 0 {
|
||||
t.Errorf("expected 0 calls without LLM, got %d", sender.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduler_SkipsInvalidSchedule(t *testing.T) {
|
||||
// Schedules without output_room or without action kind must be skipped during Start.
|
||||
// We use Fire directly to test the skip logic without timer overhead.
|
||||
sender := &fakeSender{}
|
||||
s := shellcron.New(nil, sender, nil, "", newTestLogger(t))
|
||||
ctx := context.Background()
|
||||
|
||||
// No output_room → skip
|
||||
s.Fire(ctx, config.ScheduleCfg{
|
||||
Name: "no-room",
|
||||
Cron: "@every 100ms",
|
||||
// missing OutputRoom
|
||||
Action: config.ScheduledAction{Kind: "send_message", Message: "hi"},
|
||||
})
|
||||
|
||||
// No kind → Fire calls buildHandler which returns nil → skip
|
||||
s.Fire(ctx, config.ScheduleCfg{
|
||||
Name: "no-kind",
|
||||
Cron: "@every 100ms",
|
||||
OutputRoom: "!room:server.com",
|
||||
// missing Action.Kind
|
||||
})
|
||||
|
||||
if sender.calls.Load() != 0 {
|
||||
t.Errorf("expected 0 calls for invalid schedules, got %d", sender.calls.Load())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
// Package effects interprets pure []decision.Action values into real side effects.
|
||||
package effects
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
"github.com/enmanuel/agents/shell/logger"
|
||||
"github.com/enmanuel/agents/shell/ssh"
|
||||
)
|
||||
|
||||
// Result holds the outcome of executing a single action.
|
||||
type Result struct {
|
||||
Action decision.Action
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Sender is the transport-neutral message-sending capability the runner depends
|
||||
// on. It is satisfied by the unibus bus sender (and was satisfied by the Matrix
|
||||
// client before the bus migration). SendTyping is a no-op on transports without
|
||||
// typing indicators.
|
||||
type Sender interface {
|
||||
SendText(ctx context.Context, roomID, text string) error
|
||||
SendMarkdown(ctx context.Context, roomID, markdown string) error
|
||||
SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error
|
||||
SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error
|
||||
SendTyping(ctx context.Context, roomID string, typing bool) error
|
||||
}
|
||||
|
||||
// Runner interprets actions and executes them.
|
||||
type Runner struct {
|
||||
sender Sender
|
||||
ssh *ssh.Executor
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRunner creates a Runner with the provided dependencies.
|
||||
func NewRunner(sender Sender, ssh *ssh.Executor, logger *slog.Logger) *Runner {
|
||||
return &Runner{sender: sender, ssh: ssh, logger: logger}
|
||||
}
|
||||
|
||||
// Execute runs each action sequentially and returns results.
|
||||
func (r *Runner) Execute(ctx context.Context, roomID string, actions []decision.Action) []Result {
|
||||
r.logger.Debug("effects_batch", "room", roomID, "count", len(actions))
|
||||
results := make([]Result, 0, len(actions))
|
||||
for _, a := range actions {
|
||||
start := time.Now()
|
||||
res := r.executeOne(ctx, roomID, a)
|
||||
ms := time.Since(start).Milliseconds()
|
||||
results = append(results, res)
|
||||
if res.Err != nil {
|
||||
r.logger.Error("action_failed", logger.FieldAction, a.Kind, logger.FieldDurationMS, ms, "err", res.Err)
|
||||
} else {
|
||||
r.logger.Info("action_done", logger.FieldAction, a.Kind, logger.FieldDurationMS, ms)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *Runner) executeOne(ctx context.Context, roomID string, a decision.Action) Result {
|
||||
switch a.Kind {
|
||||
case decision.ActionKindReply:
|
||||
if a.Reply == nil {
|
||||
return Result{Action: a, Err: fmt.Errorf("nil reply action")}
|
||||
}
|
||||
var err error
|
||||
switch {
|
||||
case a.Reply.ThreadID != "":
|
||||
// Thread reply: send as part of the thread with fallback in_reply_to
|
||||
err = r.sender.SendThreadMarkdown(ctx, roomID, a.Reply.ThreadID, a.Reply.InReplyTo, a.Reply.Content)
|
||||
case a.Reply.InReplyTo != "":
|
||||
err = r.sender.SendReplyMarkdown(ctx, roomID, a.Reply.InReplyTo, a.Reply.Content)
|
||||
default:
|
||||
err = r.sender.SendMarkdown(ctx, roomID, a.Reply.Content)
|
||||
}
|
||||
return Result{Action: a, Output: a.Reply.Content, Err: err}
|
||||
|
||||
case decision.ActionKindSSH:
|
||||
if a.SSH == nil {
|
||||
return Result{Action: a, Err: fmt.Errorf("nil ssh action")}
|
||||
}
|
||||
res := r.ssh.Execute(ctx, *a.SSH)
|
||||
output := res.Stdout
|
||||
if res.Stderr != "" {
|
||||
output += "\nstderr: " + res.Stderr
|
||||
}
|
||||
return Result{Action: a, Output: output, Err: res.Err}
|
||||
|
||||
default:
|
||||
return Result{Action: a, Err: fmt.Errorf("unhandled action kind: %s", a.Kind)}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package effects
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
)
|
||||
|
||||
// fakeSender records calls for assertions.
|
||||
type fakeSender struct {
|
||||
calls []senderCall
|
||||
}
|
||||
|
||||
type senderCall struct {
|
||||
method string
|
||||
roomID string
|
||||
threadRootID string
|
||||
inReplyTo string
|
||||
markdown string
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendText(ctx context.Context, roomID, text string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "SendText", roomID: roomID, markdown: text})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendMarkdown(ctx context.Context, roomID, markdown string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "SendMarkdown", roomID: roomID, markdown: markdown})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "SendReplyMarkdown", roomID: roomID, inReplyTo: inReplyTo, markdown: markdown})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error {
|
||||
f.calls = append(f.calls, senderCall{method: "SendThreadMarkdown", roomID: roomID, threadRootID: threadRootID, inReplyTo: inReplyTo, markdown: markdown})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSender) SendTyping(ctx context.Context, roomID string, typing bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExecuteReply_PlainMarkdown(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
runner := NewRunner(sender, nil, slog.Default())
|
||||
|
||||
actions := []decision.Action{{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{Content: "hello"},
|
||||
}}
|
||||
|
||||
results := runner.Execute(context.Background(), "!room:test", actions)
|
||||
if len(results) != 1 || results[0].Err != nil {
|
||||
t.Fatalf("unexpected results: %+v", results)
|
||||
}
|
||||
if len(sender.calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(sender.calls))
|
||||
}
|
||||
c := sender.calls[0]
|
||||
if c.method != "SendMarkdown" {
|
||||
t.Errorf("expected SendMarkdown, got %s", c.method)
|
||||
}
|
||||
if c.roomID != "!room:test" {
|
||||
t.Errorf("expected room !room:test, got %s", c.roomID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReply_WithInReplyTo(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
runner := NewRunner(sender, nil, slog.Default())
|
||||
|
||||
actions := []decision.Action{{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{Content: "hello", InReplyTo: "$evt1"},
|
||||
}}
|
||||
|
||||
runner.Execute(context.Background(), "!room:test", actions)
|
||||
|
||||
if len(sender.calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(sender.calls))
|
||||
}
|
||||
c := sender.calls[0]
|
||||
if c.method != "SendReplyMarkdown" {
|
||||
t.Errorf("expected SendReplyMarkdown, got %s", c.method)
|
||||
}
|
||||
if c.inReplyTo != "$evt1" {
|
||||
t.Errorf("expected inReplyTo=$evt1, got %s", c.inReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReply_WithThread(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
runner := NewRunner(sender, nil, slog.Default())
|
||||
|
||||
actions := []decision.Action{{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{
|
||||
Content: "thread reply",
|
||||
ThreadID: "$root",
|
||||
InReplyTo: "$evt2",
|
||||
},
|
||||
}}
|
||||
|
||||
runner.Execute(context.Background(), "!room:test", actions)
|
||||
|
||||
if len(sender.calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(sender.calls))
|
||||
}
|
||||
c := sender.calls[0]
|
||||
if c.method != "SendThreadMarkdown" {
|
||||
t.Errorf("expected SendThreadMarkdown, got %s", c.method)
|
||||
}
|
||||
if c.threadRootID != "$root" {
|
||||
t.Errorf("expected threadRootID=$root, got %s", c.threadRootID)
|
||||
}
|
||||
if c.inReplyTo != "$evt2" {
|
||||
t.Errorf("expected inReplyTo=$evt2, got %s", c.inReplyTo)
|
||||
}
|
||||
if c.roomID != "!room:test" {
|
||||
t.Errorf("expected room !room:test, got %s", c.roomID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReply_ThreadWithoutInReplyTo(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
runner := NewRunner(sender, nil, slog.Default())
|
||||
|
||||
actions := []decision.Action{{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{
|
||||
Content: "thread reply no fallback",
|
||||
ThreadID: "$root",
|
||||
},
|
||||
}}
|
||||
|
||||
runner.Execute(context.Background(), "!room:test", actions)
|
||||
|
||||
if len(sender.calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(sender.calls))
|
||||
}
|
||||
c := sender.calls[0]
|
||||
if c.method != "SendThreadMarkdown" {
|
||||
t.Errorf("expected SendThreadMarkdown, got %s", c.method)
|
||||
}
|
||||
// inReplyTo should be empty; SendThreadMarkdown will default to threadRootID
|
||||
if c.inReplyTo != "" {
|
||||
t.Errorf("expected empty inReplyTo, got %s", c.inReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReply_NilReply(t *testing.T) {
|
||||
sender := &fakeSender{}
|
||||
runner := NewRunner(sender, nil, slog.Default())
|
||||
|
||||
actions := []decision.Action{{
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: nil,
|
||||
}}
|
||||
|
||||
results := runner.Execute(context.Background(), "!room:test", actions)
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0].Err == nil {
|
||||
t.Error("expected error for nil reply")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package shellknowledge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
moderncsqlite "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register pure-Go SQLite driver as "sqlite3" for tests.
|
||||
sql.Register("sqlite3", &moderncsqlite.Driver{})
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
// Package shellknowledge implements the knowledge store using files + SQLite FTS5.
|
||||
package shellknowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/knowledge"
|
||||
)
|
||||
|
||||
const ftsSchema = `
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS documents USING fts5(
|
||||
slug,
|
||||
title,
|
||||
content,
|
||||
updated_at UNINDEXED
|
||||
);
|
||||
`
|
||||
|
||||
var slugRe = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{0,62}[a-z0-9]$`)
|
||||
|
||||
// ValidSlug returns true if s is a valid document slug.
|
||||
func ValidSlug(s string) bool {
|
||||
if len(s) < 2 || len(s) > 64 {
|
||||
return false
|
||||
}
|
||||
return slugRe.MatchString(s)
|
||||
}
|
||||
|
||||
// FileStore implements knowledge.Store using markdown files + SQLite FTS5 index.
|
||||
type FileStore struct {
|
||||
dir string // path to agents/<id>/knowledge/
|
||||
dbPath string // path to agents/<id>/data/knowledge.db
|
||||
db *sql.DB
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a FileStore. It ensures the knowledge dir and DB dir exist,
|
||||
// opens the SQLite database, and creates the FTS5 table if needed.
|
||||
func New(dir, dbPath string, logger *slog.Logger) (*FileStore, error) {
|
||||
log := logger.With("component", "knowledge", "dir", dir, "db_path", dbPath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create knowledge dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create knowledge db dir: %w", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open knowledge db: %w", err)
|
||||
}
|
||||
|
||||
// Enable WAL mode for better concurrency (allows multiple readers + single writer)
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("enable WAL mode: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ftsSchema); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("create knowledge fts5 table: %w", err)
|
||||
}
|
||||
|
||||
log.Info("knowledge_store_ready")
|
||||
return &FileStore{dir: dir, dbPath: dbPath, db: db, logger: log}, nil
|
||||
}
|
||||
|
||||
// Sync re-indexes all .md files from disk into the FTS5 table.
|
||||
func (s *FileStore) Sync(ctx context.Context) error {
|
||||
entries, err := os.ReadDir(s.dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read knowledge dir: %w", err)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin sync tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Clear existing index
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM documents`); err != nil {
|
||||
return fmt.Errorf("clear fts5 index: %w", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
slug := strings.TrimSuffix(e.Name(), ".md")
|
||||
if !ValidSlug(slug) {
|
||||
s.logger.Warn("skipping invalid slug", "file", e.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filepath.Join(s.dir, e.Name()))
|
||||
if err != nil {
|
||||
s.logger.Warn("skipping unreadable file", "file", e.Name(), "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
s.logger.Warn("skipping file without info", "file", e.Name(), "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
title := extractTitle(string(content), slug)
|
||||
mtime := info.ModTime().UTC().Format(time.RFC3339)
|
||||
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`INSERT INTO documents (slug, title, content, updated_at) VALUES (?, ?, ?, ?)`,
|
||||
slug, title, string(content), mtime,
|
||||
); err != nil {
|
||||
s.logger.Warn("failed to index file", "slug", slug, "err", err)
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit sync tx: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("knowledge_sync", "count", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Search performs full-text search on the FTS5 index.
|
||||
func (s *FileStore) Search(ctx context.Context, query string, limit int) ([]knowledge.SearchResult, error) {
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT slug, title, snippet(documents, 2, '**', '**', '…', 32), rank
|
||||
FROM documents WHERE documents MATCH ?
|
||||
ORDER BY rank LIMIT ?`,
|
||||
query, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("knowledge search: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []knowledge.SearchResult
|
||||
for rows.Next() {
|
||||
var r knowledge.SearchResult
|
||||
if err := rows.Scan(&r.Slug, &r.Title, &r.Snippet, &r.Rank); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// Get reads a document from disk by slug.
|
||||
func (s *FileStore) Get(ctx context.Context, slug string) (*knowledge.Document, error) {
|
||||
if !ValidSlug(slug) {
|
||||
return nil, fmt.Errorf("invalid slug: %q", slug)
|
||||
}
|
||||
|
||||
path := filepath.Join(s.dir, slug+".md")
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("document not found: %q", slug)
|
||||
}
|
||||
return nil, fmt.Errorf("read document: %w", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat document: %w", err)
|
||||
}
|
||||
|
||||
return &knowledge.Document{
|
||||
Slug: slug,
|
||||
Title: extractTitle(string(content), slug),
|
||||
Content: string(content),
|
||||
UpdatedAt: info.ModTime().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Put writes a document to disk and updates the FTS5 index.
|
||||
func (s *FileStore) Put(ctx context.Context, doc knowledge.Document) error {
|
||||
if !ValidSlug(doc.Slug) {
|
||||
return fmt.Errorf("invalid slug: %q", doc.Slug)
|
||||
}
|
||||
if len(doc.Content) > 64*1024 {
|
||||
return fmt.Errorf("document too large: %d bytes (max 65536)", len(doc.Content))
|
||||
}
|
||||
|
||||
path := filepath.Join(s.dir, doc.Slug+".md")
|
||||
if err := os.WriteFile(path, []byte(doc.Content), 0o644); err != nil {
|
||||
return fmt.Errorf("write document: %w", err)
|
||||
}
|
||||
|
||||
title := extractTitle(doc.Content, doc.Slug)
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
// Upsert: delete old + insert new (FTS5 doesn't support UPDATE well)
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin put tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM documents WHERE slug = ?`, doc.Slug); err != nil {
|
||||
return fmt.Errorf("delete old index: %w", err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`INSERT INTO documents (slug, title, content, updated_at) VALUES (?, ?, ?, ?)`,
|
||||
doc.Slug, title, doc.Content, now,
|
||||
); err != nil {
|
||||
return fmt.Errorf("insert index: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit put tx: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug("knowledge_put", "slug", doc.Slug, "size", len(doc.Content))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a document from disk and the FTS5 index.
|
||||
func (s *FileStore) Delete(ctx context.Context, slug string) error {
|
||||
if !ValidSlug(slug) {
|
||||
return fmt.Errorf("invalid slug: %q", slug)
|
||||
}
|
||||
|
||||
path := filepath.Join(s.dir, slug+".md")
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove document: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, `DELETE FROM documents WHERE slug = ?`, slug); err != nil {
|
||||
return fmt.Errorf("delete from index: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug("knowledge_delete", "slug", slug)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns all documents from the FTS5 index.
|
||||
func (s *FileStore) List(ctx context.Context) ([]knowledge.Document, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT slug, title, updated_at FROM documents ORDER BY slug`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("knowledge list: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []knowledge.Document
|
||||
for rows.Next() {
|
||||
var d knowledge.Document
|
||||
var updatedAt string
|
||||
if err := rows.Scan(&d.Slug, &d.Title, &updatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
docs = append(docs, d)
|
||||
}
|
||||
return docs, rows.Err()
|
||||
}
|
||||
|
||||
// Close releases the SQLite database.
|
||||
func (s *FileStore) Close() error {
|
||||
s.logger.Info("knowledge_store_closed")
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// extractTitle returns the first H1 heading from markdown content, or a humanized slug.
|
||||
func extractTitle(content, slug string) string {
|
||||
for _, line := range strings.SplitN(content, "\n", 20) {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "# ") {
|
||||
return strings.TrimPrefix(line, "# ")
|
||||
}
|
||||
}
|
||||
// Humanize slug: "go-patterns" → "Go patterns"
|
||||
humanized := strings.ReplaceAll(slug, "-", " ")
|
||||
if len(humanized) > 0 {
|
||||
humanized = strings.ToUpper(humanized[:1]) + humanized[1:]
|
||||
}
|
||||
return humanized
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package shellknowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/knowledge"
|
||||
)
|
||||
|
||||
func testStore(t *testing.T) (*FileStore, string) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
knowledgeDir := filepath.Join(dir, "knowledge")
|
||||
dbPath := filepath.Join(dir, "data", "knowledge.db")
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
store, err := New(knowledgeDir, dbPath, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store, knowledgeDir
|
||||
}
|
||||
|
||||
func TestValidSlug(t *testing.T) {
|
||||
tests := []struct {
|
||||
slug string
|
||||
want bool
|
||||
}{
|
||||
{"go-patterns", true},
|
||||
{"ab", true},
|
||||
{"a-b", true},
|
||||
{"abc123", true},
|
||||
{"a", false}, // too short
|
||||
{"A-B", false}, // uppercase
|
||||
{"-bad", false}, // starts with hyphen
|
||||
{"bad-", false}, // ends with hyphen
|
||||
{"has space", false}, // space
|
||||
{"has_underscore", false}, // underscore
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := ValidSlug(tt.slug); got != tt.want {
|
||||
t.Errorf("ValidSlug(%q) = %v, want %v", tt.slug, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutAndGet(t *testing.T) {
|
||||
store, _ := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
doc := knowledge.Document{
|
||||
Slug: "test-doc",
|
||||
Content: "# Test Document\n\nThis is a test.",
|
||||
}
|
||||
|
||||
if err := store.Put(ctx, doc); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := store.Get(ctx, "test-doc")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Content != doc.Content {
|
||||
t.Errorf("content mismatch: got %q, want %q", got.Content, doc.Content)
|
||||
}
|
||||
if got.Title != "Test Document" {
|
||||
t.Errorf("title = %q, want %q", got.Title, "Test Document")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutInvalidSlug(t *testing.T) {
|
||||
store, _ := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.Put(ctx, knowledge.Document{Slug: "BAD", Content: "test"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid slug")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutTooLarge(t *testing.T) {
|
||||
store, _ := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
bigContent := make([]byte, 65*1024)
|
||||
for i := range bigContent {
|
||||
bigContent[i] = 'x'
|
||||
}
|
||||
err := store.Put(ctx, knowledge.Document{Slug: "too-big", Content: string(bigContent)})
|
||||
if err == nil {
|
||||
t.Error("expected error for oversized document")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncAndSearch(t *testing.T) {
|
||||
store, knowledgeDir := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Write files directly to disk
|
||||
os.WriteFile(filepath.Join(knowledgeDir, "go-patterns.md"),
|
||||
[]byte("# Go Patterns\n\nUse interfaces for dependency injection."), 0o644)
|
||||
os.WriteFile(filepath.Join(knowledgeDir, "matrix-tips.md"),
|
||||
[]byte("# Matrix Tips\n\nUse mautrix-go for Matrix bots."), 0o644)
|
||||
|
||||
if err := store.Sync(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Search for "interfaces"
|
||||
results, err := store.Search(ctx, "interfaces", 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected at least 1 search result")
|
||||
}
|
||||
if results[0].Slug != "go-patterns" {
|
||||
t.Errorf("expected slug go-patterns, got %q", results[0].Slug)
|
||||
}
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
store, _ := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty initially
|
||||
docs, err := store.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(docs) != 0 {
|
||||
t.Errorf("expected 0 docs, got %d", len(docs))
|
||||
}
|
||||
|
||||
// Add two docs
|
||||
store.Put(ctx, knowledge.Document{Slug: "alpha", Content: "# Alpha\nContent A"})
|
||||
store.Put(ctx, knowledge.Document{Slug: "beta", Content: "# Beta\nContent B"})
|
||||
|
||||
docs, err = store.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(docs) != 2 {
|
||||
t.Fatalf("expected 2 docs, got %d", len(docs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
store, knowledgeDir := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
store.Put(ctx, knowledge.Document{Slug: "to-delete", Content: "# Delete Me\nGoodbye"})
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(filepath.Join(knowledgeDir, "to-delete.md")); err != nil {
|
||||
t.Fatal("file should exist after Put")
|
||||
}
|
||||
|
||||
if err := store.Delete(ctx, "to-delete"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// File removed
|
||||
if _, err := os.Stat(filepath.Join(knowledgeDir, "to-delete.md")); !os.IsNotExist(err) {
|
||||
t.Error("file should be removed after Delete")
|
||||
}
|
||||
|
||||
// Not in index
|
||||
_, err := store.Get(ctx, "to-delete")
|
||||
if err == nil {
|
||||
t.Error("expected error for deleted document")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNotFound(t *testing.T) {
|
||||
store, _ := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := store.Get(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent document")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTitle(t *testing.T) {
|
||||
tests := []struct {
|
||||
content string
|
||||
slug string
|
||||
want string
|
||||
}{
|
||||
{"# My Title\nBody", "slug", "My Title"},
|
||||
{"No heading here", "my-doc", "My doc"},
|
||||
{"", "empty-doc", "Empty doc"},
|
||||
{"\n\n# Late Title\n", "slug", "Late Title"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := extractTitle(tt.content, tt.slug)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractTitle(%q, %q) = %q, want %q", tt.content, tt.slug, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
// Package llm contains impure LLM provider implementations.
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/shell/logger"
|
||||
)
|
||||
|
||||
const anthropicAPIBase = "https://api.anthropic.com/v1"
|
||||
const anthropicVersion = "2023-06-01"
|
||||
|
||||
// NewAnthropicComplete returns a CompleteFunc backed by the Anthropic API.
|
||||
func NewAnthropicComplete(apiKeyEnv, baseURL string, log *slog.Logger) coretypes.CompleteFunc {
|
||||
if baseURL == "" {
|
||||
baseURL = anthropicAPIBase
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
|
||||
apiKey := os.Getenv(apiKeyEnv)
|
||||
if apiKey == "" {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("env var %s is not set", apiKeyEnv)
|
||||
}
|
||||
|
||||
log.Info("llm_request",
|
||||
"provider", "anthropic",
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
)
|
||||
|
||||
body := toAnthropicRequest(req)
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/messages", bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return coretypes.CompletionResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", apiKey)
|
||||
httpReq.Header.Set("anthropic-version", anthropicVersion)
|
||||
httpReq.Header.Set("content-type", "application/json")
|
||||
|
||||
start := time.Now()
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
ms := time.Since(start).Milliseconds()
|
||||
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "err", err)
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("anthropic request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
ms := time.Since(start).Milliseconds()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "status", resp.StatusCode)
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, respBytes)
|
||||
}
|
||||
|
||||
result, err := fromAnthropicResponse(respBytes)
|
||||
if err != nil {
|
||||
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "err", err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
log.Info("llm_response",
|
||||
"provider", "anthropic",
|
||||
"model", req.Model,
|
||||
logger.FieldDurationMS, ms,
|
||||
logger.FieldTokensUsed, result.Usage.TotalTokens,
|
||||
"input_tokens", result.Usage.InputTokens,
|
||||
"output_tokens", result.Usage.OutputTokens,
|
||||
"tool_calls", len(result.ToolCalls),
|
||||
"finish_reason", result.FinishReason,
|
||||
)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ── private conversion helpers ────────────────────────────────────────────
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
Tools []anthropicTool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock represents a block in a content array.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// text block
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// tool_use block (in assistant responses)
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input map[string]any `json:"input,omitempty"`
|
||||
|
||||
// tool_result block (in user messages)
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
}
|
||||
|
||||
func toAnthropicRequest(req coretypes.CompletionRequest) anthropicRequest {
|
||||
msgs := make([]anthropicMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == coretypes.RoleSystem {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, toAnthropicMessage(m))
|
||||
}
|
||||
|
||||
tools := make([]anthropicTool, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
tools[i] = anthropicTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.InputSchema,
|
||||
}
|
||||
}
|
||||
|
||||
return anthropicRequest{
|
||||
Model: req.Model,
|
||||
MaxTokens: req.MaxTokens,
|
||||
System: req.SystemPrompt,
|
||||
Messages: msgs,
|
||||
Tools: tools,
|
||||
}
|
||||
}
|
||||
|
||||
// toAnthropicMessage converts a core Message to the Anthropic format.
|
||||
// Handles plain text, assistant messages with tool calls, and tool result messages.
|
||||
func toAnthropicMessage(m coretypes.Message) anthropicMessage {
|
||||
// Assistant message with tool calls → content array with text + tool_use blocks
|
||||
if m.Role == coretypes.RoleAssistant && len(m.ToolCalls) > 0 {
|
||||
blocks := make([]anthropicContentBlock, 0, len(m.ToolCalls)+1)
|
||||
if m.Content != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: m.Content})
|
||||
}
|
||||
for _, tc := range m.ToolCalls {
|
||||
var input map[string]any
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &input)
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
raw, _ := json.Marshal(blocks)
|
||||
return anthropicMessage{Role: "assistant", Content: raw}
|
||||
}
|
||||
|
||||
// Tool result message → user message with tool_result content array
|
||||
if m.Role == coretypes.RoleTool {
|
||||
blocks := []anthropicContentBlock{{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.ToolCallID,
|
||||
Content: m.Content,
|
||||
}}
|
||||
raw, _ := json.Marshal(blocks)
|
||||
return anthropicMessage{Role: "user", Content: raw}
|
||||
}
|
||||
|
||||
// Plain text message
|
||||
raw, _ := json.Marshal(m.Content)
|
||||
return anthropicMessage{Role: string(m.Role), Content: raw}
|
||||
}
|
||||
|
||||
func fromAnthropicResponse(raw []byte) (coretypes.CompletionResponse, error) {
|
||||
var ar anthropicResponse
|
||||
if err := json.Unmarshal(raw, &ar); err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
var content string
|
||||
var toolCalls []coretypes.ToolCall
|
||||
|
||||
for _, c := range ar.Content {
|
||||
switch c.Type {
|
||||
case "text":
|
||||
content += c.Text
|
||||
case "tool_use":
|
||||
argsJSON, _ := json.Marshal(c.Input)
|
||||
toolCalls = append(toolCalls, coretypes.ToolCall{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Arguments: string(argsJSON),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: ar.StopReason,
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: ar.Usage.InputTokens,
|
||||
OutputTokens: ar.Usage.OutputTokens,
|
||||
TotalTokens: ar.Usage.InputTokens + ar.Usage.OutputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,295 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClaudeBinary = "claude"
|
||||
defaultClaudeTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// claudeJSONOutput represents the JSON output from `claude -p --output-format json`.
|
||||
type claudeJSONOutput struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
IsError bool `json:"is_error"`
|
||||
Duration float64 `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCost float64 `json:"total_cost_usd"`
|
||||
Usage claudeUsage `json:"usage"`
|
||||
ContentBlock []claudeContent `json:"content"`
|
||||
}
|
||||
|
||||
type claudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type claudeContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// NewClaudeCodeComplete creates a CompleteFunc that executes `claude -p` as a subprocess.
|
||||
func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes.CompleteFunc {
|
||||
binary := cfg.Binary
|
||||
if binary == "" {
|
||||
binary = defaultClaudeBinary
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultClaudeTimeout
|
||||
}
|
||||
|
||||
// Resolve working directory once at init time.
|
||||
workDir := resolveWorkDir(cfg.WorkingDir, log)
|
||||
|
||||
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
prompt := flattenMessages(req.Messages)
|
||||
|
||||
log.Debug("claude_code_exec",
|
||||
"binary", binary,
|
||||
"args", strings.Join(args, " "),
|
||||
"prompt_len", len(prompt),
|
||||
"working_dir", workDir,
|
||||
)
|
||||
|
||||
cmd := exec.CommandContext(ctx, binary, args...)
|
||||
if workDir != "" {
|
||||
cmd.Dir = workDir
|
||||
}
|
||||
// Build clean env: inherit parent but remove ANTHROPIC_API_KEY
|
||||
// so claude uses its own OAuth auth instead of a potentially invalid key.
|
||||
cmd.Env = filterEnv(os.Environ(), "ANTHROPIC_API_KEY")
|
||||
cmd.Stdin = strings.NewReader(prompt)
|
||||
|
||||
// Create a new process group so we can kill claude + all its children.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// Override the default cancel behavior: kill the entire process group
|
||||
// instead of just the main process, preventing orphaned child processes.
|
||||
cmd.Cancel = func() error {
|
||||
if cmd.Process != nil {
|
||||
pgid := cmd.Process.Pid
|
||||
log.Info("killing claude-code process group", "pgid", pgid)
|
||||
// Negative PID = kill entire process group
|
||||
return syscall.Kill(-pgid, syscall.SIGKILL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
err := cmd.Run()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Ensure the process group is fully dead after Run returns,
|
||||
// even if cmd.Run() returned without triggering Cancel (normal exit).
|
||||
if cmd.Process != nil {
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
log.Debug("claude_code_done",
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
"stdout_len", stdout.Len(),
|
||||
"stderr_len", stderr.Len(),
|
||||
"exit_err", err,
|
||||
)
|
||||
|
||||
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveWorkDir determines the working directory for the claude subprocess.
|
||||
// If configured is empty, it creates a temporary directory to avoid inheriting the launcher's CWD.
|
||||
// If configured is non-empty, it ensures the directory exists.
|
||||
func resolveWorkDir(configured string, log *slog.Logger) string {
|
||||
if configured == "" {
|
||||
tmp, err := os.MkdirTemp("", "claude-agent-*")
|
||||
if err != nil {
|
||||
log.Error("claude-code: failed to create temp working dir", "err", err)
|
||||
return "" // Fall through — cmd.Dir will remain empty (inherits CWD).
|
||||
}
|
||||
log.Warn("claude-code working_dir is empty, using temporary directory",
|
||||
"dir", tmp,
|
||||
)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// Ensure configured directory exists.
|
||||
if err := os.MkdirAll(configured, 0o755); err != nil {
|
||||
log.Error("claude-code: failed to create working dir", "dir", configured, "err", err)
|
||||
}
|
||||
return configured
|
||||
}
|
||||
|
||||
// buildClaudeArgs constructs the CLI arguments for claude -p.
|
||||
func buildClaudeArgs(cfg config.ClaudeCodeCfg, req coretypes.CompletionRequest) []string {
|
||||
args := []string{"--print", "--output-format", "json"}
|
||||
|
||||
if req.SystemPrompt != "" {
|
||||
args = append(args, "--system-prompt", req.SystemPrompt)
|
||||
}
|
||||
|
||||
if cfg.DisableTools {
|
||||
args = append(args, "--tools", "")
|
||||
} else {
|
||||
if len(cfg.AllowedTools) > 0 {
|
||||
args = append(args, "--allowedTools")
|
||||
args = append(args, cfg.AllowedTools...)
|
||||
}
|
||||
|
||||
if len(cfg.DisallowedTools) > 0 {
|
||||
args = append(args, "--disallowedTools")
|
||||
args = append(args, cfg.DisallowedTools...)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.PermissionMode != "" {
|
||||
args = append(args, "--permission-mode", cfg.PermissionMode)
|
||||
}
|
||||
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
if cfg.FallbackModel != "" {
|
||||
args = append(args, "--fallback-model", cfg.FallbackModel)
|
||||
}
|
||||
|
||||
if cfg.SessionID != "" {
|
||||
args = append(args, "--session-id", cfg.SessionID)
|
||||
}
|
||||
|
||||
for _, dir := range cfg.AddDirs {
|
||||
args = append(args, "--add-dir", dir)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// flattenMessages converts a conversation history into a single text prompt for stdin.
|
||||
func flattenMessages(msgs []coretypes.Message) string {
|
||||
var b strings.Builder
|
||||
for _, m := range msgs {
|
||||
switch m.Role {
|
||||
case coretypes.RoleUser:
|
||||
fmt.Fprintf(&b, "User: %s\n\n", m.Content)
|
||||
case coretypes.RoleAssistant:
|
||||
fmt.Fprintf(&b, "Assistant: %s\n\n", m.Content)
|
||||
case coretypes.RoleTool:
|
||||
fmt.Fprintf(&b, "Tool result: %s\n\n", m.Content)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// parseClaudeOutput parses the JSON output from `claude -p --output-format json`.
|
||||
func parseClaudeOutput(
|
||||
stdout, stderr []byte,
|
||||
execErr error,
|
||||
elapsed time.Duration,
|
||||
log *slog.Logger,
|
||||
) (coretypes.CompletionResponse, error) {
|
||||
// If the process failed and there's no stdout, report the error
|
||||
if execErr != nil && len(stdout) == 0 {
|
||||
errMsg := string(stderr)
|
||||
if errMsg == "" {
|
||||
errMsg = execErr.Error()
|
||||
}
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code process failed: %s", errMsg)
|
||||
}
|
||||
|
||||
// Parse JSON output
|
||||
var output claudeJSONOutput
|
||||
if err := json.Unmarshal(stdout, &output); err != nil {
|
||||
// Fall back to treating stdout as plain text
|
||||
log.Warn("claude_code_json_parse_failed", "err", err, "stdout_len", len(stdout))
|
||||
return coretypes.CompletionResponse{
|
||||
Content: strings.TrimSpace(string(stdout)),
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if output.IsError {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code error: %s", output.Result)
|
||||
}
|
||||
|
||||
// Extract text from result field or content blocks
|
||||
content := output.Result
|
||||
if content == "" && len(output.ContentBlock) > 0 {
|
||||
var parts []string
|
||||
for _, block := range output.ContentBlock {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
content = strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if execErr != nil {
|
||||
finishReason = "error"
|
||||
}
|
||||
|
||||
log.Info("claude_code_response",
|
||||
"content_len", len(content),
|
||||
"input_tokens", output.Usage.InputTokens,
|
||||
"output_tokens", output.Usage.OutputTokens,
|
||||
"num_turns", output.NumTurns,
|
||||
"cost_usd", output.TotalCost,
|
||||
"elapsed_ms", elapsed.Milliseconds(),
|
||||
)
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: content,
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: output.Usage.InputTokens,
|
||||
OutputTokens: output.Usage.OutputTokens,
|
||||
TotalTokens: output.Usage.InputTokens + output.Usage.OutputTokens,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterEnv returns a copy of environ with the named keys removed.
|
||||
func filterEnv(environ []string, keys ...string) []string {
|
||||
out := make([]string, 0, len(environ))
|
||||
for _, e := range environ {
|
||||
skip := false
|
||||
for _, k := range keys {
|
||||
if strings.HasPrefix(e, k+"=") {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
out = append(out, e)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
var discardLog = slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
// ── buildClaudeArgs ──────────────────────────────────────────────────────
|
||||
|
||||
func TestBuildClaudeArgs_Minimal(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{}
|
||||
req := coretypes.CompletionRequest{}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
// Must always start with --print --output-format json
|
||||
want := []string{"--print", "--output-format", "json"}
|
||||
if len(args) != len(want) {
|
||||
t.Fatalf("got %v, want %v", args, want)
|
||||
}
|
||||
for i := range want {
|
||||
if args[i] != want[i] {
|
||||
t.Errorf("args[%d] = %q, want %q", i, args[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeArgs_AllOptions(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
Model: "sonnet",
|
||||
FallbackModel: "haiku",
|
||||
PermissionMode: "bypassPermissions",
|
||||
AllowedTools: []string{"Bash(git:*)", "Read"},
|
||||
SessionID: "abc-123",
|
||||
AddDirs: []string{"/tmp/extra"},
|
||||
}
|
||||
req := coretypes.CompletionRequest{
|
||||
SystemPrompt: "You are a helpful bot",
|
||||
}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
assertContains(t, args, "--system-prompt", "You are a helpful bot")
|
||||
assertContains(t, args, "--model", "sonnet")
|
||||
assertContains(t, args, "--fallback-model", "haiku")
|
||||
assertContains(t, args, "--permission-mode", "bypassPermissions")
|
||||
assertContains(t, args, "--session-id", "abc-123")
|
||||
assertContains(t, args, "--add-dir", "/tmp/extra")
|
||||
assertContains(t, args, "--allowedTools", "Bash(git:*)")
|
||||
}
|
||||
|
||||
func TestBuildClaudeArgs_DisableTools(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
DisableTools: true,
|
||||
AllowedTools: []string{"Bash"}, // should be ignored
|
||||
}
|
||||
req := coretypes.CompletionRequest{}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
|
||||
assertContains(t, args, "--tools", "")
|
||||
// --allowedTools must NOT appear when disable_tools is set
|
||||
for _, a := range args {
|
||||
if a == "--allowedTools" {
|
||||
t.Error("--allowedTools should not appear when DisableTools=true")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeArgs_DisallowedTools(t *testing.T) {
|
||||
cfg := config.ClaudeCodeCfg{
|
||||
DisallowedTools: []string{"Edit", "Write"},
|
||||
}
|
||||
req := coretypes.CompletionRequest{}
|
||||
|
||||
args := buildClaudeArgs(cfg, req)
|
||||
assertContains(t, args, "--disallowedTools", "Edit")
|
||||
}
|
||||
|
||||
// ── flattenMessages ──────────────────────────────────────────────────────
|
||||
|
||||
func TestFlattenMessages_Empty(t *testing.T) {
|
||||
got := flattenMessages(nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlattenMessages_MultiRole(t *testing.T) {
|
||||
msgs := []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: "hello"},
|
||||
{Role: coretypes.RoleAssistant, Content: "hi there"},
|
||||
{Role: coretypes.RoleTool, Content: `{"time":"12:00"}`},
|
||||
{Role: coretypes.RoleUser, Content: "thanks"},
|
||||
}
|
||||
|
||||
got := flattenMessages(msgs)
|
||||
|
||||
expects := []string{
|
||||
"User: hello",
|
||||
"Assistant: hi there",
|
||||
`Tool result: {"time":"12:00"}`,
|
||||
"User: thanks",
|
||||
}
|
||||
for _, e := range expects {
|
||||
if !contains(got, e) {
|
||||
t.Errorf("missing %q in:\n%s", e, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlattenMessages_SkipsSystem(t *testing.T) {
|
||||
msgs := []coretypes.Message{
|
||||
{Role: coretypes.RoleSystem, Content: "system prompt"},
|
||||
{Role: coretypes.RoleUser, Content: "hello"},
|
||||
}
|
||||
|
||||
got := flattenMessages(msgs)
|
||||
if contains(got, "system prompt") {
|
||||
t.Error("system messages should not appear in flattened output")
|
||||
}
|
||||
if !contains(got, "User: hello") {
|
||||
t.Error("user message missing")
|
||||
}
|
||||
}
|
||||
|
||||
// ── parseClaudeOutput ────────────────────────────────────────────────────
|
||||
|
||||
func TestParseClaudeOutput_Success(t *testing.T) {
|
||||
output := claudeJSONOutput{
|
||||
Type: "result",
|
||||
Subtype: "success",
|
||||
IsError: false,
|
||||
NumTurns: 1,
|
||||
Result: "Hello! I'm Claude.",
|
||||
TotalCost: 0.025,
|
||||
Usage: claudeUsage{InputTokens: 10, OutputTokens: 50},
|
||||
}
|
||||
stdout, _ := json.Marshal(output)
|
||||
|
||||
resp, err := parseClaudeOutput(stdout, nil, nil, 2*time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello! I'm Claude." {
|
||||
t.Errorf("content = %q, want %q", resp.Content, "Hello! I'm Claude.")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("input tokens = %d, want 10", resp.Usage.InputTokens)
|
||||
}
|
||||
if resp.Usage.OutputTokens != 50 {
|
||||
t.Errorf("output tokens = %d, want 50", resp.Usage.OutputTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 60 {
|
||||
t.Errorf("total tokens = %d, want 60", resp.Usage.TotalTokens)
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("finish reason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_ErrorResponse(t *testing.T) {
|
||||
output := claudeJSONOutput{
|
||||
IsError: true,
|
||||
Result: "Invalid API key",
|
||||
}
|
||||
stdout, _ := json.Marshal(output)
|
||||
|
||||
_, err := parseClaudeOutput(stdout, nil, nil, time.Second, discardLog)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for IsError=true")
|
||||
}
|
||||
if !contains(err.Error(), "Invalid API key") {
|
||||
t.Errorf("error = %q, should contain 'Invalid API key'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_ProcessFailedNoStdout(t *testing.T) {
|
||||
_, err := parseClaudeOutput(nil, []byte("unknown option\n"), errors.New("exit 1"), time.Second, discardLog)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when process fails with no stdout")
|
||||
}
|
||||
if !contains(err.Error(), "unknown option") {
|
||||
t.Errorf("error = %q, should contain stderr message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_ProcessFailedNoStderr(t *testing.T) {
|
||||
_, err := parseClaudeOutput(nil, nil, errors.New("exit 1"), time.Second, discardLog)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !contains(err.Error(), "exit 1") {
|
||||
t.Errorf("error = %q, should contain exec error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_FallbackPlainText(t *testing.T) {
|
||||
// Non-JSON stdout should be treated as plain text
|
||||
resp, err := parseClaudeOutput([]byte("just plain text\n"), nil, nil, time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "just plain text" {
|
||||
t.Errorf("content = %q, want %q", resp.Content, "just plain text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_ContentBlocks(t *testing.T) {
|
||||
output := claudeJSONOutput{
|
||||
Result: "", // empty result, content in blocks
|
||||
ContentBlock: []claudeContent{
|
||||
{Type: "text", Text: "First part."},
|
||||
{Type: "text", Text: "Second part."},
|
||||
},
|
||||
Usage: claudeUsage{InputTokens: 5, OutputTokens: 20},
|
||||
}
|
||||
stdout, _ := json.Marshal(output)
|
||||
|
||||
resp, err := parseClaudeOutput(stdout, nil, nil, time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "First part.\nSecond part." {
|
||||
t.Errorf("content = %q, want joined blocks", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeOutput_ExecErrWithStdout(t *testing.T) {
|
||||
// Process failed but produced valid JSON output — should parse and set finish_reason=error
|
||||
output := claudeJSONOutput{
|
||||
Result: "partial answer",
|
||||
Usage: claudeUsage{InputTokens: 3, OutputTokens: 10},
|
||||
}
|
||||
stdout, _ := json.Marshal(output)
|
||||
|
||||
resp, err := parseClaudeOutput(stdout, nil, errors.New("timeout"), time.Second, discardLog)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.FinishReason != "error" {
|
||||
t.Errorf("finish reason = %q, want %q", resp.FinishReason, "error")
|
||||
}
|
||||
if resp.Content != "partial answer" {
|
||||
t.Errorf("content = %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// ── filterEnv ────────────────────────────────────────────────────────────
|
||||
|
||||
func TestFilterEnv_RemovesSingleKey(t *testing.T) {
|
||||
env := []string{
|
||||
"HOME=/home/user",
|
||||
"ANTHROPIC_API_KEY=sk-secret",
|
||||
"PATH=/usr/bin",
|
||||
}
|
||||
|
||||
got := filterEnv(env, "ANTHROPIC_API_KEY")
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d: %v", len(got), got)
|
||||
}
|
||||
for _, e := range got {
|
||||
if contains(e, "ANTHROPIC_API_KEY") {
|
||||
t.Errorf("ANTHROPIC_API_KEY should have been removed: %v", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterEnv_RemovesMultipleKeys(t *testing.T) {
|
||||
env := []string{
|
||||
"HOME=/home/user",
|
||||
"ANTHROPIC_API_KEY=sk-secret",
|
||||
"OPENAI_API_KEY=sk-openai",
|
||||
"PATH=/usr/bin",
|
||||
}
|
||||
|
||||
got := filterEnv(env, "ANTHROPIC_API_KEY", "OPENAI_API_KEY")
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d: %v", len(got), got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterEnv_NoMatchKeepsAll(t *testing.T) {
|
||||
env := []string{"HOME=/home/user", "PATH=/usr/bin"}
|
||||
|
||||
got := filterEnv(env, "NONEXISTENT")
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterEnv_PrefixSafety(t *testing.T) {
|
||||
// ANTHROPIC_API_KEY_V2 should NOT be removed when filtering ANTHROPIC_API_KEY
|
||||
env := []string{
|
||||
"ANTHROPIC_API_KEY=secret",
|
||||
"ANTHROPIC_API_KEY_V2=other",
|
||||
}
|
||||
|
||||
got := filterEnv(env, "ANTHROPIC_API_KEY")
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1, got %d: %v", len(got), got)
|
||||
}
|
||||
if got[0] != "ANTHROPIC_API_KEY_V2=other" {
|
||||
t.Errorf("wrong entry kept: %q", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ── resolveWorkDir ──────────────────────────────────────────────────────
|
||||
|
||||
func TestResolveWorkDir_EmptyCreatesTempDir(t *testing.T) {
|
||||
dir := resolveWorkDir("", discardLog)
|
||||
if dir == "" {
|
||||
t.Fatal("expected a temp directory, got empty string")
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
if !strings.Contains(dir, "claude-agent-") {
|
||||
t.Errorf("temp dir %q should contain 'claude-agent-' prefix", dir)
|
||||
}
|
||||
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("temp dir should exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Error("temp dir should be a directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWorkDir_ConfiguredValueUsed(t *testing.T) {
|
||||
want := filepath.Join(t.TempDir(), "custom-workdir")
|
||||
|
||||
got := resolveWorkDir(want, discardLog)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
info, err := os.Stat(got)
|
||||
if err != nil {
|
||||
t.Fatalf("configured dir should be created: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Error("configured dir should be a directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWorkDir_ConfiguredAlreadyExists(t *testing.T) {
|
||||
want := t.TempDir() // already exists
|
||||
|
||||
got := resolveWorkDir(want, discardLog)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// ── helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && stringContains(s, substr)))
|
||||
}
|
||||
|
||||
func stringContains(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func assertContains(t *testing.T, args []string, flag, value string) {
|
||||
t.Helper()
|
||||
for i, a := range args {
|
||||
if a == flag && i+1 < len(args) && args[i+1] == value {
|
||||
return
|
||||
}
|
||||
// For --tools "" where value is empty string
|
||||
if a == flag && value == "" && i+1 < len(args) && args[i+1] == "" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("args %v missing %s %q", args, flag, value)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// FromConfig builds a CompleteFunc from an LLMProviderCfg.
|
||||
func FromConfig(cfg config.LLMProviderCfg, log *slog.Logger) (coretypes.CompleteFunc, error) {
|
||||
log.Info("llm_provider_init", "provider", cfg.Provider, "model", cfg.Model)
|
||||
switch cfg.Provider {
|
||||
case "anthropic":
|
||||
return NewAnthropicComplete(cfg.APIKeyEnv, cfg.BaseURL, log), nil
|
||||
case "openai":
|
||||
return NewOpenAIComplete(cfg.APIKeyEnv, cfg.BaseURL, log), nil
|
||||
case "ollama":
|
||||
base := cfg.BaseURL
|
||||
if base == "" {
|
||||
base = "http://localhost:11434/v1"
|
||||
}
|
||||
return NewOpenAIComplete("OLLAMA_API_KEY", base, log), nil
|
||||
case "claude-code":
|
||||
return NewClaudeCodeComplete(cfg.ClaudeCode, log), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown LLM provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallback wraps primary with a fallback CompleteFunc.
|
||||
// If primary returns an error, fallback is tried with the fallback config's model.
|
||||
func WithFallback(primary, fallback coretypes.CompleteFunc, fallbackCfg config.LLMProviderCfg, log *slog.Logger) coretypes.CompleteFunc {
|
||||
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
|
||||
resp, err := primary(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("llm_fallback_triggered", "primary_err", err)
|
||||
// Override request fields with fallback config values
|
||||
if fallbackCfg.Model != "" {
|
||||
req.Model = fallbackCfg.Model
|
||||
}
|
||||
if fallbackCfg.MaxTokens > 0 {
|
||||
req.MaxTokens = fallbackCfg.MaxTokens
|
||||
}
|
||||
return fallback(ctx, req)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/shell/logger"
|
||||
)
|
||||
|
||||
// NewOpenAIComplete returns a CompleteFunc backed by the OpenAI-compatible API.
|
||||
// Works with OpenAI, Ollama, vLLM, LMStudio — just change baseURL.
|
||||
func NewOpenAIComplete(apiKeyEnv, baseURL string, log *slog.Logger) coretypes.CompleteFunc {
|
||||
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
|
||||
apiKey := os.Getenv(apiKeyEnv)
|
||||
if apiKey == "" {
|
||||
apiKey = "ollama" // Ollama doesn't require a real key
|
||||
}
|
||||
|
||||
cfg := openai.DefaultConfig(apiKey)
|
||||
if baseURL != "" {
|
||||
cfg.BaseURL = baseURL
|
||||
}
|
||||
client := openai.NewClientWithConfig(cfg)
|
||||
|
||||
msgs := make([]openai.ChatCompletionMessage, 0, len(req.Messages)+1)
|
||||
if req.SystemPrompt != "" {
|
||||
msgs = append(msgs, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: req.SystemPrompt,
|
||||
})
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
msgs = append(msgs, toOpenAIMessage(m))
|
||||
}
|
||||
|
||||
openReq := openai.ChatCompletionRequest{
|
||||
Model: req.Model,
|
||||
Messages: msgs,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Temperature: float32(req.Temperature),
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if len(req.Tools) > 0 {
|
||||
openReq.Tools = toOpenAITools(req.Tools)
|
||||
}
|
||||
|
||||
log.Info("llm_request",
|
||||
"provider", "openai",
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.CreateChatCompletion(ctx, openReq)
|
||||
if err != nil {
|
||||
ms := time.Since(start).Milliseconds()
|
||||
log.Error("llm_error", "provider", "openai", logger.FieldDurationMS, ms, "err", err)
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("openai completion: %w", err)
|
||||
}
|
||||
ms := time.Since(start).Milliseconds()
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
log.Error("llm_error", "provider", "openai", logger.FieldDurationMS, ms, "err", "empty choices")
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("openai: empty choices")
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
var toolCalls []coretypes.ToolCall
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
toolCalls = append(toolCalls, coretypes.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
log.Info("llm_response",
|
||||
"provider", "openai",
|
||||
"model", req.Model,
|
||||
logger.FieldDurationMS, ms,
|
||||
logger.FieldTokensUsed, resp.Usage.TotalTokens,
|
||||
"input_tokens", resp.Usage.PromptTokens,
|
||||
"output_tokens", resp.Usage.CompletionTokens,
|
||||
"tool_calls", len(toolCalls),
|
||||
"finish_reason", string(choice.FinishReason),
|
||||
)
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: string(choice.FinishReason),
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: resp.Usage.PromptTokens,
|
||||
OutputTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// toOpenAIMessage converts a core Message to an OpenAI ChatCompletionMessage.
|
||||
func toOpenAIMessage(m coretypes.Message) openai.ChatCompletionMessage {
|
||||
role := openai.ChatMessageRoleUser
|
||||
switch m.Role {
|
||||
case coretypes.RoleAssistant:
|
||||
role = openai.ChatMessageRoleAssistant
|
||||
case coretypes.RoleSystem:
|
||||
role = openai.ChatMessageRoleSystem
|
||||
case coretypes.RoleTool:
|
||||
role = openai.ChatMessageRoleTool
|
||||
}
|
||||
|
||||
msg := openai.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
ToolCallID: m.ToolCallID,
|
||||
}
|
||||
|
||||
// Assistant messages with tool calls
|
||||
if m.Role == coretypes.RoleAssistant && len(m.ToolCalls) > 0 {
|
||||
msg.ToolCalls = make([]openai.ToolCall, len(m.ToolCalls))
|
||||
for i, tc := range m.ToolCalls {
|
||||
msg.ToolCalls[i] = openai.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: openai.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// toOpenAITools converts core ToolSpecs to OpenAI Tool format.
|
||||
func toOpenAITools(specs []coretypes.ToolSpec) []openai.Tool {
|
||||
tools := make([]openai.Tool, len(specs))
|
||||
for i, s := range specs {
|
||||
tools[i] = openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Parameters: json.RawMessage(marshalSchema(s.InputSchema)),
|
||||
},
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// marshalSchema marshals a JSON schema map to bytes. Falls back to empty object.
|
||||
func marshalSchema(schema map[string]any) []byte {
|
||||
b, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
return []byte("{}")
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// runCleanup periodically removes log files older than maxAgeDays for the
|
||||
// given agent. It runs until ctx is cancelled.
|
||||
func runCleanup(ctx context.Context, baseDir, agentID string, maxAgeDays int, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run once immediately at startup.
|
||||
cleanOldLogs(baseDir, agentID, maxAgeDays)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOldLogs(baseDir, agentID, maxAgeDays)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanOldLogs removes .jsonl and .jsonl.gz files older than maxAgeDays.
|
||||
func cleanOldLogs(baseDir, agentID string, maxAgeDays int) {
|
||||
dir := filepath.Join(baseDir, agentID)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -maxAgeDays)
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !isLogFile(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
date := parseDateFromFilename(name)
|
||||
if date.IsZero() {
|
||||
continue
|
||||
}
|
||||
if date.Before(cutoff) {
|
||||
os.Remove(filepath.Join(dir, name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isLogFile returns true for .jsonl and .jsonl.gz files.
|
||||
func isLogFile(name string) bool {
|
||||
return strings.HasSuffix(name, ".jsonl") || strings.HasSuffix(name, ".jsonl.gz")
|
||||
}
|
||||
|
||||
// parseDateFromFilename extracts YYYY-MM-DD from filenames like:
|
||||
//
|
||||
// 2026-03-06.jsonl
|
||||
// 2026-03-06.1.jsonl
|
||||
// 2026-03-06.jsonl.gz
|
||||
func parseDateFromFilename(name string) time.Time {
|
||||
// Strip extensions.
|
||||
base := strings.TrimSuffix(name, ".gz")
|
||||
base = strings.TrimSuffix(base, ".jsonl")
|
||||
|
||||
// Remove numeric suffix (e.g., ".1" from "2026-03-06.1").
|
||||
if idx := strings.LastIndex(base, "."); idx >= 0 {
|
||||
candidate := base[:idx]
|
||||
if t, err := time.Parse("2006-01-02", candidate); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
t, _ := time.Parse("2006-01-02", base)
|
||||
return t
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCleanOldLogs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
agentDir := filepath.Join(dir, "bot1")
|
||||
os.MkdirAll(agentDir, 0o755)
|
||||
|
||||
// Create files: 10 days ago, 5 days ago, today.
|
||||
files := []string{
|
||||
"2026-02-24.jsonl",
|
||||
"2026-02-24.jsonl.gz",
|
||||
"2026-03-01.jsonl",
|
||||
"2026-03-06.jsonl",
|
||||
}
|
||||
for _, f := range files {
|
||||
os.WriteFile(filepath.Join(agentDir, f), []byte("{}"), 0o644)
|
||||
}
|
||||
|
||||
// Retain 7 days → should remove 2026-02-24 files.
|
||||
cleanOldLogs(dir, "bot1", 7)
|
||||
|
||||
remaining, _ := os.ReadDir(agentDir)
|
||||
names := make(map[string]bool)
|
||||
for _, e := range remaining {
|
||||
names[e.Name()] = true
|
||||
}
|
||||
|
||||
if names["2026-02-24.jsonl"] {
|
||||
t.Error("2026-02-24.jsonl should have been removed")
|
||||
}
|
||||
if names["2026-02-24.jsonl.gz"] {
|
||||
t.Error("2026-02-24.jsonl.gz should have been removed")
|
||||
}
|
||||
if !names["2026-03-01.jsonl"] {
|
||||
t.Error("2026-03-01.jsonl should still exist")
|
||||
}
|
||||
if !names["2026-03-06.jsonl"] {
|
||||
t.Error("2026-03-06.jsonl should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDateFromFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want string
|
||||
}{
|
||||
{"2026-03-06.jsonl", "2026-03-06"},
|
||||
{"2026-03-06.1.jsonl", "2026-03-06"},
|
||||
{"2026-03-06.jsonl.gz", "2026-03-06"},
|
||||
{"2026-03-06.2.jsonl.gz", "2026-03-06"},
|
||||
{"invalid.jsonl", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
d := parseDateFromFilename(tt.name)
|
||||
got := ""
|
||||
if !d.IsZero() {
|
||||
got = d.Format("2006-01-02")
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseDateFromFilename(%q) = %q, want %q", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanOldLogs_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Should not panic on non-existent agent dir.
|
||||
cleanOldLogs(dir, "nonexistent", 7)
|
||||
}
|
||||
|
||||
func TestIsLogFile(t *testing.T) {
|
||||
if !isLogFile("2026-03-06.jsonl") {
|
||||
t.Error("should match .jsonl")
|
||||
}
|
||||
if !isLogFile("2026-03-06.jsonl.gz") {
|
||||
t.Error("should match .jsonl.gz")
|
||||
}
|
||||
if isLogFile("readme.txt") {
|
||||
t.Error("should not match .txt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCleanup_Cancellation(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.MkdirAll(filepath.Join(dir, "bot1"), 0o755)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
runCleanup(ctx, dir, "bot1", 7, 50*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("cleanup goroutine did not exit after cancel")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
// Package logger provides structured JSONL logging for agents with daily
|
||||
// file rotation, size-based splitting, automatic cleanup, and query helpers.
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Standard field names for structured logging across all agents.
|
||||
const (
|
||||
FieldAgentID = "agent_id"
|
||||
FieldTraceID = "trace_id"
|
||||
FieldAction = "action"
|
||||
FieldReason = "reason"
|
||||
FieldDurationMS = "duration_ms"
|
||||
FieldTokensUsed = "tokens_used"
|
||||
FieldResult = "result"
|
||||
FieldErrorType = "error_type"
|
||||
FieldComponent = "component"
|
||||
)
|
||||
|
||||
// traceKey is the context key for trace IDs.
|
||||
type traceKey struct{}
|
||||
|
||||
// WithTraceID returns a new context carrying the given trace ID.
|
||||
func WithTraceID(ctx context.Context, id string) context.Context {
|
||||
return context.WithValue(ctx, traceKey{}, id)
|
||||
}
|
||||
|
||||
// TraceIDFromCtx extracts the trace ID from ctx, or "" if absent.
|
||||
func TraceIDFromCtx(ctx context.Context) string {
|
||||
if v, ok := ctx.Value(traceKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// LoggerConfig configures a per-agent logger.
|
||||
type LoggerConfig struct {
|
||||
BaseDir string // root log directory (default: "logs"); empty → stdout only
|
||||
AgentID string // agent identifier (required)
|
||||
MaxSizeMB int64 // max file size before rotation (default: 50)
|
||||
MaxAgeDays int // retention in days (default: 7)
|
||||
Compress bool // gzip rotated files (default: true)
|
||||
CleanupInterval time.Duration // cleanup ticker interval (default: 24h)
|
||||
Level slog.Level // minimum log level (default: INFO)
|
||||
}
|
||||
|
||||
func (c *LoggerConfig) defaults() {
|
||||
if c.BaseDir == "" {
|
||||
c.BaseDir = "logs"
|
||||
}
|
||||
if c.MaxSizeMB <= 0 {
|
||||
c.MaxSizeMB = 50
|
||||
}
|
||||
if c.MaxAgeDays <= 0 {
|
||||
c.MaxAgeDays = 7
|
||||
}
|
||||
if c.CleanupInterval <= 0 {
|
||||
c.CleanupInterval = 24 * time.Hour
|
||||
}
|
||||
}
|
||||
|
||||
// NewAgentLogger creates a structured JSON logger that writes to daily-rotated
|
||||
// JSONL files under BaseDir/<AgentID>/. It returns:
|
||||
// - a *slog.Logger pre-enriched with agent_id
|
||||
// - a cleanup func to call on shutdown (closes files, stops cleanup goroutine)
|
||||
// - an error if the log directory cannot be created
|
||||
//
|
||||
// If BaseDir is literally "stdout", the logger writes to os.Stdout with no
|
||||
// file rotation or cleanup.
|
||||
func NewAgentLogger(cfg LoggerConfig) (*slog.Logger, func(), error) {
|
||||
if cfg.BaseDir == "stdout" {
|
||||
h := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: cfg.Level})
|
||||
l := slog.New(h).With(FieldAgentID, cfg.AgentID)
|
||||
return l, func() {}, nil
|
||||
}
|
||||
|
||||
cfg.defaults()
|
||||
|
||||
w, err := NewDailyRotatingWriter(cfg.BaseDir, cfg.AgentID, cfg.MaxSizeMB, cfg.Compress)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
h := slog.NewJSONHandler(w, &slog.HandlerOptions{Level: cfg.Level})
|
||||
l := slog.New(h).With(FieldAgentID, cfg.AgentID)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go runCleanup(ctx, cfg.BaseDir, cfg.AgentID, cfg.MaxAgeDays, cfg.CleanupInterval)
|
||||
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
w.Close()
|
||||
}
|
||||
|
||||
return l, cleanup, nil
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewAgentLogger_WritesJSONL(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
l, cleanup, err := NewAgentLogger(LoggerConfig{
|
||||
BaseDir: dir,
|
||||
AgentID: "test-bot",
|
||||
Level: slog.LevelDebug,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
l.Info("hello world", FieldAction, "greet", FieldReason, "testing")
|
||||
|
||||
// Force flush by closing.
|
||||
cleanup()
|
||||
|
||||
files, _ := os.ReadDir(filepath.Join(dir, "test-bot"))
|
||||
if len(files) == 0 {
|
||||
t.Fatal("expected at least one log file")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, "test-bot", files[0].Name()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
t.Fatalf("log line is not valid JSON: %s", data)
|
||||
}
|
||||
if m["msg"] != "hello world" {
|
||||
t.Errorf("msg = %v, want hello world", m["msg"])
|
||||
}
|
||||
if m[FieldAgentID] != "test-bot" {
|
||||
t.Errorf("agent_id = %v, want test-bot", m[FieldAgentID])
|
||||
}
|
||||
if m[FieldAction] != "greet" {
|
||||
t.Errorf("action = %v, want greet", m[FieldAction])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentLogger_Stdout(t *testing.T) {
|
||||
l, cleanup, err := NewAgentLogger(LoggerConfig{
|
||||
BaseDir: "stdout",
|
||||
AgentID: "dev-bot",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
// Just verify it doesn't panic.
|
||||
l.Info("stdout test")
|
||||
}
|
||||
|
||||
func TestTraceIDContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if got := TraceIDFromCtx(ctx); got != "" {
|
||||
t.Errorf("empty ctx should return empty trace, got %q", got)
|
||||
}
|
||||
|
||||
ctx = WithTraceID(ctx, "abc-123")
|
||||
if got := TraceIDFromCtx(ctx); got != "abc-123" {
|
||||
t.Errorf("trace = %q, want abc-123", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReadLogs returns all log entries for agentID between from and to (inclusive).
|
||||
func ReadLogs(baseDir, agentID string, from, to time.Time) ([]json.RawMessage, error) {
|
||||
var result []json.RawMessage
|
||||
for d := from; !d.After(to); d = d.AddDate(0, 0, 1) {
|
||||
entries, err := ReadDayLogs(baseDir, agentID, d)
|
||||
if err != nil {
|
||||
continue // skip missing days
|
||||
}
|
||||
result = append(result, entries...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ReadDayLogs returns all log entries for a specific day.
|
||||
func ReadDayLogs(baseDir, agentID string, date time.Time) ([]json.RawMessage, error) {
|
||||
dir := filepath.Join(baseDir, agentID)
|
||||
prefix := date.Format("2006-01-02")
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
var result []json.RawMessage
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if !strings.HasPrefix(name, prefix) || !isLogFile(name) {
|
||||
continue
|
||||
}
|
||||
lines, err := readLogFile(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, lines...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SearchLogs returns log entries where field equals value, within the date range.
|
||||
func SearchLogs(baseDir, agentID string, field, value string, from, to time.Time) ([]json.RawMessage, error) {
|
||||
all, err := ReadLogs(baseDir, agentID, from, to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var matched []json.RawMessage
|
||||
for _, raw := range all {
|
||||
var m map[string]any
|
||||
if json.Unmarshal(raw, &m) != nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := m[field]; ok && fmt.Sprint(v) == value {
|
||||
matched = append(matched, raw)
|
||||
}
|
||||
}
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// ListAgents returns the agent IDs that have log directories under baseDir.
|
||||
func ListAgents(baseDir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(baseDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read base dir %s: %w", baseDir, err)
|
||||
}
|
||||
var ids []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
ids = append(ids, e.Name())
|
||||
}
|
||||
}
|
||||
sort.Strings(ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// ListDates returns the dates for which logs exist for the given agent.
|
||||
func ListDates(baseDir, agentID string) ([]time.Time, error) {
|
||||
dir := filepath.Join(baseDir, agentID)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
var dates []time.Time
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !isLogFile(e.Name()) {
|
||||
continue
|
||||
}
|
||||
d := parseDateFromFilename(e.Name())
|
||||
if d.IsZero() {
|
||||
continue
|
||||
}
|
||||
key := d.Format("2006-01-02")
|
||||
if !seen[key] {
|
||||
seen[key] = true
|
||||
dates = append(dates, d)
|
||||
}
|
||||
}
|
||||
sort.Slice(dates, func(i, j int) bool { return dates[i].Before(dates[j]) })
|
||||
return dates, nil
|
||||
}
|
||||
|
||||
// readLogFile reads all JSONL lines from a file (.jsonl or .jsonl.gz).
|
||||
func readLogFile(path string) ([]json.RawMessage, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var r io.Reader = f
|
||||
if strings.HasSuffix(path, ".gz") {
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gz.Close()
|
||||
r = gz
|
||||
}
|
||||
|
||||
var lines []json.RawMessage
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
cp := make([]byte, len(line))
|
||||
copy(cp, line)
|
||||
lines = append(lines, json.RawMessage(cp))
|
||||
}
|
||||
return lines, scanner.Err()
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setupQueryDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
|
||||
bot1 := filepath.Join(dir, "bot1")
|
||||
bot2 := filepath.Join(dir, "bot2")
|
||||
os.MkdirAll(bot1, 0o755)
|
||||
os.MkdirAll(bot2, 0o755)
|
||||
|
||||
lines := []string{
|
||||
`{"time":"2026-03-05T10:00:00Z","level":"INFO","msg":"hello","action":"greet"}`,
|
||||
`{"time":"2026-03-05T11:00:00Z","level":"ERROR","msg":"oops","action":"fail"}`,
|
||||
}
|
||||
os.WriteFile(filepath.Join(bot1, "2026-03-05.jsonl"),
|
||||
[]byte(lines[0]+"\n"+lines[1]+"\n"), 0o644)
|
||||
os.WriteFile(filepath.Join(bot1, "2026-03-06.jsonl"),
|
||||
[]byte(`{"time":"2026-03-06T09:00:00Z","level":"INFO","msg":"day2"}`+"\n"), 0o644)
|
||||
|
||||
os.WriteFile(filepath.Join(bot2, "2026-03-06.jsonl"),
|
||||
[]byte(`{"time":"2026-03-06T08:00:00Z","level":"DEBUG","msg":"bot2 log"}`+"\n"), 0o644)
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestListAgents(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
agents, err := ListAgents(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(agents) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
if agents[0] != "bot1" || agents[1] != "bot2" {
|
||||
t.Errorf("agents = %v, want [bot1 bot2]", agents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDates(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
dates, err := ListDates(dir, "bot1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(dates) != 2 {
|
||||
t.Fatalf("expected 2 dates, got %d", len(dates))
|
||||
}
|
||||
if dates[0].Format("2006-01-02") != "2026-03-05" {
|
||||
t.Errorf("first date = %v, want 2026-03-05", dates[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDayLogs(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
day := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
|
||||
entries, err := ReadDayLogs(dir, "bot1", day)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLogs(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
from := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
|
||||
to := time.Date(2026, 3, 6, 0, 0, 0, 0, time.UTC)
|
||||
entries, err := ReadLogs(dir, "bot1", from, to)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) != 3 {
|
||||
t.Fatalf("expected 3 entries across 2 days, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchLogs(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
from := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
|
||||
to := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
results, err := SearchLogs(dir, "bot1", "action", "fail", from, to)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 match, got %d", len(results))
|
||||
}
|
||||
var m map[string]any
|
||||
json.Unmarshal(results[0], &m)
|
||||
if m["msg"] != "oops" {
|
||||
t.Errorf("msg = %v, want oops", m["msg"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchLogs_NoMatch(t *testing.T) {
|
||||
dir := setupQueryDir(t)
|
||||
from := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
|
||||
to := time.Date(2026, 3, 6, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
results, err := SearchLogs(dir, "bot1", "action", "nonexistent", from, to)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 matches, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgents_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
agents, err := ListAgents(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(agents) != 0 {
|
||||
t.Errorf("expected 0 agents, got %d", len(agents))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DailyRotatingWriter is an io.Writer that rotates log files daily and by
|
||||
// size. Files are named <baseDir>/<agentID>/YYYY-MM-DD.jsonl with optional
|
||||
// numeric suffixes for size-based splits within the same day.
|
||||
type DailyRotatingWriter struct {
|
||||
baseDir string
|
||||
agentID string
|
||||
maxSize int64 // bytes
|
||||
compress bool
|
||||
nowFunc func() time.Time // for testing; defaults to time.Now().UTC
|
||||
dir string // resolved agent log directory
|
||||
|
||||
mu sync.Mutex
|
||||
current *os.File
|
||||
written int64
|
||||
currentDay string
|
||||
suffix int
|
||||
}
|
||||
|
||||
// NewDailyRotatingWriter creates a writer that stores logs under
|
||||
// baseDir/agentID/. It creates the directory if needed and opens the first
|
||||
// log file for today.
|
||||
func NewDailyRotatingWriter(baseDir, agentID string, maxSizeMB int64, compress bool) (*DailyRotatingWriter, error) {
|
||||
dir := filepath.Join(baseDir, agentID)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create log dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
w := &DailyRotatingWriter{
|
||||
baseDir: baseDir,
|
||||
agentID: agentID,
|
||||
maxSize: maxSizeMB * 1024 * 1024,
|
||||
compress: compress,
|
||||
nowFunc: func() time.Time { return time.Now().UTC() },
|
||||
dir: dir,
|
||||
}
|
||||
|
||||
if err := w.openFile(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer with daily and size-based rotation.
|
||||
func (w *DailyRotatingWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
today := w.nowFunc().Format("2006-01-02")
|
||||
|
||||
// Day changed → rotate to new day file.
|
||||
if today != w.currentDay {
|
||||
if err := w.rotate(today, 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Size exceeded → split within same day.
|
||||
if w.written+int64(len(p)) > w.maxSize && w.written > 0 {
|
||||
w.suffix++
|
||||
if err := w.rotate(today, w.suffix); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
n, err := w.current.Write(p)
|
||||
w.written += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close closes the current log file.
|
||||
func (w *DailyRotatingWriter) Close() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.current != nil {
|
||||
return w.current.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rotate closes the current file (optionally compressing it) and opens a new one.
|
||||
func (w *DailyRotatingWriter) rotate(day string, suffix int) error {
|
||||
prev := w.current
|
||||
prevPath := ""
|
||||
if prev != nil {
|
||||
prevPath = prev.Name()
|
||||
prev.Close()
|
||||
}
|
||||
|
||||
// Compress the previous file in the background if enabled and it's from a
|
||||
// different day (we don't compress intra-day splits until day rotates).
|
||||
if w.compress && prevPath != "" && day != w.currentDay {
|
||||
go compressFile(prevPath)
|
||||
}
|
||||
|
||||
w.currentDay = day
|
||||
w.suffix = suffix
|
||||
w.written = 0
|
||||
|
||||
return w.openFile()
|
||||
}
|
||||
|
||||
// openFile opens (or creates) the log file for the current day/suffix.
|
||||
func (w *DailyRotatingWriter) openFile() error {
|
||||
w.currentDay = w.nowFunc().Format("2006-01-02")
|
||||
name := w.filename(w.currentDay, w.suffix)
|
||||
|
||||
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Track how much has already been written (append mode).
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
w.current = f
|
||||
w.written = info.Size()
|
||||
return nil
|
||||
}
|
||||
|
||||
// filename returns the full path for a given day and suffix.
|
||||
func (w *DailyRotatingWriter) filename(day string, suffix int) string {
|
||||
if suffix == 0 {
|
||||
return filepath.Join(w.dir, day+".jsonl")
|
||||
}
|
||||
return filepath.Join(w.dir, fmt.Sprintf("%s.%d.jsonl", day, suffix))
|
||||
}
|
||||
|
||||
// compressFile gzips src to src.gz and removes the original.
|
||||
func compressFile(src string) {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := os.Create(src + ".gz")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
gz := gzip.NewWriter(out)
|
||||
if _, err := io.Copy(gz, in); err != nil {
|
||||
gz.Close()
|
||||
out.Close()
|
||||
os.Remove(src + ".gz")
|
||||
return
|
||||
}
|
||||
gz.Close()
|
||||
out.Close()
|
||||
in.Close()
|
||||
os.Remove(src)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDailyRotatingWriter_DayRotation(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
w, err := NewDailyRotatingWriter(dir, "bot1", 50, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
day1 := time.Date(2026, 3, 5, 12, 0, 0, 0, time.UTC)
|
||||
day2 := time.Date(2026, 3, 6, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
w.nowFunc = func() time.Time { return day1 }
|
||||
// Force re-open with correct day.
|
||||
w.current.Close()
|
||||
w.currentDay = ""
|
||||
w.openFile()
|
||||
|
||||
w.Write([]byte(`{"msg":"day1"}`))
|
||||
|
||||
w.nowFunc = func() time.Time { return day2 }
|
||||
w.Write([]byte(`{"msg":"day2"}`))
|
||||
w.Close()
|
||||
|
||||
agentDir := filepath.Join(dir, "bot1")
|
||||
entries, _ := os.ReadDir(agentDir)
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, e := range entries {
|
||||
names[e.Name()] = true
|
||||
}
|
||||
|
||||
if !names["2026-03-05.jsonl"] && !names["2026-03-05.jsonl.gz"] {
|
||||
t.Error("expected 2026-03-05.jsonl or .gz")
|
||||
}
|
||||
if !names["2026-03-06.jsonl"] {
|
||||
t.Error("expected 2026-03-06.jsonl")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDailyRotatingWriter_SizeRotation(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// 1 byte max to force rotation on every write.
|
||||
w, err := NewDailyRotatingWriter(dir, "bot2", 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Override maxSize to a tiny value (can't use 0 MB).
|
||||
w.maxSize = 10
|
||||
|
||||
now := time.Date(2026, 3, 6, 10, 0, 0, 0, time.UTC)
|
||||
w.nowFunc = func() time.Time { return now }
|
||||
w.current.Close()
|
||||
w.currentDay = ""
|
||||
w.openFile()
|
||||
|
||||
w.Write([]byte(`{"line":1}` + "\n"))
|
||||
w.Write([]byte(`{"line":2}` + "\n"))
|
||||
w.Write([]byte(`{"line":3}` + "\n"))
|
||||
w.Close()
|
||||
|
||||
entries, _ := os.ReadDir(filepath.Join(dir, "bot2"))
|
||||
if len(entries) < 2 {
|
||||
t.Errorf("expected multiple files from size rotation, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDailyRotatingWriter_Concurrent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
w, err := NewDailyRotatingWriter(dir, "bot3", 50, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w.Write([]byte(`{"concurrent":true}` + "\n"))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
entries, _ := os.ReadDir(filepath.Join(dir, "bot3"))
|
||||
if len(entries) == 0 {
|
||||
t.Error("expected at least one log file")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
// Package mcp provides MCP client and server implementations.
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// Client wraps an MCP client (stdio or SSE) and exposes discovered tools.
|
||||
type Client struct {
|
||||
name string
|
||||
transport string // "stdio" | "sse"
|
||||
mcpClient *client.Client
|
||||
tools []mcp.Tool
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewStdioClient creates an MCP client that connects to a stdio-based MCP server.
|
||||
func NewStdioClient(ctx context.Context, name, command string, args []string, env map[string]string, logger *slog.Logger) (*Client, error) {
|
||||
logger.Info("creating stdio MCP client", "name", name, "command", command, "args", args)
|
||||
|
||||
// Prepare environment
|
||||
envSlice := os.Environ()
|
||||
for k, v := range env {
|
||||
envSlice = append(envSlice, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
|
||||
// Create stdio client
|
||||
mcpClient, err := client.NewStdioMCPClient(command, envSlice, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdio client: %w", err)
|
||||
}
|
||||
|
||||
// Initialize
|
||||
initReq := mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "agents-mcp-client",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
Capabilities: mcp.ClientCapabilities{},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = mcpClient.Initialize(ctx, initReq)
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
|
||||
}
|
||||
|
||||
// Discover tools
|
||||
toolsResp, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return nil, fmt.Errorf("failed to list tools: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("discovered MCP tools", "name", name, "count", len(toolsResp.Tools))
|
||||
|
||||
return &Client{
|
||||
name: name,
|
||||
transport: "stdio",
|
||||
mcpClient: mcpClient,
|
||||
tools: toolsResp.Tools,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewSSEClient creates an MCP client that connects to an SSE/HTTP-based MCP server.
|
||||
func NewSSEClient(ctx context.Context, name, url string, headers map[string]string, logger *slog.Logger) (*Client, error) {
|
||||
logger.Info("creating SSE MCP client", "name", name, "url", url)
|
||||
|
||||
// Create SSE client (no custom headers support in basic API, would need transport options)
|
||||
mcpClient, err := client.NewSSEMCPClient(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSE client: %w", err)
|
||||
}
|
||||
|
||||
// Initialize
|
||||
initReq := mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "agents-mcp-client",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
Capabilities: mcp.ClientCapabilities{},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = mcpClient.Initialize(ctx, initReq)
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
|
||||
}
|
||||
|
||||
// Discover tools
|
||||
toolsResp, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return nil, fmt.Errorf("failed to list tools: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("discovered MCP tools", "name", name, "count", len(toolsResp.Tools))
|
||||
|
||||
return &Client{
|
||||
name: name,
|
||||
transport: "sse",
|
||||
mcpClient: mcpClient,
|
||||
tools: toolsResp.Tools,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tools returns the discovered MCP tools.
|
||||
func (c *Client) Tools() []mcp.Tool {
|
||||
return c.tools
|
||||
}
|
||||
|
||||
// Name returns the client name.
|
||||
func (c *Client) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// CallTool invokes an MCP tool by name with the given arguments.
|
||||
func (c *Client) CallTool(ctx context.Context, name string, args map[string]any, timeout time.Duration) (*mcp.CallToolResult, error) {
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := c.mcpClient.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tool call failed: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close closes the MCP client connection.
|
||||
func (c *Client) Close() error {
|
||||
c.logger.Info("closing MCP client", "name", c.name, "transport", c.transport)
|
||||
return c.mcpClient.Close()
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// Manager manages multiple MCP client connections.
|
||||
type Manager struct {
|
||||
clients map[string]*Client // server name → client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager and initializes all configured servers.
|
||||
func NewManager(ctx context.Context, servers []config.MCPServerCfg, logger *slog.Logger) (*Manager, error) {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
clients: make(map[string]*Client),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
for _, serverCfg := range servers {
|
||||
if err := m.addServer(ctx, serverCfg); err != nil {
|
||||
// Close any already-created clients before returning error
|
||||
m.Close()
|
||||
return nil, fmt.Errorf("failed to initialize MCP server %q: %w", serverCfg.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("MCP manager initialized", "servers", len(m.clients))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// addServer creates and adds a single MCP client to the manager.
|
||||
func (m *Manager) addServer(ctx context.Context, cfg config.MCPServerCfg) error {
|
||||
if cfg.Name == "" {
|
||||
return fmt.Errorf("MCP server must have a name")
|
||||
}
|
||||
|
||||
// Auto-detect transport if not specified
|
||||
transport := cfg.Transport
|
||||
if transport == "" {
|
||||
if cfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if cfg.URL != "" {
|
||||
transport = "sse"
|
||||
} else {
|
||||
return fmt.Errorf("MCP server %q must specify either command (stdio) or url (sse)", cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
var client *Client
|
||||
var err error
|
||||
|
||||
switch transport {
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("MCP server %q with stdio transport must have a command", cfg.Name)
|
||||
}
|
||||
// Expand environment variables in command and args
|
||||
command := os.ExpandEnv(cfg.Command)
|
||||
args := make([]string, len(cfg.Args))
|
||||
for i, arg := range cfg.Args {
|
||||
args[i] = os.ExpandEnv(arg)
|
||||
}
|
||||
// Expand env vars in environment map
|
||||
env := make(map[string]string, len(cfg.Env))
|
||||
for k, v := range cfg.Env {
|
||||
env[k] = os.ExpandEnv(v)
|
||||
}
|
||||
client, err = NewStdioClient(ctx, cfg.Name, command, args, env, m.logger)
|
||||
|
||||
case "sse":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("MCP server %q with sse transport must have a url", cfg.Name)
|
||||
}
|
||||
url := os.ExpandEnv(cfg.URL)
|
||||
headers := make(map[string]string, len(cfg.Headers))
|
||||
for k, v := range cfg.Headers {
|
||||
headers[k] = os.ExpandEnv(v)
|
||||
}
|
||||
client, err = NewSSEClient(ctx, cfg.Name, url, headers, m.logger)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown transport %q for MCP server %q (must be stdio or sse)", transport, cfg.Name)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.clients[cfg.Name] = client
|
||||
m.logger.Info("MCP server connected", "name", cfg.Name, "transport", transport, "tools", len(client.Tools()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient returns an MCP client by name, or nil if not found.
|
||||
func (m *Manager) GetClient(name string) *Client {
|
||||
return m.clients[name]
|
||||
}
|
||||
|
||||
// AllClients returns all MCP clients managed by this manager.
|
||||
func (m *Manager) AllClients() map[string]*Client {
|
||||
return m.clients
|
||||
}
|
||||
|
||||
// Close closes all MCP client connections.
|
||||
func (m *Manager) Close() error {
|
||||
var errs []string
|
||||
for name, client := range m.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("%s: %v", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("errors closing MCP clients: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
|
||||
m.logger.Info("MCP manager closed", "servers", len(m.clients))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Package mcp provides MCP client and server implementations.
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/tools"
|
||||
)
|
||||
|
||||
// MCPServer exposes agent tools as an MCP server.
|
||||
type MCPServer struct {
|
||||
srv *server.MCPServer
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewMCPServer creates an MCP server exposing the given tool specs.
|
||||
func NewMCPServer(name, version string, specs []tools.ToolSpec, logger *slog.Logger) *MCPServer {
|
||||
srv := server.NewMCPServer(name, version)
|
||||
|
||||
for _, spec := range specs {
|
||||
spec := spec // capture
|
||||
tool := mcp.NewTool(spec.Name,
|
||||
mcp.WithDescription(spec.Description),
|
||||
)
|
||||
srv.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Placeholder handler — wire real execution here
|
||||
return mcp.NewToolResultText(fmt.Sprintf("tool %s called", spec.Name)), nil
|
||||
})
|
||||
}
|
||||
|
||||
return &MCPServer{srv: srv, logger: logger}
|
||||
}
|
||||
|
||||
// ServeStdio runs the MCP server over stdin/stdout (for Claude Desktop / CLI integration).
|
||||
func (m *MCPServer) ServeStdio(ctx context.Context) error {
|
||||
m.logger.Info("mcp server starting on stdio")
|
||||
return server.ServeStdio(m.srv)
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
// Package shellmem implements persistent memory storage using SQLite.
|
||||
package shellmem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/memory"
|
||||
)
|
||||
|
||||
const schema = `
|
||||
CREATE TABLE IF NOT EXISTS facts (
|
||||
agent_id TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (agent_id, subject, key)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
agent_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(agent_id, room_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(agent_id, subject);
|
||||
`
|
||||
|
||||
// SQLiteStore implements memory.Store using SQLite.
|
||||
type SQLiteStore struct {
|
||||
db *sql.DB
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New opens (or creates) a SQLite database at dbPath and runs migrations.
|
||||
func New(dbPath string, logger *slog.Logger) (*SQLiteStore, error) {
|
||||
log := logger.With("component", "memory", "db_path", dbPath)
|
||||
log.Info("memory_open")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create memory db dir: %w", err)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open memory db: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("migrate memory db: %w", err)
|
||||
}
|
||||
log.Info("memory_ready")
|
||||
return &SQLiteStore{db: db, logger: log}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SaveFact(ctx context.Context, f memory.Fact) error {
|
||||
s.logger.Debug("memory_save_fact", "subject", f.Subject, "key", f.Key)
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT OR REPLACE INTO facts (agent_id, subject, key, value, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
f.AgentID, f.Subject, f.Key, f.Value, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("memory_save_fact_error", "subject", f.Subject, "key", f.Key, "err", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if key != nil {
|
||||
rows, err = s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||
WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||
agentID, subject, *key,
|
||||
)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||
WHERE agent_id = ? AND subject = ?`,
|
||||
agentID, subject,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var facts []memory.Fact
|
||||
for rows.Next() {
|
||||
var f memory.Fact
|
||||
if err := rows.Scan(&f.AgentID, &f.Subject, &f.Key, &f.Value, &f.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
facts = append(facts, f)
|
||||
}
|
||||
s.logger.Debug("memory_recall", "subject", subject, "count", len(facts))
|
||||
return facts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteFacts(ctx context.Context, agentID, subject string, key *string) error {
|
||||
if key != nil {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM facts WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||
agentID, subject, *key,
|
||||
)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM facts WHERE agent_id = ? AND subject = ?`,
|
||||
agentID, subject,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SaveMessage(ctx context.Context, m memory.HistoryMessage) error {
|
||||
s.logger.Debug("memory_save_msg", "room", m.RoomID, "role", m.Role)
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO messages (agent_id, room_id, role, content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
m.AgentID, m.RoomID, string(m.Role), m.Content, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("memory_save_msg_error", "room", m.RoomID, "err", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]memory.HistoryMessage, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, room_id, role, content, created_at FROM messages
|
||||
WHERE agent_id = ? AND room_id = ?
|
||||
ORDER BY created_at DESC LIMIT ?`,
|
||||
agentID, roomID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var msgs []memory.HistoryMessage
|
||||
for rows.Next() {
|
||||
var m memory.HistoryMessage
|
||||
var role string
|
||||
if err := rows.Scan(&m.AgentID, &m.RoomID, &role, &m.Content, &m.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Role = llm.Role(role)
|
||||
msgs = append(msgs, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reverse to chronological order
|
||||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||
}
|
||||
s.logger.Debug("memory_load_msgs", "room", roomID, "count", len(msgs))
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteMessages(ctx context.Context, agentID string, roomID *string) error {
|
||||
if roomID != nil {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM messages WHERE agent_id = ? AND room_id = ?`,
|
||||
agentID, *roomID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM messages WHERE agent_id = ?`,
|
||||
agentID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) Close() error {
|
||||
s.logger.Info("memory_closed")
|
||||
return s.db.Close()
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package orchestration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/orchestration"
|
||||
)
|
||||
|
||||
// evaluate asks the LLM to score the quality of a bot's response.
|
||||
func (o *Orchestrator) evaluate(ctx context.Context, question string, response orchestration.BotResponse) orchestration.QualityScore {
|
||||
userContent := fmt.Sprintf("Question: %s\n\nResponse from %s:\n%s", question, response.BotID, response.Text)
|
||||
|
||||
resp, err := o.llm(ctx, coretypes.CompletionRequest{
|
||||
Model: o.cfg.LLM.Primary.Model,
|
||||
MaxTokens: o.cfg.LLM.Primary.MaxTokens,
|
||||
Temperature: o.cfg.LLM.Primary.Temperature,
|
||||
SystemPrompt: o.qualityPrompt,
|
||||
Messages: []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: userContent},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
o.logger.Error("quality evaluation LLM call failed", "err", err)
|
||||
// On LLM failure, assume quality is good enough to stop the pipeline
|
||||
return orchestration.QualityScore{
|
||||
Score: 1.0,
|
||||
Continue: false,
|
||||
Reason: fmt.Sprintf("evaluation failed: %s, assuming good quality", err),
|
||||
}
|
||||
}
|
||||
|
||||
var qs orchestration.QualityScore
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(resp.Content)), &qs); err != nil {
|
||||
o.logger.Warn("failed to parse quality score", "content", resp.Content, "err", err)
|
||||
// On parse failure, assume good quality
|
||||
return orchestration.QualityScore{
|
||||
Score: 1.0,
|
||||
Continue: false,
|
||||
Reason: fmt.Sprintf("parse failed: %s", err),
|
||||
}
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
@@ -0,0 +1,455 @@
|
||||
// Package orchestration implements the multi-bot orchestrator runtime.
|
||||
// The orchestrator intercepts events in managed rooms and coordinates which bot
|
||||
// responds via the in-process bus.
|
||||
//
|
||||
// PARKED (Matrix-out, issue matrix-out): the orchestrator is no longer wired
|
||||
// into the launcher. The room-discovery side (RoomScanner / SetScanner /
|
||||
// ScanExistingRooms / evaluateRoom / NotifyMembership) was intrinsic to Matrix —
|
||||
// it enumerated Matrix rooms via mautrix — and has been removed. What remains is
|
||||
// the transport-neutral routing/quality pipeline, which compiles without any
|
||||
// messaging fabric. Re-introducing auto-discovery over unibus
|
||||
// (GET /rooms/{id}/members) is a later step.
|
||||
package orchestration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/orchestration"
|
||||
"github.com/enmanuel/agents/shell/bus"
|
||||
shelllm "github.com/enmanuel/agents/shell/llm"
|
||||
)
|
||||
|
||||
// Orchestrator coordinates multi-bot rooms. It has no transport identity —
|
||||
// it intercepts events before they reach bots and delegates via the bus.
|
||||
type Orchestrator struct {
|
||||
cfg *config.SpecialConfig
|
||||
llm coretypes.CompleteFunc
|
||||
bus *bus.Bus
|
||||
logger *slog.Logger
|
||||
|
||||
// mu protects managedRooms, participants, and knownBotIDs.
|
||||
mu sync.RWMutex
|
||||
managedRooms map[string][]string // roomID → []botID
|
||||
participants map[string]orchestration.ParticipantInfo // botID → info
|
||||
knownBotIDs map[string]string // senderID → botID
|
||||
|
||||
// Prompts loaded from files
|
||||
routingPrompt string
|
||||
qualityPrompt string
|
||||
refinementPrompt string
|
||||
|
||||
// Dedup: multiple bots in the same room will each trigger Intercept().
|
||||
// We use a set of "room:sender:content" keys to ensure only one fires.
|
||||
seenMu sync.Mutex
|
||||
seen map[string]bool
|
||||
}
|
||||
|
||||
// New creates an Orchestrator from its config.
|
||||
func New(cfg *config.SpecialConfig, agentBus *bus.Bus, logger *slog.Logger) (*Orchestrator, error) {
|
||||
llmFunc, err := shelllm.FromConfig(cfg.LLM.Primary, logger.With("component", "llm"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("orchestrator LLM: %w", err)
|
||||
}
|
||||
|
||||
managed := make(map[string][]string)
|
||||
for _, room := range cfg.Orchestration.Rooms {
|
||||
if room.RoomID == "" {
|
||||
continue // skip empty room IDs (unset env vars)
|
||||
}
|
||||
managed[room.RoomID] = room.Participants
|
||||
}
|
||||
|
||||
o := &Orchestrator{
|
||||
cfg: cfg,
|
||||
llm: llmFunc,
|
||||
bus: agentBus,
|
||||
managedRooms: managed,
|
||||
participants: make(map[string]orchestration.ParticipantInfo),
|
||||
knownBotIDs: make(map[string]string),
|
||||
logger: logger,
|
||||
seen: make(map[string]bool),
|
||||
}
|
||||
|
||||
if err := o.loadPrompts(); err != nil {
|
||||
return nil, fmt.Errorf("load prompts: %w", err)
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// RegisterParticipant adds bot metadata used for LLM routing decisions.
|
||||
func (o *Orchestrator) RegisterParticipant(info orchestration.ParticipantInfo) {
|
||||
o.mu.Lock()
|
||||
o.participants[info.ID] = info
|
||||
if info.MatrixUserID != "" {
|
||||
o.knownBotIDs[info.MatrixUserID] = info.ID
|
||||
}
|
||||
o.mu.Unlock()
|
||||
o.logger.Debug("registered participant", "bot", info.ID, "sender_id", info.MatrixUserID)
|
||||
}
|
||||
|
||||
// ShouldIntercept returns true if the room is managed by this orchestrator.
|
||||
func (o *Orchestrator) ShouldIntercept(roomID string) bool {
|
||||
o.mu.RLock()
|
||||
_, ok := o.managedRooms[roomID]
|
||||
o.mu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// Intercept is the InterceptFunc used by bot listeners. It checks if the
|
||||
// room is managed and, if so, starts the orchestration pipeline asynchronously.
|
||||
// Returns true if the event was intercepted (all bots in the room should return true,
|
||||
// but only the first one triggers actual routing — the rest are deduped).
|
||||
func (o *Orchestrator) Intercept(ctx context.Context, msgCtx decision.MessageContext) bool {
|
||||
if !o.ShouldIntercept(msgCtx.RoomID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Ignore messages from known bots to prevent feedback loops.
|
||||
o.mu.RLock()
|
||||
_, senderIsBot := o.knownBotIDs[msgCtx.SenderID]
|
||||
o.mu.RUnlock()
|
||||
if senderIsBot {
|
||||
return true // suppress but don't route — bot's own message
|
||||
}
|
||||
|
||||
// Dedup: multiple bots receive the same event. Only route once.
|
||||
key := msgCtx.RoomID + ":" + msgCtx.SenderID + ":" + msgCtx.Content
|
||||
o.seenMu.Lock()
|
||||
if o.seen[key] {
|
||||
o.seenMu.Unlock()
|
||||
return true // still intercept (don't let the bot handle it) but don't route again
|
||||
}
|
||||
o.seen[key] = true
|
||||
o.seenMu.Unlock()
|
||||
|
||||
// Route asynchronously so the listener isn't blocked.
|
||||
// Clean up the dedup key after routing completes.
|
||||
go func() {
|
||||
defer func() {
|
||||
o.seenMu.Lock()
|
||||
delete(o.seen, key)
|
||||
o.seenMu.Unlock()
|
||||
}()
|
||||
if err := o.Route(ctx, msgCtx); err != nil {
|
||||
o.logger.Error("orchestration failed", "room", msgCtx.RoomID, "err", err)
|
||||
}
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// Route is the main entry point. Called when a human posts in a managed room.
|
||||
// It decides which bot(s) should respond and dispatches tasks via the bus.
|
||||
func (o *Orchestrator) Route(ctx context.Context, msgCtx decision.MessageContext) error {
|
||||
o.mu.RLock()
|
||||
participants, ok := o.managedRooms[msgCtx.RoomID]
|
||||
participantsCopy := append([]string(nil), participants...)
|
||||
o.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("room %s is not managed", msgCtx.RoomID)
|
||||
}
|
||||
|
||||
o.logger.Info("orchestrating message",
|
||||
"room", msgCtx.RoomID,
|
||||
"sender", msgCtx.SenderID,
|
||||
"participants", participantsCopy,
|
||||
"content_preview", truncate(msgCtx.Content, 80),
|
||||
)
|
||||
|
||||
// Optimization: single bot → dispatch directly without LLM
|
||||
if len(participantsCopy) == 1 {
|
||||
o.logger.Debug("single participant, dispatching directly", "bot", participantsCopy[0])
|
||||
_, err := o.dispatchAndWait(ctx, participantsCopy[0], msgCtx, 0, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
var responses []orchestration.BotResponse
|
||||
var lastBot string
|
||||
maxIter := o.cfg.Orchestration.MaxIterations
|
||||
if maxIter <= 0 {
|
||||
maxIter = 3
|
||||
}
|
||||
|
||||
for i := 0; i < maxIter; i++ {
|
||||
// Route: decide which bot responds
|
||||
var target string
|
||||
var err error
|
||||
|
||||
if i == 0 {
|
||||
rd, routeErr := o.routeInitial(ctx, msgCtx.Content, participantsCopy)
|
||||
if routeErr != nil {
|
||||
o.logger.Error("routing failed, falling back to first participant", "err", routeErr)
|
||||
target = participantsCopy[0]
|
||||
} else {
|
||||
target = rd.TargetBotID
|
||||
o.logger.Info("routed to bot",
|
||||
"bot", target,
|
||||
"confidence", rd.Confidence,
|
||||
"reason", rd.Reason,
|
||||
"iteration", i,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
rd, routeErr := o.routeRefinement(ctx, msgCtx.Content, responses, participantsCopy, lastBot)
|
||||
if routeErr != nil {
|
||||
o.logger.Warn("refinement routing failed, stopping pipeline", "err", routeErr)
|
||||
break
|
||||
}
|
||||
target = rd.TargetBotID
|
||||
o.logger.Info("refinement routed to bot",
|
||||
"bot", target,
|
||||
"reason", rd.Reason,
|
||||
"iteration", i,
|
||||
)
|
||||
}
|
||||
|
||||
// Dispatch: send TaskEvent to bot via bus and wait for response
|
||||
response, err := o.dispatchAndWait(ctx, target, msgCtx, i, responses)
|
||||
if err != nil {
|
||||
o.logger.Error("dispatch failed", "bot", target, "err", err)
|
||||
break
|
||||
}
|
||||
|
||||
responses = append(responses, response)
|
||||
lastBot = target
|
||||
|
||||
o.logger.Info("bot responded",
|
||||
"bot", target,
|
||||
"response_len", len(response.Text),
|
||||
"iteration", i,
|
||||
)
|
||||
|
||||
// Fallback: detect circular conversations before quality evaluation
|
||||
if o.detectRepetition(responses) {
|
||||
o.logger.Warn("repetition detected, stopping pipeline to prevent circular conversation",
|
||||
"iteration", i+1,
|
||||
"total_responses", len(responses),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
// Evaluate quality (Fase 3)
|
||||
score := o.evaluate(ctx, msgCtx.Content, response)
|
||||
o.logger.Info("quality evaluated",
|
||||
"score", score.Score,
|
||||
"continue", score.Continue,
|
||||
"reason", score.Reason,
|
||||
"iteration", i,
|
||||
)
|
||||
|
||||
if score.Score >= o.cfg.Orchestration.QualityThreshold || !score.Continue {
|
||||
o.logger.Info("pipeline complete",
|
||||
"iterations", i+1,
|
||||
"final_score", score.Score,
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dispatchAndWait sends a TaskEvent to a bot and waits for its response.
|
||||
func (o *Orchestrator) dispatchAndWait(
|
||||
ctx context.Context,
|
||||
botID string,
|
||||
msgCtx decision.MessageContext,
|
||||
iteration int,
|
||||
previousResponses []orchestration.BotResponse,
|
||||
) (orchestration.BotResponse, error) {
|
||||
taskID := fmt.Sprintf("orch-%s-%s-%d", msgCtx.RoomID, botID, iteration)
|
||||
|
||||
task := orchestration.TaskEvent{
|
||||
TaskID: taskID,
|
||||
TargetBotID: botID,
|
||||
TargetRoomID: msgCtx.RoomID,
|
||||
OriginalSender: msgCtx.SenderID,
|
||||
OriginalQuestion: msgCtx.Content,
|
||||
Iteration: iteration,
|
||||
PreviousResponses: previousResponses,
|
||||
}
|
||||
|
||||
taskJSON, err := orchestration.MarshalTaskEvent(task)
|
||||
if err != nil {
|
||||
return orchestration.BotResponse{}, fmt.Errorf("marshal task: %w", err)
|
||||
}
|
||||
|
||||
msg := bus.AgentMessage{
|
||||
From: bus.AgentID(o.cfg.Special.ID),
|
||||
To: bus.AgentID(botID),
|
||||
Kind: bus.KindTask,
|
||||
Payload: map[string]string{"task_json": taskJSON},
|
||||
}
|
||||
|
||||
timeout := o.cfg.Orchestration.DelegationTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 30_000_000_000 // 30s default
|
||||
}
|
||||
|
||||
reply, err := o.bus.SendAndWait(ctx, msg, taskID, timeout)
|
||||
if err != nil {
|
||||
return orchestration.BotResponse{}, err
|
||||
}
|
||||
|
||||
resultJSON, ok := reply.Payload["result_json"]
|
||||
if !ok {
|
||||
return orchestration.BotResponse{}, fmt.Errorf("reply missing result_json")
|
||||
}
|
||||
|
||||
result, err := orchestration.UnmarshalTaskResult(resultJSON)
|
||||
if err != nil {
|
||||
return orchestration.BotResponse{}, fmt.Errorf("unmarshal result: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return orchestration.BotResponse{}, fmt.Errorf("bot %s error: %s", botID, result.Error)
|
||||
}
|
||||
|
||||
return orchestration.BotResponse{
|
||||
BotID: botID,
|
||||
Text: result.Text,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadPrompts reads the orchestrator's prompt files.
|
||||
func (o *Orchestrator) loadPrompts() error {
|
||||
base := filepath.Join("agents", "specials", "orchestrator", "prompts")
|
||||
|
||||
routing, err := os.ReadFile(filepath.Join(base, "routing.md"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("routing prompt: %w", err)
|
||||
}
|
||||
o.routingPrompt = string(routing)
|
||||
|
||||
quality, err := os.ReadFile(filepath.Join(base, "quality.md"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("quality prompt: %w", err)
|
||||
}
|
||||
o.qualityPrompt = string(quality)
|
||||
|
||||
refinement, err := os.ReadFile(filepath.Join(base, "refinement.md"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("refinement prompt: %w", err)
|
||||
}
|
||||
o.refinementPrompt = string(refinement)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildParticipantsList formats participant info for LLM prompts.
|
||||
func (o *Orchestrator) buildParticipantsList(botIDs []string, exclude string) string {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
var sb strings.Builder
|
||||
for _, id := range botIDs {
|
||||
if id == exclude {
|
||||
continue
|
||||
}
|
||||
info, ok := o.participants[id]
|
||||
if !ok {
|
||||
sb.WriteString(fmt.Sprintf("- %s: (no description available)\n", id))
|
||||
continue
|
||||
}
|
||||
caps := ""
|
||||
if len(info.Capabilities) > 0 {
|
||||
caps = fmt.Sprintf(" (capabilities: %s)", strings.Join(info.Capabilities, ", "))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("- %s: %s%s\n", info.ID, info.Description, caps))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// detectRepetition checks if a new response is too similar to previous responses,
|
||||
// indicating a circular conversation that should be stopped.
|
||||
// Returns true if the conversation should be terminated.
|
||||
func (o *Orchestrator) detectRepetition(responses []orchestration.BotResponse) bool {
|
||||
if len(responses) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
threshold := o.cfg.Orchestration.RepetitionThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 0.6 // default
|
||||
}
|
||||
|
||||
latest := responses[len(responses)-1].Text
|
||||
for i := 0; i < len(responses)-1; i++ {
|
||||
if similarity(latest, responses[i].Text) >= threshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// similarity computes a simple bigram-based similarity ratio between two strings.
|
||||
// Returns a value between 0.0 (completely different) and 1.0 (identical).
|
||||
func similarity(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
a = strings.ToLower(strings.TrimSpace(a))
|
||||
b = strings.ToLower(strings.TrimSpace(b))
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
if len(a) < 2 || len(b) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
bigramsA := makeBigrams(a)
|
||||
bigramsB := makeBigrams(b)
|
||||
|
||||
// Count intersection
|
||||
intersection := 0
|
||||
for bg, countA := range bigramsA {
|
||||
if countB, ok := bigramsB[bg]; ok {
|
||||
if countA < countB {
|
||||
intersection += countA
|
||||
} else {
|
||||
intersection += countB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
totalA := 0
|
||||
for _, c := range bigramsA {
|
||||
totalA += c
|
||||
}
|
||||
totalB := 0
|
||||
for _, c := range bigramsB {
|
||||
totalB += c
|
||||
}
|
||||
|
||||
if totalA+totalB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(2*intersection) / float64(totalA+totalB)
|
||||
}
|
||||
|
||||
func makeBigrams(s string) map[string]int {
|
||||
runes := []rune(s)
|
||||
bgs := make(map[string]int, len(runes))
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
bg := string(runes[i : i+2])
|
||||
bgs[bg]++
|
||||
}
|
||||
return bgs
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= n {
|
||||
return s
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package orchestration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/orchestration"
|
||||
)
|
||||
|
||||
// routeInitial asks the LLM which bot should handle the question first.
|
||||
func (o *Orchestrator) routeInitial(ctx context.Context, question string, participants []string) (orchestration.RoutingDecision, error) {
|
||||
systemPrompt := strings.ReplaceAll(o.routingPrompt, "{{PARTICIPANTS}}", o.buildParticipantsList(participants, ""))
|
||||
|
||||
resp, err := o.llm(ctx, coretypes.CompletionRequest{
|
||||
Model: o.cfg.LLM.Primary.Model,
|
||||
MaxTokens: o.cfg.LLM.Primary.MaxTokens,
|
||||
Temperature: o.cfg.LLM.Primary.Temperature,
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: question},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return orchestration.RoutingDecision{}, fmt.Errorf("LLM routing call: %w", err)
|
||||
}
|
||||
|
||||
var rd orchestration.RoutingDecision
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(resp.Content)), &rd); err != nil {
|
||||
o.logger.Warn("failed to parse routing response, raw", "content", resp.Content, "err", err)
|
||||
return orchestration.RoutingDecision{}, fmt.Errorf("parse routing decision: %w", err)
|
||||
}
|
||||
|
||||
// Validate the chosen bot is actually a participant
|
||||
if !contains(participants, rd.TargetBotID) {
|
||||
o.logger.Warn("LLM chose unknown bot, falling back to first", "chosen", rd.TargetBotID)
|
||||
rd.TargetBotID = participants[0]
|
||||
rd.Confidence = 0.5
|
||||
rd.Reason = "fallback: LLM chose unknown bot"
|
||||
}
|
||||
|
||||
return rd, nil
|
||||
}
|
||||
|
||||
// routeRefinement asks the LLM which bot should improve the response,
|
||||
// excluding the last respondent.
|
||||
func (o *Orchestrator) routeRefinement(
|
||||
ctx context.Context,
|
||||
question string,
|
||||
responses []orchestration.BotResponse,
|
||||
participants []string,
|
||||
excludeBot string,
|
||||
) (orchestration.RoutingDecision, error) {
|
||||
lastResponse := ""
|
||||
if len(responses) > 0 {
|
||||
lastResponse = responses[len(responses)-1].Text
|
||||
}
|
||||
|
||||
systemPrompt := strings.ReplaceAll(o.refinementPrompt, "{{PARTICIPANTS}}", o.buildParticipantsList(participants, excludeBot))
|
||||
systemPrompt = strings.ReplaceAll(systemPrompt, "{{LAST_RESPONSE}}", lastResponse)
|
||||
|
||||
userContent := fmt.Sprintf("Original question: %s\n\nCurrent response that needs improvement:\n%s", question, lastResponse)
|
||||
|
||||
resp, err := o.llm(ctx, coretypes.CompletionRequest{
|
||||
Model: o.cfg.LLM.Primary.Model,
|
||||
MaxTokens: o.cfg.LLM.Primary.MaxTokens,
|
||||
Temperature: o.cfg.LLM.Primary.Temperature,
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: userContent},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return orchestration.RoutingDecision{}, fmt.Errorf("LLM refinement call: %w", err)
|
||||
}
|
||||
|
||||
var rd orchestration.RoutingDecision
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(resp.Content)), &rd); err != nil {
|
||||
o.logger.Warn("failed to parse refinement response", "content", resp.Content, "err", err)
|
||||
return orchestration.RoutingDecision{}, fmt.Errorf("parse refinement decision: %w", err)
|
||||
}
|
||||
|
||||
// Validate: must be a participant and not the excluded bot
|
||||
if rd.TargetBotID == excludeBot || !contains(participants, rd.TargetBotID) {
|
||||
// Pick first available that isn't excluded
|
||||
for _, p := range participants {
|
||||
if p != excludeBot {
|
||||
rd.TargetBotID = p
|
||||
rd.Reason = "fallback: LLM chose excluded or unknown bot"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rd, nil
|
||||
}
|
||||
|
||||
func contains(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
// Package process manages agent processes: discovery, start, stop, kill, stats.
|
||||
// This is the impure shell layer — all I/O happens here.
|
||||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// AgentInfo holds metadata about an agent parsed from its config.
|
||||
type AgentInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Version string
|
||||
Desc string
|
||||
ConfigPath string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// AgentStatus combines agent metadata with runtime state.
|
||||
type AgentStatus struct {
|
||||
AgentInfo
|
||||
Running bool
|
||||
PID int
|
||||
Instances int
|
||||
}
|
||||
|
||||
// ProcessStats holds resource usage for a running process.
|
||||
type ProcessStats struct {
|
||||
PID int
|
||||
UptimeSecs int64
|
||||
MemRSSKB int64
|
||||
CPUPct float64
|
||||
LogBytes int64
|
||||
}
|
||||
|
||||
// processProber abstracts process detection for testing.
|
||||
type processProber interface {
|
||||
// pgrepPIDs runs pgrep -f with the given pattern and returns matching PIDs.
|
||||
pgrepPIDs(pattern string) []int
|
||||
// processComm returns the comm name for a PID (e.g. "launcher", "go").
|
||||
processComm(pid int) string
|
||||
// isAlive checks if a PID is running.
|
||||
isAlive(pid int) bool
|
||||
}
|
||||
|
||||
// osProber is the real implementation using OS calls.
|
||||
type osProber struct{}
|
||||
|
||||
func (osProber) pgrepPIDs(pattern string) []int {
|
||||
out, err := exec.Command("pgrep", "-f", pattern).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var pids []int
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
||||
if p, err := strconv.Atoi(strings.TrimSpace(line)); err == nil && p > 0 {
|
||||
pids = append(pids, p)
|
||||
}
|
||||
}
|
||||
return pids
|
||||
}
|
||||
|
||||
func (osProber) processComm(pid int) string {
|
||||
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
func (osProber) isAlive(pid int) bool {
|
||||
return syscall.Kill(pid, 0) == nil
|
||||
}
|
||||
|
||||
const unifiedID = "launcher" // PID/log file ID for the unified launcher
|
||||
|
||||
// Manager handles agent process lifecycle.
|
||||
type Manager struct {
|
||||
runDir string
|
||||
agentsGlob string
|
||||
binPath string
|
||||
envFile string // path to .env file for child processes
|
||||
prober processProber
|
||||
}
|
||||
|
||||
// NewManager creates a Manager. binPath can be empty for auto-detection.
|
||||
func NewManager(runDir, agentsGlob, binPath string) *Manager {
|
||||
return &Manager{runDir: runDir, agentsGlob: agentsGlob, binPath: binPath, envFile: ".env", prober: osProber{}}
|
||||
}
|
||||
|
||||
// Scan discovers all agents from config files.
|
||||
func (m *Manager) Scan() ([]AgentInfo, error) {
|
||||
matches, err := filepath.Glob(m.agentsGlob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var agents []AgentInfo
|
||||
for _, path := range matches {
|
||||
cfg, err := config.LoadMeta(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
agents = append(agents, AgentInfo{
|
||||
ID: cfg.Agent.ID,
|
||||
Name: cfg.Agent.Name,
|
||||
Version: cfg.Agent.Version,
|
||||
Desc: cfg.Agent.Description,
|
||||
ConfigPath: path,
|
||||
Enabled: cfg.Agent.Enabled,
|
||||
})
|
||||
}
|
||||
return agents, nil
|
||||
}
|
||||
|
||||
// Status returns the runtime status for a single agent.
|
||||
func (m *Manager) Status(info AgentInfo) AgentStatus {
|
||||
pids := m.findProcessPIDs(info.ID)
|
||||
primary := 0
|
||||
if len(pids) > 0 {
|
||||
primary = pids[0]
|
||||
}
|
||||
return AgentStatus{
|
||||
AgentInfo: info,
|
||||
Running: len(pids) > 0,
|
||||
PID: primary,
|
||||
Instances: len(pids),
|
||||
}
|
||||
}
|
||||
|
||||
// StatusAll returns status for every discovered agent.
|
||||
func (m *Manager) StatusAll() ([]AgentStatus, error) {
|
||||
agents, err := m.Scan()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
statuses := make([]AgentStatus, len(agents))
|
||||
for i, a := range agents {
|
||||
statuses[i] = m.Status(a)
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
// Start launches an agent process in the background.
|
||||
// Returns an error if the agent is already running.
|
||||
func (m *Manager) Start(info AgentInfo) error {
|
||||
if pids := m.findProcessPIDs(info.ID); len(pids) > 0 {
|
||||
return fmt.Errorf("agent %q is already running (PID %d)", info.ID, pids[0])
|
||||
}
|
||||
if err := os.MkdirAll(m.runDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create run dir: %w", err)
|
||||
}
|
||||
|
||||
logFile, err := os.OpenFile(m.logPath(info.ID), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log: %w", err)
|
||||
}
|
||||
|
||||
bin := m.resolvedBin()
|
||||
var cmd *exec.Cmd
|
||||
if strings.HasPrefix(bin, "go run") {
|
||||
cmd = exec.Command("go", "run", "-tags", "goolm", "./cmd/launcher", "-c", info.ConfigPath)
|
||||
} else {
|
||||
cmd = exec.Command(bin, "-c", info.ConfigPath)
|
||||
}
|
||||
|
||||
cmd.Env = m.BuildEnv()
|
||||
cmd.Stdout = logFile
|
||||
cmd.Stderr = logFile
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
logFile.Close()
|
||||
return fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.pidPath(info.ID), []byte(strconv.Itoa(cmd.Process.Pid)), 0o644); err != nil {
|
||||
return fmt.Errorf("write PID: %w", err)
|
||||
}
|
||||
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop sends SIGTERM to all instances, waits up to 5s, then SIGKILL if needed.
|
||||
func (m *Manager) Stop(id string) error {
|
||||
pids := m.findProcessPIDs(id)
|
||||
// Also include PID file PID if alive and not already in the list
|
||||
filePID := m.readPID(id)
|
||||
if filePID > 0 && m.isAlive(filePID) {
|
||||
found := false
|
||||
for _, p := range pids {
|
||||
if p == filePID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
pids = append(pids, filePID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pids) == 0 {
|
||||
return fmt.Errorf("agent %q is not running", id)
|
||||
}
|
||||
|
||||
// SIGTERM all instances
|
||||
for _, pid := range pids {
|
||||
_ = syscall.Kill(pid, syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// Wait up to 5 seconds for graceful shutdown.
|
||||
for i := 0; i < 10; i++ {
|
||||
allDead := true
|
||||
for _, pid := range pids {
|
||||
if m.isAlive(pid) {
|
||||
allDead = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allDead {
|
||||
m.removePID(id)
|
||||
return nil
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Force kill survivors.
|
||||
for _, pid := range pids {
|
||||
if m.isAlive(pid) {
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
m.removePID(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kill sends SIGKILL to all instances immediately.
|
||||
func (m *Manager) Kill(id string) error {
|
||||
pids := m.findProcessPIDs(id)
|
||||
filePID := m.readPID(id)
|
||||
if filePID > 0 && m.isAlive(filePID) {
|
||||
found := false
|
||||
for _, p := range pids {
|
||||
if p == filePID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
pids = append(pids, filePID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pids) == 0 {
|
||||
return fmt.Errorf("agent %q is not running", id)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, pid := range pids {
|
||||
if err := syscall.Kill(pid, syscall.SIGKILL); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
m.removePID(id)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Stats gathers resource usage for a running agent from /proc.
|
||||
func (m *Manager) Stats(id string) (ProcessStats, error) {
|
||||
pid := m.resolveRunningPID(id)
|
||||
if pid == 0 {
|
||||
return ProcessStats{}, fmt.Errorf("agent %q is not running", id)
|
||||
}
|
||||
return m.statsForPID(pid, id), nil
|
||||
}
|
||||
|
||||
// statsForPID gathers resource usage for a specific PID.
|
||||
func (m *Manager) statsForPID(pid int, id string) ProcessStats {
|
||||
s := ProcessStats{PID: pid}
|
||||
|
||||
// Uptime from /proc/<pid>/stat
|
||||
if data, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)); err == nil {
|
||||
fields := strings.Fields(string(data))
|
||||
if len(fields) > 21 {
|
||||
startTicks, _ := strconv.ParseInt(fields[21], 10, 64)
|
||||
clkTck := int64(100) // sysconf(_SC_CLK_TCK) is 100 on Linux
|
||||
if raw, err := os.ReadFile("/proc/stat"); err == nil {
|
||||
for _, line := range strings.Split(string(raw), "\n") {
|
||||
if strings.HasPrefix(line, "btime ") {
|
||||
btime, _ := strconv.ParseInt(strings.Fields(line)[1], 10, 64)
|
||||
procStart := btime + startTicks/clkTck
|
||||
s.UptimeSecs = time.Now().Unix() - procStart
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RSS from /proc/<pid>/status
|
||||
if data, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid)); err == nil {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "VmRSS:") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
s.MemRSSKB, _ = strconv.ParseInt(fields[1], 10, 64)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CPU% from ps (simpler than calculating from /proc/stat deltas)
|
||||
if out, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "pcpu=").Output(); err == nil {
|
||||
s.CPUPct, _ = strconv.ParseFloat(strings.TrimSpace(string(out)), 64)
|
||||
}
|
||||
|
||||
// Log file size
|
||||
if info, err := os.Stat(m.logPath(id)); err == nil {
|
||||
s.LogBytes = info.Size()
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// LogTail returns the last N lines of an agent's log.
|
||||
func (m *Manager) LogTail(id string, lines int) ([]string, error) {
|
||||
f, err := os.Open(m.logPath(id))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Read all lines and keep last N. For large files a reverse scanner
|
||||
// would be better, but agent logs are typically small.
|
||||
var all []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
all = append(all, scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(all) > lines {
|
||||
all = all[len(all)-lines:]
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsRunning checks if an agent process is alive.
|
||||
func (m *Manager) IsRunning(id string) bool {
|
||||
return m.resolveRunningPID(id) > 0
|
||||
}
|
||||
|
||||
// InstanceCount returns how many launcher processes are running for an agent.
|
||||
func (m *Manager) InstanceCount(id string) int {
|
||||
return len(m.findProcessPIDs(id))
|
||||
}
|
||||
|
||||
// ReadPID returns the PID from the PID file, or 0.
|
||||
func (m *Manager) ReadPID(id string) int {
|
||||
return m.readPID(id)
|
||||
}
|
||||
|
||||
// PidPath returns the path to the PID file for an agent.
|
||||
func (m *Manager) PidPath(id string) string { return m.pidPath(id) }
|
||||
|
||||
// LogPath returns the path to the log file for an agent.
|
||||
func (m *Manager) LogPath(id string) string { return m.logPath(id) }
|
||||
|
||||
// Build compiles all project binaries by running build.sh.
|
||||
// Returns the combined output and any error.
|
||||
func (m *Manager) Build() (string, error) {
|
||||
cmd := exec.Command("bash", "build.sh")
|
||||
cmd.Env = m.BuildEnv()
|
||||
out, err := cmd.CombinedOutput()
|
||||
return string(out), err
|
||||
}
|
||||
|
||||
// ── Unified launcher ─────────────────────────────────────────────────────
|
||||
// The unified launcher runs ALL enabled agents + orchestrator in a single
|
||||
// process. PID → run/launcher.pid, log → run/launcher.log.
|
||||
|
||||
// StartUnified launches the unified launcher (no -c flag → discovers all agents).
|
||||
func (m *Manager) StartUnified() error {
|
||||
if m.IsUnifiedRunning() {
|
||||
return fmt.Errorf("unified launcher is already running (PID %d)", m.readPID(unifiedID))
|
||||
}
|
||||
if err := os.MkdirAll(m.runDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create run dir: %w", err)
|
||||
}
|
||||
|
||||
logFile, err := os.OpenFile(m.logPath(unifiedID), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log: %w", err)
|
||||
}
|
||||
|
||||
bin := m.resolvedBin()
|
||||
var cmd *exec.Cmd
|
||||
if strings.HasPrefix(bin, "go run") {
|
||||
cmd = exec.Command("go", "run", "-tags", "goolm", "./cmd/launcher", "--log-level", "info")
|
||||
} else {
|
||||
cmd = exec.Command(bin, "--log-level", "info")
|
||||
}
|
||||
|
||||
cmd.Env = m.BuildEnv()
|
||||
cmd.Stdout = logFile
|
||||
cmd.Stderr = logFile
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
logFile.Close()
|
||||
return fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.pidPath(unifiedID), []byte(strconv.Itoa(cmd.Process.Pid)), 0o644); err != nil {
|
||||
return fmt.Errorf("write PID: %w", err)
|
||||
}
|
||||
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopUnified stops the unified launcher process.
|
||||
func (m *Manager) StopUnified() error {
|
||||
return m.Stop(unifiedID)
|
||||
}
|
||||
|
||||
// KillUnified sends SIGKILL to the unified launcher.
|
||||
func (m *Manager) KillUnified() error {
|
||||
return m.Kill(unifiedID)
|
||||
}
|
||||
|
||||
// IsUnifiedRunning checks if the unified launcher is alive.
|
||||
func (m *Manager) IsUnifiedRunning() bool {
|
||||
pid := m.readPID(unifiedID)
|
||||
if pid > 0 && m.isAlive(pid) {
|
||||
return true
|
||||
}
|
||||
// Fallback: search for launcher running without -c flag
|
||||
pids := m.findUnifiedPIDs()
|
||||
return len(pids) > 0
|
||||
}
|
||||
|
||||
// UnifiedPID returns the PID of the running unified launcher, or 0.
|
||||
func (m *Manager) UnifiedPID() int {
|
||||
pid := m.readPID(unifiedID)
|
||||
if pid > 0 && m.isAlive(pid) {
|
||||
return pid
|
||||
}
|
||||
pids := m.findUnifiedPIDs()
|
||||
if len(pids) > 0 {
|
||||
// Repair PID file
|
||||
_ = os.WriteFile(m.pidPath(unifiedID), []byte(strconv.Itoa(pids[0])), 0o644)
|
||||
return pids[0]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// UnifiedStats returns resource usage for the unified launcher process.
|
||||
func (m *Manager) UnifiedStats() (ProcessStats, error) {
|
||||
pid := m.UnifiedPID()
|
||||
if pid == 0 {
|
||||
return ProcessStats{}, fmt.Errorf("unified launcher is not running")
|
||||
}
|
||||
return m.statsForPID(pid, unifiedID), nil
|
||||
}
|
||||
|
||||
// UnifiedLogTail returns the last N lines of the unified launcher log.
|
||||
func (m *Manager) UnifiedLogTail(lines int) ([]string, error) {
|
||||
return m.LogTail(unifiedID, lines)
|
||||
}
|
||||
|
||||
// StatusAllUnified returns status for all agents, deriving "running" from
|
||||
// whether the unified launcher is running + the agent is enabled.
|
||||
func (m *Manager) StatusAllUnified() ([]AgentStatus, error) {
|
||||
agents, err := m.Scan()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
launcherRunning := m.IsUnifiedRunning()
|
||||
launcherPID := m.UnifiedPID()
|
||||
|
||||
statuses := make([]AgentStatus, len(agents))
|
||||
for i, a := range agents {
|
||||
running := launcherRunning && a.Enabled
|
||||
pid := 0
|
||||
instances := 0
|
||||
if running {
|
||||
pid = launcherPID
|
||||
instances = 1
|
||||
}
|
||||
statuses[i] = AgentStatus{
|
||||
AgentInfo: a,
|
||||
Running: running,
|
||||
PID: pid,
|
||||
Instances: instances,
|
||||
}
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
// ToggleEnabled sets the enabled field in an agent's config.yaml.
|
||||
func (m *Manager) ToggleEnabled(id string, enabled bool) error {
|
||||
agents, err := m.Scan()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, a := range agents {
|
||||
if a.ID == id {
|
||||
return m.setEnabledInConfig(a.ConfigPath, enabled)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("agent %q not found", id)
|
||||
}
|
||||
|
||||
// setEnabledInConfig rewrites the enabled field in a config.yaml.
|
||||
func (m *Manager) setEnabledInConfig(path string, enabled bool) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val := "false"
|
||||
if enabled {
|
||||
val = "true"
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "enabled:") {
|
||||
// Preserve indentation
|
||||
indent := line[:len(line)-len(strings.TrimLeft(line, " \t"))]
|
||||
lines[i] = indent + "enabled: " + val
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0o644)
|
||||
}
|
||||
|
||||
// findUnifiedPIDs finds launcher processes running without -c flag.
|
||||
func (m *Manager) findUnifiedPIDs() []int {
|
||||
// Search for launcher processes that do NOT have -c flag
|
||||
raw := m.prober.pgrepPIDs("launcher.*--log-level")
|
||||
var pids []int
|
||||
for _, p := range raw {
|
||||
comm := m.prober.processComm(p)
|
||||
if comm == "go" {
|
||||
continue
|
||||
}
|
||||
pids = append(pids, p)
|
||||
}
|
||||
return pids
|
||||
}
|
||||
|
||||
// ── internal helpers ─────────────────────────────────────────────────────
|
||||
|
||||
func (m *Manager) pidPath(id string) string { return filepath.Join(m.runDir, id+".pid") }
|
||||
func (m *Manager) logPath(id string) string { return filepath.Join(m.runDir, id+".log") }
|
||||
|
||||
func (m *Manager) readPID(id string) int {
|
||||
raw, err := os.ReadFile(m.pidPath(id))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
pid, _ := strconv.Atoi(strings.TrimSpace(string(raw)))
|
||||
return pid
|
||||
}
|
||||
|
||||
// findProcessPIDs searches for running launcher processes for a given agent ID
|
||||
// using pgrep. Filters out "go run" wrapper PIDs to avoid double-counting.
|
||||
func (m *Manager) findProcessPIDs(id string) []int {
|
||||
configPath := m.configPathFor(id)
|
||||
if configPath == "" {
|
||||
return nil
|
||||
}
|
||||
pattern := fmt.Sprintf("launcher.*-c.*%s", configPath)
|
||||
raw := m.prober.pgrepPIDs(pattern)
|
||||
|
||||
// Filter out the "go" wrapper process that appears when using "go run".
|
||||
var pids []int
|
||||
for _, p := range raw {
|
||||
comm := m.prober.processComm(p)
|
||||
if comm == "go" {
|
||||
continue
|
||||
}
|
||||
pids = append(pids, p)
|
||||
}
|
||||
return pids
|
||||
}
|
||||
|
||||
// configPathFor returns the config file path for the given agent ID.
|
||||
func (m *Manager) configPathFor(id string) string {
|
||||
matches, err := filepath.Glob(m.agentsGlob)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, path := range matches {
|
||||
cfg, err := config.LoadMeta(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cfg.Agent.ID == id {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveRunningPID returns the PID of the running agent, checking the PID file
|
||||
// first and falling back to process discovery. It also repairs stale PID files.
|
||||
func (m *Manager) resolveRunningPID(id string) int {
|
||||
// Check PID file first
|
||||
pid := m.readPID(id)
|
||||
if pid > 0 && m.isAlive(pid) {
|
||||
return pid
|
||||
}
|
||||
|
||||
// PID file is stale or missing — search for actual processes
|
||||
pids := m.findProcessPIDs(id)
|
||||
if len(pids) > 0 {
|
||||
// Repair the PID file with the first found process
|
||||
_ = os.WriteFile(m.pidPath(id), []byte(strconv.Itoa(pids[0])), 0o644)
|
||||
return pids[0]
|
||||
}
|
||||
|
||||
// Clean up stale PID file
|
||||
if pid > 0 {
|
||||
m.removePID(id)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Manager) isAlive(pid int) bool {
|
||||
return m.prober.isAlive(pid)
|
||||
}
|
||||
|
||||
func (m *Manager) removePID(id string) {
|
||||
_ = os.Remove(m.pidPath(id))
|
||||
}
|
||||
|
||||
// BuildEnv returns the environment for child processes: current env + .env file vars.
|
||||
func (m *Manager) BuildEnv() []string {
|
||||
env := os.Environ()
|
||||
if m.envFile == "" {
|
||||
return env
|
||||
}
|
||||
data, err := os.ReadFile(m.envFile)
|
||||
if err != nil {
|
||||
return env
|
||||
}
|
||||
// Parse KEY=VALUE lines, skip comments and blanks.
|
||||
seen := make(map[string]bool)
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(line, "="); idx > 0 {
|
||||
key := line[:idx]
|
||||
seen[key] = true
|
||||
env = append(env, line)
|
||||
}
|
||||
}
|
||||
_ = seen // .env values appended last, so they override earlier entries
|
||||
return env
|
||||
}
|
||||
|
||||
func (m *Manager) resolvedBin() string {
|
||||
if m.binPath != "" {
|
||||
return m.binPath
|
||||
}
|
||||
if _, err := os.Stat("bin/launcher"); err == nil {
|
||||
return "bin/launcher"
|
||||
}
|
||||
return "go run ./cmd/launcher"
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fakeProber is a test double for processProber.
|
||||
type fakeProber struct {
|
||||
pids map[string][]int // pattern → PIDs
|
||||
comms map[int]string // PID → comm name
|
||||
alive map[int]bool // PID → is alive
|
||||
}
|
||||
|
||||
func newFakeProber() *fakeProber {
|
||||
return &fakeProber{
|
||||
pids: make(map[string][]int),
|
||||
comms: make(map[int]string),
|
||||
alive: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeProber) pgrepPIDs(pattern string) []int { return f.pids[pattern] }
|
||||
func (f *fakeProber) processComm(pid int) string { return f.comms[pid] }
|
||||
func (f *fakeProber) isAlive(pid int) bool { return f.alive[pid] }
|
||||
|
||||
// testManager creates a Manager with a temp dir, fake prober, and a config file.
|
||||
func testManager(t *testing.T, fp *fakeProber) (*Manager, string) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
runDir := filepath.Join(dir, "run")
|
||||
agentsDir := filepath.Join(dir, "agents", "test-bot")
|
||||
_ = os.MkdirAll(runDir, 0o755)
|
||||
_ = os.MkdirAll(agentsDir, 0o755)
|
||||
|
||||
// Minimal config.yaml so Scan() and configPathFor() work.
|
||||
cfgPath := filepath.Join(agentsDir, "config.yaml")
|
||||
_ = os.WriteFile(cfgPath, []byte(`agent:
|
||||
id: test-bot
|
||||
name: Test Bot
|
||||
version: "0.1"
|
||||
enabled: true
|
||||
`), 0o644)
|
||||
|
||||
glob := filepath.Join(dir, "agents", "*", "config.yaml")
|
||||
m := &Manager{
|
||||
runDir: runDir,
|
||||
agentsGlob: glob,
|
||||
binPath: "/bin/true", // won't actually run
|
||||
envFile: "",
|
||||
prober: fp,
|
||||
}
|
||||
return m, cfgPath
|
||||
}
|
||||
|
||||
func TestFindProcessPIDs_FiltersGoWrapper(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, cfgPath := testManager(t, fp)
|
||||
|
||||
// Simulate pgrep returning 2 PIDs: go wrapper (100) + real launcher (200).
|
||||
pattern := "launcher.*-c.*" + cfgPath
|
||||
fp.pids[pattern] = []int{100, 200}
|
||||
fp.comms[100] = "go"
|
||||
fp.comms[200] = "launcher"
|
||||
|
||||
pids := m.findProcessPIDs("test-bot")
|
||||
|
||||
if len(pids) != 1 {
|
||||
t.Fatalf("expected 1 PID, got %d: %v", len(pids), pids)
|
||||
}
|
||||
if pids[0] != 200 {
|
||||
t.Errorf("expected PID 200, got %d", pids[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindProcessPIDs_NoPIDs(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, _ := testManager(t, fp)
|
||||
|
||||
pids := m.findProcessPIDs("test-bot")
|
||||
if len(pids) != 0 {
|
||||
t.Fatalf("expected 0 PIDs, got %d", len(pids))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus_SingleInstance(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, cfgPath := testManager(t, fp)
|
||||
|
||||
pattern := "launcher.*-c.*" + cfgPath
|
||||
fp.pids[pattern] = []int{42}
|
||||
fp.comms[42] = "launcher"
|
||||
|
||||
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
|
||||
st := m.Status(info)
|
||||
|
||||
if !st.Running {
|
||||
t.Error("expected Running=true")
|
||||
}
|
||||
if st.PID != 42 {
|
||||
t.Errorf("expected PID=42, got %d", st.PID)
|
||||
}
|
||||
if st.Instances != 1 {
|
||||
t.Errorf("expected Instances=1, got %d", st.Instances)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus_NoInstances(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, cfgPath := testManager(t, fp)
|
||||
|
||||
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
|
||||
st := m.Status(info)
|
||||
|
||||
if st.Running {
|
||||
t.Error("expected Running=false")
|
||||
}
|
||||
if st.Instances != 0 {
|
||||
t.Errorf("expected Instances=0, got %d", st.Instances)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_RejectsWhenAlreadyRunning(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, cfgPath := testManager(t, fp)
|
||||
|
||||
pattern := "launcher.*-c.*" + cfgPath
|
||||
fp.pids[pattern] = []int{99}
|
||||
fp.comms[99] = "launcher"
|
||||
|
||||
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
|
||||
err := m.Start(info)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when agent already running")
|
||||
}
|
||||
if got := err.Error(); got != `agent "test-bot" is already running (PID 99)` {
|
||||
t.Errorf("unexpected error: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRunningPID_RepairsStale(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, cfgPath := testManager(t, fp)
|
||||
|
||||
// Write a stale PID file (PID 999 is dead).
|
||||
_ = os.MkdirAll(m.runDir, 0o755)
|
||||
_ = os.WriteFile(m.pidPath("test-bot"), []byte("999"), 0o644)
|
||||
fp.alive[999] = false
|
||||
|
||||
// But the real process is at PID 42.
|
||||
pattern := "launcher.*-c.*" + cfgPath
|
||||
fp.pids[pattern] = []int{42}
|
||||
fp.comms[42] = "launcher"
|
||||
|
||||
pid := m.resolveRunningPID("test-bot")
|
||||
if pid != 42 {
|
||||
t.Errorf("expected repaired PID=42, got %d", pid)
|
||||
}
|
||||
|
||||
// Verify PID file was repaired.
|
||||
data, err := os.ReadFile(m.pidPath("test-bot"))
|
||||
if err != nil {
|
||||
t.Fatalf("read pid file: %v", err)
|
||||
}
|
||||
if got, _ := strconv.Atoi(string(data)); got != 42 {
|
||||
t.Errorf("expected PID file to contain 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRunningPID_CleansUpStalePIDFile(t *testing.T) {
|
||||
fp := newFakeProber()
|
||||
m, _ := testManager(t, fp)
|
||||
|
||||
// Write a stale PID file, no real process running.
|
||||
_ = os.MkdirAll(m.runDir, 0o755)
|
||||
_ = os.WriteFile(m.pidPath("test-bot"), []byte("999"), 0o644)
|
||||
fp.alive[999] = false
|
||||
|
||||
pid := m.resolveRunningPID("test-bot")
|
||||
if pid != 0 {
|
||||
t.Errorf("expected 0 for dead process, got %d", pid)
|
||||
}
|
||||
|
||||
// PID file should be removed.
|
||||
if _, err := os.Stat(m.pidPath("test-bot")); !os.IsNotExist(err) {
|
||||
t.Error("expected stale PID file to be removed")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
// Package security provides the impure loader for security policy YAML files.
|
||||
// It reads security/ directory files and returns a pure security.SecurityPolicy.
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/security"
|
||||
)
|
||||
|
||||
// --- YAML intermediate types (private, only for parsing) ---
|
||||
|
||||
type yamlUserGroups struct {
|
||||
Groups map[string]struct {
|
||||
Members []string `yaml:"members"`
|
||||
} `yaml:"groups"`
|
||||
}
|
||||
|
||||
type yamlAgentGroups struct {
|
||||
Groups map[string]struct {
|
||||
Agents []string `yaml:"agents"`
|
||||
} `yaml:"groups"`
|
||||
}
|
||||
|
||||
type yamlPermissions struct {
|
||||
Policies []struct {
|
||||
AgentGroup string `yaml:"agent_group"`
|
||||
Permissions []struct {
|
||||
UserGroup string `yaml:"user_group"`
|
||||
Actions []string `yaml:"actions"`
|
||||
} `yaml:"permissions"`
|
||||
} `yaml:"policies"`
|
||||
}
|
||||
|
||||
// Load reads the security YAML files from dir and returns a SecurityPolicy.
|
||||
// If dir does not exist or is empty, returns an empty policy without error.
|
||||
// If an individual file is missing, that section is left empty.
|
||||
// If a YAML file is malformed, returns an error naming the file.
|
||||
func Load(dir string) (security.SecurityPolicy, error) {
|
||||
if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) {
|
||||
return security.SecurityPolicy{}, nil
|
||||
}
|
||||
|
||||
userGroups, err := loadUserGroups(filepath.Join(dir, "user-groups.yaml"))
|
||||
if err != nil {
|
||||
return security.SecurityPolicy{}, err
|
||||
}
|
||||
|
||||
agentGroups, err := loadAgentGroups(filepath.Join(dir, "agent-groups.yaml"))
|
||||
if err != nil {
|
||||
return security.SecurityPolicy{}, err
|
||||
}
|
||||
|
||||
policies, err := loadPermissions(filepath.Join(dir, "permissions.yaml"))
|
||||
if err != nil {
|
||||
return security.SecurityPolicy{}, err
|
||||
}
|
||||
|
||||
return security.SecurityPolicy{
|
||||
UserGroups: userGroups,
|
||||
AgentGroups: agentGroups,
|
||||
Policies: policies,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadUserGroups(path string) ([]security.UserGroup, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("security: reading %s: %w", path, err)
|
||||
}
|
||||
|
||||
var raw yamlUserGroups
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("security: parsing %s: %w", path, err)
|
||||
}
|
||||
|
||||
groups := make([]security.UserGroup, 0, len(raw.Groups))
|
||||
for name, g := range raw.Groups {
|
||||
groups = append(groups, security.UserGroup{
|
||||
Name: name,
|
||||
Members: g.Members,
|
||||
})
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func loadAgentGroups(path string) ([]security.AgentGroup, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("security: reading %s: %w", path, err)
|
||||
}
|
||||
|
||||
var raw yamlAgentGroups
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("security: parsing %s: %w", path, err)
|
||||
}
|
||||
|
||||
groups := make([]security.AgentGroup, 0, len(raw.Groups))
|
||||
for name, g := range raw.Groups {
|
||||
groups = append(groups, security.AgentGroup{
|
||||
Name: name,
|
||||
Agents: g.Agents,
|
||||
})
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func loadPermissions(path string) ([]security.AgentPolicy, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("security: reading %s: %w", path, err)
|
||||
}
|
||||
|
||||
var raw yamlPermissions
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("security: parsing %s: %w", path, err)
|
||||
}
|
||||
|
||||
policies := make([]security.AgentPolicy, 0, len(raw.Policies))
|
||||
for _, p := range raw.Policies {
|
||||
perms := make([]security.Permission, 0, len(p.Permissions))
|
||||
for _, perm := range p.Permissions {
|
||||
perms = append(perms, security.Permission{
|
||||
UserGroup: perm.UserGroup,
|
||||
Actions: perm.Actions,
|
||||
})
|
||||
}
|
||||
policies = append(policies, security.AgentPolicy{
|
||||
AgentGroup: p.AgentGroup,
|
||||
Permissions: perms,
|
||||
})
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
shellsecurity "github.com/enmanuel/agents/shell/security"
|
||||
)
|
||||
|
||||
// writeFile is a helper that creates a file in dir with the given content.
|
||||
func writeFile(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("writeFile %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.1: directorio inexistente → policy vacía, sin error ---
|
||||
|
||||
func TestLoad_NonExistentDir(t *testing.T) {
|
||||
policy, err := shellsecurity.Load("/tmp/does-not-exist-security-xyz")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if len(policy.UserGroups) != 0 || len(policy.AgentGroups) != 0 || len(policy.Policies) != 0 {
|
||||
t.Errorf("expected empty policy, got: %+v", policy)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.2: directorio vacío (sin YAML) → policy vacía, sin error ---
|
||||
|
||||
func TestLoad_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
policy, err := shellsecurity.Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if len(policy.UserGroups) != 0 || len(policy.AgentGroups) != 0 || len(policy.Policies) != 0 {
|
||||
t.Errorf("expected empty policy, got: %+v", policy)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.3: los 3 YAML válidos → policy con todos los campos ---
|
||||
|
||||
func TestLoad_AllFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeFile(t, dir, "user-groups.yaml", `
|
||||
groups:
|
||||
admins:
|
||||
members: ["@admin:example.com"]
|
||||
everyone:
|
||||
members: ["*"]
|
||||
`)
|
||||
writeFile(t, dir, "agent-groups.yaml", `
|
||||
groups:
|
||||
assistants:
|
||||
agents:
|
||||
- assistant-bot
|
||||
all:
|
||||
agents: ["*"]
|
||||
`)
|
||||
writeFile(t, dir, "permissions.yaml", `
|
||||
policies:
|
||||
- agent_group: all
|
||||
permissions:
|
||||
- user_group: admins
|
||||
actions: ["*"]
|
||||
- user_group: everyone
|
||||
actions: ["ask"]
|
||||
`)
|
||||
|
||||
policy, err := shellsecurity.Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(policy.UserGroups) != 2 {
|
||||
t.Errorf("expected 2 user groups, got %d", len(policy.UserGroups))
|
||||
}
|
||||
if len(policy.AgentGroups) != 2 {
|
||||
t.Errorf("expected 2 agent groups, got %d", len(policy.AgentGroups))
|
||||
}
|
||||
if len(policy.Policies) != 1 {
|
||||
t.Errorf("expected 1 policy, got %d", len(policy.Policies))
|
||||
}
|
||||
if len(policy.Policies[0].Permissions) != 2 {
|
||||
t.Errorf("expected 2 permissions, got %d", len(policy.Policies[0].Permissions))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.4: solo user-groups.yaml → user groups poblados, resto vacío ---
|
||||
|
||||
func TestLoad_OnlyUserGroups(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeFile(t, dir, "user-groups.yaml", `
|
||||
groups:
|
||||
admins:
|
||||
members: ["@admin:example.com"]
|
||||
`)
|
||||
|
||||
policy, err := shellsecurity.Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(policy.UserGroups) != 1 {
|
||||
t.Errorf("expected 1 user group, got %d", len(policy.UserGroups))
|
||||
}
|
||||
if policy.UserGroups[0].Name != "admins" {
|
||||
t.Errorf("expected group name 'admins', got %q", policy.UserGroups[0].Name)
|
||||
}
|
||||
if len(policy.AgentGroups) != 0 {
|
||||
t.Errorf("expected no agent groups, got %d", len(policy.AgentGroups))
|
||||
}
|
||||
if len(policy.Policies) != 0 {
|
||||
t.Errorf("expected no policies, got %d", len(policy.Policies))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.5: YAML malformado → error con nombre de archivo en el mensaje ---
|
||||
|
||||
func TestLoad_MalformedYAML(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeFile(t, dir, "user-groups.yaml", `this: is: not: valid: yaml: [`)
|
||||
|
||||
_, err := shellsecurity.Load(dir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for malformed YAML, got nil")
|
||||
}
|
||||
if got := err.Error(); len(got) == 0 {
|
||||
t.Fatal("error message is empty")
|
||||
}
|
||||
// Must mention the filename
|
||||
if !containsString(err.Error(), "user-groups.yaml") {
|
||||
t.Errorf("error message should contain filename, got: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test 3.6: "*" como string literal en members y agents ---
|
||||
|
||||
func TestLoad_WildcardStrings(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeFile(t, dir, "user-groups.yaml", `
|
||||
groups:
|
||||
everyone:
|
||||
members: ["*"]
|
||||
`)
|
||||
writeFile(t, dir, "agent-groups.yaml", `
|
||||
groups:
|
||||
all:
|
||||
agents: ["*"]
|
||||
`)
|
||||
|
||||
policy, err := shellsecurity.Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(policy.UserGroups) != 1 {
|
||||
t.Fatalf("expected 1 user group, got %d", len(policy.UserGroups))
|
||||
}
|
||||
if len(policy.UserGroups[0].Members) != 1 || policy.UserGroups[0].Members[0] != "*" {
|
||||
t.Errorf("expected members=[\"*\"], got %v", policy.UserGroups[0].Members)
|
||||
}
|
||||
|
||||
if len(policy.AgentGroups) != 1 {
|
||||
t.Fatalf("expected 1 agent group, got %d", len(policy.AgentGroups))
|
||||
}
|
||||
if len(policy.AgentGroups[0].Agents) != 1 || policy.AgentGroups[0].Agents[0] != "*" {
|
||||
t.Errorf("expected agents=[\"*\"], got %v", policy.AgentGroups[0].Agents)
|
||||
}
|
||||
}
|
||||
|
||||
func containsString(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub))
|
||||
}
|
||||
|
||||
func containsSubstr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Executor ejecuta scripts de skills de forma segura con allowlist de interpreters.
|
||||
type Executor struct {
|
||||
allowedInterpreters []string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewExecutor crea un nuevo Executor con la configuracion dada.
|
||||
// Si allowedInterpreters esta vacio, se usa un default de ["bash", "sh"].
|
||||
func NewExecutor(allowedInterpreters []string, timeout time.Duration) *Executor {
|
||||
if len(allowedInterpreters) == 0 {
|
||||
allowedInterpreters = []string{"bash", "sh"}
|
||||
}
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
return &Executor{
|
||||
allowedInterpreters: allowedInterpreters,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Run ejecuta un script de skill con los argumentos dados.
|
||||
// scriptPath es la ruta absoluta al script.
|
||||
// args son los argumentos pasados al script.
|
||||
//
|
||||
// El script debe tener una extension reconocida (.sh, .bash, .py, etc.) o
|
||||
// un shebang que indique el interprete.
|
||||
//
|
||||
// Retorna stdout+stderr combinados y error si falla.
|
||||
func (e *Executor) Run(ctx context.Context, scriptPath string, args []string) (string, error) {
|
||||
// Inferir interprete desde extension
|
||||
interpreter, err := e.inferInterpreter(scriptPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validar que el interprete esta en la allowlist
|
||||
if !e.isAllowed(interpreter) {
|
||||
return "", fmt.Errorf("interpreter not allowed: %s (allowed: %v)", interpreter, e.allowedInterpreters)
|
||||
}
|
||||
|
||||
// Construir comando
|
||||
cmdArgs := append([]string{scriptPath}, args...)
|
||||
cmd := exec.CommandContext(ctx, interpreter, cmdArgs...)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Aplicar timeout
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, e.timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd = exec.CommandContext(timeoutCtx, interpreter, cmdArgs...)
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
output := stdout.String() + stderr.String()
|
||||
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return output, fmt.Errorf("script timeout exceeded (%s)", e.timeout)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return output, fmt.Errorf("script failed: %w", err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// inferInterpreter detecta el interprete a usar desde la extension del archivo.
|
||||
func (e *Executor) inferInterpreter(path string) (string, error) {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
|
||||
switch ext {
|
||||
case ".sh", ".bash":
|
||||
return "bash", nil
|
||||
case ".py":
|
||||
return "python3", nil
|
||||
case ".rb":
|
||||
return "ruby", nil
|
||||
case ".js":
|
||||
return "node", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported script extension: %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowed verifica si un interprete esta en la allowlist.
|
||||
func (e *Executor) isAllowed(interpreter string) bool {
|
||||
for _, allowed := range e.allowedInterpreters {
|
||||
if allowed == interpreter {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExecutor(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a simple bash script
|
||||
scriptPath := filepath.Join(tmpDir, "test.sh")
|
||||
scriptContent := `#!/bin/bash
|
||||
echo "Hello from script"
|
||||
echo "Args: $@"
|
||||
`
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a script that times out
|
||||
timeoutScriptPath := filepath.Join(tmpDir, "timeout.sh")
|
||||
timeoutContent := `#!/bin/bash
|
||||
sleep 10
|
||||
`
|
||||
if err := os.WriteFile(timeoutScriptPath, []byte(timeoutContent), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a failing script
|
||||
failScriptPath := filepath.Join(tmpDir, "fail.sh")
|
||||
failContent := `#!/bin/bash
|
||||
exit 1
|
||||
`
|
||||
if err := os.WriteFile(failScriptPath, []byte(failContent), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
executor := NewExecutor([]string{"bash", "sh"}, 2*time.Second)
|
||||
|
||||
// Test successful execution
|
||||
t.Run("successful_execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
output, err := executor.Run(ctx, scriptPath, []string{"arg1", "arg2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Run failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Hello from script") {
|
||||
t.Errorf("expected 'Hello from script' in output, got: %q", output)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Args: arg1 arg2") {
|
||||
t.Errorf("expected 'Args: arg1 arg2' in output, got: %q", output)
|
||||
}
|
||||
})
|
||||
|
||||
// Test timeout
|
||||
t.Run("timeout", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := executor.Run(ctx, timeoutScriptPath, nil)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "timeout") {
|
||||
t.Errorf("expected timeout error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test script failure
|
||||
t.Run("script_failure", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := executor.Run(ctx, failScriptPath, nil)
|
||||
if err == nil {
|
||||
t.Error("expected script failure error")
|
||||
}
|
||||
})
|
||||
|
||||
// Test disallowed interpreter
|
||||
t.Run("disallowed_interpreter", func(t *testing.T) {
|
||||
pyScriptPath := filepath.Join(tmpDir, "test.py")
|
||||
pyContent := `#!/usr/bin/env python3
|
||||
print("hello")
|
||||
`
|
||||
if err := os.WriteFile(pyScriptPath, []byte(pyContent), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := executor.Run(ctx, pyScriptPath, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for disallowed interpreter")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "not allowed") {
|
||||
t.Errorf("expected 'not allowed' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test allowed python interpreter
|
||||
t.Run("allowed_python", func(t *testing.T) {
|
||||
pyExecutor := NewExecutor([]string{"python3"}, 2*time.Second)
|
||||
|
||||
pyScriptPath := filepath.Join(tmpDir, "hello.py")
|
||||
pyContent := `#!/usr/bin/env python3
|
||||
print("Hello from Python")
|
||||
`
|
||||
if err := os.WriteFile(pyScriptPath, []byte(pyContent), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
output, err := pyExecutor.Run(ctx, pyScriptPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Run failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Hello from Python") {
|
||||
t.Errorf("expected 'Hello from Python' in output, got: %q", output)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/skills"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Loader descubre y carga skills desde un directorio base.
|
||||
type Loader struct {
|
||||
basePath string
|
||||
}
|
||||
|
||||
// NewLoader crea un nuevo Loader apuntando al directorio de skills.
|
||||
func NewLoader(basePath string) *Loader {
|
||||
return &Loader{basePath: basePath}
|
||||
}
|
||||
|
||||
// LoadMeta carga solo la metadata (nivel 1) de todas las skills.
|
||||
// Recorre el directorio base buscando SKILL.md y extrae el frontmatter YAML.
|
||||
func (l *Loader) LoadMeta() ([]skills.SkillMeta, error) {
|
||||
var metas []skills.SkillMeta
|
||||
|
||||
// Recorre categorias (devops/, analysis/, etc.)
|
||||
categories, err := os.ReadDir(l.basePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read skills dir: %w", err)
|
||||
}
|
||||
|
||||
for _, catEntry := range categories {
|
||||
if !catEntry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
category := catEntry.Name()
|
||||
catPath := filepath.Join(l.basePath, category)
|
||||
|
||||
// Recorre skills dentro de la categoria
|
||||
skillDirs, err := os.ReadDir(catPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, skillEntry := range skillDirs {
|
||||
if !skillEntry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
skillName := skillEntry.Name()
|
||||
skillPath := filepath.Join(catPath, skillName)
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
|
||||
// Verificar que existe SKILL.md
|
||||
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parsear metadata
|
||||
meta, _, err := parseSkillMD(skillMdPath)
|
||||
if err != nil {
|
||||
continue // skip invalid skills
|
||||
}
|
||||
|
||||
meta.Category = category
|
||||
metas = append(metas, meta)
|
||||
}
|
||||
}
|
||||
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
// LoadSkill carga una skill completa (nivel 2) por nombre.
|
||||
// Retorna el struct Skill con metadata, instrucciones y listado de recursos.
|
||||
func (l *Loader) LoadSkill(name string) (*skills.Skill, error) {
|
||||
// Buscar en todas las categorias
|
||||
categories, err := os.ReadDir(l.basePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read skills dir: %w", err)
|
||||
}
|
||||
|
||||
for _, catEntry := range categories {
|
||||
if !catEntry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
category := catEntry.Name()
|
||||
skillPath := filepath.Join(l.basePath, category, name)
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
|
||||
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parsear skill completa
|
||||
meta, instructions, err := parseSkillMD(skillMdPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse %s: %w", skillMdPath, err)
|
||||
}
|
||||
|
||||
meta.Category = category
|
||||
|
||||
skill := &skills.Skill{
|
||||
Meta: meta,
|
||||
Instructions: instructions,
|
||||
BasePath: skillPath,
|
||||
Scripts: listFiles(filepath.Join(skillPath, "scripts")),
|
||||
References: listFiles(filepath.Join(skillPath, "references")),
|
||||
Templates: listFiles(filepath.Join(skillPath, "templates")),
|
||||
Assets: listFiles(filepath.Join(skillPath, "assets")),
|
||||
}
|
||||
|
||||
return skill, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("skill not found: %s", name)
|
||||
}
|
||||
|
||||
// ReadResource lee un recurso especifico (nivel 3) de una skill.
|
||||
// path es relativo a la skill (ej: "scripts/deploy.sh", "references/api.md").
|
||||
func (l *Loader) ReadResource(skillName, resourcePath string) (string, error) {
|
||||
skill, err := l.LoadSkill(skillName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(skill.BasePath, resourcePath)
|
||||
|
||||
// Validar que el path esta dentro de la skill (evitar path traversal)
|
||||
absBasePath, err := filepath.Abs(skill.BasePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("abs base path: %w", err)
|
||||
}
|
||||
|
||||
absFullPath, err := filepath.Abs(fullPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("abs resource path: %w", err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(absFullPath, absBasePath) {
|
||||
return "", fmt.Errorf("path traversal detected: %s", resourcePath)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absFullPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read resource: %w", err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// parseSkillMD extrae el frontmatter YAML y el cuerpo markdown de un SKILL.md.
|
||||
func parseSkillMD(path string) (skills.SkillMeta, string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return skills.SkillMeta{}, "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
var yamlLines []string
|
||||
var bodyLines []string
|
||||
inYAML := false
|
||||
yamlClosed := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if strings.TrimSpace(line) == "---" {
|
||||
if !inYAML {
|
||||
inYAML = true
|
||||
continue
|
||||
} else {
|
||||
inYAML = false
|
||||
yamlClosed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if inYAML {
|
||||
yamlLines = append(yamlLines, line)
|
||||
} else if yamlClosed {
|
||||
bodyLines = append(bodyLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return skills.SkillMeta{}, "", err
|
||||
}
|
||||
|
||||
// Parse YAML frontmatter
|
||||
var meta skills.SkillMeta
|
||||
yamlStr := strings.Join(yamlLines, "\n")
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &meta); err != nil {
|
||||
return skills.SkillMeta{}, "", fmt.Errorf("parse yaml: %w", err)
|
||||
}
|
||||
|
||||
// Cuerpo markdown
|
||||
body := strings.Join(bodyLines, "\n")
|
||||
|
||||
return meta, body, nil
|
||||
}
|
||||
|
||||
// listFiles retorna una lista de archivos (rutas relativas) dentro de un directorio.
|
||||
// Si el directorio no existe, retorna una lista vacia.
|
||||
func listFiles(dir string) []string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
files = append(files, entry.Name())
|
||||
}
|
||||
}
|
||||
return files
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoader(t *testing.T) {
|
||||
// Create temporary skills directory structure
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test skill
|
||||
skillDir := filepath.Join(tmpDir, "devops", "test-skill")
|
||||
if err := os.MkdirAll(skillDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Write SKILL.md
|
||||
skillMD := `---
|
||||
name: test-skill
|
||||
description: A test skill for unit testing
|
||||
---
|
||||
|
||||
# Test Skill
|
||||
|
||||
This is the instructions body.
|
||||
It has multiple lines.
|
||||
`
|
||||
skillMDPath := filepath.Join(skillDir, "SKILL.md")
|
||||
if err := os.WriteFile(skillMDPath, []byte(skillMD), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create scripts/ directory with a test script
|
||||
scriptsDir := filepath.Join(skillDir, "scripts")
|
||||
if err := os.MkdirAll(scriptsDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
scriptPath := filepath.Join(scriptsDir, "test.sh")
|
||||
if err := os.WriteFile(scriptPath, []byte("#!/bin/bash\necho test"), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create references/ directory with a test reference
|
||||
refsDir := filepath.Join(skillDir, "references")
|
||||
if err := os.MkdirAll(refsDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
refPath := filepath.Join(refsDir, "api.md")
|
||||
if err := os.WriteFile(refPath, []byte("# API Reference"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loader := NewLoader(tmpDir)
|
||||
|
||||
// Test LoadMeta
|
||||
t.Run("LoadMeta", func(t *testing.T) {
|
||||
metas, err := loader.LoadMeta()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadMeta failed: %v", err)
|
||||
}
|
||||
|
||||
if len(metas) != 1 {
|
||||
t.Fatalf("expected 1 skill, got %d", len(metas))
|
||||
}
|
||||
|
||||
meta := metas[0]
|
||||
if meta.Name != "test-skill" {
|
||||
t.Errorf("expected name 'test-skill', got %q", meta.Name)
|
||||
}
|
||||
if meta.Category != "devops" {
|
||||
t.Errorf("expected category 'devops', got %q", meta.Category)
|
||||
}
|
||||
if meta.Description != "A test skill for unit testing" {
|
||||
t.Errorf("expected description 'A test skill for unit testing', got %q", meta.Description)
|
||||
}
|
||||
})
|
||||
|
||||
// Test LoadSkill
|
||||
t.Run("LoadSkill", func(t *testing.T) {
|
||||
skill, err := loader.LoadSkill("test-skill")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSkill failed: %v", err)
|
||||
}
|
||||
|
||||
if skill.Meta.Name != "test-skill" {
|
||||
t.Errorf("expected name 'test-skill', got %q", skill.Meta.Name)
|
||||
}
|
||||
|
||||
if skill.Instructions == "" {
|
||||
t.Error("instructions should not be empty")
|
||||
}
|
||||
|
||||
if len(skill.Scripts) != 1 || skill.Scripts[0] != "test.sh" {
|
||||
t.Errorf("expected Scripts=['test.sh'], got %v", skill.Scripts)
|
||||
}
|
||||
|
||||
if len(skill.References) != 1 || skill.References[0] != "api.md" {
|
||||
t.Errorf("expected References=['api.md'], got %v", skill.References)
|
||||
}
|
||||
})
|
||||
|
||||
// Test LoadSkill nonexistent
|
||||
t.Run("LoadSkill_nonexistent", func(t *testing.T) {
|
||||
_, err := loader.LoadSkill("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent skill")
|
||||
}
|
||||
})
|
||||
|
||||
// Test ReadResource
|
||||
t.Run("ReadResource", func(t *testing.T) {
|
||||
content, err := loader.ReadResource("test-skill", "scripts/test.sh")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadResource failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "#!/bin/bash\necho test" {
|
||||
t.Errorf("unexpected content: %q", content)
|
||||
}
|
||||
})
|
||||
|
||||
// Test ReadResource path traversal protection
|
||||
t.Run("ReadResource_path_traversal", func(t *testing.T) {
|
||||
_, err := loader.ReadResource("test-skill", "../../../etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for path traversal attempt")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Package ssh provides impure SSH command execution.
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/tools"
|
||||
"github.com/enmanuel/agents/shell/logger"
|
||||
)
|
||||
|
||||
// Result holds the output of an SSH command execution.
|
||||
type Result struct {
|
||||
Stdout string
|
||||
Stderr string
|
||||
ExitCode int
|
||||
Err error
|
||||
}
|
||||
|
||||
// Executor runs SSH commands against configured targets.
|
||||
type Executor struct {
|
||||
cfg config.SSHCfg
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewExecutor creates an Executor from the SSH config section.
|
||||
func NewExecutor(cfg config.SSHCfg, log *slog.Logger) *Executor {
|
||||
return &Executor{cfg: cfg, logger: log.With(logger.FieldComponent, "ssh")}
|
||||
}
|
||||
|
||||
// Execute runs the SSH command described by spec. Impure.
|
||||
func (e *Executor) Execute(ctx context.Context, spec tools.SSHCommandSpec) Result {
|
||||
cmdPreview := spec.Command
|
||||
if len(cmdPreview) > 80 {
|
||||
cmdPreview = cmdPreview[:80] + "..."
|
||||
}
|
||||
e.logger.Info("ssh_exec_start", "target", spec.Target, "command", cmdPreview)
|
||||
start := time.Now()
|
||||
|
||||
target, ok := e.cfg.Targets[spec.Target]
|
||||
if !ok {
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, "err", "unknown target")
|
||||
return Result{Err: fmt.Errorf("unknown SSH target: %s", spec.Target)}
|
||||
}
|
||||
|
||||
if len(target.Hosts) == 0 {
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, "err", "no hosts")
|
||||
return Result{Err: fmt.Errorf("no hosts for target: %s", spec.Target)}
|
||||
}
|
||||
|
||||
// Use first host (round-robin or load balancing can be added later)
|
||||
host := target.Hosts[0]
|
||||
user := target.User
|
||||
if user == "" {
|
||||
user = e.cfg.Defaults.User
|
||||
}
|
||||
port := target.Port
|
||||
if port == 0 {
|
||||
port = e.cfg.Defaults.Port
|
||||
}
|
||||
if port == 0 {
|
||||
port = 22
|
||||
}
|
||||
|
||||
keyEnv := target.KeyFileEnv
|
||||
if keyEnv == "" {
|
||||
keyEnv = e.cfg.Defaults.KeyFileEnv
|
||||
}
|
||||
|
||||
signer, err := loadSigner(keyEnv)
|
||||
if err != nil {
|
||||
ms := time.Since(start).Milliseconds()
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
|
||||
return Result{Err: fmt.Errorf("load SSH key: %w", err)}
|
||||
}
|
||||
|
||||
sshCfg := &gossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)},
|
||||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), // TODO: use known_hosts
|
||||
Timeout: e.cfg.Defaults.Timeout,
|
||||
}
|
||||
if sshCfg.Timeout == 0 {
|
||||
sshCfg.Timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
conn, err := gossh.Dial("tcp", addr, sshCfg)
|
||||
if err != nil {
|
||||
ms := time.Since(start).Milliseconds()
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, "host", addr, logger.FieldDurationMS, ms, "err", err)
|
||||
return Result{Err: fmt.Errorf("ssh dial %s: %w", addr, err)}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
ms := time.Since(start).Milliseconds()
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
|
||||
return Result{Err: fmt.Errorf("ssh session: %w", err)}
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
session.Stderr = &stderr
|
||||
|
||||
// Respect context cancellation via a goroutine
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- session.Run(spec.Command) }()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
session.Signal(gossh.SIGTERM)
|
||||
ms := time.Since(start).Milliseconds()
|
||||
e.logger.Warn("ssh_exec_cancelled", "target", spec.Target, logger.FieldDurationMS, ms)
|
||||
return Result{Err: ctx.Err()}
|
||||
case err := <-done:
|
||||
ms := time.Since(start).Milliseconds()
|
||||
code := 0
|
||||
if err != nil {
|
||||
var exitErr *gossh.ExitError
|
||||
if ok := asExitError(err, &exitErr); ok {
|
||||
code = exitErr.ExitStatus()
|
||||
} else {
|
||||
e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
|
||||
return Result{Err: err}
|
||||
}
|
||||
}
|
||||
e.logger.Info("ssh_exec_end", "target", spec.Target, "exit_code", code, logger.FieldDurationMS, ms)
|
||||
return Result{
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
ExitCode: code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadSigner(keyFileEnv string) (gossh.Signer, error) {
|
||||
keyPath := os.Getenv(keyFileEnv)
|
||||
if keyPath == "" {
|
||||
return nil, fmt.Errorf("env var %s not set", keyFileEnv)
|
||||
}
|
||||
raw, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gossh.ParsePrivateKey(raw)
|
||||
}
|
||||
|
||||
// asExitError is a helper for type-asserting ssh.ExitError.
|
||||
func asExitError(err error, target **gossh.ExitError) bool {
|
||||
e, ok := err.(*gossh.ExitError)
|
||||
if ok {
|
||||
*target = e
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Ensure net is used (for future jump host support)
|
||||
var _ = net.Dial
|
||||
@@ -0,0 +1,35 @@
|
||||
package transportunibus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/transport"
|
||||
)
|
||||
|
||||
// DemoEchoHandler returns a minimal bot handler that proves the unibus transport
|
||||
// end to end: it receives a neutral InboundMessage and answers in the same room.
|
||||
// It echoes the message body back as a reply, with one built-in command
|
||||
// (!ping → pong) to show command routing works over the bus. It is intentionally
|
||||
// tiny — the point is the transport, not the bot.
|
||||
func DemoEchoHandler(t transport.Transport, logger *slog.Logger) transport.Handler {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return func(ctx context.Context, in transport.InboundMessage) {
|
||||
reply := "echo: " + in.Body
|
||||
if strings.TrimSpace(in.Body) == "!ping" {
|
||||
reply = "pong"
|
||||
}
|
||||
out := transport.OutboundReply{
|
||||
RoomID: in.RoomID,
|
||||
ReplyTo: in.MsgID,
|
||||
ThreadID: in.ThreadID,
|
||||
Markdown: reply,
|
||||
}
|
||||
if err := t.Reply(ctx, out); err != nil {
|
||||
logger.Error("demo echo reply failed", "err", err, "sender", in.SenderID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
// Package transportunibus implements transport.Transport over the unibus message
|
||||
// bus (github.com/enmanuel/unibus). A bot built on the neutral
|
||||
// transport.Transport speaks unibus instead of Matrix: it discovers the rooms it
|
||||
// has been invited to, joins them, and replies in the room a message arrived on.
|
||||
//
|
||||
// Room-based model ("everything is a room"):
|
||||
//
|
||||
// - There is no inbox/outbox subject convention. A conversation is a unibus
|
||||
// room; a 1:1 DM is just a room with two members. A human peer creates an
|
||||
// encrypted room (room.ModeMatrix), invites the bot by its endpoint id, and
|
||||
// publishes a message. The bot finds the room by polling ListMyRooms,
|
||||
// Joins (fetching the sealed room key), Subscribes, and answers in place.
|
||||
// - The control plane is pull-based: there is no server push of invitations,
|
||||
// so the bot polls ListMyRooms on a ticker and reacts to rooms it has not
|
||||
// seen before.
|
||||
//
|
||||
// This adapter carries no Matrix (mautrix) types, so the agent core driving it
|
||||
// stays transport-neutral.
|
||||
package transportunibus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/message"
|
||||
"github.com/enmanuel/agents/pkg/transport"
|
||||
"github.com/enmanuel/unibus/pkg/client"
|
||||
"github.com/enmanuel/unibus/pkg/frame"
|
||||
)
|
||||
|
||||
// defaultCommandPrefix marks a command message (e.g. "!ping") when the bot's
|
||||
// config does not override it.
|
||||
const defaultCommandPrefix = "!"
|
||||
|
||||
// discoveryInterval is how often the bot polls the control plane for rooms it
|
||||
// has been invited to. The control plane has no push, so this is the latency a
|
||||
// human waits between inviting the bot and the bot joining.
|
||||
const discoveryInterval = 2 * time.Second
|
||||
|
||||
// Transport is a unibus-backed transport.Transport for one bot. It discovers
|
||||
// rooms, subscribes to them, and replies in the room each message came from.
|
||||
type Transport struct {
|
||||
handle string
|
||||
commandPrefix string
|
||||
client *client.Client
|
||||
endpoint string // this bot's own endpoint id, to skip its own messages
|
||||
ctrlURL string
|
||||
http *http.Client
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
subscribed map[string]*client.Sub // roomID -> active subscription
|
||||
memberCount map[string]int // roomID -> cached member count (for IsDirectMsg)
|
||||
}
|
||||
|
||||
// compile-time assertion that Transport satisfies the neutral interface.
|
||||
var _ transport.Transport = (*Transport)(nil)
|
||||
|
||||
// New connects to a unibus deployment using the bot's BusCfg. It loads (or
|
||||
// creates) the bot's long-term identity, connects to the NATS data plane and
|
||||
// the membershipd control plane, and records the handle used for mention
|
||||
// detection. It does not create or join any room: rooms are discovered at Run
|
||||
// time as the bot is invited to them.
|
||||
func New(busCfg config.BusCfg, logger *slog.Logger) (*Transport, error) {
|
||||
id, err := client.LoadOrCreateIdentity(busCfg.IdentityPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transportunibus: identity: %w", err)
|
||||
}
|
||||
c, err := client.New(busCfg.NatsURL, busCfg.CtrlURL, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transportunibus: connect: %w", err)
|
||||
}
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
prefix := busCfg.CommandPrefix
|
||||
if prefix == "" {
|
||||
prefix = defaultCommandPrefix
|
||||
}
|
||||
return &Transport{
|
||||
handle: busCfg.Handle,
|
||||
commandPrefix: prefix,
|
||||
client: c,
|
||||
endpoint: c.Endpoint().ID,
|
||||
ctrlURL: busCfg.CtrlURL,
|
||||
http: &http.Client{Timeout: 10 * time.Second},
|
||||
logger: logger,
|
||||
subscribed: map[string]*client.Sub{},
|
||||
memberCount: map[string]int{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Endpoint returns this bot's public endpoint id. A human peer needs it to
|
||||
// invite the bot to a room (the bot logs it at startup; a directory is a later
|
||||
// step).
|
||||
func (t *Transport) Endpoint() string { return t.endpoint }
|
||||
|
||||
// BusEndpoint returns this bot's full public endpoint (id + signing/key-exchange
|
||||
// public keys). A peer inviting the bot to an encrypted room needs the public
|
||||
// keys to seal the room key for it.
|
||||
func (t *Transport) BusEndpoint() client.Endpoint { return t.client.Endpoint() }
|
||||
|
||||
// Run polls the control plane for rooms the bot has been invited to, joins and
|
||||
// subscribes to each new one, and delivers every decrypted frame to handler as
|
||||
// a neutral InboundMessage. It blocks until ctx is cancelled.
|
||||
func (t *Transport) Run(ctx context.Context, handler transport.Handler) error {
|
||||
t.logger.Info("unibus transport running", "handle", t.handle, "endpoint", t.endpoint)
|
||||
|
||||
ticker := time.NewTicker(discoveryInterval)
|
||||
defer ticker.Stop()
|
||||
defer t.unsubscribeAll()
|
||||
|
||||
// Discover immediately so we don't wait a full interval on startup.
|
||||
t.discover(ctx, handler)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
t.discover(ctx, handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// discover lists the bot's rooms and joins+subscribes to any it has not seen.
|
||||
func (t *Transport) discover(ctx context.Context, handler transport.Handler) {
|
||||
rooms, err := t.client.ListMyRooms()
|
||||
if err != nil {
|
||||
t.logger.Warn("unibus discover: list rooms failed", "err", err)
|
||||
return
|
||||
}
|
||||
for _, r := range rooms {
|
||||
t.mu.Lock()
|
||||
_, already := t.subscribed[r.RoomID]
|
||||
t.mu.Unlock()
|
||||
if already {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := t.client.Join(r.RoomID); err != nil {
|
||||
t.logger.Warn("unibus discover: join failed", "room", r.RoomID, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
roomID := r.RoomID
|
||||
sub, err := t.client.Subscribe(roomID, func(f frame.Frame, plaintext []byte) {
|
||||
t.onFrame(ctx, handler, roomID, f, plaintext)
|
||||
})
|
||||
if err != nil {
|
||||
t.logger.Warn("unibus discover: subscribe failed", "room", roomID, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.subscribed[roomID] = sub
|
||||
t.mu.Unlock()
|
||||
t.logger.Info("joined and subscribed to room", "room", roomID, "subject", r.Subject)
|
||||
}
|
||||
}
|
||||
|
||||
// onFrame maps a decrypted frame to a neutral InboundMessage and delivers it.
|
||||
// It skips the bot's own messages (to avoid replying to itself), parses any
|
||||
// command, and computes IsDirectMsg (2-member room) and IsMention (handle in
|
||||
// body) so the agent core's command/LLM flow behaves exactly as it did on
|
||||
// Matrix.
|
||||
func (t *Transport) onFrame(ctx context.Context, handler transport.Handler, roomID string, f frame.Frame, plaintext []byte) {
|
||||
if f.Sender == t.endpoint {
|
||||
return // never react to our own messages
|
||||
}
|
||||
|
||||
body := string(plaintext)
|
||||
isDM := t.roomMemberCount(roomID) == 2
|
||||
isMention := t.handle != "" && strings.Contains(strings.ToLower(body), strings.ToLower(t.handle))
|
||||
|
||||
// Reuse the pure command parser so "!cmd args" is split the same way the
|
||||
// Matrix listener split it.
|
||||
parsed := message.Parse(body, f.Sender, roomID, 0, isDM, message.ParseOptions{
|
||||
CommandPrefix: t.commandPrefix,
|
||||
})
|
||||
|
||||
handler(ctx, transport.InboundMessage{
|
||||
RoomID: roomID,
|
||||
Subject: f.Subject,
|
||||
SenderID: f.Sender,
|
||||
MsgID: f.MsgID,
|
||||
ThreadID: f.ThreadID,
|
||||
ReplyTo: f.ReplyTo,
|
||||
Body: body,
|
||||
Command: parsed.Command,
|
||||
Args: parsed.Args,
|
||||
IsDirectMsg: isDM,
|
||||
IsMention: isMention,
|
||||
})
|
||||
}
|
||||
|
||||
// Reply publishes a reply into the room the message came from. When the reply
|
||||
// carries a ReplyTo / ThreadID anchor it is published as a threaded reply so
|
||||
// receivers can render the conversation tree.
|
||||
func (t *Transport) Reply(_ context.Context, out transport.OutboundReply) error {
|
||||
if out.ReplyTo != "" || out.ThreadID != "" {
|
||||
return t.client.PublishReply(out.RoomID, []byte(out.Markdown), out.ReplyTo, out.ThreadID)
|
||||
}
|
||||
return t.client.Publish(out.RoomID, []byte(out.Markdown))
|
||||
}
|
||||
|
||||
// Send posts a standalone message into a room.
|
||||
func (t *Transport) Send(_ context.Context, roomID, markdown string) error {
|
||||
return t.client.Publish(roomID, []byte(markdown))
|
||||
}
|
||||
|
||||
// Close unsubscribes from every room and releases the unibus client connection.
|
||||
func (t *Transport) Close() error {
|
||||
t.unsubscribeAll()
|
||||
return t.client.Close()
|
||||
}
|
||||
|
||||
// Sender returns an adapter that satisfies the effects/cron/tools Sender
|
||||
// interface, letting the agent's effects runner, scheduler, and bus_send tool
|
||||
// publish into rooms over this transport.
|
||||
func (t *Transport) Sender() *busSender { return &busSender{t: t} }
|
||||
|
||||
// unsubscribeAll cancels every active room subscription.
|
||||
func (t *Transport) unsubscribeAll() {
|
||||
t.mu.Lock()
|
||||
subs := t.subscribed
|
||||
t.subscribed = map[string]*client.Sub{}
|
||||
t.mu.Unlock()
|
||||
for roomID, sub := range subs {
|
||||
if err := sub.Unsubscribe(); err != nil {
|
||||
t.logger.Warn("unibus: unsubscribe failed", "room", roomID, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// roomMemberCount returns the number of members in a room, used to decide
|
||||
// IsDirectMsg. The control plane exposes GET /rooms/{id}/members; the result is
|
||||
// cached per room since membership rarely changes during a conversation.
|
||||
func (t *Transport) roomMemberCount(roomID string) int {
|
||||
t.mu.Lock()
|
||||
if n, ok := t.memberCount[roomID]; ok {
|
||||
t.mu.Unlock()
|
||||
return n
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
n, err := t.fetchMemberCount(roomID)
|
||||
if err != nil {
|
||||
t.logger.Warn("unibus: member count fetch failed", "room", roomID, "err", err)
|
||||
return 0 // unknown → treat as not-a-DM (mention still drives the LLM)
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.memberCount[roomID] = n
|
||||
t.mu.Unlock()
|
||||
return n
|
||||
}
|
||||
|
||||
// memberJSON mirrors the membership server's GET /rooms/{id}/members element.
|
||||
// Only the count matters here, so the body fields are ignored.
|
||||
type memberJSON struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
// fetchMemberCount calls the membershipd control plane directly to count the
|
||||
// members of a room. unibus's client does not expose this, and the task forbids
|
||||
// modifying unibus, so the minimal HTTP GET lives here.
|
||||
func (t *Transport) fetchMemberCount(roomID string) (int, error) {
|
||||
url := strings.TrimRight(t.ctrlURL, "/") + "/rooms/" + roomID + "/members"
|
||||
resp, err := t.http.Get(url)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get members: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 300 {
|
||||
return 0, fmt.Errorf("get members: status %d", resp.StatusCode)
|
||||
}
|
||||
var members []memberJSON
|
||||
if err := json.NewDecoder(resp.Body).Decode(&members); err != nil {
|
||||
return 0, fmt.Errorf("decode members: %w", err)
|
||||
}
|
||||
return len(members), nil
|
||||
}
|
||||
|
||||
// busSender adapts a *Transport to the effects.Sender / cron.Sender / tools
|
||||
// Sender interface (SendText/SendMarkdown/SendReplyMarkdown/SendThreadMarkdown/
|
||||
// SendTyping). All sends publish into the given room; SendTyping is a no-op
|
||||
// because unibus has no typing-indicator concept.
|
||||
type busSender struct{ t *Transport }
|
||||
|
||||
func (s *busSender) SendText(_ context.Context, roomID, text string) error {
|
||||
return s.t.client.Publish(roomID, []byte(text))
|
||||
}
|
||||
|
||||
func (s *busSender) SendMarkdown(_ context.Context, roomID, markdown string) error {
|
||||
return s.t.client.Publish(roomID, []byte(markdown))
|
||||
}
|
||||
|
||||
func (s *busSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error {
|
||||
return s.t.client.PublishReply(roomID, []byte(markdown), inReplyTo, "")
|
||||
}
|
||||
|
||||
func (s *busSender) SendThreadMarkdown(_ context.Context, roomID, threadRootID, inReplyTo, markdown string) error {
|
||||
return s.t.client.PublishReply(roomID, []byte(markdown), inReplyTo, threadRootID)
|
||||
}
|
||||
|
||||
// SendTyping is a no-op: unibus has no typing indicator.
|
||||
func (s *busSender) SendTyping(_ context.Context, _ string, _ bool) error { return nil }
|
||||
@@ -0,0 +1,243 @@
|
||||
package transportunibus_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/transport"
|
||||
"github.com/enmanuel/agents/shell/transportunibus"
|
||||
|
||||
cs "fn-registry/functions/cybersecurity"
|
||||
|
||||
"github.com/enmanuel/unibus/pkg/blobstore"
|
||||
"github.com/enmanuel/unibus/pkg/client"
|
||||
"github.com/enmanuel/unibus/pkg/embeddednats"
|
||||
"github.com/enmanuel/unibus/pkg/frame"
|
||||
"github.com/enmanuel/unibus/pkg/membership"
|
||||
"github.com/enmanuel/unibus/pkg/room"
|
||||
server "github.com/nats-io/nats-server/v2/server"
|
||||
)
|
||||
|
||||
// harness boots an embedded NATS + an in-process membershipd, mirroring the
|
||||
// unibus test harness so this adapter can be exercised without any external
|
||||
// service.
|
||||
type harness struct {
|
||||
natsURL string
|
||||
ctrlURL string
|
||||
ns *server.Server
|
||||
httpts *httptest.Server
|
||||
}
|
||||
|
||||
func freePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("free port: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func newHarness(t *testing.T) *harness {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
ns, err := embeddednats.StartHost(filepath.Join(dir, "js"), "127.0.0.1", freePort(t))
|
||||
if err != nil {
|
||||
t.Fatalf("embedded nats: %v", err)
|
||||
}
|
||||
store, err := membership.Open(filepath.Join(dir, "unibus.db"))
|
||||
if err != nil {
|
||||
ns.Shutdown()
|
||||
t.Fatalf("membership store: %v", err)
|
||||
}
|
||||
blobs, err := blobstore.New(filepath.Join(dir, "blobs"))
|
||||
if err != nil {
|
||||
ns.Shutdown()
|
||||
t.Fatalf("blob store: %v", err)
|
||||
}
|
||||
httpts := httptest.NewServer(membership.NewServer(store, blobs))
|
||||
h := &harness{natsURL: embeddednats.ClientURL(ns), ctrlURL: httpts.URL, ns: ns, httpts: httpts}
|
||||
t.Cleanup(func() {
|
||||
httpts.Close()
|
||||
store.Close()
|
||||
ns.Shutdown()
|
||||
ns.WaitForShutdown()
|
||||
})
|
||||
return h
|
||||
}
|
||||
|
||||
func waitHealth(t *testing.T, ctrlURL string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err := http.Get(ctrlURL + "/healthz")
|
||||
if err == nil && resp.StatusCode == 200 {
|
||||
resp.Body.Close()
|
||||
return
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("membershipd never became healthy")
|
||||
}
|
||||
|
||||
// botCfg builds a BusCfg pointing the bot at the harness, with a fresh identity
|
||||
// file under the test's temp dir.
|
||||
func botCfg(t *testing.T, h *harness, handle string) config.BusCfg {
|
||||
t.Helper()
|
||||
return config.BusCfg{
|
||||
NatsURL: h.natsURL,
|
||||
CtrlURL: h.ctrlURL,
|
||||
IdentityPath: filepath.Join(t.TempDir(), handle+".id"),
|
||||
Handle: handle,
|
||||
}
|
||||
}
|
||||
|
||||
// TestBotEchoesInEncryptedRoom is the headline room-based test: a human peer
|
||||
// creates an encrypted (room.ModeMatrix) room, invites the bot by its endpoint,
|
||||
// and publishes a mention. The bot — driven by Transport.Run + a tiny echo
|
||||
// handler that replies via Reply — answers IN THE SAME room, and the human
|
||||
// receives the reply decrypted. No Matrix is involved end to end.
|
||||
func TestBotEchoesInEncryptedRoom(t *testing.T) {
|
||||
h := newHarness(t)
|
||||
waitHealth(t, h.ctrlURL)
|
||||
|
||||
bot, err := transportunibus.New(botCfg(t, h, "demo"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("bot transport: %v", err)
|
||||
}
|
||||
defer bot.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = bot.Run(ctx, transportunibus.DemoEchoHandler(bot, nil)) }()
|
||||
|
||||
// Human peer.
|
||||
userID, err := cs.GenerateIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("user identity: %v", err)
|
||||
}
|
||||
user, err := client.New(h.natsURL, h.ctrlURL, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("user client: %v", err)
|
||||
}
|
||||
defer user.Close()
|
||||
|
||||
// Human creates an encrypted room and invites the bot by its endpoint id.
|
||||
roomID, err := user.CreateRoom("conv.demo", room.ModeMatrix)
|
||||
if err != nil {
|
||||
t.Fatalf("create room: %v", err)
|
||||
}
|
||||
// Invite the bot by its full endpoint (id + public keys), so the human can
|
||||
// seal the encrypted room key for it.
|
||||
if err := user.Invite(roomID, bot.BusEndpoint()); err != nil {
|
||||
t.Fatalf("invite bot: %v", err)
|
||||
}
|
||||
|
||||
// Human subscribes to the same room to receive the bot's reply.
|
||||
var mu sync.Mutex
|
||||
var bodies []string
|
||||
var sawAnchored bool
|
||||
if err := user.Join(roomID); err != nil {
|
||||
t.Fatalf("user join: %v", err)
|
||||
}
|
||||
sub, err := user.Subscribe(roomID, func(f frame.Frame, plaintext []byte) {
|
||||
mu.Lock()
|
||||
bodies = append(bodies, string(plaintext))
|
||||
if f.ReplyTo != "" {
|
||||
sawAnchored = true
|
||||
}
|
||||
mu.Unlock()
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("user subscribe: %v", err)
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Give the bot's discovery ticker time to find, join and subscribe to the room.
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Human posts a message mentioning the bot's handle.
|
||||
if err := user.Publish(roomID, []byte("hola demo")); err != nil {
|
||||
t.Fatalf("user publish: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := waitBody(&mu, &bodies, "echo: hola demo", 5*time.Second); !ok {
|
||||
t.Fatalf("never received echo reply; got %v", snapshot(&mu, &bodies))
|
||||
}
|
||||
mu.Lock()
|
||||
anchored := sawAnchored
|
||||
mu.Unlock()
|
||||
if !anchored {
|
||||
t.Fatalf("reply did not carry a ReplyTo anchor")
|
||||
}
|
||||
|
||||
// Command over the bus → pong, in the same room.
|
||||
if err := user.Publish(roomID, []byte("!ping")); err != nil {
|
||||
t.Fatalf("user publish ping: %v", err)
|
||||
}
|
||||
if _, ok := waitBody(&mu, &bodies, "pong", 5*time.Second); !ok {
|
||||
t.Fatalf("never received pong; got %v", snapshot(&mu, &bodies))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunStopsOnContextCancel is an error/lifecycle path: Run must return when
|
||||
// its context is cancelled rather than blocking forever.
|
||||
func TestRunStopsOnContextCancel(t *testing.T) {
|
||||
h := newHarness(t)
|
||||
waitHealth(t, h.ctrlURL)
|
||||
|
||||
bot, err := transportunibus.New(botCfg(t, h, "lifecycle"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("bot transport: %v", err)
|
||||
}
|
||||
defer bot.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- bot.Run(ctx, func(context.Context, transport.InboundMessage) {}) }()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Run returned %v, want context.Canceled", err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("Run did not return after context cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func waitBody(mu *sync.Mutex, slice *[]string, want string, timeout time.Duration) (string, bool) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
mu.Lock()
|
||||
for _, s := range *slice {
|
||||
if s == want {
|
||||
mu.Unlock()
|
||||
return s, true
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func snapshot(mu *sync.Mutex, slice *[]string) []string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return append([]string(nil), (*slice)...)
|
||||
}
|
||||
@@ -0,0 +1,419 @@
|
||||
// Package tui is the impure shell layer for the TUI.
|
||||
// It converts pure Intent values into real I/O via tea.Cmd.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
puretui "github.com/enmanuel/agents/pkg/tui"
|
||||
"github.com/enmanuel/agents/shell/process"
|
||||
)
|
||||
|
||||
// Adapter bridges pure Intents with the process Manager.
|
||||
type Adapter struct {
|
||||
mgr *process.Manager
|
||||
}
|
||||
|
||||
// NewAdapter creates an Adapter with the given Manager.
|
||||
func NewAdapter(mgr *process.Manager) *Adapter {
|
||||
return &Adapter{mgr: mgr}
|
||||
}
|
||||
|
||||
// RunIntent converts a pure Intent into a bubbletea Cmd that performs I/O.
|
||||
func (a *Adapter) RunIntent(intent puretui.Intent) tea.Cmd {
|
||||
switch intent.Kind {
|
||||
|
||||
case puretui.IntentLoadAgents:
|
||||
return a.loadAgents()
|
||||
|
||||
case puretui.IntentEnableAgent:
|
||||
return a.enableAgent(intent.AgentID)
|
||||
|
||||
case puretui.IntentDisableAgent:
|
||||
return a.disableAgent(intent.AgentID)
|
||||
|
||||
case puretui.IntentReloadAgent:
|
||||
return a.reloadAgent(intent.AgentID)
|
||||
|
||||
case puretui.IntentReloadAll:
|
||||
return a.reloadAll()
|
||||
|
||||
case puretui.IntentRestartAgent:
|
||||
return a.restartAgent(intent.AgentID)
|
||||
|
||||
case puretui.IntentLoadLogs:
|
||||
return a.loadLogs(intent.AgentID)
|
||||
|
||||
case puretui.IntentStartLauncher:
|
||||
return a.startLauncher()
|
||||
|
||||
case puretui.IntentStopLauncher:
|
||||
return a.stopLauncher()
|
||||
|
||||
case puretui.IntentRestartLauncher:
|
||||
return a.restartLauncher()
|
||||
|
||||
case puretui.IntentKillLauncher:
|
||||
return a.killLauncher()
|
||||
|
||||
case puretui.IntentRebuildRestart:
|
||||
return a.rebuildRestart()
|
||||
|
||||
case puretui.IntentRunTests:
|
||||
return a.runGoTests()
|
||||
|
||||
case puretui.IntentRunGoTests:
|
||||
return a.runGoTests()
|
||||
|
||||
case puretui.IntentRunE2ETests:
|
||||
return a.runE2ETests(false)
|
||||
|
||||
case puretui.IntentRunE2EHeadTests:
|
||||
return a.runE2ETests(true)
|
||||
|
||||
case puretui.IntentRunAllTests:
|
||||
return a.runAllTests()
|
||||
|
||||
case puretui.IntentTick:
|
||||
return a.tick()
|
||||
|
||||
case puretui.IntentQuit:
|
||||
return tea.Quit
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) loadAgents() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
statuses, err := a.mgr.StatusAllUnified()
|
||||
if err != nil {
|
||||
return puretui.MsgAgentsLoaded{}
|
||||
}
|
||||
|
||||
views := make([]puretui.AgentView, len(statuses))
|
||||
for i, s := range statuses {
|
||||
views[i] = puretui.AgentView{
|
||||
ID: s.ID,
|
||||
Name: s.Name,
|
||||
Version: s.Version,
|
||||
Desc: s.Desc,
|
||||
Enabled: s.Enabled,
|
||||
Running: s.Running,
|
||||
PID: s.PID,
|
||||
}
|
||||
}
|
||||
|
||||
msg := puretui.MsgAgentsLoaded{
|
||||
Agents: views,
|
||||
LauncherRunning: a.mgr.IsUnifiedRunning(),
|
||||
LauncherPID: a.mgr.UnifiedPID(),
|
||||
}
|
||||
|
||||
// Launcher stats
|
||||
if msg.LauncherRunning {
|
||||
if stats, err := a.mgr.UnifiedStats(); err == nil {
|
||||
msg.LauncherUptime = formatUptime(stats.UptimeSecs)
|
||||
msg.LauncherMemory = formatBytes(stats.MemRSSKB * 1024)
|
||||
msg.LauncherCPU = fmt.Sprintf("%.1f%%", stats.CPUPct)
|
||||
msg.LauncherLogSize = formatBytes(stats.LogBytes)
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) enableAgent(id string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := a.mgr.ToggleEnabled(id, true)
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Enable", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) disableAgent(id string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := a.mgr.ToggleEnabled(id, false)
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Disable", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// reloadAgent hot-reloads a single agent via SIGHUP without stopping the launcher.
|
||||
func (a *Adapter) reloadAgent(id string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
pid := a.mgr.UnifiedPID()
|
||||
if pid <= 0 {
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Reload",
|
||||
Err: fmt.Errorf("el launcher no está corriendo")}
|
||||
}
|
||||
if id != "" {
|
||||
if err := os.WriteFile("run/reload.txt", []byte(id), 0o644); err != nil {
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Reload", Err: err}
|
||||
}
|
||||
}
|
||||
err := syscall.Kill(pid, syscall.SIGHUP)
|
||||
if err != nil {
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Reload", Err: err}
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Reload", Err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
// reloadAll hot-reloads all agents via SIGHUP (no reload.txt → reload all).
|
||||
func (a *Adapter) reloadAll() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
pid := a.mgr.UnifiedPID()
|
||||
if pid <= 0 {
|
||||
return puretui.MsgServerActionDone{Action: "Reload All",
|
||||
Err: fmt.Errorf("el launcher no está corriendo")}
|
||||
}
|
||||
// Remove stale reload.txt so the launcher reloads all agents.
|
||||
_ = os.Remove("run/reload.txt")
|
||||
err := syscall.Kill(pid, syscall.SIGHUP)
|
||||
if err != nil {
|
||||
return puretui.MsgServerActionDone{Action: "Reload All", Err: err}
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
return puretui.MsgServerActionDone{Action: "Reload All", Err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
// restartAgent stops and restarts the whole launcher (full restart, all agents).
|
||||
func (a *Adapter) restartAgent(id string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
_ = a.mgr.StopUnified()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
err := a.mgr.StartUnified()
|
||||
if err == nil {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
return puretui.MsgActionDone{AgentID: id, Action: "Restart", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) startLauncher() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := a.mgr.StartUnified()
|
||||
if err == nil {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
return puretui.MsgServerActionDone{Action: "Start", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) stopLauncher() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := a.mgr.StopUnified()
|
||||
return puretui.MsgServerActionDone{Action: "Stop", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) restartLauncher() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
_ = a.mgr.StopUnified()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
err := a.mgr.StartUnified()
|
||||
if err == nil {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
return puretui.MsgServerActionDone{Action: "Restart", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) killLauncher() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := a.mgr.KillUnified()
|
||||
return puretui.MsgServerActionDone{Action: "Kill", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) rebuildRestart() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
wasRunning := a.mgr.IsUnifiedRunning()
|
||||
|
||||
// Stop if running
|
||||
if wasRunning {
|
||||
_ = a.mgr.StopUnified()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Build
|
||||
buildOut, buildErr := a.mgr.Build()
|
||||
if buildErr != nil {
|
||||
// Build failed — try to restart if was running
|
||||
if wasRunning {
|
||||
_ = a.mgr.StartUnified()
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(buildOut), "\n")
|
||||
tail := buildOut
|
||||
if len(lines) > 5 {
|
||||
tail = strings.Join(lines[len(lines)-5:], "\n")
|
||||
}
|
||||
return puretui.MsgRebuildDone{BuildOK: false, BuildLog: tail}
|
||||
}
|
||||
|
||||
// Restart launcher
|
||||
started := false
|
||||
var startErr error
|
||||
if wasRunning {
|
||||
startErr = a.mgr.StartUnified()
|
||||
if startErr == nil {
|
||||
started = true
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return puretui.MsgRebuildDone{
|
||||
BuildOK: true,
|
||||
Started: started,
|
||||
Err: startErr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) loadLogs(id string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
var lines []string
|
||||
var err error
|
||||
if id == "" {
|
||||
// Launcher logs
|
||||
lines, err = a.mgr.UnifiedLogTail(100)
|
||||
} else {
|
||||
// Agent logs — in unified mode, all go to launcher log
|
||||
lines, err = a.mgr.UnifiedLogTail(100)
|
||||
}
|
||||
if err != nil {
|
||||
return puretui.MsgLogsLoaded{Lines: []string{"Error: " + err.Error()}}
|
||||
}
|
||||
return puretui.MsgLogsLoaded{Lines: lines}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) runGoTests() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
goBin, err := exec.LookPath("go")
|
||||
if err != nil {
|
||||
goBin = "/usr/local/go/bin/go"
|
||||
}
|
||||
cmd := exec.Command(goBin, "test", "-tags", "goolm", "-count=1", "./...")
|
||||
cmd.Env = a.mgr.BuildEnv()
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
output := strings.TrimSpace(string(out))
|
||||
if output == "" && err != nil {
|
||||
output = "Error: " + err.Error()
|
||||
}
|
||||
lines := strings.Split(output, "\n")
|
||||
return puretui.MsgTestsDone{Kind: puretui.TestKindGo, Passed: err == nil, Output: lines}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) runE2ETests(headed bool) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
args := []string{"./dev-scripts/e2e/run.sh"}
|
||||
if headed {
|
||||
args = append(args, "--headed")
|
||||
}
|
||||
cmd := exec.Command("bash", args...)
|
||||
cmd.Env = a.mgr.BuildEnv()
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
output := strings.TrimSpace(string(out))
|
||||
if output == "" && err != nil {
|
||||
output = "Error: " + err.Error()
|
||||
}
|
||||
lines := strings.Split(output, "\n")
|
||||
kind := puretui.TestKindE2E
|
||||
if headed {
|
||||
kind = puretui.TestKindE2EHead
|
||||
}
|
||||
return puretui.MsgTestsDone{Kind: kind, Passed: err == nil, Output: lines}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) runAllTests() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
var allLines []string
|
||||
|
||||
// Go tests first
|
||||
goBin, err := exec.LookPath("go")
|
||||
if err != nil {
|
||||
goBin = "/usr/local/go/bin/go"
|
||||
}
|
||||
goCmd := exec.Command(goBin, "test", "-tags", "goolm", "-count=1", "./...")
|
||||
goCmd.Env = a.mgr.BuildEnv()
|
||||
goOut, goErr := goCmd.CombinedOutput()
|
||||
|
||||
allLines = append(allLines, "═══ Go Tests ═══")
|
||||
goOutput := strings.TrimSpace(string(goOut))
|
||||
if goOutput == "" && goErr != nil {
|
||||
goOutput = "Error: " + goErr.Error()
|
||||
}
|
||||
allLines = append(allLines, strings.Split(goOutput, "\n")...)
|
||||
|
||||
if goErr != nil {
|
||||
allLines = append(allLines, "", "Go tests FAILED — skipping E2E")
|
||||
return puretui.MsgTestsDone{Kind: puretui.TestKindAll, Passed: false, Output: allLines}
|
||||
}
|
||||
|
||||
// E2E tests
|
||||
allLines = append(allLines, "", "═══ E2E Tests ═══")
|
||||
e2eCmd := exec.Command("bash", "./dev-scripts/e2e/run.sh")
|
||||
e2eCmd.Env = a.mgr.BuildEnv()
|
||||
e2eOut, e2eErr := e2eCmd.CombinedOutput()
|
||||
|
||||
e2eOutput := strings.TrimSpace(string(e2eOut))
|
||||
if e2eOutput == "" && e2eErr != nil {
|
||||
e2eOutput = "Error: " + e2eErr.Error()
|
||||
}
|
||||
allLines = append(allLines, strings.Split(e2eOutput, "\n")...)
|
||||
|
||||
return puretui.MsgTestsDone{Kind: puretui.TestKindAll, Passed: e2eErr == nil, Output: allLines}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) tick() tea.Cmd {
|
||||
return tea.Tick(3*time.Second, func(time.Time) tea.Msg {
|
||||
return puretui.MsgTick{}
|
||||
})
|
||||
}
|
||||
|
||||
// ── formatting helpers ───────────────────────────────────────────────────
|
||||
|
||||
func formatUptime(secs int64) string {
|
||||
if secs < 0 {
|
||||
return "n/a"
|
||||
}
|
||||
d := secs / 86400
|
||||
h := (secs % 86400) / 3600
|
||||
m := (secs % 3600) / 60
|
||||
if d > 0 {
|
||||
return fmt.Sprintf("%dd %dh", d, h)
|
||||
}
|
||||
if h > 0 {
|
||||
return fmt.Sprintf("%dh %dm", h, m)
|
||||
}
|
||||
return fmt.Sprintf("%dm", m)
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
switch {
|
||||
case bytes >= 1<<30:
|
||||
return fmt.Sprintf("%.1f GB", float64(bytes)/float64(1<<30))
|
||||
case bytes >= 1<<20:
|
||||
return fmt.Sprintf("%.1f MB", float64(bytes)/float64(1<<20))
|
||||
case bytes >= 1<<10:
|
||||
return fmt.Sprintf("%.1f KB", float64(bytes)/float64(1<<10))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user