diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index b90513117..66c095948 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -16,8 +16,10 @@ use std::time::Duration; /// Check if an error is non-retryable (client errors that won't resolve with retries). pub fn is_non_retryable(err: &anyhow::Error) -> bool { + // Context window errors are NOT non-retryable — they can be recovered + // by truncating conversation history, so let the retry loop handle them. if is_context_window_exceeded(err) { - return true; + return false; } // 4xx errors are generally non-retryable (bad request, auth failure, etc.), @@ -75,6 +77,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool { let lower = err.to_string().to_lowercase(); let hints = [ "exceeds the context window", + "exceeds the available context size", "context window of this model", "maximum context length", "context length exceeded", @@ -197,6 +200,35 @@ fn compact_error_detail(err: &anyhow::Error) -> String { .join(" ") } +/// Truncate conversation history by dropping the oldest non-system messages. +/// Returns the number of messages dropped. Keeps at least the system message +/// (if any) and the most recent user message. 
+fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize { + // Find all non-system message indices + let non_system: Vec<usize> = messages + .iter() + .enumerate() + .filter(|(_, m)| m.role != "system") + .map(|(i, _)| i) + .collect(); + + // Keep at least the last non-system message (most recent user turn) + if non_system.len() <= 1 { + return 0; + } + + // Drop the oldest half of non-system messages + let drop_count = non_system.len() / 2; + let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec(); + + // Remove in reverse order to preserve indices + for &idx in indices_to_remove.iter().rev() { + messages.remove(idx); + } + + drop_count +} + fn push_failure( failures: &mut Vec<String>, provider_name: &str, @@ -338,6 +370,25 @@ impl Provider for ReliableProvider { return Ok(resp); } Err(e) => { + // Context window exceeded: no history to truncate + // in chat_with_system, bail immediately. + if is_context_window_exceeded(&e) { + let error_detail = compact_error_detail(&e); + push_failure( + &mut failures, + provider_name, + current_model, + attempt + 1, + self.max_retries + 1, + "non_retryable", + &error_detail, + ); + anyhow::bail!( + "Request exceeds model context window. Attempts:\n{}", + failures.join("\n") + ); + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -376,14 +427,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. 
Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -435,6 +478,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -442,22 +487,39 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { match provider - .chat_with_history(messages, current_model, temperature) + .chat_with_history(&effective_messages, current_model, temperature) .await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -494,14 +556,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. 
Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -559,6 +613,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -566,22 +622,39 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { match provider - .chat_with_tools(messages, tools, current_model, temperature) + .chat_with_tools(&effective_messages, tools, current_model, temperature) .await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -618,14 +691,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. 
Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -669,6 +734,8 @@ impl Provider for ReliableProvider { ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); + let mut effective_messages = request.messages.to_vec(); + let mut context_truncated = false; for current_model in &models { for (provider_name, provider) in &self.providers { @@ -676,23 +743,40 @@ impl Provider for ReliableProvider { for attempt in 0..=self.max_retries { let req = ChatRequest { - messages: request.messages, + messages: &effective_messages, tools: request.tools, }; match provider.chat(req, current_model, temperature).await { Ok(resp) => { - if attempt > 0 || *current_model != model { + if attempt > 0 || *current_model != model || context_truncated { tracing::info!( provider = provider_name, model = *current_model, attempt, original_model = model, + context_truncated, "Provider recovered (failover/retry)" ); } return Ok(resp); } Err(e) => { + // Context window exceeded: truncate history and retry + if is_context_window_exceeded(&e) && !context_truncated { + let dropped = truncate_for_context(&mut effective_messages); + if dropped > 0 { + context_truncated = true; + tracing::warn!( + provider = provider_name, + model = *current_model, + dropped, + remaining = effective_messages.len(), + "Context window exceeded; truncated history and retrying" + ); + continue; // Retry with truncated messages (counts as an attempt) + } + } + let non_retryable_rate_limit = is_non_retryable_rate_limit(&e); let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit; let rate_limited = is_rate_limited(&e); @@ -729,14 +813,6 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); - - if is_context_window_exceeded(&e) { - anyhow::bail!( - "Request exceeds model context window; retries and fallbacks were skipped. 
Attempts:\n{}", - failures.join("\n") - ); - } - break; } @@ -1071,7 +1147,8 @@ mod tests { assert!(!is_non_retryable(&anyhow::anyhow!( "model overloaded, try again later" ))); - assert!(is_non_retryable(&anyhow::anyhow!( + // Context window errors are now recoverable (not non-retryable) + assert!(!is_non_retryable(&anyhow::anyhow!( "OpenAI Codex stream error: Your input exceeds the context window of this model." ))); } @@ -1107,7 +1184,7 @@ mod tests { let msg = err.to_string(); assert!(msg.contains("context window")); - assert!(msg.contains("skipped")); + // chat_with_system has no history to truncate, so it bails immediately assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -1980,4 +2057,136 @@ mod tests { assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } + + // ── Context window truncation tests ───────────────────────── + + #[test] + fn context_window_error_is_not_non_retryable() { + // Context window errors should be recoverable via truncation + assert!(!is_non_retryable(&anyhow::anyhow!( + "exceeds the context window" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!( + "maximum context length exceeded" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!( + "too many tokens in the request" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!("token limit exceeded"))); + } + + #[test] + fn is_context_window_exceeded_detects_llamacpp() { + assert!(is_context_window_exceeded(&anyhow::anyhow!( + "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it" + ))); + } + + #[test] + fn truncate_for_context_drops_oldest_non_system() { + let mut messages = vec![ + ChatMessage::system("sys"), + ChatMessage::user("msg1"), + ChatMessage::assistant("resp1"), + ChatMessage::user("msg2"), + ChatMessage::assistant("resp2"), + ChatMessage::user("msg3"), + ]; + + let dropped = truncate_for_context(&mut messages); + + // 5 non-system messages, drop oldest half = 2 + assert_eq!(dropped, 
2); + // System message preserved + assert_eq!(messages[0].role, "system"); + // Remaining messages should be the newer ones + assert_eq!(messages.len(), 4); // system + 3 remaining non-system + // The last message should still be the most recent user message + assert_eq!(messages.last().unwrap().content, "msg3"); + } + + #[test] + fn truncate_for_context_preserves_system_and_last_message() { + // Only one non-system message: nothing to drop + let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")]; + let dropped = truncate_for_context(&mut messages); + assert_eq!(dropped, 0); + assert_eq!(messages.len(), 2); + + // No system message, only one user message + let mut messages = vec![ChatMessage::user("only")]; + let dropped = truncate_for_context(&mut messages); + assert_eq!(dropped, 0); + assert_eq!(messages.len(), 1); + } + + /// Mock that fails with context error on first N calls, then succeeds. + /// Tracks the number of messages received on each call. + struct ContextOverflowMock { + calls: Arc<AtomicUsize>, + fail_until_attempt: usize, + message_counts: parking_lot::Mutex<Vec<usize>>, + } + + #[async_trait] + impl Provider for ContextOverflowMock { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + Ok("ok".to_string()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + self.message_counts.lock().push(messages.len()); + if attempt <= self.fail_until_attempt { + anyhow::bail!( + "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it" + ); + } + Ok("recovered after truncation".to_string()) + } + } + + #[tokio::test] + async fn chat_with_history_truncates_on_context_overflow() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = ContextOverflowMock { + calls: Arc::clone(&calls), 
+ fail_until_attempt: 1, // fail first call, succeed after truncation + message_counts: parking_lot::Mutex::new(Vec::new()), + }; + + let provider = ReliableProvider::new( + vec![("local".into(), Box::new(mock) as Box<dyn Provider>)], + 3, + 1, + ); + + let messages = vec![ + ChatMessage::system("system prompt"), + ChatMessage::user("old message 1"), + ChatMessage::assistant("old response 1"), + ChatMessage::user("old message 2"), + ChatMessage::assistant("old response 2"), + ChatMessage::user("current question"), + ]; + + let result = provider + .chat_with_history(&messages, "local-model", 0.0) + .await + .unwrap(); + assert_eq!(result, "recovered after truncation"); + // Should have been called twice: once with full messages, once with truncated + assert_eq!(calls.load(Ordering::SeqCst), 2); + } }