From 318ed8e9f1f337f66a360935a6efbfe052711dcd Mon Sep 17 00:00:00 2001
From: argenis de la rosa
Date: Mon, 16 Mar 2026 12:08:32 -0400
Subject: [PATCH 1/3] feat(heartbeat): add health metrics, adaptive intervals, and task history

- Add HeartbeatMetrics struct with uptime, consecutive success/failure
  counts, EMA tick duration, and total ticks
- Add compute_adaptive_interval() for exponential backoff on failures and
  faster polling when high-priority tasks are present
- Add SQLite-backed task run history (src/heartbeat/store.rs) mirroring the
  cron/store.rs pattern with output truncation and pruning
- Add dead-man's switch that alerts if heartbeat stops ticking
- Wire metrics, history recording, and adaptive sleep into daemon worker
- Add config fields: adaptive, min/max_interval_minutes,
  deadman_timeout_minutes, deadman_channel, deadman_to, max_run_history
- All new fields are backward-compatible with serde defaults
---
 src/config/schema.rs    |  44 ++++++
 src/daemon/mod.rs       | 137 ++++++++++++++++--
 src/heartbeat/engine.rs | 177 +++++++++++++++++++++++
 src/heartbeat/mod.rs    |   1 +
 src/heartbeat/store.rs  | 305 ++++++++++++++++++++++++++++++++++++++++
 5 files changed, 655 insertions(+), 9 deletions(-)
 create mode 100644 src/heartbeat/store.rs

diff --git a/src/config/schema.rs b/src/config/schema.rs
index 345e230ed..db7d60348 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -3344,12 +3344,48 @@ pub struct HeartbeatConfig {
     /// explicitly set).
     #[serde(default, alias = "recipient")]
     pub to: Option<String>,
+    /// Enable adaptive intervals that back off on failures and speed up for
+    /// high-priority tasks. Default: `false`.
+    #[serde(default)]
+    pub adaptive: bool,
+    /// Minimum interval in minutes when adaptive mode is enabled. Default: `5`.
+    #[serde(default = "default_heartbeat_min_interval")]
+    pub min_interval_minutes: u32,
+    /// Maximum interval in minutes when adaptive mode backs off. Default: `120`.
+    #[serde(default = "default_heartbeat_max_interval")]
+    pub max_interval_minutes: u32,
+    /// Dead-man's switch timeout in minutes. If the heartbeat has not ticked
+    /// within this window, an alert is sent. `0` disables. Default: `0`.
+    #[serde(default)]
+    pub deadman_timeout_minutes: u32,
+    /// Channel for dead-man's switch alerts (e.g. `telegram`). Falls back to
+    /// the heartbeat delivery channel.
+    #[serde(default)]
+    pub deadman_channel: Option<String>,
+    /// Recipient for dead-man's switch alerts. Falls back to `to`.
+    #[serde(default)]
+    pub deadman_to: Option<String>,
+    /// Maximum number of heartbeat run history records to retain. Default: `100`.
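+    /// Older records are pruned automatically each time a new run is recorded.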
+    #[serde(default = "default_heartbeat_max_run_history")]
+    pub max_run_history: u32,
 }

 fn default_two_phase() -> bool {
     true
 }

+fn default_heartbeat_min_interval() -> u32 {
+    5
+}
+
+fn default_heartbeat_max_interval() -> u32 {
+    120
+}
+
+fn default_heartbeat_max_run_history() -> u32 {
+    100
+}
+
 impl Default for HeartbeatConfig {
     fn default() -> Self {
         Self {
@@ -3359,6 +3395,13 @@
             message: None,
             target: None,
             to: None,
+            adaptive: false,
+            min_interval_minutes: default_heartbeat_min_interval(),
+            max_interval_minutes: default_heartbeat_max_interval(),
+            deadman_timeout_minutes: 0,
+            deadman_channel: None,
+            deadman_to: None,
+            max_run_history: default_heartbeat_max_run_history(),
         }
     }
 }
@@ -7358,6 +7401,7 @@ default_temperature = 0.7
                 message: Some("Check London time".into()),
                 target: Some("telegram".into()),
                 to: Some("123456".into()),
+                ..HeartbeatConfig::default()
             },
             cron: CronConfig::default(),
             channels_config: ChannelsConfig {

diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs
index 267dae28a..7dc8cfe73 100644
--- a/src/daemon/mod.rs
+++ b/src/daemon/mod.rs
@@ -203,7 +203,10 @@ where
 }

 async fn run_heartbeat_worker(config: Config) -> Result<()> {
-    use crate::heartbeat::engine::HeartbeatEngine;
+    use crate::heartbeat::engine::{
+        compute_adaptive_interval, HeartbeatEngine, HeartbeatTask, TaskPriority, TaskStatus,
+    };
+    use std::sync::Arc;

     let observer: std::sync::Arc<dyn crate::observability::Observer> =
         std::sync::Arc::from(crate::observability::create_observer(&config.observability));
@@ -212,19 +215,72 @@
         config.workspace_dir.clone(),
         observer,
     );
+    let metrics = engine.metrics();

     let delivery = resolve_heartbeat_delivery(&config)?;
     let two_phase = config.heartbeat.two_phase;
+    let adaptive = config.heartbeat.adaptive;
+    let start_time = std::time::Instant::now();

-    let interval_mins = config.heartbeat.interval_minutes.max(5);
-    let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
+    // ── Deadman watcher ──────────────────────────────────────────
+    let deadman_timeout = config.heartbeat.deadman_timeout_minutes;
+    if deadman_timeout > 0 {
+        let dm_metrics = Arc::clone(&metrics);
+        let dm_config = config.clone();
+        let dm_delivery = delivery.clone();
+        tokio::spawn(async move {
+            let check_interval = Duration::from_secs(60);
+            let timeout = chrono::Duration::minutes(i64::from(deadman_timeout));
+            loop {
+                tokio::time::sleep(check_interval).await;
+                let last_tick = dm_metrics.lock().last_tick_at;
+                if let Some(last) = last_tick {
+                    if chrono::Utc::now() - last > timeout {
+                        let alert = format!(
+                            "⚠️ Heartbeat dead-man's switch: no tick in {deadman_timeout} minutes"
+                        );
+                        let (channel, target) =
+                            if let Some(ch) = &dm_config.heartbeat.deadman_channel {
+                                let to = dm_config
+                                    .heartbeat
+                                    .deadman_to
+                                    .as_deref()
+                                    .or(dm_config.heartbeat.to.as_deref())
+                                    .unwrap_or_default();
+                                (ch.clone(), to.to_string())
+                            } else if let Some((ch, to)) = &dm_delivery {
+                                (ch.clone(), to.clone())
+                            } else {
+                                continue;
+                            };
+                        let _ = crate::cron::scheduler::deliver_announcement(
+                            &dm_config, &channel, &target, &alert,
+                        )
+                        .await;
+                    }
+                }
+            }
+        });
+    }
+
+    let base_interval = config.heartbeat.interval_minutes.max(5);
+    let mut sleep_mins = base_interval;

     loop {
-        interval.tick().await;
+        tokio::time::sleep(Duration::from_secs(u64::from(sleep_mins) * 60)).await;
+
+        // Update uptime
+        {
+            let mut m = metrics.lock();
+            m.uptime_secs = start_time.elapsed().as_secs();
+        }
+
+        let tick_start =
std::time::Instant::now(); // Collect runnable tasks (active only, sorted by priority) let mut tasks = engine.collect_runnable_tasks().await?; + let has_high_priority = tasks.iter().any(|t| t.priority == TaskPriority::High); + if tasks.is_empty() { - // Try fallback message if let Some(fallback) = config .heartbeat .message @@ -232,12 +288,15 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { .map(str::trim) .filter(|m| !m.is_empty()) { - tasks.push(crate::heartbeat::engine::HeartbeatTask { + tasks.push(HeartbeatTask { text: fallback.to_string(), - priority: crate::heartbeat::engine::TaskPriority::Medium, - status: crate::heartbeat::engine::TaskStatus::Active, + priority: TaskPriority::Medium, + status: TaskStatus::Active, }); } else { + #[allow(clippy::cast_precision_loss)] + let elapsed = tick_start.elapsed().as_millis() as f64; + metrics.lock().record_success(elapsed); continue; } } @@ -250,7 +309,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { Some(decision_prompt), None, None, - 0.0, // Low temperature for deterministic decision + 0.0, vec![], false, None, @@ -263,6 +322,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { if indices.is_empty() { tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)"); crate::health::mark_component_ok("heartbeat"); + #[allow(clippy::cast_precision_loss)] + let elapsed = tick_start.elapsed().as_millis() as f64; + metrics.lock().record_success(elapsed); continue; } tracing::info!( @@ -285,7 +347,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { }; // ── Phase 2: Execute selected tasks ───────────────────── + let mut tick_had_error = false; for task in &tasks_to_run { + let task_start = std::time::Instant::now(); let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text); let temp = config.default_temperature; match Box::pin(crate::agent::run( @@ -303,6 +367,20 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { { Ok(output) => { crate::health::mark_component_ok("heartbeat"); + #[allow(clippy::cast_possible_truncation)] + let duration_ms = task_start.elapsed().as_millis() as i64; + let now = chrono::Utc::now(); + let _ = crate::heartbeat::store::record_run( + &config.workspace_dir, + &task.text, + &task.priority.to_string(), + now - chrono::Duration::milliseconds(duration_ms), + now, + "ok", + Some(output.as_str()), + duration_ms, + config.heartbeat.max_run_history, + ); let announcement = if output.trim().is_empty() { format!("💓 heartbeat task completed: {}", task.text) } else { @@ -326,11 +404,52 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { } } Err(e) => { + tick_had_error = true; + #[allow(clippy::cast_possible_truncation)] + let duration_ms = task_start.elapsed().as_millis() as i64; + let now = chrono::Utc::now(); + let _ = crate::heartbeat::store::record_run( + &config.workspace_dir, + &task.text, + &task.priority.to_string(), + now - chrono::Duration::milliseconds(duration_ms), + now, + "error", + Some(&e.to_string()), + duration_ms, + config.heartbeat.max_run_history, + ); crate::health::mark_component_error("heartbeat", e.to_string()); tracing::warn!("Heartbeat task failed: {e}"); } } } + + // Update metrics + #[allow(clippy::cast_precision_loss)] + let tick_elapsed = tick_start.elapsed().as_millis() as f64; + { + let mut m = metrics.lock(); + if tick_had_error { + m.record_failure(tick_elapsed); + } else { + m.record_success(tick_elapsed); + } + } + + // Compute next sleep interval + if adaptive { + let failures = 
metrics.lock().consecutive_failures;
+            sleep_mins = compute_adaptive_interval(
+                base_interval,
+                config.heartbeat.min_interval_minutes,
+                config.heartbeat.max_interval_minutes,
+                failures,
+                has_high_priority,
+            );
+        } else {
+            sleep_mins = base_interval;
+        }
     }
 }

diff --git a/src/heartbeat/engine.rs b/src/heartbeat/engine.rs
index f7e3b59fa..abecf0480 100644
--- a/src/heartbeat/engine.rs
+++ b/src/heartbeat/engine.rs
@@ -1,6 +1,8 @@
 use crate::config::HeartbeatConfig;
 use crate::observability::{Observer, ObserverEvent};
 use anyhow::Result;
+use chrono::{DateTime, Utc};
+use parking_lot::Mutex as ParkingMutex;
 use serde::{Deserialize, Serialize};
 use std::fmt;
 use std::path::Path;
@@ -68,6 +70,99 @@ impl fmt::Display for HeartbeatTask {
     }
 }

+// ── Health Metrics ───────────────────────────────────────────────
+
+/// Live health metrics for the heartbeat subsystem.
+///
+/// Shared via `Arc<ParkingMutex<HeartbeatMetrics>>` between the heartbeat worker,
+/// deadman watcher, and API consumers.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HeartbeatMetrics {
+    /// Monotonic uptime since the heartbeat loop started.
+    pub uptime_secs: u64,
+    /// Consecutive successful ticks (resets on failure).
+    pub consecutive_successes: u64,
+    /// Consecutive failed ticks (resets on success).
+    pub consecutive_failures: u64,
+    /// Timestamp of the most recent tick (UTC RFC 3339).
+    pub last_tick_at: Option<DateTime<Utc>>,
+    /// Exponential moving average of tick durations in milliseconds.
+    pub avg_tick_duration_ms: f64,
+    /// Total number of ticks executed since startup.
+    pub total_ticks: u64,
+}
+
+impl Default for HeartbeatMetrics {
+    fn default() -> Self {
+        Self {
+            uptime_secs: 0,
+            consecutive_successes: 0,
+            consecutive_failures: 0,
+            last_tick_at: None,
+            avg_tick_duration_ms: 0.0,
+            total_ticks: 0,
+        }
+    }
+}
+
+impl HeartbeatMetrics {
+    /// Record a successful tick with the given duration.
+    pub fn record_success(&mut self, duration_ms: f64) {
+        self.consecutive_successes += 1;
+        self.consecutive_failures = 0;
+        self.last_tick_at = Some(Utc::now());
+        self.total_ticks += 1;
+        self.update_avg_duration(duration_ms);
+    }
+
+    /// Record a failed tick with the given duration.
+    pub fn record_failure(&mut self, duration_ms: f64) {
+        self.consecutive_failures += 1;
+        self.consecutive_successes = 0;
+        self.last_tick_at = Some(Utc::now());
+        self.total_ticks += 1;
+        self.update_avg_duration(duration_ms);
+    }
+
+    fn update_avg_duration(&mut self, duration_ms: f64) {
+        const ALPHA: f64 = 0.3; // EMA smoothing factor
+        if self.total_ticks == 1 {
+            self.avg_tick_duration_ms = duration_ms;
+        } else {
+            self.avg_tick_duration_ms =
+                ALPHA * duration_ms + (1.0 - ALPHA) * self.avg_tick_duration_ms;
+        }
+    }
+}
+
+/// Compute the adaptive interval for the next heartbeat tick.
+///
+/// Strategy:
+/// - On failures: exponential back-off `base * 2^failures` capped at `max_interval`.
+/// - When high-priority tasks are present: use `min_interval` for faster reaction.
+/// - Otherwise: use `base_interval`.
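+///
+/// A few illustrative values, mirroring the unit tests below (30-minute base,
+/// 5–120 minute bounds):
+///
+/// ```ignore
+/// assert_eq!(compute_adaptive_interval(30, 5, 120, 0, false), 30);  // healthy: base
+/// assert_eq!(compute_adaptive_interval(30, 5, 120, 2, false), 120); // 30 * 2^2, at max
+/// assert_eq!(compute_adaptive_interval(30, 5, 120, 0, true), 5);    // high-priority: min
+/// ```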
+pub fn compute_adaptive_interval(
+    base_minutes: u32,
+    min_minutes: u32,
+    max_minutes: u32,
+    consecutive_failures: u64,
+    has_high_priority_tasks: bool,
+) -> u32 {
+    if consecutive_failures > 0 {
+        let backoff = base_minutes.saturating_mul(
+            1u32.checked_shl(consecutive_failures.min(10) as u32)
+                .unwrap_or(u32::MAX),
+        );
+        return backoff.min(max_minutes).max(min_minutes);
+    }
+
+    if has_high_priority_tasks {
+        return min_minutes.max(5); // never go below 5 minutes
+    }
+
+    base_minutes.clamp(min_minutes, max_minutes)
+}
+
 // ── Engine ───────────────────────────────────────────────────────

 /// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
@@ -75,6 +170,7 @@ pub struct HeartbeatEngine {
     config: HeartbeatConfig,
     workspace_dir: std::path::PathBuf,
     observer: Arc<dyn Observer>,
+    metrics: Arc<ParkingMutex<HeartbeatMetrics>>,
 }

 impl HeartbeatEngine {
@@ -87,9 +183,15 @@
             config,
             workspace_dir,
             observer,
+            metrics: Arc::new(ParkingMutex::new(HeartbeatMetrics::default())),
         }
     }

+    /// Get a shared handle to the live heartbeat metrics.
+    pub fn metrics(&self) -> Arc<ParkingMutex<HeartbeatMetrics>> {
+        Arc::clone(&self.metrics)
+    }
+
     /// Start the heartbeat loop (runs until cancelled)
     pub async fn run(&self) -> Result<()> {
         if !self.config.enabled {
@@ -673,4 +775,79 @@ mod tests {

         let _ = tokio::fs::remove_dir_all(&dir).await;
     }
+
+    // ── HeartbeatMetrics tests ───────────────────────────────────
+
+    #[test]
+    fn metrics_record_success_updates_fields() {
+        let mut m = HeartbeatMetrics::default();
+        m.record_success(100.0);
+        assert_eq!(m.consecutive_successes, 1);
+        assert_eq!(m.consecutive_failures, 0);
+        assert_eq!(m.total_ticks, 1);
+        assert!(m.last_tick_at.is_some());
+        assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
+    }
+
+    #[test]
+    fn metrics_record_failure_resets_successes() {
+        let mut m = HeartbeatMetrics::default();
+        m.record_success(50.0);
+        m.record_success(50.0);
+        m.record_failure(200.0);
+        assert_eq!(m.consecutive_successes, 0);
+        assert_eq!(m.consecutive_failures, 1);
+        assert_eq!(m.total_ticks, 3);
+    }
+
+    #[test]
+    fn metrics_ema_smoothing() {
+        let mut m = HeartbeatMetrics::default();
+        m.record_success(100.0);
+        assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
+        m.record_success(200.0);
+        // EMA: 0.3 * 200 + 0.7 * 100 = 130
+        assert!((m.avg_tick_duration_ms - 130.0).abs() < f64::EPSILON);
+    }
+
+    // ── Adaptive interval tests ─────────────────────────────────
+
+    #[test]
+    fn adaptive_uses_base_when_no_failures() {
+        let result = compute_adaptive_interval(30, 5, 120, 0, false);
+        assert_eq!(result, 30);
+    }
+
+    #[test]
+    fn adaptive_uses_min_for_high_priority() {
+        let result = compute_adaptive_interval(30, 5, 120, 0, true);
+        assert_eq!(result, 5);
+    }
+
+    #[test]
+    fn adaptive_backs_off_on_failures() {
+        // 1 failure: 30 * 2 = 60
+        assert_eq!(compute_adaptive_interval(30, 5, 120, 1, false), 60);
+        // 2 failures: 30 * 4 = 120 (exactly the max)
+        assert_eq!(compute_adaptive_interval(30, 5, 120, 2, false), 120);
+        // 3 failures: 30 * 8 = 240 → capped at 120
+        assert_eq!(compute_adaptive_interval(30, 5, 120, 3, false), 120);
+    }
+
+    #[test]
+    fn adaptive_backoff_respects_min() {
+        // A base below the configured minimum is clamped up to min
+        assert!(compute_adaptive_interval(5, 10, 120, 0, false) >= 10);
+    }
+
+    // ── Engine metrics accessor ─────────────────────────────────
+
+    #[test]
+    fn engine_exposes_shared_metrics() {
+        let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
+        let engine =
+            HeartbeatEngine::new(HeartbeatConfig::default(), std::env::temp_dir(), observer);
+        let
metrics = engine.metrics();
+        assert_eq!(metrics.lock().total_ticks, 0);
+    }
 }

diff --git a/src/heartbeat/mod.rs b/src/heartbeat/mod.rs
index 865c91e7a..caa12b5a8 100644
--- a/src/heartbeat/mod.rs
+++ b/src/heartbeat/mod.rs
@@ -1,4 +1,5 @@
 pub mod engine;
+pub mod store;

 #[cfg(test)]
 mod tests {

diff --git a/src/heartbeat/store.rs b/src/heartbeat/store.rs
new file mode 100644
index 000000000..d9140e17d
--- /dev/null
+++ b/src/heartbeat/store.rs
@@ -0,0 +1,305 @@
+//! SQLite persistence for heartbeat task execution history.
+//!
+//! Mirrors the `cron/store.rs` pattern: fresh connection per call, schema
+//! auto-created, output truncated, history pruned to a configurable limit.
+
+use anyhow::{Context, Result};
+use chrono::{DateTime, Utc};
+use rusqlite::{params, Connection};
+use std::path::{Path, PathBuf};
+
+const MAX_OUTPUT_BYTES: usize = 16 * 1024;
+const TRUNCATED_MARKER: &str = "\n...[truncated]";
+
+/// A single heartbeat task execution record.
+#[derive(Debug, Clone)]
+pub struct HeartbeatRun {
+    pub id: i64,
+    pub task_text: String,
+    pub task_priority: String,
+    pub started_at: DateTime<Utc>,
+    pub finished_at: DateTime<Utc>,
+    pub status: String, // "ok" or "error"
+    pub output: Option<String>,
+    pub duration_ms: i64,
+}
+
+/// Record a heartbeat task execution and prune old entries.
+pub fn record_run(
+    workspace_dir: &Path,
+    task_text: &str,
+    task_priority: &str,
+    started_at: DateTime<Utc>,
+    finished_at: DateTime<Utc>,
+    status: &str,
+    output: Option<&str>,
+    duration_ms: i64,
+    max_history: u32,
+) -> Result<()> {
+    let bounded_output = output.map(truncate_output);
+    with_connection(workspace_dir, |conn| {
+        let tx = conn.unchecked_transaction()?;
+
+        tx.execute(
+            "INSERT INTO heartbeat_runs
+                 (task_text, task_priority, started_at, finished_at, status, output, duration_ms)
+             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
+            params![
+                task_text,
+                task_priority,
+                started_at.to_rfc3339(),
+                finished_at.to_rfc3339(),
+                status,
+                bounded_output.as_deref(),
+                duration_ms,
+            ],
+        )
+        .context("Failed to insert heartbeat run")?;
+
+        let keep = i64::from(max_history.max(1));
+        tx.execute(
+            "DELETE FROM heartbeat_runs
+             WHERE id NOT IN (
+                 SELECT id FROM heartbeat_runs
+                 ORDER BY started_at DESC, id DESC
+                 LIMIT ?1
+             )",
+            params![keep],
+        )
+        .context("Failed to prune heartbeat run history")?;
+
+        tx.commit()
+            .context("Failed to commit heartbeat run transaction")?;
+        Ok(())
+    })
+}
+
+/// List the most recent heartbeat runs.
+pub fn list_runs(workspace_dir: &Path, limit: usize) -> Result<Vec<HeartbeatRun>> {
+    with_connection(workspace_dir, |conn| {
+        let lim = i64::try_from(limit.max(1)).context("Run history limit overflow")?;
+        let mut stmt = conn.prepare(
+            "SELECT id, task_text, task_priority, started_at, finished_at, status, output, duration_ms
+             FROM heartbeat_runs
+             ORDER BY started_at DESC, id DESC
+             LIMIT ?1",
+        )?;
+
+        let rows = stmt.query_map(params![lim], |row| {
+            Ok(HeartbeatRun {
+                id: row.get(0)?,
+                task_text: row.get(1)?,
+                task_priority: row.get(2)?,
+                started_at: parse_rfc3339(&row.get::<_, String>(3)?).map_err(sql_err)?,
+                finished_at: parse_rfc3339(&row.get::<_, String>(4)?).map_err(sql_err)?,
+                status: row.get(5)?,
+                output: row.get(6)?,
+                duration_ms: row.get(7)?,
+            })
+        })?;
+
+        let mut runs = Vec::new();
+        for row in rows {
+            runs.push(row?);
+        }
+        Ok(runs)
+    })
+}
+
+/// Get aggregate stats: (total_runs, total_ok, total_error).
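+///
+/// Illustrative usage (the workspace path is a placeholder); totals add up
+/// because only "ok" and "error" statuses are ever recorded:
+///
+/// ```ignore
+/// let (total, ok, err) = run_stats(&workspace_dir)?;
+/// assert_eq!(total, ok + err);
+/// ```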
+pub fn run_stats(workspace_dir: &Path) -> Result<(u64, u64, u64)> {
+    with_connection(workspace_dir, |conn| {
+        let total: i64 = conn.query_row("SELECT COUNT(*) FROM heartbeat_runs", [], |r| r.get(0))?;
+        let ok: i64 = conn.query_row(
+            "SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'ok'",
+            [],
+            |r| r.get(0),
+        )?;
+        let err: i64 = conn.query_row(
+            "SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'error'",
+            [],
+            |r| r.get(0),
+        )?;
+        #[allow(clippy::cast_sign_loss)]
+        Ok((total as u64, ok as u64, err as u64))
+    })
+}
+
+fn db_path(workspace_dir: &Path) -> PathBuf {
+    workspace_dir.join("heartbeat").join("history.db")
+}
+
+fn with_connection<T>(workspace_dir: &Path, f: impl FnOnce(&Connection) -> Result<T>) -> Result<T> {
+    let path = db_path(workspace_dir);
+    if let Some(parent) = path.parent() {
+        std::fs::create_dir_all(parent).with_context(|| {
+            format!("Failed to create heartbeat directory: {}", parent.display())
+        })?;
+    }
+
+    let conn = Connection::open(&path)
+        .with_context(|| format!("Failed to open heartbeat history DB: {}", path.display()))?;
+
+    conn.execute_batch(
+        "PRAGMA journal_mode = WAL;
+         PRAGMA synchronous = NORMAL;
+         PRAGMA temp_store = MEMORY;
+
+         CREATE TABLE IF NOT EXISTS heartbeat_runs (
+             id INTEGER PRIMARY KEY AUTOINCREMENT,
+             task_text TEXT NOT NULL,
+             task_priority TEXT NOT NULL,
+             started_at TEXT NOT NULL,
+             finished_at TEXT NOT NULL,
+             status TEXT NOT NULL,
+             output TEXT,
+             duration_ms INTEGER
+         );
+         CREATE INDEX IF NOT EXISTS idx_hb_runs_started ON heartbeat_runs(started_at);
+         CREATE INDEX IF NOT EXISTS idx_hb_runs_task ON heartbeat_runs(task_text);",
+    )
+    .context("Failed to initialize heartbeat history schema")?;
+
+    f(&conn)
+}
+
+fn truncate_output(output: &str) -> String {
+    if output.len() <= MAX_OUTPUT_BYTES {
+        return output.to_string();
+    }
+
+    if MAX_OUTPUT_BYTES <= TRUNCATED_MARKER.len() {
+        return TRUNCATED_MARKER.to_string();
+    }
+
+    let mut cutoff = MAX_OUTPUT_BYTES - TRUNCATED_MARKER.len();
+    while cutoff > 0 && !output.is_char_boundary(cutoff) {
+        cutoff -= 1;
+    }
+
+    let mut truncated = output[..cutoff].to_string();
+    truncated.push_str(TRUNCATED_MARKER);
+    truncated
+}
+
+fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> {
+    let parsed = DateTime::parse_from_rfc3339(raw)
+        .with_context(|| format!("Invalid RFC3339 timestamp in heartbeat DB: {raw}"))?;
+    Ok(parsed.with_timezone(&Utc))
+}
+
+fn sql_err(err: anyhow::Error) -> rusqlite::Error {
+    rusqlite::Error::ToSqlConversionFailure(err.into())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use chrono::Duration as ChronoDuration;
+    use tempfile::TempDir;
+
+    #[test]
+    fn record_and_list_runs() {
+        let tmp = TempDir::new().unwrap();
+        let base = Utc::now();
+
+        for i in 0..3 {
+            let start = base + ChronoDuration::seconds(i);
+            let end = start + ChronoDuration::milliseconds(100);
+            record_run(
+                tmp.path(),
+                &format!("Task {i}"),
+                "medium",
+                start,
+                end,
+                "ok",
+                Some("done"),
+                100,
+                50,
+            )
+            .unwrap();
+        }
+
+        let runs = list_runs(tmp.path(), 10).unwrap();
+        assert_eq!(runs.len(), 3);
+        // Most recent first
+        assert!(runs[0].task_text.contains('2'));
+    }
+
+    #[test]
+    fn prunes_old_runs() {
+        let tmp = TempDir::new().unwrap();
+        let base = Utc::now();
+
+        for i in 0..5 {
+            let start = base + ChronoDuration::seconds(i);
+            let end = start + ChronoDuration::milliseconds(50);
+            record_run(
+                tmp.path(),
+                "Task",
+                "high",
+                start,
+                end,
+                "ok",
+                None,
+                50,
+                2, // keep only 2
+            )
+            .unwrap();
+        }
+
+        let runs = list_runs(tmp.path(), 10).unwrap();
+        assert_eq!(runs.len(), 2);
+    }
+
+    #[test]
+    fn run_stats_counts_correctly() {
+        let tmp = TempDir::new().unwrap();
+        let now = Utc::now();
+
+        record_run(tmp.path(), "A", "high", now, now, "ok", None, 10, 50).unwrap();
+        record_run(
+            tmp.path(),
+            "B",
+            "low",
+            now,
+            now,
+            "error",
+            Some("fail"),
+            20,
+            50,
+        )
+        .unwrap();
+        record_run(tmp.path(), "C", "medium", now, now, "ok", None, 15, 50).unwrap();
+
+        let (total, ok, err) = run_stats(tmp.path()).unwrap();
+        assert_eq!(total, 3);
+        assert_eq!(ok, 2);
+        assert_eq!(err, 1);
+    }
+
+    #[test]
+    fn truncates_large_output() {
+        let tmp = TempDir::new().unwrap();
+        let now = Utc::now();
+        let big = "x".repeat(MAX_OUTPUT_BYTES + 512);
+
+        record_run(
+            tmp.path(),
+            "T",
+            "medium",
+            now,
+            now,
+            "ok",
+            Some(&big),
+            10,
+            50,
+        )
+        .unwrap();
+
+        let runs = list_runs(tmp.path(), 1).unwrap();
+        let stored = runs[0].output.as_deref().unwrap_or_default();
+        assert!(stored.ends_with(TRUNCATED_MARKER));
+        assert!(stored.len() <= MAX_OUTPUT_BYTES);
+    }
+}

From 9ba5ba563230b12cbc26c74af7f63c16846cf2bf Mon Sep 17 00:00:00 2001
From: argenis de la rosa
Date: Mon, 16 Mar 2026 12:23:18 -0400
Subject: [PATCH 2/3] feat(sessions): add SQLite backend with FTS5, trait abstraction, and migration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add SessionBackend trait abstracting over storage backends (load, append,
  remove_last, list, search, cleanup_stale, compact)
- Add SqliteSessionBackend with WAL mode, FTS5 full-text search, session
  metadata tracking, and TTL-based cleanup
- Add remove_last() and compact() to JSONL SessionStore
- Implement SessionBackend for both JSONL and SQLite backends
- Add automatic JSONL-to-SQLite migration (renames .jsonl → .jsonl.migrated)
- Add config: session_backend ("jsonl"/"sqlite"), session_ttl_hours
- SQLite is the new default backend; JSONL preserved for backward compat
---
 src/channels/mod.rs             |   2 +
 src/channels/session_backend.rs | 103 +++++++
 src/channels/session_sqlite.rs  | 503 ++++++++++++++++++++++++++++++++
 src/channels/session_store.rs   | 111 +++++++
 src/config/schema.rs            |  19 ++
 5 files changed, 738 insertions(+)
 create mode 100644 src/channels/session_backend.rs
 create mode 100644 src/channels/session_sqlite.rs

diff --git a/src/channels/mod.rs b/src/channels/mod.rs
index d5caaff1b..5928e6996 100644
--- a/src/channels/mod.rs
+++ b/src/channels/mod.rs
@@ -32,6 +32,8 @@ pub mod nextcloud_talk;
 pub mod nostr;
 pub mod notion;
 pub mod qq;
+pub mod session_backend;
+pub mod session_sqlite;
 pub mod session_store;
 pub mod signal;
 pub mod slack;

diff --git a/src/channels/session_backend.rs b/src/channels/session_backend.rs
new file mode 100644
index 000000000..b467b0932
--- /dev/null
+++ b/src/channels/session_backend.rs
@@ -0,0 +1,103 @@
+//! Trait abstraction for session persistence backends.
+//!
+//! Backends store per-sender conversation histories. The trait is intentionally
+//! minimal — load, append, remove_last, list — so that JSONL and SQLite (and
+//! future backends) share a common interface.
+
+use crate::providers::traits::ChatMessage;
+use chrono::{DateTime, Utc};
+
+/// Metadata about a persisted session.
+#[derive(Debug, Clone)]
+pub struct SessionMetadata {
+    /// Session key (e.g. `telegram_user123`).
+    pub key: String,
+    /// When the session was first created.
+    pub created_at: DateTime<Utc>,
+    /// When the last message was appended.
+    pub last_activity: DateTime<Utc>,
+    /// Total number of messages in the session.
+    pub message_count: usize,
+}
+
+/// Query parameters for listing sessions.
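+///
+/// For example, a query for the ten most relevant sessions mentioning "rust":
+///
+/// ```ignore
+/// let query = SessionQuery { keyword: Some("rust".into()), limit: Some(10) };
+/// ```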
+#[derive(Debug, Clone, Default)]
+pub struct SessionQuery {
+    /// Keyword to search in session messages (FTS5 if available).
+    pub keyword: Option<String>,
+    /// Maximum number of sessions to return.
+    pub limit: Option<usize>,
+}
+
+/// Trait for session persistence backends.
+///
+/// Implementations must be `Send + Sync` for sharing across async tasks.
+pub trait SessionBackend: Send + Sync {
+    /// Load all messages for a session. Returns empty vec if session doesn't exist.
+    fn load(&self, session_key: &str) -> Vec<ChatMessage>;
+
+    /// Append a single message to a session.
+    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()>;
+
+    /// Remove the last message from a session. Returns `true` if a message was removed.
+    fn remove_last(&self, session_key: &str) -> std::io::Result<bool>;
+
+    /// List all session keys.
+    fn list_sessions(&self) -> Vec<String>;
+
+    /// List sessions with metadata.
+    fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
+        // Default: construct metadata from messages (backends can override for efficiency)
+        self.list_sessions()
+            .into_iter()
+            .map(|key| {
+                let messages = self.load(&key);
+                SessionMetadata {
+                    key,
+                    created_at: Utc::now(),
+                    last_activity: Utc::now(),
+                    message_count: messages.len(),
+                }
+            })
+            .collect()
+    }
+
+    /// Compact a session file (remove duplicates/corruption). No-op by default.
+    fn compact(&self, _session_key: &str) -> std::io::Result<()> {
+        Ok(())
+    }
+
+    /// Remove sessions that haven't been active within the given TTL hours.
+    fn cleanup_stale(&self, _ttl_hours: u32) -> std::io::Result<usize> {
+        Ok(0)
+    }
+
+    /// Search sessions by keyword. Default returns empty (backends with FTS override).
+    fn search(&self, _query: &SessionQuery) -> Vec<SessionMetadata> {
+        Vec::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn session_metadata_is_constructible() {
+        let meta = SessionMetadata {
+            key: "test".into(),
+            created_at: Utc::now(),
+            last_activity: Utc::now(),
+            message_count: 5,
+        };
+        assert_eq!(meta.key, "test");
+        assert_eq!(meta.message_count, 5);
+    }
+
+    #[test]
+    fn session_query_defaults() {
+        let q = SessionQuery::default();
+        assert!(q.keyword.is_none());
+        assert!(q.limit.is_none());
+    }
+}

diff --git a/src/channels/session_sqlite.rs b/src/channels/session_sqlite.rs
new file mode 100644
index 000000000..9fb84f64a
--- /dev/null
+++ b/src/channels/session_sqlite.rs
@@ -0,0 +1,503 @@
+//! SQLite-backed session persistence with FTS5 search.
+//!
+//! Stores sessions in `{workspace}/sessions/sessions.db` using WAL mode.
+//! Provides full-text search via FTS5 and automatic TTL-based cleanup.
+//! Designed as the default backend, replacing JSONL for new installations.
+
+use crate::channels::session_backend::{SessionBackend, SessionMetadata, SessionQuery};
+use crate::providers::traits::ChatMessage;
+use anyhow::{Context, Result};
+use chrono::{DateTime, Duration, Utc};
+use parking_lot::Mutex;
+use rusqlite::{params, Connection};
+use std::path::{Path, PathBuf};
+
+/// SQLite-backed session store with FTS5 and WAL mode.
+pub struct SqliteSessionBackend {
+    conn: Mutex<Connection>,
+    #[allow(dead_code)]
+    db_path: PathBuf,
+}
+
+impl SqliteSessionBackend {
+    /// Open or create the sessions database.
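+    ///
+    /// Illustrative usage (the workspace path is a placeholder):
+    ///
+    /// ```ignore
+    /// // Creates {workspace}/sessions/sessions.db on first use.
+    /// let backend = SqliteSessionBackend::new(&workspace_dir)?;
+    /// ```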
+    pub fn new(workspace_dir: &Path) -> Result<Self> {
+        let sessions_dir = workspace_dir.join("sessions");
+        std::fs::create_dir_all(&sessions_dir).context("Failed to create sessions directory")?;
+        let db_path = sessions_dir.join("sessions.db");
+
+        let conn = Connection::open(&db_path)
+            .with_context(|| format!("Failed to open session DB: {}", db_path.display()))?;
+
+        conn.execute_batch(
+            "PRAGMA journal_mode = WAL;
+             PRAGMA synchronous = NORMAL;
+             PRAGMA temp_store = MEMORY;
+             PRAGMA mmap_size = 4194304;",
+        )?;
+
+        conn.execute_batch(
+            "CREATE TABLE IF NOT EXISTS sessions (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                session_key TEXT NOT NULL,
+                role TEXT NOT NULL,
+                content TEXT NOT NULL,
+                created_at TEXT NOT NULL
+            );
+            CREATE INDEX IF NOT EXISTS idx_sessions_key ON sessions(session_key);
+            CREATE INDEX IF NOT EXISTS idx_sessions_key_id ON sessions(session_key, id);
+
+            CREATE TABLE IF NOT EXISTS session_metadata (
+                session_key TEXT PRIMARY KEY,
+                created_at TEXT NOT NULL,
+                last_activity TEXT NOT NULL,
+                message_count INTEGER NOT NULL DEFAULT 0
+            );
+
+            CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
+                session_key, content, content=sessions, content_rowid=id
+            );
+
+            CREATE TRIGGER IF NOT EXISTS sessions_ai AFTER INSERT ON sessions BEGIN
+                INSERT INTO sessions_fts(rowid, session_key, content)
+                VALUES (new.id, new.session_key, new.content);
+            END;
+            CREATE TRIGGER IF NOT EXISTS sessions_ad AFTER DELETE ON sessions BEGIN
+                INSERT INTO sessions_fts(sessions_fts, rowid, session_key, content)
+                VALUES ('delete', old.id, old.session_key, old.content);
+            END;",
+        )
+        .context("Failed to initialize session schema")?;
+
+        Ok(Self {
+            conn: Mutex::new(conn),
+            db_path,
+        })
+    }
+
+    /// Migrate JSONL session files into SQLite. Renames migrated files to `.jsonl.migrated`.
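+    ///
+    /// A sketch of the intended call site, assuming migration runs once at startup:
+    ///
+    /// ```ignore
+    /// let backend = SqliteSessionBackend::new(&workspace_dir)?;
+    /// let migrated = backend.migrate_from_jsonl(&workspace_dir)?;
+    /// tracing::info!("migrated {migrated} JSONL session file(s)");
+    /// ```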
+    pub fn migrate_from_jsonl(&self, workspace_dir: &Path) -> Result<usize> {
+        let sessions_dir = workspace_dir.join("sessions");
+        let entries = match std::fs::read_dir(&sessions_dir) {
+            Ok(e) => e,
+            Err(_) => return Ok(0),
+        };
+
+        let mut migrated = 0;
+        for entry in entries {
+            let entry = match entry {
+                Ok(e) => e,
+                Err(_) => continue,
+            };
+            let name = match entry.file_name().into_string() {
+                Ok(n) => n,
+                Err(_) => continue,
+            };
+            let Some(key) = name.strip_suffix(".jsonl") else {
+                continue;
+            };
+
+            let path = entry.path();
+            let file = match std::fs::File::open(&path) {
+                Ok(f) => f,
+                Err(_) => continue,
+            };
+
+            let reader = std::io::BufReader::new(file);
+            let mut count = 0;
+            for line in std::io::BufRead::lines(reader) {
+                let Ok(line) = line else { continue };
+                let trimmed = line.trim();
+                if trimmed.is_empty() {
+                    continue;
+                }
+                if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
+                    if self.append(key, &msg).is_ok() {
+                        count += 1;
+                    }
+                }
+            }
+
+            if count > 0 {
+                let migrated_path = path.with_extension("jsonl.migrated");
+                let _ = std::fs::rename(&path, &migrated_path);
+                migrated += 1;
+            }
+        }
+
+        Ok(migrated)
+    }
+}
+
+impl SessionBackend for SqliteSessionBackend {
+    fn load(&self, session_key: &str) -> Vec<ChatMessage> {
+        let conn = self.conn.lock();
+        let mut stmt = match conn
+            .prepare("SELECT role, content FROM sessions WHERE session_key = ?1 ORDER BY id ASC")
+        {
+            Ok(s) => s,
+            Err(_) => return Vec::new(),
+        };
+
+        let rows = match stmt.query_map(params![session_key], |row| {
+            Ok(ChatMessage {
+                role: row.get(0)?,
+                content: row.get(1)?,
+            })
+        }) {
+            Ok(r) => r,
+            Err(_) => return Vec::new(),
+        };
+
+        rows.filter_map(|r| r.ok()).collect()
+    }
+
+    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
+        let conn = self.conn.lock();
+        let now = Utc::now().to_rfc3339();
+
+        conn.execute(
+            "INSERT INTO sessions (session_key, role, content, created_at)
+             VALUES (?1, ?2, ?3, ?4)",
+            params![session_key, message.role, message.content, now],
+        )
+        .map_err(std::io::Error::other)?;
+
+        // Upsert metadata
+        conn.execute(
+            "INSERT INTO session_metadata (session_key, created_at, last_activity, message_count)
+             VALUES (?1, ?2, ?3, 1)
+             ON CONFLICT(session_key) DO UPDATE SET
+                 last_activity = excluded.last_activity,
+                 message_count = message_count + 1",
+            params![session_key, now, now],
+        )
+        .map_err(std::io::Error::other)?;
+
+        Ok(())
+    }
+
+    fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
+        let conn = self.conn.lock();
+
+        let last_id: Option<i64> = conn
+            .query_row(
+                "SELECT id FROM sessions WHERE session_key = ?1 ORDER BY id DESC LIMIT 1",
+                params![session_key],
+                |row| row.get(0),
+            )
+            .ok();
+
+        let Some(id) = last_id else {
+            return Ok(false);
+        };
+
+        conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])
+            .map_err(std::io::Error::other)?;
+
+        // Update metadata count
+        conn.execute(
+            "UPDATE session_metadata SET message_count = MAX(0, message_count - 1)
+             WHERE session_key = ?1",
+            params![session_key],
+        )
+        .map_err(std::io::Error::other)?;
+
+        Ok(true)
+    }
+
+    fn list_sessions(&self) -> Vec<String> {
+        let conn = self.conn.lock();
+        let mut stmt = match conn
+            .prepare("SELECT session_key FROM session_metadata ORDER BY last_activity DESC")
+        {
+            Ok(s) => s,
+            Err(_) => return Vec::new(),
+        };
+
+        let rows = match stmt.query_map([], |row| row.get(0)) {
+            Ok(r) => r,
+            Err(_) => return Vec::new(),
+        };
+
+        rows.filter_map(|r| r.ok()).collect()
+    }
+
+    fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
+        let conn = self.conn.lock();
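+        // One query over session_metadata only; unlike the trait's default
+        // implementation, this never loads message bodies.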
+        let mut stmt = match conn.prepare(
+            "SELECT session_key, created_at, last_activity, message_count
+             FROM session_metadata ORDER BY last_activity DESC",
+        ) {
+            Ok(s) => s,
+            Err(_) => return Vec::new(),
+        };
+
+        let rows = match stmt.query_map([], |row| {
+            let key: String = row.get(0)?;
+            let created_str: String = row.get(1)?;
+            let activity_str: String = row.get(2)?;
+            let count: i64 = row.get(3)?;
+
+            let created = DateTime::parse_from_rfc3339(&created_str)
+                .map(|dt| dt.with_timezone(&Utc))
+                .unwrap_or_else(|_| Utc::now());
+            let activity = DateTime::parse_from_rfc3339(&activity_str)
+                .map(|dt| dt.with_timezone(&Utc))
+                .unwrap_or_else(|_| Utc::now());
+
+            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
+            Ok(SessionMetadata {
+                key,
+                created_at: created,
+                last_activity: activity,
+                message_count: count as usize,
+            })
+        }) {
+            Ok(r) => r,
+            Err(_) => return Vec::new(),
+        };
+
+        rows.filter_map(|r| r.ok()).collect()
+    }
+
+    fn cleanup_stale(&self, ttl_hours: u32) -> std::io::Result<usize> {
+        let conn = self.conn.lock();
+        let cutoff = (Utc::now() - Duration::hours(i64::from(ttl_hours))).to_rfc3339();
+
+        // Find stale sessions
+        let stale_keys: Vec<String> = {
+            let mut stmt = conn
+                .prepare("SELECT session_key FROM session_metadata WHERE last_activity < ?1")
+                .map_err(std::io::Error::other)?;
+            let rows = stmt
+                .query_map(params![cutoff], |row| row.get(0))
+                .map_err(std::io::Error::other)?;
+            rows.filter_map(|r| r.ok()).collect()
+        };
+
+        let count = stale_keys.len();
+        for key in &stale_keys {
+            let _ = conn.execute("DELETE FROM sessions WHERE session_key = ?1", params![key]);
+            let _ = conn.execute(
+                "DELETE FROM session_metadata WHERE session_key = ?1",
+                params![key],
+            );
+        }
+
+        Ok(count)
+    }
+
+    fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
+        let Some(keyword) = &query.keyword else {
+            return self.list_sessions_with_metadata();
+        };
+
+        let conn = self.conn.lock();
+        #[allow(clippy::cast_possible_wrap)]
+        let limit = query.limit.unwrap_or(50) as i64;
+
+        // FTS5 search
+        let mut stmt = match conn.prepare(
+            "SELECT DISTINCT f.session_key
+             FROM sessions_fts f
+             WHERE sessions_fts MATCH ?1
+             LIMIT ?2",
+        ) {
+            Ok(s) => s,
+            Err(_) => return Vec::new(),
+        };
+
+        // Quote each word for FTS5
+        let fts_query: String = keyword
+            .split_whitespace()
+            .map(|w| format!("\"{w}\""))
+            .collect::<Vec<_>>()
+            .join(" OR ");
+
+        let keys: Vec<String> = match stmt.query_map(params![fts_query, limit], |row| row.get(0)) {
+            Ok(r) => r.filter_map(|r| r.ok()).collect(),
+            Err(_) => return Vec::new(),
+        };
+
+        // Look up metadata for matched sessions
+        keys.iter()
+            .filter_map(|key| {
+                conn.query_row(
+                    "SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
+                    params![key],
+                    |row| {
+                        let created_str: String = row.get(0)?;
+                        let activity_str: String = row.get(1)?;
+                        let count: i64 = row.get(2)?;
+                        Ok(SessionMetadata {
+                            key: key.clone(),
+                            created_at: DateTime::parse_from_rfc3339(&created_str)
+                                .map(|dt| dt.with_timezone(&Utc))
+                                .unwrap_or_else(|_| Utc::now()),
+                            last_activity: DateTime::parse_from_rfc3339(&activity_str)
+                                .map(|dt| dt.with_timezone(&Utc))
+                                .unwrap_or_else(|_| Utc::now()),
+                            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
+                            message_count: count as usize,
+                        })
+                    },
+                )
+                .ok()
+            })
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tempfile::TempDir;
+
+    #[test]
+    fn round_trip_sqlite() {
+        let tmp = TempDir::new().unwrap();
+        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
+
backend + .append("user1", &ChatMessage::user("hello")) + .unwrap(); + backend + .append("user1", &ChatMessage::assistant("hi")) + .unwrap(); + + let msgs = backend.load("user1"); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].role, "user"); + assert_eq!(msgs[1].role, "assistant"); + } + + #[test] + fn remove_last_sqlite() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + + backend.append("u", &ChatMessage::user("a")).unwrap(); + backend.append("u", &ChatMessage::user("b")).unwrap(); + + assert!(backend.remove_last("u").unwrap()); + let msgs = backend.load("u"); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "a"); + } + + #[test] + fn remove_last_empty_sqlite() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + assert!(!backend.remove_last("nonexistent").unwrap()); + } + + #[test] + fn list_sessions_sqlite() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + + backend.append("a", &ChatMessage::user("hi")).unwrap(); + backend.append("b", &ChatMessage::user("hey")).unwrap(); + + let sessions = backend.list_sessions(); + assert_eq!(sessions.len(), 2); + } + + #[test] + fn metadata_tracks_counts() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + + backend.append("s1", &ChatMessage::user("a")).unwrap(); + backend.append("s1", &ChatMessage::user("b")).unwrap(); + backend.append("s1", &ChatMessage::user("c")).unwrap(); + + let meta = backend.list_sessions_with_metadata(); + assert_eq!(meta.len(), 1); + assert_eq!(meta[0].message_count, 3); + } + + #[test] + fn fts5_search_finds_content() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + + backend + .append( + "code_chat", + &ChatMessage::user("How do I parse JSON in Rust?"), + ) + .unwrap(); + backend + .append("weather", &ChatMessage::user("What's the weather today?")) + .unwrap(); + + let results = backend.search(&SessionQuery { + keyword: Some("Rust".into()), + limit: Some(10), + }); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "code_chat"); + } + + #[test] + fn cleanup_stale_removes_old_sessions() { + let tmp = TempDir::new().unwrap(); + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + + // Insert a session with old timestamp + { + let conn = backend.conn.lock(); + let old_time = (Utc::now() - Duration::hours(100)).to_rfc3339(); + conn.execute( + "INSERT INTO sessions (session_key, role, content, created_at) VALUES (?1, ?2, ?3, ?4)", + params!["old_session", "user", "ancient", old_time], + ).unwrap(); + conn.execute( + "INSERT INTO session_metadata (session_key, created_at, last_activity, message_count) VALUES (?1, ?2, ?3, 1)", + params!["old_session", old_time, old_time], + ).unwrap(); + } + + backend + .append("new_session", &ChatMessage::user("fresh")) + .unwrap(); + + let cleaned = backend.cleanup_stale(48).unwrap(); // 48h TTL + assert_eq!(cleaned, 1); + + let sessions = backend.list_sessions(); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0], "new_session"); + } + + #[test] + fn migrate_from_jsonl_imports_and_renames() { + let tmp = TempDir::new().unwrap(); + let sessions_dir = tmp.path().join("sessions"); + std::fs::create_dir_all(&sessions_dir).unwrap(); + + // Create a JSONL file + let jsonl_path = sessions_dir.join("test_user.jsonl"); + std::fs::write( + &jsonl_path, + 
"{\"role\":\"user\",\"content\":\"hello\"}\n{\"role\":\"assistant\",\"content\":\"hi\"}\n", + ) + .unwrap(); + + let backend = SqliteSessionBackend::new(tmp.path()).unwrap(); + let migrated = backend.migrate_from_jsonl(tmp.path()).unwrap(); + assert_eq!(migrated, 1); + + // JSONL should be renamed + assert!(!jsonl_path.exists()); + assert!(sessions_dir.join("test_user.jsonl.migrated").exists()); + + // Messages should be in SQLite + let msgs = backend.load("test_user"); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].content, "hello"); + } +} diff --git a/src/channels/session_store.rs b/src/channels/session_store.rs index b5b75e821..a9149eb5c 100644 --- a/src/channels/session_store.rs +++ b/src/channels/session_store.rs @@ -5,6 +5,7 @@ //! one-per-line as JSON, never modifying old lines. On daemon restart, sessions //! are loaded from disk to restore conversation context. +use crate::channels::session_backend::SessionBackend; use crate::providers::traits::ChatMessage; use std::io::{BufRead, Write}; use std::path::{Path, PathBuf}; @@ -78,6 +79,37 @@ impl SessionStore { Ok(()) } + /// Remove the last message from a session's JSONL file. + /// + /// Rewrite approach: load all messages, drop the last, rewrite. This is + /// O(n) but rollbacks are rare. + pub fn remove_last(&self, session_key: &str) -> std::io::Result { + let mut messages = self.load(session_key); + if messages.is_empty() { + return Ok(false); + } + messages.pop(); + self.rewrite(session_key, &messages)?; + Ok(true) + } + + /// Compact a session file by rewriting only valid messages (removes corrupt lines). + pub fn compact(&self, session_key: &str) -> std::io::Result<()> { + let messages = self.load(session_key); + self.rewrite(session_key, &messages) + } + + fn rewrite(&self, session_key: &str, messages: &[ChatMessage]) -> std::io::Result<()> { + let path = self.session_path(session_key); + let mut file = std::fs::File::create(&path)?; + for msg in messages { + let json = serde_json::to_string(msg) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + writeln!(file, "{json}")?; + } + Ok(()) + } + /// List all session keys that have files on disk. 
     pub fn list_sessions(&self) -> Vec<String> {
         let entries = match std::fs::read_dir(&self.sessions_dir) {
@@ -95,6 +127,28 @@
     }
 }

+impl SessionBackend for SessionStore {
+    fn load(&self, session_key: &str) -> Vec<ChatMessage> {
+        self.load(session_key)
+    }
+
+    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
+        self.append(session_key, message)
+    }
+
+    fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
+        self.remove_last(session_key)
+    }
+
+    fn list_sessions(&self) -> Vec<String> {
+        self.list_sessions()
+    }
+
+    fn compact(&self, session_key: &str) -> std::io::Result<()> {
+        self.compact(session_key)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -178,6 +232,63 @@ mod tests {
         assert_eq!(lines.len(), 2);
     }

+    #[test]
+    fn remove_last_drops_final_message() {
+        let tmp = TempDir::new().unwrap();
+        let store = SessionStore::new(tmp.path()).unwrap();
+
+        store
+            .append("rm_test", &ChatMessage::user("first"))
+            .unwrap();
+        store
+            .append("rm_test", &ChatMessage::user("second"))
+            .unwrap();
+
+        assert!(store.remove_last("rm_test").unwrap());
+        let messages = store.load("rm_test");
+        assert_eq!(messages.len(), 1);
+        assert_eq!(messages[0].content, "first");
+    }
+
+    #[test]
+    fn remove_last_empty_returns_false() {
+        let tmp = TempDir::new().unwrap();
+        let store = SessionStore::new(tmp.path()).unwrap();
+        assert!(!store.remove_last("nonexistent").unwrap());
+    }
+
+    #[test]
+    fn compact_removes_corrupt_lines() {
+        let tmp = TempDir::new().unwrap();
+        let store = SessionStore::new(tmp.path()).unwrap();
+        let key = "compact_test";
+
+        let path = store.session_path(key);
+        std::fs::create_dir_all(path.parent().unwrap()).unwrap();
+        let mut file = std::fs::File::create(&path).unwrap();
+        writeln!(file, r#"{{"role":"user","content":"ok"}}"#).unwrap();
+        writeln!(file, "corrupt line").unwrap();
+        writeln!(file, r#"{{"role":"assistant","content":"hi"}}"#).unwrap();
+
+        store.compact(key).unwrap();
+
+        let raw = std::fs::read_to_string(&path).unwrap();
+        assert_eq!(raw.trim().lines().count(), 2);
+    }
+
+    #[test]
+    fn session_backend_trait_works_via_dyn() {
+        let tmp = TempDir::new().unwrap();
+        let store = SessionStore::new(tmp.path()).unwrap();
+        let backend: &dyn SessionBackend = &store;
+
+        backend
+            .append("trait_test", &ChatMessage::user("hello"))
+            .unwrap();
+        let msgs = backend.load("trait_test");
+        assert_eq!(msgs.len(), 1);
+    }
+
     #[test]
     fn handles_corrupt_lines_gracefully() {
         let tmp = TempDir::new().unwrap();

diff --git a/src/config/schema.rs b/src/config/schema.rs
index db7d60348..89261b9c4 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -3630,6 +3630,13 @@ pub struct ChannelsConfig {
     /// daemon restarts. Files are stored in `{workspace}/sessions/`. Default: `true`.
     #[serde(default = "default_true")]
     pub session_persistence: bool,
+    /// Session persistence backend: `"jsonl"` (legacy) or `"sqlite"` (new default).
+    /// SQLite provides FTS5 search, metadata tracking, and TTL cleanup.
+    #[serde(default = "default_session_backend")]
+    pub session_backend: String,
+    /// Auto-archive stale sessions older than this many hours. `0` disables. Default: `0`.
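+    /// For example, `session_ttl_hours = 72` expires sessions idle for three days.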
+    #[serde(default)]
+    pub session_ttl_hours: u32,
 }

 impl ChannelsConfig {
@@ -3735,6 +3742,10 @@ fn default_channel_message_timeout_secs() -> u64 {
     300
 }

+fn default_session_backend() -> String {
+    "sqlite".into()
+}
+
 impl Default for ChannelsConfig {
     fn default() -> Self {
         Self {
@@ -3765,6 +3776,8 @@
             ack_reactions: true,
             show_tool_calls: true,
             session_persistence: true,
+            session_backend: default_session_backend(),
+            session_ttl_hours: 0,
         }
     }
 }
@@ -7439,6 +7452,8 @@ default_temperature = 0.7
                 ack_reactions: true,
                 show_tool_calls: true,
                 session_persistence: true,
+                session_backend: default_session_backend(),
+                session_ttl_hours: 0,
             },
             memory: MemoryConfig::default(),
             storage: StorageConfig::default(),
@@ -8171,6 +8186,8 @@ allowed_users = ["@ops:matrix.org"]
             ack_reactions: true,
             show_tool_calls: true,
             session_persistence: true,
+            session_backend: default_session_backend(),
+            session_ttl_hours: 0,
         };
         let toml_str = toml::to_string_pretty(&c).unwrap();
         let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
@@ -8399,6 +8416,8 @@ channel_id = "C123"
             ack_reactions: true,
             show_tool_calls: true,
             session_persistence: true,
+            session_backend: default_session_backend(),
+            session_ttl_hours: 0,
        };
         let toml_str = toml::to_string_pretty(&c).unwrap();
         let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();

From 98688c61ff46173a9fad65253f1e9c220fb32466 Mon Sep 17 00:00:00 2001
From: argenis de la rosa
Date: Mon, 16 Mar 2026 12:44:48 -0400
Subject: [PATCH 3/3] feat(cache): wire two-tier response cache, multi-provider token tracking, and cache analytics

- Two-tier response cache: in-memory LRU (hot) + SQLite (warm) with TTL-aware
  eviction
- Wire response cache into agent turn loop (temp==0.0, text-only responses only)
- Parse Anthropic cache_creation_input_tokens/cache_read_input_tokens
- Parse OpenAI prompt_tokens_details.cached_tokens
- Add cached_input_tokens to TokenUsage, prompt_caching to ProviderCapabilities
- Add CacheHit/CacheMiss observer events with Prometheus counters
- Add response_cache_hot_entries config field (default: 256)
---
 src/agent/agent.rs              |  78 ++++++++++++++
 src/agent/loop_.rs              |   1 +
 src/config/schema.rs            |   8 ++
 src/daemon/mod.rs               |   2 +-
 src/memory/response_cache.rs    | 157 ++++++++++++++++++++++------
 src/observability/log.rs        |   9 ++
 src/observability/prometheus.rs |  42 +++++++
 src/observability/traits.rs     |  12 +++
 src/onboard/wizard.rs           |   1 +
 src/providers/anthropic.rs      |   6 ++
 src/providers/azure_openai.rs   |   3 +
 src/providers/bedrock.rs        |   2 +
 src/providers/compatible.rs     |   3 +
 src/providers/copilot.rs        |   1 +
 src/providers/gemini.rs         |   1 +
 src/providers/ollama.rs         |   2 +
 src/providers/openai.rs         |  10 ++
 src/providers/openai_codex.rs   |   1 +
 src/providers/openrouter.rs     |   3 +
 src/providers/traits.rs         |  11 +++
 tests/support/mock_provider.rs  |   2 +
 21 files changed, 327 insertions(+), 28 deletions(-)

diff --git a/src/agent/agent.rs b/src/agent/agent.rs
index 309345a18..2f08c8abb 100644
--- a/src/agent/agent.rs
+++ b/src/agent/agent.rs
@@ -38,6 +38,7 @@ pub struct Agent {
     available_hints: Vec<String>,
     route_model_by_hint: HashMap<String, String>,
     allowed_tools: Option<Vec<String>>,
+    response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
 }

 pub struct AgentBuilder {
@@ -60,6 +61,7 @@ pub struct AgentBuilder {
     available_hints: Option<Vec<String>>,
     route_model_by_hint: Option<HashMap<String, String>>,
     allowed_tools: Option<Vec<String>>,
+    response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
 }

 impl AgentBuilder {
@@ -84,6 +86,7 @@ impl AgentBuilder {
             available_hints: None,
             route_model_by_hint: None,
             allowed_tools: None,
+            response_cache: None,
         }
     }

@@ -188,6 +191,14 @@ impl AgentBuilder {
         self
     }

+    pub fn response_cache(
+        mut self,
+        cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
+    ) -> Self {
+        self.response_cache = cache;
+        self
+    }
+
     pub fn build(self) -> Result<Agent> {
         let mut tools = self
             .tools
@@ -236,6 +247,7 @@
             available_hints: self.available_hints.unwrap_or_default(),
             route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
             allowed_tools: allowed,
+            response_cache: self.response_cache,
         })
     }
 }
@@ -330,11 +342,25 @@ impl Agent {
             .collect();
         let available_hints: Vec<String> = route_model_by_hint.keys().cloned().collect();

+        let response_cache = if config.memory.response_cache_enabled {
+            crate::memory::response_cache::ResponseCache::with_hot_cache(
+                &config.workspace_dir,
+                config.memory.response_cache_ttl_minutes,
+                config.memory.response_cache_max_entries,
+                config.memory.response_cache_hot_entries,
+            )
+            .ok()
+            .map(Arc::new)
+        } else {
+            None
+        };
+
         Agent::builder()
             .provider(provider)
             .tools(tools)
             .memory(memory)
             .observer(observer)
+            .response_cache(response_cache)
             .tool_dispatcher(tool_dispatcher)
             .memory_loader(Box::new(DefaultMemoryLoader::new(
                 5,
@@ -513,6 +539,47 @@
         for _ in 0..self.config.max_tool_iterations {
             let messages = self.tool_dispatcher.to_provider_messages(&self.history);
+
+            // Response cache: check before LLM call (only for deterministic, text-only prompts)
+            let cache_key = if self.temperature == 0.0 {
+                self.response_cache.as_ref().map(|_| {
+                    let last_user = messages
+                        .iter()
+                        .rfind(|m| m.role == "user")
+                        .map(|m| m.content.as_str())
+                        .unwrap_or("");
+                    let system = messages
+                        .iter()
+                        .find(|m| m.role == "system")
+                        .map(|m| m.content.as_str());
+                    crate::memory::response_cache::ResponseCache::cache_key(
+                        &effective_model,
+                        system,
+                        last_user,
+                    )
+                })
+            } else {
+                None
+            };
+
+            if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
+                if let Ok(Some(cached)) = cache.get(key) {
+                    self.observer.record_event(&ObserverEvent::CacheHit {
+                        cache_type: "response".into(),
+                        tokens_saved: 0,
+                    });
+                    self.history
+                        .push(ConversationMessage::Chat(ChatMessage::assistant(
+                            cached.clone(),
+                        )));
+                    self.trim_history();
+                    return Ok(cached);
+                }
+                self.observer.record_event(&ObserverEvent::CacheMiss {
+                    cache_type: "response".into(),
+                });
+            }
+
             let response = match self
                 .provider
                 .chat(
@@ -541,6 +608,17 @@
                 text
             };

+            // Store in response cache (text-only, no tool calls)
+            if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
+                let token_count = response
+                    .usage
+                    .as_ref()
+                    .and_then(|u| u.output_tokens)
+                    .unwrap_or(0);
+                #[allow(clippy::cast_possible_truncation)]
+                let _ = cache.put(key, &effective_model, &final_text, token_count as u32);
+            }
+
             self.history
                 .push(ConversationMessage::Chat(ChatMessage::assistant(
                     final_text.clone(),

diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs
index dbaaeb29b..3f667b9a6 100644
--- a/src/agent/loop_.rs
+++ b/src/agent/loop_.rs
@@ -3977,6 +3977,7 @@ mod tests {
         ProviderCapabilities {
             native_tool_calling: false,
             vision: true,
+            prompt_caching: false,
         }
     }

diff --git a/src/config/schema.rs b/src/config/schema.rs
index 89261b9c4..fec1626df 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -2650,6 +2650,9 @@ pub struct MemoryConfig {
     /// Max number of cached responses before LRU eviction (default: 5000)
     #[serde(default = "default_response_cache_max")]
     pub response_cache_max_entries: usize,
+    /// Max in-memory hot cache entries for the two-tier response cache (default: 256)
+    #[serde(default = "default_response_cache_hot_entries")]
+    pub response_cache_hot_entries: usize,

     // ── Memory Snapshot (soul backup to Markdown) ─────────────
     /// Enable periodic export of core memories to MEMORY_SNAPSHOT.md
@@ -2718,6 +2721,10 @@ fn default_response_cache_max() -> usize {
     5_000
 }

+fn default_response_cache_hot_entries() -> usize {
+    256
+}
+
 impl Default for MemoryConfig {
     fn default() -> Self {
         Self {
@@ -2738,6 +2745,7 @@
             response_cache_enabled: false,
             response_cache_ttl_minutes: default_response_cache_ttl(),
             response_cache_max_entries: default_response_cache_max(),
+            response_cache_hot_entries: default_response_cache_hot_entries(),
             snapshot_enabled: false,
             snapshot_on_hygiene: false,
             auto_hydrate: true,

diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs
index 7dc8cfe73..eee231220 100644
--- a/src/daemon/mod.rs
+++ b/src/daemon/mod.rs
@@ -323,7 +323,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
             tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
             crate::health::mark_component_ok("heartbeat");
             #[allow(clippy::cast_precision_loss)]
-            let elapsed = tick_start.elapsed().as_millis() as f64; 
+            let elapsed = tick_start.elapsed().as_millis() as f64;
             metrics.lock().record_success(elapsed);
             continue;
         }

diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs
index 5c6492463..e0af48a49 100644
--- a/src/memory/response_cache.rs
+++ b/src/memory/response_cache.rs
@@ -10,23 +10,45 @@ use chrono::{Duration, Local};
 use parking_lot::Mutex;
 use rusqlite::{params, Connection};
 use sha2::{Digest, Sha256};
+use std::collections::HashMap;
 use std::path::{Path, PathBuf};

-/// Response cache backed by a dedicated SQLite database.
+/// An in-memory hot cache entry for the two-tier response cache.
+struct InMemoryEntry {
+    response: String,
+    token_count: u32,
+    created_at: std::time::Instant,
+    accessed_at: std::time::Instant,
+}
+
+/// Two-tier response cache: in-memory LRU (hot) + SQLite (warm).
 ///
-/// Lives alongside `brain.db` as `response_cache.db` so it can be
-/// independently wiped without touching memories.
+/// The hot cache avoids SQLite round-trips for frequently repeated prompts.
+/// On miss from hot cache, falls through to SQLite. On hit from SQLite,
+/// the entry is promoted to the hot cache.
 pub struct ResponseCache {
     conn: Mutex<Connection>,
     #[allow(dead_code)]
     db_path: PathBuf,
     ttl_minutes: i64,
     max_entries: usize,
+    hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
+    hot_max_entries: usize,
 }

 impl ResponseCache {
     /// Open (or create) the response cache database.
     pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
+        Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
+    }
+
+    /// Open (or create) the response cache database with a custom hot cache size.
+    pub fn with_hot_cache(
+        workspace_dir: &Path,
+        ttl_minutes: u32,
+        max_entries: usize,
+        hot_max_entries: usize,
+    ) -> Result<Self> {
         let db_dir = workspace_dir.join("memory");
         std::fs::create_dir_all(&db_dir)?;
         let db_path = db_dir.join("response_cache.db");
@@ -58,6 +80,8 @@
             db_path,
             ttl_minutes: i64::from(ttl_minutes),
             max_entries,
+            hot_cache: Mutex::new(HashMap::new()),
+            hot_max_entries,
         })
     }

@@ -76,35 +100,77 @@
     }

     /// Look up a cached response. Returns `None` on miss or expired entry.
+    ///
+    /// Two-tier lookup: checks the in-memory hot cache first, then falls
+    /// through to SQLite. On a SQLite hit the entry is promoted to hot cache.
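+    ///
+    /// Illustrative round trip (model name and prompt are placeholders; key
+    /// derivation as in [`Self::cache_key`]):
+    ///
+    /// ```ignore
+    /// let key = ResponseCache::cache_key("some-model", None, "What is 2 + 2?");
+    /// cache.put(&key, "some-model", "4", 1)?;
+    /// assert_eq!(cache.get(&key)?, Some("4".to_string()));
+    /// ```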
+    #[allow(clippy::cast_sign_loss)]
     pub fn get(&self, key: &str) -> Result<Option<String>> {
-        let conn = self.conn.lock();
-
-        let now = Local::now();
-        let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
-
-        let mut stmt = conn.prepare(
-            "SELECT response FROM response_cache
-             WHERE prompt_hash = ?1 AND created_at > ?2",
-        )?;
-
-        let result: Option<String> = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok();
-
-        if result.is_some() {
-            // Bump hit count and accessed_at
-            let now_str = now.to_rfc3339();
-            conn.execute(
-                "UPDATE response_cache
-                 SET accessed_at = ?1, hit_count = hit_count + 1
-                 WHERE prompt_hash = ?2",
-                params![now_str, key],
-            )?;
+        // Tier 1: hot cache (with TTL check)
+        {
+            let mut hot = self.hot_cache.lock();
+            if let Some(entry) = hot.get_mut(key) {
+                let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
+                if entry.created_at.elapsed() > ttl {
+                    hot.remove(key);
+                } else {
+                    entry.accessed_at = std::time::Instant::now();
+                    let response = entry.response.clone();
+                    drop(hot);
+                    // Still bump SQLite hit count for accurate stats
+                    let conn = self.conn.lock();
+                    let now_str = Local::now().to_rfc3339();
+                    conn.execute(
+                        "UPDATE response_cache
+                         SET accessed_at = ?1, hit_count = hit_count + 1
+                         WHERE prompt_hash = ?2",
+                        params![now_str, key],
+                    )?;
+                    return Ok(Some(response));
+                }
+            }
         }
 
-        Ok(result)
+        // Tier 2: SQLite (warm)
+        let result: Option<(String, u32)> = {
+            let conn = self.conn.lock();
+            let now = Local::now();
+            let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
+
+            let mut stmt = conn.prepare(
+                "SELECT response, token_count FROM response_cache
+                 WHERE prompt_hash = ?1 AND created_at > ?2",
+            )?;
+
+            let result: Option<(String, u32)> = stmt
+                .query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
+                .ok();
+
+            if result.is_some() {
+                let now_str = now.to_rfc3339();
+                conn.execute(
+                    "UPDATE response_cache
+                     SET accessed_at = ?1, hit_count = hit_count + 1
+                     WHERE prompt_hash = ?2",
+                    params![now_str, key],
+                )?;
+            }
+
+            result
+        };
+
+        if let Some((ref response, token_count)) = result {
+            self.promote_to_hot(key, response, token_count);
+        }
+
+        Ok(result.map(|(r, _)| r))
     }
 
-    /// Store a response in the cache.
+    /// Store a response in the cache (both hot and warm tiers).
     pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
+        // Write to hot cache
+        self.promote_to_hot(key, response, token_count);
+
+        // Write to SQLite (warm)
         let conn = self.conn.lock();
 
         let now = Local::now().to_rfc3339();
@@ -138,6 +204,43 @@ impl ResponseCache {
         Ok(())
     }
 
+    /// Promote an entry to the in-memory hot cache, evicting the oldest if full.
+    fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
+        let mut hot = self.hot_cache.lock();
+
+        // If already present, just update (keep original created_at for TTL)
+        if let Some(entry) = hot.get_mut(key) {
+            entry.response = response.to_string();
+            entry.token_count = token_count;
+            entry.accessed_at = std::time::Instant::now();
+            return;
+        }
+
+        // Evict oldest entry if at capacity
+        if self.hot_max_entries > 0 && hot.len() >= self.hot_max_entries {
+            if let Some(oldest_key) = hot
+                .iter()
+                .min_by_key(|(_, v)| v.accessed_at)
+                .map(|(k, _)| k.clone())
+            {
+                hot.remove(&oldest_key);
+            }
+        }
+
+        if self.hot_max_entries > 0 {
+            let now = std::time::Instant::now();
+            hot.insert(
+                key.to_string(),
+                InMemoryEntry {
+                    response: response.to_string(),
+                    token_count,
+                    created_at: now,
+                    accessed_at: now,
+                },
+            );
+        }
+    }
+
     /// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
     pub fn stats(&self) -> Result<(usize, u64, u64)> {
         let conn = self.conn.lock();
@@ -163,8 +266,8 @@ impl ResponseCache {
 
     /// Wipe the entire cache (useful for `zeroclaw cache clear`).
     pub fn clear(&self) -> Result<usize> {
+        self.hot_cache.lock().clear();
         let conn = self.conn.lock();
-
         let affected = conn.execute("DELETE FROM response_cache", [])?;
         Ok(affected)
     }
diff --git a/src/observability/log.rs b/src/observability/log.rs
index e4b4a4ddb..0b7194fe2 100644
--- a/src/observability/log.rs
+++ b/src/observability/log.rs
@@ -47,6 +47,15 @@ impl Observer for LogObserver {
             ObserverEvent::HeartbeatTick => {
                 info!("heartbeat.tick");
             }
+            ObserverEvent::CacheHit {
+                cache_type,
+                tokens_saved,
+            } => {
+                info!(cache_type = %cache_type, tokens_saved = tokens_saved, "cache.hit");
+            }
+            ObserverEvent::CacheMiss { cache_type } => {
+                info!(cache_type = %cache_type, "cache.miss");
+            }
             ObserverEvent::Error { component, message } => {
                 info!(component = %component, error = %message, "error");
             }
diff --git a/src/observability/prometheus.rs b/src/observability/prometheus.rs
index 4fbb1c67a..444651913 100644
--- a/src/observability/prometheus.rs
+++ b/src/observability/prometheus.rs
@@ -16,6 +16,9 @@ pub struct PrometheusObserver {
     channel_messages: IntCounterVec,
     heartbeat_ticks: prometheus::IntCounter,
     errors: IntCounterVec,
+    cache_hits: IntCounterVec,
+    cache_misses: IntCounterVec,
+    cache_tokens_saved: IntCounterVec,
 
     // Histograms
     agent_duration: HistogramVec,
@@ -81,6 +84,27 @@ impl PrometheusObserver {
         )
         .expect("valid metric");
 
+        let cache_hits = IntCounterVec::new(
+            prometheus::Opts::new("zeroclaw_cache_hits_total", "Total response cache hits"),
+            &["cache_type"],
+        )
+        .expect("valid metric");
+
+        let cache_misses = IntCounterVec::new(
+            prometheus::Opts::new("zeroclaw_cache_misses_total", "Total response cache misses"),
+            &["cache_type"],
+        )
+        .expect("valid metric");
+
+        let cache_tokens_saved = IntCounterVec::new(
+            prometheus::Opts::new(
+                "zeroclaw_cache_tokens_saved_total",
+                "Total tokens saved by response cache",
+            ),
+            &["cache_type"],
+        )
+        .expect("valid metric");
+
         let agent_duration = HistogramVec::new(
             HistogramOpts::new(
                 "zeroclaw_agent_duration_seconds",
@@ -139,6 +163,9 @@ impl PrometheusObserver {
         registry.register(Box::new(channel_messages.clone())).ok();
         registry.register(Box::new(heartbeat_ticks.clone())).ok();
         registry.register(Box::new(errors.clone())).ok();
+        registry.register(Box::new(cache_hits.clone())).ok();
+        registry.register(Box::new(cache_misses.clone())).ok();
+        registry.register(Box::new(cache_tokens_saved.clone())).ok();
         registry.register(Box::new(agent_duration.clone())).ok();
         registry.register(Box::new(tool_duration.clone())).ok();
         registry.register(Box::new(request_latency.clone())).ok();
@@ -156,6 +183,9 @@ impl PrometheusObserver {
             channel_messages,
             heartbeat_ticks,
             errors,
+            cache_hits,
+            cache_misses,
+            cache_tokens_saved,
             agent_duration,
             tool_duration,
             request_latency,
@@ -245,6 +275,18 @@ impl Observer for PrometheusObserver {
             ObserverEvent::HeartbeatTick => {
                 self.heartbeat_ticks.inc();
             }
+            ObserverEvent::CacheHit {
+                cache_type,
+                tokens_saved,
+            } => {
+                self.cache_hits.with_label_values(&[cache_type]).inc();
+                self.cache_tokens_saved
+                    .with_label_values(&[cache_type])
+                    .inc_by(*tokens_saved);
+            }
+            ObserverEvent::CacheMiss { cache_type } => {
+                self.cache_misses.with_label_values(&[cache_type]).inc();
+            }
             ObserverEvent::Error {
                 component,
                 message: _,
diff --git a/src/observability/traits.rs b/src/observability/traits.rs
index c1391aa2e..3b68feafc 100644
--- a/src/observability/traits.rs
+++ b/src/observability/traits.rs
@@ -61,6 +61,18 @@ pub enum ObserverEvent {
     },
     /// Periodic heartbeat tick from the runtime keep-alive loop.
     HeartbeatTick,
+    /// Response cache hit — an LLM call was avoided.
+    CacheHit {
+        /// Cache layer that was hit (e.g. `"response"`).
+        cache_type: String,
+        /// Estimated tokens saved by this cache hit.
+        tokens_saved: u64,
+    },
+    /// Response cache miss — the prompt was not found in cache.
+    CacheMiss {
+        /// Cache layer that was checked (e.g. `"response"`).
+        cache_type: String,
+    },
     /// An error occurred in a named component.
     Error {
         /// Subsystem where the error originated (e.g., `"provider"`, `"gateway"`).
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index 00b643005..c0dde3161 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -402,6 +402,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
         response_cache_enabled: false,
         response_cache_ttl_minutes: 60,
         response_cache_max_entries: 5_000,
+        response_cache_hot_entries: 256,
         snapshot_enabled: false,
         snapshot_on_hygiene: false,
         auto_hydrate: true,
diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs
index b91998f84..a93cad476 100644
--- a/src/providers/anthropic.rs
+++ b/src/providers/anthropic.rs
@@ -149,6 +149,10 @@ struct AnthropicUsage {
     input_tokens: Option<u64>,
     #[serde(default)]
     output_tokens: Option<u64>,
+    #[serde(default)]
+    cache_creation_input_tokens: Option<u64>,
+    #[serde(default)]
+    cache_read_input_tokens: Option<u64>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -475,6 +479,7 @@ impl AnthropicProvider {
         let usage = response.usage.map(|u| TokenUsage {
             input_tokens: u.input_tokens,
             output_tokens: u.output_tokens,
+            cached_input_tokens: u.cache_read_input_tokens,
         });
 
         for block in response.content {
@@ -614,6 +619,7 @@ impl Provider for AnthropicProvider {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: true,
         }
     }
 
diff --git a/src/providers/azure_openai.rs b/src/providers/azure_openai.rs
index 1bdaeee07..7f053e7c4 100644
--- a/src/providers/azure_openai.rs
+++ b/src/providers/azure_openai.rs
@@ -312,6 +312,7 @@ impl Provider for AzureOpenAiProvider {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: false,
         }
     }
 
@@ -431,6 +432,7 @@ impl Provider for AzureOpenAiProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let message = native_response
             .choices
@@ -491,6 +493,7 @@ impl Provider for AzureOpenAiProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let message = native_response
             .choices
diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs
index ad353d3cd..d92c00f76 100644
--- a/src/providers/bedrock.rs
+++ b/src/providers/bedrock.rs
@@ -832,6 +832,7 @@ impl BedrockProvider {
         let usage = response.usage.map(|u| TokenUsage {
             input_tokens: u.input_tokens,
             output_tokens: u.output_tokens,
+            cached_input_tokens: None,
         });
 
         if let Some(output) = response.output {
@@ -967,6 +968,7 @@ impl Provider for BedrockProvider {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: false,
         }
     }
 
diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index e221cc289..2741d1066 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -1193,6 +1193,7 @@ impl Provider for OpenAiCompatibleProvider {
         crate::providers::traits::ProviderCapabilities {
             native_tool_calling: self.native_tool_calling,
             vision: self.supports_vision,
+            prompt_caching: false,
         }
     }
 
@@ -1514,6 +1515,7 @@ impl Provider for OpenAiCompatibleProvider {
         let usage = chat_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let choice = chat_response
             .choices
@@ -1657,6 +1659,7 @@ impl Provider for OpenAiCompatibleProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let message = native_response
             .choices
diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs
index 96ef39382..3f82cb817 100644
--- a/src/providers/copilot.rs
+++ b/src/providers/copilot.rs
@@ -353,6 +353,7 @@ impl CopilotProvider {
         let usage = api_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let choice = api_response
             .choices
diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs
index 31ab5becc..7085ae2cb 100644
--- a/src/providers/gemini.rs
+++ b/src/providers/gemini.rs
@@ -1128,6 +1128,7 @@ impl GeminiProvider {
         let usage = result.usage_metadata.map(|u| TokenUsage {
             input_tokens: u.prompt_token_count,
             output_tokens: u.candidates_token_count,
+            cached_input_tokens: None,
         });
 
         let text = result
diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs
index b1324cfa5..13637b2b3 100644
--- a/src/providers/ollama.rs
+++ b/src/providers/ollama.rs
@@ -632,6 +632,7 @@ impl Provider for OllamaProvider {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: false,
         }
     }
 
@@ -764,6 +765,7 @@ impl Provider for OllamaProvider {
             Some(TokenUsage {
                 input_tokens: response.prompt_eval_count,
                 output_tokens: response.eval_count,
+                cached_input_tokens: None,
             })
         } else {
             None
diff --git a/src/providers/openai.rs b/src/providers/openai.rs
index ae9f5ca32..7db47bd3d 100644
--- a/src/providers/openai.rs
+++ b/src/providers/openai.rs
@@ -135,6 +135,14 @@ struct UsageInfo {
     prompt_tokens: Option<u64>,
     #[serde(default)]
     completion_tokens: Option<u64>,
+    #[serde(default)]
+    prompt_tokens_details: Option<PromptTokensDetails>,
+}
+
+#[derive(Debug, Deserialize)]
+struct PromptTokensDetails {
+    #[serde(default)]
+    cached_tokens: Option<u64>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -385,6 +393,7 @@ impl Provider for OpenAiProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
         });
         let message = native_response
             .choices
@@ -448,6 +457,7 @@ impl Provider for OpenAiProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
         });
         let message = native_response
             .choices
diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs
index 457aa42df..96a7c1e41 100644
--- a/src/providers/openai_codex.rs
+++ b/src/providers/openai_codex.rs
@@ -640,6 +640,7 @@ impl Provider for OpenAiCodexProvider {
         ProviderCapabilities {
             native_tool_calling: false,
             vision: true,
+            prompt_caching: false,
         }
     }
 
diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs
index 3443b48db..c1bbdca0b 100644
--- a/src/providers/openrouter.rs
+++ b/src/providers/openrouter.rs
@@ -306,6 +306,7 @@ impl Provider for OpenRouterProvider {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: false,
         }
     }
 
@@ -463,6 +464,7 @@ impl Provider for OpenRouterProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let message = native_response
             .choices
@@ -554,6 +556,7 @@ impl Provider for OpenRouterProvider {
         let usage = native_response.usage.map(|u| TokenUsage {
             input_tokens: u.prompt_tokens,
             output_tokens: u.completion_tokens,
+            cached_input_tokens: None,
         });
         let message = native_response
             .choices
diff --git a/src/providers/traits.rs b/src/providers/traits.rs
index 1f30b602e..765eff455 100644
--- a/src/providers/traits.rs
+++ b/src/providers/traits.rs
@@ -54,6 +54,9 @@ pub struct ToolCall {
 pub struct TokenUsage {
     pub input_tokens: Option<u64>,
     pub output_tokens: Option<u64>,
+    /// Tokens served from the provider's prompt cache (Anthropic `cache_read_input_tokens`,
+    /// OpenAI `prompt_tokens_details.cached_tokens`).
+    pub cached_input_tokens: Option<u64>,
 }
 
 /// An LLM response that may contain text, tool calls, or both.
@@ -233,6 +236,9 @@ pub struct ProviderCapabilities {
     pub native_tool_calling: bool,
     /// Whether the provider supports vision / image inputs.
     pub vision: bool,
+    /// Whether the provider supports prompt caching (Anthropic cache_control,
+    /// OpenAI automatic prompt caching).
+    pub prompt_caching: bool,
 }
 
 /// Provider-specific tool payload formats.
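
[Reviewer sketch] The `cached_input_tokens` field added above surfaces provider-side prompt-cache reads (Anthropic's `cache_read_input_tokens`, OpenAI's `prompt_tokens_details.cached_tokens`) to downstream accounting. A minimal, standalone sketch of the arithmetic a caller could do with it; the struct is mirrored locally for illustration and the helper name `cached_input_fraction` is hypothetical, not part of this patch:

#[allow(dead_code)]
struct TokenUsage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
    cached_input_tokens: Option<u64>,
}

/// Fraction of input tokens served from the provider's prompt cache.
/// Returns `None` when the provider reported no input-token count.
#[allow(clippy::cast_precision_loss)]
fn cached_input_fraction(usage: &TokenUsage) -> Option<f64> {
    let input = usage.input_tokens?;
    if input == 0 {
        return None;
    }
    // Providers that do not report cache reads leave the field `None`,
    // which we treat as zero cached tokens.
    let cached = usage.cached_input_tokens.unwrap_or(0).min(input);
    Some(cached as f64 / input as f64)
}

fn main() {
    // 900 of 1,200 prompt tokens were served from the provider's cache.
    let usage = TokenUsage {
        input_tokens: Some(1_200),
        output_tokens: Some(80),
        cached_input_tokens: Some(900),
    };
    assert!((cached_input_fraction(&usage).unwrap() - 0.75).abs() < 1e-9);
}

The same ratio is what the `prompt_caching` capability flag gates on the reporting side: providers that declare it `false` always produce a fraction of zero.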
@@ -498,6 +504,7 @@ mod tests {
         ProviderCapabilities {
             native_tool_calling: true,
             vision: true,
+            prompt_caching: false,
         }
     }
 
@@ -568,6 +575,7 @@ mod tests {
         usage: Some(TokenUsage {
             input_tokens: Some(100),
             output_tokens: Some(50),
+            cached_input_tokens: None,
        }),
        reasoning_content: None,
    };
@@ -613,14 +621,17 @@ mod tests {
    let caps1 = ProviderCapabilities {
        native_tool_calling: true,
        vision: false,
+        prompt_caching: false,
    };
    let caps2 = ProviderCapabilities {
        native_tool_calling: true,
        vision: false,
+        prompt_caching: false,
    };
    let caps3 = ProviderCapabilities {
        native_tool_calling: false,
        vision: false,
+        prompt_caching: false,
    };
 
    assert_eq!(caps1, caps2);
diff --git a/tests/support/mock_provider.rs b/tests/support/mock_provider.rs
index 40e6ea6b1..e587a9fb9 100644
--- a/tests/support/mock_provider.rs
+++ b/tests/support/mock_provider.rs
@@ -166,6 +166,7 @@ impl Provider for TraceLlmProvider {
             usage: Some(TokenUsage {
                 input_tokens: Some(input_tokens),
                 output_tokens: Some(output_tokens),
+                cached_input_tokens: None,
             }),
             reasoning_content: None,
         }),
@@ -188,6 +189,7 @@ impl Provider for TraceLlmProvider {
             usage: Some(TokenUsage {
                 input_tokens: Some(input_tokens),
                 output_tokens: Some(output_tokens),
+                cached_input_tokens: None,
             }),
             reasoning_content: None,
         })
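
[Reviewer sketch] End-to-end, the two-tier cache in src/memory/response_cache.rs is meant to be driven the way the agent loop does: hash a deterministic prompt into a key, consult the cache before the provider call, and write back on a miss. A standalone sketch of that flow using the APIs added in this patch (`with_hot_cache`, `cache_key`, `get`, `put`); the `zeroclaw::` import path, `anyhow::Result`, workspace path, and model id are illustrative assumptions, not confirmed by the patch:

use std::path::Path;
use std::sync::Arc;

use zeroclaw::memory::response_cache::ResponseCache; // assumed public path

fn demo() -> anyhow::Result<()> {
    // 60-minute TTL, 5_000 warm (SQLite) entries, 256 hot (in-memory) entries.
    let cache = Arc::new(ResponseCache::with_hot_cache(
        Path::new("/tmp/zeroclaw-workspace"),
        60,
        5_000,
        256,
    )?);

    // The agent only builds a key for deterministic (temperature 0.0),
    // text-only prompts: model + system prompt + last user message.
    let key = ResponseCache::cache_key(
        "claude-sonnet-4", // assumed model id
        Some("You are a helpful assistant."),
        "What is the capital of France?",
    );

    match cache.get(&key)? {
        Some(hit) => println!("cache hit, LLM call avoided: {hit}"),
        None => {
            // On a miss the agent calls the provider, then stores the reply
            // in both tiers; `3` stands in for the response's output tokens.
            cache.put(&key, "claude-sonnet-4", "Paris.", 3)?;
        }
    }
    Ok(())
}

fn main() {
    demo().expect("cache demo");
}

A second `get` with the same key is then served from the in-memory tier without touching SQLite, which is exactly the round-trip the hot cache exists to avoid; after a process restart the lookup falls back to the warm tier and re-promotes the entry.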