feat: import agents_and_robots platform as unibots (Matrix-out, unibus transport)

Reemplaza el scaffold del echobot por la plataforma completa de bots traida
desde ~/DataProyects/Github/agents_and_robots tras la operacion Matrix-out:
los bots ya no hablan por Matrix sino por el bus unibus (modelo todo-rooms +
E2E via shell/transportunibus sobre github.com/enmanuel/unibus/pkg/client).

- go.mod: replace de unibus -> ../unibus y de fn-registry -> ../../../.. (paths
  relativos reajustados a la nueva ubicacion dentro de fn_registry).
- app.md: bump a 0.2.0, descripcion + arquitectura + comandos + gotchas reales.
- modulo Go conservado como github.com/enmanuel/agents (sin reescribir imports).

agents_and_robots queda archivado como museo de la era Matrix.
This commit is contained in:
agent
2026-06-07 11:50:13 +02:00
parent bb5b0e09b1
commit fc644ecd6e
308 changed files with 38829 additions and 474 deletions
+143
View File
@@ -0,0 +1,143 @@
// Package bus provides in-process agent-to-agent message passing.
package bus
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// Well-known message kinds used by the orchestrator.
// These are the values placed in AgentMessage.Kind.
const (
KindTask = "task" // orchestrator → bot: handle this question
KindTaskResult = "task_result" // bot → orchestrator: here is my answer
)
// AgentID identifies an agent. It is the key under which the Bus stores the
// agent's receive channel.
type AgentID string
// AgentMessage is a message between agents.
type AgentMessage struct {
// From is the sending agent; the bus only uses it for logging.
From AgentID
// To selects the destination channel when the message is sent.
To AgentID
// Kind labels the message, e.g. KindTask or KindTaskResult.
Kind string
// Payload carries free-form string key/value data; the bus does not inspect it.
Payload map[string]string
}
// Bus manages channels for inter-agent communication.
// It keeps one receive channel per subscribed agent plus a table of pending
// reply channels used by SendAndWait / Reply.
type Bus struct {
// mu guards channels.
mu sync.RWMutex
channels map[AgentID]chan AgentMessage
// replyMu guards replyChs.
replyMu sync.Mutex
replyChs map[string]chan AgentMessage // taskID → one-shot reply channel
logger *slog.Logger
}
// New builds an empty Bus: no subscribers and no pending replies.
// The logger is tagged with component=bus and must not be nil.
func New(logger *slog.Logger) *Bus {
	b := new(Bus)
	b.channels = map[AgentID]chan AgentMessage{}
	b.replyChs = map[string]chan AgentMessage{}
	b.logger = logger.With("component", "bus")
	return b
}
// Subscribe registers id on the bus and returns the channel on which the agent
// receives its messages. The channel buffers up to 64 messages. Subscribing an
// id that is already registered replaces its channel; the previous channel is
// not closed.
func (b *Bus) Subscribe(id AgentID) <-chan AgentMessage {
	inbox := make(chan AgentMessage, 64)

	b.mu.Lock()
	defer b.mu.Unlock()

	b.channels[id] = inbox
	b.logger.Info("bus_subscribe", "agent", id)
	return inbox
}
// Send delivers msg to the channel registered for msg.To without blocking.
//
// It returns an error when msg.To is not registered on the bus or when the
// recipient's buffered channel is full; in both cases the message is dropped
// and a warning is logged.
func (b *Bus) Send(msg AgentMessage) error {
	// Hold the read lock across the channel send. Unsubscribe closes the
	// channel under the write lock, so releasing the lock between the lookup
	// and the send would let a concurrent Unsubscribe close the channel first
	// and turn the send into a panic. The send below never blocks, so the
	// lock is only held briefly.
	b.mu.RLock()
	defer b.mu.RUnlock()

	ch, ok := b.channels[msg.To]
	if !ok {
		b.logger.Warn("bus_not_found", "to", msg.To, "from", msg.From, "kind", msg.Kind)
		return fmt.Errorf("agent %q not registered on bus", msg.To)
	}

	select {
	case ch <- msg:
		b.logger.Debug("bus_send", "from", msg.From, "to", msg.To, "kind", msg.Kind)
		return nil
	default:
		// Recipient is not draining fast enough; drop rather than block the sender.
		b.logger.Warn("bus_queue_full", "to", msg.To, "from", msg.From, "kind", msg.Kind)
		return fmt.Errorf("agent %q message queue full", msg.To)
	}
}
// SendAndWait sends msg and then blocks until one of three things happens:
// a reply for taskID is delivered through Reply, the timeout elapses, or ctx
// is done. The reply slot for taskID is removed again before returning.
//
// A failed Send is returned as-is without waiting. On timeout a descriptive
// error is returned; on context cancellation ctx.Err() is returned.
func (b *Bus) SendAndWait(ctx context.Context, msg AgentMessage, taskID string, timeout time.Duration) (AgentMessage, error) {
	var none AgentMessage

	// Register the reply slot before sending so that a reply cannot arrive
	// before anyone is listening. Capacity 1 lets Reply hand over the answer
	// without blocking.
	pending := make(chan AgentMessage, 1)
	b.replyMu.Lock()
	b.replyChs[taskID] = pending
	b.replyMu.Unlock()

	unregister := func() {
		b.replyMu.Lock()
		defer b.replyMu.Unlock()
		delete(b.replyChs, taskID)
	}
	defer unregister()

	if sendErr := b.Send(msg); sendErr != nil {
		return none, sendErr
	}
	b.logger.Debug("bus_send_and_wait", "task", taskID, "to", msg.To, "timeout", timeout)

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case answer := <-pending:
		return answer, nil
	case <-deadline.C:
		b.logger.Warn("bus_timeout", "task", taskID, "to", msg.To, "timeout", timeout)
		return none, fmt.Errorf("task %s: delegation timeout after %s", taskID, timeout)
	case <-ctx.Done():
		return none, ctx.Err()
	}
}
// Reply hands msg to the SendAndWait caller waiting on taskID. When nobody is
// waiting for that task, the message is delivered through Send instead.
// If the waiter's one-slot reply channel is already occupied, an error is
// returned and the message is dropped.
func (b *Bus) Reply(taskID string, msg AgentMessage) error {
	b.replyMu.Lock()
	waiter, waiting := b.replyChs[taskID]
	b.replyMu.Unlock()

	if !waiting {
		// No pending SendAndWait for this task: use the regular channel.
		return b.Send(msg)
	}

	select {
	case waiter <- msg:
		b.logger.Debug("bus_reply", "task", taskID, "from", msg.From)
		return nil
	default:
		b.logger.Warn("bus_reply_full", "task", taskID)
		return fmt.Errorf("reply channel full for task %s", taskID)
	}
}
// Unsubscribe closes the agent's receive channel and removes the agent from
// the bus. Calling it for an id that is not registered does nothing.
func (b *Bus) Unsubscribe(id AgentID) {
	b.mu.Lock()
	defer b.mu.Unlock()

	inbox, registered := b.channels[id]
	if !registered {
		return
	}
	close(inbox)
	delete(b.channels, id)
	b.logger.Info("bus_unsubscribe", "agent", id)
}
+91
View File
@@ -0,0 +1,91 @@
package bus_test
import (
"log/slog"
"testing"
"github.com/enmanuel/agents/shell/bus"
)
// newBus returns a Bus that logs through the default slog logger.
func newBus() *bus.Bus {
	logger := slog.Default()
	return bus.New(logger)
}
// TestSubscribeAndSend checks that a message sent to a subscribed agent
// arrives on that agent's channel with its kind and payload intact.
func TestSubscribeAndSend(t *testing.T) {
	b := newBus()
	inbox := b.Subscribe("agent-a")

	sendErr := b.Send(bus.AgentMessage{
		From:    "orch",
		To:      "agent-a",
		Kind:    bus.KindTask,
		Payload: map[string]string{"k": "v"},
	})
	if sendErr != nil {
		t.Fatalf("Send: %v", sendErr)
	}

	received := <-inbox
	kindOK := received.Kind == bus.KindTask
	payloadOK := received.Payload["k"] == "v"
	if !(kindOK && payloadOK) {
		t.Fatalf("unexpected message: %+v", received)
	}
}
// TestUnsubscribeClosesChannel checks that Unsubscribe closes the channel
// previously returned by Subscribe.
func TestUnsubscribeClosesChannel(t *testing.T) {
	b := newBus()
	inbox := b.Subscribe("agent-b")
	b.Unsubscribe("agent-b")

	// A receive on a closed channel yields the zero value with ok == false.
	if _, open := <-inbox; open {
		t.Fatal("expected channel to be closed after Unsubscribe")
	}
}
// TestUnsubscribeRemovesFromBus checks that sending to an agent after it has
// unsubscribed returns an error instead of panicking.
func TestUnsubscribeRemovesFromBus(t *testing.T) {
	b := newBus()
	b.Subscribe("agent-c")
	b.Unsubscribe("agent-c")

	ping := bus.AgentMessage{To: "agent-c", Kind: "ping"}
	if sendErr := b.Send(ping); sendErr == nil {
		t.Fatal("expected error when sending to unsubscribed agent")
	}
}
// TestUnsubscribeIdempotent checks that unsubscribing the same agent twice
// does not panic.
func TestUnsubscribeIdempotent(t *testing.T) {
	b := newBus()
	b.Subscribe("agent-d")

	for attempt := 0; attempt < 2; attempt++ {
		b.Unsubscribe("agent-d")
	}
}
// TestUnsubscribeNonExistent checks that unsubscribing an id that was never
// subscribed does not panic.
func TestUnsubscribeNonExistent(t *testing.T) {
	newBus().Unsubscribe("does-not-exist")
}
// TestSendToUnknownAgent checks that Send reports an error for a recipient
// that never subscribed.
func TestSendToUnknownAgent(t *testing.T) {
	b := newBus()
	hello := bus.AgentMessage{To: "ghost", Kind: "hello"}

	if sendErr := b.Send(hello); sendErr == nil {
		t.Fatal("expected error when sending to unknown agent")
	}
}
// TestResubscribeAfterUnsubscribe checks that an agent can subscribe again
// after unsubscribing and that the new channel receives messages.
func TestResubscribeAfterUnsubscribe(t *testing.T) {
	const agent = "agent-e"

	b := newBus()
	b.Subscribe(agent)
	b.Unsubscribe(agent)

	fresh := b.Subscribe(agent)
	if sendErr := b.Send(bus.AgentMessage{To: agent, Kind: "ping"}); sendErr != nil {
		t.Fatalf("Send after re-subscribe: %v", sendErr)
	}

	if received := <-fresh; received.Kind != "ping" {
		t.Fatalf("unexpected kind: %q", received.Kind)
	}
}
+116
View File
@@ -0,0 +1,116 @@
package cron
import (
"context"
"fmt"
"os"
"strings"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// Action kinds recognised by buildHandler; they are compared against a
// schedule's Action.Kind.
const actionKindSendMessage = "send_message"
const actionKindLLMPrompt = "llm_prompt"
// handler is a function that fires when a schedule triggers.
// room is the destination passed on to the sender.
type handler func(ctx context.Context, room string)
// buildHandler maps a schedule's action kind to the handler that executes it.
// It returns nil when the kind is neither send_message nor llm_prompt.
func (s *Scheduler) buildHandler(sc config.ScheduleCfg) handler {
	kind := sc.Action.Kind
	if kind == actionKindSendMessage {
		return s.sendMessageHandler(sc)
	}
	if kind == actionKindLLMPrompt {
		return s.llmPromptHandler(sc)
	}
	return nil
}
// sendMessageHandler returns a handler that sends a static message to the
// given room through the scheduler's sender. The text is the inline
// Action.Message when set, otherwise the contents of the Action.Template file.
// Resolution failures, empty content and send failures are logged, not returned.
func (s *Scheduler) sendMessageHandler(sc config.ScheduleCfg) handler {
	name := sc.Name

	fire := func(ctx context.Context, room string) {
		body, resolveErr := resolveContent(sc.Action.Message, sc.Action.Template)
		switch {
		case resolveErr != nil:
			s.logger.Error("send_message: failed to resolve content",
				"name", name, "err", resolveErr)
			return
		case body == "":
			s.logger.Warn("send_message: empty content, skipping", "name", name)
			return
		}

		s.logger.Info("cron_fire", "name", name, "kind", actionKindSendMessage, "room", room)

		sendErr := s.sender.SendMarkdown(ctx, room, body)
		if sendErr == nil {
			return
		}
		s.logger.Error("send_message: bus send failed",
			"name", name, "room", room, "err", sendErr)
	}
	return fire
}
// llmPromptHandler returns a handler that asks the LLM to complete a prompt
// and sends the trimmed answer to the given room through the scheduler's
// sender. The prompt is the inline Action.Prompt when set, otherwise the
// contents of the Action.Template file. The handler does nothing (beyond
// logging) when no LLM is configured, the prompt is empty, the LLM call fails
// or the answer is blank.
func (s *Scheduler) llmPromptHandler(sc config.ScheduleCfg) handler {
	return func(ctx context.Context, room string) {
		name := sc.Name

		if s.llm == nil {
			s.logger.Warn("llm_prompt: no LLM configured, skipping", "name", name)
			return
		}

		prompt, resolveErr := resolveContent(sc.Action.Prompt, sc.Action.Template)
		if resolveErr != nil {
			s.logger.Error("llm_prompt: failed to resolve prompt",
				"name", name, "err", resolveErr)
			return
		}
		if prompt == "" {
			s.logger.Warn("llm_prompt: empty prompt, skipping", "name", name)
			return
		}

		s.logger.Info("cron_fire", "name", name, "kind", actionKindLLMPrompt, "room", room)

		// Single-turn request: the prompt is the only (user) message.
		resp, llmErr := s.llm(ctx, coretypes.CompletionRequest{
			Model: s.model,
			Messages: []coretypes.Message{
				{Role: coretypes.RoleUser, Content: prompt},
			},
		})
		if llmErr != nil {
			s.logger.Error("llm_prompt: LLM call failed",
				"name", name, "err", llmErr)
			return
		}

		answer := strings.TrimSpace(resp.Content)
		if answer == "" {
			s.logger.Warn("llm_prompt: LLM returned empty response", "name", name)
			return
		}

		if sendErr := s.sender.SendMarkdown(ctx, room, answer); sendErr != nil {
			s.logger.Error("llm_prompt: bus send failed",
				"name", name, "room", room, "err", sendErr)
		}
	}
}
// resolveContent picks the text for an action. A non-empty inline value wins
// and is returned unchanged. Otherwise the file at templatePath is read and
// returned with surrounding whitespace trimmed. When both arguments are empty
// the result is "" with a nil error. A read failure is wrapped and returned.
func resolveContent(inline, templatePath string) (string, error) {
	switch {
	case inline != "":
		return inline, nil
	case templatePath == "":
		return "", nil
	}

	raw, readErr := os.ReadFile(templatePath)
	if readErr != nil {
		return "", fmt.Errorf("reading template %q: %w", templatePath, readErr)
	}
	text := string(raw)
	return strings.TrimSpace(text), nil
}
+110
View File
@@ -0,0 +1,110 @@
// Package cron provides a scheduler for autonomous bot activity.
// It is part of the impure shell: it reads files, calls LLMs, and sends messages
// over the bot's transport (unibus).
package cron
import (
"context"
"log/slog"
"github.com/robfig/cron/v3"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// Sender is the subset of the bot's transport sender needed by the scheduler.
// The scheduler only ever calls SendMarkdown, passing the schedule's output
// room as roomID.
type Sender interface {
SendMarkdown(ctx context.Context, roomID, markdown string) error
}
// Scheduler fires configured schedules and executes send_message or llm_prompt actions.
type Scheduler struct {
// cfg holds the schedules that Start registers.
cfg []config.ScheduleCfg
// sender delivers the resulting message to the output room.
sender Sender
llm coretypes.CompleteFunc // nil when agent has no LLM
// model is passed as the Model of every LLM completion request.
model string
logger *slog.Logger
// cron is the underlying robfig/cron runner created in New.
cron *cron.Cron
}
// New builds a Scheduler for the given schedules. llm may be nil and model
// may be empty for agents without an LLM; llm_prompt schedules are then
// skipped when they fire. The logger is tagged with component=cron.
func New(
	cfg []config.ScheduleCfg,
	sender Sender,
	llm coretypes.CompleteFunc,
	model string,
	logger *slog.Logger,
) *Scheduler {
	s := &Scheduler{cfg: cfg, sender: sender, llm: llm, model: model}
	s.logger = logger.With("component", "cron")
	s.cron = cron.New()
	return s
}
// Fire runs the action of sc right away, without involving the cron timer.
// It is meant for tests and manual triggering. A schedule with no output room
// or with an unsupported action kind is logged and skipped.
func (s *Scheduler) Fire(ctx context.Context, sc config.ScheduleCfg) {
	if sc.OutputRoom == "" {
		s.logger.Warn("Fire: schedule has no output_room, skipping", "name", sc.Name)
		return
	}

	run := s.buildHandler(sc)
	if run == nil {
		s.logger.Warn("Fire: unsupported action kind", "name", sc.Name, "kind", sc.Action.Kind)
		return
	}

	run(ctx, sc.OutputRoom)
}
// Start registers every valid schedule with the cron runner, starts the
// runner and then blocks until ctx is cancelled. On cancellation it stops the
// runner and waits on the context returned by Stop before returning.
//
// Schedules with an empty cron spec, empty action kind, empty output room,
// an unsupported kind, or a spec the runner rejects are logged and skipped.
func (s *Scheduler) Start(ctx context.Context) {
	for i := range s.cfg {
		// Copy the entry so the closure registered below owns its own values.
		entry := s.cfg[i]
		name, spec, kind, room := entry.Name, entry.Cron, entry.Action.Kind, entry.OutputRoom

		if spec == "" || kind == "" {
			s.logger.Warn("skipping invalid schedule", "name", name, "cron", spec, "kind", kind)
			continue
		}
		if room == "" {
			s.logger.Warn("schedule has no output_room, skipping", "name", name)
			continue
		}

		run := s.buildHandler(entry)
		if run == nil {
			s.logger.Warn("unsupported action kind, skipping", "name", name, "kind", kind)
			continue
		}

		if _, err := s.cron.AddFunc(spec, func() { run(ctx, room) }); err != nil {
			s.logger.Error("failed to register schedule",
				"name", name,
				"cron", spec,
				"err", err,
			)
			continue
		}
		s.logger.Info("schedule registered", "name", name, "cron", spec, "kind", kind, "room", room)
	}

	s.cron.Start()
	s.logger.Info("cron scheduler started", "schedules", len(s.cfg))

	<-ctx.Done()
	s.logger.Info("cron scheduler stopping")

	<-s.cron.Stop().Done()
	s.logger.Info("cron scheduler stopped")
}
+326
View File
@@ -0,0 +1,326 @@
package cron_test
import (
"context"
"errors"
"log/slog"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
shellcron "github.com/enmanuel/agents/shell/cron"
)
// ── fakes ──────────────────────────────────────────────────────────────────
// fakeSender is a Sender double that counts calls and remembers the room and
// markdown of the most recent one.
type fakeSender struct {
	calls  atomic.Int32
	lastMD string
	lastRM string
}

// SendMarkdown records the call and always succeeds.
func (f *fakeSender) SendMarkdown(_ context.Context, room, md string) error {
	f.calls.Add(1)
	f.lastRM, f.lastMD = room, md
	return nil
}
// errSender is a Sender double whose SendMarkdown always fails.
type errSender struct{}

// SendMarkdown ignores its arguments and returns a fixed error.
func (*errSender) SendMarkdown(context.Context, string, string) error {
	return errors.New("matrix unavailable")
}
// fakeLLM returns a CompleteFunc that ignores its request and always answers
// with reply as the response content and a nil error.
func fakeLLM(reply string) coretypes.CompleteFunc {
	canned := coretypes.CompletionResponse{Content: reply}
	return func(context.Context, coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
		return canned, nil
	}
}
// newTestLogger returns a text logger that writes to stderr at debug level.
func newTestLogger(t *testing.T) *slog.Logger {
	t.Helper()
	opts := &slog.HandlerOptions{Level: slog.LevelDebug}
	h := slog.NewTextHandler(os.Stderr, opts)
	return slog.New(h)
}
// ── helpers ────────────────────────────────────────────────────────────────
// waitCalls polls the sender every 20ms until it has seen at least n calls.
// If that has not happened within 3 seconds the test fails.
func waitCalls(t *testing.T, f *fakeSender, n int32) {
	t.Helper()

	const (
		budget = 3 * time.Second
		poll   = 20 * time.Millisecond
	)
	for began := time.Now(); time.Since(began) < budget; time.Sleep(poll) {
		if f.calls.Load() >= n {
			return
		}
	}
	t.Fatalf("expected %d call(s) to SendMarkdown, got %d", n, f.calls.Load())
}
// ── cron-based tests (require timer) ──────────────────────────────────────
// TestScheduler_SendMessage_Inline runs the scheduler with a 100ms inline
// send_message schedule and checks the room and text handed to the sender.
func TestScheduler_SendMessage_Inline(t *testing.T) {
	const (
		room = "!room:server.com"
		text = "hola mundo"
	)

	sender := &fakeSender{}
	schedules := []config.ScheduleCfg{{
		Name:       "test-inline",
		Cron:       "@every 100ms",
		OutputRoom: room,
		Action:     config.ScheduledAction{Kind: "send_message", Message: text},
	}}
	s := shellcron.New(schedules, sender, nil, "", newTestLogger(t))

	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		s.Start(ctx)
	}()

	waitCalls(t, sender, 1)
	cancel()
	<-stopped

	if got := sender.lastRM; got != room {
		t.Errorf("unexpected room: %s", got)
	}
	if got := sender.lastMD; got != text {
		t.Errorf("unexpected message: %s", got)
	}
}
// TestScheduler_SendMessage_Template runs the scheduler with a send_message
// schedule whose text comes from a template file and checks that the file's
// contents reach the sender.
func TestScheduler_SendMessage_Template(t *testing.T) {
	const text = "buenos días"

	// The template lives in a per-test temporary directory.
	templatePath := filepath.Join(t.TempDir(), "greeting.md")
	if err := os.WriteFile(templatePath, []byte(text), 0o600); err != nil {
		t.Fatal(err)
	}

	sender := &fakeSender{}
	schedules := []config.ScheduleCfg{{
		Name:       "test-template",
		Cron:       "@every 100ms",
		OutputRoom: "!room2:server.com",
		Action:     config.ScheduledAction{Kind: "send_message", Template: templatePath},
	}}
	s := shellcron.New(schedules, sender, nil, "", newTestLogger(t))

	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		s.Start(ctx)
	}()

	waitCalls(t, sender, 1)
	cancel()
	<-stopped

	if got := sender.lastMD; got != text {
		t.Errorf("unexpected message: %q", got)
	}
}
// TestScheduler_LLMPrompt runs the scheduler with an llm_prompt schedule and
// a canned LLM, and checks that the LLM's answer is what gets sent.
func TestScheduler_LLMPrompt(t *testing.T) {
	const answer = "resumen generado por LLM"

	sender := &fakeSender{}
	schedules := []config.ScheduleCfg{{
		Name:       "test-llm",
		Cron:       "@every 100ms",
		OutputRoom: "!room3:server.com",
		Action:     config.ScheduledAction{Kind: "llm_prompt", Prompt: "resume el día"},
	}}
	s := shellcron.New(schedules, sender, fakeLLM(answer), "gpt-4o", newTestLogger(t))

	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		s.Start(ctx)
	}()

	waitCalls(t, sender, 1)
	cancel()
	<-stopped

	if got := sender.lastMD; got != answer {
		t.Errorf("unexpected LLM reply: %q", got)
	}
}
// TestScheduler_MatrixSendError checks that a failing sender does not make
// the scheduler panic: the schedule is left to fire for a while against a
// sender whose SendMarkdown always errors, then the scheduler is stopped.
func TestScheduler_MatrixSendError(t *testing.T) {
	schedules := []config.ScheduleCfg{{
		Name:       "err-send",
		Cron:       "@every 100ms",
		OutputRoom: "!room:server.com",
		Action:     config.ScheduledAction{Kind: "send_message", Message: "trigger error"},
	}}
	s := shellcron.New(schedules, &errSender{}, nil, "", newTestLogger(t))

	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		s.Start(ctx)
	}()

	// Give the 100ms schedule time to fire at least once.
	const settle = 250 * time.Millisecond
	time.Sleep(settle)

	cancel()
	<-stopped
}
// ── Fire() tests (deterministic, no timer) ─────────────────────────────────
// TestFire_SendMessage_Inline checks that Fire executes an inline
// send_message schedule exactly once, with the configured room and text.
func TestFire_SendMessage_Inline(t *testing.T) {
	const (
		room = "!fireroom:server.com"
		text = "buenos días via Fire"
	)

	sender := &fakeSender{}
	s := shellcron.New(nil, sender, nil, "", newTestLogger(t))

	s.Fire(context.Background(), config.ScheduleCfg{
		Name:       "fire-inline",
		Cron:       "0 9 * * *",
		OutputRoom: room,
		Action:     config.ScheduledAction{Kind: "send_message", Message: text},
	})

	if n := sender.calls.Load(); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}
	if got := sender.lastRM; got != room {
		t.Errorf("unexpected room: %s", got)
	}
	if got := sender.lastMD; got != text {
		t.Errorf("unexpected message: %s", got)
	}
}
// TestFire_LLMPrompt checks that Fire executes an llm_prompt schedule once
// and forwards the canned LLM answer to the sender.
func TestFire_LLMPrompt(t *testing.T) {
	const answer = "respuesta del LLM via Fire"

	sender := &fakeSender{}
	s := shellcron.New(nil, sender, fakeLLM(answer), "gpt-4o", newTestLogger(t))

	s.Fire(context.Background(), config.ScheduleCfg{
		Name:       "fire-llm",
		Cron:       "0 18 * * *",
		OutputRoom: "!llmroom:server.com",
		Action:     config.ScheduledAction{Kind: "llm_prompt", Prompt: "resume el día"},
	})

	if n := sender.calls.Load(); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}
	if got := sender.lastMD; got != answer {
		t.Errorf("unexpected LLM reply: %q", got)
	}
}
// TestFire_NoOutputRoom_Skips checks that Fire sends nothing when the
// schedule has an empty output room.
func TestFire_NoOutputRoom_Skips(t *testing.T) {
	sender := &fakeSender{}
	s := shellcron.New(nil, sender, nil, "", newTestLogger(t))

	roomless := config.ScheduleCfg{
		Name:       "fire-no-room",
		Cron:       "0 9 * * *",
		OutputRoom: "", // deliberately left empty
		Action:     config.ScheduledAction{Kind: "send_message", Message: "should not send"},
	}
	s.Fire(context.Background(), roomless)

	if n := sender.calls.Load(); n != 0 {
		t.Errorf("expected 0 calls when output_room is empty, got %d", n)
	}
}
// TestFire_LLMPrompt_NoLLM_Skips checks that an llm_prompt schedule sends
// nothing when the scheduler was built without an LLM.
func TestFire_LLMPrompt_NoLLM_Skips(t *testing.T) {
	sender := &fakeSender{}
	s := shellcron.New(nil, sender, nil, "", newTestLogger(t))

	s.Fire(context.Background(), config.ScheduleCfg{
		Name:       "fire-no-llm",
		Cron:       "0 9 * * *",
		OutputRoom: "!room:server.com",
		Action:     config.ScheduledAction{Kind: "llm_prompt", Prompt: "hello"},
	})

	if n := sender.calls.Load(); n != 0 {
		t.Errorf("expected 0 calls without LLM, got %d", n)
	}
}
// TestScheduler_SkipsInvalidSchedule checks, through Fire (so no timer is
// involved), that a schedule without an output room and a schedule without an
// action kind both result in no message being sent.
func TestScheduler_SkipsInvalidSchedule(t *testing.T) {
	sender := &fakeSender{}
	s := shellcron.New(nil, sender, nil, "", newTestLogger(t))

	invalid := []config.ScheduleCfg{
		{
			// OutputRoom is missing, so Fire skips the schedule.
			Name:   "no-room",
			Cron:   "@every 100ms",
			Action: config.ScheduledAction{Kind: "send_message", Message: "hi"},
		},
		{
			// Action.Kind is missing, so buildHandler returns nil and Fire skips.
			Name:       "no-kind",
			Cron:       "@every 100ms",
			OutputRoom: "!room:server.com",
		},
	}

	ctx := context.Background()
	for _, sc := range invalid {
		s.Fire(ctx, sc)
	}

	if n := sender.calls.Load(); n != 0 {
		t.Errorf("expected 0 calls for invalid schedules, got %d", n)
	}
}
+96
View File
@@ -0,0 +1,96 @@
// Package effects interprets pure []decision.Action values into real side effects.
package effects
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/enmanuel/agents/pkg/decision"
"github.com/enmanuel/agents/shell/logger"
"github.com/enmanuel/agents/shell/ssh"
)
// Result holds the outcome of executing a single action.
type Result struct {
// Action is the action that was executed.
Action decision.Action
// Output is the reply content for reply actions, or the command's stdout
// (followed by a "stderr: ..." line when stderr is non-empty) for SSH actions.
Output string
// Err is non-nil when the action failed or could not be handled.
Err error
}
// Sender is the transport-neutral message-sending capability the runner depends
// on. It is satisfied by the unibus bus sender (and was satisfied by the Matrix
// client before the bus migration). SendTyping is a no-op on transports without
// typing indicators.
//
// For reply actions the runner picks SendThreadMarkdown when a thread ID is
// set, SendReplyMarkdown when only an in-reply-to ID is set, and SendMarkdown
// otherwise.
type Sender interface {
SendText(ctx context.Context, roomID, text string) error
SendMarkdown(ctx context.Context, roomID, markdown string) error
SendReplyMarkdown(ctx context.Context, roomID, inReplyTo, markdown string) error
SendThreadMarkdown(ctx context.Context, roomID, threadRootID, inReplyTo, markdown string) error
SendTyping(ctx context.Context, roomID string, typing bool) error
}
// Runner interprets actions and executes them.
type Runner struct {
// sender delivers reply actions.
sender Sender
// ssh runs SSH actions.
ssh *ssh.Executor
logger *slog.Logger
}
// NewRunner wires a Runner from its three dependencies: the message sender,
// the SSH executor and the logger. The values are stored as given.
func NewRunner(sender Sender, ssh *ssh.Executor, logger *slog.Logger) *Runner {
	r := new(Runner)
	r.sender = sender
	r.ssh = ssh
	r.logger = logger
	return r
}
// Execute runs the actions one after another, in order, against roomID and
// returns one Result per action. A failing action does not stop the batch;
// each outcome is logged together with its duration in milliseconds.
func (r *Runner) Execute(ctx context.Context, roomID string, actions []decision.Action) []Result {
	r.logger.Debug("effects_batch", "room", roomID, "count", len(actions))

	out := make([]Result, 0, len(actions))
	for i := range actions {
		action := actions[i]

		began := time.Now()
		outcome := r.executeOne(ctx, roomID, action)
		elapsedMS := time.Since(began).Milliseconds()
		out = append(out, outcome)

		if outcome.Err == nil {
			r.logger.Info("action_done", logger.FieldAction, action.Kind, logger.FieldDurationMS, elapsedMS)
			continue
		}
		r.logger.Error("action_failed", logger.FieldAction, action.Kind, logger.FieldDurationMS, elapsedMS, "err", outcome.Err)
	}
	return out
}
// executeOne performs a single action and reports its outcome.
//
// Reply actions are sent through the Sender: as a thread message when a
// thread ID is set, as a reply when only an in-reply-to ID is set, and as a
// plain markdown message otherwise. SSH actions are run through the SSH
// executor; the output is stdout, followed by a "stderr: ..." line when
// stderr is non-empty.
//
// An action whose payload is nil, an SSH action on a Runner built without an
// SSH executor, and an unknown action kind all yield a Result with Err set.
func (r *Runner) executeOne(ctx context.Context, roomID string, a decision.Action) Result {
	switch a.Kind {
	case decision.ActionKindReply:
		if a.Reply == nil {
			return Result{Action: a, Err: fmt.Errorf("nil reply action")}
		}
		var err error
		switch {
		case a.Reply.ThreadID != "":
			// Thread reply: send as part of the thread with fallback in_reply_to
			err = r.sender.SendThreadMarkdown(ctx, roomID, a.Reply.ThreadID, a.Reply.InReplyTo, a.Reply.Content)
		case a.Reply.InReplyTo != "":
			err = r.sender.SendReplyMarkdown(ctx, roomID, a.Reply.InReplyTo, a.Reply.Content)
		default:
			err = r.sender.SendMarkdown(ctx, roomID, a.Reply.Content)
		}
		return Result{Action: a, Output: a.Reply.Content, Err: err}

	case decision.ActionKindSSH:
		if a.SSH == nil {
			return Result{Action: a, Err: fmt.Errorf("nil ssh action")}
		}
		// A Runner may be built without an SSH executor; report that as an
		// error instead of calling a method on a nil pointer.
		if r.ssh == nil {
			return Result{Action: a, Err: fmt.Errorf("ssh action requested but no ssh executor is configured")}
		}
		res := r.ssh.Execute(ctx, *a.SSH)
		output := res.Stdout
		if res.Stderr != "" {
			output += "\nstderr: " + res.Stderr
		}
		return Result{Action: a, Output: output, Err: res.Err}

	default:
		return Result{Action: a, Err: fmt.Errorf("unhandled action kind: %s", a.Kind)}
	}
}
+172
View File
@@ -0,0 +1,172 @@
package effects
import (
"context"
"log/slog"
"testing"
"github.com/enmanuel/agents/pkg/decision"
)
// fakeSender is a Sender double that remembers every message-sending call so
// tests can assert on the method chosen and its arguments.
type fakeSender struct {
	calls []senderCall
}

// senderCall captures one recorded call. Fields that a method does not take
// stay empty; SendText stores its text in markdown.
type senderCall struct {
	method       string
	roomID       string
	threadRootID string
	inReplyTo    string
	markdown     string
}

// record appends c to the call log and reports success.
func (f *fakeSender) record(c senderCall) error {
	f.calls = append(f.calls, c)
	return nil
}

func (f *fakeSender) SendText(_ context.Context, roomID, text string) error {
	return f.record(senderCall{method: "SendText", roomID: roomID, markdown: text})
}

func (f *fakeSender) SendMarkdown(_ context.Context, roomID, markdown string) error {
	return f.record(senderCall{method: "SendMarkdown", roomID: roomID, markdown: markdown})
}

