Merge pull request #3712 from zeroclaw-labs/feat/competitive-edge-heartbeat-sessions-caching
feat: competitive edge — heartbeat metrics, SQLite sessions, two-tier prompt cache
This commit is contained in:
commit
566e3cf35b
@ -38,6 +38,7 @@ pub struct Agent {
|
||||
available_hints: Vec<String>,
|
||||
route_model_by_hint: HashMap<String, String>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
@ -60,6 +61,7 @@ pub struct AgentBuilder {
|
||||
available_hints: Option<Vec<String>>,
|
||||
route_model_by_hint: Option<HashMap<String, String>>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
@ -84,6 +86,7 @@ impl AgentBuilder {
|
||||
available_hints: None,
|
||||
route_model_by_hint: None,
|
||||
allowed_tools: None,
|
||||
response_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,6 +191,14 @@ impl AgentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_cache(
|
||||
mut self,
|
||||
cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
) -> Self {
|
||||
self.response_cache = cache;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let mut tools = self
|
||||
.tools
|
||||
@ -236,6 +247,7 @@ impl AgentBuilder {
|
||||
available_hints: self.available_hints.unwrap_or_default(),
|
||||
route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
|
||||
allowed_tools: allowed,
|
||||
response_cache: self.response_cache,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -330,11 +342,25 @@ impl Agent {
|
||||
.collect();
|
||||
let available_hints: Vec<String> = route_model_by_hint.keys().cloned().collect();
|
||||
|
||||
let response_cache = if config.memory.response_cache_enabled {
|
||||
crate::memory::response_cache::ResponseCache::with_hot_cache(
|
||||
&config.workspace_dir,
|
||||
config.memory.response_cache_ttl_minutes,
|
||||
config.memory.response_cache_max_entries,
|
||||
config.memory.response_cache_hot_entries,
|
||||
)
|
||||
.ok()
|
||||
.map(Arc::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(tools)
|
||||
.memory(memory)
|
||||
.observer(observer)
|
||||
.response_cache(response_cache)
|
||||
.tool_dispatcher(tool_dispatcher)
|
||||
.memory_loader(Box::new(DefaultMemoryLoader::new(
|
||||
5,
|
||||
@ -513,6 +539,47 @@ impl Agent {
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
|
||||
// Response cache: check before LLM call (only for deterministic, text-only prompts)
|
||||
let cache_key = if self.temperature == 0.0 {
|
||||
self.response_cache.as_ref().map(|_| {
|
||||
let last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.as_str())
|
||||
.unwrap_or("");
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
crate::memory::response_cache::ResponseCache::cache_key(
|
||||
&effective_model,
|
||||
system,
|
||||
last_user,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
if let Ok(Some(cached)) = cache.get(key) {
|
||||
self.observer.record_event(&ObserverEvent::CacheHit {
|
||||
cache_type: "response".into(),
|
||||
tokens_saved: 0,
|
||||
});
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
cached.clone(),
|
||||
)));
|
||||
self.trim_history();
|
||||
return Ok(cached);
|
||||
}
|
||||
self.observer.record_event(&ObserverEvent::CacheMiss {
|
||||
cache_type: "response".into(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.provider
|
||||
.chat(
|
||||
@ -541,6 +608,17 @@ impl Agent {
|
||||
text
|
||||
};
|
||||
|
||||
// Store in response cache (text-only, no tool calls)
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
let token_count = response
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|u| u.output_tokens)
|
||||
.unwrap_or(0);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let _ = cache.put(key, &effective_model, &final_text, token_count as u32);
|
||||
}
|
||||
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
final_text.clone(),
|
||||
|
||||
@ -3977,6 +3977,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -32,6 +32,8 @@ pub mod nextcloud_talk;
|
||||
pub mod nostr;
|
||||
pub mod notion;
|
||||
pub mod qq;
|
||||
pub mod session_backend;
|
||||
pub mod session_sqlite;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
|
||||
103
src/channels/session_backend.rs
Normal file
103
src/channels/session_backend.rs
Normal file
@ -0,0 +1,103 @@
|
||||
//! Trait abstraction for session persistence backends.
|
||||
//!
|
||||
//! Backends store per-sender conversation histories. The trait is intentionally
|
||||
//! minimal — load, append, remove_last, list — so that JSONL and SQLite (and
|
||||
//! future backends) share a common interface.
|
||||
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Metadata about a persisted session.
///
/// Produced by [`SessionBackend::list_sessions_with_metadata`] and
/// [`SessionBackend::search`]. Timestamps are only as accurate as the backend
/// that recorded them (the trait's default implementation fills both with the
/// current time).
#[derive(Debug, Clone)]
pub struct SessionMetadata {
    /// Session key (e.g. `telegram_user123`).
    pub key: String,
    /// When the session was first created.
    pub created_at: DateTime<Utc>,
    /// When the last message was appended.
    pub last_activity: DateTime<Utc>,
    /// Total number of messages in the session.
    pub message_count: usize,
}
|
||||
|
||||
/// Query parameters for listing sessions.
#[derive(Debug, Clone, Default)]
pub struct SessionQuery {
    /// Keyword to search in session messages (FTS5 if available).
    /// `None` means no filter — backends return all sessions.
    pub keyword: Option<String>,
    /// Maximum number of sessions to return.
    /// `None` leaves the cap to the backend, which may apply its own default.
    pub limit: Option<usize>,
}
|
||||
|
||||
/// Trait for session persistence backends.
|
||||
///
|
||||
/// Implementations must be `Send + Sync` for sharing across async tasks.
|
||||
pub trait SessionBackend: Send + Sync {
|
||||
/// Load all messages for a session. Returns empty vec if session doesn't exist.
|
||||
fn load(&self, session_key: &str) -> Vec<ChatMessage>;
|
||||
|
||||
/// Append a single message to a session.
|
||||
fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()>;
|
||||
|
||||
/// Remove the last message from a session. Returns `true` if a message was removed.
|
||||
fn remove_last(&self, session_key: &str) -> std::io::Result<bool>;
|
||||
|
||||
/// List all session keys.
|
||||
fn list_sessions(&self) -> Vec<String>;
|
||||
|
||||
/// List sessions with metadata.
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
// Default: construct metadata from messages (backends can override for efficiency)
|
||||
self.list_sessions()
|
||||
.into_iter()
|
||||
.map(|key| {
|
||||
let messages = self.load(&key);
|
||||
SessionMetadata {
|
||||
key,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: messages.len(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compact a session file (remove duplicates/corruption). No-op by default.
|
||||
fn compact(&self, _session_key: &str) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove sessions that haven't been active within the given TTL hours.
|
||||
fn cleanup_stale(&self, _ttl_hours: u32) -> std::io::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Search sessions by keyword. Default returns empty (backends with FTS override).
|
||||
fn search(&self, _query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the metadata struct can be built field-by-field and read back.
    #[test]
    fn session_metadata_is_constructible() {
        let now = Utc::now();
        let metadata = SessionMetadata {
            key: String::from("test"),
            created_at: now,
            last_activity: now,
            message_count: 5,
        };
        assert_eq!(metadata.key, "test");
        assert_eq!(metadata.message_count, 5);
    }

    // A default query carries no keyword filter and no result cap.
    #[test]
    fn session_query_defaults() {
        let query = SessionQuery::default();
        assert!(query.keyword.is_none());
        assert!(query.limit.is_none());
    }
}
|
||||
503
src/channels/session_sqlite.rs
Normal file
503
src/channels/session_sqlite.rs
Normal file
@ -0,0 +1,503 @@
|
||||
//! SQLite-backed session persistence with FTS5 search.
|
||||
//!
|
||||
//! Stores sessions in `{workspace}/sessions/sessions.db` using WAL mode.
|
||||
//! Provides full-text search via FTS5 and automatic TTL-based cleanup.
|
||||
//! Designed as the default backend, replacing JSONL for new installations.
|
||||
|
||||
use crate::channels::session_backend::{SessionBackend, SessionMetadata, SessionQuery};
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// SQLite-backed session store with FTS5 and WAL mode.
///
/// All access is serialized through a single connection guarded by a
/// `parking_lot::Mutex`; each method holds the lock only for the duration of
/// its statements.
pub struct SqliteSessionBackend {
    // The shared connection. Schema and PRAGMAs are applied once in `new`.
    conn: Mutex<Connection>,
    // Retained for diagnostics; not read anywhere yet (hence the allow).
    #[allow(dead_code)]
    db_path: PathBuf,
}
|
||||
|
||||
impl SqliteSessionBackend {
    /// Open or create the sessions database.
    ///
    /// Creates `{workspace}/sessions/sessions.db`, applies performance
    /// PRAGMAs, and (idempotently) creates the schema:
    /// - `sessions`: one row per message, ordered by autoincrement `id`;
    /// - `session_metadata`: one row per session with timestamps and count;
    /// - `sessions_fts`: external-content FTS5 index over `sessions`, kept in
    ///   sync by INSERT/DELETE triggers (rows are never UPDATEd, so no update
    ///   trigger is needed).
    ///
    /// # Errors
    /// Fails if the sessions directory cannot be created, the DB cannot be
    /// opened, or schema initialization fails.
    pub fn new(workspace_dir: &Path) -> Result<Self> {
        let sessions_dir = workspace_dir.join("sessions");
        std::fs::create_dir_all(&sessions_dir).context("Failed to create sessions directory")?;
        let db_path = sessions_dir.join("sessions.db");

        let conn = Connection::open(&db_path)
            .with_context(|| format!("Failed to open session DB: {}", db_path.display()))?;

        // WAL allows concurrent readers with a single writer; NORMAL sync is
        // the usual pairing with WAL (fewer fsyncs). mmap_size is 4 MiB.
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             PRAGMA synchronous = NORMAL;
             PRAGMA temp_store = MEMORY;
             PRAGMA mmap_size = 4194304;",
        )?;

        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_key TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                created_at TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_sessions_key ON sessions(session_key);
            CREATE INDEX IF NOT EXISTS idx_sessions_key_id ON sessions(session_key, id);

            CREATE TABLE IF NOT EXISTS session_metadata (
                session_key TEXT PRIMARY KEY,
                created_at TEXT NOT NULL,
                last_activity TEXT NOT NULL,
                message_count INTEGER NOT NULL DEFAULT 0
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
                session_key, content, content=sessions, content_rowid=id
            );

            CREATE TRIGGER IF NOT EXISTS sessions_ai AFTER INSERT ON sessions BEGIN
                INSERT INTO sessions_fts(rowid, session_key, content)
                VALUES (new.id, new.session_key, new.content);
            END;
            CREATE TRIGGER IF NOT EXISTS sessions_ad AFTER DELETE ON sessions BEGIN
                INSERT INTO sessions_fts(sessions_fts, rowid, session_key, content)
                VALUES ('delete', old.id, old.session_key, old.content);
            END;",
        )
        .context("Failed to initialize session schema")?;

        Ok(Self {
            conn: Mutex::new(conn),
            db_path,
        })
    }

    /// Migrate JSONL session files into SQLite. Renames migrated files to `.jsonl.migrated`.
    ///
    /// Scans `{workspace}/sessions/` for `*.jsonl` files; each parseable line
    /// is appended under the file-stem session key. Unreadable entries, bad
    /// file names, and unparseable lines are silently skipped (best-effort).
    /// Returns the number of files migrated.
    ///
    /// NOTE(review): a file whose lines all fail to parse is NOT renamed, so
    /// it will be re-scanned on every startup — confirm this is intentional.
    pub fn migrate_from_jsonl(&self, workspace_dir: &Path) -> Result<usize> {
        let sessions_dir = workspace_dir.join("sessions");
        // Missing directory simply means nothing to migrate.
        let entries = match std::fs::read_dir(&sessions_dir) {
            Ok(e) => e,
            Err(_) => return Ok(0),
        };

        let mut migrated = 0;
        for entry in entries {
            let entry = match entry {
                Ok(e) => e,
                Err(_) => continue,
            };
            let name = match entry.file_name().into_string() {
                Ok(n) => n,
                Err(_) => continue,
            };
            // Only plain `*.jsonl` files qualify; already-migrated
            // `*.jsonl.migrated` files do not match this suffix.
            let Some(key) = name.strip_suffix(".jsonl") else {
                continue;
            };

            let path = entry.path();
            let file = match std::fs::File::open(&path) {
                Ok(f) => f,
                Err(_) => continue,
            };

            let reader = std::io::BufReader::new(file);
            let mut count = 0;
            for line in std::io::BufRead::lines(reader) {
                let Ok(line) = line else { continue };
                let trimmed = line.trim();
                if trimmed.is_empty() {
                    continue;
                }
                if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
                    // `append` also maintains session_metadata counts.
                    if self.append(key, &msg).is_ok() {
                        count += 1;
                    }
                }
            }

            if count > 0 {
                // Rename rather than delete so the original data survives
                // until the operator removes it.
                let migrated_path = path.with_extension("jsonl.migrated");
                let _ = std::fs::rename(&path, &migrated_path);
                migrated += 1;
            }
        }

        Ok(migrated)
    }
}
|
||||
|
||||
impl SessionBackend for SqliteSessionBackend {
    /// Load all messages for a session in insertion order.
    ///
    /// SQL errors are swallowed and yield an empty vec — the trait signature
    /// is infallible here, and a missing session is indistinguishable from a
    /// query failure by design.
    fn load(&self, session_key: &str) -> Vec<ChatMessage> {
        let conn = self.conn.lock();
        let mut stmt = match conn
            .prepare("SELECT role, content FROM sessions WHERE session_key = ?1 ORDER BY id ASC")
        {
            Ok(s) => s,
            Err(_) => return Vec::new(),
        };

        let rows = match stmt.query_map(params![session_key], |row| {
            Ok(ChatMessage {
                role: row.get(0)?,
                content: row.get(1)?,
            })
        }) {
            Ok(r) => r,
            Err(_) => return Vec::new(),
        };

        // Rows that fail to decode are dropped rather than aborting the load.
        rows.filter_map(|r| r.ok()).collect()
    }

    /// Append a message and bump the session's metadata (upsert).
    ///
    /// NOTE(review): the two statements are not wrapped in a transaction; if
    /// the metadata upsert fails after the insert succeeds, `message_count`
    /// drifts from the real row count — confirm acceptable.
    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
        let conn = self.conn.lock();
        let now = Utc::now().to_rfc3339();

        conn.execute(
            "INSERT INTO sessions (session_key, role, content, created_at)
             VALUES (?1, ?2, ?3, ?4)",
            params![session_key, message.role, message.content, now],
        )
        .map_err(std::io::Error::other)?;

        // Upsert metadata: first append sets created_at; later appends only
        // refresh last_activity and increment the count.
        conn.execute(
            "INSERT INTO session_metadata (session_key, created_at, last_activity, message_count)
             VALUES (?1, ?2, ?3, 1)
             ON CONFLICT(session_key) DO UPDATE SET
                 last_activity = excluded.last_activity,
                 message_count = message_count + 1",
            params![session_key, now, now],
        )
        .map_err(std::io::Error::other)?;

        Ok(())
    }

    /// Delete the most recently inserted message (highest `id`) for a session.
    ///
    /// The FTS index is kept in sync by the AFTER DELETE trigger.
    /// NOTE(review): `last_activity` is not rewound to the previous message's
    /// timestamp — confirm that is acceptable for rollback semantics.
    fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
        let conn = self.conn.lock();

        // `.ok()` folds both "no rows" and query errors into None.
        let last_id: Option<i64> = conn
            .query_row(
                "SELECT id FROM sessions WHERE session_key = ?1 ORDER BY id DESC LIMIT 1",
                params![session_key],
                |row| row.get(0),
            )
            .ok();

        let Some(id) = last_id else {
            return Ok(false);
        };

        conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])
            .map_err(std::io::Error::other)?;

        // Update metadata count; MAX(0, ...) guards against going negative if
        // the count ever drifted.
        conn.execute(
            "UPDATE session_metadata SET message_count = MAX(0, message_count - 1)
             WHERE session_key = ?1",
            params![session_key],
        )
        .map_err(std::io::Error::other)?;

        Ok(true)
    }

    /// List session keys, most recently active first. Errors yield empty.
    fn list_sessions(&self) -> Vec<String> {
        let conn = self.conn.lock();
        let mut stmt = match conn
            .prepare("SELECT session_key FROM session_metadata ORDER BY last_activity DESC")
        {
            Ok(s) => s,
            Err(_) => return Vec::new(),
        };

        let rows = match stmt.query_map([], |row| row.get(0)) {
            Ok(r) => r,
            Err(_) => return Vec::new(),
        };

        rows.filter_map(|r| r.ok()).collect()
    }

    /// List sessions with real stored metadata (overrides the trait default,
    /// which would have loaded every message). Unparseable timestamps fall
    /// back to the current time rather than failing the whole listing.
    fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
        let conn = self.conn.lock();
        let mut stmt = match conn.prepare(
            "SELECT session_key, created_at, last_activity, message_count
             FROM session_metadata ORDER BY last_activity DESC",
        ) {
            Ok(s) => s,
            Err(_) => return Vec::new(),
        };

        let rows = match stmt.query_map([], |row| {
            let key: String = row.get(0)?;
            let created_str: String = row.get(1)?;
            let activity_str: String = row.get(2)?;
            let count: i64 = row.get(3)?;

            let created = DateTime::parse_from_rfc3339(&created_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());
            let activity = DateTime::parse_from_rfc3339(&activity_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());

            // count is non-negative in practice (guarded by MAX(0, ...) on
            // decrement), so the sign-loss cast is safe.
            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
            Ok(SessionMetadata {
                key,
                created_at: created,
                last_activity: activity,
                message_count: count as usize,
            })
        }) {
            Ok(r) => r,
            Err(_) => return Vec::new(),
        };

        rows.filter_map(|r| r.ok()).collect()
    }

    /// Delete sessions whose `last_activity` predates now minus `ttl_hours`.
    /// Returns the number of sessions removed. RFC 3339 strings compare
    /// correctly as text for same-offset timestamps, so the `<` works on the
    /// stored strings.
    fn cleanup_stale(&self, ttl_hours: u32) -> std::io::Result<usize> {
        let conn = self.conn.lock();
        let cutoff = (Utc::now() - Duration::hours(i64::from(ttl_hours))).to_rfc3339();

        // Find stale sessions first; the statement must be dropped before the
        // deletes below borrow the connection again.
        let stale_keys: Vec<String> = {
            let mut stmt = conn
                .prepare("SELECT session_key FROM session_metadata WHERE last_activity < ?1")
                .map_err(std::io::Error::other)?;
            let rows = stmt
                .query_map(params![cutoff], |row| row.get(0))
                .map_err(std::io::Error::other)?;
            rows.filter_map(|r| r.ok()).collect()
        };

        let count = stale_keys.len();
        for key in &stale_keys {
            // Best-effort deletes; the AFTER DELETE trigger purges FTS rows.
            let _ = conn.execute("DELETE FROM sessions WHERE session_key = ?1", params![key]);
            let _ = conn.execute(
                "DELETE FROM session_metadata WHERE session_key = ?1",
                params![key],
            );
        }

        Ok(count)
    }

    /// Full-text search over session content via FTS5.
    ///
    /// No keyword means no filter: fall through to the plain metadata listing.
    /// Each whitespace-separated word is quoted and OR-joined, so any single
    /// word matching is enough.
    fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
        let Some(keyword) = &query.keyword else {
            // Must not hold the lock here: list_sessions_with_metadata locks.
            return self.list_sessions_with_metadata();
        };

        let conn = self.conn.lock();
        #[allow(clippy::cast_possible_wrap)]
        let limit = query.limit.unwrap_or(50) as i64;

        // FTS5 search. The unaliased `sessions_fts` in the MATCH refers to
        // the hidden FTS5 column of the same name.
        let mut stmt = match conn.prepare(
            "SELECT DISTINCT f.session_key
             FROM sessions_fts f
             WHERE sessions_fts MATCH ?1
             LIMIT ?2",
        ) {
            Ok(s) => s,
            Err(_) => return Vec::new(),
        };

        // Quote each word for FTS5 so punctuation can't be parsed as syntax.
        let fts_query: String = keyword
            .split_whitespace()
            .map(|w| format!("\"{w}\""))
            .collect::<Vec<_>>()
            .join(" OR ");

        let keys: Vec<String> = match stmt.query_map(params![fts_query, limit], |row| row.get(0)) {
            Ok(r) => r.filter_map(|r| r.ok()).collect(),
            Err(_) => return Vec::new(),
        };

        // Look up metadata for matched sessions; sessions missing a metadata
        // row are silently dropped.
        keys.iter()
            .filter_map(|key| {
                conn.query_row(
                    "SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
                    params![key],
                    |row| {
                        let created_str: String = row.get(0)?;
                        let activity_str: String = row.get(1)?;
                        let count: i64 = row.get(2)?;
                        Ok(SessionMetadata {
                            key: key.clone(),
                            created_at: DateTime::parse_from_rfc3339(&created_str)
                                .map(|dt| dt.with_timezone(&Utc))
                                .unwrap_or_else(|_| Utc::now()),
                            last_activity: DateTime::parse_from_rfc3339(&activity_str)
                                .map(|dt| dt.with_timezone(&Utc))
                                .unwrap_or_else(|_| Utc::now()),
                            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
                            message_count: count as usize,
                        })
                    },
                )
                .ok()
            })
            .collect()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Append then load returns messages in insertion order with roles intact.
    #[test]
    fn round_trip_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend
            .append("user1", &ChatMessage::user("hello"))
            .unwrap();
        backend
            .append("user1", &ChatMessage::assistant("hi"))
            .unwrap();

        let msgs = backend.load("user1");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, "user");
        assert_eq!(msgs[1].role, "assistant");
    }

    // remove_last drops the newest message and leaves earlier ones untouched.
    #[test]
    fn remove_last_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("u", &ChatMessage::user("a")).unwrap();
        backend.append("u", &ChatMessage::user("b")).unwrap();

        assert!(backend.remove_last("u").unwrap());
        let msgs = backend.load("u");
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "a");
    }

    // Removing from a never-written session reports false, not an error.
    #[test]
    fn remove_last_empty_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
        assert!(!backend.remove_last("nonexistent").unwrap());
    }

    // Each distinct session key appears once in the listing.
    #[test]
    fn list_sessions_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("a", &ChatMessage::user("hi")).unwrap();
        backend.append("b", &ChatMessage::user("hey")).unwrap();

        let sessions = backend.list_sessions();
        assert_eq!(sessions.len(), 2);
    }

    // The metadata upsert in append keeps message_count in step with appends.
    #[test]
    fn metadata_tracks_counts() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("s1", &ChatMessage::user("a")).unwrap();
        backend.append("s1", &ChatMessage::user("b")).unwrap();
        backend.append("s1", &ChatMessage::user("c")).unwrap();

        let meta = backend.list_sessions_with_metadata();
        assert_eq!(meta.len(), 1);
        assert_eq!(meta[0].message_count, 3);
    }

    // FTS5 matches message content and returns only the containing session.
    #[test]
    fn fts5_search_finds_content() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend
            .append(
                "code_chat",
                &ChatMessage::user("How do I parse JSON in Rust?"),
            )
            .unwrap();
        backend
            .append("weather", &ChatMessage::user("What's the weather today?"))
            .unwrap();

        let results = backend.search(&SessionQuery {
            keyword: Some("Rust".into()),
            limit: Some(10),
        });
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "code_chat");
    }

    // Sessions older than the TTL are purged; fresh ones survive.
    #[test]
    fn cleanup_stale_removes_old_sessions() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        // Insert a session with old timestamp directly via SQL, since the
        // public API always stamps rows with the current time.
        {
            let conn = backend.conn.lock();
            let old_time = (Utc::now() - Duration::hours(100)).to_rfc3339();
            conn.execute(
                "INSERT INTO sessions (session_key, role, content, created_at) VALUES (?1, ?2, ?3, ?4)",
                params!["old_session", "user", "ancient", old_time],
            ).unwrap();
            conn.execute(
                "INSERT INTO session_metadata (session_key, created_at, last_activity, message_count) VALUES (?1, ?2, ?3, 1)",
                params!["old_session", old_time, old_time],
            ).unwrap();
        }

        backend
            .append("new_session", &ChatMessage::user("fresh"))
            .unwrap();

        let cleaned = backend.cleanup_stale(48).unwrap(); // 48h TTL
        assert_eq!(cleaned, 1);

        let sessions = backend.list_sessions();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0], "new_session");
    }

    // Migration imports every JSONL line as a message and renames the source
    // file so it is not re-imported.
    #[test]
    fn migrate_from_jsonl_imports_and_renames() {
        let tmp = TempDir::new().unwrap();
        let sessions_dir = tmp.path().join("sessions");
        std::fs::create_dir_all(&sessions_dir).unwrap();

        // Create a JSONL file
        let jsonl_path = sessions_dir.join("test_user.jsonl");
        std::fs::write(
            &jsonl_path,
            "{\"role\":\"user\",\"content\":\"hello\"}\n{\"role\":\"assistant\",\"content\":\"hi\"}\n",
        )
        .unwrap();

        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
        let migrated = backend.migrate_from_jsonl(tmp.path()).unwrap();
        assert_eq!(migrated, 1);

        // JSONL should be renamed
        assert!(!jsonl_path.exists());
        assert!(sessions_dir.join("test_user.jsonl.migrated").exists());

        // Messages should be in SQLite
        let msgs = backend.load("test_user");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content, "hello");
    }
}
|
||||
@ -5,6 +5,7 @@
|
||||
//! one-per-line as JSON, never modifying old lines. On daemon restart, sessions
|
||||
//! are loaded from disk to restore conversation context.
|
||||
|
||||
use crate::channels::session_backend::SessionBackend;
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
@ -78,6 +79,37 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the last message from a session's JSONL file.
|
||||
///
|
||||
/// Rewrite approach: load all messages, drop the last, rewrite. This is
|
||||
/// O(n) but rollbacks are rare.
|
||||
pub fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
|
||||
let mut messages = self.load(session_key);
|
||||
if messages.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
messages.pop();
|
||||
self.rewrite(session_key, &messages)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Compact a session file by rewriting only valid messages (removes corrupt lines).
|
||||
pub fn compact(&self, session_key: &str) -> std::io::Result<()> {
|
||||
let messages = self.load(session_key);
|
||||
self.rewrite(session_key, &messages)
|
||||
}
|
||||
|
||||
fn rewrite(&self, session_key: &str, messages: &[ChatMessage]) -> std::io::Result<()> {
|
||||
let path = self.session_path(session_key);
|
||||
let mut file = std::fs::File::create(&path)?;
|
||||
for msg in messages {
|
||||
let json = serde_json::to_string(msg)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
writeln!(file, "{json}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session keys that have files on disk.
|
||||
pub fn list_sessions(&self) -> Vec<String> {
|
||||
let entries = match std::fs::read_dir(&self.sessions_dir) {
|
||||
@ -95,6 +127,28 @@ impl SessionStore {
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionBackend for SessionStore {
    // Pure delegation: each trait method forwards to the inherent method of
    // the same name (Rust resolves `self.load(...)` to the inherent method,
    // not recursively to this trait method), so direct callers and trait-
    // object callers share one implementation. `cleanup_stale` and `search`
    // intentionally keep the trait defaults.
    fn load(&self, session_key: &str) -> Vec<ChatMessage> {
        self.load(session_key)
    }

    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
        self.append(session_key, message)
    }

    fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
        self.remove_last(session_key)
    }

    fn list_sessions(&self) -> Vec<String> {
        self.list_sessions()
    }

    fn compact(&self, session_key: &str) -> std::io::Result<()> {
        self.compact(session_key)
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -178,6 +232,63 @@ mod tests {
|
||||
assert_eq!(lines.len(), 2);
|
||||
}
|
||||
|
||||
// remove_last discards only the newest message; the rest of the session
// survives the rewrite.
#[test]
fn remove_last_drops_final_message() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();

    for text in ["first", "second"] {
        store.append("rm_test", &ChatMessage::user(text)).unwrap();
    }

    assert!(store.remove_last("rm_test").unwrap());
    let remaining = store.load("rm_test");
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].content, "first");
}
|
||||
|
||||
// Removing from a session that was never written is a no-op reported as false.
#[test]
fn remove_last_empty_returns_false() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();
    let removed = store.remove_last("nonexistent").unwrap();
    assert!(!removed);
}
|
||||
|
||||
// compact rewrites the file keeping only parseable JSONL lines.
#[test]
fn compact_removes_corrupt_lines() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();
    let key = "compact_test";

    // Hand-write a session file with one garbage line between two valid
    // messages.
    let path = store.session_path(key);
    std::fs::create_dir_all(path.parent().unwrap()).unwrap();
    let mut file = std::fs::File::create(&path).unwrap();
    writeln!(file, r#"{{"role":"user","content":"ok"}}"#).unwrap();
    writeln!(file, "corrupt line").unwrap();
    writeln!(file, r#"{{"role":"assistant","content":"hi"}}"#).unwrap();

    store.compact(key).unwrap();

    // Only the two parseable messages survive.
    let contents = std::fs::read_to_string(&path).unwrap();
    assert_eq!(contents.trim().lines().count(), 2);
}
|
||||
|
||||
// SessionStore is usable through a `&dyn SessionBackend` trait object.
#[test]
fn session_backend_trait_works_via_dyn() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();
    let backend: &dyn SessionBackend = &store;

    backend
        .append("trait_test", &ChatMessage::user("hello"))
        .unwrap();
    assert_eq!(backend.load("trait_test").len(), 1);
}
|
||||
|
||||
#[test]
|
||||
fn handles_corrupt_lines_gracefully() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@ -2650,6 +2650,9 @@ pub struct MemoryConfig {
|
||||
/// Max number of cached responses before LRU eviction (default: 5000)
|
||||
#[serde(default = "default_response_cache_max")]
|
||||
pub response_cache_max_entries: usize,
|
||||
/// Max in-memory hot cache entries for the two-tier response cache (default: 256)
|
||||
#[serde(default = "default_response_cache_hot_entries")]
|
||||
pub response_cache_hot_entries: usize,
|
||||
|
||||
// ── Memory Snapshot (soul backup to Markdown) ─────────────
|
||||
/// Enable periodic export of core memories to MEMORY_SNAPSHOT.md
|
||||
@ -2718,6 +2721,10 @@ fn default_response_cache_max() -> usize {
|
||||
5_000
|
||||
}
|
||||
|
||||
fn default_response_cache_hot_entries() -> usize {
|
||||
256
|
||||
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -2738,6 +2745,7 @@ impl Default for MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: default_response_cache_ttl(),
|
||||
response_cache_max_entries: default_response_cache_max(),
|
||||
response_cache_hot_entries: default_response_cache_hot_entries(),
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
@ -3344,12 +3352,48 @@ pub struct HeartbeatConfig {
|
||||
/// explicitly set).
|
||||
#[serde(default, alias = "recipient")]
|
||||
pub to: Option<String>,
|
||||
/// Enable adaptive intervals that back off on failures and speed up for
|
||||
/// high-priority tasks. Default: `false`.
|
||||
#[serde(default)]
|
||||
pub adaptive: bool,
|
||||
/// Minimum interval in minutes when adaptive mode is enabled. Default: `5`.
|
||||
#[serde(default = "default_heartbeat_min_interval")]
|
||||
pub min_interval_minutes: u32,
|
||||
/// Maximum interval in minutes when adaptive mode backs off. Default: `120`.
|
||||
#[serde(default = "default_heartbeat_max_interval")]
|
||||
pub max_interval_minutes: u32,
|
||||
/// Dead-man's switch timeout in minutes. If the heartbeat has not ticked
|
||||
/// within this window, an alert is sent. `0` disables. Default: `0`.
|
||||
#[serde(default)]
|
||||
pub deadman_timeout_minutes: u32,
|
||||
/// Channel for dead-man's switch alerts (e.g. `telegram`). Falls back to
|
||||
/// the heartbeat delivery channel.
|
||||
#[serde(default)]
|
||||
pub deadman_channel: Option<String>,
|
||||
/// Recipient for dead-man's switch alerts. Falls back to `to`.
|
||||
#[serde(default)]
|
||||
pub deadman_to: Option<String>,
|
||||
/// Maximum number of heartbeat run history records to retain. Default: `100`.
|
||||
#[serde(default = "default_heartbeat_max_run_history")]
|
||||
pub max_run_history: u32,
|
||||
}
|
||||
|
||||
/// Serde default for `two_phase`: enabled.
fn default_two_phase() -> bool {
    const TWO_PHASE: bool = true;
    TWO_PHASE
}
|
||||
|
||||
/// Serde default for `min_interval_minutes`: 5 minutes.
fn default_heartbeat_min_interval() -> u32 {
    const MIN_INTERVAL_MINUTES: u32 = 5;
    MIN_INTERVAL_MINUTES
}
|
||||
|
||||
/// Serde default for `max_interval_minutes`: 120 minutes.
fn default_heartbeat_max_interval() -> u32 {
    const MAX_INTERVAL_MINUTES: u32 = 120;
    MAX_INTERVAL_MINUTES
}
|
||||
|
||||
/// Serde default for `max_run_history`: keep the last 100 run records.
fn default_heartbeat_max_run_history() -> u32 {
    const MAX_RUN_HISTORY: u32 = 100;
    MAX_RUN_HISTORY
}
|
||||
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -3359,6 +3403,13 @@ impl Default for HeartbeatConfig {
|
||||
message: None,
|
||||
target: None,
|
||||
to: None,
|
||||
adaptive: false,
|
||||
min_interval_minutes: default_heartbeat_min_interval(),
|
||||
max_interval_minutes: default_heartbeat_max_interval(),
|
||||
deadman_timeout_minutes: 0,
|
||||
deadman_channel: None,
|
||||
deadman_to: None,
|
||||
max_run_history: default_heartbeat_max_run_history(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3587,6 +3638,13 @@ pub struct ChannelsConfig {
|
||||
/// daemon restarts. Files are stored in `{workspace}/sessions/`. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub session_persistence: bool,
|
||||
/// Session persistence backend: `"jsonl"` (legacy) or `"sqlite"` (new default).
|
||||
/// SQLite provides FTS5 search, metadata tracking, and TTL cleanup.
|
||||
#[serde(default = "default_session_backend")]
|
||||
pub session_backend: String,
|
||||
/// Auto-archive stale sessions older than this many hours. `0` disables. Default: `0`.
|
||||
#[serde(default)]
|
||||
pub session_ttl_hours: u32,
|
||||
}
|
||||
|
||||
impl ChannelsConfig {
|
||||
@ -3692,6 +3750,10 @@ fn default_channel_message_timeout_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
/// Serde default for `session_backend`: the SQLite backend.
fn default_session_backend() -> String {
    String::from("sqlite")
}
|
||||
|
||||
impl Default for ChannelsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -3722,6 +3784,8 @@ impl Default for ChannelsConfig {
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -7358,6 +7422,7 @@ default_temperature = 0.7
|
||||
message: Some("Check London time".into()),
|
||||
target: Some("telegram".into()),
|
||||
to: Some("123456".into()),
|
||||
..HeartbeatConfig::default()
|
||||
},
|
||||
cron: CronConfig::default(),
|
||||
channels_config: ChannelsConfig {
|
||||
@ -7395,6 +7460,8 @@ default_temperature = 0.7
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
},
|
||||
memory: MemoryConfig::default(),
|
||||
storage: StorageConfig::default(),
|
||||
@ -8127,6 +8194,8 @@ allowed_users = ["@ops:matrix.org"]
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@ -8355,6 +8424,8 @@ channel_id = "C123"
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
|
||||
@ -203,7 +203,10 @@ where
|
||||
}
|
||||
|
||||
async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
use crate::heartbeat::engine::HeartbeatEngine;
|
||||
use crate::heartbeat::engine::{
|
||||
compute_adaptive_interval, HeartbeatEngine, HeartbeatTask, TaskPriority, TaskStatus,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
let observer: std::sync::Arc<dyn crate::observability::Observer> =
|
||||
std::sync::Arc::from(crate::observability::create_observer(&config.observability));
|
||||
@ -212,19 +215,72 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
config.workspace_dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let metrics = engine.metrics();
|
||||
let delivery = resolve_heartbeat_delivery(&config)?;
|
||||
let two_phase = config.heartbeat.two_phase;
|
||||
let adaptive = config.heartbeat.adaptive;
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let interval_mins = config.heartbeat.interval_minutes.max(5);
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
|
||||
// ── Deadman watcher ──────────────────────────────────────────
|
||||
let deadman_timeout = config.heartbeat.deadman_timeout_minutes;
|
||||
if deadman_timeout > 0 {
|
||||
let dm_metrics = Arc::clone(&metrics);
|
||||
let dm_config = config.clone();
|
||||
let dm_delivery = delivery.clone();
|
||||
tokio::spawn(async move {
|
||||
let check_interval = Duration::from_secs(60);
|
||||
let timeout = chrono::Duration::minutes(i64::from(deadman_timeout));
|
||||
loop {
|
||||
tokio::time::sleep(check_interval).await;
|
||||
let last_tick = dm_metrics.lock().last_tick_at;
|
||||
if let Some(last) = last_tick {
|
||||
if chrono::Utc::now() - last > timeout {
|
||||
let alert = format!(
|
||||
"⚠️ Heartbeat dead-man's switch: no tick in {deadman_timeout} minutes"
|
||||
);
|
||||
let (channel, target) =
|
||||
if let Some(ch) = &dm_config.heartbeat.deadman_channel {
|
||||
let to = dm_config
|
||||
.heartbeat
|
||||
.deadman_to
|
||||
.as_deref()
|
||||
.or(dm_config.heartbeat.to.as_deref())
|
||||
.unwrap_or_default();
|
||||
(ch.clone(), to.to_string())
|
||||
} else if let Some((ch, to)) = &dm_delivery {
|
||||
(ch.clone(), to.clone())
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let _ = crate::cron::scheduler::deliver_announcement(
|
||||
&dm_config, &channel, &target, &alert,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let base_interval = config.heartbeat.interval_minutes.max(5);
|
||||
let mut sleep_mins = base_interval;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
tokio::time::sleep(Duration::from_secs(u64::from(sleep_mins) * 60)).await;
|
||||
|
||||
// Update uptime
|
||||
{
|
||||
let mut m = metrics.lock();
|
||||
m.uptime_secs = start_time.elapsed().as_secs();
|
||||
}
|
||||
|
||||
let tick_start = std::time::Instant::now();
|
||||
|
||||
// Collect runnable tasks (active only, sorted by priority)
|
||||
let mut tasks = engine.collect_runnable_tasks().await?;
|
||||
let has_high_priority = tasks.iter().any(|t| t.priority == TaskPriority::High);
|
||||
|
||||
if tasks.is_empty() {
|
||||
// Try fallback message
|
||||
if let Some(fallback) = config
|
||||
.heartbeat
|
||||
.message
|
||||
@ -232,12 +288,15 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
{
|
||||
tasks.push(crate::heartbeat::engine::HeartbeatTask {
|
||||
tasks.push(HeartbeatTask {
|
||||
text: fallback.to_string(),
|
||||
priority: crate::heartbeat::engine::TaskPriority::Medium,
|
||||
status: crate::heartbeat::engine::TaskStatus::Active,
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
});
|
||||
} else {
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
metrics.lock().record_success(elapsed);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -250,7 +309,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
Some(decision_prompt),
|
||||
None,
|
||||
None,
|
||||
0.0, // Low temperature for deterministic decision
|
||||
0.0,
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
@ -263,6 +322,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
if indices.is_empty() {
|
||||
tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
metrics.lock().record_success(elapsed);
|
||||
continue;
|
||||
}
|
||||
tracing::info!(
|
||||
@ -285,7 +347,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
let mut tick_had_error = false;
|
||||
for task in &tasks_to_run {
|
||||
let task_start = std::time::Instant::now();
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let temp = config.default_temperature;
|
||||
match Box::pin(crate::agent::run(
|
||||
@ -303,6 +367,20 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
{
|
||||
Ok(output) => {
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let duration_ms = task_start.elapsed().as_millis() as i64;
|
||||
let now = chrono::Utc::now();
|
||||
let _ = crate::heartbeat::store::record_run(
|
||||
&config.workspace_dir,
|
||||
&task.text,
|
||||
&task.priority.to_string(),
|
||||
now - chrono::Duration::milliseconds(duration_ms),
|
||||
now,
|
||||
"ok",
|
||||
Some(output.as_str()),
|
||||
duration_ms,
|
||||
config.heartbeat.max_run_history,
|
||||
);
|
||||
let announcement = if output.trim().is_empty() {
|
||||
format!("💓 heartbeat task completed: {}", task.text)
|
||||
} else {
|
||||
@ -326,11 +404,52 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tick_had_error = true;
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let duration_ms = task_start.elapsed().as_millis() as i64;
|
||||
let now = chrono::Utc::now();
|
||||
let _ = crate::heartbeat::store::record_run(
|
||||
&config.workspace_dir,
|
||||
&task.text,
|
||||
&task.priority.to_string(),
|
||||
now - chrono::Duration::milliseconds(duration_ms),
|
||||
now,
|
||||
"error",
|
||||
Some(&e.to_string()),
|
||||
duration_ms,
|
||||
config.heartbeat.max_run_history,
|
||||
);
|
||||
crate::health::mark_component_error("heartbeat", e.to_string());
|
||||
tracing::warn!("Heartbeat task failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let tick_elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
{
|
||||
let mut m = metrics.lock();
|
||||
if tick_had_error {
|
||||
m.record_failure(tick_elapsed);
|
||||
} else {
|
||||
m.record_success(tick_elapsed);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute next sleep interval
|
||||
if adaptive {
|
||||
let failures = metrics.lock().consecutive_failures;
|
||||
sleep_mins = compute_adaptive_interval(
|
||||
base_interval,
|
||||
config.heartbeat.min_interval_minutes,
|
||||
config.heartbeat.max_interval_minutes,
|
||||
failures,
|
||||
has_high_priority,
|
||||
);
|
||||
} else {
|
||||
sleep_mins = base_interval;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
use crate::config::HeartbeatConfig;
|
||||
use crate::observability::{Observer, ObserverEvent};
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use parking_lot::Mutex as ParkingMutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
@ -68,6 +70,99 @@ impl fmt::Display for HeartbeatTask {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Health Metrics ───────────────────────────────────────────────
|
||||
|
||||
/// Live health metrics for the heartbeat subsystem.
///
/// Shared via `Arc<ParkingMutex<_>>` between the heartbeat worker,
/// deadman watcher, and API consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatMetrics {
    /// Seconds elapsed since the heartbeat loop started.
    pub uptime_secs: u64,
    /// Consecutive successful ticks (reset to 0 by `record_failure`).
    pub consecutive_successes: u64,
    /// Consecutive failed ticks (reset to 0 by `record_success`).
    pub consecutive_failures: u64,
    /// Timestamp of the most recent tick; `None` until the first tick.
    pub last_tick_at: Option<DateTime<Utc>>,
    /// Exponential moving average of tick durations in milliseconds.
    pub avg_tick_duration_ms: f64,
    /// Total number of ticks executed since startup.
    pub total_ticks: u64,
}
|
||||
|
||||
impl Default for HeartbeatMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
uptime_secs: 0,
|
||||
consecutive_successes: 0,
|
||||
consecutive_failures: 0,
|
||||
last_tick_at: None,
|
||||
avg_tick_duration_ms: 0.0,
|
||||
total_ticks: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HeartbeatMetrics {
|
||||
/// Record a successful tick with the given duration.
|
||||
pub fn record_success(&mut self, duration_ms: f64) {
|
||||
self.consecutive_successes += 1;
|
||||
self.consecutive_failures = 0;
|
||||
self.last_tick_at = Some(Utc::now());
|
||||
self.total_ticks += 1;
|
||||
self.update_avg_duration(duration_ms);
|
||||
}
|
||||
|
||||
/// Record a failed tick with the given duration.
|
||||
pub fn record_failure(&mut self, duration_ms: f64) {
|
||||
self.consecutive_failures += 1;
|
||||
self.consecutive_successes = 0;
|
||||
self.last_tick_at = Some(Utc::now());
|
||||
self.total_ticks += 1;
|
||||
self.update_avg_duration(duration_ms);
|
||||
}
|
||||
|
||||
fn update_avg_duration(&mut self, duration_ms: f64) {
|
||||
const ALPHA: f64 = 0.3; // EMA smoothing factor
|
||||
if self.total_ticks == 1 {
|
||||
self.avg_tick_duration_ms = duration_ms;
|
||||
} else {
|
||||
self.avg_tick_duration_ms =
|
||||
ALPHA * duration_ms + (1.0 - ALPHA) * self.avg_tick_duration_ms;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the adaptive interval for the next heartbeat tick.
///
/// Strategy:
/// - On failures: exponential back-off `base * 2^failures` capped at `max_interval`.
/// - When high-priority tasks are present: use `min_interval` for faster reaction.
/// - Otherwise: use `base_interval`.
pub fn compute_adaptive_interval(
    base_minutes: u32,
    min_minutes: u32,
    max_minutes: u32,
    consecutive_failures: u64,
    has_high_priority_tasks: bool,
) -> u32 {
    // Failure back-off takes precedence over everything else.
    if consecutive_failures > 0 {
        // Shift amount is capped at 10, so 2^shift fits in a u32;
        // saturating_mul guards the product against overflow.
        let shift = consecutive_failures.min(10) as u32;
        let factor = 1u32.checked_shl(shift).unwrap_or(u32::MAX);
        let backed_off = base_minutes.saturating_mul(factor);
        return backed_off.min(max_minutes).max(min_minutes);
    }

    if has_high_priority_tasks {
        return min_minutes.max(5); // never go below 5 minutes
    }

    base_minutes.clamp(min_minutes, max_minutes)
}
|
||||
|
||||
// ── Engine ───────────────────────────────────────────────────────
|
||||
|
||||
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
|
||||
@ -75,6 +170,7 @@ pub struct HeartbeatEngine {
|
||||
config: HeartbeatConfig,
|
||||
workspace_dir: std::path::PathBuf,
|
||||
observer: Arc<dyn Observer>,
|
||||
metrics: Arc<ParkingMutex<HeartbeatMetrics>>,
|
||||
}
|
||||
|
||||
impl HeartbeatEngine {
|
||||
@ -87,9 +183,15 @@ impl HeartbeatEngine {
|
||||
config,
|
||||
workspace_dir,
|
||||
observer,
|
||||
metrics: Arc::new(ParkingMutex::new(HeartbeatMetrics::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a shared handle to the live heartbeat metrics.
|
||||
pub fn metrics(&self) -> Arc<ParkingMutex<HeartbeatMetrics>> {
|
||||
Arc::clone(&self.metrics)
|
||||
}
|
||||
|
||||
/// Start the heartbeat loop (runs until cancelled)
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
if !self.config.enabled {
|
||||
@ -673,4 +775,79 @@ mod tests {
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
// ── HeartbeatMetrics tests ───────────────────────────────────
|
||||
|
||||
#[test]
fn metrics_record_success_updates_fields() {
    // A first successful tick seeds every counter and the EMA directly.
    let mut m = HeartbeatMetrics::default();
    m.record_success(100.0);
    assert_eq!(m.consecutive_successes, 1);
    assert_eq!(m.consecutive_failures, 0);
    assert_eq!(m.total_ticks, 1);
    assert!(m.last_tick_at.is_some());
    assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
}
|
||||
|
||||
#[test]
fn metrics_record_failure_resets_successes() {
    // A failure zeroes the success streak but total_ticks keeps counting.
    let mut m = HeartbeatMetrics::default();
    m.record_success(50.0);
    m.record_success(50.0);
    m.record_failure(200.0);
    assert_eq!(m.consecutive_successes, 0);
    assert_eq!(m.consecutive_failures, 1);
    assert_eq!(m.total_ticks, 3);
}
|
||||
|
||||
#[test]
fn metrics_ema_smoothing() {
    // First tick seeds the average; subsequent ticks blend with ALPHA = 0.3.
    let mut m = HeartbeatMetrics::default();
    m.record_success(100.0);
    assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
    m.record_success(200.0);
    // EMA: 0.3 * 200 + 0.7 * 100 = 130
    assert!((m.avg_tick_duration_ms - 130.0).abs() < f64::EPSILON);
}
|
||||
|
||||
// ── Adaptive interval tests ─────────────────────────────────
|
||||
|
||||
#[test]
fn adaptive_uses_base_when_no_failures() {
    // Healthy, no high-priority work: stick with the configured base interval.
    let result = compute_adaptive_interval(30, 5, 120, 0, false);
    assert_eq!(result, 30);
}
|
||||
|
||||
#[test]
fn adaptive_uses_min_for_high_priority() {
    // High-priority tasks pull the interval down to the configured minimum.
    let result = compute_adaptive_interval(30, 5, 120, 0, true);
    assert_eq!(result, 5);
}
|
||||
|
||||
#[test]
fn adaptive_backs_off_on_failures() {
    // Back-off doubles per consecutive failure, capped at max_interval.
    // 1 failure: 30 * 2 = 60
    assert_eq!(compute_adaptive_interval(30, 5, 120, 1, false), 60);
    // 2 failures: 30 * 4 = 120 (capped at max)
    assert_eq!(compute_adaptive_interval(30, 5, 120, 2, false), 120);
    // 3 failures: 30 * 8 = 240 → capped at 120
    assert_eq!(compute_adaptive_interval(30, 5, 120, 3, false), 120);
}
|
||||
|
||||
#[test]
fn adaptive_backoff_respects_min() {
    // Even with failures, must be >= min
    assert!(compute_adaptive_interval(5, 10, 120, 0, false) >= 10);
}
|
||||
|
||||
// ── Engine metrics accessor ─────────────────────────────────
|
||||
|
||||
#[test]
fn engine_exposes_shared_metrics() {
    // A fresh engine hands out a shared metrics handle with zeroed counters.
    let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
    let engine =
        HeartbeatEngine::new(HeartbeatConfig::default(), std::env::temp_dir(), observer);
    let metrics = engine.metrics();
    assert_eq!(metrics.lock().total_ticks, 0);
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
pub mod engine;
|
||||
pub mod store;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
305
src/heartbeat/store.rs
Normal file
305
src/heartbeat/store.rs
Normal file
@ -0,0 +1,305 @@
|
||||
//! SQLite persistence for heartbeat task execution history.
|
||||
//!
|
||||
//! Mirrors the `cron/store.rs` pattern: fresh connection per call, schema
|
||||
//! auto-created, output truncated, history pruned to a configurable limit.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Upper bound, in bytes, for stored task output (16 KiB).
const MAX_OUTPUT_BYTES: usize = 16 * 1024;
/// Suffix appended when stored output had to be truncated.
const TRUNCATED_MARKER: &str = "\n...[truncated]";
|
||||
|
||||
/// A single heartbeat task execution record.
#[derive(Debug, Clone)]
pub struct HeartbeatRun {
    /// SQLite rowid (auto-incremented primary key).
    pub id: i64,
    /// Text of the executed task.
    pub task_text: String,
    /// Task priority label (e.g. "high", "medium", "low").
    pub task_priority: String,
    /// When execution began (stored in the DB as RFC 3339 text).
    pub started_at: DateTime<Utc>,
    /// When execution finished (stored in the DB as RFC 3339 text).
    pub finished_at: DateTime<Utc>,
    /// Execution outcome: "ok" or "error".
    pub status: String,
    /// Captured task output (truncated to `MAX_OUTPUT_BYTES`), if any.
    pub output: Option<String>,
    /// Wall-clock duration of the task in milliseconds.
    pub duration_ms: i64,
}
|
||||
|
||||
/// Record a heartbeat task execution and prune old entries.
///
/// Inserts one row into `heartbeat_runs` and, in the same transaction,
/// deletes all but the `max_history` newest rows (ordered by `started_at`,
/// then `id`). `output` is truncated to the byte cap before storage and
/// `max_history` is clamped to at least 1 so the new row survives pruning.
pub fn record_run(
    workspace_dir: &Path,
    task_text: &str,
    task_priority: &str,
    started_at: DateTime<Utc>,
    finished_at: DateTime<Utc>,
    status: &str,
    output: Option<&str>,
    duration_ms: i64,
    max_history: u32,
) -> Result<()> {
    // Bound the stored output before entering the transaction.
    let bounded_output = output.map(truncate_output);
    with_connection(workspace_dir, |conn| {
        // unchecked_transaction: `with_connection` hands us a freshly
        // opened connection with no other transaction in flight.
        let tx = conn.unchecked_transaction()?;

        tx.execute(
            "INSERT INTO heartbeat_runs
             (task_text, task_priority, started_at, finished_at, status, output, duration_ms)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            params![
                task_text,
                task_priority,
                started_at.to_rfc3339(),
                finished_at.to_rfc3339(),
                status,
                bounded_output.as_deref(),
                duration_ms,
            ],
        )
        .context("Failed to insert heartbeat run")?;

        // Keep only the newest `max_history` rows (minimum 1).
        let keep = i64::from(max_history.max(1));
        tx.execute(
            "DELETE FROM heartbeat_runs
             WHERE id NOT IN (
                 SELECT id FROM heartbeat_runs
                 ORDER BY started_at DESC, id DESC
                 LIMIT ?1
             )",
            params![keep],
        )
        .context("Failed to prune heartbeat run history")?;

        tx.commit()
            .context("Failed to commit heartbeat run transaction")?;
        Ok(())
    })
}
|
||||
|
||||
/// List the most recent heartbeat runs.
///
/// Returns up to `limit` records (clamped to at least 1), newest first.
pub fn list_runs(workspace_dir: &Path, limit: usize) -> Result<Vec<HeartbeatRun>> {
    with_connection(workspace_dir, |conn| {
        let lim = i64::try_from(limit.max(1)).context("Run history limit overflow")?;
        let mut stmt = conn.prepare(
            "SELECT id, task_text, task_priority, started_at, finished_at, status, output, duration_ms
             FROM heartbeat_runs
             ORDER BY started_at DESC, id DESC
             LIMIT ?1",
        )?;

        let rows = stmt.query_map(params![lim], |row| {
            Ok(HeartbeatRun {
                id: row.get(0)?,
                task_text: row.get(1)?,
                task_priority: row.get(2)?,
                // Timestamps are stored as RFC 3339 text; parse failures
                // surface as rusqlite conversion errors via `sql_err`.
                started_at: parse_rfc3339(&row.get::<_, String>(3)?).map_err(sql_err)?,
                finished_at: parse_rfc3339(&row.get::<_, String>(4)?).map_err(sql_err)?,
                status: row.get(5)?,
                output: row.get(6)?,
                duration_ms: row.get(7)?,
            })
        })?;

        let mut runs = Vec::new();
        for row in rows {
            runs.push(row?);
        }
        Ok(runs)
    })
}
|
||||
|
||||
/// Get aggregate stats: (total_runs, total_ok, total_error).
|
||||
pub fn run_stats(workspace_dir: &Path) -> Result<(u64, u64, u64)> {
|
||||
with_connection(workspace_dir, |conn| {
|
||||
let total: i64 = conn.query_row("SELECT COUNT(*) FROM heartbeat_runs", [], |r| r.get(0))?;
|
||||
let ok: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'ok'",
|
||||
[],
|
||||
|r| r.get(0),
|
||||
)?;
|
||||
let err: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'error'",
|
||||
[],
|
||||
|r| r.get(0),
|
||||
)?;
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
Ok((total as u64, ok as u64, err as u64))
|
||||
})
|
||||
}
|
||||
|
||||
/// Location of the heartbeat history database inside the workspace.
fn db_path(workspace_dir: &Path) -> PathBuf {
    let mut path = workspace_dir.to_path_buf();
    path.push("heartbeat");
    path.push("history.db");
    path
}
|
||||
|
||||
/// Open the heartbeat history DB and run `f` against the connection.
///
/// A fresh connection is opened per call (mirroring the `cron/store.rs`
/// pattern); the parent directory and schema are created on demand, and
/// every schema statement is idempotent (`IF NOT EXISTS`).
fn with_connection<T>(workspace_dir: &Path, f: impl FnOnce(&Connection) -> Result<T>) -> Result<T> {
    let path = db_path(workspace_dir);
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).with_context(|| {
            format!("Failed to create heartbeat directory: {}", parent.display())
        })?;
    }

    let conn = Connection::open(&path)
        .with_context(|| format!("Failed to open heartbeat history DB: {}", path.display()))?;

    // WAL + NORMAL sync favors concurrency and speed; temp tables stay
    // in memory. Timestamps are stored as RFC 3339 TEXT columns.
    conn.execute_batch(
        "PRAGMA journal_mode = WAL;
         PRAGMA synchronous = NORMAL;
         PRAGMA temp_store = MEMORY;

         CREATE TABLE IF NOT EXISTS heartbeat_runs (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             task_text TEXT NOT NULL,
             task_priority TEXT NOT NULL,
             started_at TEXT NOT NULL,
             finished_at TEXT NOT NULL,
             status TEXT NOT NULL,
             output TEXT,
             duration_ms INTEGER
         );
         CREATE INDEX IF NOT EXISTS idx_hb_runs_started ON heartbeat_runs(started_at);
         CREATE INDEX IF NOT EXISTS idx_hb_runs_task ON heartbeat_runs(task_text);",
    )
    .context("Failed to initialize heartbeat history schema")?;

    f(&conn)
}
|
||||
|
||||
fn truncate_output(output: &str) -> String {
|
||||
if output.len() <= MAX_OUTPUT_BYTES {
|
||||
return output.to_string();
|
||||
}
|
||||
|
||||
if MAX_OUTPUT_BYTES <= TRUNCATED_MARKER.len() {
|
||||
return TRUNCATED_MARKER.to_string();
|
||||
}
|
||||
|
||||
let mut cutoff = MAX_OUTPUT_BYTES - TRUNCATED_MARKER.len();
|
||||
while cutoff > 0 && !output.is_char_boundary(cutoff) {
|
||||
cutoff -= 1;
|
||||
}
|
||||
|
||||
let mut truncated = output[..cutoff].to_string();
|
||||
truncated.push_str(TRUNCATED_MARKER);
|
||||
truncated
|
||||
}
|
||||
|
||||
fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> {
|
||||
let parsed = DateTime::parse_from_rfc3339(raw)
|
||||
.with_context(|| format!("Invalid RFC3339 timestamp in heartbeat DB: {raw}"))?;
|
||||
Ok(parsed.with_timezone(&Utc))
|
||||
}
|
||||
|
||||
/// Adapt an `anyhow` error into a `rusqlite::Error` so it can be returned
/// from inside `query_map` row-mapping closures (see `list_runs`).
fn sql_err(err: anyhow::Error) -> rusqlite::Error {
    rusqlite::Error::ToSqlConversionFailure(err.into())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration as ChronoDuration;
    use tempfile::TempDir;

    // Inserts three runs a second apart and expects them back newest-first.
    #[test]
    fn record_and_list_runs() {
        let tmp = TempDir::new().unwrap();
        let base = Utc::now();

        for i in 0..3 {
            let start = base + ChronoDuration::seconds(i);
            let end = start + ChronoDuration::milliseconds(100);
            record_run(
                tmp.path(),
                &format!("Task {i}"),
                "medium",
                start,
                end,
                "ok",
                Some("done"),
                100,
                50,
            )
            .unwrap();
        }

        let runs = list_runs(tmp.path(), 10).unwrap();
        assert_eq!(runs.len(), 3);
        // Most recent first
        assert!(runs[0].task_text.contains('2'));
    }

    // With max_history = 2, only the newest two of five inserts survive.
    #[test]
    fn prunes_old_runs() {
        let tmp = TempDir::new().unwrap();
        let base = Utc::now();

        for i in 0..5 {
            let start = base + ChronoDuration::seconds(i);
            let end = start + ChronoDuration::milliseconds(50);
            record_run(
                tmp.path(),
                "Task",
                "high",
                start,
                end,
                "ok",
                None,
                50,
                2, // keep only 2
            )
            .unwrap();
        }

        let runs = list_runs(tmp.path(), 10).unwrap();
        assert_eq!(runs.len(), 2);
    }

    // Two "ok" runs and one "error" run should aggregate to (3, 2, 1).
    #[test]
    fn run_stats_counts_correctly() {
        let tmp = TempDir::new().unwrap();
        let now = Utc::now();

        record_run(tmp.path(), "A", "high", now, now, "ok", None, 10, 50).unwrap();
        record_run(
            tmp.path(),
            "B",
            "low",
            now,
            now,
            "error",
            Some("fail"),
            20,
            50,
        )
        .unwrap();
        record_run(tmp.path(), "C", "medium", now, now, "ok", None, 15, 50).unwrap();

        let (total, ok, err) = run_stats(tmp.path()).unwrap();
        assert_eq!(total, 3);
        assert_eq!(ok, 2);
        assert_eq!(err, 1);
    }

    // Output larger than the byte cap is stored truncated, marker included.
    #[test]
    fn truncates_large_output() {
        let tmp = TempDir::new().unwrap();
        let now = Utc::now();
        let big = "x".repeat(MAX_OUTPUT_BYTES + 512);

        record_run(
            tmp.path(),
            "T",
            "medium",
            now,
            now,
            "ok",
            Some(&big),
            10,
            50,
        )
        .unwrap();

        let runs = list_runs(tmp.path(), 1).unwrap();
        let stored = runs[0].output.as_deref().unwrap_or_default();
        assert!(stored.ends_with(TRUNCATED_MARKER));
        assert!(stored.len() <= MAX_OUTPUT_BYTES);
    }
}
|
||||
@ -10,23 +10,45 @@ use chrono::{Duration, Local};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Response cache backed by a dedicated SQLite database.
|
||||
/// An in-memory hot cache entry for the two-tier response cache.
|
||||
struct InMemoryEntry {
|
||||
response: String,
|
||||
token_count: u32,
|
||||
created_at: std::time::Instant,
|
||||
accessed_at: std::time::Instant,
|
||||
}
|
||||
|
||||
/// Two-tier response cache: in-memory LRU (hot) + SQLite (warm).
|
||||
///
|
||||
/// Lives alongside `brain.db` as `response_cache.db` so it can be
|
||||
/// independently wiped without touching memories.
|
||||
/// The hot cache avoids SQLite round-trips for frequently repeated prompts.
|
||||
/// On miss from hot cache, falls through to SQLite. On hit from SQLite,
|
||||
/// the entry is promoted to the hot cache.
|
||||
pub struct ResponseCache {
|
||||
conn: Mutex<Connection>,
|
||||
#[allow(dead_code)]
|
||||
db_path: PathBuf,
|
||||
ttl_minutes: i64,
|
||||
max_entries: usize,
|
||||
hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
|
||||
hot_max_entries: usize,
|
||||
}
|
||||
|
||||
impl ResponseCache {
|
||||
/// Open (or create) the response cache database.
|
||||
pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
|
||||
Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
|
||||
}
|
||||
|
||||
/// Open (or create) the response cache database with a custom hot cache size.
|
||||
pub fn with_hot_cache(
|
||||
workspace_dir: &Path,
|
||||
ttl_minutes: u32,
|
||||
max_entries: usize,
|
||||
hot_max_entries: usize,
|
||||
) -> Result<Self> {
|
||||
let db_dir = workspace_dir.join("memory");
|
||||
std::fs::create_dir_all(&db_dir)?;
|
||||
let db_path = db_dir.join("response_cache.db");
|
||||
@ -58,6 +80,8 @@ impl ResponseCache {
|
||||
db_path,
|
||||
ttl_minutes: i64::from(ttl_minutes),
|
||||
max_entries,
|
||||
hot_cache: Mutex::new(HashMap::new()),
|
||||
hot_max_entries,
|
||||
})
|
||||
}
|
||||
|
||||
@ -76,35 +100,77 @@ impl ResponseCache {
|
||||
}
|
||||
|
||||
/// Look up a cached response. Returns `None` on miss or expired entry.
|
||||
///
|
||||
/// Two-tier lookup: checks the in-memory hot cache first, then falls
|
||||
/// through to SQLite. On a SQLite hit the entry is promoted to hot cache.
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<String> = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok();
|
||||
|
||||
if result.is_some() {
|
||||
// Bump hit count and accessed_at
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
// Tier 1: hot cache (with TTL check)
|
||||
{
|
||||
let mut hot = self.hot_cache.lock();
|
||||
if let Some(entry) = hot.get_mut(key) {
|
||||
let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
|
||||
if entry.created_at.elapsed() > ttl {
|
||||
hot.remove(key);
|
||||
} else {
|
||||
entry.accessed_at = std::time::Instant::now();
|
||||
let response = entry.response.clone();
|
||||
drop(hot);
|
||||
// Still bump SQLite hit count for accurate stats
|
||||
let conn = self.conn.lock();
|
||||
let now_str = Local::now().to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
return Ok(Some(response));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
// Tier 2: SQLite (warm)
|
||||
let result: Option<(String, u32)> = {
|
||||
let conn = self.conn.lock();
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response, token_count FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<(String, u32)> = stmt
|
||||
.query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.ok();
|
||||
|
||||
if result.is_some() {
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
if let Some((ref response, token_count)) = result {
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
}
|
||||
|
||||
Ok(result.map(|(r, _)| r))
|
||||
}
|
||||
|
||||
/// Store a response in the cache.
|
||||
/// Store a response in the cache (both hot and warm tiers).
|
||||
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
||||
// Write to hot cache
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
|
||||
// Write to SQLite (warm)
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now().to_rfc3339();
|
||||
@ -138,6 +204,43 @@ impl ResponseCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Promote an entry to the in-memory hot cache, evicting the oldest if full.
|
||||
fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
|
||||
let mut hot = self.hot_cache.lock();
|
||||
|
||||
// If already present, just update (keep original created_at for TTL)
|
||||
if let Some(entry) = hot.get_mut(key) {
|
||||
entry.response = response.to_string();
|
||||
entry.token_count = token_count;
|
||||
entry.accessed_at = std::time::Instant::now();
|
||||
return;
|
||||
}
|
||||
|
||||
// Evict oldest entry if at capacity
|
||||
if self.hot_max_entries > 0 && hot.len() >= self.hot_max_entries {
|
||||
if let Some(oldest_key) = hot
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.accessed_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
hot.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
if self.hot_max_entries > 0 {
|
||||
let now = std::time::Instant::now();
|
||||
hot.insert(
|
||||
key.to_string(),
|
||||
InMemoryEntry {
|
||||
response: response.to_string(),
|
||||
token_count,
|
||||
created_at: now,
|
||||
accessed_at: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
||||
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
||||
let conn = self.conn.lock();
|
||||
@ -163,8 +266,8 @@ impl ResponseCache {
|
||||
|
||||
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
|
||||
pub fn clear(&self) -> Result<usize> {
|
||||
self.hot_cache.lock().clear();
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let affected = conn.execute("DELETE FROM response_cache", [])?;
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
@ -47,6 +47,15 @@ impl Observer for LogObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
info!("heartbeat.tick");
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
info!(cache_type = %cache_type, tokens_saved = tokens_saved, "cache.hit");
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
info!(cache_type = %cache_type, "cache.miss");
|
||||
}
|
||||
ObserverEvent::Error { component, message } => {
|
||||
info!(component = %component, error = %message, "error");
|
||||
}
|
||||
|
||||
@ -16,6 +16,9 @@ pub struct PrometheusObserver {
|
||||
channel_messages: IntCounterVec,
|
||||
heartbeat_ticks: prometheus::IntCounter,
|
||||
errors: IntCounterVec,
|
||||
cache_hits: IntCounterVec,
|
||||
cache_misses: IntCounterVec,
|
||||
cache_tokens_saved: IntCounterVec,
|
||||
|
||||
// Histograms
|
||||
agent_duration: HistogramVec,
|
||||
@ -81,6 +84,27 @@ impl PrometheusObserver {
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_hits = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_hits_total", "Total response cache hits"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_misses = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_misses_total", "Total response cache misses"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_tokens_saved = IntCounterVec::new(
|
||||
prometheus::Opts::new(
|
||||
"zeroclaw_cache_tokens_saved_total",
|
||||
"Total tokens saved by response cache",
|
||||
),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let agent_duration = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"zeroclaw_agent_duration_seconds",
|
||||
@ -139,6 +163,9 @@ impl PrometheusObserver {
|
||||
registry.register(Box::new(channel_messages.clone())).ok();
|
||||
registry.register(Box::new(heartbeat_ticks.clone())).ok();
|
||||
registry.register(Box::new(errors.clone())).ok();
|
||||
registry.register(Box::new(cache_hits.clone())).ok();
|
||||
registry.register(Box::new(cache_misses.clone())).ok();
|
||||
registry.register(Box::new(cache_tokens_saved.clone())).ok();
|
||||
registry.register(Box::new(agent_duration.clone())).ok();
|
||||
registry.register(Box::new(tool_duration.clone())).ok();
|
||||
registry.register(Box::new(request_latency.clone())).ok();
|
||||
@ -156,6 +183,9 @@ impl PrometheusObserver {
|
||||
channel_messages,
|
||||
heartbeat_ticks,
|
||||
errors,
|
||||
cache_hits,
|
||||
cache_misses,
|
||||
cache_tokens_saved,
|
||||
agent_duration,
|
||||
tool_duration,
|
||||
request_latency,
|
||||
@ -245,6 +275,18 @@ impl Observer for PrometheusObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
self.heartbeat_ticks.inc();
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
self.cache_hits.with_label_values(&[cache_type]).inc();
|
||||
self.cache_tokens_saved
|
||||
.with_label_values(&[cache_type])
|
||||
.inc_by(*tokens_saved);
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
self.cache_misses.with_label_values(&[cache_type]).inc();
|
||||
}
|
||||
ObserverEvent::Error {
|
||||
component,
|
||||
message: _,
|
||||
|
||||
@ -61,6 +61,18 @@ pub enum ObserverEvent {
|
||||
},
|
||||
/// Periodic heartbeat tick from the runtime keep-alive loop.
|
||||
HeartbeatTick,
|
||||
/// Response cache hit — an LLM call was avoided.
|
||||
CacheHit {
|
||||
/// `"hot"` (in-memory) or `"warm"` (SQLite).
|
||||
cache_type: String,
|
||||
/// Estimated tokens saved by this cache hit.
|
||||
tokens_saved: u64,
|
||||
},
|
||||
/// Response cache miss — the prompt was not found in cache.
|
||||
CacheMiss {
|
||||
/// `"response"` cache layer that was checked.
|
||||
cache_type: String,
|
||||
},
|
||||
/// An error occurred in a named component.
|
||||
Error {
|
||||
/// Subsystem where the error originated (e.g., `"provider"`, `"gateway"`).
|
||||
|
||||
@ -402,6 +402,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: 60,
|
||||
response_cache_max_entries: 5_000,
|
||||
response_cache_hot_entries: 256,
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
|
||||
@ -149,6 +149,10 @@ struct AnthropicUsage {
|
||||
input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
output_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -475,6 +479,7 @@ impl AnthropicProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: u.cache_read_input_tokens,
|
||||
});
|
||||
|
||||
for block in response.content {
|
||||
@ -614,6 +619,7 @@ impl Provider for AnthropicProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -312,6 +312,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -431,6 +432,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -491,6 +493,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -832,6 +832,7 @@ impl BedrockProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
if let Some(output) = response.output {
|
||||
@ -967,6 +968,7 @@ impl Provider for BedrockProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1193,6 +1193,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
crate::providers::traits::ProviderCapabilities {
|
||||
native_tool_calling: self.native_tool_calling,
|
||||
vision: self.supports_vision,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -1514,6 +1515,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = chat_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = chat_response
|
||||
.choices
|
||||
@ -1657,6 +1659,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -353,6 +353,7 @@ impl CopilotProvider {
|
||||
let usage = api_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = api_response
|
||||
.choices
|
||||
|
||||
@ -1128,6 +1128,7 @@ impl GeminiProvider {
|
||||
let usage = result.usage_metadata.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_token_count,
|
||||
output_tokens: u.candidates_token_count,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
let text = result
|
||||
|
||||
@ -632,6 +632,7 @@ impl Provider for OllamaProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -764,6 +765,7 @@ impl Provider for OllamaProvider {
|
||||
Some(TokenUsage {
|
||||
input_tokens: response.prompt_eval_count,
|
||||
output_tokens: response.eval_count,
|
||||
cached_input_tokens: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
||||
@ -135,6 +135,14 @@ struct UsageInfo {
|
||||
prompt_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
completion_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
prompt_tokens_details: Option<PromptTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PromptTokensDetails {
|
||||
#[serde(default)]
|
||||
cached_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -385,6 +393,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -448,6 +457,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -640,6 +640,7 @@ impl Provider for OpenAiCodexProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -306,6 +306,7 @@ impl Provider for OpenRouterProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -463,6 +464,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -554,6 +556,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -54,6 +54,9 @@ pub struct ToolCall {
|
||||
pub struct TokenUsage {
|
||||
pub input_tokens: Option<u64>,
|
||||
pub output_tokens: Option<u64>,
|
||||
/// Tokens served from the provider's prompt cache (Anthropic `cache_read_input_tokens`,
|
||||
/// OpenAI `prompt_tokens_details.cached_tokens`).
|
||||
pub cached_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// An LLM response that may contain text, tool calls, or both.
|
||||
@ -233,6 +236,9 @@ pub struct ProviderCapabilities {
|
||||
pub native_tool_calling: bool,
|
||||
/// Whether the provider supports vision / image inputs.
|
||||
pub vision: bool,
|
||||
/// Whether the provider supports prompt caching (Anthropic cache_control,
|
||||
/// OpenAI automatic prompt caching).
|
||||
pub prompt_caching: bool,
|
||||
}
|
||||
|
||||
/// Provider-specific tool payload formats.
|
||||
@ -498,6 +504,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -568,6 +575,7 @@ mod tests {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(100),
|
||||
output_tokens: Some(50),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
};
|
||||
@ -613,14 +621,17 @@ mod tests {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
|
||||
@ -166,6 +166,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}),
|
||||
@ -188,6 +189,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user