204 lines
6.3 KiB
Rust
204 lines
6.3 KiB
Rust
use async_trait::async_trait;
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
|
|
/// A single memory entry
|
|
#[derive(Clone, Serialize, Deserialize)]
|
|
pub struct MemoryEntry {
|
|
pub id: String,
|
|
pub key: String,
|
|
pub content: String,
|
|
pub category: MemoryCategory,
|
|
pub timestamp: String,
|
|
pub session_id: Option<String>,
|
|
pub score: Option<f64>,
|
|
}
|
|
|
|
impl std::fmt::Debug for MemoryEntry {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("MemoryEntry")
|
|
.field("id", &self.id)
|
|
.field("key", &self.key)
|
|
.field("content", &self.content)
|
|
.field("category", &self.category)
|
|
.field("timestamp", &self.timestamp)
|
|
.field("score", &self.score)
|
|
.finish_non_exhaustive()
|
|
}
|
|
}
|
|
|
|
/// Memory categories for organization
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum MemoryCategory {
|
|
/// Long-term facts, preferences, decisions
|
|
Core,
|
|
/// Daily session logs
|
|
Daily,
|
|
/// Conversation context
|
|
Conversation,
|
|
/// User-defined custom category
|
|
Custom(String),
|
|
}
|
|
|
|
impl std::fmt::Display for MemoryCategory {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Self::Core => write!(f, "core"),
|
|
Self::Daily => write!(f, "daily"),
|
|
Self::Conversation => write!(f, "conversation"),
|
|
Self::Custom(name) => write!(f, "{name}"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Serialize for MemoryCategory {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
serializer.serialize_str(&self.to_string())
|
|
}
|
|
}
|
|
|
|
impl<'de> Deserialize<'de> for MemoryCategory {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let raw = String::deserialize(deserializer)?;
|
|
Ok(match raw.as_str() {
|
|
"core" => Self::Core,
|
|
"daily" => Self::Daily,
|
|
"conversation" => Self::Conversation,
|
|
other => Self::Custom(other.to_string()),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Core memory trait — implement for any persistence backend
|
|
#[async_trait]
|
|
pub trait Memory: Send + Sync {
|
|
/// Backend name
|
|
fn name(&self) -> &str;
|
|
|
|
/// Store a memory entry, optionally scoped to a session
|
|
async fn store(
|
|
&self,
|
|
key: &str,
|
|
content: &str,
|
|
category: MemoryCategory,
|
|
session_id: Option<&str>,
|
|
) -> anyhow::Result<()>;
|
|
|
|
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
|
async fn recall(
|
|
&self,
|
|
query: &str,
|
|
limit: usize,
|
|
session_id: Option<&str>,
|
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
|
|
|
/// Get a specific memory by key
|
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
|
|
|
/// List all memory keys, optionally filtered by category and/or session
|
|
async fn list(
|
|
&self,
|
|
category: Option<&MemoryCategory>,
|
|
session_id: Option<&str>,
|
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
|
|
|
/// Remove a memory by key
|
|
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
|
|
|
/// Count total memories
|
|
async fn count(&self) -> anyhow::Result<usize>;
|
|
|
|
/// Health check
|
|
async fn health_check(&self) -> bool;
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn memory_category_display_outputs_expected_values() {
|
|
assert_eq!(MemoryCategory::Core.to_string(), "core");
|
|
assert_eq!(MemoryCategory::Daily.to_string(), "daily");
|
|
assert_eq!(MemoryCategory::Conversation.to_string(), "conversation");
|
|
assert_eq!(
|
|
MemoryCategory::Custom("project_notes".into()).to_string(),
|
|
"project_notes"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn memory_category_serde_roundtrip_uses_plain_strings() {
|
|
let core = serde_json::to_string(&MemoryCategory::Core).unwrap();
|
|
let daily = serde_json::to_string(&MemoryCategory::Daily).unwrap();
|
|
let conversation = serde_json::to_string(&MemoryCategory::Conversation).unwrap();
|
|
let custom = serde_json::to_string(&MemoryCategory::Custom("travel".into())).unwrap();
|
|
|
|
assert_eq!(core, "\"core\"");
|
|
assert_eq!(daily, "\"daily\"");
|
|
assert_eq!(conversation, "\"conversation\"");
|
|
assert_eq!(custom, "\"travel\"");
|
|
|
|
assert_eq!(
|
|
serde_json::from_str::<MemoryCategory>("\"core\"").unwrap(),
|
|
MemoryCategory::Core
|
|
);
|
|
assert_eq!(
|
|
serde_json::from_str::<MemoryCategory>("\"daily\"").unwrap(),
|
|
MemoryCategory::Daily
|
|
);
|
|
assert_eq!(
|
|
serde_json::from_str::<MemoryCategory>("\"conversation\"").unwrap(),
|
|
MemoryCategory::Conversation
|
|
);
|
|
assert_eq!(
|
|
serde_json::from_str::<MemoryCategory>("\"travel\"").unwrap(),
|
|
MemoryCategory::Custom("travel".into())
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn memory_entry_roundtrip_preserves_optional_fields() {
|
|
let entry = MemoryEntry {
|
|
id: "id-1".into(),
|
|
key: "favorite_language".into(),
|
|
content: "Rust".into(),
|
|
category: MemoryCategory::Core,
|
|
timestamp: "2026-02-16T00:00:00Z".into(),
|
|
session_id: Some("session-abc".into()),
|
|
score: Some(0.98),
|
|
};
|
|
|
|
let json = serde_json::to_string(&entry).unwrap();
|
|
let parsed: MemoryEntry = serde_json::from_str(&json).unwrap();
|
|
|
|
assert_eq!(parsed.id, "id-1");
|
|
assert_eq!(parsed.key, "favorite_language");
|
|
assert_eq!(parsed.content, "Rust");
|
|
assert_eq!(parsed.category, MemoryCategory::Core);
|
|
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
|
|
assert_eq!(parsed.score, Some(0.98));
|
|
}
|
|
|
|
#[test]
|
|
fn memory_entry_serializes_custom_category_as_plain_string() {
|
|
let entry = MemoryEntry {
|
|
id: "id-2".into(),
|
|
key: "trip".into(),
|
|
content: "booked a flight".into(),
|
|
category: MemoryCategory::Custom("travel".into()),
|
|
timestamp: "2026-03-04T00:00:00Z".into(),
|
|
session_id: None,
|
|
score: None,
|
|
};
|
|
|
|
let json = serde_json::to_value(&entry).unwrap();
|
|
assert_eq!(json.get("category").unwrap(), "travel");
|
|
}
|
|
}
|