func (f *fakeSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error {
	return f.record(senderCall{
		method:    "SendReplyMarkdown",
		roomID:    roomID,
		inReplyTo: inReplyTo,
		markdown:  markdown,
	})
}

func (f *fakeSender) SendThreadMarkdown(_ context.Context, roomID, threadRootID, inReplyTo, markdown string) error {
	return f.record(senderCall{
		method:       "SendThreadMarkdown",
		roomID:       roomID,
		threadRootID: threadRootID,
		inReplyTo:    inReplyTo,
		markdown:     markdown,
	})
}

// SendTyping is not recorded; it only reports success.
func (f *fakeSender) SendTyping(context.Context, string, bool) error {
	return nil
}
// TestExecuteReply_PlainMarkdown checks that a reply with neither thread nor
// in-reply-to is sent with SendMarkdown to the room passed to Execute.
func TestExecuteReply_PlainMarkdown(t *testing.T) {
	const room = "!room:test"

	sender := &fakeSender{}
	runner := NewRunner(sender, nil, slog.Default())

	reply := decision.Action{
		Kind:  decision.ActionKindReply,
		Reply: &decision.ReplyAction{Content: "hello"},
	}
	results := runner.Execute(context.Background(), room, []decision.Action{reply})

	if len(results) != 1 || results[0].Err != nil {
		t.Fatalf("unexpected results: %+v", results)
	}
	if n := len(sender.calls); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}

	call := sender.calls[0]
	if call.method != "SendMarkdown" {
		t.Errorf("expected SendMarkdown, got %s", call.method)
	}
	if call.roomID != room {
		t.Errorf("expected room !room:test, got %s", call.roomID)
	}
}
// TestExecuteReply_WithInReplyTo checks that a reply carrying only an
// in-reply-to ID is sent with SendReplyMarkdown and keeps that ID.
func TestExecuteReply_WithInReplyTo(t *testing.T) {
	sender := &fakeSender{}
	runner := NewRunner(sender, nil, slog.Default())

	reply := decision.Action{
		Kind:  decision.ActionKindReply,
		Reply: &decision.ReplyAction{Content: "hello", InReplyTo: "$evt1"},
	}
	runner.Execute(context.Background(), "!room:test", []decision.Action{reply})

	if n := len(sender.calls); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}

	call := sender.calls[0]
	if call.method != "SendReplyMarkdown" {
		t.Errorf("expected SendReplyMarkdown, got %s", call.method)
	}
	if call.inReplyTo != "$evt1" {
		t.Errorf("expected inReplyTo=$evt1, got %s", call.inReplyTo)
	}
}
// TestExecuteReply_WithThread checks that a reply with both a thread ID and
// an in-reply-to ID is sent with SendThreadMarkdown and that room, thread
// root and in-reply-to are all passed through.
func TestExecuteReply_WithThread(t *testing.T) {
	const room = "!room:test"

	sender := &fakeSender{}
	runner := NewRunner(sender, nil, slog.Default())

	payload := &decision.ReplyAction{
		Content:   "thread reply",
		ThreadID:  "$root",
		InReplyTo: "$evt2",
	}
	threaded := decision.Action{Kind: decision.ActionKindReply, Reply: payload}
	runner.Execute(context.Background(), room, []decision.Action{threaded})

	if n := len(sender.calls); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}

	call := sender.calls[0]
	if call.method != "SendThreadMarkdown" {
		t.Errorf("expected SendThreadMarkdown, got %s", call.method)
	}
	if call.threadRootID != "$root" {
		t.Errorf("expected threadRootID=$root, got %s", call.threadRootID)
	}
	if call.inReplyTo != "$evt2" {
		t.Errorf("expected inReplyTo=$evt2, got %s", call.inReplyTo)
	}
	if call.roomID != room {
		t.Errorf("expected room !room:test, got %s", call.roomID)
	}
}
// TestExecuteReply_ThreadWithoutInReplyTo checks that a reply with a thread
// ID but no in-reply-to ID still goes through SendThreadMarkdown and that the
// runner passes the in-reply-to argument through empty.
func TestExecuteReply_ThreadWithoutInReplyTo(t *testing.T) {
	sender := &fakeSender{}
	runner := NewRunner(sender, nil, slog.Default())

	payload := &decision.ReplyAction{
		Content:  "thread reply no fallback",
		ThreadID: "$root",
	}
	threaded := decision.Action{Kind: decision.ActionKindReply, Reply: payload}
	runner.Execute(context.Background(), "!room:test", []decision.Action{threaded})

	if n := len(sender.calls); n != 1 {
		t.Fatalf("expected 1 call, got %d", n)
	}

	call := sender.calls[0]
	if call.method != "SendThreadMarkdown" {
		t.Errorf("expected SendThreadMarkdown, got %s", call.method)
	}
	// The runner does not fill in a default; any fallback to the thread root
	// is left to the sender implementation.
	if call.inReplyTo != "" {
		t.Errorf("expected empty inReplyTo, got %s", call.inReplyTo)
	}
}
func TestExecuteReply_NilReply(t *testing.T) {
sender := &fakeSender{}
runner := NewRunner(sender, nil, slog.Default())
actions := []decision.Action{{
Kind: decision.ActionKindReply,
Reply: nil,
}}
results := runner.Execute(context.Background(), "!room:test", actions)
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if results[0].Err == nil {
t.Error("expected error for nil reply")
}
}
+12
View File
@@ -0,0 +1,12 @@
package shellknowledge
import (
"database/sql"
moderncsqlite "modernc.org/sqlite"
)
// init registers the modernc.org/sqlite driver with database/sql under the
// driver name "sqlite3".
func init() {
// Register pure-Go SQLite driver as "sqlite3" for tests.
sql.Register("sqlite3", &moderncsqlite.Driver{})
}
+298
View File
@@ -0,0 +1,298 @@
// Package shellknowledge implements the knowledge store using files + SQLite FTS5.
package shellknowledge
import (
"context"
"database/sql"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/enmanuel/agents/pkg/knowledge"
)
// ftsSchema creates the FTS5 virtual table that backs document search, if it
// does not exist yet. slug, title and content are indexed; updated_at is
// stored but declared UNINDEXED.
const ftsSchema = `
CREATE VIRTUAL TABLE IF NOT EXISTS documents USING fts5(
slug,
title,
content,
updated_at UNINDEXED
);
`
// slugRe accepts lowercase letters, digits and hyphens; the first and last
// characters must be a letter or digit.
var slugRe = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{0,62}[a-z0-9]$`)

// ValidSlug reports whether s is an acceptable document slug: between 2 and
// 64 bytes long and matching slugRe.
func ValidSlug(s string) bool {
	n := len(s)
	return n >= 2 && n <= 64 && slugRe.MatchString(s)
}
// FileStore implements knowledge.Store using markdown files + SQLite FTS5 index.
// Documents are stored as <slug>.md files under dir; the documents table in
// the SQLite database holds a searchable copy of them.
type FileStore struct {
dir string // path to agents/<id>/knowledge/
dbPath string // path to agents/<id>/data/knowledge.db
// db is the open handle to the database at dbPath.
db *sql.DB
logger *slog.Logger
}
// New prepares a FileStore rooted at dir with its index at dbPath. It creates
// the knowledge directory and the database's parent directory when missing,
// opens the database with the "sqlite3" driver, switches it to WAL journal
// mode and creates the FTS5 table if it does not exist. If either database
// statement fails, the handle is closed before the error is returned.
func New(dir, dbPath string, logger *slog.Logger) (*FileStore, error) {
	log := logger.With("component", "knowledge", "dir", dir, "db_path", dbPath)

	if err := os.MkdirAll(dir, 0o755); err != nil {
		return nil, fmt.Errorf("create knowledge dir: %w", err)
	}
	dbDir := filepath.Dir(dbPath)
	if err := os.MkdirAll(dbDir, 0o755); err != nil {
		return nil, fmt.Errorf("create knowledge db dir: %w", err)
	}

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open knowledge db: %w", err)
	}

	// Run in order: WAL first (better concurrency: multiple readers plus a
	// single writer), then the schema.
	setup := []struct{ stmt, what string }{
		{"PRAGMA journal_mode=WAL", "enable WAL mode"},
		{ftsSchema, "create knowledge fts5 table"},
	}
	for _, step := range setup {
		if _, execErr := db.Exec(step.stmt); execErr != nil {
			db.Close()
			return nil, fmt.Errorf("%s: %w", step.what, execErr)
		}
	}

	log.Info("knowledge_store_ready")
	store := &FileStore{dir: dir, dbPath: dbPath, db: db, logger: log}
	return store, nil
}
// Sync rebuilds the FTS5 index from the .md files in the knowledge directory.
// Inside one transaction it empties the documents table and inserts a row per
// file, using the file name without ".md" as slug and the file's modification
// time (UTC, RFC 3339) as updated_at.
//
// Sub-directories, non-.md files, files with an invalid slug, and files that
// cannot be read, stat'ed or inserted are skipped with a warning. Errors
// reading the directory, clearing the table, or beginning/committing the
// transaction are returned.
func (s *FileStore) Sync(ctx context.Context) error {
	files, err := os.ReadDir(s.dir)
	if err != nil {
		return fmt.Errorf("read knowledge dir: %w", err)
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin sync tx: %w", err)
	}
	// Undoes partial work on any early return; has no effect once Commit
	// has succeeded.
	defer tx.Rollback()

	// Start from an empty index.
	if _, err = tx.ExecContext(ctx, `DELETE FROM documents`); err != nil {
		return fmt.Errorf("clear fts5 index: %w", err)
	}

	const insertDoc = `INSERT INTO documents (slug, title, content, updated_at) VALUES (?, ?, ?, ?)`

	indexed := 0
	for _, entry := range files {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if !strings.HasSuffix(name, ".md") {
			continue
		}

		slug := strings.TrimSuffix(name, ".md")
		if !ValidSlug(slug) {
			s.logger.Warn("skipping invalid slug", "file", name)
			continue
		}

		raw, readErr := os.ReadFile(filepath.Join(s.dir, name))
		if readErr != nil {
			s.logger.Warn("skipping unreadable file", "file", name, "err", readErr)
			continue
		}

		meta, infoErr := entry.Info()
		if infoErr != nil {
			s.logger.Warn("skipping file without info", "file", name, "err", infoErr)
			continue
		}

		body := string(raw)
		title := extractTitle(body, slug)
		stamp := meta.ModTime().UTC().Format(time.RFC3339)

		if _, execErr := tx.ExecContext(ctx, insertDoc, slug, title, body, stamp); execErr != nil {
			s.logger.Warn("failed to index file", "slug", slug, "err", execErr)
			continue
		}
		indexed++
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit sync tx: %w", err)
	}
	s.logger.Info("knowledge_sync", "count", indexed)
	return nil
}
// Search runs a full-text query against the FTS5 index and returns up to
// limit matches ordered by rank. A limit of zero or less defaults to 5.
//
// query is passed to FTS5 MATCH as a bound parameter, so it is interpreted
// with FTS5 query syntax. Each result carries the slug, the title, a snippet
// of the content column with matches wrapped in "**", and the rank.
//
// On any error (query, row scan or row iteration) the results are nil and
// the error is wrapped with context.
func (s *FileStore) Search(ctx context.Context, query string, limit int) ([]knowledge.SearchResult, error) {
	if limit <= 0 {
		limit = 5
	}
	rows, err := s.db.QueryContext(ctx,
		`SELECT slug, title, snippet(documents, 2, '**', '**', '…', 32), rank
