diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index a93cad476..03f30fc06 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -211,9 +211,9 @@ impl AnthropicProvider { text.len() > 3072 } - /// Cache conversations with more than 4 messages (excluding system) + /// Cache conversations with more than 1 non-system message (i.e. after first exchange) fn should_cache_conversation(messages: &[ChatMessage]) -> bool { - messages.iter().filter(|m| m.role != "system").count() > 4 + messages.iter().filter(|m| m.role != "system").count() > 1 } /// Apply cache control to the last message content block @@ -447,17 +447,13 @@ impl AnthropicProvider { } } - // Convert system text to SystemPrompt with cache control if large + // Always use Blocks format with cache_control for system prompts let system_prompt = system_text.map(|text| { - if Self::should_cache_system(&text) { - SystemPrompt::Blocks(vec![SystemBlock { - block_type: "text".to_string(), - text, - cache_control: Some(CacheControl::ephemeral()), - }]) - } else { - SystemPrompt::String(text) - } + SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text, + cache_control: Some(CacheControl::ephemeral()), + }]) }); (system_prompt, native_messages) @@ -1063,12 +1059,8 @@ mod tests { role: "user".to_string(), content: "Hello".to_string(), }, - ChatMessage { - role: "assistant".to_string(), - content: "Hi".to_string(), - }, ]; - // Only 2 non-system messages + // Only 1 non-system message — should not cache assert!(!AnthropicProvider::should_cache_conversation(&messages)); } @@ -1078,8 +1070,8 @@ mod tests { role: "system".to_string(), content: "System prompt".to_string(), }]; - // Add 5 non-system messages - for i in 0..5 { + // Add 3 non-system messages + for i in 0..3 { messages.push(ChatMessage { role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), content: format!("Message {i}"), @@ -1090,21 +1082,24 @@ mod tests { #[test] fn 
should_cache_conversation_boundary() { - let mut messages = vec![]; - // Add exactly 4 non-system messages - for i in 0..4 { - messages.push(ChatMessage { - role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), - content: format!("Message {i}"), - }); - } + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "Hello".to_string(), + }]; + // Exactly 1 non-system message — should not cache assert!(!AnthropicProvider::should_cache_conversation(&messages)); - // Add one more to cross boundary - messages.push(ChatMessage { - role: "user".to_string(), - content: "One more".to_string(), - }); + // Rebuild with two non-system messages to cross the boundary (>1) + let messages = vec![ + ChatMessage { + role: "user".to_string(), + content: "Hello".to_string(), + }, + ChatMessage { + role: "assistant".to_string(), + content: "Hi".to_string(), + }, + ]; assert!(AnthropicProvider::should_cache_conversation(&messages)); } @@ -1217,7 +1212,7 @@ mod tests { } #[test] - fn convert_messages_small_system_prompt() { + fn convert_messages_small_system_prompt_uses_blocks_with_cache() { let messages = vec![ChatMessage { role: "system".to_string(), content: "Short system prompt".to_string(), }]; let (system_prompt, _) = AnthropicProvider::convert_messages(&messages); match system_prompt.unwrap() { - SystemPrompt::String(s) => { - assert_eq!(s, "Short system prompt"); + SystemPrompt::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + assert_eq!(blocks[0].text, "Short system prompt"); + assert!( + blocks[0].cache_control.is_some(), + "Small system prompts should have cache_control" + ); + } + SystemPrompt::String(_) => { + panic!("Expected Blocks variant with cache_control for small prompt") } - SystemPrompt::Blocks(_) => panic!("Expected String variant for small prompt"), } } @@ -1254,12 +1256,16 @@ mod tests { } #[test] - fn backward_compatibility_native_chat_request() { - // Test that requests without cache_control serialize identically to old format
+ fn native_chat_request_with_blocks_system() { + // System prompts now always use Blocks format with cache_control let req = NativeChatRequest { model: "claude-3-opus".to_string(), max_tokens: 4096, - system: Some(SystemPrompt::String("System".to_string())), + system: Some(SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text: "System".to_string(), + cache_control: Some(CacheControl::ephemeral()), + }])), messages: vec![NativeMessage { role: "user".to_string(), content: vec![NativeContentOut::Text { @@ -1272,8 +1278,11 @@ mod tests { }; let json = serde_json::to_string(&req).unwrap(); - assert!(!json.contains("cache_control")); - assert!(json.contains(r#""system":"System""#)); + assert!(json.contains("System")); + assert!( + json.contains(r#""cache_control":{"type":"ephemeral"}"#), + "System prompt should include cache_control" + ); } #[tokio::test]