fix(memory): serialize custom categories as plain strings

This commit is contained in:
argenis de la rosa 2026-03-04 00:20:33 -05:00
parent 32a2cf370d
commit 389d497a51

View File

@ -1,5 +1,5 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A single memory entry
#[derive(Clone, Serialize, Deserialize)]
@ -27,8 +27,7 @@ impl std::fmt::Debug for MemoryEntry {
}
/// Memory categories for organization
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryCategory {
/// Long-term facts, preferences, decisions
Core,
@ -51,6 +50,30 @@ impl std::fmt::Display for MemoryCategory {
}
}
impl Serialize for MemoryCategory {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for MemoryCategory {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let raw = String::deserialize(deserializer)?;
Ok(match raw.as_str() {
"core" => Self::Core,
"daily" => Self::Daily,
"conversation" => Self::Conversation,
other => Self::Custom(other.to_string()),
})
}
}
/// Core memory trait — implement for any persistence backend
#[async_trait]
pub trait Memory: Send + Sync {
@ -110,14 +133,33 @@ mod tests {
}
#[test]
fn memory_category_serde_uses_snake_case() {
fn memory_category_serde_roundtrip_uses_plain_strings() {
let core = serde_json::to_string(&MemoryCategory::Core).unwrap();
let daily = serde_json::to_string(&MemoryCategory::Daily).unwrap();
let conversation = serde_json::to_string(&MemoryCategory::Conversation).unwrap();
let custom = serde_json::to_string(&MemoryCategory::Custom("travel".into())).unwrap();
assert_eq!(core, "\"core\"");
assert_eq!(daily, "\"daily\"");
assert_eq!(conversation, "\"conversation\"");
assert_eq!(custom, "\"travel\"");
assert_eq!(
serde_json::from_str::<MemoryCategory>("\"core\"").unwrap(),
MemoryCategory::Core
);
assert_eq!(
serde_json::from_str::<MemoryCategory>("\"daily\"").unwrap(),
MemoryCategory::Daily
);
assert_eq!(
serde_json::from_str::<MemoryCategory>("\"conversation\"").unwrap(),
MemoryCategory::Conversation
);
assert_eq!(
serde_json::from_str::<MemoryCategory>("\"travel\"").unwrap(),
MemoryCategory::Custom("travel".into())
);
}
#[test]
@ -142,4 +184,20 @@ mod tests {
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
assert_eq!(parsed.score, Some(0.98));
}
#[test]
fn memory_entry_serializes_custom_category_as_plain_string() {
let entry = MemoryEntry {
id: "id-2".into(),
key: "trip".into(),
content: "booked a flight".into(),
category: MemoryCategory::Custom("travel".into()),
timestamp: "2026-03-04T00:00:00Z".into(),
session_id: None,
score: None,
};
let json = serde_json::to_value(&entry).unwrap();
assert_eq!(json.get("category").unwrap(), "travel");
}
}