fix(providers): recover from context window errors by truncating history (#3908)

When a provider returns a context-size-exceeded error, truncate the
oldest non-system messages from conversation history and retry instead
of immediately bailing out. This enables local models with small
context windows (llamafile, llama.cpp) to work by automatically
fitting the conversation within available context.

Closes #3894
This commit is contained in:
Argenis 2026-03-18 14:54:56 -04:00 committed by GitHub
parent 58b98c59a8
commit 162efbb49c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,8 +16,10 @@ use std::time::Duration;
/// Check if an error is non-retryable (client errors that won't resolve with retries).
pub fn is_non_retryable(err: &anyhow::Error) -> bool {
// Context window errors are NOT non-retryable — they can be recovered
// by truncating conversation history, so let the retry loop handle them.
if is_context_window_exceeded(err) {
return true;
return false;
}
// 4xx errors are generally non-retryable (bad request, auth failure, etc.),
@ -75,6 +77,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
let lower = err.to_string().to_lowercase();
let hints = [
"exceeds the context window",
"exceeds the available context size",
"context window of this model",
"maximum context length",
"context length exceeded",
@ -197,6 +200,35 @@ fn compact_error_detail(err: &anyhow::Error) -> String {
.join(" ")
}
/// Truncate conversation history by dropping the oldest non-system messages.
/// Returns the number of messages dropped. Keeps every system message and at
/// least the most recent non-system message (typically the latest user turn).
fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
    // Count the messages eligible for truncation (system messages are never dropped).
    let non_system_count = messages.iter().filter(|m| m.role != "system").count();
    // Keep at least the last non-system message (most recent user turn).
    if non_system_count <= 1 {
        return 0;
    }
    // Drop the oldest half of the non-system messages. A single `retain`
    // pass is O(n), unlike repeated `Vec::remove` which is O(n²).
    let drop_count = non_system_count / 2;
    let mut dropped = 0;
    messages.retain(|m| {
        if m.role != "system" && dropped < drop_count {
            dropped += 1;
            false
        } else {
            true
        }
    });
    drop_count
}
fn push_failure(
failures: &mut Vec<String>,
provider_name: &str,
@ -338,6 +370,25 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
// Context window exceeded: no history to truncate
// in chat_with_system, bail immediately.
if is_context_window_exceeded(&e) {
let error_detail = compact_error_detail(&e);
push_failure(
&mut failures,
provider_name,
current_model,
attempt + 1,
self.max_retries + 1,
"non_retryable",
&error_detail,
);
anyhow::bail!(
"Request exceeds model context window. Attempts:\n{}",
failures.join("\n")
);
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@ -376,14 +427,6 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@ -435,6 +478,8 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@ -442,22 +487,39 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
match provider
.chat_with_history(messages, current_model, temperature)
.chat_with_history(&effective_messages, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model {
if attempt > 0 || *current_model != model || context_truncated {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@ -494,14 +556,6 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@ -559,6 +613,8 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<ChatResponse> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@ -566,22 +622,39 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
match provider
.chat_with_tools(messages, tools, current_model, temperature)
.chat_with_tools(&effective_messages, tools, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model {
if attempt > 0 || *current_model != model || context_truncated {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@ -618,14 +691,6 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@ -669,6 +734,8 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<ChatResponse> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = request.messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@ -676,23 +743,40 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
let req = ChatRequest {
messages: request.messages,
messages: &effective_messages,
tools: request.tools,
};
match provider.chat(req, current_model, temperature).await {
Ok(resp) => {
if attempt > 0 || *current_model != model {
if attempt > 0 || *current_model != model || context_truncated {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@ -729,14 +813,6 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@ -1071,7 +1147,8 @@ mod tests {
assert!(!is_non_retryable(&anyhow::anyhow!(
"model overloaded, try again later"
)));
assert!(is_non_retryable(&anyhow::anyhow!(
// Context window errors are now recoverable (not non-retryable)
assert!(!is_non_retryable(&anyhow::anyhow!(
"OpenAI Codex stream error: Your input exceeds the context window of this model."
)));
}
@ -1107,7 +1184,7 @@ mod tests {
let msg = err.to_string();
assert!(msg.contains("context window"));
assert!(msg.contains("skipped"));
// chat_with_system has no history to truncate, so it bails immediately
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@ -1980,4 +2057,136 @@ mod tests {
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
// ── Context window truncation tests ─────────────────────────
#[test]
fn context_window_error_is_not_non_retryable() {
    // Truncation makes context-window errors recoverable, so none of these
    // messages may be classified as non-retryable.
    let samples = [
        "exceeds the context window",
        "maximum context length exceeded",
        "too many tokens in the request",
        "token limit exceeded",
    ];
    for msg in samples {
        assert!(!is_non_retryable(&anyhow::anyhow!("{msg}")), "{msg}");
    }
}
#[test]
fn is_context_window_exceeded_detects_llamacpp() {
    // Exact error text emitted by llama.cpp's server when the prompt does
    // not fit in the configured context.
    let err = anyhow::anyhow!(
        "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
    );
    assert!(is_context_window_exceeded(&err));
}
#[test]
fn truncate_for_context_drops_oldest_non_system() {
    let mut history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("msg1"),
        ChatMessage::assistant("resp1"),
        ChatMessage::user("msg2"),
        ChatMessage::assistant("resp2"),
        ChatMessage::user("msg3"),
    ];
    // Of the 5 non-system messages, the oldest half (2) should be dropped.
    assert_eq!(truncate_for_context(&mut history), 2);
    // The system prompt survives truncation.
    assert_eq!(history[0].role, "system");
    // system + 3 surviving non-system messages remain.
    assert_eq!(history.len(), 4);
    // The most recent user turn is still the final message.
    assert_eq!(history.last().unwrap().content, "msg3");
}
#[test]
fn truncate_for_context_preserves_system_and_last_message() {
    // A single non-system message beside a system prompt: nothing to drop.
    let mut with_system = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
    assert_eq!(truncate_for_context(&mut with_system), 0);
    assert_eq!(with_system.len(), 2);

    // A lone user message with no system prompt is likewise left untouched.
    let mut lone = vec![ChatMessage::user("only")];
    assert_eq!(truncate_for_context(&mut lone), 0);
    assert_eq!(lone.len(), 1);
}
/// Mock that fails with context error on first N calls, then succeeds.
/// Tracks the number of messages received on each call.
struct ContextOverflowMock {
    // Shared call counter; incremented once per chat_with_history invocation.
    calls: Arc<AtomicUsize>,
    // Calls numbered 1..=fail_until_attempt return a context-size error;
    // later calls succeed.
    fail_until_attempt: usize,
    // History length observed on each call, recorded in call order.
    message_counts: parking_lot::Mutex<Vec<usize>>,
}
#[async_trait]
impl Provider for ContextOverflowMock {
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        Ok("ok".to_string())
    }

    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // Record this call's history length and derive its 1-based number.
        let call_number = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
        self.message_counts.lock().push(messages.len());
        if call_number > self.fail_until_attempt {
            Ok("recovered after truncation".to_string())
        } else {
            // Replicates llama.cpp's context-overflow error text.
            anyhow::bail!(
                "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
            )
        }
    }
}
#[tokio::test]
async fn chat_with_history_truncates_on_context_overflow() {
    let call_counter = Arc::new(AtomicUsize::new(0));
    let mock = ContextOverflowMock {
        calls: Arc::clone(&call_counter),
        // First call overflows; the retry (after truncation) succeeds.
        fail_until_attempt: 1,
        message_counts: parking_lot::Mutex::new(Vec::new()),
    };
    let provider = ReliableProvider::new(
        vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
        3,
        1,
    );
    let history = vec![
        ChatMessage::system("system prompt"),
        ChatMessage::user("old message 1"),
        ChatMessage::assistant("old response 1"),
        ChatMessage::user("old message 2"),
        ChatMessage::assistant("old response 2"),
        ChatMessage::user("current question"),
    ];
    let reply = provider
        .chat_with_history(&history, "local-model", 0.0)
        .await
        .unwrap();
    assert_eq!(reply, "recovered after truncation");
    // Exactly two provider calls: full history, then the truncated retry.
    assert_eq!(call_counter.load(Ordering::SeqCst), 2);
}
}