225 lines
7.2 KiB
Rust
225 lines
7.2 KiB
Rust
use super::traits::{Tool, ToolResult};
|
|
use crate::memory::{Memory, MemoryCategory};
|
|
use crate::security::policy::ToolOperation;
|
|
use crate::security::SecurityPolicy;
|
|
use async_trait::async_trait;
|
|
use serde_json::json;
|
|
use std::sync::Arc;
|
|
|
|
/// Let the agent store memories — its own brain writes
|
|
pub struct MemoryStoreTool {
|
|
memory: Arc<dyn Memory>,
|
|
security: Arc<SecurityPolicy>,
|
|
}
|
|
|
|
impl MemoryStoreTool {
|
|
pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
|
|
Self { memory, security }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Tool for MemoryStoreTool {
|
|
fn name(&self) -> &str {
|
|
"memory_store"
|
|
}
|
|
|
|
fn description(&self) -> &str {
|
|
"Store a fact, preference, or note in long-term memory. Use category 'core' for permanent facts, 'daily' for session notes, 'conversation' for chat context, or a custom category name."
|
|
}
|
|
|
|
fn parameters_schema(&self) -> serde_json::Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"key": {
|
|
"type": "string",
|
|
"description": "Unique key for this memory (e.g. 'user_lang', 'project_stack')"
|
|
},
|
|
"content": {
|
|
"type": "string",
|
|
"description": "The information to remember"
|
|
},
|
|
"category": {
|
|
"type": "string",
|
|
"description": "Memory category: 'core' (permanent), 'daily' (session), 'conversation' (chat), or a custom category name. Defaults to 'core'."
|
|
}
|
|
},
|
|
"required": ["key", "content"]
|
|
})
|
|
}
|
|
|
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
let key = args
|
|
.get("key")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
|
|
|
|
let content = args
|
|
.get("content")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
|
|
|
|
let category = match args.get("category").and_then(|v| v.as_str()) {
|
|
Some("core") | None => MemoryCategory::Core,
|
|
Some("daily") => MemoryCategory::Daily,
|
|
Some("conversation") => MemoryCategory::Conversation,
|
|
Some(other) => MemoryCategory::Custom(other.to_string()),
|
|
};
|
|
|
|
if let Err(error) = self
|
|
.security
|
|
.enforce_tool_operation(ToolOperation::Act, "memory_store")
|
|
{
|
|
return Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(error),
|
|
});
|
|
}
|
|
|
|
match self.memory.store(key, content, category, None).await {
|
|
Ok(()) => Ok(ToolResult {
|
|
success: true,
|
|
output: format!("Stored memory: {key}"),
|
|
error: None,
|
|
}),
|
|
Err(e) => Ok(ToolResult {
|
|
success: false,
|
|
output: String::new(),
|
|
error: Some(format!("Failed to store memory: {e}")),
|
|
}),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::memory::SqliteMemory;
|
|
use crate::security::{AutonomyLevel, SecurityPolicy};
|
|
use tempfile::TempDir;
|
|
|
|
fn test_security() -> Arc<SecurityPolicy> {
|
|
Arc::new(SecurityPolicy::default())
|
|
}
|
|
|
|
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
|
let tmp = TempDir::new().unwrap();
|
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
|
(tmp, Arc::new(mem))
|
|
}
|
|
|
|
#[test]
|
|
fn name_and_schema() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem, test_security());
|
|
assert_eq!(tool.name(), "memory_store");
|
|
let schema = tool.parameters_schema();
|
|
assert!(schema["properties"]["key"].is_object());
|
|
assert!(schema["properties"]["content"].is_object());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_core() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
|
let result = tool
|
|
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(result.success);
|
|
assert!(result.output.contains("lang"));
|
|
|
|
let entry = mem.get("lang").await.unwrap();
|
|
assert!(entry.is_some());
|
|
assert_eq!(entry.unwrap().content, "Prefers Rust");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_with_category() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
|
let result = tool
|
|
.execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(result.success);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_with_custom_category() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
|
let result = tool
|
|
.execute(
|
|
json!({"key": "proj_note", "content": "Uses async runtime", "category": "project"}),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert!(result.success);
|
|
|
|
let entry = mem.get("proj_note").await.unwrap().unwrap();
|
|
assert_eq!(entry.content, "Uses async runtime");
|
|
assert_eq!(entry.category, MemoryCategory::Custom("project".into()));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_missing_key() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem, test_security());
|
|
let result = tool.execute(json!({"content": "no key"})).await;
|
|
assert!(result.is_err());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_missing_content() {
|
|
let (_tmp, mem) = test_mem();
|
|
let tool = MemoryStoreTool::new(mem, test_security());
|
|
let result = tool.execute(json!({"key": "no_content"})).await;
|
|
assert!(result.is_err());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_blocked_in_readonly_mode() {
|
|
let (_tmp, mem) = test_mem();
|
|
let readonly = Arc::new(SecurityPolicy {
|
|
autonomy: AutonomyLevel::ReadOnly,
|
|
..SecurityPolicy::default()
|
|
});
|
|
let tool = MemoryStoreTool::new(mem.clone(), readonly);
|
|
let result = tool
|
|
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(!result.success);
|
|
assert!(result
|
|
.error
|
|
.as_deref()
|
|
.unwrap_or("")
|
|
.contains("read-only mode"));
|
|
assert!(mem.get("lang").await.unwrap().is_none());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn store_blocked_when_rate_limited() {
|
|
let (_tmp, mem) = test_mem();
|
|
let limited = Arc::new(SecurityPolicy {
|
|
max_actions_per_hour: 0,
|
|
..SecurityPolicy::default()
|
|
});
|
|
let tool = MemoryStoreTool::new(mem.clone(), limited);
|
|
let result = tool
|
|
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
|
.await
|
|
.unwrap();
|
|
assert!(!result.success);
|
|
assert!(result
|
|
.error
|
|
.as_deref()
|
|
.unwrap_or("")
|
|
.contains("Rate limit exceeded"));
|
|
assert!(mem.get("lang").await.unwrap().is_none());
|
|
}
|
|
}
|