1036 lines
35 KiB
Rust
1036 lines
35 KiB
Rust
use super::traits::{Tool, ToolResult};
|
|
use super::url_validation::{
|
|
normalize_allowed_domains, validate_url, DomainPolicy, UrlSchemePolicy,
|
|
};
|
|
use crate::config::{HttpRequestCredentialProfile, UrlAccessConfig};
|
|
use crate::security::SecurityPolicy;
|
|
use async_trait::async_trait;
|
|
use serde_json::json;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
/// HTTP request tool for API interactions.
|
|
/// Supports GET, POST, PUT, DELETE methods with configurable security.
|
|
pub struct HttpRequestTool {
|
|
security: Arc<SecurityPolicy>,
|
|
allowed_domains: Vec<String>,
|
|
url_access: UrlAccessConfig,
|
|
max_response_size: usize,
|
|
timeout_secs: u64,
|
|
user_agent: String,
|
|
credential_profiles: HashMap<String, HttpRequestCredentialProfile>,
|
|
}
|
|
|
|
impl HttpRequestTool {
|
|
pub fn new(
|
|
security: Arc<SecurityPolicy>,
|
|
allowed_domains: Vec<String>,
|
|
url_access: UrlAccessConfig,
|
|
max_response_size: usize,
|
|
timeout_secs: u64,
|
|
user_agent: String,
|
|
credential_profiles: HashMap<String, HttpRequestCredentialProfile>,
|
|
) -> Self {
|
|
Self {
|
|
security,
|
|
allowed_domains: normalize_allowed_domains(allowed_domains),
|
|
url_access,
|
|
max_response_size,
|
|
timeout_secs,
|
|
user_agent,
|
|
credential_profiles: credential_profiles
|
|
.into_iter()
|
|
.map(|(name, profile)| (name.trim().to_ascii_lowercase(), profile))
|
|
.collect(),
|
|
}
|
|
}
|
|
|
|
fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
|
|
validate_url(
|
|
raw_url,
|
|
&DomainPolicy {
|
|
allowed_domains: &self.allowed_domains,
|
|
blocked_domains: &[],
|
|
allowed_field_name: "http_request.allowed_domains",
|
|
blocked_field_name: None,
|
|
empty_allowed_message: "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml",
|
|
scheme_policy: UrlSchemePolicy::HttpOrHttps,
|
|
ipv6_error_context: "http_request",
|
|
url_access: Some(&self.url_access),
|
|
},
|
|
)
|
|
}
|
|
|
|
fn validate_method(&self, method: &str) -> anyhow::Result<reqwest::Method> {
|
|
match method.to_uppercase().as_str() {
|
|
"GET" => Ok(reqwest::Method::GET),
|
|
"POST" => Ok(reqwest::Method::POST),
|
|
"PUT" => Ok(reqwest::Method::PUT),
|
|
"DELETE" => Ok(reqwest::Method::DELETE),
|
|
"PATCH" => Ok(reqwest::Method::PATCH),
|
|
"HEAD" => Ok(reqwest::Method::HEAD),
|
|
"OPTIONS" => Ok(reqwest::Method::OPTIONS),
|
|
_ => anyhow::bail!("Unsupported HTTP method: {method}. Supported: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS"),
|
|
}
|
|
}
|
|
|
|
fn parse_headers(&self, headers: &serde_json::Value) -> Vec<(String, String)> {
|
|
let mut result = Vec::new();
|
|
if let Some(obj) = headers.as_object() {
|
|
for (key, value) in obj {
|
|
if let Some(str_val) = value.as_str() {
|
|
result.push((key.clone(), str_val.to_string()));
|
|
}
|
|
}
|
|
}
|
|
result
|
|
}
|
|
|
|
fn redact_headers_for_display(headers: &[(String, String)]) -> Vec<(String, String)> {
|
|
headers
|
|
.iter()
|
|
.map(|(key, value)| {
|
|
let lower = key.to_lowercase();
|
|
let is_sensitive = lower.contains("authorization")
|
|
|| lower.contains("api-key")
|
|
|| lower.contains("apikey")
|
|
|| lower.contains("token")
|
|
|| lower.contains("secret");
|
|
if is_sensitive {
|
|
(key.clone(), "***REDACTED***".into())
|
|
} else {
|
|
(key.clone(), value.clone())
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn resolve_credential_profile(
|
|
&self,
|
|
profile_name: &str,
|
|
) -> anyhow::Result<(Vec<(String, String)>, Vec<String>)> {
|
|
let requested_name = profile_name.trim();
|
|
if requested_name.is_empty() {
|
|
anyhow::bail!("credential_profile must not be empty");
|
|
}
|
|
|
|
let profile = self
|
|
.credential_profiles
|
|
.get(&requested_name.to_ascii_lowercase())
|
|
.ok_or_else(|| {
|
|
let mut names: Vec<&str> = self
|
|
.credential_profiles
|
|
.keys()
|
|
.map(std::string::String::as_str)
|
|
.collect();
|
|
names.sort_unstable();
|
|
if names.is_empty() {
|
|
anyhow::anyhow!(
|
|
"Unknown credential_profile '{requested_name}'. No credential profiles are configured under [http_request.credential_profiles]."
|
|
)
|
|
} else {
|
|
anyhow::anyhow!(
|
|
"Unknown credential_profile '{requested_name}'. Available profiles: {}",
|
|
names.join(", ")
|
|
)
|
|
}
|
|
})?;
|
|
|
|
let header_name = profile.header_name.trim();
|
|
if header_name.is_empty() {
|
|
anyhow::bail!(
|
|
"credential_profile '{requested_name}' has an empty header_name in config"
|
|
);
|
|
}
|
|
|
|
let env_var = profile.env_var.trim();
|
|
if env_var.is_empty() {
|
|
anyhow::bail!("credential_profile '{requested_name}' has an empty env_var in config");
|
|
}
|
|
|
|
let secret = std::env::var(env_var).map_err(|_| {
|
|
anyhow::anyhow!(
|
|
"credential_profile '{requested_name}' requires environment variable {env_var}"
|
|
)
|
|
})?;
|
|
let secret = secret.trim();
|
|
if secret.is_empty() {
|
|
anyhow::bail!(
|
|
"credential_profile '{requested_name}' uses environment variable {env_var}, but it is empty"
|
|
);
|
|
}
|
|
|
|
let header_value = format!("{}{}", profile.value_prefix, secret);
|
|
let mut sensitive_values = vec![secret.to_string(), header_value.clone()];
|
|
sensitive_values.sort_unstable();
|
|
sensitive_values.dedup();
|
|
|
|
Ok((
|
|
vec![(header_name.to_string(), header_value)],
|
|
sensitive_values,
|
|
))
|
|
}
|
|
|
|
fn has_header_name_conflict(
|
|
explicit_headers: &[(String, String)],
|
|
injected_headers: &[(String, String)],
|
|
) -> bool {
|
|
explicit_headers.iter().any(|(explicit_key, _)| {
|
|
injected_headers
|
|
.iter()
|
|
.any(|(injected_key, _)| injected_key.eq_ignore_ascii_case(explicit_key))
|
|
})
|
|
}
|
|
|
|
fn redact_sensitive_values(text: &str, sensitive_values: &[String]) -> String {
|
|
let mut redacted = text.to_string();
|
|
for value in sensitive_values {
|
|
let needle = value.trim();
|
|
if needle.is_empty() || needle.len() < 6 {
|
|
continue;
|
|
}
|
|
redacted = redacted.replace(needle, "***REDACTED***");
|
|
}
|
|
redacted
|
|
}
|
|
|
|
async fn execute_request(
|
|
&self,
|
|
url: &str,
|
|
method: reqwest::Method,
|
|
headers: Vec<(String, String)>,
|
|
body: Option<&str>,
|
|
) -> anyhow::Result<reqwest::Response> {
|
|
let timeout_secs = if self.timeout_secs == 0 {
|
|
tracing::warn!("http_request: timeout_secs is 0, using safe default of 30s");
|
|
30
|
|
} else {
|
|
self.timeout_secs
|
|
};
|
|
let builder = reqwest::Client::builder()
|
|
.timeout(Duration::from_secs(timeout_secs))
|
|
.connect_timeout(Duration::from_secs(10))
|
|
.redirect(reqwest::redirect::Policy::none())
|
|
.user_agent(self.user_agent.as_str());
|
|
let builder = crate::config::apply_runtime_proxy_to_builder(builder, "tool.http_request");
|
|
let client = builder.build()?;
|
|
|
|
let mut request = client.request(method, url);
|
|
|
|
for (key, value) in headers {
|
|
request = request.header(&key, &value);
|
|
}
|
|
|
|
if let Some(body_str) = body {
|
|
request = request.body(body_str.to_string());
|
|
}
|
|
|
|
Ok(request.send().await?)
|
|
}
|
|
|
|
fn truncate_response(&self, text: &str) -> String {
|
|
if text.len() > self.max_response_size {
|
|
let mut truncated = text
|
|
.chars()
|
|
.take(self.max_response_size)
|
|
.collect::<String>();
|
|
truncated.push_str("\n\n... [Response truncated due to size limit] ...");
|
|
truncated
|
|
} else {
|
|
text.to_string()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Tool for HttpRequestTool {
|
|
fn name(&self) -> &str {
|
|
"http_request"
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
"Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS methods. \
|
|
Security constraints: allowlist-only domains, no local/private hosts, configurable timeout/response size limits, and optional env-backed credential profiles."
|
|
}
|
|
|
|
fn parameters_schema(&self) -> serde_json::Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"url": {
|
|
"type": "string",
|
|
"description": "HTTP or HTTPS URL to request"
|
|
},
|
|
"method": {
|
|
"type": "string",
|
|
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)",
|
|
"default": "GET"
|
|
},
|
|
"headers": {
|
|
"type": "object",
|
|
"description": "Optional HTTP headers as key-value pairs (e.g., {\"Authorization\": \"Bearer token\", \"Content-Type\": \"application/json\"})",
|
|
"default": {}
|
|
},
|
|
"credential_profile": {
|
|
"type": "string",
|
|
"description": "Optional profile name from [http_request.credential_profiles]. Lets the harness inject credentials from environment variables without passing raw tokens in tool arguments."
|
|
},
|
|
"body": {
|
|
"type": "string",
|
|
"description": "Optional request body (for POST, PUT, PATCH requests)"
|
|
}
|
|
},
|
|
"required": ["url"]
|
|
})
|
|
}
|
|
|
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
let url = args
|
|
.get("url")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| anyhow::anyhow!("Missing 'url' parameter"))?;
|
|
|
|
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
|
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
|
let credential_profile = match args.get("credential_profile") {
|
|
Some(value) => match value.as_str() {
|
|
Some(name) => Some(name),
|
|
None => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some("Invalid 'credential_profile': expected string".into()),
|
|
});
|
|
}
|
|
},
|
|
None => None,
|
|
};
|
|
let body = args.get("body").and_then(|v| v.as_str());
|
|
|
|
if !self.security.can_act() {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some("Action blocked: autonomy is read-only".into()),
|
|
});
|
|
}
|
|
|
|
if !self.security.record_action() {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some("Action blocked: rate limit exceeded".into()),
|
|
});
|
|
}
|
|
|
|
let url = match self.validate_url(url) {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(e.to_string()),
|
|
})
|
|
}
|
|
};
|
|
|
|
let method = match self.validate_method(method_str) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(e.to_string()),
|
|
})
|
|
}
|
|
};
|
|
|
|
let mut request_headers = self.parse_headers(&headers_val);
|
|
let mut sensitive_values = Vec::new();
|
|
if let Some(profile_name) = credential_profile {
|
|
match self.resolve_credential_profile(profile_name) {
|
|
Ok((injected_headers, profile_sensitive_values)) => {
|
|
if Self::has_header_name_conflict(&request_headers, &injected_headers) {
|
|
let names = injected_headers
|
|
.iter()
|
|
.map(|(name, _)| name.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(format!(
|
|
"credential_profile '{profile_name}' conflicts with explicit headers ({names}); remove duplicate header keys from args.headers"
|
|
)),
|
|
});
|
|
}
|
|
request_headers.extend(injected_headers);
|
|
sensitive_values.extend(profile_sensitive_values);
|
|
}
|
|
Err(e) => {
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(e.to_string()),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
match self
|
|
.execute_request(&url, method, request_headers, body)
|
|
.await
|
|
{
|
|
Ok(response) => {
|
|
let status = response.status();
|
|
let status_code = status.as_u16();
|
|
|
|
// Get response headers (redact sensitive ones)
|
|
let response_headers = response.headers().iter();
|
|
let headers_text = response_headers
|
|
.map(|(k, v)| {
|
|
let lower = k.as_str().to_ascii_lowercase();
|
|
let is_sensitive = lower.contains("set-cookie")
|
|
|| lower.contains("authorization")
|
|
|| lower.contains("api-key")
|
|
|| lower.contains("token")
|
|
|| lower.contains("secret");
|
|
if is_sensitive {
|
|
format!("{}: ***REDACTED***", k.as_str())
|
|
} else {
|
|
let val = v.to_str().unwrap_or("<non-UTF8>");
|
|
format!("{}: {}", k.as_str(), val)
|
|
}
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
let headers_text = Self::redact_sensitive_values(&headers_text, &sensitive_values);
|
|
|
|
// Get response body with size limit
|
|
let response_text = match response.text().await {
|
|
Ok(text) => self.truncate_response(&text),
|
|
Err(e) => format!("[Failed to read response body: {e}]"),
|
|
};
|
|
let response_text =
|
|
Self::redact_sensitive_values(&response_text, &sensitive_values);
|
|
|
|
let output = format!(
|
|
"Status: {} {}\nResponse Headers: {}\n\nResponse Body:\n{}",
|
|
status_code,
|
|
status.canonical_reason().unwrap_or("Unknown"),
|
|
headers_text,
|
|
response_text
|
|
);
|
|
|
|
Ok(ToolResult {
|
|
success: status.is_success(),
|
|
output,
|
|
error: if status.is_client_error() || status.is_server_error() {
|
|
Some(format!("HTTP {}", status_code))
|
|
} else {
|
|
None
|
|
},
|
|
})
|
|
}
|
|
Err(e) => Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(format!("HTTP request failed: {e}")),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::security::{AutonomyLevel, SecurityPolicy};
|
|
use crate::tools::url_validation::{is_private_or_local_host, normalize_domain};
|
|
|
|
fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
|
|
let security = Arc::new(SecurityPolicy {
|
|
autonomy: AutonomyLevel::Supervised,
|
|
..SecurityPolicy::default()
|
|
});
|
|
HttpRequestTool::new(
|
|
security,
|
|
allowed_domains.into_iter().map(String::from).collect(),
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
HashMap::new(),
|
|
)
|
|
}
|
|
|
|
#[test]
|
|
fn normalize_domain_strips_scheme_path_and_case() {
|
|
let got = normalize_domain(" HTTPS://Docs.Example.com/path ").unwrap();
|
|
assert_eq!(got, "docs.example.com");
|
|
}
|
|
|
|
#[test]
|
|
fn normalize_allowed_domains_deduplicates() {
|
|
let got = normalize_allowed_domains(vec![
|
|
"example.com".into(),
|
|
"EXAMPLE.COM".into(),
|
|
"https://example.com/".into(),
|
|
]);
|
|
assert_eq!(got, vec!["example.com".to_string()]);
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_exact_domain() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let got = tool.validate_url("https://example.com/docs").unwrap();
|
|
assert_eq!(got, "https://example.com/docs");
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_http() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
assert!(tool.validate_url("http://example.com").is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_subdomain() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
assert!(tool.validate_url("https://api.example.com/v1").is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_wildcard_allowlist_for_public_host() {
|
|
let tool = test_tool(vec!["*"]);
|
|
assert!(tool.validate_url("https://news.ycombinator.com").is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_wildcard_allowlist_still_rejects_private_host() {
|
|
let tool = test_tool(vec!["*"]);
|
|
let err = tool
|
|
.validate_url("https://localhost:8080")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("local/private"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_wildcard_subdomain_pattern() {
|
|
let tool = test_tool(vec!["*.example.com"]);
|
|
assert!(tool.validate_url("https://example.com").is_ok());
|
|
assert!(tool.validate_url("https://sub.example.com").is_ok());
|
|
assert!(tool.validate_url("https://other.com").is_err());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_allowlist_miss() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool
|
|
.validate_url("https://google.com")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("allowed_domains"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_localhost() {
|
|
let tool = test_tool(vec!["localhost"]);
|
|
let err = tool
|
|
.validate_url("https://localhost:8080")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("local/private"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_private_ipv4() {
|
|
let tool = test_tool(vec!["192.168.1.5"]);
|
|
let err = tool
|
|
.validate_url("https://192.168.1.5")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("local/private"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_whitespace() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool
|
|
.validate_url("https://example.com/hello world")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("whitespace"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_userinfo() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool
|
|
.validate_url("https://user@example.com")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("userinfo"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_requires_allowlist() {
|
|
let security = Arc::new(SecurityPolicy::default());
|
|
let tool = HttpRequestTool::new(
|
|
security,
|
|
vec![],
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
HashMap::new(),
|
|
);
|
|
let err = tool
|
|
.validate_url("https://example.com")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("allowed_domains"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_accepts_valid_methods() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
assert!(tool.validate_method("GET").is_ok());
|
|
assert!(tool.validate_method("POST").is_ok());
|
|
assert!(tool.validate_method("PUT").is_ok());
|
|
assert!(tool.validate_method("DELETE").is_ok());
|
|
assert!(tool.validate_method("PATCH").is_ok());
|
|
assert!(tool.validate_method("HEAD").is_ok());
|
|
assert!(tool.validate_method("OPTIONS").is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_invalid_method() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool.validate_method("INVALID").unwrap_err().to_string();
|
|
assert!(err.contains("Unsupported HTTP method"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_multicast_ipv4() {
|
|
assert!(is_private_or_local_host("224.0.0.1"));
|
|
assert!(is_private_or_local_host("239.255.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_broadcast() {
|
|
assert!(is_private_or_local_host("255.255.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_reserved_ipv4() {
|
|
assert!(is_private_or_local_host("240.0.0.1"));
|
|
assert!(is_private_or_local_host("250.1.2.3"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_documentation_ranges() {
|
|
assert!(is_private_or_local_host("192.0.2.1")); // TEST-NET-1
|
|
assert!(is_private_or_local_host("198.51.100.1")); // TEST-NET-2
|
|
assert!(is_private_or_local_host("203.0.113.1")); // TEST-NET-3
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_benchmarking_range() {
|
|
assert!(is_private_or_local_host("198.18.0.1"));
|
|
assert!(is_private_or_local_host("198.19.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv6_localhost() {
|
|
assert!(is_private_or_local_host("::1"));
|
|
assert!(is_private_or_local_host("[::1]"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv6_multicast() {
|
|
assert!(is_private_or_local_host("ff02::1"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv6_link_local() {
|
|
assert!(is_private_or_local_host("fe80::1"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv6_unique_local() {
|
|
assert!(is_private_or_local_host("fd00::1"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv4_mapped_ipv6() {
|
|
assert!(is_private_or_local_host("::ffff:127.0.0.1"));
|
|
assert!(is_private_or_local_host("::ffff:192.168.1.1"));
|
|
assert!(is_private_or_local_host("::ffff:10.0.0.1"));
|
|
}
|
|
|
|
#[test]
|
|
fn allows_public_ipv4() {
|
|
assert!(!is_private_or_local_host("8.8.8.8"));
|
|
assert!(!is_private_or_local_host("1.1.1.1"));
|
|
assert!(!is_private_or_local_host("93.184.216.34"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_ipv6_documentation_range() {
|
|
assert!(is_private_or_local_host("2001:db8::1"));
|
|
}
|
|
|
|
#[test]
|
|
fn allows_public_ipv6() {
|
|
assert!(!is_private_or_local_host("2607:f8b0:4004:800::200e"));
|
|
}
|
|
|
|
#[test]
|
|
fn blocks_shared_address_space() {
|
|
assert!(is_private_or_local_host("100.64.0.1"));
|
|
assert!(is_private_or_local_host("100.127.255.255"));
|
|
assert!(!is_private_or_local_host("100.63.0.1")); // Just below range
|
|
assert!(!is_private_or_local_host("100.128.0.1")); // Just above range
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn execute_blocks_readonly_mode() {
|
|
let security = Arc::new(SecurityPolicy {
|
|
autonomy: AutonomyLevel::ReadOnly,
|
|
..SecurityPolicy::default()
|
|
});
|
|
let tool = HttpRequestTool::new(
|
|
security,
|
|
vec!["example.com".into()],
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
HashMap::new(),
|
|
);
|
|
let result = tool
|
|
.execute(json!({"url": "https://example.com"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(!result.success);
|
|
assert!(result.error.unwrap().contains("read-only"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn execute_blocks_when_rate_limited() {
|
|
let security = Arc::new(SecurityPolicy {
|
|
max_actions_per_hour: 0,
|
|
..SecurityPolicy::default()
|
|
});
|
|
let tool = HttpRequestTool::new(
|
|
security,
|
|
vec!["example.com".into()],
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
HashMap::new(),
|
|
);
|
|
let result = tool
|
|
.execute(json!({"url": "https://example.com"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(!result.success);
|
|
assert!(result.error.unwrap().contains("rate limit"));
|
|
}
|
|
|
|
#[test]
|
|
fn truncate_response_within_limit() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let text = "hello world";
|
|
assert_eq!(tool.truncate_response(text), "hello world");
|
|
}
|
|
|
|
#[test]
|
|
fn truncate_response_over_limit() {
|
|
let tool = HttpRequestTool::new(
|
|
Arc::new(SecurityPolicy::default()),
|
|
vec!["example.com".into()],
|
|
UrlAccessConfig::default(),
|
|
10,
|
|
30,
|
|
"test".to_string(),
|
|
HashMap::new(),
|
|
);
|
|
let text = "hello world this is long";
|
|
let truncated = tool.truncate_response(text);
|
|
assert!(truncated.len() <= 10 + 60); // limit + message
|
|
assert!(truncated.contains("[Response truncated"));
|
|
}
|
|
|
|
#[test]
|
|
fn parse_headers_preserves_original_values() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let headers = json!({
|
|
"Authorization": "Bearer secret",
|
|
"Content-Type": "application/json",
|
|
"X-API-Key": "my-key"
|
|
});
|
|
let parsed = tool.parse_headers(&headers);
|
|
assert_eq!(parsed.len(), 3);
|
|
assert!(parsed
|
|
.iter()
|
|
.any(|(k, v)| k == "Authorization" && v == "Bearer secret"));
|
|
assert!(parsed
|
|
.iter()
|
|
.any(|(k, v)| k == "X-API-Key" && v == "my-key"));
|
|
assert!(parsed
|
|
.iter()
|
|
.any(|(k, v)| k == "Content-Type" && v == "application/json"));
|
|
}
|
|
|
|
#[test]
|
|
fn redact_headers_for_display_redacts_sensitive() {
|
|
let headers = vec![
|
|
("Authorization".into(), "Bearer secret".into()),
|
|
("Content-Type".into(), "application/json".into()),
|
|
("X-API-Key".into(), "my-key".into()),
|
|
("X-Secret-Token".into(), "tok-123".into()),
|
|
];
|
|
let redacted = HttpRequestTool::redact_headers_for_display(&headers);
|
|
assert_eq!(redacted.len(), 4);
|
|
assert!(redacted
|
|
.iter()
|
|
.any(|(k, v)| k == "Authorization" && v == "***REDACTED***"));
|
|
assert!(redacted
|
|
.iter()
|
|
.any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***"));
|
|
assert!(redacted
|
|
.iter()
|
|
.any(|(k, v)| k == "X-Secret-Token" && v == "***REDACTED***"));
|
|
assert!(redacted
|
|
.iter()
|
|
.any(|(k, v)| k == "Content-Type" && v == "application/json"));
|
|
}
|
|
|
|
#[test]
|
|
fn redact_headers_does_not_alter_original() {
|
|
let headers = vec![("Authorization".into(), "Bearer real-token".into())];
|
|
let _ = HttpRequestTool::redact_headers_for_display(&headers);
|
|
assert_eq!(headers[0].1, "Bearer real-token");
|
|
}
|
|
|
|
#[test]
|
|
fn resolve_credential_profile_injects_env_backed_header() {
|
|
let test_secret = "test-credential-value-12345";
|
|
std::env::set_var("ZEROCLAW_TEST_HTTP_CREDENTIAL", test_secret);
|
|
|
|
let mut profiles = HashMap::new();
|
|
profiles.insert(
|
|
"github".to_string(),
|
|
HttpRequestCredentialProfile {
|
|
header_name: "Authorization".to_string(),
|
|
env_var: "ZEROCLAW_TEST_HTTP_CREDENTIAL".to_string(),
|
|
value_prefix: "Bearer ".to_string(),
|
|
},
|
|
);
|
|
|
|
let tool = HttpRequestTool::new(
|
|
Arc::new(SecurityPolicy::default()),
|
|
vec!["api.github.com".into()],
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
profiles,
|
|
);
|
|
|
|
let (headers, sensitive_values) = tool
|
|
.resolve_credential_profile("github")
|
|
.expect("profile should resolve");
|
|
|
|
assert_eq!(headers.len(), 1);
|
|
assert_eq!(headers[0].0, "Authorization");
|
|
assert_eq!(headers[0].1, format!("Bearer {test_secret}"));
|
|
assert!(sensitive_values.contains(&test_secret.to_string()));
|
|
assert!(sensitive_values.contains(&format!("Bearer {test_secret}")));
|
|
|
|
std::env::remove_var("ZEROCLAW_TEST_HTTP_CREDENTIAL");
|
|
}
|
|
|
|
#[test]
|
|
fn resolve_credential_profile_missing_env_var_fails() {
|
|
let mut profiles = HashMap::new();
|
|
profiles.insert(
|
|
"missing".to_string(),
|
|
HttpRequestCredentialProfile {
|
|
header_name: "Authorization".to_string(),
|
|
env_var: "ZEROCLAW_TEST_MISSING_HTTP_REQUEST_TOKEN".to_string(),
|
|
value_prefix: "Bearer ".to_string(),
|
|
},
|
|
);
|
|
|
|
let tool = HttpRequestTool::new(
|
|
Arc::new(SecurityPolicy::default()),
|
|
vec!["example.com".into()],
|
|
UrlAccessConfig::default(),
|
|
1_000_000,
|
|
30,
|
|
"test".to_string(),
|
|
profiles,
|
|
);
|
|
|
|
let err = tool
|
|
.resolve_credential_profile("missing")
|
|
.expect_err("missing env var should fail")
|
|
.to_string();
|
|
assert!(err.contains("ZEROCLAW_TEST_MISSING_HTTP_REQUEST_TOKEN"));
|
|
}
|
|
|
|
#[test]
|
|
fn has_header_name_conflict_is_case_insensitive() {
|
|
let explicit = vec![("authorization".to_string(), "Bearer one".to_string())];
|
|
let injected = vec![("Authorization".to_string(), "Bearer two".to_string())];
|
|
assert!(HttpRequestTool::has_header_name_conflict(
|
|
&explicit, &injected
|
|
));
|
|
}
|
|
|
|
#[test]
|
|
fn redact_sensitive_values_scrubs_injected_secrets() {
|
|
let text = "Authorization: Bearer super-secret-token\nbody=super-secret-token";
|
|
let redacted = HttpRequestTool::redact_sensitive_values(
|
|
text,
|
|
&[
|
|
"super-secret-token".to_string(),
|
|
"Bearer super-secret-token".to_string(),
|
|
],
|
|
);
|
|
assert!(!redacted.contains("super-secret-token"));
|
|
assert!(redacted.contains("***REDACTED***"));
|
|
}
|
|
|
|
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
|
|
//
|
|
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
|
|
// decimal integer, zero-padded). These tests document that property
|
|
// so regressions are caught if the parsing strategy ever changes.
|
|
|
|
#[test]
|
|
fn ssrf_octal_loopback_not_parsed_as_ip() {
|
|
// 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
|
|
// Rust's IpAddr rejects it — it falls through as a hostname.
|
|
assert!(!is_private_or_local_host("0177.0.0.1"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_hex_loopback_not_parsed_as_ip() {
|
|
// 0x7f000001 is hex for 127.0.0.1 in some languages.
|
|
assert!(!is_private_or_local_host("0x7f000001"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_decimal_loopback_not_parsed_as_ip() {
|
|
// 2130706433 is decimal for 127.0.0.1 in some languages.
|
|
assert!(!is_private_or_local_host("2130706433"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
|
|
// 127.000.000.001 uses zero-padded octets.
|
|
assert!(!is_private_or_local_host("127.000.000.001"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_alternate_notations_rejected_by_validate_url() {
|
|
// Even if is_private_or_local_host doesn't flag these, they
|
|
// fail the allowlist because they're treated as hostnames.
|
|
let tool = test_tool(vec!["example.com"]);
|
|
for notation in [
|
|
"http://0177.0.0.1",
|
|
"http://0x7f000001",
|
|
"http://2130706433",
|
|
"http://127.000.000.001",
|
|
] {
|
|
let err = tool.validate_url(notation).unwrap_err().to_string();
|
|
assert!(
|
|
err.contains("allowed_domains"),
|
|
"Expected allowlist rejection for {notation}, got: {err}"
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn redirect_policy_is_none() {
|
|
// Structural test: the tool should be buildable with redirect-safe config.
|
|
// The actual Policy::none() enforcement is in execute_request's client builder.
|
|
let tool = test_tool(vec!["example.com"]);
|
|
assert_eq!(tool.name(), "http_request");
|
|
}
|
|
|
|
// ── §1.4 DNS rebinding / SSRF defense-in-depth tests ─────
|
|
|
|
#[test]
|
|
fn ssrf_blocks_loopback_127_range() {
|
|
assert!(is_private_or_local_host("127.0.0.1"));
|
|
assert!(is_private_or_local_host("127.0.0.2"));
|
|
assert!(is_private_or_local_host("127.255.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_blocks_rfc1918_10_range() {
|
|
assert!(is_private_or_local_host("10.0.0.1"));
|
|
assert!(is_private_or_local_host("10.255.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_blocks_rfc1918_172_range() {
|
|
assert!(is_private_or_local_host("172.16.0.1"));
|
|
assert!(is_private_or_local_host("172.31.255.255"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_blocks_unspecified_address() {
|
|
assert!(is_private_or_local_host("0.0.0.0"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_blocks_dot_localhost_subdomain() {
|
|
assert!(is_private_or_local_host("evil.localhost"));
|
|
assert!(is_private_or_local_host("a.b.localhost"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_blocks_dot_local_tld() {
|
|
assert!(is_private_or_local_host("service.local"));
|
|
}
|
|
|
|
#[test]
|
|
fn ssrf_ipv6_unspecified() {
|
|
assert!(is_private_or_local_host("::"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_ftp_scheme() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool
|
|
.validate_url("ftp://example.com")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("http://") || err.contains("https://"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_empty_url() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool.validate_url("").unwrap_err().to_string();
|
|
assert!(err.contains("empty"));
|
|
}
|
|
|
|
#[test]
|
|
fn validate_rejects_ipv6_host() {
|
|
let tool = test_tool(vec!["example.com"]);
|
|
let err = tool
|
|
.validate_url("http://[::1]:8080/path")
|
|
.unwrap_err()
|
|
.to_string();
|
|
assert!(err.contains("IPv6"));
|
|
}
|
|
}
|