package http import ( "context" "fmt" "io" "net" "net/http" "net/url" "strings" "time" "github.com/enmanuel/agents/internal/config" "github.com/enmanuel/agents/tools" ) // NewHTTPGet creates an http_get tool that performs GET requests. // Validates URLs against cfg.AllowedDomains (deny-by-default if non-empty) // and blocks requests to internal/private IP ranges (SSRF protection). func NewHTTPGet(cfg config.HTTPToolCfg) tools.Tool { timeout := cfg.Timeout if timeout == 0 { timeout = 30 * time.Second } client := &http.Client{Timeout: timeout} return tools.Tool{ Def: tools.Def{ Name: "http_get", Description: "Perform an HTTP GET request to a URL and return the response body.", Parameters: []tools.Param{ {Name: "url", Type: "string", Description: "The URL to request", Required: true}, }, }, Exec: func(ctx context.Context, args map[string]any) tools.Result { rawURL := tools.GetString(args, "url") if rawURL == "" { return tools.Result{Err: fmt.Errorf("http_get: url is required")} } if err := validateURL(rawURL, cfg.AllowedDomains); err != nil { return tools.Result{Err: fmt.Errorf("http_get: %w", err)} } req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) if err != nil { return tools.Result{Err: fmt.Errorf("http_get: %w", err)} } resp, err := client.Do(req) if err != nil { return tools.Result{Err: fmt.Errorf("http_get: %w", err)} } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) // 64 KB limit if err != nil { return tools.Result{Err: fmt.Errorf("http_get read body: %w", err)} } return tools.Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)} }, } } // NewHTTPPost creates an http_post tool that performs POST requests with a JSON body. // Validates URLs against cfg.AllowedDomains and blocks internal IPs. func NewHTTPPost(cfg config.HTTPToolCfg) tools.Tool { timeout := cfg.Timeout if timeout == 0 { timeout = 30 * time.Second } client := &http.Client{Timeout: timeout} return tools.Tool{ Def: tools.Def{ Name: "http_post", Description: "Perform an HTTP POST request with a JSON body and return the response.", Parameters: []tools.Param{ {Name: "url", Type: "string", Description: "The URL to request", Required: true}, {Name: "body", Type: "string", Description: "The JSON body to send", Required: true}, }, }, Exec: func(ctx context.Context, args map[string]any) tools.Result { rawURL := tools.GetString(args, "url") if rawURL == "" { return tools.Result{Err: fmt.Errorf("http_post: url is required")} } bodyStr := tools.GetString(args, "body") if bodyStr == "" { return tools.Result{Err: fmt.Errorf("http_post: body is required")} } if err := validateURL(rawURL, cfg.AllowedDomains); err != nil { return tools.Result{Err: fmt.Errorf("http_post: %w", err)} } req, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, strings.NewReader(bodyStr)) if err != nil { return tools.Result{Err: fmt.Errorf("http_post: %w", err)} } req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) if err != nil { return tools.Result{Err: fmt.Errorf("http_post: %w", err)} } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) if err != nil { return tools.Result{Err: fmt.Errorf("http_post read body: %w", err)} } return tools.Result{Output: fmt.Sprintf("HTTP %d\n%s", resp.StatusCode, body)} }, } } // validateURL checks domain allowlist and blocks internal IPs (SSRF protection). func validateURL(rawURL string, allowedDomains []string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid url: %w", err) } host := u.Hostname() if host == "" { return fmt.Errorf("url has no host") } // SSRF protection: block internal/private IPs and localhost. if err := rejectInternalHost(host); err != nil { return err } // Domain allowlist (if configured). if err := validateDomain(host, allowedDomains); err != nil { return err } return nil } // validateDomain checks that the host is in the allowed list. // If allowedDomains is empty, all domains are allowed. func validateDomain(host string, allowedDomains []string) error { if len(allowedDomains) == 0 { return nil } lower := strings.ToLower(host) for _, d := range allowedDomains { if lower == strings.ToLower(d) { return nil } } return fmt.Errorf("domain %q not in allowed list", host) } // rejectInternalHost blocks requests to localhost, private IPs, and link-local addresses. func rejectInternalHost(host string) error { lower := strings.ToLower(host) if lower == "localhost" { return fmt.Errorf("requests to localhost are blocked") } ip := net.ParseIP(host) if ip == nil { // Not an IP literal — could be a domain. Resolve it. ips, err := net.LookupIP(host) if err != nil { return nil // let the HTTP client handle DNS errors } for _, resolved := range ips { if isPrivateIP(resolved) { return fmt.Errorf("domain %q resolves to private IP %s", host, resolved) } } return nil } if isPrivateIP(ip) { return fmt.Errorf("requests to private IP %s are blocked", ip) } return nil } // isPrivateIP returns true for loopback, private, link-local, and metadata IPs. func isPrivateIP(ip net.IP) bool { return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || isMetadataIP(ip) } // isMetadataIP checks for cloud metadata service IPs (169.254.169.254). func isMetadataIP(ip net.IP) bool { return ip.Equal(net.ParseIP("169.254.169.254")) }