From fc9f511c528b1961866e59952814535d1cc448bf Mon Sep 17 00:00:00 2001 From: Argenis Date: Sat, 21 Mar 2026 05:43:21 -0400 Subject: [PATCH] fix(channels): wire model_switch callback into channel inference path (#4130) The channel path in `src/channels/mod.rs` was passing `None` as the `model_switch_callback` to `run_tool_call_loop()`, which meant model switching via the `model_switch` tool was silently ignored in channel mode. Wire the callback in following the same pattern as the CLI path: - Pass `Some(model_switch_callback.clone())` instead of `None` - Wrap the tool call loop in a retry loop - Handle `ModelSwitchRequested` errors by re-creating the provider with the new model and retrying Fixes #4107 --- src/channels/mod.rs | 118 +++++++++++++++++++++++++++++++------------- 1 file changed, 84 insertions(+), 34 deletions(-) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 96960faab..9f428bf89 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -89,7 +89,10 @@ pub use whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] pub use whatsapp_web::WhatsAppWebChannel; -use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials}; +use crate::agent::loop_::{ + build_tool_instructions, clear_model_switch_request, get_model_switch_state, + is_model_switch_requested, run_tool_call_loop, scrub_credentials, +}; use crate::approval::ApprovalManager; use crate::config::Config; use crate::identity; @@ -2081,7 +2084,7 @@ async fn process_channel_message( } let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref()); - let active_provider = match get_or_create_provider( + let mut active_provider = match get_or_create_provider( ctx.as_ref(), &route.provider, route.api_key.as_deref(), @@ -2391,45 +2394,92 @@ async fn process_channel_message( Cancelled, } + let model_switch_callback = get_model_switch_state(); let timeout_budget_secs = channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations); let llm_call_start = Instant::now(); #[allow(clippy::cast_possible_truncation)] let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64; tracing::info!(elapsed_before_llm_ms, "⏱ Starting LLM call"); - let llm_result = tokio::select! { - () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, - result = tokio::time::timeout( - Duration::from_secs(timeout_budget_secs), - run_tool_call_loop( - active_provider.as_ref(), - &mut history, - ctx.tools_registry.as_ref(), - notify_observer.as_ref() as &dyn Observer, - route.provider.as_str(), - route.model.as_str(), - runtime_defaults.temperature, - true, - Some(&*ctx.approval_manager), - msg.channel.as_str(), - Some(msg.reply_target.as_str()), - &ctx.multimodal, - ctx.max_tool_iterations, - Some(cancellation_token.clone()), - delta_tx, - ctx.hooks.as_deref(), - if msg.channel == "cli" - || ctx.autonomy_level == AutonomyLevel::Full + let llm_result = loop { + let loop_result = tokio::select! { + () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, + result = tokio::time::timeout( + Duration::from_secs(timeout_budget_secs), + run_tool_call_loop( + active_provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + notify_observer.as_ref() as &dyn Observer, + route.provider.as_str(), + route.model.as_str(), + runtime_defaults.temperature, + true, + Some(&*ctx.approval_manager), + msg.channel.as_str(), + Some(msg.reply_target.as_str()), + &ctx.multimodal, + ctx.max_tool_iterations, + Some(cancellation_token.clone()), + delta_tx.clone(), + ctx.hooks.as_deref(), + if msg.channel == "cli" + || ctx.autonomy_level == AutonomyLevel::Full + { + &[] + } else { + ctx.non_cli_excluded_tools.as_ref() + }, + ctx.tool_call_dedup_exempt.as_ref(), + ctx.activated_tools.as_ref(), + Some(model_switch_callback.clone()), + ), + ) => LlmExecutionResult::Completed(result), + }; + + // Handle model switch: re-create the provider and retry + if let LlmExecutionResult::Completed(Ok(Err(ref e))) = loop_result { + if let Some((new_provider, new_model)) = is_model_switch_requested(e) { + tracing::info!( + "Model switch requested, switching from {} {} to {} {}", + route.provider, + route.model, + new_provider, + new_model + ); + + match create_resilient_provider_nonblocking( + &new_provider, + ctx.api_key.clone(), + ctx.api_url.clone(), + ctx.reliability.as_ref().clone(), + ctx.provider_runtime_options.clone(), + ) + .await { - &[] - } else { - ctx.non_cli_excluded_tools.as_ref() - }, - ctx.tool_call_dedup_exempt.as_ref(), - ctx.activated_tools.as_ref(), - None, - ), - ) => LlmExecutionResult::Completed(result), + Ok(new_prov) => { + active_provider = Arc::from(new_prov); + route.provider = new_provider; + route.model = new_model; + clear_model_switch_request(); + + ctx.observer.record_event(&ObserverEvent::AgentStart { + provider: route.provider.clone(), + model: route.model.clone(), + }); + + continue; + } + Err(err) => { + tracing::error!("Failed to create provider after model switch: {err}"); + clear_model_switch_request(); + // Fall through with the original error + } + } + } + } + + break loop_result; }; if let Some(handle) = draft_updater {