//! WebSocket agent chat handler. //! //! Protocol: //! ```text //! Client -> Server: {"type":"message","content":"Hello"} //! Server -> Client: {"type":"chunk","content":"Hi! "} //! Server -> Client: {"type":"tool_call","name":"shell","args":{...}} //! Server -> Client: {"type":"tool_result","name":"shell","output":"..."} //! Server -> Client: {"type":"done","full_response":"..."} //! ``` use super::AppState; use crate::agent::loop_::{build_shell_policy_instructions, build_tool_instructions_from_specs}; use crate::memory::MemoryCategory; use crate::providers::ChatMessage; use axum::{ extract::{ ws::{Message, WebSocket}, ConnectInfo, RawQuery, State, WebSocketUpgrade, }, http::{header, HeaderMap}, response::IntoResponse, }; use std::net::SocketAddr; use uuid::Uuid; const EMPTY_WS_RESPONSE_FALLBACK: &str = "Tool execution completed, but the model returned no final text response. Please ask me to summarize the result."; const WS_HISTORY_MEMORY_KEY_PREFIX: &str = "gateway_ws_history"; const MAX_WS_PERSISTED_TURNS: usize = 128; const MAX_WS_SESSION_ID_LEN: usize = 128; #[derive(Debug, Default, PartialEq, Eq)] struct WsQueryParams { token: Option, session_id: Option, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)] struct WsHistoryTurn { role: String, content: String, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, PartialEq, Eq)] struct WsPersistedHistory { version: u8, messages: Vec, } fn normalize_ws_session_id(candidate: Option<&str>) -> Option { let raw = candidate?.trim(); if raw.is_empty() || raw.len() > MAX_WS_SESSION_ID_LEN { return None; } if raw .chars() .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_') { return Some(raw.to_string()); } None } fn parse_ws_query_params(raw_query: Option<&str>) -> WsQueryParams { let Some(query) = raw_query else { return WsQueryParams::default(); }; let mut params = WsQueryParams::default(); for kv in query.split('&') { let mut parts = kv.splitn(2, '='); let key = parts.next().unwrap_or("").trim(); let value = parts.next().unwrap_or("").trim(); if value.is_empty() { continue; } match key { "token" if params.token.is_none() => { params.token = Some(value.to_string()); } "session_id" if params.session_id.is_none() => { params.session_id = normalize_ws_session_id(Some(value)); } _ => {} } } params } fn ws_history_memory_key(session_id: &str) -> String { format!("{WS_HISTORY_MEMORY_KEY_PREFIX}:{session_id}") } fn ws_history_turns_from_chat(history: &[ChatMessage]) -> Vec { let mut turns = history .iter() .filter_map(|msg| match msg.role.as_str() { "user" | "assistant" => { let content = msg.content.trim(); if content.is_empty() { None } else { Some(WsHistoryTurn { role: msg.role.clone(), content: content.to_string(), }) } } _ => None, }) .collect::>(); if turns.len() > MAX_WS_PERSISTED_TURNS { let keep_from = turns.len().saturating_sub(MAX_WS_PERSISTED_TURNS); turns.drain(0..keep_from); } turns } fn restore_chat_history(system_prompt: &str, turns: &[WsHistoryTurn]) -> Vec { let mut history = vec![ChatMessage::system(system_prompt)]; for turn in turns { match turn.role.as_str() { "user" => history.push(ChatMessage::user(&turn.content)), "assistant" => history.push(ChatMessage::assistant(&turn.content)), _ => {} } } history } async fn load_ws_history( state: &AppState, session_id: &str, system_prompt: &str, ) -> Vec { let key = ws_history_memory_key(session_id); let Some(entry) = state.mem.get(&key).await.ok().flatten() else { return vec![ChatMessage::system(system_prompt)]; }; let parsed = serde_json::from_str::(&entry.content) .map(|history| history.messages) .or_else(|_| serde_json::from_str::>(&entry.content)); match parsed { Ok(turns) => restore_chat_history(system_prompt, &turns), Err(err) => { tracing::warn!( "Failed to parse persisted websocket history for session {}: {}", session_id, err ); vec![ChatMessage::system(system_prompt)] } } } async fn persist_ws_history(state: &AppState, session_id: &str, history: &[ChatMessage]) { let payload = WsPersistedHistory { version: 1, messages: ws_history_turns_from_chat(history), }; let serialized = match serde_json::to_string(&payload) { Ok(value) => value, Err(err) => { tracing::warn!( "Failed to serialize websocket history for session {}: {}", session_id, err ); return; } }; let key = ws_history_memory_key(session_id); if let Err(err) = state .mem .store( &key, &serialized, MemoryCategory::Conversation, Some(session_id), ) .await { tracing::warn!( "Failed to persist websocket history for session {}: {}", session_id, err ); } } fn sanitize_ws_response( response: &str, tools: &[Box], leak_guard: &crate::config::OutboundLeakGuardConfig, ) -> String { match crate::channels::sanitize_channel_response(response, tools, leak_guard) { crate::channels::ChannelSanitizationResult::Sanitized(sanitized) => { if sanitized.is_empty() && !response.trim().is_empty() { "I encountered malformed tool-call output and could not produce a safe reply. Please try again." .to_string() } else { sanitized } } crate::channels::ChannelSanitizationResult::Blocked { .. } => { "I blocked a draft response because it appeared to contain credential material. Please ask for a redacted summary." .to_string() } } } fn normalize_prompt_tool_results(content: &str) -> Option { let mut cleaned_lines: Vec<&str> = Vec::new(); for line in content.lines() { let trimmed = line.trim(); if trimmed.is_empty() { continue; } if trimmed.starts_with("" { continue; } cleaned_lines.push(line.trim_end()); } if cleaned_lines.is_empty() { None } else { Some(cleaned_lines.join("\n")) } } fn extract_latest_tool_output(history: &[ChatMessage]) -> Option { for msg in history.iter().rev() { match msg.role.as_str() { "tool" => { if let Ok(value) = serde_json::from_str::(&msg.content) { if let Some(content) = value .get("content") .and_then(|v| v.as_str()) .map(str::trim) .filter(|v| !v.is_empty()) { return Some(content.to_string()); } } let trimmed = msg.content.trim(); if !trimmed.is_empty() { return Some(trimmed.to_string()); } } "user" => { if let Some(payload) = msg.content.strip_prefix("[Tool results]") { let payload = payload.trim_start_matches('\n'); if let Some(cleaned) = normalize_prompt_tool_results(payload) { return Some(cleaned); } } } _ => {} } } None } fn finalize_ws_response( response: &str, history: &[ChatMessage], tools: &[Box], leak_guard: &crate::config::OutboundLeakGuardConfig, ) -> String { let sanitized = sanitize_ws_response(response, tools, leak_guard); if !sanitized.trim().is_empty() { return sanitized; } if let Some(tool_output) = extract_latest_tool_output(history) { let excerpt = crate::util::truncate_with_ellipsis(tool_output.trim(), 1200); return format!( "Tool execution completed, but the model returned no final text response.\n\nLatest tool output:\n{excerpt}" ); } EMPTY_WS_RESPONSE_FALLBACK.to_string() } fn build_ws_system_prompt( config: &crate::config::Config, model: &str, tools_registry: &[Box], native_tools: bool, ) -> String { let mut tool_specs: Vec = tools_registry.iter().map(|tool| tool.spec()).collect(); tool_specs.sort_by(|a, b| a.name.cmp(&b.name)); let tool_descs: Vec<(&str, &str)> = tool_specs .iter() .map(|spec| (spec.name.as_str(), spec.description.as_str())) .collect(); let bootstrap_max_chars = if config.agent.compact_context { Some(6000) } else { None }; let mut prompt = crate::channels::build_system_prompt_with_mode( &config.workspace_dir, model, &tool_descs, &[], Some(&config.identity), bootstrap_max_chars, native_tools, config.skills.prompt_injection_mode, ); if !native_tools { prompt.push_str(&build_tool_instructions_from_specs(&tool_specs)); } prompt.push_str(&build_shell_policy_instructions(&config.autonomy)); prompt } fn refresh_ws_history_system_prompt_datetime(history: &mut [ChatMessage]) { if let Some(system_message) = history.first_mut() { if system_message.role == "system" { crate::agent::prompt::refresh_prompt_datetime(&mut system_message.content); } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum WsAuthRejection { MissingPairingToken, NonLocalWithoutAuthLayer, } fn evaluate_ws_auth( pairing_required: bool, is_loopback_request: bool, has_valid_pairing_token: bool, ) -> Option { if pairing_required { return (!has_valid_pairing_token).then_some(WsAuthRejection::MissingPairingToken); } if !is_loopback_request && !has_valid_pairing_token { return Some(WsAuthRejection::NonLocalWithoutAuthLayer); } None } /// GET /ws/chat — WebSocket upgrade for agent chat pub async fn handle_ws_chat( State(state): State, ConnectInfo(peer_addr): ConnectInfo, headers: HeaderMap, RawQuery(query): RawQuery, ws: WebSocketUpgrade, ) -> impl IntoResponse { let query_params = parse_ws_query_params(query.as_deref()); let token = extract_ws_bearer_token(&headers, query_params.token.as_deref()).unwrap_or_default(); let has_valid_pairing_token = !token.is_empty() && state.pairing.is_authenticated(&token); let is_loopback_request = super::is_loopback_request(Some(peer_addr), &headers, state.trust_forwarded_headers); match evaluate_ws_auth( state.pairing.require_pairing(), is_loopback_request, has_valid_pairing_token, ) { Some(WsAuthRejection::MissingPairingToken) => { return ( axum::http::StatusCode::UNAUTHORIZED, "Unauthorized — provide Authorization: Bearer , Sec-WebSocket-Protocol: bearer., or ?token=", ) .into_response(); } Some(WsAuthRejection::NonLocalWithoutAuthLayer) => { return ( axum::http::StatusCode::UNAUTHORIZED, "Unauthorized — enable gateway pairing or provide a valid paired bearer token for non-local /ws/chat access", ) .into_response(); } None => {} } let session_id = query_params .session_id .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); ws.on_upgrade(move |socket| handle_socket(socket, state, session_id)) .into_response() } async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: String) { let ws_session_id = format!("ws_{}", Uuid::new_v4()); // Build system prompt once for the session let system_prompt = { let config_guard = state.config.lock(); build_ws_system_prompt( &config_guard, &state.model, state.tools_registry_exec.as_ref(), state.provider.supports_native_tools(), ) }; // Restore persisted history (if any) and replay to the client before processing new input. let mut history = load_ws_history(&state, &session_id, &system_prompt).await; let persisted_turns = ws_history_turns_from_chat(&history); let history_payload = serde_json::json!({ "type": "history", "session_id": session_id.as_str(), "messages": persisted_turns, }); let _ = socket .send(Message::Text(history_payload.to_string().into())) .await; while let Some(msg) = socket.recv().await { let msg = match msg { Ok(Message::Text(text)) => text, Ok(Message::Close(_)) | Err(_) => break, _ => continue, }; // Parse incoming message let parsed: serde_json::Value = match serde_json::from_str(&msg) { Ok(v) => v, Err(_) => { let err = serde_json::json!({"type": "error", "message": "Invalid JSON"}); let _ = socket.send(Message::Text(err.to_string().into())).await; continue; } }; let msg_type = parsed["type"].as_str().unwrap_or(""); if msg_type != "message" { continue; } let content = parsed["content"].as_str().unwrap_or("").to_string(); if content.is_empty() { continue; } let perplexity_cfg = { state.config.lock().security.perplexity_filter.clone() }; if let Some(assessment) = crate::security::detect_adversarial_suffix(&content, &perplexity_cfg) { let err = serde_json::json!({ "type": "error", "message": format!( "Input blocked by security.perplexity_filter: perplexity={:.2} (threshold {:.2}), symbol_ratio={:.2} (threshold {:.2}), suspicious_tokens={}.", assessment.perplexity, perplexity_cfg.perplexity_threshold, assessment.symbol_ratio, perplexity_cfg.symbol_ratio_threshold, assessment.suspicious_token_count ), }); let _ = socket.send(Message::Text(err.to_string().into())).await; continue; } refresh_ws_history_system_prompt_datetime(&mut history); // Add user message to history history.push(ChatMessage::user(&content)); persist_ws_history(&state, &session_id, &history).await; // Get provider info let provider_label = state .config .lock() .default_provider .clone() .unwrap_or_else(|| "unknown".to_string()); // Broadcast agent_start event let _ = state.event_tx.send(serde_json::json!({ "type": "agent_start", "provider": provider_label, "model": state.model, })); // Full agentic loop with tools (includes WASM skills, shell, memory, etc.) match super::run_gateway_chat_with_tools(&state, &content, Some(&ws_session_id)).await { Ok(response) => { let leak_guard_cfg = { state.config.lock().security.outbound_leak_guard.clone() }; let safe_response = finalize_ws_response( &response, &history, state.tools_registry_exec.as_ref(), &leak_guard_cfg, ); // Add assistant response to history history.push(ChatMessage::assistant(&safe_response)); persist_ws_history(&state, &session_id, &history).await; // Send the full response as a done message let done = serde_json::json!({ "type": "done", "full_response": safe_response, }); let _ = socket.send(Message::Text(done.to_string().into())).await; // Broadcast agent_end event let _ = state.event_tx.send(serde_json::json!({ "type": "agent_end", "provider": provider_label, "model": state.model, })); } Err(e) => { let sanitized = crate::providers::sanitize_api_error(&e.to_string()); let err = serde_json::json!({ "type": "error", "message": sanitized, }); let _ = socket.send(Message::Text(err.to_string().into())).await; // Broadcast error event let _ = state.event_tx.send(serde_json::json!({ "type": "error", "component": "ws_chat", "message": sanitized, })); } } } } fn extract_ws_bearer_token(headers: &HeaderMap, query_token: Option<&str>) -> Option { if let Some(auth_header) = headers .get(header::AUTHORIZATION) .and_then(|value| value.to_str().ok()) .map(str::trim) { if let Some(token) = auth_header.strip_prefix("Bearer ") { if !token.trim().is_empty() { return Some(token.trim().to_string()); } } } if let Some(offered) = headers .get(header::SEC_WEBSOCKET_PROTOCOL) .and_then(|value| value.to_str().ok()) { for protocol in offered.split(',').map(str::trim).filter(|s| !s.is_empty()) { if let Some(token) = protocol.strip_prefix("bearer.") { if !token.trim().is_empty() { return Some(token.trim().to_string()); } } } } query_token .map(str::trim) .filter(|token| !token.is_empty()) .map(ToOwned::to_owned) } fn extract_query_token(raw_query: Option<&str>) -> Option { parse_ws_query_params(raw_query).token } #[cfg(test)] mod tests { use super::*; use crate::tools::{Tool, ToolResult}; use async_trait::async_trait; use axum::http::HeaderValue; #[test] fn extract_ws_bearer_token_prefers_authorization_header() { let mut headers = HeaderMap::new(); headers.insert( header::AUTHORIZATION, HeaderValue::from_static("Bearer from-auth-header"), ); headers.insert( header::SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("zeroclaw.v1, bearer.from-protocol"), ); assert_eq!( extract_ws_bearer_token(&headers, None).as_deref(), Some("from-auth-header") ); } #[test] fn extract_ws_bearer_token_reads_websocket_protocol_token() { let mut headers = HeaderMap::new(); headers.insert( header::SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("zeroclaw.v1, bearer.protocol-token"), ); assert_eq!( extract_ws_bearer_token(&headers, None).as_deref(), Some("protocol-token") ); } #[test] fn extract_ws_bearer_token_rejects_empty_tokens() { let mut headers = HeaderMap::new(); headers.insert( header::AUTHORIZATION, HeaderValue::from_static("Bearer "), ); headers.insert( header::SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("zeroclaw.v1, bearer."), ); assert!(extract_ws_bearer_token(&headers, None).is_none()); } #[test] fn extract_ws_bearer_token_reads_query_token_fallback() { let headers = HeaderMap::new(); assert_eq!( extract_ws_bearer_token(&headers, Some("query-token")).as_deref(), Some("query-token") ); } #[test] fn extract_ws_bearer_token_prefers_protocol_over_query_token() { let mut headers = HeaderMap::new(); headers.insert( header::SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("zeroclaw.v1, bearer.protocol-token"), ); assert_eq!( extract_ws_bearer_token(&headers, Some("query-token")).as_deref(), Some("protocol-token") ); } #[test] fn extract_query_token_reads_token_param() { assert_eq!( extract_query_token(Some("foo=1&token=query-token&bar=2")).as_deref(), Some("query-token") ); assert!(extract_query_token(Some("foo=1")).is_none()); } #[test] fn parse_ws_query_params_reads_token_and_session_id() { let parsed = parse_ws_query_params(Some("foo=1&session_id=sess_123&token=query-token")); assert_eq!(parsed.token.as_deref(), Some("query-token")); assert_eq!(parsed.session_id.as_deref(), Some("sess_123")); } #[test] fn parse_ws_query_params_rejects_invalid_session_id() { let parsed = parse_ws_query_params(Some("session_id=../../etc/passwd")); assert!(parsed.session_id.is_none()); } #[test] fn ws_history_turns_from_chat_skips_system_and_non_dialog_turns() { let history = vec![ ChatMessage::system("sys"), ChatMessage::user(" hello "), ChatMessage { role: "tool".to_string(), content: "ignored".to_string(), }, ChatMessage::assistant(" world "), ]; let turns = ws_history_turns_from_chat(&history); assert_eq!( turns, vec![ WsHistoryTurn { role: "user".to_string(), content: "hello".to_string() }, WsHistoryTurn { role: "assistant".to_string(), content: "world".to_string() } ] ); } #[test] fn refresh_ws_history_system_prompt_datetime_updates_only_system_entry() { let mut history = vec![ ChatMessage::system("## Current Date & Time\n\n2000-01-01 00:00:00 (UTC)\n"), ChatMessage::user("hello"), ]; refresh_ws_history_system_prompt_datetime(&mut history); assert!(!history[0].content.contains("2000-01-01 00:00:00 (UTC)")); assert_eq!(history[1].content, "hello"); } #[test] fn restore_chat_history_applies_system_prompt_once() { let turns = vec![ WsHistoryTurn { role: "user".to_string(), content: "u1".to_string(), }, WsHistoryTurn { role: "assistant".to_string(), content: "a1".to_string(), }, ]; let restored = restore_chat_history("sys", &turns); assert_eq!(restored.len(), 3); assert_eq!(restored[0].role, "system"); assert_eq!(restored[0].content, "sys"); assert_eq!(restored[1].role, "user"); assert_eq!(restored[1].content, "u1"); assert_eq!(restored[2].role, "assistant"); assert_eq!(restored[2].content, "a1"); } #[test] fn evaluate_ws_auth_requires_pairing_token_when_pairing_is_enabled() { assert_eq!( evaluate_ws_auth(true, true, false), Some(WsAuthRejection::MissingPairingToken) ); assert_eq!(evaluate_ws_auth(true, false, true), None); } #[test] fn evaluate_ws_auth_rejects_public_without_auth_layer_when_pairing_disabled() { assert_eq!( evaluate_ws_auth(false, false, false), Some(WsAuthRejection::NonLocalWithoutAuthLayer) ); } #[test] fn evaluate_ws_auth_allows_loopback_or_valid_token_when_pairing_disabled() { assert_eq!(evaluate_ws_auth(false, true, false), None); assert_eq!(evaluate_ws_auth(false, false, true), None); } struct MockScheduleTool; #[async_trait] impl Tool for MockScheduleTool { fn name(&self) -> &str { "schedule" } fn description(&self) -> &str { "Mock schedule tool" } fn parameters_schema(&self) -> serde_json::Value { serde_json::json!({ "type": "object", "properties": { "action": { "type": "string" } } }) } async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { Ok(ToolResult { success: true, output: "ok".to_string(), error: None, }) } } #[test] fn sanitize_ws_response_removes_tool_call_tags() { let input = r#"Before {"name":"schedule","arguments":{"action":"create"}} After"#; let leak_guard = crate::config::OutboundLeakGuardConfig::default(); let result = sanitize_ws_response(input, &[], &leak_guard); let normalized = result .lines() .filter(|line| !line.trim().is_empty()) .collect::>() .join("\n"); assert_eq!(normalized, "Before\nAfter"); assert!(!result.contains("")); assert!(!result.contains("\"name\":\"schedule\"")); } #[test] fn sanitize_ws_response_removes_isolated_tool_json_artifacts() { let tools: Vec> = vec![Box::new(MockScheduleTool)]; let input = r#"{"name":"schedule","parameters":{"action":"create"}} {"result":{"status":"scheduled"}} Reminder set successfully."#; let leak_guard = crate::config::OutboundLeakGuardConfig::default(); let result = sanitize_ws_response(input, &tools, &leak_guard); assert_eq!(result, "Reminder set successfully."); assert!(!result.contains("\"name\":\"schedule\"")); assert!(!result.contains("\"result\"")); } #[test] fn sanitize_ws_response_blocks_detected_credentials_when_configured() { let tools: Vec> = Vec::new(); let leak_guard = crate::config::OutboundLeakGuardConfig { enabled: true, action: crate::config::OutboundLeakGuardAction::Block, sensitivity: 0.7, }; let result = sanitize_ws_response("Temporary key: AKIAABCDEFGHIJKLMNOP", &tools, &leak_guard); assert!(result.contains("blocked a draft response")); } #[test] fn build_ws_system_prompt_includes_tool_protocol_for_prompt_mode() { let config = crate::config::Config::default(); let tools: Vec> = vec![Box::new(MockScheduleTool)]; let prompt = build_ws_system_prompt(&config, "test-model", &tools, false); assert!(prompt.contains("## Tool Use Protocol")); assert!(prompt.contains("**schedule**")); assert!(prompt.contains("## Shell Policy")); } #[test] fn build_ws_system_prompt_omits_xml_protocol_for_native_mode() { let config = crate::config::Config::default(); let tools: Vec> = vec![Box::new(MockScheduleTool)]; let prompt = build_ws_system_prompt(&config, "test-model", &tools, true); assert!(!prompt.contains("## Tool Use Protocol")); assert!(prompt.contains("**schedule**")); assert!(prompt.contains("## Shell Policy")); } #[test] fn finalize_ws_response_uses_prompt_mode_tool_output_when_final_text_empty() { let tools: Vec> = vec![Box::new(MockScheduleTool)]; let history = vec![ ChatMessage::system("sys"), ChatMessage::user( "[Tool results]\n\nDisk usage: 72%\n", ), ]; let leak_guard = crate::config::OutboundLeakGuardConfig::default(); let result = finalize_ws_response("", &history, &tools, &leak_guard); assert!(result.contains("Latest tool output:")); assert!(result.contains("Disk usage: 72%")); assert!(!result.contains("> = vec![Box::new(MockScheduleTool)]; let history = vec![ChatMessage { role: "tool".to_string(), content: r#"{"tool_call_id":"call_1","content":"Filesystem /dev/disk3s1: 210G free"}"# .to_string(), }]; let leak_guard = crate::config::OutboundLeakGuardConfig::default(); let result = finalize_ws_response("", &history, &tools, &leak_guard); assert!(result.contains("Latest tool output:")); assert!(result.contains("/dev/disk3s1")); } #[test] fn finalize_ws_response_uses_static_fallback_when_nothing_available() { let tools: Vec> = vec![Box::new(MockScheduleTool)]; let history = vec![ChatMessage::system("sys")]; let leak_guard = crate::config::OutboundLeakGuardConfig::default(); let result = finalize_ws_response("", &history, &tools, &leak_guard); assert_eq!(result, EMPTY_WS_RESPONSE_FALLBACK); } }