FROM documents WHERE documents MATCH ?
ORDER BY rank LIMIT ?`,
		query, limit,
	)
	if err != nil {
		return nil, fmt.Errorf("knowledge search: %w", err)
	}
	defer rows.Close()

	var results []knowledge.SearchResult
	for rows.Next() {
		var r knowledge.SearchResult
		if err := rows.Scan(&r.Slug, &r.Title, &r.Snippet, &r.Rank); err != nil {
			return nil, fmt.Errorf("knowledge search: scan row: %w", err)
		}
		results = append(results, r)
	}
	// Next returns false both at the end of the rows and on failure; Err
	// tells the two apart. Do not hand back partial results with an error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("knowledge search: iterate rows: %w", err)
	}
	return results, nil
}
// Get loads the document stored at <dir>/<slug>.md straight from disk.
//
// The title is derived from the content via extractTitle and UpdatedAt is the
// file's modification time in UTC. An invalid slug, a missing file, and any
// other read or stat failure each produce a distinct error.
func (s *FileStore) Get(ctx context.Context, slug string) (*knowledge.Document, error) {
	if !ValidSlug(slug) {
		return nil, fmt.Errorf("invalid slug: %q", slug)
	}

	docPath := filepath.Join(s.dir, slug+".md")
	raw, readErr := os.ReadFile(docPath)
	switch {
	case readErr == nil:
		// File read fine; continue below.
	case os.IsNotExist(readErr):
		return nil, fmt.Errorf("document not found: %q", slug)
	default:
		return nil, fmt.Errorf("read document: %w", readErr)
	}

	stat, statErr := os.Stat(docPath)
	if statErr != nil {
		return nil, fmt.Errorf("stat document: %w", statErr)
	}

	body := string(raw)
	doc := &knowledge.Document{
		Slug:      slug,
		Title:     extractTitle(body, slug),
		Content:   body,
		UpdatedAt: stat.ModTime().UTC(),
	}
	return doc, nil
}
// Put validates doc, writes its content to <dir>/<slug>.md, and then replaces
// the document's row in the index.
//
// Slugs rejected by ValidSlug and content larger than 64 KiB are refused
// before anything is written. The file is written before the index
// transaction begins, so if an index step fails the file on disk already
// holds the new content while the index still holds the previous row.
// The indexed updated_at is the current UTC time formatted as RFC 3339.
func (s *FileStore) Put(ctx context.Context, doc knowledge.Document) error {
if !ValidSlug(doc.Slug) {
return fmt.Errorf("invalid slug: %q", doc.Slug)
}
// Size limit is measured in bytes of content.
if len(doc.Content) > 64*1024 {
return fmt.Errorf("document too large: %d bytes (max 65536)", len(doc.Content))
}
path := filepath.Join(s.dir, doc.Slug+".md")
if err := os.WriteFile(path, []byte(doc.Content), 0o644); err != nil {
return fmt.Errorf("write document: %w", err)
}
title := extractTitle(doc.Content, doc.Slug)
now := time.Now().UTC().Format(time.RFC3339)
// Upsert: delete old + insert new (FTS5 doesn't support UPDATE well)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin put tx: %w", err)
}
// Undo the delete/insert pair on any early return below.
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `DELETE FROM documents WHERE slug = ?`, doc.Slug); err != nil {
return fmt.Errorf("delete old index: %w", err)
}
if _, err := tx.ExecContext(ctx,
`INSERT INTO documents (slug, title, content, updated_at) VALUES (?, ?, ?, ?)`,
doc.Slug, title, doc.Content, now,
); err != nil {
return fmt.Errorf("insert index: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit put tx: %w", err)
}
s.logger.Debug("knowledge_put", "slug", doc.Slug, "size", len(doc.Content))
return nil
}
// Delete removes <dir>/<slug>.md and then drops the matching row from the
// index. A file that is already absent is not an error; the index row is
// still deleted in that case.
func (s *FileStore) Delete(ctx context.Context, slug string) error {
	if !ValidSlug(slug) {
		return fmt.Errorf("invalid slug: %q", slug)
	}

	target := filepath.Join(s.dir, slug+".md")
	if rmErr := os.Remove(target); rmErr != nil {
		if !os.IsNotExist(rmErr) {
			return fmt.Errorf("remove document: %w", rmErr)
		}
	}

	_, dbErr := s.db.ExecContext(ctx, `DELETE FROM documents WHERE slug = ?`, slug)
	if dbErr != nil {
		return fmt.Errorf("delete from index: %w", dbErr)
	}

	s.logger.Debug("knowledge_delete", "slug", slug)
	return nil
}
// List returns every indexed document, ordered by slug. Only Slug, Title and
// UpdatedAt are populated; Content is not loaded.
func (s *FileStore) List(ctx context.Context) ([]knowledge.Document, error) {
	const stmt = `SELECT slug, title, updated_at FROM documents ORDER BY slug`

	cur, err := s.db.QueryContext(ctx, stmt)
	if err != nil {
		return nil, fmt.Errorf("knowledge list: %w", err)
	}
	defer cur.Close()

	var docs []knowledge.Document
	for cur.Next() {
		var (
			entry knowledge.Document
			stamp string
		)
		if scanErr := cur.Scan(&entry.Slug, &entry.Title, &stamp); scanErr != nil {
			return nil, scanErr
		}
		// Best effort: a timestamp that is not RFC 3339 leaves UpdatedAt at
		// the zero time rather than failing the whole listing.
		parsed, _ := time.Parse(time.RFC3339, stamp)
		entry.UpdatedAt = parsed
		docs = append(docs, entry)
	}
	return docs, cur.Err()
}
// Close logs that the store is shutting down and then closes the underlying
// database handle, returning whatever error that close reports.
func (s *FileStore) Close() error {
s.logger.Info("knowledge_store_closed")
return s.db.Close()
}
// extractTitle returns the text of the first H1 heading ("# ...") found in the
// first 20 lines of markdown content, or a humanized form of slug when no such
// heading exists.
//
// Lines are examined one at a time, so a heading on the last scanned line
// yields only that heading and never the remainder of the document. Leading
// and trailing whitespace around the heading text is removed.
func extractTitle(content, slug string) string {
	const maxLines = 20

	rest := content
	for i := 0; i < maxLines; i++ {
		line, tail, more := strings.Cut(rest, "\n")
		line = strings.TrimSpace(line)
		if strings.HasPrefix(line, "# ") {
			return strings.TrimSpace(strings.TrimPrefix(line, "# "))
		}
		if !more {
			break
		}
		rest = tail
	}

	// Humanize slug: "go-patterns" → "Go patterns"
	humanized := strings.ReplaceAll(slug, "-", " ")
	if len(humanized) > 0 {
		humanized = strings.ToUpper(humanized[:1]) + humanized[1:]
	}
	return humanized
}
+208
View File
@@ -0,0 +1,208 @@
package shellknowledge
import (
"context"
"os"
"path/filepath"
"testing"
"log/slog"
"github.com/enmanuel/agents/pkg/knowledge"
)
// testStore builds a FileStore for a single test inside a fresh t.TempDir():
// documents live under <tmp>/knowledge and the database at
// <tmp>/data/knowledge.db. The logger writes Error-level records to stderr,
// and the store is closed automatically through t.Cleanup.
//
// It returns the store together with the knowledge directory path so tests
// can create or inspect files directly.
func testStore(t *testing.T) (*FileStore, string) {
t.Helper()
dir := t.TempDir()
knowledgeDir := filepath.Join(dir, "knowledge")
dbPath := filepath.Join(dir, "data", "knowledge.db")
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
store, err := New(knowledgeDir, dbPath, logger)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { store.Close() })
return store, knowledgeDir
}
// TestValidSlug checks ValidSlug against a table of accepted and rejected slugs.
func TestValidSlug(t *testing.T) {
	cases := []struct {
		input string
		valid bool
	}{
		{"go-patterns", true},
		{"ab", true},
		{"a-b", true},
		{"abc123", true},
		{"a", false},              // too short
		{"A-B", false},            // uppercase
		{"-bad", false},           // starts with hyphen
		{"bad-", false},           // ends with hyphen
		{"has space", false},      // space
		{"has_underscore", false}, // underscore
		{"", false},
	}

	for _, c := range cases {
		got := ValidSlug(c.input)
		if got == c.valid {
			continue
		}
		t.Errorf("ValidSlug(%q) = %v, want %v", c.input, got, c.valid)
	}
}
// TestPutAndGet stores a document and verifies Get returns the same content
// with the title taken from its H1 heading.
func TestPutAndGet(t *testing.T) {
	store, _ := testStore(t)
	ctx := context.Background()

	want := knowledge.Document{
		Slug:    "test-doc",
		Content: "# Test Document\n\nThis is a test.",
	}
	if putErr := store.Put(ctx, want); putErr != nil {
		t.Fatal(putErr)
	}

	stored, getErr := store.Get(ctx, "test-doc")
	if getErr != nil {
		t.Fatal(getErr)
	}

	if stored.Content != want.Content {
		t.Errorf("content mismatch: got %q, want %q", stored.Content, want.Content)
	}
	if stored.Title != "Test Document" {
		t.Errorf("title = %q, want %q", stored.Title, "Test Document")
	}
}
// TestPutInvalidSlug verifies Put rejects a slug that ValidSlug does not accept.
func TestPutInvalidSlug(t *testing.T) {
	store, _ := testStore(t)

	bad := knowledge.Document{Slug: "BAD", Content: "test"}
	if err := store.Put(context.Background(), bad); err == nil {
		t.Error("expected error for invalid slug")
	}
}
// TestPutTooLarge verifies Put refuses a 65 KiB document.
func TestPutTooLarge(t *testing.T) {
	store, _ := testStore(t)
	ctx := context.Background()

	// Build a 65 KiB payload of 'x' bytes.
	payload := make([]byte, 0, 65*1024)
	for len(payload) < cap(payload) {
		payload = append(payload, 'x')
	}

	oversized := knowledge.Document{Slug: "too-big", Content: string(payload)}
	if err := store.Put(ctx, oversized); err == nil {
		t.Error("expected error for oversized document")
	}
}
// TestSyncAndSearch writes markdown files directly into the knowledge
// directory, re-indexes them with Sync, and checks that a full-text search
// finds the matching document first.
func TestSyncAndSearch(t *testing.T) {
	store, knowledgeDir := testStore(t)
	ctx := context.Background()

	// Write files directly to disk. Failures here must stop the test;
	// otherwise a later "no search result" failure would hide the real cause.
	fixtures := []struct {
		name, body string
	}{
		{"go-patterns.md", "# Go Patterns\n\nUse interfaces for dependency injection."},
		{"matrix-tips.md", "# Matrix Tips\n\nUse mautrix-go for Matrix bots."},
	}
	for _, f := range fixtures {
		if err := os.WriteFile(filepath.Join(knowledgeDir, f.name), []byte(f.body), 0o644); err != nil {
			t.Fatalf("write fixture %s: %v", f.name, err)
		}
	}

	if err := store.Sync(ctx); err != nil {
		t.Fatal(err)
	}

	// Search for "interfaces"
	results, err := store.Search(ctx, "interfaces", 5)
	if err != nil {
		t.Fatal(err)
	}
	if len(results) == 0 {
		t.Fatal("expected at least 1 search result")
	}
	if results[0].Slug != "go-patterns" {
		t.Errorf("expected slug go-patterns, got %q", results[0].Slug)
	}
}
// TestList verifies List is empty for a new store and reports every document
// added through Put.
func TestList(t *testing.T) {
	store, _ := testStore(t)
	ctx := context.Background()

	// Empty initially
	docs, err := store.List(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(docs) != 0 {
		t.Errorf("expected 0 docs, got %d", len(docs))
	}

	// Add two docs. Put errors are checked so that a failed write is reported
	// as such rather than as a confusing document-count mismatch.
	seed := []knowledge.Document{
		{Slug: "alpha", Content: "# Alpha\nContent A"},
		{Slug: "beta", Content: "# Beta\nContent B"},
	}
	for _, d := range seed {
		if err := store.Put(ctx, d); err != nil {
			t.Fatalf("put %q: %v", d.Slug, err)
		}
	}

	docs, err = store.List(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(docs) != 2 {
		t.Fatalf("expected 2 docs, got %d", len(docs))
	}
}
// TestDelete verifies Delete removes the document's file and that the
// document can no longer be fetched afterwards.
func TestDelete(t *testing.T) {
	store, knowledgeDir := testStore(t)
	ctx := context.Background()

	if err := store.Put(ctx, knowledge.Document{Slug: "to-delete", Content: "# Delete Me\nGoodbye"}); err != nil {
		t.Fatalf("put: %v", err)
	}
	docPath := filepath.Join(knowledgeDir, "to-delete.md")

	// Verify file exists
	if _, err := os.Stat(docPath); err != nil {
		t.Fatal("file should exist after Put")
	}

	if err := store.Delete(ctx, "to-delete"); err != nil {
		t.Fatal(err)
	}

	// File removed
	if _, err := os.Stat(docPath); !os.IsNotExist(err) {
		t.Error("file should be removed after Delete")
	}

	// Get reads from disk, so it must now fail.
	_, err := store.Get(ctx, "to-delete")
	if err == nil {
		t.Error("expected error for deleted document")
	}
}
// TestGetNotFound verifies Get fails for a slug that was never stored.
func TestGetNotFound(t *testing.T) {
	store, _ := testStore(t)

	if _, err := store.Get(context.Background(), "nonexistent"); err == nil {
		t.Error("expected error for nonexistent document")
	}
}
// TestExtractTitle covers heading extraction and the humanized-slug fallback.
func TestExtractTitle(t *testing.T) {
	cases := []struct {
		body     string
		slug     string
		expected string
	}{
		{"# My Title\nBody", "slug", "My Title"},
		{"No heading here", "my-doc", "My doc"},
		{"", "empty-doc", "Empty doc"},
		{"\n\n# Late Title\n", "slug", "Late Title"},
	}

	for _, c := range cases {
		if got := extractTitle(c.body, c.slug); got != c.expected {
			t.Errorf("extractTitle(%q, %q) = %q, want %q", c.body, c.slug, got, c.expected)
		}
	}
}
+242
View File
@@ -0,0 +1,242 @@
// Package llm contains impure LLM provider implementations.
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"time"
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/shell/logger"
)
// anthropicAPIBase is the base URL used when the caller supplies no baseURL.
const anthropicAPIBase = "https://api.anthropic.com/v1"
// anthropicVersion is the value sent in the "anthropic-version" request header.
const anthropicVersion = "2023-06-01"
// NewAnthropicComplete returns a CompleteFunc backed by the Anthropic API.
//
// apiKeyEnv is the NAME of the environment variable holding the API key; it
// is read on every call, so the key may be set or changed after construction.
// An empty baseURL falls back to anthropicAPIBase. Requests are POSTed to
// baseURL+"/messages" through http.DefaultClient, so cancellation and
// deadlines come from the ctx passed to the returned function.
//
// The returned function fails when the env var is unset, the request cannot
// be marshaled, built, or sent, the body cannot be read, the status is not
// 200 (the error text includes the response body), or the response JSON
// cannot be decoded.
func NewAnthropicComplete(apiKeyEnv, baseURL string, log *slog.Logger) coretypes.CompleteFunc {
if baseURL == "" {
baseURL = anthropicAPIBase
}
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
apiKey := os.Getenv(apiKeyEnv)
if apiKey == "" {
return coretypes.CompletionResponse{}, fmt.Errorf("env var %s is not set", apiKeyEnv)
}
log.Info("llm_request",
"provider", "anthropic",
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
)
body := toAnthropicRequest(req)
raw, err := json.Marshal(body)
if err != nil {
return coretypes.CompletionResponse{}, fmt.Errorf("marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/messages", bytes.NewReader(raw))
if err != nil {
return coretypes.CompletionResponse{}, err
}
httpReq.Header.Set("x-api-key", apiKey)
httpReq.Header.Set("anthropic-version", anthropicVersion)
httpReq.Header.Set("content-type", "application/json")
// start marks the beginning of the round trip; every log line below
// reports the elapsed milliseconds measured from here.
start := time.Now()
resp, err := http.DefaultClient.Do(httpReq)
if err != nil {
ms := time.Since(start).Milliseconds()
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "err", err)
return coretypes.CompletionResponse{}, fmt.Errorf("anthropic request: %w", err)
}
defer resp.Body.Close()
// The whole body is read before the status check so that error responses
// can be included in the returned error.
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return coretypes.CompletionResponse{}, fmt.Errorf("read response: %w", err)
}
ms := time.Since(start).Milliseconds()
if resp.StatusCode != http.StatusOK {
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "status", resp.StatusCode)
return coretypes.CompletionResponse{}, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, respBytes)
}
result, err := fromAnthropicResponse(respBytes)
if err != nil {
log.Error("llm_error", "provider", "anthropic", logger.FieldDurationMS, ms, "err", err)
return result, err
}
log.Info("llm_response",
"provider", "anthropic",
"model", req.Model,
logger.FieldDurationMS, ms,
logger.FieldTokensUsed, result.Usage.TotalTokens,
"input_tokens", result.Usage.InputTokens,
"output_tokens", result.Usage.OutputTokens,
"tool_calls", len(result.ToolCalls),
"finish_reason", result.FinishReason,
)
return result, nil
}
}
// ── private conversion helpers ────────────────────────────────────────────
// anthropicRequest is the JSON body POSTed to the /messages endpoint.
// System and Tools are omitted from the JSON when empty.
type anthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Messages []anthropicMessage `json:"messages"`
Tools []anthropicTool `json:"tools,omitempty"`
}
// anthropicMessage is one entry of the request's messages array. Content is
// pre-encoded JSON: toAnthropicMessage stores either a JSON string (plain
// text) or a JSON array of content blocks.
type anthropicMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
// anthropicTool describes one tool offered to the model; its fields are
// copied from the core request's tool definitions by toAnthropicRequest.
type anthropicTool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]any `json:"input_schema"`
}
// anthropicContentBlock represents a block in a content array.
// One struct covers the three block kinds used in this file, distinguished by
// Type: "text", "tool_use", and "tool_result". Every field other than Type is
// tagged omitempty, so only the fields that are set for a given kind are
// serialized.
// NOTE(review): omitempty also drops a nil or empty Input map when encoding.
type anthropicContentBlock struct {
Type string `json:"type"`
// text block
Text string `json:"text,omitempty"`
// tool_use block (in assistant responses)
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input map[string]any `json:"input,omitempty"`
// tool_result block (in user messages)
ToolUseID string `json:"tool_use_id,omitempty"`
Content string `json:"content,omitempty"`
}
// anthropicResponse is the subset of the /messages response body that
// fromAnthropicResponse decodes: the content blocks, token usage, and the
// stop reason.
type anthropicResponse struct {
Content []anthropicContentBlock `json:"content"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
StopReason string `json:"stop_reason"`
}
// toAnthropicRequest maps a core CompletionRequest onto the Anthropic request
// body. Messages with the system role are dropped from the message list; the
// system prompt travels in the dedicated System field instead.
func toAnthropicRequest(req coretypes.CompletionRequest) anthropicRequest {
	out := anthropicRequest{
		Model:     req.Model,
		MaxTokens: req.MaxTokens,
		System:    req.SystemPrompt,
		Messages:  make([]anthropicMessage, 0, len(req.Messages)),
		Tools:     make([]anthropicTool, 0, len(req.Tools)),
	}

	for _, msg := range req.Messages {
		if msg.Role != coretypes.RoleSystem {
			out.Messages = append(out.Messages, toAnthropicMessage(msg))
		}
	}

	for _, def := range req.Tools {
		out.Tools = append(out.Tools, anthropicTool{
			Name:        def.Name,
			Description: def.Description,
			InputSchema: def.InputSchema,
		})
	}

	return out
}
// toAnthropicMessage converts a core Message to the Anthropic format.
//
// Three shapes are produced:
//   - assistant message with tool calls → content array holding an optional
//     text block followed by one tool_use block per call;
//   - tool-role message → user message holding a single tool_result block;
//   - anything else → the role unchanged with the content as a JSON string.
//
// tool_use blocks always carry an "input" object. Tool calls whose arguments
// are empty, "{}", "null", or unparsable are sent with "input": {} instead of
// having the field dropped by omitempty.
func toAnthropicMessage(m coretypes.Message) anthropicMessage {
	// toolUseBlock mirrors the tool_use fields of anthropicContentBlock but
	// without omitempty, so "input" is serialized even when it is empty.
	type toolUseBlock struct {
		Type  string         `json:"type"`
		ID    string         `json:"id"`
		Name  string         `json:"name"`
		Input map[string]any `json:"input"`
	}

	// Assistant message with tool calls → content array with text + tool_use blocks
	if m.Role == coretypes.RoleAssistant && len(m.ToolCalls) > 0 {
		blocks := make([]any, 0, len(m.ToolCalls)+1)
		if m.Content != "" {
			blocks = append(blocks, anthropicContentBlock{Type: "text", Text: m.Content})
		}
		for _, tc := range m.ToolCalls {
			var input map[string]any
			// Best effort, as before: arguments that do not decode into an
			// object leave input nil, and the call is sent with empty input.
			_ = json.Unmarshal([]byte(tc.Arguments), &input)
			if input == nil {
				input = map[string]any{}
			}
			blocks = append(blocks, toolUseBlock{
				Type:  "tool_use",
				ID:    tc.ID,
				Name:  tc.Name,
				Input: input,
			})
		}
		// Marshal cannot fail on these plain structs unless input holds an
		// unencodable value, which decoded JSON never does.
		raw, _ := json.Marshal(blocks)
		return anthropicMessage{Role: "assistant", Content: raw}
	}

	// Tool result message → user message with tool_result content array
	if m.Role == coretypes.RoleTool {
		blocks := []anthropicContentBlock{{
			Type:      "tool_result",
			ToolUseID: m.ToolCallID,
			Content:   m.Content,
		}}
		raw, _ := json.Marshal(blocks)
		return anthropicMessage{Role: "user", Content: raw}
	}

	// Plain text message
	raw, _ := json.Marshal(m.Content)
	return anthropicMessage{Role: string(m.Role), Content: raw}
}
// fromAnthropicResponse decodes a /messages response body into the core
// CompletionResponse. Text blocks are concatenated in order into Content,
// tool_use blocks become ToolCalls with their input re-encoded as a JSON
// string, and other block types are ignored. TotalTokens is the sum of input
// and output tokens.
func fromAnthropicResponse(raw []byte) (coretypes.CompletionResponse, error) {
	var decoded anthropicResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return coretypes.CompletionResponse{}, fmt.Errorf("unmarshal response: %w", err)
	}

	out := coretypes.CompletionResponse{FinishReason: decoded.StopReason}
	for i := range decoded.Content {
		block := &decoded.Content[i]
		if block.Type == "text" {
			out.Content += block.Text
			continue
		}
		if block.Type == "tool_use" {
			encoded, _ := json.Marshal(block.Input)
			out.ToolCalls = append(out.ToolCalls, coretypes.ToolCall{
				ID:        block.ID,
				Name:      block.Name,
				Arguments: string(encoded),
			})
		}
	}

	inTok, outTok := decoded.Usage.InputTokens, decoded.Usage.OutputTokens
	out.Usage = coretypes.TokenUsage{
		InputTokens:  inTok,
		OutputTokens: outTok,
		TotalTokens:  inTok + outTok,
	}
	return out, nil
}
+295
View File
@@ -0,0 +1,295 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"os/exec"
"strings"
"syscall"
"time"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// Fallbacks applied by NewClaudeCodeComplete when the config leaves the
// binary name or the timeout unset.
const (
defaultClaudeBinary = "claude"
defaultClaudeTimeout = 5 * time.Minute
)
// claudeJSONOutput represents the JSON output from `claude -p --output-format json`.
// parseClaudeOutput reads IsError, Result, ContentBlock, Usage, NumTurns and
// TotalCost; the remaining fields are decoded but not used in this file.
type claudeJSONOutput struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
CostUSD float64 `json:"cost_usd"`
IsError bool `json:"is_error"`
// Duration is decoded from the "duration_api_ms" key.
Duration float64 `json:"duration_api_ms"`
NumTurns int `json:"num_turns"`
Result string `json:"result"`
SessionID string `json:"session_id"`
TotalCost float64 `json:"total_cost_usd"`
Usage claudeUsage `json:"usage"`
// ContentBlock is decoded from the "content" key and serves as the text
// source when Result is empty.
ContentBlock []claudeContent `json:"content"`
}
// claudeUsage holds the token counts reported under the "usage" key.
type claudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// claudeContent is one element of the "content" array; only elements with
// Type "text" and non-empty Text are used by parseClaudeOutput.
type claudeContent struct {
Type string `json:"type"`
Text string `json:"text"`
}
// NewClaudeCodeComplete creates a CompleteFunc that executes `claude -p` as a subprocess.
//
// cfg.Binary falls back to defaultClaudeBinary and cfg.Timeout to
// defaultClaudeTimeout when unset. The working directory is resolved once,
// here, through resolveWorkDir and reused by every call.
//
// Each call flattens req.Messages into a single prompt fed on stdin, runs the
// binary in its own process group with ANTHROPIC_API_KEY removed from the
// environment, and passes the captured stdout/stderr and the exit error to
// parseClaudeOutput. The timeout is applied on top of the caller's ctx.
func NewClaudeCodeComplete(cfg config.ClaudeCodeCfg, log *slog.Logger) coretypes.CompleteFunc {
binary := cfg.Binary
if binary == "" {
binary = defaultClaudeBinary
}
timeout := cfg.Timeout
if timeout <= 0 {
timeout = defaultClaudeTimeout
}
// Resolve working directory once at init time.
workDir := resolveWorkDir(cfg.WorkingDir, log)
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
args := buildClaudeArgs(cfg, req)
prompt := flattenMessages(req.Messages)
// NOTE(review): the logged args include the --system-prompt value when one is set.
log.Debug("claude_code_exec",
"binary", binary,
"args", strings.Join(args, " "),
"prompt_len", len(prompt),
"working_dir", workDir,
)
cmd := exec.CommandContext(ctx, binary, args...)
// An empty workDir (temp dir creation failed) leaves cmd.Dir unset.
if workDir != "" {
cmd.Dir = workDir
}
// Build clean env: inherit parent but remove ANTHROPIC_API_KEY
// so claude uses its own OAuth auth instead of a potentially invalid key.
cmd.Env = filterEnv(os.Environ(), "ANTHROPIC_API_KEY")
cmd.Stdin = strings.NewReader(prompt)
// Create a new process group so we can kill claude + all its children.
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Override the default cancel behavior: kill the entire process group
// instead of just the main process, preventing orphaned child processes.
cmd.Cancel = func() error {
if cmd.Process != nil {
pgid := cmd.Process.Pid
log.Info("killing claude-code process group", "pgid", pgid)
// Negative PID = kill entire process group
return syscall.Kill(-pgid, syscall.SIGKILL)
}
return nil
}
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
start := time.Now()
err := cmd.Run()
elapsed := time.Since(start)
// Ensure the process group is fully dead after Run returns,
// even if cmd.Run() returned without triggering Cancel (normal exit).
// The result of this kill is deliberately ignored.
if cmd.Process != nil {
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
}
log.Debug("claude_code_done",
"elapsed_ms", elapsed.Milliseconds(),
"stdout_len", stdout.Len(),
"stderr_len", stderr.Len(),
"exit_err", err,
)
return parseClaudeOutput(stdout.Bytes(), stderr.Bytes(), err, elapsed, log)
}
}
// resolveWorkDir picks the directory the claude subprocess will run in.
//
// A non-empty configured path is created if needed (a creation failure is
// logged, and the path is returned regardless). When configured is empty, a
// fresh temporary directory is created so the subprocess does not inherit the
// launcher's CWD; if that fails, the error is logged and "" is returned, which
// leaves the command's directory unset.
func resolveWorkDir(configured string, log *slog.Logger) string {
	if configured != "" {
		// Ensure configured directory exists.
		mkErr := os.MkdirAll(configured, 0o755)
		if mkErr != nil {
			log.Error("claude-code: failed to create working dir", "dir", configured, "err", mkErr)
		}
		return configured
	}

	scratch, mkErr := os.MkdirTemp("", "claude-agent-*")
	if mkErr == nil {
		log.Warn("claude-code working_dir is empty, using temporary directory",
			"dir", scratch,
		)
		return scratch
	}

	log.Error("claude-code: failed to create temp working dir", "err", mkErr)
	return "" // Fall through — cmd.Dir will remain empty (inherits CWD).
}
// buildClaudeArgs assembles the command-line flags for one claude invocation.
//
// The list always starts with --print --output-format json. Optional flags
// follow in a fixed order: system prompt, tool selection, permission mode,
// model, fallback model, session id, and one --add-dir per extra directory.
// When cfg.DisableTools is set, `--tools ""` is emitted and the allowed and
// disallowed tool lists are ignored.
func buildClaudeArgs(cfg config.ClaudeCodeCfg, req coretypes.CompletionRequest) []string {
	args := []string{"--print", "--output-format", "json"}

	// addIfSet appends "flag value" only when value is non-empty.
	addIfSet := func(flag, value string) {
		if value == "" {
			return
		}
		args = append(args, flag, value)
	}

	addIfSet("--system-prompt", req.SystemPrompt)

	switch {
	case cfg.DisableTools:
		args = append(args, "--tools", "")
	default:
		if len(cfg.AllowedTools) != 0 {
			args = append(args, "--allowedTools")
			args = append(args, cfg.AllowedTools...)
		}
		if len(cfg.DisallowedTools) != 0 {
			args = append(args, "--disallowedTools")
			args = append(args, cfg.DisallowedTools...)
		}
	}

	addIfSet("--permission-mode", cfg.PermissionMode)
	addIfSet("--model", cfg.Model)
	addIfSet("--fallback-model", cfg.FallbackModel)
	addIfSet("--session-id", cfg.SessionID)

	for i := range cfg.AddDirs {
		args = append(args, "--add-dir", cfg.AddDirs[i])
	}
	return args
}
// flattenMessages renders a conversation as one text prompt for stdin. User,
// assistant and tool messages each become a labelled paragraph followed by a
// blank line; messages with any other role are left out.
func flattenMessages(msgs []coretypes.Message) string {
	var prompt strings.Builder
	for i := range msgs {
		var layout string
		switch msgs[i].Role {
		case coretypes.RoleUser:
			layout = "User: %s\n\n"
		case coretypes.RoleAssistant:
			layout = "Assistant: %s\n\n"
		case coretypes.RoleTool:
			layout = "Tool result: %s\n\n"
		default:
			continue
		}
		fmt.Fprintf(&prompt, layout, msgs[i].Content)
	}
	return prompt.String()
}
// parseClaudeOutput parses the JSON output from `claude -p --output-format json`.
//
// Outcomes, checked in this order:
//   - execErr is set and stdout is empty: an error carrying stderr, or
//     execErr's text when stderr is empty too.
//   - stdout is not valid JSON: the trimmed stdout is returned as the content
//     with FinishReason "stop" and no error, whether or not execErr is set.
//   - the JSON has is_error=true: an error carrying the result field.
//   - otherwise: the content is the result field or, when that is empty, the
//     non-empty text content blocks joined with "\n". FinishReason is "error"
//     when execErr is set and "stop" otherwise.
//
// elapsed is used only for logging.
func parseClaudeOutput(
stdout, stderr []byte,
execErr error,
elapsed time.Duration,
log *slog.Logger,
) (coretypes.CompletionResponse, error) {
// If the process failed and there's no stdout, report the error
if execErr != nil && len(stdout) == 0 {
errMsg := string(stderr)
if errMsg == "" {
errMsg = execErr.Error()
}
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code process failed: %s", errMsg)
}
// Parse JSON output
var output claudeJSONOutput
if err := json.Unmarshal(stdout, &output); err != nil {
// Fall back to treating stdout as plain text
log.Warn("claude_code_json_parse_failed", "err", err, "stdout_len", len(stdout))
return coretypes.CompletionResponse{
Content: strings.TrimSpace(string(stdout)),
FinishReason: "stop",
}, nil
}
if output.IsError {
return coretypes.CompletionResponse{}, fmt.Errorf("claude-code error: %s", output.Result)
}
// Extract text from result field or content blocks
content := output.Result
if content == "" && len(output.ContentBlock) > 0 {
var parts []string
for _, block := range output.ContentBlock {
if block.Type == "text" && block.Text != "" {
parts = append(parts, block.Text)
}
}
content = strings.Join(parts, "\n")
}
// A non-nil execErr alongside parseable output is reported through
// FinishReason rather than as an error.
finishReason := "stop"
if execErr != nil {
finishReason = "error"
}
log.Info("claude_code_response",
"content_len", len(content),
"input_tokens", output.Usage.InputTokens,
"output_tokens", output.Usage.OutputTokens,
"num_turns", output.NumTurns,
"cost_usd", output.TotalCost,
"elapsed_ms", elapsed.Milliseconds(),
)
return coretypes.CompletionResponse{
Content: content,
Usage: coretypes.TokenUsage{
InputTokens: output.Usage.InputTokens,
OutputTokens: output.Usage.OutputTokens,
TotalTokens: output.Usage.InputTokens + output.Usage.OutputTokens,
},
FinishReason: finishReason,
}, nil
}
// filterEnv copies environ, leaving out every "KEY=value" entry whose KEY is
// exactly one of keys. Matching is on the full name followed by "=", so a
// variable that merely starts with a filtered name is kept.
func filterEnv(environ []string, keys ...string) []string {
	kept := make([]string, 0, len(environ))
nextEntry:
	for _, entry := range environ {
		for _, key := range keys {
			if strings.HasPrefix(entry, key+"=") {
				continue nextEntry
			}
		}
		kept = append(kept, entry)
	}
	return kept
}
+402
View File
@@ -0,0 +1,402 @@
package llm
import (
"encoding/json"
"errors"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// discardLog is a logger whose output goes to io.Discard, shared by the tests
// in this file that need a *slog.Logger.
var discardLog = slog.New(slog.NewTextHandler(io.Discard, nil))
// ── buildClaudeArgs ──────────────────────────────────────────────────────
func TestBuildClaudeArgs_Minimal(t *testing.T) {
cfg := config.ClaudeCodeCfg{}
req := coretypes.CompletionRequest{}
args := buildClaudeArgs(cfg, req)
// Must always start with --print --output-format json
want := []string{"--print", "--output-format", "json"}
if len(args) != len(want) {
t.Fatalf("got %v, want %v", args, want)
}
for i := range want {
if args[i] != want[i] {
t.Errorf("args[%d] = %q, want %q", i, args[i], want[i])
}
}
}
func TestBuildClaudeArgs_AllOptions(t *testing.T) {
cfg := config.ClaudeCodeCfg{
Model: "sonnet",
FallbackModel: "haiku",
PermissionMode: "bypassPermissions",
AllowedTools: []string{"Bash(git:*)", "Read"},
SessionID: "abc-123",
AddDirs: []string{"/tmp/extra"},
}
req := coretypes.CompletionRequest{
SystemPrompt: "You are a helpful bot",
}
args := buildClaudeArgs(cfg, req)
assertContains(t, args, "--system-prompt", "You are a helpful bot")
assertContains(t, args, "--model", "sonnet")
assertContains(t, args, "--fallback-model", "haiku")
assertContains(t, args, "--permission-mode", "bypassPermissions")
assertContains(t, args, "--session-id", "abc-123")
assertContains(t, args, "--add-dir", "/tmp/extra")
assertContains(t, args, "--allowedTools", "Bash(git:*)")
}
func TestBuildClaudeArgs_DisableTools(t *testing.T) {
cfg := config.ClaudeCodeCfg{
DisableTools: true,
AllowedTools: []string{"Bash"}, // should be ignored
}
req := coretypes.CompletionRequest{}
args := buildClaudeArgs(cfg, req)
assertContains(t, args, "--tools", "")
// --allowedTools must NOT appear when disable_tools is set
for _, a := range args {
if a == "--allowedTools" {
t.Error("--allowedTools should not appear when DisableTools=true")
}
}
}
// TestBuildClaudeArgs_DisallowedTools verifies that the disallowed tool list
// follows the --disallowedTools flag.
func TestBuildClaudeArgs_DisallowedTools(t *testing.T) {
	blocked := config.ClaudeCodeCfg{
		DisallowedTools: []string{"Edit", "Write"},
	}
	args := buildClaudeArgs(blocked, coretypes.CompletionRequest{})
	assertContains(t, args, "--disallowedTools", "Edit")
}
// ── flattenMessages ──────────────────────────────────────────────────────
// TestFlattenMessages_Empty verifies that a nil message list flattens to "".
func TestFlattenMessages_Empty(t *testing.T) {
	if out := flattenMessages(nil); out != "" {
		t.Errorf("expected empty, got %q", out)
	}
}
func TestFlattenMessages_MultiRole(t *testing.T) {
msgs := []coretypes.Message{
{Role: coretypes.RoleUser, Content: "hello"},
{Role: coretypes.RoleAssistant, Content: "hi there"},
{Role: coretypes.RoleTool, Content: `{"time":"12:00"}`},
{Role: coretypes.RoleUser, Content: "thanks"},
}
got := flattenMessages(msgs)
expects := []string{
"User: hello",
"Assistant: hi there",
`Tool result: {"time":"12:00"}`,
"User: thanks",
}
for _, e := range expects {
if !contains(got, e) {
t.Errorf("missing %q in:\n%s", e, got)
}
}
}
func TestFlattenMessages_SkipsSystem(t *testing.T) {
msgs := []coretypes.Message{
{Role: coretypes.RoleSystem, Content: "system prompt"},
{Role: coretypes.RoleUser, Content: "hello"},
}
got := flattenMessages(msgs)
if contains(got, "system prompt") {
t.Error("system messages should not appear in flattened output")
}
if !contains(got, "User: hello") {
t.Error("user message missing")
}
}
// ── parseClaudeOutput ────────────────────────────────────────────────────
// TestParseClaudeOutput_Success verifies content, token usage and finish
// reason for a well-formed, non-error JSON result.
func TestParseClaudeOutput_Success(t *testing.T) {
	payload, _ := json.Marshal(claudeJSONOutput{
		Type:      "result",
		Subtype:   "success",
		IsError:   false,
		NumTurns:  1,
		Result:    "Hello! I'm Claude.",
		TotalCost: 0.025,
		Usage:     claudeUsage{InputTokens: 10, OutputTokens: 50},
	})

	resp, err := parseClaudeOutput(payload, nil, nil, 2*time.Second, discardLog)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if resp.Content != "Hello! I'm Claude." {
		t.Errorf("content = %q, want %q", resp.Content, "Hello! I'm Claude.")
	}
	usage := resp.Usage
	if usage.InputTokens != 10 {
		t.Errorf("input tokens = %d, want 10", usage.InputTokens)
	}
	if usage.OutputTokens != 50 {
		t.Errorf("output tokens = %d, want 50", usage.OutputTokens)
	}
	if usage.TotalTokens != 60 {
		t.Errorf("total tokens = %d, want 60", usage.TotalTokens)
	}
	if resp.FinishReason != "stop" {
		t.Errorf("finish reason = %q, want %q", resp.FinishReason, "stop")
	}
}
// TestParseClaudeOutput_ErrorResponse verifies that is_error=true becomes an
// error carrying the result text.
func TestParseClaudeOutput_ErrorResponse(t *testing.T) {
	payload, _ := json.Marshal(claudeJSONOutput{
		IsError: true,
		Result:  "Invalid API key",
	})

	_, err := parseClaudeOutput(payload, nil, nil, time.Second, discardLog)
	if err == nil {
		t.Fatal("expected error for IsError=true")
	}
	if msg := err.Error(); !contains(msg, "Invalid API key") {
		t.Errorf("error = %q, should contain 'Invalid API key'", msg)
	}
}
// TestParseClaudeOutput_ProcessFailedNoStdout verifies that a failed process
// with empty stdout reports the stderr text.
func TestParseClaudeOutput_ProcessFailedNoStdout(t *testing.T) {
	runErr := errors.New("exit 1")
	_, err := parseClaudeOutput(nil, []byte("unknown option\n"), runErr, time.Second, discardLog)
	if err == nil {
		t.Fatal("expected error when process fails with no stdout")
	}
	if msg := err.Error(); !contains(msg, "unknown option") {
		t.Errorf("error = %q, should contain stderr message", msg)
	}
}
// TestParseClaudeOutput_ProcessFailedNoStderr verifies that with both streams
// empty the error falls back to the exec error's text.
func TestParseClaudeOutput_ProcessFailedNoStderr(t *testing.T) {
	runErr := errors.New("exit 1")
	_, err := parseClaudeOutput(nil, nil, runErr, time.Second, discardLog)
	if err == nil {
		t.Fatal("expected error")
	}
	if msg := err.Error(); !contains(msg, "exit 1") {
		t.Errorf("error = %q, should contain exec error", msg)
	}
}
// TestParseClaudeOutput_FallbackPlainText verifies that non-JSON stdout is
// returned as trimmed plain-text content.
func TestParseClaudeOutput_FallbackPlainText(t *testing.T) {
	// Non-JSON stdout should be treated as plain text
	raw := []byte("just plain text\n")
	resp, err := parseClaudeOutput(raw, nil, nil, time.Second, discardLog)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := resp.Content; got != "just plain text" {
		t.Errorf("content = %q, want %q", got, "just plain text")
	}
}
// TestParseClaudeOutput_ContentBlocks verifies that, when the result field is
// empty, text blocks are joined with newlines to form the content.
func TestParseClaudeOutput_ContentBlocks(t *testing.T) {
	payload, _ := json.Marshal(claudeJSONOutput{
		Result: "", // empty result, content in blocks
		ContentBlock: []claudeContent{
			{Type: "text", Text: "First part."},
			{Type: "text", Text: "Second part."},
		},
		Usage: claudeUsage{InputTokens: 5, OutputTokens: 20},
	})

	resp, err := parseClaudeOutput(payload, nil, nil, time.Second, discardLog)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := resp.Content; got != "First part.\nSecond part." {
		t.Errorf("content = %q, want joined blocks", got)
	}
}
// TestParseClaudeOutput_ExecErrWithStdout verifies that valid JSON output from
// a failed process is still parsed, with the finish reason set to "error".
func TestParseClaudeOutput_ExecErrWithStdout(t *testing.T) {
	// Process failed but produced valid JSON output — should parse and set finish_reason=error
	payload, _ := json.Marshal(claudeJSONOutput{
		Result: "partial answer",
		Usage:  claudeUsage{InputTokens: 3, OutputTokens: 10},
	})

	runErr := errors.New("timeout")
	resp, err := parseClaudeOutput(payload, nil, runErr, time.Second, discardLog)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if reason := resp.FinishReason; reason != "error" {
		t.Errorf("finish reason = %q, want %q", reason, "error")
	}
	if resp.Content != "partial answer" {
		t.Errorf("content = %q", resp.Content)
	}
}
// ── filterEnv ────────────────────────────────────────────────────────────
// TestFilterEnv_RemovesSingleKey verifies that one named variable is dropped
// and the others are kept.
func TestFilterEnv_RemovesSingleKey(t *testing.T) {
	kept := filterEnv([]string{
		"HOME=/home/user",
		"ANTHROPIC_API_KEY=sk-secret",
		"PATH=/usr/bin",
	}, "ANTHROPIC_API_KEY")

	if n := len(kept); n != 2 {
		t.Fatalf("expected 2 entries, got %d: %v", n, kept)
	}
	for i := range kept {
		if contains(kept[i], "ANTHROPIC_API_KEY") {
			t.Errorf("ANTHROPIC_API_KEY should have been removed: %v", kept)
		}
	}
}
// TestFilterEnv_RemovesMultipleKeys verifies that several named variables can
// be removed in a single call.
func TestFilterEnv_RemovesMultipleKeys(t *testing.T) {
	kept := filterEnv([]string{
		"HOME=/home/user",
		"ANTHROPIC_API_KEY=sk-secret",
		"OPENAI_API_KEY=sk-openai",
		"PATH=/usr/bin",
	}, "ANTHROPIC_API_KEY", "OPENAI_API_KEY")

	if n := len(kept); n != 2 {
		t.Fatalf("expected 2 entries, got %d: %v", n, kept)
	}
}
// TestFilterEnv_NoMatchKeepsAll verifies that filtering a key that is not
// present leaves every entry in place.
func TestFilterEnv_NoMatchKeepsAll(t *testing.T) {
	kept := filterEnv([]string{"HOME=/home/user", "PATH=/usr/bin"}, "NONEXISTENT")
	if n := len(kept); n != 2 {
		t.Fatalf("expected 2, got %d", n)
	}
}
// TestFilterEnv_PrefixSafety verifies that a variable whose name merely starts
// with the filtered key is not removed.
func TestFilterEnv_PrefixSafety(t *testing.T) {
	// ANTHROPIC_API_KEY_V2 should NOT be removed when filtering ANTHROPIC_API_KEY
	kept := filterEnv([]string{
		"ANTHROPIC_API_KEY=secret",
		"ANTHROPIC_API_KEY_V2=other",
	}, "ANTHROPIC_API_KEY")

	if n := len(kept); n != 1 {
		t.Fatalf("expected 1, got %d: %v", n, kept)
	}
	if survivor := kept[0]; survivor != "ANTHROPIC_API_KEY_V2=other" {
		t.Errorf("wrong entry kept: %q", survivor)
	}
}
// ── resolveWorkDir ──────────────────────────────────────────────────────
// TestResolveWorkDir_EmptyCreatesTempDir verifies that an empty setting yields
// an existing temporary directory named with the claude-agent- prefix.
func TestResolveWorkDir_EmptyCreatesTempDir(t *testing.T) {
	tmp := resolveWorkDir("", discardLog)
	if tmp == "" {
		t.Fatal("expected a temp directory, got empty string")
	}
	defer os.RemoveAll(tmp)

	if !strings.Contains(tmp, "claude-agent-") {
		t.Errorf("temp dir %q should contain 'claude-agent-' prefix", tmp)
	}

	stat, statErr := os.Stat(tmp)
	if statErr != nil {
		t.Fatalf("temp dir should exist: %v", statErr)
	}
	if !stat.IsDir() {
		t.Error("temp dir should be a directory")
	}
}
// TestResolveWorkDir_ConfiguredValueUsed verifies that a configured path is
// returned unchanged and is created when it does not exist yet.
func TestResolveWorkDir_ConfiguredValueUsed(t *testing.T) {
	configured := filepath.Join(t.TempDir(), "custom-workdir")

	resolved := resolveWorkDir(configured, discardLog)
	if resolved != configured {
		t.Errorf("got %q, want %q", resolved, configured)
	}

	stat, err := os.Stat(resolved)
	switch {
	case err != nil:
		t.Fatalf("configured dir should be created: %v", err)
	case !stat.IsDir():
		t.Error("configured dir should be a directory")
	}
}
// TestResolveWorkDir_ConfiguredAlreadyExists verifies that a configured
// directory that already exists is returned as-is.
func TestResolveWorkDir_ConfiguredAlreadyExists(t *testing.T) {
	existing := t.TempDir() // created by the testing framework

	if resolved := resolveWorkDir(existing, discardLog); resolved != existing {
		t.Errorf("got %q, want %q", resolved, existing)
	}
}
// ── helpers ──────────────────────────────────────────────────────────────
// contains reports whether substr occurs anywhere within s. An empty substr
// is contained in every string, including the empty string.
//
// It is a thin wrapper over strings.Contains, kept so existing call sites in
// the tests continue to compile.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
// stringContains reports whether sub occurs within s. It delegates to the
// standard library instead of scanning the string by hand; an empty sub
// matches any s, and a sub longer than s never matches.
func stringContains(s, sub string) bool {
	return strings.Contains(s, sub)
}
// assertContains fails the test (without stopping it) unless args contains
// flag immediately followed by value. An empty value is handled like any
// other: `--tools ""` matches when the element after "--tools" is "".
func assertContains(t *testing.T, args []string, flag, value string) {
	t.Helper()
	// Stop one short of the end: a flag in the last position has no value.
	for i := 0; i+1 < len(args); i++ {
		if args[i] == flag && args[i+1] == value {
			return
		}
	}
	t.Errorf("args %v missing %s %q", args, flag, value)
}
+51
View File
@@ -0,0 +1,51 @@
package llm
import (
"context"
"fmt"
"log/slog"
"github.com/enmanuel/agents/internal/config"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// FromConfig selects and constructs the CompleteFunc for cfg.Provider.
//
// Supported providers are "anthropic", "openai", "ollama" and "claude-code".
// For "ollama" the OpenAI-compatible client is used with the fixed key
// variable OLLAMA_API_KEY and, when cfg.BaseURL is empty, the local default
// endpoint. Any other provider name yields an error.
func FromConfig(cfg config.LLMProviderCfg, log *slog.Logger) (coretypes.CompleteFunc, error) {
	log.Info("llm_provider_init", "provider", cfg.Provider, "model", cfg.Model)

	provider := cfg.Provider

	if provider == "anthropic" {
		return NewAnthropicComplete(cfg.APIKeyEnv, cfg.BaseURL, log), nil
	}
	if provider == "openai" {
		return NewOpenAIComplete(cfg.APIKeyEnv, cfg.BaseURL, log), nil
	}
	if provider == "ollama" {
		endpoint := cfg.BaseURL
		if endpoint == "" {
			endpoint = "http://localhost:11434/v1"
		}
		return NewOpenAIComplete("OLLAMA_API_KEY", endpoint, log), nil
	}
	if provider == "claude-code" {
		return NewClaudeCodeComplete(cfg.ClaudeCode, log), nil
	}

	return nil, fmt.Errorf("unknown LLM provider: %s", provider)
}
// WithFallback returns a CompleteFunc that tries primary first and, only if
// primary returns an error, retries the request through fallback.
//
// Before the retry, the request's Model and MaxTokens are overridden with the
// values from fallbackCfg when those are set (non-empty / positive). The
// primary error is logged as a warning and otherwise discarded; the caller
// sees whatever fallback returns.
func WithFallback(primary, fallback coretypes.CompleteFunc, fallbackCfg config.LLMProviderCfg, log *slog.Logger) coretypes.CompleteFunc {
	return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
		resp, primaryErr := primary(ctx, req)
		if primaryErr == nil {
			return resp, nil
		}

		log.Warn("llm_fallback_triggered", "primary_err", primaryErr)

		// req is a local copy, so adjusting it does not affect the caller.
		if model := fallbackCfg.Model; model != "" {
			req.Model = model
		}
		if limit := fallbackCfg.MaxTokens; limit > 0 {
			req.MaxTokens = limit
		}
		return fallback(ctx, req)
	}
}
+169
View File
@@ -0,0 +1,169 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"time"
openai "github.com/sashabaranov/go-openai"
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/shell/logger"
)
// NewOpenAIComplete returns a CompleteFunc backed by the OpenAI-compatible API.
// Works with OpenAI, Ollama, vLLM, LMStudio — just change baseURL.
//
// apiKeyEnv names the environment variable holding the API key; it is read on
// every call, and when it is unset or empty the placeholder key "ollama" is
// sent instead. baseURL overrides the library's default endpoint when
// non-empty. A new client is constructed for each completion call.
//
// The returned function sends the system prompt (if any) followed by the
// request messages, forwards tool specs when present, and maps the first
// choice of the response to a CompletionResponse. It returns an error when the
// API call fails or when the response contains no choices.
func NewOpenAIComplete(apiKeyEnv, baseURL string, log *slog.Logger) coretypes.CompleteFunc {
return func(ctx context.Context, req coretypes.CompletionRequest) (coretypes.CompletionResponse, error) {
apiKey := os.Getenv(apiKeyEnv)
if apiKey == "" {
apiKey = "ollama" // Ollama doesn't require a real key
}
cfg := openai.DefaultConfig(apiKey)
if baseURL != "" {
cfg.BaseURL = baseURL
}
client := openai.NewClientWithConfig(cfg)
// Capacity is +1 to leave room for the optional system message.
msgs := make([]openai.ChatCompletionMessage, 0, len(req.Messages)+1)
if req.SystemPrompt != "" {
msgs = append(msgs, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: req.SystemPrompt,
})
}
for _, m := range req.Messages {
msgs = append(msgs, toOpenAIMessage(m))
}
openReq := openai.ChatCompletionRequest{
Model: req.Model,
Messages: msgs,
MaxTokens: req.MaxTokens,
Temperature: float32(req.Temperature),
}
// Add tools if present
if len(req.Tools) > 0 {
openReq.Tools = toOpenAITools(req.Tools)
}
log.Info("llm_request",
"provider", "openai",
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
)
// Time only the API round trip; the duration is logged on both the
// error and the success path.
start := time.Now()
resp, err := client.CreateChatCompletion(ctx, openReq)
if err != nil {
ms := time.Since(start).Milliseconds()
log.Error("llm_error", "provider", "openai", logger.FieldDurationMS, ms, "err", err)
return coretypes.CompletionResponse{}, fmt.Errorf("openai completion: %w", err)
}
ms := time.Since(start).Milliseconds()
// A successful call with no choices is treated as an error.
if len(resp.Choices) == 0 {
log.Error("llm_error", "provider", "openai", logger.FieldDurationMS, ms, "err", "empty choices")
return coretypes.CompletionResponse{}, fmt.Errorf("openai: empty choices")
}
// Only the first choice is used; any further choices are ignored.
choice := resp.Choices[0]
var toolCalls []coretypes.ToolCall
for _, tc := range choice.Message.ToolCalls {
toolCalls = append(toolCalls, coretypes.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
})
}
log.Info("llm_response",
"provider", "openai",
"model", req.Model,
logger.FieldDurationMS, ms,
logger.FieldTokensUsed, resp.Usage.TotalTokens,
"input_tokens", resp.Usage.PromptTokens,
"output_tokens", resp.Usage.CompletionTokens,
"tool_calls", len(toolCalls),
"finish_reason", string(choice.FinishReason),
)
return coretypes.CompletionResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: string(choice.FinishReason),
Usage: coretypes.TokenUsage{
InputTokens: resp.Usage.PromptTokens,
OutputTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
},
}, nil
}
}
// toOpenAIMessage translates a core Message into the OpenAI chat message
// shape. Roles other than assistant, system and tool are sent as "user".
// Tool calls are copied only for assistant messages that carry at least one.
func toOpenAIMessage(m coretypes.Message) openai.ChatCompletionMessage {
	out := openai.ChatCompletionMessage{
		Role:       openai.ChatMessageRoleUser,
		Content:    m.Content,
		ToolCallID: m.ToolCallID,
	}

	switch m.Role {
	case coretypes.RoleAssistant:
		out.Role = openai.ChatMessageRoleAssistant
	case coretypes.RoleSystem:
		out.Role = openai.ChatMessageRoleSystem
	case coretypes.RoleTool:
		out.Role = openai.ChatMessageRoleTool
	}

	if m.Role != coretypes.RoleAssistant || len(m.ToolCalls) == 0 {
		return out
	}

	// Assistant message that requested tools: mirror each call as a
	// function-type tool call.
	calls := make([]openai.ToolCall, 0, len(m.ToolCalls))
	for _, call := range m.ToolCalls {
		calls = append(calls, openai.ToolCall{
			ID:   call.ID,
			Type: openai.ToolTypeFunction,
			Function: openai.FunctionCall{
				Name:      call.Name,
				Arguments: call.Arguments,
			},
		})
	}
	out.ToolCalls = calls
	return out
}
// toOpenAITools maps each core ToolSpec to an OpenAI function-type tool,
// carrying over its name, description and JSON-encoded input schema. The
// result has one entry per spec, in the same order.
func toOpenAITools(specs []coretypes.ToolSpec) []openai.Tool {
	converted := make([]openai.Tool, 0, len(specs))
	for _, spec := range specs {
		fn := &openai.FunctionDefinition{
			Name:        spec.Name,
			Description: spec.Description,
			Parameters:  json.RawMessage(marshalSchema(spec.InputSchema)),
		}
		converted = append(converted, openai.Tool{
			Type:     openai.ToolTypeFunction,
			Function: fn,
		})
	}
	return converted
}
// marshalSchema encodes a JSON-schema map as JSON bytes. If the map cannot be
// encoded, the empty object "{}" is returned instead of an error.
func marshalSchema(schema map[string]any) []byte {
	if encoded, err := json.Marshal(schema); err == nil {
		return encoded
	}
	return []byte("{}")
}
+84
View File
@@ -0,0 +1,84 @@
package logger
import (
"context"
"os"
"path/filepath"
"strings"
"time"
)
// runCleanup deletes the agent's expired log files once right away and then
// again on every tick of interval, blocking until ctx is cancelled.
func runCleanup(ctx context.Context, baseDir, agentID string, maxAgeDays int, interval time.Duration) {
	tick := time.NewTicker(interval)
	defer tick.Stop()

	for {
		// First pass happens immediately; later passes follow a tick.
		cleanOldLogs(baseDir, agentID, maxAgeDays)

		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}
	}
}
// cleanOldLogs deletes the agent's .jsonl and .jsonl.gz files whose filename
// date is more than maxAgeDays before the current UTC time. Directories,
// non-log files and files without a parseable date are left alone. A missing
// or unreadable agent directory is silently ignored, as are removal errors.
func cleanOldLogs(baseDir, agentID string, maxAgeDays int) {
	agentDir := filepath.Join(baseDir, agentID)
	items, err := os.ReadDir(agentDir)
	if err != nil {
		return
	}

	threshold := time.Now().UTC().AddDate(0, 0, -maxAgeDays)
	for _, item := range items {
		if item.IsDir() {
			continue
		}
		fname := item.Name()
		if !isLogFile(fname) {
			continue
		}
		if stamp := parseDateFromFilename(fname); !stamp.IsZero() && stamp.Before(threshold) {
			os.Remove(filepath.Join(agentDir, fname))
		}
	}
}
// isLogFile reports whether name ends in one of the log extensions,
// ".jsonl" or ".jsonl.gz".
func isLogFile(name string) bool {
	for _, ext := range [...]string{".jsonl", ".jsonl.gz"} {
		if strings.HasSuffix(name, ext) {
			return true
		}
	}
	return false
}
// parseDateFromFilename returns the YYYY-MM-DD date embedded in a log
// filename, or the zero time when no date can be parsed. Accepted shapes:
//
//	2026-03-06.jsonl
//	2026-03-06.1.jsonl
//	2026-03-06.jsonl.gz
func parseDateFromFilename(name string) time.Time {
	const layout = "2006-01-02"

	// Drop ".gz" first, then ".jsonl", leaving "<date>" or "<date>.<n>".
	stem := strings.TrimSuffix(strings.TrimSuffix(name, ".gz"), ".jsonl")

	// If there is a trailing ".<n>" split index, try the part before it.
	if dot := strings.LastIndex(stem, "."); dot != -1 {
		if parsed, err := time.Parse(layout, stem[:dot]); err == nil {
			return parsed
		}
	}

	// Fall back to the whole stem; a parse failure yields the zero time.
	parsed, _ := time.Parse(layout, stem)
	return parsed
}
+110
View File
@@ -0,0 +1,110 @@
package logger
import (
"context"
"os"
"path/filepath"
"testing"
"time"
)
// TestCleanOldLogs verifies that, with a 7-day retention, files dated 10 days
// ago are removed while files dated 5 days ago and today are kept.
//
// File dates are derived from the current time rather than hard-coded:
// cleanOldLogs compares against time.Now(), so fixed calendar dates would
// make the "should still exist" assertions fail once those dates age past the
// retention window.
func TestCleanOldLogs(t *testing.T) {
	dir := t.TempDir()
	agentDir := filepath.Join(dir, "bot1")
	if err := os.MkdirAll(agentDir, 0o755); err != nil {
		t.Fatal(err)
	}

	now := time.Now().UTC()
	dayName := func(daysAgo int) string {
		return now.AddDate(0, 0, -daysAgo).Format("2006-01-02")
	}

	// The 10- and 5-day offsets leave wide margins around the 7-day cutoff,
	// so the test is not sensitive to the time of day it runs at.
	expired := []string{
		dayName(10) + ".jsonl",
		dayName(10) + ".jsonl.gz",
	}
	retained := []string{
		dayName(5) + ".jsonl",
		dayName(0) + ".jsonl",
	}

	for _, group := range [][]string{expired, retained} {
		for _, f := range group {
			if err := os.WriteFile(filepath.Join(agentDir, f), []byte("{}"), 0o644); err != nil {
				t.Fatal(err)
			}
		}
	}

	cleanOldLogs(dir, "bot1", 7)

	remaining, err := os.ReadDir(agentDir)
	if err != nil {
		t.Fatal(err)
	}
	names := make(map[string]bool, len(remaining))
	for _, e := range remaining {
		names[e.Name()] = true
	}

	for _, f := range expired {
		if names[f] {
			t.Errorf("%s should have been removed", f)
		}
	}
	for _, f := range retained {
		if !names[f] {
			t.Errorf("%s should still exist", f)
		}
	}
}
// TestParseDateFromFilename checks date extraction for plain, size-split and
// gzipped log filenames, and that an undated name yields no date.
func TestParseDateFromFilename(t *testing.T) {
	const layout = "2006-01-02"

	cases := []struct {
		name string
		want string
	}{
		{"2026-03-06.jsonl", "2026-03-06"},
		{"2026-03-06.1.jsonl", "2026-03-06"},
		{"2026-03-06.jsonl.gz", "2026-03-06"},
		{"2026-03-06.2.jsonl.gz", "2026-03-06"},
		{"invalid.jsonl", ""},
	}

	for i := range cases {
		tc := cases[i]

		// A zero time is rendered as "" so it can be compared with want.
		var got string
		if parsed := parseDateFromFilename(tc.name); !parsed.IsZero() {
			got = parsed.Format(layout)
		}
		if got == tc.want {
			continue
		}
		t.Errorf("parseDateFromFilename(%q) = %q, want %q", tc.name, got, tc.want)
	}
}
// TestCleanOldLogs_EmptyDir verifies that cleaning an agent whose log
// directory does not exist is a harmless no-op (no panic).
func TestCleanOldLogs_EmptyDir(t *testing.T) {
	root := t.TempDir()

	// "nonexistent" has no directory under root.
	cleanOldLogs(root, "nonexistent", 7)
}
// TestIsLogFile checks that both log extensions are recognised and that an
// unrelated extension is rejected.
func TestIsLogFile(t *testing.T) {
	cases := []struct {
		name    string
		want    bool
		failure string
	}{
		{"2026-03-06.jsonl", true, "should match .jsonl"},
		{"2026-03-06.jsonl.gz", true, "should match .jsonl.gz"},
		{"readme.txt", false, "should not match .txt"},
	}

	for _, tc := range cases {
		if isLogFile(tc.name) != tc.want {
			t.Error(tc.failure)
		}
	}
}
// TestRunCleanup_Cancellation verifies that runCleanup returns shortly after
// its context is cancelled.
func TestRunCleanup_Cancellation(t *testing.T) {
	root := t.TempDir()
	os.MkdirAll(filepath.Join(root, "bot1"), 0o755)

	ctx, stop := context.WithCancel(context.Background())
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		runCleanup(ctx, root, "bot1", 7, 50*time.Millisecond)
	}()

	// Let a couple of ticks fire before cancelling.
	time.Sleep(150 * time.Millisecond)
	stop()

	deadline := time.After(2 * time.Second)
	select {
	case <-exited:
	case <-deadline:
		t.Error("cleanup goroutine did not exit after cancel")
	}
}
+101
View File
@@ -0,0 +1,101 @@
// Package logger provides structured JSONL logging for agents with daily
// file rotation, size-based splitting, automatic cleanup, and query helpers.
package logger
import (
"context"
"log/slog"
"os"
"time"
)
// Standard field names for structured logging across all agents.
//
// These are the attribute keys passed to slog calls so that every agent emits
// the same JSON key for the same concept.
const (
FieldAgentID = "agent_id" // attached to every record by NewAgentLogger
FieldTraceID = "trace_id"
FieldAction = "action"
FieldReason = "reason"
FieldDurationMS = "duration_ms" // elapsed time in milliseconds
FieldTokensUsed = "tokens_used"
FieldResult = "result"
FieldErrorType = "error_type"
FieldComponent = "component"
)
// traceKey is the unexported key type under which a trace ID is stored in a
// context; using a private type prevents collisions with other packages.
type traceKey struct{}

// WithTraceID derives a context from ctx that carries id as its trace ID.
func WithTraceID(ctx context.Context, id string) context.Context {
	key := traceKey{}
	return context.WithValue(ctx, key, id)
}

