fix(channels): wire model_switch callback into channel inference path (#4130)
The channel path in `src/channels/mod.rs` was passing `None` as the `model_switch_callback` to `run_tool_call_loop()`, which meant model switching via the `model_switch` tool was silently ignored in channel mode.

Wire the callback in, following the same pattern as the CLI path:

- Pass `Some(model_switch_callback.clone())` instead of `None`
- Wrap the tool call loop in a retry loop
- Handle `ModelSwitchRequested` errors by re-creating the provider with the new model and retrying

Fixes #4107
This commit is contained in:
parent
7735008246
commit
fc9f511c52
@ -89,7 +89,10 @@ pub use whatsapp::WhatsAppChannel;
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
|
||||
use crate::agent::loop_::{
|
||||
build_tool_instructions, clear_model_switch_request, get_model_switch_state,
|
||||
is_model_switch_requested, run_tool_call_loop, scrub_credentials,
|
||||
};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
@ -2081,7 +2084,7 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(
|
||||
let mut active_provider = match get_or_create_provider(
|
||||
ctx.as_ref(),
|
||||
&route.provider,
|
||||
route.api_key.as_deref(),
|
||||
@ -2391,45 +2394,92 @@ async fn process_channel_message(
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
let model_switch_callback = get_model_switch_state();
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let llm_call_start = Instant::now();
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64;
|
||||
tracing::info!(elapsed_before_llm_ms, "⏱ Starting LLM call");
|
||||
let llm_result = tokio::select! {
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
notify_observer.as_ref() as &dyn Observer,
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
Some(msg.reply_target.as_str()),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx,
|
||||
ctx.hooks.as_deref(),
|
||||
if msg.channel == "cli"
|
||||
|| ctx.autonomy_level == AutonomyLevel::Full
|
||||
let llm_result = loop {
|
||||
let loop_result = tokio::select! {
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
notify_observer.as_ref() as &dyn Observer,
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
Some(msg.reply_target.as_str()),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx.clone(),
|
||||
ctx.hooks.as_deref(),
|
||||
if msg.channel == "cli"
|
||||
|| ctx.autonomy_level == AutonomyLevel::Full
|
||||
{
|
||||
&[]
|
||||
} else {
|
||||
ctx.non_cli_excluded_tools.as_ref()
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
|
||||
// Handle model switch: re-create the provider and retry
|
||||
if let LlmExecutionResult::Completed(Ok(Err(ref e))) = loop_result {
|
||||
if let Some((new_provider, new_model)) = is_model_switch_requested(e) {
|
||||
tracing::info!(
|
||||
"Model switch requested, switching from {} {} to {} {}",
|
||||
route.provider,
|
||||
route.model,
|
||||
new_provider,
|
||||
new_model
|
||||
);
|
||||
|
||||
match create_resilient_provider_nonblocking(
|
||||
&new_provider,
|
||||
ctx.api_key.clone(),
|
||||
ctx.api_url.clone(),
|
||||
ctx.reliability.as_ref().clone(),
|
||||
ctx.provider_runtime_options.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
&[]
|
||||
} else {
|
||||
ctx.non_cli_excluded_tools.as_ref()
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
None,
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
Ok(new_prov) => {
|
||||
active_provider = Arc::from(new_prov);
|
||||
route.provider = new_provider;
|
||||
route.model = new_model;
|
||||
clear_model_switch_request();
|
||||
|
||||
ctx.observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: route.provider.clone(),
|
||||
model: route.model.clone(),
|
||||
});
|
||||
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to create provider after model switch: {err}");
|
||||
clear_model_switch_request();
|
||||
// Fall through with the original error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break loop_result;
|
||||
};
|
||||
|
||||
if let Some(handle) = draft_updater {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user