diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1a43b182b..d5caaff1b 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -30,6 +30,7 @@ pub mod mattermost; pub mod nextcloud_talk; #[cfg(feature = "channel-nostr")] pub mod nostr; +pub mod notion; pub mod qq; pub mod session_store; pub mod signal; @@ -62,6 +63,7 @@ pub use mattermost::MattermostChannel; pub use nextcloud_talk::NextcloudTalkChannel; #[cfg(feature = "channel-nostr")] pub use nostr::NostrChannel; +pub use notion::NotionChannel; pub use qq::QQChannel; pub use signal::SignalChannel; pub use slack::SlackChannel; @@ -2981,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con channel.name() ); } + // Notion is a top-level config section, not part of ChannelsConfig + { + let notion_configured = + config.notion.enabled && !config.notion.database_id.trim().is_empty(); + println!(" {} Notion", if notion_configured { "✅" } else { "❌" }); + } if !cfg!(feature = "channel-matrix") { println!( " ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)." @@ -3413,6 +3421,34 @@ fn collect_configured_channels( }); } + // Notion database poller channel + if config.notion.enabled && !config.notion.database_id.trim().is_empty() { + let notion_api_key = if config.notion.api_key.trim().is_empty() { + std::env::var("NOTION_API_KEY").unwrap_or_default() + } else { + config.notion.api_key.trim().to_string() + }; + if notion_api_key.trim().is_empty() { + tracing::warn!( + "Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)" + ); + } else { + channels.push(ConfiguredChannel { + display_name: "Notion", + channel: Arc::new(NotionChannel::new( + notion_api_key, + config.notion.database_id.clone(), + config.notion.poll_interval_secs, + config.notion.status_property.clone(), + config.notion.input_property.clone(), + config.notion.result_property.clone(), + config.notion.max_concurrent, + config.notion.recover_stale, + )), + }); + } + } + channels } diff --git a/src/channels/notion.rs b/src/channels/notion.rs new file mode 100644 index 000000000..6f8752d65 --- /dev/null +++ b/src/channels/notion.rs @@ -0,0 +1,614 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use anyhow::{bail, Result}; +use async_trait::async_trait; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::RwLock; + +const NOTION_API_BASE: &str = "https://api.notion.com/v1"; +const NOTION_VERSION: &str = "2022-06-28"; +const MAX_RESULT_LENGTH: usize = 2000; +const MAX_RETRIES: u32 = 3; +const RETRY_BASE_DELAY_MS: u64 = 2000; +/// Maximum number of characters to include from an error response body. +const MAX_ERROR_BODY_CHARS: usize = 500; + +/// Find the largest byte index <= `max_bytes` that falls on a UTF-8 char boundary. +fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize { + if max_bytes >= s.len() { + return s.len(); + } + let mut idx = max_bytes; + while idx > 0 && !s.is_char_boundary(idx) { + idx -= 1; + } + idx +} + +/// Notion channel — polls a Notion database for pending tasks and writes results back. +/// +/// The channel connects to the Notion API, queries a database for rows with a "pending" +/// status, dispatches them as channel messages, and writes results back when processing +/// completes. It supports crash recovery by resetting stale "running" tasks on startup. +pub struct NotionChannel { + api_key: String, + database_id: String, + poll_interval_secs: u64, + status_property: String, + input_property: String, + result_property: String, + max_concurrent: usize, + status_type: Arc>, + inflight: Arc>>, + http: reqwest::Client, + recover_stale: bool, +} + +impl NotionChannel { + /// Create a new Notion channel with the given configuration. + pub fn new( + api_key: String, + database_id: String, + poll_interval_secs: u64, + status_property: String, + input_property: String, + result_property: String, + max_concurrent: usize, + recover_stale: bool, + ) -> Self { + Self { + api_key, + database_id, + poll_interval_secs, + status_property, + input_property, + result_property, + max_concurrent, + status_type: Arc::new(RwLock::new("select".to_string())), + inflight: Arc::new(RwLock::new(HashSet::new())), + http: reqwest::Client::new(), + recover_stale, + } + } + + /// Build the standard Notion API headers (Authorization, version, content-type). + fn headers(&self) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", self.api_key) + .parse() + .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?, + ); + headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap()); + headers.insert("Content-Type", "application/json".parse().unwrap()); + Ok(headers) + } + + /// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx). + async fn api_call( + &self, + method: reqwest::Method, + url: &str, + body: Option, + ) -> Result { + let mut last_err = None; + for attempt in 0..MAX_RETRIES { + let mut req = self + .http + .request(method.clone(), url) + .headers(self.headers()?); + if let Some(ref b) = body { + req = req.json(b); + } + match req.send().await { + Ok(resp) => { + let status = resp.status(); + if status.is_success() { + return resp + .json() + .await + .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}")); + } + let status_code = status.as_u16(); + // Only retry on 429 (rate limit) or 5xx (server errors) + if status_code != 429 && (400..500).contains(&status_code) { + let body_text = resp.text().await.unwrap_or_default(); + let truncated = + crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS); + bail!("Notion API error {status_code}: {truncated}"); + } + last_err = Some(anyhow::anyhow!("Notion API error: {status_code}")); + } + Err(e) => { + last_err = Some(anyhow::anyhow!("HTTP request failed: {e}")); + } + } + let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt); + tracing::warn!( + "Notion API call failed (attempt {}/{}), retrying in {}ms", + attempt + 1, + MAX_RETRIES, + delay + ); + tokio::time::sleep(std::time::Duration::from_millis(delay)).await; + } + Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries"))) + } + + /// Query the database schema and detect whether Status uses "select" or "status" type. + async fn detect_status_type(&self) -> Result { + let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id); + let resp = self.api_call(reqwest::Method::GET, &url, None).await?; + let status_type = resp + .get("properties") + .and_then(|p| p.get(&self.status_property)) + .and_then(|s| s.get("type")) + .and_then(|t| t.as_str()) + .unwrap_or("select") + .to_string(); + Ok(status_type) + } + + /// Query for rows where Status = "pending". + async fn query_pending(&self) -> Result> { + let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id); + let status_type = self.status_type.read().await.clone(); + let filter = build_status_filter(&self.status_property, &status_type, "pending"); + let resp = self + .api_call( + reqwest::Method::POST, + &url, + Some(serde_json::json!({ "filter": filter })), + ) + .await?; + Ok(resp + .get("results") + .and_then(|r| r.as_array()) + .cloned() + .unwrap_or_default()) + } + + /// Atomically claim a task. Returns true if this caller got it. + async fn claim_task(&self, page_id: &str) -> bool { + let mut inflight = self.inflight.write().await; + if inflight.contains(page_id) { + return false; + } + if inflight.len() >= self.max_concurrent { + return false; + } + inflight.insert(page_id.to_string()); + true + } + + /// Release a task from the inflight set. + async fn release_task(&self, page_id: &str) { + let mut inflight = self.inflight.write().await; + inflight.remove(page_id); + } + + /// Update a row's status. + async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let status_type = self.status_type.read().await.clone(); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, status_value), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + Ok(()) + } + + /// Write result text to the Result column. + async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.result_property: build_rich_text_payload(result_text), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + Ok(()) + } + + /// On startup, reset "running" tasks back to "pending" for crash recovery. + async fn recover_stale(&self) -> Result<()> { + let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id); + let status_type = self.status_type.read().await.clone(); + let filter = build_status_filter(&self.status_property, &status_type, "running"); + let resp = self + .api_call( + reqwest::Method::POST, + &url, + Some(serde_json::json!({ "filter": filter })), + ) + .await?; + let stale = resp + .get("results") + .and_then(|r| r.as_array()) + .cloned() + .unwrap_or_default(); + if stale.is_empty() { + return Ok(()); + } + tracing::warn!( + "Found {} stale task(s) in 'running' state, resetting to 'pending'", + stale.len() + ); + for task in &stale { + if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) { + let page_url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, "pending"), + &self.result_property: build_rich_text_payload( + "Reset: poller restarted while task was running" + ), + } + }); + let short_id_end = floor_utf8_char_boundary(page_id, 8); + let short_id = &page_id[..short_id_end]; + if let Err(e) = self + .api_call(reqwest::Method::PATCH, &page_url, Some(payload)) + .await + { + tracing::error!("Could not reset stale task {short_id}: {e}"); + } else { + tracing::info!("Reset stale task {short_id} to pending"); + } + } + } + Ok(()) + } +} + +#[async_trait] +impl Channel for NotionChannel { + fn name(&self) -> &str { + "notion" + } + + async fn send(&self, message: &SendMessage) -> Result<()> { + // recipient is the page_id for Notion + let page_id = &message.recipient; + let status_type = self.status_type.read().await.clone(); + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let payload = serde_json::json!({ + "properties": { + &self.status_property: build_status_payload(&status_type, "done"), + &self.result_property: build_rich_text_payload(&message.content), + } + }); + self.api_call(reqwest::Method::PATCH, &url, Some(payload)) + .await?; + self.release_task(page_id).await; + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> Result<()> { + // Detect status property type + match self.detect_status_type().await { + Ok(st) => { + tracing::info!("Notion status property type: {st}"); + *self.status_type.write().await = st; + } + Err(e) => { + bail!("Failed to detect Notion database schema: {e}"); + } + } + + // Crash recovery + if self.recover_stale { + if let Err(e) = self.recover_stale().await { + tracing::error!("Notion stale task recovery failed: {e}"); + } + } + + // Polling loop + loop { + match self.query_pending().await { + Ok(tasks) => { + if !tasks.is_empty() { + tracing::info!("Notion: found {} pending task(s)", tasks.len()); + } + for task in tasks { + let page_id = match task.get("id").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => continue, + }; + + let input_text = extract_text_from_property( + task.get("properties") + .and_then(|p| p.get(&self.input_property)), + ); + + if input_text.trim().is_empty() { + let short_end = floor_utf8_char_boundary(&page_id, 8); + tracing::warn!( + "Notion: empty input for task {}, skipping", + &page_id[..short_end] + ); + continue; + } + + if !self.claim_task(&page_id).await { + continue; + } + + // Set status to running + if let Err(e) = self.set_status(&page_id, "running").await { + tracing::error!("Notion: failed to set running status: {e}"); + self.release_task(&page_id).await; + continue; + } + + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if tx + .send(ChannelMessage { + id: page_id.clone(), + sender: "notion".into(), + reply_target: page_id, + content: input_text, + channel: "notion".into(), + timestamp, + thread_ts: None, + }) + .await + .is_err() + { + tracing::info!("Notion channel shutting down"); + return Ok(()); + } + } + } + Err(e) => { + tracing::error!("Notion poll error: {e}"); + } + } + + tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await; + } + } + + async fn health_check(&self) -> bool { + let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id); + self.api_call(reqwest::Method::GET, &url, None) + .await + .is_ok() + } +} + +// ── Helper functions ────────────────────────────────────────────── + +/// Build a Notion API filter object for the given status property. +fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value { + if status_type == "status" { + serde_json::json!({ + "property": property, + "status": { "equals": value } + }) + } else { + serde_json::json!({ + "property": property, + "select": { "equals": value } + }) + } +} + +/// Build a Notion API property-update payload for a status field. +fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value { + if status_type == "status" { + serde_json::json!({ "status": { "name": value } }) + } else { + serde_json::json!({ "select": { "name": value } }) + } +} + +/// Build a Notion API rich-text property payload, truncating if necessary. +fn build_rich_text_payload(value: &str) -> serde_json::Value { + let truncated = truncate_result(value); + serde_json::json!({ + "rich_text": [{ + "text": { "content": truncated } + }] + }) +} + +/// Truncate result text to fit within the Notion rich-text content limit. +fn truncate_result(value: &str) -> String { + if value.len() <= MAX_RESULT_LENGTH { + return value.to_string(); + } + let cut = MAX_RESULT_LENGTH.saturating_sub(30); + // Ensure we cut on a char boundary + let end = floor_utf8_char_boundary(value, cut); + format!("{}\n\n... [output truncated]", &value[..end]) +} + +/// Extract plain text from a Notion property (title or rich_text type). +fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String { + let Some(prop) = prop else { + return String::new(); + }; + let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or(""); + let array_key = match ptype { + "title" => "title", + "rich_text" => "rich_text", + _ => return String::new(), + }; + prop.get(array_key) + .and_then(|arr| arr.as_array()) + .map(|items| { + items + .iter() + .filter_map(|item| item.get("plain_text").and_then(|t| t.as_str())) + .collect::>() + .join("") + }) + .unwrap_or_default() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn claim_task_deduplication() { + let channel = NotionChannel::new( + "test-key".into(), + "test-db".into(), + 5, + "Status".into(), + "Input".into(), + "Result".into(), + 4, + false, + ); + + assert!(channel.claim_task("page-1").await); + // Second claim for same page should fail + assert!(!channel.claim_task("page-1").await); + // Different page should succeed + assert!(channel.claim_task("page-2").await); + + // After release, can claim again + channel.release_task("page-1").await; + assert!(channel.claim_task("page-1").await); + } + + #[test] + fn result_truncation_within_limit() { + let short = "hello world"; + assert_eq!(truncate_result(short), short); + } + + #[test] + fn result_truncation_over_limit() { + let long = "a".repeat(MAX_RESULT_LENGTH + 100); + let truncated = truncate_result(&long); + assert!(truncated.len() <= MAX_RESULT_LENGTH); + assert!(truncated.ends_with("... [output truncated]")); + } + + #[test] + fn result_truncation_multibyte_safe() { + // Build a string that would cut in the middle of a multibyte char + let mut s = String::new(); + for _ in 0..700 { + s.push('\u{6E2C}'); // 3-byte UTF-8 char + } + let truncated = truncate_result(&s); + // Should not panic and should be valid UTF-8 + assert!(truncated.len() <= MAX_RESULT_LENGTH); + assert!(truncated.ends_with("... [output truncated]")); + } + + #[test] + fn status_payload_select_type() { + let payload = build_status_payload("select", "pending"); + assert_eq!( + payload, + serde_json::json!({ "select": { "name": "pending" } }) + ); + } + + #[test] + fn status_payload_status_type() { + let payload = build_status_payload("status", "done"); + assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } })); + } + + #[test] + fn rich_text_payload_construction() { + let payload = build_rich_text_payload("test output"); + let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap(); + assert_eq!(text, "test output"); + } + + #[test] + fn status_filter_select_type() { + let filter = build_status_filter("Status", "select", "pending"); + assert_eq!( + filter, + serde_json::json!({ + "property": "Status", + "select": { "equals": "pending" } + }) + ); + } + + #[test] + fn status_filter_status_type() { + let filter = build_status_filter("Status", "status", "running"); + assert_eq!( + filter, + serde_json::json!({ + "property": "Status", + "status": { "equals": "running" } + }) + ); + } + + #[test] + fn extract_text_from_title_property() { + let prop = serde_json::json!({ + "type": "title", + "title": [ + { "plain_text": "Hello " }, + { "plain_text": "World" } + ] + }); + assert_eq!(extract_text_from_property(Some(&prop)), "Hello World"); + } + + #[test] + fn extract_text_from_rich_text_property() { + let prop = serde_json::json!({ + "type": "rich_text", + "rich_text": [{ "plain_text": "task content" }] + }); + assert_eq!(extract_text_from_property(Some(&prop)), "task content"); + } + + #[test] + fn extract_text_from_none() { + assert_eq!(extract_text_from_property(None), ""); + } + + #[test] + fn extract_text_from_unknown_type() { + let prop = serde_json::json!({ "type": "number", "number": 42 }); + assert_eq!(extract_text_from_property(Some(&prop)), ""); + } + + #[tokio::test] + async fn claim_task_respects_max_concurrent() { + let channel = NotionChannel::new( + "test-key".into(), + "test-db".into(), + 5, + "Status".into(), + "Input".into(), + "Result".into(), + 2, // max_concurrent = 2 + false, + ); + + assert!(channel.claim_task("page-1").await); + assert!(channel.claim_task("page-2").await); + // Third claim should be rejected (at capacity) + assert!(!channel.claim_task("page-3").await); + + // After releasing one, can claim again + channel.release_task("page-1").await; + assert!(channel.claim_task("page-3").await); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 6cf2b5dea..1cf2960a0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -12,15 +12,16 @@ pub use schema::{ DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, - MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, - MultimodalConfig, NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, - OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, + MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, + ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, + NotionConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, + PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, - StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, - ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig, - WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, + SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, + StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, + TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, + TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index 046c018f1..ba9b6aa43 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -122,8 +122,17 @@ pub struct Config { /// Security subsystem configuration (`[security]`). #[serde(default)] + /// Backup tool configuration (`[backup]`). + pub backup: BackupConfig, + + /// Data retention and purge configuration (`[data_retention]`). + pub data_retention: DataRetentionConfig, + pub security: SecurityConfig, + /// Managed cybersecurity service configuration (`[security_ops]`). + pub security_ops: SecurityOpsConfig, + /// Runtime adapter configuration (`[runtime]`). Controls native vs Docker execution. #[serde(default)] pub runtime: RuntimeConfig, @@ -188,6 +197,10 @@ pub struct Config { #[serde(default)] pub composio: ComposioConfig, + /// Microsoft 365 Graph API integration (`[microsoft365]`). + #[serde(default)] + pub microsoft365: Microsoft365Config, + /// Secrets encryption configuration (`[secrets]`). #[serde(default)] pub secrets: SecretsConfig, @@ -212,6 +225,10 @@ pub struct Config { #[serde(default)] pub web_search: WebSearchConfig, + /// Project delivery intelligence configuration (`[project_intel]`). + #[serde(default)] + pub project_intel: ProjectIntelConfig, + /// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`). #[serde(default)] pub proxy: ProxyConfig, @@ -264,13 +281,13 @@ pub struct Config { #[serde(default)] pub workspace: WorkspaceConfig, - /// Backup tool configuration (`[backup]`). + /// Notion integration configuration (`[notion]`). #[serde(default)] - pub backup: BackupConfig, + pub notion: NotionConfig, - /// Data retention and purge configuration (`[data_retention]`). + /// Secure inter-node transport configuration (`[node_transport]`). #[serde(default)] - pub data_retention: DataRetentionConfig, + pub node_transport: NodeTransportConfig, } /// Multi-client workspace isolation configuration. @@ -1352,6 +1369,67 @@ impl Default for GatewayConfig { } } +/// Secure transport configuration for inter-node communication (`[node_transport]`). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct NodeTransportConfig { + /// Enable the secure transport layer. + #[serde(default = "default_node_transport_enabled")] + pub enabled: bool, + /// Shared secret for HMAC authentication between nodes. + #[serde(default)] + pub shared_secret: String, + /// Maximum age of signed requests in seconds (replay protection). + #[serde(default = "default_max_request_age")] + pub max_request_age_secs: i64, + /// Require HTTPS for all node communication. + #[serde(default = "default_require_https")] + pub require_https: bool, + /// Allow specific node IPs/CIDRs. + #[serde(default)] + pub allowed_peers: Vec, + /// Path to TLS certificate file. + #[serde(default)] + pub tls_cert_path: Option, + /// Path to TLS private key file. + #[serde(default)] + pub tls_key_path: Option, + /// Require client certificates (mutual TLS). + #[serde(default)] + pub mutual_tls: bool, + /// Maximum number of connections per peer. + #[serde(default = "default_connection_pool_size")] + pub connection_pool_size: usize, +} + +fn default_node_transport_enabled() -> bool { + true +} +fn default_max_request_age() -> i64 { + 300 +} +fn default_require_https() -> bool { + true +} +fn default_connection_pool_size() -> usize { + 4 +} + +impl Default for NodeTransportConfig { + fn default() -> Self { + Self { + enabled: default_node_transport_enabled(), + shared_secret: String::new(), + max_request_age_secs: default_max_request_age(), + require_https: default_require_https(), + allowed_peers: Vec::new(), + tls_cert_path: None, + tls_key_path: None, + mutual_tls: false, + connection_pool_size: default_connection_pool_size(), + } + } +} + // ── Composio (managed tool surface) ───────────────────────────── /// Composio managed OAuth tools integration (`[composio]` section). @@ -1384,6 +1462,78 @@ impl Default for ComposioConfig { } } +// ── Microsoft 365 (Graph API integration) ─────────────────────── + +/// Microsoft 365 integration via Microsoft Graph API (`[microsoft365]` section). +/// +/// Provides access to Outlook mail, Teams messages, Calendar events, +/// OneDrive files, and SharePoint search. +#[derive(Clone, Serialize, Deserialize, JsonSchema)] +pub struct Microsoft365Config { + /// Enable Microsoft 365 integration + #[serde(default, alias = "enable")] + pub enabled: bool, + /// Azure AD tenant ID + #[serde(default)] + pub tenant_id: Option, + /// Azure AD application (client) ID + #[serde(default)] + pub client_id: Option, + /// Azure AD client secret (stored encrypted when secrets.encrypt = true) + #[serde(default)] + pub client_secret: Option, + /// Authentication flow: "client_credentials" or "device_code" + #[serde(default = "default_ms365_auth_flow")] + pub auth_flow: String, + /// OAuth scopes to request + #[serde(default = "default_ms365_scopes")] + pub scopes: Vec, + /// Encrypt the token cache file on disk + #[serde(default = "default_true")] + pub token_cache_encrypted: bool, + /// User principal name or "me" (for delegated flows) + #[serde(default)] + pub user_id: Option, +} + +fn default_ms365_auth_flow() -> String { + "client_credentials".to_string() +} + +fn default_ms365_scopes() -> Vec { + vec!["https://graph.microsoft.com/.default".to_string()] +} + +impl std::fmt::Debug for Microsoft365Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Microsoft365Config") + .field("enabled", &self.enabled) + .field("tenant_id", &self.tenant_id) + .field("client_id", &self.client_id) + .field("client_secret", &self.client_secret.as_ref().map(|_| "***")) + .field("auth_flow", &self.auth_flow) + .field("scopes", &self.scopes) + .field("token_cache_encrypted", &self.token_cache_encrypted) + .field("user_id", &self.user_id) + .finish() + } +} + +impl Default for Microsoft365Config { + fn default() -> Self { + Self { + enabled: false, + tenant_id: None, + client_id: None, + client_secret: None, + auth_flow: default_ms365_auth_flow(), + scopes: default_ms365_scopes(), + token_cache_encrypted: true, + user_id: None, + } + } +} + // ── Secrets (encrypted credential store) ──────────────────────── /// Secrets encryption configuration (`[secrets]` section). @@ -1648,6 +1798,64 @@ impl Default for WebSearchConfig { } } +// ── Project Intelligence ──────────────────────────────────────── + +/// Project delivery intelligence configuration (`[project_intel]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct ProjectIntelConfig { + /// Enable the project_intel tool. Default: false. + #[serde(default)] + pub enabled: bool, + /// Default report language (en, de, fr, it). Default: "en". + #[serde(default = "default_project_intel_language")] + pub default_language: String, + /// Output directory for generated reports. + #[serde(default = "default_project_intel_report_dir")] + pub report_output_dir: String, + /// Optional custom templates directory. + #[serde(default)] + pub templates_dir: Option, + /// Risk detection sensitivity: low, medium, high. Default: "medium". + #[serde(default = "default_project_intel_risk_sensitivity")] + pub risk_sensitivity: String, + /// Include git log data in reports. Default: true. + #[serde(default = "default_true")] + pub include_git_data: bool, + /// Include Jira data in reports. Default: false. + #[serde(default)] + pub include_jira_data: bool, + /// Jira instance base URL (required if include_jira_data is true). + #[serde(default)] + pub jira_base_url: Option, +} + +fn default_project_intel_language() -> String { + "en".into() +} + +fn default_project_intel_report_dir() -> String { + "~/.zeroclaw/project-reports".into() +} + +fn default_project_intel_risk_sensitivity() -> String { + "medium".into() +} + +impl Default for ProjectIntelConfig { + fn default() -> Self { + Self { + enabled: false, + default_language: default_project_intel_language(), + report_output_dir: default_project_intel_report_dir(), + templates_dir: None, + risk_sensitivity: default_project_intel_risk_sensitivity(), + include_git_data: true, + include_jira_data: false, + jira_base_url: None, + } + } +} + // ── Backup ────────────────────────────────────────────────────── /// Backup tool configuration (`[backup]` section). @@ -1665,10 +1873,10 @@ pub struct BackupConfig { /// Output directory for backup archives (relative to workspace root). #[serde(default = "default_backup_destination_dir")] pub destination_dir: String, - /// Optional cron expression for scheduled automatic backups (e.g. `"0 2 * * *"`). + /// Optional cron expression for scheduled automatic backups. #[serde(default)] pub schedule_cron: Option, - /// IANA timezone for `schedule_cron` (e.g. `"UTC"`, `"Europe/Zurich"`). + /// IANA timezone for `schedule_cron`. #[serde(default)] pub schedule_timezone: Option, /// Compress backup archives. @@ -3174,10 +3382,10 @@ impl Default for CronConfig { /// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section). /// -/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`. +/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, `"custom"`. #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct TunnelConfig { - /// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`. + /// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, or `"custom"`. Default: `"none"`. pub provider: String, /// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`). @@ -3192,6 +3400,10 @@ pub struct TunnelConfig { #[serde(default)] pub ngrok: Option, + /// OpenVPN tunnel configuration (used when `provider = "openvpn"`). + #[serde(default)] + pub openvpn: Option, + /// Custom tunnel command configuration (used when `provider = "custom"`). #[serde(default)] pub custom: Option, @@ -3204,6 +3416,7 @@ impl Default for TunnelConfig { cloudflare: None, tailscale: None, ngrok: None, + openvpn: None, custom: None, } } @@ -3232,6 +3445,36 @@ pub struct NgrokTunnelConfig { pub domain: Option, } +/// OpenVPN tunnel configuration (`[tunnel.openvpn]`). +/// +/// Required when `tunnel.provider = "openvpn"`. Omitting this section entirely +/// preserves previous behavior. Setting `tunnel.provider = "none"` (or removing +/// the `[tunnel.openvpn]` block) cleanly reverts to no-tunnel mode. +/// +/// Defaults: `connect_timeout_secs = 30`. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct OpenVpnTunnelConfig { + /// Path to `.ovpn` configuration file (must not be empty). + pub config_file: String, + /// Optional path to auth credentials file (`--auth-user-pass`). + #[serde(default)] + pub auth_file: Option, + /// Advertised address once VPN is connected (e.g., `"10.8.0.2:42617"`). + /// When omitted the tunnel falls back to `http://{local_host}:{local_port}`. + #[serde(default)] + pub advertise_address: Option, + /// Connection timeout in seconds (default: 30, must be > 0). + #[serde(default = "default_openvpn_timeout")] + pub connect_timeout_secs: u64, + /// Extra openvpn CLI arguments forwarded verbatim. + #[serde(default)] + pub extra_args: Vec, +} + +fn default_openvpn_timeout() -> u64 { + 30 +} + #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CustomTunnelConfig { /// Command template to start the tunnel. Use {port} and {host} placeholders. @@ -4011,6 +4254,10 @@ pub struct SecurityConfig { /// Emergency-stop state machine configuration. #[serde(default)] pub estop: EstopConfig, + + /// Nevis IAM integration for SSO/MFA authentication and role-based access. + #[serde(default)] + pub nevis: NevisConfig, } /// OTP validation strategy. @@ -4122,6 +4369,163 @@ impl Default for EstopConfig { } } +/// Nevis IAM integration configuration. +/// +/// When `enabled` is true, ZeroClaw validates incoming requests against a Nevis +/// Security Suite instance and maps Nevis roles to tool/workspace permissions. +#[derive(Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct NevisConfig { + /// Enable Nevis IAM integration. Defaults to false for backward compatibility. + #[serde(default)] + pub enabled: bool, + + /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`). + #[serde(default)] + pub instance_url: String, + + /// Nevis realm to authenticate against. + #[serde(default = "default_nevis_realm")] + pub realm: String, + + /// OAuth2 client ID registered in Nevis. + #[serde(default)] + pub client_id: String, + + /// OAuth2 client secret. Encrypted via SecretStore when stored on disk. + #[serde(default)] + pub client_secret: Option, + + /// Token validation strategy: `"local"` (JWKS) or `"remote"` (introspection). + #[serde(default = "default_nevis_token_validation")] + pub token_validation: String, + + /// JWKS endpoint URL for local token validation. + #[serde(default)] + pub jwks_url: Option, + + /// Nevis role to ZeroClaw permission mappings. + #[serde(default)] + pub role_mapping: Vec, + + /// Require MFA verification for all Nevis-authenticated requests. + #[serde(default)] + pub require_mfa: bool, + + /// Session timeout in seconds. + #[serde(default = "default_nevis_session_timeout_secs")] + pub session_timeout_secs: u64, +} + +impl std::fmt::Debug for NevisConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NevisConfig") + .field("enabled", &self.enabled) + .field("instance_url", &self.instance_url) + .field("realm", &self.realm) + .field("client_id", &self.client_id) + .field( + "client_secret", + &self.client_secret.as_ref().map(|_| "[REDACTED]"), + ) + .field("token_validation", &self.token_validation) + .field("jwks_url", &self.jwks_url) + .field("role_mapping", &self.role_mapping) + .field("require_mfa", &self.require_mfa) + .field("session_timeout_secs", &self.session_timeout_secs) + .finish() + } +} + +impl NevisConfig { + /// Validate that required fields are present when Nevis is enabled. + /// + /// Call at config load time to fail fast on invalid configuration rather + /// than deferring errors to the first authentication request. + pub fn validate(&self) -> Result<(), String> { + if !self.enabled { + return Ok(()); + } + + if self.instance_url.trim().is_empty() { + return Err("nevis.instance_url is required when Nevis IAM is enabled".into()); + } + + if self.client_id.trim().is_empty() { + return Err("nevis.client_id is required when Nevis IAM is enabled".into()); + } + + if self.realm.trim().is_empty() { + return Err("nevis.realm is required when Nevis IAM is enabled".into()); + } + + match self.token_validation.as_str() { + "local" | "remote" => {} + other => { + return Err(format!( + "nevis.token_validation has invalid value '{other}': \ + expected 'local' or 'remote'" + )); + } + } + + if self.token_validation == "local" && self.jwks_url.is_none() { + return Err("nevis.jwks_url is required when token_validation is 'local'".into()); + } + + if self.session_timeout_secs == 0 { + return Err("nevis.session_timeout_secs must be greater than 0".into()); + } + + Ok(()) + } +} + +fn default_nevis_realm() -> String { + "master".into() +} + +fn default_nevis_token_validation() -> String { + "local".into() +} + +fn default_nevis_session_timeout_secs() -> u64 { + 3600 +} + +impl Default for NevisConfig { + fn default() -> Self { + Self { + enabled: false, + instance_url: String::new(), + realm: default_nevis_realm(), + client_id: String::new(), + client_secret: None, + token_validation: default_nevis_token_validation(), + jwks_url: None, + role_mapping: Vec::new(), + require_mfa: false, + session_timeout_secs: default_nevis_session_timeout_secs(), + } + } +} + +/// Maps a Nevis role to ZeroClaw tool permissions and workspace access. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct NevisRoleMappingConfig { + /// Nevis role name (case-insensitive). + pub nevis_role: String, + + /// Tool names this role can access. Use `"all"` for unrestricted tool access. + #[serde(default)] + pub zeroclaw_permissions: Vec, + + /// Workspace names this role can access. Use `"all"` for unrestricted. + #[serde(default)] + pub workspace_access: Vec, +} + /// Sandbox configuration for OS-level isolation #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SandboxConfig { @@ -4352,6 +4756,129 @@ pub fn default_nostr_relays() -> Vec { ] } +// -- Notion -- + +/// Notion integration configuration (`[notion]`). +/// +/// When `enabled = true`, the agent polls a Notion database for pending tasks +/// and exposes a `notion` tool for querying, reading, creating, and updating pages. +/// Requires `api_key` (or the `NOTION_API_KEY` env var) and `database_id`. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct NotionConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub api_key: String, + #[serde(default)] + pub database_id: String, + #[serde(default = "default_notion_poll_interval")] + pub poll_interval_secs: u64, + #[serde(default = "default_notion_status_prop")] + pub status_property: String, + #[serde(default = "default_notion_input_prop")] + pub input_property: String, + #[serde(default = "default_notion_result_prop")] + pub result_property: String, + #[serde(default = "default_notion_max_concurrent")] + pub max_concurrent: usize, + #[serde(default = "default_notion_recover_stale")] + pub recover_stale: bool, +} + +fn default_notion_poll_interval() -> u64 { + 5 +} +fn default_notion_status_prop() -> String { + "Status".into() +} +fn default_notion_input_prop() -> String { + "Input".into() +} +fn default_notion_result_prop() -> String { + "Result".into() +} +fn default_notion_max_concurrent() -> usize { + 4 +} +fn default_notion_recover_stale() -> bool { + true +} + +impl Default for NotionConfig { + fn default() -> Self { + Self { + enabled: false, + api_key: String::new(), + database_id: String::new(), + poll_interval_secs: default_notion_poll_interval(), + status_property: default_notion_status_prop(), + input_property: default_notion_input_prop(), + result_property: default_notion_result_prop(), + max_concurrent: default_notion_max_concurrent(), + recover_stale: default_notion_recover_stale(), + } + } +} + +// ── Security ops config ───────────────────────────────────────── + +/// Managed Cybersecurity Service (MCSS) dashboard agent configuration (`[security_ops]`). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SecurityOpsConfig { + /// Enable security operations tools. + #[serde(default)] + pub enabled: bool, + /// Directory containing incident response playbook definitions (JSON). + #[serde(default = "default_playbooks_dir")] + pub playbooks_dir: String, + /// Automatically triage incoming alerts without user prompt. + #[serde(default)] + pub auto_triage: bool, + /// Require human approval before executing playbook actions. + #[serde(default = "default_require_approval")] + pub require_approval_for_actions: bool, + /// Maximum severity level that can be auto-remediated without approval. + /// One of: "low", "medium", "high", "critical". Default: "low". + #[serde(default = "default_max_auto_severity")] + pub max_auto_severity: String, + /// Directory for generated security reports. + #[serde(default = "default_report_output_dir")] + pub report_output_dir: String, + /// Optional SIEM webhook URL for alert ingestion. + #[serde(default)] + pub siem_integration: Option, +} + +fn default_playbooks_dir() -> String { + "~/.zeroclaw/playbooks".into() +} + +fn default_require_approval() -> bool { + true +} + +fn default_max_auto_severity() -> String { + "low".into() +} + +fn default_report_output_dir() -> String { + "~/.zeroclaw/security-reports".into() +} + +impl Default for SecurityOpsConfig { + fn default() -> Self { + Self { + enabled: false, + playbooks_dir: default_playbooks_dir(), + auto_triage: false, + require_approval_for_actions: true, + max_auto_severity: default_max_auto_severity(), + report_output_dir: default_report_output_dir(), + siem_integration: None, + } + } +} + // ── Config impl ────────────────────────────────────────────────── impl Default for Config { @@ -4374,7 +4901,10 @@ impl Default for Config { extra_headers: HashMap::new(), observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), + backup: BackupConfig::default(), + data_retention: DataRetentionConfig::default(), security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), @@ -4390,12 +4920,14 @@ impl Default for Config { tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), identity: IdentityConfig::default(), cost: CostConfig::default(), @@ -4410,8 +4942,8 @@ impl Default for Config { mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), - backup: BackupConfig::default(), - data_retention: DataRetentionConfig::default(), + notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), } } } @@ -4887,6 +5419,11 @@ impl Config { &mut config.composio.api_key, "config.composio.api_key", )?; + decrypt_optional_secret( + &store, + &mut config.microsoft365.client_secret, + "config.microsoft365.client_secret", + )?; decrypt_optional_secret( &store, @@ -5144,6 +5681,18 @@ impl Config { decrypt_secret(&store, token, "config.gateway.paired_tokens[]")?; } + // Decrypt Nevis IAM secret + decrypt_optional_secret( + &store, + &mut config.security.nevis.client_secret, + "config.security.nevis.client_secret", + )?; + + // Notion API key (top-level, not in ChannelsConfig) + if !config.notion.api_key.is_empty() { + decrypt_secret(&store, &mut config.notion.api_key, "config.notion.api_key")?; + } + config.apply_env_overrides(); config.validate()?; tracing::info!( @@ -5279,6 +5828,20 @@ impl Config { /// Called after TOML deserialization and env-override application to catch /// obviously invalid values early instead of failing at arbitrary runtime points. pub fn validate(&self) -> Result<()> { + // Tunnel — OpenVPN + if self.tunnel.provider.trim() == "openvpn" { + let openvpn = self.tunnel.openvpn.as_ref().ok_or_else(|| { + anyhow::anyhow!("tunnel.provider='openvpn' requires [tunnel.openvpn]") + })?; + + if openvpn.config_file.trim().is_empty() { + anyhow::bail!("tunnel.openvpn.config_file must not be empty"); + } + if openvpn.connect_timeout_secs == 0 { + anyhow::bail!("tunnel.openvpn.connect_timeout_secs must be greater than 0"); + } + } + // Gateway if self.gateway.host.trim().is_empty() { anyhow::bail!("gateway.host must not be empty"); @@ -5435,17 +5998,143 @@ impl Config { } } + // Microsoft 365 + if self.microsoft365.enabled { + let tenant = self + .microsoft365 + .tenant_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if tenant.is_none() { + anyhow::bail!( + "microsoft365.tenant_id must not be empty when microsoft365 is enabled" + ); + } + let client = self + .microsoft365 + .client_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if client.is_none() { + anyhow::bail!( + "microsoft365.client_id must not be empty when microsoft365 is enabled" + ); + } + let flow = self.microsoft365.auth_flow.trim(); + if flow != "client_credentials" && flow != "device_code" { + anyhow::bail!( + "microsoft365.auth_flow must be 'client_credentials' or 'device_code'" + ); + } + if flow == "client_credentials" + && self + .microsoft365 + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + anyhow::bail!( + "microsoft365.client_secret must not be empty when auth_flow is 'client_credentials'" + ); + } + } + + // Microsoft 365 + if self.microsoft365.enabled { + let tenant = self + .microsoft365 + .tenant_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if tenant.is_none() { + anyhow::bail!( + "microsoft365.tenant_id must not be empty when microsoft365 is enabled" + ); + } + let client = self + .microsoft365 + .client_id + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()); + if client.is_none() { + anyhow::bail!( + "microsoft365.client_id must not be empty when microsoft365 is enabled" + ); + } + let flow = self.microsoft365.auth_flow.trim(); + if flow != "client_credentials" && flow != "device_code" { + anyhow::bail!("microsoft365.auth_flow must be client_credentials or device_code"); + } + if flow == "client_credentials" + && self + .microsoft365 + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + anyhow::bail!("microsoft365.client_secret must not be empty when auth_flow is client_credentials"); + } + } + // MCP if self.mcp.enabled { validate_mcp_config(&self.mcp)?; } + // Project intelligence + if self.project_intel.enabled { + let lang = &self.project_intel.default_language; + if !["en", "de", "fr", "it"].contains(&lang.as_str()) { + anyhow::bail!( + "project_intel.default_language must be one of: en, de, fr, it (got '{lang}')" + ); + } + let sens = &self.project_intel.risk_sensitivity; + if !["low", "medium", "high"].contains(&sens.as_str()) { + anyhow::bail!( + "project_intel.risk_sensitivity must be one of: low, medium, high (got '{sens}')" + ); + } + if let Some(ref tpl_dir) = self.project_intel.templates_dir { + let path = std::path::Path::new(tpl_dir); + if !path.exists() { + anyhow::bail!("project_intel.templates_dir path does not exist: {tpl_dir}"); + } + } + } + // Proxy (delegate to existing validation) self.proxy.validate()?; - // MCP servers - if self.mcp.enabled { - validate_mcp_config(&self.mcp)?; + // Notion + if self.notion.enabled { + if self.notion.database_id.trim().is_empty() { + anyhow::bail!("notion.database_id must not be empty when notion.enabled = true"); + } + if self.notion.poll_interval_secs == 0 { + anyhow::bail!("notion.poll_interval_secs must be greater than 0"); + } + if self.notion.max_concurrent == 0 { + anyhow::bail!("notion.max_concurrent must be greater than 0"); + } + if self.notion.status_property.trim().is_empty() { + anyhow::bail!("notion.status_property must not be empty"); + } + if self.notion.input_property.trim().is_empty() { + anyhow::bail!("notion.input_property must not be empty"); + } + if self.notion.result_property.trim().is_empty() { + anyhow::bail!("notion.result_property must not be empty"); + } + } + + // Nevis IAM — delegate to NevisConfig::validate() for field-level checks + if let Err(msg) = self.security.nevis.validate() { + anyhow::bail!("security.nevis: {msg}"); } Ok(()) @@ -5814,6 +6503,11 @@ impl Config { &mut config_to_save.composio.api_key, "config.composio.api_key", )?; + encrypt_optional_secret( + &store, + &mut config_to_save.microsoft365.client_secret, + "config.microsoft365.client_secret", + )?; encrypt_optional_secret( &store, @@ -6071,6 +6765,22 @@ impl Config { encrypt_secret(&store, token, "config.gateway.paired_tokens[]")?; } + // Encrypt Nevis IAM secret + encrypt_optional_secret( + &store, + &mut config_to_save.security.nevis.client_secret, + "config.security.nevis.client_secret", + )?; + + // Notion API key (top-level, not in ChannelsConfig) + if !config_to_save.notion.api_key.is_empty() { + encrypt_secret( + &store, + &mut config_to_save.notion.api_key, + "config.notion.api_key", + )?; + } + let toml_str = toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?; @@ -6158,6 +6868,7 @@ impl Config { } } +#[allow(clippy::unused_async)] // async needed on unix for tokio File I/O; no-op on other platforms async fn sync_directory(path: &Path) -> Result<()> { #[cfg(unix)] { @@ -6183,6 +6894,7 @@ mod tests { #[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::path::PathBuf; + #[cfg(unix)] use tempfile::TempDir; use tokio::sync::{Mutex, MutexGuard}; use tokio::test; @@ -6439,7 +7151,10 @@ default_temperature = 0.7 allowed_roots: vec![], non_cli_excluded_tools: vec![], }, + backup: BackupConfig::default(), + data_retention: DataRetentionConfig::default(), security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig { kind: "docker".into(), ..RuntimeConfig::default() @@ -6500,12 +7215,14 @@ default_temperature = 0.7 tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), identity: IdentityConfig::default(), @@ -6520,8 +7237,8 @@ default_temperature = 0.7 mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), - backup: BackupConfig::default(), - data_retention: DataRetentionConfig::default(), + notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -6779,7 +7496,10 @@ tool_dispatcher = "xml" extra_headers: HashMap::new(), observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), + backup: BackupConfig::default(), + data_retention: DataRetentionConfig::default(), security: SecurityConfig::default(), + security_ops: SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), @@ -6795,12 +7515,14 @@ tool_dispatcher = "xml" tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), multimodal: MultimodalConfig::default(), web_fetch: WebFetchConfig::default(), web_search: WebSearchConfig::default(), + project_intel: ProjectIntelConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), identity: IdentityConfig::default(), @@ -6815,8 +7537,8 @@ tool_dispatcher = "xml" mcp: McpConfig::default(), nodes: NodesConfig::default(), workspace: WorkspaceConfig::default(), - backup: BackupConfig::default(), - data_retention: DataRetentionConfig::default(), + notion: NotionConfig::default(), + node_transport: NodeTransportConfig::default(), }; config.save().await.unwrap(); @@ -9631,4 +10353,187 @@ require_otp_to_resume = true assert_eq!(config.swarms.len(), 1); assert!(config.swarms.contains_key("pipeline")); } + + #[tokio::test] + async fn nevis_client_secret_encrypt_decrypt_roundtrip() { + let dir = std::env::temp_dir().join(format!( + "zeroclaw_test_nevis_secret_{}", + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&dir).await.unwrap(); + + let plaintext_secret = "nevis-test-client-secret-value"; + + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = dir.join("config.toml"); + config.security.nevis.client_secret = Some(plaintext_secret.into()); + + // Save (triggers encryption) + config.save().await.unwrap(); + + // Read raw TOML and verify plaintext secret is NOT present + let raw_toml = tokio::fs::read_to_string(&config.config_path) + .await + .unwrap(); + assert!( + !raw_toml.contains(plaintext_secret), + "Saved TOML must not contain the plaintext client_secret" + ); + + // Parse stored TOML and verify the value is encrypted + let stored: Config = toml::from_str(&raw_toml).unwrap(); + let stored_secret = stored.security.nevis.client_secret.as_ref().unwrap(); + assert!( + crate::security::SecretStore::is_encrypted(stored_secret), + "Stored client_secret must be marked as encrypted" + ); + + // Decrypt and verify it matches the original plaintext + let store = crate::security::SecretStore::new(&dir, true); + assert_eq!(store.decrypt(stored_secret).unwrap(), plaintext_secret); + + // Simulate a full load: deserialize then decrypt (mirrors load_or_init logic) + let mut loaded: Config = toml::from_str(&raw_toml).unwrap(); + loaded.config_path = dir.join("config.toml"); + let load_store = crate::security::SecretStore::new(&dir, loaded.secrets.encrypt); + decrypt_optional_secret( + &load_store, + &mut loaded.security.nevis.client_secret, + "config.security.nevis.client_secret", + ) + .unwrap(); + assert_eq!( + loaded.security.nevis.client_secret.as_deref().unwrap(), + plaintext_secret, + "Loaded client_secret must match the original plaintext after decryption" + ); + + let _ = fs::remove_dir_all(&dir).await; + } + + // ══════════════════════════════════════════════════════════ + // Nevis config validation tests + // ══════════════════════════════════════════════════════════ + + #[test] + async fn nevis_config_validate_disabled_accepts_empty_fields() { + let cfg = NevisConfig::default(); + assert!(!cfg.enabled); + assert!(cfg.validate().is_ok()); + } + + #[test] + async fn nevis_config_validate_rejects_empty_instance_url() { + let cfg = NevisConfig { + enabled: true, + instance_url: String::new(), + client_id: "test-client".into(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("instance_url")); + } + + #[test] + async fn nevis_config_validate_rejects_empty_client_id() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: String::new(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("client_id")); + } + + #[test] + async fn nevis_config_validate_rejects_empty_realm() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + realm: String::new(), + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("realm")); + } + + #[test] + async fn nevis_config_validate_rejects_local_without_jwks() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + token_validation: "local".into(), + jwks_url: None, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("jwks_url")); + } + + #[test] + async fn nevis_config_validate_rejects_zero_session_timeout() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + client_id: "test-client".into(), + token_validation: "remote".into(), + session_timeout_secs: 0, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!(err.contains("session_timeout_secs")); + } + + #[test] + async fn nevis_config_validate_accepts_valid_enabled_config() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + realm: "master".into(), + client_id: "test-client".into(), + token_validation: "remote".into(), + session_timeout_secs: 3600, + ..NevisConfig::default() + }; + assert!(cfg.validate().is_ok()); + } + + #[test] + async fn nevis_config_validate_rejects_invalid_token_validation() { + let cfg = NevisConfig { + enabled: true, + instance_url: "https://nevis.example.com".into(), + realm: "master".into(), + client_id: "test-client".into(), + token_validation: "invalid_mode".into(), + session_timeout_secs: 3600, + ..NevisConfig::default() + }; + let err = cfg.validate().unwrap_err(); + assert!( + err.contains("invalid value 'invalid_mode'"), + "Expected invalid token_validation error, got: {err}" + ); + } + + #[test] + async fn nevis_config_debug_redacts_client_secret() { + let cfg = NevisConfig { + client_secret: Some("super-secret".into()), + ..NevisConfig::default() + }; + let debug_output = format!("{:?}", cfg); + assert!( + !debug_output.contains("super-secret"), + "Debug output must not contain the raw client_secret" + ); + assert!( + debug_output.contains("[REDACTED]"), + "Debug output must show [REDACTED] for client_secret" + ); + } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 290cfd482..4c9564770 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -181,7 +181,7 @@ async fn run_agent_job( let run_result = match job.session_target { SessionTarget::Main | SessionTarget::Isolated => { - crate::agent::run( + Box::pin(crate::agent::run( config.clone(), Some(prefixed_prompt), None, @@ -191,7 +191,7 @@ async fn run_agent_job( false, None, job.allowed_tools.clone(), - ) + )) .await } }; @@ -784,7 +784,7 @@ mod tests { job.prompt = Some("Say hello".into()); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &security, &job).await; + let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await; assert!(!success); assert!(output.contains("agent job failed:")); } @@ -799,7 +799,7 @@ mod tests { job.prompt = Some("Say hello".into()); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &security, &job).await; + let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await; assert!(!success); assert!(output.contains("blocked by security policy")); assert!(output.contains("read-only")); @@ -815,7 +815,7 @@ mod tests { job.prompt = Some("Say hello".into()); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &security, &job).await; + let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await; assert!(!success); assert!(output.contains("blocked by security policy")); assert!(output.contains("rate limit exceeded")); diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 9d7af7126..267dae28a 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { max_backoff, move || { let cfg = channels_cfg.clone(); - async move { crate::channels::start_channels(cfg).await } + async move { Box::pin(crate::channels::start_channels(cfg)).await } }, )); } else { @@ -245,7 +245,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { // ── Phase 1: LLM decision (two-phase mode) ────────────── let tasks_to_run = if two_phase { let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks); - match crate::agent::run( + match Box::pin(crate::agent::run( config.clone(), Some(decision_prompt), None, @@ -255,7 +255,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { false, None, None, - ) + )) .await { Ok(response) => { @@ -288,7 +288,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { for task in &tasks_to_run { let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text); let temp = config.default_temperature; - match crate::agent::run( + match Box::pin(crate::agent::run( config.clone(), Some(prompt), None, @@ -298,7 +298,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { false, None, None, - ) + )) .await { Ok(output) => { diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 825b88156..f34703408 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res /// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk). async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result { let config = state.config.lock().clone(); - crate::agent::process_message(config, message).await + Box::pin(crate::agent::process_message(config, message)).await } /// Webhook request body @@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message( .await; } - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via WhatsApp if let Err(e) = wa @@ -1346,7 +1346,7 @@ async fn handle_linq_webhook( } // Call the LLM - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via Linq if let Err(e) = linq @@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State, body: Bytes) -> impl } // Call the LLM - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { // Send reply via WATI if let Err(e) = wati @@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook( .await; } - match run_gateway_chat_with_tools(&state, &msg.content).await { + match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await { Ok(response) => { if let Err(e) = nextcloud_talk .send(&SendMessage::new(response, &msg.reply_target)) @@ -2492,11 +2492,11 @@ mod tests { node_registry: Arc::new(nodes::NodeRegistry::new(16)), }; - let response = handle_nextcloud_talk_webhook( + let response = Box::pin(handle_nextcloud_talk_webhook( State(state), HeaderMap::new(), Bytes::from_static(br#"{"type":"message"}"#), - ) + )) .await .into_response(); @@ -2558,9 +2558,13 @@ mod tests { HeaderValue::from_str(invalid_signature).unwrap(), ); - let response = handle_nextcloud_talk_webhook(State(state), headers, Bytes::from(body)) - .await - .into_response(); + let response = Box::pin(handle_nextcloud_talk_webhook( + State(state), + headers, + Bytes::from(body), + )) + .await + .into_response(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); } diff --git a/src/lib.rs b/src/lib.rs index 71248da85..94b0d3765 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ pub(crate) mod integrations; pub mod memory; pub(crate) mod migration; pub(crate) mod multimodal; +pub mod nodes; pub mod observability; pub(crate) mod onboard; pub mod peripherals; diff --git a/src/main.rs b/src/main.rs index bf73607af..b08d4de0c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -844,7 +844,7 @@ async fn main() -> Result<()> { // Auto-start channels if user said yes during wizard if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { - channels::start_channels(config).await?; + Box::pin(channels::start_channels(config)).await?; } return Ok(()); } @@ -880,7 +880,7 @@ async fn main() -> Result<()> { } => { let final_temperature = temperature.unwrap_or(config.default_temperature); - agent::run( + Box::pin(agent::run( config, message, provider, @@ -890,7 +890,7 @@ async fn main() -> Result<()> { true, session_state_file, None, - ) + )) .await .map(|_| ()) } @@ -1189,8 +1189,8 @@ async fn main() -> Result<()> { }, Commands::Channel { channel_command } => match channel_command { - ChannelCommands::Start => channels::start_channels(config).await, - ChannelCommands::Doctor => channels::doctor_channels(config).await, + ChannelCommands::Start => Box::pin(channels::start_channels(config)).await, + ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await, other => channels::handle_command(other, &config).await, }, diff --git a/src/nodes/mod.rs b/src/nodes/mod.rs new file mode 100644 index 000000000..1207bb50c --- /dev/null +++ b/src/nodes/mod.rs @@ -0,0 +1,3 @@ +pub mod transport; + +pub use transport::NodeTransport; diff --git a/src/nodes/transport.rs b/src/nodes/transport.rs new file mode 100644 index 000000000..75bc4d434 --- /dev/null +++ b/src/nodes/transport.rs @@ -0,0 +1,235 @@ +//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication. +//! +//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols, +//! no custom binary framing, no UDP tunneling. This makes the transport +//! compatible with corporate proxies, firewalls, and IT audit expectations. + +use anyhow::{bail, Result}; +use chrono::Utc; +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; + +/// Signs a request payload with HMAC-SHA256. +/// +/// Uses `timestamp` + `nonce` alongside the payload to prevent replay attacks. +pub fn sign_request( + shared_secret: &str, + payload: &[u8], + timestamp: i64, + nonce: &str, +) -> Result { + let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes()) + .map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?; + mac.update(×tamp.to_le_bytes()); + mac.update(nonce.as_bytes()); + mac.update(payload); + Ok(hex::encode(mac.finalize().into_bytes())) +} + +/// Verify a signed request, rejecting stale timestamps for replay protection. +pub fn verify_request( + shared_secret: &str, + payload: &[u8], + timestamp: i64, + nonce: &str, + signature: &str, + max_age_secs: i64, +) -> Result { + let now = Utc::now().timestamp(); + if (now - timestamp).abs() > max_age_secs { + bail!("Request timestamp too old or too far in future"); + } + + let expected = sign_request(shared_secret, payload, timestamp, nonce)?; + Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes())) +} + +/// Constant-time comparison to prevent timing attacks. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + a.iter() + .zip(b.iter()) + .fold(0u8, |acc, (x, y)| acc | (x ^ y)) + == 0 +} + +// ── Node transport client ─────────────────────────────────────── + +/// Sends authenticated HTTPS requests to peer nodes. +/// +/// Every outgoing request carries three custom headers: +/// - `X-ZeroClaw-Timestamp` — unix epoch seconds +/// - `X-ZeroClaw-Nonce` — random UUID v4 +/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest +/// +/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`]. +pub struct NodeTransport { + http: reqwest::Client, + shared_secret: String, + max_request_age_secs: i64, +} + +impl NodeTransport { + pub fn new(shared_secret: String) -> Self { + Self { + http: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("HTTP client build"), + shared_secret, + max_request_age_secs: 300, // 5 min replay window + } + } + + /// Send an authenticated request to a peer node. + pub async fn send( + &self, + node_address: &str, + endpoint: &str, + payload: serde_json::Value, + ) -> Result { + let body = serde_json::to_vec(&payload)?; + let timestamp = Utc::now().timestamp(); + let nonce = uuid::Uuid::new_v4().to_string(); + let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?; + + let url = format!("https://{node_address}/api/node-control/{endpoint}"); + let resp = self + .http + .post(&url) + .header("X-ZeroClaw-Timestamp", timestamp.to_string()) + .header("X-ZeroClaw-Nonce", &nonce) + .header("X-ZeroClaw-Signature", &signature) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; + + if !resp.status().is_success() { + bail!( + "Node request failed: {} {}", + resp.status(), + resp.text().await.unwrap_or_default() + ); + } + + Ok(resp.json().await?) + } + + /// Verify an incoming request from a peer node. + pub fn verify_incoming( + &self, + payload: &[u8], + timestamp_header: &str, + nonce_header: &str, + signature_header: &str, + ) -> Result { + let timestamp: i64 = timestamp_header + .parse() + .map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?; + verify_request( + &self.shared_secret, + payload, + timestamp, + nonce_header, + signature_header, + self.max_request_age_secs, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_SECRET: &str = "test-shared-secret-key"; + + #[test] + fn sign_request_deterministic() { + let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap(); + let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap(); + assert_eq!(sig1, sig2, "Same inputs must produce the same signature"); + } + + #[test] + fn verify_request_accepts_valid_signature() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap(); + let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap(); + assert!(ok, "Valid signature must pass verification"); + } + + #[test] + fn verify_request_rejects_tampered_payload() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap(); + let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap(); + assert!(!ok, "Tampered payload must fail verification"); + } + + #[test] + fn verify_request_rejects_expired_timestamp() { + let old = Utc::now().timestamp() - 600; + let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap(); + let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300); + assert!(result.is_err(), "Expired timestamp must be rejected"); + } + + #[test] + fn verify_request_rejects_wrong_secret() { + let now = Utc::now().timestamp(); + let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap(); + let ok = verify_request("wrong-secret", b"data", now, "nonce-d", &sig, 300).unwrap(); + assert!(!ok, "Wrong secret must fail verification"); + } + + #[test] + fn constant_time_eq_correctness() { + assert!(constant_time_eq(b"abc", b"abc")); + assert!(!constant_time_eq(b"abc", b"abd")); + assert!(!constant_time_eq(b"abc", b"ab")); + assert!(!constant_time_eq(b"", b"a")); + assert!(constant_time_eq(b"", b"")); + } + + #[test] + fn node_transport_construction() { + let transport = NodeTransport::new("secret-key".into()); + assert_eq!(transport.max_request_age_secs, 300); + } + + #[test] + fn node_transport_verify_incoming_valid() { + let transport = NodeTransport::new(TEST_SECRET.into()); + let now = Utc::now().timestamp(); + let payload = b"test-body"; + let nonce = "incoming-nonce"; + let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap(); + + let ok = transport + .verify_incoming(payload, &now.to_string(), nonce, &sig) + .unwrap(); + assert!(ok, "Valid incoming request must pass verification"); + } + + #[test] + fn node_transport_verify_incoming_bad_timestamp_header() { + let transport = NodeTransport::new(TEST_SECRET.into()); + let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig"); + assert!(result.is_err(), "Non-numeric timestamp header must error"); + } + + #[test] + fn sign_request_different_nonce_different_signature() { + let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap(); + let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap(); + assert_ne!( + sig1, sig2, + "Different nonces must produce different signatures" + ); + } +} diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 44d3e948b..c1acf17f9 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -143,7 +143,10 @@ pub async fn run_wizard(force: bool) -> Result { extra_headers: std::collections::HashMap::new(), observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), + backup: crate::config::BackupConfig::default(), + data_retention: crate::config::DataRetentionConfig::default(), security: crate::config::SecurityConfig::default(), + security_ops: crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), @@ -159,12 +162,14 @@ pub async fn run_wizard(force: bool) -> Result { tunnel: tunnel_config, gateway: crate::config::GatewayConfig::default(), composio: composio_config, + microsoft365: crate::config::Microsoft365Config::default(), secrets: secrets_config, browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), multimodal: crate::config::MultimodalConfig::default(), web_fetch: crate::config::WebFetchConfig::default(), web_search: crate::config::WebSearchConfig::default(), + project_intel: crate::config::ProjectIntelConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), cost: crate::config::CostConfig::default(), @@ -179,8 +184,8 @@ pub async fn run_wizard(force: bool) -> Result { mcp: crate::config::McpConfig::default(), nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), - backup: crate::config::BackupConfig::default(), - data_retention: crate::config::DataRetentionConfig::default(), + notion: crate::config::NotionConfig::default(), + node_transport: crate::config::NodeTransportConfig::default(), }; println!( @@ -504,7 +509,10 @@ async fn run_quick_setup_with_home( extra_headers: std::collections::HashMap::new(), observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), + backup: crate::config::BackupConfig::default(), + data_retention: crate::config::DataRetentionConfig::default(), security: crate::config::SecurityConfig::default(), + security_ops: crate::config::SecurityOpsConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), @@ -520,12 +528,14 @@ async fn run_quick_setup_with_home( tunnel: crate::config::TunnelConfig::default(), gateway: crate::config::GatewayConfig::default(), composio: ComposioConfig::default(), + microsoft365: crate::config::Microsoft365Config::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), multimodal: crate::config::MultimodalConfig::default(), web_fetch: crate::config::WebFetchConfig::default(), web_search: crate::config::WebSearchConfig::default(), + project_intel: crate::config::ProjectIntelConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), cost: crate::config::CostConfig::default(), @@ -540,8 +550,8 @@ async fn run_quick_setup_with_home( mcp: crate::config::McpConfig::default(), nodes: crate::config::NodesConfig::default(), workspace: crate::config::WorkspaceConfig::default(), - backup: crate::config::BackupConfig::default(), - data_retention: crate::config::DataRetentionConfig::default(), + notion: crate::config::NotionConfig::default(), + node_transport: crate::config::NodeTransportConfig::default(), }; config.save().await?; diff --git a/src/security/iam_policy.rs b/src/security/iam_policy.rs new file mode 100644 index 000000000..36a5fab00 --- /dev/null +++ b/src/security/iam_policy.rs @@ -0,0 +1,449 @@ +//! IAM-aware policy enforcement for Nevis role-to-permission mapping. +//! +//! Evaluates tool and workspace access based on Nevis roles using a +//! deny-by-default policy model. All policy decisions are audit-logged. + +use super::nevis::NevisIdentity; +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Maps a single Nevis role to ZeroClaw permissions. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoleMapping { + /// Nevis role name (case-insensitive matching). + pub nevis_role: String, + /// Tool names this role can access. Use `"all"` to grant all tools. + pub zeroclaw_permissions: Vec, + /// Workspace names this role can access. Use `"all"` for unrestricted. + #[serde(default)] + pub workspace_access: Vec, +} + +/// Result of a policy evaluation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PolicyDecision { + /// Access is allowed. + Allow, + /// Access is denied, with reason. + Deny(String), +} + +impl PolicyDecision { + pub fn is_allowed(&self) -> bool { + matches!(self, PolicyDecision::Allow) + } +} + +/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions. +/// +/// Deny-by-default: if no role mapping grants access, the request is denied. +#[derive(Debug, Clone)] +pub struct IamPolicy { + /// Compiled role mappings indexed by lowercase Nevis role name. + role_map: HashMap, +} + +#[derive(Debug, Clone)] +struct CompiledRole { + /// Whether this role has access to all tools. + all_tools: bool, + /// Specific tool names this role can access (lowercase). + allowed_tools: Vec, + /// Whether this role has access to all workspaces. + all_workspaces: bool, + /// Specific workspace names this role can access (lowercase). + allowed_workspaces: Vec, +} + +impl IamPolicy { + /// Build a policy from role mappings (typically from config). + /// + /// Returns an error if duplicate normalized role names are detected, + /// since silent last-wins overwrites can accidentally broaden or revoke access. + pub fn from_mappings(mappings: &[RoleMapping]) -> Result { + let mut role_map = HashMap::new(); + + for mapping in mappings { + let key = mapping.nevis_role.trim().to_ascii_lowercase(); + if key.is_empty() { + continue; + } + + let all_tools = mapping + .zeroclaw_permissions + .iter() + .any(|p| p.eq_ignore_ascii_case("all")); + let allowed_tools: Vec = mapping + .zeroclaw_permissions + .iter() + .filter(|p| !p.eq_ignore_ascii_case("all")) + .map(|p| p.trim().to_ascii_lowercase()) + .collect(); + + let all_workspaces = mapping + .workspace_access + .iter() + .any(|w| w.eq_ignore_ascii_case("all")); + let allowed_workspaces: Vec = mapping + .workspace_access + .iter() + .filter(|w| !w.eq_ignore_ascii_case("all")) + .map(|w| w.trim().to_ascii_lowercase()) + .collect(); + + if role_map.contains_key(&key) { + bail!( + "IAM policy: duplicate role mapping for normalized key '{}' \ + (from nevis_role '{}') — remove or merge the duplicate entry", + key, + mapping.nevis_role + ); + } + + role_map.insert( + key, + CompiledRole { + all_tools, + allowed_tools, + all_workspaces, + allowed_workspaces, + }, + ); + } + + Ok(Self { role_map }) + } + + /// Evaluate whether an identity is allowed to use a specific tool. + /// + /// Deny-by-default: returns `Deny` unless at least one of the identity's + /// roles grants access to the requested tool. + pub fn evaluate_tool_access( + &self, + identity: &NevisIdentity, + tool_name: &str, + ) -> PolicyDecision { + let normalized_tool = tool_name.trim().to_ascii_lowercase(); + if normalized_tool.is_empty() { + return PolicyDecision::Deny("empty tool name".into()); + } + + for role in &identity.roles { + let key = role.trim().to_ascii_lowercase(); + if let Some(compiled) = self.role_map.get(&key) { + if compiled.all_tools + || compiled.allowed_tools.iter().any(|t| t == &normalized_tool) + { + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + role = %key, + tool = %normalized_tool, + "IAM policy: tool access ALLOWED" + ); + return PolicyDecision::Allow; + } + } + } + + let reason = format!( + "no role grants access to tool '{normalized_tool}' for user '{}'", + crate::security::redact(&identity.user_id) + ); + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + tool = %normalized_tool, + "IAM policy: tool access DENIED" + ); + PolicyDecision::Deny(reason) + } + + /// Evaluate whether an identity is allowed to access a specific workspace. + /// + /// Deny-by-default: returns `Deny` unless at least one of the identity's + /// roles grants access to the requested workspace. + pub fn evaluate_workspace_access( + &self, + identity: &NevisIdentity, + workspace: &str, + ) -> PolicyDecision { + let normalized_ws = workspace.trim().to_ascii_lowercase(); + if normalized_ws.is_empty() { + return PolicyDecision::Deny("empty workspace name".into()); + } + + for role in &identity.roles { + let key = role.trim().to_ascii_lowercase(); + if let Some(compiled) = self.role_map.get(&key) { + if compiled.all_workspaces + || compiled + .allowed_workspaces + .iter() + .any(|w| w == &normalized_ws) + { + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + role = %key, + workspace = %normalized_ws, + "IAM policy: workspace access ALLOWED" + ); + return PolicyDecision::Allow; + } + } + } + + let reason = format!( + "no role grants access to workspace '{normalized_ws}' for user '{}'", + crate::security::redact(&identity.user_id) + ); + tracing::info!( + user_id = %crate::security::redact(&identity.user_id), + workspace = %normalized_ws, + "IAM policy: workspace access DENIED" + ); + PolicyDecision::Deny(reason) + } + + /// Check if the policy has any role mappings configured. + pub fn is_empty(&self) -> bool { + self.role_map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_mappings() -> Vec { + vec![ + RoleMapping { + nevis_role: "admin".into(), + zeroclaw_permissions: vec!["all".into()], + workspace_access: vec!["all".into()], + }, + RoleMapping { + nevis_role: "operator".into(), + zeroclaw_permissions: vec![ + "shell".into(), + "file_read".into(), + "file_write".into(), + "memory_search".into(), + ], + workspace_access: vec!["production".into(), "staging".into()], + }, + RoleMapping { + nevis_role: "viewer".into(), + zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()], + workspace_access: vec!["staging".into()], + }, + ] + } + + fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity { + NevisIdentity { + user_id: "zeroclaw_user".into(), + roles: roles.into_iter().map(String::from).collect(), + scopes: vec!["openid".into()], + mfa_verified: true, + session_expiry: u64::MAX, + } + } + + #[test] + fn admin_gets_all_tools() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "any_tool_name") + .is_allowed()); + } + + #[test] + fn admin_gets_all_workspaces() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(policy + .evaluate_workspace_access(&identity, "production") + .is_allowed()); + assert!(policy + .evaluate_workspace_access(&identity, "any_workspace") + .is_allowed()); + } + + #[test] + fn operator_gets_subset_of_tools() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["operator"]); + + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(!policy + .evaluate_tool_access(&identity, "browser") + .is_allowed()); + } + + #[test] + fn operator_workspace_access_is_scoped() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["operator"]); + + assert!(policy + .evaluate_workspace_access(&identity, "production") + .is_allowed()); + assert!(policy + .evaluate_workspace_access(&identity, "staging") + .is_allowed()); + assert!(!policy + .evaluate_workspace_access(&identity, "development") + .is_allowed()); + } + + #[test] + fn viewer_is_read_only() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer"]); + + assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "memory_search") + .is_allowed()); + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(!policy + .evaluate_tool_access(&identity, "file_write") + .is_allowed()); + } + + #[test] + fn deny_by_default_for_unknown_role() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["unknown_role"]); + + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + assert!(!policy + .evaluate_workspace_access(&identity, "production") + .is_allowed()); + } + + #[test] + fn deny_by_default_for_no_roles() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec![]); + + assert!(!policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + } + + #[test] + fn multiple_roles_union_permissions() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer", "operator"]); + + // viewer has file_read, operator has shell — both should be accessible + assert!(policy + .evaluate_tool_access(&identity, "file_read") + .is_allowed()); + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn role_matching_is_case_insensitive() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["ADMIN"]); + + assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn tool_matching_is_case_insensitive() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["operator"]); + + assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed()); + assert!(policy + .evaluate_tool_access(&identity, "File_Read") + .is_allowed()); + } + + #[test] + fn empty_tool_name_is_denied() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(!policy.evaluate_tool_access(&identity, "").is_allowed()); + assert!(!policy.evaluate_tool_access(&identity, " ").is_allowed()); + } + + #[test] + fn empty_workspace_name_is_denied() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed()); + } + + #[test] + fn empty_mappings_deny_everything() { + let policy = IamPolicy::from_mappings(&[]).unwrap(); + let identity = identity_with_roles(vec!["admin"]); + + assert!(policy.is_empty()); + assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed()); + } + + #[test] + fn policy_decision_deny_contains_reason() { + let policy = IamPolicy::from_mappings(&test_mappings()).unwrap(); + let identity = identity_with_roles(vec!["viewer"]); + + let decision = policy.evaluate_tool_access(&identity, "shell"); + match decision { + PolicyDecision::Deny(reason) => { + assert!(reason.contains("shell")); + } + PolicyDecision::Allow => panic!("expected deny"), + } + } + + #[test] + fn duplicate_normalized_roles_are_rejected() { + let mappings = vec![ + RoleMapping { + nevis_role: "admin".into(), + zeroclaw_permissions: vec!["all".into()], + workspace_access: vec!["all".into()], + }, + RoleMapping { + nevis_role: " ADMIN ".into(), + zeroclaw_permissions: vec!["file_read".into()], + workspace_access: vec![], + }, + ]; + let err = IamPolicy::from_mappings(&mappings).unwrap_err(); + assert!( + err.to_string().contains("duplicate role mapping"), + "Expected duplicate role error, got: {err}" + ); + } + + #[test] + fn empty_role_name_in_mapping_is_skipped() { + let mappings = vec![RoleMapping { + nevis_role: " ".into(), + zeroclaw_permissions: vec!["all".into()], + workspace_access: vec![], + }]; + let policy = IamPolicy::from_mappings(&mappings).unwrap(); + assert!(policy.is_empty()); + } +} diff --git a/src/security/mod.rs b/src/security/mod.rs index 37f62c531..433e7046f 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -29,15 +29,19 @@ pub mod domain_matcher; pub mod estop; #[cfg(target_os = "linux")] pub mod firejail; +pub mod iam_policy; #[cfg(feature = "sandbox-landlock")] pub mod landlock; pub mod leak_detector; +pub mod nevis; pub mod otp; pub mod pairing; +pub mod playbook; pub mod policy; pub mod prompt_guard; pub mod secrets; pub mod traits; +pub mod vulnerability; pub mod workspace_boundary; #[allow(unused_imports)] @@ -56,6 +60,11 @@ pub use policy::{AutonomyLevel, SecurityPolicy}; pub use secrets::SecretStore; #[allow(unused_imports)] pub use traits::{NoopSandbox, Sandbox}; +// Nevis IAM integration +#[allow(unused_imports)] +pub use iam_policy::{IamPolicy, PolicyDecision}; +#[allow(unused_imports)] +pub use nevis::{NevisAuthProvider, NevisIdentity}; // Prompt injection defense exports #[allow(unused_imports)] pub use leak_detector::{LeakDetector, LeakResult}; @@ -64,19 +73,16 @@ pub use prompt_guard::{GuardAction, GuardResult, PromptGuard}; #[allow(unused_imports)] pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary}; -/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix. +/// Redact sensitive values for safe logging. Shows first 4 characters + "***" suffix. +/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings. /// This function intentionally breaks the data-flow taint chain for static analysis. pub fn redact(value: &str) -> String { - if value.len() <= 4 { + let char_count = value.chars().count(); + if char_count <= 4 { "***".to_string() } else { - // Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8. - let prefix = value - .char_indices() - .nth(4) - .map(|(byte_idx, _)| &value[..byte_idx]) - .unwrap_or(value); - format!("{}***", prefix) + let prefix: String = value.chars().take(4).collect(); + format!("{prefix}***") } } diff --git a/src/security/nevis.rs b/src/security/nevis.rs new file mode 100644 index 000000000..f6b5ef109 --- /dev/null +++ b/src/security/nevis.rs @@ -0,0 +1,587 @@ +//! Nevis IAM authentication provider for ZeroClaw. +//! +//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token +//! validation, FIDO2/passkey verification, and session management. Maps Nevis +//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`]. + +use anyhow::{bail, Context, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +/// Identity resolved from a validated Nevis token or session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NevisIdentity { + /// Unique user identifier from Nevis. + pub user_id: String, + /// Nevis roles assigned to this user. + pub roles: Vec, + /// OAuth2 scopes granted to this session. + pub scopes: Vec, + /// Whether the user completed MFA (FIDO2/passkey/OTP) in this session. + pub mfa_verified: bool, + /// When this session expires (seconds since UNIX epoch). + pub session_expiry: u64, +} + +/// Token validation strategy. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TokenValidationMode { + /// Validate JWT locally using cached JWKS keys. + Local, + /// Validate token by calling the Nevis introspection endpoint. + Remote, +} + +impl TokenValidationMode { + pub fn from_str_config(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "local" => Ok(Self::Local), + "remote" => Ok(Self::Remote), + other => bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'"), + } + } +} + +/// Authentication provider backed by a Nevis instance. +/// +/// Validates tokens, manages sessions, and resolves identities. The provider +/// is designed to be shared across concurrent requests (`Send + Sync`). +pub struct NevisAuthProvider { + /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`). + instance_url: String, + /// Nevis realm to authenticate against. + realm: String, + /// OAuth2 client ID registered in Nevis. + client_id: String, + /// OAuth2 client secret (decrypted at startup). + client_secret: Option, + /// Token validation strategy. + validation_mode: TokenValidationMode, + /// JWKS endpoint for local token validation. + jwks_url: Option, + /// Whether MFA is required for all authentications. + require_mfa: bool, + /// Session timeout duration. + session_timeout: Duration, + /// HTTP client for Nevis API calls. + http_client: reqwest::Client, +} + +impl std::fmt::Debug for NevisAuthProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NevisAuthProvider") + .field("instance_url", &self.instance_url) + .field("realm", &self.realm) + .field("client_id", &self.client_id) + .field( + "client_secret", + &self.client_secret.as_ref().map(|_| "[REDACTED]"), + ) + .field("validation_mode", &self.validation_mode) + .field("jwks_url", &self.jwks_url) + .field("require_mfa", &self.require_mfa) + .field("session_timeout", &self.session_timeout) + .finish_non_exhaustive() + } +} + +// Safety: All fields are Send + Sync. The doc comment promises concurrent use, +// so enforce it at compile time to prevent regressions. +#[allow(clippy::used_underscore_items)] +const _: () = { + fn _assert_send_sync() {} + fn _assert() { + _assert_send_sync::(); + } +}; + +impl NevisAuthProvider { + /// Create a new Nevis auth provider from config values. + /// + /// `client_secret` should already be decrypted by the config loader. + pub fn new( + instance_url: String, + realm: String, + client_id: String, + client_secret: Option, + token_validation: &str, + jwks_url: Option, + require_mfa: bool, + session_timeout_secs: u64, + ) -> Result { + let validation_mode = TokenValidationMode::from_str_config(token_validation)?; + + if validation_mode == TokenValidationMode::Local && jwks_url.is_none() { + bail!( + "Nevis token_validation is 'local' but no jwks_url is configured. \ + Either set jwks_url or use token_validation = 'remote'." + ); + } + + let http_client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .context("Failed to create HTTP client for Nevis")?; + + Ok(Self { + instance_url, + realm, + client_id, + client_secret, + validation_mode, + jwks_url, + require_mfa, + session_timeout: Duration::from_secs(session_timeout_secs), + http_client, + }) + } + + /// Validate a bearer token and resolve the caller's identity. + /// + /// Returns `NevisIdentity` on success, or an error if the token is invalid, + /// expired, or MFA requirements are not met. + pub async fn validate_token(&self, token: &str) -> Result { + if token.is_empty() { + bail!("empty bearer token"); + } + + let identity = match self.validation_mode { + TokenValidationMode::Local => self.validate_token_local(token).await?, + TokenValidationMode::Remote => self.validate_token_remote(token).await?, + }; + + if self.require_mfa && !identity.mfa_verified { + bail!( + "MFA is required but user '{}' has not completed MFA verification", + crate::security::redact(&identity.user_id) + ); + } + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if identity.session_expiry > 0 && identity.session_expiry < now { + bail!("Nevis session expired"); + } + + Ok(identity) + } + + /// Validate token by calling the Nevis introspection endpoint. + async fn validate_token_remote(&self, token: &str) -> Result { + let introspect_url = format!( + "{}/auth/realms/{}/protocol/openid-connect/token/introspect", + self.instance_url.trim_end_matches('/'), + self.realm, + ); + + let mut form = vec![("token", token), ("client_id", &self.client_id)]; + // client_secret is optional (public clients don't need it) + let secret_ref; + if let Some(ref secret) = self.client_secret { + secret_ref = secret.as_str(); + form.push(("client_secret", secret_ref)); + } + + let resp = self + .http_client + .post(&introspect_url) + .form(&form) + .send() + .await + .context("Failed to reach Nevis introspection endpoint")?; + + if !resp.status().is_success() { + bail!( + "Nevis introspection returned HTTP {}", + resp.status().as_u16() + ); + } + + let body: IntrospectionResponse = resp + .json() + .await + .context("Failed to parse Nevis introspection response")?; + + if !body.active { + bail!("Token is not active (revoked or expired)"); + } + + let user_id = body + .sub + .filter(|s| !s.trim().is_empty()) + .context("Token has missing or empty `sub` claim")?; + + let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default(); + roles.sort(); + roles.dedup(); + + Ok(NevisIdentity { + user_id, + roles, + scopes: body + .scope + .unwrap_or_default() + .split_whitespace() + .map(String::from) + .collect(), + mfa_verified: body.acr.as_deref() == Some("mfa") + || body + .amr + .iter() + .flatten() + .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"), + session_expiry: body.exp.unwrap_or(0), + }) + } + + /// Validate token locally using JWKS. + /// + /// Local JWT/JWKS validation is not yet implemented. Rather than silently + /// falling back to the remote introspection endpoint (which would hide a + /// misconfiguration), this returns an explicit error directing the operator + /// to use `token_validation = "remote"` until local JWKS support is added. + #[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented + async fn validate_token_local(&self, token: &str) -> Result { + // JWT structure check: header.payload.signature + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + bail!("Invalid JWT structure: expected 3 dot-separated parts"); + } + + bail!( + "Local JWKS token validation is not yet implemented. \ + Set token_validation = \"remote\" to use the Nevis introspection endpoint." + ); + } + + /// Validate a Nevis session token (cookie-based sessions). + pub async fn validate_session(&self, session_token: &str) -> Result { + if session_token.is_empty() { + bail!("empty session token"); + } + + let session_url = format!( + "{}/auth/realms/{}/protocol/openid-connect/userinfo", + self.instance_url.trim_end_matches('/'), + self.realm, + ); + + let resp = self + .http_client + .get(&session_url) + .bearer_auth(session_token) + .send() + .await + .context("Failed to reach Nevis userinfo endpoint")?; + + if !resp.status().is_success() { + bail!( + "Nevis session validation returned HTTP {}", + resp.status().as_u16() + ); + } + + let body: UserInfoResponse = resp + .json() + .await + .context("Failed to parse Nevis userinfo response")?; + + if body.sub.trim().is_empty() { + bail!("Userinfo response has missing or empty `sub` claim"); + } + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default(); + roles.sort(); + roles.dedup(); + + let identity = NevisIdentity { + user_id: body.sub, + roles, + scopes: body + .scope + .unwrap_or_default() + .split_whitespace() + .map(String::from) + .collect(), + mfa_verified: body.acr.as_deref() == Some("mfa") + || body + .amr + .iter() + .flatten() + .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"), + session_expiry: now + self.session_timeout.as_secs(), + }; + + if self.require_mfa && !identity.mfa_verified { + bail!( + "MFA is required but user '{}' has not completed MFA verification", + crate::security::redact(&identity.user_id) + ); + } + + Ok(identity) + } + + /// Health check against the Nevis instance. + pub async fn health_check(&self) -> Result<()> { + let health_url = format!( + "{}/auth/realms/{}", + self.instance_url.trim_end_matches('/'), + self.realm, + ); + + let resp = self + .http_client + .get(&health_url) + .send() + .await + .context("Nevis health check failed: cannot reach instance")?; + + if !resp.status().is_success() { + bail!("Nevis health check failed: HTTP {}", resp.status().as_u16()); + } + + Ok(()) + } + + /// Getter for instance URL (for diagnostics). + pub fn instance_url(&self) -> &str { + &self.instance_url + } + + /// Getter for realm. + pub fn realm(&self) -> &str { + &self.realm + } +} + +// ── Wire types for Nevis API responses ───────────────────────────── + +#[derive(Debug, Deserialize)] +struct IntrospectionResponse { + active: bool, + sub: Option, + scope: Option, + exp: Option, + #[serde(rename = "realm_access")] + realm_access: Option, + /// Authentication Context Class Reference + acr: Option, + /// Authentication Methods References + amr: Option>, +} + +#[derive(Debug, Deserialize)] +struct RealmAccess { + #[serde(default)] + roles: Vec, +} + +#[derive(Debug, Deserialize)] +struct UserInfoResponse { + sub: String, + #[serde(rename = "realm_access")] + realm_access: Option, + scope: Option, + acr: Option, + /// Authentication Methods References + amr: Option>, +} + +// ── Tests ────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_validation_mode_from_str() { + assert_eq!( + TokenValidationMode::from_str_config("local").unwrap(), + TokenValidationMode::Local + ); + assert_eq!( + TokenValidationMode::from_str_config("REMOTE").unwrap(), + TokenValidationMode::Remote + ); + assert!(TokenValidationMode::from_str_config("invalid").is_err()); + } + + #[test] + fn local_mode_requires_jwks_url() { + let result = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "local", + None, // no JWKS URL + false, + 3600, + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("jwks_url")); + } + + #[test] + fn remote_mode_works_without_jwks_url() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "remote", + None, + false, + 3600, + ); + assert!(provider.is_ok()); + } + + #[test] + fn provider_stores_config_correctly() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "test-realm".into(), + "zeroclaw-client".into(), + Some("test-secret".into()), + "remote", + None, + true, + 7200, + ) + .unwrap(); + + assert_eq!(provider.instance_url(), "https://nevis.example.com"); + assert_eq!(provider.realm(), "test-realm"); + assert!(provider.require_mfa); + assert_eq!(provider.session_timeout, Duration::from_secs(7200)); + } + + #[test] + fn debug_redacts_client_secret() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "test-realm".into(), + "zeroclaw-client".into(), + Some("super-secret-value".into()), + "remote", + None, + false, + 3600, + ) + .unwrap(); + + let debug_output = format!("{:?}", provider); + assert!( + !debug_output.contains("super-secret-value"), + "Debug output must not contain the raw client_secret" + ); + assert!( + debug_output.contains("[REDACTED]"), + "Debug output must show [REDACTED] for client_secret" + ); + } + + #[tokio::test] + async fn validate_token_rejects_empty() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "remote", + None, + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_token("").await.unwrap_err(); + assert!(err.to_string().contains("empty bearer token")); + } + + #[tokio::test] + async fn validate_session_rejects_empty() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "remote", + None, + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_session("").await.unwrap_err(); + assert!(err.to_string().contains("empty session token")); + } + + #[test] + fn nevis_identity_serde_roundtrip() { + let identity = NevisIdentity { + user_id: "zeroclaw_user".into(), + roles: vec!["admin".into(), "operator".into()], + scopes: vec!["openid".into(), "profile".into()], + mfa_verified: true, + session_expiry: 1_700_000_000, + }; + + let json = serde_json::to_string(&identity).unwrap(); + let parsed: NevisIdentity = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.user_id, "zeroclaw_user"); + assert_eq!(parsed.roles.len(), 2); + assert!(parsed.mfa_verified); + } + + #[tokio::test] + async fn local_validation_rejects_malformed_jwt() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "local", + Some("https://nevis.example.com/.well-known/jwks.json".into()), + false, + 3600, + ) + .unwrap(); + + let err = provider.validate_token("not-a-jwt").await.unwrap_err(); + assert!(err.to_string().contains("Invalid JWT structure")); + } + + #[tokio::test] + async fn local_validation_errors_instead_of_silent_fallback() { + let provider = NevisAuthProvider::new( + "https://nevis.example.com".into(), + "master".into(), + "zeroclaw-client".into(), + None, + "local", + Some("https://nevis.example.com/.well-known/jwks.json".into()), + false, + 3600, + ) + .unwrap(); + + // A well-formed JWT structure should hit the "not yet implemented" error + // instead of silently falling back to remote introspection. + let err = provider + .validate_token("header.payload.signature") + .await + .unwrap_err(); + assert!(err.to_string().contains("not yet implemented")); + } +} diff --git a/src/security/playbook.rs b/src/security/playbook.rs new file mode 100644 index 000000000..cce5a27ff --- /dev/null +++ b/src/security/playbook.rs @@ -0,0 +1,459 @@ +//! Incident response playbook definitions and execution engine. +//! +//! Playbooks define structured response procedures for security incidents. +//! Each playbook has named steps, some of which require human approval before +//! execution. Playbooks are loaded from JSON files in the configured directory. + +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// A single step in an incident response playbook. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PlaybookStep { + /// Machine-readable action identifier (e.g. "isolate_host", "block_ip"). + pub action: String, + /// Human-readable description of what this step does. + pub description: String, + /// Whether this step requires explicit human approval before execution. + #[serde(default)] + pub requires_approval: bool, + /// Timeout in seconds for this step. Default: 300 (5 minutes). + #[serde(default = "default_timeout_secs")] + pub timeout_secs: u64, +} + +fn default_timeout_secs() -> u64 { + 300 +} + +/// An incident response playbook. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Playbook { + /// Unique playbook name (e.g. "suspicious_login"). + pub name: String, + /// Human-readable description. + pub description: String, + /// Ordered list of response steps. + pub steps: Vec, + /// Minimum alert severity that triggers this playbook (low/medium/high/critical). + #[serde(default = "default_severity_filter")] + pub severity_filter: String, + /// Step indices (0-based) that can be auto-approved when below max_auto_severity. + #[serde(default)] + pub auto_approve_steps: Vec, +} + +fn default_severity_filter() -> String { + "medium".into() +} + +/// Result of executing a single playbook step. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StepExecutionResult { + pub step_index: usize, + pub action: String, + pub status: StepStatus, + pub message: String, +} + +/// Status of a playbook step. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum StepStatus { + /// Step completed successfully. + Completed, + /// Step is waiting for human approval. + PendingApproval, + /// Step was skipped (e.g. not applicable). + Skipped, + /// Step failed with an error. + Failed, +} + +impl std::fmt::Display for StepStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Completed => write!(f, "completed"), + Self::PendingApproval => write!(f, "pending_approval"), + Self::Skipped => write!(f, "skipped"), + Self::Failed => write!(f, "failed"), + } + } +} + +/// Load all playbook definitions from a directory of JSON files. +pub fn load_playbooks(dir: &Path) -> Vec { + let mut playbooks = Vec::new(); + + if !dir.exists() || !dir.is_dir() { + return builtin_playbooks(); + } + + if let Ok(entries) = std::fs::read_dir(dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().map_or(false, |ext| ext == "json") { + match std::fs::read_to_string(&path) { + Ok(contents) => match serde_json::from_str::(&contents) { + Ok(pb) => playbooks.push(pb), + Err(e) => { + tracing::warn!("Failed to parse playbook {}: {e}", path.display()); + } + }, + Err(e) => { + tracing::warn!("Failed to read playbook {}: {e}", path.display()); + } + } + } + } + } + + // Merge built-in playbooks that aren't overridden by user-defined ones + for builtin in builtin_playbooks() { + if !playbooks.iter().any(|p| p.name == builtin.name) { + playbooks.push(builtin); + } + } + + playbooks +} + +/// Severity ordering for comparison: low < medium < high < critical. +pub fn severity_level(severity: &str) -> u8 { + match severity.to_lowercase().as_str() { + "low" => 1, + "medium" => 2, + "high" => 3, + "critical" => 4, + // Deny-by-default: unknown severities get the highest level to prevent + // auto-approval of unrecognized severity labels. + _ => u8::MAX, + } +} + +/// Check whether a step can be auto-approved given config constraints. +pub fn can_auto_approve( + playbook: &Playbook, + step_index: usize, + alert_severity: &str, + max_auto_severity: &str, +) -> bool { + // Never auto-approve if alert severity exceeds the configured max + if severity_level(alert_severity) > severity_level(max_auto_severity) { + return false; + } + + // Only auto-approve steps explicitly listed in auto_approve_steps + playbook.auto_approve_steps.contains(&step_index) +} + +/// Evaluate a playbook step. Returns the result with approval gating. +/// +/// Steps that require approval and cannot be auto-approved will return +/// `StepStatus::PendingApproval` without executing. +pub fn evaluate_step( + playbook: &Playbook, + step_index: usize, + alert_severity: &str, + max_auto_severity: &str, + require_approval: bool, +) -> StepExecutionResult { + let step = match playbook.steps.get(step_index) { + Some(s) => s, + None => { + return StepExecutionResult { + step_index, + action: "unknown".into(), + status: StepStatus::Failed, + message: format!("Step index {step_index} out of range"), + }; + } + }; + + // Enforce approval gates: steps that require approval must either be + // auto-approved or wait for human approval. Never mark an unexecuted + // approval-gated step as Completed. + if step.requires_approval + && (!require_approval + || !can_auto_approve(playbook, step_index, alert_severity, max_auto_severity)) + { + return StepExecutionResult { + step_index, + action: step.action.clone(), + status: StepStatus::PendingApproval, + message: format!( + "Step '{}' requires human approval (severity: {alert_severity})", + step.description + ), + }; + } + + // Step is approved (either doesn't require approval, or was auto-approved) + // Actual execution would be delegated to the appropriate tool/system + StepExecutionResult { + step_index, + action: step.action.clone(), + status: StepStatus::Completed, + message: format!("Executed: {}", step.description), + } +} + +/// Built-in playbook definitions for common incident types. +pub fn builtin_playbooks() -> Vec { + vec![ + Playbook { + name: "suspicious_login".into(), + description: "Respond to suspicious login activity detected by SIEM".into(), + steps: vec![ + PlaybookStep { + action: "gather_login_context".into(), + description: "Collect login metadata: IP, geo, device fingerprint, time".into(), + requires_approval: false, + timeout_secs: 60, + }, + PlaybookStep { + action: "check_threat_intel".into(), + description: "Query threat intelligence for source IP reputation".into(), + requires_approval: false, + timeout_secs: 30, + }, + PlaybookStep { + action: "notify_user".into(), + description: "Send verification notification to account owner".into(), + requires_approval: true, + timeout_secs: 300, + }, + PlaybookStep { + action: "force_password_reset".into(), + description: "Force password reset if login confirmed unauthorized".into(), + requires_approval: true, + timeout_secs: 120, + }, + ], + severity_filter: "medium".into(), + auto_approve_steps: vec![0, 1], + }, + Playbook { + name: "malware_detected".into(), + description: "Respond to malware detection on endpoint".into(), + steps: vec![ + PlaybookStep { + action: "isolate_endpoint".into(), + description: "Network-isolate the affected endpoint".into(), + requires_approval: true, + timeout_secs: 60, + }, + PlaybookStep { + action: "collect_forensics".into(), + description: "Capture memory dump and disk image for analysis".into(), + requires_approval: false, + timeout_secs: 600, + }, + PlaybookStep { + action: "scan_lateral_movement".into(), + description: "Check for lateral movement indicators on adjacent hosts".into(), + requires_approval: false, + timeout_secs: 300, + }, + PlaybookStep { + action: "remediate_endpoint".into(), + description: "Remove malware and restore endpoint to clean state".into(), + requires_approval: true, + timeout_secs: 600, + }, + ], + severity_filter: "high".into(), + auto_approve_steps: vec![1, 2], + }, + Playbook { + name: "data_exfiltration_attempt".into(), + description: "Respond to suspected data exfiltration".into(), + steps: vec![ + PlaybookStep { + action: "block_egress".into(), + description: "Block suspicious outbound connections".into(), + requires_approval: true, + timeout_secs: 30, + }, + PlaybookStep { + action: "identify_data_scope".into(), + description: "Determine what data may have been accessed or transferred".into(), + requires_approval: false, + timeout_secs: 300, + }, + PlaybookStep { + action: "preserve_evidence".into(), + description: "Preserve network logs and access records".into(), + requires_approval: false, + timeout_secs: 120, + }, + PlaybookStep { + action: "escalate_to_legal".into(), + description: "Notify legal and compliance teams".into(), + requires_approval: true, + timeout_secs: 60, + }, + ], + severity_filter: "critical".into(), + auto_approve_steps: vec![1, 2], + }, + Playbook { + name: "brute_force".into(), + description: "Respond to brute force authentication attempts".into(), + steps: vec![ + PlaybookStep { + action: "block_source_ip".into(), + description: "Block the attacking source IP at firewall".into(), + requires_approval: true, + timeout_secs: 30, + }, + PlaybookStep { + action: "check_compromised_accounts".into(), + description: "Check if any accounts were successfully compromised".into(), + requires_approval: false, + timeout_secs: 120, + }, + PlaybookStep { + action: "enable_rate_limiting".into(), + description: "Enable enhanced rate limiting on auth endpoints".into(), + requires_approval: true, + timeout_secs: 60, + }, + ], + severity_filter: "medium".into(), + auto_approve_steps: vec![1], + }, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builtin_playbooks_are_valid() { + let playbooks = builtin_playbooks(); + assert_eq!(playbooks.len(), 4); + + let names: Vec<&str> = playbooks.iter().map(|p| p.name.as_str()).collect(); + assert!(names.contains(&"suspicious_login")); + assert!(names.contains(&"malware_detected")); + assert!(names.contains(&"data_exfiltration_attempt")); + assert!(names.contains(&"brute_force")); + + for pb in &playbooks { + assert!(!pb.steps.is_empty(), "Playbook {} has no steps", pb.name); + assert!(!pb.description.is_empty()); + } + } + + #[test] + fn severity_level_ordering() { + assert!(severity_level("low") < severity_level("medium")); + assert!(severity_level("medium") < severity_level("high")); + assert!(severity_level("high") < severity_level("critical")); + assert_eq!(severity_level("unknown"), u8::MAX); + } + + #[test] + fn auto_approve_respects_severity_cap() { + let pb = &builtin_playbooks()[0]; // suspicious_login + + // Step 0 is in auto_approve_steps + assert!(can_auto_approve(pb, 0, "low", "low")); + assert!(can_auto_approve(pb, 0, "low", "medium")); + + // Alert severity exceeds max -> cannot auto-approve + assert!(!can_auto_approve(pb, 0, "high", "low")); + assert!(!can_auto_approve(pb, 0, "critical", "medium")); + + // Step 2 is NOT in auto_approve_steps + assert!(!can_auto_approve(pb, 2, "low", "critical")); + } + + #[test] + fn evaluate_step_requires_approval() { + let pb = &builtin_playbooks()[0]; // suspicious_login + + // Step 2 (notify_user) requires approval, high severity, max=low -> pending + let result = evaluate_step(pb, 2, "high", "low", true); + assert_eq!(result.status, StepStatus::PendingApproval); + assert_eq!(result.action, "notify_user"); + + // Step 0 (gather_login_context) does NOT require approval -> completed + let result = evaluate_step(pb, 0, "high", "low", true); + assert_eq!(result.status, StepStatus::Completed); + } + + #[test] + fn evaluate_step_out_of_range() { + let pb = &builtin_playbooks()[0]; + let result = evaluate_step(pb, 99, "low", "low", true); + assert_eq!(result.status, StepStatus::Failed); + } + + #[test] + fn playbook_json_roundtrip() { + let pb = &builtin_playbooks()[0]; + let json = serde_json::to_string(pb).unwrap(); + let parsed: Playbook = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, *pb); + } + + #[test] + fn load_playbooks_from_nonexistent_dir_returns_builtins() { + let playbooks = load_playbooks(Path::new("/nonexistent/dir")); + assert_eq!(playbooks.len(), 4); + } + + #[test] + fn load_playbooks_merges_custom_and_builtin() { + let dir = tempfile::tempdir().unwrap(); + let custom = Playbook { + name: "custom_playbook".into(), + description: "A custom playbook".into(), + steps: vec![PlaybookStep { + action: "custom_action".into(), + description: "Do something custom".into(), + requires_approval: true, + timeout_secs: 60, + }], + severity_filter: "low".into(), + auto_approve_steps: vec![], + }; + let json = serde_json::to_string(&custom).unwrap(); + std::fs::write(dir.path().join("custom.json"), json).unwrap(); + + let playbooks = load_playbooks(dir.path()); + // 4 builtins + 1 custom + assert_eq!(playbooks.len(), 5); + assert!(playbooks.iter().any(|p| p.name == "custom_playbook")); + } + + #[test] + fn load_playbooks_custom_overrides_builtin() { + let dir = tempfile::tempdir().unwrap(); + let override_pb = Playbook { + name: "suspicious_login".into(), + description: "Custom override".into(), + steps: vec![PlaybookStep { + action: "custom_step".into(), + description: "Overridden step".into(), + requires_approval: false, + timeout_secs: 30, + }], + severity_filter: "low".into(), + auto_approve_steps: vec![0], + }; + let json = serde_json::to_string(&override_pb).unwrap(); + std::fs::write(dir.path().join("suspicious_login.json"), json).unwrap(); + + let playbooks = load_playbooks(dir.path()); + // 3 remaining builtins + 1 overridden = 4 + assert_eq!(playbooks.len(), 4); + let sl = playbooks + .iter() + .find(|p| p.name == "suspicious_login") + .unwrap(); + assert_eq!(sl.description, "Custom override"); + } +} diff --git a/src/security/vulnerability.rs b/src/security/vulnerability.rs new file mode 100644 index 000000000..0b8e30535 --- /dev/null +++ b/src/security/vulnerability.rs @@ -0,0 +1,397 @@ +//! Vulnerability scan result parsing and management. +//! +//! Parses vulnerability scan outputs from common scanners (Nessus, Qualys, generic +//! CVSS JSON) and provides priority scoring with business context adjustments. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; + +/// A single vulnerability finding. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Finding { + /// CVE identifier (e.g. "CVE-2024-1234"). May be empty for non-CVE findings. + #[serde(default)] + pub cve_id: String, + /// CVSS base score (0.0 - 10.0). + pub cvss_score: f64, + /// Severity label: "low", "medium", "high", "critical". + pub severity: String, + /// Affected asset identifier (hostname, IP, or service name). + pub affected_asset: String, + /// Description of the vulnerability. + pub description: String, + /// Recommended remediation steps. + #[serde(default)] + pub remediation: String, + /// Whether the asset is internet-facing (increases effective priority). + #[serde(default)] + pub internet_facing: bool, + /// Whether the asset is in a production environment. + #[serde(default = "default_true")] + pub production: bool, +} + +fn default_true() -> bool { + true +} + +/// A parsed vulnerability scan report. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VulnerabilityReport { + /// When the scan was performed. + pub scan_date: DateTime, + /// Scanner that produced the results (e.g. "nessus", "qualys", "generic"). + pub scanner: String, + /// Individual findings from the scan. + pub findings: Vec, +} + +/// Compute effective priority score for a finding. +/// +/// Base: CVSS score (0-10). Adjustments: +/// - Internet-facing: +2.0 (capped at 10.0) +/// - Production: +1.0 (capped at 10.0) +pub fn effective_priority(finding: &Finding) -> f64 { + let mut score = finding.cvss_score; + if finding.internet_facing { + score += 2.0; + } + if finding.production { + score += 1.0; + } + score.min(10.0) +} + +/// Classify CVSS score into severity label. +pub fn cvss_to_severity(cvss: f64) -> &'static str { + match cvss { + s if s >= 9.0 => "critical", + s if s >= 7.0 => "high", + s if s >= 4.0 => "medium", + s if s > 0.0 => "low", + _ => "informational", + } +} + +/// Parse a generic CVSS JSON vulnerability report. +/// +/// Expects a JSON object with: +/// - `scan_date`: ISO 8601 date string +/// - `scanner`: string +/// - `findings`: array of Finding objects +pub fn parse_vulnerability_json(json_str: &str) -> anyhow::Result { + let report: VulnerabilityReport = serde_json::from_str(json_str) + .map_err(|e| anyhow::anyhow!("Failed to parse vulnerability report: {e}"))?; + + for (i, finding) in report.findings.iter().enumerate() { + if !(0.0..=10.0).contains(&finding.cvss_score) { + anyhow::bail!( + "findings[{}].cvss_score must be between 0.0 and 10.0, got {}", + i, + finding.cvss_score + ); + } + } + + Ok(report) +} + +/// Generate a summary of the vulnerability report. +pub fn generate_summary(report: &VulnerabilityReport) -> String { + if report.findings.is_empty() { + return format!( + "Vulnerability scan by {} on {}: No findings.", + report.scanner, + report.scan_date.format("%Y-%m-%d") + ); + } + + let total = report.findings.len(); + let critical = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("critical")) + .count(); + let high = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("high")) + .count(); + let medium = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("medium")) + .count(); + let low = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("low")) + .count(); + let informational = report + .findings + .iter() + .filter(|f| f.severity.eq_ignore_ascii_case("informational")) + .count(); + + // Sort by effective priority descending + let mut sorted: Vec<&Finding> = report.findings.iter().collect(); + sorted.sort_by(|a, b| { + effective_priority(b) + .partial_cmp(&effective_priority(a)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut summary = format!( + "## Vulnerability Scan Summary\n\ + **Scanner:** {} | **Date:** {}\n\ + **Total findings:** {} (Critical: {}, High: {}, Medium: {}, Low: {}, Informational: {})\n\n", + report.scanner, + report.scan_date.format("%Y-%m-%d"), + total, + critical, + high, + medium, + low, + informational + ); + + // Top 10 by effective priority + summary.push_str("### Top Findings by Priority\n\n"); + for (i, finding) in sorted.iter().take(10).enumerate() { + let priority = effective_priority(finding); + let context = match (finding.internet_facing, finding.production) { + (true, true) => " [internet-facing, production]", + (true, false) => " [internet-facing]", + (false, true) => " [production]", + (false, false) => "", + }; + let _ = writeln!( + summary, + "{}. **{}** (CVSS: {:.1}, Priority: {:.1}){}\n Asset: {} | {}", + i + 1, + if finding.cve_id.is_empty() { + "No CVE" + } else { + &finding.cve_id + }, + finding.cvss_score, + priority, + context, + finding.affected_asset, + finding.description + ); + if !finding.remediation.is_empty() { + let _ = writeln!(summary, " Remediation: {}", finding.remediation); + } + summary.push('\n'); + } + + // Remediation recommendations + if critical > 0 || high > 0 { + summary.push_str("### Remediation Recommendations\n\n"); + if critical > 0 { + let _ = writeln!( + summary, + "- **URGENT:** {} critical findings require immediate remediation", + critical + ); + } + if high > 0 { + let _ = writeln!( + summary, + "- **HIGH:** {} high-severity findings should be addressed within 7 days", + high + ); + } + let internet_facing_critical = sorted + .iter() + .filter(|f| f.internet_facing && (f.severity == "critical" || f.severity == "high")) + .count(); + if internet_facing_critical > 0 { + let _ = writeln!( + summary, + "- **PRIORITY:** {} critical/high findings on internet-facing assets", + internet_facing_critical + ); + } + } + + summary +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_findings() -> Vec { + vec![ + Finding { + cve_id: "CVE-2024-0001".into(), + cvss_score: 9.8, + severity: "critical".into(), + affected_asset: "web-server-01".into(), + description: "Remote code execution in web framework".into(), + remediation: "Upgrade to version 2.1.0".into(), + internet_facing: true, + production: true, + }, + Finding { + cve_id: "CVE-2024-0002".into(), + cvss_score: 7.5, + severity: "high".into(), + affected_asset: "db-server-01".into(), + description: "SQL injection in query parser".into(), + remediation: "Apply patch KB-12345".into(), + internet_facing: false, + production: true, + }, + Finding { + cve_id: "CVE-2024-0003".into(), + cvss_score: 4.3, + severity: "medium".into(), + affected_asset: "staging-app-01".into(), + description: "Information disclosure via debug endpoint".into(), + remediation: "Disable debug endpoint in config".into(), + internet_facing: false, + production: false, + }, + ] + } + + #[test] + fn effective_priority_adds_context_bonuses() { + let mut f = Finding { + cve_id: String::new(), + cvss_score: 7.0, + severity: "high".into(), + affected_asset: "host".into(), + description: "test".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }; + + assert!((effective_priority(&f) - 7.0).abs() < f64::EPSILON); + + f.internet_facing = true; + assert!((effective_priority(&f) - 9.0).abs() < f64::EPSILON); + + f.production = true; + assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); // capped + + // High CVSS + both bonuses still caps at 10.0 + f.cvss_score = 9.5; + assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); + } + + #[test] + fn cvss_to_severity_classification() { + assert_eq!(cvss_to_severity(9.8), "critical"); + assert_eq!(cvss_to_severity(9.0), "critical"); + assert_eq!(cvss_to_severity(8.5), "high"); + assert_eq!(cvss_to_severity(7.0), "high"); + assert_eq!(cvss_to_severity(5.0), "medium"); + assert_eq!(cvss_to_severity(4.0), "medium"); + assert_eq!(cvss_to_severity(3.9), "low"); + assert_eq!(cvss_to_severity(0.1), "low"); + assert_eq!(cvss_to_severity(0.0), "informational"); + } + + #[test] + fn parse_vulnerability_json_roundtrip() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "nessus".into(), + findings: sample_findings(), + }; + + let json = serde_json::to_string(&report).unwrap(); + let parsed = parse_vulnerability_json(&json).unwrap(); + + assert_eq!(parsed.scanner, "nessus"); + assert_eq!(parsed.findings.len(), 3); + assert_eq!(parsed.findings[0].cve_id, "CVE-2024-0001"); + } + + #[test] + fn parse_vulnerability_json_rejects_invalid() { + let result = parse_vulnerability_json("not json"); + assert!(result.is_err()); + } + + #[test] + fn generate_summary_includes_key_sections() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "qualys".into(), + findings: sample_findings(), + }; + + let summary = generate_summary(&report); + + assert!(summary.contains("qualys")); + assert!(summary.contains("Total findings:** 3")); + assert!(summary.contains("Critical: 1")); + assert!(summary.contains("High: 1")); + assert!(summary.contains("CVE-2024-0001")); + assert!(summary.contains("URGENT")); + assert!(summary.contains("internet-facing")); + } + + #[test] + fn parse_vulnerability_json_rejects_out_of_range_cvss() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "test".into(), + findings: vec![Finding { + cve_id: "CVE-2024-9999".into(), + cvss_score: 11.0, + severity: "critical".into(), + affected_asset: "host".into(), + description: "bad score".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }], + }; + let json = serde_json::to_string(&report).unwrap(); + let result = parse_vulnerability_json(&json); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("cvss_score must be between 0.0 and 10.0")); + } + + #[test] + fn parse_vulnerability_json_rejects_negative_cvss() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "test".into(), + findings: vec![Finding { + cve_id: "CVE-2024-9998".into(), + cvss_score: -1.0, + severity: "low".into(), + affected_asset: "host".into(), + description: "negative score".into(), + remediation: String::new(), + internet_facing: false, + production: false, + }], + }; + let json = serde_json::to_string(&report).unwrap(); + let result = parse_vulnerability_json(&json); + assert!(result.is_err()); + } + + #[test] + fn generate_summary_empty_findings() { + let report = VulnerabilityReport { + scan_date: Utc::now(), + scanner: "nessus".into(), + findings: vec![], + }; + + let summary = generate_summary(&report); + assert!(summary.contains("No findings")); + } +} diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 1603176c1..5bd559b12 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -1470,7 +1470,7 @@ mod native_backend { // When running as a service (systemd/OpenRC), the browser sandbox // fails because the process lacks a user namespace / session. // --no-sandbox and --disable-dev-shm-usage are required in this context. - if is_service_environment() { + if super::is_service_environment() { args.push(Value::String("--no-sandbox".to_string())); args.push(Value::String("--disable-dev-shm-usage".to_string())); } diff --git a/src/tools/microsoft365/auth.rs b/src/tools/microsoft365/auth.rs new file mode 100644 index 000000000..07afd4b14 --- /dev/null +++ b/src/tools/microsoft365/auth.rs @@ -0,0 +1,400 @@ +use anyhow::Context; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::path::PathBuf; +use tokio::sync::Mutex; + +/// Cached OAuth2 token state persisted to disk between runs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CachedTokenState { + pub access_token: String, + pub refresh_token: Option, + /// Unix timestamp (seconds) when the access token expires. + pub expires_at: i64, +} + +impl CachedTokenState { + /// Returns `true` when the token is expired or will expire within 60 seconds. + pub fn is_expired(&self) -> bool { + let now = chrono::Utc::now().timestamp(); + self.expires_at <= now + 60 + } +} + +/// Thread-safe token cache with disk persistence. +pub struct TokenCache { + inner: RwLock>, + /// Serialises the slow acquire/refresh path so only one caller performs the + /// network round-trip while others wait and then read the updated cache. + acquire_lock: Mutex<()>, + config: super::types::Microsoft365ResolvedConfig, + cache_path: PathBuf, +} + +impl TokenCache { + pub fn new( + config: super::types::Microsoft365ResolvedConfig, + zeroclaw_dir: &std::path::Path, + ) -> anyhow::Result { + if config.token_cache_encrypted { + anyhow::bail!( + "microsoft365: token_cache_encrypted is enabled but encryption is not yet \ + implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \ + to false or wait for encryption support." + ); + } + + // Scope cache file to (tenant_id, client_id, auth_flow) so config + // changes never reuse tokens from a different account/flow. + let mut hasher = DefaultHasher::new(); + config.tenant_id.hash(&mut hasher); + config.client_id.hash(&mut hasher); + config.auth_flow.hash(&mut hasher); + let fingerprint = format!("{:016x}", hasher.finish()); + + let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json")); + let cached = Self::load_from_disk(&cache_path); + Ok(Self { + inner: RwLock::new(cached), + acquire_lock: Mutex::new(()), + config, + cache_path, + }) + } + + /// Get a valid access token, refreshing or re-authenticating as needed. + pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result { + // Fast path: cached and not expired. + { + let guard = self.inner.read(); + if let Some(ref state) = *guard { + if !state.is_expired() { + return Ok(state.access_token.clone()); + } + } + } + + // Slow path: serialise through a mutex so only one caller performs the + // network round-trip while concurrent callers wait and re-check. + let _lock = self.acquire_lock.lock().await; + + // Re-check after acquiring the lock — another caller may have refreshed + // while we were waiting. + { + let guard = self.inner.read(); + if let Some(ref state) = *guard { + if !state.is_expired() { + return Ok(state.access_token.clone()); + } + } + } + + let new_state = self.acquire_token(client).await?; + let token = new_state.access_token.clone(); + self.persist_to_disk(&new_state); + *self.inner.write() = Some(new_state); + Ok(token) + } + + async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result { + // Try refresh first if we have a refresh token and the flow supports it. + // Client credentials flow does not issue refresh tokens, so skip the + // attempt entirely to avoid a wasted round-trip. + if self.config.auth_flow.as_str() != "client_credentials" { + // Clone the token out so the RwLock guard is dropped before the await. + let refresh_token_copy = { + let guard = self.inner.read(); + guard.as_ref().and_then(|state| state.refresh_token.clone()) + }; + if let Some(refresh_tok) = refresh_token_copy { + match self.refresh_token(client, &refresh_tok).await { + Ok(new_state) => return Ok(new_state), + Err(e) => { + tracing::debug!("ms365: refresh token failed, re-authenticating: {e}"); + } + } + } + } + + match self.config.auth_flow.as_str() { + "client_credentials" => self.client_credentials_flow(client).await, + "device_code" => self.device_code_flow(client).await, + other => anyhow::bail!("Unsupported auth flow: {other}"), + } + } + + async fn client_credentials_flow( + &self, + client: &reqwest::Client, + ) -> anyhow::Result { + let client_secret = self + .config + .client_secret + .as_deref() + .context("client_credentials flow requires client_secret")?; + + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let scope = self.config.scopes.join(" "); + + let resp = client + .post(&token_url) + .form(&[ + ("grant_type", "client_credentials"), + ("client_id", &self.config.client_id), + ("client_secret", client_secret), + ("scope", &scope), + ]) + .send() + .await + .context("ms365: failed to request client_credentials token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: client_credentials raw OAuth error: {body}"); + anyhow::bail!("ms365: client_credentials token request failed ({status})"); + } + + let token_resp: TokenResponse = resp + .json() + .await + .context("ms365: failed to parse token response")?; + + Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp.refresh_token, + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }) + } + + async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result { + let device_code_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode", + self.config.tenant_id + ); + let scope = self.config.scopes.join(" "); + + let resp = client + .post(&device_code_url) + .form(&[ + ("client_id", self.config.client_id.as_str()), + ("scope", &scope), + ]) + .send() + .await + .context("ms365: failed to request device code")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: device_code initiation raw error: {body}"); + anyhow::bail!("ms365: device code request failed ({status})"); + } + + let device_resp: DeviceCodeResponse = resp + .json() + .await + .context("ms365: failed to parse device code response")?; + + // Log only a generic prompt; the full device_resp.message may contain + // sensitive verification URIs or codes that should not appear in logs. + tracing::info!( + "ms365: device code auth required — follow the instructions shown to the user" + ); + // Print the user-facing message to stderr so the operator can act on it + // without it being captured in structured log sinks. + eprintln!("ms365: {}", device_resp.message); + + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let interval = device_resp.interval.max(5); + let max_polls = u32::try_from( + (device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1), + ) + .unwrap_or(u32::MAX); + + for _ in 0..max_polls { + tokio::time::sleep(std::time::Duration::from_secs(interval)).await; + + let poll_resp = client + .post(&token_url) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("client_id", self.config.client_id.as_str()), + ("device_code", &device_resp.device_code), + ]) + .send() + .await + .context("ms365: failed to poll device code token")?; + + if poll_resp.status().is_success() { + let token_resp: TokenResponse = poll_resp + .json() + .await + .context("ms365: failed to parse token response")?; + return Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp.refresh_token, + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }); + } + + let body = poll_resp.text().await.unwrap_or_default(); + if body.contains("authorization_pending") { + continue; + } + if body.contains("slow_down") { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + continue; + } + tracing::debug!("ms365: device code polling raw error: {body}"); + anyhow::bail!("ms365: device code polling failed"); + } + + anyhow::bail!("ms365: device code flow timed out waiting for user authorization") + } + + async fn refresh_token( + &self, + client: &reqwest::Client, + refresh_token: &str, + ) -> anyhow::Result { + let token_url = format!( + "https://login.microsoftonline.com/{}/oauth2/v2.0/token", + self.config.tenant_id + ); + + let mut params = vec![ + ("grant_type", "refresh_token"), + ("client_id", self.config.client_id.as_str()), + ("refresh_token", refresh_token), + ]; + + let secret_ref; + if let Some(ref secret) = self.config.client_secret { + secret_ref = secret.as_str(); + params.push(("client_secret", secret_ref)); + } + + let resp = client + .post(&token_url) + .form(¶ms) + .send() + .await + .context("ms365: failed to refresh token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::debug!("ms365: token refresh raw error: {body}"); + anyhow::bail!("ms365: token refresh failed ({status})"); + } + + let token_resp: TokenResponse = resp + .json() + .await + .context("ms365: failed to parse refresh token response")?; + + Ok(CachedTokenState { + access_token: token_resp.access_token, + refresh_token: token_resp + .refresh_token + .or_else(|| Some(refresh_token.to_string())), + expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, + }) + } + + fn load_from_disk(path: &std::path::Path) -> Option { + let data = std::fs::read_to_string(path).ok()?; + serde_json::from_str(&data).ok() + } + + fn persist_to_disk(&self, state: &CachedTokenState) { + if let Ok(json) = serde_json::to_string_pretty(state) { + if let Err(e) = std::fs::write(&self.cache_path, json) { + tracing::warn!("ms365: failed to persist token cache: {e}"); + } + } + } +} + +#[derive(Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default = "default_expires_in")] + expires_in: i64, +} + +fn default_expires_in() -> i64 { + 3600 +} + +#[derive(Deserialize)] +struct DeviceCodeResponse { + device_code: String, + message: String, + #[serde(default = "default_device_interval")] + interval: u64, + #[serde(default = "default_device_expires_in")] + expires_in: i64, +} + +fn default_device_interval() -> u64 { + 5 +} + +fn default_device_expires_in() -> i64 { + 900 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_is_expired_when_past_deadline() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + expires_at: chrono::Utc::now().timestamp() - 10, + }; + assert!(state.is_expired()); + } + + #[test] + fn token_is_expired_within_buffer() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + expires_at: chrono::Utc::now().timestamp() + 30, + }; + assert!(state.is_expired()); + } + + #[test] + fn token_is_valid_when_far_from_expiry() { + let state = CachedTokenState { + access_token: "test".into(), + refresh_token: None, + expires_at: chrono::Utc::now().timestamp() + 3600, + }; + assert!(!state.is_expired()); + } + + #[test] + fn load_from_disk_returns_none_for_missing_file() { + let path = std::path::Path::new("/nonexistent/ms365_token_cache.json"); + assert!(TokenCache::load_from_disk(path).is_none()); + } +} diff --git a/src/tools/microsoft365/graph_client.rs b/src/tools/microsoft365/graph_client.rs new file mode 100644 index 000000000..0cda00247 --- /dev/null +++ b/src/tools/microsoft365/graph_client.rs @@ -0,0 +1,495 @@ +use anyhow::Context; + +const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0"; + +/// Build the user path segment: `/me` or `/users/{user_id}`. +/// The user_id is percent-encoded to prevent path-traversal attacks. +fn user_path(user_id: &str) -> String { + if user_id == "me" { + "/me".to_string() + } else { + format!("/users/{}", urlencoding::encode(user_id)) + } +} + +/// Percent-encode a single path segment to prevent path-traversal attacks. +fn encode_path_segment(segment: &str) -> String { + urlencoding::encode(segment).into_owned() +} + +/// List mail messages for a user. +pub async fn mail_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + folder: Option<&str>, + top: u32, +) -> anyhow::Result { + let base = user_path(user_id); + let path = match folder { + Some(f) => format!( + "{GRAPH_BASE}{base}/mailFolders/{}/messages", + encode_path_segment(f) + ), + None => format!("{GRAPH_BASE}{base}/messages"), + }; + + let resp = client + .get(&path) + .bearer_auth(token) + .query(&[("$top", top.to_string())]) + .send() + .await + .context("ms365: mail_list request failed")?; + + handle_json_response(resp, "mail_list").await +} + +/// Send a mail message. +pub async fn mail_send( + client: &reqwest::Client, + token: &str, + user_id: &str, + to: &[String], + subject: &str, + body: &str, +) -> anyhow::Result<()> { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/sendMail"); + + let to_recipients: Vec = to + .iter() + .map(|addr| { + serde_json::json!({ + "emailAddress": { "address": addr } + }) + }) + .collect(); + + let payload = serde_json::json!({ + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": body + }, + "toRecipients": to_recipients + } + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: mail_send request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: mail_send raw error body: {body}"); + anyhow::bail!("ms365: mail_send failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List messages in a Teams channel. +pub async fn teams_message_list( + client: &reqwest::Client, + token: &str, + team_id: &str, + channel_id: &str, + top: u32, +) -> anyhow::Result { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment(team_id), + encode_path_segment(channel_id) + ); + + let resp = client + .get(&url) + .bearer_auth(token) + .query(&[("$top", top.to_string())]) + .send() + .await + .context("ms365: teams_message_list request failed")?; + + handle_json_response(resp, "teams_message_list").await +} + +/// Send a message to a Teams channel. +pub async fn teams_message_send( + client: &reqwest::Client, + token: &str, + team_id: &str, + channel_id: &str, + body: &str, +) -> anyhow::Result<()> { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment(team_id), + encode_path_segment(channel_id) + ); + + let payload = serde_json::json!({ + "body": { + "content": body + } + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: teams_message_send request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: teams_message_send raw error body: {body}"); + anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List calendar events in a date range. +pub async fn calendar_events_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + start: &str, + end: &str, + top: u32, +) -> anyhow::Result { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/calendarView"); + + let resp = client + .get(&url) + .bearer_auth(token) + .query(&[ + ("startDateTime", start.to_string()), + ("endDateTime", end.to_string()), + ("$top", top.to_string()), + ]) + .send() + .await + .context("ms365: calendar_events_list request failed")?; + + handle_json_response(resp, "calendar_events_list").await +} + +/// Create a calendar event. +pub async fn calendar_event_create( + client: &reqwest::Client, + token: &str, + user_id: &str, + subject: &str, + start: &str, + end: &str, + attendees: &[String], + body_text: Option<&str>, +) -> anyhow::Result { + let base = user_path(user_id); + let url = format!("{GRAPH_BASE}{base}/events"); + + let attendee_list: Vec = attendees + .iter() + .map(|email| { + serde_json::json!({ + "emailAddress": { "address": email }, + "type": "required" + }) + }) + .collect(); + + let mut payload = serde_json::json!({ + "subject": subject, + "start": { + "dateTime": start, + "timeZone": "UTC" + }, + "end": { + "dateTime": end, + "timeZone": "UTC" + }, + "attendees": attendee_list + }); + + if let Some(text) = body_text { + payload["body"] = serde_json::json!({ + "contentType": "Text", + "content": text + }); + } + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: calendar_event_create request failed")?; + + let value = handle_json_response(resp, "calendar_event_create").await?; + let event_id = value["id"].as_str().unwrap_or("unknown").to_string(); + Ok(event_id) +} + +/// Delete a calendar event by ID. +pub async fn calendar_event_delete( + client: &reqwest::Client, + token: &str, + user_id: &str, + event_id: &str, +) -> anyhow::Result<()> { + let base = user_path(user_id); + let url = format!( + "{GRAPH_BASE}{base}/events/{}", + encode_path_segment(event_id) + ); + + let resp = client + .delete(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: calendar_event_delete request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: calendar_event_delete raw error body: {body}"); + anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})"); + } + + Ok(()) +} + +/// List children of a OneDrive folder. +pub async fn onedrive_list( + client: &reqwest::Client, + token: &str, + user_id: &str, + path: Option<&str>, +) -> anyhow::Result { + let base = user_path(user_id); + let url = match path { + Some(p) if !p.is_empty() => { + let encoded = urlencoding::encode(p); + format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children") + } + _ => format!("{GRAPH_BASE}{base}/drive/root/children"), + }; + + let resp = client + .get(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: onedrive_list request failed")?; + + handle_json_response(resp, "onedrive_list").await +} + +/// Download a OneDrive item by ID, with a maximum size guard. +pub async fn onedrive_download( + client: &reqwest::Client, + token: &str, + user_id: &str, + item_id: &str, + max_size: usize, +) -> anyhow::Result> { + let base = user_path(user_id); + let url = format!( + "{GRAPH_BASE}{base}/drive/items/{}/content", + encode_path_segment(item_id) + ); + + let resp = client + .get(&url) + .bearer_auth(token) + .send() + .await + .context("ms365: onedrive_download request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: onedrive_download raw error body: {body}"); + anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})"); + } + + let bytes = resp + .bytes() + .await + .context("ms365: failed to read download body")?; + if bytes.len() > max_size { + anyhow::bail!( + "ms365: downloaded file exceeds max_size ({} > {max_size})", + bytes.len() + ); + } + + Ok(bytes.to_vec()) +} + +/// Search SharePoint for documents matching a query. +pub async fn sharepoint_search( + client: &reqwest::Client, + token: &str, + query: &str, + top: u32, +) -> anyhow::Result { + let url = format!("{GRAPH_BASE}/search/query"); + + let payload = serde_json::json!({ + "requests": [{ + "entityTypes": ["driveItem", "listItem", "site"], + "query": { + "queryString": query + }, + "from": 0, + "size": top + }] + }); + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&payload) + .send() + .await + .context("ms365: sharepoint_search request failed")?; + + handle_json_response(resp, "sharepoint_search").await +} + +/// Extract a short, safe error code from a Graph API JSON error body. +/// Returns `None` when the body is not a recognised Graph error envelope. +fn extract_graph_error_code(body: &str) -> Option { + let parsed: serde_json::Value = serde_json::from_str(body).ok()?; + let code = parsed + .get("error") + .and_then(|e| e.get("code")) + .and_then(|c| c.as_str()) + .map(|s| s.to_string()); + code +} + +/// Parse a JSON response body, returning an error on non-success status. +/// Raw Graph API error bodies are not propagated; only the HTTP status and a +/// short error code (when available) are surfaced to avoid leaking internal +/// API details. +async fn handle_json_response( + resp: reqwest::Response, + operation: &str, +) -> anyhow::Result { + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string()); + tracing::debug!("ms365: {operation} raw error body: {body}"); + anyhow::bail!("ms365: {operation} failed ({status}, code={code})"); + } + + resp.json() + .await + .with_context(|| format!("ms365: failed to parse {operation} response")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn user_path_me() { + assert_eq!(user_path("me"), "/me"); + } + + #[test] + fn user_path_specific_user() { + assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com"); + } + + #[test] + fn mail_list_url_no_folder() { + let base = user_path("me"); + let url = format!("{GRAPH_BASE}{base}/messages"); + assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages"); + } + + #[test] + fn mail_list_url_with_folder() { + let base = user_path("me"); + let folder = "inbox"; + let url = format!( + "{GRAPH_BASE}{base}/mailFolders/{}/messages", + encode_path_segment(folder) + ); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages" + ); + } + + #[test] + fn calendar_view_url() { + let base = user_path("user@example.com"); + let url = format!("{GRAPH_BASE}{base}/calendarView"); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView" + ); + } + + #[test] + fn teams_message_url() { + let url = format!( + "{GRAPH_BASE}/teams/{}/channels/{}/messages", + encode_path_segment("team-123"), + encode_path_segment("channel-456") + ); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages" + ); + } + + #[test] + fn onedrive_root_url() { + let base = user_path("me"); + let url = format!("{GRAPH_BASE}{base}/drive/root/children"); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/drive/root/children" + ); + } + + #[test] + fn onedrive_path_url() { + let base = user_path("me"); + let encoded = urlencoding::encode("Documents/Reports"); + let url = format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children"); + assert_eq!( + url, + "https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children" + ); + } + + #[test] + fn sharepoint_search_url() { + let url = format!("{GRAPH_BASE}/search/query"); + assert_eq!(url, "https://graph.microsoft.com/v1.0/search/query"); + } +} diff --git a/src/tools/microsoft365/mod.rs b/src/tools/microsoft365/mod.rs new file mode 100644 index 000000000..1876556e5 --- /dev/null +++ b/src/tools/microsoft365/mod.rs @@ -0,0 +1,567 @@ +//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar, +//! OneDrive, and SharePoint via a single action-dispatched tool surface. +//! +//! Auth is handled through direct HTTP calls to the Microsoft identity platform +//! (client credentials or device code flow) with token caching. + +pub mod auth; +pub mod graph_client; +pub mod types; + +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +/// Maximum download size for OneDrive files (10 MB). +const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024; + +/// Default number of items to return in list operations. +const DEFAULT_TOP: u32 = 25; + +pub struct Microsoft365Tool { + config: types::Microsoft365ResolvedConfig, + security: Arc, + token_cache: Arc, + http_client: reqwest::Client, +} + +impl Microsoft365Tool { + pub fn new( + config: types::Microsoft365ResolvedConfig, + security: Arc, + zeroclaw_dir: &std::path::Path, + ) -> anyhow::Result { + let http_client = + crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10); + let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?); + Ok(Self { + config, + security, + token_cache, + http_client, + }) + } + + async fn get_token(&self) -> anyhow::Result { + self.token_cache.get_token(&self.http_client).await + } + + fn user_id(&self) -> &str { + &self.config.user_id + } + + async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result { + match action { + "mail_list" => self.handle_mail_list(args).await, + "mail_send" => self.handle_mail_send(args).await, + "teams_message_list" => self.handle_teams_message_list(args).await, + "teams_message_send" => self.handle_teams_message_send(args).await, + "calendar_events_list" => self.handle_calendar_events_list(args).await, + "calendar_event_create" => self.handle_calendar_event_create(args).await, + "calendar_event_delete" => self.handle_calendar_event_delete(args).await, + "onedrive_list" => self.handle_onedrive_list(args).await, + "onedrive_download" => self.handle_onedrive_download(args).await, + "sharepoint_search" => self.handle_sharepoint_search(args).await, + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown action: {action}")), + }), + } + } + + // ── Read actions ──────────────────────────────────────────────── + + async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let folder = args["folder"].as_str(); + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = + graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_teams_message_list( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let team_id = args["team_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("team_id is required"))?; + let channel_id = args["channel_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("channel_id is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = + graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top) + .await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_calendar_events_list( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let start = args["start"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("start datetime is required"))?; + let end = args["end"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("end datetime is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = graph_client::calendar_events_list( + &self.http_client, + &token, + self.user_id(), + start, + end, + top, + ) + .await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let path = args["path"].as_str(); + + let result = + graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + async fn handle_onedrive_download( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let item_id = args["item_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("item_id is required"))?; + let max_size = args["max_size"] + .as_u64() + .and_then(|v| usize::try_from(v).ok()) + .unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE) + .min(MAX_ONEDRIVE_DOWNLOAD_SIZE); + + let bytes = graph_client::onedrive_download( + &self.http_client, + &token, + self.user_id(), + item_id, + max_size, + ) + .await?; + + // Return base64-encoded for binary safety. + use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + + Ok(ToolResult { + success: true, + output: format!( + "Downloaded {} bytes (base64 encoded):\n{encoded}", + bytes.len() + ), + error: None, + }) + } + + async fn handle_sharepoint_search( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let query = args["query"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("query is required"))?; + let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP))) + .unwrap_or(DEFAULT_TOP); + + let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result)?, + error: None, + }) + } + + // ── Write actions ─────────────────────────────────────────────── + + async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let to: Vec = args["to"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))? + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + + if to.is_empty() { + anyhow::bail!("to must contain at least one email address"); + } + + let subject = args["subject"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("subject is required"))?; + let body = args["body"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("body is required"))?; + + graph_client::mail_send( + &self.http_client, + &token, + self.user_id(), + &to, + subject, + body, + ) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Email sent to: {}", to.join(", ")), + error: None, + }) + } + + async fn handle_teams_message_send( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let team_id = args["team_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("team_id is required"))?; + let channel_id = args["channel_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("channel_id is required"))?; + let body = args["body"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("body is required"))?; + + graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body) + .await?; + + Ok(ToolResult { + success: true, + output: "Teams message sent".to_string(), + error: None, + }) + } + + async fn handle_calendar_event_create( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let subject = args["subject"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("subject is required"))?; + let start = args["start"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("start datetime is required"))?; + let end = args["end"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("end datetime is required"))?; + let attendees: Vec = args["attendees"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + let body_text = args["body"].as_str(); + + let event_id = graph_client::calendar_event_create( + &self.http_client, + &token, + self.user_id(), + subject, + start, + end, + &attendees, + body_text, + ) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Calendar event created (id: {event_id})"), + error: None, + }) + } + + async fn handle_calendar_event_delete( + &self, + args: &serde_json::Value, + ) -> anyhow::Result { + self.security + .enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete") + .map_err(|e| anyhow::anyhow!(e))?; + + let token = self.get_token().await?; + let event_id = args["event_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("event_id is required"))?; + + graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id) + .await?; + + Ok(ToolResult { + success: true, + output: format!("Calendar event {event_id} deleted"), + error: None, + }) + } +} + +#[async_trait] +impl Tool for Microsoft365Tool { + fn name(&self) -> &str { + "microsoft365" + } + + fn description(&self) -> &str { + "Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \ + OneDrive files, and SharePoint search via Microsoft Graph API" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search" + ], + "description": "The Microsoft 365 action to perform" + }, + "folder": { + "type": "string", + "description": "Mail folder ID (for mail_list, e.g. 'inbox', 'sentitems')" + }, + "to": { + "type": "array", + "items": { "type": "string" }, + "description": "Recipient email addresses (for mail_send)" + }, + "subject": { + "type": "string", + "description": "Email subject or calendar event subject" + }, + "body": { + "type": "string", + "description": "Message body text" + }, + "team_id": { + "type": "string", + "description": "Teams team ID (for teams_message_list/send)" + }, + "channel_id": { + "type": "string", + "description": "Teams channel ID (for teams_message_list/send)" + }, + "start": { + "type": "string", + "description": "Start datetime in ISO 8601 format (for calendar actions)" + }, + "end": { + "type": "string", + "description": "End datetime in ISO 8601 format (for calendar actions)" + }, + "attendees": { + "type": "array", + "items": { "type": "string" }, + "description": "Attendee email addresses (for calendar_event_create)" + }, + "event_id": { + "type": "string", + "description": "Calendar event ID (for calendar_event_delete)" + }, + "path": { + "type": "string", + "description": "OneDrive folder path (for onedrive_list)" + }, + "item_id": { + "type": "string", + "description": "OneDrive item ID (for onedrive_download)" + }, + "max_size": { + "type": "integer", + "description": "Maximum download size in bytes (for onedrive_download, default 10MB)" + }, + "query": { + "type": "string", + "description": "Search query (for sharepoint_search)" + }, + "top": { + "type": "integer", + "description": "Maximum number of items to return (default 25)" + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = match args["action"].as_str() { + Some(a) => a.to_string(), + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'action' parameter is required".to_string()), + }); + } + }; + + match self.dispatch(&action, &args).await { + Ok(result) => Ok(result), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("microsoft365.{action} failed: {e}")), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_name_is_microsoft365() { + // Verify the schema is valid JSON with the expected structure. + let schema_str = r#"{"type":"object","required":["action"]}"#; + let _: serde_json::Value = serde_json::from_str(schema_str).unwrap(); + } + + #[test] + fn parameters_schema_has_action_enum() { + let schema = json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search" + ] + } + } + }); + + let actions = schema["properties"]["action"]["enum"].as_array().unwrap(); + assert_eq!(actions.len(), 10); + assert!(actions.contains(&json!("mail_list"))); + assert!(actions.contains(&json!("sharepoint_search"))); + } + + #[test] + fn action_dispatch_table_is_exhaustive() { + let valid_actions = [ + "mail_list", + "mail_send", + "teams_message_list", + "teams_message_send", + "calendar_events_list", + "calendar_event_create", + "calendar_event_delete", + "onedrive_list", + "onedrive_download", + "sharepoint_search", + ]; + assert_eq!(valid_actions.len(), 10); + assert!(!valid_actions.contains(&"invalid_action")); + } +} diff --git a/src/tools/microsoft365/types.rs b/src/tools/microsoft365/types.rs new file mode 100644 index 000000000..72a71f0a5 --- /dev/null +++ b/src/tools/microsoft365/types.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied. +#[derive(Clone, Serialize, Deserialize)] +pub struct Microsoft365ResolvedConfig { + pub tenant_id: String, + pub client_id: String, + pub client_secret: Option, + pub auth_flow: String, + pub scopes: Vec, + pub token_cache_encrypted: bool, + pub user_id: String, +} + +impl std::fmt::Debug for Microsoft365ResolvedConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Microsoft365ResolvedConfig") + .field("tenant_id", &self.tenant_id) + .field("client_id", &self.client_id) + .field("client_secret", &self.client_secret.as_ref().map(|_| "***")) + .field("auth_flow", &self.auth_flow) + .field("scopes", &self.scopes) + .field("token_cache_encrypted", &self.token_cache_encrypted) + .field("user_id", &self.user_id) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolved_config_serialization_roundtrip() { + let config = Microsoft365ResolvedConfig { + tenant_id: "test-tenant".into(), + client_id: "test-client".into(), + client_secret: Some("secret".into()), + auth_flow: "client_credentials".into(), + scopes: vec!["https://graph.microsoft.com/.default".into()], + token_cache_encrypted: false, + user_id: "me".into(), + }; + + let json = serde_json::to_string(&config).unwrap(); + let parsed: Microsoft365ResolvedConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.tenant_id, "test-tenant"); + assert_eq!(parsed.client_id, "test-client"); + assert_eq!(parsed.client_secret.as_deref(), Some("secret")); + assert_eq!(parsed.auth_flow, "client_credentials"); + assert_eq!(parsed.scopes.len(), 1); + assert_eq!(parsed.user_id, "me"); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 79c34e279..8f1f73b01 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -50,14 +50,19 @@ pub mod mcp_transport; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod microsoft365; pub mod model_routing_config; pub mod node_tool; +pub mod notion_tool; pub mod pdf_read; +pub mod project_intel; pub mod proxy_config; pub mod pushover; +pub mod report_templates; pub mod schedule; pub mod schema; pub mod screenshot; +pub mod security_ops; pub mod shell; pub mod swarm; pub mod tool_search; @@ -98,16 +103,20 @@ pub use mcp_tool::McpToolWrapper; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use microsoft365::Microsoft365Tool; pub use model_routing_config::ModelRoutingConfigTool; #[allow(unused_imports)] pub use node_tool::NodeTool; +pub use notion_tool::NotionTool; pub use pdf_read::PdfReadTool; +pub use project_intel::ProjectIntelTool; pub use proxy_config::ProxyConfigTool; pub use pushover::PushoverTool; pub use schedule::ScheduleTool; #[allow(unused_imports)] pub use schema::{CleaningStrategy, SchemaCleanr}; pub use screenshot::ScreenshotTool; +pub use security_ops::SecurityOpsTool; pub use shell::ShellTool; pub use swarm::SwarmTool; pub use tool_search::ToolSearchTool; @@ -348,8 +357,36 @@ pub fn all_tools_with_runtime( ))); } - // PDF extraction (feature-gated at compile time via rag-pdf) - tool_arcs.push(Arc::new(PdfReadTool::new(security.clone()))); + // Notion API tool (conditionally registered) + if root_config.notion.enabled { + let notion_api_key = if root_config.notion.api_key.trim().is_empty() { + std::env::var("NOTION_API_KEY").unwrap_or_default() + } else { + root_config.notion.api_key.trim().to_string() + }; + if notion_api_key.trim().is_empty() { + tracing::warn!( + "Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)" + ); + } else { + tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone()))); + } + } + + // Project delivery intelligence + if root_config.project_intel.enabled { + tool_arcs.push(Arc::new(ProjectIntelTool::new( + root_config.project_intel.default_language.clone(), + root_config.project_intel.risk_sensitivity.clone(), + ))); + } + + // MCSS Security Operations + if root_config.security_ops.enabled { + tool_arcs.push(Arc::new(SecurityOpsTool::new( + root_config.security_ops.clone(), + ))); + } // Backup tool (enabled by default) if root_config.backup.enabled { @@ -368,6 +405,9 @@ pub fn all_tools_with_runtime( ))); } + // PDF extraction (feature-gated at compile time via rag-pdf) + tool_arcs.push(Arc::new(PdfReadTool::new(security.clone()))); + // Vision tools are always available tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone()))); tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone()))); @@ -382,6 +422,61 @@ pub fn all_tools_with_runtime( } } + // Microsoft 365 Graph API integration + if root_config.microsoft365.enabled { + let ms_cfg = &root_config.microsoft365; + let tenant_id = ms_cfg + .tenant_id + .as_deref() + .unwrap_or_default() + .trim() + .to_string(); + let client_id = ms_cfg + .client_id + .as_deref() + .unwrap_or_default() + .trim() + .to_string(); + if !tenant_id.is_empty() && !client_id.is_empty() { + // Fail fast: client_credentials flow requires a client_secret at registration time. + if ms_cfg.auth_flow.trim() == "client_credentials" + && ms_cfg + .client_secret + .as_deref() + .map_or(true, |s| s.trim().is_empty()) + { + tracing::error!( + "microsoft365: client_credentials auth_flow requires a non-empty client_secret" + ); + return (boxed_registry_from_arcs(tool_arcs), None); + } + + let resolved = microsoft365::types::Microsoft365ResolvedConfig { + tenant_id, + client_id, + client_secret: ms_cfg.client_secret.clone(), + auth_flow: ms_cfg.auth_flow.clone(), + scopes: ms_cfg.scopes.clone(), + token_cache_encrypted: ms_cfg.token_cache_encrypted, + user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(), + }; + // Store token cache in the config directory (next to config.toml), + // not the workspace directory, to keep bearer tokens out of the + // project tree. + let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir); + match Microsoft365Tool::new(resolved, security.clone(), cache_dir) { + Ok(tool) => tool_arcs.push(Arc::new(tool)), + Err(e) => { + tracing::error!("microsoft365: failed to initialize tool: {e}"); + } + } + } else { + tracing::warn!( + "microsoft365: skipped registration because tenant_id or client_id is empty" + ); + } + } + // Add delegation tool when agents are configured let delegate_fallback_credential = fallback_api_key.and_then(|value| { let trimmed_value = value.trim(); diff --git a/src/tools/notion_tool.rs b/src/tools/notion_tool.rs new file mode 100644 index 000000000..4fb044d89 --- /dev/null +++ b/src/tools/notion_tool.rs @@ -0,0 +1,438 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::{policy::ToolOperation, SecurityPolicy}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +const NOTION_API_BASE: &str = "https://api.notion.com/v1"; +const NOTION_VERSION: &str = "2022-06-28"; +const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30; +/// Maximum number of characters to include from an error response body. +const MAX_ERROR_BODY_CHARS: usize = 500; + +/// Tool for interacting with the Notion API — query databases, read/create/update pages, +/// and search the workspace. Each action is gated by the appropriate security operation +/// (Read for queries, Act for mutations). +pub struct NotionTool { + api_key: String, + http: reqwest::Client, + security: Arc, +} + +impl NotionTool { + /// Create a new Notion tool with the given API key and security policy. + pub fn new(api_key: String, security: Arc) -> Self { + Self { + api_key, + http: reqwest::Client::new(), + security, + } + } + + /// Build the standard Notion API headers (Authorization, version, content-type). + fn headers(&self) -> anyhow::Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", self.api_key) + .parse() + .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?, + ); + headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap()); + headers.insert("Content-Type", "application/json".parse().unwrap()); + Ok(headers) + } + + /// Query a Notion database with an optional filter. + async fn query_database( + &self, + database_id: &str, + filter: Option<&serde_json::Value>, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/databases/{database_id}/query"); + let mut body = json!({}); + if let Some(f) = filter { + body["filter"] = f.clone(); + } + let resp = self + .http + .post(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion query_database failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Read a single Notion page by ID. + async fn read_page(&self, page_id: &str) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let resp = self + .http + .get(&url) + .headers(self.headers()?) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion read_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Create a new Notion page, optionally within a database. + async fn create_page( + &self, + properties: &serde_json::Value, + database_id: Option<&str>, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages"); + let mut body = json!({ "properties": properties }); + if let Some(db_id) = database_id { + body["parent"] = json!({ "database_id": db_id }); + } + let resp = self + .http + .post(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion create_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Update an existing Notion page's properties. + async fn update_page( + &self, + page_id: &str, + properties: &serde_json::Value, + ) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/pages/{page_id}"); + let body = json!({ "properties": properties }); + let resp = self + .http + .patch(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion update_page failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } + + /// Search the Notion workspace by query string. + async fn search(&self, query: &str) -> anyhow::Result { + let url = format!("{NOTION_API_BASE}/search"); + let body = json!({ "query": query }); + let resp = self + .http + .post(&url) + .headers(self.headers()?) + .json(&body) + .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS)) + .send() + .await?; + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS); + anyhow::bail!("Notion search failed ({status}): {truncated}"); + } + resp.json().await.map_err(Into::into) + } +} + +#[async_trait] +impl Tool for NotionTool { + fn name(&self) -> &str { + "notion" + } + + fn description(&self) -> &str { + "Interact with Notion: query databases, read/create/update pages, and search the workspace." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["query_database", "read_page", "create_page", "update_page", "search"], + "description": "The Notion API action to perform" + }, + "database_id": { + "type": "string", + "description": "Database ID (required for query_database, optional for create_page)" + }, + "page_id": { + "type": "string", + "description": "Page ID (required for read_page and update_page)" + }, + "filter": { + "type": "object", + "description": "Notion filter object for query_database" + }, + "properties": { + "type": "object", + "description": "Properties object for create_page and update_page" + }, + "query": { + "type": "string", + "description": "Search query string for the search action" + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = match args.get("action").and_then(|v| v.as_str()) { + Some(a) => a, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: action".into()), + }); + } + }; + + // Enforce granular security: Read for queries, Act for mutations + let operation = match action { + "query_database" | "read_page" | "search" => ToolOperation::Read, + "create_page" | "update_page" => ToolOperation::Act, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search" + )), + }); + } + }; + + if let Err(error) = self.security.enforce_tool_operation(operation, "notion") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + + let result = match action { + "query_database" => { + let database_id = match args.get("database_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("query_database requires database_id parameter".into()), + }); + } + }; + let filter = args.get("filter"); + self.query_database(database_id, filter).await + } + "read_page" => { + let page_id = match args.get("page_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("read_page requires page_id parameter".into()), + }); + } + }; + self.read_page(page_id).await + } + "create_page" => { + let properties = match args.get("properties") { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("create_page requires properties parameter".into()), + }); + } + }; + let database_id = args.get("database_id").and_then(|v| v.as_str()); + self.create_page(properties, database_id).await + } + "update_page" => { + let page_id = match args.get("page_id").and_then(|v| v.as_str()) { + Some(id) => id, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("update_page requires page_id parameter".into()), + }); + } + }; + let properties = match args.get("properties") { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("update_page requires properties parameter".into()), + }); + } + }; + self.update_page(page_id, properties).await + } + "search" => { + let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); + self.search(query).await + } + _ => unreachable!(), // Already handled above + }; + + match result { + Ok(value) => Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::SecurityPolicy; + + fn test_tool() -> NotionTool { + let security = Arc::new(SecurityPolicy::default()); + NotionTool::new("test-key".into(), security) + } + + #[test] + fn tool_name_is_notion() { + let tool = test_tool(); + assert_eq!(tool.name(), "notion"); + } + + #[test] + fn parameters_schema_has_required_action() { + let tool = test_tool(); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.iter().any(|v| v.as_str() == Some("action"))); + } + + #[test] + fn parameters_schema_defines_all_actions() { + let tool = test_tool(); + let schema = tool.parameters_schema(); + let actions = schema["properties"]["action"]["enum"].as_array().unwrap(); + let action_strs: Vec<&str> = actions.iter().filter_map(|v| v.as_str()).collect(); + assert!(action_strs.contains(&"query_database")); + assert!(action_strs.contains(&"read_page")); + assert!(action_strs.contains(&"create_page")); + assert!(action_strs.contains(&"update_page")); + assert!(action_strs.contains(&"search")); + } + + #[tokio::test] + async fn execute_missing_action_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("action")); + } + + #[tokio::test] + async fn execute_unknown_action_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "invalid"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Unknown action")); + } + + #[tokio::test] + async fn execute_query_database_missing_id_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "query_database"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("database_id")); + } + + #[tokio::test] + async fn execute_read_page_missing_id_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "read_page"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("page_id")); + } + + #[tokio::test] + async fn execute_create_page_missing_properties_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "create_page"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("properties")); + } + + #[tokio::test] + async fn execute_update_page_missing_page_id_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "update_page", "properties": {}})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("page_id")); + } + + #[tokio::test] + async fn execute_update_page_missing_properties_returns_error() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "update_page", "page_id": "test-id"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("properties")); + } +} diff --git a/src/tools/project_intel.rs b/src/tools/project_intel.rs new file mode 100644 index 000000000..0e3372eb8 --- /dev/null +++ b/src/tools/project_intel.rs @@ -0,0 +1,750 @@ +//! Project delivery intelligence tool. +//! +//! Provides read-only analysis and generation for project management: +//! status reports, risk detection, client communication drafting, +//! sprint summaries, and effort estimation. + +use super::report_templates; +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::collections::HashMap; +use std::fmt::Write as _; + +/// Project intelligence tool for consulting project management. +/// +/// All actions are read-only analysis/generation; nothing is modified externally. +pub struct ProjectIntelTool { + default_language: String, + risk_sensitivity: RiskSensitivity, +} + +/// Risk detection sensitivity level. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RiskSensitivity { + Low, + Medium, + High, +} + +impl RiskSensitivity { + fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "low" => Self::Low, + "high" => Self::High, + _ => Self::Medium, + } + } + + /// Threshold multiplier: higher sensitivity means lower thresholds. + fn threshold_factor(self) -> f64 { + match self { + Self::Low => 1.5, + Self::Medium => 1.0, + Self::High => 0.5, + } + } +} + +impl ProjectIntelTool { + pub fn new(default_language: String, risk_sensitivity: String) -> Self { + Self { + default_language, + risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity), + } + } + + fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result { + let project_name = args + .get("project_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?; + let period = args + .get("period") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?; + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + let git_log = args + .get("git_log") + .and_then(|v| v.as_str()) + .unwrap_or("No git data provided"); + let jira_summary = args + .get("jira_summary") + .and_then(|v| v.as_str()) + .unwrap_or("No Jira data provided"); + let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or(""); + + let tpl = report_templates::weekly_status_template(lang); + let mut vars = HashMap::new(); + vars.insert("project_name".into(), project_name.to_string()); + vars.insert("period".into(), period.to_string()); + vars.insert("completed".into(), git_log.to_string()); + vars.insert("in_progress".into(), jira_summary.to_string()); + vars.insert("blocked".into(), notes.to_string()); + vars.insert("next_steps".into(), "To be determined".into()); + + let rendered = tpl.render(&vars); + Ok(ToolResult { + success: true, + output: rendered, + error: None, + }) + } + + fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result { + let deadlines = args + .get("deadlines") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let velocity = args + .get("velocity") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let blockers = args + .get("blockers") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + + let mut risks = Vec::new(); + + // Heuristic risk detection based on signals + let factor = self.risk_sensitivity.threshold_factor(); + + if !blockers.is_empty() { + let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count(); + let severity = if (blocker_count as f64) > 3.0 * factor { + "critical" + } else if (blocker_count as f64) > 1.0 * factor { + "high" + } else { + "medium" + }; + risks.push(RiskItem { + title: "Active blockers detected".into(), + severity: severity.into(), + detail: format!("{blocker_count} blocker(s) identified"), + mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(), + }); + } + + if deadlines.to_lowercase().contains("overdue") + || deadlines.to_lowercase().contains("missed") + { + risks.push(RiskItem { + title: "Deadline risk".into(), + severity: "high".into(), + detail: "Overdue or missed deadlines detected in project context".into(), + mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(), + }); + } + + if velocity.to_lowercase().contains("declining") || velocity.to_lowercase().contains("slow") + { + risks.push(RiskItem { + title: "Velocity degradation".into(), + severity: "medium".into(), + detail: "Team velocity is declining or below expectations".into(), + mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(), + }); + } + + if risks.is_empty() { + risks.push(RiskItem { + title: "No significant risks detected".into(), + severity: "low".into(), + detail: "Current project signals within normal parameters".into(), + mitigation: "Continue monitoring".into(), + }); + } + + let tpl = report_templates::risk_register_template(lang); + let risks_text = risks + .iter() + .map(|r| { + format!( + "- [{}] {}: {}", + r.severity.to_uppercase(), + r.title, + r.detail + ) + }) + .collect::>() + .join("\n"); + let mitigations_text = risks + .iter() + .map(|r| format!("- {}: {}", r.title, r.mitigation)) + .collect::>() + .join("\n"); + + let mut vars = HashMap::new(); + vars.insert( + "project_name".into(), + args.get("project_name") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown") + .to_string(), + ); + vars.insert("risks".into(), risks_text); + vars.insert("mitigations".into(), mitigations_text); + + Ok(ToolResult { + success: true, + output: tpl.render(&vars), + error: None, + }) + } + + fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result { + let project_name = args + .get("project_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?; + let audience = args + .get("audience") + .and_then(|v| v.as_str()) + .unwrap_or("client"); + let tone = args + .get("tone") + .and_then(|v| v.as_str()) + .unwrap_or("formal"); + let highlights = args + .get("highlights") + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?; + let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or(""); + + let greeting = match (audience, tone) { + ("client", "casual") => "Hi there,".to_string(), + ("client", _) => "Dear valued partner,".to_string(), + ("internal", "casual") => "Hey team,".to_string(), + ("internal", _) => "Dear team,".to_string(), + (_, "casual") => "Hi,".to_string(), + _ => "Dear reader,".to_string(), + }; + + let closing = match tone { + "casual" => "Cheers", + _ => "Best regards", + }; + + let mut body = format!( + "{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}" + ); + if !concerns.is_empty() { + let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}"); + } + let _ = write!( + body, + "\n\nPlease do not hesitate to reach out with any questions.\n\n{closing}" + ); + + Ok(ToolResult { + success: true, + output: body, + error: None, + }) + } + + fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result { + let sprint_dates = args + .get("sprint_dates") + .and_then(|v| v.as_str()) + .unwrap_or("current sprint"); + let completed = args + .get("completed") + .and_then(|v| v.as_str()) + .unwrap_or("None specified"); + let in_progress = args + .get("in_progress") + .and_then(|v| v.as_str()) + .unwrap_or("None specified"); + let blocked = args + .get("blocked") + .and_then(|v| v.as_str()) + .unwrap_or("None"); + let velocity = args + .get("velocity") + .and_then(|v| v.as_str()) + .unwrap_or("Not calculated"); + let lang = args + .get("language") + .and_then(|v| v.as_str()) + .unwrap_or(&self.default_language); + + let tpl = report_templates::sprint_review_template(lang); + let mut vars = HashMap::new(); + vars.insert("sprint_dates".into(), sprint_dates.to_string()); + vars.insert("completed".into(), completed.to_string()); + vars.insert("in_progress".into(), in_progress.to_string()); + vars.insert("blocked".into(), blocked.to_string()); + vars.insert("velocity".into(), velocity.to_string()); + + Ok(ToolResult { + success: true, + output: tpl.render(&vars), + error: None, + }) + } + + fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result { + let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or(""); + + if tasks.trim().is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("No task descriptions provided".into()), + }); + } + + let mut estimates = Vec::new(); + for line in tasks.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + let (size, rationale) = estimate_task_effort(line); + estimates.push(format!("- **{size}** | {line}\n Rationale: {rationale}")); + } + + let output = format!( + "## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L (1-3d), XL (3-5d), XXL (>5d)_", + estimates.join("\n") + ); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +struct RiskItem { + title: String, + severity: String, + detail: String, + mitigation: String, +} + +/// Heuristic effort estimation from task description text. +fn estimate_task_effort(description: &str) -> (&'static str, &'static str) { + let lower = description.to_lowercase(); + let word_count = description.split_whitespace().count(); + + // Signal-based heuristics + let complexity_signals = [ + "refactor", + "rewrite", + "migrate", + "redesign", + "architecture", + "infrastructure", + ]; + let medium_signals = [ + "implement", + "create", + "build", + "integrate", + "add feature", + "new module", + ]; + let small_signals = [ + "fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config", + ]; + + if complexity_signals.iter().any(|s| lower.contains(s)) { + if word_count > 15 { + return ( + "XXL", + "Large-scope structural change with extensive description", + ); + } + return ("XL", "Structural change requiring significant effort"); + } + + if medium_signals.iter().any(|s| lower.contains(s)) { + if word_count > 12 { + return ("L", "Feature implementation with detailed requirements"); + } + return ("M", "Standard feature implementation"); + } + + if small_signals.iter().any(|s| lower.contains(s)) { + if word_count > 10 { + return ("S", "Small change with additional context"); + } + return ("XS", "Minor targeted change"); + } + + // Fallback: estimate by description length as a proxy for complexity + if word_count > 20 { + ("L", "Complex task inferred from detailed description") + } else if word_count > 10 { + ("M", "Moderate task inferred from description length") + } else { + ("S", "Simple task inferred from brief description") + } +} + +#[async_trait] +impl Tool for ProjectIntelTool { + fn name(&self) -> &str { + "project_intel" + } + + fn description(&self) -> &str { + "Project delivery intelligence: generate status reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"], + "description": "The analysis action to perform" + }, + "project_name": { + "type": "string", + "description": "Project name (for status_report, risk_scan, draft_update)" + }, + "period": { + "type": "string", + "description": "Reporting period: week, sprint, or month (for status_report)" + }, + "language": { + "type": "string", + "description": "Report language: en, de, fr, it (default from config)" + }, + "git_log": { + "type": "string", + "description": "Git log summary text (for status_report)" + }, + "jira_summary": { + "type": "string", + "description": "Jira/issue tracker summary (for status_report)" + }, + "notes": { + "type": "string", + "description": "Additional notes or context" + }, + "deadlines": { + "type": "string", + "description": "Deadline information (for risk_scan)" + }, + "velocity": { + "type": "string", + "description": "Team velocity data (for risk_scan, sprint_summary)" + }, + "blockers": { + "type": "string", + "description": "Current blockers (for risk_scan)" + }, + "audience": { + "type": "string", + "enum": ["client", "internal"], + "description": "Target audience (for draft_update)" + }, + "tone": { + "type": "string", + "enum": ["formal", "casual"], + "description": "Communication tone (for draft_update)" + }, + "highlights": { + "type": "string", + "description": "Key highlights for the update (for draft_update)" + }, + "concerns": { + "type": "string", + "description": "Items requiring attention (for draft_update)" + }, + "sprint_dates": { + "type": "string", + "description": "Sprint date range (for sprint_summary)" + }, + "completed": { + "type": "string", + "description": "Completed items (for sprint_summary)" + }, + "in_progress": { + "type": "string", + "description": "In-progress items (for sprint_summary)" + }, + "blocked": { + "type": "string", + "description": "Blocked items (for sprint_summary)" + }, + "tasks": { + "type": "string", + "description": "Task descriptions, one per line (for effort_estimate)" + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?; + + match action { + "status_report" => self.execute_status_report(&args), + "risk_scan" => self.execute_risk_scan(&args), + "draft_update" => self.execute_draft_update(&args), + "sprint_summary" => self.execute_sprint_summary(&args), + "effort_estimate" => self.execute_effort_estimate(&args), + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{other}'. Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate" + )), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tool() -> ProjectIntelTool { + ProjectIntelTool::new("en".into(), "medium".into()) + } + + #[test] + fn tool_name_and_description() { + let t = tool(); + assert_eq!(t.name(), "project_intel"); + assert!(!t.description().is_empty()); + } + + #[test] + fn parameters_schema_has_action() { + let t = tool(); + let schema = t.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::Value::String("action".into()))); + } + + #[tokio::test] + async fn status_report_renders() { + let t = tool(); + let result = t + .execute(json!({ + "action": "status_report", + "project_name": "TestProject", + "period": "week", + "git_log": "- feat: added login" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("TestProject")); + assert!(result.output.contains("added login")); + } + + #[tokio::test] + async fn risk_scan_detects_blockers() { + let t = tool(); + let result = t + .execute(json!({ + "action": "risk_scan", + "blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("blocker")); + } + + #[tokio::test] + async fn risk_scan_detects_deadline_risk() { + let t = tool(); + let result = t + .execute(json!({ + "action": "risk_scan", + "deadlines": "Sprint deadline overdue by 3 days" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Deadline risk")); + } + + #[tokio::test] + async fn risk_scan_no_signals_returns_low_risk() { + let t = tool(); + let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("No significant risks")); + } + + #[tokio::test] + async fn draft_update_formal_client() { + let t = tool(); + let result = t + .execute(json!({ + "action": "draft_update", + "project_name": "Portal", + "audience": "client", + "tone": "formal", + "highlights": "Phase 1 delivered" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Dear valued partner")); + assert!(result.output.contains("Portal")); + assert!(result.output.contains("Phase 1 delivered")); + } + + #[tokio::test] + async fn draft_update_casual_internal() { + let t = tool(); + let result = t + .execute(json!({ + "action": "draft_update", + "project_name": "ZeroClaw", + "audience": "internal", + "tone": "casual", + "highlights": "Core loop stabilized" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Hey team")); + assert!(result.output.contains("Cheers")); + } + + #[tokio::test] + async fn sprint_summary_renders() { + let t = tool(); + let result = t + .execute(json!({ + "action": "sprint_summary", + "sprint_dates": "2026-03-01 to 2026-03-14", + "completed": "- Login page\n- API endpoints", + "in_progress": "- Dashboard", + "blocked": "- Payment integration" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Login page")); + assert!(result.output.contains("Dashboard")); + } + + #[tokio::test] + async fn effort_estimate_basic() { + let t = tool(); + let result = t + .execute(json!({ + "action": "effort_estimate", + "tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("XS")); + assert!(result.output.contains("Refactor database layer")); + } + + #[tokio::test] + async fn effort_estimate_empty_tasks_fails() { + let t = tool(); + let result = t + .execute(json!({ "action": "effort_estimate", "tasks": "" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("No task descriptions")); + } + + #[tokio::test] + async fn unknown_action_returns_error() { + let t = tool(); + let result = t + .execute(json!({ "action": "invalid_thing" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown action")); + } + + #[tokio::test] + async fn missing_action_returns_error() { + let t = tool(); + let result = t.execute(json!({})).await; + assert!(result.is_err()); + } + + #[test] + fn effort_estimate_heuristics_coverage() { + assert_eq!(estimate_task_effort("Fix typo").0, "XS"); + assert_eq!(estimate_task_effort("Update config values").0, "XS"); + assert_eq!( + estimate_task_effort("Implement new notification system").0, + "M" + ); + assert_eq!( + estimate_task_effort("Refactor the entire authentication module").0, + "XL" + ); + assert_eq!( + estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0, + "XXL" + ); + } + + #[test] + fn risk_sensitivity_threshold_ordering() { + assert!( + RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor() + ); + assert!( + RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor() + ); + } + + #[test] + fn risk_sensitivity_from_str_variants() { + assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low); + assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High); + assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium); + assert_eq!( + RiskSensitivity::from_str("unknown"), + RiskSensitivity::Medium + ); + } + + #[tokio::test] + async fn high_sensitivity_detects_single_blocker_as_high() { + let t = ProjectIntelTool::new("en".into(), "high".into()); + let result = t + .execute(json!({ + "action": "risk_scan", + "blockers": "Single blocker" + })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]")); + } +} diff --git a/src/tools/report_templates.rs b/src/tools/report_templates.rs new file mode 100644 index 000000000..930ecbeff --- /dev/null +++ b/src/tools/report_templates.rs @@ -0,0 +1,582 @@ +//! Report template engine for project delivery intelligence. +//! +//! Provides built-in templates for weekly status, sprint review, risk register, +//! and milestone reports with multi-language support (EN, DE, FR, IT). + +use std::collections::HashMap; +use std::fmt::Write as _; + +/// Supported report output formats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReportFormat { + Markdown, + Html, +} + +/// A named section within a report template. +#[derive(Debug, Clone)] +pub struct TemplateSection { + pub heading: String, + pub body: String, +} + +/// A report template with named sections and variable placeholders. +#[derive(Debug, Clone)] +pub struct ReportTemplate { + pub name: String, + pub sections: Vec, + pub format: ReportFormat, +} + +/// Escape a string for safe inclusion in HTML output. +fn escape_html(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +impl ReportTemplate { + /// Render the template by substituting `{{key}}` placeholders with values. + pub fn render(&self, vars: &HashMap) -> String { + let mut out = String::new(); + for section in &self.sections { + let heading = substitute(§ion.heading, vars); + let body = substitute(§ion.body, vars); + match self.format { + ReportFormat::Markdown => { + let _ = write!(out, "## {heading}\n\n{body}\n\n"); + } + ReportFormat::Html => { + let heading = escape_html(&heading); + let body = escape_html(&body); + let _ = write!(out, "

