diff --git a/.gitignore b/.gitignore index 2e419747b..4c7b4e4fc 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,7 @@ venv/ *.key *.pem credentials.json -config.toml +/config.toml .worktrees/ # Nix diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 0f00304c8..b31719921 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -10,7 +10,7 @@ use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; -use anyhow::Result; +use anyhow::{Context as _, Result}; use regex::{Regex, RegexSet}; use rustyline::completion::{Completer, Pair}; use rustyline::error::ReadlineError; @@ -19,12 +19,11 @@ use rustyline::hint::Hinter; use rustyline::validate::Validator; use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper}; use std::borrow::Cow; -use std::collections::{BTreeSet, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Write; use std::io::Write as _; -use std::sync::{Arc, LazyLock}; +use std::sync::{Arc, LazyLock, Mutex}; use std::time::{Duration, Instant}; -use tokio::sync::OnceCell; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -134,15 +133,28 @@ impl Highlighter for SlashCommandCompleter { impl Validator for SlashCommandCompleter {} impl Helper for SlashCommandCompleter {} -static CHANNEL_SESSION_MANAGER: OnceCell>> = OnceCell::const_new(); +static CHANNEL_SESSION_MANAGER: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); async fn channel_session_manager(config: &Config) -> Result>> { - let mgr = CHANNEL_SESSION_MANAGER - .get_or_try_init(|| async { - create_session_manager(&config.agent.session, &config.workspace_dir) - }) - .await?; - Ok(mgr.clone()) + let key = format!("{:?}:{:?}", config.workspace_dir, config.agent.session); + + { + let map = CHANNEL_SESSION_MANAGER.lock().unwrap(); + if let Some(mgr) = map.get(&key) { + return Ok(Some(mgr.clone())); + } + } + + let mgr_opt = create_session_manager(&config.agent.session, &config.workspace_dir)?; + + if let Some(mgr) = mgr_opt { + let mut map = CHANNEL_SESSION_MANAGER.lock().unwrap(); + map.insert(key, mgr.clone()); + Ok(Some(mgr)) + } else { + Ok(None) + } } static SENSITIVE_KEY_PATTERNS: LazyLock = LazyLock::new(|| { RegexSet::new([ @@ -2221,7 +2233,10 @@ pub async fn process_message( .into_iter() .filter(|m| m.role != "system") .collect(); - let _ = session.update_history(persisted).await; + session + .update_history(persisted) + .await + .context("Failed to update session history")?; Ok(output) } else { let mut history = vec![ diff --git a/src/agent/session.rs b/src/agent/session.rs index 0c9704469..3a5857e6a 100644 --- a/src/agent/session.rs +++ b/src/agent/session.rs @@ -1,6 +1,6 @@ use crate::providers::ChatMessage; use crate::{config::AgentSessionBackend, config::AgentSessionConfig, config::AgentSessionStrategy}; -use anyhow::Result; +use anyhow::{Context, Result}; use async_trait::async_trait; use parking_lot::Mutex; use rusqlite::{params, Connection}; @@ -237,6 +237,23 @@ impl SqliteSessionManager { } }); } + + #[cfg(test)] + pub async fn force_expire_session(&self, session_id: &str, age: Duration) -> Result<()> { + let conn = self.conn.clone(); + let session_id = session_id.to_string(); + let age_secs = age.as_secs() as i64; + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let new_time = unix_seconds_now() - age_secs; + conn.execute( + "UPDATE agent_sessions SET updated_at = ?2 WHERE session_id = ?1", + params![session_id, new_time], + )?; + Ok(()) + }).await? + } } #[async_trait] @@ -247,63 +264,85 @@ impl SessionManager for SqliteSessionManager { async fn get_history(&self, session_id: &str) -> Result> { let now = unix_seconds_now(); - let conn = self.conn.lock(); - let mut stmt = conn.prepare( - "SELECT history_json FROM agent_sessions WHERE session_id = ?1", - )?; - let mut rows = stmt.query(params![session_id])?; - if let Some(row) = rows.next()? { - let json: String = row.get(0)?; + let conn = self.conn.clone(); + let session_id = session_id.to_string(); + let max_messages = self.max_messages; + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let mut stmt = conn.prepare( + "SELECT history_json FROM agent_sessions WHERE session_id = ?1", + )?; + let mut rows = stmt.query(params![session_id])?; + if let Some(row) = rows.next()? { + let json: String = row.get(0)?; + conn.execute( + "UPDATE agent_sessions SET updated_at = ?2 WHERE session_id = ?1", + params![session_id, now], + )?; + let mut history: Vec = serde_json::from_str(&json) + .with_context(|| format!("Failed to parse session history for session_id={session_id}"))?; + trim_non_system(&mut history, max_messages); + return Ok(history); + } + conn.execute( - "UPDATE agent_sessions SET updated_at = ?2 WHERE session_id = ?1", + "INSERT INTO agent_sessions(session_id, history_json, updated_at) VALUES(?1, '[]', ?2)", params![session_id, now], )?; - let mut history: Vec = serde_json::from_str(&json).unwrap_or_default(); - trim_non_system(&mut history, self.max_messages); - return Ok(history); - } - - conn.execute( - "INSERT INTO agent_sessions(session_id, history_json, updated_at) VALUES(?1, '[]', ?2)", - params![session_id, now], - )?; - Ok(Vec::new()) + Ok(Vec::new()) + }).await? } async fn set_history(&self, session_id: &str, mut history: Vec) -> Result<()> { trim_non_system(&mut history, self.max_messages); let json = serde_json::to_string(&history)?; let now = unix_seconds_now(); - let conn = self.conn.lock(); - conn.execute( - "INSERT INTO agent_sessions(session_id, history_json, updated_at) - VALUES(?1, ?2, ?3) - ON CONFLICT(session_id) DO UPDATE SET history_json=excluded.history_json, updated_at=excluded.updated_at", - params![session_id, json, now], - )?; - Ok(()) + let conn = self.conn.clone(); + let session_id = session_id.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + conn.execute( + "INSERT INTO agent_sessions(session_id, history_json, updated_at) + VALUES(?1, ?2, ?3) + ON CONFLICT(session_id) DO UPDATE SET history_json=excluded.history_json, updated_at=excluded.updated_at", + params![session_id, json, now], + )?; + Ok(()) + }).await? } async fn delete(&self, session_id: &str) -> Result<()> { - let conn = self.conn.lock(); - conn.execute( - "DELETE FROM agent_sessions WHERE session_id = ?1", - params![session_id], - )?; - Ok(()) + let conn = self.conn.clone(); + let session_id = session_id.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + conn.execute( + "DELETE FROM agent_sessions WHERE session_id = ?1", + params![session_id], + )?; + Ok(()) + }).await? } async fn cleanup_expired(&self) -> Result { if self.ttl.is_zero() { return Ok(0); } - let cutoff = unix_seconds_now() - self.ttl.as_secs() as i64; - let conn = self.conn.lock(); - let removed = conn.execute( - "DELETE FROM agent_sessions WHERE updated_at < ?1", - params![cutoff], - )?; - Ok(removed) + let conn = self.conn.clone(); + let ttl_secs = self.ttl.as_secs() as i64; + + tokio::task::spawn_blocking(move || { + let cutoff = unix_seconds_now() - ttl_secs; + let conn = conn.lock(); + let removed = conn.execute( + "DELETE FROM agent_sessions WHERE updated_at < ?1", + params![cutoff], + )?; + Ok(removed) + }).await? } } @@ -427,12 +466,16 @@ mod tests { async fn sqlite_session_cleanup_expires() -> Result<()> { let dir = tempfile::tempdir()?; let db_path = dir.path().join("sessions.db"); + // TTL 1 second let mgr = SqliteSessionManager::new(db_path, Duration::from_secs(1), 50)?; let session = mgr.get_or_create("s1").await?; session .update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")]) .await?; - tokio::time::sleep(Duration::from_millis(2100)).await; + + // Force expire by setting age to 2 seconds + mgr.force_expire_session("s1", Duration::from_secs(2)).await?; + let removed = mgr.cleanup_expired().await?; assert!(removed >= 1); Ok(()) diff --git a/src/config/schema.rs b/src/config/schema.rs index 7ad97ddc5..2ad136c17 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -728,14 +728,27 @@ pub enum AgentSessionStrategy { Main, } +/// Session persistence configuration (`[agent.session]` section). #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentSessionConfig { + /// Session backend to use. Options: "memory", "sqlite", "none". + /// Default: "none" (no persistence). + /// Set to "none" to disable session persistence entirely. #[serde(default = "default_agent_session_backend")] pub backend: AgentSessionBackend, + + /// Strategy for resolving session IDs. Options: "per-sender", "per-channel", "main". + /// Default: "per-sender" (each user gets a unique session per channel). #[serde(default = "default_agent_session_strategy")] pub strategy: AgentSessionStrategy, + + /// Time-to-live for sessions in seconds. + /// Default: 3600 (1 hour). #[serde(default = "default_agent_session_ttl_seconds")] pub ttl_seconds: u64, + + /// Maximum number of messages to retain per session. + /// Default: 50. #[serde(default = "default_agent_session_max_messages")] pub max_messages: usize, } diff --git a/src/tools/mcp_transport.rs b/src/tools/mcp_transport.rs index e472689bf..375ee543a 100644 --- a/src/tools/mcp_transport.rs +++ b/src/tools/mcp_transport.rs @@ -272,7 +272,7 @@ impl McpTransportConn for SseTransport { let resp_text = resp.text().await.context("failed to read SSE response")?; let json_str = extract_json_from_sse_text(&resp_text); let mcp_resp: JsonRpcResponse = serde_json::from_str(json_str.as_ref()) - .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?; + .with_context(|| format!("invalid JSON-RPC response (len={})", resp_text.len()))?; Ok(mcp_resp) }