Files
agents_and_robots/shell/ssh/executor.go
T
2026-03-03 23:19:23 +00:00

147 lines
3.2 KiB
Go

// Package ssh provides impure SSH command execution.
package ssh
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"time"

	gossh "golang.org/x/crypto/ssh"

	"github.com/enmanuel/agents/internal/config"
	"github.com/enmanuel/agents/pkg/tools"
)
// Result is the outcome of running a single command over SSH.
//
// Stdout and Stderr carry the remote command's captured output streams and
// ExitCode its exit status. Err is set when a failure prevented a normal
// result from being produced; in that case the other fields may be empty.
type Result struct {
	Stdout, Stderr string
	ExitCode       int
	Err            error
}
// Executor runs commands over SSH against the targets declared in an SSH
// config section. The only state it keeps is the config section it was
// built with.
type Executor struct {
	cfg config.SSHCfg
}

// NewExecutor returns an Executor that resolves targets and connection
// defaults from the given SSH config section.
func NewExecutor(cfg config.SSHCfg) *Executor {
	exec := new(Executor)
	exec.cfg = cfg
	return exec
}
// Execute runs the SSH command described by spec against the first host of
// the named target and returns its captured output. Impure: it reads a
// private key from disk and opens a network connection.
//
// Connection settings are resolved field by field. User, port and key env var
// come from the target first and then from the section defaults. The port
// falls back to 22. The connect/handshake timeout is taken from
// Defaults.Timeout, or 10s when that is zero.
//
// A remote command that completes unsuccessfully is reported through ExitCode
// with a nil Err. Err is set for configuration, key-loading, dial, session
// and transport failures. If ctx is cancelled, Err is ctx.Err() and any
// partial output is discarded.
func (e *Executor) Execute(ctx context.Context, spec tools.SSHCommandSpec) Result {
	target, ok := e.cfg.Targets[spec.Target]
	if !ok {
		return Result{Err: fmt.Errorf("unknown SSH target: %s", spec.Target)}
	}
	if len(target.Hosts) == 0 {
		return Result{Err: fmt.Errorf("no hosts for target: %s", spec.Target)}
	}
	// Use first host (round-robin or load balancing can be added later).
	host := target.Hosts[0]

	user := target.User
	if user == "" {
		user = e.cfg.Defaults.User
	}
	port := target.Port
	if port == 0 {
		port = e.cfg.Defaults.Port
	}
	if port == 0 {
		port = 22
	}
	keyEnv := target.KeyFileEnv
	if keyEnv == "" {
		keyEnv = e.cfg.Defaults.KeyFileEnv
	}

	signer, err := loadSigner(keyEnv)
	if err != nil {
		return Result{Err: fmt.Errorf("load SSH key: %w", err)}
	}

	sshCfg := &gossh.ClientConfig{
		User: user,
		Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)},
		// NOTE(review): InsecureIgnoreHostKey accepts any server key, so the
		// connection is open to man-in-the-middle attacks.
		HostKeyCallback: gossh.InsecureIgnoreHostKey(), // TODO: use known_hosts
		Timeout:         e.cfg.Defaults.Timeout,
	}
	if sshCfg.Timeout == 0 {
		sshCfg.Timeout = 10 * time.Second
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:22"); "%s:%d" does not.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	client, err := dialWithContext(ctx, addr, sshCfg)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			// Report cancellation the same way as a cancelled command run.
			return Result{Err: ctxErr}
		}
		return Result{Err: fmt.Errorf("ssh dial %s: %w", addr, err)}
	}
	defer client.Close()

	return runRemoteCommand(ctx, client, spec.Command)
}

