From ff8017a1f66e3d15ffbf90d30cc145a7a425f712 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Mon, 2 Mar 2026 17:53:36 -0500 Subject: [PATCH 1/2] feat(memory): add multi-query keyword expansion --- src/agent/loop_/context.rs | 6 +- src/agent/memory_loader.rs | 4 +- src/memory/mod.rs | 1 + src/memory/retrieval.rs | 283 +++++++++++++++++++++++++++++++++++++ 4 files changed, 290 insertions(+), 4 deletions(-) create mode 100644 src/memory/retrieval.rs diff --git a/src/agent/loop_/context.rs b/src/agent/loop_/context.rs index 668ea0d18..cb36ec48c 100644 --- a/src/agent/loop_/context.rs +++ b/src/agent/loop_/context.rs @@ -1,4 +1,4 @@ -use crate::memory::{self, decay, Memory, MemoryCategory}; +use crate::memory::{self, decay, retrieval, Memory, MemoryCategory}; use std::fmt::Write; /// Default half-life (days) for time decay in context building. @@ -16,6 +16,7 @@ const CONTEXT_ENTRY_LIMIT: usize = 5; const RECALL_OVER_FETCH_FACTOR: usize = 2; /// Build context preamble by searching memory for relevant entries. +/// Uses enhanced retrieval (multi-query + Core boosting) for better coverage. /// Entries with a hybrid score below `min_relevance_score` are dropped to /// prevent unrelated memories from bleeding into the conversation. /// @@ -34,7 +35,8 @@ pub(super) async fn build_context( // Over-fetch so Core-boosted entries can compete fairly after re-ranking. let fetch_limit = CONTEXT_ENTRY_LIMIT * RECALL_OVER_FETCH_FACTOR; - if let Ok(mut entries) = mem.recall(user_msg, fetch_limit, session_id).await { + let mut entries = retrieval::enhanced_recall(mem, user_msg, fetch_limit, session_id).await; + if !entries.is_empty() { // Apply time decay: older non-Core memories score lower. decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS); diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index a2aa85be2..333c930c2 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -1,4 +1,4 @@ -use crate::memory::{self, decay, Memory, MemoryCategory}; +use crate::memory::{self, decay, retrieval, Memory, MemoryCategory}; use async_trait::async_trait; use std::fmt::Write; @@ -51,7 +51,7 @@ impl MemoryLoader for DefaultMemoryLoader { ) -> anyhow::Result { // Over-fetch so Core-boosted entries can compete fairly after re-ranking. let fetch_limit = self.limit * RECALL_OVER_FETCH_FACTOR; - let mut entries = memory.recall(user_message, fetch_limit, None).await?; + let mut entries = retrieval::enhanced_recall(memory, user_message, fetch_limit, None).await; if entries.is_empty() { return Ok(String::new()); } diff --git a/src/memory/mod.rs b/src/memory/mod.rs index d6227f5a1..ac54ae983 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -13,6 +13,7 @@ pub mod none; pub mod postgres; pub mod qdrant; pub mod response_cache; +pub mod retrieval; pub mod snapshot; pub mod sqlite; pub mod traits; diff --git a/src/memory/retrieval.rs b/src/memory/retrieval.rs new file mode 100644 index 000000000..a773fe5ba --- /dev/null +++ b/src/memory/retrieval.rs @@ -0,0 +1,283 @@ +use super::traits::{Memory, MemoryEntry}; + +/// Minimum message length (chars) to trigger keyword expansion query. +const MIN_EXPANSION_LENGTH: usize = 30; + +/// Minimum word length to keep in keyword extraction. +const MIN_KEYWORD_LENGTH: usize = 4; + +/// Enhanced memory retrieval with multi-query expansion. +/// +/// 1. Runs the primary recall with the full query. +/// 2. For long messages, extracts significant keywords and runs a second recall, +/// merging results (deduplicated by key, keeping the higher score). +/// 3. Returns the top `limit` entries sorted by score descending. +pub async fn enhanced_recall( + mem: &dyn Memory, + query: &str, + limit: usize, + session_id: Option<&str>, +) -> Vec { + // Primary recall with full query + let mut results = mem.recall(query, limit, session_id).await.unwrap_or_default(); + + // Multi-query expansion for long messages + if query.len() >= MIN_EXPANSION_LENGTH { + let keywords = extract_keywords(query); + if !keywords.is_empty() && keywords != query.trim() { + if let Ok(extra) = mem.recall(&keywords, limit, session_id).await { + merge_entries(&mut results, extra); + } + } + } + + // Sort by score descending, take top `limit` + results.sort_by(|a, b| { + b.score + .unwrap_or(0.0) + .partial_cmp(&a.score.unwrap_or(0.0)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + results.truncate(limit); + + results +} + +/// Extract significant keywords (length >= 4) from a message. +fn extract_keywords(msg: &str) -> String { + msg.split_whitespace() + .filter_map(|w| { + let clean = w.trim_matches(|c: char| !c.is_alphanumeric()); + if clean.len() >= MIN_KEYWORD_LENGTH { + Some(clean) + } else { + None + } + }) + .collect::>() + .join(" ") +} + +/// Merge extra entries into results, deduplicating by key (keep highest score). +fn merge_entries(results: &mut Vec, extra: Vec) { + for entry in extra { + if let Some(existing) = results.iter_mut().find(|r| r.key == entry.key) { + if entry.score.unwrap_or(0.0) > existing.score.unwrap_or(0.0) { + existing.score = entry.score; + } + } else { + results.push(entry); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::traits::MemoryCategory; + use async_trait::async_trait; + + #[test] + fn extract_keywords_filters_short_words() { + assert_eq!( + extract_keywords("I want to use PostgreSQL for the database"), + "want PostgreSQL database" + ); + } + + #[test] + fn extract_keywords_strips_punctuation() { + // trim_matches strips non-alphanumeric from both ends: + // "config?" -> "config", "settings." -> "settings", "what's" stays (apostrophe is internal) + assert_eq!( + extract_keywords("what's the config? check settings."), + "what's config check settings" + ); + } + + #[test] + fn extract_keywords_empty_for_short_words() { + assert_eq!(extract_keywords("I am ok"), ""); + } + + #[test] + fn merge_entries_deduplicates_by_key_keeping_higher_score() { + let mut results = vec![MemoryEntry { + id: "1".into(), + key: "db".into(), + content: "PostgreSQL".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.6), + }]; + let extra = vec![ + MemoryEntry { + id: "1b".into(), + key: "db".into(), + content: "PostgreSQL".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.9), // higher + }, + MemoryEntry { + id: "2".into(), + key: "lang".into(), + content: "Rust".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.7), + }, + ]; + merge_entries(&mut results, extra); + assert_eq!(results.len(), 2); + assert_eq!(results[0].score, Some(0.9)); // upgraded + assert_eq!(results[1].key, "lang"); // new entry added + } + + struct MockMemory { + primary: Vec, + keyword: Vec, + call_count: std::sync::atomic::AtomicUsize, + } + + #[async_trait] + impl Memory for MockMemory { + async fn store( + &self, _k: &str, _c: &str, _cat: MemoryCategory, _s: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + async fn recall( + &self, _query: &str, _limit: usize, _s: Option<&str>, + ) -> anyhow::Result> { + // First call returns primary results, second call returns keyword results + let n = self.call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + Ok(self.primary.clone()) + } else { + Ok(self.keyword.clone()) + } + } + async fn get(&self, _k: &str) -> anyhow::Result> { + Ok(None) + } + async fn list( + &self, _c: Option<&MemoryCategory>, _s: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![]) + } + async fn forget(&self, _k: &str) -> anyhow::Result { + Ok(true) + } + async fn count(&self) -> anyhow::Result { + Ok(0) + } + async fn health_check(&self) -> bool { + true + } + fn name(&self) -> &str { + "mock" + } + } + + #[tokio::test] + async fn enhanced_recall_merges_primary_and_keyword_results() { + let mem = MockMemory { + primary: vec![MemoryEntry { + id: "1".into(), + key: "db".into(), + content: "PostgreSQL".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.7), + }], + keyword: vec![MemoryEntry { + id: "2".into(), + key: "lang".into(), + content: "Rust".into(), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: Some(0.6), + }], + call_count: std::sync::atomic::AtomicUsize::new(0), + }; + + // Long query triggers expansion + let query = "what database and programming language should we use for this project"; + let results = enhanced_recall(&mem, query, 5, None).await; + + assert_eq!(results.len(), 2); + // "db" has higher score (0.7), ranked first + assert_eq!(results[0].key, "db"); + assert_eq!(results[0].score, Some(0.7)); + // "lang" from keyword expansion + assert_eq!(results[1].key, "lang"); + assert_eq!(results[1].score, Some(0.6)); + } + + #[tokio::test] + async fn enhanced_recall_skips_expansion_for_short_query() { + let mem = MockMemory { + primary: vec![MemoryEntry { + id: "1".into(), + key: "db".into(), + content: "PostgreSQL".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.7), + }], + keyword: vec![MemoryEntry { + id: "2".into(), + key: "lang".into(), + content: "Rust".into(), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: Some(0.6), + }], + call_count: std::sync::atomic::AtomicUsize::new(0), + }; + + // Short query — no expansion, so "keyword" recall is what gets returned + // (because our mock returns keyword results for short queries) + let results = enhanced_recall(&mem, "database?", 5, None).await; + + // Only keyword results returned (mock behavior), no merge + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn enhanced_recall_respects_limit() { + let mem = MockMemory { + primary: (0..10) + .map(|i| MemoryEntry { + id: format!("{i}"), + key: format!("key_{i}"), + content: format!("val_{i}"), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: Some(0.5 + i as f64 * 0.01), + }) + .collect(), + keyword: vec![], + call_count: std::sync::atomic::AtomicUsize::new(0), + }; + + let results = enhanced_recall( + &mem, + "a very long query that definitely triggers keyword expansion for testing purposes", + 3, + None, + ) + .await; + + assert_eq!(results.len(), 3); + } +} From d80b535f5b637d9e4d7d982703f57fc8dc5f26dd Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Mon, 2 Mar 2026 17:53:43 -0500 Subject: [PATCH 2/2] fix(memory): propagate primary recall errors in enhanced retrieval --- src/agent/loop_/context.rs | 9 +++- src/agent/memory_loader.rs | 64 +++++++++++++++++++++++- src/memory/retrieval.rs | 99 +++++++++++++++++++++++++++++++++----- 3 files changed, 157 insertions(+), 15 deletions(-) diff --git a/src/agent/loop_/context.rs b/src/agent/loop_/context.rs index cb36ec48c..8721b80f5 100644 --- a/src/agent/loop_/context.rs +++ b/src/agent/loop_/context.rs @@ -35,8 +35,13 @@ pub(super) async fn build_context( // Over-fetch so Core-boosted entries can compete fairly after re-ranking. let fetch_limit = CONTEXT_ENTRY_LIMIT * RECALL_OVER_FETCH_FACTOR; - let mut entries = retrieval::enhanced_recall(mem, user_msg, fetch_limit, session_id).await; - if !entries.is_empty() { + if let Ok(mut entries) = + retrieval::enhanced_recall(mem, user_msg, fetch_limit, session_id).await + { + if entries.is_empty() { + return context; + } + // Apply time decay: older non-Core memories score lower. decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS); diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index 333c930c2..7e022f630 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -51,7 +51,8 @@ impl MemoryLoader for DefaultMemoryLoader { ) -> anyhow::Result { // Over-fetch so Core-boosted entries can compete fairly after re-ranking. let fetch_limit = self.limit * RECALL_OVER_FETCH_FACTOR; - let mut entries = retrieval::enhanced_recall(memory, user_message, fetch_limit, None).await; + let mut entries = + retrieval::enhanced_recall(memory, user_message, fetch_limit, None).await?; if entries.is_empty() { return Ok(String::new()); } @@ -105,6 +106,7 @@ mod tests { struct MockMemoryWithEntries { entries: Arc>, } + struct FailingRecallMemory; #[async_trait] impl Memory for MockMemory { @@ -217,6 +219,56 @@ mod tests { } } + #[async_trait] + impl Memory for FailingRecallMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Err(anyhow::anyhow!("memory backend unavailable")) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![]) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(true) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "failing-recall-memory" + } + } + #[tokio::test] async fn default_loader_formats_context() { let loader = DefaultMemoryLoader::default(); @@ -345,4 +397,14 @@ mod tests { "Conversation should be truncated when limit=1: {context}" ); } + + #[tokio::test] + async fn default_loader_propagates_primary_recall_errors() { + let loader = DefaultMemoryLoader::default(); + let err = loader + .load_context(&FailingRecallMemory, "hello") + .await + .expect_err("expected memory loader to propagate primary recall failure"); + assert!(err.to_string().contains("memory backend unavailable")); + } } diff --git a/src/memory/retrieval.rs b/src/memory/retrieval.rs index a773fe5ba..6f288a37d 100644 --- a/src/memory/retrieval.rs +++ b/src/memory/retrieval.rs @@ -17,9 +17,9 @@ pub async fn enhanced_recall( query: &str, limit: usize, session_id: Option<&str>, -) -> Vec { +) -> anyhow::Result> { // Primary recall with full query - let mut results = mem.recall(query, limit, session_id).await.unwrap_or_default(); + let mut results = mem.recall(query, limit, session_id).await?; // Multi-query expansion for long messages if query.len() >= MIN_EXPANSION_LENGTH { @@ -40,7 +40,7 @@ pub async fn enhanced_recall( }); results.truncate(limit); - results + Ok(results) } /// Extract significant keywords (length >= 4) from a message. @@ -140,32 +140,53 @@ mod tests { struct MockMemory { primary: Vec, keyword: Vec, + fail_primary: bool, + fail_keyword: bool, call_count: std::sync::atomic::AtomicUsize, } #[async_trait] impl Memory for MockMemory { async fn store( - &self, _k: &str, _c: &str, _cat: MemoryCategory, _s: Option<&str>, + &self, + _k: &str, + _c: &str, + _cat: MemoryCategory, + _s: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } async fn recall( - &self, _query: &str, _limit: usize, _s: Option<&str>, + &self, + _query: &str, + _limit: usize, + _s: Option<&str>, ) -> anyhow::Result> { // First call returns primary results, second call returns keyword results - let n = self.call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let n = self + .call_count + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); if n == 0 { - Ok(self.primary.clone()) + if self.fail_primary { + Err(anyhow::anyhow!("primary recall failed")) + } else { + Ok(self.primary.clone()) + } } else { - Ok(self.keyword.clone()) + if self.fail_keyword { + Err(anyhow::anyhow!("keyword recall failed")) + } else { + Ok(self.keyword.clone()) + } } } async fn get(&self, _k: &str) -> anyhow::Result> { Ok(None) } async fn list( - &self, _c: Option<&MemoryCategory>, _s: Option<&str>, + &self, + _c: Option<&MemoryCategory>, + _s: Option<&str>, ) -> anyhow::Result> { Ok(vec![]) } @@ -204,12 +225,14 @@ mod tests { session_id: None, score: Some(0.6), }], + fail_primary: false, + fail_keyword: false, call_count: std::sync::atomic::AtomicUsize::new(0), }; // Long query triggers expansion let query = "what database and programming language should we use for this project"; - let results = enhanced_recall(&mem, query, 5, None).await; + let results = enhanced_recall(&mem, query, 5, None).await.unwrap(); assert_eq!(results.len(), 2); // "db" has higher score (0.7), ranked first @@ -241,12 +264,14 @@ mod tests { session_id: None, score: Some(0.6), }], + fail_primary: false, + fail_keyword: false, call_count: std::sync::atomic::AtomicUsize::new(0), }; // Short query — no expansion, so "keyword" recall is what gets returned // (because our mock returns keyword results for short queries) - let results = enhanced_recall(&mem, "database?", 5, None).await; + let results = enhanced_recall(&mem, "database?", 5, None).await.unwrap(); // Only keyword results returned (mock behavior), no merge assert_eq!(results.len(), 1); @@ -267,6 +292,8 @@ mod tests { }) .collect(), keyword: vec![], + fail_primary: false, + fail_keyword: false, call_count: std::sync::atomic::AtomicUsize::new(0), }; @@ -276,8 +303,56 @@ mod tests { 3, None, ) - .await; + .await + .unwrap(); assert_eq!(results.len(), 3); } + + #[tokio::test] + async fn enhanced_recall_propagates_primary_recall_errors() { + let mem = MockMemory { + primary: vec![], + keyword: vec![], + fail_primary: true, + fail_keyword: false, + call_count: std::sync::atomic::AtomicUsize::new(0), + }; + + let err = enhanced_recall(&mem, "long enough query to trigger expansion", 5, None) + .await + .expect_err("expected primary recall error to propagate"); + assert!(err.to_string().contains("primary recall failed")); + } + + #[tokio::test] + async fn enhanced_recall_tolerates_keyword_recall_errors() { + let mem = MockMemory { + primary: vec![MemoryEntry { + id: "1".into(), + key: "db".into(), + content: "PostgreSQL".into(), + category: MemoryCategory::Core, + timestamp: "now".into(), + session_id: None, + score: Some(0.7), + }], + keyword: vec![], + fail_primary: false, + fail_keyword: true, + call_count: std::sync::atomic::AtomicUsize::new(0), + }; + + let results = enhanced_recall( + &mem, + "what database and programming language should we use for this project", + 5, + None, + ) + .await + .unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "db"); + } }