package ssh

import (
	"context"
	"fmt"
	"strings"
	"unicode"

	"github.com/enmanuel/agents/internal/config"
	corespecs "github.com/enmanuel/agents/pkg/tools"
	shellssh "github.com/enmanuel/agents/shell/ssh"
	"github.com/enmanuel/agents/tools"
)

// NewSSHCommand creates an ssh_command tool that executes remote commands via SSH.
//
// Before anything is handed to exec, the target is checked against
// cfg.AllowedTargets (an empty list allows every target) and the command is
// checked against cfg.ForbiddenCommands. The executor timeout is
// cfg.Timeout.String() when cfg.Timeout > 0, otherwise "30s".
//
// The tool's output is the remote stdout, followed by a "stderr: ..." section
// when stderr is non-empty.
func NewSSHCommand(cfg config.SSHToolCfg, exec *shellssh.Executor) tools.Tool {
	return tools.Tool{
		Def: tools.Def{
			Name:        "ssh_command",
			Description: "Execute a command on a remote server via SSH.",
			Parameters: []tools.Param{
				{Name: "target", Type: "string", Description: "The SSH target name (e.g. production, staging)", Required: true},
				{Name: "command", Type: "string", Description: "The shell command to execute", Required: true},
			},
		},
		Exec: func(ctx context.Context, args map[string]any) tools.Result {
			// Trim the target so " staging " matches an allow-list entry of
			// "staging", and so whitespace-only arguments count as missing.
			target := strings.TrimSpace(tools.GetString(args, "target"))
			command := tools.GetString(args, "command")
			if target == "" || strings.TrimSpace(command) == "" {
				return tools.Result{Err: fmt.Errorf("ssh_command: target and command are required")}
			}
			if err := validateTarget(target, cfg.AllowedTargets); err != nil {
				return tools.Result{Err: err}
			}
			if err := validateCommand(command, cfg.ForbiddenCommands); err != nil {
				return tools.Result{Err: err}
			}

			timeout := "30s"
			if cfg.Timeout > 0 {
				timeout = cfg.Timeout.String()
			}

			res := exec.Execute(ctx, corespecs.SSHCommandSpec{
				Target:  target,
				Command: command,
				Timeout: timeout,
			})
			if res.Err != nil {
				// Surface remote stderr with the error so the caller can see
				// why the command failed instead of only the wrapped error.
				// NOTE(review): assumes the executor may populate Stderr
				// alongside Err (e.g. on non-zero exit) — confirm in shellssh.
				if stderr := strings.TrimSpace(res.Stderr); stderr != "" {
					return tools.Result{Err: fmt.Errorf("ssh_command: %w (stderr: %s)", res.Err, stderr)}
				}
				return tools.Result{Err: fmt.Errorf("ssh_command: %w", res.Err)}
			}

			output := res.Stdout
			if res.Stderr != "" {
				// Only separate the sections when there is stdout to separate
				// from; otherwise the output would start with a blank line.
				if output != "" {
					output += "\n"
				}
				output += "stderr: " + res.Stderr
			}
			return tools.Result{Output: output}
		},
	}
}

// validateTarget returns an error unless target exactly (case-sensitively)
// equals one of the allowed names. An empty allowed list permits every target.
func validateTarget(target string, allowed []string) error {
	if len(allowed) == 0 {
		return nil
	}
	for _, a := range allowed {
		if target == a {
			return nil
		}
	}
	return fmt.Errorf("ssh target %q not in allowed list", target)
}

// validateCommand returns an error if command contains any of the forbidden
// patterns. Matching is case-insensitive and treats any run of whitespace as a
// single space on both sides, so "RM\t  -rf /" is caught by a "rm -rf" pattern.
// Blank (empty or whitespace-only) patterns are ignored: an empty string is a
// substring of every command and would otherwise reject everything.
//
// This is a best-effort denylist, not a security boundary: quoting, shell
// expansion ($(...), variables), aliases, or alternative binaries can all
// express a forbidden operation without containing the literal pattern.
func validateCommand(command string, forbidden []string) error {
	normalized := normalizeCommandText(command)
	for _, f := range forbidden {
		pattern := normalizeCommandText(f)
		if strings.TrimSpace(pattern) == "" {
			continue
		}
		if strings.Contains(normalized, pattern) {
			return fmt.Errorf("ssh command contains forbidden pattern %q", f)
		}
	}
	return nil
}

// normalizeCommandText lowercases s and replaces every run of Unicode
// whitespace (spaces, tabs, newlines, ...) with a single ASCII space. Leading
// and trailing runs are kept as one space, so a pattern padded with spaces
// still requires whitespace around its text in the command.
func normalizeCommandText(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	inSpace := false
	for _, r := range strings.ToLower(s) {
		if unicode.IsSpace(r) {
			if !inSpace {
				b.WriteByte(' ')
			}
			inSpace = true
			continue
		}
		inSpace = false
		b.WriteRune(r)
	}
	return b.String()
}