// TraceIDFromCtx returns the trace ID stored in ctx by WithTraceID, or the
// empty string when ctx carries none.
func TraceIDFromCtx(ctx context.Context) string {
	// A failed type assertion leaves id as "", which is the "absent" result.
	id, _ := ctx.Value(traceKey{}).(string)
	return id
}
// LoggerConfig configures a per-agent logger.
//
// Zero or negative numeric fields and an empty BaseDir are replaced with the
// defaults noted below before a file-backed logger is built. Compress is used
// as given; it has no default other than the zero value.
type LoggerConfig struct {
BaseDir string // root log directory; empty → "logs"; the literal "stdout" → write to os.Stdout only
AgentID string // agent identifier (required)
MaxSizeMB int64 // max file size before rotation (default: 50)
MaxAgeDays int // retention in days (default: 7)
Compress bool // gzip rotated files (zero value false; no default is applied)
CleanupInterval time.Duration // cleanup ticker interval (default: 24h)
Level slog.Level // minimum log level (zero value is INFO)
}
// defaults fills in unset fields in place: an empty BaseDir and any
// non-positive MaxSizeMB, MaxAgeDays or CleanupInterval are replaced with
// their default values. Other fields are left untouched.
func (c *LoggerConfig) defaults() {
	const (
		defaultBaseDir    = "logs"
		defaultMaxSizeMB  = 50
		defaultMaxAgeDays = 7
		defaultInterval   = 24 * time.Hour
	)

	if c.BaseDir == "" {
		c.BaseDir = defaultBaseDir
	}
	if !(c.MaxSizeMB > 0) {
		c.MaxSizeMB = defaultMaxSizeMB
	}
	if !(c.MaxAgeDays > 0) {
		c.MaxAgeDays = defaultMaxAgeDays
	}
	if !(c.CleanupInterval > 0) {
		c.CleanupInterval = defaultInterval
	}
}
// NewAgentLogger builds a JSON slog.Logger for one agent. It returns:
//   - the logger, with agent_id already attached to every record
//   - a cleanup func to call on shutdown (stops the cleanup goroutine and
//     closes the log file)
//   - an error if the log writer cannot be created
//
// When cfg.BaseDir is exactly "stdout", records go to os.Stdout, no defaults
// are applied, nothing is rotated or cleaned up, and the cleanup func is a
// no-op. Otherwise records are written to rotated JSONL files under
// BaseDir/<AgentID>/ and a background goroutine removes expired files.
func NewAgentLogger(cfg LoggerConfig) (*slog.Logger, func(), error) {
	opts := &slog.HandlerOptions{Level: cfg.Level}

	if cfg.BaseDir == "stdout" {
		stdoutLogger := slog.New(slog.NewJSONHandler(os.Stdout, opts)).With(FieldAgentID, cfg.AgentID)
		return stdoutLogger, func() {}, nil
	}

	cfg.defaults()

	writer, err := NewDailyRotatingWriter(cfg.BaseDir, cfg.AgentID, cfg.MaxSizeMB, cfg.Compress)
	if err != nil {
		return nil, nil, err
	}
	fileLogger := slog.New(slog.NewJSONHandler(writer, opts)).With(FieldAgentID, cfg.AgentID)

	// The retention sweep runs until the returned cleanup func cancels it.
	sweepCtx, stopSweep := context.WithCancel(context.Background())
	go runCleanup(sweepCtx, cfg.BaseDir, cfg.AgentID, cfg.MaxAgeDays, cfg.CleanupInterval)

	return fileLogger, func() {
		stopSweep()
		writer.Close()
	}, nil
}
+77
View File
@@ -0,0 +1,77 @@
package logger
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
)
func TestNewAgentLogger_WritesJSONL(t *testing.T) {
dir := t.TempDir()
l, cleanup, err := NewAgentLogger(LoggerConfig{
BaseDir: dir,
AgentID: "test-bot",
Level: slog.LevelDebug,
})
if err != nil {
t.Fatal(err)
}
defer cleanup()
l.Info("hello world", FieldAction, "greet", FieldReason, "testing")
// Force flush by closing.
cleanup()
files, _ := os.ReadDir(filepath.Join(dir, "test-bot"))
if len(files) == 0 {
t.Fatal("expected at least one log file")
}
data, err := os.ReadFile(filepath.Join(dir, "test-bot", files[0].Name()))
if err != nil {
t.Fatal(err)
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
t.Fatalf("log line is not valid JSON: %s", data)
}
if m["msg"] != "hello world" {
t.Errorf("msg = %v, want hello world", m["msg"])
}
if m[FieldAgentID] != "test-bot" {
t.Errorf("agent_id = %v, want test-bot", m[FieldAgentID])
}
if m[FieldAction] != "greet" {
t.Errorf("action = %v, want greet", m[FieldAction])
}
}
func TestNewAgentLogger_Stdout(t *testing.T) {
l, cleanup, err := NewAgentLogger(LoggerConfig{
BaseDir: "stdout",
AgentID: "dev-bot",
})
if err != nil {
t.Fatal(err)
}
defer cleanup()
// Just verify it doesn't panic.
l.Info("stdout test")
}
// TestTraceIDContext verifies that a trace ID is absent from a bare context
// and round-trips through WithTraceID / TraceIDFromCtx.
func TestTraceIDContext(t *testing.T) {
	bare := context.Background()
	if id := TraceIDFromCtx(bare); id != "" {
		t.Errorf("empty ctx should return empty trace, got %q", id)
	}

	traced := WithTraceID(bare, "abc-123")
	if id := TraceIDFromCtx(traced); id != "abc-123" {
		t.Errorf("trace = %q, want abc-123", id)
	}
}
+149
View File
@@ -0,0 +1,149 @@
package logger
import (
"bufio"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// ReadLogs collects the agent's log entries for every day from `from` through
// `to` (inclusive), stepping one calendar day at a time. Days whose logs
// cannot be read are skipped, so the returned error is always nil.
func ReadLogs(baseDir, agentID string, from, to time.Time) ([]json.RawMessage, error) {
	var collected []json.RawMessage

	day := from
	for !day.After(to) {
		if dayEntries, err := ReadDayLogs(baseDir, agentID, day); err == nil {
			collected = append(collected, dayEntries...)
		}
		day = day.AddDate(0, 0, 1)
	}
	return collected, nil
}
// ReadDayLogs returns the entries of every log file of the agent whose name
// starts with the given date (YYYY-MM-DD), in directory-listing order. Files
// that fail to read are skipped; an unreadable agent directory is an error.
func ReadDayLogs(baseDir, agentID string, date time.Time) ([]json.RawMessage, error) {
	agentDir := filepath.Join(baseDir, agentID)
	listing, err := os.ReadDir(agentDir)
	if err != nil {
		return nil, fmt.Errorf("read dir %s: %w", agentDir, err)
	}

	dayPrefix := date.Format("2006-01-02")

	var collected []json.RawMessage
	for _, item := range listing {
		fname := item.Name()
		if !strings.HasPrefix(fname, dayPrefix) {
			continue
		}
		if !isLogFile(fname) {
			continue
		}
		if fileEntries, err := readLogFile(filepath.Join(agentDir, fname)); err == nil {
			collected = append(collected, fileEntries...)
		}
	}
	return collected, nil
}
// SearchLogs returns the entries in the date range whose top-level JSON field
// `field`, formatted with fmt.Sprint, equals value. Entries that are not
// valid JSON objects or lack the field are skipped.
func SearchLogs(baseDir, agentID string, field, value string, from, to time.Time) ([]json.RawMessage, error) {
	candidates, err := ReadLogs(baseDir, agentID, from, to)
	if err != nil {
		return nil, err
	}

	var hits []json.RawMessage
	for _, entry := range candidates {
		var fields map[string]any
		if err := json.Unmarshal(entry, &fields); err != nil {
			continue
		}
		got, present := fields[field]
		if !present || fmt.Sprint(got) != value {
			continue
		}
		hits = append(hits, entry)
	}
	return hits, nil
}
// ListAgents returns, sorted alphabetically, the names of the directories
// directly under baseDir; each one is taken to be an agent ID. Plain files
// are ignored. It fails if baseDir cannot be read.
func ListAgents(baseDir string) ([]string, error) {
	children, err := os.ReadDir(baseDir)
	if err != nil {
		return nil, fmt.Errorf("read base dir %s: %w", baseDir, err)
	}

	var agentIDs []string
	for i := range children {
		if !children[i].IsDir() {
			continue
		}
		agentIDs = append(agentIDs, children[i].Name())
	}

	sort.Strings(agentIDs)
	return agentIDs, nil
}
// ListDates returns, in ascending order and without duplicates, the dates
// that appear in the agent's log filenames. Directories, non-log files and
// names without a parseable date are ignored. It fails if the agent's
// directory cannot be read.
func ListDates(baseDir, agentID string) ([]time.Time, error) {
	agentDir := filepath.Join(baseDir, agentID)
	listing, err := os.ReadDir(agentDir)
	if err != nil {
		return nil, fmt.Errorf("read dir %s: %w", agentDir, err)
	}

	const layout = "2006-01-02"
	known := make(map[string]struct{})

	var days []time.Time
	for _, item := range listing {
		if item.IsDir() {
			continue
		}
		fname := item.Name()
		if !isLogFile(fname) {
			continue
		}
		day := parseDateFromFilename(fname)
		if day.IsZero() {
			continue
		}
		// Several files (splits, .gz) can share a date; keep it once.
		label := day.Format(layout)
		if _, dup := known[label]; dup {
			continue
		}
		known[label] = struct{}{}
		days = append(days, day)
	}

	sort.Slice(days, func(a, b int) bool { return days[a].Before(days[b]) })
	return days, nil
}
// readLogFile returns every non-empty line of the file at path as a separate
// json.RawMessage. Paths ending in ".gz" are decompressed on the fly. Lines
// are not validated as JSON. A line longer than 1 MiB stops the scan and is
// reported through the returned error, together with the lines read so far.
func readLogFile(path string) ([]json.RawMessage, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	src := io.Reader(file)
	if strings.HasSuffix(path, ".gz") {
		zr, err := gzip.NewReader(file)
		if err != nil {
			return nil, err
		}
		defer zr.Close()
		src = zr
	}

	sc := bufio.NewScanner(src)
	sc.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	var out []json.RawMessage
	for sc.Scan() {
		raw := sc.Bytes()
		if len(raw) == 0 {
			continue
		}
		// The scanner reuses its buffer, so each line must be copied.
		out = append(out, append(json.RawMessage(nil), raw...))
	}
	return out, sc.Err()
}
+130
View File
@@ -0,0 +1,130 @@
package logger
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
)
// setupQueryDir creates a temporary log tree for the query tests and returns
// its root:
//
//	bot1/2026-03-05.jsonl  (2 entries)
//	bot1/2026-03-06.jsonl  (1 entry)
//	bot2/2026-03-06.jsonl  (1 entry)
func setupQueryDir(t *testing.T) string {
	t.Helper()
	root := t.TempDir()

	fixtures := []struct {
		agent string
		file  string
		body  string
	}{
		{
			agent: "bot1",
			file:  "2026-03-05.jsonl",
			body: `{"time":"2026-03-05T10:00:00Z","level":"INFO","msg":"hello","action":"greet"}` + "\n" +
				`{"time":"2026-03-05T11:00:00Z","level":"ERROR","msg":"oops","action":"fail"}` + "\n",
		},
		{
			agent: "bot1",
			file:  "2026-03-06.jsonl",
			body:  `{"time":"2026-03-06T09:00:00Z","level":"INFO","msg":"day2"}` + "\n",
		},
		{
			agent: "bot2",
			file:  "2026-03-06.jsonl",
			body:  `{"time":"2026-03-06T08:00:00Z","level":"DEBUG","msg":"bot2 log"}` + "\n",
		},
	}

	for _, fx := range fixtures {
		agentDir := filepath.Join(root, fx.agent)
		os.MkdirAll(agentDir, 0o755)
		os.WriteFile(filepath.Join(agentDir, fx.file), []byte(fx.body), 0o644)
	}
	return root
}
// TestListAgents verifies that both fixture agents are listed in sorted order.
func TestListAgents(t *testing.T) {
	root := setupQueryDir(t)

	ids, err := ListAgents(root)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(ids); n != 2 {
		t.Fatalf("expected 2 agents, got %d", n)
	}
	if !(ids[0] == "bot1" && ids[1] == "bot2") {
		t.Errorf("agents = %v, want [bot1 bot2]", ids)
	}
}
// TestListDates verifies that bot1's two log dates are found and that the
// earliest one comes first.
func TestListDates(t *testing.T) {
	root := setupQueryDir(t)

	days, err := ListDates(root, "bot1")
	if err != nil {
		t.Fatal(err)
	}
	if n := len(days); n != 2 {
		t.Fatalf("expected 2 dates, got %d", n)
	}
	if first := days[0]; first.Format("2006-01-02") != "2026-03-05" {
		t.Errorf("first date = %v, want 2026-03-05", first)
	}
}
// TestReadDayLogs verifies that both entries of bot1's 2026-03-05 file are
// returned for that day.
func TestReadDayLogs(t *testing.T) {
	root := setupQueryDir(t)
	target := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)

	got, err := ReadDayLogs(root, "bot1", target)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(got); n != 2 {
		t.Fatalf("expected 2 entries, got %d", n)
	}
}
// TestReadLogs verifies that a two-day range returns the entries of both
// days (2 + 1) for bot1.
func TestReadLogs(t *testing.T) {
	root := setupQueryDir(t)
	var (
		start = time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
		end   = time.Date(2026, 3, 6, 0, 0, 0, 0, time.UTC)
	)

	got, err := ReadLogs(root, "bot1", start, end)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(got); n != 3 {
		t.Fatalf("expected 3 entries across 2 days, got %d", n)
	}
}
// TestSearchLogs verifies that searching a single day for action=fail
// returns exactly the matching entry.
func TestSearchLogs(t *testing.T) {
	root := setupQueryDir(t)
	day := time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)

	// from == to: the range covers just this one day.
	hits, err := SearchLogs(root, "bot1", "action", "fail", day, day)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(hits); n != 1 {
		t.Fatalf("expected 1 match, got %d", n)
	}

	var record map[string]any
	json.Unmarshal(hits[0], &record)
	if got := record["msg"]; got != "oops" {
		t.Errorf("msg = %v, want oops", got)
	}
}
// TestSearchLogs_NoMatch verifies that a value present in no entry yields an
// empty result without an error.
func TestSearchLogs_NoMatch(t *testing.T) {
	root := setupQueryDir(t)
	var (
		start = time.Date(2026, 3, 5, 0, 0, 0, 0, time.UTC)
		end   = time.Date(2026, 3, 6, 0, 0, 0, 0, time.UTC)
	)

	hits, err := SearchLogs(root, "bot1", "action", "nonexistent", start, end)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(hits); n != 0 {
		t.Errorf("expected 0 matches, got %d", n)
	}
}
// TestListAgents_EmptyDir verifies that an empty base directory yields no
// agents and no error.
func TestListAgents_EmptyDir(t *testing.T) {
	root := t.TempDir()

	ids, err := ListAgents(root)
	if err != nil {
		t.Fatal(err)
	}
	if n := len(ids); n != 0 {
		t.Errorf("expected 0 agents, got %d", n)
	}
}
+168
View File
@@ -0,0 +1,168 @@
package logger
import (
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"time"
)
// DailyRotatingWriter is an io.Writer that rotates log files daily and by
// size. Files are named <baseDir>/<agentID>/YYYY-MM-DD.jsonl with optional
// numeric suffixes for size-based splits within the same day.
//
// Write and Close take mu, so a single writer may be shared by several
// goroutines.
type DailyRotatingWriter struct {
baseDir string // root log directory as passed to the constructor
agentID string // agent whose logs this writer produces
maxSize int64 // bytes; a write that would exceed it starts a new split file
compress bool // when true, the previous file is gzipped on day rotation
nowFunc func() time.Time // for testing; defaults to time.Now().UTC
dir string // resolved agent log directory
mu sync.Mutex // held by Write and Close
current *os.File // file currently being appended to
written int64 // size of current: its initial size plus bytes written since
currentDay string // YYYY-MM-DD that current belongs to
suffix int // split index within currentDay; 0 means no numeric suffix
}
// NewDailyRotatingWriter returns a writer that appends to log files under
// baseDir/agentID/, creating that directory if necessary and opening today's
// file before returning. maxSizeMB is the per-file size limit in mebibytes;
// compress enables gzipping of the previous file on day rotation.
func NewDailyRotatingWriter(baseDir, agentID string, maxSizeMB int64, compress bool) (*DailyRotatingWriter, error) {
	const bytesPerMB = 1024 * 1024

	logDir := filepath.Join(baseDir, agentID)
	if err := os.MkdirAll(logDir, 0o755); err != nil {
		return nil, fmt.Errorf("create log dir %s: %w", logDir, err)
	}

	utcNow := func() time.Time { return time.Now().UTC() }

	writer := &DailyRotatingWriter{
		baseDir:  baseDir,
		agentID:  agentID,
		maxSize:  maxSizeMB * bytesPerMB,
		compress: compress,
		nowFunc:  utcNow,
		dir:      logDir,
	}

	err := writer.openFile()
	if err != nil {
		return nil, err
	}
	return writer, nil
}
// Write implements io.Writer with daily and size-based rotation.
//
// Before writing, it switches to a new file when the day reported by nowFunc
// differs from the current file's day, and starts a new numbered split when
// appending p would push a non-empty file past maxSize. p is never divided
// across files. If a rotation fails, nothing is written and (0, err) is
// returned.
func (w *DailyRotatingWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
today := w.nowFunc().Format("2006-01-02")
// Day changed → rotate to new day file.
if today != w.currentDay {
if err := w.rotate(today, 0); err != nil {
return 0, err
}
}
// Size exceeded → split within same day.
// The written > 0 guard lets a single oversized write land in an empty
// file instead of rotating on every call.
if w.written+int64(len(p)) > w.maxSize && w.written > 0 {
w.suffix++
if err := w.rotate(today, w.suffix); err != nil {
return 0, err
}
}
n, err := w.current.Write(p)
// Count the bytes actually written, even on a partial write.
w.written += int64(n)
return n, err
}
// Close closes the file currently being written, returning the error from
// the underlying close. It is a no-op when no file is open.
func (w *DailyRotatingWriter) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.current == nil {
		return nil
	}
	return w.current.Close()
}
// rotate closes the current file (optionally compressing it) and opens a new one.
//
// day and suffix select the file to open next. The caller (Write) already
// holds w.mu. The error returned is the one from opening the new file; the
// error from closing the previous file is ignored.
func (w *DailyRotatingWriter) rotate(day string, suffix int) error {
prev := w.current
prevPath := ""
if prev != nil {
prevPath = prev.Name()
prev.Close()
}
// Compress the previous file in the background if enabled and it's from a
// different day (we don't compress intra-day splits until day rotates).
// This comparison must run before w.currentDay is overwritten below.
if w.compress && prevPath != "" && day != w.currentDay {
go compressFile(prevPath)
}
w.currentDay = day
w.suffix = suffix
w.written = 0
return w.openFile()
}
// openFile opens, creating it if needed, the log file for the day reported
// by nowFunc and the current suffix, in append mode. On success it becomes
// the current file and written is set to the file's existing size so that
// size-based rotation accounts for earlier content.
func (w *DailyRotatingWriter) openFile() error {
	const appendFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND

	w.currentDay = w.nowFunc().Format("2006-01-02")
	path := w.filename(w.currentDay, w.suffix)

	file, err := os.OpenFile(path, appendFlags, 0o644)
	if err != nil {
		return fmt.Errorf("open log file %s: %w", path, err)
	}

	stat, err := file.Stat()
	if err != nil {
		file.Close()
		return err
	}

	w.current, w.written = file, stat.Size()
	return nil
}
// filename builds the full path of the log file for day and suffix:
// "<dir>/<day>.jsonl" when suffix is 0, "<dir>/<day>.<suffix>.jsonl"
// otherwise.
func (w *DailyRotatingWriter) filename(day string, suffix int) string {
	base := day + ".jsonl"
	if suffix != 0 {
		base = fmt.Sprintf("%s.%d.jsonl", day, suffix)
	}
	return filepath.Join(w.dir, base)
}
// compressFile gzips src to src.gz and removes the original.
//
// It is best-effort and reports nothing: if src cannot be opened or the .gz
// cannot be created, it returns without changing anything. If copying,
// flushing the gzip stream or closing the output fails, the partial .gz is
// deleted and src is kept, so the original is only removed once a complete
// archive has been written.
func compressFile(src string) {
	in, err := os.Open(src)
	if err != nil {
		return
	}
	defer in.Close()

	dst := src + ".gz"
	out, err := os.Create(dst)
	if err != nil {
		return
	}

	gz := gzip.NewWriter(out)
	_, copyErr := io.Copy(gz, in)
	// gz.Close writes the buffered tail and the gzip footer, and out.Close
	// can surface a deferred write error. Either failing means the archive
	// is incomplete, so both results are checked and not just the copy.
	gzErr := gz.Close()
	outErr := out.Close()
	if copyErr != nil || gzErr != nil || outErr != nil {
		os.Remove(dst)
		return
	}

	// Close the source before deleting it; the deferred Close then becomes a
	// harmless second close.
	in.Close()
	os.Remove(src)
}
+98
View File
@@ -0,0 +1,98 @@
package logger
import (
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// TestDailyRotatingWriter_DayRotation verifies that a write after the clock
// moves to the next day lands in a new file named after that day.
func TestDailyRotatingWriter_DayRotation(t *testing.T) {
	root := t.TempDir()
	w, err := NewDailyRotatingWriter(root, "bot1", 50, false)
	if err != nil {
		t.Fatal(err)
	}

	// Fake clock, advanced between the two writes.
	clock := time.Date(2026, 3, 5, 12, 0, 0, 0, time.UTC)
	w.nowFunc = func() time.Time { return clock }

	// The constructor opened a file for the real date; reopen under the
	// fake clock so the first write goes to the 2026-03-05 file.
	w.current.Close()
	w.currentDay = ""
	w.openFile()

	w.Write([]byte(`{"msg":"day1"}`))
	clock = time.Date(2026, 3, 6, 12, 0, 0, 0, time.UTC)
	w.Write([]byte(`{"msg":"day2"}`))
	w.Close()

	listing, _ := os.ReadDir(filepath.Join(root, "bot1"))
	present := make(map[string]struct{}, len(listing))
	for _, item := range listing {
		present[item.Name()] = struct{}{}
	}
	has := func(name string) bool {
		_, ok := present[name]
		return ok
	}

	if !(has("2026-03-05.jsonl") || has("2026-03-05.jsonl.gz")) {
		t.Error("expected 2026-03-05.jsonl or .gz")
	}
	if !has("2026-03-06.jsonl") {
		t.Error("expected 2026-03-06.jsonl")
	}
}
// TestDailyRotatingWriter_SizeRotation verifies that writes exceeding the
// size limit are split across several files within the same day.
func TestDailyRotatingWriter_SizeRotation(t *testing.T) {
	root := t.TempDir()

	// The constructor only accepts whole megabytes, so build the writer
	// with 0 and set a 10-byte limit directly; every 11-byte line below
	// then overflows a non-empty file.
	w, err := NewDailyRotatingWriter(root, "bot2", 0, false)
	if err != nil {
		t.Fatal(err)
	}
	w.maxSize = 10

	fixed := time.Date(2026, 3, 6, 10, 0, 0, 0, time.UTC)
	w.nowFunc = func() time.Time { return fixed }

	// Reopen under the fixed clock.
	w.current.Close()
	w.currentDay = ""
	w.openFile()

	for _, line := range []string{`{"line":1}`, `{"line":2}`, `{"line":3}`} {
		w.Write([]byte(line + "\n"))
	}
	w.Close()

	listing, _ := os.ReadDir(filepath.Join(root, "bot2"))
	if n := len(listing); n < 2 {
		t.Errorf("expected multiple files from size rotation, got %d", n)
	}
}
// TestDailyRotatingWriter_Concurrent issues 100 writes from separate
// goroutines and checks that at least one log file results.
func TestDailyRotatingWriter_Concurrent(t *testing.T) {
	const writers = 100

	root := t.TempDir()
	w, err := NewDailyRotatingWriter(root, "bot3", 50, false)
	if err != nil {
		t.Fatal(err)
	}
	defer w.Close()

	var pending sync.WaitGroup
	pending.Add(writers)
	for i := 0; i < writers; i++ {
		go func() {
			defer pending.Done()
			w.Write([]byte(`{"concurrent":true}` + "\n"))
		}()
	}
	pending.Wait()

	listing, _ := os.ReadDir(filepath.Join(root, "bot3"))
	if len(listing) == 0 {
		t.Error("expected at least one log file")
	}
}
+159
View File
@@ -0,0 +1,159 @@
// Package mcp provides MCP client and server implementations.
package mcp
import (
"context"
"fmt"
"log/slog"
"os"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
)
// Client wraps an MCP client (stdio or SSE) and exposes discovered tools.
// Instances are built by NewStdioClient or NewSSEClient, which initialize the
// connection and list the server's tools once, at construction time.
type Client struct {
name string // caller-supplied name, used in log records
transport string // "stdio" | "sse"
mcpClient *client.Client // underlying mcp-go client
tools []mcp.Tool // tools returned by ListTools during construction
logger *slog.Logger
}
// NewStdioClient launches command with args as a stdio MCP server, performs
// the MCP initialize handshake and lists the server's tools.
//
// The child process receives the current process environment with the
// entries of env appended as KEY=VALUE. If initialization or tool discovery
// fails, the client is closed before the error is returned.
func NewStdioClient(ctx context.Context, name, command string, args []string, env map[string]string, logger *slog.Logger) (*Client, error) {
	logger.Info("creating stdio MCP client", "name", name, "command", command, "args", args)

	// Inherit the parent environment, then add the per-server variables.
	procEnv := os.Environ()
	for key, val := range env {
		procEnv = append(procEnv, key+"="+val)
	}

	conn, err := client.NewStdioMCPClient(command, procEnv, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to create stdio client: %w", err)
	}

	handshake := mcp.InitializeRequest{
		Params: mcp.InitializeParams{
			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
			ClientInfo: mcp.Implementation{
				Name:    "agents-mcp-client",
				Version: "1.0.0",
			},
			Capabilities: mcp.ClientCapabilities{},
		},
	}
	if _, err := conn.Initialize(ctx, handshake); err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
	}

	listing, err := conn.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to list tools: %w", err)
	}
	logger.Info("discovered MCP tools", "name", name, "count", len(listing.Tools))

	created := &Client{
		name:      name,
		transport: "stdio",
		mcpClient: conn,
		tools:     listing.Tools,
		logger:    logger,
	}
	return created, nil
}
// NewSSEClient connects to an SSE/HTTP MCP server at url, performs the MCP
// initialize handshake and lists the server's tools.
//
// The headers argument is accepted but currently not forwarded to the
// underlying client. If initialization or tool discovery fails, the client is
// closed before the error is returned.
func NewSSEClient(ctx context.Context, name, url string, headers map[string]string, logger *slog.Logger) (*Client, error) {
	logger.Info("creating SSE MCP client", "name", name, "url", url)

	// Only the URL is passed; custom headers would need transport options.
	conn, err := client.NewSSEMCPClient(url)
	if err != nil {
		return nil, fmt.Errorf("failed to create SSE client: %w", err)
	}

	handshake := mcp.InitializeRequest{
		Params: mcp.InitializeParams{
			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
			ClientInfo: mcp.Implementation{
				Name:    "agents-mcp-client",
				Version: "1.0.0",
			},
			Capabilities: mcp.ClientCapabilities{},
		},
	}
	if _, err := conn.Initialize(ctx, handshake); err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
	}

	listing, err := conn.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to list tools: %w", err)
	}
	logger.Info("discovered MCP tools", "name", name, "count", len(listing.Tools))

	created := &Client{
		name:      name,
		transport: "sse",
		mcpClient: conn,
		tools:     listing.Tools,
		logger:    logger,
	}
	return created, nil
}
// Tools returns the discovered MCP tools.
// This is the list captured when the client was constructed. The internal
// slice is returned directly, not a copy, so callers should not modify it.
func (c *Client) Tools() []mcp.Tool {
return c.tools
}
// Name returns the client name, i.e. the name passed to the constructor.
func (c *Client) Name() string {
return c.name
}
// CallTool invokes the MCP tool called name with args and returns its result.
// A positive timeout bounds the call with a deadline derived from ctx; a zero
// or negative timeout leaves ctx as the only limit. Failures from the
// underlying client are wrapped and returned.
func (c *Client) CallTool(ctx context.Context, name string, args map[string]any, timeout time.Duration) (*mcp.CallToolResult, error) {
	callCtx := ctx
	if timeout > 0 {
		bounded, release := context.WithTimeout(ctx, timeout)
		defer release()
		callCtx = bounded
	}

	res, err := c.mcpClient.CallTool(callCtx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      name,
			Arguments: args,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("tool call failed: %w", err)
	}
	return res, nil
}
// Close logs the shutdown and closes the underlying MCP client connection,
// returning whatever error that close reports.
func (c *Client) Close() error {
	c.logger.Info("closing MCP client", "name", c.name, "transport", c.transport)
	closeErr := c.mcpClient.Close()
	return closeErr
}
+130
View File
@@ -0,0 +1,130 @@
package mcp
import (
"context"
"fmt"
"log/slog"
"os"
"strings"
"github.com/enmanuel/agents/internal/config"
)
// Manager manages multiple MCP client connections, indexed by the configured
// server name. The map is accessed without any locking in this file.
type Manager struct {
clients map[string]*Client // server name → client
logger *slog.Logger
}
// NewManager builds a Manager and connects to every server in servers, in
// order. A nil logger is replaced by slog.Default().
//
// The first server that fails to connect aborts construction: clients that
// were already created are closed and a wrapped error naming the failing
// server is returned.
func NewManager(ctx context.Context, servers []config.MCPServerCfg, logger *slog.Logger) (*Manager, error) {
	log := logger
	if log == nil {
		log = slog.Default()
	}

	mgr := &Manager{clients: map[string]*Client{}, logger: log}
	for i := range servers {
		err := mgr.addServer(ctx, servers[i])
		if err == nil {
			continue
		}
		// Close any already-created clients before returning error
		mgr.Close()
		return nil, fmt.Errorf("failed to initialize MCP server %q: %w", servers[i].Name, err)
	}

	log.Info("MCP manager initialized", "servers", len(mgr.clients))
	return mgr, nil
}
// addServer creates a single MCP client from cfg and registers it under
// cfg.Name.
//
// The transport is taken from cfg.Transport; when empty it is inferred:
// a non-empty Command selects "stdio", otherwise a non-empty URL selects
// "sse". Environment variable references ($VAR / ${VAR}) are expanded in the
// command, its arguments, the environment map values, the URL and the header
// values before the client is created.
//
// It returns an error when the name is empty, when the name is already
// registered (registering it again would overwrite the map entry and leak the
// earlier, still-open client), when the transport is unknown or
// under-specified, or when the client constructor fails.
func (m *Manager) addServer(ctx context.Context, cfg config.MCPServerCfg) error {
	if cfg.Name == "" {
		return fmt.Errorf("MCP server must have a name")
	}
	// Reject duplicates before connecting so no second connection is opened.
	if _, dup := m.clients[cfg.Name]; dup {
		return fmt.Errorf("MCP server %q is configured more than once", cfg.Name)
	}

	// Auto-detect transport if not specified
	transport := cfg.Transport
	if transport == "" {
		switch {
		case cfg.Command != "":
			transport = "stdio"
		case cfg.URL != "":
			transport = "sse"
		default:
			return fmt.Errorf("MCP server %q must specify either command (stdio) or url (sse)", cfg.Name)
		}
	}

	var client *Client
	var err error
	switch transport {
	case "stdio":
		if cfg.Command == "" {
			return fmt.Errorf("MCP server %q with stdio transport must have a command", cfg.Name)
		}
		// Expand environment variables in command and args
		command := os.ExpandEnv(cfg.Command)
		args := make([]string, len(cfg.Args))
		for i, arg := range cfg.Args {
			args[i] = os.ExpandEnv(arg)
		}
		// Expand env vars in environment map
		env := make(map[string]string, len(cfg.Env))
		for k, v := range cfg.Env {
			env[k] = os.ExpandEnv(v)
		}
		client, err = NewStdioClient(ctx, cfg.Name, command, args, env, m.logger)
	case "sse":
		if cfg.URL == "" {
			return fmt.Errorf("MCP server %q with sse transport must have a url", cfg.Name)
		}
		url := os.ExpandEnv(cfg.URL)
		headers := make(map[string]string, len(cfg.Headers))
		for k, v := range cfg.Headers {
			headers[k] = os.ExpandEnv(v)
		}
		client, err = NewSSEClient(ctx, cfg.Name, url, headers, m.logger)
	default:
		return fmt.Errorf("unknown transport %q for MCP server %q (must be stdio or sse)", transport, cfg.Name)
	}
	if err != nil {
		return err
	}

	m.clients[cfg.Name] = client
	m.logger.Info("MCP server connected", "name", cfg.Name, "transport", transport, "tools", len(client.Tools()))
	return nil
}
// GetClient returns the MCP client registered under name, or nil if no
// client with that name exists (a plain map lookup).
func (m *Manager) GetClient(name string) *Client {
return m.clients[name]
}
// AllClients returns all MCP clients managed by this manager, keyed by server
// name. The manager's internal map is returned directly (not a copy), so
// modifications by the caller affect the manager.
func (m *Manager) AllClients() map[string]*Client {
return m.clients
}
// Close closes every managed MCP client. All clients are attempted even when
// some fail; the failures are collected and reported together in a single
// error. The clients map itself is left populated.
func (m *Manager) Close() error {
	var failures []string
	for name, client := range m.clients {
		err := client.Close()
		if err == nil {
			continue
		}
		failures = append(failures, fmt.Sprintf("%s: %v", name, err))
	}

	if len(failures) == 0 {
		m.logger.Info("MCP manager closed", "servers", len(m.clients))
		return nil
	}
	return fmt.Errorf("errors closing MCP clients: %s", strings.Join(failures, "; "))
}
+43
View File
@@ -0,0 +1,43 @@
// Package mcp provides MCP client and server implementations.
package mcp
import (
"context"
"fmt"
"log/slog"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/enmanuel/agents/pkg/tools"
)
// MCPServer exposes agent tools as an MCP server. It wraps the mcp-go server
// instance together with the logger used for lifecycle messages.
type MCPServer struct {
srv *server.MCPServer
logger *slog.Logger
}
// NewMCPServer creates an MCP server named name/version and registers one
// tool per entry in specs, using each spec's name and description.
//
// The registered handlers are placeholders: they do not execute the tool and
// only return a text result stating that the tool was called.
func NewMCPServer(name, version string, specs []tools.ToolSpec, logger *slog.Logger) *MCPServer {
	srv := server.NewMCPServer(name, version)

	for i := range specs {
		// Copy the name so each handler closes over its own value.
		toolName := specs[i].Name
		definition := mcp.NewTool(toolName, mcp.WithDescription(specs[i].Description))

		// Placeholder handler — wire real execution here
		handler := func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			return mcp.NewToolResultText(fmt.Sprintf("tool %s called", toolName)), nil
		}
		srv.AddTool(definition, handler)
	}

	return &MCPServer{srv: srv, logger: logger}
}
// ServeStdio runs the MCP server over stdin/stdout (for Claude Desktop / CLI integration).
// It returns whatever server.ServeStdio returns.
//
// NOTE(review): ctx is accepted but not referenced in the body, so cancelling
// it has no effect on the running server here.
func (m *MCPServer) ServeStdio(ctx context.Context) error {
m.logger.Info("mcp server starting on stdio")
return server.ServeStdio(m.srv)
}
+191
View File
@@ -0,0 +1,191 @@
// Package shellmem implements persistent memory storage using SQLite.
package shellmem
import (
"context"
"database/sql"
"fmt"
"log/slog"
"os"
"path/filepath"
"time"
"github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/pkg/memory"
)
// schema is the SQL executed by New every time a store is opened. Every
// statement uses IF NOT EXISTS, so re-running it against an existing database
// is harmless. It defines:
//   - facts: one value per (agent_id, subject, key) triple;
//   - messages: an append-only conversation log per agent and room;
//   - an index on messages for per-room lookups ordered by creation time, and
//     an index on facts for per-subject lookups.
const schema = `
CREATE TABLE IF NOT EXISTS facts (
agent_id TEXT NOT NULL,
subject TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (agent_id, subject, key)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_id TEXT NOT NULL,
room_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(agent_id, room_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(agent_id, subject);
`
// SQLiteStore implements memory.Store using SQLite. It holds the open
// database handle and a logger that New tags with the component name and
// database path.
type SQLiteStore struct {
db *sql.DB
logger *slog.Logger
}
// New opens (or creates) the SQLite database at dbPath, creating the parent
// directory if needed, and applies the schema before returning the store.
//
// The returned store's logger carries the "component" and "db_path"
// attributes. If applying the schema fails, the database handle is closed
// before the error is returned.
func New(dbPath string, logger *slog.Logger) (*SQLiteStore, error) {
	storeLog := logger.With("component", "memory", "db_path", dbPath)
	storeLog.Info("memory_open")

	parentDir := filepath.Dir(dbPath)
	if mkErr := os.MkdirAll(parentDir, 0o755); mkErr != nil {
		return nil, fmt.Errorf("create memory db dir: %w", mkErr)
	}

	handle, openErr := sql.Open("sqlite3", dbPath)
	if openErr != nil {
		return nil, fmt.Errorf("open memory db: %w", openErr)
	}

	_, migrateErr := handle.Exec(schema)
	if migrateErr != nil {
		handle.Close()
		return nil, fmt.Errorf("migrate memory db: %w", migrateErr)
	}

	storeLog.Info("memory_ready")
	store := &SQLiteStore{db: handle, logger: storeLog}
	return store, nil
}
// SaveFact inserts f, replacing any existing row with the same
// (agent_id, subject, key). updated_at is set to the current UTC time rather
// than taken from f. A failure is logged and also returned to the caller.
func (s *SQLiteStore) SaveFact(ctx context.Context, f memory.Fact) error {
	s.logger.Debug("memory_save_fact", "subject", f.Subject, "key", f.Key)

	now := time.Now().UTC()
	_, err := s.db.ExecContext(ctx,
		`INSERT OR REPLACE INTO facts (agent_id, subject, key, value, updated_at)
VALUES (?, ?, ?, ?, ?)`,
		f.AgentID, f.Subject, f.Key, f.Value, now,
	)
	if err == nil {
		return nil
	}
	s.logger.Error("memory_save_fact_error", "subject", f.Subject, "key", f.Key, "err", err)
	return err
}
// RecallFacts returns the facts stored for agentID and subject. When key is
// non-nil only the fact with that key is returned; otherwise every fact for
// the subject is returned.
//
// On any error — running the query, scanning a row, or an iteration error
// reported by rows.Err — it returns a nil slice and the error, so callers
// never receive a partially populated result.
func (s *SQLiteStore) RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error) {
	var rows *sql.Rows
	var err error
	if key != nil {
		rows, err = s.db.QueryContext(ctx,
			`SELECT agent_id, subject, key, value, updated_at FROM facts
WHERE agent_id = ? AND subject = ? AND key = ?`,
			agentID, subject, *key,
		)
	} else {
		rows, err = s.db.QueryContext(ctx,
			`SELECT agent_id, subject, key, value, updated_at FROM facts
WHERE agent_id = ? AND subject = ?`,
			agentID, subject,
		)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var facts []memory.Fact
	for rows.Next() {
		var f memory.Fact
		if err := rows.Scan(&f.AgentID, &f.Subject, &f.Key, &f.Value, &f.UpdatedAt); err != nil {
			return nil, err
		}
		facts = append(facts, f)
	}
	// rows.Next also returns false when iteration fails midway; check before
	// reporting success so a truncated result is not mistaken for a full one.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	s.logger.Debug("memory_recall", "subject", subject, "count", len(facts))
	return facts, nil
}
// DeleteFacts removes facts for agentID and subject. With a non-nil key only
// that single fact is deleted; with a nil key every fact of the subject is
// deleted. The error from the DELETE statement, if any, is returned as-is.
func (s *SQLiteStore) DeleteFacts(ctx context.Context, agentID, subject string, key *string) error {
	query := `DELETE FROM facts WHERE agent_id = ? AND subject = ?`
	args := []any{agentID, subject}
	if key != nil {
		query = `DELETE FROM facts WHERE agent_id = ? AND subject = ? AND key = ?`
		args = append(args, *key)
	}

	_, err := s.db.ExecContext(ctx, query, args...)
	return err
}
// SaveMessage appends m to the messages table. created_at is set to the
// current UTC time rather than taken from m. A failure is logged and also
// returned to the caller.
func (s *SQLiteStore) SaveMessage(ctx context.Context, m memory.HistoryMessage) error {
	s.logger.Debug("memory_save_msg", "room", m.RoomID, "role", m.Role)

	role := string(m.Role)
	stamp := time.Now().UTC()
	_, err := s.db.ExecContext(ctx,
		`INSERT INTO messages (agent_id, room_id, role, content, created_at)
VALUES (?, ?, ?, ?, ?)`,
		m.AgentID, m.RoomID, role, m.Content, stamp,
	)
	if err == nil {
		return nil
	}
	s.logger.Error("memory_save_msg_error", "room", m.RoomID, "err", err)
	return err
}
// LoadMessages returns up to limit of the most recent messages for agentID
// in roomID, ordered oldest first. The query selects newest-first so that
// LIMIT keeps the latest rows; the slice is then reversed in place.
func (s *SQLiteStore) LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]memory.HistoryMessage, error) {
	rows, err := s.db.QueryContext(ctx,
		`SELECT agent_id, room_id, role, content, created_at FROM messages
WHERE agent_id = ? AND room_id = ?
ORDER BY created_at DESC LIMIT ?`,
		agentID, roomID, limit,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var history []memory.HistoryMessage
	for rows.Next() {
		var (
			entry   memory.HistoryMessage
			rawRole string
		)
		scanErr := rows.Scan(&entry.AgentID, &entry.RoomID, &rawRole, &entry.Content, &entry.CreatedAt)
		if scanErr != nil {
			return nil, scanErr
		}
		entry.Role = llm.Role(rawRole)
		history = append(history, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	// Reverse to chronological order
	for lo := 0; lo < len(history)/2; lo++ {
		hi := len(history) - 1 - lo
		history[lo], history[hi] = history[hi], history[lo]
	}

	s.logger.Debug("memory_load_msgs", "room", roomID, "count", len(history))
	return history, nil
}
// DeleteMessages removes stored messages for agentID. With a non-nil roomID
// only that room's messages are deleted; with a nil roomID every message of
// the agent is deleted. The error from the DELETE statement is returned as-is.
func (s *SQLiteStore) DeleteMessages(ctx context.Context, agentID string, roomID *string) error {
	query := `DELETE FROM messages WHERE agent_id = ?`
	args := []any{agentID}
	if roomID != nil {
		query = `DELETE FROM messages WHERE agent_id = ? AND room_id = ?`
		args = append(args, *roomID)
	}

	_, err := s.db.ExecContext(ctx, query, args...)
	return err
}
// Close logs the shutdown and closes the underlying database handle,
// returning the error reported by that close.
func (s *SQLiteStore) Close() error {
s.logger.Info("memory_closed")
return s.db.Close()
}
+48
View File
@@ -0,0 +1,48 @@
package orchestration
import (
"context"
"encoding/json"
"fmt"
"strings"
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/pkg/orchestration"
)
// evaluate asks the LLM to score how well response answers question and
// returns the parsed verdict.
//
// The model's reply is expected to be a JSON document that unmarshals into
// orchestration.QualityScore. If the LLM call fails or the reply cannot be
// parsed, evaluate deliberately returns a perfect score with Continue=false
// so that the pipeline stops instead of looping on a broken evaluator.
func (o *Orchestrator) evaluate(ctx context.Context, question string, response orchestration.BotResponse) orchestration.QualityScore {
	// stopWith builds the optimistic verdict used when scoring itself fails.
	stopWith := func(reason string) orchestration.QualityScore {
		return orchestration.QualityScore{Score: 1.0, Continue: false, Reason: reason}
	}

	prompt := fmt.Sprintf("Question: %s\n\nResponse from %s:\n%s", question, response.BotID, response.Text)
	request := coretypes.CompletionRequest{
		Model:        o.cfg.LLM.Primary.Model,
		MaxTokens:    o.cfg.LLM.Primary.MaxTokens,
		Temperature:  o.cfg.LLM.Primary.Temperature,
		SystemPrompt: o.qualityPrompt,
		Messages: []coretypes.Message{
			{Role: coretypes.RoleUser, Content: prompt},
		},
	}

	resp, err := o.llm(ctx, request)
	if err != nil {
		o.logger.Error("quality evaluation LLM call failed", "err", err)
		// On LLM failure, assume quality is good enough to stop the pipeline
		return stopWith(fmt.Sprintf("evaluation failed: %s, assuming good quality", err))
	}

	var verdict orchestration.QualityScore
	payload := []byte(strings.TrimSpace(resp.Content))
	if err := json.Unmarshal(payload, &verdict); err != nil {
		o.logger.Warn("failed to parse quality score", "content", resp.Content, "err", err)
		// On parse failure, assume good quality
		return stopWith(fmt.Sprintf("parse failed: %s", err))
	}
	return verdict
}
+455
View File
@@ -0,0 +1,455 @@
// Package orchestration implements the multi-bot orchestrator runtime.
// The orchestrator intercepts events in managed rooms and coordinates which bot
// responds via the in-process bus.
//
// PARKED (Matrix-out, issue matrix-out): the orchestrator is no longer wired
// into the launcher. The room-discovery side (RoomScanner / SetScanner /
// ScanExistingRooms / evaluateRoom / NotifyMembership) was intrinsic to Matrix —
// it enumerated Matrix rooms via mautrix — and has been removed. What remains is
// the transport-neutral routing/quality pipeline, which compiles without any
// messaging fabric. Re-introducing auto-discovery over unibus
// (GET /rooms/{id}/members) is a later step.
package orchestration
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"github.com/enmanuel/agents/internal/config"
"github.com/enmanuel/agents/pkg/decision"
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/pkg/orchestration"
"github.com/enmanuel/agents/shell/bus"
shelllm "github.com/enmanuel/agents/shell/llm"
)
// Orchestrator coordinates multi-bot rooms. It has no transport identity —
// it intercepts events before they reach bots and delegates via the bus.
//
// Two independent locks are used: mu guards the room/participant maps, and
// seenMu guards the dedup set used by Intercept.
type Orchestrator struct {
cfg *config.SpecialConfig
llm coretypes.CompleteFunc
bus *bus.Bus
logger *slog.Logger
// mu protects managedRooms, participants, and knownBotIDs.
mu sync.RWMutex
managedRooms map[string][]string // roomID → []botID
participants map[string]orchestration.ParticipantInfo // botID → info
knownBotIDs map[string]string // senderID → botID
// Prompts loaded from files
routingPrompt string
qualityPrompt string
refinementPrompt string
// Dedup: multiple bots in the same room will each trigger Intercept().
// We use a set of "room:sender:content" keys to ensure only one fires.
// Keys are removed again once routing for that message has finished.
seenMu sync.Mutex
seen map[string]bool
}
// New builds an Orchestrator from cfg, wired to agentBus.
//
// It constructs the LLM completion function from the primary LLM config,
// records the configured rooms and their participants (entries with an empty
// room ID are ignored), and loads the prompt files. An error from the LLM
// setup or from prompt loading aborts construction.
func New(cfg *config.SpecialConfig, agentBus *bus.Bus, logger *slog.Logger) (*Orchestrator, error) {
	complete, err := shelllm.FromConfig(cfg.LLM.Primary, logger.With("component", "llm"))
	if err != nil {
		return nil, fmt.Errorf("orchestrator LLM: %w", err)
	}

	rooms := make(map[string][]string)
	for _, r := range cfg.Orchestration.Rooms {
		// skip empty room IDs (unset env vars)
		if r.RoomID != "" {
			rooms[r.RoomID] = r.Participants
		}
	}

	orch := &Orchestrator{
		cfg:          cfg,
		llm:          complete,
		bus:          agentBus,
		logger:       logger,
		managedRooms: rooms,
		participants: make(map[string]orchestration.ParticipantInfo),
		knownBotIDs:  make(map[string]string),
		seen:         make(map[string]bool),
	}
	if err = orch.loadPrompts(); err != nil {
		return nil, fmt.Errorf("load prompts: %w", err)
	}
	return orch, nil
}
// RegisterParticipant stores info under its bot ID for use in LLM routing
// prompts. When info carries a non-empty MatrixUserID, that sender ID is also
// mapped to the bot so Intercept can recognise messages sent by bots.
func (o *Orchestrator) RegisterParticipant(info orchestration.ParticipantInfo) {
	// Keep the critical section scoped so logging happens outside the lock.
	func() {
		o.mu.Lock()
		defer o.mu.Unlock()

		o.participants[info.ID] = info
		if sender := info.MatrixUserID; sender != "" {
			o.knownBotIDs[sender] = info.ID
		}
	}()

	o.logger.Debug("registered participant", "bot", info.ID, "sender_id", info.MatrixUserID)
}
// ShouldIntercept reports whether roomID is one of the rooms managed by this
// orchestrator.
func (o *Orchestrator) ShouldIntercept(roomID string) bool {
	o.mu.RLock()
	defer o.mu.RUnlock()

	_, managed := o.managedRooms[roomID]
	return managed
}
// Intercept is the InterceptFunc used by bot listeners.
//
// It returns false for rooms that are not managed, leaving the event to the
// bot. For managed rooms it always returns true (the bot must not handle the
// event itself) and:
//   - drops messages whose sender is a registered bot, to avoid feedback loops;
//   - drops duplicates of a message that is currently being routed, since
//     every bot in the room reports the same event;
//   - otherwise starts Route in a new goroutine so the listener is not blocked.
//
// The dedup entry is removed once Route returns.
func (o *Orchestrator) Intercept(ctx context.Context, msgCtx decision.MessageContext) bool {
	if !o.ShouldIntercept(msgCtx.RoomID) {
		return false
	}

	// Ignore messages from known bots to prevent feedback loops.
	o.mu.RLock()
	_, fromBot := o.knownBotIDs[msgCtx.SenderID]
	o.mu.RUnlock()
	if fromBot {
		// suppress but don't route — bot's own message
		return true
	}

	// Dedup: multiple bots receive the same event. Only route once.
	dedupKey := msgCtx.RoomID + ":" + msgCtx.SenderID + ":" + msgCtx.Content
	o.seenMu.Lock()
	inFlight := o.seen[dedupKey]
	if !inFlight {
		o.seen[dedupKey] = true
	}
	o.seenMu.Unlock()
	if inFlight {
		// still intercept (don't let the bot handle it) but don't route again
		return true
	}

	// forget clears the dedup entry after routing completes.
	forget := func() {
		o.seenMu.Lock()
		delete(o.seen, dedupKey)
		o.seenMu.Unlock()
	}

	// Route asynchronously so the listener isn't blocked.
	go func() {
		defer forget()
		err := o.Route(ctx, msgCtx)
		if err != nil {
			o.logger.Error("orchestration failed", "room", msgCtx.RoomID, "err", err)
		}
	}()
	return true
}
// Route is the main entry point. Called when a human posts in a managed room.
// It decides which bot(s) should respond and dispatches tasks via the bus.
//
// Flow:
//  1. A room with exactly one participant is dispatched to directly, without
//     consulting the LLM.
//  2. Otherwise up to MaxIterations rounds run (3 when unset or non-positive).
//     The first round picks a bot via routeInitial; later rounds pick a bot
//     to refine the previous answer via routeRefinement.
//  3. After each answer the pipeline stops when the answer repeats an earlier
//     one, when the quality score reaches QualityThreshold, or when the
//     evaluator says not to continue.
//
// It returns an error when the room is not managed or has no participants,
// and in the single-participant case when the dispatch fails. Failures inside
// the multi-bot loop are logged and end the pipeline, but are not returned.
func (o *Orchestrator) Route(ctx context.Context, msgCtx decision.MessageContext) error {
	o.mu.RLock()
	participants, ok := o.managedRooms[msgCtx.RoomID]
	// Copy under the lock so the rest of the pipeline works on a stable list.
	participantsCopy := append([]string(nil), participants...)
	o.mu.RUnlock()
	if !ok {
		return fmt.Errorf("room %s is not managed", msgCtx.RoomID)
	}
	// Without this guard the fallback paths below index participantsCopy[0]
	// and panic inside the routing goroutine.
	if len(participantsCopy) == 0 {
		return fmt.Errorf("room %s has no participants", msgCtx.RoomID)
	}

	o.logger.Info("orchestrating message",
		"room", msgCtx.RoomID,
		"sender", msgCtx.SenderID,
		"participants", participantsCopy,
		"content_preview", truncate(msgCtx.Content, 80),
	)

	// Optimization: single bot → dispatch directly without LLM
	if len(participantsCopy) == 1 {
		o.logger.Debug("single participant, dispatching directly", "bot", participantsCopy[0])
		_, err := o.dispatchAndWait(ctx, participantsCopy[0], msgCtx, 0, nil)
		return err
	}

	var responses []orchestration.BotResponse
	var lastBot string
	maxIter := o.cfg.Orchestration.MaxIterations
	if maxIter <= 0 {
		maxIter = 3
	}

	for i := 0; i < maxIter; i++ {
		// Route: decide which bot responds
		var target string
		if i == 0 {
			rd, routeErr := o.routeInitial(ctx, msgCtx.Content, participantsCopy)
			if routeErr != nil {
				o.logger.Error("routing failed, falling back to first participant", "err", routeErr)
				target = participantsCopy[0]
			} else {
				target = rd.TargetBotID
				o.logger.Info("routed to bot",
					"bot", target,
					"confidence", rd.Confidence,
					"reason", rd.Reason,
					"iteration", i,
				)
			}
		} else {
			rd, routeErr := o.routeRefinement(ctx, msgCtx.Content, responses, participantsCopy, lastBot)
			if routeErr != nil {
				o.logger.Warn("refinement routing failed, stopping pipeline", "err", routeErr)
				break
			}
			target = rd.TargetBotID
			o.logger.Info("refinement routed to bot",
				"bot", target,
				"reason", rd.Reason,
				"iteration", i,
			)
		}

		// Dispatch: send TaskEvent to bot via bus and wait for response
		response, err := o.dispatchAndWait(ctx, target, msgCtx, i, responses)
		if err != nil {
			o.logger.Error("dispatch failed", "bot", target, "err", err)
			break
		}
		responses = append(responses, response)
		lastBot = target
		o.logger.Info("bot responded",
			"bot", target,
			"response_len", len(response.Text),
			"iteration", i,
		)

		// Fallback: detect circular conversations before quality evaluation
		if o.detectRepetition(responses) {
			o.logger.Warn("repetition detected, stopping pipeline to prevent circular conversation",
				"iteration", i+1,
				"total_responses", len(responses),
			)
			break
		}

		// Evaluate quality (Fase 3)
		score := o.evaluate(ctx, msgCtx.Content, response)
		o.logger.Info("quality evaluated",
			"score", score.Score,
			"continue", score.Continue,
			"reason", score.Reason,
			"iteration", i,
		)
		if score.Score >= o.cfg.Orchestration.QualityThreshold || !score.Continue {
			o.logger.Info("pipeline complete",
				"iterations", i+1,
				"final_score", score.Score,
			)
			break
		}
	}
	return nil
}
// dispatchAndWait sends a task to botID over the bus and blocks until the bot
// replies or the delegation timeout elapses.
//
// The task carries the original question, the iteration number and all
// previous responses. The task ID is derived from the room, the bot and the
// iteration. The timeout comes from DelegationTimeout, falling back to 30s
// when it is unset or non-positive.
//
// It returns an error when the task cannot be marshalled, the bus call fails,
// the reply has no "result_json" payload, the payload cannot be decoded, or
// the bot reported an error of its own.
func (o *Orchestrator) dispatchAndWait(
	ctx context.Context,
	botID string,
	msgCtx decision.MessageContext,
	iteration int,
	previousResponses []orchestration.BotResponse,
) (orchestration.BotResponse, error) {
	var none orchestration.BotResponse

	taskID := fmt.Sprintf("orch-%s-%s-%d", msgCtx.RoomID, botID, iteration)
	taskJSON, err := orchestration.MarshalTaskEvent(orchestration.TaskEvent{
		TaskID:            taskID,
		TargetBotID:       botID,
		TargetRoomID:      msgCtx.RoomID,
		OriginalSender:    msgCtx.SenderID,
		OriginalQuestion:  msgCtx.Content,
		Iteration:         iteration,
		PreviousResponses: previousResponses,
	})
	if err != nil {
		return none, fmt.Errorf("marshal task: %w", err)
	}

	outgoing := bus.AgentMessage{
		From:    bus.AgentID(o.cfg.Special.ID),
		To:      bus.AgentID(botID),
		Kind:    bus.KindTask,
		Payload: map[string]string{"task_json": taskJSON},
	}

	timeout := o.cfg.Orchestration.DelegationTimeout
	if timeout <= 0 {
		timeout = 30_000_000_000 // 30s default
	}

	reply, err := o.bus.SendAndWait(ctx, outgoing, taskID, timeout)
	if err != nil {
		return none, err
	}

	resultJSON, present := reply.Payload["result_json"]
	if !present {
		return none, fmt.Errorf("reply missing result_json")
	}

	result, err := orchestration.UnmarshalTaskResult(resultJSON)
	if err != nil {
		return none, fmt.Errorf("unmarshal result: %w", err)
	}
	if result.Error != "" {
		return none, fmt.Errorf("bot %s error: %s", botID, result.Error)
	}

	return orchestration.BotResponse{BotID: botID, Text: result.Text}, nil
}
// loadPrompts reads the routing, quality and refinement prompt files from
// agents/specials/orchestrator/prompts (relative to the working directory)
// into the orchestrator. Files are read in that order and loading stops at
// the first one that cannot be read.
func (o *Orchestrator) loadPrompts() error {
	promptDir := filepath.Join("agents", "specials", "orchestrator", "prompts")

	targets := []struct {
		file  string
		label string
		dest  *string
	}{
		{"routing.md", "routing prompt", &o.routingPrompt},
		{"quality.md", "quality prompt", &o.qualityPrompt},
		{"refinement.md", "refinement prompt", &o.refinementPrompt},
	}

	for _, t := range targets {
		raw, err := os.ReadFile(filepath.Join(promptDir, t.file))
		if err != nil {
			return fmt.Errorf("%s: %w", t.label, err)
		}
		*t.dest = string(raw)
	}
	return nil
}
// buildParticipantsList renders botIDs as a bulleted list for LLM prompts,
// one "- id: description" line per bot, skipping the bot named by exclude.
// Registered bots with capabilities get them appended in parentheses; bots
// that were never registered are listed with a placeholder description.
func (o *Orchestrator) buildParticipantsList(botIDs []string, exclude string) string {
	var out strings.Builder

	o.mu.RLock()
	defer o.mu.RUnlock()

	for _, botID := range botIDs {
		if botID == exclude {
			continue
		}

		info, registered := o.participants[botID]
		if !registered {
			fmt.Fprintf(&out, "- %s: (no description available)\n", botID)
			continue
		}

		var caps string
		if len(info.Capabilities) > 0 {
			caps = fmt.Sprintf(" (capabilities: %s)", strings.Join(info.Capabilities, ", "))
		}
		fmt.Fprintf(&out, "- %s: %s%s\n", info.ID, info.Description, caps)
	}
	return out.String()
}
// detectRepetition reports whether the most recent response is too similar
// to any earlier one, which signals a circular conversation that should be
// stopped. Similarity is compared against RepetitionThreshold (0.6 when it
// is unset or non-positive). With fewer than two responses it returns false.
func (o *Orchestrator) detectRepetition(responses []orchestration.BotResponse) bool {
	last := len(responses) - 1
	if last < 1 {
		return false
	}

	limit := o.cfg.Orchestration.RepetitionThreshold
	if limit <= 0 {
		limit = 0.6 // default
	}

	newest := responses[last].Text
	for _, earlier := range responses[:last] {
		if similarity(newest, earlier.Text) >= limit {
			return true
		}
	}
	return false
}
// similarity returns a bigram-overlap ratio between a and b, from 0.0 (no
// shared bigrams) to 1.0 (identical).
//
// Inputs are compared exactly first, then again after trimming surrounding
// whitespace and lower-casing. The score is 2*shared/(total bigrams of both
// strings), where a bigram occurring several times counts up to the smaller
// of its two occurrence counts. Normalised strings shorter than two bytes
// score 0.0.
func similarity(a, b string) float64 {
	if a == b {
		return 1.0
	}

	left := strings.ToLower(strings.TrimSpace(a))
	right := strings.ToLower(strings.TrimSpace(b))
	switch {
	case left == right:
		return 1.0
	case len(left) < 2 || len(right) < 2:
		return 0.0
	}

	leftGrams, rightGrams := makeBigrams(left), makeBigrams(right)

	// Walk the left multiset once, accumulating its size and the overlap.
	shared, total := 0, 0
	for gram, leftCount := range leftGrams {
		total += leftCount
		rightCount, ok := rightGrams[gram]
		if !ok {
			continue
		}
		if rightCount < leftCount {
			shared += rightCount
		} else {
			shared += leftCount
		}
	}
	for _, rightCount := range rightGrams {
		total += rightCount
	}

	if total == 0 {
		return 0.0
	}
	return float64(2*shared) / float64(total)
}

// makeBigrams counts every pair of adjacent runes in s. Strings with fewer
// than two runes yield an empty map.
func makeBigrams(s string) map[string]int {
	runes := []rune(s)
	counts := make(map[string]int, len(runes))
	for end := 2; end <= len(runes); end++ {
		counts[string(runes[end-2:end])]++
	}
	return counts
}
// truncate shortens s to at most n runes, appending "..." when anything was
// cut off. Strings of n runes or fewer are returned unchanged. Counting is by
// rune, so multi-byte characters are never split.
func truncate(s string, n int) string {
	if runes := []rune(s); len(runes) > n {
		return string(runes[:n]) + "..."
	}
	return s
}
+107
View File
@@ -0,0 +1,107 @@
package orchestration
import (
"context"
"encoding/json"
"fmt"
"strings"
coretypes "github.com/enmanuel/agents/pkg/llm"
"github.com/enmanuel/agents/pkg/orchestration"
)
// routeInitial asks the LLM which bot should handle the question first.
//
// The routing prompt's {{PARTICIPANTS}} placeholder is replaced with the
// formatted participant list, and the model's reply is expected to be a JSON
// document that unmarshals into orchestration.RoutingDecision.
//
// It returns an error when participants is empty, when the LLM call fails, or
// when the reply cannot be parsed. If the model names a bot that is not in
// participants, the decision falls back to the first participant with a
// confidence of 0.5.
func (o *Orchestrator) routeInitial(ctx context.Context, question string, participants []string) (orchestration.RoutingDecision, error) {
	// The fallback below indexes participants[0]; refuse an empty list up
	// front instead of panicking (and before spending an LLM call).
	if len(participants) == 0 {
		return orchestration.RoutingDecision{}, fmt.Errorf("no participants to route to")
	}

	systemPrompt := strings.ReplaceAll(o.routingPrompt, "{{PARTICIPANTS}}", o.buildParticipantsList(participants, ""))
	resp, err := o.llm(ctx, coretypes.CompletionRequest{
		Model:        o.cfg.LLM.Primary.Model,
		MaxTokens:    o.cfg.LLM.Primary.MaxTokens,
		Temperature:  o.cfg.LLM.Primary.Temperature,
		SystemPrompt: systemPrompt,
		Messages: []coretypes.Message{
			{Role: coretypes.RoleUser, Content: question},
		},
	})
	if err != nil {
		return orchestration.RoutingDecision{}, fmt.Errorf("LLM routing call: %w", err)
	}

	var rd orchestration.RoutingDecision
	if err := json.Unmarshal([]byte(strings.TrimSpace(resp.Content)), &rd); err != nil {
		o.logger.Warn("failed to parse routing response, raw", "content", resp.Content, "err", err)
		return orchestration.RoutingDecision{}, fmt.Errorf("parse routing decision: %w", err)
	}

	// Validate the chosen bot is actually a participant
	if !contains(participants, rd.TargetBotID) {
		o.logger.Warn("LLM chose unknown bot, falling back to first", "chosen", rd.TargetBotID)
		rd.TargetBotID = participants[0]
		rd.Confidence = 0.5
		rd.Reason = "fallback: LLM chose unknown bot"
	}
	return rd, nil
}
// routeRefinement asks the LLM which bot should improve on the latest
// response, never choosing excludeBot (the bot that produced it).
//
// The refinement prompt's {{PARTICIPANTS}} and {{LAST_RESPONSE}} placeholders
// are filled in, and the model's reply is expected to be a JSON document that
// unmarshals into orchestration.RoutingDecision. If the model picks the
// excluded bot or a bot outside participants, the first participant other
// than excludeBot is used instead; when no such participant exists the
// model's choice is returned unchanged.
func (o *Orchestrator) routeRefinement(
	ctx context.Context,
	question string,
	responses []orchestration.BotResponse,
	participants []string,
	excludeBot string,
) (orchestration.RoutingDecision, error) {
	var latest string
	if n := len(responses); n > 0 {
		latest = responses[n-1].Text
	}

	candidates := o.buildParticipantsList(participants, excludeBot)
	systemPrompt := strings.ReplaceAll(o.refinementPrompt, "{{PARTICIPANTS}}", candidates)
	systemPrompt = strings.ReplaceAll(systemPrompt, "{{LAST_RESPONSE}}", latest)

	request := coretypes.CompletionRequest{
		Model:        o.cfg.LLM.Primary.Model,
		MaxTokens:    o.cfg.LLM.Primary.MaxTokens,
		Temperature:  o.cfg.LLM.Primary.Temperature,
		SystemPrompt: systemPrompt,
		Messages: []coretypes.Message{
			{
				Role:    coretypes.RoleUser,
				Content: fmt.Sprintf("Original question: %s\n\nCurrent response that needs improvement:\n%s", question, latest),
			},
		},
	}
	resp, err := o.llm(ctx, request)
	if err != nil {
		return orchestration.RoutingDecision{}, fmt.Errorf("LLM refinement call: %w", err)
	}

	var decision orchestration.RoutingDecision
	payload := []byte(strings.TrimSpace(resp.Content))
	if err := json.Unmarshal(payload, &decision); err != nil {
		o.logger.Warn("failed to parse refinement response", "content", resp.Content, "err", err)
		return orchestration.RoutingDecision{}, fmt.Errorf("parse refinement decision: %w", err)
	}

	// Validate: must be a participant and not the excluded bot
	valid := decision.TargetBotID != excludeBot && contains(participants, decision.TargetBotID)
	if valid {
		return decision, nil
	}

	// Pick first available that isn't excluded
	for i := range participants {
		if participants[i] == excludeBot {
			continue
		}
		decision.TargetBotID = participants[i]
		decision.Reason = "fallback: LLM chose excluded or unknown bot"
		break
	}
	return decision, nil
}
// contains reports whether s is an element of ss. A nil or empty slice never
// contains anything.
func contains(ss []string, s string) bool {
	for i := 0; i < len(ss); i++ {
		if ss[i] == s {
			return true
		}
	}
	return false
}
+692
View File
@@ -0,0 +1,692 @@
// Package process manages agent processes: discovery, start, stop, kill, stats.
// This is the impure shell layer — all I/O happens here.
package process
import (
"bufio"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/enmanuel/agents/internal/config"
)
// AgentInfo holds metadata about an agent parsed from its config.
// Scan fills it from the agent section of a config file and records the
// path of that file in ConfigPath.
type AgentInfo struct {
ID string
Name string
Version string
Desc string
ConfigPath string
Enabled bool
}
// AgentStatus combines agent metadata with runtime state.
// As filled in by Status: Running is true when at least one matching process
// was found, PID is the first such process (0 when none), and Instances is
// the number of matching processes.
type AgentStatus struct {
AgentInfo
Running bool
PID int
Instances int
}
// ProcessStats holds resource usage for a running process.
// Field names carry the units: UptimeSecs in seconds, MemRSSKB in kilobytes,
// CPUPct as a percentage and LogBytes in bytes.
type ProcessStats struct {
PID int
UptimeSecs int64
MemRSSKB int64
CPUPct float64
LogBytes int64
}
// processProber abstracts process detection for testing.
// osProber is the real implementation; tests can substitute a fake through
// the Manager's prober field.
type processProber interface {
// pgrepPIDs runs pgrep -f with the given pattern and returns matching PIDs.
pgrepPIDs(pattern string) []int
// processComm returns the comm name for a PID (e.g. "launcher", "go").
processComm(pid int) string
// isAlive checks if a PID is running.
isAlive(pid int) bool
}
// osProber is the real processProber, backed by pgrep, /proc and signals.
type osProber struct{}

// pgrepPIDs runs "pgrep -f pattern" and returns the positive PIDs it prints,
// one per output line. Any failure to run pgrep — including its non-zero exit
// when nothing matches — yields nil.
func (osProber) pgrepPIDs(pattern string) []int {
	raw, err := exec.Command("pgrep", "-f", pattern).Output()
	if err != nil {
		return nil
	}

	var found []int
	for _, field := range strings.Split(strings.TrimSpace(string(raw)), "\n") {
		pid, convErr := strconv.Atoi(strings.TrimSpace(field))
		if convErr != nil || pid <= 0 {
			continue
		}
		found = append(found, pid)
	}
	return found
}

// processComm returns the trimmed contents of /proc/<pid>/comm, or "" when
// the file cannot be read.
func (osProber) processComm(pid int) string {
	commPath := fmt.Sprintf("/proc/%d/comm", pid)
	raw, err := os.ReadFile(commPath)
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(raw))
}

// isAlive probes pid with signal 0, which performs the delivery checks
// without sending a signal. Any error, including a permission error, is
// reported as not alive.
func (osProber) isAlive(pid int) bool {
	err := syscall.Kill(pid, 0)
	return err == nil
}
// unifiedID is the identifier under which the unified launcher's PID and log
// files are named.
const unifiedID = "launcher" // PID/log file ID for the unified launcher
// Manager handles agent process lifecycle.
// runDir is created by Start if missing; agentsGlob is the glob pattern Scan
// uses to find agent config files; prober is the process-detection seam
// (osProber by default).
type Manager struct {
runDir string
agentsGlob string
binPath string
envFile string // path to .env file for child processes
prober processProber
}
// NewManager creates a Manager for the given run directory and agent config
// glob. binPath can be empty for auto-detection. The env file defaults to
// ".env" and process detection uses the real OS prober.
func NewManager(runDir, agentsGlob, binPath string) *Manager {
	mgr := &Manager{
		runDir:     runDir,
		agentsGlob: agentsGlob,
		binPath:    binPath,
		envFile:    ".env",
		prober:     osProber{},
	}
	return mgr
}
// Scan discovers agents by loading the metadata of every config file that
// matches the manager's glob. Config files whose metadata cannot be loaded
// are skipped silently; only a malformed glob pattern produces an error.
func (m *Manager) Scan() ([]AgentInfo, error) {
	paths, err := filepath.Glob(m.agentsGlob)
	if err != nil {
		return nil, err
	}

	var found []AgentInfo
	for _, cfgPath := range paths {
		meta, loadErr := config.LoadMeta(cfgPath)
		if loadErr != nil {
			// Best effort: one unreadable config must not hide the others.
			continue
		}
		entry := AgentInfo{
			ID:         meta.Agent.ID,
			Name:       meta.Agent.Name,
			Version:    meta.Agent.Version,
			Desc:       meta.Agent.Description,
			ConfigPath: cfgPath,
			Enabled:    meta.Agent.Enabled,
		}
		found = append(found, entry)
	}
	return found, nil
}
// Status returns the runtime status for a single agent: whether any matching
// process exists, the first matching PID (0 when none) and how many
// instances were found.
func (m *Manager) Status(info AgentInfo) AgentStatus {
	pids := m.findProcessPIDs(info.ID)

	status := AgentStatus{AgentInfo: info, Instances: len(pids)}
	if status.Instances > 0 {
		status.Running = true
		status.PID = pids[0]
	}
	return status
}
// StatusAll returns the status of every agent found by Scan, in scan order.
// A scan error is returned unchanged.
func (m *Manager) StatusAll() ([]AgentStatus, error) {
	agents, err := m.Scan()
	if err != nil {
		return nil, err
	}

	statuses := make([]AgentStatus, 0, len(agents))
	for i := range agents {
		statuses = append(statuses, m.Status(agents[i]))
	}
	return statuses, nil
}
// Start launches an agent process in the background.
//
// The child runs in its own session (Setsid), with stdout and stderr appended
// to the agent's log file and the environment from BuildEnv. Its PID is
// written to the agent's PID file. When the resolved binary starts with
// "go run", the launcher is run through the Go toolchain instead of a
// prebuilt binary.
//
// It returns an error if the agent is already running, if the run directory
// or log file cannot be prepared, if the process cannot be started, or if the
// PID file cannot be written. In that last case the process has already been
// started and keeps running.
func (m *Manager) Start(info AgentInfo) error {
	if pids := m.findProcessPIDs(info.ID); len(pids) > 0 {
		return fmt.Errorf("agent %q is already running (PID %d)", info.ID, pids[0])
	}
	if err := os.MkdirAll(m.runDir, 0o755); err != nil {
		return fmt.Errorf("create run dir: %w", err)
	}

	logFile, err := os.OpenFile(m.logPath(info.ID), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
	if err != nil {
		return fmt.Errorf("open log: %w", err)
	}
	// The child receives its own copy of the descriptor when it is started,
	// so the parent's handle is released on every path rather than leaked
	// once per Start call.
	defer logFile.Close()

	bin := m.resolvedBin()
	var cmd *exec.Cmd
	if strings.HasPrefix(bin, "go run") {
		cmd = exec.Command("go", "run", "-tags", "goolm", "./cmd/launcher", "-c", info.ConfigPath)
	} else {
		cmd = exec.Command(bin, "-c", info.ConfigPath)
	}
	cmd.Env = m.BuildEnv()
	cmd.Stdout = logFile
	cmd.Stderr = logFile
	cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("exec: %w", err)
	}

	// Capture the PID before handing cmd to the reaper goroutine, and start
	// reaping before anything else can fail so the child never lingers as a
	// zombie.
	pid := cmd.Process.Pid
	go func() { _ = cmd.Wait() }()

	if err := os.WriteFile(m.pidPath(info.ID), []byte(strconv.Itoa(pid)), 0o644); err != nil {
		return fmt.Errorf("write PID: %w", err)
	}
	return nil
}
// Stop sends SIGTERM to all instances, waits up to 5s, then SIGKILL if needed.
//
// The target set is every discovered process for id plus the PID recorded in
// the PID file when that process is alive and not already in the set. The
// PID file is removed once the instances are gone or have been force-killed.
// An error is returned only when nothing is running.
func (m *Manager) Stop(id string) error {
	targets := m.findProcessPIDs(id)

	// Also include PID file PID if alive and not already in the list
	if recorded := m.readPID(id); recorded > 0 && m.isAlive(recorded) {
		listed := false
		for i := range targets {
			if targets[i] == recorded {
				listed = true
				break
			}
		}
		if !listed {
			targets = append(targets, recorded)
		}
	}
	if len(targets) == 0 {
		return fmt.Errorf("agent %q is not running", id)
	}

	// SIGTERM all instances
	for i := range targets {
		_ = syscall.Kill(targets[i], syscall.SIGTERM)
	}

	anyAlive := func() bool {
		for _, pid := range targets {
			if m.isAlive(pid) {
				return true
			}
		}
		return false
	}

	// Wait up to 5 seconds for graceful shutdown.
	for attempt := 0; attempt < 10; attempt++ {
		if !anyAlive() {
			m.removePID(id)
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}

	// Force kill survivors.
	for _, pid := range targets {
		if !m.isAlive(pid) {
			continue
		}
		_ = syscall.Kill(pid, syscall.SIGKILL)
	}
	m.removePID(id)
	return nil
}
// Kill sends SIGKILL to every instance of the agent without a grace period.
//
// Targets are the discovered PIDs plus the PID-file PID when that process is
// alive. The PID file is removed afterwards. Returns an error when nothing is
// running; otherwise returns the last signal-delivery error, if any.
func (m *Manager) Kill(id string) error {
	victims := m.findProcessPIDs(id)
	if recorded := m.readPID(id); recorded > 0 && m.isAlive(recorded) {
		listed := false
		for i := range victims {
			if victims[i] == recorded {
				listed = true
				break
			}
		}
		if !listed {
			victims = append(victims, recorded)
		}
	}
	if len(victims) == 0 {
		return fmt.Errorf("agent %q is not running", id)
	}

	var failure error
	for _, pid := range victims {
		if err := syscall.Kill(pid, syscall.SIGKILL); err != nil {
			failure = err
		}
	}
	m.removePID(id)
	return failure
}
// Stats reports resource usage for a running agent.
//
// The PID comes from resolveRunningPID; when that yields 0 the agent is
// treated as not running and an error is returned with zero-valued stats.
func (m *Manager) Stats(id string) (ProcessStats, error) {
	if pid := m.resolveRunningPID(id); pid != 0 {
		return m.statsForPID(pid, id), nil
	}
	return ProcessStats{}, fmt.Errorf("agent %q is not running", id)
}
// statsForPID gathers resource usage for a specific PID.
//
// Each metric is best-effort: a source that cannot be read or parsed leaves
// the corresponding field at its zero value.
//
//   - UptimeSecs: from the process start time in /proc/<pid>/stat and the
//     boot time in /proc/stat.
//   - MemRSSKB: the VmRSS line of /proc/<pid>/status.
//   - CPUPct: the output of `ps -p <pid> -o pcpu=`.
//   - LogBytes: size of the log file for id.
func (m *Manager) statsForPID(pid int, id string) ProcessStats {
	s := ProcessStats{PID: pid}

	if data, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)); err == nil {
		if startTicks, ok := procStartTicks(string(data)); ok {
			if btime, ok := procBootTime(); ok {
				// Assumes the common Linux clock tick rate of 100 Hz
				// (sysconf(_SC_CLK_TCK)).
				const clkTck = int64(100)
				s.UptimeSecs = time.Now().Unix() - (btime + startTicks/clkTck)
			}
		}
	}

	// RSS from /proc/<pid>/status.
	if data, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid)); err == nil {
		for _, line := range strings.Split(string(data), "\n") {
			if strings.HasPrefix(line, "VmRSS:") {
				fields := strings.Fields(line)
				if len(fields) >= 2 {
					s.MemRSSKB, _ = strconv.ParseInt(fields[1], 10, 64)
				}
				break
			}
		}
	}

	// CPU% from ps (simpler than calculating from /proc/stat deltas).
	if out, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "pcpu=").Output(); err == nil {
		s.CPUPct, _ = strconv.ParseFloat(strings.TrimSpace(string(out)), 64)
	}

	// Log file size.
	if info, err := os.Stat(m.logPath(id)); err == nil {
		s.LogBytes = info.Size()
	}
	return s
}

