diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 3dd7c6331..6140de52a 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -81,6 +81,7 @@ pub mod tool_search; pub mod traits; pub mod verifiable_intent; pub mod web_fetch; +mod web_search_provider_routing; pub mod web_search_tool; pub mod workspace_tool; diff --git a/src/tools/web_search_provider_routing.rs b/src/tools/web_search_provider_routing.rs new file mode 100644 index 000000000..4fd372f4b --- /dev/null +++ b/src/tools/web_search_provider_routing.rs @@ -0,0 +1,73 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WebSearchProviderRoute { + DuckDuckGo, + Brave, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WebSearchProviderResolution { + pub route: WebSearchProviderRoute, + pub canonical_provider: &'static str, + pub used_fallback: bool, +} + +pub const DEFAULT_WEB_SEARCH_PROVIDER: &str = "duckduckgo"; +const BRAVE_PROVIDER: &str = "brave"; + +pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResolution { + let normalized = raw_provider.trim().to_ascii_lowercase(); + match normalized.as_str() { + "" | "default" | "duckduckgo" | "ddg" | "duck-duck-go" | "duck_duck_go" => { + WebSearchProviderResolution { + route: WebSearchProviderRoute::DuckDuckGo, + canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER, + used_fallback: false, + } + } + "brave" | "brave-search" | "brave_search" => WebSearchProviderResolution { + route: WebSearchProviderRoute::Brave, + canonical_provider: BRAVE_PROVIDER, + used_fallback: false, + }, + _ => WebSearchProviderResolution { + route: WebSearchProviderRoute::DuckDuckGo, + canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER, + used_fallback: true, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_aliases_to_duckduckgo() { + let ddg_aliases = ["duckduckgo", "ddg", "duck-duck-go", "duck_duck_go"]; + for alias in ddg_aliases { + let resolved = resolve_web_search_provider(alias); + assert_eq!(resolved.route, WebSearchProviderRoute::DuckDuckGo); + assert_eq!(resolved.canonical_provider, DEFAULT_WEB_SEARCH_PROVIDER); + assert!(!resolved.used_fallback); + } + } + + #[test] + fn resolve_aliases_to_brave() { + let brave_aliases = ["brave", "brave-search", "brave_search"]; + for alias in brave_aliases { + let resolved = resolve_web_search_provider(alias); + assert_eq!(resolved.route, WebSearchProviderRoute::Brave); + assert_eq!(resolved.canonical_provider, BRAVE_PROVIDER); + assert!(!resolved.used_fallback); + } + } + + #[test] + fn resolve_unknown_provider_falls_back_to_default() { + let resolved = resolve_web_search_provider("bing"); + assert_eq!(resolved.route, WebSearchProviderRoute::DuckDuckGo); + assert_eq!(resolved.canonical_provider, DEFAULT_WEB_SEARCH_PROVIDER); + assert!(resolved.used_fallback); + } +} diff --git a/src/tools/web_search_tool.rs b/src/tools/web_search_tool.rs index a1ef6470e..2b0183a40 100644 --- a/src/tools/web_search_tool.rs +++ b/src/tools/web_search_tool.rs @@ -1,4 +1,5 @@ use super::traits::{Tool, ToolResult}; +use super::web_search_provider_routing::{resolve_web_search_provider, WebSearchProviderRoute}; use async_trait::async_trait; use regex::Regex; use serde_json::json; @@ -13,6 +14,7 @@ use std::time::Duration; /// `[web_search] brave_api_key` field, and uses the result. This ensures that /// keys set or rotated after boot, and encrypted keys, are correctly picked up. pub struct WebSearchTool { + /// Provider selector as configured by user. Routed via provider aliases at runtime. provider: String, /// Boot-time key snapshot (may be `None` if not yet configured at startup). boot_brave_api_key: Option, @@ -300,13 +302,18 @@ impl Tool for WebSearchTool { tracing::info!("Searching web for: {}", query); - let result = match self.provider.as_str() { - "duckduckgo" | "ddg" => self.search_duckduckgo(query).await?, - "brave" => self.search_brave(query).await?, - _ => anyhow::bail!( - "Unknown search provider: '{}'. Set tools.web_search.provider to 'duckduckgo' or 'brave' in config.toml", - self.provider - ), + let resolution = resolve_web_search_provider(&self.provider); + if resolution.used_fallback { + tracing::warn!( + "Unknown web search provider '{}'; falling back to '{}'", + self.provider, + resolution.canonical_provider + ); + } + + let result = match resolution.route { + WebSearchProviderRoute::DuckDuckGo => self.search_duckduckgo(query).await?, + WebSearchProviderRoute::Brave => self.search_brave(query).await?, }; Ok(ToolResult {