fix(provider): add chat() override to ReliableProvider for native tool calling
ReliableProvider was missing a chat() override, so calls fell through to the default Provider::chat() trait implementation. That default delegates to chat_with_history(), which returns a plain String, and wraps the result in a ChatResponse with tool_calls: Vec::new(). As a result, native tool calling was completely broken through the retry/failover wrapper, even though the underlying provider supports it.

Changes:
- Add chat() with full retry/backoff/failover logic matching the existing chat_with_system(), chat_with_history(), and chat_with_tools() overrides
- Include the context_window_exceeded early exit used by the other methods
- Add 7 focused tests: delegation with tool calls, retry recovery, supports_native_tools propagation, aggregated error reporting, model failover, non-retryable error skip, and system prompt zero-XML verification
This commit is contained in:
parent
04640a963e
commit
4fd41d5f2c
@ -3432,4 +3432,60 @@ Let me check the result."#;
|
||||
assert_eq!(history[0].role, "system");
|
||||
assert_eq!(history[1].content, "new msg");
|
||||
}
|
||||
|
||||
/// Checks that `build_system_prompt_with_mode`, called with
/// `native_tools = true`, produces a prompt free of XML tool-call protocol
/// artifacts.
///
/// `build_tool_instructions` is never called on the native path, so the
/// system prompt on its own has to be clean of the XML protocol.
#[test]
fn native_tools_system_prompt_contains_zero_xml() {
    use crate::channels::build_system_prompt_with_mode;

    let tool_summaries: Vec<(&str, &str)> = vec![
        ("shell", "Execute shell commands"),
        ("file_read", "Read files"),
    ];

    let system_prompt = build_system_prompt_with_mode(
        std::path::Path::new("/tmp"),
        "test-model",
        &tool_summaries,
        &[],  // no skills
        None, // no identity config
        None, // no bootstrap_max_chars
        true, // native_tools
    );

    // Negative checks: none of the XML protocol markers may appear.
    let forbidden = [
        ("<tool_call>", "Native prompt must not contain <tool_call>"),
        ("</tool_call>", "Native prompt must not contain </tool_call>"),
        ("<tool_result>", "Native prompt must not contain <tool_result>"),
        ("</tool_result>", "Native prompt must not contain </tool_result>"),
        (
            "## Tool Use Protocol",
            "Native prompt must not contain XML protocol header",
        ),
    ];
    for (marker, message) in forbidden {
        assert!(!system_prompt.contains(marker), "{}", message);
    }

    // Positive checks: tools are still listed and the task section is present.
    let required = [
        ("shell", "Native prompt must list tool names"),
        (
            "## Your Task",
            "Native prompt should contain task instructions",
        ),
    ];
    for (marker, message) in required {
        assert!(system_prompt.contains(marker), "{}", message);
    }
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
use super::traits::{ChatMessage, ChatResponse, StreamChunk, StreamOptions, StreamResult};
|
||||
use super::traits::{
|
||||
ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamOptions, StreamResult,
|
||||
};
|
||||
use super::Provider;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{stream, StreamExt};
|
||||
@ -548,6 +550,115 @@ impl Provider for ReliableProvider {
|
||||
.any(|(_, provider)| provider.supports_vision())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
let mut backoff_ms = self.base_backoff_ms;
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
let req = ChatRequest {
|
||||
messages: request.messages,
|
||||
tools: request.tools,
|
||||
};
|
||||
match provider.chat(req, current_model, temperature).await {
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
let failure_reason = failure_reason(rate_limited, non_retryable);
|
||||
let error_detail = compact_error_detail(&e);
|
||||
|
||||
push_failure(
|
||||
&mut failures,
|
||||
provider_name,
|
||||
current_model,
|
||||
attempt + 1,
|
||||
self.max_retries + 1,
|
||||
failure_reason,
|
||||
&error_detail,
|
||||
);
|
||||
|
||||
if rate_limited && !non_retryable_rate_limit {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
error = %error_detail,
|
||||
"Rate limited, rotated API key (key ending ...{})",
|
||||
&new_key[new_key.len().saturating_sub(4)..]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if non_retryable {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
error = %error_detail,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
|
||||
if is_context_window_exceeded(&e) {
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if attempt < self.max_retries {
|
||||
let wait = self.compute_backoff(backoff_ms, &e);
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt = attempt + 1,
|
||||
backoff_ms = wait,
|
||||
reason = failure_reason,
|
||||
error = %error_detail,
|
||||
"Provider call failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(wait)).await;
|
||||
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
"Exhausted retries, trying next provider/model"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"All providers/models failed. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_with_tools(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
@ -1509,4 +1620,350 @@ mod tests {
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock provider that implements `chat()` with native tool support.
|
||||
struct NativeToolMock {
|
||||
calls: Arc<AtomicUsize>,
|
||||
fail_until_attempt: usize,
|
||||
response_text: &'static str,
|
||||
tool_calls: Vec<super::super::traits::ToolCall>,
|
||||
error: &'static str,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for NativeToolMock {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(self.response_text.to_string())
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if attempt <= self.fail_until_attempt {
|
||||
anyhow::bail!(self.error);
|
||||
}
|
||||
Ok(ChatResponse {
|
||||
text: Some(self.response_text.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_delegates_to_inner_provider() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let tool_call = super::super::traits::ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
name: "shell".to_string(),
|
||||
arguments: r#"{"command":"date"}"#.to_string(),
|
||||
};
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 0,
|
||||
response_text: "ok",
|
||||
tool_calls: vec![tool_call.clone()],
|
||||
error: "boom",
|
||||
}) as Box<dyn Provider>,
|
||||
)],
|
||||
2,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("what time is it?")];
|
||||
let request = ChatRequest {
|
||||
messages: &messages,
|
||||
tools: None,
|
||||
};
|
||||
let result = provider.chat(request, "test-model", 0.0).await.unwrap();
|
||||
|
||||
assert_eq!(result.text.as_deref(), Some("ok"));
|
||||
assert_eq!(result.tool_calls.len(), 1);
|
||||
assert_eq!(result.tool_calls[0].name, "shell");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_retries_and_recovers() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let tool_call = super::super::traits::ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
name: "shell".to_string(),
|
||||
arguments: r#"{"command":"date"}"#.to_string(),
|
||||
};
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 2,
|
||||
response_text: "recovered",
|
||||
tool_calls: vec![tool_call],
|
||||
error: "temporary failure",
|
||||
}) as Box<dyn Provider>,
|
||||
)],
|
||||
3,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("test")];
|
||||
let request = ChatRequest {
|
||||
messages: &messages,
|
||||
tools: None,
|
||||
};
|
||||
let result = provider.chat(request, "test-model", 0.0).await.unwrap();
|
||||
|
||||
assert_eq!(result.text.as_deref(), Some("recovered"));
|
||||
assert!(
|
||||
calls.load(Ordering::SeqCst) > 1,
|
||||
"should have retried at least once"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_preserves_native_tools_support() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 0,
|
||||
response_text: "ok",
|
||||
tool_calls: vec![],
|
||||
error: "boom",
|
||||
}) as Box<dyn Provider>,
|
||||
)],
|
||||
2,
|
||||
1,
|
||||
);
|
||||
|
||||
assert!(
|
||||
provider.supports_native_tools(),
|
||||
"ReliableProvider must propagate supports_native_tools from inner provider"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Gap 2-4: Parity tests for chat() ────────────────────────
|
||||
|
||||
/// Gap 2: `chat()` returns an aggregated error when all providers fail,
|
||||
/// matching behavior of `returns_aggregated_error_when_all_providers_fail`.
|
||||
#[tokio::test]
|
||||
async fn chat_returns_aggregated_error_when_all_providers_fail() {
|
||||
let provider = ReliableProvider::new(
|
||||
vec![
|
||||
(
|
||||
"p1".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response_text: "never",
|
||||
tool_calls: vec![],
|
||||
error: "p1 chat error",
|
||||
}) as Box<dyn Provider>,
|
||||
),
|
||||
(
|
||||
"p2".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response_text: "never",
|
||||
tool_calls: vec![],
|
||||
error: "p2 chat error",
|
||||
}) as Box<dyn Provider>,
|
||||
),
|
||||
],
|
||||
0,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("hello")];
|
||||
let request = ChatRequest {
|
||||
messages: &messages,
|
||||
tools: None,
|
||||
};
|
||||
let err = provider
|
||||
.chat(request, "test", 0.0)
|
||||
.await
|
||||
.expect_err("all providers should fail");
|
||||
let msg = err.to_string();
|
||||
assert!(msg.contains("All providers/models failed"));
|
||||
assert!(msg.contains("provider=p1 model=test"));
|
||||
assert!(msg.contains("provider=p2 model=test"));
|
||||
assert!(msg.contains("error=p1 chat error"));
|
||||
assert!(msg.contains("error=p2 chat error"));
|
||||
assert!(msg.contains("retryable"));
|
||||
}
|
||||
|
||||
/// Mock that records model names and can fail specific models,
|
||||
/// implementing `chat()` for native tool calling parity tests.
|
||||
struct NativeModelAwareMock {
|
||||
calls: Arc<AtomicUsize>,
|
||||
models_seen: parking_lot::Mutex<Vec<String>>,
|
||||
fail_models: Vec<&'static str>,
|
||||
response_text: &'static str,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for NativeModelAwareMock {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(self.response_text.to_string())
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.models_seen.lock().push(model.to_string());
|
||||
if self.fail_models.contains(&model) {
|
||||
anyhow::bail!("500 model {} unavailable", model);
|
||||
}
|
||||
Ok(ChatResponse {
|
||||
text: Some(self.response_text.to_string()),
|
||||
tool_calls: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for Arc<NativeModelAwareMock> {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.as_ref()
|
||||
.chat_with_system(system_prompt, message, model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
self.as_ref().chat(request, model, temperature).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Gap 3: `chat()` tries fallback models on failure,
|
||||
/// matching behavior of `model_failover_tries_fallback_model`.
|
||||
#[tokio::test]
|
||||
async fn chat_tries_model_failover_on_failure() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = Arc::new(NativeModelAwareMock {
|
||||
calls: Arc::clone(&calls),
|
||||
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||
fail_models: vec!["claude-opus"],
|
||||
response_text: "ok from sonnet",
|
||||
});
|
||||
|
||||
let mut fallbacks = HashMap::new();
|
||||
fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"anthropic".into(),
|
||||
Box::new(mock.clone()) as Box<dyn Provider>,
|
||||
)],
|
||||
0, // no retries — force immediate model failover
|
||||
1,
|
||||
)
|
||||
.with_model_fallbacks(fallbacks);
|
||||
|
||||
let messages = vec![ChatMessage::user("hello")];
|
||||
let request = ChatRequest {
|
||||
messages: &messages,
|
||||
tools: None,
|
||||
};
|
||||
let result = provider.chat(request, "claude-opus", 0.0).await.unwrap();
|
||||
assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
|
||||
|
||||
let seen = mock.models_seen.lock();
|
||||
assert_eq!(seen.len(), 2);
|
||||
assert_eq!(seen[0], "claude-opus");
|
||||
assert_eq!(seen[1], "claude-sonnet");
|
||||
}
|
||||
|
||||
/// Gap 4: `chat()` skips retries on non-retryable errors (401, 403, etc.),
|
||||
/// matching behavior of `skips_retries_on_non_retryable_error`.
|
||||
#[tokio::test]
|
||||
async fn chat_skips_non_retryable_errors() {
|
||||
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![
|
||||
(
|
||||
"primary".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::clone(&primary_calls),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response_text: "never",
|
||||
tool_calls: vec![],
|
||||
error: "401 Unauthorized",
|
||||
}) as Box<dyn Provider>,
|
||||
),
|
||||
(
|
||||
"fallback".into(),
|
||||
Box::new(NativeToolMock {
|
||||
calls: Arc::clone(&fallback_calls),
|
||||
fail_until_attempt: 0,
|
||||
response_text: "from fallback",
|
||||
tool_calls: vec![],
|
||||
error: "fallback err",
|
||||
}) as Box<dyn Provider>,
|
||||
),
|
||||
],
|
||||
3,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("hello")];
|
||||
let request = ChatRequest {
|
||||
messages: &messages,
|
||||
tools: None,
|
||||
};
|
||||
let result = provider.chat(request, "test", 0.0).await.unwrap();
|
||||
assert_eq!(result.text.as_deref(), Some("from fallback"));
|
||||
// Primary should have been called only once (no retries)
|
||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user