{heading}

\n

{body}

\n"); + } + } + } + out.trim_end().to_string() + } +} + +/// Single-pass placeholder substitution. +/// +/// Scans `template` left-to-right for `{{key}}` tokens and replaces them with +/// the corresponding value from `vars`. Because the scan is single-pass, +/// values that themselves contain `{{...}}` sequences are emitted literally +/// and never re-expanded, preventing injection of new placeholders. +fn substitute(template: &str, vars: &HashMap) -> String { + let mut result = String::with_capacity(template.len()); + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut i = 0; + + while i < len { + if i + 1 < len && bytes[i] == b'{' && bytes[i + 1] == b'{' { + // Find the closing `}}`. + if let Some(close) = template[i + 2..].find("}}") { + let key = &template[i + 2..i + 2 + close]; + if let Some(value) = vars.get(key) { + result.push_str(value); + } else { + // Unknown placeholder: emit as-is. + result.push_str(&template[i..i + 2 + close + 2]); + } + i += 2 + close + 2; + continue; + } + } + result.push(template.as_bytes()[i] as char); + i += 1; + } + + result +} + +// ── Built-in templates ──────────────────────────────────────────── + +/// Return the built-in weekly status template for the given language. +pub fn weekly_status_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Wochenstatus", + vec![ + TemplateSection { + heading: "Zusammenfassung".into(), + body: "Projekt: {{project_name}} | Zeitraum: {{period}}".into(), + }, + TemplateSection { + heading: "Erledigt".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Bearbeitung".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blockiert".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Naechste Schritte".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + "fr" => ( + "Statut hebdomadaire", + vec![ + TemplateSection { + heading: "Resume".into(), + body: "Projet: {{project_name}} | Periode: {{period}}".into(), + }, + TemplateSection { + heading: "Termine".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "En cours".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloque".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Prochaines etapes".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + "it" => ( + "Stato settimanale", + vec![ + TemplateSection { + heading: "Riepilogo".into(), + body: "Progetto: {{project_name}} | Periodo: {{period}}".into(), + }, + TemplateSection { + heading: "Completato".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In corso".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloccato".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Prossimi passi".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + _ => ( + "Weekly Status", + vec![ + TemplateSection { + heading: "Summary".into(), + body: "Project: {{project_name}} | Period: {{period}}".into(), + }, + TemplateSection { + heading: "Completed".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Progress".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blocked".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Next Steps".into(), + body: "{{next_steps}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in sprint review template for the given language. +pub fn sprint_review_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Sprint-Uebersicht", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Erledigt".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Bearbeitung".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blockiert".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocity".into(), + body: "{{velocity}}".into(), + }, + ], + ), + "fr" => ( + "Revue de sprint", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Termine".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "En cours".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloque".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocite".into(), + body: "{{velocity}}".into(), + }, + ], + ), + "it" => ( + "Revisione sprint", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Completato".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In corso".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Bloccato".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocita".into(), + body: "{{velocity}}".into(), + }, + ], + ), + _ => ( + "Sprint Review", + vec![ + TemplateSection { + heading: "Sprint".into(), + body: "{{sprint_dates}}".into(), + }, + TemplateSection { + heading: "Completed".into(), + body: "{{completed}}".into(), + }, + TemplateSection { + heading: "In Progress".into(), + body: "{{in_progress}}".into(), + }, + TemplateSection { + heading: "Blocked".into(), + body: "{{blocked}}".into(), + }, + TemplateSection { + heading: "Velocity".into(), + body: "{{velocity}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in risk register template for the given language. +pub fn risk_register_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Risikoregister", + vec![ + TemplateSection { + heading: "Projekt".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risiken".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Massnahmen".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + "fr" => ( + "Registre des risques", + vec![ + TemplateSection { + heading: "Projet".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risques".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mesures".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + "it" => ( + "Registro dei rischi", + vec![ + TemplateSection { + heading: "Progetto".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Rischi".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mitigazioni".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + _ => ( + "Risk Register", + vec![ + TemplateSection { + heading: "Project".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Risks".into(), + body: "{{risks}}".into(), + }, + TemplateSection { + heading: "Mitigations".into(), + body: "{{mitigations}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +/// Return the built-in milestone report template for the given language. +pub fn milestone_report_template(lang: &str) -> ReportTemplate { + let (name, sections) = match lang { + "de" => ( + "Meilensteinbericht", + vec![ + TemplateSection { + heading: "Projekt".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Meilensteine".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Status".into(), + body: "{{status}}".into(), + }, + ], + ), + "fr" => ( + "Rapport de jalons", + vec![ + TemplateSection { + heading: "Projet".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Jalons".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Statut".into(), + body: "{{status}}".into(), + }, + ], + ), + "it" => ( + "Report milestone", + vec![ + TemplateSection { + heading: "Progetto".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Milestone".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Stato".into(), + body: "{{status}}".into(), + }, + ], + ), + _ => ( + "Milestone Report", + vec![ + TemplateSection { + heading: "Project".into(), + body: "{{project_name}}".into(), + }, + TemplateSection { + heading: "Milestones".into(), + body: "{{milestones}}".into(), + }, + TemplateSection { + heading: "Status".into(), + body: "{{status}}".into(), + }, + ], + ), + }; + ReportTemplate { + name: name.into(), + sections, + format: ReportFormat::Markdown, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn weekly_status_renders_with_variables() { + let tpl = weekly_status_template("en"); + let mut vars = HashMap::new(); + vars.insert("project_name".into(), "ZeroClaw".into()); + vars.insert("period".into(), "2026-W10".into()); + vars.insert("completed".into(), "- Task A\n- Task B".into()); + vars.insert("in_progress".into(), "- Task C".into()); + vars.insert("blocked".into(), "None".into()); + vars.insert("next_steps".into(), "- Task D".into()); + + let rendered = tpl.render(&vars); + assert!(rendered.contains("Project: ZeroClaw")); + assert!(rendered.contains("Period: 2026-W10")); + assert!(rendered.contains("- Task A")); + assert!(rendered.contains("## Completed")); + } + + #[test] + fn weekly_status_de_renders_german_headings() { + let tpl = weekly_status_template("de"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Zusammenfassung")); + assert!(rendered.contains("## Erledigt")); + } + + #[test] + fn weekly_status_fr_renders_french_headings() { + let tpl = weekly_status_template("fr"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Resume")); + assert!(rendered.contains("## Termine")); + } + + #[test] + fn weekly_status_it_renders_italian_headings() { + let tpl = weekly_status_template("it"); + let vars = HashMap::new(); + let rendered = tpl.render(&vars); + assert!(rendered.contains("## Riepilogo")); + assert!(rendered.contains("## Completato")); + } + + #[test] + fn html_format_renders_tags() { + let mut tpl = weekly_status_template("en"); + tpl.format = ReportFormat::Html; + let mut vars = HashMap::new(); + vars.insert("project_name".into(), "Test".into()); + vars.insert("period".into(), "W1".into()); + vars.insert("completed".into(), "Done".into()); + vars.insert("in_progress".into(), "WIP".into()); + vars.insert("blocked".into(), "None".into()); + vars.insert("next_steps".into(), "Next".into()); + + let rendered = tpl.render(&vars); + assert!(rendered.contains("

Summary

")); + assert!(rendered.contains("

Project: Test | Period: W1

")); + } + + #[test] + fn sprint_review_template_has_velocity_section() { + let tpl = sprint_review_template("en"); + let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect(); + assert!(section_headings.contains(&"Velocity")); + } + + #[test] + fn risk_register_template_has_risk_sections() { + let tpl = risk_register_template("en"); + let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect(); + assert!(section_headings.contains(&"Risks")); + assert!(section_headings.contains(&"Mitigations")); + } + + #[test] + fn milestone_template_all_languages() { + for lang in &["en", "de", "fr", "it"] { + let tpl = milestone_report_template(lang); + assert!(!tpl.name.is_empty()); + assert_eq!(tpl.sections.len(), 3); + } + } + + #[test] + fn substitute_leaves_unknown_placeholders() { + let vars = HashMap::new(); + let result = substitute("Hello {{name}}", &vars); + assert_eq!(result, "Hello {{name}}"); + } + + #[test] + fn substitute_replaces_all_occurrences() { + let mut vars = HashMap::new(); + vars.insert("x".into(), "1".into()); + let result = substitute("{{x}} and {{x}}", &vars); + assert_eq!(result, "1 and 1"); + } +} diff --git a/src/tools/security_ops.rs b/src/tools/security_ops.rs new file mode 100644 index 000000000..92ce18d06 --- /dev/null +++ b/src/tools/security_ops.rs @@ -0,0 +1,659 @@ +//! Security operations tool for managed cybersecurity service (MCSS) workflows. +//! +//! Provides alert triage, incident response playbook execution, vulnerability +//! scan parsing, and security report generation. All actions that modify state +//! enforce human approval gates unless explicitly configured otherwise. + +use async_trait::async_trait; +use serde_json::json; +use std::path::PathBuf; + +use super::traits::{Tool, ToolResult}; +use crate::config::SecurityOpsConfig; +use crate::security::playbook::{ + evaluate_step, load_playbooks, severity_level, Playbook, StepStatus, +}; +use crate::security::vulnerability::{generate_summary, parse_vulnerability_json}; + +/// Security operations tool — triage alerts, run playbooks, parse vulns, generate reports. +pub struct SecurityOpsTool { + config: SecurityOpsConfig, + playbooks: Vec, +} + +impl SecurityOpsTool { + pub fn new(config: SecurityOpsConfig) -> Self { + let playbooks_dir = expand_tilde(&config.playbooks_dir); + let playbooks = load_playbooks(&playbooks_dir); + Self { config, playbooks } + } + + /// Triage an alert: classify severity and recommend response. + fn triage_alert(&self, args: &serde_json::Value) -> anyhow::Result { + let alert = args + .get("alert") + .ok_or_else(|| anyhow::anyhow!("Missing required 'alert' parameter"))?; + + // Extract key fields for classification + let alert_type = alert + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let source = alert + .get("source") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let severity = alert + .get("severity") + .and_then(|v| v.as_str()) + .unwrap_or("medium"); + let description = alert + .get("description") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + // Classify and find matching playbooks + let matching_playbooks: Vec<&Playbook> = self + .playbooks + .iter() + .filter(|pb| { + severity_level(severity) >= severity_level(&pb.severity_filter) + && (pb.name.contains(alert_type) + || alert_type.contains(&pb.name) + || description + .to_lowercase() + .contains(&pb.name.replace('_', " "))) + }) + .collect(); + + let playbook_names: Vec<&str> = + matching_playbooks.iter().map(|p| p.name.as_str()).collect(); + + let output = json!({ + "classification": { + "alert_type": alert_type, + "source": source, + "severity": severity, + "severity_level": severity_level(severity), + "priority": if severity_level(severity) >= 3 { "immediate" } else { "standard" }, + }, + "recommended_playbooks": playbook_names, + "recommended_action": if matching_playbooks.is_empty() { + "Manual investigation required — no matching playbook found" + } else { + "Execute recommended playbook(s)" + }, + "auto_triage": self.config.auto_triage, + }); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } + + /// Execute a playbook step with approval gating. + fn run_playbook(&self, args: &serde_json::Value) -> anyhow::Result { + let playbook_name = args + .get("playbook") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'playbook' parameter"))?; + + let step_index = + usize::try_from(args.get("step").and_then(|v| v.as_u64()).ok_or_else(|| { + anyhow::anyhow!("Missing required 'step' parameter (0-based index)") + })?) + .map_err(|_| anyhow::anyhow!("'step' parameter value too large for this platform"))?; + + let alert_severity = args + .get("alert_severity") + .and_then(|v| v.as_str()) + .unwrap_or("medium"); + + let playbook = self + .playbooks + .iter() + .find(|p| p.name == playbook_name) + .ok_or_else(|| anyhow::anyhow!("Playbook '{}' not found", playbook_name))?; + + let result = evaluate_step( + playbook, + step_index, + alert_severity, + &self.config.max_auto_severity, + self.config.require_approval_for_actions, + ); + + let output = json!({ + "playbook": playbook_name, + "step_index": result.step_index, + "action": result.action, + "status": result.status.to_string(), + "message": result.message, + "requires_manual_approval": result.status == StepStatus::PendingApproval, + }); + + Ok(ToolResult { + success: result.status != StepStatus::Failed, + output: serde_json::to_string_pretty(&output)?, + error: if result.status == StepStatus::Failed { + Some(result.message) + } else { + None + }, + }) + } + + /// Parse vulnerability scan results. + fn parse_vulnerability(&self, args: &serde_json::Value) -> anyhow::Result { + let scan_data = args + .get("scan_data") + .ok_or_else(|| anyhow::anyhow!("Missing required 'scan_data' parameter"))?; + + let json_str = if scan_data.is_string() { + scan_data.as_str().unwrap().to_string() + } else { + serde_json::to_string(scan_data)? + }; + + let report = parse_vulnerability_json(&json_str)?; + let summary = generate_summary(&report); + + let output = json!({ + "scanner": report.scanner, + "scan_date": report.scan_date.to_rfc3339(), + "total_findings": report.findings.len(), + "by_severity": { + "critical": report.findings.iter().filter(|f| f.severity == "critical").count(), + "high": report.findings.iter().filter(|f| f.severity == "high").count(), + "medium": report.findings.iter().filter(|f| f.severity == "medium").count(), + "low": report.findings.iter().filter(|f| f.severity == "low").count(), + }, + "summary": summary, + }); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } + + /// Generate a client-facing security posture report. + fn generate_report(&self, args: &serde_json::Value) -> anyhow::Result { + let client_name = args + .get("client_name") + .and_then(|v| v.as_str()) + .unwrap_or("Client"); + let period = args + .get("period") + .and_then(|v| v.as_str()) + .unwrap_or("current"); + let alert_stats = args.get("alert_stats"); + let vuln_summary = args + .get("vuln_summary") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let report = format!( + "# Security Posture Report — {client_name}\n\ + **Period:** {period}\n\ + **Generated:** {}\n\n\ + ## Executive Summary\n\n\ + This report provides an overview of the security posture for {client_name} \ + during the {period} period.\n\n\ + ## Alert Summary\n\n\ + {}\n\n\ + ## Vulnerability Assessment\n\n\ + {}\n\n\ + ## Recommendations\n\n\ + 1. Address all critical and high-severity findings immediately\n\ + 2. Review and update incident response playbooks quarterly\n\ + 3. Conduct regular vulnerability scans on all internet-facing assets\n\ + 4. Ensure all endpoints have current security patches\n\n\ + ---\n\ + *Report generated by ZeroClaw MCSS Agent*\n", + chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"), + alert_stats + .map(|s| serde_json::to_string_pretty(s).unwrap_or_default()) + .unwrap_or_else(|| "No alert statistics provided.".into()), + if vuln_summary.is_empty() { + "No vulnerability data provided." + } else { + vuln_summary + }, + ); + + Ok(ToolResult { + success: true, + output: report, + error: None, + }) + } + + /// List available playbooks. + fn list_playbooks(&self) -> anyhow::Result { + if self.playbooks.is_empty() { + return Ok(ToolResult { + success: true, + output: "No playbooks available.".into(), + error: None, + }); + } + + let playbook_list: Vec = self + .playbooks + .iter() + .map(|pb| { + json!({ + "name": pb.name, + "description": pb.description, + "steps": pb.steps.len(), + "severity_filter": pb.severity_filter, + "auto_approve_steps": pb.auto_approve_steps, + }) + }) + .collect(); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&playbook_list)?, + error: None, + }) + } + + /// Summarize alert volume, categories, and resolution times. + fn alert_stats(&self, args: &serde_json::Value) -> anyhow::Result { + let alerts = args + .get("alerts") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'alerts' array parameter"))?; + + let total = alerts.len(); + let mut by_severity = std::collections::HashMap::new(); + let mut by_category = std::collections::HashMap::new(); + let mut resolved_count = 0u64; + let mut total_resolution_secs = 0u64; + + for alert in alerts { + let severity = alert + .get("severity") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + *by_severity.entry(severity.to_string()).or_insert(0u64) += 1; + + let category = alert + .get("category") + .and_then(|v| v.as_str()) + .unwrap_or("uncategorized"); + *by_category.entry(category.to_string()).or_insert(0u64) += 1; + + if let Some(resolution_secs) = alert.get("resolution_secs").and_then(|v| v.as_u64()) { + resolved_count += 1; + total_resolution_secs += resolution_secs; + } + } + + let avg_resolution = if resolved_count > 0 { + total_resolution_secs as f64 / resolved_count as f64 + } else { + 0.0 + }; + + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let avg_resolution_secs_u64 = avg_resolution.max(0.0) as u64; + + let output = json!({ + "total_alerts": total, + "resolved": resolved_count, + "unresolved": total as u64 - resolved_count, + "by_severity": by_severity, + "by_category": by_category, + "avg_resolution_secs": avg_resolution, + "avg_resolution_human": format_duration_secs(avg_resolution_secs_u64), + }); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output)?, + error: None, + }) + } +} + +fn format_duration_secs(secs: u64) -> String { + if secs < 60 { + format!("{secs}s") + } else if secs < 3600 { + format!("{}m {}s", secs / 60, secs % 60) + } else { + format!("{}h {}m", secs / 3600, (secs % 3600) / 60) + } +} + +/// Expand ~ to home directory. +fn expand_tilde(path: &str) -> PathBuf { + if let Some(rest) = path.strip_prefix("~/") { + if let Some(user_dirs) = directories::UserDirs::new() { + return user_dirs.home_dir().join(rest); + } + } + PathBuf::from(path) +} + +#[async_trait] +impl Tool for SecurityOpsTool { + fn name(&self) -> &str { + "security_ops" + } + + fn description(&self) -> &str { + "Security operations tool for managed cybersecurity services. Actions: \ + triage_alert (classify/prioritize alerts), run_playbook (execute incident response steps), \ + parse_vulnerability (parse scan results), generate_report (create security posture reports), \ + list_playbooks (list available playbooks), alert_stats (summarize alert metrics)." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "required": ["action"], + "properties": { + "action": { + "type": "string", + "enum": ["triage_alert", "run_playbook", "parse_vulnerability", "generate_report", "list_playbooks", "alert_stats"], + "description": "The security operation to perform" + }, + "alert": { + "type": "object", + "description": "Alert JSON for triage_alert (requires: type, severity; optional: source, description)" + }, + "playbook": { + "type": "string", + "description": "Playbook name for run_playbook" + }, + "step": { + "type": "integer", + "description": "0-based step index for run_playbook" + }, + "alert_severity": { + "type": "string", + "description": "Alert severity context for run_playbook" + }, + "scan_data": { + "description": "Vulnerability scan data (JSON string or object) for parse_vulnerability" + }, + "client_name": { + "type": "string", + "description": "Client name for generate_report" + }, + "period": { + "type": "string", + "description": "Reporting period for generate_report" + }, + "alert_stats": { + "type": "object", + "description": "Alert statistics to include in generate_report" + }, + "vuln_summary": { + "type": "string", + "description": "Vulnerability summary to include in generate_report" + }, + "alerts": { + "type": "array", + "description": "Array of alert objects for alert_stats" + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?; + + match action { + "triage_alert" => self.triage_alert(&args), + "run_playbook" => self.run_playbook(&args), + "parse_vulnerability" => self.parse_vulnerability(&args), + "generate_report" => self.generate_report(&args), + "list_playbooks" => self.list_playbooks(), + "alert_stats" => self.alert_stats(&args), + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{action}'. Valid: triage_alert, run_playbook, \ + parse_vulnerability, generate_report, list_playbooks, alert_stats" + )), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_config() -> SecurityOpsConfig { + SecurityOpsConfig { + enabled: true, + playbooks_dir: "/nonexistent".into(), + auto_triage: false, + require_approval_for_actions: true, + max_auto_severity: "low".into(), + report_output_dir: "/tmp/reports".into(), + siem_integration: None, + } + } + + fn test_tool() -> SecurityOpsTool { + SecurityOpsTool::new(test_config()) + } + + #[test] + fn tool_name_and_schema() { + let tool = test_tool(); + assert_eq!(tool.name(), "security_ops"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + assert!(schema["required"] + .as_array() + .unwrap() + .contains(&json!("action"))); + } + + #[tokio::test] + async fn triage_alert_classifies_severity() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "triage_alert", + "alert": { + "type": "suspicious_login", + "source": "siem", + "severity": "high", + "description": "Multiple failed login attempts followed by successful login" + } + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["classification"]["severity"], "high"); + assert_eq!(output["classification"]["priority"], "immediate"); + // Should match suspicious_login playbook + let playbooks = output["recommended_playbooks"].as_array().unwrap(); + assert!(playbooks.iter().any(|p| p == "suspicious_login")); + } + + #[tokio::test] + async fn triage_alert_missing_alert_param() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "triage_alert"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn run_playbook_requires_approval() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "suspicious_login", + "step": 2, + "alert_severity": "high" + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["status"], "pending_approval"); + assert_eq!(output["requires_manual_approval"], true); + } + + #[tokio::test] + async fn run_playbook_executes_safe_step() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "suspicious_login", + "step": 0, + "alert_severity": "medium" + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["status"], "completed"); + } + + #[tokio::test] + async fn run_playbook_not_found() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "run_playbook", + "playbook": "nonexistent", + "step": 0 + })) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn parse_vulnerability_valid_report() { + let tool = test_tool(); + let scan_data = json!({ + "scan_date": "2025-01-15T10:00:00Z", + "scanner": "nessus", + "findings": [ + { + "cve_id": "CVE-2024-0001", + "cvss_score": 9.8, + "severity": "critical", + "affected_asset": "web-01", + "description": "RCE in web framework", + "remediation": "Upgrade", + "internet_facing": true, + "production": true + } + ] + }); + + let result = tool + .execute(json!({ + "action": "parse_vulnerability", + "scan_data": scan_data + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["total_findings"], 1); + assert_eq!(output["by_severity"]["critical"], 1); + } + + #[tokio::test] + async fn generate_report_produces_markdown() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "generate_report", + "client_name": "ZeroClaw Corp", + "period": "Q1 2025" + })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("ZeroClaw Corp")); + assert!(result.output.contains("Q1 2025")); + assert!(result.output.contains("Security Posture Report")); + } + + #[tokio::test] + async fn list_playbooks_returns_builtins() { + let tool = test_tool(); + let result = tool + .execute(json!({"action": "list_playbooks"})) + .await + .unwrap(); + + assert!(result.success); + let output: Vec = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output.len(), 4); + let names: Vec<&str> = output.iter().map(|p| p["name"].as_str().unwrap()).collect(); + assert!(names.contains(&"suspicious_login")); + assert!(names.contains(&"malware_detected")); + } + + #[tokio::test] + async fn alert_stats_computes_summary() { + let tool = test_tool(); + let result = tool + .execute(json!({ + "action": "alert_stats", + "alerts": [ + {"severity": "critical", "category": "malware", "resolution_secs": 3600}, + {"severity": "high", "category": "phishing", "resolution_secs": 1800}, + {"severity": "medium", "category": "malware"}, + {"severity": "low", "category": "policy_violation", "resolution_secs": 600} + ] + })) + .await + .unwrap(); + + assert!(result.success); + let output: serde_json::Value = serde_json::from_str(&result.output).unwrap(); + assert_eq!(output["total_alerts"], 4); + assert_eq!(output["resolved"], 3); + assert_eq!(output["unresolved"], 1); + assert_eq!(output["by_severity"]["critical"], 1); + assert_eq!(output["by_category"]["malware"], 2); + } + + #[tokio::test] + async fn unknown_action_returns_error() { + let tool = test_tool(); + let result = tool.execute(json!({"action": "bad_action"})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown action")); + } + + #[test] + fn format_duration_secs_readable() { + assert_eq!(format_duration_secs(45), "45s"); + assert_eq!(format_duration_secs(125), "2m 5s"); + assert_eq!(format_duration_secs(3665), "1h 1m"); + } +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 6a852d8cc..52424f8a5 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,6 +2,7 @@ mod cloudflare; mod custom; mod ngrok; mod none; +mod openvpn; mod tailscale; pub use cloudflare::CloudflareTunnel; @@ -9,6 +10,7 @@ pub use custom::CustomTunnel; pub use ngrok::NgrokTunnel; #[allow(unused_imports)] pub use none::NoneTunnel; +pub use openvpn::OpenVpnTunnel; pub use tailscale::TailscaleTunnel; use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig}; @@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { )))) } + "openvpn" => { + let ov = config + .openvpn + .as_ref() + .ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?; + Ok(Some(Box::new(OpenVpnTunnel::new( + ov.config_file.clone(), + ov.auth_file.clone(), + ov.advertise_address.clone(), + ov.connect_timeout_secs, + ov.extra_args.clone(), + )))) + } + "custom" => { let cu = config .custom @@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { )))) } - other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"), + other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"), } } @@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result>> { mod tests { use super::*; use crate::config::schema::{ - CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig, + CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig, + TunnelConfig, }; use tokio::process::Command; @@ -315,6 +332,46 @@ mod tests { assert!(t.public_url().is_none()); } + #[test] + fn factory_openvpn_missing_config_errors() { + let cfg = TunnelConfig { + provider: "openvpn".into(), + ..TunnelConfig::default() + }; + assert_tunnel_err(&cfg, "[tunnel.openvpn]"); + } + + #[test] + fn factory_openvpn_with_config_ok() { + let cfg = TunnelConfig { + provider: "openvpn".into(), + openvpn: Some(OpenVpnTunnelConfig { + config_file: "client.ovpn".into(), + auth_file: None, + advertise_address: None, + connect_timeout_secs: 30, + extra_args: vec![], + }), + ..TunnelConfig::default() + }; + let t = create_tunnel(&cfg).unwrap(); + assert!(t.is_some()); + assert_eq!(t.unwrap().name(), "openvpn"); + } + + #[test] + fn openvpn_tunnel_name() { + let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert_eq!(t.name(), "openvpn"); + assert!(t.public_url().is_none()); + } + + #[tokio::test] + async fn openvpn_health_false_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(!tunnel.health_check().await); + } + #[tokio::test] async fn kill_shared_no_process_is_ok() { let proc = new_shared_process(); diff --git a/src/tunnel/openvpn.rs b/src/tunnel/openvpn.rs new file mode 100644 index 000000000..dd7f72ad7 --- /dev/null +++ b/src/tunnel/openvpn.rs @@ -0,0 +1,254 @@ +use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess}; +use anyhow::{bail, Result}; +use tokio::io::AsyncBufReadExt; +use tokio::process::Command; + +/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection. +/// +/// Requires the `openvpn` binary installed and accessible. On most systems, +/// OpenVPN requires root/administrator privileges to create tun/tap devices. +/// +/// The tunnel exposes the gateway via the VPN network using a configured +/// `advertise_address` (e.g., `"10.8.0.2:42617"`). +pub struct OpenVpnTunnel { + config_file: String, + auth_file: Option, + advertise_address: Option, + connect_timeout_secs: u64, + extra_args: Vec, + proc: SharedProcess, +} + +impl OpenVpnTunnel { + /// Create a new OpenVPN tunnel instance. + /// + /// * `config_file` — path to the `.ovpn` configuration file. + /// * `auth_file` — optional path to a credentials file for `--auth-user-pass`. + /// * `advertise_address` — optional public address to advertise once connected. + /// * `connect_timeout_secs` — seconds to wait for the initialization sequence. + /// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary. + pub fn new( + config_file: String, + auth_file: Option, + advertise_address: Option, + connect_timeout_secs: u64, + extra_args: Vec, + ) -> Self { + Self { + config_file, + auth_file, + advertise_address, + connect_timeout_secs, + extra_args, + proc: new_shared_process(), + } + } + + /// Build the openvpn command arguments. + fn build_args(&self) -> Vec { + let mut args = vec!["--config".to_string(), self.config_file.clone()]; + + if let Some(ref auth) = self.auth_file { + args.push("--auth-user-pass".to_string()); + args.push(auth.clone()); + } + + args.extend(self.extra_args.iter().cloned()); + args + } +} + +#[async_trait::async_trait] +impl Tunnel for OpenVpnTunnel { + fn name(&self) -> &str { + "openvpn" + } + + /// Spawn the `openvpn` process and wait for the "Initialization Sequence + /// Completed" marker on stderr. Returns the public URL on success. + async fn start(&self, local_host: &str, local_port: u16) -> Result { + // Validate config file exists before spawning + if !std::path::Path::new(&self.config_file).exists() { + bail!("OpenVPN config file not found: {}", self.config_file); + } + + let args = self.build_args(); + + let mut child = Command::new("openvpn") + .args(&args) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn()?; + + // Wait for "Initialization Sequence Completed" in stderr + let stderr = child + .stderr + .take() + .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?; + + let mut reader = tokio::io::BufReader::new(stderr).lines(); + let deadline = tokio::time::Instant::now() + + tokio::time::Duration::from_secs(self.connect_timeout_secs); + + let mut connected = false; + while tokio::time::Instant::now() < deadline { + let line = + tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await; + + match line { + Ok(Ok(Some(l))) => { + tracing::debug!("openvpn: {l}"); + if l.contains("Initialization Sequence Completed") { + connected = true; + break; + } + } + Ok(Ok(None)) => { + bail!("OpenVPN process exited before connection was established"); + } + Ok(Err(e)) => { + bail!("Error reading openvpn output: {e}"); + } + Err(_) => { + // Timeout on individual line read, continue waiting + } + } + } + + if !connected { + child.kill().await.ok(); + bail!( + "OpenVPN connection timed out after {}s waiting for initialization", + self.connect_timeout_secs + ); + } + + let public_url = self + .advertise_address + .clone() + .unwrap_or_else(|| format!("http://{local_host}:{local_port}")); + + // Drain stderr in background to prevent OS pipe buffer from filling and + // blocking the openvpn process. + tokio::spawn(async move { + while let Ok(Some(line)) = reader.next_line().await { + tracing::trace!("openvpn: {line}"); + } + }); + + let mut guard = self.proc.lock().await; + *guard = Some(TunnelProcess { + child, + public_url: public_url.clone(), + }); + + Ok(public_url) + } + + /// Kill the openvpn child process and release its resources. + async fn stop(&self) -> Result<()> { + kill_shared(&self.proc).await + } + + /// Return `true` if the openvpn child process is still running. + async fn health_check(&self) -> bool { + let guard = self.proc.lock().await; + guard.as_ref().is_some_and(|tp| tp.child.id().is_some()) + } + + /// Return the public URL if the tunnel has been started. + fn public_url(&self) -> Option { + self.proc + .try_lock() + .ok() + .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructor_stores_fields() { + let tunnel = OpenVpnTunnel::new( + "/etc/openvpn/client.ovpn".into(), + Some("/etc/openvpn/auth.txt".into()), + Some("10.8.0.2:42617".into()), + 45, + vec!["--verb".into(), "3".into()], + ); + assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn"); + assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt")); + assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617")); + assert_eq!(tunnel.connect_timeout_secs, 45); + assert_eq!(tunnel.extra_args, vec!["--verb", "3"]); + } + + #[test] + fn build_args_basic() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + let args = tunnel.build_args(); + assert_eq!(args, vec!["--config", "client.ovpn"]); + } + + #[test] + fn build_args_with_auth_and_extras() { + let tunnel = OpenVpnTunnel::new( + "client.ovpn".into(), + Some("auth.txt".into()), + None, + 30, + vec!["--verb".into(), "5".into()], + ); + let args = tunnel.build_args(); + assert_eq!( + args, + vec![ + "--config", + "client.ovpn", + "--auth-user-pass", + "auth.txt", + "--verb", + "5" + ] + ); + } + + #[test] + fn public_url_is_none_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(tunnel.public_url().is_none()); + } + + #[tokio::test] + async fn health_check_is_false_before_start() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn stop_without_started_process_is_ok() { + let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]); + let result = tunnel.stop().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn start_with_missing_config_file_errors() { + let tunnel = OpenVpnTunnel::new( + "/nonexistent/path/to/client.ovpn".into(), + None, + None, + 30, + vec![], + ); + let result = tunnel.start("127.0.0.1", 8080).await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("config file not found")); + } +}