diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 20591f8ee..b0c1e1874 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -22,6 +22,16 @@ use axum::{ response::IntoResponse, }; +fn sanitize_ws_response(response: &str, tools: &[Box]) -> String { + let sanitized = crate::channels::sanitize_channel_response(response, tools); + if sanitized.is_empty() && !response.trim().is_empty() { + "I encountered malformed tool-call output and could not produce a safe reply. Please try again." + .to_string() + } else { + sanitized + } +} + /// GET /ws/chat — WebSocket upgrade for agent chat pub async fn handle_ws_chat( State(state): State, @@ -137,13 +147,15 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) { match result { Ok(response) => { + let safe_response = + sanitize_ws_response(&response, state.tools_registry_exec.as_ref()); // Add assistant response to history - history.push(ChatMessage::assistant(&response)); + history.push(ChatMessage::assistant(&safe_response)); // Send the full response as a done message let done = serde_json::json!({ "type": "done", - "full_response": response, + "full_response": safe_response, }); let _ = socket.send(Message::Text(done.to_string().into())).await; @@ -204,6 +216,8 @@ fn extract_ws_bearer_token(headers: &HeaderMap) -> Option { #[cfg(test)] mod tests { use super::*; + use crate::tools::{Tool, ToolResult}; + use async_trait::async_trait; use axum::http::HeaderValue; #[test] @@ -252,4 +266,66 @@ mod tests { assert!(extract_ws_bearer_token(&headers).is_none()); } + + struct MockScheduleTool; + + #[async_trait] + impl Tool for MockScheduleTool { + fn name(&self) -> &str { + "schedule" + } + + fn description(&self) -> &str { + "Mock schedule tool" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "action": { "type": "string" } + } + }) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(ToolResult { + success: true, + output: "ok".to_string(), + error: None, + }) + } + } + + #[test] + fn sanitize_ws_response_removes_tool_call_tags() { + let input = r#"Before + +{"name":"schedule","arguments":{"action":"create"}} + +After"#; + + let result = sanitize_ws_response(input, &[]); + let normalized = result + .lines() + .filter(|line| !line.trim().is_empty()) + .collect::>() + .join("\n"); + assert_eq!(normalized, "Before\nAfter"); + assert!(!result.contains("")); + assert!(!result.contains("\"name\":\"schedule\"")); + } + + #[test] + fn sanitize_ws_response_removes_isolated_tool_json_artifacts() { + let tools: Vec> = vec![Box::new(MockScheduleTool)]; + let input = r#"{"name":"schedule","parameters":{"action":"create"}} +{"result":{"status":"scheduled"}} +Reminder set successfully."#; + + let result = sanitize_ws_response(input, &tools); + assert_eq!(result, "Reminder set successfully."); + assert!(!result.contains("\"name\":\"schedule\"")); + assert!(!result.contains("\"result\"")); + } }