fix(provider): add chat() override to ReliableProvider for native tool calling

ReliableProvider was missing a chat() override, causing it to fall through
to the default Provider::chat() trait implementation. The default
implementation delegates to chat_with_history() which returns a plain
String and wraps it in ChatResponse with tool_calls: Vec::new() — so
native tool calling was completely broken through the retry/failover
wrapper even though the underlying provider properly supports it.

Changes:
- Add chat() with full retry/backoff/failover logic matching existing
  chat_with_system(), chat_with_history(), and chat_with_tools() overrides
- Include context_window_exceeded early-exit matching other method patterns
- Add 7 focused tests: delegation with tool calls, retry recovery,
  supports_native_tools propagation, aggregated error reporting,
  model failover, non-retryable error skip, and system prompt zero-XML
  verification
This commit is contained in:
Vernon Stinebaker 2026-02-20 14:45:24 +08:00 committed by Chummy
parent 04640a963e
commit 4fd41d5f2c
2 changed files with 514 additions and 1 deletions

View File

@ -3432,4 +3432,60 @@ Let me check the result."#;
assert_eq!(history[0].role, "system");
assert_eq!(history[1].content, "new msg");
}
/// When `build_system_prompt_with_mode` is called with `native_tools = true`,
/// the output must contain ZERO XML protocol artifacts. In the native path
/// `build_tool_instructions` is never called, so the system prompt alone
/// must be clean of XML tool-call protocol.
#[test]
fn native_tools_system_prompt_contains_zero_xml() {
use crate::channels::build_system_prompt_with_mode;
let tool_summaries: Vec<(&str, &str)> = vec![
("shell", "Execute shell commands"),
("file_read", "Read files"),
];
let system_prompt = build_system_prompt_with_mode(
std::path::Path::new("/tmp"),
"test-model",
&tool_summaries,
&[], // no skills
None, // no identity config
None, // no bootstrap_max_chars
true, // native_tools
);
// Must contain zero XML protocol artifacts
assert!(
!system_prompt.contains("<tool_call>"),
"Native prompt must not contain <tool_call>"
);
assert!(
!system_prompt.contains("</tool_call>"),
"Native prompt must not contain </tool_call>"
);
assert!(
!system_prompt.contains("<tool_result>"),
"Native prompt must not contain <tool_result>"
);
assert!(
!system_prompt.contains("</tool_result>"),
"Native prompt must not contain </tool_result>"
);
assert!(
!system_prompt.contains("## Tool Use Protocol"),
"Native prompt must not contain XML protocol header"
);
// Positive: native prompt should still list tools and contain task instructions
assert!(
system_prompt.contains("shell"),
"Native prompt must list tool names"
);
assert!(
system_prompt.contains("## Your Task"),
"Native prompt should contain task instructions"
);
}
}

View File

