package ssh import ( "context" "fmt" "strings" "github.com/enmanuel/agents/internal/config" corespecs "github.com/enmanuel/agents/pkg/tools" shellssh "github.com/enmanuel/agents/shell/ssh" "github.com/enmanuel/agents/tools" ) // NewSSHCommand creates an ssh_command tool that executes remote commands via SSH. // Validates targets against AllowedTargets (deny-by-default if non-empty), // commands against AllowedCommands allowlist (if non-empty, only those prefixes permitted), // and against ForbiddenCommands blocklist as a second defense layer. func NewSSHCommand(cfg config.SSHToolCfg, exec *shellssh.Executor) tools.Tool { return tools.Tool{ Def: tools.Def{ Name: "ssh_command", Description: "Execute a command on a remote server via SSH.", Parameters: []tools.Param{ {Name: "target", Type: "string", Description: "The SSH target name (e.g. production, staging)", Required: true}, {Name: "command", Type: "string", Description: "The shell command to execute", Required: true}, }, }, Exec: func(ctx context.Context, args map[string]any) tools.Result { target := tools.GetString(args, "target") command := tools.GetString(args, "command") if target == "" || command == "" { return tools.Result{Err: fmt.Errorf("ssh_command: target and command are required")} } if err := validateTarget(target, cfg.AllowedTargets); err != nil { return tools.Result{Err: err} } if err := validateAllowedCommand(command, cfg.AllowedCommands); err != nil { return tools.Result{Err: err} } if err := validateForbiddenCommand(command, cfg.ForbiddenCommands); err != nil { return tools.Result{Err: err} } if err := validateCommandSyntax(command); err != nil { return tools.Result{Err: err} } timeout := "30s" if cfg.Timeout > 0 { timeout = cfg.Timeout.String() } res := exec.Execute(ctx, corespecs.SSHCommandSpec{ Target: target, Command: command, Timeout: timeout, }) if res.Err != nil { return tools.Result{Err: fmt.Errorf("ssh_command: %w", res.Err)} } output := res.Stdout if res.Stderr != "" { output += "\nstderr: " + res.Stderr } return tools.Result{Output: output} }, } } func validateTarget(target string, allowed []string) error { if len(allowed) == 0 { return nil } for _, a := range allowed { if target == a { return nil } } return fmt.Errorf("ssh target %q not in allowed list", target) } // validateAllowedCommand checks that the command starts with one of the allowed prefixes. // If the allowlist is empty, all commands pass this check (blocklist still applies). func validateAllowedCommand(command string, allowed []string) error { if len(allowed) == 0 { return nil } lower := strings.ToLower(command) for _, a := range allowed { if strings.HasPrefix(lower, strings.ToLower(a)) { return nil } } return fmt.Errorf("ssh command not in allowed commands list") } // validateForbiddenCommand checks that the command does not contain any forbidden patterns. func validateForbiddenCommand(command string, forbidden []string) error { lower := strings.ToLower(command) for _, f := range forbidden { if strings.Contains(lower, strings.ToLower(f)) { return fmt.Errorf("ssh command contains forbidden pattern %q", f) } } return nil } // validateCommandSyntax rejects commands with suspicious shell constructs // that could be used to bypass restrictions: pipes to external services, // subshells, and output redirection. func validateCommandSyntax(command string) error { suspicious := []string{ "|", // pipe (can exfiltrate output) "$(", // command substitution "`", // backtick substitution ">>", // append redirection ">", // output redirection "&&", // command chaining "||", // command chaining ";", // command separator } for _, s := range suspicious { if strings.Contains(command, s) { return fmt.Errorf("ssh command contains disallowed shell syntax %q", s) } } return nil }