// procStartTicks extracts the starttime field (field 22) from the contents of
// /proc/<pid>/stat. The second field (comm) is wrapped in parentheses and may
// itself contain spaces or ')', so fields are counted from the last ')'.
func procStartTicks(stat string) (int64, bool) {
	closing := strings.LastIndexByte(stat, ')')
	if closing < 0 {
		return 0, false
	}
	rest := strings.Fields(stat[closing+1:])
	// rest[0] is field 3 (state), so field 22 is rest[19].
	const startTimeIdx = 19
	if len(rest) <= startTimeIdx {
		return 0, false
	}
	ticks, err := strconv.ParseInt(rest[startTimeIdx], 10, 64)
	if err != nil {
		return 0, false
	}
	return ticks, true
}

// procBootTime returns the value of the "btime" line of /proc/stat.
func procBootTime() (int64, bool) {
	raw, err := os.ReadFile("/proc/stat")
	if err != nil {
		return 0, false
	}
	for _, line := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(line, "btime ") {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) < 2 {
			return 0, false
		}
		btime, err := strconv.ParseInt(fields[1], 10, 64)
		if err != nil {
			return 0, false
		}
		return btime, true
	}
	return 0, false
}
// LogTail returns the last N lines of an agent's log.
//
// A non-positive lines value yields no lines (a negative value previously
// caused a slice-bounds panic). The file is read front to back, but only the
// most recent `lines` entries are retained, so memory use is bounded by the
// requested tail rather than by the size of the log.
//
// Returns an error if the log cannot be opened or scanning fails (for example
// a single line longer than the 1 MiB scanner limit).
func (m *Manager) LogTail(id string, lines int) ([]string, error) {
	f, err := os.Open(m.logPath(id))
	if err != nil {
		return nil, fmt.Errorf("open log: %w", err)
	}
	defer f.Close()

	if lines <= 0 {
		return nil, nil
	}

	var tail []string
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		tail = append(tail, scanner.Text())
		if len(tail) > lines {
			// Drop the oldest line; clear the slot so its string can be
			// collected before the backing array is next reallocated.
			tail[0] = ""
			tail = tail[1:]
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return tail, nil
}
// IsRunning reports whether resolveRunningPID finds a live process for the
// agent (PID file first, then process discovery).
func (m *Manager) IsRunning(id string) bool {
return m.resolveRunningPID(id) > 0
}
// InstanceCount returns how many launcher processes findProcessPIDs reports
// for the agent. The PID file is not consulted.
func (m *Manager) InstanceCount(id string) int {
return len(m.findProcessPIDs(id))
}
// ReadPID returns the PID stored in the agent's PID file, or 0 when the file
// cannot be read. It does not check whether that process is alive.
func (m *Manager) ReadPID(id string) int {
return m.readPID(id)
}
// PidPath returns the path to the PID file for an agent.
func (m *Manager) PidPath(id string) string { return m.pidPath(id) }
// LogPath returns the path to the log file for an agent.
func (m *Manager) LogPath(id string) string { return m.logPath(id) }
// Build runs `bash build.sh` in the current working directory with the
// environment produced by BuildEnv.
//
// It returns the script's combined stdout/stderr as a string, together with
// the error reported by the command (nil on success).
func (m *Manager) Build() (string, error) {
	build := exec.Command("bash", "build.sh")
	build.Env = m.BuildEnv()
	combined, err := build.CombinedOutput()
	return string(combined), err
}
// ── Unified launcher ─────────────────────────────────────────────────────
// The unified launcher runs ALL enabled agents + orchestrator in a single
// process. PID → run/launcher.pid, log → run/launcher.log.
// StartUnified launches the unified launcher (no -c flag → discovers all agents).
//
// The process runs in its own session with stdout and stderr appended to the
// unified log file, and its PID is written to the unified PID file.
//
// Returns an error if a unified launcher is already running, or if the run
// directory, log file, process, or PID file cannot be created.
func (m *Manager) StartUnified() error {
	// UnifiedPID applies the same liveness rules as IsUnifiedRunning but also
	// yields the PID found by process discovery, so the message never reports
	// "PID 0" when the PID file is missing or stale.
	if pid := m.UnifiedPID(); pid != 0 {
		return fmt.Errorf("unified launcher is already running (PID %d)", pid)
	}
	if err := os.MkdirAll(m.runDir, 0o755); err != nil {
		return fmt.Errorf("create run dir: %w", err)
	}
	logFile, err := os.OpenFile(m.logPath(unifiedID), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
	if err != nil {
		return fmt.Errorf("open log: %w", err)
	}
	// The child gets the file directly; the parent's handle is only needed
	// until cmd.Start returns. Closing on every path avoids a descriptor leak.
	defer logFile.Close()

	bin := m.resolvedBin()
	var cmd *exec.Cmd
	if strings.HasPrefix(bin, "go run") {
		// No prebuilt binary: run the launcher from source.
		cmd = exec.Command("go", "run", "-tags", "goolm", "./cmd/launcher", "--log-level", "info")
	} else {
		cmd = exec.Command(bin, "--log-level", "info")
	}
	cmd.Env = m.BuildEnv()
	cmd.Stdout = logFile
	cmd.Stderr = logFile
	cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("exec: %w", err)
	}
	if err := os.WriteFile(m.pidPath(unifiedID), []byte(strconv.Itoa(cmd.Process.Pid)), 0o644); err != nil {
		return fmt.Errorf("write PID: %w", err)
	}
	// Reap the child when it exits so it does not linger as a zombie.
	go func() { _ = cmd.Wait() }()
	return nil
}
// StopUnified stops the unified launcher process.
//
// Stop locates processes through per-agent discovery and the PID file; it
// does not use findUnifiedPIDs. UnifiedPID is therefore called first: when
// the PID file is missing or stale it rewrites it from process discovery, so
// a launcher that IsUnifiedRunning reports as alive can actually be stopped.
func (m *Manager) StopUnified() error {
	_ = m.UnifiedPID() // called for its PID-file repair side effect
	return m.Stop(unifiedID)
}
// KillUnified sends SIGKILL to the unified launcher.
//
// Kill locates processes through per-agent discovery and the PID file; it
// does not use findUnifiedPIDs. UnifiedPID is therefore called first: when
// the PID file is missing or stale it rewrites it from process discovery, so
// a launcher that IsUnifiedRunning reports as alive can actually be killed.
func (m *Manager) KillUnified() error {
	_ = m.UnifiedPID() // called for its PID-file repair side effect
	return m.Kill(unifiedID)
}
// IsUnifiedRunning reports whether a unified launcher process is alive.
//
// The PID file is trusted when the process it names is alive; otherwise the
// answer comes from process discovery via findUnifiedPIDs.
func (m *Manager) IsUnifiedRunning() bool {
	if recorded := m.readPID(unifiedID); recorded > 0 && m.isAlive(recorded) {
		return true
	}
	return len(m.findUnifiedPIDs()) > 0
}
// UnifiedPID returns the PID of the running unified launcher, or 0 if none.
//
// A live PID from the PID file wins. Otherwise the first PID reported by
// findUnifiedPIDs is returned and, as a side effect, written back to the PID
// file (write errors are ignored).
func (m *Manager) UnifiedPID() int {
	if recorded := m.readPID(unifiedID); recorded > 0 && m.isAlive(recorded) {
		return recorded
	}
	found := m.findUnifiedPIDs()
	if len(found) == 0 {
		return 0
	}
	first := found[0]
	// Best-effort repair of the PID file.
	_ = os.WriteFile(m.pidPath(unifiedID), []byte(strconv.Itoa(first)), 0o644)
	return first
}
// UnifiedStats reports resource usage for the unified launcher process.
//
// Returns an error, with zero-valued stats, when UnifiedPID finds no running
// launcher.
func (m *Manager) UnifiedStats() (ProcessStats, error) {
	if pid := m.UnifiedPID(); pid != 0 {
		return m.statsForPID(pid, unifiedID), nil
	}
	return ProcessStats{}, fmt.Errorf("unified launcher is not running")
}
// UnifiedLogTail returns the last N lines of the unified launcher log.
// It delegates to LogTail with the unified launcher's ID.
func (m *Manager) UnifiedLogTail(lines int) ([]string, error) {
return m.LogTail(unifiedID, lines)
}
// StatusAllUnified returns status for all agents, deriving "running" from
// whether the unified launcher is running + the agent is enabled.
//
// The launcher is probed once: its PID (0 when not running) provides both the
// running flag and the reported PID, so the two cannot disagree. Agents that
// count as running report the launcher's PID and one instance; all others
// report Running=false, PID 0 and zero instances.
func (m *Manager) StatusAllUnified() ([]AgentStatus, error) {
	agents, err := m.Scan()
	if err != nil {
		return nil, err
	}
	launcherPID := m.UnifiedPID()
	launcherRunning := launcherPID > 0

	statuses := make([]AgentStatus, len(agents))
	for i, a := range agents {
		st := AgentStatus{AgentInfo: a}
		if launcherRunning && a.Enabled {
			st.Running = true
			st.PID = launcherPID
			st.Instances = 1
		}
		statuses[i] = st
	}
	return statuses, nil
}
// ToggleEnabled rewrites the enabled field in the config file of the agent
// with the given ID.
//
// Returns the Scan error if discovery fails, the rewrite error if updating
// the file fails, or a "not found" error when no scanned agent has that ID.
func (m *Manager) ToggleEnabled(id string, enabled bool) error {
	agents, err := m.Scan()
	if err != nil {
		return err
	}
	for i := range agents {
		if agents[i].ID != id {
			continue
		}
		return m.setEnabledInConfig(agents[i].ConfigPath, enabled)
	}
	return fmt.Errorf("agent %q not found", id)
}
// setEnabledInConfig rewrites the enabled field in a config.yaml.
//
// The first line whose trimmed text starts with "enabled:" is replaced by
// "enabled: true" or "enabled: false", keeping its leading indentation and a
// trailing carriage return if the file uses CRLF line endings. Anything else
// on that line (such as an inline comment) is dropped.
//
// Returns an error if the file cannot be read or written, or if it contains
// no enabled field; in the latter case the file is left untouched instead of
// reporting a change that never happened.
//
// NOTE(review): the match is purely textual, so the first "enabled:" key at
// any nesting level is the one rewritten — confirm that is always the
// agent-level field.
func (m *Manager) setEnabledInConfig(path string, enabled bool) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}

	lines := strings.Split(string(data), "\n")
	replaced := false
	for i, line := range lines {
		if !strings.HasPrefix(strings.TrimSpace(line), "enabled:") {
			continue
		}
		// Preserve indentation.
		indent := line[:len(line)-len(strings.TrimLeft(line, " \t"))]
		updated := indent + "enabled: " + strconv.FormatBool(enabled)
		if strings.HasSuffix(line, "\r") {
			updated += "\r"
		}
		lines[i] = updated
		replaced = true
		break
	}
	if !replaced {
		return fmt.Errorf("no enabled field found in %s", path)
	}
	return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0o644)
}
// findUnifiedPIDs returns the PIDs of processes matching the pgrep pattern
// "launcher.*--log-level", excluding any whose comm is "go" (the `go run`
// wrapper).
//
// NOTE(review): the pattern itself does not exclude command lines that also
// carry -c — confirm per-agent launchers are never started with --log-level.
func (m *Manager) findUnifiedPIDs() []int {
	var launchers []int
	for _, pid := range m.prober.pgrepPIDs("launcher.*--log-level") {
		if m.prober.processComm(pid) != "go" {
			launchers = append(launchers, pid)
		}
	}
	return launchers
}
// ── internal helpers ─────────────────────────────────────────────────────
// pidPath returns <runDir>/<id>.pid.
func (m *Manager) pidPath(id string) string { return filepath.Join(m.runDir, id+".pid") }
// logPath returns <runDir>/<id>.log.
func (m *Manager) logPath(id string) string { return filepath.Join(m.runDir, id+".log") }
// readPID returns the integer stored in the PID file for id.
//
// It returns 0 when the file cannot be read. The contents are trimmed of
// surrounding whitespace before conversion, and the conversion error is
// deliberately ignored, so non-numeric contents also yield 0.
func (m *Manager) readPID(id string) int {
	contents, err := os.ReadFile(m.pidPath(id))
	if err != nil {
		return 0
	}
	text := strings.TrimSpace(string(contents))
	pid, _ := strconv.Atoi(text)
	return pid
}
// findProcessPIDs returns the PIDs of launcher processes started for the
// given agent ID.
//
// The agent's config path is looked up first; an unknown ID yields nil.
// Processes are matched with the pgrep pattern "launcher.*-c.*<configPath>",
// and PIDs whose comm is "go" are dropped so that the `go run` wrapper is not
// counted alongside the real launcher.
//
// NOTE(review): the config path is inserted into the pattern verbatim, so any
// regex metacharacters it contains (such as '.') are interpreted by pgrep.
func (m *Manager) findProcessPIDs(id string) []int {
	cfg := m.configPathFor(id)
	if cfg == "" {
		return nil
	}
	candidates := m.prober.pgrepPIDs("launcher.*-c.*" + cfg)

	var pids []int
	for _, pid := range candidates {
		if m.prober.processComm(pid) == "go" {
			continue
		}
		pids = append(pids, pid)
	}
	return pids
}
// configPathFor returns the path of the first config file matched by
// agentsGlob whose agent ID equals id.
//
// Files that fail to load are skipped. An empty string is returned when the
// glob pattern is invalid or no file declares that ID.
func (m *Manager) configPathFor(id string) string {
	candidates, err := filepath.Glob(m.agentsGlob)
	if err != nil {
		return ""
	}
	for i := range candidates {
		meta, loadErr := config.LoadMeta(candidates[i])
		if loadErr == nil && meta.Agent.ID == id {
			return candidates[i]
		}
	}
	return ""
}
// resolveRunningPID returns the PID of the running agent, or 0 if none.
//
// Resolution order:
//  1. the PID file, when the process it names is alive;
//  2. process discovery, in which case the first PID found is written back
//     to the PID file (best-effort);
//  3. otherwise 0, removing the PID file if it held a stale PID.
func (m *Manager) resolveRunningPID(id string) int {
	recorded := m.readPID(id)
	if recorded > 0 && m.isAlive(recorded) {
		return recorded
	}

	if live := m.findProcessPIDs(id); len(live) > 0 {
		_ = os.WriteFile(m.pidPath(id), []byte(strconv.Itoa(live[0])), 0o644)
		return live[0]
	}

	if recorded > 0 {
		m.removePID(id)
	}
	return 0
}
// isAlive reports whether the prober considers the given PID alive.
func (m *Manager) isAlive(pid int) bool {
return m.prober.isAlive(pid)
}
// removePID deletes the PID file for id; any removal error (including the
// file not existing) is ignored.
func (m *Manager) removePID(id string) {
_ = os.Remove(m.pidPath(id))
}
// BuildEnv returns the environment for child processes: current env + .env file vars.
//
// When no env file is configured, or it cannot be read, the current process
// environment is returned unchanged. Otherwise every non-blank, non-comment
// line of the form KEY=VALUE (with a non-empty KEY) is appended verbatim
// after the inherited variables. Per the os/exec documentation, when Cmd.Env
// contains duplicate keys the last value wins, so entries from the env file
// override inherited ones.
//
// NOTE(review): lines are not unquoted and an "export " prefix is not
// stripped — confirm the env files in use are plain KEY=VALUE.
func (m *Manager) BuildEnv() []string {
	env := os.Environ()
	if m.envFile == "" {
		return env
	}
	data, err := os.ReadFile(m.envFile)
	if err != nil {
		return env
	}
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// idx > 0 rejects lines without '=' as well as lines with an empty key.
		if idx := strings.Index(line, "="); idx > 0 {
			env = append(env, line)
		}
	}
	return env
}
// resolvedBin picks the launcher command to run.
//
// Preference order: the configured binPath, then "bin/launcher" if it exists
// relative to the working directory, and finally the "go run ./cmd/launcher"
// marker string, which callers detect by its "go run" prefix.
func (m *Manager) resolvedBin() string {
	if configured := m.binPath; configured != "" {
		return configured
	}
	const built = "bin/launcher"
	if _, err := os.Stat(built); err != nil {
		return "go run ./cmd/launcher"
	}
	return built
}
+190
View File
@@ -0,0 +1,190 @@
package process
import (
"os"
"path/filepath"
"strconv"
"testing"
)
// fakeProber is an in-memory stand-in for processProber. Each method answers
// from a lookup table that the test fills in; missing keys yield zero values.
type fakeProber struct {
	pids  map[string][]int // pgrep pattern → PIDs reported for it
	comms map[int]string   // PID → comm name
	alive map[int]bool     // PID → whether the process counts as alive
}

