feat: implement memory management system with SQLite persistence, including conversation windows and episodic facts
This commit is contained in:
+165
-8
@@ -8,21 +8,27 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"maunium.net/go/mautrix/event"
|
"maunium.net/go/mautrix/event"
|
||||||
|
|
||||||
"github.com/enmanuel/agents/internal/config"
|
"github.com/enmanuel/agents/internal/config"
|
||||||
"github.com/enmanuel/agents/pkg/decision"
|
"github.com/enmanuel/agents/pkg/decision"
|
||||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||||
|
"github.com/enmanuel/agents/pkg/memory"
|
||||||
"github.com/enmanuel/agents/pkg/personality"
|
"github.com/enmanuel/agents/pkg/personality"
|
||||||
"github.com/enmanuel/agents/shell/effects"
|
"github.com/enmanuel/agents/shell/effects"
|
||||||
shelllm "github.com/enmanuel/agents/shell/llm"
|
shelllm "github.com/enmanuel/agents/shell/llm"
|
||||||
"github.com/enmanuel/agents/shell/matrix"
|
"github.com/enmanuel/agents/shell/matrix"
|
||||||
|
shellmem "github.com/enmanuel/agents/shell/memory"
|
||||||
"github.com/enmanuel/agents/shell/ssh"
|
"github.com/enmanuel/agents/shell/ssh"
|
||||||
"github.com/enmanuel/agents/tools"
|
"github.com/enmanuel/agents/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultMaxToolIterations = 5
|
const (
|
||||||
|
defaultMaxToolIterations = 5
|
||||||
|
defaultWindowSize = 20
|
||||||
|
)
|
||||||
|
|
||||||
// Agent is the assembled runtime: pure core + impure shell.
|
// Agent is the assembled runtime: pure core + impure shell.
|
||||||
type Agent struct {
|
type Agent struct {
|
||||||
@@ -36,6 +42,20 @@ type Agent struct {
|
|||||||
toolReg *tools.Registry
|
toolReg *tools.Registry
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
cryptoStore io.Closer // non-nil when E2EE is enabled; closed on shutdown
|
cryptoStore io.Closer // non-nil when E2EE is enabled; closed on shutdown
|
||||||
|
|
||||||
|
// Memory
|
||||||
|
windows map[string]memory.Window
|
||||||
|
windowsMu sync.RWMutex
|
||||||
|
memStore memory.Store // nil when memory is disabled
|
||||||
|
windowSize int
|
||||||
|
roomCtx *tools.RoomContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearWindow resets the conversation window for a room. Implements tools.WindowClearer.
|
||||||
|
func (a *Agent) ClearWindow(roomID string) {
|
||||||
|
a.windowsMu.Lock()
|
||||||
|
defer a.windowsMu.Unlock()
|
||||||
|
a.windows[roomID] = memory.NewWindow(a.windowSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New assembles an Agent from its config, rules, and logger.
|
// New assembles an Agent from its config, rules, and logger.
|
||||||
@@ -100,8 +120,31 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (*
|
|||||||
// Effects runner
|
// Effects runner
|
||||||
runner := effects.NewRunner(matrixClient, sshExec, logger)
|
runner := effects.NewRunner(matrixClient, sshExec, logger)
|
||||||
|
|
||||||
|
// Memory subsystem
|
||||||
|
var memStore memory.Store
|
||||||
|
windowSize := defaultWindowSize
|
||||||
|
roomCtx := &tools.RoomContext{}
|
||||||
|
|
||||||
|
if cfg.Memory.Enabled {
|
||||||
|
windowSize = cfg.Memory.WindowSize
|
||||||
|
if windowSize <= 0 {
|
||||||
|
windowSize = defaultWindowSize
|
||||||
|
}
|
||||||
|
|
||||||
|
dbPath := cfg.Memory.DBPath
|
||||||
|
if dbPath == "" {
|
||||||
|
dbPath = filepath.Join("agents", cfg.Agent.ID, "data", "memory.db")
|
||||||
|
}
|
||||||
|
store, err := shellmem.New(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("memory store: %w", err)
|
||||||
|
}
|
||||||
|
memStore = store
|
||||||
|
logger.Info("memory enabled", "window_size", windowSize, "db", dbPath)
|
||||||
|
}
|
||||||
|
|
||||||
// Tool registry — register tools enabled in config
|
// Tool registry — register tools enabled in config
|
||||||
toolReg := buildToolRegistry(cfg, sshExec, matrixClient, logger)
|
toolReg := buildToolRegistry(cfg, sshExec, matrixClient, memStore, roomCtx, logger)
|
||||||
|
|
||||||
a := &Agent{
|
a := &Agent{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -112,6 +155,15 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (*
|
|||||||
toolReg: toolReg,
|
toolReg: toolReg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
cryptoStore: cryptoStore,
|
cryptoStore: cryptoStore,
|
||||||
|
windows: make(map[string]memory.Window),
|
||||||
|
memStore: memStore,
|
||||||
|
windowSize: windowSize,
|
||||||
|
roomCtx: roomCtx,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register memory_clear_context with self as WindowClearer (after a is created)
|
||||||
|
if cfg.Tools.Memory.Enabled && memStore != nil {
|
||||||
|
toolReg.Register(tools.NewMemoryClearContext(a, roomCtx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix event listener
|
// Matrix event listener
|
||||||
@@ -125,6 +177,9 @@ func (a *Agent) Run(ctx context.Context) error {
|
|||||||
if a.cryptoStore != nil {
|
if a.cryptoStore != nil {
|
||||||
defer a.cryptoStore.Close()
|
defer a.cryptoStore.Close()
|
||||||
}
|
}
|
||||||
|
if a.memStore != nil {
|
||||||
|
defer a.memStore.Close()
|
||||||
|
}
|
||||||
a.logger.Info("agent starting",
|
a.logger.Info("agent starting",
|
||||||
"id", a.cfg.Agent.ID,
|
"id", a.cfg.Agent.ID,
|
||||||
"name", a.cfg.Agent.Name,
|
"name", a.cfg.Agent.Name,
|
||||||
@@ -142,9 +197,14 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
|||||||
"command", msgCtx.Command,
|
"command", msgCtx.Command,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
roomID := evt.RoomID.String()
|
||||||
|
|
||||||
|
// Update room context for memory tools
|
||||||
|
a.roomCtx.Set(roomID)
|
||||||
|
|
||||||
if a.cfg.Personality.Behavior.TypingIndicator {
|
if a.cfg.Personality.Behavior.TypingIndicator {
|
||||||
_ = a.matrix.SendTyping(ctx, evt.RoomID.String(), true)
|
_ = a.matrix.SendTyping(ctx, roomID, true)
|
||||||
defer a.matrix.SendTyping(ctx, evt.RoomID.String(), false)
|
defer a.matrix.SendTyping(ctx, roomID, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
actions := decision.Evaluate(msgCtx, a.rules)
|
actions := decision.Evaluate(msgCtx, a.rules)
|
||||||
@@ -171,6 +231,13 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
|||||||
expanded := make([]decision.Action, 0, len(actions))
|
expanded := make([]decision.Action, 0, len(actions))
|
||||||
for _, act := range actions {
|
for _, act := range actions {
|
||||||
if act.Kind == decision.ActionKindLLM {
|
if act.Kind == decision.ActionKindLLM {
|
||||||
|
// Memory: load window + append user message before LLM call
|
||||||
|
a.ensureWindowLoaded(ctx, roomID)
|
||||||
|
a.appendToWindow(roomID, coretypes.Message{
|
||||||
|
Role: coretypes.RoleUser, Content: msgCtx.Content,
|
||||||
|
})
|
||||||
|
a.persistMessage(ctx, roomID, coretypes.RoleUser, msgCtx.Content)
|
||||||
|
|
||||||
reply, err := a.runLLM(ctx, msgCtx)
|
reply, err := a.runLLM(ctx, msgCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error("llm error", "err", err)
|
a.logger.Error("llm error", "err", err)
|
||||||
@@ -183,13 +250,19 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
|||||||
Kind: decision.ActionKindReply,
|
Kind: decision.ActionKindReply,
|
||||||
Reply: &decision.ReplyAction{Content: reply},
|
Reply: &decision.ReplyAction{Content: reply},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Memory: append assistant reply after LLM call
|
||||||
|
a.appendToWindow(roomID, coretypes.Message{
|
||||||
|
Role: coretypes.RoleAssistant, Content: reply,
|
||||||
|
})
|
||||||
|
a.persistMessage(ctx, roomID, coretypes.RoleAssistant, reply)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expanded = append(expanded, act)
|
expanded = append(expanded, act)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.runner.Execute(ctx, evt.RoomID.String(), expanded)
|
a.runner.Execute(ctx, roomID, expanded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (string, error) {
|
func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (string, error) {
|
||||||
@@ -201,8 +274,13 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str
|
|||||||
// Load system prompt from file if configured, else use description
|
// Load system prompt from file if configured, else use description
|
||||||
systemPrompt := a.cfg.Agent.Description
|
systemPrompt := a.cfg.Agent.Description
|
||||||
|
|
||||||
messages := []coretypes.Message{
|
// Build messages: conversation history from window (includes current user msg)
|
||||||
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
messages := a.getWindowMessages(msgCtx.RoomID)
|
||||||
|
if len(messages) == 0 {
|
||||||
|
// Fallback if memory is disabled: just the current message
|
||||||
|
messages = []coretypes.Message{
|
||||||
|
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build tool specs for the LLM if tool_use is enabled
|
// Build tool specs for the LLM if tool_use is enabled
|
||||||
@@ -294,8 +372,78 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str
|
|||||||
return "I've reached the maximum number of tool iterations. Here's what I found so far.", nil
|
return "I've reached the maximum number of tool iterations. Here's what I found so far.", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Memory helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ensureWindowLoaded loads the conversation window from SQLite on first access for a room.
|
||||||
|
func (a *Agent) ensureWindowLoaded(ctx context.Context, roomID string) {
|
||||||
|
a.windowsMu.Lock()
|
||||||
|
defer a.windowsMu.Unlock()
|
||||||
|
if _, ok := a.windows[roomID]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w := memory.NewWindow(a.windowSize)
|
||||||
|
if a.memStore != nil {
|
||||||
|
msgs, err := a.memStore.LoadMessages(ctx, a.cfg.Agent.ID, roomID, a.windowSize)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Warn("failed to load message history", "room", roomID, "err", err)
|
||||||
|
} else {
|
||||||
|
for _, m := range msgs {
|
||||||
|
w = w.Append(coretypes.Message{Role: m.Role, Content: m.Content})
|
||||||
|
}
|
||||||
|
if len(msgs) > 0 {
|
||||||
|
a.logger.Debug("loaded message history", "room", roomID, "count", len(msgs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.windows[roomID] = w
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendToWindow adds a message to the in-memory conversation window.
|
||||||
|
func (a *Agent) appendToWindow(roomID string, msg coretypes.Message) {
|
||||||
|
a.windowsMu.Lock()
|
||||||
|
defer a.windowsMu.Unlock()
|
||||||
|
w, ok := a.windows[roomID]
|
||||||
|
if !ok {
|
||||||
|
w = memory.NewWindow(a.windowSize)
|
||||||
|
}
|
||||||
|
a.windows[roomID] = w.Append(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWindowMessages returns a copy of the conversation window for a room.
|
||||||
|
func (a *Agent) getWindowMessages(roomID string) []coretypes.Message {
|
||||||
|
a.windowsMu.RLock()
|
||||||
|
defer a.windowsMu.RUnlock()
|
||||||
|
w, ok := a.windows[roomID]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ToLLMMessages()
|
||||||
|
}
|
||||||
|
|
||||||
|
// persistMessage saves a message to the SQLite store (no-op if store is nil).
|
||||||
|
func (a *Agent) persistMessage(ctx context.Context, roomID string, role coretypes.Role, content string) {
|
||||||
|
if a.memStore == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := a.memStore.SaveMessage(ctx, memory.HistoryMessage{
|
||||||
|
AgentID: a.cfg.Agent.ID,
|
||||||
|
RoomID: roomID,
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
}); err != nil {
|
||||||
|
a.logger.Warn("failed to persist message", "room", roomID, "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buildToolRegistry creates a Registry with tools enabled in the agent's config.
|
// buildToolRegistry creates a Registry with tools enabled in the agent's config.
|
||||||
func buildToolRegistry(cfg *config.AgentConfig, sshExec *ssh.Executor, matrixClient *matrix.Client, logger *slog.Logger) *tools.Registry {
|
func buildToolRegistry(
|
||||||
|
cfg *config.AgentConfig,
|
||||||
|
sshExec *ssh.Executor,
|
||||||
|
matrixClient *matrix.Client,
|
||||||
|
memStore memory.Store,
|
||||||
|
roomCtx *tools.RoomContext,
|
||||||
|
logger *slog.Logger,
|
||||||
|
) *tools.Registry {
|
||||||
reg := tools.NewRegistry()
|
reg := tools.NewRegistry()
|
||||||
|
|
||||||
if cfg.Tools.HTTP.Enabled {
|
if cfg.Tools.HTTP.Enabled {
|
||||||
@@ -322,5 +470,14 @@ func buildToolRegistry(cfg *config.AgentConfig, sshExec *ssh.Executor, matrixCli
|
|||||||
reg.Register(tools.NewMatrixSend(matrixClient))
|
reg.Register(tools.NewMatrixSend(matrixClient))
|
||||||
logger.Debug("registered matrix tool")
|
logger.Debug("registered matrix tool")
|
||||||
|
|
||||||
|
// Memory tools (memory_clear_context registered later since it needs the Agent)
|
||||||
|
if cfg.Tools.Memory.Enabled && memStore != nil {
|
||||||
|
reg.Register(tools.NewMemorySave(cfg.Agent.ID, memStore))
|
||||||
|
reg.Register(tools.NewMemoryRecall(cfg.Agent.ID, memStore))
|
||||||
|
reg.Register(tools.NewMemoryForget(cfg.Agent.ID, memStore))
|
||||||
|
reg.Register(tools.NewMemorySummary(cfg.Agent.ID, memStore))
|
||||||
|
logger.Debug("registered memory tools")
|
||||||
|
}
|
||||||
|
|
||||||
return reg
|
return reg
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type AgentConfig struct {
|
|||||||
Observability ObservabilityCfg `yaml:"observability"`
|
Observability ObservabilityCfg `yaml:"observability"`
|
||||||
Resilience ResilienceCfg `yaml:"resilience"`
|
Resilience ResilienceCfg `yaml:"resilience"`
|
||||||
Storage StorageCfg `yaml:"storage"`
|
Storage StorageCfg `yaml:"storage"`
|
||||||
|
Memory MemoryCfg `yaml:"memory"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Identity ──────────────────────────────────────────────────────────────
|
// ── Identity ──────────────────────────────────────────────────────────────
|
||||||
@@ -107,6 +108,7 @@ type ToolsCfg struct {
|
|||||||
Scripts ScriptsCfg `yaml:"scripts"`
|
Scripts ScriptsCfg `yaml:"scripts"`
|
||||||
FileOps FileOpsCfg `yaml:"file_ops"`
|
FileOps FileOpsCfg `yaml:"file_ops"`
|
||||||
MCP MCPToolCfg `yaml:"mcp"`
|
MCP MCPToolCfg `yaml:"mcp"`
|
||||||
|
Memory MemoryToolCfg `yaml:"memory"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SSHToolCfg struct {
|
type SSHToolCfg struct {
|
||||||
@@ -384,3 +386,15 @@ type HistoryStorageCfg struct {
|
|||||||
Path string `yaml:"path"`
|
Path string `yaml:"path"`
|
||||||
Retention time.Duration `yaml:"retention"`
|
Retention time.Duration `yaml:"retention"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Memory ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type MemoryCfg struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
WindowSize int `yaml:"window_size"` // sliding window size per room (default 20)
|
||||||
|
DBPath string `yaml:"db_path"` // SQLite path (default agents/<id>/data/memory.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MemoryToolCfg struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package memory
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Store is the interface for persistent memory operations.
|
||||||
|
// Defined in the pure package; implemented by shell/memory.
|
||||||
|
type Store interface {
|
||||||
|
// Facts
|
||||||
|
SaveFact(ctx context.Context, fact Fact) error
|
||||||
|
RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]Fact, error)
|
||||||
|
DeleteFacts(ctx context.Context, agentID, subject string, key *string) error
|
||||||
|
|
||||||
|
// Message history
|
||||||
|
SaveMessage(ctx context.Context, msg HistoryMessage) error
|
||||||
|
LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]HistoryMessage, error)
|
||||||
|
DeleteMessages(ctx context.Context, agentID string, roomID *string) error
|
||||||
|
|
||||||
|
// Lifecycle
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
// Package memory provides pure types for agent memory: conversation windows and episodic facts.
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/enmanuel/agents/pkg/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fact is a single episodic fact: a key-value pair scoped to a subject.
|
||||||
|
type Fact struct {
|
||||||
|
AgentID string
|
||||||
|
Subject string
|
||||||
|
Key string
|
||||||
|
Value string
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// HistoryMessage is a persisted conversation message.
|
||||||
|
type HistoryMessage struct {
|
||||||
|
AgentID string
|
||||||
|
RoomID string
|
||||||
|
Role llm.Role
|
||||||
|
Content string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package memory
|
||||||
|
|
||||||
|
import "github.com/enmanuel/agents/pkg/llm"
|
||||||
|
|
||||||
|
// Window is an immutable sliding window of conversation messages for a single room.
|
||||||
|
type Window struct {
|
||||||
|
messages []llm.Message
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWindow creates an empty window with the given capacity.
|
||||||
|
func NewWindow(maxSize int) Window {
|
||||||
|
return Window{maxSize: maxSize}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append returns a new Window with the message added, dropping the oldest
|
||||||
|
// messages if capacity is exceeded.
|
||||||
|
func (w Window) Append(msg llm.Message) Window {
|
||||||
|
msgs := make([]llm.Message, len(w.messages), len(w.messages)+1)
|
||||||
|
copy(msgs, w.messages)
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
if len(msgs) > w.maxSize {
|
||||||
|
msgs = msgs[len(msgs)-w.maxSize:]
|
||||||
|
}
|
||||||
|
return Window{messages: msgs, maxSize: w.maxSize}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLLMMessages returns a copy of the window contents as []llm.Message.
|
||||||
|
func (w Window) ToLLMMessages() []llm.Message {
|
||||||
|
out := make([]llm.Message, len(w.messages))
|
||||||
|
copy(out, w.messages)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of messages in the window.
|
||||||
|
func (w Window) Len() int {
|
||||||
|
return len(w.messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear returns an empty window with the same capacity.
|
||||||
|
func (w Window) Clear() Window {
|
||||||
|
return NewWindow(w.maxSize)
|
||||||
|
}
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
// Package shellmem implements persistent memory storage using SQLite.
|
||||||
|
package shellmem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/enmanuel/agents/pkg/llm"
|
||||||
|
"github.com/enmanuel/agents/pkg/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
const schema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS facts (
|
||||||
|
agent_id TEXT NOT NULL,
|
||||||
|
subject TEXT NOT NULL,
|
||||||
|
key TEXT NOT NULL,
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY (agent_id, subject, key)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
agent_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(agent_id, room_id, created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_facts_subject ON facts(agent_id, subject);
|
||||||
|
`
|
||||||
|
|
||||||
|
// SQLiteStore implements memory.Store using SQLite.
|
||||||
|
type SQLiteStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// New opens (or creates) a SQLite database at dbPath and runs migrations.
|
||||||
|
func New(dbPath string) (*SQLiteStore, error) {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||||
|
return nil, fmt.Errorf("create memory db dir: %w", err)
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open memory db: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(schema); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("migrate memory db: %w", err)
|
||||||
|
}
|
||||||
|
return &SQLiteStore{db: db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) SaveFact(ctx context.Context, f memory.Fact) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT OR REPLACE INTO facts (agent_id, subject, key, value, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
f.AgentID, f.Subject, f.Key, f.Value, time.Now().UTC(),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error) {
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if key != nil {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||||
|
WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||||
|
agentID, subject, *key,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT agent_id, subject, key, value, updated_at FROM facts
|
||||||
|
WHERE agent_id = ? AND subject = ?`,
|
||||||
|
agentID, subject,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var facts []memory.Fact
|
||||||
|
for rows.Next() {
|
||||||
|
var f memory.Fact
|
||||||
|
if err := rows.Scan(&f.AgentID, &f.Subject, &f.Key, &f.Value, &f.UpdatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
facts = append(facts, f)
|
||||||
|
}
|
||||||
|
return facts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) DeleteFacts(ctx context.Context, agentID, subject string, key *string) error {
|
||||||
|
if key != nil {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM facts WHERE agent_id = ? AND subject = ? AND key = ?`,
|
||||||
|
agentID, subject, *key,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM facts WHERE agent_id = ? AND subject = ?`,
|
||||||
|
agentID, subject,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) SaveMessage(ctx context.Context, m memory.HistoryMessage) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO messages (agent_id, room_id, role, content, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
m.AgentID, m.RoomID, string(m.Role), m.Content, time.Now().UTC(),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) LoadMessages(ctx context.Context, agentID, roomID string, limit int) ([]memory.HistoryMessage, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx,
|
||||||
|
`SELECT agent_id, room_id, role, content, created_at FROM messages
|
||||||
|
WHERE agent_id = ? AND room_id = ?
|
||||||
|
ORDER BY created_at DESC LIMIT ?`,
|
||||||
|
agentID, roomID, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var msgs []memory.HistoryMessage
|
||||||
|
for rows.Next() {
|
||||||
|
var m memory.HistoryMessage
|
||||||
|
var role string
|
||||||
|
if err := rows.Scan(&m.AgentID, &m.RoomID, &role, &m.Content, &m.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.Role = llm.Role(role)
|
||||||
|
msgs = append(msgs, m)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse to chronological order
|
||||||
|
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||||
|
}
|
||||||
|
return msgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) DeleteMessages(ctx context.Context, agentID string, roomID *string) error {
|
||||||
|
if roomID != nil {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM messages WHERE agent_id = ? AND room_id = ?`,
|
||||||
|
agentID, *roomID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM messages WHERE agent_id = ?`,
|
||||||
|
agentID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) Close() error {
|
||||||
|
return s.db.Close()
|
||||||
|
}
|
||||||
+199
@@ -0,0 +1,199 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/enmanuel/agents/pkg/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryStore is the subset of memory.Store needed by memory tools.
|
||||||
|
type MemoryStore interface {
|
||||||
|
SaveFact(ctx context.Context, fact memory.Fact) error
|
||||||
|
RecallFacts(ctx context.Context, agentID, subject string, key *string) ([]memory.Fact, error)
|
||||||
|
DeleteFacts(ctx context.Context, agentID, subject string, key *string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// WindowClearer allows tools to clear the conversation window for a room.
|
||||||
|
type WindowClearer interface {
|
||||||
|
ClearWindow(roomID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomContext is a thread-safe holder for the current room ID.
|
||||||
|
// Set by the runtime before each event handling; read by memory_clear_context.
|
||||||
|
type RoomContext struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
roomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set updates the current room ID.
|
||||||
|
func (rc *RoomContext) Set(roomID string) {
|
||||||
|
rc.mu.Lock()
|
||||||
|
rc.roomID = roomID
|
||||||
|
rc.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the current room ID.
|
||||||
|
func (rc *RoomContext) Get() string {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
return rc.roomID
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemorySave creates a tool that saves a fact to long-term memory.
|
||||||
|
func NewMemorySave(agentID string, store MemoryStore) Tool {
|
||||||
|
return Tool{
|
||||||
|
Def: Def{
|
||||||
|
Name: "memory_save",
|
||||||
|
Description: "Save a fact to long-term memory. Use this to remember important information about users, topics, or preferences.",
|
||||||
|
Parameters: []Param{
|
||||||
|
{Name: "subject", Type: "string", Description: "The subject this fact is about (e.g. a username, a topic)", Required: true},
|
||||||
|
{Name: "key", Type: "string", Description: "The fact key (e.g. 'favorite_language', 'timezone')", Required: true},
|
||||||
|
{Name: "value", Type: "string", Description: "The fact value to store", Required: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||||
|
subject := getString(args, "subject")
|
||||||
|
key := getString(args, "key")
|
||||||
|
value := getString(args, "value")
|
||||||
|
if subject == "" || key == "" || value == "" {
|
||||||
|
return Result{Err: fmt.Errorf("memory_save: subject, key, and value are required")}
|
||||||
|
}
|
||||||
|
err := store.SaveFact(ctx, memory.Fact{
|
||||||
|
AgentID: agentID,
|
||||||
|
Subject: subject,
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return Result{Err: fmt.Errorf("memory_save: %w", err)}
|
||||||
|
}
|
||||||
|
return Result{Output: fmt.Sprintf("saved: %s.%s = %s", subject, key, value)}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryRecall creates a tool that retrieves facts from long-term memory.
|
||||||
|
func NewMemoryRecall(agentID string, store MemoryStore) Tool {
|
||||||
|
return Tool{
|
||||||
|
Def: Def{
|
||||||
|
Name: "memory_recall",
|
||||||
|
Description: "Recall facts from long-term memory about a subject. Omit key to get all facts for the subject.",
|
||||||
|
Parameters: []Param{
|
||||||
|
{Name: "subject", Type: "string", Description: "The subject to recall facts about", Required: true},
|
||||||
|
{Name: "key", Type: "string", Description: "Optional specific fact key to recall", Required: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||||
|
subject := getString(args, "subject")
|
||||||
|
if subject == "" {
|
||||||
|
return Result{Err: fmt.Errorf("memory_recall: subject is required")}
|
||||||
|
}
|
||||||
|
var keyPtr *string
|
||||||
|
if k := getString(args, "key"); k != "" {
|
||||||
|
keyPtr = &k
|
||||||
|
}
|
||||||
|
facts, err := store.RecallFacts(ctx, agentID, subject, keyPtr)
|
||||||
|
if err != nil {
|
||||||
|
return Result{Err: fmt.Errorf("memory_recall: %w", err)}
|
||||||
|
}
|
||||||
|
if len(facts) == 0 {
|
||||||
|
return Result{Output: fmt.Sprintf("no facts found for subject %q", subject)}
|
||||||
|
}
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, f := range facts {
|
||||||
|
fmt.Fprintf(&sb, "%s.%s = %s\n", f.Subject, f.Key, f.Value)
|
||||||
|
}
|
||||||
|
return Result{Output: sb.String()}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryForget creates a tool that deletes facts from long-term memory.
|
||||||
|
func NewMemoryForget(agentID string, store MemoryStore) Tool {
|
||||||
|
return Tool{
|
||||||
|
Def: Def{
|
||||||
|
Name: "memory_forget",
|
||||||
|
Description: "Delete facts from long-term memory. Omit key to delete all facts for the subject.",
|
||||||
|
Parameters: []Param{
|
||||||
|
{Name: "subject", Type: "string", Description: "The subject whose facts to delete", Required: true},
|
||||||
|
{Name: "key", Type: "string", Description: "Optional specific fact key to delete; omit to delete all", Required: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||||
|
subject := getString(args, "subject")
|
||||||
|
if subject == "" {
|
||||||
|
return Result{Err: fmt.Errorf("memory_forget: subject is required")}
|
||||||
|
}
|
||||||
|
var keyPtr *string
|
||||||
|
if k := getString(args, "key"); k != "" {
|
||||||
|
keyPtr = &k
|
||||||
|
}
|
||||||
|
err := store.DeleteFacts(ctx, agentID, subject, keyPtr)
|
||||||
|
if err != nil {
|
||||||
|
return Result{Err: fmt.Errorf("memory_forget: %w", err)}
|
||||||
|
}
|
||||||
|
if keyPtr != nil {
|
||||||
|
return Result{Output: fmt.Sprintf("forgot %s.%s", subject, *keyPtr)}
|
||||||
|
}
|
||||||
|
return Result{Output: fmt.Sprintf("forgot all facts about %s", subject)}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryClearContext creates a tool that clears the conversation window.
|
||||||
|
func NewMemoryClearContext(clearer WindowClearer, roomCtx *RoomContext) Tool {
|
||||||
|
return Tool{
|
||||||
|
Def: Def{
|
||||||
|
Name: "memory_clear_context",
|
||||||
|
Description: "Clear the conversation context window. Useful to start fresh. Omit room_id to clear the current room.",
|
||||||
|
Parameters: []Param{
|
||||||
|
{Name: "room_id", Type: "string", Description: "Optional room ID to clear; defaults to current room", Required: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||||
|
roomID := getString(args, "room_id")
|
||||||
|
if roomID == "" {
|
||||||
|
roomID = roomCtx.Get()
|
||||||
|
}
|
||||||
|
if roomID == "" {
|
||||||
|
return Result{Err: fmt.Errorf("memory_clear_context: no room_id provided and no current room")}
|
||||||
|
}
|
||||||
|
clearer.ClearWindow(roomID)
|
||||||
|
return Result{Output: fmt.Sprintf("conversation context cleared for room %s", roomID)}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemorySummary creates a tool that saves an important summary to long-term memory.
|
||||||
|
func NewMemorySummary(agentID string, store MemoryStore) Tool {
|
||||||
|
return Tool{
|
||||||
|
Def: Def{
|
||||||
|
Name: "memory_summary",
|
||||||
|
Description: "Save an important summary or takeaway from the current conversation to long-term memory.",
|
||||||
|
Parameters: []Param{
|
||||||
|
{Name: "text", Type: "string", Description: "The summary text to save", Required: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||||
|
text := getString(args, "text")
|
||||||
|
if text == "" {
|
||||||
|
return Result{Err: fmt.Errorf("memory_summary: text is required")}
|
||||||
|
}
|
||||||
|
key := time.Now().UTC().Format("2006-01-02T15:04:05")
|
||||||
|
err := store.SaveFact(ctx, memory.Fact{
|
||||||
|
AgentID: agentID,
|
||||||
|
Subject: "_summary",
|
||||||
|
Key: key,
|
||||||
|
Value: text,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return Result{Err: fmt.Errorf("memory_summary: %w", err)}
|
||||||
|
}
|
||||||
|
return Result{Output: "summary saved"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user