diff --git a/src/memory/traits.rs b/src/memory/traits.rs index de72923d3..63f766ad6 100644 --- a/src/memory/traits.rs +++ b/src/memory/traits.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// A single memory entry #[derive(Clone, Serialize, Deserialize)] @@ -27,8 +27,7 @@ impl std::fmt::Debug for MemoryEntry { } /// Memory categories for organization -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum MemoryCategory { /// Long-term facts, preferences, decisions Core, @@ -51,6 +50,30 @@ impl std::fmt::Display for MemoryCategory { } } +impl Serialize for MemoryCategory { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for MemoryCategory { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw = String::deserialize(deserializer)?; + Ok(match raw.as_str() { + "core" => Self::Core, + "daily" => Self::Daily, + "conversation" => Self::Conversation, + other => Self::Custom(other.to_string()), + }) + } +} + /// Core memory trait — implement for any persistence backend #[async_trait] pub trait Memory: Send + Sync { @@ -110,14 +133,33 @@ mod tests { } #[test] - fn memory_category_serde_uses_snake_case() { + fn memory_category_serde_roundtrip_uses_plain_strings() { let core = serde_json::to_string(&MemoryCategory::Core).unwrap(); let daily = serde_json::to_string(&MemoryCategory::Daily).unwrap(); let conversation = serde_json::to_string(&MemoryCategory::Conversation).unwrap(); + let custom = serde_json::to_string(&MemoryCategory::Custom("travel".into())).unwrap(); assert_eq!(core, "\"core\""); assert_eq!(daily, "\"daily\""); assert_eq!(conversation, "\"conversation\""); + assert_eq!(custom, "\"travel\""); + + assert_eq!( + serde_json::from_str::("\"core\"").unwrap(), + MemoryCategory::Core + ); + assert_eq!( + serde_json::from_str::("\"daily\"").unwrap(), + MemoryCategory::Daily + ); + assert_eq!( + serde_json::from_str::("\"conversation\"").unwrap(), + MemoryCategory::Conversation + ); + assert_eq!( + serde_json::from_str::("\"travel\"").unwrap(), + MemoryCategory::Custom("travel".into()) + ); } #[test] @@ -142,4 +184,20 @@ mod tests { assert_eq!(parsed.session_id.as_deref(), Some("session-abc")); assert_eq!(parsed.score, Some(0.98)); } + + #[test] + fn memory_entry_serializes_custom_category_as_plain_string() { + let entry = MemoryEntry { + id: "id-2".into(), + key: "trip".into(), + content: "booked a flight".into(), + category: MemoryCategory::Custom("travel".into()), + timestamp: "2026-03-04T00:00:00Z".into(), + session_id: None, + score: None, + }; + + let json = serde_json::to_value(&entry).unwrap(); + assert_eq!(json.get("category").unwrap(), "travel"); + } }