// newFakeProber returns a fakeProber with empty, ready-to-write tables.
func newFakeProber() *fakeProber {
	fp := new(fakeProber)
	fp.pids = map[string][]int{}
	fp.comms = map[int]string{}
	fp.alive = map[int]bool{}
	return fp
}

// pgrepPIDs returns the PIDs registered for exactly this pattern string.
func (f *fakeProber) pgrepPIDs(pattern string) []int {
	return f.pids[pattern]
}

// processComm returns the comm name registered for pid ("" if none).
func (f *fakeProber) processComm(pid int) string {
	return f.comms[pid]
}

// isAlive returns the liveness registered for pid (false if none).
func (f *fakeProber) isAlive(pid int) bool {
	return f.alive[pid]
}
// testManager creates a Manager with a temp dir, fake prober, and a config file.
//
// It returns the Manager and the path of the generated config.yaml, which
// declares a single agent with ID "test-bot". Any failure while preparing the
// fixture aborts the test immediately instead of being silently ignored, so a
// broken fixture cannot show up later as a confusing assertion failure.
func testManager(t *testing.T, fp *fakeProber) (*Manager, string) {
	t.Helper()
	dir := t.TempDir()
	runDir := filepath.Join(dir, "run")
	agentsDir := filepath.Join(dir, "agents", "test-bot")
	for _, d := range []string{runDir, agentsDir} {
		if err := os.MkdirAll(d, 0o755); err != nil {
			t.Fatalf("create %s: %v", d, err)
		}
	}

	// Minimal config.yaml so Scan() and configPathFor() work.
	cfgPath := filepath.Join(agentsDir, "config.yaml")
	if err := os.WriteFile(cfgPath, []byte(`agent:
id: test-bot
name: Test Bot
version: "0.1"
enabled: true
`), 0o644); err != nil {
		t.Fatalf("write %s: %v", cfgPath, err)
	}

	glob := filepath.Join(dir, "agents", "*", "config.yaml")
	m := &Manager{
		runDir:     runDir,
		agentsGlob: glob,
		binPath:    "/bin/true", // won't actually run
		envFile:    "",
		prober:     fp,
	}
	return m, cfgPath
}
// TestFindProcessPIDs_FiltersGoWrapper checks that a PID whose comm is "go"
// is dropped from the result while the launcher PID is kept.
func TestFindProcessPIDs_FiltersGoWrapper(t *testing.T) {
fp := newFakeProber()
m, cfgPath := testManager(t, fp)
// Simulate pgrep returning 2 PIDs: go wrapper (100) + real launcher (200).
// The fake prober matches on the exact pattern string, so this must equal
// the pattern findProcessPIDs builds.
pattern := "launcher.*-c.*" + cfgPath
fp.pids[pattern] = []int{100, 200}
fp.comms[100] = "go"
fp.comms[200] = "launcher"
pids := m.findProcessPIDs("test-bot")
if len(pids) != 1 {
t.Fatalf("expected 1 PID, got %d: %v", len(pids), pids)
}
if pids[0] != 200 {
t.Errorf("expected PID 200, got %d", pids[0])
}
}
// TestFindProcessPIDs_NoPIDs checks that nothing is reported when the prober
// has no PIDs registered for the agent's pattern.
func TestFindProcessPIDs_NoPIDs(t *testing.T) {
fp := newFakeProber()
m, _ := testManager(t, fp)
pids := m.findProcessPIDs("test-bot")
if len(pids) != 0 {
t.Fatalf("expected 0 PIDs, got %d", len(pids))
}
}
// TestStatus_SingleInstance checks that one discovered launcher process is
// reported as Running with its PID and an instance count of 1.
func TestStatus_SingleInstance(t *testing.T) {
fp := newFakeProber()
m, cfgPath := testManager(t, fp)
pattern := "launcher.*-c.*" + cfgPath
fp.pids[pattern] = []int{42}
fp.comms[42] = "launcher"
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
st := m.Status(info)
if !st.Running {
t.Error("expected Running=true")
}
if st.PID != 42 {
t.Errorf("expected PID=42, got %d", st.PID)
}
if st.Instances != 1 {
t.Errorf("expected Instances=1, got %d", st.Instances)
}
}
// TestStatus_NoInstances checks that an agent with no discovered processes is
// reported as not running with zero instances.
func TestStatus_NoInstances(t *testing.T) {
fp := newFakeProber()
m, cfgPath := testManager(t, fp)
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
st := m.Status(info)
if st.Running {
t.Error("expected Running=false")
}
if st.Instances != 0 {
t.Errorf("expected Instances=0, got %d", st.Instances)
}
}
// TestStart_RejectsWhenAlreadyRunning checks that Start refuses to launch a
// second instance and names the PID of the existing one in its error.
func TestStart_RejectsWhenAlreadyRunning(t *testing.T) {
fp := newFakeProber()
m, cfgPath := testManager(t, fp)
pattern := "launcher.*-c.*" + cfgPath
fp.pids[pattern] = []int{99}
fp.comms[99] = "launcher"
info := AgentInfo{ID: "test-bot", Name: "Test", ConfigPath: cfgPath, Enabled: true}
err := m.Start(info)
if err == nil {
t.Fatal("expected error when agent already running")
}
// The exact message is part of the contract being pinned here.
if got := err.Error(); got != `agent "test-bot" is already running (PID 99)` {
t.Errorf("unexpected error: %s", got)
}
}
// TestResolveRunningPID_RepairsStale checks that when the PID file names a
// dead process but discovery finds a live one, the live PID is returned and
// written back to the PID file.
func TestResolveRunningPID_RepairsStale(t *testing.T) {
fp := newFakeProber()
m, cfgPath := testManager(t, fp)
// Write a stale PID file (PID 999 is dead).
_ = os.MkdirAll(m.runDir, 0o755)
_ = os.WriteFile(m.pidPath("test-bot"), []byte("999"), 0o644)
fp.alive[999] = false
// But the real process is at PID 42.
pattern := "launcher.*-c.*" + cfgPath
fp.pids[pattern] = []int{42}
fp.comms[42] = "launcher"
pid := m.resolveRunningPID("test-bot")
if pid != 42 {
t.Errorf("expected repaired PID=42, got %d", pid)
}
// Verify PID file was repaired.
data, err := os.ReadFile(m.pidPath("test-bot"))
if err != nil {
t.Fatalf("read pid file: %v", err)
}
if got, _ := strconv.Atoi(string(data)); got != 42 {
t.Errorf("expected PID file to contain 42, got %d", got)
}
}
// TestResolveRunningPID_CleansUpStalePIDFile checks that a PID file naming a
// dead process is deleted when discovery finds no replacement, and that 0 is
// returned.
func TestResolveRunningPID_CleansUpStalePIDFile(t *testing.T) {
fp := newFakeProber()
m, _ := testManager(t, fp)
// Write a stale PID file, no real process running.
_ = os.MkdirAll(m.runDir, 0o755)
_ = os.WriteFile(m.pidPath("test-bot"), []byte("999"), 0o644)
fp.alive[999] = false
pid := m.resolveRunningPID("test-bot")
if pid != 0 {
t.Errorf("expected 0 for dead process, got %d", pid)
}
// PID file should be removed.
if _, err := os.Stat(m.pidPath("test-bot")); !os.IsNotExist(err) {
t.Error("expected stale PID file to be removed")
}
}
+148
View File
@@ -0,0 +1,148 @@
// Package security provides the impure loader for security policy YAML files.
// It reads security/ directory files and returns a pure security.SecurityPolicy.
package security
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"

	"gopkg.in/yaml.v3"

	"github.com/enmanuel/agents/pkg/security"
)
// --- YAML intermediate types (private, only for parsing) ---
// yamlUserGroups mirrors user-groups.yaml: a top-level "groups" mapping from
// group name to its list of members.
type yamlUserGroups struct {
Groups map[string]struct {
Members []string `yaml:"members"`
} `yaml:"groups"`
}
// yamlAgentGroups mirrors agent-groups.yaml: a top-level "groups" mapping
// from group name to its list of agents.
type yamlAgentGroups struct {
Groups map[string]struct {
Agents []string `yaml:"agents"`
} `yaml:"groups"`
}
// yamlPermissions mirrors permissions.yaml: a top-level "policies" list where
// each entry names an agent group and the actions granted per user group.
type yamlPermissions struct {
Policies []struct {
AgentGroup string `yaml:"agent_group"`
Permissions []struct {
UserGroup string `yaml:"user_group"`
Actions []string `yaml:"actions"`
} `yaml:"permissions"`
} `yaml:"policies"`
}
// Load builds a SecurityPolicy from the YAML files found in dir.
//
// It reads user-groups.yaml, agent-groups.yaml and permissions.yaml, in that
// order. A missing directory yields an empty policy and no error, and a
// missing individual file leaves its section empty. The first file that
// cannot be read or parsed aborts loading with an error that names it.
func Load(dir string) (security.SecurityPolicy, error) {
	var empty security.SecurityPolicy

	if _, statErr := os.Stat(dir); errors.Is(statErr, os.ErrNotExist) {
		return empty, nil
	}

	users, err := loadUserGroups(filepath.Join(dir, "user-groups.yaml"))
	if err != nil {
		return empty, err
	}
	agents, err := loadAgentGroups(filepath.Join(dir, "agent-groups.yaml"))
	if err != nil {
		return empty, err
	}
	perms, err := loadPermissions(filepath.Join(dir, "permissions.yaml"))
	if err != nil {
		return empty, err
	}

	var policy security.SecurityPolicy
	policy.UserGroups = users
	policy.AgentGroups = agents
	policy.Policies = perms
	return policy, nil
}
// loadUserGroups parses the user-groups YAML file at path.
//
// A missing file yields (nil, nil). Read and parse failures are wrapped with
// the file path. Groups are returned sorted by name: the YAML is decoded into
// a map, whose iteration order is random, so sorting makes the result
// identical from one load to the next.
func loadUserGroups(path string) ([]security.UserGroup, error) {
	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("security: reading %s: %w", path, err)
	}
	var raw yamlUserGroups
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("security: parsing %s: %w", path, err)
	}

	names := make([]string, 0, len(raw.Groups))
	for name := range raw.Groups {
		names = append(names, name)
	}
	sort.Strings(names)

	groups := make([]security.UserGroup, 0, len(names))
	for _, name := range names {
		groups = append(groups, security.UserGroup{
			Name:    name,
			Members: raw.Groups[name].Members,
		})
	}
	return groups, nil
}
// loadAgentGroups parses the agent-groups YAML file at path.
//
// A missing file yields (nil, nil). Read and parse failures are wrapped with
// the file path. Groups are returned sorted by name: the YAML is decoded into
// a map, whose iteration order is random, so sorting makes the result
// identical from one load to the next.
func loadAgentGroups(path string) ([]security.AgentGroup, error) {
	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("security: reading %s: %w", path, err)
	}
	var raw yamlAgentGroups
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("security: parsing %s: %w", path, err)
	}

	names := make([]string, 0, len(raw.Groups))
	for name := range raw.Groups {
		names = append(names, name)
	}
	sort.Strings(names)

	groups := make([]security.AgentGroup, 0, len(names))
	for _, name := range names {
		groups = append(groups, security.AgentGroup{
			Name:   name,
			Agents: raw.Groups[name].Agents,
		})
	}
	return groups, nil
}
// loadPermissions parses the permissions YAML file at path into agent
// policies, keeping the order in which policies and permissions appear in
// the file.
//
// A missing file yields (nil, nil). Read and parse failures are wrapped with
// the file path.
func loadPermissions(path string) ([]security.AgentPolicy, error) {
	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("security: reading %s: %w", path, err)
	}

	var raw yamlPermissions
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("security: parsing %s: %w", path, err)
	}

	out := make([]security.AgentPolicy, 0, len(raw.Policies))
	for i := range raw.Policies {
		entry := &raw.Policies[i]
		granted := make([]security.Permission, 0, len(entry.Permissions))
		for j := range entry.Permissions {
			granted = append(granted, security.Permission{
				UserGroup: entry.Permissions[j].UserGroup,
				Actions:   entry.Permissions[j].Actions,
			})
		}
		out = append(out, security.AgentPolicy{
			AgentGroup:  entry.AgentGroup,
			Permissions: granted,
		})
	}
	return out, nil
}
+189
View File
@@ -0,0 +1,189 @@
package security_test
import (
"os"
"path/filepath"
"testing"
shellsecurity "github.com/enmanuel/agents/shell/security"
)
// writeFile creates (or overwrites) dir/name with the given content and
// fails the calling test immediately if the write does not succeed.
func writeFile(t *testing.T, dir, name, content string) {
	t.Helper()
	target := filepath.Join(dir, name)
	err := os.WriteFile(target, []byte(content), 0o644)
	if err != nil {
		t.Fatalf("writeFile %s: %v", name, err)
	}
}
// --- Test 3.1: non-existent directory → empty policy, no error ---
func TestLoad_NonExistentDir(t *testing.T) {
policy, err := shellsecurity.Load("/tmp/does-not-exist-security-xyz")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if len(policy.UserGroups) != 0 || len(policy.AgentGroups) != 0 || len(policy.Policies) != 0 {
t.Errorf("expected empty policy, got: %+v", policy)
}
}
// --- Test 3.2: empty directory (no YAML files) → empty policy, no error ---
func TestLoad_EmptyDir(t *testing.T) {
dir := t.TempDir()
policy, err := shellsecurity.Load(dir)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if len(policy.UserGroups) != 0 || len(policy.AgentGroups) != 0 || len(policy.Policies) != 0 {
t.Errorf("expected empty policy, got: %+v", policy)
}
}
// --- Test 3.3: all three YAML files valid → policy with every section filled ---
//
// The policy-count check is fatal because the next assertion indexes
// Policies[0]; with a non-fatal check an empty result would panic instead of
// failing cleanly.
func TestLoad_AllFiles(t *testing.T) {
	dir := t.TempDir()
	writeFile(t, dir, "user-groups.yaml", `
groups:
admins:
members: ["@admin:example.com"]
everyone:
members: ["*"]
`)
	writeFile(t, dir, "agent-groups.yaml", `
groups:
assistants:
agents:
- assistant-bot
all:
agents: ["*"]
`)
	writeFile(t, dir, "permissions.yaml", `
policies:
- agent_group: all
permissions:
- user_group: admins
actions: ["*"]
- user_group: everyone
actions: ["ask"]
`)
	policy, err := shellsecurity.Load(dir)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(policy.UserGroups) != 2 {
		t.Errorf("expected 2 user groups, got %d", len(policy.UserGroups))
	}
	if len(policy.AgentGroups) != 2 {
		t.Errorf("expected 2 agent groups, got %d", len(policy.AgentGroups))
	}
	if len(policy.Policies) != 1 {
		t.Fatalf("expected 1 policy, got %d", len(policy.Policies))
	}
	if len(policy.Policies[0].Permissions) != 2 {
		t.Errorf("expected 2 permissions, got %d", len(policy.Policies[0].Permissions))
	}
}
// --- Test 3.4: only user-groups.yaml → user groups populated, the rest empty ---
//
// The user-group count check is fatal because the next assertion indexes
// UserGroups[0]; with a non-fatal check an empty result would panic instead
// of failing cleanly.
func TestLoad_OnlyUserGroups(t *testing.T) {
	dir := t.TempDir()
	writeFile(t, dir, "user-groups.yaml", `
groups:
admins:
members: ["@admin:example.com"]
`)
	policy, err := shellsecurity.Load(dir)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(policy.UserGroups) != 1 {
		t.Fatalf("expected 1 user group, got %d", len(policy.UserGroups))
	}
	if policy.UserGroups[0].Name != "admins" {
		t.Errorf("expected group name 'admins', got %q", policy.UserGroups[0].Name)
	}
	if len(policy.AgentGroups) != 0 {
		t.Errorf("expected no agent groups, got %d", len(policy.AgentGroups))
	}
	if len(policy.Policies) != 0 {
		t.Errorf("expected no policies, got %d", len(policy.Policies))
	}
}
// --- Test 3.5: malformed YAML → error whose message names the file ---
func TestLoad_MalformedYAML(t *testing.T) {
dir := t.TempDir()
writeFile(t, dir, "user-groups.yaml", `this: is: not: valid: yaml: [`)
_, err := shellsecurity.Load(dir)
if err == nil {
t.Fatal("expected error for malformed YAML, got nil")
}
if got := err.Error(); len(got) == 0 {
t.Fatal("error message is empty")
}
// Must mention the filename
if !containsString(err.Error(), "user-groups.yaml") {
t.Errorf("error message should contain filename, got: %s", err.Error())
}
}
// --- Test 3.6: "*" is kept as a literal string in members and agents ---
func TestLoad_WildcardStrings(t *testing.T) {
dir := t.TempDir()
writeFile(t, dir, "user-groups.yaml", `
groups:
everyone:
members: ["*"]
`)
writeFile(t, dir, "agent-groups.yaml", `
groups:
all:
agents: ["*"]
`)
policy, err := shellsecurity.Load(dir)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Length checks are fatal because the following assertions index element 0.
if len(policy.UserGroups) != 1 {
t.Fatalf("expected 1 user group, got %d", len(policy.UserGroups))
}
if len(policy.UserGroups[0].Members) != 1 || policy.UserGroups[0].Members[0] != "*" {
t.Errorf("expected members=[\"*\"], got %v", policy.UserGroups[0].Members)
}
if len(policy.AgentGroups) != 1 {
t.Fatalf("expected 1 agent group, got %d", len(policy.AgentGroups))
}
if len(policy.AgentGroups[0].Agents) != 1 || policy.AgentGroups[0].Agents[0] != "*" {
t.Errorf("expected agents=[\"*\"], got %v", policy.AgentGroups[0].Agents)
}
}
// containsString reports whether sub occurs anywhere within s.
func containsString(s, sub string) bool {
	if len(s) < len(sub) {
		return false
	}
	if s == sub {
		return true
	}
	return len(s) > 0 && containsSubstr(s, sub)
}
// containsSubstr reports whether sub occurs in s by comparing sub against
// every window of s that has the same length. A sub longer than s never
// matches; an empty sub matches any s.
func containsSubstr(s, sub string) bool {
	width := len(sub)
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == sub {
			return true
		}
	}
	return false
}
+110
View File
@@ -0,0 +1,110 @@
package skills
import (
"bytes"
"context"
"fmt"
"os/exec"
"path/filepath"
"strings"
"time"
)
// Executor runs skill scripts safely, restricting execution to an allowlist
// of interpreters and bounding each run with a timeout.
type Executor struct {
	allowedInterpreters []string
	timeout             time.Duration
}

