feat(providers): add responses-mode chat-completions fallback (#2417)
commit 561c4765e1
parent d024877ba8
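
In brief: a provider running in responses mode (`CompatibleApiMode::OpenAiResponses`) now retries a failed Responses request once through the chat-completions endpoint, but only for retryable failures (404, 408, 429, any 5xx, transport/decode errors, or a malformed Responses payload). Authentication failures (401/403, or auth-flavored error text such as "invalid api key") surface immediately so a bad credential is never masked by a second request. A minimal standalone sketch of that decision rule, assuming only a `reqwest` dependency; `should_fallback` is a hypothetical condensation of the diff's `should_fallback_to_chat_completions`, not the crate's API:

// Sketch only: condenses the fallback policy this commit adds.
use reqwest::StatusCode;

fn should_fallback(status: Option<StatusCode>, message: &str) -> bool {
    let lower = message.to_ascii_lowercase();
    // Credential problems are surfaced immediately, never retried.
    let auth_hints = ["invalid api key", "unauthorized", "forbidden", "permission denied"];
    if auth_hints.iter().any(|hint| lower.contains(hint)) {
        return false;
    }
    if let Some(status) = status {
        if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
            return false;
        }
        // Retryable: the endpoint may not exist, be throttled, or be down.
        return status == StatusCode::NOT_FOUND
            || status == StatusCode::REQUEST_TIMEOUT
            || status == StatusCode::TOO_MANY_REQUESTS
            || status.is_server_error();
    }
    // No status at all: only payload-shape complaints are retryable here.
    lower.contains("responses api returned an unexpected payload")
}

fn main() {
    assert!(should_fallback(Some(StatusCode::NOT_FOUND), "responses endpoint unavailable"));
    assert!(should_fallback(Some(StatusCode::INTERNAL_SERVER_ERROR), "backend down"));
    assert!(!should_fallback(Some(StatusCode::UNAUTHORIZED), "invalid api key"));
}
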
@@ -29,6 +29,7 @@ use tokio_tungstenite::{
 /// A provider that speaks the OpenAI-compatible chat completions API.
 /// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
 /// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
 #[derive(Clone)]
+#[allow(clippy::struct_excessive_bools)]
 pub struct OpenAiCompatibleProvider {
     pub(crate) name: String,
@@ -1147,6 +1148,90 @@ impl OpenAiCompatibleProvider {
         self.api_mode == CompatibleApiMode::OpenAiResponses
     }

+    fn chat_completions_fallback_provider(&self) -> Self {
+        let mut provider = self.clone();
+        provider.api_mode = CompatibleApiMode::OpenAiChatCompletions;
+        provider.supports_responses_fallback = false;
+        provider
+    }
+
+    fn error_status_code(error: &anyhow::Error) -> Option<reqwest::StatusCode> {
+        if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
+            if let Some(status) = reqwest_error.status() {
+                return Some(status);
+            }
+        }
+
+        let message = error.to_string();
+        for token in message.split(|c: char| !c.is_ascii_digit()) {
+            let Ok(code) = token.parse::<u16>() else {
+                continue;
+            };
+            if let Ok(status) = reqwest::StatusCode::from_u16(code) {
+                if status.is_client_error() || status.is_server_error() {
+                    return Some(status);
+                }
+            }
+        }
+
+        None
+    }
+
+    fn is_authentication_error(error: &anyhow::Error) -> bool {
+        if let Some(status) = Self::error_status_code(error) {
+            if status == reqwest::StatusCode::UNAUTHORIZED
+                || status == reqwest::StatusCode::FORBIDDEN
+            {
+                return true;
+            }
+        }
+
+        let lower = error.to_string().to_ascii_lowercase();
+        let auth_hints = [
+            "invalid api key",
+            "incorrect api key",
+            "missing api key",
+            "api key not set",
+            "authentication failed",
+            "auth failed",
+            "unauthorized",
+            "forbidden",
+            "permission denied",
+            "access denied",
+            "invalid token",
+        ];
+
+        auth_hints.iter().any(|hint| lower.contains(hint))
+    }
+
+    fn should_fallback_to_chat_completions(error: &anyhow::Error) -> bool {
+        if Self::is_authentication_error(error) {
+            return false;
+        }
+
+        if let Some(status) = Self::error_status_code(error) {
+            return status == reqwest::StatusCode::NOT_FOUND
+                || status == reqwest::StatusCode::REQUEST_TIMEOUT
+                || status == reqwest::StatusCode::TOO_MANY_REQUESTS
+                || status.is_server_error();
+        }
+
+        if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
+            if reqwest_error.is_connect()
+                || reqwest_error.is_timeout()
+                || reqwest_error.is_request()
+                || reqwest_error.is_body()
+                || reqwest_error.is_decode()
+            {
+                return true;
+            }
+        }
+
+        let lower = error.to_string().to_ascii_lowercase();
+        lower.contains("responses api returned an unexpected payload")
+            || lower.contains("no response from")
+    }
+
     fn effective_max_tokens(&self) -> Option<u32> {
         self.max_tokens_override.filter(|value| *value > 0)
     }
@@ -1331,8 +1416,10 @@ impl OpenAiCompatibleProvider {
             .await?;

         if !response.status().is_success() {
+            let status = response.status();
             let error = response.text().await?;
-            anyhow::bail!("{} Responses API error: {error}", self.name);
+            let sanitized = super::sanitize_api_error(&error);
+            anyhow::bail!("{} Responses API error ({status}): {sanitized}", self.name);
         }

         let body = response.text().await?;
@@ -1383,10 +1470,37 @@ impl OpenAiCompatibleProvider {
         credential: &str,
         messages: &[ChatMessage],
         model: &str,
+        temperature: f64,
     ) -> anyhow::Result<String> {
-        let responses = self
+        let responses = match self
             .send_responses_request(credential, messages, model, None)
-            .await?;
+            .await
+        {
+            Ok(response) => response,
+            Err(responses_err) => {
+                if self.should_use_responses_mode()
+                    && Self::should_fallback_to_chat_completions(&responses_err)
+                {
+                    tracing::warn!(
+                        provider = %self.name,
+                        error = %responses_err,
+                        "Responses API request failed in responses mode; retrying via chat completions"
+                    );
+                    let fallback_provider = self.chat_completions_fallback_provider();
+                    let sanitized = super::sanitize_api_error(&responses_err.to_string());
+                    return fallback_provider
+                        .chat_with_history(messages, model, temperature)
+                        .await
+                        .map_err(|chat_err| {
+                            anyhow::anyhow!(
+                                "{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
+                                self.name
+                            )
+                        });
+                }
+                return Err(responses_err);
+            }
+        };
         extract_responses_text(&responses)
             .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
     }
@@ -1397,10 +1511,51 @@ impl OpenAiCompatibleProvider {
         messages: &[ChatMessage],
         model: &str,
         tools: Option<Vec<Value>>,
+        temperature: f64,
     ) -> anyhow::Result<ProviderChatResponse> {
-        let responses = self
-            .send_responses_request(credential, messages, model, tools)
-            .await?;
+        let responses = match self
+            .send_responses_request(credential, messages, model, tools.clone())
+            .await
+        {
+            Ok(response) => response,
+            Err(responses_err) => {
+                if self.should_use_responses_mode()
+                    && Self::should_fallback_to_chat_completions(&responses_err)
+                {
+                    tracing::warn!(
+                        provider = %self.name,
+                        error = %responses_err,
+                        "Responses API request failed in responses mode; retrying via chat completions"
+                    );
+                    let fallback_provider = self.chat_completions_fallback_provider();
+                    let fallback_tool_specs = tools
+                        .as_deref()
+                        .map(Self::openai_tools_to_tool_specs)
+                        .unwrap_or_default();
+                    let fallback_tools =
+                        (!fallback_tool_specs.is_empty()).then_some(fallback_tool_specs.as_slice());
+                    let sanitized = super::sanitize_api_error(&responses_err.to_string());
+
+                    return fallback_provider
+                        .chat(
+                            ProviderChatRequest {
+                                messages,
+                                tools: fallback_tools,
+                            },
+                            model,
+                            temperature,
+                        )
+                        .await
+                        .map_err(|chat_err| {
+                            anyhow::anyhow!(
+                                "{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
+                                self.name
+                            )
+                        });
+                }
+                return Err(responses_err);
+            }
+        };
         let parsed = parse_responses_chat_response(responses);
         if parsed.text.is_none() && parsed.tool_calls.is_empty() {
             anyhow::bail!("No response from {} Responses API", self.name);
@@ -1714,7 +1869,7 @@ impl Provider for OpenAiCompatibleProvider {

         if self.should_use_responses_mode() {
             return self
-                .chat_via_responses(credential, &fallback_messages, model)
+                .chat_via_responses(credential, &fallback_messages, model, temperature)
                 .await;
         }

@@ -1728,7 +1883,7 @@ impl Provider for OpenAiCompatibleProvider {
             if self.supports_responses_fallback {
                 let sanitized = super::sanitize_api_error(&chat_error.to_string());
                 return self
-                    .chat_via_responses(credential, &fallback_messages, model)
+                    .chat_via_responses(credential, &fallback_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1749,7 +1904,7 @@ impl Provider for OpenAiCompatibleProvider {

         if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
             return self
-                .chat_via_responses(credential, &fallback_messages, model)
+                .chat_via_responses(credential, &fallback_messages, model, temperature)
                 .await
                 .map_err(|responses_err| {
                     anyhow::anyhow!(
@@ -1830,7 +1985,7 @@ impl Provider for OpenAiCompatibleProvider {

         if self.should_use_responses_mode() {
             return self
-                .chat_via_responses(credential, &effective_messages, model)
+                .chat_via_responses(credential, &effective_messages, model, temperature)
                 .await;
         }

@@ -1845,7 +2000,7 @@ impl Provider for OpenAiCompatibleProvider {
             if self.supports_responses_fallback {
                 let sanitized = super::sanitize_api_error(&chat_error.to_string());
                 return self
-                    .chat_via_responses(credential, &effective_messages, model)
+                    .chat_via_responses(credential, &effective_messages, model, temperature)
                     .await
                     .map_err(|responses_err| {
                         anyhow::anyhow!(
@@ -1865,7 +2020,7 @@ impl Provider for OpenAiCompatibleProvider {
         // Mirror chat_with_system: 404 may mean this provider uses the Responses API
         if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
             return self
-                .chat_via_responses(credential, &effective_messages, model)
+                .chat_via_responses(credential, &effective_messages, model, temperature)
                 .await
                 .map_err(|responses_err| {
                     anyhow::anyhow!(
@@ -1960,6 +2115,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     (!tools.is_empty()).then(|| tools.to_vec()),
+                    temperature,
                 )
                 .await;
         }
@@ -2011,6 +2167,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     (!tools.is_empty()).then(|| tools.to_vec()),
+                    temperature,
                 )
                 .await;
         }
@@ -2098,6 +2255,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await;
         }
@@ -2121,6 +2279,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await
                 .map_err(|responses_err| {
@@ -2158,6 +2317,7 @@ impl Provider for OpenAiCompatibleProvider {
                     &effective_messages,
                     model,
                     response_tools.clone(),
+                    temperature,
                 )
                 .await
                 .map_err(|responses_err| {
@@ -2674,7 +2834,12 @@ mod tests {
     async fn chat_via_responses_requires_non_system_message() {
         let provider = make_provider("custom", "https://api.example.com", Some("test-key"));
         let err = provider
-            .chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test")
+            .chat_via_responses(
+                "test-key",
+                &[ChatMessage::system("policy")],
+                "gpt-test",
+                0.7,
+            )
             .await
             .expect_err("system-only fallback payload should fail");

@@ -2683,6 +2848,278 @@ mod tests {
             .contains("requires at least one non-system message"));
    }

+    #[tokio::test]
+    async fn responses_mode_falls_back_to_chat_completions_on_responses_404() {
+        #[derive(Clone, Default)]
+        struct FallbackState {
+            hits: Arc<Mutex<Vec<String>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<FallbackState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::NOT_FOUND,
+                Json(serde_json::json!({
+                    "error": { "message": "responses endpoint unavailable" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<FallbackState>,
+            Json(payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            assert_eq!(
+                payload.get("model").and_then(Value::as_str),
+                Some("test-model")
+            );
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": "chat fallback ok"
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = FallbackState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let text = provider
+            .chat_with_system(Some("system"), "hello", "test-model", 0.2)
+            .await
+            .expect("responses 404 should retry chat completions in responses mode");
+        assert_eq!(text, "chat fallback ok");
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string(), "chat".to_string()],
+            "must attempt responses first, then chat-completions fallback"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
+    #[tokio::test]
+    async fn responses_mode_does_not_fallback_to_chat_completions_on_auth_error() {
+        #[derive(Clone, Default)]
+        struct AuthFailureState {
+            hits: Arc<Mutex<Vec<String>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<AuthFailureState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::UNAUTHORIZED,
+                Json(serde_json::json!({
+                    "error": { "message": "invalid api key" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<AuthFailureState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": "should not be reached"
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = AuthFailureState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let err = provider
+            .chat_with_system(None, "hello", "test-model", 0.2)
+            .await
+            .expect_err("auth errors should not trigger chat-completions fallback");
+        assert!(err.to_string().contains("401"));
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string()],
+            "auth failures must not trigger fallback chat attempt"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
+    #[tokio::test]
+    async fn responses_mode_native_chat_falls_back_and_preserves_tool_call_id() {
+        #[derive(Clone, Default)]
+        struct NativeFallbackState {
+            hits: Arc<Mutex<Vec<String>>>,
+            chat_payloads: Arc<Mutex<Vec<Value>>>,
+        }
+
+        async fn responses_endpoint(
+            State(state): State<NativeFallbackState>,
+            Json(_payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("responses".to_string());
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": "responses backend unavailable" }
+                })),
+            )
+        }
+
+        async fn chat_endpoint(
+            State(state): State<NativeFallbackState>,
+            Json(payload): Json<Value>,
+        ) -> (StatusCode, Json<Value>) {
+            state.hits.lock().await.push("chat".to_string());
+            state.chat_payloads.lock().await.push(payload);
+            (
+                StatusCode::OK,
+                Json(serde_json::json!({
+                    "choices": [{
+                        "message": {
+                            "content": null,
+                            "tool_calls": [{
+                                "id": "call_abc",
+                                "type": "function",
+                                "function": {
+                                    "name": "shell",
+                                    "arguments": "{\"command\":\"pwd\"}"
+                                }
+                            }]
+                        }
+                    }]
+                })),
+            )
+        }
+
+        let state = NativeFallbackState::default();
+        let app = Router::new()
+            .route("/v1/responses", post(responses_endpoint))
+            .route("/chat/completions", post(chat_endpoint))
+            .with_state(state.clone());
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
+            .await
+            .expect("bind test server");
+        let addr = listener.local_addr().expect("server local addr");
+        let server = tokio::spawn(async move {
+            axum::serve(listener, app).await.expect("serve test app");
+        });
+
+        let provider = OpenAiCompatibleProvider::new_custom_with_mode(
+            "custom",
+            &format!("http://{}", addr),
+            Some("test-key"),
+            AuthStyle::Bearer,
+            false,
+            CompatibleApiMode::OpenAiResponses,
+            None,
+        );
+        let messages = vec![ChatMessage::user("run a command")];
+        let tools = vec![crate::tools::ToolSpec {
+            name: "shell".to_string(),
+            description: "Run a command".to_string(),
+            parameters: serde_json::json!({
+                "type": "object",
+                "properties": {"command": {"type": "string"}},
+                "required": ["command"]
+            }),
+        }];
+        let result = provider
+            .chat(
+                ProviderChatRequest {
+                    messages: &messages,
+                    tools: Some(&tools),
+                },
+                "test-model",
+                0.2,
+            )
+            .await
+            .expect("responses server errors should retry via native chat-completions");
+
+        assert_eq!(result.tool_calls.len(), 1);
+        assert_eq!(result.tool_calls[0].id, "call_abc");
+        assert_eq!(result.tool_calls[0].name, "shell");
+
+        let hits = state.hits.lock().await.clone();
+        assert_eq!(
+            hits,
+            vec!["responses".to_string(), "chat".to_string()],
+            "responses mode should retry via chat for retryable errors"
+        );
+
+        let chat_payloads = state.chat_payloads.lock().await;
+        assert_eq!(chat_payloads.len(), 1);
+        assert!(
+            chat_payloads[0].get("tools").is_some(),
+            "fallback native chat request should preserve tool schema"
+        );
+
+        server.abort();
+        let _ = server.await;
+    }
+
     #[test]
     fn tool_call_function_name_falls_back_to_top_level_name() {
         let call: ToolCall = serde_json::from_value(serde_json::json!({