From 327e2b4c472c2b3e2f40405dbcf0b124bc2be47e Mon Sep 17 00:00:00 2001 From: Argenis Date: Sun, 15 Mar 2026 23:34:26 -0400 Subject: [PATCH 01/11] style: cargo fmt Box::pin calls in cron scheduler (#3667) Co-authored-by: Claude Opus 4.6 --- src/cron/scheduler.rs | 37 ++++++++++++++++++++++++------------- src/tools/cron_run.rs | 3 ++- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 2335dad6d..290cfd482 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -101,18 +101,21 @@ async fn process_due_jobs( crate::health::mark_component_ok(component); let max_concurrent = config.scheduler.max_concurrent.max(1); - let mut in_flight = - stream::iter( - jobs.into_iter().map(|job| { - let config = config.clone(); - let security = Arc::clone(security); - let component = component.to_owned(); - async move { - Box::pin(execute_and_persist_job(&config, security.as_ref(), &job, &component)).await - } - }), - ) - .buffer_unordered(max_concurrent); + let mut in_flight = stream::iter(jobs.into_iter().map(|job| { + let config = config.clone(); + let security = Arc::clone(security); + let component = component.to_owned(); + async move { + Box::pin(execute_and_persist_job( + &config, + security.as_ref(), + &job, + &component, + )) + .await + } + })) + .buffer_unordered(max_concurrent); while let Some((job_id, success, output)) = in_flight.next().await { if !success { @@ -133,7 +136,15 @@ async fn execute_and_persist_job( let started_at = Utc::now(); let (success, output) = Box::pin(execute_job_with_retry(config, security, job)).await; let finished_at = Utc::now(); - let success = Box::pin(persist_job_result(config, job, success, &output, started_at, finished_at)).await; + let success = Box::pin(persist_job_result( + config, + job, + success, + &output, + started_at, + finished_at, + )) + .await; (job.id.clone(), success, output) } diff --git a/src/tools/cron_run.rs b/src/tools/cron_run.rs index 
bb70cbb28..deed1d2c2 100644 --- a/src/tools/cron_run.rs +++ b/src/tools/cron_run.rs @@ -116,7 +116,8 @@ impl Tool for CronRunTool { } let started_at = Utc::now(); - let (success, output) = Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await; + let (success, output) = + Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await; let finished_at = Utc::now(); let duration_ms = (finished_at - started_at).num_milliseconds(); let status = if success { "ok" } else { "error" }; From 82fe2e53fd028c4546f0bc4d4d1a4bfc9f480dea Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 00:34:34 -0400 Subject: [PATCH 02/11] feat(tunnel): add OpenVPN tunnel provider (#3648) * feat(tunnel): add OpenVPN tunnel provider Add OpenVPN as a new tunnel provider alongside cloudflare, tailscale, ngrok, and custom. Includes config schema, validation, factory wiring, and comprehensive unit tests. Co-authored-by: rareba Co-Authored-By: Claude Opus 4.6 * fix: add missing approval_manager field to ChannelRuntimeContext constructors Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: rareba Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 10 +- src/config/schema.rs | 53 ++++++++- src/tunnel/mod.rs | 61 +++++++++- src/tunnel/openvpn.rs | 254 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 369 insertions(+), 9 deletions(-) create mode 100644 src/tunnel/openvpn.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index e9dc86f12..98938aa1d 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -13,11 +13,11 @@ pub use schema::{ GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig, - NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod, - PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, 
QdrantConfig, - QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, - SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, - SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, + NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, + OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, + QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, + RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, + SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, diff --git a/src/config/schema.rs b/src/config/schema.rs index c38987d9d..490891b18 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -3069,10 +3069,10 @@ impl Default for CronConfig { /// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section). /// -/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`. +/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, `"custom"`. #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct TunnelConfig { - /// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`. + /// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, or `"custom"`. Default: `"none"`. pub provider: String, /// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`). 
@@ -3087,6 +3087,10 @@ pub struct TunnelConfig { #[serde(default)] pub ngrok: Option, + /// OpenVPN tunnel configuration (used when `provider = "openvpn"`). + #[serde(default)] + pub openvpn: Option, + /// Custom tunnel command configuration (used when `provider = "custom"`). #[serde(default)] pub custom: Option, @@ -3099,6 +3103,7 @@ impl Default for TunnelConfig { cloudflare: None, tailscale: None, ngrok: None, + openvpn: None, custom: None, } } @@ -3127,6 +3132,36 @@ pub struct NgrokTunnelConfig { pub domain: Option, } +/// OpenVPN tunnel configuration (`[tunnel.openvpn]`). +/// +/// Required when `tunnel.provider = "openvpn"`. Omitting this section entirely +/// preserves previous behavior. Setting `tunnel.provider = "none"` (or removing +/// the `[tunnel.openvpn]` block) cleanly reverts to no-tunnel mode. +/// +/// Defaults: `connect_timeout_secs = 30`. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct OpenVpnTunnelConfig { + /// Path to `.ovpn` configuration file (must not be empty). + pub config_file: String, + /// Optional path to auth credentials file (`--auth-user-pass`). + #[serde(default)] + pub auth_file: Option, + /// Advertised address once VPN is connected (e.g., `"10.8.0.2:42617"`). + /// When omitted the tunnel falls back to `http://{local_host}:{local_port}`. + #[serde(default)] + pub advertise_address: Option, + /// Connection timeout in seconds (default: 30, must be > 0). + #[serde(default = "default_openvpn_timeout")] + pub connect_timeout_secs: u64, + /// Extra openvpn CLI arguments forwarded verbatim. + #[serde(default)] + pub extra_args: Vec, +} + +fn default_openvpn_timeout() -> u64 { + 30 +} + #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CustomTunnelConfig { /// Command template to start the tunnel. Use {port} and {host} placeholders. 
@@ -5172,6 +5207,20 @@ impl Config { /// Called after TOML deserialization and env-override application to catch /// obviously invalid values early instead of failing at arbitrary runtime points. pub fn validate(&self) -> Result<()> { + // Tunnel — OpenVPN + if self.tunnel.provider.trim() == "openvpn" { + let openvpn = self.tunnel.openvpn.as_ref().ok_or_else(|| { + anyhow::anyhow!("tunnel.provider='openvpn' requires [tunnel.openvpn]") + })?; + + if openvpn.config_file.trim().is_empty() { + anyhow::bail!("tunnel.openvpn.config_file must not be empty"); + } + if openvpn.connect_timeout_secs == 0 { + anyhow::bail!("tunnel.openvpn.connect_timeout_secs must be greater than 0"); + } + } + // Gateway if self.gateway.host.trim().is_empty() { anyhow::bail!("gateway.host must not be empty"); diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 6a852d8cc..52424f8a5 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,6 +2,7 @@ mod cloudflare; mod custom; mod ngrok; mod none; +mod openvpn; mod tailscale; pub use cloudflare::CloudflareTunnel; @@ -9,6 +10,7 @@ pub use custom::CustomTunnel; pub use ngrok::NgrokTunnel; #[allow(unused_imports)] pub use none::NoneTunnel; +pub use openvpn::OpenVpnTunnel; pub use tailscale::TailscaleTunnel; use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig}; @@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { )))) } + "openvpn" => { + let ov = config + .openvpn + .as_ref() + .ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?; + Ok(Some(Box::new(OpenVpnTunnel::new( + ov.config_file.clone(), + ov.auth_file.clone(), + ov.advertise_address.clone(), + ov.connect_timeout_secs, + ov.extra_args.clone(), + )))) + } + "custom" => { let cu = config .custom @@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { )))) } - other => bail!("Unknown tunnel provider: \"{other}\". 
Valid: none, cloudflare, tailscale, ngrok, custom"), + other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"), } } @@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { mod tests { use super::*; use crate::config::schema::{ - CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig, + CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig, + TunnelConfig, }; use tokio::process::Command; @@ -315,6 +332,46 @@ mod tests { assert!(t.public_url().is_none()); } + #[test] + fn factory_openvpn_missing_config_errors() { + let cfg = TunnelConfig { + provider: "openvpn".into(), + ..TunnelConfig::default() + }; + assert_tunnel_err(&cfg, "[tunnel.openvpn]"); + } + + #[test] + fn factory_openvpn_with_config_ok() { + let cfg = TunnelConfig { + provider: "openvpn".into(), + openvpn: Some(OpenVpnTunnelConfig { + config_file: "client.ovpn".into(), + auth_file: None, + advertise_address: None, + connect_timeout_secs: 30, + extra_args: vec![], + }), + ..TunnelConfig::default() + }; + let t = create_tunnel(&cfg).unwrap(); + assert!(t.is_some()); + assert_eq!(t.unwrap().name(), "openvpn"); + } + + #[test] + fn openvpn_tunnel_name() { + let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert_eq!(t.name(), "openvpn"); + assert!(t.public_url().is_none()); + } + + #[tokio::test] + async fn openvpn_health_false_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(!tunnel.health_check().await); + } + #[tokio::test] async fn kill_shared_no_process_is_ok() { let proc = new_shared_process(); diff --git a/src/tunnel/openvpn.rs b/src/tunnel/openvpn.rs new file mode 100644 index 000000000..dd7f72ad7 --- /dev/null +++ b/src/tunnel/openvpn.rs @@ -0,0 +1,254 @@ +use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess}; +use anyhow::{bail, Result}; +use 
tokio::io::AsyncBufReadExt; +use tokio::process::Command; + +/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection. +/// +/// Requires the `openvpn` binary installed and accessible. On most systems, +/// OpenVPN requires root/administrator privileges to create tun/tap devices. +/// +/// The tunnel exposes the gateway via the VPN network using a configured +/// `advertise_address` (e.g., `"10.8.0.2:42617"`). +pub struct OpenVpnTunnel { + config_file: String, + auth_file: Option, + advertise_address: Option, + connect_timeout_secs: u64, + extra_args: Vec, + proc: SharedProcess, +} + +impl OpenVpnTunnel { + /// Create a new OpenVPN tunnel instance. + /// + /// * `config_file` — path to the `.ovpn` configuration file. + /// * `auth_file` — optional path to a credentials file for `--auth-user-pass`. + /// * `advertise_address` — optional public address to advertise once connected. + /// * `connect_timeout_secs` — seconds to wait for the initialization sequence. + /// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary. + pub fn new( + config_file: String, + auth_file: Option, + advertise_address: Option, + connect_timeout_secs: u64, + extra_args: Vec, + ) -> Self { + Self { + config_file, + auth_file, + advertise_address, + connect_timeout_secs, + extra_args, + proc: new_shared_process(), + } + } + + /// Build the openvpn command arguments. + fn build_args(&self) -> Vec { + let mut args = vec!["--config".to_string(), self.config_file.clone()]; + + if let Some(ref auth) = self.auth_file { + args.push("--auth-user-pass".to_string()); + args.push(auth.clone()); + } + + args.extend(self.extra_args.iter().cloned()); + args + } +} + +#[async_trait::async_trait] +impl Tunnel for OpenVpnTunnel { + fn name(&self) -> &str { + "openvpn" + } + + /// Spawn the `openvpn` process and wait for the "Initialization Sequence + /// Completed" marker on stderr. Returns the public URL on success. 
+ async fn start(&self, local_host: &str, local_port: u16) -> Result { + // Validate config file exists before spawning + if !std::path::Path::new(&self.config_file).exists() { + bail!("OpenVPN config file not found: {}", self.config_file); + } + + let args = self.build_args(); + + let mut child = Command::new("openvpn") + .args(&args) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn()?; + + // Wait for "Initialization Sequence Completed" in stderr + let stderr = child + .stderr + .take() + .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?; + + let mut reader = tokio::io::BufReader::new(stderr).lines(); + let deadline = tokio::time::Instant::now() + + tokio::time::Duration::from_secs(self.connect_timeout_secs); + + let mut connected = false; + while tokio::time::Instant::now() < deadline { + let line = + tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await; + + match line { + Ok(Ok(Some(l))) => { + tracing::debug!("openvpn: {l}"); + if l.contains("Initialization Sequence Completed") { + connected = true; + break; + } + } + Ok(Ok(None)) => { + bail!("OpenVPN process exited before connection was established"); + } + Ok(Err(e)) => { + bail!("Error reading openvpn output: {e}"); + } + Err(_) => { + // Timeout on individual line read, continue waiting + } + } + } + + if !connected { + child.kill().await.ok(); + bail!( + "OpenVPN connection timed out after {}s waiting for initialization", + self.connect_timeout_secs + ); + } + + let public_url = self + .advertise_address + .clone() + .unwrap_or_else(|| format!("http://{local_host}:{local_port}")); + + // Drain stderr in background to prevent OS pipe buffer from filling and + // blocking the openvpn process. 
+ tokio::spawn(async move { + while let Ok(Some(line)) = reader.next_line().await { + tracing::trace!("openvpn: {line}"); + } + }); + + let mut guard = self.proc.lock().await; + *guard = Some(TunnelProcess { + child, + public_url: public_url.clone(), + }); + + Ok(public_url) + } + + /// Kill the openvpn child process and release its resources. + async fn stop(&self) -> Result<()> { + kill_shared(&self.proc).await + } + + /// Return `true` if the openvpn child process is still running. + async fn health_check(&self) -> bool { + let guard = self.proc.lock().await; + guard.as_ref().is_some_and(|tp| tp.child.id().is_some()) + } + + /// Return the public URL if the tunnel has been started. + fn public_url(&self) -> Option { + self.proc + .try_lock() + .ok() + .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructor_stores_fields() { + let tunnel = OpenVpnTunnel::new( + "/etc/openvpn/client.ovpn".into(), + Some("/etc/openvpn/auth.txt".into()), + Some("10.8.0.2:42617".into()), + 45, + vec!["--verb".into(), "3".into()], + ); + assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn"); + assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt")); + assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617")); + assert_eq!(tunnel.connect_timeout_secs, 45); + assert_eq!(tunnel.extra_args, vec!["--verb", "3"]); + } + + #[test] + fn build_args_basic() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + let args = tunnel.build_args(); + assert_eq!(args, vec!["--config", "client.ovpn"]); + } + + #[test] + fn build_args_with_auth_and_extras() { + let tunnel = OpenVpnTunnel::new( + "client.ovpn".into(), + Some("auth.txt".into()), + None, + 30, + vec!["--verb".into(), "5".into()], + ); + let args = tunnel.build_args(); + assert_eq!( + args, + vec![ + "--config", + "client.ovpn", + "--auth-user-pass", + "auth.txt", + "--verb", + "5" + ] + ); + } + 
+ #[test] + fn public_url_is_none_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(tunnel.public_url().is_none()); + } + + #[tokio::test] + async fn health_check_is_false_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn stop_without_started_process_is_ok() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + let result = tunnel.stop().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn start_with_missing_config_file_errors() { + let tunnel = OpenVpnTunnel::new( + "/nonexistent/path/to/client.ovpn".into(), + None, + None, + 30, + vec![], + ); + let result = tunnel.start("127.0.0.1", 8080).await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("config file not found")); + } +} From 75701195d7be0af9a839d12afaf2501053d30334 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 00:34:52 -0400 Subject: [PATCH 03/11] feat(security): add Nevis IAM integration for SSO/MFA authentication (#3651) * feat(security): add Nevis IAM integration for SSO/MFA authentication Add NevisAuthProvider supporting OAuth2/OIDC token validation (local JWKS + remote introspection), FIDO2/passkey/OTP MFA verification, session management, and health checks. Add IamPolicy engine mapping Nevis roles to ZeroClaw tool and workspace permissions with deny-by-default enforcement and audit logging. Add NevisConfig and NevisRoleMappingConfig to config schema with client_secret wired through SecretStore encrypt/decrypt. All features disabled by default. Rebased on latest master to resolve merge conflicts in security/mod.rs (redact function) and config/schema.rs (test section). Original work by @rareba. Supersedes #3593. 
Co-Authored-By: rareba <5985289+rareba@users.noreply.github.com> Co-Authored-By: Claude Opus 4.6 * style: cargo fmt Box::pin calls Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: rareba <5985289+rareba@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/config/schema.rs | 365 +++++++++++++++++++++++ src/cron/scheduler.rs | 6 +- src/security/iam_policy.rs | 449 ++++++++++++++++++++++++++++ src/security/mod.rs | 22 +- src/security/nevis.rs | 587 +++++++++++++++++++++++++++++++++++++ 5 files changed, 1417 insertions(+), 12 deletions(-) create mode 100644 src/security/iam_policy.rs create mode 100644 src/security/nevis.rs diff --git a/src/config/schema.rs b/src/config/schema.rs index 490891b18..e39e1233b 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -3941,6 +3941,10 @@ pub struct SecurityConfig { /// Emergency-stop state machine configuration. #[serde(default)] pub estop: EstopConfig, + + /// Nevis IAM integration for SSO/MFA authentication and role-based access. + #[serde(default)] + pub nevis: NevisConfig, } /// OTP validation strategy. @@ -4052,6 +4056,163 @@ impl Default for EstopConfig { } } +/// Nevis IAM integration configuration. +/// +/// When `enabled` is true, ZeroClaw validates incoming requests against a Nevis +/// Security Suite instance and maps Nevis roles to tool/workspace permissions. +#[derive(Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct NevisConfig { + /// Enable Nevis IAM integration. Defaults to false for backward compatibility. + #[serde(default)] + pub enabled: bool, + + /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`). + #[serde(default)] + pub instance_url: String, + + /// Nevis realm to authenticate against. + #[serde(default = "default_nevis_realm")] + pub realm: String, + + /// OAuth2 client ID registered in Nevis. + #[serde(default)] + pub client_id: String, + + /// OAuth2 client secret. Encrypted via SecretStore when stored on disk. 
+ #[serde(default)] + pub client_secret: Option, + + /// Token validation strategy: `"local"` (JWKS) or `"remote"` (introspection). + #[serde(default = "default_nevis_token_validation")] + pub token_validation: String, + + /// JWKS endpoint URL for local token validation. + #[serde(default)] + pub jwks_url: Option, + + /// Nevis role to ZeroClaw permission mappings. + #[serde(default)] + pub role_mapping: Vec, + + /// Require MFA verification for all Nevis-authenticated requests. + #[serde(default)] + pub require_mfa: bool, + + /// Session timeout in seconds. + #[serde(default = "default_nevis_session_timeout_secs")] + pub session_timeout_secs: u64, +} + +impl std::fmt::Debug for NevisConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NevisConfig") + .field("enabled", &self.enabled) + .field("instance_url", &self.instance_url) + .field("realm", &self.realm) + .field("client_id", &self.client_id) + .field( + "client_secret", + &self.client_secret.as_ref().map(|_| "[REDACTED]"), + ) + .field("token_validation", &self.token_validation) + .field("jwks_url", &self.jwks_url) + .field("role_mapping", &self.role_mapping) + .field("require_mfa", &self.require_mfa) + .field("session_timeout_secs", &self.session_timeout_secs) + .finish() + } +} + +impl NevisConfig { + /// Validate that required fields are present when Nevis is enabled. + /// + /// Call at config load time to fail fast on invalid configuration rather + /// than deferring errors to the first authentication request. 
+ pub fn validate(&self) -> Result<(), String> { + if !self.enabled { + return Ok(()); + } + + if self.instance_url.trim().is_empty() { + return Err("nevis.instance_url is required when Nevis IAM is enabled".into()); + } + + if self.client_id.trim().is_empty() { + return Err("nevis.client_id is required when Nevis IAM is enabled".into()); + } + + if self.realm.trim().is_empty() { + return Err("nevis.realm is required when Nevis IAM is enabled".into()); + } + + match self.token_validation.as_str() { + "local" | "remote" => {} + other => { + return Err(format!( + "nevis.token_validation has invalid value '{other}': \ + expected 'local' or 'remote'" + )); + } + } + + if self.token_validation == "local" && self.jwks_url.is_none() { + return Err("nevis.jwks_url is required when token_validation is 'local'".into()); + } + + if self.session_timeout_secs == 0 { + return Err("nevis.session_timeout_secs must be greater than 0".into()); + } + + Ok(()) + } +} + +fn default_nevis_realm() -> String { + "master".into() +} + +fn default_nevis_token_validation() -> String { + "local".into() +} + +fn default_nevis_session_timeout_secs() -> u64 { + 3600 +} + +impl Default for NevisConfig { + fn default() -> Self { + Self { + enabled: false, + instance_url: String::new(), + realm: default_nevis_realm(), + client_id: String::new(), + client_secret: None, + token_validation: default_nevis_token_validation(), + jwks_url: None, + role_mapping: Vec::new(), + require_mfa: false, + session_timeout_secs: default_nevis_session_timeout_secs(), + } + } +} + +/// Maps a Nevis role to ZeroClaw tool permissions and workspace access. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct NevisRoleMappingConfig { + /// Nevis role name (case-insensitive). + pub nevis_role: String, + + /// Tool names this role can access. Use `"all"` for unrestricted tool access. 
+ #[serde(default)] + pub zeroclaw_permissions: Vec, + + /// Workspace names this role can access. Use `"all"` for unrestricted. + #[serde(default)] + pub workspace_access: Vec, +} + /// Sandbox configuration for OS-level isolation #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SandboxConfig { @@ -5072,6 +5233,13 @@ impl Config { decrypt_secret(&store, token, "config.gateway.paired_tokens[]")?; } + // Decrypt Nevis IAM secret + decrypt_optional_secret( + &store, + &mut config.security.nevis.client_secret, + "config.security.nevis.client_secret", + )?; + config.apply_env_overrides(); config.validate()?; tracing::info!( @@ -5390,6 +5558,11 @@ impl Config { validate_mcp_config(&self.mcp)?; } + // Nevis IAM — delegate to NevisConfig::validate() for field-level checks + if let Err(msg) = self.security.nevis.validate() { + anyhow::bail!("security.nevis: {msg}"); + } + Ok(()) } @@ -6013,6 +6186,13 @@ impl Config { encrypt_secret(&store, token, "config.gateway.paired_tokens[]")?; } + // Encrypt Nevis IAM secret + encrypt_optional_secret( + &store, + &mut config_to_save.security.nevis.client_secret, + "config.security.nevis.client_secret", + )?; + let toml_str = toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?; @@ -6100,6 +6280,7 @@ impl Config { } } +#[allow(clippy::unused_async)] // async needed on unix for tokio File I/O; no-op on other platforms async fn sync_directory(path: &Path) -> Result<()> { #[cfg(unix)] { @@ -6125,6 +6306,7 @@ mod tests { #[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::path::PathBuf; + #[cfg(unix)] use tempfile::TempDir; use tokio::sync::{Mutex, MutexGuard}; use tokio::test; @@ -9569,4 +9751,187 @@ require_otp_to_resume = true assert_eq!(config.swarms.len(), 1); assert!(config.swarms.contains_key("pipeline")); } + + #[tokio::test] + async fn nevis_client_secret_encrypt_decrypt_roundtrip() { + let dir = std::env::temp_dir().join(format!( + "zeroclaw_test_nevis_secret_{}", + 
uuid::Uuid::new_v4() + )); + fs::create_dir_all(&dir).await.unwrap(); + + let plaintext_secret = "nevis-test-client-secret-value"; + + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = dir.join("config.toml"); + config.security.nevis.client_secret = Some(plaintext_secret.into()); + + // Save (triggers encryption) + config.save().await.unwrap(); + + // Read raw TOML and verify plaintext secret is NOT present + let raw_toml = tokio::fs::read_to_string(&config.config_path) + .await + .unwrap(); + assert!( + !raw_toml.contains(plaintext_secret), + "Saved TOML must not contain the plaintext client_secret" + ); + + // Parse stored TOML and verify the value is encrypted + let stored: Config = toml::from_str(&raw_toml).unwrap(); + let stored_secret = stored.security.nevis.client_secret.as_ref().unwrap(); + assert!( + crate::security::SecretStore::is_encrypted(stored_secret), + "Stored client_secret must be marked as encrypted" + ); + + // Decrypt and verify it matches the original plaintext + let store = crate::security::SecretStore::new(&dir, true); + assert_eq!(store.decrypt(stored_secret).unwrap(), plaintext_secret); + + // Simulate a full load: deserialize then decrypt (mirrors load_or_init logic) + let mut loaded: Config = toml::from_str(&raw_toml).unwrap(); + loaded.config_path = dir.join("config.toml"); + let load_store = crate::security::SecretStore::new(&dir, loaded.secrets.encrypt); + decrypt_optional_secret( + &load_store, + &mut loaded.security.nevis.client_secret, + "config.security.nevis.client_secret", + ) + .unwrap(); + assert_eq!( + loaded.security.nevis.client_secret.as_deref().unwrap(), + plaintext_secret, + "Loaded client_secret must match the original plaintext after decryption" + ); + + let _ = fs::remove_dir_all(&dir).await; + } + + // ══════════════════════════════════════════════════════════ + // Nevis config validation tests + // ══════════════════════════════════════════════════════════ + + 
#[test] + async fn nevis_config_validate_disabled_accepts_empty_fields() { + let cfg = NevisConfig::default(); + assert!(!cfg.enabled); + assert!(cfg.validate().is_ok()); + } + + #[test] + async fn nevis_config_validate_rejects_empty_instance_url() { + let cfg = NevisConfig { + enabled: true, + instance_url: String::new(), + client_id: "test-client".into(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("instance_url")); + } + + #[test] + async fn nevis_config_validate_rejects_empty_client_id() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: String::new(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("client_id")); + } + + #[test] + async fn nevis_config_validate_rejects_empty_realm() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + realm: String::new(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("realm")); + } + + #[test] + async fn nevis_config_validate_rejects_local_without_jwks() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + token_validation: "local".into(), + jwks_url: None, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("jwks_url")); + } + + #[test] + async fn nevis_config_validate_rejects_zero_session_timeout() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + token_validation: "remote".into(), + session_timeout_secs: 0, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("session_timeout_secs")); + } + + #[test] + async fn nevis_config_validate_accepts_valid_enabled_config() { + let cfg = NevisConfig { 
+ enabled: true, + instance_url: "https://nevis.example.com".into(), + realm: "master".into(), + client_id: "test-client".into(), + token_validation: "remote".into(), + session_timeout_secs: 3600, + ..NevisConfig::default() + }; + assert!(cfg.validate().is_ok()); + } + + #[test] + async fn nevis_config_validate_rejects_invalid_token_validation() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + realm: "master".into(), + client_id: "test-client".into(), + token_validation: "invalid_mode".into(), + session_timeout_secs: 3600, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!( + err.contains("invalid value 'invalid_mode'"), + "Expected invalid token_validation error, got: {err}" + ); + } + + #[test] + async fn nevis_config_debug_redacts_client_secret() { + let cfg = NevisConfig { + client_secret: Some("super-secret".into()), + ..NevisConfig::default() + }; + let debug_output = format!("{:?}", cfg); + assert!( + !debug_output.contains("super-secret"), + "Debug output must not contain the raw client_secret" + ); + assert!( + debug_output.contains("[REDACTED]"), + "Debug output must show [REDACTED] for client_secret" + ); + } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 290cfd482..992609f14 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -784,7 +784,7 @@ mod tests { job.prompt = Some("Say hello".into()); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &security, &job).await; + let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await; assert!(!success); assert!(output.contains("agent job failed:")); } @@ -799,7 +799,7 @@ mod tests { job.prompt = Some("Say hello".into()); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &security, &job).await; + let (success, 
output) = Box::pin(run_agent_job(&config, &security, &job)).await;
         assert!(!success);
         assert!(output.contains("blocked by security policy"));
         assert!(output.contains("read-only"));
