fix(channels): prompt non-CLI always_ask approvals (#2337)
* fix(channels): prompt non-cli always_ask approvals * chore(ci): retrigger intake after PR template update
This commit is contained in:
parent
f3c82cb13a
commit
0683467bc1
@ -778,36 +778,40 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
progress_mode: ProgressMode,
|
||||
safety_heartbeat: Option<SafetyHeartbeatConfig>,
|
||||
) -> Result<String> {
|
||||
let reply_target = non_cli_approval_context
|
||||
.as_ref()
|
||||
.map(|ctx| ctx.reply_target.clone());
|
||||
|
||||
SAFETY_HEARTBEAT_CONFIG
|
||||
TOOL_LOOP_PROGRESS_MODE
|
||||
.scope(
|
||||
safety_heartbeat,
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
progress_mode,
|
||||
SAFETY_HEARTBEAT_CONFIG.scope(
|
||||
safety_heartbeat,
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -3617,6 +3621,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
ProgressMode::Verbose,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
|
||||
@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{
|
||||
build_shell_policy_instructions, build_tool_instructions_from_specs,
|
||||
run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
|
||||
run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext,
|
||||
NonCliApprovalPrompt, SafetyHeartbeatConfig,
|
||||
};
|
||||
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
|
||||
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
|
||||
@ -3664,11 +3665,53 @@ or tune thresholds in config.",
|
||||
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
|
||||
let (approval_prompt_tx, mut approval_prompt_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<NonCliApprovalPrompt>();
|
||||
let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() {
|
||||
Some(NonCliApprovalContext {
|
||||
sender: msg.sender.clone(),
|
||||
reply_target: msg.reply_target.clone(),
|
||||
prompt_tx: approval_prompt_tx,
|
||||
})
|
||||
} else {
|
||||
drop(approval_prompt_tx);
|
||||
None
|
||||
};
|
||||
let approval_prompt_dispatcher = if let (Some(channel_ref), true) =
|
||||
(target_channel.as_ref(), non_cli_approval_context.is_some())
|
||||
{
|
||||
let channel = Arc::clone(channel_ref);
|
||||
let reply_target = msg.reply_target.clone();
|
||||
let thread_ts = msg.thread_ts.clone();
|
||||
Some(tokio::spawn(async move {
|
||||
while let Some(prompt) = approval_prompt_rx.recv().await {
|
||||
if let Err(err) = channel
|
||||
.send_approval_prompt(
|
||||
&reply_target,
|
||||
&prompt.request_id,
|
||||
&prompt.tool_name,
|
||||
&prompt.arguments,
|
||||
thread_ts.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to send non-CLI approval prompt for request {}: {err}",
|
||||
prompt.request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let llm_result = tokio::select! {
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
run_tool_call_loop_with_reply_target(
|
||||
run_tool_call_loop_with_non_cli_approval_context(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
@ -3679,7 +3722,7 @@ or tune thresholds in config.",
|
||||
true,
|
||||
Some(ctx.approval_manager.as_ref()),
|
||||
msg.channel.as_str(),
|
||||
Some(msg.reply_target.as_str()),
|
||||
non_cli_approval_context,
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
@ -3687,6 +3730,7 @@ or tune thresholds in config.",
|
||||
ctx.hooks.as_deref(),
|
||||
&excluded_tools_snapshot,
|
||||
progress_mode,
|
||||
ctx.safety_heartbeat.clone(),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@ -3694,6 +3738,9 @@ or tune thresholds in config.",
|
||||
if let Some(handle) = draft_updater {
|
||||
let _ = handle.await;
|
||||
}
|
||||
if let Some(handle) = approval_prompt_dispatcher {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
if let Some(token) = typing_cancellation.as_ref() {
|
||||
token.cancel();
|
||||
@ -7653,6 +7700,131 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
always_ask: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)),
|
||||
safety_heartbeat: None,
|
||||
startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
|
||||
});
|
||||
|
||||
let runtime_ctx_for_first_turn = runtime_ctx.clone();
|
||||
let first_turn = tokio::spawn(async move {
|
||||
process_channel_message(
|
||||
runtime_ctx_for_first_turn,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-non-cli-approval-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let request_id = tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests(
|
||||
Some("alice"),
|
||||
Some("telegram"),
|
||||
Some("chat-1"),
|
||||
);
|
||||
if let Some(req) = pending.first() {
|
||||
break req.request_id.clone();
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("pending approval request should be created for always_ask tool");
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx.clone(),
|
||||
traits::ChannelMessage {
|
||||
id: "msg-non-cli-approval-2".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: format!("/approve-allow {request_id}"),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), first_turn)
|
||||
.await
|
||||
.expect("first channel turn should finish after approval")
|
||||
.expect("first channel turn task should not panic");
|
||||
|
||||
let sent = channel_impl.sent_messages.lock().await;
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("Approval required for tool `mock_price`")),
|
||||
"channel should emit non-cli approval prompt"
|
||||
);
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("Approved supervised execution for `mock_price`")),
|
||||
"channel should acknowledge explicit approval command"
|
||||
);
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("BTC is currently around")),
|
||||
"tool call should execute after approval and produce final response"
|
||||
);
|
||||
assert!(
|
||||
sent.iter().all(|entry| !entry.contains("Denied by user.")),
|
||||
"always_ask tool should not be silently denied once non-cli approval prompt path is wired"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_denies_approval_management_for_unlisted_sender() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
|
||||
Loading…
Reference in New Issue
Block a user