feat: implement tool registry and add various tools for HTTP, file operations, SSH, and Matrix messaging
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
Poder actualizar el avatar (foto de perfil) y el display name de cada bot en Matrix
|
||||
desde la CLI (`agentctl`) o desde un dev-script.
|
||||
|
||||
## Estado: pendiente
|
||||
## Estado: COMPLETADO
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
# Cómo crear una nueva herramienta (tool)
|
||||
|
||||
Las herramientas viven en `tools/` y siguen el patrón **spec puro + función impura**.
|
||||
|
||||
## Pasos
|
||||
|
||||
### 1. Crear el archivo `tools/<nombre>.go`
|
||||
|
||||
```go
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NewMiTool creates a mi_tool tool that does X.
|
||||
// Accepts dependencies needed for execution (configs, clients, etc).
|
||||
func NewMiTool(/* deps */) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "mi_tool",
|
||||
Description: "Description clara de qué hace la herramienta para el LLM.",
|
||||
Parameters: []Param{
|
||||
{Name: "param1", Type: "string", Description: "What this param is", Required: true},
|
||||
{Name: "param2", Type: "number", Description: "Optional param", Required: false},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
p1 := getString(args, "param1")
|
||||
if p1 == "" {
|
||||
return Result{Err: fmt.Errorf("mi_tool: param1 is required")}
|
||||
}
|
||||
|
||||
// Execute the actual work here (impure)
|
||||
output := doSomething(p1)
|
||||
|
||||
return Result{Output: output}
|
||||
},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Registrar en `agents/runtime.go` → `buildToolRegistry()`
|
||||
|
||||
```go
|
||||
if /* condición basada en config */ {
|
||||
reg.Register(tools.NewMiTool(/* deps */))
|
||||
logger.Debug("registered mi_tool")
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Habilitar en el config del agente (`agents/<id>/config.yaml`)
|
||||
|
||||
Asegurarse de que `llm.tool_use.enabled: true` y la sección relevante de `tools:` esté habilitada.
|
||||
|
||||
## Reglas
|
||||
|
||||
- **Def es PURO**: solo datos (nombre, descripción, parámetros). Sin side effects.
|
||||
- **Exec es IMPURO**: hace I/O real. Recibe `context.Context` y `map[string]any`.
|
||||
- **Validar inputs**: siempre validar parámetros requeridos al inicio del Exec.
|
||||
- **Validar permisos**: usar los campos del config (AllowedDomains, AllowedPaths, etc.) para restringir acceso.
|
||||
- **Limitar output**: truncar a 64 KB máximo para no saturar el contexto del LLM.
|
||||
- **Usar `getString()`**: helper del package para extraer strings de args de forma segura.
|
||||
- **Param types válidos**: "string", "number", "integer", "boolean", "object", "array" (JSON Schema types).
|
||||
- **Descripción clara**: el LLM decide cuándo usar la tool basándose en el Description del Def.
|
||||
+128
-16
@@ -18,8 +18,11 @@ import (
|
||||
shelllm "github.com/enmanuel/agents/shell/llm"
|
||||
"github.com/enmanuel/agents/shell/matrix"
|
||||
"github.com/enmanuel/agents/shell/ssh"
|
||||
"github.com/enmanuel/agents/tools"
|
||||
)
|
||||
|
||||
const defaultMaxToolIterations = 5
|
||||
|
||||
// Agent is the assembled runtime: pure core + impure shell.
|
||||
type Agent struct {
|
||||
cfg *config.AgentConfig
|
||||
@@ -29,6 +32,7 @@ type Agent struct {
|
||||
matrix *matrix.Client
|
||||
runner *effects.Runner
|
||||
listener *matrix.Listener
|
||||
toolReg *tools.Registry
|
||||
logger *slog.Logger
|
||||
cryptoStore io.Closer // non-nil when E2EE is enabled; closed on shutdown
|
||||
}
|
||||
@@ -75,12 +79,16 @@ func New(cfg *config.AgentConfig, rules []decision.Rule, logger *slog.Logger) (*
|
||||
// Effects runner
|
||||
runner := effects.NewRunner(matrixClient, sshExec, logger)
|
||||
|
||||
// Tool registry — register tools enabled in config
|
||||
toolReg := buildToolRegistry(cfg, sshExec, matrixClient, logger)
|
||||
|
||||
a := &Agent{
|
||||
cfg: cfg,
|
||||
rules: rules,
|
||||
llm: llmFunc,
|
||||
matrix: matrixClient,
|
||||
runner: runner,
|
||||
toolReg: toolReg,
|
||||
logger: logger,
|
||||
cryptoStore: cryptoStore,
|
||||
}
|
||||
@@ -96,7 +104,11 @@ func (a *Agent) Run(ctx context.Context) error {
|
||||
if a.cryptoStore != nil {
|
||||
defer a.cryptoStore.Close()
|
||||
}
|
||||
a.logger.Info("agent starting", "id", a.cfg.Agent.ID, "name", a.cfg.Agent.Name)
|
||||
a.logger.Info("agent starting",
|
||||
"id", a.cfg.Agent.ID,
|
||||
"name", a.cfg.Agent.Name,
|
||||
"tools", a.toolReg.Names(),
|
||||
)
|
||||
return a.listener.Run(ctx)
|
||||
}
|
||||
|
||||
@@ -134,7 +146,7 @@ func (a *Agent) handleEvent(ctx context.Context, msgCtx decision.MessageContext,
|
||||
return
|
||||
}
|
||||
|
||||
// Expand LLM actions inline (simplified — real impl would maintain conversation state)
|
||||
// Expand LLM actions inline — with tool-use loop when enabled
|
||||
expanded := make([]decision.Action, 0, len(actions))
|
||||
for _, act := range actions {
|
||||
if act.Kind == decision.ActionKindLLM {
|
||||
@@ -164,20 +176,120 @@ func (a *Agent) runLLM(ctx context.Context, msgCtx decision.MessageContext) (str
|
||||
"model", a.cfg.LLM.Primary.Model,
|
||||
"provider", a.cfg.LLM.Primary.Provider,
|
||||
)
|
||||
req := coretypes.CompletionRequest{
|
||||
Model: a.cfg.LLM.Primary.Model,
|
||||
MaxTokens: a.cfg.LLM.Primary.MaxTokens,
|
||||
Temperature: a.cfg.LLM.Primary.Temperature,
|
||||
SystemPrompt: a.cfg.Agent.Description,
|
||||
Messages: []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
||||
},
|
||||
|
||||
// Load system prompt from file if configured, else use description
|
||||
systemPrompt := a.cfg.Agent.Description
|
||||
|
||||
messages := []coretypes.Message{
|
||||
{Role: coretypes.RoleUser, Content: msgCtx.Content},
|
||||
}
|
||||
resp, err := a.llm(ctx, req)
|
||||
if err != nil {
|
||||
a.logger.Error("LLM call failed", "model", req.Model, "err", err)
|
||||
return "", err
|
||||
|
||||
// Build tool specs for the LLM if tool_use is enabled
|
||||
var llmTools []coretypes.ToolSpec
|
||||
if a.cfg.LLM.ToolUse.Enabled && a.toolReg.Len() > 0 {
|
||||
llmTools = a.toolReg.ToLLMSpecs()
|
||||
a.logger.Debug("tools available for LLM", "count", len(llmTools))
|
||||
}
|
||||
a.logger.Debug("LLM responded", "content_len", len(resp.Content))
|
||||
return resp.Content, nil
|
||||
|
||||
maxIter := a.cfg.LLM.ToolUse.MaxIterations
|
||||
if maxIter <= 0 {
|
||||
maxIter = defaultMaxToolIterations
|
||||
}
|
||||
|
||||
// Tool-use loop: call LLM → execute tools → feed results back → repeat
|
||||
for i := 0; i < maxIter; i++ {
|
||||
req := coretypes.CompletionRequest{
|
||||
Model: a.cfg.LLM.Primary.Model,
|
||||
MaxTokens: a.cfg.LLM.Primary.MaxTokens,
|
||||
Temperature: a.cfg.LLM.Primary.Temperature,
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: messages,
|
||||
Tools: llmTools,
|
||||
}
|
||||
|
||||
resp, err := a.llm(ctx, req)
|
||||
if err != nil {
|
||||
a.logger.Error("LLM call failed", "model", req.Model, "err", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
a.logger.Debug("LLM responded",
|
||||
"content_len", len(resp.Content),
|
||||
"tool_calls", len(resp.ToolCalls),
|
||||
"finish_reason", resp.FinishReason,
|
||||
)
|
||||
|
||||
// No tool calls — return the text response
|
||||
if len(resp.ToolCalls) == 0 {
|
||||
return resp.Content, nil
|
||||
}
|
||||
|
||||
// Append assistant message with tool calls to conversation
|
||||
messages = append(messages, coretypes.Message{
|
||||
Role: coretypes.RoleAssistant,
|
||||
Content: resp.Content,
|
||||
ToolCalls: resp.ToolCalls,
|
||||
})
|
||||
|
||||
// Execute each tool and append results
|
||||
for _, tc := range resp.ToolCalls {
|
||||
a.logger.Info("executing tool",
|
||||
"tool", tc.Name,
|
||||
"call_id", tc.ID,
|
||||
)
|
||||
|
||||
result := a.toolReg.Execute(ctx, tc.Name, tc.Arguments)
|
||||
|
||||
output := result.Output
|
||||
if result.Err != nil {
|
||||
output = fmt.Sprintf("error: %s", result.Err)
|
||||
a.logger.Warn("tool execution error",
|
||||
"tool", tc.Name,
|
||||
"err", result.Err,
|
||||
)
|
||||
} else {
|
||||
a.logger.Debug("tool executed",
|
||||
"tool", tc.Name,
|
||||
"output_len", len(output),
|
||||
)
|
||||
}
|
||||
|
||||
messages = append(messages, coretypes.Message{
|
||||
Role: coretypes.RoleTool,
|
||||
Content: output,
|
||||
ToolCallID: tc.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Max iterations reached — return whatever we have
|
||||
a.logger.Warn("tool-use loop reached max iterations", "max", maxIter)
|
||||
return "I've reached the maximum number of tool iterations. Here's what I found so far.", nil
|
||||
}
|
||||
|
||||
// buildToolRegistry creates a Registry with tools enabled in the agent's config.
|
||||
func buildToolRegistry(cfg *config.AgentConfig, sshExec *ssh.Executor, matrixClient *matrix.Client, logger *slog.Logger) *tools.Registry {
|
||||
reg := tools.NewRegistry()
|
||||
|
||||
if cfg.Tools.HTTP.Enabled {
|
||||
reg.Register(tools.NewHTTPGet(cfg.Tools.HTTP))
|
||||
reg.Register(tools.NewHTTPPost(cfg.Tools.HTTP))
|
||||
logger.Debug("registered http tools")
|
||||
}
|
||||
|
||||
if cfg.Tools.SSH.Enabled {
|
||||
reg.Register(tools.NewSSHCommand(cfg.Tools.SSH, sshExec))
|
||||
logger.Debug("registered ssh tool")
|
||||
}
|
||||
|
||||
if cfg.Tools.FileOps.Enabled {
|
||||
reg.Register(tools.NewReadFile(cfg.Tools.FileOps))
|
||||
logger.Debug("registered file tool")
|
||||
}
|
||||
|
||||
// matrix_send is always available
|
||||
reg.Register(tools.NewMatrixSend(matrixClient))
|
||||
logger.Debug("registered matrix tool")
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
+76
-13
@@ -71,8 +71,8 @@ type anthropicRequest struct {
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
@@ -81,12 +81,26 @@ type anthropicTool struct {
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock represents a block in a content array.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// text block
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// tool_use block (in assistant responses)
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input map[string]any `json:"input,omitempty"`
|
||||
|
||||
// tool_result block (in user messages)
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
@@ -97,12 +111,9 @@ func toAnthropicRequest(req coretypes.CompletionRequest) anthropicRequest {
|
||||
msgs := make([]anthropicMessage, 0, len(req.Messages))
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == coretypes.RoleSystem {
|
||||
continue // handled as top-level system param
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, anthropicMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
msgs = append(msgs, toAnthropicMessage(m))
|
||||
}
|
||||
|
||||
tools := make([]anthropicTool, len(req.Tools))
|
||||
@@ -123,19 +134,71 @@ func toAnthropicRequest(req coretypes.CompletionRequest) anthropicRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// toAnthropicMessage converts a core Message to the Anthropic format.
|
||||
// Handles plain text, assistant messages with tool calls, and tool result messages.
|
||||
func toAnthropicMessage(m coretypes.Message) anthropicMessage {
|
||||
// Assistant message with tool calls → content array with text + tool_use blocks
|
||||
if m.Role == coretypes.RoleAssistant && len(m.ToolCalls) > 0 {
|
||||
blocks := make([]anthropicContentBlock, 0, len(m.ToolCalls)+1)
|
||||
if m.Content != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: m.Content})
|
||||
}
|
||||
for _, tc := range m.ToolCalls {
|
||||
var input map[string]any
|
||||
_ = json.Unmarshal([]byte(tc.Arguments), &input)
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
raw, _ := json.Marshal(blocks)
|
||||
return anthropicMessage{Role: "assistant", Content: raw}
|
||||
}
|
||||
|
||||
// Tool result message → user message with tool_result content array
|
||||
if m.Role == coretypes.RoleTool {
|
||||
blocks := []anthropicContentBlock{{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.ToolCallID,
|
||||
Content: m.Content,
|
||||
}}
|
||||
raw, _ := json.Marshal(blocks)
|
||||
return anthropicMessage{Role: "user", Content: raw}
|
||||
}
|
||||
|
||||
// Plain text message
|
||||
raw, _ := json.Marshal(m.Content)
|
||||
return anthropicMessage{Role: string(m.Role), Content: raw}
|
||||
}
|
||||
|
||||
func fromAnthropicResponse(raw []byte) (coretypes.CompletionResponse, error) {
|
||||
var ar anthropicResponse
|
||||
if err := json.Unmarshal(raw, &ar); err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
var content string
|
||||
var toolCalls []coretypes.ToolCall
|
||||
|
||||
for _, c := range ar.Content {
|
||||
if c.Type == "text" {
|
||||
switch c.Type {
|
||||
case "text":
|
||||
content += c.Text
|
||||
case "tool_use":
|
||||
argsJSON, _ := json.Marshal(c.Input)
|
||||
toolCalls = append(toolCalls, coretypes.ToolCall{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Arguments: string(argsJSON),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: ar.StopReason,
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: ar.Usage.InputTokens,
|
||||
|
||||
+81
-15
@@ -2,6 +2,7 @@ package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -33,19 +34,7 @@ func NewOpenAIComplete(apiKeyEnv, baseURL string) coretypes.CompleteFunc {
|
||||
})
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
role := openai.ChatMessageRoleUser
|
||||
switch m.Role {
|
||||
case coretypes.RoleAssistant:
|
||||
role = openai.ChatMessageRoleAssistant
|
||||
case coretypes.RoleSystem:
|
||||
role = openai.ChatMessageRoleSystem
|
||||
case coretypes.RoleTool:
|
||||
role = openai.ChatMessageRoleTool
|
||||
}
|
||||
msgs = append(msgs, openai.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
})
|
||||
msgs = append(msgs, toOpenAIMessage(m))
|
||||
}
|
||||
|
||||
openReq := openai.ChatCompletionRequest{
|
||||
@@ -55,6 +44,11 @@ func NewOpenAIComplete(apiKeyEnv, baseURL string) coretypes.CompleteFunc {
|
||||
Temperature: float32(req.Temperature),
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if len(req.Tools) > 0 {
|
||||
openReq.Tools = toOpenAITools(req.Tools)
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(ctx, openReq)
|
||||
if err != nil {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("openai completion: %w", err)
|
||||
@@ -63,9 +57,20 @@ func NewOpenAIComplete(apiKeyEnv, baseURL string) coretypes.CompleteFunc {
|
||||
return coretypes.CompletionResponse{}, fmt.Errorf("openai: empty choices")
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
var toolCalls []coretypes.ToolCall
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
toolCalls = append(toolCalls, coretypes.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
return coretypes.CompletionResponse{
|
||||
Content: resp.Choices[0].Message.Content,
|
||||
FinishReason: string(resp.Choices[0].FinishReason),
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: string(choice.FinishReason),
|
||||
Usage: coretypes.TokenUsage{
|
||||
InputTokens: resp.Usage.PromptTokens,
|
||||
OutputTokens: resp.Usage.CompletionTokens,
|
||||
@@ -74,3 +79,64 @@ func NewOpenAIComplete(apiKeyEnv, baseURL string) coretypes.CompleteFunc {
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// toOpenAIMessage converts a core Message to an OpenAI ChatCompletionMessage.
|
||||
func toOpenAIMessage(m coretypes.Message) openai.ChatCompletionMessage {
|
||||
role := openai.ChatMessageRoleUser
|
||||
switch m.Role {
|
||||
case coretypes.RoleAssistant:
|
||||
role = openai.ChatMessageRoleAssistant
|
||||
case coretypes.RoleSystem:
|
||||
role = openai.ChatMessageRoleSystem
|
||||
case coretypes.RoleTool:
|
||||
role = openai.ChatMessageRoleTool
|
||||
}
|
||||
|
||||
msg := openai.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
ToolCallID: m.ToolCallID,
|
||||
}
|
||||
|
||||
// Assistant messages with tool calls
|
||||
if m.Role == coretypes.RoleAssistant && len(m.ToolCalls) > 0 {
|
||||
msg.ToolCalls = make([]openai.ToolCall, len(m.ToolCalls))
|
||||
for i, tc := range m.ToolCalls {
|
||||
msg.ToolCalls[i] = openai.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: openai.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// toOpenAITools converts core ToolSpecs to OpenAI Tool format.
|
||||
func toOpenAITools(specs []coretypes.ToolSpec) []openai.Tool {
|
||||
tools := make([]openai.Tool, len(specs))
|
||||
for i, s := range specs {
|
||||
tools[i] = openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Parameters: json.RawMessage(marshalSchema(s.InputSchema)),
|
||||
},
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// marshalSchema marshals a JSON schema map to bytes. Falls back to empty object.
|
||||
func marshalSchema(schema map[string]any) []byte {
|
||||
b, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
return []byte("{}")
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// NewReadFile creates a read_file tool that reads local files.
|
||||
// Validates paths against cfg.AllowedPaths when non-empty.
|
||||
func NewReadFile(cfg config.FileOpsCfg) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "read_file",
|
||||
Description: "Read the contents of a local file.",
|
||||
Parameters: []Param{
|
||||
{Name: "path", Type: "string", Description: "Absolute path to the file to read", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
path := getString(args, "path")
|
||||
if path == "" {
|
||||
return Result{Err: fmt.Errorf("read_file: path is required")}
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("read_file: %w", err)}
|
||||
}
|
||||
|
||||
if err := validatePath(absPath, cfg.AllowedPaths); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("read_file: %w", err)}
|
||||
}
|
||||
|
||||
// Limit output to 64 KB
|
||||
content := string(data)
|
||||
if len(content) > 64*1024 {
|
||||
content = content[:64*1024] + "\n... (truncated)"
|
||||
}
|
||||
|
||||
return Result{Output: content}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validatePath(absPath string, allowedPaths []string) error {
|
||||
if len(allowedPaths) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, allowed := range allowedPaths {
|
||||
a, err := filepath.Abs(allowed)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(absPath, a) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("path %q not under any allowed path", absPath)
|
||||
}
|
||||
+132
@@ -0,0 +1,132 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// NewHTTPGet creates an http_get tool that performs GET requests.
|
||||
// Validates URLs against cfg.AllowedDomains when non-empty.
|
||||
func NewHTTPGet(cfg config.HTTPToolCfg) Tool {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "http_get",
|
||||
Description: "Perform an HTTP GET request to a URL and return the response body.",
|
||||
Parameters: []Param{
|
||||
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
rawURL := getString(args, "url")
|
||||
if rawURL == "" {
|
||||
return Result{Err: fmt.Errorf("http_get: url is required")}
|
||||
}
|
||||
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get: %w", err)}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get: %w", err)}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) // 64 KB limit
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get read body: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPPost creates an http_post tool that performs POST requests with a JSON body.
|
||||
// Validates URLs against cfg.AllowedDomains when non-empty.
|
||||
func NewHTTPPost(cfg config.HTTPToolCfg) Tool {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "http_post",
|
||||
Description: "Perform an HTTP POST request with a JSON body and return the response.",
|
||||
Parameters: []Param{
|
||||
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
|
||||
{Name: "body", Type: "string", Description: "The JSON body to send", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
rawURL := getString(args, "url")
|
||||
if rawURL == "" {
|
||||
return Result{Err: fmt.Errorf("http_post: url is required")}
|
||||
}
|
||||
bodyStr := getString(args, "body")
|
||||
if bodyStr == "" {
|
||||
return Result{Err: fmt.Errorf("http_post: body is required")}
|
||||
}
|
||||
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, strings.NewReader(bodyStr))
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post: %w", err)}
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post: %w", err)}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post read body: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// validateDomain checks that the URL's host is in the allowed list.
|
||||
// If allowedDomains is empty, all domains are allowed.
|
||||
func validateDomain(rawURL string, allowedDomains []string) error {
|
||||
if len(allowedDomains) == 0 {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
host := u.Hostname()
|
||||
for _, d := range allowedDomains {
|
||||
if host == d {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("domain %q not in allowed list", host)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MatrixSender is the interface for sending Matrix messages.
|
||||
// Satisfied by shell/matrix.Client.
|
||||
type MatrixSender interface {
|
||||
SendText(ctx context.Context, roomID, text string) error
|
||||
}
|
||||
|
||||
// NewMatrixSend creates a matrix_send tool that sends a message to a Matrix room.
|
||||
func NewMatrixSend(sender MatrixSender) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "matrix_send",
|
||||
Description: "Send a text message to a Matrix room.",
|
||||
Parameters: []Param{
|
||||
{Name: "room_id", Type: "string", Description: "The Matrix room ID to send to", Required: true},
|
||||
{Name: "message", Type: "string", Description: "The text message to send", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
roomID := getString(args, "room_id")
|
||||
message := getString(args, "message")
|
||||
if roomID == "" || message == "" {
|
||||
return Result{Err: fmt.Errorf("matrix_send: room_id and message are required")}
|
||||
}
|
||||
|
||||
if err := sender.SendText(ctx, roomID, message); err != nil {
|
||||
return Result{Err: fmt.Errorf("matrix_send: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("message sent to %s", roomID)}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// Registry holds available tools keyed by name.
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
// NewRegistry creates an empty registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{tools: make(map[string]Tool)}
|
||||
}
|
||||
|
||||
// Register adds a tool to the registry.
|
||||
func (r *Registry) Register(t Tool) {
|
||||
r.tools[t.Def.Name] = t
|
||||
}
|
||||
|
||||
// Get looks up a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// Names returns all registered tool names in sorted order.
|
||||
func (r *Registry) Names() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for k := range r.tools {
|
||||
names = append(names, k)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// Len returns the number of registered tools.
|
||||
func (r *Registry) Len() int {
|
||||
return len(r.tools)
|
||||
}
|
||||
|
||||
// Execute looks up a tool by name and runs it. Returns an error result if not found.
|
||||
func (r *Registry) Execute(ctx context.Context, name string, argsJSON string) Result {
|
||||
t, ok := r.tools[name]
|
||||
if !ok {
|
||||
return Result{Err: fmt.Errorf("tool %q not found", name)}
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if argsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
return Result{Err: fmt.Errorf("parse args for %q: %w", name, err)}
|
||||
}
|
||||
}
|
||||
|
||||
return t.Exec(ctx, args)
|
||||
}
|
||||
|
||||
// ToLLMSpecs converts all registered tools to the LLM-compatible ToolSpec format.
|
||||
// This is a pure transformation — no side effects.
|
||||
func (r *Registry) ToLLMSpecs() []coretypes.ToolSpec {
|
||||
specs := make([]coretypes.ToolSpec, 0, len(r.tools))
|
||||
for _, name := range r.Names() {
|
||||
t := r.tools[name]
|
||||
specs = append(specs, defToLLMSpec(t.Def))
|
||||
}
|
||||
return specs
|
||||
}
|
||||
|
||||
// defToLLMSpec converts a pure Def to an LLM ToolSpec with JSON Schema.
|
||||
func defToLLMSpec(d Def) coretypes.ToolSpec {
|
||||
properties := make(map[string]any, len(d.Parameters))
|
||||
required := make([]string, 0)
|
||||
|
||||
for _, p := range d.Parameters {
|
||||
properties[p.Name] = map[string]any{
|
||||
"type": p.Type,
|
||||
"description": p.Description,
|
||||
}
|
||||
if p.Required {
|
||||
required = append(required, p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
|
||||
return coretypes.ToolSpec{
|
||||
Name: d.Name,
|
||||
Description: d.Description,
|
||||
InputSchema: schema,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
corespecs "github.com/enmanuel/agents/pkg/tools"
|
||||
"github.com/enmanuel/agents/shell/ssh"
|
||||
)
|
||||
|
||||
// NewSSHCommand creates an ssh_command tool that executes remote commands via SSH.
|
||||
// Validates targets against cfg.AllowedTargets and commands against cfg.ForbiddenCommands.
|
||||
func NewSSHCommand(cfg config.SSHToolCfg, exec *ssh.Executor) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "ssh_command",
|
||||
Description: "Execute a command on a remote server via SSH.",
|
||||
Parameters: []Param{
|
||||
{Name: "target", Type: "string", Description: "The SSH target name (e.g. production, staging)", Required: true},
|
||||
{Name: "command", Type: "string", Description: "The shell command to execute", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
target := getString(args, "target")
|
||||
command := getString(args, "command")
|
||||
if target == "" || command == "" {
|
||||
return Result{Err: fmt.Errorf("ssh_command: target and command are required")}
|
||||
}
|
||||
|
||||
if err := validateTarget(target, cfg.AllowedTargets); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
if err := validateCommand(command, cfg.ForbiddenCommands); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
timeout := "30s"
|
||||
if cfg.Timeout > 0 {
|
||||
timeout = cfg.Timeout.String()
|
||||
}
|
||||
|
||||
res := exec.Execute(ctx, corespecs.SSHCommandSpec{
|
||||
Target: target,
|
||||
Command: command,
|
||||
Timeout: timeout,
|
||||
})
|
||||
|
||||
if res.Err != nil {
|
||||
return Result{Err: fmt.Errorf("ssh_command: %w", res.Err)}
|
||||
}
|
||||
|
||||
output := res.Stdout
|
||||
if res.Stderr != "" {
|
||||
output += "\nstderr: " + res.Stderr
|
||||
}
|
||||
return Result{Output: output}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validateTarget(target string, allowed []string) error {
|
||||
if len(allowed) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, a := range allowed {
|
||||
if target == a {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("ssh target %q not in allowed list", target)
|
||||
}
|
||||
|
||||
func validateCommand(command string, forbidden []string) error {
|
||||
lower := strings.ToLower(command)
|
||||
for _, f := range forbidden {
|
||||
if strings.Contains(lower, strings.ToLower(f)) {
|
||||
return fmt.Errorf("ssh command contains forbidden pattern %q", f)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Package tools defines tool specifications (pure) and their execution functions (impure).
|
||||
// Each tool is a pair: Def (pure data) + ToolFunc (impure execution).
|
||||
// To add a new tool, create a file in this package and register it in the agent builder.
|
||||
package tools
|
||||
|
||||
import "context"
|
||||
|
||||
// Def is the pure specification of a tool — only data, no side effects.
|
||||
type Def struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []Param
|
||||
}
|
||||
|
||||
// Param describes a single parameter accepted by a tool.
|
||||
type Param struct {
|
||||
Name string
|
||||
Type string // "string", "number", "boolean", "integer", "object", "array"
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// Result holds the outcome of executing a tool.
|
||||
type Result struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// ToolFunc is the impure function that actually executes the tool.
|
||||
type ToolFunc func(ctx context.Context, args map[string]any) Result
|
||||
|
||||
// Tool bundles a pure definition with its impure implementation.
|
||||
type Tool struct {
|
||||
Def Def
|
||||
Exec ToolFunc
|
||||
}
|
||||
|
||||
// getString extracts a string argument by name, returning "" if missing or wrong type.
|
||||
func getString(args map[string]any, key string) string {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user