feat: implement memory management system with SQLite persistence, including conversation windows and episodic facts
This commit is contained in:
+165
-8
@@ -8,21 +8,27 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/memory"
|
||||
"github.com/enmanuel/agents/pkg/personality"
|
||||
"github.com/enmanuel/agents/shell/effects"
|
||||
shelllm "github.com/enmanuel/agents/shell/llm"
|
||||
"github.com/enmanuel/agents/shell/matrix"
|
||||
shellmem "github.com/enmanuel/agents/shell/memory"
|
||||
"github.com/enmanuel/agents/shell/ssh"
|
||||
"github.com/enmanuel/agents/tools"
|
||||
)
|
||||
|
||||
const defaultMaxToolIterations = 5
|
||||
const (
|
||||
defaultMaxToolIterations = 5
|
||||
defaultWindowSize = 20
|
||||
)
|
||||
|
||||
// Agent is the assembled runtime: pure core + impure shell.
|
||||
type Agent struct {
|
||||
@@ -36,6 +42,20 @@ type Agent struct {
|
||||
toolReg *tools.Registry
|
||||
logger *slog.Logger
|
||||
cryptoStore io.Closer // non-nil when E2EE is enabled; closed on shutdown
|
||||
|
||||
// Memory
|
||||
windows map[string]memory.Window
|
||||
windowsMu sync.RWMutex
|
||||
memStore memory.Store // nil when memory is disabled
|
||||
windowSize int
|
||||
roomCtx *tools.RoomContext
|
||||
}
|
||||
|
||||
// ClearWindow resets the conversation window for a room. Implements tools.WindowClearer.
|
||||
func (a *Agent) ClearWindow(roomID string) {
|
||||
a.windowsMu.Lock()
|
||||
defer a.windowsMu.Unlock()
|
||||
a.windows[roomID] = memory.NewWindow(a.windowSize)
|
||||
}
|
||||
|
||||
// New assembles an Agent from its config, rules, and logger.
|
||||
@@ -100,8 +120,31 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (*
|
||||
// Effects runner
|
||||
runner := effects.NewRunner(matrixClient, sshExec, logger)
|
||||
|
||||
// Memory subsystem
|
||||
var memStore memory.Store
|
||||
windowSize := defaultWindowSize
|
||||
roomCtx := &tools.RoomContext{}
|
||||
|
||||
if cfg.Memory.Enabled {
|
||||
windowSize = cfg.Memory.WindowSize
|
||||
if windowSize <= 0 {
|
||||
windowSize = defaultWindowSize
|
||||
}
|
||||
|
||||
dbPath := cfg.Memory.DBPath
|
||||
if dbPath == "" {
|
||||
dbPath = filepath.Join("agents", cfg.Agent.ID, "data", "memory.db")
|
||||
}
|
||||
store, err := shellmem.New(dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("memory store: %w", err)
|
||||
}
|
||||
memStore = store
|
||||
logger.Info("memory enabled", "window_size", windowSize, "db", dbPath)
|
||||
}
|
||||
|
||||
// Tool registry — register tools enabled in config
|
||||
toolReg := buildToolRegistry(cfg, sshExec, matrixClient, logger)
|
||||
toolReg := buildToolRegistry(cfg, sshExec, matrixClient, memStore, roomCtx, logger)
|
||||
|
||||
a := &Agent{
|
||||
cfg: cfg,
|
||||
@@ -112,6 +155,15 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (*
|
||||
toolReg: toolReg,
|
||||
logger: logger,
|
||||
cryptoStore: cryptoStore,
|
||||
windows: make(map[string]memory.Window),
|
||||
memStore: memStore,
|
||||
windowSize: windowSize,
|
||||
roomCtx: roomCtx,
|
||||
}
|
||||
|
||||
// Register memory_clear_context with self as WindowClearer (after a is created)
|
||||
if cfg.Tools.Memory.Enabled && memStore != nil {
|
||||
toolReg.Register(tools.NewMemoryClearContext(a, roomCtx))
|
||||
}
|
||||
|
||||
// Matrix event listener
|
||||
@@ -125,6 +177,9 @@ func (a *Agent) Run(ctx context.Context) error {
|
||||
if a.cryptoStore != nil {
|
||||
defer a.cryptoStore.Close()
|
||||
}
|
||||
if a.memStore != nil {
|
||||
defer a.memStore.Close()
|
||||
}
|
||||
a.logger.Info("agent starting",
|
||||
"id", a.cfg.Agent.ID,
|
||||
"name", a.cfg.Agent.Name,
|
||||
@@ -142,9 +197,14 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
||||
"command", msgCtx.Command,
|
||||
)
|
||||
|
||||
roomID := evt.RoomID.String()
|
||||
|
||||
// Update room context for memory tools
|
||||
a.roomCtx.Set(roomID)
|
||||
|
||||
if a.cfg.Personality.Behavior.TypingIndicator {
|
||||
_ = a.matrix.SendTyping(ctx, evt.RoomID.String(), true)
|
||||
defer a.matrix.SendTyping(ctx, evt.RoomID.String(), false)
|
||||
_ = a.matrix.SendTyping(ctx, roomID, true)
|
||||
defer a.matrix.SendTyping(ctx, roomID, false)
|
||||
}
|
||||
|
||||
actions := decision.Evaluate(msgCtx, a.rules)
|
||||
@@ -171,6 +231,13 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
||||
expanded := make([]decision.Action, 0, len(actions))
|
||||
for _, act := range actions {
|
||||
if act.Kind == decision.ActionKindLLM {
|
||||
// Memory: load window + append user message before LLM call
|
||||
a.ensureWindowLoaded(ctx, roomID)
|
||||
a.appendToWindow(roomID, coretypes.Message{
|
||||
Role: coretypes.RoleUser, Content: msgCtx.Content,
|
||||
})
|
||||
a.persistMessage(ctx, roomID, coretypes.RoleUser, msgCtx.Content)
|
||||
|
||||
reply, err := a.runLLM(ctx, msgCtx)
|
||||
if err != nil {
|
||||
a.logger.Error("llm error", "err", err)
|
||||
@@ -183,13 +250,19 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
||||
Kind: decision.ActionKindReply,
|
||||
Reply: &decision.ReplyAction{Content: reply},
|
||||
})
|
||||
|
||||
// Memory: append assistant reply after LLM call
|
||||
a.appendToWindow(roomID, coretypes.Message{
|
||||
Role: coretypes.RoleAssistant, Content: reply,
|
||||
})
|
||||
a.persistMessage(ctx, roomID, coretypes.RoleAssistant, reply)
|
||||
}
|
||||
} else {
|
||||
expanded = append(expanded, act)
|
||||
}
|
||||
}
|
||||
|
||||
a.runner.Execute(ctx, evt.RoomID.String(), expanded)
|
||||
a.runner.Execute(ctx, roomID, expanded)
|
||||
}
|
||||
|
||||
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (string, error) {
|
||||
@@ -201,8 +274,13 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str
|
||||
// Load system prompt from file if configured, else use description
|
||||
systemPrompt := a.cfg.Agent.Description
|
||||
|
||||
messages := []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
||||
// Build messages: conversation history from window (includes current user msg)
|
||||
messages := a.getWindowMessages(msgCtx.RoomID)
|
||||
if len(messages) == 0 {
|
||||
// Fallback if memory is disabled: just the current message
|
||||
messages = []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
||||
}
|
||||
}
|
||||
|
||||
// Build tool specs for the LLM if tool_use is enabled
|
||||
@@ -294,8 +372,78 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str
|
||||
return "I've reached the maximum number of tool iterations. Here's what I found so far.", nil
|
||||
}
|
||||
|
||||
// ── Memory helpers ───────────────────────────────────────────────────────
|
||||
|
||||
// ensureWindowLoaded loads the conversation window from SQLite on first access for a room.
|
||||
func (a *Agent) ensureWindowLoaded(ctx context.Context, roomID string) {
|
||||
a.windowsMu.Lock()
|
||||
defer a.windowsMu.Unlock()
|
||||
if _, ok := a.windows[roomID]; ok {
|
||||
return
|
||||
}
|
||||
w := memory.NewWindow(a.windowSize)
|
||||
if a.memStore != nil {
|
||||
msgs, err := a.memStore.LoadMessages(ctx, a.cfg.Agent.ID, roomID, a.windowSize)
|
||||
if err != nil {
|
||||
a.logger.Warn("failed to load message history", "room", roomID, "err", err)
|
||||
} else {
|
||||
for _, m := range msgs {
|
||||
w = w.Append(coretypes.Message{Role: m.Role, Content: m.Content})
|
||||
}
|
||||
if len(msgs) > 0 {
|
||||
a.logger.Debug("loaded message history", "room", roomID, "count", len(msgs))
|
||||
}
|
||||
}
|
||||
}
|
||||
a.windows[roomID] = w
|
||||
}
|
||||
|
||||
// appendToWindow adds a message to the in-memory conversation window.
|
||||
func (a *Agent) appendToWindow(roomID string, msg coretypes.Message) {
|
||||
a.windowsMu.Lock()
|
||||
defer a.windowsMu.Unlock()
|
||||
w, ok := a.windows[roomID]
|
||||
if !ok {
|
||||
w = memory.NewWindow(a.windowSize)
|
||||
}
|
||||
a.windows[roomID] = w.Append(msg)
|
||||
}
|
||||
|
||||
// getWindowMessages returns a copy of the conversation window for a room.
|
||||
func (a *Agent) getWindowMessages(roomID string) []coretypes.Message {
|
||||
a.windowsMu.RLock()
|
||||
defer a.windowsMu.RUnlock()
|
||||
w, ok := a.windows[roomID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return w.ToLLMMessages()
|
||||
}
|
||||
|
||||
// persistMessage saves a message to the SQLite store (no-op if store is nil).
|
||||
func (a *Agent) persistMessage(ctx context.Context, roomID string, role coretypes.Role, content string) {
|
||||
if a.memStore == nil {
|
||||
return
|
||||
}
|
||||
if err := a.memStore.SaveMessage(ctx, memory.HistoryMessage{
|
||||
AgentID: a.cfg.Agent.ID,
|
||||
RoomID: roomID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
a.logger.Warn("failed to persist message", "room", roomID, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// buildToolRegistry creates a Registry with tools enabled in the agent's config.
|
||||
func buildToolRegistry(cfg *config.AgentConfig, sshExec *ssh.Executor, matrixClient *matrix.Client, logger *slog.Logger) *tools.Registry {
|
||||
func buildToolRegistry(
|
||||
cfg *config.AgentConfig,
|
||||
sshExec *ssh.Executor,
|
||||
matrixClient *matrix.Client,
|
||||
memStore memory.Store,
|
||||
roomCtx *tools.RoomContext,
|
||||
logger *slog.Logger,
|
||||
) *tools.Registry {
|
||||
reg := tools.NewRegistry()
|
||||
|
||||
if cfg.Tools.HTTP.Enabled {
|
||||
@@ -322,5 +470,14 @@ func buildToolRegistry(cfg *config.AgentConfig, sshExec *ssh.Executor, matrixCli
|
||||
reg.Register(tools.NewMatrixSend(matrixClient))
|
||||
logger.Debug("registered matrix tool")
|
||||
|
||||
// Memory tools (memory_clear_context registered later since it needs the Agent)
|
||||
if cfg.Tools.Memory.Enabled && memStore != nil {
|
||||
reg.Register(tools.NewMemorySave(cfg.Agent.ID, memStore))
|
||||
reg.Register(tools.NewMemoryRecall(cfg.Agent.ID, memStore))
|
||||
reg.Register(tools.NewMemoryForget(cfg.Agent.ID, memStore))
|
||||
reg.Register(tools.NewMemorySummary(cfg.Agent.ID, memStore))
|
||||
logger.Debug("registered memory tools")
|
||||
}
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ type AgentConfig struct {
|
||||
Observability ObservabilityCfg `yaml:"observability"`
|
||||
Resilience ResilienceCfg `yaml:"resilience"`
|
||||
Storage StorageCfg `yaml:"storage"`
|
||||
Memory MemoryCfg `yaml:"memory"`
|
||||
}
|
||||
|
||||
// ── Identity ──────────────────────────────────────────────────────────────
|
||||
@@ -107,6 +108,7 @@ type ToolsCfg struct {
|
||||
Scripts ScriptsCfg `yaml:"scripts"`
|
||||
FileOps FileOpsCfg `yaml:"file_ops"`
|
||||
MCP MCPToolCfg `yaml:"mcp"`
|
||||
Memory MemoryToolCfg `yaml:"memory"`
|
||||
}
|
||||
|
||||
type SSHToolCfg struct {
|
||||
@@ -384,3 +386,15 @@ type HistoryStorageCfg struct {
|
||||
Path string `yaml:"path"`
|
||||
Retention time.Duration `yaml:"retention"`
|
||||
}
|
||||
|
||||
// ── Memory ────────────────────────────────────────────────────────────────
|
||||
|
||||
type MemoryCfg struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
WindowSize int `yaml:"window_size"` // sliding window size per room (default 20)
|
||||
DBPath string `yaml:"db_path"` // SQLite path (default agents/<id>/data/memory.db)
|
||||
}
|
||||
|
||||
type MemoryToolCfg struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package memory
|
||||
|
||||
import "context"
|
||||
|
||||
// Store is the interface for persistent memory operations.
|
||||
// Defined in the pure package; implemented by shell/memory.
|
||||
type Store interface {
|
||||
// Facts
|
||||
SaveFact(ctx context.Context, fact Fact) error
|
||||
RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]Fact, error)
|
||||
DeleteFacts(ctx context.Context, agentID, subject string, key *string) error
|
||||
|
||||
// Message history
|
||||
SaveMessage(ctx context.Context, msg HistoryMessage) error
|
||||
LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]HistoryMessage, error)
|
||||
DeleteMessages(ctx context.Context, agentID string, roomID *string) error
|
||||
|
||||
// Lifecycle
|
||||
Close() error
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Package memory provides pure types for agent memory: conversation windows and episodic facts.
|
||||
package memory
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// Fact is a single episodic fact: a key-value pair scoped to a subject.
|
||||
type Fact struct {
|
||||
AgentID string
|
||||
Subject string
|
||||
Key string
|
||||
Value string
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// HistoryMessage is a persisted conversation message.
|
||||
type HistoryMessage struct {
|
||||
AgentID string
|
||||
RoomID string
|
||||
Role llm.Role
|
||||
Content string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package memory
|
||||
|
||||
import "github.com/enmanuel/agents/pkg/llm"
|
||||
|
||||
// Window is an immutable sliding window of conversation messages for a single room.
|
||||
type Window struct {
|
||||
messages []llm.Message
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// NewWindow creates an empty window with the given capacity.
|
||||
func NewWindow(maxSize int) Window {
|
||||
return Window{maxSize: maxSize}
|
||||
}
|
||||
|
||||
// Append returns a new Window with the message added, dropping the oldest
|
||||
// messages if capacity is exceeded.
|
||||
func (w Window) Append(msg llm.Message) Window {
|
||||
msgs := make([]llm.Message, len(w.messages), len(w.messages)+1)
|
||||
copy(msgs, w.messages)
|
||||
msgs = append(msgs, msg)
|
||||
if len(msgs) > w.maxSize {
|
||||
msgs = msgs[len(msgs)-w.maxSize:]
|
||||
}
|
||||
return Window{messages: msgs, maxSize: w.maxSize}
|
||||
}
|
||||
|
||||
// ToLLMMessages returns a copy of the window contents as []llm.Message.
|
||||
func (w Window) ToLLMMessages() []llm.Message {
|
||||
out := make([]llm.Message, len(w.messages))
|
||||
copy(out, w.messages)
|
||||
return out
|
||||
}
|
||||
|
||||
// Len returns the number of messages in the window.
|
||||
func (w Window) Len() int {
|
||||
return len(w.messages)
|
||||
}
|
||||
|
||||
// Clear returns an empty window with the same capacity.
|
||||
func (w Window) Clear() Window {
|
||||
return NewWindow(w.maxSize)
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
// Package shellmem implements persistent memory storage using SQLite.
|
||||
package shellmem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/llm"
|
||||
"github.com/enmanuel/agents/pkg/memory"
|
||||
)
|
||||
|
||||
const schema = `
|
||||
CREATE TABLE IF NOT EXISTS facts (
|
||||
agent_id TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (agent_id, subject, key)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
agent_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(agent_id, room_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(agent_id, subject);
|
||||
`
|
||||
|
||||
// SQLiteStore implements memory.Store using SQLite.
|
||||
type SQLiteStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// New opens (or creates) a SQLite database at dbPath and runs migrations.
|
||||
func New(dbPath string) (*SQLiteStore, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create memory db dir: %w", err)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open memory db: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("migrate memory db: %w", err)
|
||||
}
|
||||
return &SQLiteStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SaveFact(ctx context.Context, f memory.Fact) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT OR REPLACE INTO facts (agent_id, subject, key, value, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
f.AgentID, f.Subject, f.Key, f.Value, time.Now().UTC(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if key != nil {
|
||||
rows, err = s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||
WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||
agentID, subject, *key,
|
||||
)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||
WHERE agent_id = ? AND subject = ?`,
|
||||
agentID, subject,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var facts []memory.Fact
|
||||
for rows.Next() {
|
||||
var f memory.Fact
|
||||
if err := rows.Scan(&f.AgentID, &f.Subject, &f.Key, &f.Value, &f.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
facts = append(facts, f)
|
||||
}
|
||||
return facts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteFacts(ctx context.Context, agentID, subject string, key *string) error {
|
||||
if key != nil {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM facts WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||
agentID, subject, *key,
|
||||
)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM facts WHERE agent_id = ? AND subject = ?`,
|
||||
agentID, subject,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SaveMessage(ctx context.Context, m memory.HistoryMessage) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO messages (agent_id, room_id, role, content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
m.AgentID, m.RoomID, string(m.Role), m.Content, time.Now().UTC(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]memory.HistoryMessage, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT agent_id, room_id, role, content, created_at FROM messages
|
||||
WHERE agent_id = ? AND room_id = ?
|
||||
ORDER BY created_at DESC LIMIT ?`,
|
||||
agentID, roomID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var msgs []memory.HistoryMessage
|
||||
for rows.Next() {
|
||||
var m memory.HistoryMessage
|
||||
var role string
|
||||
if err := rows.Scan(&m.AgentID, &m.RoomID, &role, &m.Content, &m.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Role = llm.Role(role)
|
||||
msgs = append(msgs, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reverse to chronological order
|
||||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteMessages(ctx context.Context, agentID string, roomID *string) error {
|
||||
if roomID != nil {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM messages WHERE agent_id = ? AND room_id = ?`,
|
||||
agentID, *roomID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM messages WHERE agent_id = ?`,
|
||||
agentID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
+199
@@ -0,0 +1,199 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/pkg/memory"
|
||||
)
|
||||
|
||||
// MemoryStore is the subset of memory.Store needed by memory tools.
|
||||
type MemoryStore interface {
|
||||
SaveFact(ctx context.Context, fact memory.Fact) error
|
||||
RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error)
|
||||
DeleteFacts(ctx context.Context, agentID, subject string, key *string) error
|
||||
}
|
||||
|
||||
// WindowClearer allows tools to clear the conversation window for a room.
|
||||
type WindowClearer interface {
|
||||
ClearWindow(roomID string)
|
||||
}
|
||||
|
||||
// RoomContext is a thread-safe holder for the current room ID.
|
||||
// Set by the runtime before each event handling; read by memory_clear_context.
|
||||
type RoomContext struct {
|
||||
mu sync.RWMutex
|
||||
roomID string
|
||||
}
|
||||
|
||||
// Set updates the current room ID.
|
||||
func (rc *RoomContext) Set(roomID string) {
|
||||
rc.mu.Lock()
|
||||
rc.roomID = roomID
|
||||
rc.mu.Unlock()
|
||||
}
|
||||
|
||||
// Get returns the current room ID.
|
||||
func (rc *RoomContext) Get() string {
|
||||
rc.mu.RLock()
|
||||
defer rc.mu.RUnlock()
|
||||
return rc.roomID
|
||||
}
|
||||
|
||||
// NewMemorySave creates a tool that saves a fact to long-term memory.
|
||||
func NewMemorySave(agentID string, store MemoryStore) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "memory_save",
|
||||
Description: "Save a fact to long-term memory. Use this to remember important information about users, topics, or preferences.",
|
||||
Parameters: []Param{
|
||||
{Name: "subject", Type: "string", Description: "The subject this fact is about (e.g. a username, a topic)", Required: true},
|
||||
{Name: "key", Type: "string", Description: "The fact key (e.g. 'favorite_language', 'timezone')", Required: true},
|
||||
{Name: "value", Type: "string", Description: "The fact value to store", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
subject := getString(args, "subject")
|
||||
key := getString(args, "key")
|
||||
value := getString(args, "value")
|
||||
if subject == "" || key == "" || value == "" {
|
||||
return Result{Err: fmt.Errorf("memory_save: subject, key, and value are required")}
|
||||
}
|
||||
err := store.SaveFact(ctx, memory.Fact{
|
||||
AgentID: agentID,
|
||||
Subject: subject,
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("memory_save: %w", err)}
|
||||
}
|
||||
return Result{Output: fmt.Sprintf("saved: %s.%s = %s", subject, key, value)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemoryRecall creates a tool that retrieves facts from long-term memory.
|
||||
func NewMemoryRecall(agentID string, store MemoryStore) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "memory_recall",
|
||||
Description: "Recall facts from long-term memory about a subject. Omit key to get all facts for the subject.",
|
||||
Parameters: []Param{
|
||||
{Name: "subject", Type: "string", Description: "The subject to recall facts about", Required: true},
|
||||
{Name: "key", Type: "string", Description: "Optional specific fact key to recall", Required: false},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
subject := getString(args, "subject")
|
||||
if subject == "" {
|
||||
return Result{Err: fmt.Errorf("memory_recall: subject is required")}
|
||||
}
|
||||
var keyPtr *string
|
||||
if k := getString(args, "key"); k != "" {
|
||||
keyPtr = &k
|
||||
}
|
||||
facts, err := store.RecallFacts(ctx, agentID, subject, keyPtr)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("memory_recall: %w", err)}
|
||||
}
|
||||
if len(facts) == 0 {
|
||||
return Result{Output: fmt.Sprintf("no facts found for subject %q", subject)}
|
||||
}
|
||||
var sb strings.Builder
|
||||
for _, f := range facts {
|
||||
fmt.Fprintf(&sb, "%s.%s = %s\n", f.Subject, f.Key, f.Value)
|
||||
}
|
||||
return Result{Output: sb.String()}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemoryForget creates a tool that deletes facts from long-term memory.
|
||||
func NewMemoryForget(agentID string, store MemoryStore) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "memory_forget",
|
||||
Description: "Delete facts from long-term memory. Omit key to delete all facts for the subject.",
|
||||
Parameters: []Param{
|
||||
{Name: "subject", Type: "string", Description: "The subject whose facts to delete", Required: true},
|
||||
{Name: "key", Type: "string", Description: "Optional specific fact key to delete; omit to delete all", Required: false},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
subject := getString(args, "subject")
|
||||
if subject == "" {
|
||||
return Result{Err: fmt.Errorf("memory_forget: subject is required")}
|
||||
}
|
||||
var keyPtr *string
|
||||
if k := getString(args, "key"); k != "" {
|
||||
keyPtr = &k
|
||||
}
|
||||
err := store.DeleteFacts(ctx, agentID, subject, keyPtr)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("memory_forget: %w", err)}
|
||||
}
|
||||
if keyPtr != nil {
|
||||
return Result{Output: fmt.Sprintf("forgot %s.%s", subject, *keyPtr)}
|
||||
}
|
||||
return Result{Output: fmt.Sprintf("forgot all facts about %s", subject)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemoryClearContext creates a tool that clears the conversation window.
|
||||
func NewMemoryClearContext(clearer WindowClearer, roomCtx *RoomContext) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "memory_clear_context",
|
||||
Description: "Clear the conversation context window. Useful to start fresh. Omit room_id to clear the current room.",
|
||||
Parameters: []Param{
|
||||
{Name: "room_id", Type: "string", Description: "Optional room ID to clear; defaults to current room", Required: false},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
roomID := getString(args, "room_id")
|
||||
if roomID == "" {
|
||||
roomID = roomCtx.Get()
|
||||
}
|
||||
if roomID == "" {
|
||||
return Result{Err: fmt.Errorf("memory_clear_context: no room_id provided and no current room")}
|
||||
}
|
||||
clearer.ClearWindow(roomID)
|
||||
return Result{Output: fmt.Sprintf("conversation context cleared for room %s", roomID)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemorySummary creates a tool that saves an important summary to long-term memory.
|
||||
func NewMemorySummary(agentID string, store MemoryStore) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "memory_summary",
|
||||
Description: "Save an important summary or takeaway from the current conversation to long-term memory.",
|
||||
Parameters: []Param{
|
||||
{Name: "text", Type: "string", Description: "The summary text to save", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
text := getString(args, "text")
|
||||
if text == "" {
|
||||
return Result{Err: fmt.Errorf("memory_summary: text is required")}
|
||||
}
|
||||
key := time.Now().UTC().Format("2006-01-02T15:04:05")
|
||||
err := store.SaveFact(ctx, memory.Fact{
|
||||
AgentID: agentID,
|
||||
Subject: "_summary",
|
||||
Key: key,
|
||||
Value: text,
|
||||
})
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("memory_summary: %w", err)}
|
||||
}
|
||||
return Result{Output: "summary saved"}
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user