// NewExecutor creates an Executor with the given configuration.
//
// An empty allowedInterpreters falls back to ["bash", "sh"]. A zero or
// negative timeout falls back to 60 seconds; a negative value would otherwise
// produce an already-expired context and make every run fail immediately.
// The allowlist is copied, so later changes to the caller's slice cannot
// alter what this Executor permits.
func NewExecutor(allowedInterpreters []string, timeout time.Duration) *Executor {
	if len(allowedInterpreters) == 0 {
		allowedInterpreters = []string{"bash", "sh"}
	}
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	return &Executor{
		allowedInterpreters: append([]string(nil), allowedInterpreters...),
		timeout:             timeout,
	}
}

// Run executes a skill script with the given arguments.
//
// scriptPath is the path to the script and args are passed to it unchanged.
// The interpreter is chosen from the file extension only (see
// inferInterpreter) and must be in the allowlist.
//
// It returns stdout followed by stderr (concatenated, not interleaved). The
// error is non-nil when the extension is unsupported, the interpreter is not
// allowed, the timeout elapses, or the script fails; in the last two cases
// the output captured so far is still returned.
//
// NOTE(review): output is captured through pipes, so a descendant process
// that inherits them may delay the return past the timeout — consider
// setting cmd.WaitDelay if that matters.
func (e *Executor) Run(ctx context.Context, scriptPath string, args []string) (string, error) {
	interpreter, err := e.inferInterpreter(scriptPath)
	if err != nil {
		return "", err
	}
	if !e.isAllowed(interpreter) {
		return "", fmt.Errorf("interpreter not allowed: %s (allowed: %v)", interpreter, e.allowedInterpreters)
	}

	// Bound the run by the executor's timeout on top of the caller's context.
	timeoutCtx, cancel := context.WithTimeout(ctx, e.timeout)
	defer cancel()

	cmdArgs := make([]string, 0, len(args)+1)
	cmdArgs = append(cmdArgs, scriptPath)
	cmdArgs = append(cmdArgs, args...)

	var stdout, stderr bytes.Buffer
	cmd := exec.CommandContext(timeoutCtx, interpreter, cmdArgs...)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	err = cmd.Run()
	output := stdout.String() + stderr.String()

	// Check the deadline first so a kill caused by the timeout is reported
	// as a timeout rather than a generic script failure.
	if timeoutCtx.Err() == context.DeadlineExceeded {
		return output, fmt.Errorf("script timeout exceeded (%s)", e.timeout)
	}
	if err != nil {
		return output, fmt.Errorf("script failed: %w", err)
	}
	return output, nil
}

// inferInterpreter maps the file extension of path (case-insensitive) to the
// interpreter that should run it. Unknown extensions are an error.
func (e *Executor) inferInterpreter(path string) (string, error) {
	ext := strings.ToLower(filepath.Ext(path))
	switch ext {
	case ".sh", ".bash":
		return "bash", nil
	case ".py":
		return "python3", nil
	case ".rb":
		return "ruby", nil
	case ".js":
		return "node", nil
	default:
		return "", fmt.Errorf("unsupported script extension: %s", ext)
	}
}

// isAllowed reports whether interpreter is in the allowlist (exact match).
func (e *Executor) isAllowed(interpreter string) bool {
	for _, allowed := range e.allowedInterpreters {
		if allowed == interpreter {
			return true
		}
	}
	return false
}
+127
View File
@@ -0,0 +1,127 @@
package skills
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestExecutor exercises Executor.Run against small scripts written to a
// temporary directory: a successful run, a run that exceeds the timeout, a
// failing script, an interpreter outside the allowlist, and a Python script
// run by an executor that allows python3.
//
// Where a subtest inspects err.Error(), the preceding nil check is fatal so
// that an unexpected success fails cleanly instead of dereferencing nil.
func TestExecutor(t *testing.T) {
	tmpDir := t.TempDir()

	// Create a simple bash script
	scriptPath := filepath.Join(tmpDir, "test.sh")
	scriptContent := `#!/bin/bash
echo "Hello from script"
echo "Args: $@"
`
	if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
		t.Fatal(err)
	}

	// Create a script that times out
	timeoutScriptPath := filepath.Join(tmpDir, "timeout.sh")
	timeoutContent := `#!/bin/bash
sleep 10
`
	if err := os.WriteFile(timeoutScriptPath, []byte(timeoutContent), 0755); err != nil {
		t.Fatal(err)
	}

	// Create a failing script
	failScriptPath := filepath.Join(tmpDir, "fail.sh")
	failContent := `#!/bin/bash
exit 1
`
	if err := os.WriteFile(failScriptPath, []byte(failContent), 0755); err != nil {
		t.Fatal(err)
	}

	executor := NewExecutor([]string{"bash", "sh"}, 2*time.Second)

	// Test successful execution
	t.Run("successful_execution", func(t *testing.T) {
		ctx := context.Background()
		output, err := executor.Run(ctx, scriptPath, []string{"arg1", "arg2"})
		if err != nil {
			t.Fatalf("Run failed: %v", err)
		}
		if !strings.Contains(output, "Hello from script") {
			t.Errorf("expected 'Hello from script' in output, got: %q", output)
		}
		if !strings.Contains(output, "Args: arg1 arg2") {
			t.Errorf("expected 'Args: arg1 arg2' in output, got: %q", output)
		}
	})

	// Test timeout
	t.Run("timeout", func(t *testing.T) {
		ctx := context.Background()
		_, err := executor.Run(ctx, timeoutScriptPath, nil)
		if err == nil {
			t.Fatal("expected timeout error")
		}
		if !strings.Contains(err.Error(), "timeout") {
			t.Errorf("expected timeout error, got: %v", err)
		}
	})

	// Test script failure
	t.Run("script_failure", func(t *testing.T) {
		ctx := context.Background()
		_, err := executor.Run(ctx, failScriptPath, nil)
		if err == nil {
			t.Error("expected script failure error")
		}
	})

	// Test disallowed interpreter
	t.Run("disallowed_interpreter", func(t *testing.T) {
		pyScriptPath := filepath.Join(tmpDir, "test.py")
		pyContent := `#!/usr/bin/env python3
print("hello")
`
		if err := os.WriteFile(pyScriptPath, []byte(pyContent), 0755); err != nil {
			t.Fatal(err)
		}
		ctx := context.Background()
		_, err := executor.Run(ctx, pyScriptPath, nil)
		if err == nil {
			t.Fatal("expected error for disallowed interpreter")
		}
		if !strings.Contains(err.Error(), "not allowed") {
			t.Errorf("expected 'not allowed' error, got: %v", err)
		}
	})

	// Test allowed python interpreter
	t.Run("allowed_python", func(t *testing.T) {
		pyExecutor := NewExecutor([]string{"python3"}, 2*time.Second)
		pyScriptPath := filepath.Join(tmpDir, "hello.py")
		pyContent := `#!/usr/bin/env python3
print("Hello from Python")
`
		if err := os.WriteFile(pyScriptPath, []byte(pyContent), 0755); err != nil {
			t.Fatal(err)
		}
		ctx := context.Background()
		output, err := pyExecutor.Run(ctx, pyScriptPath, nil)
		if err != nil {
			t.Fatalf("Run failed: %v", err)
		}
		if !strings.Contains(output, "Hello from Python") {
			t.Errorf("expected 'Hello from Python' in output, got: %q", output)
		}
	})
}
+223
View File
@@ -0,0 +1,223 @@
package skills
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/enmanuel/agents/pkg/skills"
"gopkg.in/yaml.v3"
)
// Loader discovers and loads skills from a base directory.
type Loader struct {
basePath string
}
// NewLoader creates a Loader rooted at the given skills directory.
// The path is stored as-is; nothing is read until a Load* method is called.
func NewLoader(basePath string) *Loader {
return &Loader{basePath: basePath}
}
// LoadMeta loads only the metadata (level 1) of every skill.
//
// It walks <basePath>/<category>/<skill>/SKILL.md, parses each file with
// parseSkillMD and sets the Category of the result to the category directory
// name. Non-directories, unreadable category directories, skill directories
// without a SKILL.md, and files that fail to parse are skipped silently.
//
// Returns an error only when the base directory itself cannot be read.
func (l *Loader) LoadMeta() ([]skills.SkillMeta, error) {
	categories, err := os.ReadDir(l.basePath)
	if err != nil {
		return nil, fmt.Errorf("read skills dir: %w", err)
	}

	var found []skills.SkillMeta
	for _, cat := range categories {
		if !cat.IsDir() {
			continue
		}
		catDir := filepath.Join(l.basePath, cat.Name())

		entries, err := os.ReadDir(catDir)
		if err != nil {
			continue
		}
		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			manifest := filepath.Join(catDir, entry.Name(), "SKILL.md")
			if _, statErr := os.Stat(manifest); os.IsNotExist(statErr) {
				continue
			}
			meta, _, parseErr := parseSkillMD(manifest)
			if parseErr != nil {
				continue // skip invalid skills
			}
			meta.Category = cat.Name()
			found = append(found, meta)
		}
	}
	return found, nil
}
// LoadSkill loads one complete skill (level 2) by name.
//
// Every category directory under the base path is probed for
// <category>/<name>/SKILL.md; the first one found wins. The returned Skill
// carries the parsed metadata (with Category set to the directory it was found
// in), the markdown instructions, the skill's base path and the file names
// found directly inside its scripts/, references/, templates/ and assets/
// subdirectories.
//
// Errors: the base directory cannot be listed, the SKILL.md found fails to
// parse, or no category contains the skill.
func (l *Loader) LoadSkill(name string) (*skills.Skill, error) {
	catEntries, err := os.ReadDir(l.basePath)
	if err != nil {
		return nil, fmt.Errorf("read skills dir: %w", err)
	}

	for _, cat := range catEntries {
		if !cat.IsDir() {
			continue
		}
		catName := cat.Name()
		skillDir := filepath.Join(l.basePath, catName, name)
		mdFile := filepath.Join(skillDir, "SKILL.md")
		if _, statErr := os.Stat(mdFile); os.IsNotExist(statErr) {
			continue
		}

		meta, body, parseErr := parseSkillMD(mdFile)
		if parseErr != nil {
			return nil, fmt.Errorf("parse %s: %w", mdFile, parseErr)
		}
		meta.Category = catName

		// resources lists the plain files of one resource subdirectory.
		resources := func(sub string) []string {
			return listFiles(filepath.Join(skillDir, sub))
		}
		loaded := &skills.Skill{
			Meta:         meta,
			Instructions: body,
			BasePath:     skillDir,
			Scripts:      resources("scripts"),
			References:   resources("references"),
			Templates:    resources("templates"),
			Assets:       resources("assets"),
		}
		return loaded, nil
	}

	return nil, fmt.Errorf("skill not found: %s", name)
}
// ReadResource reads one specific resource (level 3) of a skill.
//
// resourcePath is relative to the skill directory (for example
// "scripts/deploy.sh" or "references/api.md"). The skill is located with
// LoadSkill, so its errors (unknown skill, unparsable SKILL.md) are returned
// unchanged.
//
// The resolved path must stay inside the skill directory. Containment is
// decided on the path relative to the skill's base directory rather than with
// a string prefix test: a prefix test would accept a sibling directory whose
// name merely starts with the skill's name (".../test-skill-evil/x" has the
// prefix ".../test-skill"). Anything that resolves outside yields a
// "path traversal detected" error.
//
// NOTE(review): the check is purely lexical; a symlink inside the skill
// directory that points elsewhere is still followed by os.ReadFile.
func (l *Loader) ReadResource(skillName, resourcePath string) (string, error) {
	skill, err := l.LoadSkill(skillName)
	if err != nil {
		return "", err
	}

	absBasePath, err := filepath.Abs(skill.BasePath)
	if err != nil {
		return "", fmt.Errorf("abs base path: %w", err)
	}
	absFullPath, err := filepath.Abs(filepath.Join(skill.BasePath, resourcePath))
	if err != nil {
		return "", fmt.Errorf("abs resource path: %w", err)
	}

	// Both paths are absolute and cleaned, so the relative path escapes the
	// base directory exactly when it is ".." or starts with "../".
	rel, err := filepath.Rel(absBasePath, absFullPath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal detected: %s", resourcePath)
	}

	content, err := os.ReadFile(absFullPath)
	if err != nil {
		return "", fmt.Errorf("read resource: %w", err)
	}
	return string(content), nil
}
// parseSkillMD splits a SKILL.md file into its YAML frontmatter and its
// markdown body.
//
// The frontmatter is the run of lines between the first two lines that consist
// solely of "---" (surrounding whitespace ignored). Lines before the opening
// delimiter are discarded. Once the frontmatter has been closed, every
// remaining line belongs to the body verbatim, including later "---" lines:
// in markdown those are horizontal rules, so they must neither be dropped nor
// re-open the YAML section (which would feed body text to the YAML decoder).
// If the frontmatter is never closed the body is empty.
//
// Returns the decoded metadata and the body lines joined with "\n". Errors
// from opening or scanning the file are returned as-is; a YAML decoding
// failure is wrapped with "parse yaml".
func parseSkillMD(path string) (skills.SkillMeta, string, error) {
	f, err := os.Open(path)
	if err != nil {
		return skills.SkillMeta{}, "", err
	}
	defer f.Close()

	var (
		yamlLines  []string
		bodyLines  []string
		inYAML     bool // between the opening and closing delimiter
		yamlClosed bool // closing delimiter seen; the rest is body
	)

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case yamlClosed:
			// Checked first so delimiters inside the body stay in the body.
			bodyLines = append(bodyLines, line)
		case strings.TrimSpace(line) == "---":
			if inYAML {
				inYAML = false
				yamlClosed = true
			} else {
				inYAML = true
			}
		case inYAML:
			yamlLines = append(yamlLines, line)
		}
	}
	if err := scanner.Err(); err != nil {
		return skills.SkillMeta{}, "", err
	}

	// Decode the YAML frontmatter.
	var meta skills.SkillMeta
	if err := yaml.Unmarshal([]byte(strings.Join(yamlLines, "\n")), &meta); err != nil {
		return skills.SkillMeta{}, "", fmt.Errorf("parse yaml: %w", err)
	}

	return meta, strings.Join(bodyLines, "\n"), nil
}
// listFiles returns the names of the non-directory entries found directly
// inside dir (names only, no recursion). When dir cannot be read, for example
// because it does not exist, or when it holds no files, the result is nil.
func listFiles(dir string) []string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil
	}

	var names []string
	for i := range entries {
		if entries[i].IsDir() {
			continue
		}
		names = append(names, entries[i].Name())
	}
	return names
}
+131
View File
@@ -0,0 +1,131 @@
package skills
import (
"os"
"path/filepath"
"testing"
)
// TestLoader exercises the Loader against a throwaway skills tree:
//
//	<root>/devops/test-skill/SKILL.md
//	<root>/devops/test-skill/scripts/test.sh
//	<root>/devops/test-skill/references/api.md
//
// It covers metadata discovery, full skill loading, the not-found error,
// resource reading and the path traversal guard.
func TestLoader(t *testing.T) {
	root := t.TempDir()

	// Fixture helpers: any failure while building the tree aborts the test.
	mkdir := func(dir string) {
		if err := os.MkdirAll(dir, 0755); err != nil {
			t.Fatal(err)
		}
	}
	write := func(path, content string, mode os.FileMode) {
		if err := os.WriteFile(path, []byte(content), mode); err != nil {
			t.Fatal(err)
		}
	}

	// The skill itself, with its SKILL.md.
	skillRoot := filepath.Join(root, "devops", "test-skill")
	mkdir(skillRoot)
	write(filepath.Join(skillRoot, "SKILL.md"), `---
name: test-skill
description: A test skill for unit testing
---
# Test Skill
This is the instructions body.
It has multiple lines.
`, 0644)

	// One script and one reference document.
	scripts := filepath.Join(skillRoot, "scripts")
	mkdir(scripts)
	write(filepath.Join(scripts, "test.sh"), "#!/bin/bash\necho test", 0755)

	references := filepath.Join(skillRoot, "references")
	mkdir(references)
	write(filepath.Join(references, "api.md"), "# API Reference", 0644)

	loader := NewLoader(root)

	t.Run("LoadMeta", func(t *testing.T) {
		metas, err := loader.LoadMeta()
		if err != nil {
			t.Fatalf("LoadMeta failed: %v", err)
		}
		if len(metas) != 1 {
			t.Fatalf("expected 1 skill, got %d", len(metas))
		}
		got := metas[0]
		if got.Name != "test-skill" {
			t.Errorf("expected name 'test-skill', got %q", got.Name)
		}
		if got.Category != "devops" {
			t.Errorf("expected category 'devops', got %q", got.Category)
		}
		if got.Description != "A test skill for unit testing" {
			t.Errorf("expected description 'A test skill for unit testing', got %q", got.Description)
		}
	})

	t.Run("LoadSkill", func(t *testing.T) {
		skill, err := loader.LoadSkill("test-skill")
		if err != nil {
			t.Fatalf("LoadSkill failed: %v", err)
		}
		if skill.Meta.Name != "test-skill" {
			t.Errorf("expected name 'test-skill', got %q", skill.Meta.Name)
		}
		if skill.Instructions == "" {
			t.Error("instructions should not be empty")
		}
		if len(skill.Scripts) != 1 || skill.Scripts[0] != "test.sh" {
			t.Errorf("expected Scripts=['test.sh'], got %v", skill.Scripts)
		}
		if len(skill.References) != 1 || skill.References[0] != "api.md" {
			t.Errorf("expected References=['api.md'], got %v", skill.References)
		}
	})

	t.Run("LoadSkill_nonexistent", func(t *testing.T) {
		if _, err := loader.LoadSkill("nonexistent"); err == nil {
			t.Error("expected error for nonexistent skill")
		}
	})

	t.Run("ReadResource", func(t *testing.T) {
		content, err := loader.ReadResource("test-skill", "scripts/test.sh")
		if err != nil {
			t.Fatalf("ReadResource failed: %v", err)
		}
		if content != "#!/bin/bash\necho test" {
			t.Errorf("unexpected content: %q", content)
		}
	})

	// A resource path that climbs out of the skill directory must be refused.
	t.Run("ReadResource_path_traversal", func(t *testing.T) {
		if _, err := loader.ReadResource("test-skill", "../../../etc/passwd"); err == nil {
			t.Error("expected error for path traversal attempt")
		}
	})
}
+169
View File
@@ -0,0 +1,169 @@
// Package ssh provides impure SSH command execution.
package ssh
import (
"bytes"
"context"
"fmt"
"log/slog"
"net"
"os"
"time"
gossh "golang.org/x/crypto/ssh"
"github.com/enmanuel/agents/internal/config"
"github.com/enmanuel/agents/pkg/tools"
"github.com/enmanuel/agents/shell/logger"
)
// Result holds the output of an SSH command execution.
//
// When Err is non-nil the command could not be run to completion (unknown
// target, key/dial/session failure, context cancellation, or a run error that
// is not an exit status) and the other fields are left at their zero values.
type Result struct {
	Stdout   string // captured standard output of the remote command
	Stderr   string // captured standard error of the remote command
	ExitCode int    // remote exit status; 0 when the command ran without error
	Err      error  // failure that prevented obtaining an exit status, if any
}
// Executor runs SSH commands against configured targets.
type Executor struct {
	cfg    config.SSHCfg // targets plus the defaults (user, port, key env var, timeout)
	logger *slog.Logger  // logger scoped to the "ssh" component by NewExecutor
}
// NewExecutor builds an Executor from the SSH section of the configuration.
// Every log line it emits is tagged with the "ssh" component.
func NewExecutor(cfg config.SSHCfg, log *slog.Logger) *Executor {
	scoped := log.With(logger.FieldComponent, "ssh")
	return &Executor{
		cfg:    cfg,
		logger: scoped,
	}
}
// Execute runs the SSH command described by spec. Impure.
//
// The target named by spec.Target is looked up in the configuration and the
// command is run on its first host. User, port and key-file environment
// variable fall back from the target to the configured defaults (and the port
// finally to 22). Authentication is public-key only, using the private key
// file named by that environment variable.
//
// The returned Result carries stdout, stderr and the exit status when the
// command ran (a non-zero exit status is NOT reported through Err). Err is set
// for an unknown target, a target without hosts, key loading, dial or session
// failures, a run error that is not an exit status, and context cancellation
// (in which case Err is ctx.Err()).
//
// Every outcome is logged: ssh_exec_start, then one of ssh_exec_error,
// ssh_exec_cancelled or ssh_exec_end with the elapsed milliseconds.
func (e *Executor) Execute(ctx context.Context, spec tools.SSHCommandSpec) Result {
	// Log only a bounded preview of the command.
	// NOTE(review): the cut is at byte 80, so a multi-byte UTF-8 character can
	// be split in the logged preview.
	cmdPreview := spec.Command
	if len(cmdPreview) > 80 {
		cmdPreview = cmdPreview[:80] + "..."
	}
	e.logger.Info("ssh_exec_start", "target", spec.Target, "command", cmdPreview)
	start := time.Now()

	target, ok := e.cfg.Targets[spec.Target]
	if !ok {
		e.logger.Error("ssh_exec_error", "target", spec.Target, "err", "unknown target")
		return Result{Err: fmt.Errorf("unknown SSH target: %s", spec.Target)}
	}
	if len(target.Hosts) == 0 {
		e.logger.Error("ssh_exec_error", "target", spec.Target, "err", "no hosts")
		return Result{Err: fmt.Errorf("no hosts for target: %s", spec.Target)}
	}

	// Use first host (round-robin or load balancing can be added later)
	host := target.Hosts[0]

	// Per-target settings win; empty/zero values fall back to the defaults.
	user := target.User
	if user == "" {
		user = e.cfg.Defaults.User
	}
	port := target.Port
	if port == 0 {
		port = e.cfg.Defaults.Port
	}
	if port == 0 {
		port = 22 // standard SSH port when nothing is configured
	}
	keyEnv := target.KeyFileEnv
	if keyEnv == "" {
		keyEnv = e.cfg.Defaults.KeyFileEnv
	}

	signer, err := loadSigner(keyEnv)
	if err != nil {
		ms := time.Since(start).Milliseconds()
		e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
		return Result{Err: fmt.Errorf("load SSH key: %w", err)}
	}

	// SECURITY: the server's host key is not verified, so the connection is
	// open to man-in-the-middle attacks until known_hosts checking is added.
	sshCfg := &gossh.ClientConfig{
		User:            user,
		Auth:            []gossh.AuthMethod{gossh.PublicKeys(signer)},
		HostKeyCallback: gossh.InsecureIgnoreHostKey(), // TODO: use known_hosts
		Timeout:         e.cfg.Defaults.Timeout,
	}
	if sshCfg.Timeout == 0 {
		sshCfg.Timeout = 10 * time.Second
	}

	// NOTE(review): the dial is bounded by sshCfg.Timeout only; ctx is not
	// consulted until the command is running.
	addr := fmt.Sprintf("%s:%d", host, port)
	conn, err := gossh.Dial("tcp", addr, sshCfg)
	if err != nil {
		ms := time.Since(start).Milliseconds()
		e.logger.Error("ssh_exec_error", "target", spec.Target, "host", addr, logger.FieldDurationMS, ms, "err", err)
		return Result{Err: fmt.Errorf("ssh dial %s: %w", addr, err)}
	}
	defer conn.Close()

	session, err := conn.NewSession()
	if err != nil {
		ms := time.Since(start).Milliseconds()
		e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
		return Result{Err: fmt.Errorf("ssh session: %w", err)}
	}
	defer session.Close()

	var stdout, stderr bytes.Buffer
	session.Stdout = &stdout
	session.Stderr = &stderr

	// Run the command in a goroutine so cancellation can be observed while it
	// is in flight. The channel is buffered so the goroutine can always deliver
	// its result and exit, even after the cancellation branch has returned.
	done := make(chan error, 1)
	go func() { done <- session.Run(spec.Command) }()

	select {
	case <-ctx.Done():
		// Best effort: ask the remote process to terminate (the error from
		// Signal is ignored); the deferred Close calls then drop the session
		// and the connection.
		session.Signal(gossh.SIGTERM)
		ms := time.Since(start).Milliseconds()
		e.logger.Warn("ssh_exec_cancelled", "target", spec.Target, logger.FieldDurationMS, ms)
		return Result{Err: ctx.Err()}
	case err := <-done:
		ms := time.Since(start).Milliseconds()
		code := 0
		if err != nil {
			// A non-zero exit status is a normal result, not a failure of
			// Execute; any other error is reported through Err.
			var exitErr *gossh.ExitError
			if ok := asExitError(err, &exitErr); ok {
				code = exitErr.ExitStatus()
			} else {
				e.logger.Error("ssh_exec_error", "target", spec.Target, logger.FieldDurationMS, ms, "err", err)
				return Result{Err: err}
			}
		}
		e.logger.Info("ssh_exec_end", "target", spec.Target, "exit_code", code, logger.FieldDurationMS, ms)
		return Result{
			Stdout:   stdout.String(),
			Stderr:   stderr.String(),
			ExitCode: code,
		}
	}
}
// loadSigner builds an SSH signer from the private key file whose path is held
// in the environment variable named keyFileEnv.
//
// It fails when the variable is unset or empty, when the file cannot be read
// (the read error is returned unchanged), or when the key cannot be parsed.
func loadSigner(keyFileEnv string) (gossh.Signer, error) {
	path := os.Getenv(keyFileEnv)
	if path == "" {
		return nil, fmt.Errorf("env var %s not set", keyFileEnv)
	}

	pemBytes, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}

	signer, parseErr := gossh.ParsePrivateKey(pemBytes)
	return signer, parseErr
}
// asExitError reports whether err is a *gossh.ExitError and, if so, stores it
// in *target. Only a direct type match counts; wrapped errors are not
// unwrapped. *target is left untouched when the result is false.
func asExitError(err error, target **gossh.ExitError) bool {
	if exitErr, isExit := err.(*gossh.ExitError); isExit {
		*target = exitErr
		return true
	}
	return false
}
// Keep the net import referenced even though this assignment is the only use
// of the package here; the import is retained for future jump-host support.
var _ = net.Dial
+35
View File
@@ -0,0 +1,35 @@
package transportunibus
import (
"context"
"log/slog"
"strings"
"github.com/enmanuel/agents/pkg/transport"
)
// DemoEchoHandler builds the smallest possible bot handler, meant to prove the
// unibus transport end to end rather than to be a useful bot.
//
// For every inbound message it replies in the same room, anchored to the
// incoming message and thread: the body is echoed back prefixed with
// "echo: ", except that a body equal to "!ping" (ignoring surrounding
// whitespace) is answered with "pong". A failed reply is logged, not returned.
// A nil logger falls back to slog.Default().
func DemoEchoHandler(t transport.Transport, logger *slog.Logger) transport.Handler {
	log := logger
	if log == nil {
		log = slog.Default()
	}

	return func(ctx context.Context, in transport.InboundMessage) {
		var text string
		switch strings.TrimSpace(in.Body) {
		case "!ping":
			text = "pong"
		default:
			text = "echo: " + in.Body
		}

		err := t.Reply(ctx, transport.OutboundReply{
			RoomID:   in.RoomID,
			ReplyTo:  in.MsgID,
			ThreadID: in.ThreadID,
			Markdown: text,
		})
		if err == nil {
			return
		}
		log.Error("demo echo reply failed", "err", err, "sender", in.SenderID)
	}
}
+314
View File
@@ -0,0 +1,314 @@
// Package transportunibus implements transport.Transport over the unibus message
// bus (github.com/enmanuel/unibus). A bot built on the neutral
// transport.Transport speaks unibus instead of Matrix: it discovers the rooms it
// has been invited to, joins them, and replies in the room a message arrived on.
//
// Room-based model ("everything is a room"):
//
// - There is no inbox/outbox subject convention. A conversation is a unibus
// room; a 1:1 DM is just a room with two members. A human peer creates an
// encrypted room (room.ModeMatrix), invites the bot by its endpoint id, and
// publishes a message. The bot finds the room by polling ListMyRooms,
// Joins (fetching the sealed room key), Subscribes, and answers in place.
// - The control plane is pull-based: there is no server push of invitations,
// so the bot polls ListMyRooms on a ticker and reacts to rooms it has not
// seen before.
//
// This adapter carries no Matrix (mautrix) types, so the agent core driving it
// stays transport-neutral.
package transportunibus
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/enmanuel/agents/internal/config"
"github.com/enmanuel/agents/pkg/message"
"github.com/enmanuel/agents/pkg/transport"
"github.com/enmanuel/unibus/pkg/client"
"github.com/enmanuel/unibus/pkg/frame"
)
const (
	// defaultCommandPrefix is the marker that introduces a command message
	// (e.g. "!ping") when the bot's configuration does not set its own.
	defaultCommandPrefix = "!"

	// discoveryInterval is the period between two polls of the control plane
	// for rooms the bot has been invited to. The control plane has no push, so
	// this is the latency a human waits between inviting the bot and the bot
	// joining.
	discoveryInterval = 2 * time.Second
)
// Transport is a unibus-backed transport.Transport for one bot. It discovers
// rooms, subscribes to them, and replies in the room each message came from.
type Transport struct {
	handle        string         // bot handle searched for (case-insensitively) in bodies to flag mentions
	commandPrefix string         // prefix handed to the command parser, e.g. "!"
	client        *client.Client // unibus client used to list, join, subscribe and publish
	endpoint      string         // this bot's own endpoint id, to skip its own messages
	ctrlURL       string         // control-plane base URL, used for the member-count lookup
	http          *http.Client   // HTTP client for direct control-plane requests
	logger        *slog.Logger

	// mu guards the two maps below.
	mu          sync.Mutex
	subscribed  map[string]*client.Sub // roomID -> active subscription
	memberCount map[string]int         // roomID -> cached member count (for IsDirectMsg)
}

