diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 5a789fbe7..20b618119 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -10,7 +10,7 @@ //! ``` use super::AppState; -use crate::agent::loop_::run_tool_call_loop; +use crate::agent::loop_::{run_tool_call_loop, DRAFT_CLEAR_SENTINEL, DRAFT_PROGRESS_SENTINEL}; use crate::approval::ApprovalManager; use crate::providers::ChatMessage; use axum::{ @@ -21,11 +21,26 @@ use axum::{ http::{header, HeaderMap}, response::IntoResponse, }; +use serde_json::json; const EMPTY_WS_RESPONSE_FALLBACK: &str = "Tool execution completed, but the model returned no final text response. Please ask me to summarize the result."; const WS_CHAT_SUBPROTOCOL: &str = "zeroclaw.v1"; +#[derive(Debug, Clone, PartialEq, Eq)] +enum WsDeltaEvent { + ContentChunk(String), + ToolCall { + name: String, + hint: Option, + }, + ToolResult { + name: String, + success: bool, + duration_secs: Option, + }, +} + fn sanitize_ws_response(response: &str, tools: &[Box]) -> String { let sanitized = crate::channels::sanitize_channel_response(response, tools); if sanitized.is_empty() && !response.trim().is_empty() { @@ -112,6 +127,109 @@ fn finalize_ws_response( EMPTY_WS_RESPONSE_FALLBACK.to_string() } +fn parse_tool_completion_payload(raw: &str) -> Option<(String, Option)> { + let trimmed = raw.trim(); + let (name_part, duration_part) = trimmed.rsplit_once(" (")?; + let duration_part = duration_part.strip_suffix(')')?; + let secs = duration_part.strip_suffix('s')?.parse::().ok(); + Some((name_part.trim().to_string(), secs)) +} + +fn parse_ws_delta_event(delta: &str) -> Option { + if delta == DRAFT_CLEAR_SENTINEL { + return None; + } + + if let Some(progress) = delta.strip_prefix(DRAFT_PROGRESS_SENTINEL) { + let progress = progress.trim(); + if let Some(rest) = progress.strip_prefix("⏳ ") { + let rest = rest.trim(); + if rest.is_empty() { + return None; + } + let (name, hint) = match rest.split_once(": ") { + Some((name, hint)) => { + let hint = hint.trim(); + ( + name.trim().to_string(), + if hint.is_empty() { + None + } else { + Some(hint.to_string()) + }, + ) + } + None => (rest.to_string(), None), + }; + return Some(WsDeltaEvent::ToolCall { name, hint }); + } + + if let Some(rest) = progress.strip_prefix("✅ ") { + if let Some((name, duration_secs)) = parse_tool_completion_payload(rest) { + return Some(WsDeltaEvent::ToolResult { + name, + success: true, + duration_secs, + }); + } + } + + if let Some(rest) = progress.strip_prefix("❌ ") { + if let Some((name, duration_secs)) = parse_tool_completion_payload(rest) { + return Some(WsDeltaEvent::ToolResult { + name, + success: false, + duration_secs, + }); + } + } + + return None; + } + + if delta.is_empty() { + None + } else { + Some(WsDeltaEvent::ContentChunk(delta.to_string())) + } +} + +async fn emit_ws_delta_event(socket: &mut WebSocket, event: WsDeltaEvent) { + let payload = match event { + WsDeltaEvent::ContentChunk(content) => json!({ + "type": "chunk", + "content": content, + }), + WsDeltaEvent::ToolCall { name, hint } => json!({ + "type": "tool_call", + "name": name, + "args": { + "hint": hint, + }, + }), + WsDeltaEvent::ToolResult { + name, + success, + duration_secs, + } => { + let status = if success { "ok" } else { "error" }; + let output = match duration_secs { + Some(secs) => format!("{status} ({secs}s)"), + None => status.to_string(), + }; + json!({ + "type": "tool_result", + "name": name, + "success": success, + "duration_secs": duration_secs, + "output": output, + }) + } + }; + + let _ = socket.send(Message::Text(payload.to_string().into())).await; +} + /// GET /ws/chat — WebSocket upgrade for agent chat pub async fn handle_ws_chat( State(state): State, @@ -205,26 +323,50 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) { "model": state.model, })); - // Run the agent loop with tool execution - let result = run_tool_call_loop( - state.provider.as_ref(), - &mut history, - state.tools_registry_exec.as_ref(), - state.observer.as_ref(), - &provider_label, - &state.model, - state.temperature, - true, // silent - no console output - Some(&approval_manager), - "webchat", - &state.multimodal, - state.max_tool_iterations, - None, // cancellation token - None, // delta streaming - None, // hooks - &[], // excluded tools - ) - .await; + // Run the agent loop with real-time delta streaming for web clients. + let result = { + let (delta_tx, mut delta_rx) = tokio::sync::mpsc::channel::(128); + let mut loop_future = std::pin::pin!(run_tool_call_loop( + state.provider.as_ref(), + &mut history, + state.tools_registry_exec.as_ref(), + state.observer.as_ref(), + &provider_label, + &state.model, + state.temperature, + true, // silent - no console output + Some(&approval_manager), + "webchat", + &state.multimodal, + state.max_tool_iterations, + None, // cancellation token + Some(delta_tx), // delta streaming + None, // hooks + &[], // excluded tools + )); + + loop { + tokio::select! { + maybe_delta = delta_rx.recv() => { + if let Some(delta) = maybe_delta { + if let Some(event) = parse_ws_delta_event(&delta) { + emit_ws_delta_event(&mut socket, event).await; + } + } else { + break loop_future.await; + } + } + response = &mut loop_future => { + while let Ok(delta) = delta_rx.try_recv() { + if let Some(event) = parse_ws_delta_event(&delta) { + emit_ws_delta_event(&mut socket, event).await; + } + } + break response; + } + } + } + }; match result { Ok(response) => { @@ -319,6 +461,40 @@ mod tests { ); } + #[test] + fn parse_ws_delta_event_maps_tool_start() { + let delta = format!("{DRAFT_PROGRESS_SENTINEL}⏳ shell: ls -la\n"); + assert_eq!( + parse_ws_delta_event(&delta), + Some(WsDeltaEvent::ToolCall { + name: "shell".to_string(), + hint: Some("ls -la".to_string()), + }) + ); + } + + #[test] + fn parse_ws_delta_event_maps_tool_success() { + let delta = format!("{DRAFT_PROGRESS_SENTINEL}✅ shell (2s)\n"); + assert_eq!( + parse_ws_delta_event(&delta), + Some(WsDeltaEvent::ToolResult { + name: "shell".to_string(), + success: true, + duration_secs: Some(2), + }) + ); + } + + #[test] + fn parse_ws_delta_event_treats_plain_text_as_chunk() { + let delta = "partial response ".to_string(); + assert_eq!( + parse_ws_delta_event(&delta), + Some(WsDeltaEvent::ContentChunk(delta)) + ); + } + #[test] fn extract_ws_bearer_token_reads_websocket_protocol_token() { let mut headers = HeaderMap::new();