//! WebSocket agent chat handler.
//!
//! Protocol:
//! ```text
//! Client -> Server: {"type":"message","content":"Hello"}
//! Server -> Client: {"type":"history","session_id":"...","messages":[...]}
//! Server -> Client: {"type":"chunk","content":"Hi! "}
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
//! Server -> Client: {"type":"done","full_response":"..."}
//! Server -> Client: {"type":"error","message":"..."}
//! ```

use super::AppState;
use crate::agent::loop_::{build_shell_policy_instructions, build_tool_instructions_from_specs};
use crate::memory::MemoryCategory;
use crate::providers::ChatMessage;
use axum::{
    extract::{
        ws::{Message, WebSocket},
        ConnectInfo, RawQuery, State, WebSocketUpgrade,
    },
    http::{header, HeaderMap},
    response::IntoResponse,
};
use std::net::SocketAddr;
use uuid::Uuid;

/// Reply used when the agent produced no final text and no tool output could be recovered.
const EMPTY_WS_RESPONSE_FALLBACK: &str =
    "Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";
/// Prefix of the memory key under which a session's chat history is stored.
const WS_HISTORY_MEMORY_KEY_PREFIX: &str = "gateway_ws_history";
/// Maximum number of user/assistant turns kept when history is persisted.
const MAX_WS_PERSISTED_TURNS: usize = 128;
/// Maximum accepted length, in bytes, of a client-supplied `session_id`.
const MAX_WS_SESSION_ID_LEN: usize = 128;

