feat: import agents_and_robots platform as unibots (Matrix-out, unibus transport)
Reemplaza el scaffold del echobot por la plataforma completa de bots traída desde ~/DataProyects/Github/agents_and_robots tras la operación Matrix-out: los bots ya no hablan por Matrix sino por el bus unibus (modelo todo-rooms + E2E vía shell/transportunibus sobre github.com/enmanuel/unibus/pkg/client). - go.mod: replace de unibus -> ../unibus y de fn-registry -> ../../../.. (paths relativos reajustados a la nueva ubicación dentro de fn_registry). - app.md: bump a 0.2.0, descripción + arquitectura + comandos + gotchas reales. - módulo Go conservado como github.com/enmanuel/agents (sin reescribir imports). agents_and_robots queda archivado como museo de la era Matrix.
This commit is contained in:
@@ -0,0 +1,252 @@
|
||||
// Command launcher starts one or more agents from their config files.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go run ./cmd/launcher # auto-discovers agents/*/config.yaml
|
||||
// go run ./cmd/launcher -c agents/assistant/config.yaml
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/enmanuel/agents/agents"
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
pksecurity "github.com/enmanuel/agents/pkg/security"
|
||||
"github.com/enmanuel/agents/shell/bus"
|
||||
agentlog "github.com/enmanuel/agents/shell/logger"
|
||||
shellsecurity "github.com/enmanuel/agents/shell/security"
|
||||
|
||||
// Blank imports: each agent self-registers its rules via init().
|
||||
_ "github.com/enmanuel/agents/agents/asistente-2"
|
||||
_ "github.com/enmanuel/agents/agents/assistant-bot"
|
||||
_ "github.com/enmanuel/agents/agents/meteorologo"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
configPaths []string
|
||||
logLevel string
|
||||
logDir string
|
||||
)
|
||||
|
||||
root := &cobra.Command{
|
||||
Use: "launcher",
|
||||
Short: "Start Matrix agents from config files",
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(configPaths) == 0 {
|
||||
matches, _ := filepath.Glob("agents/*/config.yaml")
|
||||
configPaths = matches
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
lvl := parseLogLevel(logLevel)
|
||||
|
||||
// ── Launcher-level logger ──
|
||||
logger, launcherCleanup, err := agentlog.NewAgentLogger(agentlog.LoggerConfig{
|
||||
BaseDir: logDir,
|
||||
AgentID: "launcher",
|
||||
Level: lvl,
|
||||
})
|
||||
if err != nil {
|
||||
// Fallback to stdout if file logger fails.
|
||||
logger = newLogger(logLevel)
|
||||
logger.Warn("could not create file logger, falling back to stdout", "err", err)
|
||||
launcherCleanup = func() {}
|
||||
}
|
||||
defer launcherCleanup()
|
||||
|
||||
if len(configPaths) == 0 {
|
||||
logger.Warn("no agent configs found — nothing to start")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// ── Load centralized security policy ──
|
||||
secPolicy, secErr := shellsecurity.Load("security/")
|
||||
if secErr != nil {
|
||||
logger.Warn("security policy load failed, using empty policy (open access)", "err", secErr)
|
||||
secPolicy = pksecurity.SecurityPolicy{}
|
||||
} else {
|
||||
logger.Info("security policy loaded",
|
||||
"user_groups", len(secPolicy.UserGroups),
|
||||
"agent_groups", len(secPolicy.AgentGroups),
|
||||
"policies", len(secPolicy.Policies),
|
||||
)
|
||||
}
|
||||
|
||||
// ── Shared bus for inter-agent communication ──
|
||||
agentBus := bus.New(logger)
|
||||
|
||||
// NOTE: the multi-bot orchestrator is parked (Matrix-out). Its room
|
||||
// discovery was Matrix-intrinsic and has been removed; it is no longer
|
||||
// wired into the launcher. Re-introducing it over unibus is a later step.
|
||||
|
||||
// ── Shared dependencies for agent registry ──
|
||||
deps := &launchDeps{
|
||||
agentBus: agentBus,
|
||||
logDir: logDir,
|
||||
logLevel: lvl,
|
||||
parentCtx: ctx,
|
||||
secPolicy: secPolicy,
|
||||
}
|
||||
registry := newAgentRegistry(deps)
|
||||
|
||||
// ── SIGHUP: hot-reload individual agent or all agents ──
|
||||
sighup := make(chan os.Signal, 1)
|
||||
signal.Notify(sighup, syscall.SIGHUP)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case _, ok := <-sighup:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id := readReloadTarget("run/reload.txt")
|
||||
// Remove the target file after reading so it doesn't
|
||||
// affect the next SIGHUP.
|
||||
_ = os.Remove("run/reload.txt")
|
||||
if id == "" {
|
||||
logger.Info("sighup: reloading all agents")
|
||||
registry.reloadAll(rulesFor)
|
||||
} else {
|
||||
logger.Info("sighup: reloading agent", "id", id)
|
||||
registry.reload(id, rulesFor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// ── Start normal agents ──
|
||||
for _, path := range configPaths {
|
||||
path := path
|
||||
cfg, err := config.Load(path)
|
||||
if err != nil {
|
||||
logger.Error("failed to load config", "path", path, "err", err)
|
||||
continue
|
||||
}
|
||||
if !cfg.Agent.Enabled {
|
||||
logger.Info("agent disabled, skipping", "id", cfg.Agent.ID)
|
||||
continue
|
||||
}
|
||||
if cfg.Agent.Template {
|
||||
logger.Info("agent is template, skipping", "id", cfg.Agent.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Per-agent logger → writes to logs/<agent-id>/YYYY-MM-DD.jsonl
|
||||
agentLogger, agentCleanup, aErr := agentlog.NewAgentLogger(agentlog.LoggerConfig{
|
||||
BaseDir: logDir,
|
||||
AgentID: cfg.Agent.ID,
|
||||
Level: lvl,
|
||||
})
|
||||
if aErr != nil {
|
||||
logger.Warn("agent file logger failed, using launcher logger", "agent", cfg.Agent.ID, "err", aErr)
|
||||
agentLogger = logger.With("agent", cfg.Agent.ID)
|
||||
agentCleanup = func() {}
|
||||
}
|
||||
|
||||
// Branch: robot (command-only, lightweight) vs agent (full runtime).
|
||||
var runner agents.Runner
|
||||
|
||||
if cfg.Agent.Type == "robot" {
|
||||
robot, rErr := agents.NewRobot(cfg, agentLogger)
|
||||
if rErr != nil {
|
||||
logger.Error("failed to create robot", "id", cfg.Agent.ID, "err", rErr)
|
||||
agentCleanup()
|
||||
continue
|
||||
}
|
||||
runner = robot
|
||||
agentLogger.Info("created robot", "id", cfg.Agent.ID)
|
||||
} else {
|
||||
rules := rulesFor(cfg.Agent.ID, logger)
|
||||
|
||||
// Resolve centralized ACL for this agent
|
||||
agentACL := pksecurity.ResolveACL(cfg.Agent.ID, deps.secPolicy)
|
||||
agentLogger.Debug("resolved acl for agent",
|
||||
"agent", cfg.Agent.ID,
|
||||
"acl_empty", agentACL.Empty(),
|
||||
)
|
||||
|
||||
a, cErr := agents.New(cfg, rules, agentACL, agentLogger)
|
||||
if cErr != nil {
|
||||
logger.Error("failed to create agent", "id", cfg.Agent.ID, "err", cErr)
|
||||
agentCleanup()
|
||||
continue
|
||||
}
|
||||
|
||||
// Connect agent to the inter-agent bus.
|
||||
a.SetBus(agentBus)
|
||||
|
||||
runner = a
|
||||
}
|
||||
|
||||
registry.register(&runningAgent{
|
||||
runner: runner,
|
||||
cfg: cfg,
|
||||
cfgPath: path,
|
||||
logger: agentLogger,
|
||||
logCleanup: agentCleanup,
|
||||
})
|
||||
}
|
||||
|
||||
registry.waitAll()
|
||||
registry.cleanupLogs()
|
||||
logger.Info("all agents stopped")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
root.Flags().StringSliceVarP(&configPaths, "config", "c", nil,
|
||||
"Agent config file(s). If omitted, discovers all agents/*/config.yaml")
|
||||
root.Flags().StringVar(&logLevel, "log-level", "info",
|
||||
"Log level: debug | info | warn | error")
|
||||
root.Flags().StringVar(&logDir, "log-dir", "logs",
|
||||
`Log directory (logs/<agent>/YYYY-MM-DD.jsonl). Use "stdout" for console only`)
|
||||
|
||||
if err := root.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// rulesFor retrieves the rule factory for the given agent ID from the
|
||||
// global registry (populated by init() in each agent package).
|
||||
// Returns nil if no rules are registered (command-only bot).
|
||||
func rulesFor(agentID string, logger *slog.Logger) []decision.Rule {
|
||||
factory := agents.GetRules(agentID)
|
||||
if factory == nil {
|
||||
logger.Warn("no rules registered for agent, using empty ruleset (command-only)", "id", agentID)
|
||||
return nil
|
||||
}
|
||||
return factory()
|
||||
}
|
||||
|
||||
// parseLogLevel converts the --log-level flag value into a slog.Level.
//
// Parsing is delegated to slog.Level.UnmarshalText, so the names "debug",
// "info", "warn" and "error" are matched case-insensitively (e.g. "DEBUG").
// Any value that cannot be parsed, including the empty string, falls back to
// slog.LevelInfo.
func parseLogLevel(level string) slog.Level {
	var parsed slog.Level
	if err := parsed.UnmarshalText([]byte(level)); err != nil {
		return slog.LevelInfo
	}
	return parsed
}
|
||||
|
||||
// newLogger creates a stdout-only JSON logger (fallback when file logger fails).
|
||||
func newLogger(level string) *slog.Logger {
|
||||
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: parseLogLevel(level)}))
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/agents"
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
"github.com/enmanuel/agents/pkg/decision"
|
||||
pksecurity "github.com/enmanuel/agents/pkg/security"
|
||||
"github.com/enmanuel/agents/shell/bus"
|
||||
agentlog "github.com/enmanuel/agents/shell/logger"
|
||||
)
|
||||
|
||||
// runningAgent holds a live runner (Agent or Robot) and the metadata needed to recreate it.
type runningAgent struct {
	runner     agents.Runner       // the live Agent or Robot instance
	cfg        *config.AgentConfig // config the runner was built from
	cfgPath    string              // file that reload re-reads to rebuild the runner
	logger     *slog.Logger        // logger handed to the runner and used for its lifecycle messages
	logCleanup func()              // releases the logger's resources; called on reload and at shutdown, may be nil
}
|
||||
|
||||
// launchDeps holds shared resources needed to start/reload agents.
type launchDeps struct {
	agentBus  *bus.Bus                  // inter-agent bus wired into every non-robot agent
	logDir    string                    // base directory for per-agent loggers created on reload
	logLevel  slog.Level                // level applied to per-agent loggers created on reload
	parentCtx context.Context           // context passed to every runner's Run
	secPolicy pksecurity.SecurityPolicy // centralized security policy loaded from security/
}
|
||||
|
||||
// agentRegistry tracks all running agents by ID, enabling individual hot-reload.
type agentRegistry struct {
	mu     sync.Mutex               // guards agents
	agents map[string]*runningAgent // keyed by agent ID (cfg.Agent.ID at registration time)
	deps   *launchDeps              // shared resources used when (re)starting runners
}
|
||||
|
||||
func newAgentRegistry(deps *launchDeps) *agentRegistry {
|
||||
return &agentRegistry{
|
||||
agents: make(map[string]*runningAgent),
|
||||
deps: deps,
|
||||
}
|
||||
}
|
||||
|
||||
// register adds a running agent/robot to the registry and starts its goroutine.
|
||||
func (r *agentRegistry) register(ra *runningAgent) {
|
||||
r.mu.Lock()
|
||||
r.agents[ra.cfg.Agent.ID] = ra
|
||||
r.mu.Unlock()
|
||||
|
||||
runtimeType := ra.cfg.Agent.Type
|
||||
if runtimeType == "" {
|
||||
runtimeType = "agent"
|
||||
}
|
||||
|
||||
go func() {
|
||||
ra.logger.Info("runner started", "type", runtimeType)
|
||||
if err := ra.runner.Run(r.deps.parentCtx); err != nil {
|
||||
ra.logger.Error("runner stopped with error", "err", err, "type", runtimeType)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// stopAndWait stops a running agent/robot and waits for it to finish.
|
||||
// Caller must NOT hold r.mu.
|
||||
func (r *agentRegistry) stopAndWait(id string) {
|
||||
r.mu.Lock()
|
||||
ra, ok := r.agents[id]
|
||||
r.mu.Unlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ra.runner.Stop()
|
||||
select {
|
||||
case <-ra.runner.Done():
|
||||
case <-time.After(10 * time.Second):
|
||||
ra.logger.Warn("runner did not stop within 10s, forcing", "id", id)
|
||||
}
|
||||
|
||||
// Unsubscribe from bus so no stale channel remains.
|
||||
r.deps.agentBus.Unsubscribe(bus.AgentID(id))
|
||||
}
|
||||
|
||||
// reload stops an agent, re-reads its config, recreates it, and restarts it.
|
||||
func (r *agentRegistry) reload(id string, rulesFor func(string, *slog.Logger) []decision.Rule) {
|
||||
r.mu.Lock()
|
||||
ra, ok := r.agents[id]
|
||||
r.mu.Unlock()
|
||||
if !ok {
|
||||
slog.Warn("reload: agent not found", "id", id)
|
||||
return
|
||||
}
|
||||
|
||||
cfgPath := ra.cfgPath
|
||||
oldCleanup := ra.logCleanup
|
||||
|
||||
ra.logger.Info("agent_reload_start", "id", id)
|
||||
|
||||
// 1. Stop current instance and wait.
|
||||
r.stopAndWait(id)
|
||||
|
||||
// 2. Cleanup old log writer.
|
||||
if oldCleanup != nil {
|
||||
oldCleanup()
|
||||
}
|
||||
|
||||
// 3. Re-read config.
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
slog.Error("reload: failed to load config", "path", cfgPath, "err", err)
|
||||
return
|
||||
}
|
||||
if !cfg.Agent.Enabled {
|
||||
slog.Info("reload: agent is disabled, not restarting", "id", id)
|
||||
r.mu.Lock()
|
||||
delete(r.agents, id)
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// 4. New per-agent logger.
|
||||
newLogger, newCleanup, aErr := agentlog.NewAgentLogger(agentlog.LoggerConfig{
|
||||
BaseDir: r.deps.logDir,
|
||||
AgentID: cfg.Agent.ID,
|
||||
Level: r.deps.logLevel,
|
||||
})
|
||||
if aErr != nil {
|
||||
newLogger = slog.Default().With("agent", cfg.Agent.ID)
|
||||
newCleanup = func() {}
|
||||
}
|
||||
|
||||
// 5. Create new runner (validates config before discarding the old one).
|
||||
var newRunner agents.Runner
|
||||
|
||||
if cfg.Agent.Type == "robot" {
|
||||
robot, rErr := agents.NewRobot(cfg, newLogger)
|
||||
if rErr != nil {
|
||||
newLogger.Error("reload: failed to create robot", "id", id, "err", rErr)
|
||||
newCleanup()
|
||||
return
|
||||
}
|
||||
newRunner = robot
|
||||
} else {
|
||||
rules := rulesFor(cfg.Agent.ID, newLogger)
|
||||
agentACL := pksecurity.ResolveACL(cfg.Agent.ID, r.deps.secPolicy)
|
||||
newLogger.Debug("resolved acl for agent (reload)", "agent", cfg.Agent.ID, "acl_empty", agentACL.Empty())
|
||||
newAgent, aErr := agents.New(cfg, rules, agentACL, newLogger)
|
||||
if aErr != nil {
|
||||
newLogger.Error("reload: failed to create agent", "id", id, "err", aErr)
|
||||
newCleanup()
|
||||
return
|
||||
}
|
||||
|
||||
// Wire bus (orchestration is parked; only agents connect to the bus).
|
||||
newAgent.SetBus(r.deps.agentBus)
|
||||
newRunner = newAgent
|
||||
}
|
||||
|
||||
newRA := &runningAgent{
|
||||
runner: newRunner,
|
||||
cfg: cfg,
|
||||
cfgPath: cfgPath,
|
||||
logger: newLogger,
|
||||
logCleanup: newCleanup,
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.agents[id] = newRA
|
||||
r.mu.Unlock()
|
||||
|
||||
// 7. Start new goroutine.
|
||||
runtimeType := cfg.Agent.Type
|
||||
if runtimeType == "" {
|
||||
runtimeType = "agent"
|
||||
}
|
||||
go func() {
|
||||
newLogger.Info("runner started", "type", runtimeType)
|
||||
if err := newRunner.Run(r.deps.parentCtx); err != nil {
|
||||
newLogger.Error("runner stopped with error", "err", err, "type", runtimeType)
|
||||
}
|
||||
}()
|
||||
|
||||
newLogger.Info("runner_reloaded", "id", id, "type", runtimeType)
|
||||
}
|
||||
|
||||
// reloadAll reloads every registered agent sequentially.
|
||||
func (r *agentRegistry) reloadAll(rulesFor func(string, *slog.Logger) []decision.Rule) {
|
||||
r.mu.Lock()
|
||||
ids := make([]string, 0, len(r.agents))
|
||||
for id := range r.agents {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, id := range ids {
|
||||
r.reload(id, rulesFor)
|
||||
}
|
||||
}
|
||||
|
||||
// waitAll blocks until all registered runners have stopped.
|
||||
func (r *agentRegistry) waitAll() {
|
||||
r.mu.Lock()
|
||||
dones := make([]<-chan struct{}, 0, len(r.agents))
|
||||
for _, ra := range r.agents {
|
||||
dones = append(dones, ra.runner.Done())
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, done := range dones {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupLogs calls every agent's log cleanup function (called on launcher shutdown).
|
||||
func (r *agentRegistry) cleanupLogs() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, ra := range r.agents {
|
||||
if ra.logCleanup != nil {
|
||||
ra.logCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readReloadTarget returns the whitespace-trimmed content of the file at
// path, which names the agent to reload. It returns "" (meaning "reload all")
// when the file cannot be read, is empty or blank, or contains just "*".
func readReloadTarget(path string) string {
	raw, err := os.ReadFile(path)
	if err != nil {
		return ""
	}

	switch target := strings.TrimSpace(string(raw)); target {
	case "*":
		return ""
	default:
		return target
	}
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadReloadTarget_missing(t *testing.T) {
|
||||
got := readReloadTarget(filepath.Join(t.TempDir(), "reload.txt"))
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string for missing file, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadReloadTarget_empty(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "reload.txt")
|
||||
if err := os.WriteFile(f, []byte(""), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := readReloadTarget(f)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string for empty file, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadReloadTarget_star(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "reload.txt")
|
||||
if err := os.WriteFile(f, []byte("*\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := readReloadTarget(f)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string for '*', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadReloadTarget_agentID(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "reload.txt")
|
||||
if err := os.WriteFile(f, []byte("assistant-bot\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := readReloadTarget(f)
|
||||
if got != "assistant-bot" {
|
||||
t.Fatalf("expected 'assistant-bot', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadReloadTarget_whitespace(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "reload.txt")
|
||||
if err := os.WriteFile(f, []byte(" asistente-2 \n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := readReloadTarget(f)
|
||||
if got != "asistente-2" {
|
||||
t.Fatalf("expected 'asistente-2', got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
moderncsqlite "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// mautrix dbutil opens sqlite as "sqlite3"; register the pure-Go driver
|
||||
// under that name. We add a connection hook that sets WAL mode and a
|
||||
// busy timeout on every connection to prevent SQLITE_BUSY crashes during
|
||||
// concurrent writes (crypto store sync + memory store).
|
||||
d := &moderncsqlite.Driver{}
|
||||
d.RegisterConnectionHook(sqlitePragmaHook)
|
||||
sql.Register("sqlite3", d)
|
||||
}
|
||||
|
||||
// sqlitePragmaHook sets WAL journal mode and a 5-second busy timeout on
|
||||
// every new SQLite connection. This prevents SQLITE_BUSY errors when
|
||||
// multiple goroutines write concurrently (e.g. mautrix crypto sync +
|
||||
// memory/knowledge stores).
|
||||
func sqlitePragmaHook(conn moderncsqlite.ExecQuerierContext, _ string) error {
|
||||
ctx := context.Background()
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
}
|
||||
for _, p := range pragmas {
|
||||
if _, err := conn.ExecContext(ctx, p, []driver.NamedValue{}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSQLitePragmaHook opens a database through the registered "sqlite3"
// driver and checks that the connection hook took effect: journal_mode must
// report "wal" and busy_timeout must report 5000.
func TestSQLitePragmaHook(t *testing.T) {
	const (
		wantJournal = "wal"
		wantTimeout = 5000
	)

	dbPath := filepath.Join(t.TempDir(), "test.db")

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	defer db.Close()

	// sql.Open is lazy; Ping forces a real connection to be established.
	if err = db.Ping(); err != nil {
		t.Fatalf("ping: %v", err)
	}

	var gotJournal string
	err = db.QueryRow("PRAGMA journal_mode").Scan(&gotJournal)
	if err != nil {
		t.Fatalf("query journal_mode: %v", err)
	}
	if gotJournal != wantJournal {
		t.Errorf("journal_mode = %q, want %q", gotJournal, wantJournal)
	}

	var gotTimeout int
	err = db.QueryRow("PRAGMA busy_timeout").Scan(&gotTimeout)
	if err != nil {
		t.Fatalf("query busy_timeout: %v", err)
	}
	if gotTimeout != wantTimeout {
		t.Errorf("busy_timeout = %d, want %d", gotTimeout, wantTimeout)
	}
}
|
||||
|
||||
// TestSQLiteConcurrentWrites verifies that goroutines writing concurrently
// through one sql.DB do not get SQLITE_BUSY errors thanks to WAL mode and
// busy_timeout. Every writer upserts the same key, so the table must hold
// exactly one row afterwards.
func TestSQLiteConcurrentWrites(t *testing.T) {
	dir := t.TempDir()
	dbPath := filepath.Join(dir, "concurrent.db")

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	defer db.Close()

	// Create a table to write to.
	if _, err := db.Exec(`CREATE TABLE kv (k TEXT PRIMARY KEY, v TEXT)`); err != nil {
		t.Fatalf("create table: %v", err)
	}

	// Simulate the scenario: multiple goroutines writing concurrently,
	// like mautrix crypto sync + memory store + knowledge store.
	const writers = 5
	const writesPerWriter = 50
	ctx := context.Background()

	var wg sync.WaitGroup
	// Buffered for the worst case so a failing writer never blocks.
	errs := make(chan error, writers*writesPerWriter)

	for w := 0; w < writers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < writesPerWriter; i++ {
				_, err := db.ExecContext(ctx,
					`INSERT OR REPLACE INTO kv (k, v) VALUES (?, ?)`,
					// All writers deliberately hit the same key so that
					// every write contends on the same row.
					"key", "value",
				)
				if err != nil {
					errs <- err
				}
			}
		}()
	}

	wg.Wait()
	close(errs)

	for err := range errs {
		t.Errorf("concurrent write error: %v", err)
	}

	// Every write was an upsert of the same key: exactly one row must remain.
	var rows int
	if err := db.QueryRow(`SELECT COUNT(*) FROM kv`).Scan(&rows); err != nil {
		t.Fatalf("count: %v", err)
	}
	if rows != 1 {
		t.Errorf("row count = %d, want %d", rows, 1)
	}
}
|
||||
|
||||
// TestSQLiteConcurrentWritesSeparateConnections opens the same database file
// through two independent sql.DB handles and has each insert rows concurrently
// (the original scenario: crypto.db opened by both mautrix and our code). No
// insert may fail and the final row count must equal the total number of
// inserts.
func TestSQLiteConcurrentWritesSeparateConnections(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "shared.db")

	// Two separate handles on one file (simulates mautrix + our memory store
	// sharing a DB, or separate processes).
	db1, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open db1: %v", err)
	}
	defer db1.Close()

	db2, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open db2: %v", err)
	}
	defer db2.Close()

	// The schema is created through the first handle.
	_, err = db1.Exec(`CREATE TABLE t (id INTEGER PRIMARY KEY, data TEXT)`)
	if err != nil {
		t.Fatalf("create table: %v", err)
	}

	const iterations = 100
	ctx := context.Background()

	var wg sync.WaitGroup
	errs := make(chan error, iterations*2)

	// spawn starts one writer that inserts `iterations` rows through db.
	spawn := func(db *sql.DB, payload string) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < iterations; i++ {
				if _, insErr := db.ExecContext(ctx, `INSERT INTO t (data) VALUES (?)`, payload); insErr != nil {
					errs <- insErr
				}
			}
		}()
	}

	spawn(db1, "from_crypto_sync")  // simulates mautrix SaveNextBatch
	spawn(db2, "from_memory_store") // simulates our memory store SaveMessage

	wg.Wait()
	close(errs)

	for writeErr := range errs {
		t.Errorf("concurrent write error (separate conns): %v", writeErr)
	}

	// Every insert must have landed.
	var count int
	if err := db1.QueryRow("SELECT COUNT(*) FROM t").Scan(&count); err != nil {
		t.Fatalf("count: %v", err)
	}
	if expected := iterations * 2; count != expected {
		t.Errorf("row count = %d, want %d", count, expected)
	}
}
|
||||
|
||||
// TestSQLiteWALFileCreated verifies that WAL mode actually creates the -wal file,
// confirming the pragma took effect at the filesystem level. Any other failure
// to stat the file is reported as well instead of being silently ignored.
func TestSQLiteWALFileCreated(t *testing.T) {
	dir := t.TempDir()
	dbPath := filepath.Join(dir, "walcheck.db")

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	defer db.Close()

	// Create a table and write data to trigger WAL file creation.
	if _, err := db.Exec(`CREATE TABLE x (id INTEGER PRIMARY KEY)`); err != nil {
		t.Fatalf("create: %v", err)
	}
	if _, err := db.Exec(`INSERT INTO x (id) VALUES (1)`); err != nil {
		t.Fatalf("insert: %v", err)
	}

	walPath := dbPath + "-wal"
	if _, err := os.Stat(walPath); err != nil {
		if os.IsNotExist(err) {
			t.Errorf("WAL file not created at %s — PRAGMA journal_mode=WAL may not be taking effect", walPath)
		} else {
			t.Errorf("stat %s: %v", walPath, err)
		}
	}
}
|
||||
Reference in New Issue
Block a user