use super::traits::{Tool, ToolResult}; use crate::security::SecurityPolicy; use async_trait::async_trait; use regex::Regex; use reqwest::StatusCode; use serde_json::json; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; /// Web search tool for searching the internet. /// Supports providers: DuckDuckGo (free), Brave, Firecrawl, Tavily. pub struct WebSearchTool { security: Arc, provider: String, api_keys: Vec, api_url: Option, max_results: usize, timeout_secs: u64, user_agent: String, key_index: Arc, } impl WebSearchTool { fn duckduckgo_status_hint(status: StatusCode) -> &'static str { match status { StatusCode::FORBIDDEN | StatusCode::TOO_MANY_REQUESTS => { " DuckDuckGo may be blocking this network. Try [web_search].provider = \"brave\" with [web_search].brave_api_key, or set provider = \"firecrawl\"." } StatusCode::SERVICE_UNAVAILABLE | StatusCode::BAD_GATEWAY | StatusCode::GATEWAY_TIMEOUT => { " DuckDuckGo may be temporarily unavailable. Retry later or switch providers." } _ => "", } } pub fn new( security: Arc, provider: String, api_key: Option, api_url: Option, max_results: usize, timeout_secs: u64, user_agent: String, ) -> Self { let api_keys = api_key .as_ref() .map(|raw| { raw.split(',') .map(str::trim) .filter(|s| !s.is_empty()) .map(ToOwned::to_owned) .collect() }) .unwrap_or_default(); Self { security, provider: provider.trim().to_lowercase(), api_keys, api_url, max_results: max_results.clamp(1, 10), timeout_secs: timeout_secs.max(1), user_agent, key_index: Arc::new(AtomicUsize::new(0)), } } fn get_next_api_key(&self) -> Option { if self.api_keys.is_empty() { return None; } let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len(); Some(self.api_keys[idx].clone()) } async fn search_duckduckgo(&self, query: &str) -> anyhow::Result { let encoded_query = urlencoding::encode(query); let search_url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query); let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .user_agent(self.user_agent.as_str()) .build()?; let response = client.get(&search_url).send().await.map_err(|e| { anyhow::anyhow!( "DuckDuckGo search request failed: {e}. Check outbound network/proxy settings, or switch [web_search].provider to \"brave\"/\"firecrawl\"." ) })?; if !response.status().is_success() { let status = response.status(); anyhow::bail!( "DuckDuckGo search failed with status: {}.{}", status, Self::duckduckgo_status_hint(status) ); } let html = response.text().await?; self.parse_duckduckgo_results(&html, query) } fn parse_duckduckgo_results(&self, html: &str, query: &str) -> anyhow::Result { // Extract result links: Title let link_regex = Regex::new( r#"]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)"#, )?; // Extract snippets: ... let snippet_regex = Regex::new(r#"]*>([\s\S]*?)"#)?; let link_matches: Vec<_> = link_regex .captures_iter(html) .take(self.max_results + 2) .collect(); let snippet_matches: Vec<_> = snippet_regex .captures_iter(html) .take(self.max_results + 2) .collect(); if link_matches.is_empty() { return Ok(format!("No results found for: {}", query)); } let mut lines = vec![format!("Search results for: {} (via DuckDuckGo)", query)]; let count = link_matches.len().min(self.max_results); for i in 0..count { let caps = &link_matches[i]; let url_str = decode_ddg_redirect_url(&caps[1]); let title = strip_tags(&caps[2]); lines.push(format!("{}. {}", i + 1, title.trim())); lines.push(format!(" {}", url_str.trim())); // Add snippet if available if i < snippet_matches.len() { let snippet = strip_tags(&snippet_matches[i][1]); let snippet = snippet.trim(); if !snippet.is_empty() { lines.push(format!(" {}", snippet)); } } } Ok(lines.join("\n")) } async fn search_brave(&self, query: &str) -> anyhow::Result { let auth_token = self .get_next_api_key() .ok_or_else(|| anyhow::anyhow!("Brave API key not configured"))?; let encoded_query = urlencoding::encode(query); let search_url = format!( "https://api.search.brave.com/res/v1/web/search?q={}&count={}", encoded_query, self.max_results ); let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .user_agent(self.user_agent.as_str()) .build()?; let response = client .get(&search_url) .header("Accept", "application/json") .header("X-Subscription-Token", auth_token) .send() .await?; if !response.status().is_success() { anyhow::bail!("Brave search failed with status: {}", response.status()); } let json: serde_json::Value = response.json().await?; self.parse_brave_results(&json, query) } fn parse_brave_results(&self, json: &serde_json::Value, query: &str) -> anyhow::Result { let results = json .get("web") .and_then(|w| w.get("results")) .and_then(|r| r.as_array()) .ok_or_else(|| anyhow::anyhow!("Invalid Brave API response"))?; if results.is_empty() { return Ok(format!("No results found for: {}", query)); } let mut lines = vec![format!("Search results for: {} (via Brave)", query)]; for (i, result) in results.iter().take(self.max_results).enumerate() { let title = result .get("title") .and_then(|t| t.as_str()) .unwrap_or("No title"); let url = result.get("url").and_then(|u| u.as_str()).unwrap_or(""); let description = result .get("description") .and_then(|d| d.as_str()) .unwrap_or(""); lines.push(format!("{}. {}", i + 1, title)); lines.push(format!(" {}", url)); if !description.is_empty() { lines.push(format!(" {}", description)); } } Ok(lines.join("\n")) } #[cfg(feature = "firecrawl")] async fn search_firecrawl(&self, query: &str) -> anyhow::Result { let auth_token = self.get_next_api_key().ok_or_else(|| { anyhow::anyhow!( "web_search provider 'firecrawl' requires [web_search].api_key in config.toml" ) })?; let api_url = self .api_url .as_deref() .map(str::trim) .filter(|s| !s.is_empty()) .unwrap_or("https://api.firecrawl.dev"); let endpoint = format!("{}/v1/search", api_url.trim_end_matches('/')); let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .user_agent(self.user_agent.as_str()) .build()?; let response = client .post(endpoint) .header( reqwest::header::AUTHORIZATION, format!("Bearer {}", auth_token), ) .json(&json!({ "query": query, "limit": self.max_results, "timeout": (self.timeout_secs * 1000) as u64, })) .send() .await .map_err(|e| anyhow::anyhow!("Firecrawl search failed: {e}"))?; let status = response.status(); let body = response.text().await?; if !status.is_success() { anyhow::bail!( "Firecrawl search failed with status {}: {}", status.as_u16(), body ); } let parsed: serde_json::Value = serde_json::from_str(&body) .map_err(|e| anyhow::anyhow!("Invalid Firecrawl response JSON: {e}"))?; if !parsed .get("success") .and_then(serde_json::Value::as_bool) .unwrap_or(false) { let error = parsed .get("error") .and_then(serde_json::Value::as_str) .unwrap_or("unknown error"); anyhow::bail!("Firecrawl search failed: {error}"); } let results = parsed .get("data") .and_then(serde_json::Value::as_array) .ok_or_else(|| anyhow::anyhow!("Firecrawl response missing data array"))?; if results.is_empty() { return Ok(format!("No results found for: {}", query)); } let mut lines = vec![format!("Search results for: {} (via Firecrawl)", query)]; for (i, result) in results.iter().take(self.max_results).enumerate() { let title = result .get("title") .and_then(serde_json::Value::as_str) .unwrap_or("No title"); let url = result .get("url") .and_then(serde_json::Value::as_str) .unwrap_or(""); let description = result .get("description") .and_then(serde_json::Value::as_str) .unwrap_or(""); lines.push(format!("{}. {}", i + 1, title)); lines.push(format!(" {}", url)); if !description.trim().is_empty() { lines.push(format!(" {}", description.trim())); } } Ok(lines.join("\n")) } #[cfg(not(feature = "firecrawl"))] #[allow(clippy::unused_async)] async fn search_firecrawl(&self, _query: &str) -> anyhow::Result { anyhow::bail!("web_search provider 'firecrawl' requires Cargo feature 'firecrawl'") } async fn search_tavily(&self, query: &str) -> anyhow::Result { let api_key = self.get_next_api_key().ok_or_else(|| { anyhow::anyhow!( "web_search provider 'tavily' requires [web_search].api_key in config.toml" ) })?; let api_url = self .api_url .as_deref() .map(str::trim) .filter(|s| !s.is_empty()) .unwrap_or("https://api.tavily.com"); let endpoint = format!("{}/search", api_url.trim_end_matches('/')); let client = reqwest::Client::builder() .timeout(Duration::from_secs(self.timeout_secs)) .user_agent(self.user_agent.as_str()) .build()?; let response = client .post(&endpoint) .json(&json!({ "api_key": api_key, "query": query, "max_results": self.max_results, "search_depth": "basic", "include_answer": false, "include_raw_content": false, "include_images": false })) .send() .await .map_err(|e| anyhow::anyhow!("Tavily search failed: {e}"))?; let status = response.status(); let body = response.text().await?; if !status.is_success() { anyhow::bail!( "Tavily search failed with status {}: {}", status.as_u16(), body ); } let parsed: serde_json::Value = serde_json::from_str(&body) .map_err(|e| anyhow::anyhow!("Invalid Tavily response JSON: {e}"))?; if let Some(error) = parsed.get("error").and_then(serde_json::Value::as_str) { anyhow::bail!("Tavily API error: {error}"); } let results = parsed .get("results") .and_then(serde_json::Value::as_array) .ok_or_else(|| anyhow::anyhow!("Tavily response missing results array"))?; if results.is_empty() { return Ok(format!("No results found for: {}", query)); } let mut lines = vec![format!("Search results for: {} (via Tavily)", query)]; for (i, result) in results.iter().take(self.max_results).enumerate() { let title = result .get("title") .and_then(serde_json::Value::as_str) .unwrap_or("No title"); let url = result .get("url") .and_then(serde_json::Value::as_str) .unwrap_or(""); let content = result .get("content") .and_then(serde_json::Value::as_str) .unwrap_or("") .trim(); lines.push(format!("{}. {}", i + 1, title)); lines.push(format!(" {}", url)); if !content.is_empty() { lines.push(format!(" {}", content)); } } Ok(lines.join("\n")) } } fn decode_ddg_redirect_url(raw_url: &str) -> String { if let Some(index) = raw_url.find("uddg=") { let encoded = &raw_url[index + 5..]; let encoded = encoded.split('&').next().unwrap_or(encoded); if let Ok(decoded) = urlencoding::decode(encoded) { return decoded.into_owned(); } } raw_url.to_string() } fn strip_tags(content: &str) -> String { let re = Regex::new(r"<[^>]+>").unwrap(); re.replace_all(content, "").to_string() } #[async_trait] impl Tool for WebSearchTool { fn name(&self) -> &str { "web_search_tool" } fn description(&self) -> &str { "Search the web for information. Returns relevant search results with titles, URLs, and descriptions. Use this to find current information, news, or research topics." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "query": { "type": "string", "description": "The search query. Be specific for better results." } }, "required": ["query"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { if !self.security.can_act() { return Ok(ToolResult { success: false, output: String::new(), error: Some("Action blocked: autonomy is read-only".into()), }); } if !self.security.record_action() { return Ok(ToolResult { success: false, output: String::new(), error: Some("Action blocked: rate limit exceeded".into()), }); } let query = args .get("query") .and_then(|q| q.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing required parameter: query"))?; if query.trim().is_empty() { anyhow::bail!("Search query cannot be empty"); } tracing::info!("Searching web for: {}", query); let result = match self.provider.as_str() { "duckduckgo" | "ddg" => self.search_duckduckgo(query).await?, "brave" => self.search_brave(query).await?, "firecrawl" => self.search_firecrawl(query).await?, "tavily" => self.search_tavily(query).await?, _ => anyhow::bail!( "Unknown search provider: '{}'. Set [web_search].provider to 'duckduckgo', 'brave', 'firecrawl', or 'tavily' in config.toml", self.provider ), }; Ok(ToolResult { success: true, output: result, error: None, }) } } #[cfg(test)] mod tests { use super::*; use crate::security::{AutonomyLevel, SecurityPolicy}; fn test_security() -> Arc { Arc::new(SecurityPolicy { autonomy: AutonomyLevel::Supervised, ..SecurityPolicy::default() }) } #[test] fn test_tool_name() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); assert_eq!(tool.name(), "web_search_tool"); } #[test] fn test_tool_description() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); assert!(tool.description().contains("Search the web")); } #[test] fn test_parameters_schema() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let schema = tool.parameters_schema(); assert_eq!(schema["type"], "object"); assert!(schema["properties"]["query"].is_object()); } #[test] fn test_strip_tags() { let html = "Hello World"; assert_eq!(strip_tags(html), "Hello World"); } #[test] fn test_parse_duckduckgo_results_empty() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool .parse_duckduckgo_results("No results here", "test") .unwrap(); assert!(result.contains("No results found")); } #[test] fn test_parse_duckduckgo_results_with_data() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let html = r#" Example Title This is a description "#; let result = tool.parse_duckduckgo_results(html, "test").unwrap(); assert!(result.contains("Example Title")); assert!(result.contains("https://example.com")); } #[test] fn test_parse_duckduckgo_results_decodes_redirect_url() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let html = r#" Example Title This is a description "#; let result = tool.parse_duckduckgo_results(html, "test").unwrap(); assert!(result.contains("https://example.com/path?a=1")); assert!(!result.contains("rut=test")); } #[test] fn duckduckgo_status_hint_for_403_mentions_provider_switch() { let hint = WebSearchTool::duckduckgo_status_hint(StatusCode::FORBIDDEN); assert!(hint.contains("provider")); assert!(hint.contains("brave")); } #[test] fn duckduckgo_status_hint_for_500_is_empty() { assert!( WebSearchTool::duckduckgo_status_hint(StatusCode::INTERNAL_SERVER_ERROR).is_empty() ); } #[test] fn test_constructor_clamps_web_search_limits() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 0, 0, "test".to_string(), ); let html = r#" Example Title This is a description "#; let result = tool.parse_duckduckgo_results(html, "test").unwrap(); assert!(result.contains("Example Title")); } #[tokio::test] async fn test_execute_missing_query() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[tokio::test] async fn test_execute_empty_query() { let tool = WebSearchTool::new( test_security(), "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({"query": ""})).await; assert!(result.is_err()); } #[tokio::test] async fn test_execute_brave_without_api_key() { let tool = WebSearchTool::new( test_security(), "brave".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({"query": "test"})).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("API key")); } #[tokio::test] async fn test_execute_firecrawl_without_api_key() { let tool = WebSearchTool::new( test_security(), "firecrawl".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({"query": "test"})).await; assert!(result.is_err()); let error = result.unwrap_err().to_string(); if cfg!(feature = "firecrawl") { assert!(error.contains("api_key")); } else { assert!(error.contains("requires Cargo feature 'firecrawl'")); } } #[tokio::test] async fn test_execute_tavily_without_api_key() { let tool = WebSearchTool::new( test_security(), "tavily".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({"query": "test"})).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("api_key")); } #[test] fn test_parses_multiple_api_keys() { let tool = WebSearchTool::new( test_security(), "tavily".to_string(), Some("key1,key2,key3".to_string()), None, 5, 15, "test".to_string(), ); assert_eq!(tool.api_keys, vec!["key1", "key2", "key3"]); } #[test] fn test_round_robin_api_key_selection_cycles() { let tool = WebSearchTool::new( test_security(), "tavily".to_string(), Some("k1,k2".to_string()), None, 5, 15, "test".to_string(), ); assert_eq!(tool.get_next_api_key().as_deref(), Some("k1")); assert_eq!(tool.get_next_api_key().as_deref(), Some("k2")); assert_eq!(tool.get_next_api_key().as_deref(), Some("k1")); } #[tokio::test] async fn test_execute_blocked_in_read_only_mode() { let security = Arc::new(SecurityPolicy { autonomy: AutonomyLevel::ReadOnly, ..SecurityPolicy::default() }); let tool = WebSearchTool::new( security, "duckduckgo".to_string(), None, None, 5, 15, "test".to_string(), ); let result = tool.execute(json!({"query": "rust"})).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("read-only")); } }