/// Query-string parameters recognised on the WebSocket upgrade request.
#[derive(Debug, Default, PartialEq, Eq)]
struct WsQueryParams {
    /// First non-empty `token` value found in the query string.
    token: Option<String>,
    /// `session_id` value, kept only if it passed `normalize_ws_session_id`.
    session_id: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
|
struct WsHistoryTurn {
|
|
role: String,
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, PartialEq, Eq)]
|
|
struct WsPersistedHistory {
|
|
version: u8,
|
|
messages: Vec<WsHistoryTurn>,
|
|
}
|
|
|
|
fn normalize_ws_session_id(candidate: Option<&str>) -> Option<String> {
|
|
let raw = candidate?.trim();
|
|
if raw.is_empty() || raw.len() > MAX_WS_SESSION_ID_LEN {
|
|
return None;
|
|
}
|
|
|
|
if raw
|
|
.chars()
|
|
.all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_')
|
|
{
|
|
return Some(raw.to_string());
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn parse_ws_query_params(raw_query: Option<&str>) -> WsQueryParams {
|
|
let Some(query) = raw_query else {
|
|
return WsQueryParams::default();
|
|
};
|
|
|
|
let mut params = WsQueryParams::default();
|
|
for kv in query.split('&') {
|
|
let mut parts = kv.splitn(2, '=');
|
|
let key = parts.next().unwrap_or("").trim();
|
|
let value = parts.next().unwrap_or("").trim();
|
|
if value.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
match key {
|
|
"token" if params.token.is_none() => {
|
|
params.token = Some(value.to_string());
|
|
}
|
|
"session_id" if params.session_id.is_none() => {
|
|
params.session_id = normalize_ws_session_id(Some(value));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
params
|
|
}
|
|
|
|
fn ws_history_memory_key(session_id: &str) -> String {
|
|
format!("{WS_HISTORY_MEMORY_KEY_PREFIX}:{session_id}")
|
|
}
|
|
|
|
fn ws_history_turns_from_chat(history: &[ChatMessage]) -> Vec<WsHistoryTurn> {
|
|
let mut turns = history
|
|
.iter()
|
|
.filter_map(|msg| match msg.role.as_str() {
|
|
"user" | "assistant" => {
|
|
let content = msg.content.trim();
|
|
if content.is_empty() {
|
|
None
|
|
} else {
|
|
Some(WsHistoryTurn {
|
|
role: msg.role.clone(),
|
|
content: content.to_string(),
|
|
})
|
|
}
|
|
}
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
if turns.len() > MAX_WS_PERSISTED_TURNS {
|
|
let keep_from = turns.len().saturating_sub(MAX_WS_PERSISTED_TURNS);
|
|
turns.drain(0..keep_from);
|
|
}
|
|
turns
|
|
}
|
|
|
|
fn restore_chat_history(system_prompt: &str, turns: &[WsHistoryTurn]) -> Vec<ChatMessage> {
|
|
let mut history = vec![ChatMessage::system(system_prompt)];
|
|
for turn in turns {
|
|
match turn.role.as_str() {
|
|
"user" => history.push(ChatMessage::user(&turn.content)),
|
|
"assistant" => history.push(ChatMessage::assistant(&turn.content)),
|
|
_ => {}
|
|
}
|
|
}
|
|
history
|
|
}
|
|
|
|
async fn load_ws_history(
|
|
state: &AppState,
|
|
session_id: &str,
|
|
system_prompt: &str,
|
|
) -> Vec<ChatMessage> {
|
|
let key = ws_history_memory_key(session_id);
|
|
let Some(entry) = state.mem.get(&key).await.ok().flatten() else {
|
|
return vec![ChatMessage::system(system_prompt)];
|
|
};
|
|
|
|
let parsed = serde_json::from_str::<WsPersistedHistory>(&entry.content)
|
|
.map(|history| history.messages)
|
|
.or_else(|_| serde_json::from_str::<Vec<WsHistoryTurn>>(&entry.content));
|
|
|
|
match parsed {
|
|
Ok(turns) => restore_chat_history(system_prompt, &turns),
|
|
Err(err) => {
|
|
tracing::warn!(
|
|
"Failed to parse persisted websocket history for session {}: {}",
|
|
session_id,
|
|
err
|
|
);
|
|
vec![ChatMessage::system(system_prompt)]
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn persist_ws_history(state: &AppState, session_id: &str, history: &[ChatMessage]) {
|
|
let payload = WsPersistedHistory {
|
|
version: 1,
|
|
messages: ws_history_turns_from_chat(history),
|
|
};
|
|
let serialized = match serde_json::to_string(&payload) {
|
|
Ok(value) => value,
|
|
Err(err) => {
|
|
tracing::warn!(
|
|
"Failed to serialize websocket history for session {}: {}",
|
|
session_id,
|
|
err
|
|
);
|
|
return;
|
|
}
|
|
};
|
|
|
|
let key = ws_history_memory_key(session_id);
|
|
if let Err(err) = state
|
|
.mem
|
|
.store(
|
|
&key,
|
|
&serialized,
|
|
MemoryCategory::Conversation,
|
|
Some(session_id),
|
|
)
|
|
.await
|
|
{
|
|
tracing::warn!(
|
|
"Failed to persist websocket history for session {}: {}",
|
|
session_id,
|
|
err
|
|
);
|
|
}
|
|
}
|
|
|
|
fn sanitize_ws_response(
|
|
response: &str,
|
|
tools: &[Box<dyn crate::tools::Tool>],
|
|
leak_guard: &crate::config::OutboundLeakGuardConfig,
|
|
) -> String {
|
|
match crate::channels::sanitize_channel_response(response, tools, leak_guard) {
|
|
crate::channels::ChannelSanitizationResult::Sanitized(sanitized) => {
|
|
if sanitized.is_empty() && !response.trim().is_empty() {
|
|
"I encountered malformed tool-call output and could not produce a safe reply. Please try again."
|
|
.to_string()
|
|
} else {
|
|
sanitized
|
|
}
|
|
}
|
|
crate::channels::ChannelSanitizationResult::Blocked { .. } => {
|
|
"I blocked a draft response because it appeared to contain credential material. Please ask for a redacted summary."
|
|
.to_string()
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Strips `<tool_result ...>` / `</tool_result>` wrapper lines and blank lines
/// from a prompt-mode tool-results payload.
///
/// Remaining lines keep their leading whitespace, lose trailing whitespace,
/// and are joined with `\n`. Returns `None` when nothing remains.
fn normalize_prompt_tool_results(content: &str) -> Option<String> {
    let cleaned_lines: Vec<&str> = content
        .lines()
        .filter(|line| {
            let trimmed = line.trim();
            !(trimmed.is_empty()
                || trimmed.starts_with("<tool_result")
                || trimmed == "</tool_result>")
        })
        .map(str::trim_end)
        .collect();

    (!cleaned_lines.is_empty()).then(|| cleaned_lines.join("\n"))
}

fn extract_latest_tool_output(history: &[ChatMessage]) -> Option<String> {
|
|
for msg in history.iter().rev() {
|
|
match msg.role.as_str() {
|
|
"tool" => {
|
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&msg.content) {
|
|
if let Some(content) = value
|
|
.get("content")
|
|
.and_then(|v| v.as_str())
|
|
.map(str::trim)
|
|
.filter(|v| !v.is_empty())
|
|
{
|
|
return Some(content.to_string());
|
|
}
|
|
}
|
|
|
|
let trimmed = msg.content.trim();
|
|
if !trimmed.is_empty() {
|
|
return Some(trimmed.to_string());
|
|
}
|
|
}
|
|
"user" => {
|
|
if let Some(payload) = msg.content.strip_prefix("[Tool results]") {
|
|
let payload = payload.trim_start_matches('\n');
|
|
if let Some(cleaned) = normalize_prompt_tool_results(payload) {
|
|
return Some(cleaned);
|
|
}
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn finalize_ws_response(
|
|
response: &str,
|
|
history: &[ChatMessage],
|
|
tools: &[Box<dyn crate::tools::Tool>],
|
|
leak_guard: &crate::config::OutboundLeakGuardConfig,
|
|
) -> String {
|
|
let sanitized = sanitize_ws_response(response, tools, leak_guard);
|
|
if !sanitized.trim().is_empty() {
|
|
return sanitized;
|
|
}
|
|
|
|
if let Some(tool_output) = extract_latest_tool_output(history) {
|
|
let excerpt = crate::util::truncate_with_ellipsis(tool_output.trim(), 1200);
|
|
return format!(
|
|
"Tool execution completed, but the model returned no final text response.\n\nLatest tool output:\n{excerpt}"
|
|
);
|
|
}
|
|
|
|
EMPTY_WS_RESPONSE_FALLBACK.to_string()
|
|
}
|
|
|
|
fn build_ws_system_prompt(
|
|
config: &crate::config::Config,
|
|
model: &str,
|
|
tools_registry: &[Box<dyn crate::tools::Tool>],
|
|
native_tools: bool,
|
|
) -> String {
|
|
let mut tool_specs: Vec<crate::tools::ToolSpec> =
|
|
tools_registry.iter().map(|tool| tool.spec()).collect();
|
|
tool_specs.sort_by(|a, b| a.name.cmp(&b.name));
|
|
|
|
let tool_descs: Vec<(&str, &str)> = tool_specs
|
|
.iter()
|
|
.map(|spec| (spec.name.as_str(), spec.description.as_str()))
|
|
.collect();
|
|
|
|
let bootstrap_max_chars = if config.agent.compact_context {
|
|
Some(6000)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut prompt = crate::channels::build_system_prompt_with_mode(
|
|
&config.workspace_dir,
|
|
model,
|
|
&tool_descs,
|
|
&[],
|
|
Some(&config.identity),
|
|
bootstrap_max_chars,
|
|
native_tools,
|
|
config.skills.prompt_injection_mode,
|
|
);
|
|
if !native_tools {
|
|
prompt.push_str(&build_tool_instructions_from_specs(&tool_specs));
|
|
}
|
|
prompt.push_str(&build_shell_policy_instructions(&config.autonomy));
|
|
|
|
prompt
|
|
}
|
|
|
|
fn refresh_ws_history_system_prompt_datetime(history: &mut [ChatMessage]) {
|
|
if let Some(system_message) = history.first_mut() {
|
|
if system_message.role == "system" {
|
|
crate::agent::prompt::refresh_prompt_datetime(&mut system_message.content);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Reason a WebSocket upgrade request is refused by `evaluate_ws_auth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsAuthRejection {
    /// Pairing is required and the request carried no valid pairing token.
    MissingPairingToken,
    /// Pairing is disabled, the request is not loopback, and it has no valid token.
    NonLocalWithoutAuthLayer,
}

fn evaluate_ws_auth(
|
|
pairing_required: bool,
|
|
is_loopback_request: bool,
|
|
has_valid_pairing_token: bool,
|
|
) -> Option<WsAuthRejection> {
|
|
if pairing_required {
|
|
return (!has_valid_pairing_token).then_some(WsAuthRejection::MissingPairingToken);
|
|
}
|
|
|
|
if !is_loopback_request && !has_valid_pairing_token {
|
|
return Some(WsAuthRejection::NonLocalWithoutAuthLayer);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// GET /ws/chat — WebSocket upgrade for agent chat
|
|
pub async fn handle_ws_chat(
|
|
State(state): State<AppState>,
|
|
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
|
headers: HeaderMap,
|
|
RawQuery(query): RawQuery,
|
|
ws: WebSocketUpgrade,
|
|
) -> impl IntoResponse {
|
|
let query_params = parse_ws_query_params(query.as_deref());
|
|
let token =
|
|
extract_ws_bearer_token(&headers, query_params.token.as_deref()).unwrap_or_default();
|
|
let has_valid_pairing_token = !token.is_empty() && state.pairing.is_authenticated(&token);
|
|
let is_loopback_request =
|
|
super::is_loopback_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
|
|
|
|
match evaluate_ws_auth(
|
|
state.pairing.require_pairing(),
|
|
is_loopback_request,
|
|
has_valid_pairing_token,
|
|
) {
|
|
Some(WsAuthRejection::MissingPairingToken) => {
|
|
return (
|
|
axum::http::StatusCode::UNAUTHORIZED,
|
|
"Unauthorized — provide Authorization: Bearer <token>, Sec-WebSocket-Protocol: bearer.<token>, or ?token=<token>",
|
|
)
|
|
.into_response();
|
|
}
|
|
Some(WsAuthRejection::NonLocalWithoutAuthLayer) => {
|
|
return (
|
|
axum::http::StatusCode::UNAUTHORIZED,
|
|
"Unauthorized — enable gateway pairing or provide a valid paired bearer token for non-local /ws/chat access",
|
|
)
|
|
.into_response();
|
|
}
|
|
None => {}
|
|
}
|
|
|
|
let session_id = query_params
|
|
.session_id
|
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
|
|
.into_response()
|
|
}
|
|
|
|
async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: String) {
|
|
let ws_session_id = format!("ws_{}", Uuid::new_v4());
|
|
|
|
// Build system prompt once for the session
|
|
let system_prompt = {
|
|
let config_guard = state.config.lock();
|
|
build_ws_system_prompt(
|
|
&config_guard,
|
|
&state.model,
|
|
state.tools_registry_exec.as_ref(),
|
|
state.provider.supports_native_tools(),
|
|
)
|
|
};
|
|
|
|
// Restore persisted history (if any) and replay to the client before processing new input.
|
|
let mut history = load_ws_history(&state, &session_id, &system_prompt).await;
|
|
let persisted_turns = ws_history_turns_from_chat(&history);
|
|
let history_payload = serde_json::json!({
|
|
"type": "history",
|
|
"session_id": session_id.as_str(),
|
|
"messages": persisted_turns,
|
|
});
|
|
let _ = socket
|
|
.send(Message::Text(history_payload.to_string().into()))
|
|
.await;
|
|
|
|
while let Some(msg) = socket.recv().await {
|
|
let msg = match msg {
|
|
Ok(Message::Text(text)) => text,
|
|
Ok(Message::Close(_)) | Err(_) => break,
|
|
_ => continue,
|
|
};
|
|
|
|
// Parse incoming message
|
|
let parsed: serde_json::Value = match serde_json::from_str(&msg) {
|
|
Ok(v) => v,
|
|
Err(_) => {
|
|
let err = serde_json::json!({"type": "error", "message": "Invalid JSON"});
|
|
let _ = socket.send(Message::Text(err.to_string().into())).await;
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let msg_type = parsed["type"].as_str().unwrap_or("");
|
|
if msg_type != "message" {
|
|
continue;
|
|
}
|
|
|
|
let content = parsed["content"].as_str().unwrap_or("").to_string();
|
|
if content.is_empty() {
|
|
continue;
|
|
}
|
|
let perplexity_cfg = { state.config.lock().security.perplexity_filter.clone() };
|
|
if let Some(assessment) =
|
|
crate::security::detect_adversarial_suffix(&content, &perplexity_cfg)
|
|
{
|
|
let err = serde_json::json!({
|
|
"type": "error",
|
|
"message": format!(
|
|
"Input blocked by security.perplexity_filter: perplexity={:.2} (threshold {:.2}), symbol_ratio={:.2} (threshold {:.2}), suspicious_tokens={}.",
|
|
assessment.perplexity,
|
|
perplexity_cfg.perplexity_threshold,
|
|
assessment.symbol_ratio,
|
|
perplexity_cfg.symbol_ratio_threshold,
|
|
assessment.suspicious_token_count
|
|
),
|
|
});
|
|
let _ = socket.send(Message::Text(err.to_string().into())).await;
|
|
continue;
|
|
}
|
|
|
|
refresh_ws_history_system_prompt_datetime(&mut history);
|
|
|
|
// Add user message to history
|
|
history.push(ChatMessage::user(&content));
|
|
persist_ws_history(&state, &session_id, &history).await;
|
|
|
|
// Get provider info
|
|
let provider_label = state
|
|
.config
|
|
.lock()
|
|
.default_provider
|
|
.clone()
|
|
.unwrap_or_else(|| "unknown".to_string());
|
|
|
|
// Broadcast agent_start event
|
|
let _ = state.event_tx.send(serde_json::json!({
|
|
"type": "agent_start",
|
|
"provider": provider_label,
|
|
"model": state.model,
|
|
}));
|
|
|
|
// Full agentic loop with tools (includes WASM skills, shell, memory, etc.)
|
|
match super::run_gateway_chat_with_tools(&state, &content, Some(&ws_session_id)).await {
|
|
Ok(response) => {
|
|
let leak_guard_cfg = { state.config.lock().security.outbound_leak_guard.clone() };
|
|
let safe_response = finalize_ws_response(
|
|
&response,
|
|
&history,
|
|
state.tools_registry_exec.as_ref(),
|
|
&leak_guard_cfg,
|
|
);
|
|
// Add assistant response to history
|
|
history.push(ChatMessage::assistant(&safe_response));
|
|
persist_ws_history(&state, &session_id, &history).await;
|
|
|
|
// Send the full response as a done message
|
|
let done = serde_json::json!({
|
|
"type": "done",
|
|
"full_response": safe_response,
|
|
});
|
|
let _ = socket.send(Message::Text(done.to_string().into())).await;
|
|
|
|
// Broadcast agent_end event
|
|
let _ = state.event_tx.send(serde_json::json!({
|
|
"type": "agent_end",
|
|
"provider": provider_label,
|
|
"model": state.model,
|
|
}));
|
|
}
|
|
Err(e) => {
|
|
let sanitized = crate::providers::sanitize_api_error(&e.to_string());
|
|
let err = serde_json::json!({
|
|
"type": "error",
|
|
"message": sanitized,
|
|
});
|
|
let _ = socket.send(Message::Text(err.to_string().into())).await;
|
|
|
|
// Broadcast error event
|
|
let _ = state.event_tx.send(serde_json::json!({
|
|
"type": "error",
|
|
"component": "ws_chat",
|
|
"message": sanitized,
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn extract_ws_bearer_token(headers: &HeaderMap, query_token: Option<&str>) -> Option<String> {
|
|
if let Some(auth_header) = headers
|
|
.get(header::AUTHORIZATION)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(str::trim)
|
|
{
|
|
if let Some(token) = auth_header.strip_prefix("Bearer ") {
|
|
if !token.trim().is_empty() {
|
|
return Some(token.trim().to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(offered) = headers
|
|
.get(header::SEC_WEBSOCKET_PROTOCOL)
|
|
.and_then(|value| value.to_str().ok())
|
|
{
|
|
for protocol in offered.split(',').map(str::trim).filter(|s| !s.is_empty()) {
|
|
if let Some(token) = protocol.strip_prefix("bearer.") {
|
|
if !token.trim().is_empty() {
|
|
return Some(token.trim().to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
query_token
|
|
.map(str::trim)
|
|
.filter(|token| !token.is_empty())
|
|
.map(ToOwned::to_owned)
|
|
}
|
|
|
|
/// Returns the `token` query parameter, if present and non-empty.
///
/// Thin wrapper over `parse_ws_query_params`. Within this file it is only
/// called from the unit tests, so it is compiled for test builds only.
#[cfg(test)]
fn extract_query_token(raw_query: Option<&str>) -> Option<String> {
    parse_ws_query_params(raw_query).token
}

#[cfg(test)]
mod tests {
    //! Unit tests for token extraction, query parsing, history conversion,
    //! auth evaluation, response sanitization and system-prompt construction.

    use super::*;
    use crate::tools::{Tool, ToolResult};
    use async_trait::async_trait;
    use axum::http::HeaderValue;

    #[test]
    fn extract_ws_bearer_token_prefers_authorization_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            HeaderValue::from_static("Bearer from-auth-header"),
        );
        headers.insert(
            header::SEC_WEBSOCKET_PROTOCOL,
            HeaderValue::from_static("zeroclaw.v1, bearer.from-protocol"),
        );

        assert_eq!(
            extract_ws_bearer_token(&headers, None).as_deref(),
            Some("from-auth-header")
        );
    }

    #[test]
    fn extract_ws_bearer_token_reads_websocket_protocol_token() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::SEC_WEBSOCKET_PROTOCOL,
            HeaderValue::from_static("zeroclaw.v1, bearer.protocol-token"),
        );

        assert_eq!(
            extract_ws_bearer_token(&headers, None).as_deref(),
            Some("protocol-token")
        );
    }

    #[test]
    fn extract_ws_bearer_token_rejects_empty_tokens() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            HeaderValue::from_static("Bearer "),
        );
        headers.insert(
            header::SEC_WEBSOCKET_PROTOCOL,
            HeaderValue::from_static("zeroclaw.v1, bearer."),
        );

        assert!(extract_ws_bearer_token(&headers, None).is_none());
    }

    #[test]
    fn extract_ws_bearer_token_reads_query_token_fallback() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_ws_bearer_token(&headers, Some("query-token")).as_deref(),
            Some("query-token")
        );
    }

    #[test]
    fn extract_ws_bearer_token_prefers_protocol_over_query_token() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::SEC_WEBSOCKET_PROTOCOL,
            HeaderValue::from_static("zeroclaw.v1, bearer.protocol-token"),
        );

        assert_eq!(
            extract_ws_bearer_token(&headers, Some("query-token")).as_deref(),
            Some("protocol-token")
        );
    }

    #[test]
    fn extract_query_token_reads_token_param() {
        assert_eq!(
            extract_query_token(Some("foo=1&token=query-token&bar=2")).as_deref(),
            Some("query-token")
        );
        assert!(extract_query_token(Some("foo=1")).is_none());
    }

    #[test]
    fn parse_ws_query_params_reads_token_and_session_id() {
        let parsed = parse_ws_query_params(Some("foo=1&session_id=sess_123&token=query-token"));
        assert_eq!(parsed.token.as_deref(), Some("query-token"));
        assert_eq!(parsed.session_id.as_deref(), Some("sess_123"));
    }

    #[test]
    fn parse_ws_query_params_rejects_invalid_session_id() {
        let parsed = parse_ws_query_params(Some("session_id=../../etc/passwd"));
        assert!(parsed.session_id.is_none());
    }

    #[test]
    fn ws_history_turns_from_chat_skips_system_and_non_dialog_turns() {
        let history = vec![
            ChatMessage::system("sys"),
            ChatMessage::user(" hello "),
            ChatMessage {
                role: "tool".to_string(),
                content: "ignored".to_string(),
            },
            ChatMessage::assistant(" world "),
        ];

        let turns = ws_history_turns_from_chat(&history);
        assert_eq!(
            turns,
            vec![
                WsHistoryTurn {
                    role: "user".to_string(),
                    content: "hello".to_string()
                },
                WsHistoryTurn {
                    role: "assistant".to_string(),
                    content: "world".to_string()
                }
            ]
        );
    }

    #[test]
    fn refresh_ws_history_system_prompt_datetime_updates_only_system_entry() {
        let mut history = vec![
            ChatMessage::system("## Current Date & Time\n\n2000-01-01 00:00:00 (UTC)\n"),
            ChatMessage::user("hello"),
        ];
        refresh_ws_history_system_prompt_datetime(&mut history);
        assert!(!history[0].content.contains("2000-01-01 00:00:00 (UTC)"));
        assert_eq!(history[1].content, "hello");
    }

    #[test]
    fn restore_chat_history_applies_system_prompt_once() {
        let turns = vec![
            WsHistoryTurn {
                role: "user".to_string(),
                content: "u1".to_string(),
            },
            WsHistoryTurn {
                role: "assistant".to_string(),
                content: "a1".to_string(),
            },
        ];

        let restored = restore_chat_history("sys", &turns);
        assert_eq!(restored.len(), 3);
        assert_eq!(restored[0].role, "system");
        assert_eq!(restored[0].content, "sys");
        assert_eq!(restored[1].role, "user");
        assert_eq!(restored[1].content, "u1");
        assert_eq!(restored[2].role, "assistant");
        assert_eq!(restored[2].content, "a1");
    }

    #[test]
    fn evaluate_ws_auth_requires_pairing_token_when_pairing_is_enabled() {
        assert_eq!(
            evaluate_ws_auth(true, true, false),
            Some(WsAuthRejection::MissingPairingToken)
        );
        assert_eq!(evaluate_ws_auth(true, false, true), None);
    }

    #[test]
    fn evaluate_ws_auth_rejects_public_without_auth_layer_when_pairing_disabled() {
        assert_eq!(
            evaluate_ws_auth(false, false, false),
            Some(WsAuthRejection::NonLocalWithoutAuthLayer)
        );
    }

    #[test]
    fn evaluate_ws_auth_allows_loopback_or_valid_token_when_pairing_disabled() {
        assert_eq!(evaluate_ws_auth(false, true, false), None);
        assert_eq!(evaluate_ws_auth(false, false, true), None);
    }

    /// Minimal tool named `schedule`, used to exercise sanitizing and prompt building.
    struct MockScheduleTool;

    #[async_trait]
    impl Tool for MockScheduleTool {
        fn name(&self) -> &str {
            "schedule"
        }

        fn description(&self) -> &str {
            "Mock schedule tool"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "action": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: "ok".to_string(),
                error: None,
            })
        }
    }

    #[test]
    fn sanitize_ws_response_removes_tool_call_tags() {
        let input = r#"Before
<tool_call>
{"name":"schedule","arguments":{"action":"create"}}
</tool_call>
After"#;

        let leak_guard = crate::config::OutboundLeakGuardConfig::default();
        let result = sanitize_ws_response(input, &[], &leak_guard);
        let normalized = result
            .lines()
            .filter(|line| !line.trim().is_empty())
            .collect::<Vec<_>>()
            .join("\n");
        assert_eq!(normalized, "Before\nAfter");
        assert!(!result.contains("<tool_call>"));
        assert!(!result.contains("\"name\":\"schedule\""));
    }

    #[test]
    fn sanitize_ws_response_removes_isolated_tool_json_artifacts() {
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];
        let input = r#"{"name":"schedule","parameters":{"action":"create"}}
{"result":{"status":"scheduled"}}
Reminder set successfully."#;

        let leak_guard = crate::config::OutboundLeakGuardConfig::default();
        let result = sanitize_ws_response(input, &tools, &leak_guard);
        assert_eq!(result, "Reminder set successfully.");
        assert!(!result.contains("\"name\":\"schedule\""));
        assert!(!result.contains("\"result\""));
    }

    #[test]
    fn sanitize_ws_response_blocks_detected_credentials_when_configured() {
        let tools: Vec<Box<dyn Tool>> = Vec::new();
        let leak_guard = crate::config::OutboundLeakGuardConfig {
            enabled: true,
            action: crate::config::OutboundLeakGuardAction::Block,
            sensitivity: 0.7,
        };

        let result =
            sanitize_ws_response("Temporary key: AKIAABCDEFGHIJKLMNOP", &tools, &leak_guard);
        assert!(result.contains("blocked a draft response"));
    }

    #[test]
    fn build_ws_system_prompt_includes_tool_protocol_for_prompt_mode() {
        let config = crate::config::Config::default();
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];

        let prompt = build_ws_system_prompt(&config, "test-model", &tools, false);

        assert!(prompt.contains("## Tool Use Protocol"));
        assert!(prompt.contains("**schedule**"));
        assert!(prompt.contains("## Shell Policy"));
    }

    #[test]
    fn build_ws_system_prompt_omits_xml_protocol_for_native_mode() {
        let config = crate::config::Config::default();
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];

        let prompt = build_ws_system_prompt(&config, "test-model", &tools, true);

        assert!(!prompt.contains("## Tool Use Protocol"));
        assert!(prompt.contains("**schedule**"));
        assert!(prompt.contains("## Shell Policy"));
    }

    #[test]
    fn finalize_ws_response_uses_prompt_mode_tool_output_when_final_text_empty() {
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];
        let history = vec![
            ChatMessage::system("sys"),
            ChatMessage::user(
                "[Tool results]\n<tool_result name=\"schedule\">\nDisk usage: 72%\n</tool_result>",
            ),
        ];

        let leak_guard = crate::config::OutboundLeakGuardConfig::default();
        let result = finalize_ws_response("", &history, &tools, &leak_guard);
        assert!(result.contains("Latest tool output:"));
        assert!(result.contains("Disk usage: 72%"));
        assert!(!result.contains("<tool_result"));
    }

    #[test]
    fn finalize_ws_response_uses_native_tool_message_output_when_final_text_empty() {
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];
        let history = vec![ChatMessage {
            role: "tool".to_string(),
            content: r#"{"tool_call_id":"call_1","content":"Filesystem /dev/disk3s1: 210G free"}"#
                .to_string(),
        }];

        let leak_guard = crate::config::OutboundLeakGuardConfig::default();
        let result = finalize_ws_response("", &history, &tools, &leak_guard);
        assert!(result.contains("Latest tool output:"));
        assert!(result.contains("/dev/disk3s1"));
    }

    #[test]
    fn finalize_ws_response_uses_static_fallback_when_nothing_available() {
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockScheduleTool)];
        let history = vec![ChatMessage::system("sys")];

        let leak_guard = crate::config::OutboundLeakGuardConfig::default();
        let result = finalize_ws_response("", &history, &tools, &leak_guard);
        assert_eq!(result, EMPTY_WS_RESPONSE_FALLBACK);
    }
}