Sanitize WebSocket chat done responses to prevent tool artifact leaks
This commit is contained in:
parent
3b6786d0d7
commit
1f257d7bf8
@ -22,6 +22,16 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
fn sanitize_ws_response(response: &str, tools: &[Box<dyn crate::tools::Tool>]) -> String {
|
||||
let sanitized = crate::channels::sanitize_channel_response(response, tools);
|
||||
if sanitized.is_empty() && !response.trim().is_empty() {
|
||||
"I encountered malformed tool-call output and could not produce a safe reply. Please try again."
|
||||
.to_string()
|
||||
} else {
|
||||
sanitized
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /ws/chat — WebSocket upgrade for agent chat
|
||||
pub async fn handle_ws_chat(
|
||||
State(state): State<AppState>,
|
||||
@ -137,13 +147,15 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let safe_response =
|
||||
sanitize_ws_response(&response, state.tools_registry_exec.as_ref());
|
||||
// Add assistant response to history
|
||||
history.push(ChatMessage::assistant(&response));
|
||||
history.push(ChatMessage::assistant(&safe_response));
|
||||
|
||||
// Send the full response as a done message
|
||||
let done = serde_json::json!({
|
||||
"type": "done",
|
||||
"full_response": response,
|
||||
"full_response": safe_response,
|
||||
});
|
||||
let _ = socket.send(Message::Text(done.to_string().into())).await;
|
||||
|
||||
@ -204,6 +216,8 @@ fn extract_ws_bearer_token(headers: &HeaderMap) -> Option<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use axum::http::HeaderValue;
|
||||
|
||||
#[test]
|
||||
@ -252,4 +266,66 @@ mod tests {
|
||||
|
||||
assert!(extract_ws_bearer_token(&headers).is_none());
|
||||
}
|
||||
|
||||
struct MockScheduleTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockScheduleTool {
|
||||
fn name(&self) -> &str {
|
||||
"schedule"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Mock schedule tool"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": { "type": "string" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: "ok".to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_ws_response_removes_tool_call_tags() {
|
||||
let input = r#"Before
|
||||
<tool_call>
|
||||
{"name":"schedule","arguments":{"action":"create"}}
|
||||
</tool_call>
|
||||
After"#;
|
||||
|
||||
let result = sanitize_ws_response(input, &[]);
|
||||
let normalized = result
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
assert_eq!(normalized, "Before\nAfter");
|
||||
assert!(!result.contains("<tool_call>"));
|
||||
assert!(!result.contains("\"name\":\"schedule\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_ws_response_removes_isolated_tool_json_artifacts() {
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];
|
||||
let input = r#"{"name":"schedule","parameters":{"action":"create"}}
|
||||
{"result":{"status":"scheduled"}}
|
||||
Reminder set successfully."#;
|
||||
|
||||
let result = sanitize_ws_response(input, &tools);
|
||||
assert_eq!(result, "Reminder set successfully.");
|
||||
assert!(!result.contains("\"name\":\"schedule\""));
|
||||
assert!(!result.contains("\"result\""));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user