fix(gateway): add ws subprotocol negotiation and tool-enabled /agent endpoint
This commit is contained in:
parent
e2d65aef2a
commit
7d293a0069
@ -610,6 +610,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}/");
|
||||
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" POST /agent — tool-enabled agent chat {{\"message\": \"your prompt\"}}");
|
||||
if whatsapp_channel.is_some() {
|
||||
println!(" GET /whatsapp — Meta webhook verification");
|
||||
println!(" POST /whatsapp — WhatsApp message webhook");
|
||||
@ -718,6 +719,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/metrics", get(handle_metrics))
|
||||
.route("/pair", post(handle_pair))
|
||||
.route("/webhook", post(handle_webhook))
|
||||
.route("/agent", post(handle_agent))
|
||||
.route("/whatsapp", get(handle_whatsapp_verify))
|
||||
.route("/whatsapp", post(handle_whatsapp_message))
|
||||
.route("/linq", post(handle_linq_webhook))
|
||||
@ -974,6 +976,12 @@ pub struct WebhookBody {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Agent request body
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct AgentBody {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct NodeControlRequest {
|
||||
pub method: String,
|
||||
@ -1157,6 +1165,149 @@ async fn handle_node_control(
|
||||
}
|
||||
}
|
||||
|
||||
/// POST /agent — authenticated single-turn agent endpoint with tool execution.
|
||||
///
|
||||
/// This compatibility route mirrors CLI-style agent behavior for callers that
|
||||
/// expect a JSON POST API rather than WebSocket chat.
|
||||
async fn handle_agent(
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
body: Result<Json<AgentBody>, axum::extract::rejection::JsonRejection>,
|
||||
) -> impl IntoResponse {
|
||||
let rate_key =
|
||||
client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
|
||||
if !state.rate_limiter.allow_webhook(&rate_key) {
|
||||
tracing::warn!("/agent rate limit exceeded");
|
||||
let err = serde_json::json!({
|
||||
"error": "Too many agent requests. Please retry later.",
|
||||
"retry_after": RATE_LIMIT_WINDOW_SECS,
|
||||
});
|
||||
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
|
||||
}
|
||||
|
||||
if state.pairing.require_pairing() {
|
||||
let auth = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let token = auth.strip_prefix("Bearer ").unwrap_or("");
|
||||
if !state.pairing.is_authenticated(token) {
|
||||
let err = serde_json::json!({
|
||||
"error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
|
||||
});
|
||||
return (StatusCode::UNAUTHORIZED, Json(err));
|
||||
}
|
||||
}
|
||||
|
||||
let Json(agent_body) = match body {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("/agent JSON parse error: {e}");
|
||||
let err = serde_json::json!({
|
||||
"error": "Invalid JSON body. Expected: {\"message\": \"...\"}"
|
||||
});
|
||||
return (StatusCode::BAD_REQUEST, Json(err));
|
||||
}
|
||||
};
|
||||
|
||||
let message = agent_body.message.trim();
|
||||
if message.is_empty() {
|
||||
let err = serde_json::json!({
|
||||
"error": "message must not be empty"
|
||||
});
|
||||
return (StatusCode::BAD_REQUEST, Json(err));
|
||||
}
|
||||
|
||||
if state.auto_save {
|
||||
let key = webhook_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
let provider_label = state
|
||||
.config
|
||||
.lock()
|
||||
.default_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let model_label = state.model.clone();
|
||||
let started_at = Instant::now();
|
||||
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::AgentStart {
|
||||
provider: provider_label.clone(),
|
||||
model: model_label.clone(),
|
||||
});
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::LlmRequest {
|
||||
provider: provider_label.clone(),
|
||||
model: model_label.clone(),
|
||||
messages_count: 1,
|
||||
});
|
||||
|
||||
let response = match run_gateway_chat_with_tools(&state, message).await {
|
||||
Ok(response) => {
|
||||
let safe = sanitize_gateway_response(&response, state.tools_registry_exec.as_ref());
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::LlmResponse {
|
||||
provider: provider_label.clone(),
|
||||
model: model_label.clone(),
|
||||
duration: started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
});
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::TurnComplete);
|
||||
safe
|
||||
}
|
||||
Err(e) => {
|
||||
let sanitized = crate::providers::sanitize_api_error(&e.to_string());
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::LlmResponse {
|
||||
provider: provider_label.clone(),
|
||||
model: model_label.clone(),
|
||||
duration: started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(sanitized.clone()),
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
});
|
||||
|
||||
let err = serde_json::json!({
|
||||
"error": format!("Provider error: {sanitized}")
|
||||
});
|
||||
return (StatusCode::BAD_GATEWAY, Json(err));
|
||||
}
|
||||
};
|
||||
|
||||
state
|
||||
.observer
|
||||
.record_event(&crate::observability::ObserverEvent::AgentEnd {
|
||||
provider: provider_label,
|
||||
model: model_label,
|
||||
duration: started_at.elapsed(),
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"response": response
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
/// POST /webhook — main webhook endpoint
|
||||
async fn handle_webhook(
|
||||
State(state): State<AppState>,
|
||||
@ -1975,6 +2126,18 @@ mod tests {
|
||||
assert!(parsed.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_body_requires_message_field() {
|
||||
let valid = r#"{"message": "hello"}"#;
|
||||
let parsed: Result<AgentBody, _> = serde_json::from_str(valid);
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap().message, "hello");
|
||||
|
||||
let missing = r#"{"other": "field"}"#;
|
||||
let parsed: Result<AgentBody, _> = serde_json::from_str(missing);
|
||||
assert!(parsed.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn whatsapp_query_fields_are_optional() {
|
||||
let q = WhatsAppVerifyQuery {
|
||||
@ -2676,6 +2839,56 @@ Reminder set successfully."#;
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn agent_endpoint_requires_bearer_token_when_pairing_enabled() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl;
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
let paired_token = "zc_test_token".to_string();
|
||||
|
||||
let state = AppState {
|
||||
config: Arc::new(Mutex::new(Config::default())),
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: None,
|
||||
pairing: Arc::new(PairingGuard::new(true, std::slice::from_ref(&paired_token))),
|
||||
trust_forwarded_headers: false,
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
linq: None,
|
||||
linq_signing_secret: None,
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
qq: None,
|
||||
qq_webhook_enabled: false,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
tools_registry_exec: Arc::new(Vec::new()),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
max_tool_iterations: 10,
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
};
|
||||
|
||||
let unauthorized = handle_agent(
|
||||
State(state),
|
||||
test_connect_info(),
|
||||
HeaderMap::new(),
|
||||
Ok(Json(AgentBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_rejects_public_traffic_without_auth_layers() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
|
||||
@ -24,6 +24,7 @@ use axum::{
|
||||
|
||||
const EMPTY_WS_RESPONSE_FALLBACK: &str =
|
||||
"Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";
|
||||
const WS_CHAT_SUBPROTOCOL: &str = "zeroclaw.v1";
|
||||
|
||||
fn sanitize_ws_response(response: &str, tools: &[Box<dyn crate::tools::Tool>]) -> String {
|
||||
let sanitized = crate::channels::sanitize_channel_response(response, tools);
|
||||
@ -123,13 +124,14 @@ pub async fn handle_ws_chat(
|
||||
if !state.pairing.is_authenticated(&token) {
|
||||
return (
|
||||
axum::http::StatusCode::UNAUTHORIZED,
|
||||
"Unauthorized — provide Authorization: Bearer <token> or Sec-WebSocket-Protocol: bearer.<token>",
|
||||
"Unauthorized — provide Authorization: Bearer <token> or Sec-WebSocket-Protocol: zeroclaw.v1, bearer.<token>",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state))
|
||||
ws.protocols([WS_CHAT_SUBPROTOCOL])
|
||||
.on_upgrade(move |socket| handle_socket(socket, state))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
@ -331,6 +333,17 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_bearer_token_ignores_protocol_without_bearer_value() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
header::SEC_WEBSOCKET_PROTOCOL,
|
||||
HeaderValue::from_static("zeroclaw.v1"),
|
||||
);
|
||||
|
||||
assert!(extract_ws_bearer_token(&headers).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_bearer_token_rejects_empty_tokens() {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user