fix(channels): wire model_switch callback into channel inference path (#4130)

The channel path in `src/channels/mod.rs` was passing `None` as the
`model_switch_callback` to `run_tool_call_loop()`, which meant model
switching via the `model_switch` tool was silently ignored in channel
mode.

Wire the callback in following the same pattern as the CLI path:
- Pass `Some(model_switch_callback.clone())` instead of `None`
- Wrap the tool call loop in a retry loop
- Handle `ModelSwitchRequested` errors by re-creating the provider
  with the new model and retrying

Fixes #4107
This commit is contained in:
Argenis 2026-03-21 05:43:21 -04:00 committed by Roman Tataurov
parent 7735008246
commit fc9f511c52
No known key found for this signature in database
GPG Key ID: 70A51EF3185C334B

View File

@ -89,7 +89,10 @@ pub use whatsapp::WhatsAppChannel;
#[cfg(feature = "whatsapp-web")]
pub use whatsapp_web::WhatsAppWebChannel;
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
use crate::agent::loop_::{
build_tool_instructions, clear_model_switch_request, get_model_switch_state,
is_model_switch_requested, run_tool_call_loop, scrub_credentials,
};
use crate::approval::ApprovalManager;
use crate::config::Config;
use crate::identity;
@ -2081,7 +2084,7 @@ async fn process_channel_message(
}
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
let active_provider = match get_or_create_provider(
let mut active_provider = match get_or_create_provider(
ctx.as_ref(),
&route.provider,
route.api_key.as_deref(),
@ -2391,45 +2394,92 @@ async fn process_channel_message(
Cancelled,
}
let model_switch_callback = get_model_switch_state();
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
let llm_call_start = Instant::now();
#[allow(clippy::cast_possible_truncation)]
let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64;
tracing::info!(elapsed_before_llm_ms, "⏱ Starting LLM call");
let llm_result = tokio::select! {
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
run_tool_call_loop(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
notify_observer.as_ref() as &dyn Observer,
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(&*ctx.approval_manager),
msg.channel.as_str(),
Some(msg.reply_target.as_str()),
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
if msg.channel == "cli"
|| ctx.autonomy_level == AutonomyLevel::Full
let llm_result = loop {
let loop_result = tokio::select! {
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
run_tool_call_loop(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
notify_observer.as_ref() as &dyn Observer,
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(&*ctx.approval_manager),
msg.channel.as_str(),
Some(msg.reply_target.as_str()),
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx.clone(),
ctx.hooks.as_deref(),
if msg.channel == "cli"
|| ctx.autonomy_level == AutonomyLevel::Full
{
&[]
} else {
ctx.non_cli_excluded_tools.as_ref()
},
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
Some(model_switch_callback.clone()),
),
) => LlmExecutionResult::Completed(result),
};
// Handle model switch: re-create the provider and retry
if let LlmExecutionResult::Completed(Ok(Err(ref e))) = loop_result {
if let Some((new_provider, new_model)) = is_model_switch_requested(e) {
tracing::info!(
"Model switch requested, switching from {} {} to {} {}",
route.provider,
route.model,
new_provider,
new_model
);
match create_resilient_provider_nonblocking(
&new_provider,
ctx.api_key.clone(),
ctx.api_url.clone(),
ctx.reliability.as_ref().clone(),
ctx.provider_runtime_options.clone(),
)
.await
{
&[]
} else {
ctx.non_cli_excluded_tools.as_ref()
},
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
None,
),
) => LlmExecutionResult::Completed(result),
Ok(new_prov) => {
active_provider = Arc::from(new_prov);
route.provider = new_provider;
route.model = new_model;
clear_model_switch_request();
ctx.observer.record_event(&ObserverEvent::AgentStart {
provider: route.provider.clone(),
model: route.model.clone(),
});
continue;
}
Err(err) => {
tracing::error!("Failed to create provider after model switch: {err}");
clear_model_switch_request();
// Fall through with the original error
}
}
}
}
break loop_result;
};
if let Some(handle) = draft_updater {