// compile-time assertion that Transport satisfies the neutral interface.
var _ transport.Transport = (*Transport)(nil)
// New connects a bot to a unibus deployment described by busCfg.
//
// It loads (or creates) the bot's long-term identity from busCfg.IdentityPath,
// opens the unibus client against the NATS data plane and the membershipd
// control plane, and records the handle used for mention detection. No room is
// created or joined here: rooms are discovered at Run time as the bot gets
// invited to them.
//
// An empty busCfg.CommandPrefix falls back to defaultCommandPrefix and a nil
// logger to slog.Default(). Identity and connection failures are returned
// wrapped with a "transportunibus:" prefix.
func New(busCfg config.BusCfg, logger *slog.Logger) (*Transport, error) {
	identity, err := client.LoadOrCreateIdentity(busCfg.IdentityPath)
	if err != nil {
		return nil, fmt.Errorf("transportunibus: identity: %w", err)
	}
	busClient, err := client.New(busCfg.NatsURL, busCfg.CtrlURL, identity)
	if err != nil {
		return nil, fmt.Errorf("transportunibus: connect: %w", err)
	}

	tr := &Transport{
		handle:        busCfg.Handle,
		commandPrefix: busCfg.CommandPrefix,
		client:        busClient,
		endpoint:      busClient.Endpoint().ID,
		ctrlURL:       busCfg.CtrlURL,
		http:          &http.Client{Timeout: 10 * time.Second},
		logger:        logger,
		subscribed:    make(map[string]*client.Sub),
		memberCount:   make(map[string]int),
	}
	// Fill in the defaults for the optional settings.
	if tr.logger == nil {
		tr.logger = slog.Default()
	}
	if tr.commandPrefix == "" {
		tr.commandPrefix = defaultCommandPrefix
	}
	return tr, nil
}
// Endpoint returns this bot's public endpoint id, the value a human peer needs
// in order to invite the bot to a room. The bot logs it at startup; a
// directory is a later step.
func (t *Transport) Endpoint() string {
	id := t.endpoint
	return id
}
// BusEndpoint returns this bot's full public endpoint as reported by the
// unibus client (id plus signing/key-exchange public keys). A peer inviting
// the bot to an encrypted room needs those keys to seal the room key for it.
func (t *Transport) BusEndpoint() client.Endpoint {
	full := t.client.Endpoint()
	return full
}
// Run drives room discovery until ctx is cancelled.
//
// It runs one discovery pass immediately and then one per discoveryInterval
// tick; each pass joins and subscribes to rooms not seen before, and every
// decrypted frame from those rooms reaches handler as a neutral
// InboundMessage. Run blocks, returns ctx.Err() once ctx is done, and drops
// all room subscriptions on the way out.
func (t *Transport) Run(ctx context.Context, handler transport.Handler) error {
	t.logger.Info("unibus transport running", "handle", t.handle, "endpoint", t.endpoint)

	poll := time.NewTicker(discoveryInterval)
	defer poll.Stop()
	defer t.unsubscribeAll()

	for {
		// The first iteration discovers right away, so startup does not wait
		// a full interval; later iterations run after each tick.
		t.discover(ctx, handler)

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-poll.C:
		}
	}
}
// discover performs one discovery pass: it lists the bot's rooms and, for each
// room without an active subscription, joins it and subscribes with a callback
// that forwards frames to onFrame. Failures are logged and the room is left
// unrecorded, so the next pass tries it again.
func (t *Transport) discover(ctx context.Context, handler transport.Handler) {
	rooms, err := t.client.ListMyRooms()
	if err != nil {
		t.logger.Warn("unibus discover: list rooms failed", "err", err)
		return
	}

	for _, r := range rooms {
		id := r.RoomID

		t.mu.Lock()
		_, known := t.subscribed[id]
		t.mu.Unlock()
		if known {
			continue
		}

		if joinErr := t.client.Join(id); joinErr != nil {
			t.logger.Warn("unibus discover: join failed", "room", id, "err", joinErr)
			continue
		}

		// The callback closes over this iteration's room id.
		deliver := func(f frame.Frame, plaintext []byte) {
			t.onFrame(ctx, handler, id, f, plaintext)
		}
		sub, subErr := t.client.Subscribe(id, deliver)
		if subErr != nil {
			t.logger.Warn("unibus discover: subscribe failed", "room", id, "err", subErr)
			continue
		}

		t.mu.Lock()
		t.subscribed[id] = sub
		t.mu.Unlock()
		t.logger.Info("joined and subscribed to room", "room", id, "subject", r.Subject)
	}
}
// onFrame turns one decrypted frame into a neutral InboundMessage and hands it
// to handler.
//
// Frames sent by the bot itself are dropped so it never answers its own
// messages. For everything else it derives IsDirectMsg (the room has exactly
// two members), IsMention (the handle occurs in the body, case-insensitively;
// never true for an empty handle) and the command/args split produced by the
// shared message parser with the configured command prefix, so the agent
// core's command/LLM flow behaves as it did on Matrix.
func (t *Transport) onFrame(ctx context.Context, handler transport.Handler, roomID string, f frame.Frame, plaintext []byte) {
	if f.Sender == t.endpoint {
		return // our own message
	}

	text := string(plaintext)
	direct := t.roomMemberCount(roomID) == 2

	mentioned := false
	if t.handle != "" {
		mentioned = strings.Contains(strings.ToLower(text), strings.ToLower(t.handle))
	}

	// Same pure parser the Matrix listener used, so "!cmd args" splits alike.
	opts := message.ParseOptions{CommandPrefix: t.commandPrefix}
	parsed := message.Parse(text, f.Sender, roomID, 0, direct, opts)

	inbound := transport.InboundMessage{
		RoomID:      roomID,
		Subject:     f.Subject,
		SenderID:    f.Sender,
		MsgID:       f.MsgID,
		ThreadID:    f.ThreadID,
		ReplyTo:     f.ReplyTo,
		Body:        text,
		Command:     parsed.Command,
		Args:        parsed.Args,
		IsDirectMsg: direct,
		IsMention:   mentioned,
	}
	handler(ctx, inbound)
}
// Reply publishes out.Markdown into out.RoomID. When the reply carries a
// ReplyTo or ThreadID anchor it goes out as a threaded reply so receivers can
// render the conversation tree; with neither anchor it is a plain publish.
// The context argument is not used.
func (t *Transport) Reply(_ context.Context, out transport.OutboundReply) error {
	payload := []byte(out.Markdown)
	if out.ReplyTo == "" && out.ThreadID == "" {
		return t.client.Publish(out.RoomID, payload)
	}
	return t.client.PublishReply(out.RoomID, payload, out.ReplyTo, out.ThreadID)
}
// Send publishes markdown into roomID as a standalone (unanchored) message.
// The context argument is not used.
func (t *Transport) Send(_ context.Context, roomID, markdown string) error {
	payload := []byte(markdown)
	return t.client.Publish(roomID, payload)
}
// Close drops every room subscription and then closes the unibus client,
// returning the client's close error.
func (t *Transport) Close() error {
	t.unsubscribeAll()
	err := t.client.Close()
	return err
}
// Sender wraps this transport in an adapter exposing the
// SendText/SendMarkdown/... method set expected by the agent's effects runner,
// scheduler and bus_send tool, so they can publish into rooms over unibus.
func (t *Transport) Sender() *busSender {
	adapter := new(busSender)
	adapter.t = t
	return adapter
}
// unsubscribeAll cancels every active room subscription. The subscription map
// is swapped for an empty one under the lock and the old entries are
// unsubscribed afterwards, outside the lock; failures are only logged.
func (t *Transport) unsubscribeAll() {
	fresh := map[string]*client.Sub{}

	t.mu.Lock()
	active := t.subscribed
	t.subscribed = fresh
	t.mu.Unlock()

	for roomID, sub := range active {
		err := sub.Unsubscribe()
		if err == nil {
			continue
		}
		t.logger.Warn("unibus: unsubscribe failed", "room", roomID, "err", err)
	}
}
// roomMemberCount reports how many members a room has; onFrame uses it to
// decide IsDirectMsg. The first successful lookup for a room is cached and
// reused from then on, since membership rarely changes during a conversation.
// A failed lookup is logged, is not cached, and yields 0, which callers treat
// as "not a DM" (a mention can still drive the LLM).
func (t *Transport) roomMemberCount(roomID string) int {
	t.mu.Lock()
	cached, hit := t.memberCount[roomID]
	t.mu.Unlock()
	if hit {
		return cached
	}

	// Fetch outside the lock: this is a network round trip.
	count, err := t.fetchMemberCount(roomID)
	if err != nil {
		t.logger.Warn("unibus: member count fetch failed", "room", roomID, "err", err)
		return 0 // unknown → treat as not-a-DM (mention still drives the LLM)
	}

	t.mu.Lock()
	t.memberCount[roomID] = count
	t.mu.Unlock()
	return count
}
// memberJSON mirrors one element of the array returned by the membership
// server's GET /rooms/{id}/members. Only the number of elements is used by
// fetchMemberCount; the decoded Endpoint value itself is never read.
type memberJSON struct {
	Endpoint string `json:"endpoint"`
}
// fetchMemberCount asks the membershipd control plane how many members a room
// has by issuing GET <ctrlURL>/rooms/<roomID>/members and counting the
// elements of the JSON array in the response. unibus's client does not expose
// this and unibus must not be modified, so the minimal HTTP call lives here.
//
// Errors: the request fails, the status code is 300 or above, or the body is
// not a decodable member array.
func (t *Transport) fetchMemberCount(roomID string) (int, error) {
	base := strings.TrimRight(t.ctrlURL, "/")
	resp, err := t.http.Get(base + "/rooms/" + roomID + "/members")
	if err != nil {
		return 0, fmt.Errorf("get members: %w", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 300 {
		return 0, fmt.Errorf("get members: status %d", code)
	}

	var list []memberJSON
	decoder := json.NewDecoder(resp.Body)
	if decodeErr := decoder.Decode(&list); decodeErr != nil {
		return 0, fmt.Errorf("decode members: %w", decodeErr)
	}
	return len(list), nil
}
// busSender adapts a *Transport to the effects.Sender / cron.Sender / tools
// Sender method set (SendText, SendMarkdown, SendReplyMarkdown,
// SendThreadMarkdown, SendTyping). Every send publishes into the given room
// through the transport's unibus client; none of the methods uses its context
// argument.
type busSender struct {
	t *Transport
}

// SendText publishes text into roomID as a standalone message.
func (s *busSender) SendText(_ context.Context, roomID, text string) error {
	payload := []byte(text)
	return s.t.client.Publish(roomID, payload)
}

// SendMarkdown publishes markdown into roomID as a standalone message.
func (s *busSender) SendMarkdown(_ context.Context, roomID, markdown string) error {
	payload := []byte(markdown)
	return s.t.client.Publish(roomID, payload)
}

// SendReplyMarkdown publishes markdown into roomID as a reply to inReplyTo,
// without a thread anchor.
func (s *busSender) SendReplyMarkdown(_ context.Context, roomID, inReplyTo, markdown string) error {
	const noThread = ""
	payload := []byte(markdown)
	return s.t.client.PublishReply(roomID, payload, inReplyTo, noThread)
}

// SendThreadMarkdown publishes markdown into roomID as a reply to inReplyTo
// inside the thread rooted at threadRootID.
func (s *busSender) SendThreadMarkdown(_ context.Context, roomID, threadRootID, inReplyTo, markdown string) error {
	payload := []byte(markdown)
	return s.t.client.PublishReply(roomID, payload, inReplyTo, threadRootID)
}

// SendTyping does nothing and always succeeds: unibus has no typing indicator.
func (s *busSender) SendTyping(_ context.Context, _ string, _ bool) error {
	return nil
}
+243
View File
@@ -0,0 +1,243 @@
package transportunibus_test
import (
"context"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"time"
"github.com/enmanuel/agents/internal/config"
"github.com/enmanuel/agents/pkg/transport"
"github.com/enmanuel/agents/shell/transportunibus"
cs "fn-registry/functions/cybersecurity"
"github.com/enmanuel/unibus/pkg/blobstore"
"github.com/enmanuel/unibus/pkg/client"
"github.com/enmanuel/unibus/pkg/embeddednats"
"github.com/enmanuel/unibus/pkg/frame"
"github.com/enmanuel/unibus/pkg/membership"
"github.com/enmanuel/unibus/pkg/room"
server "github.com/nats-io/nats-server/v2/server"
)
// harness boots an embedded NATS + an in-process membershipd, mirroring the
// unibus test harness so this adapter can be exercised without any external
// service.
type harness struct {
	natsURL string           // client URL of the embedded NATS server
	ctrlURL string           // URL of the in-process membership HTTP server
	ns      *server.Server   // embedded NATS server
	httpts  *httptest.Server // test HTTP server serving the membership handler
}
// freePort asks the OS for an unused TCP port on 127.0.0.1 by listening on
// port 0, reads back the port that was assigned and releases the listener
// before returning. The test fails immediately if the listen fails.
func freePort(t *testing.T) int {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("free port: %v", err)
	}
	defer listener.Close()

	tcpAddr := listener.Addr().(*net.TCPAddr)
	return tcpAddr.Port
}
// newHarness starts the services a transport test needs, all under the test's
// temp directory: an embedded NATS server on a free loopback port, a
// membership store, a blob store and an HTTP test server exposing the
// membership control plane. It returns their URLs and handles.
//
// Teardown is registered with t.Cleanup right after each resource is created,
// so a failure part-way through (for example the blob store) still releases
// everything started before it. Cleanups run in reverse registration order,
// giving the teardown sequence HTTP server, store, then NATS shutdown + wait.
func newHarness(t *testing.T) *harness {
	t.Helper()
	dir := t.TempDir()

	ns, err := embeddednats.StartHost(filepath.Join(dir, "js"), "127.0.0.1", freePort(t))
	if err != nil {
		t.Fatalf("embedded nats: %v", err)
	}
	t.Cleanup(func() {
		ns.Shutdown()
		ns.WaitForShutdown()
	})

	store, err := membership.Open(filepath.Join(dir, "unibus.db"))
	if err != nil {
		t.Fatalf("membership store: %v", err)
	}
	t.Cleanup(func() { store.Close() })

	blobs, err := blobstore.New(filepath.Join(dir, "blobs"))
	if err != nil {
		t.Fatalf("blob store: %v", err)
	}

	httpts := httptest.NewServer(membership.NewServer(store, blobs))
	t.Cleanup(httpts.Close)

	return &harness{
		natsURL: embeddednats.ClientURL(ns),
		ctrlURL: httpts.URL,
		ns:      ns,
		httpts:  httpts,
	}
}
// waitHealth polls GET <ctrlURL>/healthz every 50ms until it answers 200 and
// fails the test if that has not happened once the 3 second deadline has
// passed.
func waitHealth(t *testing.T, ctrlURL string) {
	t.Helper()

	const pause = 50 * time.Millisecond
	for stop := time.Now().Add(3 * time.Second); time.Now().Before(stop); time.Sleep(pause) {
		resp, err := http.Get(ctrlURL + "/healthz")
		if resp != nil {
			resp.Body.Close()
		}
		if err == nil && resp.StatusCode == 200 {
			return
		}
	}
	t.Fatalf("membershipd never became healthy")
}
// botCfg returns a BusCfg that points a bot at the harness's NATS and control
// plane URLs, uses handle as the bot's handle, and stores the identity in a
// "<handle>.id" file inside a fresh temp directory of the test.
func botCfg(t *testing.T, h *harness, handle string) config.BusCfg {
	t.Helper()

	var cfg config.BusCfg
	cfg.NatsURL = h.natsURL
	cfg.CtrlURL = h.ctrlURL
	cfg.IdentityPath = filepath.Join(t.TempDir(), handle+".id")
	cfg.Handle = handle
	return cfg
}
// TestBotEchoesInEncryptedRoom is the headline room-based test: a human peer
// creates an encrypted (room.ModeMatrix) room, invites the bot by its endpoint,
// and publishes a mention. The bot — driven by Transport.Run + a tiny echo
// handler that replies via Reply — answers IN THE SAME room, and the human
// receives the reply decrypted. No Matrix is involved end to end.
//
// Assertions: the echo "echo: hola demo" arrives, at least one received frame
// carries a ReplyTo anchor, and "!ping" is answered with "pong".
func TestBotEchoesInEncryptedRoom(t *testing.T) {
	h := newHarness(t)
	waitHealth(t, h.ctrlURL)

	bot, err := transportunibus.New(botCfg(t, h, "demo"), nil)
	if err != nil {
		t.Fatalf("bot transport: %v", err)
	}
	defer bot.Close()

	// Run the bot in the background; cancelling ctx at the end of the test
	// stops it. Run's error is deliberately discarded here.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { _ = bot.Run(ctx, transportunibus.DemoEchoHandler(bot, nil)) }()

	// Human peer.
	userID, err := cs.GenerateIdentity()
	if err != nil {
		t.Fatalf("user identity: %v", err)
	}
	user, err := client.New(h.natsURL, h.ctrlURL, userID)
	if err != nil {
		t.Fatalf("user client: %v", err)
	}
	defer user.Close()

	// Human creates an encrypted room and invites the bot by its endpoint id.
	roomID, err := user.CreateRoom("conv.demo", room.ModeMatrix)
	if err != nil {
		t.Fatalf("create room: %v", err)
	}
	// Invite the bot by its full endpoint (id + public keys), so the human can
	// seal the encrypted room key for it.
	if err := user.Invite(roomID, bot.BusEndpoint()); err != nil {
		t.Fatalf("invite bot: %v", err)
	}

	// Human subscribes to the same room to receive the bot's reply.
	// mu guards bodies and sawAnchored, which the subscription callback
	// writes while the test goroutine reads them.
	var mu sync.Mutex
	var bodies []string
	var sawAnchored bool
	if err := user.Join(roomID); err != nil {
		t.Fatalf("user join: %v", err)
	}
	sub, err := user.Subscribe(roomID, func(f frame.Frame, plaintext []byte) {
		mu.Lock()
		bodies = append(bodies, string(plaintext))
		if f.ReplyTo != "" {
			sawAnchored = true
		}
		mu.Unlock()
	})
	if err != nil {
		t.Fatalf("user subscribe: %v", err)
	}
	defer sub.Unsubscribe()

	// Give the bot's discovery ticker time to find, join and subscribe to the room.
	// NOTE(review): the transport's discoveryInterval is 2s, so 300ms only
	// covers the discovery pass Run performs at startup. If that pass ran
	// before the invite, the bot joins on the first tick instead, and the test
	// then presumably relies on the subscription delivering the earlier
	// message within waitBody's 5s budget — confirm against unibus semantics.
	time.Sleep(300 * time.Millisecond)

	// Human posts a message mentioning the bot's handle.
	if err := user.Publish(roomID, []byte("hola demo")); err != nil {
		t.Fatalf("user publish: %v", err)
	}

	if _, ok := waitBody(&mu, &bodies, "echo: hola demo", 5*time.Second); !ok {
		t.Fatalf("never received echo reply; got %v", snapshot(&mu, &bodies))
	}
	// Read the flag under the lock, then assert outside it.
	mu.Lock()
	anchored := sawAnchored
	mu.Unlock()
	if !anchored {
		t.Fatalf("reply did not carry a ReplyTo anchor")
	}

	// Command over the bus → pong, in the same room.
	if err := user.Publish(roomID, []byte("!ping")); err != nil {
		t.Fatalf("user publish ping: %v", err)
	}
	if _, ok := waitBody(&mu, &bodies, "pong", 5*time.Second); !ok {
		t.Fatalf("never received pong; got %v", snapshot(&mu, &bodies))
	}
}
// TestRunStopsOnContextCancel covers the lifecycle path: once its context is
// cancelled, Run must return context.Canceled within three seconds instead of
// blocking forever.
func TestRunStopsOnContextCancel(t *testing.T) {
	h := newHarness(t)
	waitHealth(t, h.ctrlURL)

	bot, err := transportunibus.New(botCfg(t, h, "lifecycle"), nil)
	if err != nil {
		t.Fatalf("bot transport: %v", err)
	}
	defer bot.Close()

	ctx, cancel := context.WithCancel(context.Background())
	ignore := func(context.Context, transport.InboundMessage) {}

	// Buffered so the goroutine can finish even if the test already timed out.
	runErr := make(chan error, 1)
	go func() {
		runErr <- bot.Run(ctx, ignore)
	}()

	// Let Run get going before pulling the plug.
	time.Sleep(100 * time.Millisecond)
	cancel()

	giveUp := time.After(3 * time.Second)
	select {
	case got := <-runErr:
		if got != context.Canceled {
			t.Fatalf("Run returned %v, want context.Canceled", got)
		}
	case <-giveUp:
		t.Fatalf("Run did not return after context cancel")
	}
}
// ---- helpers ----
// waitBody polls *slice (under mu) every 25ms until it contains an element
// equal to want or timeout has elapsed. It returns (want, true) on a match and
// ("", false) on timeout; with a non-positive timeout it does not look at all.
func waitBody(mu *sync.Mutex, slice *[]string, want string, timeout time.Duration) (string, bool) {
	// present takes one locked look at the current contents.
	present := func() bool {
		mu.Lock()
		defer mu.Unlock()
		for i := range *slice {
			if (*slice)[i] == want {
				return true
			}
		}
		return false
	}

	const pause = 25 * time.Millisecond
	for stop := time.Now().Add(timeout); time.Now().Before(stop); time.Sleep(pause) {
		if present() {
			return want, true
		}
	}
	return "", false
}
// snapshot returns a copy of *slice taken while holding mu, so the caller can
// read it without racing writers. An empty source yields a nil slice.
func snapshot(mu *sync.Mutex, slice *[]string) []string {
	mu.Lock()
	defer mu.Unlock()

	var out []string
	for _, s := range *slice {
		out = append(out, s)
	}
	return out
}
+419
View File
@@ -0,0 +1,419 @@
// Package tui is the impure shell layer for the TUI.
// It converts pure Intent values into real I/O via tea.Cmd.
package tui
import (
"fmt"
"os"
"os/exec"
"strings"
"syscall"
"time"
tea "github.com/charmbracelet/bubbletea"
puretui "github.com/enmanuel/agents/pkg/tui"
"github.com/enmanuel/agents/shell/process"
)
// Adapter bridges pure Intents with the process Manager.
type Adapter struct {
	mgr *process.Manager // process manager the intent commands call into
}
// NewAdapter returns an Adapter whose commands operate through mgr.
func NewAdapter(mgr *process.Manager) *Adapter {
	adapter := new(Adapter)
	adapter.mgr = mgr
	return adapter
}
// RunIntent maps a pure Intent onto the bubbletea Cmd that carries out its
// I/O. Agent-scoped intents act on intent.AgentID. IntentRunTests and
// IntentRunGoTests both run the Go test suite, IntentQuit yields tea.Quit,
// and an unrecognised kind yields a nil Cmd.
func (a *Adapter) RunIntent(intent puretui.Intent) tea.Cmd {
	switch id := intent.AgentID; intent.Kind {
	// Agents.
	case puretui.IntentLoadAgents:
		return a.loadAgents()
	case puretui.IntentEnableAgent:
		return a.enableAgent(id)
	case puretui.IntentDisableAgent:
		return a.disableAgent(id)
	case puretui.IntentReloadAgent:
		return a.reloadAgent(id)
	case puretui.IntentReloadAll:
		return a.reloadAll()
	case puretui.IntentRestartAgent:
		return a.restartAgent(id)
	case puretui.IntentLoadLogs:
		return a.loadLogs(id)

	// Launcher lifecycle.
	case puretui.IntentStartLauncher:
		return a.startLauncher()
	case puretui.IntentStopLauncher:
		return a.stopLauncher()
	case puretui.IntentRestartLauncher:
		return a.restartLauncher()
	case puretui.IntentKillLauncher:
		return a.killLauncher()
	case puretui.IntentRebuildRestart:
		return a.rebuildRestart()

	// Test runs.
	case puretui.IntentRunTests, puretui.IntentRunGoTests:
		return a.runGoTests()
	case puretui.IntentRunE2ETests:
		return a.runE2ETests(false)
	case puretui.IntentRunE2EHeadTests:
		return a.runE2ETests(true)
	case puretui.IntentRunAllTests:
		return a.runAllTests()

	// Housekeeping.
	case puretui.IntentTick:
		return a.tick()
	case puretui.IntentQuit:
		return tea.Quit
	}
	return nil
}
// loadAgents returns a Cmd that reads the status of every agent from the
// manager and reports it as a MsgAgentsLoaded, together with whether the
// launcher is running and its PID. While the launcher is running and its stats
// can be read, the message also carries formatted uptime, memory, CPU and log
// size. If the agent statuses cannot be read, an empty MsgAgentsLoaded is
// returned and the error is not surfaced.
func (a *Adapter) loadAgents() tea.Cmd {
	return func() tea.Msg {
		statuses, err := a.mgr.StatusAllUnified()
		if err != nil {
			return puretui.MsgAgentsLoaded{}
		}

		loaded := puretui.MsgAgentsLoaded{
			Agents: make([]puretui.AgentView, 0, len(statuses)),
		}
		for _, st := range statuses {
			loaded.Agents = append(loaded.Agents, puretui.AgentView{
				ID:      st.ID,
				Name:    st.Name,
				Version: st.Version,
				Desc:    st.Desc,
				Enabled: st.Enabled,
				Running: st.Running,
				PID:     st.PID,
			})
		}
		loaded.LauncherRunning = a.mgr.IsUnifiedRunning()
		loaded.LauncherPID = a.mgr.UnifiedPID()

		// Launcher stats only make sense while it is running.
		if !loaded.LauncherRunning {
			return loaded
		}
		stats, statsErr := a.mgr.UnifiedStats()
		if statsErr != nil {
			return loaded
		}
		loaded.LauncherUptime = formatUptime(stats.UptimeSecs)
		loaded.LauncherMemory = formatBytes(stats.MemRSSKB * 1024)
		loaded.LauncherCPU = fmt.Sprintf("%.1f%%", stats.CPUPct)
		loaded.LauncherLogSize = formatBytes(stats.LogBytes)
		return loaded
	}
}
// enableAgent returns a command that marks the agent as enabled through the
// Manager and reports the outcome as an "Enable" action result.
func (a *Adapter) enableAgent(id string) tea.Cmd {
	return func() tea.Msg {
		return puretui.MsgActionDone{
			AgentID: id,
			Action:  "Enable",
			Err:     a.mgr.ToggleEnabled(id, true),
		}
	}
}
// disableAgent returns a command that marks the agent as disabled through the
// Manager and reports the outcome as a "Disable" action result.
func (a *Adapter) disableAgent(id string) tea.Cmd {
	return func() tea.Msg {
		return puretui.MsgActionDone{
			AgentID: id,
			Action:  "Disable",
			Err:     a.mgr.ToggleEnabled(id, false),
		}
	}
}
// reloadAgent returns a command that asks the running launcher to reload one
// agent without stopping the launcher itself.
//
// A non-empty id is written to run/reload.txt (relative to the working
// directory) and the launcher process is then sent SIGHUP. The command fails
// if the launcher has no live PID, if the file cannot be written, or if the
// signal cannot be delivered. After a successful signal it waits one second
// before reporting success.
func (a *Adapter) reloadAgent(id string) tea.Cmd {
	// result builds the action message for this agent with the given outcome.
	result := func(err error) tea.Msg {
		return puretui.MsgActionDone{AgentID: id, Action: "Reload", Err: err}
	}

	return func() tea.Msg {
		launcherPID := a.mgr.UnifiedPID()
		if launcherPID <= 0 {
			return result(fmt.Errorf("el launcher no está corriendo"))
		}

		// Name the target agent before signalling.
		if id != "" {
			if writeErr := os.WriteFile("run/reload.txt", []byte(id), 0o644); writeErr != nil {
				return result(writeErr)
			}
		}

		if sigErr := syscall.Kill(launcherPID, syscall.SIGHUP); sigErr != nil {
			return result(sigErr)
		}

		time.Sleep(1 * time.Second)
		return result(nil)
	}
}
// reloadAll returns a command that asks the running launcher to reload every
// agent: any leftover run/reload.txt is deleted (its absence means "reload
// all") and the launcher process is sent SIGHUP. The command fails if the
// launcher has no live PID or the signal cannot be delivered. After a
// successful signal it waits one second before reporting success.
func (a *Adapter) reloadAll() tea.Cmd {
	// result builds the server-action message with the given outcome.
	result := func(err error) tea.Msg {
		return puretui.MsgServerActionDone{Action: "Reload All", Err: err}
	}

	return func() tea.Msg {
		launcherPID := a.mgr.UnifiedPID()
		if launcherPID <= 0 {
			return result(fmt.Errorf("el launcher no está corriendo"))
		}

		// Best effort: the file usually does not exist, so the error is ignored.
		_ = os.Remove("run/reload.txt")

		if sigErr := syscall.Kill(launcherPID, syscall.SIGHUP); sigErr != nil {
			return result(sigErr)
		}

		time.Sleep(1 * time.Second)
		return result(nil)
	}
}
// restartAgent returns a command that performs a full launcher restart, which
// affects all agents. The id is not used to target anything; it is only
// echoed back in the resulting "Restart" action message.
//
// The stop error is deliberately ignored; only the start error is reported.
func (a *Adapter) restartAgent(id string) tea.Cmd {
	const pause = 500 * time.Millisecond

	return func() tea.Msg {
		_ = a.mgr.StopUnified()
		time.Sleep(pause)

		outcome := puretui.MsgActionDone{AgentID: id, Action: "Restart"}
		if outcome.Err = a.mgr.StartUnified(); outcome.Err == nil {
			// Pause after a successful start before reporting.
			time.Sleep(pause)
		}
		return outcome
	}
}
// startLauncher returns a command that starts the launcher and reports the
// outcome as a "Start" server action. On success it waits half a second
// before reporting.
func (a *Adapter) startLauncher() tea.Cmd {
	return func() tea.Msg {
		outcome := puretui.MsgServerActionDone{Action: "Start"}
		if outcome.Err = a.mgr.StartUnified(); outcome.Err == nil {
			time.Sleep(500 * time.Millisecond)
		}
		return outcome
	}
}
// stopLauncher returns a command that stops the launcher through the Manager
// and reports the outcome as a "Stop" server action.
func (a *Adapter) stopLauncher() tea.Cmd {
	return func() tea.Msg {
		return puretui.MsgServerActionDone{
			Action: "Stop",
			Err:    a.mgr.StopUnified(),
		}
	}
}
// restartLauncher returns a command that stops and then starts the launcher,
// reporting the outcome as a "Restart" server action. The stop error is
// deliberately ignored; only the start error is reported.
func (a *Adapter) restartLauncher() tea.Cmd {
	const pause = 500 * time.Millisecond

	return func() tea.Msg {
		_ = a.mgr.StopUnified()
		time.Sleep(pause)

		outcome := puretui.MsgServerActionDone{Action: "Restart"}
		if outcome.Err = a.mgr.StartUnified(); outcome.Err == nil {
			// Pause after a successful start before reporting.
			time.Sleep(pause)
		}
		return outcome
	}
}
// killLauncher returns a command that kills the launcher through the Manager
// and reports the outcome as a "Kill" server action.
func (a *Adapter) killLauncher() tea.Cmd {
	return func() tea.Msg {
		return puretui.MsgServerActionDone{
			Action: "Kill",
			Err:    a.mgr.KillUnified(),
		}
	}
}
// rebuildRestart returns a command that rebuilds the project and, if the
// launcher was running beforehand, brings it back up afterwards.
//
// Sequence: stop the launcher (only if it was running), run Manager.Build,
// then start the launcher again (only if it was running).
//
// On build failure the launcher is restarted on a best-effort basis and the
// resulting MsgRebuildDone has BuildOK=false with BuildLog set to the last
// five lines of the build output. If the build produced no output at all,
// BuildLog carries the build error text instead, so the failure reason is
// never lost.
//
// On build success the message has BuildOK=true; Started reports whether the
// launcher was brought back up and Err carries the start error, if any.
func (a *Adapter) rebuildRestart() tea.Cmd {
	const (
		settle       = 500 * time.Millisecond
		maxTailLines = 5
	)

	return func() tea.Msg {
		wasRunning := a.mgr.IsUnifiedRunning()

		if wasRunning {
			// Best effort: a failed stop must not prevent the rebuild.
			_ = a.mgr.StopUnified()
			time.Sleep(settle)
		}

		buildOut, buildErr := a.mgr.Build()
		if buildErr != nil {
			if wasRunning {
				// Best effort: try to bring the launcher back up even though
				// the build failed; the build failure is what gets reported.
				_ = a.mgr.StartUnified()
			}

			tail := buildOut
			trimmed := strings.TrimSpace(buildOut)
			if trimmed == "" {
				// No build output: surface the error itself rather than an
				// empty log.
				tail = "Error: " + buildErr.Error()
			} else if lines := strings.Split(trimmed, "\n"); len(lines) > maxTailLines {
				tail = strings.Join(lines[len(lines)-maxTailLines:], "\n")
			}
			return puretui.MsgRebuildDone{BuildOK: false, BuildLog: tail}
		}

		// Build succeeded: restore the launcher only if it was up before.
		started := false
		var startErr error
		if wasRunning {
			startErr = a.mgr.StartUnified()
			if startErr == nil {
				started = true
				time.Sleep(settle)
			}
		}
		return puretui.MsgRebuildDone{
			BuildOK: true,
			Started: started,
			Err:     startErr,
		}
	}
}
// loadLogs returns a command that fetches the last 100 lines of the launcher
// log. The id does not change what is read: in unified mode agent output goes
// to the launcher log, so the same tail serves both the launcher view
// (id == "") and any per-agent view.
//
// A read failure is reported in-band as a single "Error: ..." line.
func (a *Adapter) loadLogs(id string) tea.Cmd {
	return func() tea.Msg {
		tail, tailErr := a.mgr.UnifiedLogTail(100)
		if tailErr != nil {
			return puretui.MsgLogsLoaded{Lines: []string{"Error: " + tailErr.Error()}}
		}
		return puretui.MsgLogsLoaded{Lines: tail}
	}
}
// runGoTests returns a command that runs `go test -tags goolm -count=1 ./...`
// with the Manager's build environment and reports the combined output, split
// into lines, as a MsgTestsDone of kind TestKindGo.
//
// The go binary is resolved via PATH, falling back to /usr/local/go/bin/go.
// Passed is true only when the command exits without error. If the command
// fails without printing anything, the error text becomes the output.
func (a *Adapter) runGoTests() tea.Cmd {
	return func() tea.Msg {
		goPath, lookErr := exec.LookPath("go")
		if lookErr != nil {
			goPath = "/usr/local/go/bin/go"
		}

		testCmd := exec.Command(goPath, "test", "-tags", "goolm", "-count=1", "./...")
		testCmd.Env = a.mgr.BuildEnv()

		raw, runErr := testCmd.CombinedOutput()
		text := strings.TrimSpace(string(raw))
		if runErr != nil && text == "" {
			text = "Error: " + runErr.Error()
		}

		return puretui.MsgTestsDone{
			Kind:   puretui.TestKindGo,
			Passed: runErr == nil,
			Output: strings.Split(text, "\n"),
		}
	}
}
// runE2ETests returns a command that runs ./dev-scripts/e2e/run.sh through
// bash with the Manager's build environment. When headed is true the script
// receives --headed and the result is tagged TestKindE2EHead; otherwise it is
// tagged TestKindE2E.
//
// Passed is true only when the script exits without error. If it fails
// without printing anything, the error text becomes the output.
func (a *Adapter) runE2ETests(headed bool) tea.Cmd {
	return func() tea.Msg {
		scriptArgs := []string{"./dev-scripts/e2e/run.sh"}
		testKind := puretui.TestKindE2E
		if headed {
			scriptArgs = append(scriptArgs, "--headed")
			testKind = puretui.TestKindE2EHead
		}

		script := exec.Command("bash", scriptArgs...)
		script.Env = a.mgr.BuildEnv()

		raw, runErr := script.CombinedOutput()
		text := strings.TrimSpace(string(raw))
		if runErr != nil && text == "" {
			text = "Error: " + runErr.Error()
		}

		return puretui.MsgTestsDone{
			Kind:   testKind,
			Passed: runErr == nil,
			Output: strings.Split(text, "\n"),
		}
	}
}
// runAllTests returns a command that runs the Go test suite and then, only if
// it passed, the headless E2E script. Both outputs are concatenated under
// section banners into one MsgTestsDone of kind TestKindAll.
//
// A Go test failure short-circuits: the E2E stage is skipped and Passed is
// false. Otherwise Passed reflects the E2E script's exit status. For either
// stage, a failure with no output is replaced by the error text.
func (a *Adapter) runAllTests() tea.Cmd {
	return func() tea.Msg {
		// Stage 1: Go tests.
		goPath, lookErr := exec.LookPath("go")
		if lookErr != nil {
			goPath = "/usr/local/go/bin/go"
		}
		unit := exec.Command(goPath, "test", "-tags", "goolm", "-count=1", "./...")
		unit.Env = a.mgr.BuildEnv()
		unitRaw, unitErr := unit.CombinedOutput()

		unitText := strings.TrimSpace(string(unitRaw))
		if unitErr != nil && unitText == "" {
			unitText = "Error: " + unitErr.Error()
		}
		report := []string{"═══ Go Tests ═══"}
		report = append(report, strings.Split(unitText, "\n")...)

		if unitErr != nil {
			report = append(report, "", "Go tests FAILED — skipping E2E")
			return puretui.MsgTestsDone{Kind: puretui.TestKindAll, Passed: false, Output: report}
		}

		// Stage 2: E2E tests (headless).
		report = append(report, "", "═══ E2E Tests ═══")
		e2e := exec.Command("bash", "./dev-scripts/e2e/run.sh")
		e2e.Env = a.mgr.BuildEnv()
		e2eRaw, e2eErr := e2e.CombinedOutput()

		e2eText := strings.TrimSpace(string(e2eRaw))
		if e2eErr != nil && e2eText == "" {
			e2eText = "Error: " + e2eErr.Error()
		}
		report = append(report, strings.Split(e2eText, "\n")...)

		return puretui.MsgTestsDone{
			Kind:   puretui.TestKindAll,
			Passed: e2eErr == nil,
			Output: report,
		}
	}
}
// tick returns a command that delivers a puretui.MsgTick after three seconds.
func (a *Adapter) tick() tea.Cmd {
	const interval = 3 * time.Second

	onTick := func(time.Time) tea.Msg {
		return puretui.MsgTick{}
	}
	return tea.Tick(interval, onTick)
}
// ── formatting helpers ───────────────────────────────────────────────────
// formatUptime renders a duration given in whole seconds as a short label
// using the two most significant units: "Nd Nh" from one day up, "Nh Nm" from
// one hour up, and "Nm" otherwise (so anything under a minute is "0m").
// Negative input yields "n/a".
func formatUptime(secs int64) string {
	const (
		perMinute = 60
		perHour   = 60 * perMinute
		perDay    = 24 * perHour
	)

	if secs < 0 {
		return "n/a"
	}

	days, intoDay := secs/perDay, secs%perDay
	hours := intoDay / perHour
	minutes := (secs % perHour) / perMinute

	switch {
	case days > 0:
		return fmt.Sprintf("%dd %dh", days, hours)
	case hours > 0:
		return fmt.Sprintf("%dh %dm", hours, minutes)
	default:
		return fmt.Sprintf("%dm", minutes)
	}
}
// formatBytes renders a byte count with binary (1024-based) thresholds and
// one decimal place: "GB" from 1<<30, "MB" from 1<<20, "KB" from 1<<10.
// Anything smaller, including negative values, is printed as a plain integer
// followed by " B".
func formatBytes(bytes int64) string {
	const (
		kilo = int64(1) << 10
		mega = int64(1) << 20
		giga = int64(1) << 30
	)

	// scaled formats the count as a fraction of unit with the given suffix.
	scaled := func(unit int64, suffix string) string {
		return fmt.Sprintf("%.1f %s", float64(bytes)/float64(unit), suffix)
	}

	if bytes >= giga {
		return scaled(giga, "GB")
	}
	if bytes >= mega {
		return scaled(mega, "MB")
	}
	if bytes >= kilo {
		return scaled(kilo, "KB")
	}
	return fmt.Sprintf("%d B", bytes)
}