From 162efbb49cf67a3fdb068a12ef8a2e1c4d6939e3 Mon Sep 17 00:00:00 2001 From: Argenis Date: Wed, 18 Mar 2026 14:54:56 -0400 Subject: [PATCH] fix(providers): recover from context window errors by truncating history (#3908) When a provider returns a context-size-exceeded error, truncate the oldest non-system messages from conversation history and retry instead of immediately bailing out. This enables local models with small context windows (llamafile, llama.cpp) to work by automatically fitting the conversation within available context. Closes #3894 --- src/providers/reliable.rs | 291 ++++++++++++++++++++++++++++++++------ 1 file changed, 250 insertions(+), 41 deletions(-) diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index b90513117..66c095948 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -16,8 +16,10 @@ use std::time::Duration; /// Check if an error is non-retryable (client errors that won't resolve with retries). pub fn is_non_retryable(err: &anyhow::Error) -> bool { + // Context window errors are NOT non-retryable — they can be recovered + // by truncating conversation history, so let the retry loop handle them. if is_context_window_exceeded(err) { - return true; + return false; } // 4xx errors are generally non-retryable (bad request, auth failure, etc.), @@ -75,6 +77,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool { let lower = err.to_string().to_lowercase(); let hints = [ "exceeds the context window", + "exceeds the available context size", "context window of this model", "maximum context length", "context length exceeded", @@ -197,6 +200,35 @@ fn compact_error_detail(err: &anyhow::Error) -> String { .join(" ") } +/// Truncate conversation history by dropping the oldest non-system messages. +/// Returns the number of messages dropped. Keeps at least the system message +/// (if any) and the most recent user message. +fn truncate_for_context(messages: &mut Vec) -> usize { + // Find all non-system message indices + let non_system: Vec = messages + .iter() + .enumerate() + .filter(|(_, m)| m.role != "system") + .map(|(i, _)| i) + .collect(); + + // Keep at least the last non-system message (most recent user turn) + if non_system.len() <= 1 { + return 0; + } + + // Drop the oldest half of non-system messages + let drop_count = non_system.len() / 2; + let indices_to_remove: Vec = non_system[..drop_count].to_vec(); + + // Remove in reverse order to preserve indices + for &idx in indices_to_remove.iter().rev() { + messages.remove(idx); + } + + drop_count +} + fn push_failure( failures: &mut Vec, provider_name: &str, @@ -338,6 +370,25 @@ impl Provider for ReliableProvider { return Ok(resp); } Err(e) => { + // Context window exceeded: no history to truncate + // in chat_with_system, bail immediately. + if is_context_window_exceeded(&e) { + let error_detail = compact_error_detail(&e); + push_failure( + &mut failures, + provider_name, + current_model, + attempt + 1, + self.max_retries + 1, + "non_retryable", + &error_detail, + ); + anyhow::bail!( + "Request exceeds model context window. Attempts:\n{}", + failures.join("\n") + ); + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -376,14 +427,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -435,6 +478,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -442,22 +487,39 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { match provider - .chat_with_history(messages, current_model, temperature) + .chat_with_history(&effective_messages, current_model, temperature) .await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -494,14 +556,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -559,6 +613,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -566,22 +622,39 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { match provider - .chat_with_tools(messages, tools, current_model, temperature) + .chat_with_tools(&effective_messages, tools, current_model, temperature) .await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -618,14 +691,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -669,6 +734,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = request.messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -676,23 +743,40 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { let req = ChatRequest { - messages: request.messages, + messages: &effective_messages, tools: request.tools, }; match provider.chat(req, current_model, temperature).await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -729,14 +813,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -1071,7 +1147,8 @@ mod tests { assert!(!is_non_retryable(&anyhow::anyhow!( "model overloaded, try again later" ))); - assert!(is_non_retryable(&anyhow::anyhow!( + // Context window errors are now recoverable (not non-retryable) + assert!(!is_non_retryable(&anyhow::anyhow!( "OpenAI Codex stream error: Your input exceeds the context window of this model." ))); } @@ -1107,7 +1184,7 @@ mod tests { let msg = err.to_string(); assert!(msg.contains("context window")); - assert!(msg.contains("skipped")); + // chat_with_system has no history to truncate, so it bails immediately assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -1980,4 +2057,136 @@ mod tests { assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } + + // ── Context window truncation tests ───────────────────────── + + #[test] + fn context_window_error_is_not_non_retryable() { + // Context window errors should be recoverable via truncation + assert!(!is_non_retryable(&anyhow::anyhow!( + "exceeds the context window" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!( + "maximum context length exceeded" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!( + "too many tokens in the request" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!("token limit exceeded"))); + } + + #[test] + fn is_context_window_exceeded_detects_llamacpp() { + assert!(is_context_window_exceeded(&anyhow::anyhow!( + "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it" + ))); + } + + #[test] + fn truncate_for_context_drops_oldest_non_system() { + let mut messages = vec![ + ChatMessage::system("sys"), + ChatMessage::user("msg1"), + ChatMessage::assistant("resp1"), + ChatMessage::user("msg2"), + ChatMessage::assistant("resp2"), + ChatMessage::user("msg3"), + ]; + + let dropped = truncate_for_context(&mut messages); + + // 5 non-system messages, drop oldest half = 2 + assert_eq!(dropped, 2); + // System message preserved + assert_eq!(messages[0].role, "system"); + // Remaining messages should be the newer ones + assert_eq!(messages.len(), 4); // system + 3 remaining non-system + // The last message should still be the most recent user message + assert_eq!(messages.last().unwrap().content, "msg3"); + } + + #[test] + fn truncate_for_context_preserves_system_and_last_message() { + // Only one non-system message: nothing to drop + let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")]; + let dropped = truncate_for_context(&mut messages); + assert_eq!(dropped, 0); + assert_eq!(messages.len(), 2); + + // No system message, only one user message + let mut messages = vec![ChatMessage::user("only")]; + let dropped = truncate_for_context(&mut messages); + assert_eq!(dropped, 0); + assert_eq!(messages.len(), 1); + } + + /// Mock that fails with context error on first N calls, then succeeds. + /// Tracks the number of messages received on each call. + struct ContextOverflowMock { + calls: Arc, + fail_until_attempt: usize, + message_counts: parking_lot::Mutex>, + } + + #[async_trait] + impl Provider for ContextOverflowMock { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".to_string()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + self.message_counts.lock().push(messages.len()); + if attempt <= self.fail_until_attempt { + anyhow::bail!( + "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it" + ); + } + Ok("recovered after truncation".to_string()) + } + } + + #[tokio::test] + async fn chat_with_history_truncates_on_context_overflow() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = ContextOverflowMock { + calls: Arc::clone(&calls), + fail_until_attempt: 1, // fail first call, succeed after truncation + message_counts: parking_lot::Mutex::new(Vec::new()), + }; + + let provider = ReliableProvider::new( + vec![("local".into(), Box::new(mock) as Box)], + 3, + 1, + ); + + let messages = vec![ + ChatMessage::system("system prompt"), + ChatMessage::user("old message 1"), + ChatMessage::assistant("old response 1"), + ChatMessage::user("old message 2"), + ChatMessage::assistant("old response 2"), + ChatMessage::user("current question"), + ]; + + let result = provider + .chat_with_history(&messages, "local-model", 0.0) + .await + .unwrap(); + assert_eq!(result, "recovered after truncation"); + // Should have been called twice: once with full messages, once with truncated + assert_eq!(calls.load(Ordering::SeqCst), 2); + } }