From 2ecfa0d269ddfca496c48f3631d186153faf97f4 Mon Sep 17 00:00:00 2001 From: Chummy Date: Wed, 25 Feb 2026 18:11:11 +0800 Subject: [PATCH] hardening: enforce channel tool boundaries and websocket auth --- src/agent/loop_.rs | 155 +++++++++++++++++++++++++++++++++++++++- src/channels/discord.rs | 122 ++++++++++++++++++++++++++----- src/channels/mod.rs | 91 ++++++++++++++++++++--- src/gateway/ws.rs | 108 ++++++++++++++++++++++++---- web/src/lib/ws.ts | 8 ++- 5 files changed, 443 insertions(+), 41 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index c9783b3ee..6c86cc75f 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -699,6 +699,36 @@ pub(crate) async fn run_tool_call_loop( } } + if excluded_tools.iter().any(|ex| ex == &tool_name) { + let blocked = format!("Tool '{tool_name}' is not available in this channel."); + runtime_trace::record_event( + "tool_call_result", + Some(channel_name), + Some(provider_name), + Some(model), + Some(&turn_id), + Some(false), + Some(&blocked), + serde_json::json!({ + "iteration": iteration + 1, + "tool": tool_name.clone(), + "arguments": scrub_credentials(&tool_args.to_string()), + "blocked_by_channel_policy": true, + }), + ); + ordered_results[idx] = Some(( + tool_name.clone(), + call.tool_call_id.clone(), + ToolExecutionOutcome { + output: blocked.clone(), + success: false, + error_reason: Some(blocked), + duration: Duration::ZERO, + }, + )); + continue; + } + // ── Approval hook ──────────────────────────────── if let Some(mgr) = approval { if mgr.needs_approval(&tool_name) { @@ -707,11 +737,12 @@ pub(crate) async fn run_tool_call_loop( arguments: tool_args.clone(), }; - // Only prompt interactively on CLI; auto-approve on other channels. + // Only CLI supports interactive prompts today. For non-CLI channels, + // fail closed instead of silently auto-approving privileged tools. let decision = if channel_name == "cli" { mgr.prompt_cli(&request) } else { - ApprovalResponse::Yes + ApprovalResponse::No }; mgr.record_decision(&tool_name, &tool_args, decision, channel_name); @@ -2278,6 +2309,126 @@ mod tests { ); } + #[tokio::test] + async fn run_tool_call_loop_denies_supervised_tools_on_non_cli_channels() { + let provider = ScriptedProvider::from_text_responses(vec![ + r#" +{"name":"shell","arguments":{"command":"echo hi"}} +"#, + "done", + ]); + + let active = Arc::new(AtomicUsize::new(0)); + let max_active = Arc::new(AtomicUsize::new(0)); + let tools_registry: Vec> = vec![Box::new(DelayTool::new( + "shell", + 50, + Arc::clone(&active), + Arc::clone(&max_active), + ))]; + + let approval_mgr = ApprovalManager::from_config(&crate::config::AutonomyConfig::default()); + + let mut history = vec![ + ChatMessage::system("test-system"), + ChatMessage::user("run shell"), + ]; + let observer = NoopObserver; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + Some(&approval_mgr), + "telegram", + &crate::config::MultimodalConfig::default(), + 4, + None, + None, + None, + &[], + ) + .await + .expect("tool loop should complete with denied tool execution"); + + assert_eq!(result, "done"); + assert_eq!( + max_active.load(Ordering::SeqCst), + 0, + "shell tool must not execute when approval is unavailable on non-CLI channels" + ); + } + + #[tokio::test] + async fn run_tool_call_loop_blocks_tools_excluded_for_channel() { + let provider = ScriptedProvider::from_text_responses(vec![ + r#" +{"name":"shell","arguments":{"command":"echo hi"}} +"#, + "done", + ]); + + let active = Arc::new(AtomicUsize::new(0)); + let max_active = Arc::new(AtomicUsize::new(0)); + let tools_registry: Vec> = vec![Box::new(DelayTool::new( + "shell", + 50, + Arc::clone(&active), + Arc::clone(&max_active), + ))]; + + let mut history = vec![ + ChatMessage::system("test-system"), + ChatMessage::user("run shell"), + ]; + let observer = NoopObserver; + let excluded_tools = vec!["shell".to_string()]; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "telegram", + &crate::config::MultimodalConfig::default(), + 4, + None, + None, + None, + &excluded_tools, + ) + .await + .expect("tool loop should complete with blocked tool execution"); + + assert_eq!(result, "done"); + assert_eq!( + max_active.load(Ordering::SeqCst), + 0, + "excluded tool must not execute even if the model requests it" + ); + + let tool_results_message = history + .iter() + .find(|msg| msg.role == "user" && msg.content.starts_with("[Tool results]")) + .expect("tool results message should be present"); + assert!( + tool_results_message + .content + .contains("not available in this channel"), + "blocked reason should be visible to the model" + ); + } + #[tokio::test] async fn run_tool_call_loop_deduplicates_repeated_tool_calls() { let provider = ScriptedProvider::from_text_responses(vec![ diff --git a/src/channels/discord.rs b/src/channels/discord.rs index d10007e5b..dc6dc727e 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -1,4 +1,5 @@ use super::traits::{Channel, ChannelMessage, SendMessage}; +use anyhow::Context; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; use parking_lot::Mutex; @@ -16,6 +17,7 @@ pub struct DiscordChannel { allowed_users: Vec, listen_to_bots: bool, mention_only: bool, + workspace_dir: Option, typing_handles: Mutex>>, } @@ -33,10 +35,17 @@ impl DiscordChannel { allowed_users, listen_to_bots, mention_only, + workspace_dir: None, typing_handles: Mutex::new(HashMap::new()), } } + /// Configure workspace directory used for validating local attachment paths. + pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self { + self.workspace_dir = Some(dir); + self + } + fn http_client(&self) -> reqwest::Client { crate::config::build_runtime_proxy_client("channel.discord") } @@ -53,6 +62,42 @@ impl DiscordChannel { let part = token.split('.').next()?; base64_decode(part) } + + fn resolve_local_attachment_path(&self, target: &str) -> anyhow::Result { + let workspace = self.workspace_dir.as_ref().ok_or_else(|| { + anyhow::anyhow!("workspace_dir is not configured; local file attachments are disabled") + })?; + let workspace_root = workspace + .canonicalize() + .unwrap_or_else(|_| workspace.to_path_buf()); + + let target_path = if let Some(rel) = target.strip_prefix("/workspace/") { + workspace.join(rel) + } else if target == "/workspace" { + workspace.to_path_buf() + } else { + let path = Path::new(target); + if path.is_absolute() { + path.to_path_buf() + } else { + workspace.join(path) + } + }; + + let resolved = target_path + .canonicalize() + .with_context(|| format!("attachment path not found: {target}"))?; + + if !resolved.starts_with(&workspace_root) { + anyhow::bail!("attachment path escapes workspace: {target}"); + } + + if !resolved.is_file() { + anyhow::bail!("attachment path is not a file: {}", resolved.display()); + } + + Ok(resolved) + } } /// Process Discord message attachments and return a string to append to the @@ -188,10 +233,10 @@ fn parse_attachment_markers(message: &str) -> (String, Vec) { fn classify_outgoing_attachments( attachments: &[DiscordAttachment], -) -> (Vec, Vec, Vec) { +) -> (Vec, Vec, Vec) { let mut local_files = Vec::new(); let mut remote_urls = Vec::new(); - let mut unresolved_markers = Vec::new(); + let unresolved_markers = Vec::new(); for attachment in attachments { let target = attachment.target.trim(); @@ -200,13 +245,7 @@ fn classify_outgoing_attachments( continue; } - let path = Path::new(target); - if path.exists() && path.is_file() { - local_files.push(path.to_path_buf()); - continue; - } - - unresolved_markers.push(format!("[{}:{}]", attachment.kind.marker_name(), target)); + local_files.push(attachment.clone()); } (local_files, remote_urls, unresolved_markers) @@ -490,8 +529,28 @@ impl Channel for DiscordChannel { async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { let raw_content = super::strip_tool_call_tags(&message.content); let (cleaned_content, parsed_attachments) = parse_attachment_markers(&raw_content); - let (mut local_files, remote_urls, unresolved_markers) = + let (local_attachment_targets, remote_urls, mut unresolved_markers) = classify_outgoing_attachments(&parsed_attachments); + let mut local_files = Vec::new(); + + for attachment in &local_attachment_targets { + let target = attachment.target.trim(); + match self.resolve_local_attachment_path(target) { + Ok(path) => local_files.push(path), + Err(error) => { + tracing::warn!( + target, + error = %error, + "discord: local attachment rejected by workspace policy" + ); + unresolved_markers.push(format!( + "[{}:{}]", + attachment.kind.marker_name(), + target + )); + } + } + } if !unresolved_markers.is_empty() { tracing::warn!( @@ -1483,13 +1542,11 @@ mod tests { ]; let (locals, remotes, unresolved) = classify_outgoing_attachments(&attachments); - assert_eq!(locals.len(), 1); - assert_eq!(locals[0], file_path); + assert_eq!(locals.len(), 2); + assert_eq!(locals[0].target, file_path.to_string_lossy()); + assert_eq!(locals[1].target, "/tmp/does-not-exist.mp4"); assert_eq!(remotes, vec!["https://example.com/remote.png".to_string()]); - assert_eq!( - unresolved, - vec!["[VIDEO:/tmp/does-not-exist.mp4]".to_string()] - ); + assert!(unresolved.is_empty()); } #[test] @@ -1504,4 +1561,37 @@ mod tests { "Done\nhttps://example.com/a.png\n[IMAGE:/tmp/missing.png]" ); } + + #[test] + fn with_workspace_dir_sets_field() { + let channel = DiscordChannel::new("fake".into(), None, vec![], false, false) + .with_workspace_dir(PathBuf::from("/tmp/discord-workspace")); + assert_eq!( + channel.workspace_dir.as_deref(), + Some(Path::new("/tmp/discord-workspace")) + ); + } + + #[test] + fn resolve_local_attachment_path_blocks_workspace_escape() { + let temp = tempfile::tempdir().expect("tempdir"); + let workspace = temp.path().join("workspace"); + std::fs::create_dir_all(&workspace).expect("workspace should exist"); + + let outside = temp.path().join("outside.txt"); + std::fs::write(&outside, b"secret").expect("fixture should be written"); + + let channel = DiscordChannel::new("fake".into(), None, vec![], false, false) + .with_workspace_dir(workspace.clone()); + + let allowed_path = workspace.join("ok.txt"); + std::fs::write(&allowed_path, b"ok").expect("workspace fixture should be written"); + let allowed = channel + .resolve_local_attachment_path("ok.txt") + .expect("workspace file should be allowed"); + assert!(allowed.starts_with(workspace.canonicalize().unwrap_or(workspace))); + + let escaped = channel.resolve_local_attachment_path(outside.to_string_lossy().as_ref()); + assert!(escaped.is_err(), "path outside workspace must be rejected"); + } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 3428b9dde..1afa43e9e 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -70,6 +70,7 @@ pub use whatsapp_web::WhatsAppWebChannel; use crate::agent::loop_::{ build_shell_policy_instructions, build_tool_instructions, run_tool_call_loop, scrub_credentials, }; +use crate::approval::ApprovalManager; use crate::config::Config; use crate::identity; use crate::memory::{self, Memory}; @@ -228,6 +229,7 @@ struct ChannelRuntimeContext { non_cli_excluded_tools: Arc>, query_classification: crate::config::QueryClassificationConfig, model_routes: Vec, + approval_manager: Arc, } #[derive(Clone)] @@ -1782,7 +1784,7 @@ async fn process_channel_message( route.model.as_str(), runtime_defaults.temperature, true, - None, + Some(ctx.approval_manager.as_ref()), msg.channel.as_str(), &ctx.multimodal, ctx.max_tool_iterations, @@ -2746,13 +2748,16 @@ fn collect_configured_channels( if let Some(ref dc) = config.channels_config.discord { channels.push(ConfiguredChannel { display_name: "Discord", - channel: Arc::new(DiscordChannel::new( - dc.bot_token.clone(), - dc.guild_id.clone(), - dc.allowed_users.clone(), - dc.listen_to_bots, - dc.mention_only, - )), + channel: Arc::new( + DiscordChannel::new( + dc.bot_token.clone(), + dc.guild_id.clone(), + dc.allowed_users.clone(), + dc.listen_to_bots, + dc.mention_only, + ) + .with_workspace_dir(config.workspace_dir.clone()), + ), }); } @@ -3380,6 +3385,7 @@ pub async fn start_channels(config: Config) -> Result<()> { non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()), query_classification: config.query_classification.clone(), model_routes: config.model_routes.clone(), + approval_manager: Arc::new(ApprovalManager::from_config(&config.autonomy)), }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -3595,6 +3601,9 @@ mod tests { non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }; assert!(compact_sender_history(&ctx, &sender)); @@ -3646,6 +3655,9 @@ mod tests { non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }; append_sender_turn(&ctx, &sender, ChatMessage::user("hello")); @@ -3700,6 +3712,9 @@ mod tests { non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }; assert!(rollback_orphan_user_turn(&ctx, &sender, "pending")); @@ -4173,6 +4188,9 @@ BTC is currently around $65,000 based on latest tool output."# message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, interrupt_on_new_message: false, non_cli_excluded_tools: Arc::new(Vec::new()), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), multimodal: crate::config::MultimodalConfig::default(), hooks: None, query_classification: crate::config::QueryClassificationConfig::default(), @@ -4234,6 +4252,9 @@ BTC is currently around $65,000 based on latest tool output."# message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, interrupt_on_new_message: false, non_cli_excluded_tools: Arc::new(Vec::new()), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), multimodal: crate::config::MultimodalConfig::default(), hooks: None, query_classification: crate::config::QueryClassificationConfig::default(), @@ -4313,6 +4334,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4374,6 +4398,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4444,6 +4471,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4535,6 +4565,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4608,6 +4641,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4696,6 +4732,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4769,6 +4808,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -4831,6 +4873,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -5004,6 +5049,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -5086,6 +5134,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); let (tx, rx) = tokio::sync::mpsc::channel::(8); @@ -5180,6 +5231,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); let (tx, rx) = tokio::sync::mpsc::channel::(8); @@ -5256,6 +5310,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -5317,6 +5374,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -5835,6 +5895,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -5922,6 +5985,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -6009,6 +6075,9 @@ BTC is currently around $65,000 based on latest tool output."# non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( @@ -6560,6 +6629,9 @@ This is an example JSON object for profile settings."#; non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); // Simulate a photo attachment message with [IMAGE:] marker. @@ -6628,6 +6700,9 @@ This is an example JSON object for profile settings."#; non_cli_excluded_tools: Arc::new(Vec::new()), query_classification: crate::config::QueryClassificationConfig::default(), model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config( + &crate::config::AutonomyConfig::default(), + )), }); process_channel_message( diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 79fee5105..20591f8ee 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -11,34 +11,30 @@ use super::AppState; use crate::agent::loop_::run_tool_call_loop; +use crate::approval::ApprovalManager; use crate::providers::ChatMessage; use axum::{ extract::{ ws::{Message, WebSocket}, - Query, State, WebSocketUpgrade, + State, WebSocketUpgrade, }, + http::{header, HeaderMap}, response::IntoResponse, }; -use serde::Deserialize; - -#[derive(Deserialize)] -pub struct WsQuery { - pub token: Option, -} /// GET /ws/chat — WebSocket upgrade for agent chat pub async fn handle_ws_chat( State(state): State, - Query(params): Query, + headers: HeaderMap, ws: WebSocketUpgrade, ) -> impl IntoResponse { - // Auth via query param (browser WebSocket limitation) + // Auth via Authorization header or websocket protocol token. if state.pairing.require_pairing() { - let token = params.token.as_deref().unwrap_or(""); - if !state.pairing.is_authenticated(token) { + let token = extract_ws_bearer_token(&headers).unwrap_or_default(); + if !state.pairing.is_authenticated(&token) { return ( axum::http::StatusCode::UNAUTHORIZED, - "Unauthorized — provide ?token=", + "Unauthorized — provide Authorization: Bearer or Sec-WebSocket-Protocol: bearer.", ) .into_response(); } @@ -68,6 +64,11 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) { // Add system message to history history.push(ChatMessage::system(&system_prompt)); + let approval_manager = { + let config_guard = state.config.lock(); + ApprovalManager::from_config(&config_guard.autonomy) + }; + while let Some(msg) = socket.recv().await { let msg = match msg { Ok(Message::Text(text)) => text, @@ -123,7 +124,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) { &state.model, state.temperature, true, // silent - no console output - None, // approval manager + Some(&approval_manager), "webchat", &state.multimodal, state.max_tool_iterations, @@ -171,3 +172,84 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) { } } } + +fn extract_ws_bearer_token(headers: &HeaderMap) -> Option { + if let Some(auth_header) = headers + .get(header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .map(str::trim) + { + if let Some(token) = auth_header.strip_prefix("Bearer ") { + if !token.trim().is_empty() { + return Some(token.trim().to_string()); + } + } + } + + let offered = headers + .get(header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|value| value.to_str().ok())?; + + for protocol in offered.split(',').map(str::trim).filter(|s| !s.is_empty()) { + if let Some(token) = protocol.strip_prefix("bearer.") { + if !token.trim().is_empty() { + return Some(token.trim().to_string()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::HeaderValue; + + #[test] + fn extract_ws_bearer_token_prefers_authorization_header() { + let mut headers = HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + HeaderValue::from_static("Bearer from-auth-header"), + ); + headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("zeroclaw.v1, bearer.from-protocol"), + ); + + assert_eq!( + extract_ws_bearer_token(&headers).as_deref(), + Some("from-auth-header") + ); + } + + #[test] + fn extract_ws_bearer_token_reads_websocket_protocol_token() { + let mut headers = HeaderMap::new(); + headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("zeroclaw.v1, bearer.protocol-token"), + ); + + assert_eq!( + extract_ws_bearer_token(&headers).as_deref(), + Some("protocol-token") + ); + } + + #[test] + fn extract_ws_bearer_token_rejects_empty_tokens() { + let mut headers = HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + HeaderValue::from_static("Bearer "), + ); + headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("zeroclaw.v1, bearer."), + ); + + assert!(extract_ws_bearer_token(&headers).is_none()); + } +} diff --git a/web/src/lib/ws.ts b/web/src/lib/ws.ts index 920391f75..fc0bfa329 100644 --- a/web/src/lib/ws.ts +++ b/web/src/lib/ws.ts @@ -52,9 +52,13 @@ export class WebSocketClient { this.clearReconnectTimer(); const token = getToken(); - const url = `${this.baseUrl}/ws/chat${token ? `?token=${encodeURIComponent(token)}` : ''}`; + const url = `${this.baseUrl}/ws/chat`; + const protocols = ['zeroclaw.v1']; + if (token) { + protocols.push(`bearer.${token}`); + } - this.ws = new WebSocket(url); + this.ws = new WebSocket(url, protocols); this.ws.onopen = () => { this.currentDelay = this.reconnectDelay;