diff --git a/agents/commands.go b/agents/commands.go index ff5dbc5..0064f1b 100644 --- a/agents/commands.go +++ b/agents/commands.go @@ -107,7 +107,7 @@ func (a *Agent) cmdTool(ctx context.Context, msgCtx decision.MessageContext) str "args", argsJSON, ) - result := a.toolReg.Execute(ctx, toolName, argsJSON) + result := a.toolReg.ExecuteForRoom(ctx, toolName, argsJSON, msgCtx.RoomID) if result.Err != nil { return fmt.Sprintf("Error ejecutando %s: %s", toolName, result.Err) } diff --git a/agents/runtime.go b/agents/runtime.go index fa21eb7..e8f7a75 100644 --- a/agents/runtime.go +++ b/agents/runtime.go @@ -230,6 +230,29 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (* // Tool registry — register tools enabled in config toolReg := buildToolRegistry(cfg, sshExec, matrixClient, memStore, kStore, roomCtx, logger) + // Rate limiting for tools + if cfg.Security.ToolRateLimit.Enabled { + maxCalls := cfg.Security.ToolRateLimit.MaxCallsPerMin + if maxCalls <= 0 { + maxCalls = 10 + } + rl := tools.NewRateLimiter(maxCalls, time.Minute) + toolReg.SetRateLimiter(rl) + + cleanupInterval := cfg.Security.ToolRateLimit.CleanupIntervalS + if cleanupInterval <= 0 { + cleanupInterval = 60 + } + go func() { + ticker := time.NewTicker(time.Duration(cleanupInterval) * time.Second) + defer ticker.Stop() + for range ticker.C { + rl.Cleanup() + } + }() + logger.Info("tool rate limiting enabled", "max_calls_per_min", maxCalls) + } + a := &Agent{ cfg: cfg, acl: agentACL, @@ -753,7 +776,7 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str a.logger.Warn("failed to send tool call notice", "tool", tc.Name, "err", err) } - result := a.toolReg.Execute(ctx, tc.Name, tc.Arguments) + result := a.toolReg.ExecuteForRoom(ctx, tc.Name, tc.Arguments, msgCtx.RoomID) output := result.Output if result.Err != nil { diff --git a/internal/config/schema.go b/internal/config/schema.go index a7f2fae..37876c8 100644 --- a/internal/config/schema.go +++ b/internal/config/schema.go @@ -280,10 +280,18 @@ type SSHTargetCfg struct { // ── Security ────────────────────────────────────────────────────────────── type SecurityCfg struct { - Roles map[string]RoleCfg `yaml:"roles"` - Audit AuditCfg `yaml:"audit"` - Secrets SecretsCfg `yaml:"secrets"` - Sanitize SanitizeCfg `yaml:"sanitize"` + Roles map[string]RoleCfg `yaml:"roles"` + Audit AuditCfg `yaml:"audit"` + Secrets SecretsCfg `yaml:"secrets"` + Sanitize SanitizeCfg `yaml:"sanitize"` + ToolRateLimit ToolRateLimitCfg `yaml:"tool_rate_limit"` +} + +// ToolRateLimitCfg controls per-room rate limiting of tool executions. +type ToolRateLimitCfg struct { + Enabled bool `yaml:"enabled"` // enable tool rate limiting (default false) + MaxCallsPerMin int `yaml:"max_calls_per_min"` // max tool calls per room per minute (default 10) + CleanupIntervalS int `yaml:"cleanup_interval_s"` // seconds between stale entry cleanup (default 60) } // SanitizeCfg controls prompt injection detection on incoming messages. diff --git a/tools/ratelimit.go b/tools/ratelimit.go new file mode 100644 index 0000000..b856627 --- /dev/null +++ b/tools/ratelimit.go @@ -0,0 +1,70 @@ +package tools + +import ( + "sync" + "time" +) + +// RateLimiter tracks tool call counts per key (typically roomID) using a +// sliding window. It is safe for concurrent use. +type RateLimiter struct { + maxCalls int + window time.Duration + mu sync.Mutex + buckets map[string][]time.Time +} + +// NewRateLimiter creates a rate limiter that allows maxCalls per window per key. +func NewRateLimiter(maxCalls int, window time.Duration) *RateLimiter { + return &RateLimiter{ + maxCalls: maxCalls, + window: window, + buckets: make(map[string][]time.Time), + } +} + +// Allow checks whether a call for the given key is within the rate limit. +// If allowed, it records the call and returns true. Otherwise returns false. +func (rl *RateLimiter) Allow(key string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-rl.window) + + // Trim expired entries + calls := rl.buckets[key] + start := 0 + for start < len(calls) && calls[start].Before(cutoff) { + start++ + } + calls = calls[start:] + + if len(calls) >= rl.maxCalls { + rl.buckets[key] = calls + return false + } + + rl.buckets[key] = append(calls, now) + return true +} + +// Cleanup removes stale entries for keys that have no recent calls. +// Should be called periodically to prevent memory growth. +func (rl *RateLimiter) Cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + cutoff := time.Now().Add(-rl.window) + for key, calls := range rl.buckets { + start := 0 + for start < len(calls) && calls[start].Before(cutoff) { + start++ + } + if start >= len(calls) { + delete(rl.buckets, key) + } else { + rl.buckets[key] = calls[start:] + } + } +} diff --git a/tools/registry.go b/tools/registry.go index 8613350..2c8761d 100644 --- a/tools/registry.go +++ b/tools/registry.go @@ -14,8 +14,9 @@ import ( // Registry holds available tools keyed by name. type Registry struct { - tools map[string]Tool - logger *slog.Logger + tools map[string]Tool + logger *slog.Logger + rateLimiter *RateLimiter // nil when rate limiting is disabled } // NewRegistry creates an empty registry. @@ -53,6 +54,24 @@ func (r *Registry) Len() int { return len(r.tools) } +// SetRateLimiter attaches a rate limiter to the registry. +// When set, ExecuteForRoom checks the limit before running the tool. +func (r *Registry) SetRateLimiter(rl *RateLimiter) { + r.rateLimiter = rl +} + +// ExecuteForRoom is like Execute but checks the per-room rate limit first. +// If the rate limit is exceeded, it returns an error result without executing. +func (r *Registry) ExecuteForRoom(ctx context.Context, name, argsJSON, roomID string) Result { + if r.rateLimiter != nil && roomID != "" { + if !r.rateLimiter.Allow(roomID) { + r.logger.Warn("tool_rate_limited", "tool", name, "room", roomID) + return Result{Err: fmt.Errorf("rate limit exceeded for room %s: too many tool calls per minute", roomID)} + } + } + return r.Execute(ctx, name, argsJSON) +} + // Execute looks up a tool by name and runs it. Returns an error result if not found. func (r *Registry) Execute(ctx context.Context, name string, argsJSON string) Result { t, ok := r.tools[name]