// dialWithContext opens a TCP connection to addr that honours ctx and
// cfg.Timeout, then performs the SSH handshake on it. The handshake is also
// bounded by cfg.Timeout, because gossh.NewClientConn applies no timeout of
// its own (ClientConfig.Timeout only covers the TCP connect).
func dialWithContext(ctx context.Context, addr string, cfg *gossh.ClientConfig) (*gossh.Client, error) {
	dialer := net.Dialer{Timeout: cfg.Timeout}
	nc, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}
	if cfg.Timeout > 0 {
		// Best effort: if the deadline cannot be set, the handshake is just
		// unbounded, as it was with gossh.Dial.
		_ = nc.SetDeadline(time.Now().Add(cfg.Timeout))
	}
	c, chans, reqs, err := gossh.NewClientConn(nc, addr, cfg)
	if err != nil {
		_ = nc.Close() // harmless if NewClientConn already closed it
		return nil, err
	}
	// Clear the handshake deadline so long-running commands are not cut off.
	_ = nc.SetDeadline(time.Time{})
	return gossh.NewClient(c, chans, reqs), nil
}

// runRemoteCommand runs command in a new session on client and collects its
// stdout, stderr and exit status. If ctx is cancelled first, it sends SIGTERM
// (best effort) and returns ctx.Err().
func runRemoteCommand(ctx context.Context, client *gossh.Client, command string) Result {
	session, err := client.NewSession()
	if err != nil {
		return Result{Err: fmt.Errorf("ssh session: %w", err)}
	}
	// Closing the session on return also unblocks the Run goroutine below
	// after a cancellation.
	defer session.Close()

	var stdout, stderr bytes.Buffer
	session.Stdout = &stdout
	session.Stderr = &stderr

	// Buffered so the Run goroutine can always deliver its result and exit,
	// even after we stop listening because ctx was cancelled.
	done := make(chan error, 1)
	go func() { done <- session.Run(command) }()

	select {
	case <-ctx.Done():
		// Many servers ignore signal requests, so the error is not actionable;
		// the deferred Close tears the session down regardless.
		_ = session.Signal(gossh.SIGTERM)
		return Result{Err: ctx.Err()}
	case err := <-done:
		code := 0
		if err != nil {
			var exitErr *gossh.ExitError
			if !asExitError(err, &exitErr) {
				return Result{Err: err}
			}
			code = exitErr.ExitStatus()
		}
		return Result{
			Stdout:   stdout.String(),
			Stderr:   stderr.String(),
			ExitCode: code,
		}
	}
}
// loadSigner reads a private key from the file whose path is stored in the
// environment variable named keyFileEnv, and parses it into an SSH signer.
//
// It returns an error in these cases:
//   - keyFileEnv is empty, meaning no variable was configured.
//   - The variable is unset or empty.
//   - The file cannot be read.
//   - gossh.ParsePrivateKey rejects the contents. No passphrase is supplied,
//     so encrypted keys are expected to fail here.
func loadSigner(keyFileEnv string) (gossh.Signer, error) {
	if keyFileEnv == "" {
		// os.Getenv("") is always "", which would otherwise produce the
		// misleading message "env var  not set".
		return nil, errors.New("no SSH key env var configured")
	}
	keyPath := os.Getenv(keyFileEnv)
	if keyPath == "" {
		return nil, fmt.Errorf("env var %s not set", keyFileEnv)
	}
	raw, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, err // os.ReadFile's *PathError already names the file
	}
	signer, err := gossh.ParsePrivateKey(raw)
	if err != nil {
		return nil, fmt.Errorf("parse private key %s: %w", keyPath, err)
	}
	return signer, nil
}
// asExitError reports whether err, or any error in its wrap chain, is a
// *gossh.ExitError, which signals unsuccessful completion of a remote
// command. On a match it stores that error in *target. target must be
// non-nil. This is a typed convenience wrapper around errors.As.
func asExitError(err error, target **gossh.ExitError) bool {
	return errors.As(err, target)
}
// Blank, explicitly typed reference that keeps the net import in use; net is
// kept around for planned jump-host support.
var _ func(network, address string) (net.Conn, error) = net.Dial