fix(gateway): add ws subprotocol negotiation and tool-enabled /agent endpoint

This commit is contained in:
argenis de la rosa 2026-03-04 06:19:18 -05:00
parent e2d65aef2a
commit 7d293a0069
2 changed files with 228 additions and 2 deletions

View File

@ -610,6 +610,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
println!(" 🌐 Web Dashboard: http://{display_addr}/");
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
println!(" POST /agent — tool-enabled agent chat {{\"message\": \"your prompt\"}}");
if whatsapp_channel.is_some() {
println!(" GET /whatsapp — Meta webhook verification");
println!(" POST /whatsapp — WhatsApp message webhook");
@ -718,6 +719,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.route("/metrics", get(handle_metrics))
.route("/pair", post(handle_pair))
.route("/webhook", post(handle_webhook))
.route("/agent", post(handle_agent))
.route("/whatsapp", get(handle_whatsapp_verify))
.route("/whatsapp", post(handle_whatsapp_message))
.route("/linq", post(handle_linq_webhook))
@ -974,6 +976,12 @@ pub struct WebhookBody {
pub message: String,
}
/// Agent request body
#[derive(serde::Deserialize)]
pub struct AgentBody {
pub message: String,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NodeControlRequest {
pub method: String,
@ -1157,6 +1165,149 @@ async fn handle_node_control(
}
}
/// POST /agent — authenticated single-turn agent endpoint with tool execution.
///
/// This compatibility route mirrors CLI-style agent behavior for callers that
/// expect a JSON POST API rather than WebSocket chat.
async fn handle_agent(
State(state): State<AppState>,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
body: Result<Json<AgentBody>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse {
let rate_key =
client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
if !state.rate_limiter.allow_webhook(&rate_key) {
tracing::warn!("/agent rate limit exceeded");
let err = serde_json::json!({
"error": "Too many agent requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
if state.pairing.require_pairing() {
let auth = headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let token = auth.strip_prefix("Bearer ").unwrap_or("");
if !state.pairing.is_authenticated(token) {
let err = serde_json::json!({
"error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
});
return (StatusCode::UNAUTHORIZED, Json(err));
}
}
let Json(agent_body) = match body {
Ok(b) => b,
Err(e) => {
tracing::warn!("/agent JSON parse error: {e}");
let err = serde_json::json!({
"error": "Invalid JSON body. Expected: {\"message\": \"...\"}"
});
return (StatusCode::BAD_REQUEST, Json(err));
}
};
let message = agent_body.message.trim();
if message.is_empty() {
let err = serde_json::json!({
"error": "message must not be empty"
});
return (StatusCode::BAD_REQUEST, Json(err));
}
if state.auto_save {
let key = webhook_memory_key();
let _ = state
.mem
.store(&key, message, MemoryCategory::Conversation, None)
.await;
}
let provider_label = state
.config
.lock()
.default_provider
.clone()
.unwrap_or_else(|| "unknown".to_string());
let model_label = state.model.clone();
let started_at = Instant::now();
state
.observer
.record_event(&crate::observability::ObserverEvent::AgentStart {
provider: provider_label.clone(),
model: model_label.clone(),
});
state
.observer
.record_event(&crate::observability::ObserverEvent::LlmRequest {
provider: provider_label.clone(),
model: model_label.clone(),
messages_count: 1,
});
let response = match run_gateway_chat_with_tools(&state, message).await {
Ok(response) => {
let safe = sanitize_gateway_response(&response, state.tools_registry_exec.as_ref());
state
.observer
.record_event(&crate::observability::ObserverEvent::LlmResponse {
provider: provider_label.clone(),
model: model_label.clone(),
duration: started_at.elapsed(),
success: true,
error_message: None,
input_tokens: None,
output_tokens: None,
});
state
.observer
.record_event(&crate::observability::ObserverEvent::TurnComplete);
safe
}
Err(e) => {
let sanitized = crate::providers::sanitize_api_error(&e.to_string());
state
.observer
.record_event(&crate::observability::ObserverEvent::LlmResponse {
provider: provider_label.clone(),
model: model_label.clone(),
duration: started_at.elapsed(),
success: false,
error_message: Some(sanitized.clone()),
input_tokens: None,
output_tokens: None,
});
let err = serde_json::json!({
"error": format!("Provider error: {sanitized}")
});
return (StatusCode::BAD_GATEWAY, Json(err));
}
};
state
.observer
.record_event(&crate::observability::ObserverEvent::AgentEnd {
provider: provider_label,
model: model_label,
duration: started_at.elapsed(),
tokens_used: None,
cost_usd: None,
});
(
StatusCode::OK,
Json(serde_json::json!({
"response": response
})),
)
}
/// POST /webhook — main webhook endpoint
async fn handle_webhook(
State(state): State<AppState>,
@ -1975,6 +2126,18 @@ mod tests {
assert!(parsed.is_err());
}
#[test]
fn agent_body_requires_message_field() {
let valid = r#"{"message": "hello"}"#;
let parsed: Result<AgentBody, _> = serde_json::from_str(valid);
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap().message, "hello");
let missing = r#"{"other": "field"}"#;
let parsed: Result<AgentBody, _> = serde_json::from_str(missing);
assert!(parsed.is_err());
}
#[test]
fn whatsapp_query_fields_are_optional() {
let q = WhatsAppVerifyQuery {
@ -2676,6 +2839,56 @@ Reminder set successfully."#;
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn agent_endpoint_requires_bearer_token_when_pairing_enabled() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl;
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let paired_token = "zc_test_token".to_string();
let state = AppState {
config: Arc::new(Mutex::new(Config::default())),
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(true, std::slice::from_ref(&paired_token))),
trust_forwarded_headers: false,
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
whatsapp: None,
whatsapp_app_secret: None,
linq: None,
linq_signing_secret: None,
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
qq: None,
qq_webhook_enabled: false,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
tools_registry_exec: Arc::new(Vec::new()),
multimodal: crate::config::MultimodalConfig::default(),
max_tool_iterations: 10,
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
};
let unauthorized = handle_agent(
State(state),
test_connect_info(),
HeaderMap::new(),
Ok(Json(AgentBody {
message: "hello".into(),
})),
)
.await
.into_response();
assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn webhook_rejects_public_traffic_without_auth_layers() {
let provider_impl = Arc::new(MockProvider::default());

View File

@ -24,6 +24,7 @@ use axum::{
const EMPTY_WS_RESPONSE_FALLBACK: &str =
"Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";
const WS_CHAT_SUBPROTOCOL: &str = "zeroclaw.v1";
fn sanitize_ws_response(response: &str, tools: &[Box<dyn crate::tools::Tool>]) -> String {
let sanitized = crate::channels::sanitize_channel_response(response, tools);
@ -123,13 +124,14 @@ pub async fn handle_ws_chat(
if !state.pairing.is_authenticated(&token) {
return (
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized — provide Authorization: Bearer <token> or Sec-WebSocket-Protocol: bearer.<token>",
"Unauthorized — provide Authorization: Bearer <token> or Sec-WebSocket-Protocol: zeroclaw.v1, bearer.<token>",
)
.into_response();
}
}
ws.on_upgrade(move |socket| handle_socket(socket, state))
ws.protocols([WS_CHAT_SUBPROTOCOL])
.on_upgrade(move |socket| handle_socket(socket, state))
.into_response()
}
@ -331,6 +333,17 @@ mod tests {
);
}
#[test]
fn extract_ws_bearer_token_ignores_protocol_without_bearer_value() {
let mut headers = HeaderMap::new();
headers.insert(
header::SEC_WEBSOCKET_PROTOCOL,
HeaderValue::from_static("zeroclaw.v1"),
);
assert!(extract_ws_bearer_token(&headers).is_none());
}
#[test]
fn extract_ws_bearer_token_rejects_empty_tokens() {
let mut headers = HeaderMap::new();