From 561c4765e10b96a34663c73bd43c6b4c2d14b12e Mon Sep 17 00:00:00 2001
From: Argenis
Date: Sun, 1 Mar 2026 13:22:55 -0500
Subject: [PATCH] feat(providers): add responses-mode chat-completions fallback (#2417)

---
 src/providers/compatible.rs | 463 +++++++++++++++++++++++++++++++++++-
 1 file changed, 450 insertions(+), 13 deletions(-)

diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index 3a4bed581..10fb97be4 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -29,6 +29,7 @@ use tokio_tungstenite::{
 /// A provider that speaks the OpenAI-compatible chat completions API.
 /// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
 /// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
+#[derive(Clone)]
 #[allow(clippy::struct_excessive_bools)]
 pub struct OpenAiCompatibleProvider {
     pub(crate) name: String,
@@ -1147,6 +1148,90 @@ impl OpenAiCompatibleProvider {
         self.api_mode == CompatibleApiMode::OpenAiResponses
     }
 
+    fn chat_completions_fallback_provider(&self) -> Self {
+        let mut provider = self.clone();
+        provider.api_mode = CompatibleApiMode::OpenAiChatCompletions;
+        provider.supports_responses_fallback = false;
+        provider
+    }
+
+    fn error_status_code(error: &anyhow::Error) -> Option<reqwest::StatusCode> {
+        if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
+            if let Some(status) = reqwest_error.status() {
+                return Some(status);
+            }
+        }
+
+        let message = error.to_string();
+        for token in message.split(|c: char| !c.is_ascii_digit()) {
+            let Ok(code) = token.parse::<u16>() else {
+                continue;
+            };
+            if let Ok(status) = reqwest::StatusCode::from_u16(code) {
+                if status.is_client_error() || status.is_server_error() {
+                    return Some(status);
+                }
+            }
+        }
+
+        None
+    }
+
+    fn is_authentication_error(error: &anyhow::Error) -> bool {
+        if let Some(status) = Self::error_status_code(error) {
+            if status == reqwest::StatusCode::UNAUTHORIZED
+                || status == reqwest::StatusCode::FORBIDDEN
+            {
+                return true;
+            }
+        }
+
+        let lower = error.to_string().to_ascii_lowercase();
+        let auth_hints = [
+            "invalid api key",
+            "incorrect api key",
+            "missing api key",
+            "api key not set",
+            "authentication failed",
+            "auth failed",
+            "unauthorized",
+            "forbidden",
+            "permission denied",
+            "access denied",
+            "invalid token",
+        ];
+
+        auth_hints.iter().any(|hint| lower.contains(hint))
+    }
+
+    fn should_fallback_to_chat_completions(error: &anyhow::Error) -> bool {
+        if Self::is_authentication_error(error) {
+            return false;
+        }
+
+        if let Some(status) = Self::error_status_code(error) {
+            return status == reqwest::StatusCode::NOT_FOUND
+                || status == reqwest::StatusCode::REQUEST_TIMEOUT
+                || status == reqwest::StatusCode::TOO_MANY_REQUESTS
+                || status.is_server_error();
+        }
+
+        if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
+            if reqwest_error.is_connect()
+                || reqwest_error.is_timeout()
+                || reqwest_error.is_request()
+                || reqwest_error.is_body()
+                || reqwest_error.is_decode()
+            {
+                return true;
+            }
+        }
+
+        let lower = error.to_string().to_ascii_lowercase();
+        lower.contains("responses api returned an unexpected payload")
+            || lower.contains("no response from")
+    }
+
     fn effective_max_tokens(&self) -> Option<u32> {
         self.max_tokens_override.filter(|value| *value > 0)
     }
@@ -1331,8 +1416,10 @@ impl OpenAiCompatibleProvider {
             .await?;
 
         if !response.status().is_success() {
+            let status = response.status();
             let error = response.text().await?;
-            anyhow::bail!("{} Responses API error: {error}", self.name);
+            let sanitized = super::sanitize_api_error(&error);
+            anyhow::bail!("{} Responses API error ({status}): {sanitized}", self.name);
         }
 
         let body = response.text().await?;
@@ -1383,10 +1470,37 @@ impl OpenAiCompatibleProvider {
         credential: &str,
         messages: &[ChatMessage],
         model: &str,
+        temperature: f64,
     ) -> anyhow::Result<String> {
-        let responses = self
+        let responses = match self
             .send_responses_request(credential, messages, model, None)
-            .await?;
+            .await
+        {
+            Ok(response) => response,
+            Err(responses_err) => {
+                if self.should_use_responses_mode()
+                    && Self::should_fallback_to_chat_completions(&responses_err)
+                {
+                    tracing::warn!(
+                        provider = %self.name,
+                        error = %responses_err,
+                        "Responses API request failed in responses mode; retrying via chat completions"
+                    );
+                    let fallback_provider = self.chat_completions_fallback_provider();
+                    let sanitized = super::sanitize_api_error(&responses_err.to_string());
+                    return fallback_provider
+                        .chat_with_history(messages, model, temperature)
+                        .await
+                        .map_err(|chat_err| {
+                            anyhow::anyhow!(
+                                "{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
+                                self.name
+                            )
+                        });
+                }
+                return Err(responses_err);
+            }
+        };
         extract_responses_text(&responses)
             .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
     }
@@ -1397,10 +1511,51 @@
         messages: &[ChatMessage],
         model: &str,
         tools: Option<Vec<Value>>,
+        temperature: f64,
     ) -> anyhow::Result {
-        let responses = self
-            .send_responses_request(credential, messages, model, tools)
-            .await?;
+        let responses = match self
+            .send_responses_request(credential, messages, model, tools.clone())
+            .await
+        {
+            Ok(response) => response,
+            Err(responses_err) => {
+                if self.should_use_responses_mode()
+                    && Self::should_fallback_to_chat_completions(&responses_err)
+                {
+                    tracing::warn!(
+                        provider = %self.name,
+                        error = %responses_err,
+                        "Responses API request failed in responses mode; retrying via chat completions"
+                    );
+                    let fallback_provider = self.chat_completions_fallback_provider();
+                    let fallback_tool_specs = tools
+                        .as_deref()
+                        .map(Self::openai_tools_to_tool_specs)
+                        .unwrap_or_default();
+                    let fallback_tools =
+                        (!fallback_tool_specs.is_empty()).then_some(fallback_tool_specs.as_slice());
+                    let sanitized = super::sanitize_api_error(&responses_err.to_string());
+
+                    return fallback_provider
+                        .chat(
+                            ProviderChatRequest {
+                                messages,
+                                tools: fallback_tools,
+                            },
+                            model,
+                            temperature,
+                        )
+                        .await
+                        .map_err(|chat_err| {
+                            anyhow::anyhow!(
+                                "{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
+                                self.name
+                            )
+                        });
+                }
+                return Err(responses_err);
+            }
+        };
         let parsed = parse_responses_chat_response(responses);
         if parsed.text.is_none() && parsed.tool_calls.is_empty() {
             anyhow::bail!("No response from {} Responses API", self.name);
@@ -1714,7 +1869,7 @@ impl Provider for OpenAiCompatibleProvider {
 
         if self.should_use_responses_mode() {
             return self
-                .chat_via_responses(credential, &fallback_messages, model)
+                .chat_via_responses(credential, &fallback_messages, model, temperature)
                 .await;
         }
 
@@ -1728,7 +1883,7 @@ impl Provider for OpenAiCompatibleProvider {
             if self.supports_responses_fallback {
                 let sanitized = super::sanitize_api_error(&chat_error.to_string());
                 return self
-                    .chat_via_responses(credential, &fallback_messages, model)
+                    .chat_via_responses(credential, &fallback_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1749,7 +1904,7 @@
 
             if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
                 return self
-                    .chat_via_responses(credential, &fallback_messages, model)
+                    .chat_via_responses(credential, &fallback_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1830,7 +1985,7 @@ impl Provider for OpenAiCompatibleProvider {
 
         if self.should_use_responses_mode() {
             return self
-                .chat_via_responses(credential, &effective_messages, model)
+                .chat_via_responses(credential, &effective_messages, model, temperature)
                 .await;
         }
 
@@ -1845,7 +2000,7 @@ impl Provider for OpenAiCompatibleProvider {
             if self.supports_responses_fallback {
                 let sanitized = super::sanitize_api_error(&chat_error.to_string());
                 return self
-                    .chat_via_responses(credential, &effective_messages, model)
+                    .chat_via_responses(credential, &effective_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1865,7 +2020,7 @@ impl Provider for OpenAiCompatibleProvider {
             // Mirror chat_with_system: 404 may mean this provider uses the Responses API
             if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
                 return self
-                    .chat_via_responses(credential, &effective_messages, model)
+                    .chat_via_responses(credential, &effective_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1960,6 +2115,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     (!tools.is_empty()).then(|| tools.to_vec()),
+                    temperature,
                 )
                 .await;
         }
@@ -2011,6 +2167,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     (!tools.is_empty()).then(|| tools.to_vec()),
+                    temperature,
                 )
                 .await;
         }
@@ -2098,6 +2255,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await;
         }
@@ -2121,6 +2279,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await
                 .map_err(|responses_err| {
@@ -2158,6 +2317,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await
                 .map_err(|responses_err| {
@@ -2674,7 +2834,12 @@ mod tests {
     async fn chat_via_responses_requires_non_system_message() {
         let provider = make_provider("custom", "https://api.example.com", Some("test-key"));
         let err = provider
-            .chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test")
+            .chat_via_responses(
+                "test-key",
+                &[ChatMessage::system("policy")],
+                "gpt-test",
+                0.7,
+            )
             .await
             .expect_err("system-only fallback payload should fail");
 
         assert!(err
             .to_string()
             .contains("requires at least one non-system message"));
     }
 
@@ -2683,6 +2848,278 @@ mod tests {
+    #[tokio::test]
+    async fn responses_mode_falls_back_to_chat_completions_on_responses_404() {
+        #[derive(Clone, Default)]
+        struct FallbackState {
+            hits: Arc<Mutex<Vec<String>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<FallbackState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::NOT_FOUND,
+                Json(serde_json::json!({
+                    "error": { "message": "responses endpoint unavailable" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<FallbackState>,
+            Json(payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            assert_eq!(
+                payload.get("model").and_then(Value::as_str),
+                Some("test-model")
+            );
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": "chat fallback ok"
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = FallbackState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let text = provider
+            .chat_with_system(Some("system"), "hello", "test-model", 0.2)
+            .await
+            .expect("responses 404 should retry chat completions in responses mode");
+        assert_eq!(text, "chat fallback ok");
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string(), "chat".to_string()],
+            "must attempt responses first, then chat-completions fallback"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
+    #[tokio::test]
+    async fn responses_mode_does_not_fallback_to_chat_completions_on_auth_error() {
+        #[derive(Clone, Default)]
+        struct AuthFailureState {
+            hits: Arc<Mutex<Vec<String>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<AuthFailureState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::UNAUTHORIZED,
+                Json(serde_json::json!({
+                    "error": { "message": "invalid api key" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<AuthFailureState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": "should not be reached"
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = AuthFailureState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let err = provider
+            .chat_with_system(None, "hello", "test-model", 0.2)
+            .await
+            .expect_err("auth errors should not trigger chat-completions fallback");
+        assert!(err.to_string().contains("401"));
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string()],
+            "auth failures must not trigger fallback chat attempt"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
+    #[tokio::test]
+    async fn responses_mode_native_chat_falls_back_and_preserves_tool_call_id() {
+        #[derive(Clone, Default)]
+        struct NativeFallbackState {
+            hits: Arc<Mutex<Vec<String>>>,
+            chat_payloads: Arc<Mutex<Vec<Value>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<NativeFallbackState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": "responses backend unavailable" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<NativeFallbackState>,
+            Json(payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            state.chat_payloads.lock().await.push(payload);
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": null,
+                            "tool_calls": [{
+                                "id": "call_abc",
+                                "type": "function",
+                                "function": {
+                                    "name": "shell",
+                                    "arguments": "{\"command\":\"pwd\"}"
+                                }
+                            }]
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = NativeFallbackState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let messages = vec![ChatMessage::user("run a command")];
+        let tools = vec![crate::tools::ToolSpec {
+            name: "shell".to_string(),
+            description: "Run a command".to_string(),
+            parameters: serde_json::json!({
+                "type": "object",
+                "properties": {"command": {"type": "string"}},
+                "required": ["command"]
+            }),
+        }];
+        let result = provider
+            .chat(
+                ProviderChatRequest {
+                    messages: &messages,
+                    tools: Some(&tools),
+                },
+                "test-model",
+                0.2,
+            )
+            .await
+            .expect("responses server errors should retry via native chat-completions");
+
+        assert_eq!(result.tool_calls.len(), 1);
+        assert_eq!(result.tool_calls[0].id, "call_abc");
+        assert_eq!(result.tool_calls[0].name, "shell");
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string(), "chat".to_string()],
+            "responses mode should retry via chat for retryable errors"
+        );
+
+        let chat_payloads = state.chat_payloads.lock().await;
+        assert_eq!(chat_payloads.len(), 1);
+        assert!(
+            chat_payloads[0].get("tools").is_some(),
+            "fallback native chat request should preserve tool schema"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
     #[test]
     fn tool_call_function_name_falls_back_to_top_level_name() {
         let call: ToolCall = serde_json::from_value(serde_json::json!({