feat: implement tool registry and add various tools for HTTP, file operations, SSH, and Matrix messaging

This commit is contained in:
2026-03-04 21:10:29 +00:00
parent ddec55871b
commit 0f8d2f9ca0
11 changed files with 828 additions and 45 deletions
+69
View File
@@ -0,0 +1,69 @@
package tools
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/enmanuel/agents/internal/config"
)
// NewReadFile creates a read_file tool that reads local files.
// Validates paths against cfg.AllowedPaths when non-empty.
func NewReadFile(cfg config.FileOpsCfg) Tool {
return Tool{
Def: Def{
Name: "read_file",
Description: "Read the contents of a local file.",
Parameters: []Param{
{Name: "path", Type: "string", Description: "Absolute path to the file to read", Required: true},
},
},
Exec: func(ctx context.Context, args map[string]any) Result {
path := getString(args, "path")
if path == "" {
return Result{Err: fmt.Errorf("read_file: path is required")}
}
absPath, err := filepath.Abs(path)
if err != nil {
return Result{Err: fmt.Errorf("read_file: %w", err)}
}
if err := validatePath(absPath, cfg.AllowedPaths); err != nil {
return Result{Err: err}
}
data, err := os.ReadFile(absPath)
if err != nil {
return Result{Err: fmt.Errorf("read_file: %w", err)}
}
// Limit output to 64 KB
content := string(data)
if len(content) > 64*1024 {
content = content[:64*1024] + "\n... (truncated)"
}
return Result{Output: content}
},
}
}
func validatePath(absPath string, allowedPaths []string) error {
if len(allowedPaths) == 0 {
return nil
}
for _, allowed := range allowedPaths {
a, err := filepath.Abs(allowed)
if err != nil {
continue
}
if strings.HasPrefix(absPath, a) {
return nil
}
}
return fmt.Errorf("path %q not under any allowed path", absPath)
}
+132
View File
@@ -0,0 +1,132 @@
package tools
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/enmanuel/agents/internal/config"
)
// NewHTTPGet creates an http_get tool that performs GET requests.
// Validates URLs against cfg.AllowedDomains when non-empty.
func NewHTTPGet(cfg config.HTTPToolCfg) Tool {
timeout := cfg.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
client := &http.Client{Timeout: timeout}
return Tool{
Def: Def{
Name: "http_get",
Description: "Perform an HTTP GET request to a URL and return the response body.",
Parameters: []Param{
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
},
},
Exec: func(ctx context.Context, args map[string]any) Result {
rawURL := getString(args, "url")
if rawURL == "" {
return Result{Err: fmt.Errorf("http_get: url is required")}
}
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
return Result{Err: err}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return Result{Err: fmt.Errorf("http_get: %w", err)}
}
resp, err := client.Do(req)
if err != nil {
return Result{Err: fmt.Errorf("http_get: %w", err)}
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) // 64 KB limit
if err != nil {
return Result{Err: fmt.Errorf("http_get read body: %w", err)}
}
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
},
}
}
// NewHTTPPost creates an http_post tool that performs POST requests with a JSON body.
// Validates URLs against cfg.AllowedDomains when non-empty.
func NewHTTPPost(cfg config.HTTPToolCfg) Tool {
timeout := cfg.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
client := &http.Client{Timeout: timeout}
return Tool{
Def: Def{
Name: "http_post",
Description: "Perform an HTTP POST request with a JSON body and return the response.",
Parameters: []Param{
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
{Name: "body", Type: "string", Description: "The JSON body to send", Required: true},
},
},
Exec: func(ctx context.Context, args map[string]any) Result {
rawURL := getString(args, "url")
if rawURL == "" {
return Result{Err: fmt.Errorf("http_post: url is required")}
}
bodyStr := getString(args, "body")
if bodyStr == "" {
return Result{Err: fmt.Errorf("http_post: body is required")}
}
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
return Result{Err: err}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, strings.NewReader(bodyStr))
if err != nil {
return Result{Err: fmt.Errorf("http_post: %w", err)}
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return Result{Err: fmt.Errorf("http_post: %w", err)}
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if err != nil {
return Result{Err: fmt.Errorf("http_post read body: %w", err)}
}
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
},
}
}
// validateDomain checks that the URL's host is in the allowed list.
// If allowedDomains is empty, all domains are allowed.
func validateDomain(rawURL string, allowedDomains []string) error {
if len(allowedDomains) == 0 {
return nil
}
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid url: %w", err)
}
host := u.Hostname()
for _, d := range allowedDomains {
if host == d {
return nil
}
}
return fmt.Errorf("domain %q not in allowed list", host)
}
+39
View File
@@ -0,0 +1,39 @@
package tools
import (
"context"
"fmt"
)
// MatrixSender is the interface for sending Matrix messages.
// Satisfied by shell/matrix.Client.
type MatrixSender interface {
SendText(ctx context.Context, roomID, text string) error
}
// NewMatrixSend creates a matrix_send tool that sends a message to a Matrix room.
func NewMatrixSend(sender MatrixSender) Tool {
return Tool{
Def: Def{
Name: "matrix_send",
Description: "Send a text message to a Matrix room.",
Parameters: []Param{
{Name: "room_id", Type: "string", Description: "The Matrix room ID to send to", Required: true},
{Name: "message", Type: "string", Description: "The text message to send", Required: true},
},
},
Exec: func(ctx context.Context, args map[string]any) Result {
roomID := getString(args, "room_id")
message := getString(args, "message")
if roomID == "" || message == "" {
return Result{Err: fmt.Errorf("matrix_send: room_id and message are required")}
}
if err := sender.SendText(ctx, roomID, message); err != nil {
return Result{Err: fmt.Errorf("matrix_send: %w", err)}
}
return Result{Output: fmt.Sprintf("message sent to %s", roomID)}
},
}
}
+104
View File
@@ -0,0 +1,104 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"sort"
coretypes "github.com/enmanuel/agents/pkg/llm"
)
// Registry holds available tools keyed by name.
type Registry struct {
tools map[string]Tool
}
// NewRegistry creates an empty registry.
func NewRegistry() *Registry {
return &Registry{tools: make(map[string]Tool)}
}
// Register adds a tool to the registry.
func (r *Registry) Register(t Tool) {
r.tools[t.Def.Name] = t
}
// Get looks up a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
t, ok := r.tools[name]
return t, ok
}
// Names returns all registered tool names in sorted order.
func (r *Registry) Names() []string {
names := make([]string, 0, len(r.tools))
for k := range r.tools {
names = append(names, k)
}
sort.Strings(names)
return names
}
// Len returns the number of registered tools.
func (r *Registry) Len() int {
return len(r.tools)
}
// Execute looks up a tool by name and runs it. Returns an error result if not found.
func (r *Registry) Execute(ctx context.Context, name string, argsJSON string) Result {
t, ok := r.tools[name]
if !ok {
return Result{Err: fmt.Errorf("tool %q not found", name)}
}
var args map[string]any
if argsJSON != "" {
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return Result{Err: fmt.Errorf("parse args for %q: %w", name, err)}
}
}
return t.Exec(ctx, args)
}
// ToLLMSpecs converts all registered tools to the LLM-compatible ToolSpec format.
// This is a pure transformation — no side effects.
func (r *Registry) ToLLMSpecs() []coretypes.ToolSpec {
specs := make([]coretypes.ToolSpec, 0, len(r.tools))
for _, name := range r.Names() {
t := r.tools[name]
specs = append(specs, defToLLMSpec(t.Def))
}
return specs
}
// defToLLMSpec converts a pure Def to an LLM ToolSpec with JSON Schema.
func defToLLMSpec(d Def) coretypes.ToolSpec {
properties := make(map[string]any, len(d.Parameters))
required := make([]string, 0)
for _, p := range d.Parameters {
properties[p.Name] = map[string]any{
"type": p.Type,
"description": p.Description,
}
if p.Required {
required = append(required, p.Name)
}
}
schema := map[string]any{
"type": "object",
"properties": properties,
}
if len(required) > 0 {
schema["required"] = required
}
return coretypes.ToolSpec{
Name: d.Name,
Description: d.Description,
InputSchema: schema,
}
}
+83
View File
@@ -0,0 +1,83 @@
package tools
import (
"context"
"fmt"
"strings"
"github.com/enmanuel/agents/internal/config"
corespecs "github.com/enmanuel/agents/pkg/tools"
"github.com/enmanuel/agents/shell/ssh"
)
// NewSSHCommand creates an ssh_command tool that executes remote commands via SSH.
// Validates targets against cfg.AllowedTargets and commands against cfg.ForbiddenCommands.
func NewSSHCommand(cfg config.SSHToolCfg, exec *ssh.Executor) Tool {
return Tool{
Def: Def{
Name: "ssh_command",
Description: "Execute a command on a remote server via SSH.",
Parameters: []Param{
{Name: "target", Type: "string", Description: "The SSH target name (e.g. production, staging)", Required: true},
{Name: "command", Type: "string", Description: "The shell command to execute", Required: true},
},
},
Exec: func(ctx context.Context, args map[string]any) Result {
target := getString(args, "target")
command := getString(args, "command")
if target == "" || command == "" {
return Result{Err: fmt.Errorf("ssh_command: target and command are required")}
}
if err := validateTarget(target, cfg.AllowedTargets); err != nil {
return Result{Err: err}
}
if err := validateCommand(command, cfg.ForbiddenCommands); err != nil {
return Result{Err: err}
}
timeout := "30s"
if cfg.Timeout > 0 {
timeout = cfg.Timeout.String()
}
res := exec.Execute(ctx, corespecs.SSHCommandSpec{
Target: target,
Command: command,
Timeout: timeout,
})
if res.Err != nil {
return Result{Err: fmt.Errorf("ssh_command: %w", res.Err)}
}
output := res.Stdout
if res.Stderr != "" {
output += "\nstderr: " + res.Stderr
}
return Result{Output: output}
},
}
}
func validateTarget(target string, allowed []string) error {
if len(allowed) == 0 {
return nil
}
for _, a := range allowed {
if target == a {
return nil
}
}
return fmt.Errorf("ssh target %q not in allowed list", target)
}
func validateCommand(command string, forbidden []string) error {
lower := strings.ToLower(command)
for _, f := range forbidden {
if strings.Contains(lower, strings.ToLower(f)) {
return fmt.Errorf("ssh command contains forbidden pattern %q", f)
}
}
return nil
}
+49
View File
@@ -0,0 +1,49 @@
// Package tools defines tool specifications (pure) and their execution functions (impure).
// Each tool is a pair: Def (pure data) + ToolFunc (impure execution).
// To add a new tool, create a file in this package and register it in the agent builder.
package tools
import "context"
// Def is the pure specification of a tool — only data, no side effects.
type Def struct {
Name string
Description string
Parameters []Param
}
// Param describes a single parameter accepted by a tool.
type Param struct {
Name string
Type string // "string", "number", "boolean", "integer", "object", "array"
Description string
Required bool
}
// Result holds the outcome of executing a tool.
type Result struct {
Output string
Err error
}
// ToolFunc is the impure function that actually executes the tool.
type ToolFunc func(ctx context.Context, args map[string]any) Result
// Tool bundles a pure definition with its impure implementation.
type Tool struct {
Def Def
Exec ToolFunc
}
// getString extracts a string argument by name, returning "" if missing or wrong type.
func getString(args map[string]any, key string) string {
v, ok := args[key]
if !ok {
return ""
}
s, ok := v.(string)
if !ok {
return ""
}
return s
}