diff --git a/Cargo.toml b/Cargo.toml index 1fde74817..b4c91ed03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ tokio-stream = { version = "0.1.18", default-features = false, features = ["fs", reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream", "socks"] } # Matrix client + E2EE decryption -matrix-sdk = { version = "0.16", optional = true, default-features = false, features = ["e2e-encryption", "rustls-tls", "markdown"] } +matrix-sdk = { version = "0.16", optional = true, default-features = false, features = ["e2e-encryption", "rustls-tls", "markdown", "sqlite"] } # Serialization serde = { version = "1.0", default-features = false, features = ["derive"] } diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index e0ad8ad67..e6d4c3836 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -13,6 +13,7 @@ use matrix_sdk::{ }; use reqwest::Client; use serde::Deserialize; +use std::path::PathBuf; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, OnceCell, RwLock}; @@ -26,6 +27,7 @@ pub struct MatrixChannel { allowed_users: Vec, session_owner_hint: Option, session_device_id_hint: Option, + zeroclaw_dir: Option, resolved_room_id_cache: Arc>>, sdk_client: Arc>, http_client: Client, @@ -120,6 +122,26 @@ impl MatrixChannel { allowed_users: Vec, owner_hint: Option, device_id_hint: Option, + ) -> Self { + Self::new_with_session_hint_and_zeroclaw_dir( + homeserver, + access_token, + room_id, + allowed_users, + owner_hint, + device_id_hint, + None, + ) + } + + pub fn new_with_session_hint_and_zeroclaw_dir( + homeserver: String, + access_token: String, + room_id: String, + allowed_users: Vec, + owner_hint: Option, + device_id_hint: Option, + zeroclaw_dir: Option, ) -> Self { let homeserver = homeserver.trim_end_matches('/').to_string(); let access_token = access_token.trim().to_string(); @@ -137,6 +159,7 @@ impl MatrixChannel { allowed_users, session_owner_hint: Self::normalize_optional_field(owner_hint), session_device_id_hint: Self::normalize_optional_field(device_id_hint), + zeroclaw_dir, resolved_room_id_cache: Arc::new(RwLock::new(None)), sdk_client: Arc::new(OnceCell::new()), http_client: Client::new(), @@ -168,6 +191,12 @@ impl MatrixChannel { format!("Bearer {}", self.access_token) } + fn matrix_store_dir(&self) -> Option { + self.zeroclaw_dir + .as_ref() + .map(|dir| dir.join("state").join("matrix")) + } + fn is_user_allowed(&self, sender: &str) -> bool { Self::is_sender_allowed(&self.allowed_users, sender) } @@ -314,10 +343,19 @@ impl MatrixChannel { } }; - let client = MatrixSdkClient::builder() - .homeserver_url(&self.homeserver) - .build() - .await?; + let mut client_builder = MatrixSdkClient::builder().homeserver_url(&self.homeserver); + + if let Some(store_dir) = self.matrix_store_dir() { + tokio::fs::create_dir_all(&store_dir).await.map_err(|error| { + anyhow::anyhow!( + "Matrix failed to initialize persistent store directory at '{}': {error}", + store_dir.display() + ) + })?; + client_builder = client_builder.sqlite_store(&store_dir, None); + } + + let client = client_builder.build().await?; let user_id: OwnedUserId = resolved_user_id.parse()?; let session = MatrixSession { @@ -744,6 +782,38 @@ mod tests { assert!(ch.session_device_id_hint.is_none()); } + #[test] + fn matrix_store_dir_is_derived_from_zeroclaw_dir() { + let ch = MatrixChannel::new_with_session_hint_and_zeroclaw_dir( + "https://matrix.org".to_string(), + "tok".to_string(), + "!r:m".to_string(), + vec![], + None, + None, + Some(PathBuf::from("/tmp/zeroclaw")), + ); + + assert_eq!( + ch.matrix_store_dir(), + Some(PathBuf::from("/tmp/zeroclaw/state/matrix")) + ); + } + + #[test] + fn matrix_store_dir_absent_without_zeroclaw_dir() { + let ch = MatrixChannel::new_with_session_hint( + "https://matrix.org".to_string(), + "tok".to_string(), + "!r:m".to_string(), + vec![], + None, + None, + ); + + assert!(ch.matrix_store_dir().is_none()); + } + #[test] fn encode_path_segment_encodes_room_refs() { assert_eq!( diff --git a/src/channels/mod.rs b/src/channels/mod.rs index ebf61fe02..2f89522ce 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -2690,13 +2690,14 @@ fn collect_configured_channels( if let Some(ref mx) = config.channels_config.matrix { channels.push(ConfiguredChannel { display_name: "Matrix", - channel: Arc::new(MatrixChannel::new_with_session_hint( + channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir( mx.homeserver.clone(), mx.access_token.clone(), mx.room_id.clone(), mx.allowed_users.clone(), mx.user_id.clone(), mx.device_id.clone(), + config.config_path.parent().map(|path| path.to_path_buf()), )), }); } @@ -3965,6 +3966,7 @@ BTC is currently around $65,000 based on latest tool output."# workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, interrupt_on_new_message: false, + non_cli_excluded_tools: Arc::new(Vec::new()), multimodal: crate::config::MultimodalConfig::default(), hooks: None, non_cli_excluded_tools: Arc::new(Vec::new()),