diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index bf5793fc1..802dd7453 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -778,36 +778,40 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context( on_delta: Option>, hooks: Option<&crate::hooks::HookRunner>, excluded_tools: &[String], + progress_mode: ProgressMode, safety_heartbeat: Option, ) -> Result { let reply_target = non_cli_approval_context .as_ref() .map(|ctx| ctx.reply_target.clone()); - SAFETY_HEARTBEAT_CONFIG + TOOL_LOOP_PROGRESS_MODE .scope( - safety_heartbeat, - TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope( - non_cli_approval_context, - TOOL_LOOP_REPLY_TARGET.scope( - reply_target, - run_tool_call_loop( - provider, - history, - tools_registry, - observer, - provider_name, - model, - temperature, - silent, - approval, - channel_name, - multimodal_config, - max_tool_iterations, - cancellation_token, - on_delta, - hooks, - excluded_tools, + progress_mode, + SAFETY_HEARTBEAT_CONFIG.scope( + safety_heartbeat, + TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope( + non_cli_approval_context, + TOOL_LOOP_REPLY_TARGET.scope( + reply_target, + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + provider_name, + model, + temperature, + silent, + approval, + channel_name, + multimodal_config, + max_tool_iterations, + cancellation_token, + on_delta, + hooks, + excluded_tools, + ), ), ), ), @@ -3617,6 +3621,7 @@ mod tests { None, None, &[], + ProgressMode::Verbose, None, ) .await diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 51cf345de..1a5251895 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel; use crate::agent::loop_::{ build_shell_policy_instructions, build_tool_instructions_from_specs, - run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig, + run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext, + NonCliApprovalPrompt, SafetyHeartbeatConfig, }; use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager}; use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError}; @@ -3664,11 +3665,53 @@ or tune thresholds in config.", let timeout_budget_secs = channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations); + + let (approval_prompt_tx, mut approval_prompt_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() { + Some(NonCliApprovalContext { + sender: msg.sender.clone(), + reply_target: msg.reply_target.clone(), + prompt_tx: approval_prompt_tx, + }) + } else { + drop(approval_prompt_tx); + None + }; + let approval_prompt_dispatcher = if let (Some(channel_ref), true) = + (target_channel.as_ref(), non_cli_approval_context.is_some()) + { + let channel = Arc::clone(channel_ref); + let reply_target = msg.reply_target.clone(); + let thread_ts = msg.thread_ts.clone(); + Some(tokio::spawn(async move { + while let Some(prompt) = approval_prompt_rx.recv().await { + if let Err(err) = channel + .send_approval_prompt( + &reply_target, + &prompt.request_id, + &prompt.tool_name, + &prompt.arguments, + thread_ts.clone(), + ) + .await + { + tracing::warn!( + "Failed to send non-CLI approval prompt for request {}: {err}", + prompt.request_id + ); + } + } + })) + } else { + None + }; + let llm_result = tokio::select! { () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, result = tokio::time::timeout( Duration::from_secs(timeout_budget_secs), - run_tool_call_loop_with_reply_target( + run_tool_call_loop_with_non_cli_approval_context( active_provider.as_ref(), &mut history, ctx.tools_registry.as_ref(), @@ -3679,7 +3722,7 @@ or tune thresholds in config.", true, Some(ctx.approval_manager.as_ref()), msg.channel.as_str(), - Some(msg.reply_target.as_str()), + non_cli_approval_context, &ctx.multimodal, ctx.max_tool_iterations, Some(cancellation_token.clone()), @@ -3687,6 +3730,7 @@ or tune thresholds in config.", ctx.hooks.as_deref(), &excluded_tools_snapshot, progress_mode, + ctx.safety_heartbeat.clone(), ), ) => LlmExecutionResult::Completed(result), }; @@ -3694,6 +3738,9 @@ or tune thresholds in config.", if let Some(handle) = draft_updater { let _ = handle.await; } + if let Some(handle) = approval_prompt_dispatcher { + let _ = handle.await; + } if let Some(token) = typing_cancellation.as_ref() { token.cancel(); @@ -7653,6 +7700,131 @@ BTC is currently around $65,000 based on latest tool output."# assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0); } + #[tokio::test] + async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let autonomy_cfg = crate::config::AutonomyConfig { + always_ask: vec!["mock_price".to_string()], + ..crate::config::AutonomyConfig::default() + }; + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(ToolCallingProvider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + conversation_locks: Default::default(), + session_config: crate::config::AgentSessionConfig::default(), + session_manager: None, + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + hooks: None, + non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())), + query_classification: crate::config::QueryClassificationConfig::default(), + model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)), + safety_heartbeat: None, + startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(), + }); + + let runtime_ctx_for_first_turn = runtime_ctx.clone(); + let first_turn = tokio::spawn(async move { + process_channel_message( + runtime_ctx_for_first_turn, + traits::ChannelMessage { + id: "msg-non-cli-approval-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "What is the BTC price now?".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + }); + + let request_id = tokio::time::timeout(Duration::from_secs(2), async { + loop { + let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests( + Some("alice"), + Some("telegram"), + Some("chat-1"), + ); + if let Some(req) = pending.first() { + break req.request_id.clone(); + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("pending approval request should be created for always_ask tool"); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "msg-non-cli-approval-2".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: format!("/approve-allow {request_id}"), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + tokio::time::timeout(Duration::from_secs(5), first_turn) + .await + .expect("first channel turn should finish after approval") + .expect("first channel turn task should not panic"); + + let sent = channel_impl.sent_messages.lock().await; + assert!( + sent.iter() + .any(|entry| entry.contains("Approval required for tool `mock_price`")), + "channel should emit non-cli approval prompt" + ); + assert!( + sent.iter() + .any(|entry| entry.contains("Approved supervised execution for `mock_price`")), + "channel should acknowledge explicit approval command" + ); + assert!( + sent.iter() + .any(|entry| entry.contains("BTC is currently around")), + "tool call should execute after approval and produce final response" + ); + assert!( + sent.iter().all(|entry| !entry.contains("Denied by user.")), + "always_ask tool should not be silently denied once non-cli approval prompt path is wired" + ); + } + #[tokio::test] async fn process_channel_message_denies_approval_management_for_unlisted_sender() { let channel_impl = Arc::new(TelegramRecordingChannel::default());