diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 8ee49a598..859b68d93 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -17,7 +17,7 @@ use std::collections::HashSet; use std::fmt::Write; use std::io::Write as _; use std::path::{Path, PathBuf}; -use std::sync::{Arc, LazyLock}; +use std::sync::{Arc, LazyLock, Mutex}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -33,6 +33,29 @@ const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10; /// Matches the channel-side constant in `channels/mod.rs`. const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20; +/// Callback type for checking if model has been switched during tool execution. +/// Returns Some((provider, model)) if a switch was requested, None otherwise. +pub type ModelSwitchCallback = Arc>>; + +/// Global model switch request state - used for runtime model switching via model_switch tool. +/// This is set by the model_switch tool and checked by the agent loop. +#[allow(clippy::type_complexity)] +static MODEL_SWITCH_REQUEST: LazyLock>>> = + LazyLock::new(|| Arc::new(Mutex::new(None))); + +/// Get the global model switch request state +pub fn get_model_switch_state() -> ModelSwitchCallback { + Arc::clone(&MODEL_SWITCH_REQUEST) +} + +/// Clear any pending model switch request +pub fn clear_model_switch_request() { + if let Ok(guard) = MODEL_SWITCH_REQUEST.lock() { + let mut guard = guard; + *guard = None; + } +} + fn glob_match(pattern: &str, name: &str) -> bool { match pattern.find('*') { None => pattern == name, @@ -2118,6 +2141,31 @@ pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool { err.chain().any(|source| source.is::()) } +#[derive(Debug)] +pub(crate) struct ModelSwitchRequested { + pub provider: String, + pub model: String, +} + +impl std::fmt::Display for ModelSwitchRequested { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "model switch requested to {} {}", + self.provider, self.model + ) + } +} + +impl std::error::Error for ModelSwitchRequested {} + +pub(crate) fn is_model_switch_requested(err: &anyhow::Error) -> Option<(String, String)> { + err.chain() + .filter_map(|source| source.downcast_ref::()) + .map(|e| (e.provider.clone(), e.model.clone())) + .next() +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. /// When `silent` is true, suppresses stdout (for channel use). @@ -2137,6 +2185,7 @@ pub(crate) async fn agent_turn( excluded_tools: &[String], dedup_exempt_tools: &[String], activated_tools: Option<&std::sync::Arc>>, + model_switch_callback: Option, ) -> Result { run_tool_call_loop( provider, @@ -2157,6 +2206,7 @@ pub(crate) async fn agent_turn( excluded_tools, dedup_exempt_tools, activated_tools, + model_switch_callback, ) .await } @@ -2362,6 +2412,7 @@ pub(crate) async fn run_tool_call_loop( excluded_tools: &[String], dedup_exempt_tools: &[String], activated_tools: Option<&std::sync::Arc>>, + model_switch_callback: Option, ) -> Result { let max_iterations = if max_tool_iterations == 0 { DEFAULT_MAX_TOOL_ITERATIONS @@ -2380,6 +2431,28 @@ pub(crate) async fn run_tool_call_loop( return Err(ToolLoopCancelled.into()); } + // Check if model switch was requested via model_switch tool + if let Some(ref callback) = model_switch_callback { + if let Ok(guard) = callback.lock() { + if let Some((new_provider, new_model)) = guard.as_ref() { + if new_provider != provider_name || new_model != model { + tracing::info!( + "Model switch detected: {} {} -> {} {}", + provider_name, + model, + new_provider, + new_model + ); + return Err(ModelSwitchRequested { + provider: new_provider.clone(), + model: new_model.clone(), + } + .into()); + } + } + } + } + // Rebuild tool_specs each iteration so newly activated deferred tools appear. let mut tool_specs: Vec = tools_registry .iter() @@ -3199,28 +3272,32 @@ pub async fn run( } // ── Resolve provider ───────────────────────────────────────── - let provider_name = provider_override + let mut provider_name = provider_override .as_deref() .or(config.default_provider.as_deref()) - .unwrap_or("openrouter"); + .unwrap_or("openrouter") + .to_string(); - let model_name = model_override + let mut model_name = model_override .as_deref() .or(config.default_model.as_deref()) - .unwrap_or("anthropic/claude-sonnet-4"); + .unwrap_or("anthropic/claude-sonnet-4") + .to_string(); let provider_runtime_options = providers::provider_runtime_options_from_config(&config); - let provider: Box = providers::create_routed_provider_with_options( - provider_name, + let mut provider: Box = providers::create_routed_provider_with_options( + &provider_name, config.api_key.as_deref(), config.api_url.as_deref(), &config.reliability, &config.model_routes, - model_name, + &model_name, &provider_runtime_options, )?; + let model_switch_callback = get_model_switch_state(); + observer.record_event(&ObserverEvent::AgentStart { provider: provider_name.to_string(), model: model_name.to_string(), @@ -3364,7 +3441,7 @@ pub async fn run( let native_tools = provider.supports_native_tools(); let mut system_prompt = crate::channels::build_system_prompt_with_mode( &config.workspace_dir, - model_name, + &model_name, &tool_descs, &skills, Some(&config.identity), @@ -3447,27 +3524,72 @@ pub async fn run( let excluded_tools = compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg); - let response = run_tool_call_loop( - provider.as_ref(), - &mut history, - &tools_registry, - observer.as_ref(), - provider_name, - model_name, - temperature, - false, - approval_manager.as_ref(), - channel_name, - &config.multimodal, - config.agent.max_tool_iterations, - None, - None, - None, - &excluded_tools, - &config.agent.tool_call_dedup_exempt, - activated_handle.as_ref(), - ) - .await?; + #[allow(unused_assignments)] + let mut response = String::new(); + loop { + match run_tool_call_loop( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + &provider_name, + &model_name, + temperature, + false, + approval_manager.as_ref(), + channel_name, + &config.multimodal, + config.agent.max_tool_iterations, + None, + None, + None, + &excluded_tools, + &config.agent.tool_call_dedup_exempt, + activated_handle.as_ref(), + Some(model_switch_callback.clone()), + ) + .await + { + Ok(resp) => { + response = resp; + break; + } + Err(e) => { + if let Some((new_provider, new_model)) = is_model_switch_requested(&e) { + tracing::info!( + "Model switch requested, switching from {} {} to {} {}", + provider_name, + model_name, + new_provider, + new_model + ); + + provider = providers::create_routed_provider_with_options( + &new_provider, + config.api_key.as_deref(), + config.api_url.as_deref(), + &config.reliability, + &config.model_routes, + &new_model, + &provider_runtime_options, + )?; + + provider_name = new_provider; + model_name = new_model; + + clear_model_switch_request(); + + observer.record_event(&ObserverEvent::AgentStart { + provider: provider_name.to_string(), + model: model_name.to_string(), + }); + + continue; + } + return Err(e); + } + } + } final_output = response.clone(); println!("{response}"); observer.record_event(&ObserverEvent::TurnComplete); @@ -3609,32 +3731,66 @@ pub async fn run( &user_input, ); - let response = match run_tool_call_loop( - provider.as_ref(), - &mut history, - &tools_registry, - observer.as_ref(), - provider_name, - model_name, - temperature, - false, - approval_manager.as_ref(), - channel_name, - &config.multimodal, - config.agent.max_tool_iterations, - None, - None, - None, - &excluded_tools, - &config.agent.tool_call_dedup_exempt, - activated_handle.as_ref(), - ) - .await - { - Ok(resp) => resp, - Err(e) => { - eprintln!("\nError: {e}\n"); - continue; + let response = loop { + match run_tool_call_loop( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + &provider_name, + &model_name, + temperature, + false, + approval_manager.as_ref(), + channel_name, + &config.multimodal, + config.agent.max_tool_iterations, + None, + None, + None, + &excluded_tools, + &config.agent.tool_call_dedup_exempt, + activated_handle.as_ref(), + Some(model_switch_callback.clone()), + ) + .await + { + Ok(resp) => break resp, + Err(e) => { + if let Some((new_provider, new_model)) = is_model_switch_requested(&e) { + tracing::info!( + "Model switch requested, switching from {} {} to {} {}", + provider_name, + model_name, + new_provider, + new_model + ); + + provider = providers::create_routed_provider_with_options( + &new_provider, + config.api_key.as_deref(), + config.api_url.as_deref(), + &config.reliability, + &config.model_routes, + &new_model, + &provider_runtime_options, + )?; + + provider_name = new_provider; + model_name = new_model; + + clear_model_switch_request(); + + observer.record_event(&ObserverEvent::AgentStart { + provider: provider_name.to_string(), + model: model_name.to_string(), + }); + + continue; + } + eprintln!("\nError: {e}\n"); + break String::new(); + } } }; final_output = response.clone(); @@ -3652,7 +3808,7 @@ pub async fn run( if let Ok(compacted) = auto_compact_history( &mut history, provider.as_ref(), - model_name, + &model_name, config.agent.max_history_messages, config.agent.max_context_tokens, ) @@ -3946,6 +4102,7 @@ pub async fn process_message( &excluded_tools, &config.agent.tool_call_dedup_exempt, activated_handle_pm.as_ref(), + None, ) .await } @@ -4403,6 +4560,7 @@ mod tests { &[], &[], None, + None, ) .await .expect_err("provider without vision support should fail"); @@ -4451,6 +4609,7 @@ mod tests { &[], &[], None, + None, ) .await .expect_err("oversized payload must fail"); @@ -4493,6 +4652,7 @@ mod tests { &[], &[], None, + None, ) .await .expect("valid multimodal payload should pass"); @@ -4621,6 +4781,7 @@ mod tests { &[], &[], None, + None, ) .await .expect("parallel execution should complete"); @@ -4692,6 +4853,7 @@ mod tests { &[], &[], None, + None, ) .await .expect("loop should finish after deduplicating repeated calls"); @@ -4759,6 +4921,7 @@ mod tests { &[], &[], None, + None, ) .await .expect("non-interactive shell should succeed for low-risk command"); @@ -4817,6 +4980,7 @@ mod tests { &[], &exempt, None, + None, ) .await .expect("loop should finish with exempt tool executing twice"); @@ -4895,6 +5059,7 @@ mod tests { &[], &exempt, None, + None, ) .await .expect("loop should complete"); @@ -4950,6 +5115,7 @@ mod tests { &[], &[], None, + None, ) .await .expect("native fallback id flow should complete"); @@ -5018,6 +5184,7 @@ mod tests { &[], &[], Some(&activated), + None, ) .await .expect("wrapper path should execute activated tools"); @@ -6909,6 +7076,7 @@ Let me check the result."#; &[], &[], None, + None, ) .await .expect("tool loop should complete"); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 8bf75daad..0ec690cde 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -2251,6 +2251,7 @@ async fn process_channel_message( }, ctx.tool_call_dedup_exempt.as_ref(), ctx.activated_tools.as_ref(), + None, ), ) => LlmExecutionResult::Completed(result), }; diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index aa47785f3..0831ad05e 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -422,6 +422,7 @@ impl DelegateTool { &[], &[], None, + None, ), ) .await; diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 7c5ae54d0..159f926cc 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -59,6 +59,7 @@ pub mod memory_recall; pub mod memory_store; pub mod microsoft365; pub mod model_routing_config; +pub mod model_switch; pub mod node_tool; pub mod notion_tool; pub mod pdf_read; @@ -119,6 +120,7 @@ pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; pub use microsoft365::Microsoft365Tool; pub use model_routing_config::ModelRoutingConfigTool; +pub use model_switch::ModelSwitchTool; #[allow(unused_imports)] pub use node_tool::NodeTool; pub use notion_tool::NotionTool; @@ -302,6 +304,7 @@ pub fn all_tools_with_runtime( config.clone(), security.clone(), )), + Arc::new(ModelSwitchTool::new(security.clone())), Arc::new(ProxyConfigTool::new(config.clone(), security.clone())), Arc::new(GitOperationsTool::new( security.clone(), diff --git a/src/tools/model_switch.rs b/src/tools/model_switch.rs new file mode 100644 index 000000000..a5882a210 --- /dev/null +++ b/src/tools/model_switch.rs @@ -0,0 +1,264 @@ +use super::traits::{Tool, ToolResult}; +use crate::agent::loop_::get_model_switch_state; +use crate::providers; +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +pub struct ModelSwitchTool { + security: Arc, +} + +impl ModelSwitchTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +#[async_trait] +impl Tool for ModelSwitchTool { + fn name(&self) -> &str { + "model_switch" + } + + fn description(&self) -> &str { + "Switch the AI model at runtime. Use 'get' to see current model, 'list_providers' to see available providers, 'list_models' to see models for a provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["get", "set", "list_providers", "list_models"], + "description": "Action to perform: get current model, set a new model, list available providers, or list models for a provider" + }, + "provider": { + "type": "string", + "description": "Provider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions." + }, + "model": { + "type": "string", + "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action." + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get"); + + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "model_switch") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + + match action { + "get" => self.handle_get(), + "set" => self.handle_set(&args), + "list_providers" => self.handle_list_providers(), + "list_models" => self.handle_list_models(&args), + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action: {}. Valid actions: get, set, list_providers, list_models", + action + )), + }), + } + } +} + +impl ModelSwitchTool { + fn handle_get(&self) -> anyhow::Result { + let switch_state = get_model_switch_state(); + let pending = switch_state.lock().unwrap().clone(); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "pending_switch": pending, + "note": "To switch models, use action 'set' with provider and model parameters" + }))?, + error: None, + }) + } + + fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result { + let provider = args.get("provider").and_then(|v| v.as_str()); + + let provider = match provider { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'provider' parameter for 'set' action".to_string()), + }); + } + }; + + let model = args.get("model").and_then(|v| v.as_str()); + + let model = match model { + Some(m) => m, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'model' parameter for 'set' action".to_string()), + }); + } + }; + + // Validate the provider exists + let known_providers = providers::list_providers(); + let provider_valid = known_providers.iter().any(|p| { + p.name.eq_ignore_ascii_case(provider) + || p.aliases.iter().any(|a| a.eq_ignore_ascii_case(provider)) + }); + + if !provider_valid { + return Ok(ToolResult { + success: false, + output: serde_json::to_string_pretty(&json!({ + "available_providers": known_providers.iter().map(|p| p.name).collect::>() + }))?, + error: Some(format!( + "Unknown provider: {}. Use 'list_providers' to see available options.", + provider + )), + }); + } + + // Set the global model switch request + let switch_state = get_model_switch_state(); + *switch_state.lock().unwrap() = Some((provider.to_string(), model.to_string())); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "message": "Model switch requested", + "provider": provider, + "model": model, + "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch." + }))?, + error: None, + }) + } + + fn handle_list_providers(&self) -> anyhow::Result { + let providers_list = providers::list_providers(); + + let providers: Vec = providers_list + .iter() + .map(|p| { + json!({ + "name": p.name, + "display_name": p.display_name, + "aliases": p.aliases, + "local": p.local + }) + }) + .collect(); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "providers": providers, + "count": providers.len(), + "example": "Use action 'set' with provider and model to switch" + }))?, + error: None, + }) + } + + fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result { + let provider = args.get("provider").and_then(|v| v.as_str()); + + let provider = match provider { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "Missing 'provider' parameter for 'list_models' action".to_string(), + ), + }); + } + }; + + // Return common models for known providers + let models = match provider.to_lowercase().as_str() { + "openai" => vec![ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4", + "gpt-3.5-turbo", + ], + "anthropic" => vec![ + "claude-sonnet-4-6", + "claude-sonnet-4-5", + "claude-3-5-sonnet", + "claude-3-opus", + "claude-3-haiku", + ], + "openrouter" => vec![ + "anthropic/claude-sonnet-4-6", + "openai/gpt-4o", + "google/gemini-pro", + "meta-llama/llama-3-70b-instruct", + ], + "groq" => vec![ + "llama-3.3-70b-versatile", + "mixtral-8x7b-32768", + "llama-3.1-70b-speculative", + ], + "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"], + "deepseek" => vec!["deepseek-chat", "deepseek-coder"], + "mistral" => vec![ + "mistral-large-latest", + "mistral-small-latest", + "mistral-nemo", + ], + "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"], + "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"], + _ => vec![], + }; + + if models.is_empty() { + return Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "provider": provider, + "models": [], + "note": "No common models listed for this provider. Check provider documentation for available models." + }))?, + error: None, + }); + } + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "provider": provider, + "models": models, + "example": "Use action 'set' with this provider and a model ID to switch" + }))?, + error: None, + }) + } +}