diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 03849d66c..670ade317 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -15,7 +15,7 @@ use axum::{ ws::{Message, WebSocket}, Query, State, WebSocketUpgrade, }, - http::HeaderMap, + http::{header, HeaderMap}, response::IntoResponse, }; use futures_util::{SinkExt, StreamExt}; @@ -24,12 +24,62 @@ use serde::Deserialize; /// The sub-protocol we support for the chat WebSocket. const WS_PROTOCOL: &str = "zeroclaw.v1"; +/// Prefix used in `Sec-WebSocket-Protocol` to carry a bearer token. +const BEARER_SUBPROTO_PREFIX: &str = "bearer."; + #[derive(Deserialize)] pub struct WsQuery { pub token: Option, pub session_id: Option, } +/// Extract a bearer token from WebSocket-compatible sources. +/// +/// Precedence (first non-empty wins): +/// 1. `Authorization: Bearer ` header +/// 2. `Sec-WebSocket-Protocol: bearer.` subprotocol +/// 3. `?token=` query parameter +/// +/// Browsers cannot set custom headers on `new WebSocket(url)`, so the query +/// parameter and subprotocol paths are required for browser-based clients. +fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> { + // 1. Authorization header + if let Some(t) = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|auth| auth.strip_prefix("Bearer ")) + { + if !t.is_empty() { + return Some(t); + } + } + + // 2. Sec-WebSocket-Protocol: bearer. + if let Some(t) = headers + .get("sec-websocket-protocol") + .and_then(|v| v.to_str().ok()) + .and_then(|protos| { + protos + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX)) + }) + { + if !t.is_empty() { + return Some(t); + } + } + + // 3. ?token= query parameter + if let Some(t) = query_token { + if !t.is_empty() { + return Some(t); + } + } + + None +} + /// GET /ws/chat — WebSocket upgrade for agent chat pub async fn handle_ws_chat( State(state): State, @@ -37,13 +87,13 @@ pub async fn handle_ws_chat( headers: HeaderMap, ws: WebSocketUpgrade, ) -> impl IntoResponse { - // Auth via query param (browser WebSocket limitation) + // Auth: check header, subprotocol, then query param (precedence order) if state.pairing.require_pairing() { - let token = params.token.as_deref().unwrap_or(""); + let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or(""); if !state.pairing.is_authenticated(token) { return ( axum::http::StatusCode::UNAUTHORIZED, - "Unauthorized — provide ?token=", + "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param", ) .into_response(); } @@ -183,3 +233,85 @@ async fn handle_socket(socket: WebSocket, state: AppState, _session_id: Option