Merge pull request #2798 from zeroclaw-labs/issue-2786-streaming-tool-events-dev
feat(gateway): stream chunk and tool events over websocket
This commit is contained in:
commit
fc995b9446
@ -10,7 +10,7 @@
|
||||
//! ```
|
||||
|
||||
use super::AppState;
|
||||
use crate::agent::loop_::run_tool_call_loop;
|
||||
use crate::agent::loop_::{run_tool_call_loop, DRAFT_CLEAR_SENTINEL, DRAFT_PROGRESS_SENTINEL};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::providers::ChatMessage;
|
||||
use axum::{
|
||||
@ -21,11 +21,26 @@ use axum::{
|
||||
http::{header, HeaderMap},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
const EMPTY_WS_RESPONSE_FALLBACK: &str =
|
||||
"Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";
|
||||
const WS_CHAT_SUBPROTOCOL: &str = "zeroclaw.v1";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum WsDeltaEvent {
|
||||
ContentChunk(String),
|
||||
ToolCall {
|
||||
name: String,
|
||||
hint: Option<String>,
|
||||
},
|
||||
ToolResult {
|
||||
name: String,
|
||||
success: bool,
|
||||
duration_secs: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
fn sanitize_ws_response(response: &str, tools: &[Box<dyn crate::tools::Tool>]) -> String {
|
||||
let sanitized = crate::channels::sanitize_channel_response(response, tools);
|
||||
if sanitized.is_empty() && !response.trim().is_empty() {
|
||||
@ -112,6 +127,109 @@ fn finalize_ws_response(
|
||||
EMPTY_WS_RESPONSE_FALLBACK.to_string()
|
||||
}
|
||||
|
||||
fn parse_tool_completion_payload(raw: &str) -> Option<(String, Option<u64>)> {
|
||||
let trimmed = raw.trim();
|
||||
let (name_part, duration_part) = trimmed.rsplit_once(" (")?;
|
||||
let duration_part = duration_part.strip_suffix(')')?;
|
||||
let secs = duration_part.strip_suffix('s')?.parse::<u64>().ok();
|
||||
Some((name_part.trim().to_string(), secs))
|
||||
}
|
||||
|
||||
fn parse_ws_delta_event(delta: &str) -> Option<WsDeltaEvent> {
|
||||
if delta == DRAFT_CLEAR_SENTINEL {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(progress) = delta.strip_prefix(DRAFT_PROGRESS_SENTINEL) {
|
||||
let progress = progress.trim();
|
||||
if let Some(rest) = progress.strip_prefix("⏳ ") {
|
||||
let rest = rest.trim();
|
||||
if rest.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let (name, hint) = match rest.split_once(": ") {
|
||||
Some((name, hint)) => {
|
||||
let hint = hint.trim();
|
||||
(
|
||||
name.trim().to_string(),
|
||||
if hint.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(hint.to_string())
|
||||
},
|
||||
)
|
||||
}
|
||||
None => (rest.to_string(), None),
|
||||
};
|
||||
return Some(WsDeltaEvent::ToolCall { name, hint });
|
||||
}
|
||||
|
||||
if let Some(rest) = progress.strip_prefix("✅ ") {
|
||||
if let Some((name, duration_secs)) = parse_tool_completion_payload(rest) {
|
||||
return Some(WsDeltaEvent::ToolResult {
|
||||
name,
|
||||
success: true,
|
||||
duration_secs,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(rest) = progress.strip_prefix("❌ ") {
|
||||
if let Some((name, duration_secs)) = parse_tool_completion_payload(rest) {
|
||||
return Some(WsDeltaEvent::ToolResult {
|
||||
name,
|
||||
success: false,
|
||||
duration_secs,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
|
||||
if delta.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(WsDeltaEvent::ContentChunk(delta.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_ws_delta_event(socket: &mut WebSocket, event: WsDeltaEvent) {
|
||||
let payload = match event {
|
||||
WsDeltaEvent::ContentChunk(content) => json!({
|
||||
"type": "chunk",
|
||||
"content": content,
|
||||
}),
|
||||
WsDeltaEvent::ToolCall { name, hint } => json!({
|
||||
"type": "tool_call",
|
||||
"name": name,
|
||||
"args": {
|
||||
"hint": hint,
|
||||
},
|
||||
}),
|
||||
WsDeltaEvent::ToolResult {
|
||||
name,
|
||||
success,
|
||||
duration_secs,
|
||||
} => {
|
||||
let status = if success { "ok" } else { "error" };
|
||||
let output = match duration_secs {
|
||||
Some(secs) => format!("{status} ({secs}s)"),
|
||||
None => status.to_string(),
|
||||
};
|
||||
json!({
|
||||
"type": "tool_result",
|
||||
"name": name,
|
||||
"success": success,
|
||||
"duration_secs": duration_secs,
|
||||
"output": output,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let _ = socket.send(Message::Text(payload.to_string().into())).await;
|
||||
}
|
||||
|
||||
/// GET /ws/chat — WebSocket upgrade for agent chat
|
||||
pub async fn handle_ws_chat(
|
||||
State(state): State<AppState>,
|
||||
@ -205,26 +323,50 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
"model": state.model,
|
||||
}));
|
||||
|
||||
// Run the agent loop with tool execution
|
||||
let result = run_tool_call_loop(
|
||||
state.provider.as_ref(),
|
||||
&mut history,
|
||||
state.tools_registry_exec.as_ref(),
|
||||
state.observer.as_ref(),
|
||||
&provider_label,
|
||||
&state.model,
|
||||
state.temperature,
|
||||
true, // silent - no console output
|
||||
Some(&approval_manager),
|
||||
"webchat",
|
||||
&state.multimodal,
|
||||
state.max_tool_iterations,
|
||||
None, // cancellation token
|
||||
None, // delta streaming
|
||||
None, // hooks
|
||||
&[], // excluded tools
|
||||
)
|
||||
.await;
|
||||
// Run the agent loop with real-time delta streaming for web clients.
|
||||
let result = {
|
||||
let (delta_tx, mut delta_rx) = tokio::sync::mpsc::channel::<String>(128);
|
||||
let mut loop_future = std::pin::pin!(run_tool_call_loop(
|
||||
state.provider.as_ref(),
|
||||
&mut history,
|
||||
state.tools_registry_exec.as_ref(),
|
||||
state.observer.as_ref(),
|
||||
&provider_label,
|
||||
&state.model,
|
||||
state.temperature,
|
||||
true, // silent - no console output
|
||||
Some(&approval_manager),
|
||||
"webchat",
|
||||
&state.multimodal,
|
||||
state.max_tool_iterations,
|
||||
None, // cancellation token
|
||||
Some(delta_tx), // delta streaming
|
||||
None, // hooks
|
||||
&[], // excluded tools
|
||||
));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_delta = delta_rx.recv() => {
|
||||
if let Some(delta) = maybe_delta {
|
||||
if let Some(event) = parse_ws_delta_event(&delta) {
|
||||
emit_ws_delta_event(&mut socket, event).await;
|
||||
}
|
||||
} else {
|
||||
break loop_future.await;
|
||||
}
|
||||
}
|
||||
response = &mut loop_future => {
|
||||
while let Ok(delta) = delta_rx.try_recv() {
|
||||
if let Some(event) = parse_ws_delta_event(&delta) {
|
||||
emit_ws_delta_event(&mut socket, event).await;
|
||||
}
|
||||
}
|
||||
break response;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
@ -319,6 +461,40 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ws_delta_event_maps_tool_start() {
|
||||
let delta = format!("{DRAFT_PROGRESS_SENTINEL}⏳ shell: ls -la\n");
|
||||
assert_eq!(
|
||||
parse_ws_delta_event(&delta),
|
||||
Some(WsDeltaEvent::ToolCall {
|
||||
name: "shell".to_string(),
|
||||
hint: Some("ls -la".to_string()),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ws_delta_event_maps_tool_success() {
|
||||
let delta = format!("{DRAFT_PROGRESS_SENTINEL}✅ shell (2s)\n");
|
||||
assert_eq!(
|
||||
parse_ws_delta_event(&delta),
|
||||
Some(WsDeltaEvent::ToolResult {
|
||||
name: "shell".to_string(),
|
||||
success: true,
|
||||
duration_secs: Some(2),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ws_delta_event_treats_plain_text_as_chunk() {
|
||||
let delta = "partial response ".to_string();
|
||||
assert_eq!(
|
||||
parse_ws_delta_event(&delta),
|
||||
Some(WsDeltaEvent::ContentChunk(delta))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_bearer_token_reads_websocket_protocol_token() {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user