@@ -815,7 +815,7 @@ mod tests {
         job.prompt = Some("Say hello".into());
         let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
-        let (success, output) = run_agent_job(&config, &security, &job).await;
+        let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
         assert!(!success);
         assert!(output.contains("blocked by security policy"));
         assert!(output.contains("rate limit exceeded"));
diff --git a/src/security/iam_policy.rs b/src/security/iam_policy.rs
new file mode 100644
index 000000000..36a5fab00
--- /dev/null
+++ b/src/security/iam_policy.rs
@@ -0,0 +1,449 @@
+//! IAM-aware policy enforcement for Nevis role-to-permission mapping.
+//!
+//! Evaluates tool and workspace access based on Nevis roles using a
+//! deny-by-default policy model. All policy decisions are audit-logged.
+
+use super::nevis::NevisIdentity;
+use anyhow::{bail, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// Maps a single Nevis role to ZeroClaw permissions.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RoleMapping {
+    /// Nevis role name (case-insensitive matching).
+    pub nevis_role: String,
+    /// Tool names this role can access. Use `"all"` to grant all tools.
+    pub zeroclaw_permissions: Vec<String>,
+    /// Workspace names this role can access. Use `"all"` for unrestricted.
+    #[serde(default)]
+    pub workspace_access: Vec<String>,
+}
+
+/// Result of a policy evaluation.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum PolicyDecision {
+    /// Access is allowed.
+    Allow,
+    /// Access is denied, with reason.
+    Deny(String),
+}
+
+impl PolicyDecision {
+    pub fn is_allowed(&self) -> bool {
+        matches!(self, PolicyDecision::Allow)
+    }
+}
+
+/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions.
+///
+/// Deny-by-default: if no role mapping grants access, the request is denied.
+#[derive(Debug, Clone)]
+pub struct IamPolicy {
+    /// Compiled role mappings indexed by lowercase Nevis role name.
+    role_map: HashMap<String, CompiledRole>,
+}
+
+#[derive(Debug, Clone)]
+struct CompiledRole {
+    /// Whether this role has access to all tools.
+    all_tools: bool,
+    /// Specific tool names this role can access (lowercase).
+    allowed_tools: Vec<String>,
+    /// Whether this role has access to all workspaces.
+    all_workspaces: bool,
+    /// Specific workspace names this role can access (lowercase).
+    allowed_workspaces: Vec<String>,
+}
+
+impl IamPolicy {
+    /// Build a policy from role mappings (typically from config).
+    ///
+    /// Returns an error if duplicate normalized role names are detected,
+    /// since silent last-wins overwrites can accidentally broaden or revoke access.
+    pub fn from_mappings(mappings: &[RoleMapping]) -> Result<Self> {
+        let mut role_map = HashMap::new();
+
+        for mapping in mappings {
+            let key = mapping.nevis_role.trim().to_ascii_lowercase();
+            if key.is_empty() {
+                continue;
+            }
+
+            let all_tools = mapping
+                .zeroclaw_permissions
+                .iter()
+                .any(|p| p.eq_ignore_ascii_case("all"));
+            let allowed_tools: Vec<String> = mapping
+                .zeroclaw_permissions
+                .iter()
+                .filter(|p| !p.eq_ignore_ascii_case("all"))
+                .map(|p| p.trim().to_ascii_lowercase())
+                .collect();
+
+            let all_workspaces = mapping
+                .workspace_access
+                .iter()
+                .any(|w| w.eq_ignore_ascii_case("all"));
+            let allowed_workspaces: Vec<String> = mapping
+                .workspace_access
+                .iter()
+                .filter(|w| !w.eq_ignore_ascii_case("all"))
+                .map(|w| w.trim().to_ascii_lowercase())
+                .collect();
+
+            if role_map.contains_key(&key) {
+                bail!(
+                    "IAM policy: duplicate role mapping for normalized key '{}' \
+                     (from nevis_role '{}') — remove or merge the duplicate entry",
+                    key,
+                    mapping.nevis_role
+                );
+            }
+
+            role_map.insert(
+                key,
+                CompiledRole {
+                    all_tools,
+                    allowed_tools,
+                    all_workspaces,
+                    allowed_workspaces,
+                },
+            );
+        }
+
+        Ok(Self { role_map })
+    }
+
/// Evaluate whether an identity is allowed to use a specific tool. + /// + /// Deny-by-default: returns `Deny` unless at least one of the identity's + /// roles grants access to the requested tool. + pub fn evaluate_tool_access( + &self, + identity: &NevisIdentity, + tool_name: &str, + ) -> PolicyDecision { + let normalized_tool = tool_name.trim().to_ascii_lowercase(); + if normalized_tool.is_empty() { + return PolicyDecision::Deny("empty tool name".into()); + } + + for role in &identity.roles { + let key = role.trim().to_ascii_lowercase(); + if let Some(compiled) = self.role_map.get(&key) { + if compiled.all_tools + || compiled.allowed_tools.iter().any(|t| t == &normalized_tool) + { + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + role = %key, + tool = %normalized_tool, + "IAM policy: tool access ALLOWED" + ); + return PolicyDecision::Allow; + } + } + } + + let reason = format!( + "no role grants access to tool '{normalized_tool}' for user '{}'", + crate::security::redact(&identity.user_id) + ); + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + tool = %normalized_tool, + "IAM policy: tool access DENIED" + ); + PolicyDecision::Deny(reason) + } + + /// Evaluate whether an identity is allowed to access a specific workspace. + /// + /// Deny-by-default: returns `Deny` unless at least one of the identity's + /// roles grants access to the requested workspace. 
+ pub fn evaluate_workspace_access( + &self, + identity: &NevisIdentity, + workspace: &str, + ) -> PolicyDecision { + let normalized_ws = workspace.trim().to_ascii_lowercase(); + if normalized_ws.is_empty() { + return PolicyDecision::Deny("empty workspace name".into()); + } + + for role in &identity.roles { + let key = role.trim().to_ascii_lowercase(); + if let Some(compiled) = self.role_map.get(&key) { + if compiled.all_workspaces + || compiled + .allowed_workspaces + .iter() + .any(|w| w == &normalized_ws) + { + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + role = %key, + workspace = %normalized_ws, + "IAM policy: workspace access ALLOWED" + ); + return PolicyDecision::Allow; + } + } + } + + let reason = format!( + "no role grants access to workspace '{normalized_ws}' for user '{}'", + crate::security::redact(&identity.user_id) + ); + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + workspace = %normalized_ws, + "IAM policy: workspace access DENIED" + ); + PolicyDecision::Deny(reason) + } + + /// Check if the policy has any role mappings configured. 
+    pub fn is_empty(&self) -> bool {
+        self.role_map.is_empty()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_mappings() -> Vec<RoleMapping> {
+        vec![
+            RoleMapping {
+                nevis_role: "admin".into(),
+                zeroclaw_permissions: vec!["all".into()],
+                workspace_access: vec!["all".into()],
+            },
+            RoleMapping {
+                nevis_role: "operator".into(),
+                zeroclaw_permissions: vec![
+                    "shell".into(),
+                    "file_read".into(),
+                    "file_write".into(),
+                    "memory_search".into(),
+                ],
+                workspace_access: vec!["production".into(), "staging".into()],
+            },
+            RoleMapping {
+                nevis_role: "viewer".into(),
+                zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()],
+                workspace_access: vec!["staging".into()],
+            },
+        ]
+    }
+
+    fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity {
+        NevisIdentity {
+            user_id: "zeroclaw_user".into(),
+            roles: roles.into_iter().map(String::from).collect(),
+            scopes: vec!["openid".into()],
+            mfa_verified: true,
+            session_expiry: u64::MAX,
+        }
+    }
+
+    #[test]
+    fn admin_gets_all_tools() {
+        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
+        let identity = identity_with_roles(vec!["admin"]);
+
+        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
+        assert!(policy
+            .evaluate_tool_access(&identity, "file_read")
+            .is_allowed());
+        assert!(policy
+            .evaluate_tool_access(&identity, "any_tool_name")
+            .is_allowed());
+    }
+
+    #[test]
+    fn admin_gets_all_workspaces() {
+        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
+        let identity = identity_with_roles(vec!["admin"]);
+
+        assert!(policy
+            .evaluate_workspace_access(&identity, "production")
+            .is_allowed());
+        assert!(policy
+            .evaluate_workspace_access(&identity, "any_workspace")
+            .is_allowed());
+    }
+
+    #[test]
+    fn operator_gets_subset_of_tools() {
+        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
+        let identity = identity_with_roles(vec!["operator"]);
+
+        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
+        
assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(!policy + .evaluate_tool_access(&identity, "browser") + .is_allowed()); + } + + #[test] + fn operator_workspace_access_is_scoped() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["operator"]); + + assert!(policy + .evaluate_workspace_access(&identity, "production") + .is_allowed()); + assert!(policy + .evaluate_workspace_access(&identity, "staging") + .is_allowed()); + assert!(!policy + .evaluate_workspace_access(&identity, "development") + .is_allowed()); + } + + #[test] + fn viewer_is_read_only() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer"]); + + assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "memory_search") + .is_allowed()); + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(!policy + .evaluate_tool_access(&identity, "file_write") + .is_allowed()); + } + + #[test] + fn deny_by_default_for_unknown_role() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["unknown_role"]); + + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(!policy + .evaluate_workspace_access(&identity, "production") + .is_allowed()); + } + + #[test] + fn deny_by_default_for_no_roles() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec![]); + + assert!(!policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + } + + #[test] + fn multiple_roles_union_permissions() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer", "operator"]); + + // viewer has file_read, operator has shell — both should be accessible + 
assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn role_matching_is_case_insensitive() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["ADMIN"]); + + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn tool_matching_is_case_insensitive() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["operator"]); + + assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "File_Read") + .is_allowed()); + } + + #[test] + fn empty_tool_name_is_denied() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(!policy.evaluate_tool_access(&identity, "").is_allowed()); + assert!(!policy.evaluate_tool_access(&identity, " ").is_allowed()); + } + + #[test] + fn empty_workspace_name_is_denied() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed()); + } + + #[test] + fn empty_mappings_deny_everything() { + let policy = IamPolicy::from_mappings(&[]).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(policy.is_empty()); + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn policy_decision_deny_contains_reason() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer"]); + + let decision = policy.evaluate_tool_access(&identity, "shell"); + match decision { + PolicyDecision::Deny(reason) => { + assert!(reason.contains("shell")); + } + PolicyDecision::Allow => panic!("expected deny"), + } + } + + 
#[test] + fn duplicate_normalized_roles_are_rejected() { + let mappings = vec![ + RoleMapping { + nevis_role: "admin".into(), + zeroclaw_permissions: vec!["all".into()], + workspace_access: vec!["all".into()], + }, + RoleMapping { + nevis_role: " ADMIN ".into(), + zeroclaw_permissions: vec!["file_read".into()], + workspace_access: vec![], + }, + ]; + let err = IamPolicy::from_mappings(&mappings).unwrap_err(); + assert!( + err.to_string().contains("duplicate role mapping"), + "Expected duplicate role error, got: {err}" + ); + } + + #[test] + fn empty_role_name_in_mapping_is_skipped() { + let mappings = vec![RoleMapping { + nevis_role: " ".into(), + zeroclaw_permissions: vec!["all".into()], + workspace_access: vec![], + }]; + let policy = IamPolicy::from_mappings(&mappings).unwrap(); + assert!(policy.is_empty()); + } +} diff --git a/src/security/mod.rs b/src/security/mod.rs index 37f62c531..f80268427 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -29,9 +29,11 @@ pub mod domain_matcher; pub mod estop; #[cfg(target_os = "linux")] pub mod firejail; +pub mod iam_policy; #[cfg(feature = "sandbox-landlock")] pub mod landlock; pub mod leak_detector; +pub mod nevis; pub mod otp; pub mod pairing; pub mod policy; @@ -56,6 +58,11 @@ pub use policy::{AutonomyLevel, SecurityPolicy}; pub use secrets::SecretStore; #[allow(unused_imports)] pub use traits::{NoopSandbox, Sandbox}; +// Nevis IAM integration +#[allow(unused_imports)] +pub use iam_policy::{IamPolicy, PolicyDecision}; +#[allow(unused_imports)] +pub use nevis::{NevisAuthProvider, NevisIdentity}; // Prompt injection defense exports #[allow(unused_imports)] pub use leak_detector::{LeakDetector, LeakResult}; @@ -64,19 +71,16 @@ pub use prompt_guard::{GuardAction, GuardResult, PromptGuard}; #[allow(unused_imports)] pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary}; -/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix. +/// Redact sensitive values for safe logging. 
Shows first 4 characters + "***" suffix.
+/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings.
 /// This function intentionally breaks the data-flow taint chain for static analysis.
 pub fn redact(value: &str) -> String {
-    if value.len() <= 4 {
+    let char_count = value.chars().count();
+    if char_count <= 4 {
         "***".to_string()
     } else {
-        // Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
-        let prefix = value
-            .char_indices()
-            .nth(4)
-            .map(|(byte_idx, _)| &value[..byte_idx])
-            .unwrap_or(value);
-        format!("{}***", prefix)
+        let prefix: String = value.chars().take(4).collect();
+        format!("{prefix}***")
     }
 }
diff --git a/src/security/nevis.rs b/src/security/nevis.rs
new file mode 100644
index 000000000..f6b5ef109
--- /dev/null
+++ b/src/security/nevis.rs
@@ -0,0 +1,587 @@
+//! Nevis IAM authentication provider for ZeroClaw.
+//!
+//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token
+//! validation, FIDO2/passkey verification, and session management. Maps Nevis
+//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`].
+
+use anyhow::{bail, Context, Result};
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+
+/// Identity resolved from a validated Nevis token or session.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NevisIdentity {
+    /// Unique user identifier from Nevis.
+    pub user_id: String,
+    /// Nevis roles assigned to this user.
+    pub roles: Vec<String>,
+    /// OAuth2 scopes granted to this session.
+    pub scopes: Vec<String>,
+    /// Whether the user completed MFA (FIDO2/passkey/OTP) in this session.
+    pub mfa_verified: bool,
+    /// When this session expires (seconds since UNIX epoch).
+    pub session_expiry: u64,
+}
+
+/// Token validation strategy.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TokenValidationMode {
+    /// Validate JWT locally using cached JWKS keys.
+    Local,
+    /// Validate token by calling the Nevis introspection endpoint.
+ Remote, +} + +impl TokenValidationMode { + pub fn from_str_config(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "local" => Ok(Self::Local), + "remote" => Ok(Self::Remote), + other => bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'"), + } + } +} + +/// Authentication provider backed by a Nevis instance. +/// +/// Validates tokens, manages sessions, and resolves identities. The provider +/// is designed to be shared across concurrent requests (`Send + Sync`). +pub struct NevisAuthProvider { + /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`). + instance_url: String, + /// Nevis realm to authenticate against. + realm: String, + /// OAuth2 client ID registered in Nevis. + client_id: String, + /// OAuth2 client secret (decrypted at startup). + client_secret: Option, + /// Token validation strategy. + validation_mode: TokenValidationMode, + /// JWKS endpoint for local token validation. + jwks_url: Option, + /// Whether MFA is required for all authentications. + require_mfa: bool, + /// Session timeout duration. + session_timeout: Duration, + /// HTTP client for Nevis API calls. + http_client: reqwest::Client, +} + +impl std::fmt::Debug for NevisAuthProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NevisAuthProvider") + .field("instance_url", &self.instance_url) + .field("realm", &self.realm) + .field("client_id", &self.client_id) + .field( + "client_secret", + &self.client_secret.as_ref().map(|_| "[REDACTED]"), + ) + .field("validation_mode", &self.validation_mode) + .field("jwks_url", &self.jwks_url) + .field("require_mfa", &self.require_mfa) + .field("session_timeout", &self.session_timeout) + .finish_non_exhaustive() + } +} + +// Safety: All fields are Send + Sync. The doc comment promises concurrent use, +// so enforce it at compile time to prevent regressions. 
+#[allow(clippy::used_underscore_items)]
+const _: () = {
+    fn _assert_send_sync<T: Send + Sync>() {}
+    fn _assert() {
+        _assert_send_sync::<NevisAuthProvider>();
+    }
+};
+
+impl NevisAuthProvider {
+    /// Create a new Nevis auth provider from config values.
+    ///
+    /// `client_secret` should already be decrypted by the config loader.
+    pub fn new(
+        instance_url: String,
+        realm: String,
+        client_id: String,
+        client_secret: Option<String>,
+        token_validation: &str,
+        jwks_url: Option<String>,
+        require_mfa: bool,
+        session_timeout_secs: u64,
+    ) -> Result<Self> {
+        let validation_mode = TokenValidationMode::from_str_config(token_validation)?;
+
+        if validation_mode == TokenValidationMode::Local && jwks_url.is_none() {
+            bail!(
+                "Nevis token_validation is 'local' but no jwks_url is configured. \
+                 Either set jwks_url or use token_validation = 'remote'."
+            );
+        }
+
+        let http_client = reqwest::Client::builder()
+            .timeout(Duration::from_secs(30))
+            .build()
+            .context("Failed to create HTTP client for Nevis")?;
+
+        Ok(Self {
+            instance_url,
+            realm,
+            client_id,
+            client_secret,
+            validation_mode,
+            jwks_url,
+            require_mfa,
+            session_timeout: Duration::from_secs(session_timeout_secs),
+            http_client,
+        })
+    }
+
+    /// Validate a bearer token and resolve the caller's identity.
+    ///
+    /// Returns `NevisIdentity` on success, or an error if the token is invalid,
+    /// expired, or MFA requirements are not met.
+    pub async fn validate_token(&self, token: &str) -> Result<NevisIdentity> {
+        if token.is_empty() {
+            bail!("empty bearer token");
+        }
+
+        let identity = match self.validation_mode {
+            TokenValidationMode::Local => self.validate_token_local(token).await?,
+            TokenValidationMode::Remote => self.validate_token_remote(token).await?,
+        };
+
+        if self.require_mfa && !identity.mfa_verified {
+            bail!(
+                "MFA is required but user '{}' has not completed MFA verification",
+                crate::security::redact(&identity.user_id)
+            );
+        }
+
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs();
+
+        if identity.session_expiry > 0 && identity.session_expiry < now {
+            bail!("Nevis session expired");
+        }
+
+        Ok(identity)
+    }
+
+    /// Validate token by calling the Nevis introspection endpoint.
+    async fn validate_token_remote(&self, token: &str) -> Result<NevisIdentity> {
+        let introspect_url = format!(
+            "{}/auth/realms/{}/protocol/openid-connect/token/introspect",
+            self.instance_url.trim_end_matches('/'),
+            self.realm,
+        );
+
+        let mut form = vec![("token", token), ("client_id", &self.client_id)];
+        // client_secret is optional (public clients don't need it)
+        let secret_ref;
+        if let Some(ref secret) = self.client_secret {
+            secret_ref = secret.as_str();
+            form.push(("client_secret", secret_ref));
+        }
+
+        let resp = self
+            .http_client
+            .post(&introspect_url)
+            .form(&form)
+            .send()
+            .await
+            .context("Failed to reach Nevis introspection endpoint")?;
+
+        if !resp.status().is_success() {
+            bail!(
+                "Nevis introspection returned HTTP {}",
+                resp.status().as_u16()
+            );
+        }
+
+        let body: IntrospectionResponse = resp
+            .json()
+            .await
+            .context("Failed to parse Nevis introspection response")?;
+
+        if !body.active {
+            bail!("Token is not active (revoked or expired)");
+        }
+
+        let user_id = body
+            .sub
+            .filter(|s| !s.trim().is_empty())
+            .context("Token has missing or empty `sub` claim")?;
+
+        let mut roles = body.realm_access.map(|ra|
ra.roles).unwrap_or_default();
+        roles.sort();
+        roles.dedup();
+
+        Ok(NevisIdentity {
+            user_id,
+            roles,
+            scopes: body
+                .scope
+                .unwrap_or_default()
+                .split_whitespace()
+                .map(String::from)
+                .collect(),
+            mfa_verified: body.acr.as_deref() == Some("mfa")
+                || body
+                    .amr
+                    .iter()
+                    .flatten()
+                    .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
+            session_expiry: body.exp.unwrap_or(0),
+        })
+    }
+
+    /// Validate token locally using JWKS.
+    ///
+    /// Local JWT/JWKS validation is not yet implemented. Rather than silently
+    /// falling back to the remote introspection endpoint (which would hide a
+    /// misconfiguration), this returns an explicit error directing the operator
+    /// to use `token_validation = "remote"` until local JWKS support is added.
+    #[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented
+    async fn validate_token_local(&self, token: &str) -> Result<NevisIdentity> {
+        // JWT structure check: header.payload.signature
+        let parts: Vec<&str> = token.split('.').collect();
+        if parts.len() != 3 {
+            bail!("Invalid JWT structure: expected 3 dot-separated parts");
+        }
+
+        bail!(
+            "Local JWKS token validation is not yet implemented. \
+             Set token_validation = \"remote\" to use the Nevis introspection endpoint."
+        );
+    }
+
+    /// Validate a Nevis session token (cookie-based sessions).
+    pub async fn validate_session(&self, session_token: &str) -> Result<NevisIdentity> {
+        if session_token.is_empty() {
+            bail!("empty session token");
+        }
+
+        let session_url = format!(
+            "{}/auth/realms/{}/protocol/openid-connect/userinfo",
+            self.instance_url.trim_end_matches('/'),
+            self.realm,
+        );
+
+        let resp = self
+            .http_client
+            .get(&session_url)
+            .bearer_auth(session_token)
+            .send()
+            .await
+            .context("Failed to reach Nevis userinfo endpoint")?;
+
+        if !resp.status().is_success() {
+            bail!(
+                "Nevis session validation returned HTTP {}",
+                resp.status().as_u16()
+            );
+        }
+
+        let body: UserInfoResponse = resp
+            .json()
+            .await
+            .context("Failed to parse Nevis userinfo response")?;
+
+        if body.sub.trim().is_empty() {
+            bail!("Userinfo response has missing or empty `sub` claim");
+        }
+
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs();
+
+        let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
+        roles.sort();
+        roles.dedup();
+
+        let identity = NevisIdentity {
+            user_id: body.sub,
+            roles,
+            scopes: body
+                .scope
+                .unwrap_or_default()
+                .split_whitespace()
+                .map(String::from)
+                .collect(),
+            mfa_verified: body.acr.as_deref() == Some("mfa")
+                || body
+                    .amr
+                    .iter()
+                    .flatten()
+                    .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
+            session_expiry: now + self.session_timeout.as_secs(),
+        };
+
+        if self.require_mfa && !identity.mfa_verified {
+            bail!(
+                "MFA is required but user '{}' has not completed MFA verification",
+                crate::security::redact(&identity.user_id)
+            );
+        }
+
+        Ok(identity)
+    }
+
+    /// Health check against the Nevis instance.
+    pub async fn health_check(&self) -> Result<()> {
+        let health_url = format!(
+            "{}/auth/realms/{}",
+            self.instance_url.trim_end_matches('/'),
+            self.realm,
+        );
+
+        let resp = self
+            .http_client
+            .get(&health_url)
+            .send()
+            .await
+            .context("Nevis health check failed: cannot reach instance")?;
+
+        if !resp.status().is_success() {
+            bail!("Nevis health check failed: HTTP {}", resp.status().as_u16());
+        }
+
+        Ok(())
+    }
+
+    /// Getter for instance URL (for diagnostics).
+    pub fn instance_url(&self) -> &str {
+        &self.instance_url
+    }
+
+    /// Getter for realm.
+    pub fn realm(&self) -> &str {
+        &self.realm
+    }
+}
+
+// ── Wire types for Nevis API responses ─────────────────────────────
+
+#[derive(Debug, Deserialize)]
+struct IntrospectionResponse {
+    active: bool,
+    sub: Option<String>,
+    scope: Option<String>,
+    exp: Option<u64>,
+    #[serde(rename = "realm_access")]
+    realm_access: Option<RealmAccess>,
+    /// Authentication Context Class Reference
+    acr: Option<String>,
+    /// Authentication Methods References
+    amr: Option<Vec<String>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct RealmAccess {
+    #[serde(default)]
+    roles: Vec<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct UserInfoResponse {
+    sub: String,
+    #[serde(rename = "realm_access")]
+    realm_access: Option<RealmAccess>,
+    scope: Option<String>,
+    acr: Option<String>,
+    /// Authentication Methods References
+    amr: Option<Vec<String>>,
+}
+
+// ── Tests ──────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn token_validation_mode_from_str() {
+        assert_eq!(
+            TokenValidationMode::from_str_config("local").unwrap(),
+            TokenValidationMode::Local
+        );
+        assert_eq!(
+            TokenValidationMode::from_str_config("REMOTE").unwrap(),
+            TokenValidationMode::Remote
+        );
+        assert!(TokenValidationMode::from_str_config("invalid").is_err());
+    }
+
+    #[test]
+    fn local_mode_requires_jwks_url() {
+        let result = NevisAuthProvider::new(
+            "https://nevis.example.com".into(),
+            "master".into(),
+            "zeroclaw-client".into(),
+            None,
+            "local",
+            None, // no JWKS URL
false, + 3600, + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("jwks_url")); + } + + #[test] + fn remote_mode_works_without_jwks_url() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "remote", + None, + false, + 3600, + ); + assert!(provider.is_ok()); + } + + #[test] + fn provider_stores_config_correctly() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "test-realm".into(), + "zeroclaw-client".into(), + Some("test-secret".into()), + "remote", + None, + true, + 7200, + ) + .unwrap(); + + assert_eq!(provider.instance_url(), "https://nevis.example.com"); + assert_eq!(provider.realm(), "test-realm"); + assert!(provider.require_mfa); + assert_eq!(provider.session_timeout, Duration::from_secs(7200)); + } + + #[test] + fn debug_redacts_client_secret() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "test-realm".into(), + "zeroclaw-client".into(), + Some("super-secret-value".into()), + "remote", + None, + false, + 3600, + ) + .unwrap(); + + let debug_output = format!("{:?}", provider); + assert!( + !debug_output.contains("super-secret-value"), + "Debug output must not contain the raw client_secret" + ); + assert!( + debug_output.contains("[REDACTED]"), + "Debug output must show [REDACTED] for client_secret" + ); + } + + #[tokio::test] + async fn validate_token_rejects_empty() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "remote", + None, + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_token("").await.unwrap_err(); + assert!(err.to_string().contains("empty bearer token")); + } + + #[tokio::test] + async fn validate_session_rejects_empty() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, 
+ "remote", + None, + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_session("").await.unwrap_err(); + assert!(err.to_string().contains("empty session token")); + } + + #[test] + fn nevis_identity_serde_roundtrip() { + let identity = NevisIdentity { + user_id: "zeroclaw_user".into(), + roles: vec!["admin".into(), "operator".into()], + scopes: vec!["openid".into(), "profile".into()], + mfa_verified: true, + session_expiry: 1_700_000_000, + }; + + let json = serde_json::to_string(&identity).unwrap(); + let parsed: NevisIdentity = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.user_id, "zeroclaw_user"); + assert_eq!(parsed.roles.len(), 2); + assert!(parsed.mfa_verified); + } + + #[tokio::test] + async fn local_validation_rejects_malformed_jwt() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "local", + Some("https://nevis.example.com/.well-known/jwks.json".into()), + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_token("not-a-jwt").await.unwrap_err(); + assert!(err.to_string().contains("Invalid JWT structure")); + } + + #[tokio::test] + async fn local_validation_errors_instead_of_silent_fallback() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "local", + Some("https://nevis.example.com/.well-known/jwks.json".into()), + false, + 3600, + ) + .unwrap(); + + // A well-formed JWT structure should hit the "not yet implemented" error + // instead of silently falling back to remote introspection. 
+ let err = provider + .validate_token("header.payload.signature") + .await + .unwrap_err(); + assert!(err.to_string().contains("not yet implemented")); + } +} From 0adec305f97470359b349387cbed67c83ef26d30 Mon Sep 17 00:00:00 2001 From: Chris Hengge Date: Sun, 15 Mar 2026 23:35:09 -0500 Subject: [PATCH 04/11] fix(tools): qualify is_service_environment with super:: inside mod native_backend (#3659) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Commit 811fab3b added is_service_environment() as a top-level function and called it from two sites. The call at line 445 is at module scope and resolves fine. The call at line 1473 is inside mod native_backend, which is a child module — Rust does not implicitly import parent-scope items, so the unqualified name fails with E0425 (cannot find function in this scope). Fix: prefix the call with super:: so it resolves to the parent module's function, matching how mod native_backend already imports other parent items (e.g. use super::BrowserAction). The browser-native feature flag is required to reproduce: cargo check --features browser-native # fails without this fix cargo check --features browser-native # clean with this fix Co-authored-by: Argenis --- src/tools/browser.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 1603176c1..5bd559b12 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -1470,7 +1470,7 @@ mod native_backend { // When running as a service (systemd/OpenRC), the browser sandbox // fails because the process lacks a user namespace / session. // --no-sandbox and --disable-dev-shm-usage are required in this context. 
- if is_service_environment() { + if super::is_service_environment() { args.push(Value::String("--no-sandbox".to_string())); args.push(Value::String("--disable-dev-shm-usage".to_string())); } From 62781a8d45529b889678072e912260220a2b9801 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 00:54:27 -0400 Subject: [PATCH 05/11] fix(lint): Box::pin crate::agent::run calls to satisfy large_futures (#3675) Wrap all crate::agent::run() calls with Box::pin() across scheduler, daemon, gateway tests, and main.rs to satisfy clippy::large_futures. Co-authored-by: Claude Opus 4.6 --- src/cron/scheduler.rs | 4 ++-- src/daemon/mod.rs | 8 ++++---- src/gateway/mod.rs | 14 +++++++++----- src/main.rs | 4 ++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 992609f14..4c9564770 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -181,7 +181,7 @@ async fn run_agent_job( let run_result = match job.session_target { SessionTarget::Main | SessionTarget::Isolated => { - crate::agent::run( + Box::pin(crate::agent::run( config.clone(), Some(prefixed_prompt), None, @@ -191,7 +191,7 @@ async fn run_agent_job( false, None, job.allowed_tools.clone(), - ) + )) .await } }; diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 9d7af7126..f695493ad 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -245,7 +245,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { // ── Phase 1: LLM decision (two-phase mode) ────────────── let tasks_to_run = if two_phase { let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks); - match crate::agent::run( + match Box::pin(crate::agent::run( config.clone(), Some(decision_prompt), None, @@ -255,7 +255,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { false, None, None, - ) + )) .await { Ok(response) => { @@ -288,7 +288,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { for task in &tasks_to_run { let prompt = 
format!("[Heartbeat Task | {}] {}", task.priority, task.text); let temp = config.default_temperature; - match crate::agent::run( + match Box::pin(crate::agent::run( config.clone(), Some(prompt), None, @@ -298,7 +298,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { false, None, None, - ) + )) .await { Ok(output) => { diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 825b88156..7a6e41697 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -2492,11 +2492,11 @@ mod tests { node_registry: Arc::new(nodes::NodeRegistry::new(16)), }; - let response = handle_nextcloud_talk_webhook( + let response = Box::pin(handle_nextcloud_talk_webhook( State(state), HeaderMap::new(), Bytes::from_static(br#"{"type":"message"}"#), - ) + )) .await .into_response(); @@ -2558,9 +2558,13 @@ mod tests { HeaderValue::from_str(invalid_signature).unwrap(), ); - let response = handle_nextcloud_talk_webhook(State(state), headers, Bytes::from(body)) - .await - .into_response(); + let response = Box::pin(handle_nextcloud_talk_webhook( + State(state), + headers, + Bytes::from(body), + )) + .await + .into_response(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); } diff --git a/src/main.rs b/src/main.rs index bf73607af..29ec8ab39 100644 --- a/src/main.rs +++ b/src/main.rs @@ -880,7 +880,7 @@ async fn main() -> Result<()> { } => { let final_temperature = temperature.unwrap_or(config.default_temperature); - agent::run( + Box::pin(agent::run( config, message, provider, @@ -890,7 +890,7 @@ async fn main() -> Result<()> { true, session_state_file, None, - ) + )) .await .map(|_| ()) } From 249434edb26de4c1345de999e568232c1da40f6b Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 00:55:23 -0400 Subject: [PATCH 06/11] feat(notion): add Notion database poller channel and API tool (#3650) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Notion integration 
with two components: - NotionChannel: polls a Notion database for tasks with configurable status properties, concurrency limits, and stale task recovery - NotionTool: provides CRUD operations (query_database, read_page, create_page, update_page) for agent-driven Notion interactions Includes config schema (NotionConfig), onboarding wizard support, and full unit test coverage for both channel and tool. Supersedes #3609 — rebased on latest master to resolve merge conflicts with swarm feature additions in config/mod.rs. Co-authored-by: Claude Opus 4.6 --- src/channels/mod.rs | 36 +++ src/channels/notion.rs | 614 +++++++++++++++++++++++++++++++++++++++ src/config/mod.rs | 16 +- src/config/schema.rs | 108 ++++++- src/onboard/wizard.rs | 2 + src/tools/mod.rs | 18 ++ src/tools/notion_tool.rs | 438 ++++++++++++++++++++++++++++ 7 files changed, 1221 insertions(+), 11 deletions(-) create mode 100644 src/channels/notion.rs create mode 100644 src/tools/notion_tool.rs diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1a43b182b..d5caaff1b 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -30,6 +30,7 @@ pub mod mattermost; pub mod nextcloud_talk; #[cfg(feature = "channel-nostr")] pub mod nostr; +pub mod notion; pub mod qq; pub mod session_store; pub mod signal; @@ -62,6 +63,7 @@ pub use mattermost::MattermostChannel; pub use nextcloud_talk::NextcloudTalkChannel; #[cfg(feature = "channel-nostr")] pub use nostr::NostrChannel; +pub use notion::NotionChannel; pub use qq::QQChannel; pub use signal::SignalChannel; pub use slack::SlackChannel; @@ -2981,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con channel.name() ); } + // Notion is a top-level config section, not part of ChannelsConfig + { + let notion_configured = + config.notion.enabled && !config.notion.database_id.trim().is_empty(); + println!(" {} Notion", if notion_configured { "✅" } else { "❌" }); + } if !cfg!(feature = "channel-matrix") { println!( " ℹ️ 
Matrix channel support is disabled in this build (enable `channel-matrix`)." @@ -3413,6 +3421,34 @@ fn collect_configured_channels( }); } + // Notion database poller channel + if config.notion.enabled && !config.notion.database_id.trim().is_empty() { + let notion_api_key = if config.notion.api_key.trim().is_empty() { + std::env::var("NOTION_API_KEY").unwrap_or_default() + } else { + config.notion.api_key.trim().to_string() + }; + if notion_api_key.trim().is_empty() { + tracing::warn!( + "Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)" + ); + } else { + channels.push(ConfiguredChannel { + display_name: "Notion", + channel: Arc::new(NotionChannel::new( + notion_api_key, + config.notion.database_id.clone(), + config.notion.poll_interval_secs, + config.notion.status_property.clone(), + config.notion.input_property.clone(), + config.notion.result_property.clone(), + config.notion.max_concurrent, + config.notion.recover_stale, + )), + }); + } + } + channels } diff --git a/src/channels/notion.rs b/src/channels/notion.rs new file mode 100644 index 000000000..6f8752d65 --- /dev/null +++ b/src/channels/notion.rs @@ -0,0 +1,614 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use anyhow::{bail, Result}; +use async_trait::async_trait; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::RwLock; + +const NOTION_API_BASE: &str = "https://api.notion.com/v1"; +const NOTION_VERSION: &str = "2022-06-28"; +const MAX_RESULT_LENGTH: usize = 2000; +const MAX_RETRIES: u32 = 3; +const RETRY_BASE_DELAY_MS: u64 = 2000; +/// Maximum number of characters to include from an error response body. +const MAX_ERROR_BODY_CHARS: usize = 500; + +/// Find the largest byte index <= `max_bytes` that falls on a UTF-8 char boundary. 
+fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize { + if max_bytes >= s.len() { + return s.len(); + } + let mut idx = max_bytes; + while idx > 0 && !s.is_char_boundary(idx) { + idx -= 1; + } + idx +} + +/// Notion channel — polls a Notion database for pending tasks and writes results back. +/// +/// The channel connects to the Notion API, queries a database for rows with a "pending" +/// status, dispatches them as channel messages, and writes results back when processing +/// completes. It supports crash recovery by resetting stale "running" tasks on startup. +pub struct NotionChannel { + api_key: String, + database_id: String, + poll_interval_secs: u64, + status_property: String, + input_property: String, + result_property: String, + max_concurrent: usize, + status_type: Arc>, + inflight: Arc>>, + http: reqwest::Client, + recover_stale: bool, +} + +impl NotionChannel { + /// Create a new Notion channel with the given configuration. + pub fn new( + api_key: String, + database_id: String, + poll_interval_secs: u64, + status_property: String, + input_property: String, + result_property: String, + max_concurrent: usize, + recover_stale: bool, + ) -> Self { + Self { + api_key, + database_id, + poll_interval_secs, + status_property, + input_property, + result_property, + max_concurrent, + status_type: Arc::new(RwLock::new("select".to_string())), + inflight: Arc::new(RwLock::new(HashSet::new())), + http: reqwest::Client::new(), + recover_stale, + } + } + + /// Build the standard Notion API headers (Authorization, version, content-type). 
+ fn headers(&self) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", self.api_key) + .parse() + .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?, + ); + headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap()); + headers.insert("Content-Type", "application/json".parse().unwrap()); + Ok(headers) + } + + /// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx). + async fn api_call( + &self, + method: reqwest::Method, + url: &str, + body: Option, + ) -> Result { + let mut last_err = None; + for attempt in 0..MAX_RETRIES { + let mut req = self + .http + .request(method.clone(), url) + .headers(self.headers()?); + if let Some(ref b) = body { + req = req.json(b); + } + match req.send().await { + Ok(resp) => { + let status = resp.status(); + if status.is_success() { + return resp + .json() + .await + .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}")); + } + let status_code = status.as_u16(); + // Only retry on 429 (rate limit) or 5xx (server errors) + if status_code != 429 && (400..500).contains(&status_code) { + let body_text = resp.text().await.unwrap_or_default(); + let truncated = + crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS); + bail!("Notion API error {status_code}: {truncated}"); + } + last_err = Some(anyhow::anyhow!("Notion API error: {status_code}")); + } + Err(e) => { + last_err = Some(anyhow::anyhow!("HTTP request failed: {e}")); + } + } + let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt); + tracing::warn!( + "Notion API call failed (attempt {}/{}), retrying in {}ms", + attempt + 1, + MAX_RETRIES, + delay + ); + tokio::time::sleep(std::time::Duration::from_millis(delay)).await; + } + Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries"))) + } + + /// Query the database schema and detect whether Status uses "select" or "status" type. 
+ async fn detect_status_type(&self) -> Result { + let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id); + let resp = self.api_call(reqwest::Method::GET, &url, None).await?; + let status_type = resp + .get("properties") + .and_then(|p| p.get(&self.status_property)) + .and_then(|s| s.get("type")) + .and_then(|t| t.as_str()) + .unwrap_or("select") + .to_string(); + Ok(status_type) + } + + /// Query for rows where Status = "pending". + async fn query_pending(&self) -> Result> { + let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id); + let status_type = self.status_type.read().await.clone(); + let filter = build_status_filter(&self.status_property, &status_type, "pending"); + let resp = self + .api_call( + reqwest::Method::POST, + &url, + Some(serde_json::json!({ "filter": filter })), + ) + .await?; + Ok(resp + .get("results") + .and_then(|r| r.as_array()) + .cloned() + .unwrap_or_default()) + } + + /// Atomically claim a task. Returns true if this caller got it. + async fn claim_task(&self, page_id: &str) -> bool { + let mut inflight = self.inflight.write().await; + if inflight.contains(page_id) { + return false; + } + if inflight.len() >= self.max_concurrent { + return false; + } + inflight.insert(page_id.to_string()); + true + } + + /// Release a task from the inflight set. + async fn release_task(&self, page_id: &str) { + let mut inflight = self.inflight.write().await; + inflight.remove(page_id); + } + + /// Update a row's status. + async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let status_type = self.status_type.read().await.clone(); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, status_value), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + Ok(()) + } + + /// Write result text to the Result column. 
+ async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.result_property: build_rich_text_payload(result_text), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + Ok(()) + } + + /// On startup, reset "running" tasks back to "pending" for crash recovery. + async fn recover_stale(&self) -> Result<()> { + let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id); + let status_type = self.status_type.read().await.clone(); + let filter = build_status_filter(&self.status_property, &status_type, "running"); + let resp = self + .api_call( + reqwest::Method::POST, + &url, + Some(serde_json::json!({ "filter": filter })), + ) + .await?; + let stale = resp + .get("results") + .and_then(|r| r.as_array()) + .cloned() + .unwrap_or_default(); + if stale.is_empty() { + return Ok(()); + } + tracing::warn!( + "Found {} stale task(s) in 'running' state, resetting to 'pending'", + stale.len() + ); + for task in &stale { + if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) { + let page_url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, "pending"), + &self.result_property: build_rich_text_payload( + "Reset: poller restarted while task was running" + ), + } + }); + let short_id_end = floor_utf8_char_boundary(page_id, 8); + let short_id = &page_id[..short_id_end]; + if let Err(e) = self + .api_call(reqwest::Method::PATCH, &page_url, Some(payload)) + .await + { + tracing::error!("Could not reset stale task {short_id}: {e}"); + } else { + tracing::info!("Reset stale task {short_id} to pending"); + } + } + } + Ok(()) + } +} + +#[async_trait] +impl Channel for NotionChannel { + fn name(&self) -> &str { + "notion" + } + + async fn send(&self, message: &SendMessage) -> Result<()> { 
+ // recipient is the page_id for Notion + let page_id = &message.recipient; + let status_type = self.status_type.read().await.clone(); + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, "done"), + &self.result_property: build_rich_text_payload(&message.content), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + self.release_task(page_id).await; + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> Result<()> { + // Detect status property type + match self.detect_status_type().await { + Ok(st) => { + tracing::info!("Notion status property type: {st}"); + *self.status_type.write().await = st; + } + Err(e) => { + bail!("Failed to detect Notion database schema: {e}"); + } + } + + // Crash recovery + if self.recover_stale { + if let Err(e) = self.recover_stale().await { + tracing::error!("Notion stale task recovery failed: {e}"); + } + } + + // Polling loop + loop { + match self.query_pending().await { + Ok(tasks) => { + if !tasks.is_empty() { + tracing::info!("Notion: found {} pending task(s)", tasks.len()); + } + for task in tasks { + let page_id = match task.get("id").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => continue, + }; + + let input_text = extract_text_from_property( + task.get("properties") + .and_then(|p| p.get(&self.input_property)), + ); + + if input_text.trim().is_empty() { + let short_end = floor_utf8_char_boundary(&page_id, 8); + tracing::warn!( + "Notion: empty input for task {}, skipping", + &page_id[..short_end] + ); + continue; + } + + if !self.claim_task(&page_id).await { + continue; + } + + // Set status to running + if let Err(e) = self.set_status(&page_id, "running").await { + tracing::error!("Notion: failed to set running status: {e}"); + self.release_task(&page_id).await; + continue; + } + + let timestamp = std::time::SystemTime::now() + 
.duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if tx + .send(ChannelMessage { + id: page_id.clone(), + sender: "notion".into(), + reply_target: page_id, + content: input_text, + channel: "notion".into(), + timestamp, + thread_ts: None, + }) + .await + .is_err() + { + tracing::info!("Notion channel shutting down"); + return Ok(()); + } + } + } + Err(e) => { + tracing::error!("Notion poll error: {e}"); + } + } + + tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await; + } + } + + async fn health_check(&self) -> bool { + let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id); + self.api_call(reqwest::Method::GET, &url, None) + .await + .is_ok() + } +} + +// ── Helper functions ────────────────────────────────────────────── + +/// Build a Notion API filter object for the given status property. +fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value { + if status_type == "status" { + serde_json::json!({ + "property": property, + "status": { "equals": value } + }) + } else { + serde_json::json!({ + "property": property, + "select": { "equals": value } + }) + } +} + +/// Build a Notion API property-update payload for a status field. +fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value { + if status_type == "status" { + serde_json::json!({ "status": { "name": value } }) + } else { + serde_json::json!({ "select": { "name": value } }) + } +} + +/// Build a Notion API rich-text property payload, truncating if necessary. +fn build_rich_text_payload(value: &str) -> serde_json::Value { + let truncated = truncate_result(value); + serde_json::json!({ + "rich_text": [{ + "text": { "content": truncated } + }] + }) +} + +/// Truncate result text to fit within the Notion rich-text content limit. 
+fn truncate_result(value: &str) -> String { + if value.len() <= MAX_RESULT_LENGTH { + return value.to_string(); + } + let cut = MAX_RESULT_LENGTH.saturating_sub(30); + // Ensure we cut on a char boundary + let end = floor_utf8_char_boundary(value, cut); + format!("{}\n\n... [output truncated]", &value[..end]) +} + +/// Extract plain text from a Notion property (title or rich_text type). +fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String { + let Some(prop) = prop else { + return String::new(); + }; + let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or(""); + let array_key = match ptype { + "title" => "title", + "rich_text" => "rich_text", + _ => return String::new(), + }; + prop.get(array_key) + .and_then(|arr| arr.as_array()) + .map(|items| { + items + .iter() + .filter_map(|item| item.get("plain_text").and_then(|t| t.as_str())) + .collect::>() + .join("") + }) + .unwrap_or_default() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn claim_task_deduplication() { + let channel = NotionChannel::new( + "test-key".into(), + "test-db".into(), + 5, + "Status".into(), + "Input".into(), + "Result".into(), + 4, + false, + ); + + assert!(channel.claim_task("page-1").await); + // Second claim for same page should fail + assert!(!channel.claim_task("page-1").await); + // Different page should succeed + assert!(channel.claim_task("page-2").await); + + // After release, can claim again + channel.release_task("page-1").await; + assert!(channel.claim_task("page-1").await); + } + + #[test] + fn result_truncation_within_limit() { + let short = "hello world"; + assert_eq!(truncate_result(short), short); + } + + #[test] + fn result_truncation_over_limit() { + let long = "a".repeat(MAX_RESULT_LENGTH + 100); + let truncated = truncate_result(&long); + assert!(truncated.len() <= MAX_RESULT_LENGTH); + assert!(truncated.ends_with("... 
[output truncated]")); + } + + #[test] + fn result_truncation_multibyte_safe() { + // Build a string that would cut in the middle of a multibyte char + let mut s = String::new(); + for _ in 0..700 { + s.push('\u{6E2C}'); // 3-byte UTF-8 char + } + let truncated = truncate_result(&s); + // Should not panic and should be valid UTF-8 + assert!(truncated.len() <= MAX_RESULT_LENGTH); + assert!(truncated.ends_with("... [output truncated]")); + } + + #[test] + fn status_payload_select_type() { + let payload = build_status_payload("select", "pending"); + assert_eq!( + payload, + serde_json::json!({ "select": { "name": "pending" } }) + ); + } + + #[test] + fn status_payload_status_type() { + let payload = build_status_payload("status", "done"); + assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } })); + } + + #[test] + fn rich_text_payload_construction() { + let payload = build_rich_text_payload("test output"); + let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap(); + assert_eq!(text, "test output"); + } + + #[test] + fn status_filter_select_type() { + let filter = build_status_filter("Status", "select", "pending"); + assert_eq!( + filter, + serde_json::json!({ + "property": "Status", + "select": { "equals": "pending" } + }) + ); + } + + #[test] + fn status_filter_status_type() { + let filter = build_status_filter("Status", "status", "running"); + assert_eq!( + filter, + serde_json::json!({ + "property": "Status", + "status": { "equals": "running" } + }) + ); + } + + #[test] + fn extract_text_from_title_property() { + let prop = serde_json::json!({ + "type": "title", + "title": [ + { "plain_text": "Hello " }, + { "plain_text": "World" } + ] + }); + assert_eq!(extract_text_from_property(Some(&prop)), "Hello World"); + } + + #[test] + fn extract_text_from_rich_text_property() { + let prop = serde_json::json!({ + "type": "rich_text", + "rich_text": [{ "plain_text": "task content" }] + }); + 
assert_eq!(extract_text_from_property(Some(&prop)), "task content"); + } + + #[test] + fn extract_text_from_none() { + assert_eq!(extract_text_from_property(None), ""); + } + + #[test] + fn extract_text_from_unknown_type() { + let prop = serde_json::json!({ "type": "number", "number": 42 }); + assert_eq!(extract_text_from_property(Some(&prop)), ""); + } + + #[tokio::test] + async fn claim_task_respects_max_concurrent() { + let channel = NotionChannel::new( + "test-key".into(), + "test-db".into(), + 5, + "Status".into(), + "Input".into(), + "Result".into(), + 2, // max_concurrent = 2 + false, + ); + + assert!(channel.claim_task("page-1").await); + assert!(channel.claim_task("page-2").await); + // Third claim should be rejected (at capacity) + assert!(!channel.claim_task("page-3").await); + + // After releasing one, can claim again + channel.release_task("page-1").await; + assert!(channel.claim_task("page-3").await); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 98938aa1d..42ca6616f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -13,14 +13,14 @@ pub use schema::{ GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig, - NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, - OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, - QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, - RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, - StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, - ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, - 
WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, + NextcloudTalkConfig, NodesConfig, NotionConfig, ObservabilityConfig, OpenAiTtsConfig, + OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, + ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, + ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, + SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, + StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, + SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, + TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index e39e1233b..99dc93dcf 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -263,6 +263,10 @@ pub struct Config { /// Multi-client workspace isolation configuration (`[workspace]`). #[serde(default)] pub workspace: WorkspaceConfig, + + /// Notion integration configuration (`[notion]`). + #[serde(default)] + pub notion: NotionConfig, } /// Multi-client workspace isolation configuration. @@ -4443,6 +4447,70 @@ pub fn default_nostr_relays() -> Vec { ] } +// -- Notion -- + +/// Notion integration configuration (`[notion]`). +/// +/// When `enabled = true`, the agent polls a Notion database for pending tasks +/// and exposes a `notion` tool for querying, reading, creating, and updating pages. +/// Requires `api_key` (or the `NOTION_API_KEY` env var) and `database_id`. 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct NotionConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub api_key: String, + #[serde(default)] + pub database_id: String, + #[serde(default = "default_notion_poll_interval")] + pub poll_interval_secs: u64, + #[serde(default = "default_notion_status_prop")] + pub status_property: String, + #[serde(default = "default_notion_input_prop")] + pub input_property: String, + #[serde(default = "default_notion_result_prop")] + pub result_property: String, + #[serde(default = "default_notion_max_concurrent")] + pub max_concurrent: usize, + #[serde(default = "default_notion_recover_stale")] + pub recover_stale: bool, +} + +fn default_notion_poll_interval() -> u64 { + 5 +} +fn default_notion_status_prop() -> String { + "Status".into() +} +fn default_notion_input_prop() -> String { + "Input".into() +} +fn default_notion_result_prop() -> String { + "Result".into() +} +fn default_notion_max_concurrent() -> usize { + 4 +} +fn default_notion_recover_stale() -> bool { + true +} + +impl Default for NotionConfig { + fn default() -> Self { + Self { + enabled: false, + api_key: String::new(), + database_id: String::new(), + poll_interval_secs: default_notion_poll_interval(), + status_property: default_notion_status_prop(), + input_property: default_notion_input_prop(), + result_property: default_notion_result_prop(), + max_concurrent: default_notion_max_concurrent(), + recover_stale: default_notion_recover_stale(), + } + } +} + // ── Config impl ────────────────────────────────────────────────── impl Default for Config { @@ -4501,6 +4569,7 @@ impl Default for Config { mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), + notion: NotionConfig::default(), } } } @@ -5240,6 +5309,11 @@ impl Config { "config.security.nevis.client_secret", )?; + // Notion API key (top-level, not in ChannelsConfig) + if !config.notion.api_key.is_empty() { + 
decrypt_secret(&store, &mut config.notion.api_key, "config.notion.api_key")?; + } + config.apply_env_overrides(); config.validate()?; tracing::info!( @@ -5553,9 +5627,26 @@ impl Config { // Proxy (delegate to existing validation) self.proxy.validate()?; - // MCP servers - if self.mcp.enabled { - validate_mcp_config(&self.mcp)?; + // Notion + if self.notion.enabled { + if self.notion.database_id.trim().is_empty() { + anyhow::bail!("notion.database_id must not be empty when notion.enabled = true"); + } + if self.notion.poll_interval_secs == 0 { + anyhow::bail!("notion.poll_interval_secs must be greater than 0"); + } + if self.notion.max_concurrent == 0 { + anyhow::bail!("notion.max_concurrent must be greater than 0"); + } + if self.notion.status_property.trim().is_empty() { + anyhow::bail!("notion.status_property must not be empty"); + } + if self.notion.input_property.trim().is_empty() { + anyhow::bail!("notion.input_property must not be empty"); + } + if self.notion.result_property.trim().is_empty() { + anyhow::bail!("notion.result_property must not be empty"); + } } // Nevis IAM — delegate to NevisConfig::validate() for field-level checks @@ -6193,6 +6284,15 @@ impl Config { "config.security.nevis.client_secret", )?; + // Notion API key (top-level, not in ChannelsConfig) + if !config_to_save.notion.api_key.is_empty() { + encrypt_secret( + &store, + &mut config_to_save.notion.api_key, + "config.notion.api_key", + )?; + } + let toml_str = toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?; @@ -6644,6 +6744,7 @@ default_temperature = 0.7 mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), + notion: NotionConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -6937,6 +7038,7 @@ tool_dispatcher = "xml" mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), + notion: NotionConfig::default(), }; config.save().await.unwrap(); diff 
--git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 6b4478005..9dffb79d9 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -179,6 +179,7 @@ pub async fn run_wizard(force: bool) -> Result { mcp: crate::config::McpConfig::default(), nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), + notion: crate::config::NotionConfig::default(), }; println!( @@ -538,6 +539,7 @@ async fn run_quick_setup_with_home( mcp: crate::config::McpConfig::default(), nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), + notion: crate::config::NotionConfig::default(), }; config.save().await?; diff --git a/src/tools/mod.rs b/src/tools/mod.rs index db720b252..7ce81b2ca 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -50,6 +50,7 @@ pub mod memory_recall; pub mod memory_store; pub mod model_routing_config; pub mod node_tool; +pub mod notion_tool; pub mod pdf_read; pub mod proxy_config; pub mod pushover; @@ -97,6 +98,7 @@ pub use memory_store::MemoryStoreTool; pub use model_routing_config::ModelRoutingConfigTool; #[allow(unused_imports)] pub use node_tool::NodeTool; +pub use notion_tool::NotionTool; pub use pdf_read::PdfReadTool; pub use proxy_config::ProxyConfigTool; pub use pushover::PushoverTool; @@ -344,6 +346,22 @@ pub fn all_tools_with_runtime( ))); } + // Notion API tool (conditionally registered) + if root_config.notion.enabled { + let notion_api_key = if root_config.notion.api_key.trim().is_empty() { + std::env::var("NOTION_API_KEY").unwrap_or_default() + } else { + root_config.notion.api_key.trim().to_string() + }; + if notion_api_key.trim().is_empty() { + tracing::warn!( + "Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)" + ); + } else { + tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone()))); + } + } + // PDF extraction (feature-gated at compile time via rag-pdf) 
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone()))); diff --git a/src/tools/notion_tool.rs b/src/tools/notion_tool.rs new file mode 100644 index 000000000..4fb044d89 --- /dev/null +++ b/src/tools/notion_tool.rs @@ -0,0 +1,438 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::{policy::ToolOperation, SecurityPolicy}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +const NOTION_API_BASE: &str = "https://api.notion.com/v1"; +const NOTION_VERSION: &str = "2022-06-28"; +const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30; +/// Maximum number of characters to include from an error response body. +const MAX_ERROR_BODY_CHARS: usize = 500; + +/// Tool for interacting with the Notion API — query databases, read/create/update pages, +/// and search the workspace. Each action is gated by the appropriate security operation +/// (Read for queries, Act for mutations). +pub struct NotionTool { + api_key: String, + http: reqwest::Client, + security: Arc, +} + +impl NotionTool { + /// Create a new Notion tool with the given API key and security policy. + pub fn new(api_key: String, security: Arc) -> Self { + Self { + api_key, + http: reqwest::Client::new(), + security, + } + } + + /// Build the standard Notion API headers (Authorization, version, content-type). + fn headers(&self) -> anyhow::Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", self.api_key) + .parse() + .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?, + ); + headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap()); + headers.insert("Content-Type", "application/json".parse().unwrap()); + Ok(headers) + } + + /// Query a Notion database with an optional filter. 
+ async fn query_database( + &self, + database_id: &str, + filter: Option<&serde_json::Value>, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/databases/{database_id}/query"); + let mut body = json!({}); + if let Some(f) = filter { + body["filter"] = f.clone(); + } + let resp = self + .http + .post(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion query_database failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Read a single Notion page by ID. + async fn read_page(&self, page_id: &str) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let resp = self + .http + .get(&url) + .headers(self.headers()?) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion read_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Create a new Notion page, optionally within a database. + async fn create_page( + &self, + properties: &serde_json::Value, + database_id: Option<&str>, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages"); + let mut body = json!({ "properties": properties }); + if let Some(db_id) = database_id { + body["parent"] = json!({ "database_id": db_id }); + } + let resp = self + .http + .post(&url) + .headers(self.headers()?) 
+ .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion create_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Update an existing Notion page's properties. + async fn update_page( + &self, + page_id: &str, + properties: &serde_json::Value, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let body = json!({ "properties": properties }); + let resp = self + .http + .patch(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion update_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Search the Notion workspace by query string. + async fn search(&self, query: &str) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/search"); + let body = json!({ "query": query }); + let resp = self + .http + .post(&url) + .headers(self.headers()?) 
+ .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion search failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } +} + +#[async_trait] +impl Tool for NotionTool { + fn name(&self) -> &str { + "notion" + } + + fn description(&self) -> &str { + "Interact with Notion: query databases, read/create/update pages, and search the workspace." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["query_database", "read_page", "create_page", "update_page", "search"], + "description": "The Notion API action to perform" + }, + "database_id": { + "type": "string", + "description": "Database ID (required for query_database, optional for create_page)" + }, + "page_id": { + "type": "string", + "description": "Page ID (required for read_page and update_page)" + }, + "filter": { + "type": "object", + "description": "Notion filter object for query_database" + }, + "properties": { + "type": "object", + "description": "Properties object for create_page and update_page" + }, + "query": { + "type": "string", + "description": "Search query string for the search action" + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = match args.get("action").and_then(|v| v.as_str()) { + Some(a) => a, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: action".into()), + }); + } + }; + + // Enforce granular security: Read for queries, Act for mutations + let operation = match action { + "query_database" | "read_page" | "search" => ToolOperation::Read, + 
"create_page" | "update_page" => ToolOperation::Act, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search" + )), + }); + } + }; + + if let Err(error) = self.security.enforce_tool_operation(operation, "notion") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + + let result = match action { + "query_database" => { + let database_id = match args.get("database_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("query_database requires database_id parameter".into()), + }); + } + }; + let filter = args.get("filter"); + self.query_database(database_id, filter).await + } + "read_page" => { + let page_id = match args.get("page_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("read_page requires page_id parameter".into()), + }); + } + }; + self.read_page(page_id).await + } + "create_page" => { + let properties = match args.get("properties") { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("create_page requires properties parameter".into()), + }); + } + }; + let database_id = args.get("database_id").and_then(|v| v.as_str()); + self.create_page(properties, database_id).await + } + "update_page" => { + let page_id = match args.get("page_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("update_page requires page_id parameter".into()), + }); + } + }; + let properties = match args.get("properties") { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("update_page requires properties 
parameter".into()), + }); + } + }; + self.update_page(page_id, properties).await + } + "search" => { + let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); + self.search(query).await + } + _ => unreachable!(), // Already handled above + }; + + match result { + Ok(value) => Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::SecurityPolicy; + + fn test_tool() -> NotionTool { + let security = Arc::new(SecurityPolicy::default()); + NotionTool::new("test-key".into(), security) + } + + #[test] + fn tool_name_is_notion() { + let tool = test_tool(); + assert_eq!(tool.name(), "notion"); + } + + #[test] + fn parameters_schema_has_required_action() { + let tool = test_tool(); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.iter().any(|v| v.as_str() == Some("action"))); + } + + #[test] + fn parameters_schema_defines_all_actions() { + let tool = test_tool(); + let schema = tool.parameters_schema(); + let actions = schema["properties"]["action"]["enum"].as_array().unwrap(); + let action_strs: Vec<&str> = actions.iter().filter_map(|v| v.as_str()).collect(); + assert!(action_strs.contains(&"query_database")); + assert!(action_strs.contains(&"read_page")); + assert!(action_strs.contains(&"create_page")); + assert!(action_strs.contains(&"update_page")); + assert!(action_strs.contains(&"search")); + } + + #[tokio::test] + async fn execute_missing_action_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("action")); + } + + #[tokio::test] + async fn execute_unknown_action_returns_error() { + let tool = 
test_tool(); + let result = tool.execute(json!({"action": "invalid"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Unknown action")); + } + + #[tokio::test] + async fn execute_query_database_missing_id_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "query_database"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("database_id")); + } + + #[tokio::test] + async fn execute_read_page_missing_id_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "read_page"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("page_id")); + } + + #[tokio::test] + async fn execute_create_page_missing_properties_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "create_page"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("properties")); + } + + #[tokio::test] + async fn execute_update_page_missing_page_id_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "update_page", "properties": {}})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("page_id")); + } + + #[tokio::test] + async fn execute_update_page_missing_properties_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "update_page", "page_id": "test-id"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("properties")); + } +} From d9ab017df09aa5c8a7fb58caf63d2f0c734d1cb7 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 01:44:39 -0400 Subject: [PATCH 07/11] feat(tools): add Microsoft 365 integration via Graph API (#3653) Add Microsoft 365 tool providing access to Outlook mail, Teams messages, Calendar 
events, OneDrive files, and SharePoint search via Microsoft Graph API. Includes OAuth2 token caching (client credentials and device code flows), security policy enforcement, and config validation. Rebased on latest master, resolving conflicts with SwarmConfig exports and adding approval_manager to ChannelRuntimeContext test constructors. Original work by @rareba. Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 19 +- src/config/schema.rs | 171 ++++++ src/daemon/mod.rs | 2 +- src/main.rs | 4 +- src/onboard/wizard.rs | 2 + src/tools/microsoft365/auth.rs | 400 +++++++++++++ src/tools/microsoft365/graph_client.rs | 495 ++++++++++++++++ src/tools/microsoft365/mod.rs | 567 +++++++++++++++++++ src/tools/microsoft365/types.rs | 55 ++ src/tools/mod.rs | 57 ++ src/tools/project_intel.rs | 750 +++++++++++++++++++++++++ src/tools/report_templates.rs | 582 +++++++++++++++++++ 12 files changed, 3092 insertions(+), 12 deletions(-) create mode 100644 src/tools/microsoft365/auth.rs create mode 100644 src/tools/microsoft365/graph_client.rs create mode 100644 src/tools/microsoft365/mod.rs create mode 100644 src/tools/microsoft365/types.rs create mode 100644 src/tools/project_intel.rs create mode 100644 src/tools/report_templates.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 42ca6616f..a82209d84 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -12,15 +12,16 @@ pub use schema::{ ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, - McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig, - NextcloudTalkConfig, NodesConfig, NotionConfig, ObservabilityConfig, OpenAiTtsConfig, - OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, - ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, 
ReliabilityConfig, - ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, - SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, - StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, - SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, - TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, + McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, + MultimodalConfig, NextcloudTalkConfig, NodesConfig, NotionConfig, ObservabilityConfig, + OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig, + PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, + ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, + SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode, + SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, + SwarmConfig, SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, + TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, + WorkspaceConfig, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index 99dc93dcf..6dca7cf8f 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -188,6 +188,10 @@ pub struct Config { #[serde(default)] pub composio: ComposioConfig, + /// Microsoft 365 Graph API integration (`[microsoft365]`). + #[serde(default)] + pub microsoft365: Microsoft365Config, + /// Secrets encryption configuration (`[secrets]`). 
#[serde(default)] pub secrets: SecretsConfig, @@ -1380,6 +1384,78 @@ impl Default for ComposioConfig { } } +// ── Microsoft 365 (Graph API integration) ─────────────────────── + +/// Microsoft 365 integration via Microsoft Graph API (`[microsoft365]` section). +/// +/// Provides access to Outlook mail, Teams messages, Calendar events, +/// OneDrive files, and SharePoint search. +#[derive(Clone, Serialize, Deserialize, JsonSchema)] +pub struct Microsoft365Config { + /// Enable Microsoft 365 integration + #[serde(default, alias = "enable")] + pub enabled: bool, + /// Azure AD tenant ID + #[serde(default)] + pub tenant_id: Option, + /// Azure AD application (client) ID + #[serde(default)] + pub client_id: Option, + /// Azure AD client secret (stored encrypted when secrets.encrypt = true) + #[serde(default)] + pub client_secret: Option, + /// Authentication flow: "client_credentials" or "device_code" + #[serde(default = "default_ms365_auth_flow")] + pub auth_flow: String, + /// OAuth scopes to request + #[serde(default = "default_ms365_scopes")] + pub scopes: Vec, + /// Encrypt the token cache file on disk + #[serde(default = "default_true")] + pub token_cache_encrypted: bool, + /// User principal name or "me" (for delegated flows) + #[serde(default)] + pub user_id: Option, +} + +fn default_ms365_auth_flow() -> String { + "client_credentials".to_string() +} + +fn default_ms365_scopes() -> Vec { + vec!["https://graph.microsoft.com/.default".to_string()] +} + +impl std::fmt::Debug for Microsoft365Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Microsoft365Config") + .field("enabled", &self.enabled) + .field("tenant_id", &self.tenant_id) + .field("client_id", &self.client_id) + .field("client_secret", &self.client_secret.as_ref().map(|_| "***")) + .field("auth_flow", &self.auth_flow) + .field("scopes", &self.scopes) + .field("token_cache_encrypted", &self.token_cache_encrypted) + .field("user_id", &self.user_id) + 
.finish() + } +} + +impl Default for Microsoft365Config { + fn default() -> Self { + Self { + enabled: false, + tenant_id: None, + client_id: None, + client_secret: None, + auth_flow: default_ms365_auth_flow(), + scopes: default_ms365_scopes(), + token_cache_encrypted: true, + user_id: None, + } + } +} + // ── Secrets (encrypted credential store) ──────────────────────── /// Secrets encryption configuration (`[secrets]` section). @@ -4549,6 +4625,7 @@ impl Default for Config { tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), @@ -5045,6 +5122,11 @@ impl Config { &mut config.composio.api_key, "config.composio.api_key", )?; + decrypt_optional_secret( + &store, + &mut config.microsoft365.client_secret, + "config.microsoft365.client_secret", + )?; decrypt_optional_secret( &store, @@ -5619,6 +5701,88 @@ impl Config { } } + // Microsoft 365 + if self.microsoft365.enabled { + let tenant = self + .microsoft365 + .tenant_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if tenant.is_none() { + anyhow::bail!( + "microsoft365.tenant_id must not be empty when microsoft365 is enabled" + ); + } + let client = self + .microsoft365 + .client_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if client.is_none() { + anyhow::bail!( + "microsoft365.client_id must not be empty when microsoft365 is enabled" + ); + } + let flow = self.microsoft365.auth_flow.trim(); + if flow != "client_credentials" && flow != "device_code" { + anyhow::bail!( + "microsoft365.auth_flow must be 'client_credentials' or 'device_code'" + ); + } + if flow == "client_credentials" + && self + .microsoft365 + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + anyhow::bail!( + "microsoft365.client_secret must not be empty when auth_flow is 
'client_credentials'" + ); + } + } + + // Microsoft 365 + if self.microsoft365.enabled { + let tenant = self + .microsoft365 + .tenant_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if tenant.is_none() { + anyhow::bail!( + "microsoft365.tenant_id must not be empty when microsoft365 is enabled" + ); + } + let client = self + .microsoft365 + .client_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if client.is_none() { + anyhow::bail!( + "microsoft365.client_id must not be empty when microsoft365 is enabled" + ); + } + let flow = self.microsoft365.auth_flow.trim(); + if flow != "client_credentials" && flow != "device_code" { + anyhow::bail!("microsoft365.auth_flow must be client_credentials or device_code"); + } + if flow == "client_credentials" + && self + .microsoft365 + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + anyhow::bail!("microsoft365.client_secret must not be empty when auth_flow is client_credentials"); + } + } + // MCP if self.mcp.enabled { validate_mcp_config(&self.mcp)?; @@ -6020,6 +6184,11 @@ impl Config { &mut config_to_save.composio.api_key, "config.composio.api_key", )?; + encrypt_optional_secret( + &store, + &mut config_to_save.microsoft365.client_secret, + "config.microsoft365.client_secret", + )?; encrypt_optional_secret( &store, @@ -6724,6 +6893,7 @@ default_temperature = 0.7 tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), @@ -7018,6 +7188,7 @@ tool_dispatcher = "xml" tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), diff --git a/src/daemon/mod.rs 
b/src/daemon/mod.rs index f695493ad..267dae28a 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { max_backoff, move || { let cfg = channels_cfg.clone(); - async move { crate::channels::start_channels(cfg).await } + async move { Box::pin(crate::channels::start_channels(cfg)).await } }, )); } else { diff --git a/src/main.rs b/src/main.rs index 29ec8ab39..e2d04c736 100644 --- a/src/main.rs +++ b/src/main.rs @@ -844,7 +844,7 @@ async fn main() -> Result<()> { // Auto-start channels if user said yes during wizard if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { - channels::start_channels(config).await?; + Box::pin(channels::start_channels(config)).await?; } return Ok(()); } @@ -1189,7 +1189,7 @@ async fn main() -> Result<()> { }, Commands::Channel { channel_command } => match channel_command { - ChannelCommands::Start => channels::start_channels(config).await, + ChannelCommands::Start => Box::pin(channels::start_channels(config)).await, ChannelCommands::Doctor => channels::doctor_channels(config).await, other => channels::handle_command(other, &config).await, }, diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 9dffb79d9..189d39f19 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -159,6 +159,7 @@ pub async fn run_wizard(force: bool) -> Result { tunnel: tunnel_config, gateway: crate::config::GatewayConfig::default(), composio: composio_config, + microsoft365: crate::config::Microsoft365Config::default(), secrets: secrets_config, browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), @@ -519,6 +520,7 @@ async fn run_quick_setup_with_home( tunnel: crate::config::TunnelConfig::default(), gateway: crate::config::GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: crate::config::Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: 
BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), diff --git a/src/tools/microsoft365/auth.rs b/src/tools/microsoft365/auth.rs new file mode 100644 index 000000000..07afd4b14 --- /dev/null +++ b/src/tools/microsoft365/auth.rs @@ -0,0 +1,400 @@ +use anyhow::Context; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::path::PathBuf; +use tokio::sync::Mutex; + +/// Cached OAuth2 token state persisted to disk between runs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CachedTokenState { + pub access_token: String, + pub refresh_token: Option, + /// Unix timestamp (seconds) when the access token expires. + pub expires_at: i64, +} + +impl CachedTokenState { + /// Returns `true` when the token is expired or will expire within 60 seconds. + pub fn is_expired(&self) -> bool { + let now = chrono::Utc::now().timestamp(); + self.expires_at <= now + 60 + } +} + +/// Thread-safe token cache with disk persistence. +pub struct TokenCache { + inner: RwLock>, + /// Serialises the slow acquire/refresh path so only one caller performs the + /// network round-trip while others wait and then read the updated cache. + acquire_lock: Mutex<()>, + config: super::types::Microsoft365ResolvedConfig, + cache_path: PathBuf, +} + +impl TokenCache { + pub fn new( + config: super::types::Microsoft365ResolvedConfig, + zeroclaw_dir: &std::path::Path, + ) -> anyhow::Result { + if config.token_cache_encrypted { + anyhow::bail!( + "microsoft365: token_cache_encrypted is enabled but encryption is not yet \ + implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \ + to false or wait for encryption support." + ); + } + + // Scope cache file to (tenant_id, client_id, auth_flow) so config + // changes never reuse tokens from a different account/flow. 
+ let mut hasher = DefaultHasher::new(); + config.tenant_id.hash(&mut hasher); + config.client_id.hash(&mut hasher); + config.auth_flow.hash(&mut hasher); + let fingerprint = format!("{:016x}", hasher.finish()); + + let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json")); + let cached = Self::load_from_disk(&cache_path); + Ok(Self { + inner: RwLock::new(cached), + acquire_lock: Mutex::new(()), + config, + cache_path, + }) + } + + /// Get a valid access token, refreshing or re-authenticating as needed. + pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result { + // Fast path: cached and not expired. + { + let guard = self.inner.read(); + if let Some(ref state) = *guard { + if !state.is_expired() { + return Ok(state.access_token.clone()); + } + } + } + + // Slow path: serialise through a mutex so only one caller performs the + // network round-trip while concurrent callers wait and re-check. + let _lock = self.acquire_lock.lock().await; + + // Re-check after acquiring the lock — another caller may have refreshed + // while we were waiting. + { + let guard = self.inner.read(); + if let Some(ref state) = *guard { + if !state.is_expired() { + return Ok(state.access_token.clone()); + } + } + } + + let new_state = self.acquire_token(client).await?; + let token = new_state.access_token.clone(); + self.persist_to_disk(&new_state); + *self.inner.write() = Some(new_state); + Ok(token) + } + + async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result { + // Try refresh first if we have a refresh token and the flow supports it. + // Client credentials flow does not issue refresh tokens, so skip the + // attempt entirely to avoid a wasted round-trip. + if self.config.auth_flow.as_str() != "client_credentials" { + // Clone the token out so the RwLock guard is dropped before the await. 
+ let refresh_token_copy = { + let guard = self.inner.read(); + guard.as_ref().and_then(|state| state.refresh_token.clone()) + }; + if let Some(refresh_tok) = refresh_token_copy { + match self.refresh_token(client, &refresh_tok).await { + Ok(new_state) => return Ok(new_state), + Err(e) => { + tracing::debug!("ms365: refresh token failed, re-authenticating: {e}"); + } + } + } + } + + match self.config.auth_flow.as_str() { + "client_credentials" => self.client_credentials_flow(client).await, + "device_code" => self.device_code_flow(client).await, + other => anyhow::bail!("Unsupported auth flow: {other}"), + } + } + + async fn client_credentials_flow( + &self, + client: &reqwest::Client, + ) -> anyhow::Result { + let client_secret = self + .config + .client_secret + .as_deref() + .context("client_credentials flow requires client_secret")?; + + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let scope = self.config.scopes.join(" "); + + let resp = client + .post(&token_url) + .form(&[ + ("grant_type", "client_credentials"), + ("client_id", &self.config.client_id), + ("client_secret", client_secret), + ("scope", &scope), + ]) + .send() + .await + .context("ms365: failed to request client_credentials token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: client_credentials raw OAuth error: {body}"); + anyhow::bail!("ms365: client_credentials token request failed ({status})"); + } + + let token_resp: TokenResponse = resp + .json() + .await + .context("ms365: failed to parse token response")?; + + Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp.refresh_token, + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }) + } + + async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result { + let device_code_url = format!( + 
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode", + self.config.tenant_id + ); + let scope = self.config.scopes.join(" "); + + let resp = client + .post(&device_code_url) + .form(&[ + ("client_id", self.config.client_id.as_str()), + ("scope", &scope), + ]) + .send() + .await + .context("ms365: failed to request device code")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: device_code initiation raw error: {body}"); + anyhow::bail!("ms365: device code request failed ({status})"); + } + + let device_resp: DeviceCodeResponse = resp + .json() + .await + .context("ms365: failed to parse device code response")?; + + // Log only a generic prompt; the full device_resp.message may contain + // sensitive verification URIs or codes that should not appear in logs. + tracing::info!( + "ms365: device code auth required — follow the instructions shown to the user" + ); + // Print the user-facing message to stderr so the operator can act on it + // without it being captured in structured log sinks. 
+ eprintln!("ms365: {}", device_resp.message); + + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let interval = device_resp.interval.max(5); + let max_polls = u32::try_from( + (device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1), + ) + .unwrap_or(u32::MAX); + + for _ in 0..max_polls { + tokio::time::sleep(std::time::Duration::from_secs(interval)).await; + + let poll_resp = client + .post(&token_url) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("client_id", self.config.client_id.as_str()), + ("device_code", &device_resp.device_code), + ]) + .send() + .await + .context("ms365: failed to poll device code token")?; + + if poll_resp.status().is_success() { + let token_resp: TokenResponse = poll_resp + .json() + .await + .context("ms365: failed to parse token response")?; + return Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp.refresh_token, + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }); + } + + let body = poll_resp.text().await.unwrap_or_default(); + if body.contains("authorization_pending") { + continue; + } + if body.contains("slow_down") { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + continue; + } + tracing::debug!("ms365: device code polling raw error: {body}"); + anyhow::bail!("ms365: device code polling failed"); + } + + anyhow::bail!("ms365: device code flow timed out waiting for user authorization") + } + + async fn refresh_token( + &self, + client: &reqwest::Client, + refresh_token: &str, + ) -> anyhow::Result { + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let mut params = vec![ + ("grant_type", "refresh_token"), + ("client_id", self.config.client_id.as_str()), + ("refresh_token", refresh_token), + ]; + + let secret_ref; + if let Some(ref secret) = 
self.config.client_secret { + secret_ref = secret.as_str(); + params.push(("client_secret", secret_ref)); + } + + let resp = client + .post(&token_url) + .form(¶ms) + .send() + .await + .context("ms365: failed to refresh token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: token refresh raw error: {body}"); + anyhow::bail!("ms365: token refresh failed ({status})"); + } + + let token_resp: TokenResponse = resp + .json() + .await + .context("ms365: failed to parse refresh token response")?; + + Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp + .refresh_token + .or_else(|| Some(refresh_token.to_string())), + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }) + } + + fn load_from_disk(path: &std::path::Path) -> Option { + let data = std::fs::read_to_string(path).ok()?; + serde_json::from_str(&data).ok() + } + + fn persist_to_disk(&self, state: &CachedTokenState) { + if let Ok(json) = serde_json::to_string_pretty(state) { + if let Err(e) = std::fs::write(&self.cache_path, json) { + tracing::warn!("ms365: failed to persist token cache: {e}"); + } + } + } +} + +#[derive(Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default = "default_expires_in")] + expires_in: i64, +} + +fn default_expires_in() -> i64 { + 3600 +} + +#[derive(Deserialize)] +struct DeviceCodeResponse { + device_code: String, + message: String, + #[serde(default = "default_device_interval")] + interval: u64, + #[serde(default = "default_device_expires_in")] + expires_in: i64, +} + +fn default_device_interval() -> u64 { + 5 +} + +fn default_device_expires_in() -> i64 { + 900 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_is_expired_when_past_deadline() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + 
expires_at: chrono::Utc::now().timestamp() - 10, + }; + assert!(state.is_expired()); + } + + #[test] + fn token_is_expired_within_buffer() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + expires_at: chrono::Utc::now().timestamp() + 30, + }; + assert!(state.is_expired()); + } + + #[test] + fn token_is_valid_when_far_from_expiry() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + expires_at: chrono::Utc::now().timestamp() + 3600, + }; + assert!(!state.is_expired()); + } + + #[test] + fn load_from_disk_returns_none_for_missing_file() { + let path = std::path::Path::new("/nonexistent/ms365_token_cache.json"); + assert!(TokenCache::load_from_disk(path).is_none()); + } +} diff --git a/src/tools/microsoft365/graph_client.rs b/src/tools/microsoft365/graph_client.rs new file mode 100644 index 000000000..0cda00247 --- /dev/null +++ b/src/tools/microsoft365/graph_client.rs @@ -0,0 +1,495 @@ +use anyhow::Context; + +const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0"; + +/// Build the user path segment: `/me` or `/users/{user_id}`. +/// The user_id is percent-encoded to prevent path-traversal attacks. +fn user_path(user_id: &str) -> String { + if user_id == "me" { + "/me".to_string() + } else { + format!("/users/{}", urlencoding::encode(user_id)) + } +} + +/// Percent-encode a single path segment to prevent path-traversal attacks. +fn encode_path_segment(segment: &str) -> String { + urlencoding::encode(segment).into_owned() +} + +/// List mail messages for a user. 
+pub async fn mail_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + folder: Option<&str>, + top: u32, +) -> anyhow::Result { + let base = user_path(user_id); + let path = match folder { + Some(f) => format!( + "{GRAPH_BASE}{base}/mailFolders/{}/messages", + encode_path_segment(f) + ), + None => format!("{GRAPH_BASE}{base}/messages"), + }; + + let resp = client + .get(&path) + .bearer_auth(token) + .query(&[("$top", top.to_string())]) + .send() + .await + .context("ms365: mail_list request failed")?; + + handle_json_response(resp, "mail_list").await +} + +/// Send a mail message. +pub async fn mail_send( + client: &reqwest::Client, + token: &str, + user_id: &str, + to: &[String], + subject: &str, + body: &str, +) -> anyhow::Result<()> { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/sendMail"); + + let to_recipients: Vec = to + .iter() + .map(|addr| { + serde_json::json!({ + "emailAddress": { "address": addr } + }) + }) + .collect(); + + let payload = serde_json::json!({ + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": body + }, + "toRecipients": to_recipients + } + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: mail_send request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: mail_send raw error body: {body}"); + anyhow::bail!("ms365: mail_send failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List messages in a Teams channel. 
+pub async fn teams_message_list( + client: &reqwest::Client, + token: &str, + team_id: &str, + channel_id: &str, + top: u32, +) -> anyhow::Result { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment(team_id), + encode_path_segment(channel_id) + ); + + let resp = client + .get(&url) + .bearer_auth(token) + .query(&[("$top", top.to_string())]) + .send() + .await + .context("ms365: teams_message_list request failed")?; + + handle_json_response(resp, "teams_message_list").await +} + +/// Send a message to a Teams channel. +pub async fn teams_message_send( + client: &reqwest::Client, + token: &str, + team_id: &str, + channel_id: &str, + body: &str, +) -> anyhow::Result<()> { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment(team_id), + encode_path_segment(channel_id) + ); + + let payload = serde_json::json!({ + "body": { + "content": body + } + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: teams_message_send request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: teams_message_send raw error body: {body}"); + anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List calendar events in a date range. 
+pub async fn calendar_events_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + start: &str, + end: &str, + top: u32, +) -> anyhow::Result { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/calendarView"); + + let resp = client + .get(&url) + .bearer_auth(token) + .query(&[ + ("startDateTime", start.to_string()), + ("endDateTime", end.to_string()), + ("$top", top.to_string()), + ]) + .send() + .await + .context("ms365: calendar_events_list request failed")?; + + handle_json_response(resp, "calendar_events_list").await +} + +/// Create a calendar event. +pub async fn calendar_event_create( + client: &reqwest::Client, + token: &str, + user_id: &str, + subject: &str, + start: &str, + end: &str, + attendees: &[String], + body_text: Option<&str>, +) -> anyhow::Result { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/events"); + + let attendee_list: Vec = attendees + .iter() + .map(|email| { + serde_json::json!({ + "emailAddress": { "address": email }, + "type": "required" + }) + }) + .collect(); + + let mut payload = serde_json::json!({ + "subject": subject, + "start": { + "dateTime": start, + "timeZone": "UTC" + }, + "end": { + "dateTime": end, + "timeZone": "UTC" + }, + "attendees": attendee_list + }); + + if let Some(text) = body_text { + payload["body"] = serde_json::json!({ + "contentType": "Text", + "content": text + }); + } + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: calendar_event_create request failed")?; + + let value = handle_json_response(resp, "calendar_event_create").await?; + let event_id = value["id"].as_str().unwrap_or("unknown").to_string(); + Ok(event_id) +} + +/// Delete a calendar event by ID. 
+pub async fn calendar_event_delete( + client: &reqwest::Client, + token: &str, + user_id: &str, + event_id: &str, +) -> anyhow::Result<()> { + let base = user_path(user_id); + let url = format!( + "{GRAPH_BASE}{base}/events/{}", + encode_path_segment(event_id) + ); + + let resp = client + .delete(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: calendar_event_delete request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: calendar_event_delete raw error body: {body}"); + anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List children of a OneDrive folder. +pub async fn onedrive_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + path: Option<&str>, +) -> anyhow::Result { + let base = user_path(user_id); + let url = match path { + Some(p) if !p.is_empty() => { + let encoded = urlencoding::encode(p); + format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children") + } + _ => format!("{GRAPH_BASE}{base}/drive/root/children"), + }; + + let resp = client + .get(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: onedrive_list request failed")?; + + handle_json_response(resp, "onedrive_list").await +} + +/// Download a OneDrive item by ID, with a maximum size guard. 
+pub async fn onedrive_download( + client: &reqwest::Client, + token: &str, + user_id: &str, + item_id: &str, + max_size: usize, +) -> anyhow::Result> { + let base = user_path(user_id); + let url = format!( + "{GRAPH_BASE}{base}/drive/items/{}/content", + encode_path_segment(item_id) + ); + + let resp = client + .get(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: onedrive_download request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: onedrive_download raw error body: {body}"); + anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})"); + } + + let bytes = resp + .bytes() + .await + .context("ms365: failed to read download body")?; + if bytes.len() > max_size { + anyhow::bail!( + "ms365: downloaded file exceeds max_size ({} > {max_size})", + bytes.len() + ); + } + + Ok(bytes.to_vec()) +} + +/// Search SharePoint for documents matching a query. +pub async fn sharepoint_search( + client: &reqwest::Client, + token: &str, + query: &str, + top: u32, +) -> anyhow::Result { + let url = format!("{GRAPH_BASE}/search/query"); + + let payload = serde_json::json!({ + "requests": [{ + "entityTypes": ["driveItem", "listItem", "site"], + "query": { + "queryString": query + }, + "from": 0, + "size": top + }] + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: sharepoint_search request failed")?; + + handle_json_response(resp, "sharepoint_search").await +} + +/// Extract a short, safe error code from a Graph API JSON error body. +/// Returns `None` when the body is not a recognised Graph error envelope. 
+fn extract_graph_error_code(body: &str) -> Option { + let parsed: serde_json::Value = serde_json::from_str(body).ok()?; + let code = parsed + .get("error") + .and_then(|e| e.get("code")) + .and_then(|c| c.as_str()) + .map(|s| s.to_string()); + code +} + +/// Parse a JSON response body, returning an error on non-success status. +/// Raw Graph API error bodies are not propagated; only the HTTP status and a +/// short error code (when available) are surfaced to avoid leaking internal +/// API details. +async fn handle_json_response( + resp: reqwest::Response, + operation: &str, +) -> anyhow::Result { + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: {operation} raw error body: {body}"); + anyhow::bail!("ms365: {operation} failed ({status}, code={code})"); + } + + resp.json() + .await + .with_context(|| format!("ms365: failed to parse {operation} response")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn user_path_me() { + assert_eq!(user_path("me"), "/me"); + } + + #[test] + fn user_path_specific_user() { + assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com"); + } + + #[test] + fn mail_list_url_no_folder() { + let base = user_path("me"); + let url = format!("{GRAPH_BASE}{base}/messages"); + assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages"); + } + + #[test] + fn mail_list_url_with_folder() { + let base = user_path("me"); + let folder = "inbox"; + let url = format!( + "{GRAPH_BASE}{base}/mailFolders/{}/messages", + encode_path_segment(folder) + ); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages" + ); + } + + #[test] + fn calendar_view_url() { + let base = user_path("user@example.com"); + let url = format!("{GRAPH_BASE}{base}/calendarView"); + assert_eq!( + url, + 
"https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView" + ); + } + + #[test] + fn teams_message_url() { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment("team-123"), + encode_path_segment("channel-456") + ); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages" + ); + } + + #[test] + fn onedrive_root_url() { + let base = user_path("me"); + let url = format!("{GRAPH_BASE}{base}/drive/root/children"); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/drive/root/children" + ); + } + + #[test] + fn onedrive_path_url() { + let base = user_path("me"); + let encoded = urlencoding::encode("Documents/Reports"); + let url = format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children"); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children" + ); + } + + #[test] + fn sharepoint_search_url() { + let url = format!("{GRAPH_BASE}/search/query"); + assert_eq!(url, "https://graph.microsoft.com/v1.0/search/query"); + } +} diff --git a/src/tools/microsoft365/mod.rs b/src/tools/microsoft365/mod.rs new file mode 100644 index 000000000..1876556e5 --- /dev/null +++ b/src/tools/microsoft365/mod.rs @@ -0,0 +1,567 @@ +//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar, +//! OneDrive, and SharePoint via a single action-dispatched tool surface. +//! +//! Auth is handled through direct HTTP calls to the Microsoft identity platform +//! (client credentials or device code flow) with token caching. + +pub mod auth; +pub mod graph_client; +pub mod types; + +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +/// Maximum download size for OneDrive files (10 MB). 
+const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024; + +/// Default number of items to return in list operations. +const DEFAULT_TOP: u32 = 25; + +pub struct Microsoft365Tool { + config: types::Microsoft365ResolvedConfig, + security: Arc, + token_cache: Arc, + http_client: reqwest::Client, +} + +impl Microsoft365Tool { + pub fn new( + config: types::Microsoft365ResolvedConfig, + security: Arc, + zeroclaw_dir: &std::path::Path, + ) -> anyhow::Result { + let http_client = + crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10); + let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?); + Ok(Self { + config, + security, + token_cache, + http_client, + }) + } + + async fn get_token(&self) -> anyhow::Result { + self.token_cache.get_token(&self.http_client).await + } + + fn user_id(&self) -> &str { + &self.config.user_id + } + + async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result { + match action { + "mail_list" => self.handle_mail_list(args).await, + "mail_send" => self.handle_mail_send(args).await, + "teams_message_list" => self.handle_teams_message_list(args).await, + "teams_message_send" => self.handle_teams_message_send(args).await, + "calendar_events_list" => self.handle_calendar_events_list(args).await, + "calendar_event_create" => self.handle_calendar_event_create(args).await, + "calendar_event_delete" => self.handle_calendar_event_delete(args).await, + "onedrive_list" => self.handle_onedrive_list(args).await, + "onedrive_download" => self.handle_onedrive_download(args).await, + "sharepoint_search" => self.handle_sharepoint_search(args).await, + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown action: {action}")), + }), + } + } + + // ── Read actions ──────────────────────────────────────────────── + + async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + 
.enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let folder = args["folder"].as_str(); + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = + graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_teams_message_list( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let team_id = args["team_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("team_id is required"))?; + let channel_id = args["channel_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("channel_id is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = + graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top) + .await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_calendar_events_list( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let start = args["start"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("start datetime is required"))?; + let end = args["end"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("end datetime is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = graph_client::calendar_events_list( + 
&self.http_client, + &token, + self.user_id(), + start, + end, + top, + ) + .await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let path = args["path"].as_str(); + + let result = + graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_onedrive_download( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let item_id = args["item_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("item_id is required"))?; + let max_size = args["max_size"] + .as_u64() + .and_then(|v| usize::try_from(v).ok()) + .unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE) + .min(MAX_ONEDRIVE_DOWNLOAD_SIZE); + + let bytes = graph_client::onedrive_download( + &self.http_client, + &token, + self.user_id(), + item_id, + max_size, + ) + .await?; + + // Return base64-encoded for binary safety. 
+ use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + + Ok(ToolResult { + success: true, + output: format!( + "Downloaded {} bytes (base64 encoded):\n{encoded}", + bytes.len() + ), + error: None, + }) + } + + async fn handle_sharepoint_search( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let query = args["query"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("query is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + // ── Write actions ─────────────────────────────────────────────── + + async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let to: Vec = args["to"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))? 
+ .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + + if to.is_empty() { + anyhow::bail!("to must contain at least one email address"); + } + + let subject = args["subject"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("subject is required"))?; + let body = args["body"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("body is required"))?; + + graph_client::mail_send( + &self.http_client, + &token, + self.user_id(), + &to, + subject, + body, + ) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Email sent to: {}", to.join(", ")), + error: None, + }) + } + + async fn handle_teams_message_send( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let team_id = args["team_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("team_id is required"))?; + let channel_id = args["channel_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("channel_id is required"))?; + let body = args["body"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("body is required"))?; + + graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body) + .await?; + + Ok(ToolResult { + success: true, + output: "Teams message sent".to_string(), + error: None, + }) + } + + async fn handle_calendar_event_create( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let subject = args["subject"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("subject is required"))?; + let start = args["start"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("start datetime is required"))?; + let end = args["end"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("end datetime is required"))?; + let attendees: Vec 
= args["attendees"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + let body_text = args["body"].as_str(); + + let event_id = graph_client::calendar_event_create( + &self.http_client, + &token, + self.user_id(), + subject, + start, + end, + &attendees, + body_text, + ) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Calendar event created (id: {event_id})"), + error: None, + }) + } + + async fn handle_calendar_event_delete( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let event_id = args["event_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("event_id is required"))?; + + graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Calendar event {event_id} deleted"), + error: None, + }) + } +} + +#[async_trait] +impl Tool for Microsoft365Tool { + fn name(&self) -> &str { + "microsoft365" + } + + fn description(&self) -> &str { + "Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \ + OneDrive files, and SharePoint search via Microsoft Graph API" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search" + ], + "description": "The Microsoft 365 action to perform" + }, + "folder": { + "type": "string", + "description": "Mail folder ID (for mail_list, e.g. 
'inbox', 'sentitems')" + }, + "to": { + "type": "array", + "items": { "type": "string" }, + "description": "Recipient email addresses (for mail_send)" + }, + "subject": { + "type": "string", + "description": "Email subject or calendar event subject" + }, + "body": { + "type": "string", + "description": "Message body text" + }, + "team_id": { + "type": "string", + "description": "Teams team ID (for teams_message_list/send)" + }, + "channel_id": { + "type": "string", + "description": "Teams channel ID (for teams_message_list/send)" + }, + "start": { + "type": "string", + "description": "Start datetime in ISO 8601 format (for calendar actions)" + }, + "end": { + "type": "string", + "description": "End datetime in ISO 8601 format (for calendar actions)" + }, + "attendees": { + "type": "array", + "items": { "type": "string" }, + "description": "Attendee email addresses (for calendar_event_create)" + }, + "event_id": { + "type": "string", + "description": "Calendar event ID (for calendar_event_delete)" + }, + "path": { + "type": "string", + "description": "OneDrive folder path (for onedrive_list)" + }, + "item_id": { + "type": "string", + "description": "OneDrive item ID (for onedrive_download)" + }, + "max_size": { + "type": "integer", + "description": "Maximum download size in bytes (for onedrive_download, default 10MB)" + }, + "query": { + "type": "string", + "description": "Search query (for sharepoint_search)" + }, + "top": { + "type": "integer", + "description": "Maximum number of items to return (default 25)" + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = match args["action"].as_str() { + Some(a) => a.to_string(), + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'action' parameter is required".to_string()), + }); + } + }; + + match self.dispatch(&action, &args).await { + Ok(result) => Ok(result), + Err(e) => Ok(ToolResult { + success: false, + output: 
String::new(), + error: Some(format!("microsoft365.{action} failed: {e}")), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_name_is_microsoft365() { + // Verify the schema is valid JSON with the expected structure. + let schema_str = r#"{"type":"object","required":["action"]}"#; + let _: serde_json::Value = serde_json::from_str(schema_str).unwrap(); + } + + #[test] + fn parameters_schema_has_action_enum() { + let schema = json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search" + ] + } + } + }); + + let actions = schema["properties"]["action"]["enum"].as_array().unwrap(); + assert_eq!(actions.len(), 10); + assert!(actions.contains(&json!("mail_list"))); + assert!(actions.contains(&json!("sharepoint_search"))); + } + + #[test] + fn action_dispatch_table_is_exhaustive() { + let valid_actions = [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search", + ]; + assert_eq!(valid_actions.len(), 10); + assert!(!valid_actions.contains(&"invalid_action")); + } +} diff --git a/src/tools/microsoft365/types.rs b/src/tools/microsoft365/types.rs new file mode 100644 index 000000000..72a71f0a5 --- /dev/null +++ b/src/tools/microsoft365/types.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied. 
+#[derive(Clone, Serialize, Deserialize)] +pub struct Microsoft365ResolvedConfig { + pub tenant_id: String, + pub client_id: String, + pub client_secret: Option, + pub auth_flow: String, + pub scopes: Vec, + pub token_cache_encrypted: bool, + pub user_id: String, +} + +impl std::fmt::Debug for Microsoft365ResolvedConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Microsoft365ResolvedConfig") + .field("tenant_id", &self.tenant_id) + .field("client_id", &self.client_id) + .field("client_secret", &self.client_secret.as_ref().map(|_| "***")) + .field("auth_flow", &self.auth_flow) + .field("scopes", &self.scopes) + .field("token_cache_encrypted", &self.token_cache_encrypted) + .field("user_id", &self.user_id) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolved_config_serialization_roundtrip() { + let config = Microsoft365ResolvedConfig { + tenant_id: "test-tenant".into(), + client_id: "test-client".into(), + client_secret: Some("secret".into()), + auth_flow: "client_credentials".into(), + scopes: vec!["https://graph.microsoft.com/.default".into()], + token_cache_encrypted: false, + user_id: "me".into(), + }; + + let json = serde_json::to_string(&config).unwrap(); + let parsed: Microsoft365ResolvedConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.tenant_id, "test-tenant"); + assert_eq!(parsed.client_id, "test-client"); + assert_eq!(parsed.client_secret.as_deref(), Some("secret")); + assert_eq!(parsed.auth_flow, "client_credentials"); + assert_eq!(parsed.scopes.len(), 1); + assert_eq!(parsed.user_id, "me"); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 7ce81b2ca..5fe76ef6f 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -48,6 +48,7 @@ pub mod mcp_transport; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod microsoft365; pub mod model_routing_config; pub mod node_tool; pub mod notion_tool; @@ -95,6 +96,7 @@ pub use 
mcp_tool::McpToolWrapper; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use microsoft365::Microsoft365Tool; pub use model_routing_config::ModelRoutingConfigTool; #[allow(unused_imports)] pub use node_tool::NodeTool; @@ -379,6 +381,61 @@ pub fn all_tools_with_runtime( } } + // Microsoft 365 Graph API integration + if root_config.microsoft365.enabled { + let ms_cfg = &root_config.microsoft365; + let tenant_id = ms_cfg + .tenant_id + .as_deref() + .unwrap_or_default() + .trim() + .to_string(); + let client_id = ms_cfg + .client_id + .as_deref() + .unwrap_or_default() + .trim() + .to_string(); + if !tenant_id.is_empty() && !client_id.is_empty() { + // Fail fast: client_credentials flow requires a client_secret at registration time. + if ms_cfg.auth_flow.trim() == "client_credentials" + && ms_cfg + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + tracing::error!( + "microsoft365: client_credentials auth_flow requires a non-empty client_secret" + ); + return (boxed_registry_from_arcs(tool_arcs), None); + } + + let resolved = microsoft365::types::Microsoft365ResolvedConfig { + tenant_id, + client_id, + client_secret: ms_cfg.client_secret.clone(), + auth_flow: ms_cfg.auth_flow.clone(), + scopes: ms_cfg.scopes.clone(), + token_cache_encrypted: ms_cfg.token_cache_encrypted, + user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(), + }; + // Store token cache in the config directory (next to config.toml), + // not the workspace directory, to keep bearer tokens out of the + // project tree. 
+ let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir); + match Microsoft365Tool::new(resolved, security.clone(), cache_dir) { + Ok(tool) => tool_arcs.push(Arc::new(tool)), + Err(e) => { + tracing::error!("microsoft365: failed to initialize tool: {e}"); + } + } + } else { + tracing::warn!( + "microsoft365: skipped registration because tenant_id or client_id is empty" + ); + } + } + // Add delegation tool when agents are configured let delegate_fallback_credential = fallback_api_key.and_then(|value| { let trimmed_value = value.trim(); diff --git a/src/tools/project_intel.rs b/src/tools/project_intel.rs new file mode 100644 index 000000000..0e3372eb8 --- /dev/null +++ b/src/tools/project_intel.rs @@ -0,0 +1,750 @@ +//! Project delivery intelligence tool. +//! +//! Provides read-only analysis and generation for project management: +//! status reports, risk detection, client communication drafting, +//! sprint summaries, and effort estimation. + +use super::report_templates; +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::collections::HashMap; +use std::fmt::Write as _; + +/// Project intelligence tool for consulting project management. +/// +/// All actions are read-only analysis/generation; nothing is modified externally. +pub struct ProjectIntelTool { + default_language: String, + risk_sensitivity: RiskSensitivity, +} + +/// Risk detection sensitivity level. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RiskSensitivity { + Low, + Medium, + High, +} + +impl RiskSensitivity { + fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "low" => Self::Low, + "high" => Self::High, + _ => Self::Medium, + } + } + + /// Threshold multiplier: higher sensitivity means lower thresholds. 
+ fn threshold_factor(self) -> f64 { + match self { + Self::Low => 1.5, + Self::Medium => 1.0, + Self::High => 0.5, + } + } +} + +impl ProjectIntelTool { + pub fn new(default_language: String, risk_sensitivity: String) -> Self { + Self { + default_language, + risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity), + } + } + + fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result { + let project_name = args + .get("project_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?; + let period = args + .get("period") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?; + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + let git_log = args + .get("git_log") + .and_then(|v| v.as_str()) + .unwrap_or("No git data provided"); + let jira_summary = args + .get("jira_summary") + .and_then(|v| v.as_str()) + .unwrap_or("No Jira data provided"); + let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or(""); + + let tpl = report_templates::weekly_status_template(lang); + let mut vars = HashMap::new(); + vars.insert("project_name".into(), project_name.to_string()); + vars.insert("period".into(), period.to_string()); + vars.insert("completed".into(), git_log.to_string()); + vars.insert("in_progress".into(), jira_summary.to_string()); + vars.insert("blocked".into(), notes.to_string()); + vars.insert("next_steps".into(), "To be determined".into()); + + let rendered = tpl.render(&vars); + Ok(ToolResult { + success: true, + output: rendered, + error: None, + }) + } + + fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result { + let deadlines = args + .get("deadlines") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let velocity = args + .get("velocity") + .and_then(|v| v.as_str()) 
+ .unwrap_or_default(); + let blockers = args + .get("blockers") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + + let mut risks = Vec::new(); + + // Heuristic risk detection based on signals + let factor = self.risk_sensitivity.threshold_factor(); + + if !blockers.is_empty() { + let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count(); + let severity = if (blocker_count as f64) > 3.0 * factor { + "critical" + } else if (blocker_count as f64) > 1.0 * factor { + "high" + } else { + "medium" + }; + risks.push(RiskItem { + title: "Active blockers detected".into(), + severity: severity.into(), + detail: format!("{blocker_count} blocker(s) identified"), + mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(), + }); + } + + if deadlines.to_lowercase().contains("overdue") + || deadlines.to_lowercase().contains("missed") + { + risks.push(RiskItem { + title: "Deadline risk".into(), + severity: "high".into(), + detail: "Overdue or missed deadlines detected in project context".into(), + mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(), + }); + } + + if velocity.to_lowercase().contains("declining") || velocity.to_lowercase().contains("slow") + { + risks.push(RiskItem { + title: "Velocity degradation".into(), + severity: "medium".into(), + detail: "Team velocity is declining or below expectations".into(), + mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(), + }); + } + + if risks.is_empty() { + risks.push(RiskItem { + title: "No significant risks detected".into(), + severity: "low".into(), + detail: "Current project signals within normal parameters".into(), + mitigation: "Continue monitoring".into(), + }); + } + + let tpl = report_templates::risk_register_template(lang); + let risks_text = risks + .iter() + .map(|r| { + format!( + "- [{}] {}: {}", + 
r.severity.to_uppercase(), + r.title, + r.detail + ) + }) + .collect::<Vec<_>>() + .join("\n"); + let mitigations_text = risks + .iter() + .map(|r| format!("- {}: {}", r.title, r.mitigation)) + .collect::<Vec<_>>() + .join("\n"); + + let mut vars = HashMap::new(); + vars.insert( + "project_name".into(), + args.get("project_name") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown") + .to_string(), + ); + vars.insert("risks".into(), risks_text); + vars.insert("mitigations".into(), mitigations_text); + + Ok(ToolResult { + success: true, + output: tpl.render(&vars), + error: None, + }) + } + + fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> { + let project_name = args + .get("project_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?; + let audience = args + .get("audience") + .and_then(|v| v.as_str()) + .unwrap_or("client"); + let tone = args + .get("tone") + .and_then(|v| v.as_str()) + .unwrap_or("formal"); + let highlights = args + .get("highlights") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?; + let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or(""); + + let greeting = match (audience, tone) { + ("client", "casual") => "Hi there,".to_string(), + ("client", _) => "Dear valued partner,".to_string(), + ("internal", "casual") => "Hey team,".to_string(), + ("internal", _) => "Dear team,".to_string(), + (_, "casual") => "Hi,".to_string(), + _ => "Dear reader,".to_string(), + }; + + let closing = match tone { + "casual" => "Cheers", + _ => "Best regards", + }; + + let mut body = format!( + "{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}" + ); + if !concerns.is_empty() { + let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}"); + } + let _ = write!( + body, + "\n\nPlease do not 
hesitate to reach out with any questions.\n\n{closing}" + ); + + Ok(ToolResult { + success: true, + output: body, + error: None, + }) + } + + fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result { + let sprint_dates = args + .get("sprint_dates") + .and_then(|v| v.as_str()) + .unwrap_or("current sprint"); + let completed = args + .get("completed") + .and_then(|v| v.as_str()) + .unwrap_or("None specified"); + let in_progress = args + .get("in_progress") + .and_then(|v| v.as_str()) + .unwrap_or("None specified"); + let blocked = args + .get("blocked") + .and_then(|v| v.as_str()) + .unwrap_or("None"); + let velocity = args + .get("velocity") + .and_then(|v| v.as_str()) + .unwrap_or("Not calculated"); + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + + let tpl = report_templates::sprint_review_template(lang); + let mut vars = HashMap::new(); + vars.insert("sprint_dates".into(), sprint_dates.to_string()); + vars.insert("completed".into(), completed.to_string()); + vars.insert("in_progress".into(), in_progress.to_string()); + vars.insert("blocked".into(), blocked.to_string()); + vars.insert("velocity".into(), velocity.to_string()); + + Ok(ToolResult { + success: true, + output: tpl.render(&vars), + error: None, + }) + } + + fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result { + let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or(""); + + if tasks.trim().is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("No task descriptions provided".into()), + }); + } + + let mut estimates = Vec::new(); + for line in tasks.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + let (size, rationale) = estimate_task_effort(line); + estimates.push(format!("- **{size}** | {line}\n Rationale: {rationale}")); + } + + let output = format!( + "## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L 
(1-3d), XL (3-5d), XXL (>5d)_", + estimates.join("\n") + ); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +struct RiskItem { + title: String, + severity: String, + detail: String, + mitigation: String, +} + +/// Heuristic effort estimation from task description text. +fn estimate_task_effort(description: &str) -> (&'static str, &'static str) { + let lower = description.to_lowercase(); + let word_count = description.split_whitespace().count(); + + // Signal-based heuristics + let complexity_signals = [ + "refactor", + "rewrite", + "migrate", + "redesign", + "architecture", + "infrastructure", + ]; + let medium_signals = [ + "implement", + "create", + "build", + "integrate", + "add feature", + "new module", + ]; + let small_signals = [ + "fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config", + ]; + + if complexity_signals.iter().any(|s| lower.contains(s)) { + if word_count > 15 { + return ( + "XXL", + "Large-scope structural change with extensive description", + ); + } + return ("XL", "Structural change requiring significant effort"); + } + + if medium_signals.iter().any(|s| lower.contains(s)) { + if word_count > 12 { + return ("L", "Feature implementation with detailed requirements"); + } + return ("M", "Standard feature implementation"); + } + + if small_signals.iter().any(|s| lower.contains(s)) { + if word_count > 10 { + return ("S", "Small change with additional context"); + } + return ("XS", "Minor targeted change"); + } + + // Fallback: estimate by description length as a proxy for complexity + if word_count > 20 { + ("L", "Complex task inferred from detailed description") + } else if word_count > 10 { + ("M", "Moderate task inferred from description length") + } else { + ("S", "Simple task inferred from brief description") + } +} + +#[async_trait] +impl Tool for ProjectIntelTool { + fn name(&self) -> &str { + "project_intel" + } + + fn description(&self) -> &str { + "Project delivery intelligence: generate status 
reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"], + "description": "The analysis action to perform" + }, + "project_name": { + "type": "string", + "description": "Project name (for status_report, risk_scan, draft_update)" + }, + "period": { + "type": "string", + "description": "Reporting period: week, sprint, or month (for status_report)" + }, + "language": { + "type": "string", + "description": "Report language: en, de, fr, it (default from config)" + }, + "git_log": { + "type": "string", + "description": "Git log summary text (for status_report)" + }, + "jira_summary": { + "type": "string", + "description": "Jira/issue tracker summary (for status_report)" + }, + "notes": { + "type": "string", + "description": "Additional notes or context" + }, + "deadlines": { + "type": "string", + "description": "Deadline information (for risk_scan)" + }, + "velocity": { + "type": "string", + "description": "Team velocity data (for risk_scan, sprint_summary)" + }, + "blockers": { + "type": "string", + "description": "Current blockers (for risk_scan)" + }, + "audience": { + "type": "string", + "enum": ["client", "internal"], + "description": "Target audience (for draft_update)" + }, + "tone": { + "type": "string", + "enum": ["formal", "casual"], + "description": "Communication tone (for draft_update)" + }, + "highlights": { + "type": "string", + "description": "Key highlights for the update (for draft_update)" + }, + "concerns": { + "type": "string", + "description": "Items requiring attention (for draft_update)" + }, + "sprint_dates": { + "type": "string", + "description": "Sprint date range (for sprint_summary)" + }, + "completed": { + "type": "string", + "description": "Completed items 
(for sprint_summary)" + }, + "in_progress": { + "type": "string", + "description": "In-progress items (for sprint_summary)" + }, + "blocked": { + "type": "string", + "description": "Blocked items (for sprint_summary)" + }, + "tasks": { + "type": "string", + "description": "Task descriptions, one per line (for effort_estimate)" + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?; + + match action { + "status_report" => self.execute_status_report(&args), + "risk_scan" => self.execute_risk_scan(&args), + "draft_update" => self.execute_draft_update(&args), + "sprint_summary" => self.execute_sprint_summary(&args), + "effort_estimate" => self.execute_effort_estimate(&args), + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{other}'. 
Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate" + )), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tool() -> ProjectIntelTool { + ProjectIntelTool::new("en".into(), "medium".into()) + } + + #[test] + fn tool_name_and_description() { + let t = tool(); + assert_eq!(t.name(), "project_intel"); + assert!(!t.description().is_empty()); + } + + #[test] + fn parameters_schema_has_action() { + let t = tool(); + let schema = t.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::Value::String("action".into()))); + } + + #[tokio::test] + async fn status_report_renders() { + let t = tool(); + let result = t + .execute(json!({ + "action": "status_report", + "project_name": "TestProject", + "period": "week", + "git_log": "- feat: added login" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("TestProject")); + assert!(result.output.contains("added login")); + } + + #[tokio::test] + async fn risk_scan_detects_blockers() { + let t = tool(); + let result = t + .execute(json!({ + "action": "risk_scan", + "blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("blocker")); + } + + #[tokio::test] + async fn risk_scan_detects_deadline_risk() { + let t = tool(); + let result = t + .execute(json!({ + "action": "risk_scan", + "deadlines": "Sprint deadline overdue by 3 days" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Deadline risk")); + } + + #[tokio::test] + async fn risk_scan_no_signals_returns_low_risk() { + let t = tool(); + let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("No significant risks")); + } + + #[tokio::test] + async fn 
draft_update_formal_client() { + let t = tool(); + let result = t + .execute(json!({ + "action": "draft_update", + "project_name": "Portal", + "audience": "client", + "tone": "formal", + "highlights": "Phase 1 delivered" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Dear valued partner")); + assert!(result.output.contains("Portal")); + assert!(result.output.contains("Phase 1 delivered")); + } + + #[tokio::test] + async fn draft_update_casual_internal() { + let t = tool(); + let result = t + .execute(json!({ + "action": "draft_update", + "project_name": "ZeroClaw", + "audience": "internal", + "tone": "casual", + "highlights": "Core loop stabilized" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Hey team")); + assert!(result.output.contains("Cheers")); + } + + #[tokio::test] + async fn sprint_summary_renders() { + let t = tool(); + let result = t + .execute(json!({ + "action": "sprint_summary", + "sprint_dates": "2026-03-01 to 2026-03-14", + "completed": "- Login page\n- API endpoints", + "in_progress": "- Dashboard", + "blocked": "- Payment integration" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Login page")); + assert!(result.output.contains("Dashboard")); + } + + #[tokio::test] + async fn effort_estimate_basic() { + let t = tool(); + let result = t + .execute(json!({ + "action": "effort_estimate", + "tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("XS")); + assert!(result.output.contains("Refactor database layer")); + } + + #[tokio::test] + async fn effort_estimate_empty_tasks_fails() { + let t = tool(); + let result = t + .execute(json!({ "action": "effort_estimate", "tasks": "" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("No task descriptions")); + } + + 
#[tokio::test] + async fn unknown_action_returns_error() { + let t = tool(); + let result = t + .execute(json!({ "action": "invalid_thing" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown action")); + } + + #[tokio::test] + async fn missing_action_returns_error() { + let t = tool(); + let result = t.execute(json!({})).await; + assert!(result.is_err()); + } + + #[test] + fn effort_estimate_heuristics_coverage() { + assert_eq!(estimate_task_effort("Fix typo").0, "XS"); + assert_eq!(estimate_task_effort("Update config values").0, "XS"); + assert_eq!( + estimate_task_effort("Implement new notification system").0, + "M" + ); + assert_eq!( + estimate_task_effort("Refactor the entire authentication module").0, + "XL" + ); + assert_eq!( + estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0, + "XXL" + ); + } + + #[test] + fn risk_sensitivity_threshold_ordering() { + assert!( + RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor() + ); + assert!( + RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor() + ); + } + + #[test] + fn risk_sensitivity_from_str_variants() { + assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low); + assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High); + assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium); + assert_eq!( + RiskSensitivity::from_str("unknown"), + RiskSensitivity::Medium + ); + } + + #[tokio::test] + async fn high_sensitivity_detects_single_blocker_as_high() { + let t = ProjectIntelTool::new("en".into(), "high".into()); + let result = t + .execute(json!({ + "action": "risk_scan", + "blockers": "Single blocker" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]")); + } +} diff --git 
a/src/tools/report_templates.rs b/src/tools/report_templates.rs new file mode 100644 index 000000000..930ecbeff --- /dev/null +++ b/src/tools/report_templates.rs @@ -0,0 +1,582 @@ +//! Report template engine for project delivery intelligence. +//! +//! Provides built-in templates for weekly status, sprint review, risk register, +//! and milestone reports with multi-language support (EN, DE, FR, IT). + +use std::collections::HashMap; +use std::fmt::Write as _; + +/// Supported report output formats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReportFormat { + Markdown, + Html, +} + +/// A named section within a report template. +#[derive(Debug, Clone)] +pub struct TemplateSection { + pub heading: String, + pub body: String, +} + +/// A report template with named sections and variable placeholders. +#[derive(Debug, Clone)] +pub struct ReportTemplate { + pub name: String, + pub sections: Vec<TemplateSection>, + pub format: ReportFormat, +} + +/// Escape a string for safe inclusion in HTML output. +fn escape_html(s: &str) -> String { + s.replace('&', "&amp;") + .replace('<', "&lt;") + .replace('>', "&gt;") + .replace('"', "&quot;") + .replace('\'', "&#39;") +} + +impl ReportTemplate { + /// Render the template by substituting `{{key}}` placeholders with values. + pub fn render(&self, vars: &HashMap<String, String>) -> String { + let mut out = String::new(); + for section in &self.sections { + let heading = substitute(&section.heading, vars); + let body = substitute(&section.body, vars); + match self.format { + ReportFormat::Markdown => { + let _ = write!(out, "## {heading}\n\n{body}\n\n"); + } + ReportFormat::Html => { + let heading = escape_html(&heading); + let body = escape_html(&body); + let _ = write!(out, "<h2>{heading}</h2>\n<p>{body}</p>\n"); + } + } + } + out.trim_end().to_string() + } +} + +/// Single-pass placeholder substitution. +/// +/// Scans `template` left-to-right for `{{key}}` tokens and replaces them with +/// the corresponding value from `vars`. Because the scan is single-pass, +/// values that themselves contain `{{...}}` sequences are emitted literally +/// and never re-expanded, preventing injection of new placeholders. +fn substitute(template: &str, vars: &HashMap<String, String>) -> String { + let mut result = String::with_capacity(template.len()); + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut i = 0; + + while i < len { + if i + 1 < len && bytes[i] == b'{' && bytes[i + 1] == b'{' { + // Find the closing `}}`. + if let Some(close) = template[i + 2..].find("}}") { + let key = &template[i + 2..i + 2 + close]; + if let Some(value) = vars.get(key) { + result.push_str(value); + } else { + // Unknown placeholder: emit as-is. + result.push_str(&template[i..i + 2 + close + 2]); + } + i += 2 + close + 2; + continue; + } + } + result.push(template.as_bytes()[i] as char); + i += 1; + } + + result +} + +// ── Built-in templates ──────────────────────────────────────────── + +/// Return the built-in weekly status template for the given language. 
+pub fn weekly_status_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Wochenstatus", + vec![ + TemplateSection { + heading: "Zusammenfassung".into(), + body: "Projekt: {{project_name}} | Zeitraum: {{period}}".into(), + }, + TemplateSection { + heading: "Erledigt".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Bearbeitung".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blockiert".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Naechste Schritte".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + "fr" => ( + "Statut hebdomadaire", + vec![ + TemplateSection { + heading: "Resume".into(), + body: "Projet: {{project_name}} | Periode: {{period}}".into(), + }, + TemplateSection { + heading: "Termine".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "En cours".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloque".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Prochaines etapes".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + "it" => ( + "Stato settimanale", + vec![ + TemplateSection { + heading: "Riepilogo".into(), + body: "Progetto: {{project_name}} | Periodo: {{period}}".into(), + }, + TemplateSection { + heading: "Completato".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In corso".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloccato".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Prossimi passi".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + _ => ( + "Weekly Status", + vec![ + TemplateSection { + heading: "Summary".into(), + body: "Project: {{project_name}} | Period: {{period}}".into(), + }, + TemplateSection { + heading: "Completed".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Progress".into(), 
+ body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blocked".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Next Steps".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in sprint review template for the given language. +pub fn sprint_review_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Sprint-Uebersicht", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Erledigt".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Bearbeitung".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blockiert".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocity".into(), + body: "{{velocity}}".into(), + }, + ], + ), + "fr" => ( + "Revue de sprint", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Termine".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "En cours".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloque".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocite".into(), + body: "{{velocity}}".into(), + }, + ], + ), + "it" => ( + "Revisione sprint", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Completato".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In corso".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloccato".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocita".into(), + body: "{{velocity}}".into(), + }, + ], + ), + _ => ( + "Sprint Review", + vec![ 
+ TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Completed".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Progress".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blocked".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocity".into(), + body: "{{velocity}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in risk register template for the given language. +pub fn risk_register_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Risikoregister", + vec![ + TemplateSection { + heading: "Projekt".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risiken".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Massnahmen".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + "fr" => ( + "Registre des risques", + vec![ + TemplateSection { + heading: "Projet".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risques".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mesures".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + "it" => ( + "Registro dei rischi", + vec![ + TemplateSection { + heading: "Progetto".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Rischi".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mitigazioni".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + _ => ( + "Risk Register", + vec![ + TemplateSection { + heading: "Project".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risks".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mitigations".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + }; + ReportTemplate { + 
name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in milestone report template for the given language. +pub fn milestone_report_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Meilensteinbericht", + vec![ + TemplateSection { + heading: "Projekt".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Meilensteine".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Status".into(), + body: "{{status}}".into(), + }, + ], + ), + "fr" => ( + "Rapport de jalons", + vec![ + TemplateSection { + heading: "Projet".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Jalons".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Statut".into(), + body: "{{status}}".into(), + }, + ], + ), + "it" => ( + "Report milestone", + vec![ + TemplateSection { + heading: "Progetto".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Milestone".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Stato".into(), + body: "{{status}}".into(), + }, + ], + ), + _ => ( + "Milestone Report", + vec![ + TemplateSection { + heading: "Project".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Milestones".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Status".into(), + body: "{{status}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn weekly_status_renders_with_variables() { + let tpl = weekly_status_template("en"); + let mut vars = HashMap::new(); + vars.insert("project_name".into(), "ZeroClaw".into()); + vars.insert("period".into(), "2026-W10".into()); + vars.insert("completed".into(), "- Task A\n- Task B".into()); + vars.insert("in_progress".into(), "- Task 
C".into()); + vars.insert("blocked".into(), "None".into()); + vars.insert("next_steps".into(), "- Task D".into()); + + let rendered = tpl.render(&vars); + assert!(rendered.contains("Project: ZeroClaw")); + assert!(rendered.contains("Period: 2026-W10")); + assert!(rendered.contains("- Task A")); + assert!(rendered.contains("## Completed")); + } + + #[test] + fn weekly_status_de_renders_german_headings() { + let tpl = weekly_status_template("de"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Zusammenfassung")); + assert!(rendered.contains("## Erledigt")); + } + + #[test] + fn weekly_status_fr_renders_french_headings() { + let tpl = weekly_status_template("fr"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Resume")); + assert!(rendered.contains("## Termine")); + } + + #[test] + fn weekly_status_it_renders_italian_headings() { + let tpl = weekly_status_template("it"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Riepilogo")); + assert!(rendered.contains("## Completato")); + } + + #[test] + fn html_format_renders_tags() { + let mut tpl = weekly_status_template("en"); + tpl.format = ReportFormat::Html; + let mut vars = HashMap::new(); + vars.insert("project_name".into(), "Test".into()); + vars.insert("period".into(), "W1".into()); + vars.insert("completed".into(), "Done".into()); + vars.insert("in_progress".into(), "WIP".into()); + vars.insert("blocked".into(), "None".into()); + vars.insert("next_steps".into(), "Next".into()); + + let rendered = tpl.render(&vars); + assert!(rendered.contains("

Summary

")); + assert!(rendered.contains("

Project: Test | Period: W1

")); + } + + #[test] + fn sprint_review_template_has_velocity_section() { + let tpl = sprint_review_template("en"); + let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect(); + assert!(section_headings.contains(&"Velocity")); + } + + #[test] + fn risk_register_template_has_risk_sections() { + let tpl = risk_register_template("en"); + let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect(); + assert!(section_headings.contains(&"Risks")); + assert!(section_headings.contains(&"Mitigations")); + } + + #[test] + fn milestone_template_all_languages() { + for lang in &["en", "de", "fr", "it"] { + let tpl = milestone_report_template(lang); + assert!(!tpl.name.is_empty()); + assert_eq!(tpl.sections.len(), 3); + } + } + + #[test] + fn substitute_leaves_unknown_placeholders() { + let vars = HashMap::new(); + let result = substitute("Hello {{name}}", &vars); + assert_eq!(result, "Hello {{name}}"); + } + + #[test] + fn substitute_replaces_all_occurrences() { + let mut vars = HashMap::new(); + vars.insert("x".into(), "1".into()); + let result = substitute("{{x}} and {{x}}", &vars); + assert_eq!(result, "1 and 1"); + } +} From a8c6363cde91709e8fe327530d7e457145a49444 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 01:53:47 -0400 Subject: [PATCH 08/11] feat(nodes): add secure HMAC-SHA256 node transport layer (#3654) Add a new `nodes` module with HMAC-SHA256 authenticated transport for secure inter-node communication over standard HTTPS. Includes replay protection via timestamped nonces and constant-time signature comparison. Also adds `NodeTransportConfig` to the config schema and fixes missing `approval_manager` field in four `ChannelRuntimeContext` test constructors that failed compilation on latest master. Original work by @rareba. Rebased on latest master to resolve merge conflicts (SwarmConfig/SwarmStrategy exports, duplicate MCP validation, test constructor fields). 
Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 18 ++-- src/config/schema.rs | 68 ++++++++++++ src/gateway/mod.rs | 8 +- src/lib.rs | 1 + src/nodes/mod.rs | 3 + src/nodes/transport.rs | 235 +++++++++++++++++++++++++++++++++++++++++ src/onboard/wizard.rs | 2 + 7 files changed, 322 insertions(+), 13 deletions(-) create mode 100644 src/nodes/mod.rs create mode 100644 src/nodes/transport.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index a82209d84..1ce1ebe02 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -13,15 +13,15 @@ pub use schema::{ GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, - MultimodalConfig, NextcloudTalkConfig, NodesConfig, NotionConfig, ObservabilityConfig, - OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig, - PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, - ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, - SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode, - SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, - SwarmConfig, SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, - TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, - WorkspaceConfig, + MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, + ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, + PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig, + QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, + SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, + 
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, + StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, + ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, + WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index 6dca7cf8f..84508f546 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -271,6 +271,10 @@ pub struct Config { /// Notion integration configuration (`[notion]`). #[serde(default)] pub notion: NotionConfig, + + /// Secure inter-node transport configuration (`[node_transport]`). + #[serde(default)] + pub node_transport: NodeTransportConfig, } /// Multi-client workspace isolation configuration. @@ -1352,6 +1356,67 @@ impl Default for GatewayConfig { } } +/// Secure transport configuration for inter-node communication (`[node_transport]`). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct NodeTransportConfig { + /// Enable the secure transport layer. + #[serde(default = "default_node_transport_enabled")] + pub enabled: bool, + /// Shared secret for HMAC authentication between nodes. + #[serde(default)] + pub shared_secret: String, + /// Maximum age of signed requests in seconds (replay protection). + #[serde(default = "default_max_request_age")] + pub max_request_age_secs: i64, + /// Require HTTPS for all node communication. + #[serde(default = "default_require_https")] + pub require_https: bool, + /// Allow specific node IPs/CIDRs. + #[serde(default)] + pub allowed_peers: Vec, + /// Path to TLS certificate file. + #[serde(default)] + pub tls_cert_path: Option, + /// Path to TLS private key file. + #[serde(default)] + pub tls_key_path: Option, + /// Require client certificates (mutual TLS). + #[serde(default)] + pub mutual_tls: bool, + /// Maximum number of connections per peer. 
+ #[serde(default = "default_connection_pool_size")] + pub connection_pool_size: usize, +} + +fn default_node_transport_enabled() -> bool { + true +} +fn default_max_request_age() -> i64 { + 300 +} +fn default_require_https() -> bool { + true +} +fn default_connection_pool_size() -> usize { + 4 +} + +impl Default for NodeTransportConfig { + fn default() -> Self { + Self { + enabled: default_node_transport_enabled(), + shared_secret: String::new(), + max_request_age_secs: default_max_request_age(), + require_https: default_require_https(), + allowed_peers: Vec::new(), + tls_cert_path: None, + tls_key_path: None, + mutual_tls: false, + connection_pool_size: default_connection_pool_size(), + } + } +} + // ── Composio (managed tool surface) ───────────────────────────── /// Composio managed OAuth tools integration (`[composio]` section). @@ -4647,6 +4712,7 @@ impl Default for Config { nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), } } } @@ -6915,6 +6981,7 @@ default_temperature = 0.7 nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -7210,6 +7277,7 @@ tool_dispatcher = "xml" nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), }; config.save().await.unwrap(); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 7a6e41697..23d74d444 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message( .await; } - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via WhatsApp if let Err(e) = wa @@ -1346,7 +1346,7 @@ async fn 
handle_linq_webhook( } // Call the LLM - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via Linq if let Err(e) = linq @@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State, body: Bytes) -> impl } // Call the LLM - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via WATI if let Err(e) = wati @@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook( .await; } - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { if let Err(e) = nextcloud_talk .send(&SendMessage::new(response, &msg.reply_target)) diff --git a/src/lib.rs b/src/lib.rs index 71248da85..94b0d3765 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ pub(crate) mod integrations; pub mod memory; pub(crate) mod migration; pub(crate) mod multimodal; +pub mod nodes; pub mod observability; pub(crate) mod onboard; pub mod peripherals; diff --git a/src/nodes/mod.rs b/src/nodes/mod.rs new file mode 100644 index 000000000..1207bb50c --- /dev/null +++ b/src/nodes/mod.rs @@ -0,0 +1,3 @@ +pub mod transport; + +pub use transport::NodeTransport; diff --git a/src/nodes/transport.rs b/src/nodes/transport.rs new file mode 100644 index 000000000..75bc4d434 --- /dev/null +++ b/src/nodes/transport.rs @@ -0,0 +1,235 @@ +//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication. +//! +//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols, +//! no custom binary framing, no UDP tunneling. This makes the transport +//! compatible with corporate proxies, firewalls, and IT audit expectations. 
+ +use anyhow::{bail, Result}; +use chrono::Utc; +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; + +/// Signs a request payload with HMAC-SHA256. +/// +/// Uses `timestamp` + `nonce` alongside the payload to prevent replay attacks. +pub fn sign_request( + shared_secret: &str, + payload: &[u8], + timestamp: i64, + nonce: &str, +) -> Result { + let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes()) + .map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?; + mac.update(×tamp.to_le_bytes()); + mac.update(nonce.as_bytes()); + mac.update(payload); + Ok(hex::encode(mac.finalize().into_bytes())) +} + +/// Verify a signed request, rejecting stale timestamps for replay protection. +pub fn verify_request( + shared_secret: &str, + payload: &[u8], + timestamp: i64, + nonce: &str, + signature: &str, + max_age_secs: i64, +) -> Result { + let now = Utc::now().timestamp(); + if (now - timestamp).abs() > max_age_secs { + bail!("Request timestamp too old or too far in future"); + } + + let expected = sign_request(shared_secret, payload, timestamp, nonce)?; + Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes())) +} + +/// Constant-time comparison to prevent timing attacks. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + a.iter() + .zip(b.iter()) + .fold(0u8, |acc, (x, y)| acc | (x ^ y)) + == 0 +} + +// ── Node transport client ─────────────────────────────────────── + +/// Sends authenticated HTTPS requests to peer nodes. +/// +/// Every outgoing request carries three custom headers: +/// - `X-ZeroClaw-Timestamp` — unix epoch seconds +/// - `X-ZeroClaw-Nonce` — random UUID v4 +/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest +/// +/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`]. 
+pub struct NodeTransport { + http: reqwest::Client, + shared_secret: String, + max_request_age_secs: i64, +} + +impl NodeTransport { + pub fn new(shared_secret: String) -> Self { + Self { + http: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("HTTP client build"), + shared_secret, + max_request_age_secs: 300, // 5 min replay window + } + } + + /// Send an authenticated request to a peer node. + pub async fn send( + &self, + node_address: &str, + endpoint: &str, + payload: serde_json::Value, + ) -> Result { + let body = serde_json::to_vec(&payload)?; + let timestamp = Utc::now().timestamp(); + let nonce = uuid::Uuid::new_v4().to_string(); + let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?; + + let url = format!("https://{node_address}/api/node-control/{endpoint}"); + let resp = self + .http + .post(&url) + .header("X-ZeroClaw-Timestamp", timestamp.to_string()) + .header("X-ZeroClaw-Nonce", &nonce) + .header("X-ZeroClaw-Signature", &signature) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; + + if !resp.status().is_success() { + bail!( + "Node request failed: {} {}", + resp.status(), + resp.text().await.unwrap_or_default() + ); + } + + Ok(resp.json().await?) + } + + /// Verify an incoming request from a peer node. 
+ pub fn verify_incoming( + &self, + payload: &[u8], + timestamp_header: &str, + nonce_header: &str, + signature_header: &str, + ) -> Result { + let timestamp: i64 = timestamp_header + .parse() + .map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?; + verify_request( + &self.shared_secret, + payload, + timestamp, + nonce_header, + signature_header, + self.max_request_age_secs, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_SECRET: &str = "test-shared-secret-key"; + + #[test] + fn sign_request_deterministic() { + let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap(); + let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap(); + assert_eq!(sig1, sig2, "Same inputs must produce the same signature"); + } + + #[test] + fn verify_request_accepts_valid_signature() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap(); + let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap(); + assert!(ok, "Valid signature must pass verification"); + } + + #[test] + fn verify_request_rejects_tampered_payload() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap(); + let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap(); + assert!(!ok, "Tampered payload must fail verification"); + } + + #[test] + fn verify_request_rejects_expired_timestamp() { + let old = Utc::now().timestamp() - 600; + let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap(); + let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300); + assert!(result.is_err(), "Expired timestamp must be rejected"); + } + + #[test] + fn verify_request_rejects_wrong_secret() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap(); + let ok = verify_request("wrong-secret", b"data", now, "nonce-d", 
&sig, 300).unwrap(); + assert!(!ok, "Wrong secret must fail verification"); + } + + #[test] + fn constant_time_eq_correctness() { + assert!(constant_time_eq(b"abc", b"abc")); + assert!(!constant_time_eq(b"abc", b"abd")); + assert!(!constant_time_eq(b"abc", b"ab")); + assert!(!constant_time_eq(b"", b"a")); + assert!(constant_time_eq(b"", b"")); + } + + #[test] + fn node_transport_construction() { + let transport = NodeTransport::new("secret-key".into()); + assert_eq!(transport.max_request_age_secs, 300); + } + + #[test] + fn node_transport_verify_incoming_valid() { + let transport = NodeTransport::new(TEST_SECRET.into()); + let now = Utc::now().timestamp(); + let payload = b"test-body"; + let nonce = "incoming-nonce"; + let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap(); + + let ok = transport + .verify_incoming(payload, &now.to_string(), nonce, &sig) + .unwrap(); + assert!(ok, "Valid incoming request must pass verification"); + } + + #[test] + fn node_transport_verify_incoming_bad_timestamp_header() { + let transport = NodeTransport::new(TEST_SECRET.into()); + let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig"); + assert!(result.is_err(), "Non-numeric timestamp header must error"); + } + + #[test] + fn sign_request_different_nonce_different_signature() { + let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap(); + let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap(); + assert_ne!( + sig1, sig2, + "Different nonces must produce different signatures" + ); + } +} diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 189d39f19..4b2692a99 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -181,6 +181,7 @@ pub async fn run_wizard(force: bool) -> Result { nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), notion: crate::config::NotionConfig::default(), + node_transport: 
crate::config::NodeTransportConfig::default(), }; println!( @@ -542,6 +543,7 @@ async fn run_quick_setup_with_home( nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), notion: crate::config::NotionConfig::default(), + node_transport: crate::config::NodeTransportConfig::default(), }; config.save().await?; From dcc0a629ecb4a44fdabcf61b1a92ef071d593586 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 02:10:14 -0400 Subject: [PATCH 09/11] feat(tools): add project delivery intelligence tool (#3656) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new read-only project_intel tool that provides: - Status report generation (weekly/sprint/month) - Risk scanning with configurable sensitivity - Client update drafting (formal/casual, client/internal) - Sprint summary generation - Heuristic effort estimation Includes multi-language report templates (EN, DE, FR, IT), ProjectIntelConfig schema with validation, and comprehensive tests. Also fixes missing approval_manager field in 4 ChannelRuntimeContext test constructors. Supersedes #3591 — rebased on latest master. Original work by @rareba. 
Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 8 ++-- src/config/schema.rs | 87 +++++++++++++++++++++++++++++++++++++++++++ src/gateway/mod.rs | 2 +- src/main.rs | 2 +- src/onboard/wizard.rs | 2 + src/tools/mod.rs | 11 ++++++ 6 files changed, 106 insertions(+), 6 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 1ce1ebe02..b49edfda8 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -15,10 +15,10 @@ pub use schema::{ McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, - PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig, - QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, - SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, - SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, + PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, + QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, + RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, + SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, diff --git a/src/config/schema.rs b/src/config/schema.rs index 84508f546..462fabe97 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -216,6 +216,10 @@ pub struct Config { #[serde(default)] pub web_search: WebSearchConfig, + /// Project delivery intelligence configuration (`[project_intel]`). 
+ #[serde(default)] + pub project_intel: ProjectIntelConfig, + /// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`). #[serde(default)] pub proxy: ProxyConfig, @@ -1785,6 +1789,64 @@ impl Default for WebSearchConfig { } } +// ── Project Intelligence ──────────────────────────────────────── + +/// Project delivery intelligence configuration (`[project_intel]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct ProjectIntelConfig { + /// Enable the project_intel tool. Default: false. + #[serde(default)] + pub enabled: bool, + /// Default report language (en, de, fr, it). Default: "en". + #[serde(default = "default_project_intel_language")] + pub default_language: String, + /// Output directory for generated reports. + #[serde(default = "default_project_intel_report_dir")] + pub report_output_dir: String, + /// Optional custom templates directory. + #[serde(default)] + pub templates_dir: Option, + /// Risk detection sensitivity: low, medium, high. Default: "medium". + #[serde(default = "default_project_intel_risk_sensitivity")] + pub risk_sensitivity: String, + /// Include git log data in reports. Default: true. + #[serde(default = "default_true")] + pub include_git_data: bool, + /// Include Jira data in reports. Default: false. + #[serde(default)] + pub include_jira_data: bool, + /// Jira instance base URL (required if include_jira_data is true). 
+ #[serde(default)] + pub jira_base_url: Option, +} + +fn default_project_intel_language() -> String { + "en".into() +} + +fn default_project_intel_report_dir() -> String { + "~/.zeroclaw/project-reports".into() +} + +fn default_project_intel_risk_sensitivity() -> String { + "medium".into() +} + +impl Default for ProjectIntelConfig { + fn default() -> Self { + Self { + enabled: false, + default_language: default_project_intel_language(), + report_output_dir: default_project_intel_report_dir(), + templates_dir: None, + risk_sensitivity: default_project_intel_risk_sensitivity(), + include_git_data: true, + include_jira_data: false, + jira_base_url: None, + } + } +} + // ── Proxy ─────────────────────────────────────────────────────── /// Proxy application scope — determines which outbound traffic uses the proxy. @@ -4697,6 +4759,7 @@ impl Default for Config { multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), identity: IdentityConfig::default(), cost: CostConfig::default(), @@ -5854,6 +5917,28 @@ impl Config { validate_mcp_config(&self.mcp)?; } + // Project intelligence + if self.project_intel.enabled { + let lang = &self.project_intel.default_language; + if !["en", "de", "fr", "it"].contains(&lang.as_str()) { + anyhow::bail!( + "project_intel.default_language must be one of: en, de, fr, it (got '{lang}')" + ); + } + let sens = &self.project_intel.risk_sensitivity; + if !["low", "medium", "high"].contains(&sens.as_str()) { + anyhow::bail!( + "project_intel.risk_sensitivity must be one of: low, medium, high (got '{sens}')" + ); + } + if let Some(ref tpl_dir) = self.project_intel.templates_dir { + let path = std::path::Path::new(tpl_dir); + if !path.exists() { + anyhow::bail!("project_intel.templates_dir path does not exist: {tpl_dir}"); + } + } + } + // Proxy (delegate to existing validation) self.proxy.validate()?; @@ 
-6966,6 +7051,7 @@ default_temperature = 0.7 multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), identity: IdentityConfig::default(), @@ -7262,6 +7348,7 @@ tool_dispatcher = "xml" multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), identity: IdentityConfig::default(), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 23d74d444..f34703408 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res /// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk). async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result { let config = state.config.lock().clone(); - crate::agent::process_message(config, message).await + Box::pin(crate::agent::process_message(config, message)).await } /// Webhook request body diff --git a/src/main.rs b/src/main.rs index e2d04c736..b08d4de0c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1190,7 +1190,7 @@ async fn main() -> Result<()> { Commands::Channel { channel_command } => match channel_command { ChannelCommands::Start => Box::pin(channels::start_channels(config)).await, - ChannelCommands::Doctor => channels::doctor_channels(config).await, + ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await, other => channels::handle_command(other, &config).await, }, diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 4b2692a99..2fe8fef28 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -166,6 +166,7 @@ pub async fn run_wizard(force: bool) -> Result { multimodal: 
crate::config::MultimodalConfig::default(), web_fetch: crate::config::WebFetchConfig::default(), web_search: crate::config::WebSearchConfig::default(), + project_intel: crate::config::ProjectIntelConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), cost: crate::config::CostConfig::default(), @@ -528,6 +529,7 @@ async fn run_quick_setup_with_home( multimodal: crate::config::MultimodalConfig::default(), web_fetch: crate::config::WebFetchConfig::default(), web_search: crate::config::WebSearchConfig::default(), + project_intel: crate::config::ProjectIntelConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), cost: crate::config::CostConfig::default(), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 5fe76ef6f..75cfccbc1 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -53,8 +53,10 @@ pub mod model_routing_config; pub mod node_tool; pub mod notion_tool; pub mod pdf_read; +pub mod project_intel; pub mod proxy_config; pub mod pushover; +pub mod report_templates; pub mod schedule; pub mod schema; pub mod screenshot; @@ -102,6 +104,7 @@ pub use model_routing_config::ModelRoutingConfigTool; pub use node_tool::NodeTool; pub use notion_tool::NotionTool; pub use pdf_read::PdfReadTool; +pub use project_intel::ProjectIntelTool; pub use proxy_config::ProxyConfigTool; pub use pushover::PushoverTool; pub use schedule::ScheduleTool; @@ -364,6 +367,14 @@ pub fn all_tools_with_runtime( } } + // Project delivery intelligence + if root_config.project_intel.enabled { + tool_arcs.push(Arc::new(ProjectIntelTool::new( + root_config.project_intel.default_language.clone(), + root_config.project_intel.risk_sensitivity.clone(), + ))); + } + // PDF extraction (feature-gated at compile time via rag-pdf) tool_arcs.push(Arc::new(PdfReadTool::new(security.clone()))); From 8a61a283b205ccb080462e1c8a85af4d5c869b23 Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 
Mar 2026 02:28:54 -0400 Subject: [PATCH 10/11] feat(security): add MCSS security operations tool (#3657) * feat(security): add MCSS security operations tool Add managed cybersecurity service (MCSS) tool with alert triage, incident response playbook execution, vulnerability scan parsing, and security report generation. Includes SecurityOpsConfig, playbook engine with approval gating, vulnerability scoring, and full test coverage. Also fixes pre-existing missing approval_manager field in ChannelRuntimeContext test constructors. Original work by @rareba. Supersedes #3599. Co-Authored-By: Claude Opus 4.6 * fix: add SecurityOpsConfig to re-exports, fix test constructors Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 8 +- src/config/schema.rs | 65 ++++ src/onboard/wizard.rs | 2 + src/security/mod.rs | 2 + src/security/playbook.rs | 459 +++++++++++++++++++++++ src/security/vulnerability.rs | 397 ++++++++++++++++++++ src/tools/mod.rs | 9 + src/tools/security_ops.rs | 659 ++++++++++++++++++++++++++++++++++ 8 files changed, 1597 insertions(+), 4 deletions(-) create mode 100644 src/security/playbook.rs create mode 100644 src/security/vulnerability.rs create mode 100644 src/tools/security_ops.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index b49edfda8..45fbd6ff7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -18,10 +18,10 @@ pub use schema::{ PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, - StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, - ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, - WebFetchConfig, WebSearchConfig, WebhookConfig, 
WorkspaceConfig, + SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, + StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, + TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, + TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index 462fabe97..9ca82e547 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -124,6 +124,9 @@ pub struct Config { #[serde(default)] pub security: SecurityConfig, + /// Managed cybersecurity service configuration (`[security_ops]`). + pub security_ops: SecurityOpsConfig, + /// Runtime adapter configuration (`[runtime]`). Controls native vs Docker execution. #[serde(default)] pub runtime: RuntimeConfig, @@ -4714,6 +4717,65 @@ impl Default for NotionConfig { } } +// ── Security ops config ───────────────────────────────────────── + +/// Managed Cybersecurity Service (MCSS) dashboard agent configuration (`[security_ops]`). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SecurityOpsConfig { + /// Enable security operations tools. + #[serde(default)] + pub enabled: bool, + /// Directory containing incident response playbook definitions (JSON). + #[serde(default = "default_playbooks_dir")] + pub playbooks_dir: String, + /// Automatically triage incoming alerts without user prompt. + #[serde(default)] + pub auto_triage: bool, + /// Require human approval before executing playbook actions. + #[serde(default = "default_require_approval")] + pub require_approval_for_actions: bool, + /// Maximum severity level that can be auto-remediated without approval. + /// One of: "low", "medium", "high", "critical". Default: "low". + #[serde(default = "default_max_auto_severity")] + pub max_auto_severity: String, + /// Directory for generated security reports. 
+ #[serde(default = "default_report_output_dir")] + pub report_output_dir: String, + /// Optional SIEM webhook URL for alert ingestion. + #[serde(default)] + pub siem_integration: Option, +} + +fn default_playbooks_dir() -> String { + "~/.zeroclaw/playbooks".into() +} + +fn default_require_approval() -> bool { + true +} + +fn default_max_auto_severity() -> String { + "low".into() +} + +fn default_report_output_dir() -> String { + "~/.zeroclaw/security-reports".into() +} + +impl Default for SecurityOpsConfig { + fn default() -> Self { + Self { + enabled: false, + playbooks_dir: default_playbooks_dir(), + auto_triage: false, + require_approval_for_actions: true, + max_auto_severity: default_max_auto_severity(), + report_output_dir: default_report_output_dir(), + siem_integration: None, + } + } +} + // ── Config impl ────────────────────────────────────────────────── impl Default for Config { @@ -4737,6 +4799,7 @@ impl Default for Config { observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), @@ -6984,6 +7047,7 @@ default_temperature = 0.7 non_cli_excluded_tools: vec![], }, security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig { kind: "docker".into(), ..RuntimeConfig::default() @@ -7326,6 +7390,7 @@ tool_dispatcher = "xml" observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 2fe8fef28..210114954 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -144,6 +144,7 @@ pub async fn 
run_wizard(force: bool) -> Result { observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), security: crate::config::SecurityConfig::default(), + security_ops: crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), @@ -507,6 +508,7 @@ async fn run_quick_setup_with_home( observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), security: crate::config::SecurityConfig::default(), + security_ops: crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), diff --git a/src/security/mod.rs b/src/security/mod.rs index f80268427..433e7046f 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -36,10 +36,12 @@ pub mod leak_detector; pub mod nevis; pub mod otp; pub mod pairing; +pub mod playbook; pub mod policy; pub mod prompt_guard; pub mod secrets; pub mod traits; +pub mod vulnerability; pub mod workspace_boundary; #[allow(unused_imports)] diff --git a/src/security/playbook.rs b/src/security/playbook.rs new file mode 100644 index 000000000..cce5a27ff --- /dev/null +++ b/src/security/playbook.rs @@ -0,0 +1,459 @@ +//! Incident response playbook definitions and execution engine. +//! +//! Playbooks define structured response procedures for security incidents. +//! Each playbook has named steps, some of which require human approval before +//! execution. Playbooks are loaded from JSON files in the configured directory. + +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// A single step in an incident response playbook. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PlaybookStep { + /// Machine-readable action identifier (e.g. "isolate_host", "block_ip"). 
+ pub action: String, + /// Human-readable description of what this step does. + pub description: String, + /// Whether this step requires explicit human approval before execution. + #[serde(default)] + pub requires_approval: bool, + /// Timeout in seconds for this step. Default: 300 (5 minutes). + #[serde(default = "default_timeout_secs")] + pub timeout_secs: u64, +} + +fn default_timeout_secs() -> u64 { + 300 +} + +/// An incident response playbook. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Playbook { + /// Unique playbook name (e.g. "suspicious_login"). + pub name: String, + /// Human-readable description. + pub description: String, + /// Ordered list of response steps. + pub steps: Vec, + /// Minimum alert severity that triggers this playbook (low/medium/high/critical). + #[serde(default = "default_severity_filter")] + pub severity_filter: String, + /// Step indices (0-based) that can be auto-approved when below max_auto_severity. + #[serde(default)] + pub auto_approve_steps: Vec, +} + +fn default_severity_filter() -> String { + "medium".into() +} + +/// Result of executing a single playbook step. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StepExecutionResult { + pub step_index: usize, + pub action: String, + pub status: StepStatus, + pub message: String, +} + +/// Status of a playbook step. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum StepStatus { + /// Step completed successfully. + Completed, + /// Step is waiting for human approval. + PendingApproval, + /// Step was skipped (e.g. not applicable). + Skipped, + /// Step failed with an error. 
+ Failed, +} + +impl std::fmt::Display for StepStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Completed => write!(f, "completed"), + Self::PendingApproval => write!(f, "pending_approval"), + Self::Skipped => write!(f, "skipped"), + Self::Failed => write!(f, "failed"), + } + } +} + +/// Load all playbook definitions from a directory of JSON files. +pub fn load_playbooks(dir: &Path) -> Vec { + let mut playbooks = Vec::new(); + + if !dir.exists() || !dir.is_dir() { + return builtin_playbooks(); + } + + if let Ok(entries) = std::fs::read_dir(dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().map_or(false, |ext| ext == "json") { + match std::fs::read_to_string(&path) { + Ok(contents) => match serde_json::from_str::(&contents) { + Ok(pb) => playbooks.push(pb), + Err(e) => { + tracing::warn!("Failed to parse playbook {}: {e}", path.display()); + } + }, + Err(e) => { + tracing::warn!("Failed to read playbook {}: {e}", path.display()); + } + } + } + } + } + + // Merge built-in playbooks that aren't overridden by user-defined ones + for builtin in builtin_playbooks() { + if !playbooks.iter().any(|p| p.name == builtin.name) { + playbooks.push(builtin); + } + } + + playbooks +} + +/// Severity ordering for comparison: low < medium < high < critical. +pub fn severity_level(severity: &str) -> u8 { + match severity.to_lowercase().as_str() { + "low" => 1, + "medium" => 2, + "high" => 3, + "critical" => 4, + // Deny-by-default: unknown severities get the highest level to prevent + // auto-approval of unrecognized severity labels. + _ => u8::MAX, + } +} + +/// Check whether a step can be auto-approved given config constraints. 
+pub fn can_auto_approve( + playbook: &Playbook, + step_index: usize, + alert_severity: &str, + max_auto_severity: &str, +) -> bool { + // Never auto-approve if alert severity exceeds the configured max + if severity_level(alert_severity) > severity_level(max_auto_severity) { + return false; + } + + // Only auto-approve steps explicitly listed in auto_approve_steps + playbook.auto_approve_steps.contains(&step_index) +} + +/// Evaluate a playbook step. Returns the result with approval gating. +/// +/// Steps that require approval and cannot be auto-approved will return +/// `StepStatus::PendingApproval` without executing. +pub fn evaluate_step( + playbook: &Playbook, + step_index: usize, + alert_severity: &str, + max_auto_severity: &str, + require_approval: bool, +) -> StepExecutionResult { + let step = match playbook.steps.get(step_index) { + Some(s) => s, + None => { + return StepExecutionResult { + step_index, + action: "unknown".into(), + status: StepStatus::Failed, + message: format!("Step index {step_index} out of range"), + }; + } + }; + + // Enforce approval gates: steps that require approval must either be + // auto-approved or wait for human approval. Never mark an unexecuted + // approval-gated step as Completed. + if step.requires_approval + && (!require_approval + || !can_auto_approve(playbook, step_index, alert_severity, max_auto_severity)) + { + return StepExecutionResult { + step_index, + action: step.action.clone(), + status: StepStatus::PendingApproval, + message: format!( + "Step '{}' requires human approval (severity: {alert_severity})", + step.description + ), + }; + } + + // Step is approved (either doesn't require approval, or was auto-approved) + // Actual execution would be delegated to the appropriate tool/system + StepExecutionResult { + step_index, + action: step.action.clone(), + status: StepStatus::Completed, + message: format!("Executed: {}", step.description), + } +} + +/// Built-in playbook definitions for common incident types. 
+pub fn builtin_playbooks() -> Vec { + vec![ + Playbook { + name: "suspicious_login".into(), + description: "Respond to suspicious login activity detected by SIEM".into(), + steps: vec![ + PlaybookStep { + action: "gather_login_context".into(), + description: "Collect login metadata: IP, geo, device fingerprint, time".into(), + requires_approval: false, + timeout_secs: 60, + }, + PlaybookStep { + action: "check_threat_intel".into(), + description: "Query threat intelligence for source IP reputation".into(), + requires_approval: false, + timeout_secs: 30, + }, + PlaybookStep { + action: "notify_user".into(), + description: "Send verification notification to account owner".into(), + requires_approval: true, + timeout_secs: 300, + }, + PlaybookStep { + action: "force_password_reset".into(), + description: "Force password reset if login confirmed unauthorized".into(), + requires_approval: true, + timeout_secs: 120, + }, + ], + severity_filter: "medium".into(), + auto_approve_steps: vec![0, 1], + }, + Playbook { + name: "malware_detected".into(), + description: "Respond to malware detection on endpoint".into(), + steps: vec![ + PlaybookStep { + action: "isolate_endpoint".into(), + description: "Network-isolate the affected endpoint".into(), + requires_approval: true, + timeout_secs: 60, + }, + PlaybookStep { + action: "collect_forensics".into(), + description: "Capture memory dump and disk image for analysis".into(), + requires_approval: false, + timeout_secs: 600, + }, + PlaybookStep { + action: "scan_lateral_movement".into(), + description: "Check for lateral movement indicators on adjacent hosts".into(), + requires_approval: false, + timeout_secs: 300, + }, + PlaybookStep { + action: "remediate_endpoint".into(), + description: "Remove malware and restore endpoint to clean state".into(), + requires_approval: true, + timeout_secs: 600, + }, + ], + severity_filter: "high".into(), + auto_approve_steps: vec![1, 2], + }, + Playbook { + name: 
"data_exfiltration_attempt".into(), + description: "Respond to suspected data exfiltration".into(), + steps: vec![ + PlaybookStep { + action: "block_egress".into(), + description: "Block suspicious outbound connections".into(), + requires_approval: true, + timeout_secs: 30, + }, + PlaybookStep { + action: "identify_data_scope".into(), + description: "Determine what data may have been accessed or transferred".into(), + requires_approval: false, + timeout_secs: 300, + }, + PlaybookStep { + action: "preserve_evidence".into(), + description: "Preserve network logs and access records".into(), + requires_approval: false, + timeout_secs: 120, + }, + PlaybookStep { + action: "escalate_to_legal".into(), + description: "Notify legal and compliance teams".into(), + requires_approval: true, + timeout_secs: 60, + }, + ], + severity_filter: "critical".into(), + auto_approve_steps: vec![1, 2], + }, + Playbook { + name: "brute_force".into(), + description: "Respond to brute force authentication attempts".into(), + steps: vec![ + PlaybookStep { + action: "block_source_ip".into(), + description: "Block the attacking source IP at firewall".into(), + requires_approval: true, + timeout_secs: 30, + }, + PlaybookStep { + action: "check_compromised_accounts".into(), + description: "Check if any accounts were successfully compromised".into(), + requires_approval: false, + timeout_secs: 120, + }, + PlaybookStep { + action: "enable_rate_limiting".into(), + description: "Enable enhanced rate limiting on auth endpoints".into(), + requires_approval: true, + timeout_secs: 60, + }, + ], + severity_filter: "medium".into(), + auto_approve_steps: vec![1], + }, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builtin_playbooks_are_valid() { + let playbooks = builtin_playbooks(); + assert_eq!(playbooks.len(), 4); + + let names: Vec<&str> = playbooks.iter().map(|p| p.name.as_str()).collect(); + assert!(names.contains(&"suspicious_login")); + 
assert!(names.contains(&"malware_detected")); + assert!(names.contains(&"data_exfiltration_attempt")); + assert!(names.contains(&"brute_force")); + + for pb in &playbooks { + assert!(!pb.steps.is_empty(), "Playbook {} has no steps", pb.name); + assert!(!pb.description.is_empty()); + } + } + + #[test] + fn severity_level_ordering() { + assert!(severity_level("low") < severity_level("medium")); + assert!(severity_level("medium") < severity_level("high")); + assert!(severity_level("high") < severity_level("critical")); + assert_eq!(severity_level("unknown"), u8::MAX); + } + + #[test] + fn auto_approve_respects_severity_cap() { + let pb = &builtin_playbooks()[0]; // suspicious_login + + // Step 0 is in auto_approve_steps + assert!(can_auto_approve(pb, 0, "low", "low")); + assert!(can_auto_approve(pb, 0, "low", "medium")); + + // Alert severity exceeds max -> cannot auto-approve + assert!(!can_auto_approve(pb, 0, "high", "low")); + assert!(!can_auto_approve(pb, 0, "critical", "medium")); + + // Step 2 is NOT in auto_approve_steps + assert!(!can_auto_approve(pb, 2, "low", "critical")); + } + + #[test] + fn evaluate_step_requires_approval() { + let pb = &builtin_playbooks()[0]; // suspicious_login + + // Step 2 (notify_user) requires approval, high severity, max=low -> pending + let result = evaluate_step(pb, 2, "high", "low", true); + assert_eq!(result.status, StepStatus::PendingApproval); + assert_eq!(result.action, "notify_user"); + + // Step 0 (gather_login_context) does NOT require approval -> completed + let result = evaluate_step(pb, 0, "high", "low", true); + assert_eq!(result.status, StepStatus::Completed); + } + + #[test] + fn evaluate_step_out_of_range() { + let pb = &builtin_playbooks()[0]; + let result = evaluate_step(pb, 99, "low", "low", true); + assert_eq!(result.status, StepStatus::Failed); + } + + #[test] + fn playbook_json_roundtrip() { + let pb = &builtin_playbooks()[0]; + let json = serde_json::to_string(pb).unwrap(); + let parsed: Playbook = 
serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, *pb); + } + + #[test] + fn load_playbooks_from_nonexistent_dir_returns_builtins() { + let playbooks = load_playbooks(Path::new("/nonexistent/dir")); + assert_eq!(playbooks.len(), 4); + } + + #[test] + fn load_playbooks_merges_custom_and_builtin() { + let dir = tempfile::tempdir().unwrap(); + let custom = Playbook { + name: "custom_playbook".into(), + description: "A custom playbook".into(), + steps: vec![PlaybookStep { + action: "custom_action".into(), + description: "Do something custom".into(), + requires_approval: true, + timeout_secs: 60, + }], + severity_filter: "low".into(), + auto_approve_steps: vec![], + }; + let json = serde_json::to_string(&custom).unwrap(); + std::fs::write(dir.path().join("custom.json"), json).unwrap(); + + let playbooks = load_playbooks(dir.path()); + // 4 builtins + 1 custom + assert_eq!(playbooks.len(), 5); + assert!(playbooks.iter().any(|p| p.name == "custom_playbook")); + } + + #[test] + fn load_playbooks_custom_overrides_builtin() { + let dir = tempfile::tempdir().unwrap(); + let override_pb = Playbook { + name: "suspicious_login".into(), + description: "Custom override".into(), + steps: vec![PlaybookStep { + action: "custom_step".into(), + description: "Overridden step".into(), + requires_approval: false, + timeout_secs: 30, + }], + severity_filter: "low".into(), + auto_approve_steps: vec![0], + }; + let json = serde_json::to_string(&override_pb).unwrap(); + std::fs::write(dir.path().join("suspicious_login.json"), json).unwrap(); + + let playbooks = load_playbooks(dir.path()); + // 3 remaining builtins + 1 overridden = 4 + assert_eq!(playbooks.len(), 4); + let sl = playbooks + .iter() + .find(|p| p.name == "suspicious_login") + .unwrap(); + assert_eq!(sl.description, "Custom override"); + } +} diff --git a/src/security/vulnerability.rs b/src/security/vulnerability.rs new file mode 100644 index 000000000..0b8e30535 --- /dev/null +++ b/src/security/vulnerability.rs @@ -0,0 
+1,397 @@ +//! Vulnerability scan result parsing and management. +//! +//! Parses vulnerability scan outputs from common scanners (Nessus, Qualys, generic +//! CVSS JSON) and provides priority scoring with business context adjustments. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; + +/// A single vulnerability finding. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Finding { + /// CVE identifier (e.g. "CVE-2024-1234"). May be empty for non-CVE findings. + #[serde(default)] + pub cve_id: String, + /// CVSS base score (0.0 - 10.0). + pub cvss_score: f64, + /// Severity label: "low", "medium", "high", "critical". + pub severity: String, + /// Affected asset identifier (hostname, IP, or service name). + pub affected_asset: String, + /// Description of the vulnerability. + pub description: String, + /// Recommended remediation steps. + #[serde(default)] + pub remediation: String, + /// Whether the asset is internet-facing (increases effective priority). + #[serde(default)] + pub internet_facing: bool, + /// Whether the asset is in a production environment. + #[serde(default = "default_true")] + pub production: bool, +} + +fn default_true() -> bool { + true +} + +/// A parsed vulnerability scan report. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VulnerabilityReport { + /// When the scan was performed. + pub scan_date: DateTime, + /// Scanner that produced the results (e.g. "nessus", "qualys", "generic"). + pub scanner: String, + /// Individual findings from the scan. + pub findings: Vec, +} + +/// Compute effective priority score for a finding. +/// +/// Base: CVSS score (0-10). 
Adjustments: +/// - Internet-facing: +2.0 (capped at 10.0) +/// - Production: +1.0 (capped at 10.0) +pub fn effective_priority(finding: &Finding) -> f64 { + let mut score = finding.cvss_score; + if finding.internet_facing { + score += 2.0; + } + if finding.production { + score += 1.0; + } + score.min(10.0) +} + +/// Classify CVSS score into severity label. +pub fn cvss_to_severity(cvss: f64) -> &'static str { + match cvss { + s if s >= 9.0 => "critical", + s if s >= 7.0 => "high", + s if s >= 4.0 => "medium", + s if s > 0.0 => "low", + _ => "informational", + } +} + +/// Parse a generic CVSS JSON vulnerability report. +/// +/// Expects a JSON object with: +/// - `scan_date`: ISO 8601 date string +/// - `scanner`: string +/// - `findings`: array of Finding objects +pub fn parse_vulnerability_json(json_str: &str) -> anyhow::Result { + let report: VulnerabilityReport = serde_json::from_str(json_str) + .map_err(|e| anyhow::anyhow!("Failed to parse vulnerability report: {e}"))?; + + for (i, finding) in report.findings.iter().enumerate() { + if !(0.0..=10.0).contains(&finding.cvss_score) { + anyhow::bail!( + "findings[{}].cvss_score must be between 0.0 and 10.0, got {}", + i, + finding.cvss_score + ); + } + } + + Ok(report) +} + +/// Generate a summary of the vulnerability report. 
+pub fn generate_summary(report: &VulnerabilityReport) -> String { + if report.findings.is_empty() { + return format!( + "Vulnerability scan by {} on {}: No findings.", + report.scanner, + report.scan_date.format("%Y-%m-%d") + ); + } + + let total = report.findings.len(); + let critical = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("critical")) + .count(); + let high = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("high")) + .count(); + let medium = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("medium")) + .count(); + let low = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("low")) + .count(); + let informational = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("informational")) + .count(); + + // Sort by effective priority descending + let mut sorted: Vec<&Finding> = report.findings.iter().collect(); + sorted.sort_by(|a, b| { + effective_priority(b) + .partial_cmp(&effective_priority(a)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut summary = format!( + "## Vulnerability Scan Summary\n\ + **Scanner:** {} | **Date:** {}\n\ + **Total findings:** {} (Critical: {}, High: {}, Medium: {}, Low: {}, Informational: {})\n\n", + report.scanner, + report.scan_date.format("%Y-%m-%d"), + total, + critical, + high, + medium, + low, + informational + ); + + // Top 10 by effective priority + summary.push_str("### Top Findings by Priority\n\n"); + for (i, finding) in sorted.iter().take(10).enumerate() { + let priority = effective_priority(finding); + let context = match (finding.internet_facing, finding.production) { + (true, true) => " [internet-facing, production]", + (true, false) => " [internet-facing]", + (false, true) => " [production]", + (false, false) => "", + }; + let _ = writeln!( + summary, + "{}. 
**{}** (CVSS: {:.1}, Priority: {:.1}){}\n Asset: {} | {}", + i + 1, + if finding.cve_id.is_empty() { + "No CVE" + } else { + &finding.cve_id + }, + finding.cvss_score, + priority, + context, + finding.affected_asset, + finding.description + ); + if !finding.remediation.is_empty() { + let _ = writeln!(summary, " Remediation: {}", finding.remediation); + } + summary.push('\n'); + } + + // Remediation recommendations + if critical > 0 || high > 0 { + summary.push_str("### Remediation Recommendations\n\n"); + if critical > 0 { + let _ = writeln!( + summary, + "- **URGENT:** {} critical findings require immediate remediation", + critical + ); + } + if high > 0 { + let _ = writeln!( + summary, + "- **HIGH:** {} high-severity findings should be addressed within 7 days", + high + ); + } + let internet_facing_critical = sorted + .iter() + .filter(|f| f.internet_facing && (f.severity == "critical" || f.severity == "high")) + .count(); + if internet_facing_critical > 0 { + let _ = writeln!( + summary, + "- **PRIORITY:** {} critical/high findings on internet-facing assets", + internet_facing_critical + ); + } + } + + summary +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_findings() -> Vec { + vec![ + Finding { + cve_id: "CVE-2024-0001".into(), + cvss_score: 9.8, + severity: "critical".into(), + affected_asset: "web-server-01".into(), + description: "Remote code execution in web framework".into(), + remediation: "Upgrade to version 2.1.0".into(), + internet_facing: true, + production: true, + }, + Finding { + cve_id: "CVE-2024-0002".into(), + cvss_score: 7.5, + severity: "high".into(), + affected_asset: "db-server-01".into(), + description: "SQL injection in query parser".into(), + remediation: "Apply patch KB-12345".into(), + internet_facing: false, + production: true, + }, + Finding { + cve_id: "CVE-2024-0003".into(), + cvss_score: 4.3, + severity: "medium".into(), + affected_asset: "staging-app-01".into(), + description: "Information disclosure via debug 
endpoint".into(), + remediation: "Disable debug endpoint in config".into(), + internet_facing: false, + production: false, + }, + ] + } + + #[test] + fn effective_priority_adds_context_bonuses() { + let mut f = Finding { + cve_id: String::new(), + cvss_score: 7.0, + severity: "high".into(), + affected_asset: "host".into(), + description: "test".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }; + + assert!((effective_priority(&f) - 7.0).abs() < f64::EPSILON); + + f.internet_facing = true; + assert!((effective_priority(&f) - 9.0).abs() < f64::EPSILON); + + f.production = true; + assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); // capped + + // High CVSS + both bonuses still caps at 10.0 + f.cvss_score = 9.5; + assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); + } + + #[test] + fn cvss_to_severity_classification() { + assert_eq!(cvss_to_severity(9.8), "critical"); + assert_eq!(cvss_to_severity(9.0), "critical"); + assert_eq!(cvss_to_severity(8.5), "high"); + assert_eq!(cvss_to_severity(7.0), "high"); + assert_eq!(cvss_to_severity(5.0), "medium"); + assert_eq!(cvss_to_severity(4.0), "medium"); + assert_eq!(cvss_to_severity(3.9), "low"); + assert_eq!(cvss_to_severity(0.1), "low"); + assert_eq!(cvss_to_severity(0.0), "informational"); + } + + #[test] + fn parse_vulnerability_json_roundtrip() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "nessus".into(), + findings: sample_findings(), + }; + + let json = serde_json::to_string(&report).unwrap(); + let parsed = parse_vulnerability_json(&json).unwrap(); + + assert_eq!(parsed.scanner, "nessus"); + assert_eq!(parsed.findings.len(), 3); + assert_eq!(parsed.findings[0].cve_id, "CVE-2024-0001"); + } + + #[test] + fn parse_vulnerability_json_rejects_invalid() { + let result = parse_vulnerability_json("not json"); + assert!(result.is_err()); + } + + #[test] + fn generate_summary_includes_key_sections() { + let report = 
VulnerabilityReport { + scan_date: Utc::now(), + scanner: "qualys".into(), + findings: sample_findings(), + }; + + let summary = generate_summary(&report); + + assert!(summary.contains("qualys")); + assert!(summary.contains("Total findings:** 3")); + assert!(summary.contains("Critical: 1")); + assert!(summary.contains("High: 1")); + assert!(summary.contains("CVE-2024-0001")); + assert!(summary.contains("URGENT")); + assert!(summary.contains("internet-facing")); + } + + #[test] + fn parse_vulnerability_json_rejects_out_of_range_cvss() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "test".into(), + findings: vec![Finding { + cve_id: "CVE-2024-9999".into(), + cvss_score: 11.0, + severity: "critical".into(), + affected_asset: "host".into(), + description: "bad score".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }], + }; + let json = serde_json::to_string(&report).unwrap(); + let result = parse_vulnerability_json(&json); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("cvss_score must be between 0.0 and 10.0")); + } + + #[test] + fn parse_vulnerability_json_rejects_negative_cvss() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "test".into(), + findings: vec![Finding { + cve_id: "CVE-2024-9998".into(), + cvss_score: -1.0, + severity: "low".into(), + affected_asset: "host".into(), + description: "negative score".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }], + }; + let json = serde_json::to_string(&report).unwrap(); + let result = parse_vulnerability_json(&json); + assert!(result.is_err()); + } + + #[test] + fn generate_summary_empty_findings() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "nessus".into(), + findings: vec![], + }; + + let summary = generate_summary(&report); + assert!(summary.contains("No findings")); + } +} diff --git a/src/tools/mod.rs 
b/src/tools/mod.rs index 75cfccbc1..89761a5f0 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -60,6 +60,7 @@ pub mod report_templates; pub mod schedule; pub mod schema; pub mod screenshot; +pub mod security_ops; pub mod shell; pub mod swarm; pub mod tool_search; @@ -111,6 +112,7 @@ pub use schedule::ScheduleTool; #[allow(unused_imports)] pub use schema::{CleaningStrategy, SchemaCleanr}; pub use screenshot::ScreenshotTool; +pub use security_ops::SecurityOpsTool; pub use shell::ShellTool; pub use swarm::SwarmTool; pub use tool_search::ToolSearchTool; @@ -375,6 +377,13 @@ pub fn all_tools_with_runtime( ))); } + // MCSS Security Operations + if root_config.security_ops.enabled { + tool_arcs.push(Arc::new(SecurityOpsTool::new( + root_config.security_ops.clone(), + ))); + } + // PDF extraction (feature-gated at compile time via rag-pdf) tool_arcs.push(Arc::new(PdfReadTool::new(security.clone()))); diff --git a/src/tools/security_ops.rs b/src/tools/security_ops.rs new file mode 100644 index 000000000..92ce18d06 --- /dev/null +++ b/src/tools/security_ops.rs @@ -0,0 +1,659 @@ +//! Security operations tool for managed cybersecurity service (MCSS) workflows. +//! +//! Provides alert triage, incident response playbook execution, vulnerability +//! scan parsing, and security report generation. All actions that modify state +//! enforce human approval gates unless explicitly configured otherwise. + +use async_trait::async_trait; +use serde_json::json; +use std::path::PathBuf; + +use super::traits::{Tool, ToolResult}; +use crate::config::SecurityOpsConfig; +use crate::security::playbook::{ + evaluate_step, load_playbooks, severity_level, Playbook, StepStatus, +}; +use crate::security::vulnerability::{generate_summary, parse_vulnerability_json}; + +/// Security operations tool — triage alerts, run playbooks, parse vulns, generate reports. 
+pub struct SecurityOpsTool { + config: SecurityOpsConfig, + playbooks: Vec, +} + +impl SecurityOpsTool { + pub fn new(config: SecurityOpsConfig) -> Self { + let playbooks_dir = expand_tilde(&config.playbooks_dir); + let playbooks = load_playbooks(&playbooks_dir); + Self { config, playbooks } + } + + /// Triage an alert: classify severity and recommend response. + fn triage_alert(&self, args: &serde_json::Value) -> anyhow::Result { + let alert = args + .get("alert") + .ok_or_else(|| anyhow::anyhow!("Missing required 'alert' parameter"))?; + + // Extract key fields for classification + let alert_type = alert + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let source = alert + .get("source") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let severity = alert + .get("severity") + .and_then(|v| v.as_str()) + .unwrap_or("medium"); + let description = alert + .get("description") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + // Classify and find matching playbooks + let matching_playbooks: Vec<&Playbook> = self + .playbooks + .iter() + .filter(|pb| { + severity_level(severity) >= severity_level(&pb.severity_filter) + && (pb.name.contains(alert_type) + || alert_type.contains(&pb.name) + || description + .to_lowercase() + .contains(&pb.name.replace('_', " "))) + }) + .collect(); + + let playbook_names: Vec<&str> = + matching_playbooks.iter().map(|p| p.name.as_str()).collect(); + + let output = json!({ + "classification": { + "alert_type": alert_type, + "source": source, + "severity": severity, + "severity_level": severity_level(severity), + "priority": if severity_level(severity) >= 3 { "immediate" } else { "standard" }, + }, + "recommended_playbooks": playbook_names, + "recommended_action": if matching_playbooks.is_empty() { + "Manual investigation required — no matching playbook found" + } else { + "Execute recommended playbook(s)" + }, + "auto_triage": self.config.auto_triage, + }); + + Ok(ToolResult { + success: true, + output: 
serde_json::to_string_pretty(&output)?, + error: None, + }) + } + + /// Execute a playbook step with approval gating. + fn run_playbook(&self, args: &serde_json::Value) -> anyhow::Result { + let playbook_name = args + .get("playbook") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'playbook' parameter"))?; + + let step_index = + usize::try_from(args.get("step").and_then(|v| v.as_u64()).ok_or_else(|| { + anyhow::anyhow!("Missing required 'step' parameter (0-based index)") + })?) + .map_err(|_| anyhow::anyhow!("'step' parameter value too large for this platform"))?; + + let alert_severity = args + .get("alert_severity") + .and_then(|v| v.as_str()) + .unwrap_or("medium"); + + let playbook = self + .playbooks + .iter() + .find(|p| p.name == playbook_name) + .ok_or_else(|| anyhow::anyhow!("Playbook '{}' not found", playbook_name))?; + + let result = evaluate_step( + playbook, + step_index, + alert_severity, + &self.config.max_auto_severity, + self.config.require_approval_for_actions, + ); + + let output = json!({ + "playbook": playbook_name, + "step_index": result.step_index, + "action": result.action, + "status": result.status.to_string(), + "message": result.message, + "requires_manual_approval": result.status == StepStatus::PendingApproval, + }); + + Ok(ToolResult { + success: result.status != StepStatus::Failed, + output: serde_json::to_string_pretty(&output)?, + error: if result.status == StepStatus::Failed { + Some(result.message) + } else { + None + }, + }) + } + + /// Parse vulnerability scan results. + fn parse_vulnerability(&self, args: &serde_json::Value) -> anyhow::Result { + let scan_data = args + .get("scan_data") + .ok_or_else(|| anyhow::anyhow!("Missing required 'scan_data' parameter"))?; + + let json_str = if scan_data.is_string() { + scan_data.as_str().unwrap().to_string() + } else { + serde_json::to_string(scan_data)? 
+ }; + + let report = parse_vulnerability_json(&json_str)?; + let summary = generate_summary(&report); + + let output = json!({ + "scanner": report.scanner, + "scan_date": report.scan_date.to_rfc3339(), + "total_findings": report.findings.len(), + "by_severity": { + "critical": report.findings.iter().filter(|f| f.severity == "critical").count(), + "high": report.findings.iter().filter(|f| f.severity == "high").count(), + "medium": report.findings.iter().filter(|f| f.severity == "medium").count(), + "low": report.findings.iter().filter(|f| f.severity == "low").count(), + }, + "summary": summary, + }); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } + + /// Generate a client-facing security posture report. + fn generate_report(&self, args: &serde_json::Value) -> anyhow::Result { + let client_name = args + .get("client_name") + .and_then(|v| v.as_str()) + .unwrap_or("Client"); + let period = args + .get("period") + .and_then(|v| v.as_str()) + .unwrap_or("current"); + let alert_stats = args.get("alert_stats"); + let vuln_summary = args + .get("vuln_summary") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let report = format!( + "# Security Posture Report — {client_name}\n\ + **Period:** {period}\n\ + **Generated:** {}\n\n\ + ## Executive Summary\n\n\ + This report provides an overview of the security posture for {client_name} \ + during the {period} period.\n\n\ + ## Alert Summary\n\n\ + {}\n\n\ + ## Vulnerability Assessment\n\n\ + {}\n\n\ + ## Recommendations\n\n\ + 1. Address all critical and high-severity findings immediately\n\ + 2. Review and update incident response playbooks quarterly\n\ + 3. Conduct regular vulnerability scans on all internet-facing assets\n\ + 4. 
Ensure all endpoints have current security patches\n\n\ + ---\n\ + *Report generated by ZeroClaw MCSS Agent*\n", + chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"), + alert_stats + .map(|s| serde_json::to_string_pretty(s).unwrap_or_default()) + .unwrap_or_else(|| "No alert statistics provided.".into()), + if vuln_summary.is_empty() { + "No vulnerability data provided." + } else { + vuln_summary + }, + ); + + Ok(ToolResult { + success: true, + output: report, + error: None, + }) + } + + /// List available playbooks. + fn list_playbooks(&self) -> anyhow::Result { + if self.playbooks.is_empty() { + return Ok(ToolResult { + success: true, + output: "No playbooks available.".into(), + error: None, + }); + } + + let playbook_list: Vec = self + .playbooks + .iter() + .map(|pb| { + json!({ + "name": pb.name, + "description": pb.description, + "steps": pb.steps.len(), + "severity_filter": pb.severity_filter, + "auto_approve_steps": pb.auto_approve_steps, + }) + }) + .collect(); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&playbook_list)?, + error: None, + }) + } + + /// Summarize alert volume, categories, and resolution times. 
+ fn alert_stats(&self, args: &serde_json::Value) -> anyhow::Result { + let alerts = args + .get("alerts") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'alerts' array parameter"))?; + + let total = alerts.len(); + let mut by_severity = std::collections::HashMap::new(); + let mut by_category = std::collections::HashMap::new(); + let mut resolved_count = 0u64; + let mut total_resolution_secs = 0u64; + + for alert in alerts { + let severity = alert + .get("severity") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + *by_severity.entry(severity.to_string()).or_insert(0u64) += 1; + + let category = alert + .get("category") + .and_then(|v| v.as_str()) + .unwrap_or("uncategorized"); + *by_category.entry(category.to_string()).or_insert(0u64) += 1; + + if let Some(resolution_secs) = alert.get("resolution_secs").and_then(|v| v.as_u64()) { + resolved_count += 1; + total_resolution_secs += resolution_secs; + } + } + + let avg_resolution = if resolved_count > 0 { + total_resolution_secs as f64 / resolved_count as f64 + } else { + 0.0 + }; + + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let avg_resolution_secs_u64 = avg_resolution.max(0.0) as u64; + + let output = json!({ + "total_alerts": total, + "resolved": resolved_count, + "unresolved": total as u64 - resolved_count, + "by_severity": by_severity, + "by_category": by_category, + "avg_resolution_secs": avg_resolution, + "avg_resolution_human": format_duration_secs(avg_resolution_secs_u64), + }); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } +} + +fn format_duration_secs(secs: u64) -> String { + if secs < 60 { + format!("{secs}s") + } else if secs < 3600 { + format!("{}m {}s", secs / 60, secs % 60) + } else { + format!("{}h {}m", secs / 3600, (secs % 3600) / 60) + } +} + +/// Expand ~ to home directory. 
+fn expand_tilde(path: &str) -> PathBuf { + if let Some(rest) = path.strip_prefix("~/") { + if let Some(user_dirs) = directories::UserDirs::new() { + return user_dirs.home_dir().join(rest); + } + } + PathBuf::from(path) +} + +#[async_trait] +impl Tool for SecurityOpsTool { + fn name(&self) -> &str { + "security_ops" + } + + fn description(&self) -> &str { + "Security operations tool for managed cybersecurity services. Actions: \ + triage_alert (classify/prioritize alerts), run_playbook (execute incident response steps), \ + parse_vulnerability (parse scan results), generate_report (create security posture reports), \ + list_playbooks (list available playbooks), alert_stats (summarize alert metrics)." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": ["triage_alert", "run_playbook", "parse_vulnerability", "generate_report", "list_playbooks", "alert_stats"], + "description": "The security operation to perform" + }, + "alert": { + "type": "object", + "description": "Alert JSON for triage_alert (requires: type, severity; optional: source, description)" + }, + "playbook": { + "type": "string", + "description": "Playbook name for run_playbook" + }, + "step": { + "type": "integer", + "description": "0-based step index for run_playbook" + }, + "alert_severity": { + "type": "string", + "description": "Alert severity context for run_playbook" + }, + "scan_data": { + "description": "Vulnerability scan data (JSON string or object) for parse_vulnerability" + }, + "client_name": { + "type": "string", + "description": "Client name for generate_report" + }, + "period": { + "type": "string", + "description": "Reporting period for generate_report" + }, + "alert_stats": { + "type": "object", + "description": "Alert statistics to include in generate_report" + }, + "vuln_summary": { + "type": "string", + "description": "Vulnerability summary to include in 
generate_report" + }, + "alerts": { + "type": "array", + "description": "Array of alert objects for alert_stats" + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?; + + match action { + "triage_alert" => self.triage_alert(&args), + "run_playbook" => self.run_playbook(&args), + "parse_vulnerability" => self.parse_vulnerability(&args), + "generate_report" => self.generate_report(&args), + "list_playbooks" => self.list_playbooks(), + "alert_stats" => self.alert_stats(&args), + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{action}'. Valid: triage_alert, run_playbook, \ + parse_vulnerability, generate_report, list_playbooks, alert_stats" + )), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_config() -> SecurityOpsConfig { + SecurityOpsConfig { + enabled: true, + playbooks_dir: "/nonexistent".into(), + auto_triage: false, + require_approval_for_actions: true, + max_auto_severity: "low".into(), + report_output_dir: "/tmp/reports".into(), + siem_integration: None, + } + } + + fn test_tool() -> SecurityOpsTool { + SecurityOpsTool::new(test_config()) + } + + #[test] + fn tool_name_and_schema() { + let tool = test_tool(); + assert_eq!(tool.name(), "security_ops"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + assert!(schema["required"] + .as_array() + .unwrap() + .contains(&json!("action"))); + } + + #[tokio::test] + async fn triage_alert_classifies_severity() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "triage_alert", + "alert": { + "type": "suspicious_login", + "source": "siem", + "severity": "high", + "description": "Multiple failed login attempts followed by successful login" + } + })) + .await + .unwrap(); + + 
assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["classification"]["severity"], "high"); + assert_eq!(output["classification"]["priority"], "immediate"); + // Should match suspicious_login playbook + let playbooks = output["recommended_playbooks"].as_array().unwrap(); + assert!(playbooks.iter().any(|p| p == "suspicious_login")); + } + + #[tokio::test] + async fn triage_alert_missing_alert_param() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "triage_alert"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn run_playbook_requires_approval() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "suspicious_login", + "step": 2, + "alert_severity": "high" + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["status"], "pending_approval"); + assert_eq!(output["requires_manual_approval"], true); + } + + #[tokio::test] + async fn run_playbook_executes_safe_step() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "suspicious_login", + "step": 0, + "alert_severity": "medium" + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["status"], "completed"); + } + + #[tokio::test] + async fn run_playbook_not_found() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "nonexistent", + "step": 0 + })) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn parse_vulnerability_valid_report() { + let tool = test_tool(); + let scan_data = json!({ + "scan_date": "2025-01-15T10:00:00Z", + "scanner": "nessus", + "findings": [ + { + "cve_id": "CVE-2024-0001", + "cvss_score": 
9.8, + "severity": "critical", + "affected_asset": "web-01", + "description": "RCE in web framework", + "remediation": "Upgrade", + "internet_facing": true, + "production": true + } + ] + }); + + let result = tool + .execute(json!({ + "action": "parse_vulnerability", + "scan_data": scan_data + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["total_findings"], 1); + assert_eq!(output["by_severity"]["critical"], 1); + } + + #[tokio::test] + async fn generate_report_produces_markdown() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "generate_report", + "client_name": "ZeroClaw Corp", + "period": "Q1 2025" + })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("ZeroClaw Corp")); + assert!(result.output.contains("Q1 2025")); + assert!(result.output.contains("Security Posture Report")); + } + + #[tokio::test] + async fn list_playbooks_returns_builtins() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "list_playbooks"})) + .await + .unwrap(); + + assert!(result.success); + let output: Vec = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output.len(), 4); + let names: Vec<&str> = output.iter().map(|p| p["name"].as_str().unwrap()).collect(); + assert!(names.contains(&"suspicious_login")); + assert!(names.contains(&"malware_detected")); + } + + #[tokio::test] + async fn alert_stats_computes_summary() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "alert_stats", + "alerts": [ + {"severity": "critical", "category": "malware", "resolution_secs": 3600}, + {"severity": "high", "category": "phishing", "resolution_secs": 1800}, + {"severity": "medium", "category": "malware"}, + {"severity": "low", "category": "policy_violation", "resolution_secs": 600} + ] + })) + .await + .unwrap(); + + assert!(result.success); + let output: 
serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["total_alerts"], 4); + assert_eq!(output["resolved"], 3); + assert_eq!(output["unresolved"], 1); + assert_eq!(output["by_severity"]["critical"], 1); + assert_eq!(output["by_category"]["malware"], 2); + } + + #[tokio::test] + async fn unknown_action_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "bad_action"})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown action")); + } + + #[test] + fn format_duration_secs_readable() { + assert_eq!(format_duration_secs(45), "45s"); + assert_eq!(format_duration_secs(125), "2m 5s"); + assert_eq!(format_duration_secs(3665), "1h 1m"); + } +} From 861dd3e2e9473b63fef9f92868bdd2ec74f9e6cc Mon Sep 17 00:00:00 2001 From: Argenis Date: Mon, 16 Mar 2026 02:35:44 -0400 Subject: [PATCH 11/11] feat(tools): add backup/restore and data management tools (#3662) Add BackupTool for creating, listing, verifying, and restoring timestamped workspace backups with SHA-256 manifest integrity checking. Add DataManagementTool for retention status, time-based purge, and storage statistics. Both tools are config-driven via new BackupConfig and DataRetentionConfig sections. Original work by @rareba. Rebased on latest master with conflict resolution for SwarmConfig/SwarmStrategy exports and swarm tool registration, and added missing approval_manager fields in ChannelRuntimeContext test constructors. 
Co-authored-by: Claude Opus 4.6 --- src/config/mod.rs | 18 +- src/config/schema.rs | 109 ++++++++ src/onboard/wizard.rs | 4 + src/tools/backup_tool.rs | 466 +++++++++++++++++++++++++++++++++++ src/tools/data_management.rs | 320 ++++++++++++++++++++++++ src/tools/mod.rs | 21 ++ 6 files changed, 929 insertions(+), 9 deletions(-) create mode 100644 src/tools/backup_tool.rs create mode 100644 src/tools/data_management.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 45fbd6ff7..1cf2960a0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,15 +6,15 @@ pub mod workspace; pub use schema::{ apply_runtime_proxy_to_builder, build_runtime_proxy_client, build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config, - AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, - BuiltinHooksConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config, CostConfig, - CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, - ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, - GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, - HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, - McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, - MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, - ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, + AgentConfig, AuditConfig, AutonomyConfig, BackupConfig, BrowserComputerUseConfig, + BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config, + CostConfig, CronConfig, DataRetentionConfig, DelegateAgentConfig, DiscordConfig, + DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, + FeishuConfig, GatewayConfig, GoogleTtsConfig, HardwareConfig, HardwareTransport, + 
HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig,
+    MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
+    ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
+    NotionConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
     PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
     QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
     RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
diff --git a/src/config/schema.rs b/src/config/schema.rs
index 9ca82e547..ba9b6aa43 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -122,6 +122,12 @@ pub struct Config {
+    /// Backup tool configuration (`[backup]`).
+    #[serde(default)]
+    pub backup: BackupConfig,
+    /// Data retention and purge configuration (`[data_retention]`).
+    #[serde(default)]
+    pub data_retention: DataRetentionConfig,
     /// Security subsystem configuration (`[security]`).
     #[serde(default)]
     pub security: SecurityConfig,

     /// Managed cybersecurity service configuration (`[security_ops]`).
@@ -1850,6 +1856,103 @@ impl Default for ProjectIntelConfig {
     }
 }

+// ── Backup ──────────────────────────────────────────────────────
+
+/// Backup tool configuration (`[backup]` section).
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct BackupConfig {
+    /// Enable the `backup` tool.
+    #[serde(default = "default_true")]
+    pub enabled: bool,
+    /// Maximum number of backups to keep (oldest are pruned).
+    #[serde(default = "default_backup_max_keep")]
+    pub max_keep: usize,
+    /// Workspace subdirectories to include in backups.
+    #[serde(default = "default_backup_include_dirs")]
+    pub include_dirs: Vec<String>,
+    /// Output directory for backup archives (relative to workspace root).
+    #[serde(default = "default_backup_destination_dir")]
+    pub destination_dir: String,
+    /// Optional cron expression for scheduled automatic backups.
+    #[serde(default)]
+    pub schedule_cron: Option<String>,
+    /// IANA timezone for `schedule_cron`.
+    #[serde(default)]
+    pub schedule_timezone: Option<String>,
+    /// Compress backup archives.
+    #[serde(default = "default_true")]
+    pub compress: bool,
+    /// Encrypt backup archives (requires a configured secret store key).
+    #[serde(default)]
+    pub encrypt: bool,
+}
+
+fn default_backup_max_keep() -> usize {
+    10
+}
+
+fn default_backup_include_dirs() -> Vec<String> {
+    vec![
+        "config".into(),
+        "memory".into(),
+        "audit".into(),
+        "knowledge".into(),
+    ]
+}
+
+fn default_backup_destination_dir() -> String {
+    "state/backups".into()
+}
+
+impl Default for BackupConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            max_keep: default_backup_max_keep(),
+            include_dirs: default_backup_include_dirs(),
+            destination_dir: default_backup_destination_dir(),
+            schedule_cron: None,
+            schedule_timezone: None,
+            compress: true,
+            encrypt: false,
+        }
+    }
+}
+
+// ── Data Retention ──────────────────────────────────────────────
+
+/// Data retention and purge configuration (`[data_retention]` section).
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct DataRetentionConfig {
+    /// Enable the `data_management` tool.
+    #[serde(default)]
+    pub enabled: bool,
+    /// Days of data to retain before purge eligibility.
+    #[serde(default = "default_retention_days")]
+    pub retention_days: u64,
+    /// Preview what would be deleted without actually removing anything.
+    #[serde(default)]
+    pub dry_run: bool,
+    /// Limit retention enforcement to specific data categories (empty = all).
+    #[serde(default)]
+    pub categories: Vec<String>,
+}
+
+fn default_retention_days() -> u64 {
+    90
+}
+
+impl Default for DataRetentionConfig {
+    fn default() -> Self {
+        Self {
+            enabled: false,
+            retention_days: default_retention_days(),
+            dry_run: false,
+            categories: Vec::new(),
+        }
+    }
+}
+
 // ── Proxy ───────────────────────────────────────────────────────

 /// Proxy application scope — determines which outbound traffic uses the proxy.
