feat(providers): add responses-mode chat-completions fallback (#2417)

This commit is contained in:
Argenis 2026-03-01 13:22:55 -05:00 committed by GitHub
parent d024877ba8
commit 561c4765e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,6 +29,7 @@ use tokio_tungstenite::{
/// A provider that speaks the OpenAI-compatible chat completions API.
/// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
/// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
#[derive(Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct OpenAiCompatibleProvider {
pub(crate) name: String,
@ -1147,6 +1148,90 @@ impl OpenAiCompatibleProvider {
self.api_mode == CompatibleApiMode::OpenAiResponses
}
fn chat_completions_fallback_provider(&self) -> Self {
let mut provider = self.clone();
provider.api_mode = CompatibleApiMode::OpenAiChatCompletions;
provider.supports_responses_fallback = false;
provider
}
fn error_status_code(error: &anyhow::Error) -> Option<reqwest::StatusCode> {
if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_error.status() {
return Some(status);
}
}
let message = error.to_string();
for token in message.split(|c: char| !c.is_ascii_digit()) {
let Ok(code) = token.parse::<u16>() else {
continue;
};
if let Ok(status) = reqwest::StatusCode::from_u16(code) {
if status.is_client_error() || status.is_server_error() {
return Some(status);
}
}
}
None
}
fn is_authentication_error(error: &anyhow::Error) -> bool {
if let Some(status) = Self::error_status_code(error) {
if status == reqwest::StatusCode::UNAUTHORIZED
|| status == reqwest::StatusCode::FORBIDDEN
{
return true;
}
}
let lower = error.to_string().to_ascii_lowercase();
let auth_hints = [
"invalid api key",
"incorrect api key",
"missing api key",
"api key not set",
"authentication failed",
"auth failed",
"unauthorized",
"forbidden",
"permission denied",
"access denied",
"invalid token",
];
auth_hints.iter().any(|hint| lower.contains(hint))
}
fn should_fallback_to_chat_completions(error: &anyhow::Error) -> bool {
if Self::is_authentication_error(error) {
return false;
}
if let Some(status) = Self::error_status_code(error) {
return status == reqwest::StatusCode::NOT_FOUND
|| status == reqwest::StatusCode::REQUEST_TIMEOUT
|| status == reqwest::StatusCode::TOO_MANY_REQUESTS
|| status.is_server_error();
}
if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
if reqwest_error.is_connect()
|| reqwest_error.is_timeout()
|| reqwest_error.is_request()
|| reqwest_error.is_body()
|| reqwest_error.is_decode()
{
return true;
}
}
let lower = error.to_string().to_ascii_lowercase();
lower.contains("responses api returned an unexpected payload")
|| lower.contains("no response from")
}
fn effective_max_tokens(&self) -> Option<u32> {
self.max_tokens_override.filter(|value| *value > 0)
}
@ -1331,8 +1416,10 @@ impl OpenAiCompatibleProvider {
.await?;
if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
anyhow::bail!("{} Responses API error: {error}", self.name);
let sanitized = super::sanitize_api_error(&error);
anyhow::bail!("{} Responses API error ({status}): {sanitized}", self.name);
}
let body = response.text().await?;
@ -1383,10 +1470,37 @@ impl OpenAiCompatibleProvider {
credential: &str,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let responses = self
let responses = match self
.send_responses_request(credential, messages, model, None)
.await?;
.await
{
Ok(response) => response,
Err(responses_err) => {
if self.should_use_responses_mode()
&& Self::should_fallback_to_chat_completions(&responses_err)
{
tracing::warn!(
provider = %self.name,
error = %responses_err,
"Responses API request failed in responses mode; retrying via chat completions"
);
let fallback_provider = self.chat_completions_fallback_provider();
let sanitized = super::sanitize_api_error(&responses_err.to_string());
return fallback_provider
.chat_with_history(messages, model, temperature)
.await
.map_err(|chat_err| {
anyhow::anyhow!(
"{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
self.name
)
});
}
return Err(responses_err);
}
};
extract_responses_text(&responses)
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
}
@ -1397,10 +1511,51 @@ impl OpenAiCompatibleProvider {
messages: &[ChatMessage],
model: &str,
tools: Option<Vec<Value>>,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let responses = self
.send_responses_request(credential, messages, model, tools)
.await?;
let responses = match self
.send_responses_request(credential, messages, model, tools.clone())
.await
{
Ok(response) => response,
Err(responses_err) => {
if self.should_use_responses_mode()
&& Self::should_fallback_to_chat_completions(&responses_err)
{
tracing::warn!(
provider = %self.name,
error = %responses_err,
"Responses API request failed in responses mode; retrying via chat completions"
);
let fallback_provider = self.chat_completions_fallback_provider();
let fallback_tool_specs = tools
.as_deref()
.map(Self::openai_tools_to_tool_specs)
.unwrap_or_default();
let fallback_tools =
(!fallback_tool_specs.is_empty()).then_some(fallback_tool_specs.as_slice());
let sanitized = super::sanitize_api_error(&responses_err.to_string());
return fallback_provider
.chat(
ProviderChatRequest {
messages,
tools: fallback_tools,
},
model,
temperature,
)
.await
.map_err(|chat_err| {
anyhow::anyhow!(
"{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
self.name
)
});
}
return Err(responses_err);
}
};
let parsed = parse_responses_chat_response(responses);
if parsed.text.is_none() && parsed.tool_calls.is_empty() {
anyhow::bail!("No response from {} Responses API", self.name);
@ -1714,7 +1869,7 @@ impl Provider for OpenAiCompatibleProvider {
if self.should_use_responses_mode() {
return self
.chat_via_responses(credential, &fallback_messages, model)
.chat_via_responses(credential, &fallback_messages, model, temperature)
.await;
}
@ -1728,7 +1883,7 @@ impl Provider for OpenAiCompatibleProvider {
if self.supports_responses_fallback {
let sanitized = super::sanitize_api_error(&chat_error.to_string());
return self
.chat_via_responses(credential, &fallback_messages, model)
.chat_via_responses(credential, &fallback_messages, model, temperature)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -1749,7 +1904,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self
.chat_via_responses(credential, &fallback_messages, model)
.chat_via_responses(credential, &fallback_messages, model, temperature)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -1830,7 +1985,7 @@ impl Provider for OpenAiCompatibleProvider {
if self.should_use_responses_mode() {
return self
.chat_via_responses(credential, &effective_messages, model)
.chat_via_responses(credential, &effective_messages, model, temperature)
.await;
}
@ -1845,7 +2000,7 @@ impl Provider for OpenAiCompatibleProvider {
if self.supports_responses_fallback {
let sanitized = super::sanitize_api_error(&chat_error.to_string());
return self
.chat_via_responses(credential, &effective_messages, model)
.chat_via_responses(credential, &effective_messages, model, temperature)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -1865,7 +2020,7 @@ impl Provider for OpenAiCompatibleProvider {
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self
.chat_via_responses(credential, &effective_messages, model)
.chat_via_responses(credential, &effective_messages, model, temperature)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -1960,6 +2115,7 @@ impl Provider for OpenAiCompatibleProvider {
&effective_messages,
model,
(!tools.is_empty()).then(|| tools.to_vec()),
temperature,
)
.await;
}
@ -2011,6 +2167,7 @@ impl Provider for OpenAiCompatibleProvider {
&effective_messages,
model,
(!tools.is_empty()).then(|| tools.to_vec()),
temperature,
)
.await;
}
@ -2098,6 +2255,7 @@ impl Provider for OpenAiCompatibleProvider {
&effective_messages,
model,
response_tools.clone(),
temperature,
)
.await;
}
@ -2121,6 +2279,7 @@ impl Provider for OpenAiCompatibleProvider {
&effective_messages,
model,
response_tools.clone(),
temperature,
)
.await
.map_err(|responses_err| {
@ -2158,6 +2317,7 @@ impl Provider for OpenAiCompatibleProvider {
&effective_messages,
model,
response_tools.clone(),
temperature,
)
.await
.map_err(|responses_err| {
@ -2674,7 +2834,12 @@ mod tests {
async fn chat_via_responses_requires_non_system_message() {
let provider = make_provider("custom", "https://api.example.com", Some("test-key"));
let err = provider
.chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test")
.chat_via_responses(
"test-key",
&[ChatMessage::system("policy")],
"gpt-test",
0.7,
)
.await
.expect_err("system-only fallback payload should fail");
@ -2683,6 +2848,278 @@ mod tests {
.contains("requires at least one non-system message"));
}
#[tokio::test]
async fn responses_mode_falls_back_to_chat_completions_on_responses_404() {
#[derive(Clone, Default)]
struct FallbackState {
hits: Arc<Mutex<Vec<String>>>,
}
async fn responses_endpoint(
State(state): State<FallbackState>,
Json(_payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("responses".to_string());
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": { "message": "responses endpoint unavailable" }
})),
)
}
async fn chat_endpoint(
State(state): State<FallbackState>,
Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("chat".to_string());
assert_eq!(
payload.get("model").and_then(Value::as_str),
Some("test-model")
);
(
StatusCode::OK,
Json(serde_json::json!({
"choices": [{
"message": {
"content": "chat fallback ok"
}
}]
})),
)
}
let state = FallbackState::default();
let app = Router::new()
.route("/v1/responses", post(responses_endpoint))
.route("/chat/completions", post(chat_endpoint))
.with_state(state.clone());
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind test server");
let addr = listener.local_addr().expect("server local addr");
let server = tokio::spawn(async move {
axum::serve(listener, app).await.expect("serve test app");
});
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
"custom",
&format!("http://{}", addr),
Some("test-key"),
AuthStyle::Bearer,
false,
CompatibleApiMode::OpenAiResponses,
None,
);
let text = provider
.chat_with_system(Some("system"), "hello", "test-model", 0.2)
.await
.expect("responses 404 should retry chat completions in responses mode");
assert_eq!(text, "chat fallback ok");
let hits = state.hits.lock().await.clone();
assert_eq!(
hits,
vec!["responses".to_string(), "chat".to_string()],
"must attempt responses first, then chat-completions fallback"
);
server.abort();
let _ = server.await;
}
#[tokio::test]
async fn responses_mode_does_not_fallback_to_chat_completions_on_auth_error() {
#[derive(Clone, Default)]
struct AuthFailureState {
hits: Arc<Mutex<Vec<String>>>,
}
async fn responses_endpoint(
State(state): State<AuthFailureState>,
Json(_payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("responses".to_string());
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": { "message": "invalid api key" }
})),
)
}
async fn chat_endpoint(
State(state): State<AuthFailureState>,
Json(_payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("chat".to_string());
(
StatusCode::OK,
Json(serde_json::json!({
"choices": [{
"message": {
"content": "should not be reached"
}
}]
})),
)
}
let state = AuthFailureState::default();
let app = Router::new()
.route("/v1/responses", post(responses_endpoint))
.route("/chat/completions", post(chat_endpoint))
.with_state(state.clone());
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind test server");
let addr = listener.local_addr().expect("server local addr");
let server = tokio::spawn(async move {
axum::serve(listener, app).await.expect("serve test app");
});
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
"custom",
&format!("http://{}", addr),
Some("test-key"),
AuthStyle::Bearer,
false,
CompatibleApiMode::OpenAiResponses,
None,
);
let err = provider
.chat_with_system(None, "hello", "test-model", 0.2)
.await
.expect_err("auth errors should not trigger chat-completions fallback");
assert!(err.to_string().contains("401"));
let hits = state.hits.lock().await.clone();
assert_eq!(
hits,
vec!["responses".to_string()],
"auth failures must not trigger fallback chat attempt"
);
server.abort();
let _ = server.await;
}
#[tokio::test]
async fn responses_mode_native_chat_falls_back_and_preserves_tool_call_id() {
#[derive(Clone, Default)]
struct NativeFallbackState {
hits: Arc<Mutex<Vec<String>>>,
chat_payloads: Arc<Mutex<Vec<Value>>>,
}
async fn responses_endpoint(
State(state): State<NativeFallbackState>,
Json(_payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("responses".to_string());
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "responses backend unavailable" }
})),
)
}
async fn chat_endpoint(
State(state): State<NativeFallbackState>,
Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
state.hits.lock().await.push("chat".to_string());
state.chat_payloads.lock().await.push(payload);
(
StatusCode::OK,
Json(serde_json::json!({
"choices": [{
"message": {
"content": null,
"tool_calls": [{
"id": "call_abc",
"type": "function",
"function": {
"name": "shell",
"arguments": "{\"command\":\"pwd\"}"
}
}]
}
}]
})),
)
}
let state = NativeFallbackState::default();
let app = Router::new()
.route("/v1/responses", post(responses_endpoint))
.route("/chat/completions", post(chat_endpoint))
.with_state(state.clone());
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind test server");
let addr = listener.local_addr().expect("server local addr");
let server = tokio::spawn(async move {
axum::serve(listener, app).await.expect("serve test app");
});
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
"custom",
&format!("http://{}", addr),
Some("test-key"),
AuthStyle::Bearer,
false,
CompatibleApiMode::OpenAiResponses,
None,
);
let messages = vec![ChatMessage::user("run a command")];
let tools = vec![crate::tools::ToolSpec {
name: "shell".to_string(),
description: "Run a command".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": {"command": {"type": "string"}},
"required": ["command"]
}),
}];
let result = provider
.chat(
ProviderChatRequest {
messages: &messages,
tools: Some(&tools),
},
"test-model",
0.2,
)
.await
.expect("responses server errors should retry via native chat-completions");
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].id, "call_abc");
assert_eq!(result.tool_calls[0].name, "shell");
let hits = state.hits.lock().await.clone();
assert_eq!(
hits,
vec!["responses".to_string(), "chat".to_string()],
"responses mode should retry via chat for retryable errors"
);
let chat_payloads = state.chat_payloads.lock().await;
assert_eq!(chat_payloads.len(), 1);
assert!(
chat_payloads[0].get("tools").is_some(),
"fallback native chat request should preserve tool schema"
);
server.abort();
let _ = server.await;
}
#[test]
fn tool_call_function_name_falls_back_to_top_level_name() {
let call: ToolCall = serde_json::from_value(serde_json::json!({