diff --git a/docs/config-reference.md b/docs/config-reference.md index 4fa30452c..ccc8c33c1 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -382,6 +382,7 @@ WASM profile templates: | Key | Default | Purpose | |---|---|---| | `reasoning_level` | unset (`None`) | Reasoning effort/level override for providers that support explicit levels (currently OpenAI Codex `/responses`) | +| `transport` | unset (`None`) | Provider transport override (`auto`, `websocket`, `sse`) | Notes: @@ -389,6 +390,13 @@ Notes: - When set, overrides `ZEROCLAW_CODEX_REASONING_EFFORT` for OpenAI Codex requests. - Unset falls back to `ZEROCLAW_CODEX_REASONING_EFFORT` if present, otherwise defaults to `xhigh`. - If both `provider.reasoning_level` and deprecated `runtime.reasoning_level` are set, provider-level value wins. +- `provider.transport` is normalized case-insensitively (`ws` aliases to `websocket`; `http` aliases to `sse`). +- For OpenAI Codex, default transport mode is `auto` (WebSocket-first with SSE fallback). +- Transport override precedence for OpenAI Codex: + 1. `[[model_routes]].transport` (route-specific) + 2. `provider.transport` + 3. `ZEROCLAW_CODEX_TRANSPORT` / `ZEROCLAW_PROVIDER_TRANSPORT` + 4. legacy `ZEROCLAW_RESPONSES_WEBSOCKET` (boolean) ## `[skills]` @@ -668,6 +676,7 @@ Use route hints so integrations can keep stable names while model IDs evolve. 
| `model` | _required_ | Model to use with that provider | | `max_tokens` | unset | Optional per-route output token cap forwarded to provider APIs | | `api_key` | unset | Optional API key override for this route's provider | +| `transport` | unset | Optional per-route transport override (`auto`, `websocket`, `sse`) | ### `[[embedding_routes]]` diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 038193503..c1da7d6fd 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1704,6 +1704,7 @@ pub async fn run( let provider_runtime_options = providers::ProviderRuntimeOptions { auth_profile_override: None, provider_api_url: config.api_url.clone(), + provider_transport: config.effective_provider_transport(), zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, reasoning_enabled: config.runtime.reasoning_enabled, @@ -2184,6 +2185,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { let provider_runtime_options = providers::ProviderRuntimeOptions { auth_profile_override: None, provider_api_url: config.api_url.clone(), + provider_transport: config.effective_provider_transport(), zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, reasoning_enabled: config.runtime.reasoning_enabled, diff --git a/src/agent/prompt.rs b/src/agent/prompt.rs index 6d63489a2..612a5c958 100644 --- a/src/agent/prompt.rs +++ b/src/agent/prompt.rs @@ -115,7 +115,9 @@ impl PromptSection for IdentitySection { inject_workspace_file(&mut prompt, ctx.workspace_dir, "MEMORY.md"); } - let extra_files = ctx.identity_config.map_or(&[][..], |cfg| cfg.extra_files.as_slice()); + let extra_files = ctx + .identity_config + .map_or(&[][..], |cfg| cfg.extra_files.as_slice()); for file in extra_files { if let Some(safe_relative) = normalize_openclaw_identity_extra_file(file) { inject_workspace_file(&mut prompt, ctx.workspace_dir, safe_relative); diff --git 
a/src/channels/mod.rs b/src/channels/mod.rs index 34254d17d..decf6f017 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -4805,6 +4805,7 @@ pub async fn start_channels(config: Config) -> Result<()> { let provider_runtime_options = providers::ProviderRuntimeOptions { auth_profile_override: None, provider_api_url: config.api_url.clone(), + provider_transport: config.effective_provider_transport(), zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, reasoning_enabled: config.runtime.reasoning_enabled, diff --git a/src/config/mod.rs b/src/config/mod.rs index ac6abaa2b..cb5acb468 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,15 +11,15 @@ pub use schema::{ DockerRuntimeConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, - MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig, - OtpChallengeDelivery, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, - NonCliNaturalLanguageApprovalMode, PerplexityFilterConfig, PluginEntryConfig, PluginsConfig, - ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, - ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger, ResourceLimitsConfig, RuntimeConfig, - SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, - StorageProviderConfig, StorageProviderSection, StreamMode, SyscallAnomalyConfig, - TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig, + MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, + NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig, + OtpMethod, PeripheralBoardConfig, PeripheralsConfig, 
PerplexityFilterConfig, PluginEntryConfig, + PluginsConfig, ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig, + QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger, + ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, + SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode, + SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, + SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig, WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy, WasmRuntimeConfig, WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index fb7315409..162f3f5b9 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -309,6 +309,10 @@ pub struct ProviderConfig { /// (e.g. OpenAI Codex `/responses` reasoning effort). #[serde(default)] pub reasoning_level: Option, + /// Optional transport override for providers that support multiple transports. + /// Supported values: "auto", "websocket", "sse". + #[serde(default)] + pub transport: Option, } // ── Delegate Agents ────────────────────────────────────────────── @@ -3195,6 +3199,10 @@ pub struct ModelRouteConfig { /// Optional API key override for this route's provider #[serde(default)] pub api_key: Option, + /// Optional provider transport override for this route. + /// Supported values: "auto", "websocket", "sse". 
+ #[serde(default)] + pub transport: Option, } // ── Embedding routing ─────────────────────────────────────────── @@ -6041,6 +6049,28 @@ impl Config { } } + fn normalize_provider_transport(raw: Option<&str>, source: &str) -> Option { + let value = raw?.trim(); + if value.is_empty() { + return None; + } + + let normalized = value.to_ascii_lowercase().replace(['-', '_'], ""); + match normalized.as_str() { + "auto" => Some("auto".to_string()), + "websocket" | "ws" => Some("websocket".to_string()), + "sse" | "http" => Some("sse".to_string()), + _ => { + tracing::warn!( + transport = %value, + source, + "Ignoring invalid provider transport override" + ); + None + } + } + } + /// Resolve provider reasoning level with backward-compatible runtime alias. /// /// Priority: @@ -6084,6 +6114,16 @@ impl Config { } } + /// Resolve provider transport mode (`provider.transport`). + /// + /// Supported values: + /// - `auto` + /// - `websocket` + /// - `sse` + pub fn effective_provider_transport(&self) -> Option { + Self::normalize_provider_transport(self.provider.transport.as_deref(), "provider.transport") + } + fn lookup_model_provider_profile( &self, provider_name: &str, @@ -6447,6 +6487,32 @@ impl Config { if route.max_tokens == Some(0) { anyhow::bail!("model_routes[{i}].max_tokens must be greater than 0"); } + if route + .transport + .as_deref() + .is_some_and(|value| !value.trim().is_empty()) + && Self::normalize_provider_transport( + route.transport.as_deref(), + "model_routes[].transport", + ) + .is_none() + { + anyhow::bail!("model_routes[{i}].transport must be one of: auto, websocket, sse"); + } + } + + if self + .provider + .transport + .as_deref() + .is_some_and(|value| !value.trim().is_empty()) + && Self::normalize_provider_transport( + self.provider.transport.as_deref(), + "provider.transport", + ) + .is_none() + { + anyhow::bail!("provider.transport must be one of: auto, websocket, sse"); } if self.provider_api.is_some() @@ -6778,6 +6844,17 @@ impl Config { } } + // 
Provider transport override: ZEROCLAW_PROVIDER_TRANSPORT or PROVIDER_TRANSPORT + if let Ok(transport) = std::env::var("ZEROCLAW_PROVIDER_TRANSPORT") + .or_else(|_| std::env::var("PROVIDER_TRANSPORT")) + { + if let Some(normalized) = + Self::normalize_provider_transport(Some(&transport), "env:provider_transport") + { + self.provider.transport = Some(normalized); + } + } + // Vision support override: ZEROCLAW_MODEL_SUPPORT_VISION or MODEL_SUPPORT_VISION if let Ok(flag) = std::env::var("ZEROCLAW_MODEL_SUPPORT_VISION") .or_else(|_| std::env::var("MODEL_SUPPORT_VISION")) @@ -9376,6 +9453,7 @@ provider_api = "not-a-real-mode" model: "anthropic/claude-sonnet-4.6".to_string(), max_tokens: Some(0), api_key: None, + transport: None, }]; let err = config @@ -9386,6 +9464,48 @@ provider_api = "not-a-real-mode" .contains("model_routes[0].max_tokens must be greater than 0")); } + #[test] + async fn provider_transport_normalizes_aliases() { + let mut config = Config::default(); + config.provider.transport = Some("WS".to_string()); + assert_eq!( + config.effective_provider_transport().as_deref(), + Some("websocket") + ); + } + + #[test] + async fn provider_transport_invalid_is_rejected() { + let mut config = Config::default(); + config.provider.transport = Some("udp".to_string()); + let err = config + .validate() + .expect_err("provider.transport should reject invalid values"); + assert!(err + .to_string() + .contains("provider.transport must be one of: auto, websocket, sse")); + } + + #[test] + async fn model_route_transport_invalid_is_rejected() { + let mut config = Config::default(); + config.model_routes = vec![ModelRouteConfig { + hint: "reasoning".to_string(), + provider: "openrouter".to_string(), + model: "anthropic/claude-sonnet-4.6".to_string(), + max_tokens: None, + api_key: None, + transport: Some("udp".to_string()), + }]; + + let err = config + .validate() + .expect_err("model_routes[].transport should reject invalid values"); + assert!(err + .to_string() + 
.contains("model_routes[0].transport must be one of: auto, websocket, sse")); + } + #[test] async fn env_override_glm_api_key_for_regional_aliases() { let _env_guard = env_override_lock().await; diff --git a/src/doctor/mod.rs b/src/doctor/mod.rs index edb29a2a3..bd80c54af 100644 --- a/src/doctor/mod.rs +++ b/src/doctor/mod.rs @@ -1167,6 +1167,7 @@ mod tests { model: String::new(), max_tokens: None, api_key: None, + transport: None, }]; let mut items = Vec::new(); check_config_semantics(&config, &mut items); diff --git a/src/gateway/api.rs b/src/gateway/api.rs index a936264c0..20adaf035 100644 --- a/src/gateway/api.rs +++ b/src/gateway/api.rs @@ -706,7 +706,10 @@ fn restore_masked_sensitive_fields( restore_optional_secret(&mut incoming.proxy.http_proxy, &current.proxy.http_proxy); restore_optional_secret(&mut incoming.proxy.https_proxy, &current.proxy.https_proxy); restore_optional_secret(&mut incoming.proxy.all_proxy, &current.proxy.all_proxy); - restore_optional_secret(&mut incoming.transcription.api_key, &current.transcription.api_key); + restore_optional_secret( + &mut incoming.transcription.api_key, + &current.transcription.api_key, + ); restore_optional_secret( &mut incoming.browser.computer_use.api_key, &current.browser.computer_use.api_key, @@ -932,7 +935,10 @@ mod tests { assert_eq!(hydrated.config_path, current.config_path); assert_eq!(hydrated.workspace_dir, current.workspace_dir); assert_eq!(hydrated.api_key, current.api_key); - assert_eq!(hydrated.transcription.api_key, current.transcription.api_key); + assert_eq!( + hydrated.transcription.api_key, + current.transcription.api_key + ); assert_eq!(hydrated.default_model.as_deref(), Some("gpt-4.1-mini")); assert_eq!( hydrated.reliability.api_keys, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e11659729..4df567bfb 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -362,6 +362,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &providers::ProviderRuntimeOptions { auth_profile_override: 
None, provider_api_url: config.api_url.clone(), + provider_transport: config.effective_provider_transport(), zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, reasoning_enabled: config.runtime.reasoning_enabled, diff --git a/src/gateway/openclaw_compat.rs b/src/gateway/openclaw_compat.rs index 9be848da7..521222c75 100644 --- a/src/gateway/openclaw_compat.rs +++ b/src/gateway/openclaw_compat.rs @@ -95,9 +95,7 @@ pub async fn handle_api_chat( && state.webhook_secret_hash.is_none() && !peer_addr.ip().is_loopback() { - tracing::warn!( - "/api/chat: rejected unauthenticated non-loopback request" - ); + tracing::warn!("/api/chat: rejected unauthenticated non-loopback request"); let err = serde_json::json!({ "error": "Unauthorized — configure pairing or X-Webhook-Secret for non-local access" }); @@ -152,7 +150,11 @@ pub async fn handle_api_chat( message.to_string() } else { let recent: Vec<&String> = chat_body.context.iter().rev().take(10).rev().collect(); - let context_block = recent.iter().map(|s| s.as_str()).collect::>().join("\n"); + let context_block = recent + .iter() + .map(|s| s.as_str()) + .collect::>() + .join("\n"); format!( "Recent conversation context:\n{}\n\nCurrent message:\n{}", context_block, message @@ -395,7 +397,9 @@ pub async fn handle_v1_chat_completions_with_tools( .unwrap_or(""); let token = auth.strip_prefix("Bearer ").unwrap_or(""); if !state.pairing.is_authenticated(token) { - tracing::warn!("/v1/chat/completions (compat): rejected — not paired / invalid bearer token"); + tracing::warn!( + "/v1/chat/completions (compat): rejected — not paired / invalid bearer token" + ); let err = serde_json::json!({ "error": { "message": "Invalid API key. 
Pair first via POST /pair, then use Authorization: Bearer ", @@ -481,7 +485,11 @@ pub async fn handle_v1_chat_completions_with_tools( .rev() .filter(|m| m.role == "user" || m.role == "assistant") .map(|m| { - let role_label = if m.role == "user" { "User" } else { "Assistant" }; + let role_label = if m.role == "user" { + "User" + } else { + "Assistant" + }; format!("{}: {}", role_label, m.content) }) .collect(); @@ -495,7 +503,11 @@ pub async fn handle_v1_chat_completions_with_tools( .take(MAX_CONTEXT_MESSAGES) .rev() .collect(); - let context_block = recent.iter().map(|s| s.as_str()).collect::>().join("\n"); + let context_block = recent + .iter() + .map(|s| s.as_str()) + .collect::>() + .join("\n"); format!( "Recent conversation context:\n{}\n\nCurrent message:\n{}", context_block, message @@ -617,9 +629,7 @@ pub async fn handle_v1_chat_completions_with_tools( } }; - let model_name = request - .model - .unwrap_or_else(|| state.model.clone()); + let model_name = request.model.unwrap_or_else(|| state.model.clone()); #[allow(clippy::cast_possible_truncation)] let prompt_tokens = (enriched_message.len() / 4) as u32; @@ -844,14 +854,20 @@ mod tests { fn api_chat_body_rejects_missing_message() { let json = r#"{"session_id": "s1"}"#; let result: Result = serde_json::from_str(json); - assert!(result.is_err(), "missing `message` field should fail deserialization"); + assert!( + result.is_err(), + "missing `message` field should fail deserialization" + ); } #[test] fn oai_request_rejects_empty_messages() { let json = r#"{"messages": []}"#; let req: OaiChatRequest = serde_json::from_str(json).unwrap(); - assert!(req.messages.is_empty(), "empty messages should parse but be caught by handler"); + assert!( + req.messages.is_empty(), + "empty messages should parse but be caught by handler" + ); } #[test] @@ -892,7 +908,17 @@ mod tests { .skip(1) .rev() .filter(|m| m.role == "user" || m.role == "assistant") - .map(|m| format!("{}: {}", if m.role == "user" { "User" } else { 
"Assistant" }, m.content)) + .map(|m| { + format!( + "{}: {}", + if m.role == "user" { + "User" + } else { + "Assistant" + }, + m.content + ) + }) .collect(); assert_eq!(context_messages.len(), 2); diff --git a/src/lib.rs b/src/lib.rs index 056ab6ad9..5a8be0779 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,8 +57,6 @@ pub(crate) mod heartbeat; pub mod hooks; pub(crate) mod identity; // Intentionally unused re-export — public API surface for plugin authors. -#[allow(unused_imports)] -pub(crate) mod plugins; pub(crate) mod integrations; pub mod memory; pub(crate) mod migration; @@ -66,6 +64,8 @@ pub(crate) mod multimodal; pub mod observability; pub(crate) mod onboard; pub mod peripherals; +#[allow(unused_imports)] +pub(crate) mod plugins; pub mod providers; pub mod rag; pub mod runtime; diff --git a/src/plugins/discovery.rs b/src/plugins/discovery.rs index 330080e18..44fab394f 100644 --- a/src/plugins/discovery.rs +++ b/src/plugins/discovery.rs @@ -5,7 +5,9 @@ use std::path::{Path, PathBuf}; -use super::manifest::{load_manifest, ManifestLoadResult, PluginManifest, PLUGIN_MANIFEST_FILENAME}; +use super::manifest::{ + load_manifest, ManifestLoadResult, PluginManifest, PLUGIN_MANIFEST_FILENAME, +}; use super::registry::{DiagnosticLevel, PluginDiagnostic, PluginOrigin}; /// A discovered plugin before loading. @@ -79,10 +81,7 @@ fn scan_dir(dir: &Path, origin: PluginOrigin) -> (Vec, Vec/.zeroclaw/extensions/` /// 4. 
Extra paths from config `[plugins] load_paths` -pub fn discover_plugins( - workspace_dir: Option<&Path>, - extra_paths: &[PathBuf], -) -> DiscoveryResult { +pub fn discover_plugins(workspace_dir: Option<&Path>, extra_paths: &[PathBuf]) -> DiscoveryResult { let mut all_plugins = Vec::new(); let mut all_diagnostics = Vec::new(); @@ -183,10 +182,7 @@ version = "0.1.0" make_plugin_dir(&ext_dir, "custom-one"); let result = discover_plugins(None, &[ext_dir]); - assert!(result - .plugins - .iter() - .any(|p| p.manifest.id == "custom-one")); + assert!(result.plugins.iter().any(|p| p.manifest.id == "custom-one")); } #[test] diff --git a/src/plugins/loader.rs b/src/plugins/loader.rs index 722e11f1a..073cd7a1a 100644 --- a/src/plugins/loader.rs +++ b/src/plugins/loader.rs @@ -306,7 +306,10 @@ mod tests { }; let reg = load_plugins(&cfg, None, vec![]); assert_eq!(reg.active_count(), 0); - assert!(reg.diagnostics.iter().any(|d| d.message.contains("disabled"))); + assert!(reg + .diagnostics + .iter() + .any(|d| d.message.contains("disabled"))); } #[test] diff --git a/src/providers/mod.rs b/src/providers/mod.rs index d716137f7..bca02f915 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -683,6 +683,7 @@ fn zai_base_url(name: &str) -> Option<&'static str> { pub struct ProviderRuntimeOptions { pub auth_profile_override: Option, pub provider_api_url: Option, + pub provider_transport: Option, pub zeroclaw_dir: Option, pub secrets_encrypt: bool, pub reasoning_enabled: Option, @@ -697,6 +698,7 @@ impl Default for ProviderRuntimeOptions { Self { auth_profile_override: None, provider_api_url: None, + provider_transport: None, zeroclaw_dir: None, secrets_encrypt: true, reasoning_enabled: None, @@ -1512,7 +1514,15 @@ pub fn create_routed_provider_with_options( .then_some(api_url) .flatten(); - let route_options = options.clone(); + let mut route_options = options.clone(); + if let Some(transport) = route + .transport + .as_deref() + .map(str::trim) + .filter(|value| 
!value.is_empty()) + { + route_options.provider_transport = Some(transport.to_string()); + } match create_resilient_provider_with_options( &route.provider, @@ -3049,6 +3059,7 @@ mod tests { model: "anthropic/claude-sonnet-4.6".to_string(), max_tokens: Some(4096), api_key: None, + transport: None, }]; let provider = create_routed_provider_with_options( diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs index 36b0d472f..aedaa481a 100644 --- a/src/providers/openai_codex.rs +++ b/src/providers/openai_codex.rs @@ -4,21 +4,44 @@ use crate::multimodal; use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities}; use crate::providers::ProviderRuntimeOptions; use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::path::PathBuf; +use tokio_tungstenite::{ + connect_async, + tungstenite::{ + client::IntoClientRequest, + http::{ + header::{AUTHORIZATION, USER_AGENT}, + HeaderValue as WsHeaderValue, + }, + Message as WsMessage, + }, +}; const DEFAULT_CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses"; const CODEX_RESPONSES_URL_ENV: &str = "ZEROCLAW_CODEX_RESPONSES_URL"; const CODEX_BASE_URL_ENV: &str = "ZEROCLAW_CODEX_BASE_URL"; +const CODEX_TRANSPORT_ENV: &str = "ZEROCLAW_CODEX_TRANSPORT"; +const CODEX_PROVIDER_TRANSPORT_ENV: &str = "ZEROCLAW_PROVIDER_TRANSPORT"; +const CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY: &str = "ZEROCLAW_RESPONSES_WEBSOCKET"; const DEFAULT_CODEX_INSTRUCTIONS: &str = "You are ZeroClaw, a concise and helpful coding assistant."; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CodexTransport { + Auto, + WebSocket, + Sse, +} + pub struct OpenAiCodexProvider { auth: AuthService, auth_profile_override: Option, responses_url: String, + transport: CodexTransport, custom_endpoint: bool, gateway_api_key: Option, reasoning_level: Option, @@ -104,6 +127,7 @@ impl OpenAiCodexProvider { 
auth_profile_override: options.auth_profile_override.clone(), custom_endpoint: !is_default_responses_url(&responses_url), responses_url, + transport: resolve_transport_mode(options), gateway_api_key: gateway_api_key.map(ToString::to_string), reasoning_level: normalize_reasoning_level( options.reasoning_level.as_deref(), @@ -204,6 +228,73 @@ fn first_nonempty(text: Option<&str>) -> Option { }) } +fn parse_transport_override(raw: Option<&str>, source: &str) -> Option { + let value = raw?.trim(); + if value.is_empty() { + return None; + } + + let normalized = value.to_ascii_lowercase().replace(['-', '_'], ""); + match normalized.as_str() { + "auto" => Some(CodexTransport::Auto), + "websocket" | "ws" => Some(CodexTransport::WebSocket), + "sse" | "http" => Some(CodexTransport::Sse), + _ => { + tracing::warn!( + transport = %value, + source, + "Ignoring invalid OpenAI Codex transport override" + ); + None + } + } +} + +fn parse_legacy_websocket_flag(raw: &str) -> Option { + let normalized = raw.trim().to_ascii_lowercase(); + match normalized.as_str() { + "1" | "true" | "on" | "yes" => Some(CodexTransport::WebSocket), + "0" | "false" | "off" | "no" => Some(CodexTransport::Sse), + _ => None, + } +} + +fn resolve_transport_mode(options: &ProviderRuntimeOptions) -> CodexTransport { + if let Some(mode) = parse_transport_override( + options.provider_transport.as_deref(), + "provider.transport runtime override", + ) { + return mode; + } + + if let Some(mode) = std::env::var(CODEX_TRANSPORT_ENV) + .ok() + .and_then(|value| parse_transport_override(Some(&value), CODEX_TRANSPORT_ENV)) + { + return mode; + } + + if let Some(mode) = std::env::var(CODEX_PROVIDER_TRANSPORT_ENV) + .ok() + .and_then(|value| parse_transport_override(Some(&value), CODEX_PROVIDER_TRANSPORT_ENV)) + { + return mode; + } + + if let Some(mode) = std::env::var(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY) + .ok() + .and_then(|value| parse_legacy_websocket_flag(&value)) + { + tracing::warn!( + env = 
CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, + "Using deprecated websocket toggle env for OpenAI Codex transport" + ); + return mode; + } + + CodexTransport::Auto +} + fn resolve_instructions(system_prompt: Option<&str>) -> String { first_nonempty(system_prompt).unwrap_or_else(|| DEFAULT_CODEX_INSTRUCTIONS.to_string()) } @@ -526,6 +617,218 @@ async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result anyhow::Result { + let mut url = reqwest::Url::parse(&self.responses_url)?; + let next_scheme: &'static str = match url.scheme() { + "https" | "wss" => "wss", + "http" | "ws" => "ws", + other => { + anyhow::bail!( + "OpenAI Codex websocket transport does not support URL scheme: {}", + other + ); + } + }; + + url.set_scheme(next_scheme) + .map_err(|()| anyhow::anyhow!("failed to set websocket URL scheme"))?; + + if !url.query_pairs().any(|(k, _)| k == "model") { + url.query_pairs_mut().append_pair("model", model); + } + + Ok(url.into()) + } + + fn apply_auth_headers_ws( + &self, + request: &mut tokio_tungstenite::tungstenite::http::Request<()>, + bearer_token: &str, + account_id: Option<&str>, + access_token: Option<&str>, + use_gateway_api_key_auth: bool, + ) -> anyhow::Result<()> { + let headers = request.headers_mut(); + headers.insert( + AUTHORIZATION, + WsHeaderValue::from_str(&format!("Bearer {bearer_token}"))?, + ); + headers.insert( + "OpenAI-Beta", + WsHeaderValue::from_static("responses=experimental"), + ); + headers.insert("originator", WsHeaderValue::from_static("pi")); + headers.insert("accept", WsHeaderValue::from_static("text/event-stream")); + headers.insert(USER_AGENT, WsHeaderValue::from_static("zeroclaw")); + + if let Some(account_id) = account_id { + headers.insert("chatgpt-account-id", WsHeaderValue::from_str(account_id)?); + } + + if use_gateway_api_key_auth { + if let Some(access_token) = access_token { + headers.insert( + "x-openai-access-token", + WsHeaderValue::from_str(access_token)?, + ); + } + if let Some(account_id) = account_id 
{ + headers.insert("x-openai-account-id", WsHeaderValue::from_str(account_id)?); + } + } + + Ok(()) + } + + async fn send_responses_websocket_request( + &self, + request: &ResponsesRequest, + model: &str, + bearer_token: &str, + account_id: Option<&str>, + access_token: Option<&str>, + use_gateway_api_key_auth: bool, + ) -> anyhow::Result { + let ws_url = self.responses_websocket_url(model)?; + let mut ws_request = ws_url + .into_client_request() + .map_err(|error| anyhow::anyhow!("invalid websocket request URL: {error}"))?; + self.apply_auth_headers_ws( + &mut ws_request, + bearer_token, + account_id, + access_token, + use_gateway_api_key_auth, + )?; + + let payload = serde_json::json!({ + "type": "response.create", + "model": &request.model, + "input": &request.input, + "instructions": &request.instructions, + "store": request.store, + "text": &request.text, + "reasoning": &request.reasoning, + "include": &request.include, + "tool_choice": &request.tool_choice, + "parallel_tool_calls": request.parallel_tool_calls, + }); + + let (mut ws_stream, _) = connect_async(ws_request).await?; + ws_stream + .send(WsMessage::Text(serde_json::to_string(&payload)?.into())) + .await?; + + let mut saw_delta = false; + let mut delta_accumulator = String::new(); + let mut fallback_text: Option = None; + + while let Some(frame) = ws_stream.next().await { + let frame = frame?; + let event: Value = match frame { + WsMessage::Text(text) => serde_json::from_str(text.as_ref())?, + WsMessage::Binary(binary) => { + let text = String::from_utf8(binary.to_vec()).map_err(|error| { + anyhow::anyhow!("invalid UTF-8 websocket frame from OpenAI Codex: {error}") + })?; + serde_json::from_str(&text)? 
+ } + WsMessage::Ping(payload) => { + ws_stream.send(WsMessage::Pong(payload)).await?; + continue; + } + WsMessage::Close(_) => break, + _ => continue, + }; + + if let Some(message) = extract_stream_error_message(&event) { + anyhow::bail!("OpenAI Codex websocket stream error: {message}"); + } + + if let Some(text) = extract_stream_event_text(&event, saw_delta) { + let event_type = event.get("type").and_then(Value::as_str); + if event_type == Some("response.output_text.delta") { + saw_delta = true; + delta_accumulator.push_str(&text); + } else if fallback_text.is_none() { + fallback_text = Some(text); + } + } + + let event_type = event.get("type").and_then(Value::as_str); + if event_type == Some("response.completed") || event_type == Some("response.done") { + if let Some(response_value) = event.get("response").cloned() { + if let Ok(parsed) = serde_json::from_value::(response_value) + { + if let Some(text) = extract_responses_text(&parsed) { + let _ = ws_stream.close(None).await; + return Ok(text); + } + } + } + + if saw_delta { + let _ = ws_stream.close(None).await; + return nonempty_preserve(Some(&delta_accumulator)) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI Codex")); + } + if let Some(text) = fallback_text.clone() { + let _ = ws_stream.close(None).await; + return Ok(text); + } + } + } + + if saw_delta { + return nonempty_preserve(Some(&delta_accumulator)) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI Codex")); + } + if let Some(text) = fallback_text { + return Ok(text); + } + + anyhow::bail!("No response from OpenAI Codex websocket stream"); + } + + async fn send_responses_sse_request( + &self, + request: &ResponsesRequest, + bearer_token: &str, + account_id: Option<&str>, + access_token: Option<&str>, + use_gateway_api_key_auth: bool, + ) -> anyhow::Result { + let mut request_builder = self + .client + .post(&self.responses_url) + .header("Authorization", format!("Bearer {bearer_token}")) + .header("OpenAI-Beta", 
"responses=experimental") + .header("originator", "pi") + .header("accept", "text/event-stream") + .header("Content-Type", "application/json"); + + if let Some(account_id) = account_id { + request_builder = request_builder.header("chatgpt-account-id", account_id); + } + + if use_gateway_api_key_auth { + if let Some(access_token) = access_token { + request_builder = request_builder.header("x-openai-access-token", access_token); + } + if let Some(account_id) = account_id { + request_builder = request_builder.header("x-openai-account-id", account_id); + } + } + + let response = request_builder.json(request).send().await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenAI Codex", response).await); + } + + decode_responses_body(response).await + } + async fn send_responses_request( &self, input: Vec, @@ -613,35 +916,58 @@ impl OpenAiCodexProvider { access_token.as_deref().unwrap_or_default() }; - let mut request_builder = self - .client - .post(&self.responses_url) - .header("Authorization", format!("Bearer {bearer_token}")) - .header("OpenAI-Beta", "responses=experimental") - .header("originator", "pi") - .header("accept", "text/event-stream") - .header("Content-Type", "application/json"); - - if let Some(account_id) = account_id.as_deref() { - request_builder = request_builder.header("chatgpt-account-id", account_id); - } - - if use_gateway_api_key_auth { - if let Some(access_token) = access_token.as_deref() { - request_builder = request_builder.header("x-openai-access-token", access_token); + match self.transport { + CodexTransport::WebSocket => { + self.send_responses_websocket_request( + &request, + normalized_model, + bearer_token, + account_id.as_deref(), + access_token.as_deref(), + use_gateway_api_key_auth, + ) + .await } - if let Some(account_id) = account_id.as_deref() { - request_builder = request_builder.header("x-openai-account-id", account_id); + CodexTransport::Sse => { + self.send_responses_sse_request( + &request, + 
bearer_token, + account_id.as_deref(), + access_token.as_deref(), + use_gateway_api_key_auth, + ) + .await + } + CodexTransport::Auto => { + match self + .send_responses_websocket_request( + &request, + normalized_model, + bearer_token, + account_id.as_deref(), + access_token.as_deref(), + use_gateway_api_key_auth, + ) + .await + { + Ok(text) => Ok(text), + Err(error) => { + tracing::warn!( + error = %error, + "OpenAI Codex websocket request failed; falling back to SSE" + ); + self.send_responses_sse_request( + &request, + bearer_token, + account_id.as_deref(), + access_token.as_deref(), + use_gateway_api_key_auth, + ) + .await + } + } } } - - let response = request_builder.json(&request).send().await?; - - if !response.status().is_success() { - return Err(super::api_error("OpenAI Codex", response).await); - } - - decode_responses_body(response).await } } @@ -809,6 +1135,63 @@ mod tests { ); } + #[test] + fn resolve_transport_mode_defaults_to_auto() { + let _env_lock = env_lock(); + let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None); + let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, None); + let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None); + + assert_eq!( + resolve_transport_mode(&ProviderRuntimeOptions::default()), + CodexTransport::Auto + ); + } + + #[test] + fn resolve_transport_mode_accepts_runtime_override() { + let _env_lock = env_lock(); + let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, Some("sse")); + + let options = ProviderRuntimeOptions { + provider_transport: Some("websocket".to_string()), + ..ProviderRuntimeOptions::default() + }; + + assert_eq!(resolve_transport_mode(&options), CodexTransport::WebSocket); + } + + #[test] + fn resolve_transport_mode_legacy_bool_env_is_supported() { + let _env_lock = env_lock(); + let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None); + let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None); + let _legacy_guard = 
EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, Some("false")); + + assert_eq!( + resolve_transport_mode(&ProviderRuntimeOptions::default()), + CodexTransport::Sse + ); + } + + #[test] + fn websocket_url_uses_ws_scheme_and_model_query() { + let _env_lock = env_lock(); + let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None); + let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None); + + let options = ProviderRuntimeOptions::default(); + let provider = OpenAiCodexProvider::new(&options, None).expect("provider should init"); + let ws_url = provider + .responses_websocket_url("gpt-5.3-codex") + .expect("websocket URL should be derived"); + + assert_eq!( + ws_url, + "wss://chatgpt.com/backend-api/codex/responses?model=gpt-5.3-codex" + ); + } + #[test] fn default_responses_url_detector_handles_equivalent_urls() { assert!(is_default_responses_url(DEFAULT_CODEX_RESPONSES_URL)); @@ -1077,6 +1460,7 @@ data: [DONE] fn capabilities_includes_vision() { let options = ProviderRuntimeOptions { provider_api_url: None, + provider_transport: None, zeroclaw_dir: None, secrets_encrypt: false, auth_profile_override: None, diff --git a/src/tools/docx_read.rs b/src/tools/docx_read.rs index 316a308e5..e63527631 100644 --- a/src/tools/docx_read.rs +++ b/src/tools/docx_read.rs @@ -202,24 +202,23 @@ impl Tool for DocxReadTool { } }; - let text = - match tokio::task::spawn_blocking(move || extract_docx_text(&bytes)).await { - Ok(Ok(t)) => t, - Ok(Err(e)) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("DOCX extraction failed: {e}")), - }); - } - Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("DOCX extraction task panicked: {e}")), - }); - } - }; + let text = match tokio::task::spawn_blocking(move || extract_docx_text(&bytes)).await { + Ok(Ok(t)) => t, + Ok(Err(e)) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("DOCX extraction 
failed: {e}")), + }); + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("DOCX extraction task panicked: {e}")), + }); + } + }; if text.trim().is_empty() { return Ok(ToolResult { diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b28af8b64..7ddc341d8 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -428,6 +428,7 @@ pub fn all_tools_with_runtime( let provider_runtime_options = crate::providers::ProviderRuntimeOptions { auth_profile_override: None, provider_api_url: root_config.api_url.clone(), + provider_transport: None, zeroclaw_dir: root_config .config_path .parent() diff --git a/src/tools/model_routing_config.rs b/src/tools/model_routing_config.rs index 7e7b01096..6ff23ebb9 100644 --- a/src/tools/model_routing_config.rs +++ b/src/tools/model_routing_config.rs @@ -466,6 +466,7 @@ impl ModelRoutingConfigTool { model: model.clone(), max_tokens: None, api_key: None, + transport: None, }); next_route.hint = hint.clone(); @@ -1160,6 +1161,9 @@ mod tests { .unwrap(); assert_eq!(route["api_key_configured"], json!(true)); - assert_eq!(output["agents"]["voice_helper"]["api_key_configured"], json!(true)); + assert_eq!( + output["agents"]["voice_helper"]["api_key_configured"], + json!(true) + ); } } diff --git a/tests/openai_codex_vision_e2e.rs b/tests/openai_codex_vision_e2e.rs index 843108906..a4a7875fa 100644 --- a/tests/openai_codex_vision_e2e.rs +++ b/tests/openai_codex_vision_e2e.rs @@ -148,6 +148,7 @@ async fn openai_codex_second_vision_support() -> Result<()> { let opts = ProviderRuntimeOptions { auth_profile_override: Some("second".to_string()), provider_api_url: None, + provider_transport: None, zeroclaw_dir: None, secrets_encrypt: false, reasoning_enabled: None,