@@ -4798,6 +4901,8 @@ impl Default for Config {
             extra_headers: HashMap::new(),
             observability: ObservabilityConfig::default(),
             autonomy: AutonomyConfig::default(),
+            backup: BackupConfig::default(),
+            data_retention: DataRetentionConfig::default(),
             security: SecurityConfig::default(),
             security_ops: SecurityOpsConfig::default(),
             runtime: RuntimeConfig::default(),
@@ -7046,6 +7151,8 @@ default_temperature = 0.7
                 allowed_roots: vec![],
                 non_cli_excluded_tools: vec![],
             },
+            backup: BackupConfig::default(),
+            data_retention: DataRetentionConfig::default(),
             security: SecurityConfig::default(),
             security_ops: SecurityOpsConfig::default(),
             runtime: RuntimeConfig {
@@ -7389,6 +7496,8 @@ tool_dispatcher = "xml"
             extra_headers: HashMap::new(),
             observability: ObservabilityConfig::default(),
             autonomy: AutonomyConfig::default(),
+            backup: BackupConfig::default(),
+            data_retention: DataRetentionConfig::default(),
             security: SecurityConfig::default(),
             security_ops: SecurityOpsConfig::default(),
             runtime: RuntimeConfig::default(),
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index 210114954..c1acf17f9 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -143,6 +143,8 @@ pub async fn run_wizard(force: bool) -> Result {
         extra_headers: std::collections::HashMap::new(),
         observability: ObservabilityConfig::default(),
         autonomy: AutonomyConfig::default(),
+        backup: crate::config::BackupConfig::default(),
+        data_retention: crate::config::DataRetentionConfig::default(),
         security: crate::config::SecurityConfig::default(),
         security_ops:
crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), @@ -507,6 +509,8 @@ async fn run_quick_setup_with_home( extra_headers: std::collections::HashMap::new(), observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), + backup: crate::config::BackupConfig::default(), + data_retention: crate::config::DataRetentionConfig::default(), security: crate::config::SecurityConfig::default(), security_ops: crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), diff --git a/src/tools/backup_tool.rs b/src/tools/backup_tool.rs new file mode 100644 index 000000000..fe6ea248d --- /dev/null +++ b/src/tools/backup_tool.rs @@ -0,0 +1,466 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use tokio::fs; + +/// Workspace backup tool: create, list, verify, and restore timestamped backups +/// with SHA-256 manifest integrity checking. 
+pub struct BackupTool { + workspace_dir: PathBuf, + include_dirs: Vec, + max_keep: usize, +} + +impl BackupTool { + pub fn new(workspace_dir: PathBuf, include_dirs: Vec, max_keep: usize) -> Self { + Self { + workspace_dir, + include_dirs, + max_keep, + } + } + + fn backups_dir(&self) -> PathBuf { + self.workspace_dir.join("backups") + } + + async fn cmd_create(&self) -> anyhow::Result { + let ts = chrono::Utc::now().format("%Y%m%dT%H%M%SZ"); + let name = format!("backup-{ts}"); + let backup_dir = self.backups_dir().join(&name); + fs::create_dir_all(&backup_dir).await?; + + for sub in &self.include_dirs { + let src = self.workspace_dir.join(sub); + if src.is_dir() { + let dst = backup_dir.join(sub); + copy_dir_recursive(&src, &dst).await?; + } + } + + let checksums = compute_checksums(&backup_dir).await?; + let file_count = checksums.len(); + let manifest = serde_json::to_string_pretty(&checksums)?; + fs::write(backup_dir.join("manifest.json"), &manifest).await?; + + // Enforce max_keep: remove oldest backups beyond the limit. + self.enforce_max_keep().await?; + + Ok(ToolResult { + success: true, + output: json!({ + "backup": name, + "file_count": file_count, + }) + .to_string(), + error: None, + }) + } + + async fn enforce_max_keep(&self) -> anyhow::Result<()> { + let mut backups = self.list_backup_dirs().await?; + // Sorted newest-first; drop excess from the tail. + while backups.len() > self.max_keep { + if let Some(old) = backups.pop() { + fs::remove_dir_all(old).await?; + } + } + Ok(()) + } + + async fn list_backup_dirs(&self) -> anyhow::Result> { + let dir = self.backups_dir(); + if !dir.is_dir() { + return Ok(Vec::new()); + } + let mut entries = Vec::new(); + let mut rd = fs::read_dir(&dir).await?; + while let Some(e) = rd.next_entry().await? 
{ + let p = e.path(); + if p.is_dir() && e.file_name().to_string_lossy().starts_with("backup-") { + entries.push(p); + } + } + entries.sort(); + entries.reverse(); // newest first + Ok(entries) + } + + async fn cmd_list(&self) -> anyhow::Result { + let dirs = self.list_backup_dirs().await?; + let mut items = Vec::new(); + for d in &dirs { + let name = d + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default(); + let manifest_path = d.join("manifest.json"); + let file_count = if manifest_path.is_file() { + let data = fs::read_to_string(&manifest_path).await?; + let map: HashMap = serde_json::from_str(&data).unwrap_or_default(); + map.len() + } else { + 0 + }; + let meta = fs::metadata(d).await?; + let created = meta + .created() + .or_else(|_| meta.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH); + let dt: chrono::DateTime = created.into(); + items.push(json!({ + "name": name, + "file_count": file_count, + "created": dt.to_rfc3339(), + })); + } + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&items)?, + error: None, + }) + } + + async fn cmd_verify(&self, backup_name: &str) -> anyhow::Result { + let backup_dir = self.backups_dir().join(backup_name); + if !backup_dir.is_dir() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Backup not found: {backup_name}")), + }); + } + let manifest_path = backup_dir.join("manifest.json"); + let data = fs::read_to_string(&manifest_path).await?; + let expected: HashMap = serde_json::from_str(&data)?; + let actual = compute_checksums(&backup_dir).await?; + + let mut mismatches = Vec::new(); + for (path, expected_hash) in &expected { + match actual.get(path) { + Some(actual_hash) if actual_hash == expected_hash => {} + Some(actual_hash) => mismatches.push(json!({ + "file": path, + "expected": expected_hash, + "actual": actual_hash, + })), + None => mismatches.push(json!({ + "file": path, + "error": "missing", + })), + } + } + let 
pass = mismatches.is_empty(); + Ok(ToolResult { + success: pass, + output: json!({ + "backup": backup_name, + "pass": pass, + "checked": expected.len(), + "mismatches": mismatches, + }) + .to_string(), + error: if pass { + None + } else { + Some("Integrity check failed".into()) + }, + }) + } + + async fn cmd_restore(&self, backup_name: &str, confirm: bool) -> anyhow::Result { + let backup_dir = self.backups_dir().join(backup_name); + if !backup_dir.is_dir() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Backup not found: {backup_name}")), + }); + } + + // Collect restorable subdirectories (skip manifest.json). + let mut restore_items: Vec = Vec::new(); + let mut rd = fs::read_dir(&backup_dir).await?; + while let Some(e) = rd.next_entry().await? { + let name = e.file_name().to_string_lossy().to_string(); + if name == "manifest.json" { + continue; + } + if e.path().is_dir() { + restore_items.push(name); + } + } + + if !confirm { + return Ok(ToolResult { + success: true, + output: json!({ + "dry_run": true, + "backup": backup_name, + "would_restore": restore_items, + }) + .to_string(), + error: None, + }); + } + + for sub in &restore_items { + let src = backup_dir.join(sub); + let dst = self.workspace_dir.join(sub); + copy_dir_recursive(&src, &dst).await?; + } + Ok(ToolResult { + success: true, + output: json!({ + "restored": backup_name, + "directories": restore_items, + }) + .to_string(), + error: None, + }) + } +} + +#[async_trait] +impl Tool for BackupTool { + fn name(&self) -> &str { + "backup" + } + + fn description(&self) -> &str { + "Create, list, verify, and restore workspace backups" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "enum": ["create", "list", "verify", "restore"], + "description": "Backup command to execute" + }, + "backup_name": { + "type": "string", + "description": "Name of backup (for 
verify/restore)" + }, + "confirm": { + "type": "boolean", + "description": "Confirm restore (required for actual restore, default false)" + } + }, + "required": ["command"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let command = match args.get("command").and_then(|v| v.as_str()) { + Some(c) => c, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'command' parameter".into()), + }); + } + }; + + match command { + "create" => self.cmd_create().await, + "list" => self.cmd_list().await, + "verify" => { + let name = args + .get("backup_name") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for verify"))?; + self.cmd_verify(name).await + } + "restore" => { + let name = args + .get("backup_name") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for restore"))?; + let confirm = args + .get("confirm") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + self.cmd_restore(name, confirm).await + } + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown command: {other}")), + }), + } + } +} + +// -- Helpers ------------------------------------------------------------------ + +async fn copy_dir_recursive(src: &Path, dst: &Path) -> anyhow::Result<()> { + fs::create_dir_all(dst).await?; + let mut rd = fs::read_dir(src).await?; + while let Some(entry) = rd.next_entry().await? 
{ + let src_path = entry.path(); + let dst_path = dst.join(entry.file_name()); + if src_path.is_dir() { + Box::pin(copy_dir_recursive(&src_path, &dst_path)).await?; + } else { + fs::copy(&src_path, &dst_path).await?; + } + } + Ok(()) +} + +async fn compute_checksums(dir: &Path) -> anyhow::Result> { + let mut map = HashMap::new(); + let base = dir.to_path_buf(); + walk_and_hash(&base, dir, &mut map).await?; + Ok(map) +} + +async fn walk_and_hash( + base: &Path, + dir: &Path, + map: &mut HashMap, +) -> anyhow::Result<()> { + let mut rd = fs::read_dir(dir).await?; + while let Some(entry) = rd.next_entry().await? { + let path = entry.path(); + if path.is_dir() { + Box::pin(walk_and_hash(base, &path, map)).await?; + } else { + let rel = path + .strip_prefix(base) + .unwrap_or(&path) + .to_string_lossy() + .replace('\\', "/"); + if rel == "manifest.json" { + continue; + } + let bytes = fs::read(&path).await?; + let hash = hex::encode(Sha256::digest(&bytes)); + map.insert(rel, hash); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn make_tool(tmp: &TempDir) -> BackupTool { + BackupTool::new( + tmp.path().to_path_buf(), + vec!["config".into(), "memory".into()], + 10, + ) + } + + #[tokio::test] + async fn create_backup_produces_manifest() { + let tmp = TempDir::new().unwrap(); + // Seed workspace subdirectories. + let cfg_dir = tmp.path().join("config"); + std::fs::create_dir_all(&cfg_dir).unwrap(); + std::fs::write(cfg_dir.join("a.toml"), "key = 1").unwrap(); + + let tool = make_tool(&tmp); + let res = tool.execute(json!({"command": "create"})).await.unwrap(); + assert!(res.success, "create failed: {:?}", res.error); + + let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert_eq!(parsed["file_count"], 1); + + // Manifest should exist inside the backup directory. 
+ let backup_name = parsed["backup"].as_str().unwrap(); + let manifest = tmp + .path() + .join("backups") + .join(backup_name) + .join("manifest.json"); + assert!(manifest.exists()); + } + + #[tokio::test] + async fn verify_backup_detects_corruption() { + let tmp = TempDir::new().unwrap(); + let cfg_dir = tmp.path().join("config"); + std::fs::create_dir_all(&cfg_dir).unwrap(); + std::fs::write(cfg_dir.join("a.toml"), "original").unwrap(); + + let tool = make_tool(&tmp); + let res = tool.execute(json!({"command": "create"})).await.unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + let name = parsed["backup"].as_str().unwrap(); + + // Corrupt a file inside the backup. + let backed_up = tmp.path().join("backups").join(name).join("config/a.toml"); + std::fs::write(&backed_up, "corrupted").unwrap(); + + let res = tool + .execute(json!({"command": "verify", "backup_name": name})) + .await + .unwrap(); + assert!(!res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert!(!v["mismatches"].as_array().unwrap().is_empty()); + } + + #[tokio::test] + async fn restore_requires_confirmation() { + let tmp = TempDir::new().unwrap(); + let cfg_dir = tmp.path().join("config"); + std::fs::create_dir_all(&cfg_dir).unwrap(); + std::fs::write(cfg_dir.join("a.toml"), "v1").unwrap(); + + let tool = make_tool(&tmp); + let res = tool.execute(json!({"command": "create"})).await.unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + let name = parsed["backup"].as_str().unwrap(); + + // Without confirm: dry-run. + let res = tool + .execute(json!({"command": "restore", "backup_name": name})) + .await + .unwrap(); + assert!(res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert_eq!(v["dry_run"], true); + + // With confirm: actual restore. 
+ let res = tool + .execute(json!({"command": "restore", "backup_name": name, "confirm": true})) + .await + .unwrap(); + assert!(res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert!(v.get("restored").is_some()); + } + + #[tokio::test] + async fn list_backups_sorted_newest_first() { + let tmp = TempDir::new().unwrap(); + let cfg_dir = tmp.path().join("config"); + std::fs::create_dir_all(&cfg_dir).unwrap(); + std::fs::write(cfg_dir.join("a.toml"), "v1").unwrap(); + + let tool = make_tool(&tmp); + tool.execute(json!({"command": "create"})).await.unwrap(); + // Delay to ensure different second-resolution timestamps. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + tool.execute(json!({"command": "create"})).await.unwrap(); + + let res = tool.execute(json!({"command": "list"})).await.unwrap(); + assert!(res.success); + let items: Vec = serde_json::from_str(&res.output).unwrap(); + assert_eq!(items.len(), 2); + // Newest first by name (ISO8601 names sort lexicographically). + assert!(items[0]["name"].as_str().unwrap() >= items[1]["name"].as_str().unwrap()); + } +} diff --git a/src/tools/data_management.rs b/src/tools/data_management.rs new file mode 100644 index 000000000..b6fc6538e --- /dev/null +++ b/src/tools/data_management.rs @@ -0,0 +1,320 @@ +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::path::{Path, PathBuf}; +use tokio::fs; + +/// Workspace data lifecycle tool: retention status, time-based purge, and +/// storage statistics. 
+pub struct DataManagementTool { + workspace_dir: PathBuf, + retention_days: u64, +} + +impl DataManagementTool { + pub fn new(workspace_dir: PathBuf, retention_days: u64) -> Self { + Self { + workspace_dir, + retention_days, + } + } + + async fn cmd_retention_status(&self) -> anyhow::Result { + let cutoff = chrono::Utc::now() + - chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX)); + let cutoff_ts = cutoff.timestamp().try_into().unwrap_or(0u64); + let count = count_files_older_than(&self.workspace_dir, cutoff_ts).await?; + + Ok(ToolResult { + success: true, + output: json!({ + "retention_days": self.retention_days, + "cutoff": cutoff.to_rfc3339(), + "affected_files": count, + }) + .to_string(), + error: None, + }) + } + + async fn cmd_purge(&self, dry_run: bool) -> anyhow::Result { + let cutoff = chrono::Utc::now() + - chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX)); + let cutoff_ts: u64 = cutoff.timestamp().try_into().unwrap_or(0); + let (deleted, bytes) = purge_old_files(&self.workspace_dir, cutoff_ts, dry_run).await?; + + Ok(ToolResult { + success: true, + output: json!({ + "dry_run": dry_run, + "files": deleted, + "bytes_freed": bytes, + "bytes_freed_human": format_bytes(bytes), + }) + .to_string(), + error: None, + }) + } + + async fn cmd_stats(&self) -> anyhow::Result { + let (total_files, total_bytes, breakdown) = dir_stats(&self.workspace_dir).await?; + Ok(ToolResult { + success: true, + output: json!({ + "total_files": total_files, + "total_size": total_bytes, + "total_size_human": format_bytes(total_bytes), + "subdirectories": breakdown, + }) + .to_string(), + error: None, + }) + } +} + +#[async_trait] +impl Tool for DataManagementTool { + fn name(&self) -> &str { + "data_management" + } + + fn description(&self) -> &str { + "Workspace data retention, purge, and storage statistics" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + 
"command": { + "type": "string", + "enum": ["retention_status", "purge", "stats"], + "description": "Data management command" + }, + "dry_run": { + "type": "boolean", + "description": "If true, purge only lists what would be deleted (default true)" + } + }, + "required": ["command"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let command = match args.get("command").and_then(|v| v.as_str()) { + Some(c) => c, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'command' parameter".into()), + }); + } + }; + + match command { + "retention_status" => self.cmd_retention_status().await, + "purge" => { + let dry_run = args + .get("dry_run") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + self.cmd_purge(dry_run).await + } + "stats" => self.cmd_stats().await, + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown command: {other}")), + }), + } + } +} + +// -- Helpers ------------------------------------------------------------------ + +fn format_bytes(bytes: u64) -> String { + const KB: u64 = 1024; + const MB: u64 = 1024 * KB; + const GB: u64 = 1024 * MB; + if bytes >= GB { + format!("{:.1} GB", bytes as f64 / GB as f64) + } else if bytes >= MB { + format!("{:.1} MB", bytes as f64 / MB as f64) + } else if bytes >= KB { + format!("{:.1} KB", bytes as f64 / KB as f64) + } else { + format!("{bytes} B") + } +} + +async fn count_files_older_than(dir: &Path, cutoff_epoch: u64) -> anyhow::Result { + let mut count = 0; + if !dir.is_dir() { + return Ok(0); + } + let mut rd = fs::read_dir(dir).await?; + while let Some(entry) = rd.next_entry().await? 
{ + let path = entry.path(); + if path.is_dir() { + count += Box::pin(count_files_older_than(&path, cutoff_epoch)).await?; + } else if let Ok(meta) = fs::metadata(&path).await { + let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH); + let epoch = modified + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if epoch < cutoff_epoch { + count += 1; + } + } + } + Ok(count) +} + +async fn purge_old_files( + dir: &Path, + cutoff_epoch: u64, + dry_run: bool, +) -> anyhow::Result<(usize, u64)> { + let mut deleted = 0usize; + let mut bytes = 0u64; + if !dir.is_dir() { + return Ok((0, 0)); + } + let mut rd = fs::read_dir(dir).await?; + while let Some(entry) = rd.next_entry().await? { + let path = entry.path(); + if path.is_dir() { + let (d, b) = Box::pin(purge_old_files(&path, cutoff_epoch, dry_run)).await?; + deleted += d; + bytes += b; + } else if let Ok(meta) = fs::metadata(&path).await { + let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH); + let epoch = modified + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if epoch < cutoff_epoch { + bytes += meta.len(); + deleted += 1; + if !dry_run { + let _ = fs::remove_file(&path).await; + } + } + } + } + Ok((deleted, bytes)) +} + +async fn dir_stats(root: &Path) -> anyhow::Result<(usize, u64, serde_json::Value)> { + let mut total_files = 0usize; + let mut total_bytes = 0u64; + let mut breakdown = serde_json::Map::new(); + + if !root.is_dir() { + return Ok((0, 0, serde_json::Value::Object(breakdown))); + } + + let mut rd = fs::read_dir(root).await?; + while let Some(entry) = rd.next_entry().await? 
{ + let path = entry.path(); + if path.is_dir() { + let name = entry.file_name().to_string_lossy().to_string(); + let (f, b) = count_dir_contents(&path).await?; + total_files += f; + total_bytes += b; + breakdown.insert( + name, + json!({"files": f, "size": b, "size_human": format_bytes(b)}), + ); + } else if let Ok(meta) = fs::metadata(&path).await { + total_files += 1; + total_bytes += meta.len(); + } + } + Ok(( + total_files, + total_bytes, + serde_json::Value::Object(breakdown), + )) +} + +async fn count_dir_contents(dir: &Path) -> anyhow::Result<(usize, u64)> { + let mut files = 0usize; + let mut bytes = 0u64; + let mut rd = fs::read_dir(dir).await?; + while let Some(entry) = rd.next_entry().await? { + let path = entry.path(); + if path.is_dir() { + let (f, b) = Box::pin(count_dir_contents(&path)).await?; + files += f; + bytes += b; + } else if let Ok(meta) = fs::metadata(&path).await { + files += 1; + bytes += meta.len(); + } + } + Ok((files, bytes)) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn make_tool(tmp: &TempDir) -> DataManagementTool { + DataManagementTool::new(tmp.path().to_path_buf(), 90) + } + + #[tokio::test] + async fn retention_status_reports_correct_cutoff() { + let tmp = TempDir::new().unwrap(); + let tool = make_tool(&tmp); + let res = tool + .execute(json!({"command": "retention_status"})) + .await + .unwrap(); + assert!(res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert_eq!(v["retention_days"], 90); + assert!(v["cutoff"].is_string()); + } + + #[tokio::test] + async fn purge_dry_run_does_not_delete() { + let tmp = TempDir::new().unwrap(); + // Create a file with an old modification time by writing it (it will have + // the current mtime, so it should not be purged with a 90-day retention). 
+ std::fs::write(tmp.path().join("recent.txt"), "data").unwrap(); + + let tool = make_tool(&tmp); + let res = tool + .execute(json!({"command": "purge", "dry_run": true})) + .await + .unwrap(); + assert!(res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert_eq!(v["dry_run"], true); + // Recent file should not be counted for purge. + assert_eq!(v["files"], 0); + // File still exists. + assert!(tmp.path().join("recent.txt").exists()); + } + + #[tokio::test] + async fn stats_counts_files_correctly() { + let tmp = TempDir::new().unwrap(); + let sub = tmp.path().join("subdir"); + std::fs::create_dir_all(&sub).unwrap(); + std::fs::write(sub.join("a.txt"), "hello").unwrap(); + std::fs::write(sub.join("b.txt"), "world").unwrap(); + std::fs::write(tmp.path().join("root.txt"), "top").unwrap(); + + let tool = make_tool(&tmp); + let res = tool.execute(json!({"command": "stats"})).await.unwrap(); + assert!(res.success); + let v: serde_json::Value = serde_json::from_str(&res.output).unwrap(); + assert_eq!(v["total_files"], 3); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 89761a5f0..8f1f73b01 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -15,6 +15,7 @@ //! To add a new tool, implement [`Tool`] in a new submodule and register it in //! [`all_tools_with_runtime`]. See `AGENTS.md` §7.3 for the full change playbook. 
+pub mod backup_tool; pub mod browser; pub mod browser_open; pub mod cli_discovery; @@ -26,6 +27,7 @@ pub mod cron_remove; pub mod cron_run; pub mod cron_runs; pub mod cron_update; +pub mod data_management; pub mod delegate; pub mod file_edit; pub mod file_read; @@ -69,6 +71,7 @@ pub mod web_fetch; pub mod web_search_tool; pub mod workspace_tool; +pub use backup_tool::BackupTool; pub use browser::{BrowserTool, ComputerUseConfig}; pub use browser_open::BrowserOpenTool; pub use composio::ComposioTool; @@ -79,6 +82,7 @@ pub use cron_remove::CronRemoveTool; pub use cron_run::CronRunTool; pub use cron_runs::CronRunsTool; pub use cron_update::CronUpdateTool; +pub use data_management::DataManagementTool; pub use delegate::DelegateTool; pub use file_edit::FileEditTool; pub use file_read::FileReadTool; @@ -384,6 +388,23 @@ pub fn all_tools_with_runtime( ))); } + // Backup tool (enabled by default) + if root_config.backup.enabled { + tool_arcs.push(Arc::new(BackupTool::new( + workspace_dir.to_path_buf(), + root_config.backup.include_dirs.clone(), + root_config.backup.max_keep, + ))); + } + + // Data management tool (disabled by default) + if root_config.data_retention.enabled { + tool_arcs.push(Arc::new(DataManagementTool::new( + workspace_dir.to_path_buf(), + root_config.data_retention.retention_days, + ))); + } + // PDF extraction (feature-gated at compile time via rag-pdf) tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));