@ -1,4 +1,6 @@
use super::traits::{ChatMessage, ChatResponse, StreamChunk, StreamOptions, StreamResult};
use super::traits::{
ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamOptions, StreamResult,
};
use super::Provider;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
@ -548,6 +550,115 @@ impl Provider for ReliableProvider {
.any(|(_, provider)| provider.supports_vision())
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let models = self.model_chain(model);
let mut failures = Vec::new();
for current_model in &models {
for (provider_name, provider) in &self.providers {
let mut backoff_ms = self.base_backoff_ms;
for attempt in 0..=self.max_retries {
let req = ChatRequest {
messages: request.messages,
tools: request.tools,
};
match provider.chat(req, current_model, temperature).await {
Ok(resp) => {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
push_failure(
&mut failures,
provider_name,
current_model,
attempt + 1,
self.max_retries + 1,
failure_reason,
&error_detail,
);
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
error = %error_detail,
"Rate limited, rotated API key (key ending ...{})",
&new_key[new_key.len().saturating_sub(4)..]
);
}
}
if non_retryable {
tracing::warn!(
provider = provider_name,
model = *current_model,
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
if attempt < self.max_retries {
let wait = self.compute_backoff(backoff_ms, &e);
tracing::warn!(
provider = provider_name,
model = *current_model,
attempt = attempt + 1,
backoff_ms = wait,
reason = failure_reason,
error = %error_detail,
"Provider call failed, retrying"
);
tokio::time::sleep(Duration::from_millis(wait)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
}
}
}
}
tracing::warn!(
provider = provider_name,
model = *current_model,
"Exhausted retries, trying next provider/model"
);
}
}
anyhow::bail!(
"All providers/models failed. Attempts:\n{}",
failures.join("\n")
)
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
@ -1509,4 +1620,350 @@ mod tests {
.await
}
}
/// Mock provider that implements `chat()` with native tool support.
struct NativeToolMock {
calls: Arc<AtomicUsize>,
fail_until_attempt: usize,
response_text: &'static str,
tool_calls: Vec<super::super::traits::ToolCall>,
error: &'static str,
}
#[async_trait]
impl Provider for NativeToolMock {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok(self.response_text.to_string())
}
fn supports_native_tools(&self) -> bool {
true
}
async fn chat(
&self,
_request: ChatRequest<'_>,
_model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
if attempt <= self.fail_until_attempt {
anyhow::bail!(self.error);
}
Ok(ChatResponse {
text: Some(self.response_text.to_string()),
tool_calls: self.tool_calls.clone(),
})
}
}
#[tokio::test]
async fn chat_delegates_to_inner_provider() {
let calls = Arc::new(AtomicUsize::new(0));
let tool_call = super::super::traits::ToolCall {
id: "call_1".to_string(),
name: "shell".to_string(),
arguments: r#"{"command":"date"}"#.to_string(),
};
let provider = ReliableProvider::new(
vec![(
"primary".into(),
Box::new(NativeToolMock {
calls: Arc::clone(&calls),
fail_until_attempt: 0,
response_text: "ok",
tool_calls: vec![tool_call.clone()],
error: "boom",
}) as Box<dyn Provider>,
)],
2,
1,
);
let messages = vec![ChatMessage::user("what time is it?")];
let request = ChatRequest {
messages: &messages,
tools: None,
};
let result = provider.chat(request, "test-model", 0.0).await.unwrap();
assert_eq!(result.text.as_deref(), Some("ok"));
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "shell");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn chat_retries_and_recovers() {
let calls = Arc::new(AtomicUsize::new(0));
let tool_call = super::super::traits::ToolCall {
id: "call_1".to_string(),
name: "shell".to_string(),
arguments: r#"{"command":"date"}"#.to_string(),
};
let provider = ReliableProvider::new(
vec![(
"primary".into(),
Box::new(NativeToolMock {
calls: Arc::clone(&calls),
fail_until_attempt: 2,
response_text: "recovered",
tool_calls: vec![tool_call],
error: "temporary failure",
}) as Box<dyn Provider>,
)],
3,
1,
);
let messages = vec![ChatMessage::user("test")];
let request = ChatRequest {
messages: &messages,
tools: None,
};
let result = provider.chat(request, "test-model", 0.0).await.unwrap();
assert_eq!(result.text.as_deref(), Some("recovered"));
assert!(
calls.load(Ordering::SeqCst) > 1,
"should have retried at least once"
);
}
#[tokio::test]
async fn chat_preserves_native_tools_support() {
let calls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![(
"primary".into(),
Box::new(NativeToolMock {
calls: Arc::clone(&calls),
fail_until_attempt: 0,
response_text: "ok",
tool_calls: vec![],
error: "boom",
}) as Box<dyn Provider>,
)],
2,
1,
);
assert!(
provider.supports_native_tools(),
"ReliableProvider must propagate supports_native_tools from inner provider"
);
}
// ── Gap 2-4: Parity tests for chat() ────────────────────────
/// Gap 2: `chat()` returns an aggregated error when all providers fail,
/// matching behavior of `returns_aggregated_error_when_all_providers_fail`.
#[tokio::test]
async fn chat_returns_aggregated_error_when_all_providers_fail() {
let provider = ReliableProvider::new(
vec![
(
"p1".into(),
Box::new(NativeToolMock {
calls: Arc::new(AtomicUsize::new(0)),
fail_until_attempt: usize::MAX,
response_text: "never",
tool_calls: vec![],
error: "p1 chat error",
}) as Box<dyn Provider>,
),
(
"p2".into(),
Box::new(NativeToolMock {
calls: Arc::new(AtomicUsize::new(0)),
fail_until_attempt: usize::MAX,
response_text: "never",
tool_calls: vec![],
error: "p2 chat error",
}) as Box<dyn Provider>,
),
],
0,
1,
);
let messages = vec![ChatMessage::user("hello")];
let request = ChatRequest {
messages: &messages,
tools: None,
};
let err = provider
.chat(request, "test", 0.0)
.await
.expect_err("all providers should fail");
let msg = err.to_string();
assert!(msg.contains("All providers/models failed"));
assert!(msg.contains("provider=p1 model=test"));
assert!(msg.contains("provider=p2 model=test"));
assert!(msg.contains("error=p1 chat error"));
assert!(msg.contains("error=p2 chat error"));
assert!(msg.contains("retryable"));
}
/// Mock that records model names and can fail specific models,
/// implementing `chat()` for native tool calling parity tests.
struct NativeModelAwareMock {
calls: Arc<AtomicUsize>,
models_seen: parking_lot::Mutex<Vec<String>>,
fail_models: Vec<&'static str>,
response_text: &'static str,
}
#[async_trait]
impl Provider for NativeModelAwareMock {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok(self.response_text.to_string())
}
fn supports_native_tools(&self) -> bool {
true
}
async fn chat(
&self,
_request: ChatRequest<'_>,
model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.models_seen.lock().push(model.to_string());
if self.fail_models.contains(&model) {
anyhow::bail!("500 model {} unavailable", model);
}
Ok(ChatResponse {
text: Some(self.response_text.to_string()),
tool_calls: vec![],
})
}
}
#[async_trait]
impl Provider for Arc<NativeModelAwareMock> {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
self.as_ref()
.chat_with_system(system_prompt, message, model, temperature)
.await
}
fn supports_native_tools(&self) -> bool {
true
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
self.as_ref().chat(request, model, temperature).await
}
}
/// Gap 3: `chat()` tries fallback models on failure,
/// matching behavior of `model_failover_tries_fallback_model`.
#[tokio::test]
async fn chat_tries_model_failover_on_failure() {
let calls = Arc::new(AtomicUsize::new(0));
let mock = Arc::new(NativeModelAwareMock {
calls: Arc::clone(&calls),
models_seen: parking_lot::Mutex::new(Vec::new()),
fail_models: vec!["claude-opus"],
response_text: "ok from sonnet",
});
let mut fallbacks = HashMap::new();
fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
let provider = ReliableProvider::new(
vec![(
"anthropic".into(),
Box::new(mock.clone()) as Box<dyn Provider>,
)],
0, // no retries — force immediate model failover
1,
)
.with_model_fallbacks(fallbacks);
let messages = vec![ChatMessage::user("hello")];
let request = ChatRequest {
messages: &messages,
tools: None,
};
let result = provider.chat(request, "claude-opus", 0.0).await.unwrap();
assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
let seen = mock.models_seen.lock();
assert_eq!(seen.len(), 2);
assert_eq!(seen[0], "claude-opus");
assert_eq!(seen[1], "claude-sonnet");
}
/// Gap 4: `chat()` skips retries on non-retryable errors (401, 403, etc.),
/// matching behavior of `skips_retries_on_non_retryable_error`.
#[tokio::test]
async fn chat_skips_non_retryable_errors() {
let primary_calls = Arc::new(AtomicUsize::new(0));
let fallback_calls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![
(
"primary".into(),
Box::new(NativeToolMock {
calls: Arc::clone(&primary_calls),
fail_until_attempt: usize::MAX,
response_text: "never",
tool_calls: vec![],
error: "401 Unauthorized",
}) as Box<dyn Provider>,
),
(
"fallback".into(),
Box::new(NativeToolMock {
calls: Arc::clone(&fallback_calls),
fail_until_attempt: 0,
response_text: "from fallback",
tool_calls: vec![],
error: "fallback err",
}) as Box<dyn Provider>,
),
],
3,
1,
);
let messages = vec![ChatMessage::user("hello")];
let request = ChatRequest {
messages: &messages,
tools: None,
};
let result = provider.chat(request, "test", 0.0).await.unwrap();
assert_eq!(result.text.as_deref(), Some("from fallback"));
// Primary should have been called only once (no retries)
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
}