feat: implement tool registry and add various tools for HTTP, file operations, SSH, and Matrix messaging
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// NewReadFile creates a read_file tool that reads local files.
|
||||
// Validates paths against cfg.AllowedPaths when non-empty.
|
||||
func NewReadFile(cfg config.FileOpsCfg) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "read_file",
|
||||
Description: "Read the contents of a local file.",
|
||||
Parameters: []Param{
|
||||
{Name: "path", Type: "string", Description: "Absolute path to the file to read", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
path := getString(args, "path")
|
||||
if path == "" {
|
||||
return Result{Err: fmt.Errorf("read_file: path is required")}
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("read_file: %w", err)}
|
||||
}
|
||||
|
||||
if err := validatePath(absPath, cfg.AllowedPaths); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("read_file: %w", err)}
|
||||
}
|
||||
|
||||
// Limit output to 64 KB
|
||||
content := string(data)
|
||||
if len(content) > 64*1024 {
|
||||
content = content[:64*1024] + "\n... (truncated)"
|
||||
}
|
||||
|
||||
return Result{Output: content}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validatePath(absPath string, allowedPaths []string) error {
|
||||
if len(allowedPaths) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, allowed := range allowedPaths {
|
||||
a, err := filepath.Abs(allowed)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(absPath, a) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("path %q not under any allowed path", absPath)
|
||||
}
|
||||
+132
@@ -0,0 +1,132 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
)
|
||||
|
||||
// NewHTTPGet creates an http_get tool that performs GET requests.
|
||||
// Validates URLs against cfg.AllowedDomains when non-empty.
|
||||
func NewHTTPGet(cfg config.HTTPToolCfg) Tool {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "http_get",
|
||||
Description: "Perform an HTTP GET request to a URL and return the response body.",
|
||||
Parameters: []Param{
|
||||
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
rawURL := getString(args, "url")
|
||||
if rawURL == "" {
|
||||
return Result{Err: fmt.Errorf("http_get: url is required")}
|
||||
}
|
||||
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get: %w", err)}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get: %w", err)}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) // 64 KB limit
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_get read body: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPPost creates an http_post tool that performs POST requests with a JSON body.
|
||||
// Validates URLs against cfg.AllowedDomains when non-empty.
|
||||
func NewHTTPPost(cfg config.HTTPToolCfg) Tool {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "http_post",
|
||||
Description: "Perform an HTTP POST request with a JSON body and return the response.",
|
||||
Parameters: []Param{
|
||||
{Name: "url", Type: "string", Description: "The URL to request", Required: true},
|
||||
{Name: "body", Type: "string", Description: "The JSON body to send", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
rawURL := getString(args, "url")
|
||||
if rawURL == "" {
|
||||
return Result{Err: fmt.Errorf("http_post: url is required")}
|
||||
}
|
||||
bodyStr := getString(args, "body")
|
||||
if bodyStr == "" {
|
||||
return Result{Err: fmt.Errorf("http_post: body is required")}
|
||||
}
|
||||
if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, strings.NewReader(bodyStr))
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post: %w", err)}
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post: %w", err)}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
if err != nil {
|
||||
return Result{Err: fmt.Errorf("http_post read body: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// validateDomain checks that the URL's host is in the allowed list.
|
||||
// If allowedDomains is empty, all domains are allowed.
|
||||
func validateDomain(rawURL string, allowedDomains []string) error {
|
||||
if len(allowedDomains) == 0 {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
host := u.Hostname()
|
||||
for _, d := range allowedDomains {
|
||||
if host == d {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("domain %q not in allowed list", host)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MatrixSender is the interface for sending Matrix messages.
|
||||
// Satisfied by shell/matrix.Client.
|
||||
type MatrixSender interface {
|
||||
SendText(ctx context.Context, roomID, text string) error
|
||||
}
|
||||
|
||||
// NewMatrixSend creates a matrix_send tool that sends a message to a Matrix room.
|
||||
func NewMatrixSend(sender MatrixSender) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "matrix_send",
|
||||
Description: "Send a text message to a Matrix room.",
|
||||
Parameters: []Param{
|
||||
{Name: "room_id", Type: "string", Description: "The Matrix room ID to send to", Required: true},
|
||||
{Name: "message", Type: "string", Description: "The text message to send", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
roomID := getString(args, "room_id")
|
||||
message := getString(args, "message")
|
||||
if roomID == "" || message == "" {
|
||||
return Result{Err: fmt.Errorf("matrix_send: room_id and message are required")}
|
||||
}
|
||||
|
||||
if err := sender.SendText(ctx, roomID, message); err != nil {
|
||||
return Result{Err: fmt.Errorf("matrix_send: %w", err)}
|
||||
}
|
||||
|
||||
return Result{Output: fmt.Sprintf("message sent to %s", roomID)}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
coretypes "github.com/enmanuel/agents/pkg/llm"
|
||||
)
|
||||
|
||||
// Registry holds available tools keyed by name.
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
// NewRegistry creates an empty registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{tools: make(map[string]Tool)}
|
||||
}
|
||||
|
||||
// Register adds a tool to the registry.
|
||||
func (r *Registry) Register(t Tool) {
|
||||
r.tools[t.Def.Name] = t
|
||||
}
|
||||
|
||||
// Get looks up a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// Names returns all registered tool names in sorted order.
|
||||
func (r *Registry) Names() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for k := range r.tools {
|
||||
names = append(names, k)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// Len returns the number of registered tools.
|
||||
func (r *Registry) Len() int {
|
||||
return len(r.tools)
|
||||
}
|
||||
|
||||
// Execute looks up a tool by name and runs it. Returns an error result if not found.
|
||||
func (r *Registry) Execute(ctx context.Context, name string, argsJSON string) Result {
|
||||
t, ok := r.tools[name]
|
||||
if !ok {
|
||||
return Result{Err: fmt.Errorf("tool %q not found", name)}
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if argsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
return Result{Err: fmt.Errorf("parse args for %q: %w", name, err)}
|
||||
}
|
||||
}
|
||||
|
||||
return t.Exec(ctx, args)
|
||||
}
|
||||
|
||||
// ToLLMSpecs converts all registered tools to the LLM-compatible ToolSpec format.
|
||||
// This is a pure transformation — no side effects.
|
||||
func (r *Registry) ToLLMSpecs() []coretypes.ToolSpec {
|
||||
specs := make([]coretypes.ToolSpec, 0, len(r.tools))
|
||||
for _, name := range r.Names() {
|
||||
t := r.tools[name]
|
||||
specs = append(specs, defToLLMSpec(t.Def))
|
||||
}
|
||||
return specs
|
||||
}
|
||||
|
||||
// defToLLMSpec converts a pure Def to an LLM ToolSpec with JSON Schema.
|
||||
func defToLLMSpec(d Def) coretypes.ToolSpec {
|
||||
properties := make(map[string]any, len(d.Parameters))
|
||||
required := make([]string, 0)
|
||||
|
||||
for _, p := range d.Parameters {
|
||||
properties[p.Name] = map[string]any{
|
||||
"type": p.Type,
|
||||
"description": p.Description,
|
||||
}
|
||||
if p.Required {
|
||||
required = append(required, p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
|
||||
return coretypes.ToolSpec{
|
||||
Name: d.Name,
|
||||
Description: d.Description,
|
||||
InputSchema: schema,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/enmanuel/agents/internal/config"
|
||||
corespecs "github.com/enmanuel/agents/pkg/tools"
|
||||
"github.com/enmanuel/agents/shell/ssh"
|
||||
)
|
||||
|
||||
// NewSSHCommand creates an ssh_command tool that executes remote commands via SSH.
|
||||
// Validates targets against cfg.AllowedTargets and commands against cfg.ForbiddenCommands.
|
||||
func NewSSHCommand(cfg config.SSHToolCfg, exec *ssh.Executor) Tool {
|
||||
return Tool{
|
||||
Def: Def{
|
||||
Name: "ssh_command",
|
||||
Description: "Execute a command on a remote server via SSH.",
|
||||
Parameters: []Param{
|
||||
{Name: "target", Type: "string", Description: "The SSH target name (e.g. production, staging)", Required: true},
|
||||
{Name: "command", Type: "string", Description: "The shell command to execute", Required: true},
|
||||
},
|
||||
},
|
||||
Exec: func(ctx context.Context, args map[string]any) Result {
|
||||
target := getString(args, "target")
|
||||
command := getString(args, "command")
|
||||
if target == "" || command == "" {
|
||||
return Result{Err: fmt.Errorf("ssh_command: target and command are required")}
|
||||
}
|
||||
|
||||
if err := validateTarget(target, cfg.AllowedTargets); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
if err := validateCommand(command, cfg.ForbiddenCommands); err != nil {
|
||||
return Result{Err: err}
|
||||
}
|
||||
|
||||
timeout := "30s"
|
||||
if cfg.Timeout > 0 {
|
||||
timeout = cfg.Timeout.String()
|
||||
}
|
||||
|
||||
res := exec.Execute(ctx, corespecs.SSHCommandSpec{
|
||||
Target: target,
|
||||
Command: command,
|
||||
Timeout: timeout,
|
||||
})
|
||||
|
||||
if res.Err != nil {
|
||||
return Result{Err: fmt.Errorf("ssh_command: %w", res.Err)}
|
||||
}
|
||||
|
||||
output := res.Stdout
|
||||
if res.Stderr != "" {
|
||||
output += "\nstderr: " + res.Stderr
|
||||
}
|
||||
return Result{Output: output}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validateTarget(target string, allowed []string) error {
|
||||
if len(allowed) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, a := range allowed {
|
||||
if target == a {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("ssh target %q not in allowed list", target)
|
||||
}
|
||||
|
||||
func validateCommand(command string, forbidden []string) error {
|
||||
lower := strings.ToLower(command)
|
||||
for _, f := range forbidden {
|
||||
if strings.Contains(lower, strings.ToLower(f)) {
|
||||
return fmt.Errorf("ssh command contains forbidden pattern %q", f)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Package tools defines tool specifications (pure) and their execution functions (impure).
|
||||
// Each tool is a pair: Def (pure data) + ToolFunc (impure execution).
|
||||
// To add a new tool, create a file in this package and register it in the agent builder.
|
||||
package tools
|
||||
|
||||
import "context"
|
||||
|
||||
// Def is the pure specification of a tool — only data, no side effects.
|
||||
type Def struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []Param
|
||||
}
|
||||
|
||||
// Param describes a single parameter accepted by a tool.
|
||||
type Param struct {
|
||||
Name string
|
||||
Type string // "string", "number", "boolean", "integer", "object", "array"
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// Result holds the outcome of executing a tool.
|
||||
type Result struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// ToolFunc is the impure function that actually executes the tool.
|
||||
type ToolFunc func(ctx context.Context, args map[string]any) Result
|
||||
|
||||
// Tool bundles a pure definition with its impure implementation.
|
||||
type Tool struct {
|
||||
Def Def
|
||||
Exec ToolFunc
|
||||
}
|
||||
|
||||
// getString extracts a string argument by name, returning "" if missing or wrong type.
|
||||
func getString(args map[string]any, key string) string {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user