diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index a779f965f..68d3d45bb 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -610,6 +610,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { println!(" 🌐 Web Dashboard: http://{display_addr}/"); println!(" POST /pair — pair a new client (X-Pairing-Code header)"); println!(" POST /webhook — {{\"message\": \"your prompt\"}}"); + println!(" POST /agent — tool-enabled agent chat {{\"message\": \"your prompt\"}}"); if whatsapp_channel.is_some() { println!(" GET /whatsapp — Meta webhook verification"); println!(" POST /whatsapp — WhatsApp message webhook"); @@ -718,6 +719,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .route("/metrics", get(handle_metrics)) .route("/pair", post(handle_pair)) .route("/webhook", post(handle_webhook)) + .route("/agent", post(handle_agent)) .route("/whatsapp", get(handle_whatsapp_verify)) .route("/whatsapp", post(handle_whatsapp_message)) .route("/linq", post(handle_linq_webhook)) @@ -974,6 +976,12 @@ pub struct WebhookBody { pub message: String, } +/// Agent request body +#[derive(serde::Deserialize)] +pub struct AgentBody { + pub message: String, +} + #[derive(Debug, Clone, serde::Deserialize)] pub struct NodeControlRequest { pub method: String, @@ -1157,6 +1165,149 @@ async fn handle_node_control( } } +/// POST /agent — authenticated single-turn agent endpoint with tool execution. +/// +/// This compatibility route mirrors CLI-style agent behavior for callers that +/// expect a JSON POST API rather than WebSocket chat. +async fn handle_agent( + State(state): State, + ConnectInfo(peer_addr): ConnectInfo, + headers: HeaderMap, + body: Result, axum::extract::rejection::JsonRejection>, +) -> impl IntoResponse { + let rate_key = + client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); + if !state.rate_limiter.allow_webhook(&rate_key) { + tracing::warn!("/agent rate limit exceeded"); + let err = serde_json::json!({ + "error": "Too many agent requests. Please retry later.", + "retry_after": RATE_LIMIT_WINDOW_SECS, + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)); + } + + if state.pairing.require_pairing() { + let auth = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + let token = auth.strip_prefix("Bearer ").unwrap_or(""); + if !state.pairing.is_authenticated(token) { + let err = serde_json::json!({ + "error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer " + }); + return (StatusCode::UNAUTHORIZED, Json(err)); + } + } + + let Json(agent_body) = match body { + Ok(b) => b, + Err(e) => { + tracing::warn!("/agent JSON parse error: {e}"); + let err = serde_json::json!({ + "error": "Invalid JSON body. Expected: {\"message\": \"...\"}" + }); + return (StatusCode::BAD_REQUEST, Json(err)); + } + }; + + let message = agent_body.message.trim(); + if message.is_empty() { + let err = serde_json::json!({ + "error": "message must not be empty" + }); + return (StatusCode::BAD_REQUEST, Json(err)); + } + + if state.auto_save { + let key = webhook_memory_key(); + let _ = state + .mem + .store(&key, message, MemoryCategory::Conversation, None) + .await; + } + + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); + let model_label = state.model.clone(); + let started_at = Instant::now(); + + state + .observer + .record_event(&crate::observability::ObserverEvent::AgentStart { + provider: provider_label.clone(), + model: model_label.clone(), + }); + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmRequest { + provider: provider_label.clone(), + model: model_label.clone(), + messages_count: 1, + }); + + let response = match run_gateway_chat_with_tools(&state, message).await { + Ok(response) => { + let safe = sanitize_gateway_response(&response, state.tools_registry_exec.as_ref()); + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmResponse { + provider: provider_label.clone(), + model: model_label.clone(), + duration: started_at.elapsed(), + success: true, + error_message: None, + input_tokens: None, + output_tokens: None, + }); + state + .observer + .record_event(&crate::observability::ObserverEvent::TurnComplete); + safe + } + Err(e) => { + let sanitized = crate::providers::sanitize_api_error(&e.to_string()); + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmResponse { + provider: provider_label.clone(), + model: model_label.clone(), + duration: started_at.elapsed(), + success: false, + error_message: Some(sanitized.clone()), + input_tokens: None, + output_tokens: None, + }); + + let err = serde_json::json!({ + "error": format!("Provider error: {sanitized}") + }); + return (StatusCode::BAD_GATEWAY, Json(err)); + } + }; + + state + .observer + .record_event(&crate::observability::ObserverEvent::AgentEnd { + provider: provider_label, + model: model_label, + duration: started_at.elapsed(), + tokens_used: None, + cost_usd: None, + }); + + ( + StatusCode::OK, + Json(serde_json::json!({ + "response": response + })), + ) +} + /// POST /webhook — main webhook endpoint async fn handle_webhook( State(state): State, @@ -1975,6 +2126,18 @@ mod tests { assert!(parsed.is_err()); } + #[test] + fn agent_body_requires_message_field() { + let valid = r#"{"message": "hello"}"#; + let parsed: Result = serde_json::from_str(valid); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap().message, "hello"); + + let missing = r#"{"other": "field"}"#; + let parsed: Result = serde_json::from_str(missing); + assert!(parsed.is_err()); + } + #[test] fn whatsapp_query_fields_are_optional() { let q = WhatsAppVerifyQuery { @@ -2676,6 +2839,56 @@ Reminder set successfully."#; assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); } + #[tokio::test] + async fn agent_endpoint_requires_bearer_token_when_pairing_enabled() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl; + let memory: Arc = Arc::new(MockMemory); + let paired_token = "zc_test_token".to_string(); + + let state = AppState { + config: Arc::new(Mutex::new(Config::default())), + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: None, + pairing: Arc::new(PairingGuard::new(true, std::slice::from_ref(&paired_token))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), + whatsapp: None, + whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, + nextcloud_talk: None, + nextcloud_talk_webhook_secret: None, + wati: None, + qq: None, + qq_webhook_enabled: false, + observer: Arc::new(crate::observability::NoopObserver), + tools_registry: Arc::new(Vec::new()), + tools_registry_exec: Arc::new(Vec::new()), + multimodal: crate::config::MultimodalConfig::default(), + max_tool_iterations: 10, + cost_tracker: None, + event_tx: tokio::sync::broadcast::channel(16).0, + }; + + let unauthorized = handle_agent( + State(state), + test_connect_info(), + HeaderMap::new(), + Ok(Json(AgentBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED); + } + #[tokio::test] async fn webhook_rejects_public_traffic_without_auth_layers() { let provider_impl = Arc::new(MockProvider::default()); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 8f343ab82..5a789fbe7 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -24,6 +24,7 @@ use axum::{ const EMPTY_WS_RESPONSE_FALLBACK: &str = "Tool execution completed, but the model returned no final text response. Please ask me to summarize the result."; +const WS_CHAT_SUBPROTOCOL: &str = "zeroclaw.v1"; fn sanitize_ws_response(response: &str, tools: &[Box]) -> String { let sanitized = crate::channels::sanitize_channel_response(response, tools); @@ -123,13 +124,14 @@ pub async fn handle_ws_chat( if !state.pairing.is_authenticated(&token) { return ( axum::http::StatusCode::UNAUTHORIZED, - "Unauthorized — provide Authorization: Bearer or Sec-WebSocket-Protocol: bearer.", + "Unauthorized — provide Authorization: Bearer or Sec-WebSocket-Protocol: zeroclaw.v1, bearer.", ) .into_response(); } } - ws.on_upgrade(move |socket| handle_socket(socket, state)) + ws.protocols([WS_CHAT_SUBPROTOCOL]) + .on_upgrade(move |socket| handle_socket(socket, state)) .into_response() } @@ -331,6 +333,17 @@ mod tests { ); } + #[test] + fn extract_ws_bearer_token_ignores_protocol_without_bearer_value() { + let mut headers = HeaderMap::new(); + headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("zeroclaw.v1"), + ); + + assert!(extract_ws_bearer_token(&headers).is_none()); + } + #[test] fn extract_ws_bearer_token_rejects_empty_tokens() { let mut headers = HeaderMap::new();