Merge remote-tracking branch 'origin/master' into work/security-ops

# Conflicts:
#	src/config/mod.rs
#	src/config/schema.rs
#	src/onboard/wizard.rs
#	src/tools/mod.rs
This commit is contained in:
argenis de la rosa
2026-03-16 02:16:55 -04:00
26 changed files with 6417 additions and 52 deletions
+36
View File
@@ -30,6 +30,7 @@ pub mod mattermost;
pub mod nextcloud_talk;
#[cfg(feature = "channel-nostr")]
pub mod nostr;
pub mod notion;
pub mod qq;
pub mod session_store;
pub mod signal;
@@ -62,6 +63,7 @@ pub use mattermost::MattermostChannel;
pub use nextcloud_talk::NextcloudTalkChannel;
#[cfg(feature = "channel-nostr")]
pub use nostr::NostrChannel;
pub use notion::NotionChannel;
pub use qq::QQChannel;
pub use signal::SignalChannel;
pub use slack::SlackChannel;
@@ -2981,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con
channel.name()
);
}
// Notion is a top-level config section, not part of ChannelsConfig
{
let notion_configured =
config.notion.enabled && !config.notion.database_id.trim().is_empty();
println!(" {} Notion", if notion_configured { "" } else { "" });
}
if !cfg!(feature = "channel-matrix") {
println!(
" ️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
@@ -3413,6 +3421,34 @@ fn collect_configured_channels(
});
}
// Notion database poller channel
if config.notion.enabled && !config.notion.database_id.trim().is_empty() {
let notion_api_key = if config.notion.api_key.trim().is_empty() {
std::env::var("NOTION_API_KEY").unwrap_or_default()
} else {
config.notion.api_key.trim().to_string()
};
if notion_api_key.trim().is_empty() {
tracing::warn!(
"Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
);
} else {
channels.push(ConfiguredChannel {
display_name: "Notion",
channel: Arc::new(NotionChannel::new(
notion_api_key,
config.notion.database_id.clone(),
config.notion.poll_interval_secs,
config.notion.status_property.clone(),
config.notion.input_property.clone(),
config.notion.result_property.clone(),
config.notion.max_concurrent,
config.notion.recover_stale,
)),
});
}
}
channels
}
+614
View File
@@ -0,0 +1,614 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use anyhow::{bail, Result};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
const NOTION_VERSION: &str = "2022-06-28";
const MAX_RESULT_LENGTH: usize = 2000;
const MAX_RETRIES: u32 = 3;
const RETRY_BASE_DELAY_MS: u64 = 2000;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;
/// Largest byte index not exceeding `max_bytes` that lands on a UTF-8
/// character boundary of `s`.
///
/// Returns `s.len()` when `max_bytes` covers the whole string, and `0` in
/// the degenerate case (index 0 is always a boundary).
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
    if max_bytes >= s.len() {
        s.len()
    } else {
        // Scan downward for the first valid boundary; 0 always qualifies.
        (0..=max_bytes)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }
}
/// Notion channel — polls a Notion database for pending tasks and writes results back.
///
/// The channel connects to the Notion API, queries a database for rows with a "pending"
/// status, dispatches them as channel messages, and writes results back when processing
/// completes. It supports crash recovery by resetting stale "running" tasks on startup.
pub struct NotionChannel {
    /// Bearer token used in the `Authorization` header for all API calls.
    api_key: String,
    /// Id of the database whose rows act as the task queue.
    database_id: String,
    /// Seconds to sleep between polling cycles in `listen`.
    poll_interval_secs: u64,
    /// Name of the property holding task state ("pending"/"running"/"done").
    status_property: String,
    /// Name of the property holding the task's input text (title or rich_text).
    input_property: String,
    /// Name of the rich-text property results are written to.
    result_property: String,
    /// Upper bound on simultaneously claimed (inflight) tasks.
    max_concurrent: usize,
    /// Detected kind of the status property: "select" or "status";
    /// initialized to "select" and corrected at listen() startup.
    status_type: Arc<RwLock<String>>,
    /// Page ids currently claimed and awaiting a result write-back.
    inflight: Arc<RwLock<HashSet<String>>>,
    /// Shared HTTP client reused across all API calls.
    http: reqwest::Client,
    /// When true, reset stale "running" rows back to "pending" on startup.
    recover_stale: bool,
}
impl NotionChannel {
    /// Create a new Notion channel with the given configuration.
    pub fn new(
        api_key: String,
        database_id: String,
        poll_interval_secs: u64,
        status_property: String,
        input_property: String,
        result_property: String,
        max_concurrent: usize,
        recover_stale: bool,
    ) -> Self {
        Self {
            api_key,
            database_id,
            poll_interval_secs,
            status_property,
            input_property,
            result_property,
            max_concurrent,
            // "select" is the safe default; detect_status_type() corrects this
            // at listen() startup if the database uses a "status" property.
            status_type: Arc::new(RwLock::new("select".to_string())),
            inflight: Arc::new(RwLock::new(HashSet::new())),
            http: reqwest::Client::new(),
            recover_stale,
        }
    }
    /// Build the standard Notion API headers (Authorization, version, content-type).
    fn headers(&self) -> Result<reqwest::header::HeaderMap> {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "Authorization",
            format!("Bearer {}", self.api_key)
                .parse()
                .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
        );
        headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
        headers.insert("Content-Type", "application/json".parse().unwrap());
        Ok(headers)
    }
    /// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx).
    ///
    /// Non-retriable client errors (4xx other than 429) fail immediately with a
    /// truncated copy of the response body for diagnostics. Retries back off
    /// exponentially starting at `RETRY_BASE_DELAY_MS`.
    async fn api_call(
        &self,
        method: reqwest::Method,
        url: &str,
        body: Option<serde_json::Value>,
    ) -> Result<serde_json::Value> {
        let mut last_err = None;
        for attempt in 0..MAX_RETRIES {
            let mut req = self
                .http
                .request(method.clone(), url)
                .headers(self.headers()?);
            if let Some(ref b) = body {
                req = req.json(b);
            }
            match req.send().await {
                Ok(resp) => {
                    let status = resp.status();
                    if status.is_success() {
                        return resp
                            .json()
                            .await
                            .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
                    }
                    let status_code = status.as_u16();
                    // Only retry on 429 (rate limit) or 5xx (server errors)
                    if status_code != 429 && (400..500).contains(&status_code) {
                        let body_text = resp.text().await.unwrap_or_default();
                        let truncated =
                            crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
                        bail!("Notion API error {status_code}: {truncated}");
                    }
                    last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
                }
                Err(e) => {
                    last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
                }
            }
            // Back off only when another attempt remains; sleeping (and logging
            // "retrying") after the final failure would just delay the error.
            if attempt + 1 < MAX_RETRIES {
                let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
                tracing::warn!(
                    "Notion API call failed (attempt {}/{}), retrying in {}ms",
                    attempt + 1,
                    MAX_RETRIES,
                    delay
                );
                tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
            }
        }
        Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
    }
    /// Query the database schema and detect whether Status uses "select" or "status" type.
    async fn detect_status_type(&self) -> Result<String> {
        let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
        let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
        let status_type = resp
            .get("properties")
            .and_then(|p| p.get(&self.status_property))
            .and_then(|s| s.get("type"))
            .and_then(|t| t.as_str())
            // Fall back to "select" when the property is missing or untyped.
            .unwrap_or("select")
            .to_string();
        Ok(status_type)
    }
    /// Query for rows where Status = "pending".
    async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
        let status_type = self.status_type.read().await.clone();
        let filter = build_status_filter(&self.status_property, &status_type, "pending");
        let resp = self
            .api_call(
                reqwest::Method::POST,
                &url,
                Some(serde_json::json!({ "filter": filter })),
            )
            .await?;
        Ok(resp
            .get("results")
            .and_then(|r| r.as_array())
            .cloned()
            .unwrap_or_default())
    }
    /// Atomically claim a task. Returns true if this caller got it.
    ///
    /// Fails when the page is already inflight or the `max_concurrent`
    /// capacity is exhausted.
    async fn claim_task(&self, page_id: &str) -> bool {
        let mut inflight = self.inflight.write().await;
        if inflight.contains(page_id) {
            return false;
        }
        if inflight.len() >= self.max_concurrent {
            return false;
        }
        inflight.insert(page_id.to_string());
        true
    }
    /// Release a task from the inflight set.
    async fn release_task(&self, page_id: &str) {
        let mut inflight = self.inflight.write().await;
        inflight.remove(page_id);
    }
    /// Update a row's status.
    async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
        let status_type = self.status_type.read().await.clone();
        let payload = serde_json::json!({
            "properties": {
                &self.status_property: build_status_payload(&status_type, status_value),
            }
        });
        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
            .await?;
        Ok(())
    }
    /// Write result text to the Result column.
    async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
        let payload = serde_json::json!({
            "properties": {
                &self.result_property: build_rich_text_payload(result_text),
            }
        });
        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
            .await?;
        Ok(())
    }
    /// On startup, reset "running" tasks back to "pending" for crash recovery.
    ///
    /// Note: shares its name with the `recover_stale` gating field; the field
    /// is checked in `listen` before this method runs.
    async fn recover_stale(&self) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
        let status_type = self.status_type.read().await.clone();
        let filter = build_status_filter(&self.status_property, &status_type, "running");
        let resp = self
            .api_call(
                reqwest::Method::POST,
                &url,
                Some(serde_json::json!({ "filter": filter })),
            )
            .await?;
        let stale = resp
            .get("results")
            .and_then(|r| r.as_array())
            .cloned()
            .unwrap_or_default();
        if stale.is_empty() {
            return Ok(());
        }
        tracing::warn!(
            "Found {} stale task(s) in 'running' state, resetting to 'pending'",
            stale.len()
        );
        for task in &stale {
            if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
                let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
                let payload = serde_json::json!({
                    "properties": {
                        &self.status_property: build_status_payload(&status_type, "pending"),
                        &self.result_property: build_rich_text_payload(
                            "Reset: poller restarted while task was running"
                        ),
                    }
                });
                // Log only a short id prefix; cut on a char boundary to stay panic-free.
                let short_id_end = floor_utf8_char_boundary(page_id, 8);
                let short_id = &page_id[..short_id_end];
                if let Err(e) = self
                    .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
                    .await
                {
                    tracing::error!("Could not reset stale task {short_id}: {e}");
                } else {
                    tracing::info!("Reset stale task {short_id} to pending");
                }
            }
        }
        Ok(())
    }
}
#[async_trait]
impl Channel for NotionChannel {
fn name(&self) -> &str {
"notion"
}
async fn send(&self, message: &SendMessage) -> Result<()> {
// recipient is the page_id for Notion
let page_id = &message.recipient;
let status_type = self.status_type.read().await.clone();
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let payload = serde_json::json!({
"properties": {
&self.status_property: build_status_payload(&status_type, "done"),
&self.result_property: build_rich_text_payload(&message.content),
}
});
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
.await?;
self.release_task(page_id).await;
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
// Detect status property type
match self.detect_status_type().await {
Ok(st) => {
tracing::info!("Notion status property type: {st}");
*self.status_type.write().await = st;
}
Err(e) => {
bail!("Failed to detect Notion database schema: {e}");
}
}
// Crash recovery
if self.recover_stale {
if let Err(e) = self.recover_stale().await {
tracing::error!("Notion stale task recovery failed: {e}");
}
}
// Polling loop
loop {
match self.query_pending().await {
Ok(tasks) => {
if !tasks.is_empty() {
tracing::info!("Notion: found {} pending task(s)", tasks.len());
}
for task in tasks {
let page_id = match task.get("id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => continue,
};
let input_text = extract_text_from_property(
task.get("properties")
.and_then(|p| p.get(&self.input_property)),
);
if input_text.trim().is_empty() {
let short_end = floor_utf8_char_boundary(&page_id, 8);
tracing::warn!(
"Notion: empty input for task {}, skipping",
&page_id[..short_end]
);
continue;
}
if !self.claim_task(&page_id).await {
continue;
}
// Set status to running
if let Err(e) = self.set_status(&page_id, "running").await {
tracing::error!("Notion: failed to set running status: {e}");
self.release_task(&page_id).await;
continue;
}
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if tx
.send(ChannelMessage {
id: page_id.clone(),
sender: "notion".into(),
reply_target: page_id,
content: input_text,
channel: "notion".into(),
timestamp,
thread_ts: None,
})
.await
.is_err()
{
tracing::info!("Notion channel shutting down");
return Ok(());
}
}
}
Err(e) => {
tracing::error!("Notion poll error: {e}");
}
}
tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
}
}
async fn health_check(&self) -> bool {
let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
self.api_call(reqwest::Method::GET, &url, None)
.await
.is_ok()
}
}
// ── Helper functions ──────────────────────────────────────────────
/// Build a Notion API filter object for the given status property.
///
/// The filter key mirrors the property kind: `"status"` properties use a
/// `status` clause, everything else falls back to `select`.
fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
    let type_key = if status_type == "status" {
        "status"
    } else {
        "select"
    };
    let mut filter = serde_json::Map::new();
    filter.insert(
        "property".to_string(),
        serde_json::Value::String(property.to_string()),
    );
    filter.insert(type_key.to_string(), serde_json::json!({ "equals": value }));
    serde_json::Value::Object(filter)
}
/// Build a Notion API property-update payload for a status field.
///
/// `"status"` properties take a `status.name` object; any other kind is
/// treated as a `select.name` update.
fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
    match status_type {
        "status" => serde_json::json!({ "status": { "name": value } }),
        _ => serde_json::json!({ "select": { "name": value } }),
    }
}
/// Build a Notion API rich-text property payload, truncating if necessary.
fn build_rich_text_payload(value: &str) -> serde_json::Value {
let truncated = truncate_result(value);
serde_json::json!({
"rich_text": [{
"text": { "content": truncated }
}]
})
}
/// Truncate result text to fit within the Notion rich-text content limit.
///
/// Short values pass through unchanged; long values are cut on a UTF-8 char
/// boundary (leaving headroom for the marker) and suffixed with a truncation
/// notice.
fn truncate_result(value: &str) -> String {
    if value.len() <= MAX_RESULT_LENGTH {
        return value.to_string();
    }
    // Reserve room for the suffix, then snap the cut to a char boundary.
    let end = floor_utf8_char_boundary(value, MAX_RESULT_LENGTH.saturating_sub(30));
    let mut out = String::with_capacity(end + 26);
    out.push_str(&value[..end]);
    out.push_str("\n\n... [output truncated]");
    out
}
/// Extract plain text from a Notion property (title or rich_text type).
///
/// Concatenates the `plain_text` of every fragment in the property's array;
/// missing, untyped, or unsupported properties yield an empty string.
fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
    let Some(prop) = prop else {
        return String::new();
    };
    let array_key = match prop.get("type").and_then(|t| t.as_str()) {
        Some("title") => "title",
        Some("rich_text") => "rich_text",
        _ => return String::new(),
    };
    let mut text = String::new();
    if let Some(items) = prop.get(array_key).and_then(|arr| arr.as_array()) {
        for item in items {
            if let Some(piece) = item.get("plain_text").and_then(|t| t.as_str()) {
                text.push_str(piece);
            }
        }
    }
    text
}
#[cfg(test)]
mod tests {
    //! Unit tests for claim bookkeeping, truncation safety, and Notion
    //! payload/filter helpers. Network-dependent paths are not exercised here.
    use super::*;
    // Claiming the same page twice must fail until it is released.
    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            4,
            false,
        );
        assert!(channel.claim_task("page-1").await);
        // Second claim for same page should fail
        assert!(!channel.claim_task("page-1").await);
        // Different page should succeed
        assert!(channel.claim_task("page-2").await);
        // After release, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }
    #[test]
    fn result_truncation_within_limit() {
        let short = "hello world";
        assert_eq!(truncate_result(short), short);
    }
    #[test]
    fn result_truncation_over_limit() {
        let long = "a".repeat(MAX_RESULT_LENGTH + 100);
        let truncated = truncate_result(&long);
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }
    // Truncation must never split a multibyte character (would panic on slice).
    #[test]
    fn result_truncation_multibyte_safe() {
        // Build a string that would cut in the middle of a multibyte char
        let mut s = String::new();
        for _ in 0..700 {
            s.push('\u{6E2C}'); // 3-byte UTF-8 char
        }
        let truncated = truncate_result(&s);
        // Should not panic and should be valid UTF-8
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }
    #[test]
    fn status_payload_select_type() {
        let payload = build_status_payload("select", "pending");
        assert_eq!(
            payload,
            serde_json::json!({ "select": { "name": "pending" } })
        );
    }
    #[test]
    fn status_payload_status_type() {
        let payload = build_status_payload("status", "done");
        assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } }));
    }
    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert_eq!(text, "test output");
    }
    #[test]
    fn status_filter_select_type() {
        let filter = build_status_filter("Status", "select", "pending");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "select": { "equals": "pending" }
            })
        );
    }
    #[test]
    fn status_filter_status_type() {
        let filter = build_status_filter("Status", "status", "running");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "status": { "equals": "running" }
            })
        );
    }
    // Title fragments are concatenated in order.
    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }
    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }
    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }
    // Unsupported property kinds (e.g. number) yield empty input.
    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }
    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            2, // max_concurrent = 2
            false,
        );
        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        // Third claim should be rejected (at capacity)
        assert!(!channel.claim_task("page-3").await);
        // After releasing one, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}
+10 -9
View File
@@ -12,15 +12,16 @@ pub use schema::{
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig,
McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig,
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig,
WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig,
};
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
+747 -5
View File
@@ -188,6 +188,10 @@ pub struct Config {
#[serde(default)]
pub composio: ComposioConfig,
/// Microsoft 365 Graph API integration (`[microsoft365]`).
#[serde(default)]
pub microsoft365: Microsoft365Config,
/// Secrets encryption configuration (`[secrets]`).
#[serde(default)]
pub secrets: SecretsConfig,
@@ -212,6 +216,10 @@ pub struct Config {
#[serde(default)]
pub web_search: WebSearchConfig,
/// Project delivery intelligence configuration (`[project_intel]`).
#[serde(default)]
pub project_intel: ProjectIntelConfig,
/// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`).
#[serde(default)]
pub proxy: ProxyConfig,
@@ -1348,6 +1356,67 @@ impl Default for GatewayConfig {
}
}
/// Secure transport configuration for inter-node communication (`[node_transport]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NodeTransportConfig {
    /// Enable the secure transport layer.
    #[serde(default = "default_node_transport_enabled")]
    pub enabled: bool,
    /// Shared secret for HMAC authentication between nodes.
    // NOTE(review): defaults to empty — presumably peers are unauthenticated
    // until a secret is configured; confirm enforcement at the call sites.
    #[serde(default)]
    pub shared_secret: String,
    /// Maximum age of signed requests in seconds (replay protection).
    #[serde(default = "default_max_request_age")]
    pub max_request_age_secs: i64,
    /// Require HTTPS for all node communication.
    #[serde(default = "default_require_https")]
    pub require_https: bool,
    /// Allow specific node IPs/CIDRs.
    #[serde(default)]
    pub allowed_peers: Vec<String>,
    /// Path to TLS certificate file.
    #[serde(default)]
    pub tls_cert_path: Option<String>,
    /// Path to TLS private key file.
    #[serde(default)]
    pub tls_key_path: Option<String>,
    /// Require client certificates (mutual TLS).
    #[serde(default)]
    pub mutual_tls: bool,
    /// Maximum number of connections per peer.
    #[serde(default = "default_connection_pool_size")]
    pub connection_pool_size: usize,
}
/// Serde default: the secure transport layer is on unless explicitly disabled.
fn default_node_transport_enabled() -> bool {
    true
}
/// Serde default: signed requests older than 300 seconds are rejected.
fn default_max_request_age() -> i64 {
    300
}
/// Serde default: HTTPS is required for node communication.
fn default_require_https() -> bool {
    true
}
/// Serde default: up to 4 pooled connections per peer.
fn default_connection_pool_size() -> usize {
    4
}
impl Default for NodeTransportConfig {
    /// Mirror the serde field defaults: transport enabled, HTTPS required,
    /// 300s request age window, pool of 4, no secret or TLS material set.
    fn default() -> Self {
        Self {
            shared_secret: String::new(),
            allowed_peers: Vec::new(),
            tls_cert_path: None,
            tls_key_path: None,
            mutual_tls: false,
            enabled: default_node_transport_enabled(),
            max_request_age_secs: default_max_request_age(),
            require_https: default_require_https(),
            connection_pool_size: default_connection_pool_size(),
        }
    }
}
// ── Composio (managed tool surface) ─────────────────────────────
/// Composio managed OAuth tools integration (`[composio]` section).
@@ -1380,6 +1449,78 @@ impl Default for ComposioConfig {
}
}
// ── Microsoft 365 (Graph API integration) ───────────────────────
/// Microsoft 365 integration via Microsoft Graph API (`[microsoft365]` section).
///
/// Provides access to Outlook mail, Teams messages, Calendar events,
/// OneDrive files, and SharePoint search.
// Debug is implemented manually (below) so client_secret prints redacted.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct Microsoft365Config {
    /// Enable Microsoft 365 integration
    #[serde(default, alias = "enable")]
    pub enabled: bool,
    /// Azure AD tenant ID
    #[serde(default)]
    pub tenant_id: Option<String>,
    /// Azure AD application (client) ID
    #[serde(default)]
    pub client_id: Option<String>,
    /// Azure AD client secret (stored encrypted when secrets.encrypt = true)
    #[serde(default)]
    pub client_secret: Option<String>,
    /// Authentication flow: "client_credentials" or "device_code"
    #[serde(default = "default_ms365_auth_flow")]
    pub auth_flow: String,
    /// OAuth scopes to request
    #[serde(default = "default_ms365_scopes")]
    pub scopes: Vec<String>,
    /// Encrypt the token cache file on disk
    #[serde(default = "default_true")]
    pub token_cache_encrypted: bool,
    /// User principal name or "me" (for delegated flows)
    #[serde(default)]
    pub user_id: Option<String>,
}
/// Default auth flow: app-only client-credentials grant.
fn default_ms365_auth_flow() -> String {
    String::from("client_credentials")
}
/// Default scope set: Graph's `.default` scope (all statically consented permissions).
fn default_ms365_scopes() -> Vec<String> {
    vec![String::from("https://graph.microsoft.com/.default")]
}
// Manual Debug so the client secret is never written to logs; all other
// fields are rendered normally.
impl std::fmt::Debug for Microsoft365Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Microsoft365Config")
            .field("enabled", &self.enabled)
            .field("tenant_id", &self.tenant_id)
            .field("client_id", &self.client_id)
            // Redact the secret; only presence/absence is shown.
            .field("client_secret", &self.client_secret.as_ref().map(|_| "***"))
            .field("auth_flow", &self.auth_flow)
            .field("scopes", &self.scopes)
            .field("token_cache_encrypted", &self.token_cache_encrypted)
            .field("user_id", &self.user_id)
            .finish()
    }
}
impl Default for Microsoft365Config {
    /// Disabled by default; auth flow and scopes mirror the serde defaults,
    /// and the token cache is encrypted.
    fn default() -> Self {
        Self {
            enabled: false,
            auth_flow: default_ms365_auth_flow(),
            scopes: default_ms365_scopes(),
            token_cache_encrypted: true,
            tenant_id: None,
            client_id: None,
            client_secret: None,
            user_id: None,
        }
    }
}
// ── Secrets (encrypted credential store) ────────────────────────
/// Secrets encryption configuration (`[secrets]` section).
@@ -1644,6 +1785,64 @@ impl Default for WebSearchConfig {
}
}
// ── Project Intelligence ────────────────────────────────────────
/// Project delivery intelligence configuration (`[project_intel]` section).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProjectIntelConfig {
    /// Enable the project_intel tool. Default: false.
    #[serde(default)]
    pub enabled: bool,
    /// Default report language (en, de, fr, it). Default: "en".
    #[serde(default = "default_project_intel_language")]
    pub default_language: String,
    /// Output directory for generated reports.
    #[serde(default = "default_project_intel_report_dir")]
    pub report_output_dir: String,
    /// Optional custom templates directory.
    #[serde(default)]
    pub templates_dir: Option<String>,
    /// Risk detection sensitivity: low, medium, high. Default: "medium".
    #[serde(default = "default_project_intel_risk_sensitivity")]
    pub risk_sensitivity: String,
    /// Include git log data in reports. Default: true.
    #[serde(default = "default_true")]
    pub include_git_data: bool,
    /// Include Jira data in reports. Default: false.
    #[serde(default)]
    pub include_jira_data: bool,
    /// Jira instance base URL (required if include_jira_data is true).
    // NOTE(review): the "required" constraint is not enforced by this type —
    // confirm it is validated wherever include_jira_data is consumed.
    #[serde(default)]
    pub jira_base_url: Option<String>,
}
/// Serde default: reports are generated in English.
fn default_project_intel_language() -> String {
    String::from("en")
}
/// Serde default: reports land under the user's zeroclaw data directory.
fn default_project_intel_report_dir() -> String {
    String::from("~/.zeroclaw/project-reports")
}
/// Serde default: medium risk-detection sensitivity.
fn default_project_intel_risk_sensitivity() -> String {
    String::from("medium")
}
impl Default for ProjectIntelConfig {
    /// Disabled by default; language/report-dir/sensitivity mirror the serde
    /// defaults, git data is included, Jira data is not.
    fn default() -> Self {
        Self {
            enabled: false,
            include_git_data: true,
            include_jira_data: false,
            templates_dir: None,
            jira_base_url: None,
            default_language: default_project_intel_language(),
            report_output_dir: default_project_intel_report_dir(),
            risk_sensitivity: default_project_intel_risk_sensitivity(),
        }
    }
}
// ── Proxy ───────────────────────────────────────────────────────
/// Proxy application scope — determines which outbound traffic uses the proxy.
@@ -3073,10 +3272,10 @@ impl Default for CronConfig {
/// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section).
///
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`.
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, `"custom"`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TunnelConfig {
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`.
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, or `"custom"`. Default: `"none"`.
pub provider: String,
/// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`).
@@ -3091,6 +3290,10 @@ pub struct TunnelConfig {
#[serde(default)]
pub ngrok: Option<NgrokTunnelConfig>,
/// OpenVPN tunnel configuration (used when `provider = "openvpn"`).
#[serde(default)]
pub openvpn: Option<OpenVpnTunnelConfig>,
/// Custom tunnel command configuration (used when `provider = "custom"`).
#[serde(default)]
pub custom: Option<CustomTunnelConfig>,
@@ -3103,6 +3306,7 @@ impl Default for TunnelConfig {
cloudflare: None,
tailscale: None,
ngrok: None,
openvpn: None,
custom: None,
}
}
@@ -3131,6 +3335,36 @@ pub struct NgrokTunnelConfig {
pub domain: Option<String>,
}
/// OpenVPN tunnel configuration (`[tunnel.openvpn]`).
///
/// Required when `tunnel.provider = "openvpn"`. Omitting this section entirely
/// preserves previous behavior. Setting `tunnel.provider = "none"` (or removing
/// the `[tunnel.openvpn]` block) cleanly reverts to no-tunnel mode.
///
/// Defaults: `connect_timeout_secs = 30`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OpenVpnTunnelConfig {
    /// Path to `.ovpn` configuration file (must not be empty).
    // NOTE(review): non-emptiness is not enforced by this type — confirm it is
    // validated where the tunnel is started.
    pub config_file: String,
    /// Optional path to auth credentials file (`--auth-user-pass`).
    #[serde(default)]
    pub auth_file: Option<String>,
    /// Advertised address once VPN is connected (e.g., `"10.8.0.2:42617"`).
    /// When omitted the tunnel falls back to `http://{local_host}:{local_port}`.
    #[serde(default)]
    pub advertise_address: Option<String>,
    /// Connection timeout in seconds (default: 30, must be > 0).
    #[serde(default = "default_openvpn_timeout")]
    pub connect_timeout_secs: u64,
    /// Extra openvpn CLI arguments forwarded verbatim.
    #[serde(default)]
    pub extra_args: Vec<String>,
}
/// Serde default: 30-second OpenVPN connection timeout.
fn default_openvpn_timeout() -> u64 {
    30
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct CustomTunnelConfig {
/// Command template to start the tunnel. Use {port} and {host} placeholders.
@@ -3910,6 +4144,10 @@ pub struct SecurityConfig {
/// Emergency-stop state machine configuration.
#[serde(default)]
pub estop: EstopConfig,
/// Nevis IAM integration for SSO/MFA authentication and role-based access.
#[serde(default)]
pub nevis: NevisConfig,
}
/// OTP validation strategy.
@@ -4021,6 +4259,163 @@ impl Default for EstopConfig {
}
}
/// Nevis IAM integration configuration.
///
/// When `enabled` is true, ZeroClaw validates incoming requests against a Nevis
/// Security Suite instance and maps Nevis roles to tool/workspace permissions.
// Debug is implemented manually (below) so client_secret prints [REDACTED].
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct NevisConfig {
    /// Enable Nevis IAM integration. Defaults to false for backward compatibility.
    #[serde(default)]
    pub enabled: bool,
    /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`).
    /// Required (non-empty) when `enabled` — see `NevisConfig::validate`.
    #[serde(default)]
    pub instance_url: String,
    /// Nevis realm to authenticate against.
    #[serde(default = "default_nevis_realm")]
    pub realm: String,
    /// OAuth2 client ID registered in Nevis.
    /// Required (non-empty) when `enabled` — see `NevisConfig::validate`.
    #[serde(default)]
    pub client_id: String,
    /// OAuth2 client secret. Encrypted via SecretStore when stored on disk.
    #[serde(default)]
    pub client_secret: Option<String>,
    /// Token validation strategy: `"local"` (JWKS) or `"remote"` (introspection).
    #[serde(default = "default_nevis_token_validation")]
    pub token_validation: String,
    /// JWKS endpoint URL for local token validation.
    /// Required when `token_validation = "local"` — see `NevisConfig::validate`.
    #[serde(default)]
    pub jwks_url: Option<String>,
    /// Nevis role to ZeroClaw permission mappings.
    #[serde(default)]
    pub role_mapping: Vec<NevisRoleMappingConfig>,
    /// Require MFA verification for all Nevis-authenticated requests.
    #[serde(default)]
    pub require_mfa: bool,
    /// Session timeout in seconds.
    #[serde(default = "default_nevis_session_timeout_secs")]
    pub session_timeout_secs: u64,
}
// Manual Debug so the OAuth2 client secret never reaches logs; all other
// fields are rendered normally.
impl std::fmt::Debug for NevisConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NevisConfig")
            .field("enabled", &self.enabled)
            .field("instance_url", &self.instance_url)
            .field("realm", &self.realm)
            .field("client_id", &self.client_id)
            // Redact the secret; only presence/absence is shown.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("token_validation", &self.token_validation)
            .field("jwks_url", &self.jwks_url)
            .field("role_mapping", &self.role_mapping)
            .field("require_mfa", &self.require_mfa)
            .field("session_timeout_secs", &self.session_timeout_secs)
            .finish()
    }
}
impl NevisConfig {
    /// Validate that required fields are present when Nevis is enabled.
    ///
    /// Call at config load time to fail fast on invalid configuration rather
    /// than deferring errors to the first authentication request. Checks run
    /// in order: instance_url, client_id, realm, token_validation, jwks_url,
    /// session timeout.
    pub fn validate(&self) -> Result<(), String> {
        // A disabled integration needs no further validation.
        if !self.enabled {
            return Ok(());
        }
        let required: [(&str, &str); 3] = [
            (
                &self.instance_url,
                "nevis.instance_url is required when Nevis IAM is enabled",
            ),
            (
                &self.client_id,
                "nevis.client_id is required when Nevis IAM is enabled",
            ),
            (
                &self.realm,
                "nevis.realm is required when Nevis IAM is enabled",
            ),
        ];
        for (value, message) in required {
            if value.trim().is_empty() {
                return Err(message.to_string());
            }
        }
        if !matches!(self.token_validation.as_str(), "local" | "remote") {
            return Err(format!(
                "nevis.token_validation has invalid value '{}': expected 'local' or 'remote'",
                self.token_validation
            ));
        }
        if self.token_validation == "local" && self.jwks_url.is_none() {
            return Err("nevis.jwks_url is required when token_validation is 'local'".into());
        }
        if self.session_timeout_secs == 0 {
            return Err("nevis.session_timeout_secs must be greater than 0".into());
        }
        Ok(())
    }
}
/// Serde default for `NevisConfig::realm`.
fn default_nevis_realm() -> String {
    String::from("master")
}
/// Serde default for `NevisConfig::token_validation` (JWKS-based local checks).
fn default_nevis_token_validation() -> String {
    String::from("local")
}
/// Serde default for `NevisConfig::session_timeout_secs`: one hour.
fn default_nevis_session_timeout_secs() -> u64 {
    60 * 60
}
impl Default for NevisConfig {
    /// Disabled configuration whose field values mirror the `#[serde(default)]`
    /// attributes on the struct, so a missing `[security.nevis]` section and an
    /// explicitly empty one deserialize to the same state.
    fn default() -> Self {
        Self {
            enabled: false,
            instance_url: String::new(),
            realm: default_nevis_realm(),
            client_id: String::new(),
            client_secret: None,
            token_validation: default_nevis_token_validation(),
            jwks_url: None,
            role_mapping: Vec::new(),
            require_mfa: false,
            session_timeout_secs: default_nevis_session_timeout_secs(),
        }
    }
}
/// Maps a Nevis role to ZeroClaw tool permissions and workspace access.
///
/// `nevis_role` is the only required field; permission and workspace lists
/// default to empty (i.e. no access) when omitted from the TOML.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct NevisRoleMappingConfig {
    /// Nevis role name (case-insensitive).
    pub nevis_role: String,
    /// Tool names this role can access. Use `"all"` for unrestricted tool access.
    #[serde(default)]
    pub zeroclaw_permissions: Vec<String>,
    /// Workspace names this role can access. Use `"all"` for unrestricted.
    #[serde(default)]
    pub workspace_access: Vec<String>,
}
/// Sandbox configuration for OS-level isolation
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SandboxConfig {
@@ -4348,12 +4743,14 @@ impl Default for Config {
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
microsoft365: Microsoft365Config::default(),
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
proxy: ProxyConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
@@ -4844,6 +5241,11 @@ impl Config {
&mut config.composio.api_key,
"config.composio.api_key",
)?;
decrypt_optional_secret(
&store,
&mut config.microsoft365.client_secret,
"config.microsoft365.client_secret",
)?;
decrypt_optional_secret(
&store,
@@ -5107,6 +5509,18 @@ impl Config {
decrypt_secret(&store, token, "config.gateway.paired_tokens[]")?;
}
// Decrypt Nevis IAM secret
decrypt_optional_secret(
&store,
&mut config.security.nevis.client_secret,
"config.security.nevis.client_secret",
)?;
// Notion API key (top-level, not in ChannelsConfig)
if !config.notion.api_key.is_empty() {
decrypt_secret(&store, &mut config.notion.api_key, "config.notion.api_key")?;
}
config.apply_env_overrides();
config.validate()?;
tracing::info!(
@@ -5242,6 +5656,20 @@ impl Config {
/// Called after TOML deserialization and env-override application to catch
/// obviously invalid values early instead of failing at arbitrary runtime points.
pub fn validate(&self) -> Result<()> {
// Tunnel — OpenVPN
if self.tunnel.provider.trim() == "openvpn" {
let openvpn = self.tunnel.openvpn.as_ref().ok_or_else(|| {
anyhow::anyhow!("tunnel.provider='openvpn' requires [tunnel.openvpn]")
})?;
if openvpn.config_file.trim().is_empty() {
anyhow::bail!("tunnel.openvpn.config_file must not be empty");
}
if openvpn.connect_timeout_secs == 0 {
anyhow::bail!("tunnel.openvpn.connect_timeout_secs must be greater than 0");
}
}
// Gateway
if self.gateway.host.trim().is_empty() {
anyhow::bail!("gateway.host must not be empty");
@@ -5398,6 +5826,88 @@ impl Config {
}
}
// Microsoft 365 — tenant/client/auth-flow sanity checks.
// NOTE: a merge had duplicated this entire block verbatim (the second copy
// ran the same checks with slightly different messages); the redundant copy
// has been removed, keeping the quoted-value messages from the first copy.
if self.microsoft365.enabled {
    let tenant = self
        .microsoft365
        .tenant_id
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty());
    if tenant.is_none() {
        anyhow::bail!(
            "microsoft365.tenant_id must not be empty when microsoft365 is enabled"
        );
    }
    let client = self
        .microsoft365
        .client_id
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty());
    if client.is_none() {
        anyhow::bail!(
            "microsoft365.client_id must not be empty when microsoft365 is enabled"
        );
    }
    let flow = self.microsoft365.auth_flow.trim();
    if flow != "client_credentials" && flow != "device_code" {
        anyhow::bail!(
            "microsoft365.auth_flow must be 'client_credentials' or 'device_code'"
        );
    }
    // client_credentials is the only flow that needs a secret up front.
    if flow == "client_credentials"
        && self
            .microsoft365
            .client_secret
            .as_deref()
            .map_or(true, |s| s.trim().is_empty())
    {
        anyhow::bail!(
            "microsoft365.client_secret must not be empty when auth_flow is 'client_credentials'"
        );
    }
}
// MCP
if self.mcp.enabled {
validate_mcp_config(&self.mcp)?;
@@ -5431,9 +5941,31 @@ impl Config {
// Proxy (delegate to existing validation)
self.proxy.validate()?;
// MCP servers
if self.mcp.enabled {
validate_mcp_config(&self.mcp)?;
// Notion
if self.notion.enabled {
if self.notion.database_id.trim().is_empty() {
anyhow::bail!("notion.database_id must not be empty when notion.enabled = true");
}
if self.notion.poll_interval_secs == 0 {
anyhow::bail!("notion.poll_interval_secs must be greater than 0");
}
if self.notion.max_concurrent == 0 {
anyhow::bail!("notion.max_concurrent must be greater than 0");
}
if self.notion.status_property.trim().is_empty() {
anyhow::bail!("notion.status_property must not be empty");
}
if self.notion.input_property.trim().is_empty() {
anyhow::bail!("notion.input_property must not be empty");
}
if self.notion.result_property.trim().is_empty() {
anyhow::bail!("notion.result_property must not be empty");
}
}
// Nevis IAM — delegate to NevisConfig::validate() for field-level checks
if let Err(msg) = self.security.nevis.validate() {
anyhow::bail!("security.nevis: {msg}");
}
Ok(())
@@ -5802,6 +6334,11 @@ impl Config {
&mut config_to_save.composio.api_key,
"config.composio.api_key",
)?;
encrypt_optional_secret(
&store,
&mut config_to_save.microsoft365.client_secret,
"config.microsoft365.client_secret",
)?;
encrypt_optional_secret(
&store,
@@ -6065,6 +6602,22 @@ impl Config {
encrypt_secret(&store, token, "config.gateway.paired_tokens[]")?;
}
// Encrypt Nevis IAM secret
encrypt_optional_secret(
&store,
&mut config_to_save.security.nevis.client_secret,
"config.security.nevis.client_secret",
)?;
// Notion API key (top-level, not in ChannelsConfig)
if !config_to_save.notion.api_key.is_empty() {
encrypt_secret(
&store,
&mut config_to_save.notion.api_key,
"config.notion.api_key",
)?;
}
let toml_str =
toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?;
@@ -6152,6 +6705,7 @@ impl Config {
}
}
#[allow(clippy::unused_async)] // async needed on unix for tokio File I/O; no-op on other platforms
async fn sync_directory(path: &Path) -> Result<()> {
#[cfg(unix)]
{
@@ -6177,6 +6731,7 @@ mod tests {
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
#[cfg(unix)]
use tempfile::TempDir;
use tokio::sync::{Mutex, MutexGuard};
use tokio::test;
@@ -6494,12 +7049,14 @@ default_temperature = 0.7
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
microsoft365: Microsoft365Config::default(),
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
@@ -6788,12 +7345,14 @@ tool_dispatcher = "xml"
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
microsoft365: Microsoft365Config::default(),
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
@@ -9623,4 +10182,187 @@ require_otp_to_resume = true
assert_eq!(config.swarms.len(), 1);
assert!(config.swarms.contains_key("pipeline"));
}
#[tokio::test]
async fn nevis_client_secret_encrypt_decrypt_roundtrip() {
    // End-to-end secret handling: save() must encrypt nevis.client_secret on
    // disk, and the load path must decrypt it back to the original plaintext.
    // Unique temp dir per run so parallel test invocations cannot collide.
    let dir = std::env::temp_dir().join(format!(
        "zeroclaw_test_nevis_secret_{}",
        uuid::Uuid::new_v4()
    ));
    fs::create_dir_all(&dir).await.unwrap();
    let plaintext_secret = "nevis-test-client-secret-value";
    let mut config = Config::default();
    config.workspace_dir = dir.join("workspace");
    config.config_path = dir.join("config.toml");
    config.security.nevis.client_secret = Some(plaintext_secret.into());
    // Save (triggers encryption)
    config.save().await.unwrap();
    // Read raw TOML and verify plaintext secret is NOT present
    let raw_toml = tokio::fs::read_to_string(&config.config_path)
        .await
        .unwrap();
    assert!(
        !raw_toml.contains(plaintext_secret),
        "Saved TOML must not contain the plaintext client_secret"
    );
    // Parse stored TOML and verify the value is encrypted
    let stored: Config = toml::from_str(&raw_toml).unwrap();
    let stored_secret = stored.security.nevis.client_secret.as_ref().unwrap();
    assert!(
        crate::security::SecretStore::is_encrypted(stored_secret),
        "Stored client_secret must be marked as encrypted"
    );
    // Decrypt and verify it matches the original plaintext
    let store = crate::security::SecretStore::new(&dir, true);
    assert_eq!(store.decrypt(stored_secret).unwrap(), plaintext_secret);
    // Simulate a full load: deserialize then decrypt (mirrors load_or_init logic)
    let mut loaded: Config = toml::from_str(&raw_toml).unwrap();
    loaded.config_path = dir.join("config.toml");
    // Honor the saved `secrets.encrypt` flag, exactly as load_or_init does.
    let load_store = crate::security::SecretStore::new(&dir, loaded.secrets.encrypt);
    decrypt_optional_secret(
        &load_store,
        &mut loaded.security.nevis.client_secret,
        "config.security.nevis.client_secret",
    )
    .unwrap();
    assert_eq!(
        loaded.security.nevis.client_secret.as_deref().unwrap(),
        plaintext_secret,
        "Loaded client_secret must match the original plaintext after decryption"
    );
    // Best-effort cleanup; failure to remove the temp dir is not a test failure.
    let _ = fs::remove_dir_all(&dir).await;
}
// ══════════════════════════════════════════════════════════
// Nevis config validation tests
// ══════════════════════════════════════════════════════════
#[test]
async fn nevis_config_validate_disabled_accepts_empty_fields() {
    // When the integration is disabled, validation short-circuits and
    // empty fields are acceptable.
    let config = NevisConfig::default();
    assert!(!config.enabled);
    assert!(config.validate().is_ok());
}
#[test]
async fn nevis_config_validate_rejects_empty_instance_url() {
    // client_id is valid here so the failure is attributable to instance_url.
    let config = NevisConfig {
        client_id: "test-client".into(),
        instance_url: String::new(),
        enabled: true,
        ..NevisConfig::default()
    };
    let message = config.validate().unwrap_err();
    assert!(message.contains("instance_url"));
}
#[test]
async fn nevis_config_validate_rejects_empty_client_id() {
    // instance_url is valid here so the failure is attributable to client_id.
    let config = NevisConfig {
        instance_url: "https://nevis.example.com".into(),
        client_id: String::new(),
        enabled: true,
        ..NevisConfig::default()
    };
    let message = config.validate().unwrap_err();
    assert!(message.contains("client_id"));
}
#[test]
async fn nevis_config_validate_rejects_empty_realm() {
    // URL and client are valid; an explicitly blank realm must be rejected.
    let config = NevisConfig {
        realm: String::new(),
        client_id: "test-client".into(),
        instance_url: "https://nevis.example.com".into(),
        enabled: true,
        ..NevisConfig::default()
    };
    let message = config.validate().unwrap_err();
    assert!(message.contains("realm"));
}
#[test]
async fn nevis_config_validate_rejects_local_without_jwks() {
    // `local` validation needs a JWKS endpoint to fetch signing keys from.
    let config = NevisConfig {
        token_validation: "local".into(),
        jwks_url: None,
        client_id: "test-client".into(),
        instance_url: "https://nevis.example.com".into(),
        enabled: true,
        ..NevisConfig::default()
    };
    let message = config.validate().unwrap_err();
    assert!(message.contains("jwks_url"));
}
#[test]
async fn nevis_config_validate_rejects_zero_session_timeout() {
    // All other fields valid; a zero timeout must be the only rejection cause.
    let config = NevisConfig {
        session_timeout_secs: 0,
        token_validation: "remote".into(),
        client_id: "test-client".into(),
        instance_url: "https://nevis.example.com".into(),
        enabled: true,
        ..NevisConfig::default()
    };
    let message = config.validate().unwrap_err();
    assert!(message.contains("session_timeout_secs"));
}
#[test]
async fn nevis_config_validate_accepts_valid_enabled_config() {
    // Fully-populated remote-validation config: every check should pass.
    let config = NevisConfig {
        session_timeout_secs: 3600,
        token_validation: "remote".into(),
        client_id: "test-client".into(),
        realm: "master".into(),
        instance_url: "https://nevis.example.com".into(),
        enabled: true,
        ..NevisConfig::default()
    };
    assert!(config.validate().is_ok());
}
#[test]
async fn nevis_config_validate_rejects_invalid_token_validation() {
    // Only "local" and "remote" are accepted strategies.
    let config = NevisConfig {
        token_validation: "invalid_mode".into(),
        session_timeout_secs: 3600,
        client_id: "test-client".into(),
        realm: "master".into(),
        instance_url: "https://nevis.example.com".into(),
        enabled: true,
        ..NevisConfig::default()
    };
    let err = config.validate().unwrap_err();
    assert!(
        err.contains("invalid value 'invalid_mode'"),
        "Expected invalid token_validation error, got: {err}"
    );
}
#[test]
async fn nevis_config_debug_redacts_client_secret() {
    // The manual Debug impl must mask the secret while showing presence.
    let config = NevisConfig {
        client_secret: Some("super-secret".into()),
        ..NevisConfig::default()
    };
    let rendered = format!("{:?}", config);
    assert!(
        !rendered.contains("super-secret"),
        "Debug output must not contain the raw client_secret"
    );
    assert!(
        rendered.contains("[REDACTED]"),
        "Debug output must show [REDACTED] for client_secret"
    );
}
}
+5 -5
View File
@@ -181,7 +181,7 @@ async fn run_agent_job(
let run_result = match job.session_target {
SessionTarget::Main | SessionTarget::Isolated => {
crate::agent::run(
Box::pin(crate::agent::run(
config.clone(),
Some(prefixed_prompt),
None,
@@ -191,7 +191,7 @@ async fn run_agent_job(
false,
None,
job.allowed_tools.clone(),
)
))
.await
}
};
@@ -784,7 +784,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("agent job failed:"));
}
@@ -799,7 +799,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("blocked by security policy"));
assert!(output.contains("read-only"));
@@ -815,7 +815,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("blocked by security policy"));
assert!(output.contains("rate limit exceeded"));
+5 -5
View File
@@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
max_backoff,
move || {
let cfg = channels_cfg.clone();
async move { crate::channels::start_channels(cfg).await }
async move { Box::pin(crate::channels::start_channels(cfg)).await }
},
));
} else {
@@ -245,7 +245,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
// ── Phase 1: LLM decision (two-phase mode) ──────────────
let tasks_to_run = if two_phase {
let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks);
match crate::agent::run(
match Box::pin(crate::agent::run(
config.clone(),
Some(decision_prompt),
None,
@@ -255,7 +255,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
false,
None,
None,
)
))
.await
{
Ok(response) => {
@@ -288,7 +288,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
for task in &tasks_to_run {
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
let temp = config.default_temperature;
match crate::agent::run(
match Box::pin(crate::agent::run(
config.clone(),
Some(prompt),
None,
@@ -298,7 +298,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
false,
None,
None,
)
))
.await
{
Ok(output) => {
+14 -10
View File
@@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
/// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk).
async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result<String> {
let config = state.config.lock().clone();
crate::agent::process_message(config, message).await
Box::pin(crate::agent::process_message(config, message)).await
}
/// Webhook request body
@@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message(
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa
@@ -1346,7 +1346,7 @@ async fn handle_linq_webhook(
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via Linq
if let Err(e) = linq
@@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via WATI
if let Err(e) = wati
@@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook(
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
if let Err(e) = nextcloud_talk
.send(&SendMessage::new(response, &msg.reply_target))
@@ -2492,11 +2492,11 @@ mod tests {
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
};
let response = handle_nextcloud_talk_webhook(
let response = Box::pin(handle_nextcloud_talk_webhook(
State(state),
HeaderMap::new(),
Bytes::from_static(br#"{"type":"message"}"#),
)
))
.await
.into_response();
@@ -2558,9 +2558,13 @@ mod tests {
HeaderValue::from_str(invalid_signature).unwrap(),
);
let response = handle_nextcloud_talk_webhook(State(state), headers, Bytes::from(body))
.await
.into_response();
let response = Box::pin(handle_nextcloud_talk_webhook(
State(state),
headers,
Bytes::from(body),
))
.await
.into_response();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
}
+1
View File
@@ -58,6 +58,7 @@ pub(crate) mod integrations;
pub mod memory;
pub(crate) mod migration;
pub(crate) mod multimodal;
pub mod nodes;
pub mod observability;
pub(crate) mod onboard;
pub mod peripherals;
+5 -5
View File
@@ -844,7 +844,7 @@ async fn main() -> Result<()> {
// Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?;
Box::pin(channels::start_channels(config)).await?;
}
return Ok(());
}
@@ -880,7 +880,7 @@ async fn main() -> Result<()> {
} => {
let final_temperature = temperature.unwrap_or(config.default_temperature);
agent::run(
Box::pin(agent::run(
config,
message,
provider,
@@ -890,7 +890,7 @@ async fn main() -> Result<()> {
true,
session_state_file,
None,
)
))
.await
.map(|_| ())
}
@@ -1189,8 +1189,8 @@ async fn main() -> Result<()> {
},
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await,
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
other => channels::handle_command(other, &config).await,
},
+3
View File
@@ -0,0 +1,3 @@
pub mod transport;
pub use transport::NodeTransport;
+235
View File
@@ -0,0 +1,235 @@
//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication.
//!
//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols,
//! no custom binary framing, no UDP tunneling. This makes the transport
//! compatible with corporate proxies, firewalls, and IT audit expectations.
use anyhow::{bail, Result};
use chrono::Utc;
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// Signs a request payload with HMAC-SHA256.
///
/// The MAC covers `timestamp`, `nonce`, and `payload`. Each variable-length
/// field is prefixed with its length (u64, little-endian) so field boundaries
/// are unambiguous: without framing, moving bytes between the end of `nonce`
/// and the start of `payload` yields the identical MAC input — and therefore
/// the identical signature — for a *different* (nonce, payload) pair, which
/// an attacker could exploit against `verify_request`. This module is new in
/// this commit, so changing the signing format breaks no deployed peers.
///
/// # Errors
/// Returns an error only if the HMAC implementation rejects the key
/// (HMAC-SHA256 accepts keys of any length, so this is effectively never).
pub fn sign_request(
    shared_secret: &str,
    payload: &[u8],
    timestamp: i64,
    nonce: &str,
) -> Result<String> {
    let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes())
        .map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
    mac.update(&timestamp.to_le_bytes());
    // Length-prefix the variable-length fields (domain separation) so the
    // nonce/payload boundary is bound into the MAC.
    mac.update(&(nonce.len() as u64).to_le_bytes());
    mac.update(nonce.as_bytes());
    mac.update(&(payload.len() as u64).to_le_bytes());
    mac.update(payload);
    Ok(hex::encode(mac.finalize().into_bytes()))
}
/// Verify a signed request, rejecting stale timestamps for replay protection.
///
/// Returns `Ok(true)` when the signature matches, `Ok(false)` on a signature
/// mismatch, and `Err` when the timestamp is outside the allowed window.
///
/// NOTE(review): the nonce is covered by the MAC but is never recorded
/// anywhere, so an identical request can still be replayed within
/// `max_age_secs` — confirm whether a seen-nonce cache is needed at the
/// call site.
pub fn verify_request(
    shared_secret: &str,
    payload: &[u8],
    timestamp: i64,
    nonce: &str,
    signature: &str,
    max_age_secs: i64,
) -> Result<bool> {
    let now = Utc::now().timestamp();
    // Reject both stale and far-future timestamps (bounds clock skew too).
    if (now - timestamp).abs() > max_age_secs {
        bail!("Request timestamp too old or too far in future");
    }
    // Recompute the expected signature and compare in constant time.
    let expected = sign_request(shared_secret, payload, timestamp, nonce)?;
    Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes()))
}
/// Constant-time comparison to prevent timing attacks.
///
/// Only the length check can exit early; for equal-length inputs every byte
/// pair is examined regardless of where the first mismatch occurs.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }
    diff == 0
}
// ── Node transport client ───────────────────────────────────────
/// Sends authenticated HTTPS requests to peer nodes.
///
/// Every outgoing request carries three custom headers:
/// - `X-ZeroClaw-Timestamp` — unix epoch seconds
/// - `X-ZeroClaw-Nonce` — random UUID v4
/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest
///
/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`].
pub struct NodeTransport {
    /// Reused HTTPS client; request timeout is configured in `new`.
    http: reqwest::Client,
    /// Symmetric key shared with peer nodes, used for HMAC signing/verifying.
    shared_secret: String,
    /// Maximum accepted age (and future skew) of incoming requests, seconds.
    max_request_age_secs: i64,
}
impl NodeTransport {
    /// Build a transport for the given symmetric `shared_secret`, with a
    /// 30-second per-request timeout and a 5-minute replay window.
    pub fn new(shared_secret: String) -> Self {
        Self {
            http: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(30))
                .build()
                // Builder failure means a broken TLS/runtime setup at startup;
                // panicking here is intentional.
                .expect("HTTP client build"),
            shared_secret,
            max_request_age_secs: 300, // 5 min replay window
        }
    }
    /// Send an authenticated request to a peer node.
    ///
    /// Serializes `payload` as JSON, signs it together with a fresh timestamp
    /// and a random UUID nonce, and POSTs to
    /// `https://{node_address}/api/node-control/{endpoint}`.
    ///
    /// NOTE(review): `node_address` and `endpoint` are interpolated into the
    /// URL verbatim and the URL itself is not covered by the signature —
    /// confirm both values come from trusted configuration, not peer input.
    ///
    /// # Errors
    /// Fails on serialization errors, transport errors, or any non-2xx
    /// response status (the response body is included in the error).
    pub async fn send(
        &self,
        node_address: &str,
        endpoint: &str,
        payload: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let body = serde_json::to_vec(&payload)?;
        let timestamp = Utc::now().timestamp();
        let nonce = uuid::Uuid::new_v4().to_string();
        let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?;
        let url = format!("https://{node_address}/api/node-control/{endpoint}");
        let resp = self
            .http
            .post(&url)
            .header("X-ZeroClaw-Timestamp", timestamp.to_string())
            .header("X-ZeroClaw-Nonce", &nonce)
            .header("X-ZeroClaw-Signature", &signature)
            .header("Content-Type", "application/json")
            .body(body)
            .send()
            .await?;
        if !resp.status().is_success() {
            bail!(
                "Node request failed: {} {}",
                resp.status(),
                resp.text().await.unwrap_or_default()
            );
        }
        Ok(resp.json().await?)
    }
    /// Verify an incoming request from a peer node.
    ///
    /// Parses the timestamp header and delegates to [`verify_request`] with
    /// this transport's shared secret and replay window.
    ///
    /// # Errors
    /// Fails when the timestamp header is not an integer or the timestamp is
    /// outside the replay window; returns `Ok(false)` on signature mismatch.
    pub fn verify_incoming(
        &self,
        payload: &[u8],
        timestamp_header: &str,
        nonce_header: &str,
        signature_header: &str,
    ) -> Result<bool> {
        let timestamp: i64 = timestamp_header
            .parse()
            .map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?;
        verify_request(
            &self.shared_secret,
            payload,
            timestamp,
            nonce_header,
            signature_header,
            self.max_request_age_secs,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Shared secret used by every signing/verification test below.
    const TEST_SECRET: &str = "test-shared-secret-key";
    // Same (secret, payload, timestamp, nonce) must always yield the same hex.
    #[test]
    fn sign_request_deterministic() {
        let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
        let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
        assert_eq!(sig1, sig2, "Same inputs must produce the same signature");
    }
    #[test]
    fn verify_request_accepts_valid_signature() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap();
        let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap();
        assert!(ok, "Valid signature must pass verification");
    }
    #[test]
    fn verify_request_rejects_tampered_payload() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap();
        let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap();
        assert!(!ok, "Tampered payload must fail verification");
    }
    // 600 s old with a 300 s window: must be rejected as Err, not Ok(false).
    #[test]
    fn verify_request_rejects_expired_timestamp() {
        let old = Utc::now().timestamp() - 600;
        let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap();
        let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300);
        assert!(result.is_err(), "Expired timestamp must be rejected");
    }
    #[test]
    fn verify_request_rejects_wrong_secret() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap();
        let ok = verify_request("wrong-secret", b"data", now, "nonce-d", &sig, 300).unwrap();
        assert!(!ok, "Wrong secret must fail verification");
    }
    // Functional correctness only; timing behavior is not asserted here.
    #[test]
    fn constant_time_eq_correctness() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"ab"));
        assert!(!constant_time_eq(b"", b"a"));
        assert!(constant_time_eq(b"", b""));
    }
    #[test]
    fn node_transport_construction() {
        let transport = NodeTransport::new("secret-key".into());
        assert_eq!(transport.max_request_age_secs, 300);
    }
    // End-to-end header path: sign like a peer would, verify via the transport.
    #[test]
    fn node_transport_verify_incoming_valid() {
        let transport = NodeTransport::new(TEST_SECRET.into());
        let now = Utc::now().timestamp();
        let payload = b"test-body";
        let nonce = "incoming-nonce";
        let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap();
        let ok = transport
            .verify_incoming(payload, &now.to_string(), nonce, &sig)
            .unwrap();
        assert!(ok, "Valid incoming request must pass verification");
    }
    #[test]
    fn node_transport_verify_incoming_bad_timestamp_header() {
        let transport = NodeTransport::new(TEST_SECRET.into());
        let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig");
        assert!(result.is_err(), "Non-numeric timestamp header must error");
    }
    // The nonce must be bound into the MAC, not ignored.
    #[test]
    fn sign_request_different_nonce_different_signature() {
        let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap();
        let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap();
        assert_ne!(
            sig1, sig2,
            "Different nonces must produce different signatures"
        );
    }
}
+4
View File
@@ -159,12 +159,14 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
tunnel: tunnel_config,
gateway: crate::config::GatewayConfig::default(),
composio: composio_config,
microsoft365: crate::config::Microsoft365Config::default(),
secrets: secrets_config,
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
@@ -519,12 +521,14 @@ async fn run_quick_setup_with_home(
tunnel: crate::config::TunnelConfig::default(),
gateway: crate::config::GatewayConfig::default(),
composio: ComposioConfig::default(),
microsoft365: crate::config::Microsoft365Config::default(),
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
+449
View File
@@ -0,0 +1,449 @@
//! IAM-aware policy enforcement for Nevis role-to-permission mapping.
//!
//! Evaluates tool and workspace access based on Nevis roles using a
//! deny-by-default policy model. All policy decisions are audit-logged.
use super::nevis::NevisIdentity;
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Maps a single Nevis role to ZeroClaw permissions.
///
/// Compiled into the internal lookup form by `IamPolicy::from_mappings`;
/// matching of the role name and entries is case-insensitive there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleMapping {
    /// Nevis role name (case-insensitive matching).
    pub nevis_role: String,
    /// Tool names this role can access. Use `"all"` to grant all tools.
    pub zeroclaw_permissions: Vec<String>,
    /// Workspace names this role can access. Use `"all"` for unrestricted.
    #[serde(default)]
    pub workspace_access: Vec<String>,
}
/// Result of a policy evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
    /// Access is allowed.
    Allow,
    /// Access is denied, with reason.
    Deny(String),
}
impl PolicyDecision {
    /// True when the decision grants access.
    pub fn is_allowed(&self) -> bool {
        match self {
            PolicyDecision::Allow => true,
            PolicyDecision::Deny(_) => false,
        }
    }
}
/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions.
///
/// Deny-by-default: if no role mapping grants access, the request is denied.
/// An empty policy therefore denies every tool and workspace.
#[derive(Debug, Clone)]
pub struct IamPolicy {
    /// Compiled role mappings indexed by lowercase Nevis role name.
    role_map: HashMap<String, CompiledRole>,
}
/// Normalized, lookup-ready form of a [`RoleMapping`] entry.
#[derive(Debug, Clone)]
struct CompiledRole {
    /// Whether this role has access to all tools ("all" wildcard).
    all_tools: bool,
    /// Specific tool names this role can access (lowercase).
    allowed_tools: Vec<String>,
    /// Whether this role has access to all workspaces ("all" wildcard).
    all_workspaces: bool,
    /// Specific workspace names this role can access (lowercase).
    allowed_workspaces: Vec<String>,
}
impl IamPolicy {
    /// Build a policy from role mappings (typically from config).
    ///
    /// Role names, tool names, and workspace names are trimmed and lowercased
    /// so matching is whitespace- and case-insensitive — including the special
    /// `"all"` wildcard grant (previously `" all "` was stored as a literal
    /// tool named `all` instead of granting everything).
    ///
    /// # Errors
    /// Returns an error if duplicate normalized role names are detected,
    /// since silent last-wins overwrites can accidentally broaden or revoke access.
    pub fn from_mappings(mappings: &[RoleMapping]) -> Result<Self> {
        let mut role_map = HashMap::new();
        for mapping in mappings {
            let key = mapping.nevis_role.trim().to_ascii_lowercase();
            if key.is_empty() {
                // Blank role names can never match an identity role; skip.
                continue;
            }
            // Check duplicates before compiling grants — no point doing the
            // work for an entry we are about to reject.
            if role_map.contains_key(&key) {
                bail!(
                    "IAM policy: duplicate role mapping for normalized key '{}' \
                    (from nevis_role '{}') — remove or merge the duplicate entry",
                    key,
                    mapping.nevis_role
                );
            }
            let (all_tools, allowed_tools) = Self::compile_grants(&mapping.zeroclaw_permissions);
            let (all_workspaces, allowed_workspaces) =
                Self::compile_grants(&mapping.workspace_access);
            role_map.insert(
                key,
                CompiledRole {
                    all_tools,
                    allowed_tools,
                    all_workspaces,
                    allowed_workspaces,
                },
            );
        }
        Ok(Self { role_map })
    }
    /// Normalize one grant list: trim + lowercase entries, drop blanks, and
    /// detect the wildcard `"all"` entry.
    ///
    /// Returns `(has_all, specific_entries)`. Trimming before the `"all"`
    /// comparison keeps wildcard detection consistent with how role keys are
    /// normalized.
    fn compile_grants(entries: &[String]) -> (bool, Vec<String>) {
        let mut has_all = false;
        let mut allowed = Vec::new();
        for entry in entries {
            let normalized = entry.trim().to_ascii_lowercase();
            if normalized.is_empty() {
                continue;
            }
            if normalized == "all" {
                has_all = true;
            } else {
                allowed.push(normalized);
            }
        }
        (has_all, allowed)
    }
    /// Evaluate whether an identity is allowed to use a specific tool.
    ///
    /// Deny-by-default: returns `Deny` unless at least one of the identity's
    /// roles grants access to the requested tool.
    pub fn evaluate_tool_access(
        &self,
        identity: &NevisIdentity,
        tool_name: &str,
    ) -> PolicyDecision {
        let normalized_tool = tool_name.trim().to_ascii_lowercase();
        if normalized_tool.is_empty() {
            return PolicyDecision::Deny("empty tool name".into());
        }
        for role in &identity.roles {
            let key = role.trim().to_ascii_lowercase();
            if let Some(compiled) = self.role_map.get(&key) {
                if compiled.all_tools
                    || compiled.allowed_tools.iter().any(|t| t == &normalized_tool)
                {
                    tracing::info!(
                        user_id = %crate::security::redact(&identity.user_id),
                        role = %key,
                        tool = %normalized_tool,
                        "IAM policy: tool access ALLOWED"
                    );
                    return PolicyDecision::Allow;
                }
            }
        }
        let reason = format!(
            "no role grants access to tool '{normalized_tool}' for user '{}'",
            crate::security::redact(&identity.user_id)
        );
        tracing::info!(
            user_id = %crate::security::redact(&identity.user_id),
            tool = %normalized_tool,
            "IAM policy: tool access DENIED"
        );
        PolicyDecision::Deny(reason)
    }
    /// Evaluate whether an identity is allowed to access a specific workspace.
    ///
    /// Deny-by-default: returns `Deny` unless at least one of the identity's
    /// roles grants access to the requested workspace.
    pub fn evaluate_workspace_access(
        &self,
        identity: &NevisIdentity,
        workspace: &str,
    ) -> PolicyDecision {
        let normalized_ws = workspace.trim().to_ascii_lowercase();
        if normalized_ws.is_empty() {
            return PolicyDecision::Deny("empty workspace name".into());
        }
        for role in &identity.roles {
            let key = role.trim().to_ascii_lowercase();
            if let Some(compiled) = self.role_map.get(&key) {
                if compiled.all_workspaces
                    || compiled
                        .allowed_workspaces
                        .iter()
                        .any(|w| w == &normalized_ws)
                {
                    tracing::info!(
                        user_id = %crate::security::redact(&identity.user_id),
                        role = %key,
                        workspace = %normalized_ws,
                        "IAM policy: workspace access ALLOWED"
                    );
                    return PolicyDecision::Allow;
                }
            }
        }
        let reason = format!(
            "no role grants access to workspace '{normalized_ws}' for user '{}'",
            crate::security::redact(&identity.user_id)
        );
        tracing::info!(
            user_id = %crate::security::redact(&identity.user_id),
            workspace = %normalized_ws,
            "IAM policy: workspace access DENIED"
        );
        PolicyDecision::Deny(reason)
    }
    /// Check if the policy has any role mappings configured.
    pub fn is_empty(&self) -> bool {
        self.role_map.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Shared fixture: admin = all tools/workspaces, operator = scoped write
    // access to production/staging, viewer = read-only tools on staging.
    fn test_mappings() -> Vec<RoleMapping> {
        vec![
            RoleMapping {
                nevis_role: "admin".into(),
                zeroclaw_permissions: vec!["all".into()],
                workspace_access: vec!["all".into()],
            },
            RoleMapping {
                nevis_role: "operator".into(),
                zeroclaw_permissions: vec![
                    "shell".into(),
                    "file_read".into(),
                    "file_write".into(),
                    "memory_search".into(),
                ],
                workspace_access: vec!["production".into(), "staging".into()],
            },
            RoleMapping {
                nevis_role: "viewer".into(),
                zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()],
                workspace_access: vec!["staging".into()],
            },
        ]
    }
    // Identity helper with sane defaults: MFA verified, never-expiring session.
    fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity {
        NevisIdentity {
            user_id: "zeroclaw_user".into(),
            roles: roles.into_iter().map(String::from).collect(),
            scopes: vec!["openid".into()],
            mfa_verified: true,
            session_expiry: u64::MAX,
        }
    }
    #[test]
    fn admin_gets_all_tools() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);
        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "any_tool_name")
            .is_allowed());
    }
    #[test]
    fn admin_gets_all_workspaces() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);
        assert!(policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
        assert!(policy
            .evaluate_workspace_access(&identity, "any_workspace")
            .is_allowed());
    }
    #[test]
    fn operator_gets_subset_of_tools() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);
        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(!policy
            .evaluate_tool_access(&identity, "browser")
            .is_allowed());
    }
    #[test]
    fn operator_workspace_access_is_scoped() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);
        assert!(policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
        assert!(policy
            .evaluate_workspace_access(&identity, "staging")
            .is_allowed());
        assert!(!policy
            .evaluate_workspace_access(&identity, "development")
            .is_allowed());
    }
    #[test]
    fn viewer_is_read_only() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer"]);
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "memory_search")
            .is_allowed());
        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(!policy
            .evaluate_tool_access(&identity, "file_write")
            .is_allowed());
    }
    #[test]
    fn deny_by_default_for_unknown_role() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["unknown_role"]);
        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(!policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
    }
    #[test]
    fn deny_by_default_for_no_roles() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec![]);
        assert!(!policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
    }
    #[test]
    fn multiple_roles_union_permissions() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer", "operator"]);
        // viewer has file_read, operator has shell — both should be accessible
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }
    #[test]
    fn role_matching_is_case_insensitive() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["ADMIN"]);
        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }
    #[test]
    fn tool_matching_is_case_insensitive() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);
        assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "File_Read")
            .is_allowed());
    }
    #[test]
    fn empty_tool_name_is_denied() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);
        assert!(!policy.evaluate_tool_access(&identity, "").is_allowed());
        assert!(!policy.evaluate_tool_access(&identity, "   ").is_allowed());
    }
    #[test]
    fn empty_workspace_name_is_denied() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);
        assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed());
    }
    #[test]
    fn empty_mappings_deny_everything() {
        let policy = IamPolicy::from_mappings(&[]).unwrap();
        let identity = identity_with_roles(vec!["admin"]);
        assert!(policy.is_empty());
        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }
    #[test]
    fn policy_decision_deny_contains_reason() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer"]);
        let decision = policy.evaluate_tool_access(&identity, "shell");
        match decision {
            PolicyDecision::Deny(reason) => {
                assert!(reason.contains("shell"));
            }
            PolicyDecision::Allow => panic!("expected deny"),
        }
    }
    // Duplicate detection is on the NORMALIZED key, so " ADMIN " collides
    // with "admin".
    #[test]
    fn duplicate_normalized_roles_are_rejected() {
        let mappings = vec![
            RoleMapping {
                nevis_role: "admin".into(),
                zeroclaw_permissions: vec!["all".into()],
                workspace_access: vec!["all".into()],
            },
            RoleMapping {
                nevis_role: "  ADMIN  ".into(),
                zeroclaw_permissions: vec!["file_read".into()],
                workspace_access: vec![],
            },
        ];
        let err = IamPolicy::from_mappings(&mappings).unwrap_err();
        assert!(
            err.to_string().contains("duplicate role mapping"),
            "Expected duplicate role error, got: {err}"
        );
    }
    #[test]
    fn empty_role_name_in_mapping_is_skipped() {
        let mappings = vec![RoleMapping {
            nevis_role: "   ".into(),
            zeroclaw_permissions: vec!["all".into()],
            workspace_access: vec![],
        }];
        let policy = IamPolicy::from_mappings(&mappings).unwrap();
        assert!(policy.is_empty());
    }
}
+13 -9
View File
@@ -29,9 +29,11 @@ pub mod domain_matcher;
pub mod estop;
#[cfg(target_os = "linux")]
pub mod firejail;
pub mod iam_policy;
#[cfg(feature = "sandbox-landlock")]
pub mod landlock;
pub mod leak_detector;
pub mod nevis;
pub mod otp;
pub mod pairing;
pub mod playbook;
@@ -58,6 +60,11 @@ pub use policy::{AutonomyLevel, SecurityPolicy};
pub use secrets::SecretStore;
#[allow(unused_imports)]
pub use traits::{NoopSandbox, Sandbox};
// Nevis IAM integration
#[allow(unused_imports)]
pub use iam_policy::{IamPolicy, PolicyDecision};
#[allow(unused_imports)]
pub use nevis::{NevisAuthProvider, NevisIdentity};
// Prompt injection defense exports
#[allow(unused_imports)]
pub use leak_detector::{LeakDetector, LeakResult};
@@ -66,19 +73,16 @@ pub use prompt_guard::{GuardAction, GuardResult, PromptGuard};
#[allow(unused_imports)]
pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary};
/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix.
/// Redact sensitive values for safe logging. Shows first 4 characters + "***" suffix.
/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings.
/// This function intentionally breaks the data-flow taint chain for static analysis.
pub fn redact(value: &str) -> String {
if value.len() <= 4 {
let char_count = value.chars().count();
if char_count <= 4 {
"***".to_string()
} else {
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
let prefix = value
.char_indices()
.nth(4)
.map(|(byte_idx, _)| &value[..byte_idx])
.unwrap_or(value);
format!("{}***", prefix)
let prefix: String = value.chars().take(4).collect();
format!("{prefix}***")
}
}
+587
View File
@@ -0,0 +1,587 @@
//! Nevis IAM authentication provider for ZeroClaw.
//!
//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token
//! validation, FIDO2/passkey verification, and session management. Maps Nevis
//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`].
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Identity resolved from a validated Nevis token or session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NevisIdentity {
    /// Unique user identifier from Nevis.
    pub user_id: String,
    /// Nevis roles assigned to this user (sorted and deduplicated by the
    /// provider before this struct is built).
    pub roles: Vec<String>,
    /// OAuth2 scopes granted to this session.
    pub scopes: Vec<String>,
    /// Whether the user completed MFA (FIDO2/passkey/OTP) in this session.
    pub mfa_verified: bool,
    /// When this session expires (seconds since UNIX epoch). `0` means the
    /// token carried no expiry claim and the expiry check is skipped.
    pub session_expiry: u64,
}
/// Token validation strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationMode {
    /// Validate JWT locally using cached JWKS keys.
    /// (Local validation is currently unimplemented and errors explicitly.)
    Local,
    /// Validate token by calling the Nevis introspection endpoint.
    Remote,
}
impl TokenValidationMode {
    /// Parse the `token_validation` config value (case-insensitive).
    ///
    /// Accepts only `"local"` or `"remote"`; anything else is a config error.
    pub fn from_str_config(s: &str) -> Result<Self> {
        let normalized = s.to_ascii_lowercase();
        if normalized == "local" {
            return Ok(Self::Local);
        }
        if normalized == "remote" {
            return Ok(Self::Remote);
        }
        let other = normalized.as_str();
        bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'")
    }
}
/// Authentication provider backed by a Nevis instance.
///
/// Validates tokens, manages sessions, and resolves identities. The provider
/// is designed to be shared across concurrent requests (`Send + Sync`).
pub struct NevisAuthProvider {
    /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`).
    instance_url: String,
    /// Nevis realm to authenticate against.
    realm: String,
    /// OAuth2 client ID registered in Nevis.
    client_id: String,
    /// OAuth2 client secret (decrypted at startup). Never printed: the
    /// manual `Debug` impl shows `[REDACTED]` instead.
    client_secret: Option<String>,
    /// Token validation strategy.
    validation_mode: TokenValidationMode,
    /// JWKS endpoint for local token validation.
    jwks_url: Option<String>,
    /// Whether MFA is required for all authentications.
    require_mfa: bool,
    /// Session timeout duration (used to derive expiry for userinfo sessions).
    session_timeout: Duration,
    /// HTTP client for Nevis API calls.
    http_client: reqwest::Client,
}
// Manual Debug so the client secret can never leak into logs or panics.
impl std::fmt::Debug for NevisAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Preserve Some/None-ness but substitute a fixed marker for the value.
        let secret_marker = self.client_secret.as_ref().map(|_| "[REDACTED]");
        let mut dbg = f.debug_struct("NevisAuthProvider");
        dbg.field("instance_url", &self.instance_url);
        dbg.field("realm", &self.realm);
        dbg.field("client_id", &self.client_id);
        dbg.field("client_secret", &secret_marker);
        dbg.field("validation_mode", &self.validation_mode);
        dbg.field("jwks_url", &self.jwks_url);
        dbg.field("require_mfa", &self.require_mfa);
        dbg.field("session_timeout", &self.session_timeout);
        dbg.finish_non_exhaustive()
    }
}
// Safety: All fields are Send + Sync. The doc comment promises concurrent use,
// so enforce it at compile time to prevent regressions.
#[allow(clippy::used_underscore_items)]
const _: () = {
    fn _assert_send_sync<T: Send + Sync>() {}
    // Never called at runtime — instantiating the generic is what performs
    // the compile-time check.
    fn _assert() {
        _assert_send_sync::<NevisAuthProvider>();
    }
};
impl NevisAuthProvider {
/// Create a new Nevis auth provider from config values.
///
/// `client_secret` should already be decrypted by the config loader.
pub fn new(
instance_url: String,
realm: String,
client_id: String,
client_secret: Option<String>,
token_validation: &str,
jwks_url: Option<String>,
require_mfa: bool,
session_timeout_secs: u64,
) -> Result<Self> {
let validation_mode = TokenValidationMode::from_str_config(token_validation)?;
if validation_mode == TokenValidationMode::Local && jwks_url.is_none() {
bail!(
"Nevis token_validation is 'local' but no jwks_url is configured. \
Either set jwks_url or use token_validation = 'remote'."
);
}
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.context("Failed to create HTTP client for Nevis")?;
Ok(Self {
instance_url,
realm,
client_id,
client_secret,
validation_mode,
jwks_url,
require_mfa,
session_timeout: Duration::from_secs(session_timeout_secs),
http_client,
})
}
/// Validate a bearer token and resolve the caller's identity.
///
/// Returns `NevisIdentity` on success, or an error if the token is invalid,
/// expired, or MFA requirements are not met.
pub async fn validate_token(&self, token: &str) -> Result<NevisIdentity> {
if token.is_empty() {
bail!("empty bearer token");
}
let identity = match self.validation_mode {
TokenValidationMode::Local => self.validate_token_local(token).await?,
TokenValidationMode::Remote => self.validate_token_remote(token).await?,
};
if self.require_mfa && !identity.mfa_verified {
bail!(
"MFA is required but user '{}' has not completed MFA verification",
crate::security::redact(&identity.user_id)
);
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if identity.session_expiry > 0 && identity.session_expiry < now {
bail!("Nevis session expired");
}
Ok(identity)
}
/// Validate token by calling the Nevis introspection endpoint.
async fn validate_token_remote(&self, token: &str) -> Result<NevisIdentity> {
let introspect_url = format!(
"{}/auth/realms/{}/protocol/openid-connect/token/introspect",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let mut form = vec![("token", token), ("client_id", &self.client_id)];
// client_secret is optional (public clients don't need it)
let secret_ref;
if let Some(ref secret) = self.client_secret {
secret_ref = secret.as_str();
form.push(("client_secret", secret_ref));
}
let resp = self
.http_client
.post(&introspect_url)
.form(&form)
.send()
.await
.context("Failed to reach Nevis introspection endpoint")?;
if !resp.status().is_success() {
bail!(
"Nevis introspection returned HTTP {}",
resp.status().as_u16()
);
}
let body: IntrospectionResponse = resp
.json()
.await
.context("Failed to parse Nevis introspection response")?;
if !body.active {
bail!("Token is not active (revoked or expired)");
}
let user_id = body
.sub
.filter(|s| !s.trim().is_empty())
.context("Token has missing or empty `sub` claim")?;
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
roles.sort();
roles.dedup();
Ok(NevisIdentity {
user_id,
roles,
scopes: body
.scope
.unwrap_or_default()
.split_whitespace()
.map(String::from)
.collect(),
mfa_verified: body.acr.as_deref() == Some("mfa")
|| body
.amr
.iter()
.flatten()
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
session_expiry: body.exp.unwrap_or(0),
})
}
/// Validate token locally using JWKS.
///
/// Local JWT/JWKS validation is not yet implemented. Rather than silently
/// falling back to the remote introspection endpoint (which would hide a
/// misconfiguration), this returns an explicit error directing the operator
/// to use `token_validation = "remote"` until local JWKS support is added.
#[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented
async fn validate_token_local(&self, token: &str) -> Result<NevisIdentity> {
// JWT structure check: header.payload.signature
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
bail!("Invalid JWT structure: expected 3 dot-separated parts");
}
bail!(
"Local JWKS token validation is not yet implemented. \
Set token_validation = \"remote\" to use the Nevis introspection endpoint."
);
}
/// Validate a Nevis session token (cookie-based sessions).
pub async fn validate_session(&self, session_token: &str) -> Result<NevisIdentity> {
if session_token.is_empty() {
bail!("empty session token");
}
let session_url = format!(
"{}/auth/realms/{}/protocol/openid-connect/userinfo",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let resp = self
.http_client
.get(&session_url)
.bearer_auth(session_token)
.send()
.await
.context("Failed to reach Nevis userinfo endpoint")?;
if !resp.status().is_success() {
bail!(
"Nevis session validation returned HTTP {}",
resp.status().as_u16()
);
}
let body: UserInfoResponse = resp
.json()
.await
.context("Failed to parse Nevis userinfo response")?;
if body.sub.trim().is_empty() {
bail!("Userinfo response has missing or empty `sub` claim");
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
roles.sort();
roles.dedup();
let identity = NevisIdentity {
user_id: body.sub,
roles,
scopes: body
.scope
.unwrap_or_default()
.split_whitespace()
.map(String::from)
.collect(),
mfa_verified: body.acr.as_deref() == Some("mfa")
|| body
.amr
.iter()
.flatten()
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
session_expiry: now + self.session_timeout.as_secs(),
};
if self.require_mfa && !identity.mfa_verified {
bail!(
"MFA is required but user '{}' has not completed MFA verification",
crate::security::redact(&identity.user_id)
);
}
Ok(identity)
}
/// Health check against the Nevis instance.
pub async fn health_check(&self) -> Result<()> {
let health_url = format!(
"{}/auth/realms/{}",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let resp = self
.http_client
.get(&health_url)
.send()
.await
.context("Nevis health check failed: cannot reach instance")?;
if !resp.status().is_success() {
bail!("Nevis health check failed: HTTP {}", resp.status().as_u16());
}
Ok(())
}
/// Getter for instance URL (for diagnostics).
pub fn instance_url(&self) -> &str {
&self.instance_url
}
/// Getter for realm.
pub fn realm(&self) -> &str {
&self.realm
}
}
// ── Wire types for Nevis API responses ─────────────────────────────
/// Wire shape of an RFC 7662 token-introspection response (subset of fields).
#[derive(Debug, Deserialize)]
struct IntrospectionResponse {
    /// Whether the token is currently valid at the server.
    active: bool,
    /// Subject (user id); may be absent on inactive tokens.
    sub: Option<String>,
    /// Space-delimited scope list.
    scope: Option<String>,
    /// Expiry, seconds since UNIX epoch.
    exp: Option<u64>,
    // NOTE(review): this rename is redundant — the field is already named
    // `realm_access` — but it is harmless.
    #[serde(rename = "realm_access")]
    realm_access: Option<RealmAccess>,
    /// Authentication Context Class Reference
    acr: Option<String>,
    /// Authentication Methods References
    amr: Option<Vec<String>>,
}
/// Nested `realm_access` claim carrying the user's realm roles.
#[derive(Debug, Deserialize)]
struct RealmAccess {
    // Defaults to an empty list when the claim omits `roles`.
    #[serde(default)]
    roles: Vec<String>,
}
/// Wire shape of an OIDC userinfo response (subset of fields).
#[derive(Debug, Deserialize)]
struct UserInfoResponse {
    /// Subject (user id); required — validated as non-blank by the caller.
    sub: String,
    // NOTE(review): redundant rename, kept for symmetry with the
    // introspection struct.
    #[serde(rename = "realm_access")]
    realm_access: Option<RealmAccess>,
    /// Space-delimited scope list.
    scope: Option<String>,
    /// Authentication Context Class Reference
    acr: Option<String>,
    /// Authentication Methods References
    amr: Option<Vec<String>>,
}
// ── Tests ──────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn token_validation_mode_from_str() {
        assert_eq!(
            TokenValidationMode::from_str_config("local").unwrap(),
            TokenValidationMode::Local
        );
        assert_eq!(
            TokenValidationMode::from_str_config("REMOTE").unwrap(),
            TokenValidationMode::Remote
        );
        assert!(TokenValidationMode::from_str_config("invalid").is_err());
    }
    // Constructor must refuse local mode without a JWKS endpoint.
    #[test]
    fn local_mode_requires_jwks_url() {
        let result = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            None, // no JWKS URL
            false,
            3600,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("jwks_url"));
    }
    #[test]
    fn remote_mode_works_without_jwks_url() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        );
        assert!(provider.is_ok());
    }
    #[test]
    fn provider_stores_config_correctly() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("test-secret".into()),
            "remote",
            None,
            true,
            7200,
        )
        .unwrap();
        assert_eq!(provider.instance_url(), "https://nevis.example.com");
        assert_eq!(provider.realm(), "test-realm");
        assert!(provider.require_mfa);
        assert_eq!(provider.session_timeout, Duration::from_secs(7200));
    }
    // Guards the manual Debug impl: the raw secret must never be printable.
    #[test]
    fn debug_redacts_client_secret() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("super-secret-value".into()),
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();
        let debug_output = format!("{:?}", provider);
        assert!(
            !debug_output.contains("super-secret-value"),
            "Debug output must not contain the raw client_secret"
        );
        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output must show [REDACTED] for client_secret"
        );
    }
    #[tokio::test]
    async fn validate_token_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();
        let err = provider.validate_token("").await.unwrap_err();
        assert!(err.to_string().contains("empty bearer token"));
    }
    #[tokio::test]
    async fn validate_session_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();
        let err = provider.validate_session("").await.unwrap_err();
        assert!(err.to_string().contains("empty session token"));
    }
    #[test]
    fn nevis_identity_serde_roundtrip() {
        let identity = NevisIdentity {
            user_id: "zeroclaw_user".into(),
            roles: vec!["admin".into(), "operator".into()],
            scopes: vec!["openid".into(), "profile".into()],
            mfa_verified: true,
            session_expiry: 1_700_000_000,
        };
        let json = serde_json::to_string(&identity).unwrap();
        let parsed: NevisIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.user_id, "zeroclaw_user");
        assert_eq!(parsed.roles.len(), 2);
        assert!(parsed.mfa_verified);
    }
    #[tokio::test]
    async fn local_validation_rejects_malformed_jwt() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();
        let err = provider.validate_token("not-a-jwt").await.unwrap_err();
        assert!(err.to_string().contains("Invalid JWT structure"));
    }
    #[tokio::test]
    async fn local_validation_errors_instead_of_silent_fallback() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();
        // A well-formed JWT structure should hit the "not yet implemented" error
        // instead of silently falling back to remote introspection.
        let err = provider
            .validate_token("header.payload.signature")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("not yet implemented"));
    }
}
+1 -1
View File
@@ -1470,7 +1470,7 @@ mod native_backend {
// When running as a service (systemd/OpenRC), the browser sandbox
// fails because the process lacks a user namespace / session.
// --no-sandbox and --disable-dev-shm-usage are required in this context.
if is_service_environment() {
if super::is_service_environment() {
args.push(Value::String("--no-sandbox".to_string()));
args.push(Value::String("--disable-dev-shm-usage".to_string()));
}
+400
View File
@@ -0,0 +1,400 @@
use anyhow::Context;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use tokio::sync::Mutex;
/// Cached OAuth2 token state persisted to disk between runs.
///
/// Stored as plaintext JSON; `TokenCache::new` refuses to run when encryption
/// is requested but unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTokenState {
    /// Bearer token presented to Microsoft Graph.
    pub access_token: String,
    /// Refresh token, when the grant issued one (client_credentials does not).
    pub refresh_token: Option<String>,
    /// Unix timestamp (seconds) when the access token expires.
    pub expires_at: i64,
}
impl CachedTokenState {
    /// Returns `true` when the token is expired or will expire within 60 seconds.
    pub fn is_expired(&self) -> bool {
        // Refresh 60 s early so a token never dies mid-request.
        let refresh_deadline = chrono::Utc::now().timestamp() + 60;
        self.expires_at <= refresh_deadline
    }
}
/// Thread-safe token cache with disk persistence.
pub struct TokenCache {
    /// Current token state; `None` until the first acquisition.
    inner: RwLock<Option<CachedTokenState>>,
    /// Serialises the slow acquire/refresh path so only one caller performs the
    /// network round-trip while others wait and then read the updated cache.
    acquire_lock: Mutex<()>,
    /// Resolved config (tenant, client, flow, scopes, secret).
    config: super::types::Microsoft365ResolvedConfig,
    /// On-disk location of the persisted token JSON.
    cache_path: PathBuf,
}
impl TokenCache {
    /// Create a cache for `config`, loading any previously persisted token
    /// state from `zeroclaw_dir`.
    ///
    /// Fails fast when `token_cache_encrypted` is set, since encryption is not
    /// implemented and tokens would otherwise land on disk in plaintext.
    pub fn new(
        config: super::types::Microsoft365ResolvedConfig,
        zeroclaw_dir: &std::path::Path,
    ) -> anyhow::Result<Self> {
        if config.token_cache_encrypted {
            anyhow::bail!(
                "microsoft365: token_cache_encrypted is enabled but encryption is not yet \
                implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \
                to false or wait for encryption support."
            );
        }
        // Scope cache file to (tenant_id, client_id, auth_flow) so config
        // changes never reuse tokens from a different account/flow.
        // NOTE(review): DefaultHasher's algorithm is not guaranteed stable
        // across Rust releases, so a toolchain upgrade may change the
        // fingerprint and force a one-time re-auth — confirm this is acceptable.
        let mut hasher = DefaultHasher::new();
        config.tenant_id.hash(&mut hasher);
        config.client_id.hash(&mut hasher);
        config.auth_flow.hash(&mut hasher);
        let fingerprint = format!("{:016x}", hasher.finish());
        let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json"));
        // Best effort: a missing/corrupt cache file simply starts empty.
        let cached = Self::load_from_disk(&cache_path);
        Ok(Self {
            inner: RwLock::new(cached),
            acquire_lock: Mutex::new(()),
            config,
            cache_path,
        })
    }
    /// Get a valid access token, refreshing or re-authenticating as needed.
    ///
    /// Concurrency: the fast path takes only a short-lived read lock; the slow
    /// path is serialised by `acquire_lock` and re-checks the cache after
    /// acquisition (double-checked pattern).
    pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
        // Fast path: cached and not expired.
        {
            // Scoped so the read guard is dropped before any await point.
            let guard = self.inner.read();
            if let Some(ref state) = *guard {
                if !state.is_expired() {
                    return Ok(state.access_token.clone());
                }
            }
        }
        // Slow path: serialise through a mutex so only one caller performs the
        // network round-trip while concurrent callers wait and re-check.
        let _lock = self.acquire_lock.lock().await;
        // Re-check after acquiring the lock — another caller may have refreshed
        // while we were waiting.
        {
            let guard = self.inner.read();
            if let Some(ref state) = *guard {
                if !state.is_expired() {
                    return Ok(state.access_token.clone());
                }
            }
        }
        let new_state = self.acquire_token(client).await?;
        let token = new_state.access_token.clone();
        // Write-through: persist to disk, then publish to the in-memory cache.
        self.persist_to_disk(&new_state);
        *self.inner.write() = Some(new_state);
        Ok(token)
    }
async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
// Try refresh first if we have a refresh token and the flow supports it.
// Client credentials flow does not issue refresh tokens, so skip the
// attempt entirely to avoid a wasted round-trip.
if self.config.auth_flow.as_str() != "client_credentials" {
// Clone the token out so the RwLock guard is dropped before the await.
let refresh_token_copy = {
let guard = self.inner.read();
guard.as_ref().and_then(|state| state.refresh_token.clone())
};
if let Some(refresh_tok) = refresh_token_copy {
match self.refresh_token(client, &refresh_tok).await {
Ok(new_state) => return Ok(new_state),
Err(e) => {
tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
}
}
}
}
match self.config.auth_flow.as_str() {
"client_credentials" => self.client_credentials_flow(client).await,
"device_code" => self.device_code_flow(client).await,
other => anyhow::bail!("Unsupported auth flow: {other}"),
}
}
async fn client_credentials_flow(
&self,
client: &reqwest::Client,
) -> anyhow::Result<CachedTokenState> {
let client_secret = self
.config
.client_secret
.as_deref()
.context("client_credentials flow requires client_secret")?;
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let scope = self.config.scopes.join(" ");
let resp = client
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", client_secret),
("scope", &scope),
])
.send()
.await
.context("ms365: failed to request client_credentials token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::debug!("ms365: client_credentials raw OAuth error: {body}");
anyhow::bail!("ms365: client_credentials token request failed ({status})");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
/// Acquire tokens via the OAuth2 device-authorization grant.
///
/// Initiates the flow at the tenant's `/devicecode` endpoint, prints the
/// user-facing verification instructions to stderr, then polls the token
/// endpoint until the user approves, the code expires, or the server
/// returns a terminal error.
///
/// # Errors
/// Fails when initiation or polling hits a transport error, when the
/// server rejects the request, or when the device code expires without
/// the user authorizing.
async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
    let device_code_url = format!(
        "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
        self.config.tenant_id
    );
    let scope = self.config.scopes.join(" ");
    let resp = client
        .post(&device_code_url)
        .form(&[
            ("client_id", self.config.client_id.as_str()),
            ("scope", &scope),
        ])
        .send()
        .await
        .context("ms365: failed to request device code")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        // Raw response body goes to debug logs only; the error kept generic.
        tracing::debug!("ms365: device_code initiation raw error: {body}");
        anyhow::bail!("ms365: device code request failed ({status})");
    }
    let device_resp: DeviceCodeResponse = resp
        .json()
        .await
        .context("ms365: failed to parse device code response")?;
    // Log only a generic prompt; the full device_resp.message may contain
    // sensitive verification URIs or codes that should not appear in logs.
    tracing::info!(
        "ms365: device code auth required — follow the instructions shown to the user"
    );
    // Print the user-facing message to stderr so the operator can act on it
    // without it being captured in structured log sinks.
    eprintln!("ms365: {}", device_resp.message);
    let token_url = format!(
        "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
        self.config.tenant_id
    );
    // Respect the server-requested poll interval, with a 5-second floor.
    let interval = device_resp.interval.max(5);
    // Bound the polling loop so it cannot outlive the device code's lifetime
    // (expires_in seconds divided by the per-poll interval, at least 1).
    let max_polls = u32::try_from(
        (device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1),
    )
    .unwrap_or(u32::MAX);
    for _ in 0..max_polls {
        tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
        let poll_resp = client
            .post(&token_url)
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
                ("client_id", self.config.client_id.as_str()),
                ("device_code", &device_resp.device_code),
            ])
            .send()
            .await
            .context("ms365: failed to poll device code token")?;
        if poll_resp.status().is_success() {
            let token_resp: TokenResponse = poll_resp
                .json()
                .await
                .context("ms365: failed to parse token response")?;
            return Ok(CachedTokenState {
                access_token: token_resp.access_token,
                refresh_token: token_resp.refresh_token,
                expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
            });
        }
        let body = poll_resp.text().await.unwrap_or_default();
        // "authorization_pending": the user has not approved yet — keep polling.
        if body.contains("authorization_pending") {
            continue;
        }
        // "slow_down": server asked us to back off; add an extra 5s delay.
        if body.contains("slow_down") {
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            continue;
        }
        // Any other error body is terminal for this flow.
        tracing::debug!("ms365: device code polling raw error: {body}");
        anyhow::bail!("ms365: device code polling failed");
    }
    anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
}
async fn refresh_token(
&self,
client: &reqwest::Client,
refresh_token: &str,
) -> anyhow::Result<CachedTokenState> {
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let mut params = vec![
("grant_type", "refresh_token"),
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
];
let secret_ref;
if let Some(ref secret) = self.config.client_secret {
secret_ref = secret.as_str();
params.push(("client_secret", secret_ref));
}
let resp = client
.post(&token_url)
.form(&params)
.send()
.await
.context("ms365: failed to refresh token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::debug!("ms365: token refresh raw error: {body}");
anyhow::bail!("ms365: token refresh failed ({status})");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse refresh token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp
.refresh_token
.or_else(|| Some(refresh_token.to_string())),
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
/// Read a cached token state from `path`.
/// Returns `None` when the file is missing, unreadable, or not valid JSON
/// for a `CachedTokenState` — all treated as "no cache".
fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
    std::fs::read_to_string(path)
        .ok()
        .and_then(|data| serde_json::from_str(&data).ok())
}
/// Serialize `state` as pretty JSON and write it to the cache file.
/// Best-effort: serialization failure is silently ignored (matching the
/// original contract) and a write failure is only logged as a warning.
fn persist_to_disk(&self, state: &CachedTokenState) {
    let Ok(json) = serde_json::to_string_pretty(state) else {
        return;
    };
    if let Err(e) = std::fs::write(&self.cache_path, json) {
        tracing::warn!("ms365: failed to persist token cache: {e}");
    }
}
}
/// Subset of the OAuth2 token-endpoint response consumed by this module.
#[derive(Deserialize)]
struct TokenResponse {
    // Bearer token presented on Graph API calls.
    access_token: String,
    // Absent for client_credentials responses; present for delegated flows.
    #[serde(default)]
    refresh_token: Option<String>,
    // Token lifetime in seconds; defaults to one hour when omitted.
    #[serde(default = "default_expires_in")]
    expires_in: i64,
}
/// Serde default for `TokenResponse::expires_in`: one hour, in seconds.
fn default_expires_in() -> i64 {
    60 * 60
}
/// Response from the device-code initiation endpoint.
#[derive(Deserialize)]
struct DeviceCodeResponse {
    // Opaque code polled against the token endpoint.
    device_code: String,
    // Human-readable instructions shown to the user (verification URI + code).
    message: String,
    // Minimum seconds the client should wait between polls.
    #[serde(default = "default_device_interval")]
    interval: u64,
    // Seconds until the device code expires.
    #[serde(default = "default_device_expires_in")]
    expires_in: i64,
}
/// Serde default for `DeviceCodeResponse::interval`: poll every 5 seconds.
fn default_device_interval() -> u64 {
    5
}
/// Serde default for `DeviceCodeResponse::expires_in`: fifteen minutes.
fn default_device_expires_in() -> i64 {
    15 * 60
}
#[cfg(test)]
mod tests {
    use super::*;

    // A token whose deadline already passed must read as expired.
    #[test]
    fn token_is_expired_when_past_deadline() {
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() - 10,
        };
        assert!(state.is_expired());
    }

    // A token only 30s from expiry is also treated as expired — this test
    // implies is_expired() applies a refresh buffer of more than 30 seconds
    // (buffer size defined in CachedTokenState::is_expired, not visible here).
    #[test]
    fn token_is_expired_within_buffer() {
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + 30,
        };
        assert!(state.is_expired());
    }

    // Well clear of expiry (1 hour out) the token must be considered valid.
    #[test]
    fn token_is_valid_when_far_from_expiry() {
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + 3600,
        };
        assert!(!state.is_expired());
    }

    // Missing cache file is "no cache", not an error.
    #[test]
    fn load_from_disk_returns_none_for_missing_file() {
        let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
        assert!(TokenCache::load_from_disk(path).is_none());
    }
}
+495
View File
@@ -0,0 +1,495 @@
use anyhow::Context;
/// Base URL for all Microsoft Graph v1.0 REST calls made by this module.
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
/// Build the user path segment: `/me` or `/users/{user_id}`.
/// Any id other than the literal "me" is percent-encoded so it cannot
/// smuggle extra path segments (path-traversal hardening).
fn user_path(user_id: &str) -> String {
    match user_id {
        "me" => "/me".to_string(),
        other => format!("/users/{}", urlencoding::encode(other)),
    }
}
/// Percent-encode a single URL path segment to prevent path-traversal attacks.
fn encode_path_segment(segment: &str) -> String {
    String::from(urlencoding::encode(segment))
}
/// List mail messages for a user.
///
/// * `user_id` — "me" for the signed-in user, or a user principal name/ID.
/// * `folder` — optional folder name or ID (e.g. "inbox"); `None` lists
///   the user's default message set.
/// * `top` — maximum number of messages to request (`$top` query param).
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn mail_list(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    folder: Option<&str>,
    top: u32,
) -> anyhow::Result<serde_json::Value> {
    let base = user_path(user_id);
    let path = match folder {
        Some(f) => format!(
            "{GRAPH_BASE}{base}/mailFolders/{}/messages",
            encode_path_segment(f)
        ),
        None => format!("{GRAPH_BASE}{base}/messages"),
    };
    let resp = client
        .get(&path)
        .bearer_auth(token)
        .query(&[("$top", top.to_string())])
        .send()
        .await
        .context("ms365: mail_list request failed")?;
    handle_json_response(resp, "mail_list").await
}
/// Send a plain-text mail message as `user_id`.
///
/// Builds a Graph `sendMail` payload with one `toRecipients` entry per
/// address in `to`; the body is always sent with contentType "Text".
///
/// # Errors
/// Fails on transport errors or non-2xx responses. Only the HTTP status
/// and a short Graph error code are surfaced; the raw error body goes to
/// debug logs only.
pub async fn mail_send(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    to: &[String],
    subject: &str,
    body: &str,
) -> anyhow::Result<()> {
    let base = user_path(user_id);
    let url = format!("{GRAPH_BASE}{base}/sendMail");
    let to_recipients: Vec<serde_json::Value> = to
        .iter()
        .map(|addr| {
            serde_json::json!({
                "emailAddress": { "address": addr }
            })
        })
        .collect();
    let payload = serde_json::json!({
        "message": {
            "subject": subject,
            "body": {
                "contentType": "Text",
                "content": body
            },
            "toRecipients": to_recipients
        }
    });
    let resp = client
        .post(&url)
        .bearer_auth(token)
        .json(&payload)
        .send()
        .await
        .context("ms365: mail_send request failed")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
        tracing::debug!("ms365: mail_send raw error body: {body}");
        anyhow::bail!("ms365: mail_send failed ({status}, code={code})");
    }
    Ok(())
}
/// List messages in a Teams channel.
/// Both IDs are percent-encoded before being placed in the URL path.
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn teams_message_list(
    client: &reqwest::Client,
    token: &str,
    team_id: &str,
    channel_id: &str,
    top: u32,
) -> anyhow::Result<serde_json::Value> {
    let team = encode_path_segment(team_id);
    let channel = encode_path_segment(channel_id);
    let url = format!("{GRAPH_BASE}/teams/{team}/channels/{channel}/messages");
    let request = client
        .get(&url)
        .bearer_auth(token)
        .query(&[("$top", top.to_string())]);
    let resp = request
        .send()
        .await
        .context("ms365: teams_message_list request failed")?;
    handle_json_response(resp, "teams_message_list").await
}
/// Send a message to a Teams channel.
///
/// `body` is sent as the message content (content type left to the Graph
/// default). Team and channel IDs are percent-encoded in the URL path.
///
/// # Errors
/// Fails on transport errors or non-2xx responses; only the status and a
/// short Graph error code are surfaced, raw body goes to debug logs.
pub async fn teams_message_send(
    client: &reqwest::Client,
    token: &str,
    team_id: &str,
    channel_id: &str,
    body: &str,
) -> anyhow::Result<()> {
    let url = format!(
        "{GRAPH_BASE}/teams/{}/channels/{}/messages",
        encode_path_segment(team_id),
        encode_path_segment(channel_id)
    );
    let payload = serde_json::json!({
        "body": {
            "content": body
        }
    });
    let resp = client
        .post(&url)
        .bearer_auth(token)
        .json(&payload)
        .send()
        .await
        .context("ms365: teams_message_send request failed")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
        tracing::debug!("ms365: teams_message_send raw error body: {body}");
        anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})");
    }
    Ok(())
}
/// List calendar events in a date range via the Graph `calendarView` endpoint.
/// `start`/`end` are passed through as the `startDateTime`/`endDateTime`
/// query parameters.
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn calendar_events_list(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    start: &str,
    end: &str,
    top: u32,
) -> anyhow::Result<serde_json::Value> {
    let url = format!("{GRAPH_BASE}{}/calendarView", user_path(user_id));
    let query = [
        ("startDateTime", start.to_string()),
        ("endDateTime", end.to_string()),
        ("$top", top.to_string()),
    ];
    let resp = client
        .get(&url)
        .bearer_auth(token)
        .query(&query)
        .send()
        .await
        .context("ms365: calendar_events_list request failed")?;
    handle_json_response(resp, "calendar_events_list").await
}
/// Create a calendar event.
///
/// `start`/`end` are sent with timeZone fixed to "UTC"; each address in
/// `attendees` becomes a "required" attendee; `body_text` (when present)
/// is attached as a plain-text body.
///
/// # Returns
/// The created event's Graph ID, or the literal "unknown" when the
/// response lacks an `id` field.
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn calendar_event_create(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    subject: &str,
    start: &str,
    end: &str,
    attendees: &[String],
    body_text: Option<&str>,
) -> anyhow::Result<String> {
    let base = user_path(user_id);
    let url = format!("{GRAPH_BASE}{base}/events");
    let attendee_list: Vec<serde_json::Value> = attendees
        .iter()
        .map(|email| {
            serde_json::json!({
                "emailAddress": { "address": email },
                "type": "required"
            })
        })
        .collect();
    let mut payload = serde_json::json!({
        "subject": subject,
        "start": {
            "dateTime": start,
            "timeZone": "UTC"
        },
        "end": {
            "dateTime": end,
            "timeZone": "UTC"
        },
        "attendees": attendee_list
    });
    if let Some(text) = body_text {
        payload["body"] = serde_json::json!({
            "contentType": "Text",
            "content": text
        });
    }
    let resp = client
        .post(&url)
        .bearer_auth(token)
        .json(&payload)
        .send()
        .await
        .context("ms365: calendar_event_create request failed")?;
    let value = handle_json_response(resp, "calendar_event_create").await?;
    let event_id = value["id"].as_str().unwrap_or("unknown").to_string();
    Ok(event_id)
}
/// Delete a calendar event by ID.
/// The event ID is percent-encoded before being placed in the URL path.
///
/// # Errors
/// Fails on transport errors or non-2xx responses; only the status and a
/// short Graph error code are surfaced, raw body goes to debug logs.
pub async fn calendar_event_delete(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    event_id: &str,
) -> anyhow::Result<()> {
    let base = user_path(user_id);
    let url = format!(
        "{GRAPH_BASE}{base}/events/{}",
        encode_path_segment(event_id)
    );
    let resp = client
        .delete(&url)
        .bearer_auth(token)
        .send()
        .await
        .context("ms365: calendar_event_delete request failed")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
        tracing::debug!("ms365: calendar_event_delete raw error body: {body}");
        anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})");
    }
    Ok(())
}
/// List children of a OneDrive folder.
/// `path` is a drive-relative folder path; `None` or an empty string lists
/// the drive root. Non-root paths are percent-encoded into the
/// `/drive/root:/{path}:/children` address form.
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn onedrive_list(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    path: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
    let base = user_path(user_id);
    let url = match path.filter(|p| !p.is_empty()) {
        Some(p) => format!(
            "{GRAPH_BASE}{base}/drive/root:/{}:/children",
            urlencoding::encode(p)
        ),
        None => format!("{GRAPH_BASE}{base}/drive/root/children"),
    };
    let resp = client
        .get(&url)
        .bearer_auth(token)
        .send()
        .await
        .context("ms365: onedrive_list request failed")?;
    handle_json_response(resp, "onedrive_list").await
}
/// Download a OneDrive item by ID, with a maximum size guard.
///
/// The guard is enforced twice: up front via the `Content-Length` header
/// (when the server provides one) so oversized files fail before any body
/// bytes are buffered in memory, and again after reading, because the
/// header is advisory and may be absent or inaccurate.
///
/// # Errors
/// Fails on transport errors, non-2xx Graph responses, or when the file
/// exceeds `max_size` bytes.
pub async fn onedrive_download(
    client: &reqwest::Client,
    token: &str,
    user_id: &str,
    item_id: &str,
    max_size: usize,
) -> anyhow::Result<Vec<u8>> {
    let base = user_path(user_id);
    let url = format!(
        "{GRAPH_BASE}{base}/drive/items/{}/content",
        encode_path_segment(item_id)
    );
    let resp = client
        .get(&url)
        .bearer_auth(token)
        .send()
        .await
        .context("ms365: onedrive_download request failed")?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
        tracing::debug!("ms365: onedrive_download raw error body: {body}");
        anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})");
    }
    // Fail early when the declared size already exceeds the cap, before
    // buffering the whole body in memory.
    if let Some(len) = resp.content_length() {
        if usize::try_from(len).map_or(true, |l| l > max_size) {
            anyhow::bail!("ms365: downloaded file exceeds max_size ({len} > {max_size})");
        }
    }
    let bytes = resp
        .bytes()
        .await
        .context("ms365: failed to read download body")?;
    // Re-check the actual size: Content-Length may be missing or wrong.
    if bytes.len() > max_size {
        anyhow::bail!(
            "ms365: downloaded file exceeds max_size ({} > {max_size})",
            bytes.len()
        );
    }
    Ok(bytes.to_vec())
}
/// Search SharePoint for documents matching a query.
///
/// Issues a Graph `/search/query` request covering driveItem, listItem,
/// and site entity types, returning up to `top` results starting at
/// offset 0.
///
/// # Errors
/// Fails on transport errors or non-2xx Graph responses.
pub async fn sharepoint_search(
    client: &reqwest::Client,
    token: &str,
    query: &str,
    top: u32,
) -> anyhow::Result<serde_json::Value> {
    let url = format!("{GRAPH_BASE}/search/query");
    let payload = serde_json::json!({
        "requests": [{
            "entityTypes": ["driveItem", "listItem", "site"],
            "query": {
                "queryString": query
            },
            "from": 0,
            "size": top
        }]
    });
    let resp = client
        .post(&url)
        .bearer_auth(token)
        .json(&payload)
        .send()
        .await
        .context("ms365: sharepoint_search request failed")?;
    handle_json_response(resp, "sharepoint_search").await
}
/// Extract a short, safe error code from a Graph API JSON error body
/// (the `error.code` field of the standard Graph error envelope).
/// Returns `None` when the body is not a recognised Graph error envelope.
fn extract_graph_error_code(body: &str) -> Option<String> {
    // Previous version bound the result to a local and immediately returned
    // it (clippy::let_and_return); return the chain directly instead.
    let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
    parsed
        .get("error")
        .and_then(|e| e.get("code"))
        .and_then(|c| c.as_str())
        .map(str::to_string)
}
/// Parse a JSON response body, returning an error on non-success status.
/// Raw Graph API error bodies are not propagated; only the HTTP status and a
/// short error code (when available) are surfaced to avoid leaking internal
/// API details. `operation` names the calling operation for log/error text.
async fn handle_json_response(
    resp: reqwest::Response,
    operation: &str,
) -> anyhow::Result<serde_json::Value> {
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
        // Full body only at debug level; the bail message stays sanitized.
        tracing::debug!("ms365: {operation} raw error body: {body}");
        anyhow::bail!("ms365: {operation} failed ({status}, code={code})");
    }
    resp.json()
        .await
        .with_context(|| format!("ms365: failed to parse {operation} response"))
}
#[cfg(test)]
mod tests {
    use super::*;

    // The literal "me" passes through unencoded as the /me segment.
    #[test]
    fn user_path_me() {
        assert_eq!(user_path("me"), "/me");
    }

    // Any other id is percent-encoded ('@' -> %40) under /users/.
    #[test]
    fn user_path_specific_user() {
        assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com");
    }

    // Without a folder, mail_list targets the user's default message set.
    #[test]
    fn mail_list_url_no_folder() {
        let base = user_path("me");
        let url = format!("{GRAPH_BASE}{base}/messages");
        assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages");
    }

    // With a folder, the folder name becomes an encoded path segment.
    #[test]
    fn mail_list_url_with_folder() {
        let base = user_path("me");
        let folder = "inbox";
        let url = format!(
            "{GRAPH_BASE}{base}/mailFolders/{}/messages",
            encode_path_segment(folder)
        );
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages"
        );
    }

    // calendarView URL for a specific (encoded) user principal name.
    #[test]
    fn calendar_view_url() {
        let base = user_path("user@example.com");
        let url = format!("{GRAPH_BASE}{base}/calendarView");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView"
        );
    }

    // Team and channel IDs are each encoded as separate path segments.
    #[test]
    fn teams_message_url() {
        let url = format!(
            "{GRAPH_BASE}/teams/{}/channels/{}/messages",
            encode_path_segment("team-123"),
            encode_path_segment("channel-456")
        );
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages"
        );
    }

    // Drive root listing uses the plain /drive/root/children form.
    #[test]
    fn onedrive_root_url() {
        let base = user_path("me");
        let url = format!("{GRAPH_BASE}{base}/drive/root/children");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root/children"
        );
    }

    // Non-root paths use the colon-addressed form; '/' is encoded as %2F.
    #[test]
    fn onedrive_path_url() {
        let base = user_path("me");
        let encoded = urlencoding::encode("Documents/Reports");
        let url = format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children"
        );
    }

    // SharePoint search always posts to the fixed /search/query endpoint.
    #[test]
    fn sharepoint_search_url() {
        let url = format!("{GRAPH_BASE}/search/query");
        assert_eq!(url, "https://graph.microsoft.com/v1.0/search/query");
    }
}
+567
View File
@@ -0,0 +1,567 @@
//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar,
//! OneDrive, and SharePoint via a single action-dispatched tool surface.
//!
//! Auth is handled through direct HTTP calls to the Microsoft identity platform
//! (client credentials or device code flow) with token caching.
pub mod auth;
pub mod graph_client;
pub mod types;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use crate::tools::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Maximum download size for OneDrive files (10 MB).
/// Also the hard ceiling: caller-supplied `max_size` values are clamped to it.
const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024;
/// Default number of items to return in list operations.
const DEFAULT_TOP: u32 = 25;
/// Action-dispatched tool exposing Microsoft Graph operations
/// (mail, Teams, calendar, OneDrive, SharePoint search).
pub struct Microsoft365Tool {
    // Resolved configuration: tenant/client ids, auth flow, scopes, user.
    config: types::Microsoft365ResolvedConfig,
    // Policy gate consulted before every read/act operation.
    security: Arc<SecurityPolicy>,
    // Shared token cache handling acquisition, refresh, and persistence.
    token_cache: Arc<auth::TokenCache>,
    // HTTP client with proxy settings and timeouts applied at construction.
    http_client: reqwest::Client,
}
impl Microsoft365Tool {
pub fn new(
config: types::Microsoft365ResolvedConfig,
security: Arc<SecurityPolicy>,
zeroclaw_dir: &std::path::Path,
) -> anyhow::Result<Self> {
let http_client =
crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10);
let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?);
Ok(Self {
config,
security,
token_cache,
http_client,
})
}
async fn get_token(&self) -> anyhow::Result<String> {
self.token_cache.get_token(&self.http_client).await
}
fn user_id(&self) -> &str {
&self.config.user_id
}
async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
match action {
"mail_list" => self.handle_mail_list(args).await,
"mail_send" => self.handle_mail_send(args).await,
"teams_message_list" => self.handle_teams_message_list(args).await,
"teams_message_send" => self.handle_teams_message_send(args).await,
"calendar_events_list" => self.handle_calendar_events_list(args).await,
"calendar_event_create" => self.handle_calendar_event_create(args).await,
"calendar_event_delete" => self.handle_calendar_event_delete(args).await,
"onedrive_list" => self.handle_onedrive_list(args).await,
"onedrive_download" => self.handle_onedrive_download(args).await,
"sharepoint_search" => self.handle_sharepoint_search(args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {action}")),
}),
}
}
// ── Read actions ────────────────────────────────────────────────
async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let folder = args["folder"].as_str();
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result =
graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_teams_message_list(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let team_id = args["team_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
let channel_id = args["channel_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result =
graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top)
.await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_calendar_events_list(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let start = args["start"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
let end = args["end"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result = graph_client::calendar_events_list(
&self.http_client,
&token,
self.user_id(),
start,
end,
top,
)
.await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let path = args["path"].as_str();
let result =
graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_onedrive_download(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let item_id = args["item_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("item_id is required"))?;
let max_size = args["max_size"]
.as_u64()
.and_then(|v| usize::try_from(v).ok())
.unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE)
.min(MAX_ONEDRIVE_DOWNLOAD_SIZE);
let bytes = graph_client::onedrive_download(
&self.http_client,
&token,
self.user_id(),
item_id,
max_size,
)
.await?;
// Return base64-encoded for binary safety.
use base64::Engine;
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
Ok(ToolResult {
success: true,
output: format!(
"Downloaded {} bytes (base64 encoded):\n{encoded}",
bytes.len()
),
error: None,
})
}
async fn handle_sharepoint_search(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let query = args["query"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("query is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
// ── Write actions ───────────────────────────────────────────────
async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let to: Vec<String> = args["to"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))?
.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect();
if to.is_empty() {
anyhow::bail!("to must contain at least one email address");
}
let subject = args["subject"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
let body = args["body"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
graph_client::mail_send(
&self.http_client,
&token,
self.user_id(),
&to,
subject,
body,
)
.await?;
Ok(ToolResult {
success: true,
output: format!("Email sent to: {}", to.join(", ")),
error: None,
})
}
async fn handle_teams_message_send(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let team_id = args["team_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
let channel_id = args["channel_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
let body = args["body"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body)
.await?;
Ok(ToolResult {
success: true,
output: "Teams message sent".to_string(),
error: None,
})
}
async fn handle_calendar_event_create(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let subject = args["subject"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
let start = args["start"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
let end = args["end"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
let attendees: Vec<String> = args["attendees"]
.as_array()
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let body_text = args["body"].as_str();
let event_id = graph_client::calendar_event_create(
&self.http_client,
&token,
self.user_id(),
subject,
start,
end,
&attendees,
body_text,
)
.await?;
Ok(ToolResult {
success: true,
output: format!("Calendar event created (id: {event_id})"),
error: None,
})
}
async fn handle_calendar_event_delete(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let event_id = args["event_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("event_id is required"))?;
graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id)
.await?;
Ok(ToolResult {
success: true,
output: format!("Calendar event {event_id} deleted"),
error: None,
})
}
}
#[async_trait]
impl Tool for Microsoft365Tool {
    /// Stable tool identifier used for registration and dispatch.
    fn name(&self) -> &str {
        "microsoft365"
    }

    /// Human-readable summary shown to the model/user.
    fn description(&self) -> &str {
        "Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \
         OneDrive files, and SharePoint search via Microsoft Graph API"
    }

    /// JSON Schema describing the `action` discriminator and the union of
    /// per-action parameters.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "mail_list",
                        "mail_send",
                        "teams_message_list",
                        "teams_message_send",
                        "calendar_events_list",
                        "calendar_event_create",
                        "calendar_event_delete",
                        "onedrive_list",
                        "onedrive_download",
                        "sharepoint_search"
                    ],
                    "description": "The Microsoft 365 action to perform"
                },
                "folder": {
                    "type": "string",
                    "description": "Mail folder ID (for mail_list, e.g. 'inbox', 'sentitems')"
                },
                "to": {
                    "type": "array",
                    "items": { "type": "string" },
                    "description": "Recipient email addresses (for mail_send)"
                },
                "subject": {
                    "type": "string",
                    "description": "Email subject or calendar event subject"
                },
                "body": {
                    "type": "string",
                    "description": "Message body text"
                },
                "team_id": {
                    "type": "string",
                    "description": "Teams team ID (for teams_message_list/send)"
                },
                "channel_id": {
                    "type": "string",
                    "description": "Teams channel ID (for teams_message_list/send)"
                },
                "start": {
                    "type": "string",
                    "description": "Start datetime in ISO 8601 format (for calendar actions)"
                },
                "end": {
                    "type": "string",
                    "description": "End datetime in ISO 8601 format (for calendar actions)"
                },
                "attendees": {
                    "type": "array",
                    "items": { "type": "string" },
                    "description": "Attendee email addresses (for calendar_event_create)"
                },
                "event_id": {
                    "type": "string",
                    "description": "Calendar event ID (for calendar_event_delete)"
                },
                "path": {
                    "type": "string",
                    "description": "OneDrive folder path (for onedrive_list)"
                },
                "item_id": {
                    "type": "string",
                    "description": "OneDrive item ID (for onedrive_download)"
                },
                "max_size": {
                    "type": "integer",
                    "description": "Maximum download size in bytes (for onedrive_download, default 10MB)"
                },
                "query": {
                    "type": "string",
                    "description": "Search query (for sharepoint_search)"
                },
                "top": {
                    "type": "integer",
                    "description": "Maximum number of items to return (default 25)"
                }
            }
        })
    }

    /// Validate the `action` argument and dispatch; handler errors are
    /// folded into a failed `ToolResult` rather than propagated as `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Borrow `action` directly from `args` — `dispatch` only needs
        // shared references to both, so the previous `to_string()` copy
        // was a needless allocation.
        let Some(action) = args["action"].as_str() else {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("'action' parameter is required".to_string()),
            });
        };
        match self.dispatch(action, &args).await {
            Ok(result) => Ok(result),
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("microsoft365.{action} failed: {e}")),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Renamed from `tool_name_is_microsoft365`: the body never touched the
    // tool name — it only validates a JSON schema skeleton, so the name now
    // says what the test actually checks.
    #[test]
    fn required_action_schema_skeleton_is_valid_json() {
        // Verify the schema is valid JSON with the expected structure.
        let schema_str = r#"{"type":"object","required":["action"]}"#;
        let _: serde_json::Value = serde_json::from_str(schema_str).unwrap();
    }

    // The action enum must list all ten supported operations.
    #[test]
    fn parameters_schema_has_action_enum() {
        let schema = json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "mail_list",
                        "mail_send",
                        "teams_message_list",
                        "teams_message_send",
                        "calendar_events_list",
                        "calendar_event_create",
                        "calendar_event_delete",
                        "onedrive_list",
                        "onedrive_download",
                        "sharepoint_search"
                    ]
                }
            }
        });
        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
        assert_eq!(actions.len(), 10);
        assert!(actions.contains(&json!("mail_list")));
        assert!(actions.contains(&json!("sharepoint_search")));
    }

    // Sanity-check the dispatch table's action list size and rejection of
    // unknown actions.
    #[test]
    fn action_dispatch_table_is_exhaustive() {
        let valid_actions = [
            "mail_list",
            "mail_send",
            "teams_message_list",
            "teams_message_send",
            "calendar_events_list",
            "calendar_event_create",
            "calendar_event_delete",
            "onedrive_list",
            "onedrive_download",
            "sharepoint_search",
        ];
        assert_eq!(valid_actions.len(), 10);
        assert!(!valid_actions.contains(&"invalid_action"));
    }
}
+55
View File
@@ -0,0 +1,55 @@
use serde::{Deserialize, Serialize};
/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied.
#[derive(Clone, Serialize, Deserialize)]
pub struct Microsoft365ResolvedConfig {
    // Azure AD tenant the app is registered in.
    pub tenant_id: String,
    // Application (client) ID of the app registration.
    pub client_id: String,
    // Present for confidential clients; None for public/device-code clients.
    pub client_secret: Option<String>,
    // Which OAuth flow to use (e.g. "client_credentials" per the tests below).
    pub auth_flow: String,
    // OAuth scopes requested when acquiring tokens.
    pub scopes: Vec<String>,
    // Whether the on-disk token cache should be encrypted.
    pub token_cache_encrypted: bool,
    // Graph user context: "me" or a specific user principal name/ID.
    pub user_id: String,
}
// Manual Debug so the client secret is never printed: its value is replaced
// with "***" while still revealing whether a secret is configured
// (Some("***") vs None).
impl std::fmt::Debug for Microsoft365ResolvedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Microsoft365ResolvedConfig")
            .field("tenant_id", &self.tenant_id)
            .field("client_id", &self.client_id)
            // Redact the secret value; keep presence visible for debugging.
            .field("client_secret", &self.client_secret.as_ref().map(|_| "***"))
            .field("auth_flow", &self.auth_flow)
            .field("scopes", &self.scopes)
            .field("token_cache_encrypted", &self.token_cache_encrypted)
            .field("user_id", &self.user_id)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Serde round-trip must preserve every field, including the secret
    // (serialization is NOT redacted — only Debug is).
    #[test]
    fn resolved_config_serialization_roundtrip() {
        let config = Microsoft365ResolvedConfig {
            tenant_id: "test-tenant".into(),
            client_id: "test-client".into(),
            client_secret: Some("secret".into()),
            auth_flow: "client_credentials".into(),
            scopes: vec!["https://graph.microsoft.com/.default".into()],
            token_cache_encrypted: false,
            user_id: "me".into(),
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: Microsoft365ResolvedConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.tenant_id, "test-tenant");
        assert_eq!(parsed.client_id, "test-client");
        assert_eq!(parsed.client_secret.as_deref(), Some("secret"));
        assert_eq!(parsed.auth_flow, "client_credentials");
        assert_eq!(parsed.scopes.len(), 1);
        assert_eq!(parsed.user_id, "me");
    }
}
+88 -1
View File
@@ -48,11 +48,15 @@ pub mod mcp_transport;
pub mod memory_forget;
pub mod memory_recall;
pub mod memory_store;
pub mod microsoft365;
pub mod model_routing_config;
pub mod node_tool;
pub mod notion_tool;
pub mod pdf_read;
pub mod project_intel;
pub mod proxy_config;
pub mod pushover;
pub mod report_templates;
pub mod schedule;
pub mod schema;
pub mod screenshot;
@@ -95,10 +99,13 @@ pub use mcp_tool::McpToolWrapper;
pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use microsoft365::Microsoft365Tool;
pub use model_routing_config::ModelRoutingConfigTool;
#[allow(unused_imports)]
pub use node_tool::NodeTool;
pub use notion_tool::NotionTool;
pub use pdf_read::PdfReadTool;
pub use project_intel::ProjectIntelTool;
pub use proxy_config::ProxyConfigTool;
pub use pushover::PushoverTool;
pub use schedule::ScheduleTool;
@@ -346,10 +353,35 @@ pub fn all_tools_with_runtime(
)));
}
// Security operations (MCSS) tools
// Notion API tool (conditionally registered)
if root_config.notion.enabled {
let notion_api_key = if root_config.notion.api_key.trim().is_empty() {
std::env::var("NOTION_API_KEY").unwrap_or_default()
} else {
root_config.notion.api_key.trim().to_string()
};
if notion_api_key.trim().is_empty() {
tracing::warn!(
"Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
);
} else {
tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone())));
}
}
// Project delivery intelligence
if root_config.project_intel.enabled {
tool_arcs.push(Arc::new(ProjectIntelTool::new(
root_config.project_intel.default_language.clone(),
root_config.project_intel.risk_sensitivity.clone(),
)));
}
// MCSS Security Operations
if root_config.security_ops.enabled {
tool_arcs.push(Arc::new(SecurityOpsTool::new(
root_config.security_ops.clone(),
security.clone(),
)));
}
@@ -370,6 +402,61 @@ pub fn all_tools_with_runtime(
}
}
// Microsoft 365 Graph API integration
if root_config.microsoft365.enabled {
let ms_cfg = &root_config.microsoft365;
let tenant_id = ms_cfg
.tenant_id
.as_deref()
.unwrap_or_default()
.trim()
.to_string();
let client_id = ms_cfg
.client_id
.as_deref()
.unwrap_or_default()
.trim()
.to_string();
if !tenant_id.is_empty() && !client_id.is_empty() {
// Fail fast: client_credentials flow requires a client_secret at registration time.
if ms_cfg.auth_flow.trim() == "client_credentials"
&& ms_cfg
.client_secret
.as_deref()
.map_or(true, |s| s.trim().is_empty())
{
tracing::error!(
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
);
return (boxed_registry_from_arcs(tool_arcs), None);
}
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
tenant_id,
client_id,
client_secret: ms_cfg.client_secret.clone(),
auth_flow: ms_cfg.auth_flow.clone(),
scopes: ms_cfg.scopes.clone(),
token_cache_encrypted: ms_cfg.token_cache_encrypted,
user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(),
};
// Store token cache in the config directory (next to config.toml),
// not the workspace directory, to keep bearer tokens out of the
// project tree.
let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir);
match Microsoft365Tool::new(resolved, security.clone(), cache_dir) {
Ok(tool) => tool_arcs.push(Arc::new(tool)),
Err(e) => {
tracing::error!("microsoft365: failed to initialize tool: {e}");
}
}
} else {
tracing::warn!(
"microsoft365: skipped registration because tenant_id or client_id is empty"
);
}
}
// Add delegation tool when agents are configured
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
let trimmed_value = value.trim();
+438
View File
@@ -0,0 +1,438 @@
use super::traits::{Tool, ToolResult};
use crate::security::{policy::ToolOperation, SecurityPolicy};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Base URL for the Notion REST API (v1).
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// API version sent in the `Notion-Version` header on every request.
const NOTION_VERSION: &str = "2022-06-28";
/// Per-request timeout applied to every Notion HTTP call.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;
/// Tool for interacting with the Notion API — query databases, read/create/update pages,
/// and search the workspace. Each action is gated by the appropriate security operation
/// (Read for queries, Act for mutations).
pub struct NotionTool {
    // Bearer token for the Notion integration.
    api_key: String,
    // Shared HTTP client, created once in `new` and reused for all requests.
    http: reqwest::Client,
    // Policy consulted before every action in `execute`.
    security: Arc<SecurityPolicy>,
}
impl NotionTool {
/// Create a new Notion tool with the given API key and security policy.
pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
Self {
api_key,
http: reqwest::Client::new(),
security,
}
}
/// Build the standard Notion API headers (Authorization, version, content-type).
fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", self.api_key)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
);
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
headers.insert("Content-Type", "application/json".parse().unwrap());
Ok(headers)
}
/// Query a Notion database with an optional filter.
async fn query_database(
&self,
database_id: &str,
filter: Option<&serde_json::Value>,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
let mut body = json!({});
if let Some(f) = filter {
body["filter"] = f.clone();
}
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion query_database failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Read a single Notion page by ID.
async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let resp = self
.http
.get(&url)
.headers(self.headers()?)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion read_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Create a new Notion page, optionally within a database.
async fn create_page(
&self,
properties: &serde_json::Value,
database_id: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages");
let mut body = json!({ "properties": properties });
if let Some(db_id) = database_id {
body["parent"] = json!({ "database_id": db_id });
}
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion create_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Update an existing Notion page's properties.
async fn update_page(
&self,
page_id: &str,
properties: &serde_json::Value,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let body = json!({ "properties": properties });
let resp = self
.http
.patch(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion update_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Search the Notion workspace by query string.
async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/search");
let body = json!({ "query": query });
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion search failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
}
/// Build a failed `ToolResult` carrying only an error message.
fn notion_failure(message: String) -> ToolResult {
    ToolResult {
        success: false,
        output: String::new(),
        error: Some(message),
    }
}
#[async_trait]
impl Tool for NotionTool {
    fn name(&self) -> &str {
        "notion"
    }
    fn description(&self) -> &str {
        "Interact with Notion: query databases, read/create/update pages, and search the workspace."
    }
    /// JSON schema advertised to the model. Only `action` is required; the
    /// per-action requirements are validated in `execute`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["query_database", "read_page", "create_page", "update_page", "search"],
                    "description": "The Notion API action to perform"
                },
                "database_id": {
                    "type": "string",
                    "description": "Database ID (required for query_database, optional for create_page)"
                },
                "page_id": {
                    "type": "string",
                    "description": "Page ID (required for read_page and update_page)"
                },
                "filter": {
                    "type": "object",
                    "description": "Notion filter object for query_database"
                },
                "properties": {
                    "type": "object",
                    "description": "Properties object for create_page and update_page"
                },
                "query": {
                    "type": "string",
                    "description": "Search query string for the search action"
                }
            },
            "required": ["action"]
        })
    }
    /// Validate arguments, enforce the security policy, then dispatch to the
    /// matching API call. All failures — bad arguments, denied policy, or API
    /// errors — are reported as an unsuccessful `ToolResult`, never an `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let Some(action) = args.get("action").and_then(|v| v.as_str()) else {
            return Ok(notion_failure("Missing required parameter: action".into()));
        };
        // Enforce granular security: Read for queries, Act for mutations.
        let operation = match action {
            "query_database" | "read_page" | "search" => ToolOperation::Read,
            "create_page" | "update_page" => ToolOperation::Act,
            _ => {
                return Ok(notion_failure(format!(
                    "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
                )));
            }
        };
        if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
            return Ok(notion_failure(error));
        }
        let result = match action {
            "query_database" => {
                let Some(database_id) = args.get("database_id").and_then(|v| v.as_str()) else {
                    return Ok(notion_failure(
                        "query_database requires database_id parameter".into(),
                    ));
                };
                self.query_database(database_id, args.get("filter")).await
            }
            "read_page" => {
                let Some(page_id) = args.get("page_id").and_then(|v| v.as_str()) else {
                    return Ok(notion_failure("read_page requires page_id parameter".into()));
                };
                self.read_page(page_id).await
            }
            "create_page" => {
                let Some(properties) = args.get("properties") else {
                    return Ok(notion_failure(
                        "create_page requires properties parameter".into(),
                    ));
                };
                let database_id = args.get("database_id").and_then(|v| v.as_str());
                self.create_page(properties, database_id).await
            }
            "update_page" => {
                let Some(page_id) = args.get("page_id").and_then(|v| v.as_str()) else {
                    return Ok(notion_failure(
                        "update_page requires page_id parameter".into(),
                    ));
                };
                let Some(properties) = args.get("properties") else {
                    return Ok(notion_failure(
                        "update_page requires properties parameter".into(),
                    ));
                };
                self.update_page(page_id, properties).await
            }
            "search" => {
                let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
                self.search(query).await
            }
            // The first match on `action` already rejected anything else.
            _ => unreachable!(),
        };
        match result {
            Ok(value) => Ok(ToolResult {
                success: true,
                output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
                error: None,
            }),
            Err(e) => Ok(notion_failure(e.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;
    // Tool with a dummy key and default policy; tests below exercise only
    // argument validation, so no network request is ever sent.
    fn test_tool() -> NotionTool {
        let security = Arc::new(SecurityPolicy::default());
        NotionTool::new("test-key".into(), security)
    }
    #[test]
    fn tool_name_is_notion() {
        let tool = test_tool();
        assert_eq!(tool.name(), "notion");
    }
    #[test]
    fn parameters_schema_has_required_action() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("action")));
    }
    #[test]
    fn parameters_schema_defines_all_actions() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
        let action_strs: Vec<&str> = actions.iter().filter_map(|v| v.as_str()).collect();
        assert!(action_strs.contains(&"query_database"));
        assert!(action_strs.contains(&"read_page"));
        assert!(action_strs.contains(&"create_page"));
        assert!(action_strs.contains(&"update_page"));
        assert!(action_strs.contains(&"search"));
    }
    // Missing/unknown `action` is reported as a failed result, not an Err.
    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("action"));
    }
    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Unknown action"));
    }
    // Per-action required-parameter validation happens before any HTTP call.
    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "query_database"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("database_id"));
    }
    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "read_page"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("page_id"));
    }
    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "create_page"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("properties"));
    }
    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "update_page", "properties": {}}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("page_id"));
    }
    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "update_page", "page_id": "test-id"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("properties"));
    }
}
+750
View File
@@ -0,0 +1,750 @@
//! Project delivery intelligence tool.
//!
//! Provides read-only analysis and generation for project management:
//! status reports, risk detection, client communication drafting,
//! sprint summaries, and effort estimation.
use super::report_templates;
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::fmt::Write as _;
/// Project intelligence tool for consulting project management.
///
/// All actions are read-only analysis/generation; nothing is modified externally.
pub struct ProjectIntelTool {
    // Report language used when the caller does not pass `language`.
    default_language: String,
    // Parsed once in `new`; scales the risk-detection thresholds.
    risk_sensitivity: RiskSensitivity,
}
/// Risk detection sensitivity level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskSensitivity {
    Low,
    Medium,
    High,
}
impl RiskSensitivity {
    /// Parse a case-insensitive sensitivity name. Anything other than
    /// "low" or "high" (including unrecognized input) maps to `Medium`.
    fn from_str(s: &str) -> Self {
        let normalized = s.to_lowercase();
        if normalized == "low" {
            RiskSensitivity::Low
        } else if normalized == "high" {
            RiskSensitivity::High
        } else {
            RiskSensitivity::Medium
        }
    }
    /// Threshold multiplier: higher sensitivity means lower thresholds.
    fn threshold_factor(self) -> f64 {
        match self {
            RiskSensitivity::High => 0.5,
            RiskSensitivity::Medium => 1.0,
            RiskSensitivity::Low => 1.5,
        }
    }
}
impl ProjectIntelTool {
    /// Create a tool with the given default report language and risk
    /// sensitivity name ("low" | "medium" | "high"; anything else is medium).
    pub fn new(default_language: String, risk_sensitivity: String) -> Self {
        Self {
            default_language,
            risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity),
        }
    }
    /// Render a periodic status report from git/Jira/notes context.
    ///
    /// Requires `project_name` and `period`; the data inputs fall back to
    /// placeholder text so a partial report is still produced.
    fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let project_name = args
            .get("project_name")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?;
        let period = args
            .get("period")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?;
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);
        let git_log = args
            .get("git_log")
            .and_then(|v| v.as_str())
            .unwrap_or("No git data provided");
        let jira_summary = args
            .get("jira_summary")
            .and_then(|v| v.as_str())
            .unwrap_or("No Jira data provided");
        let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or("");
        // Map the raw inputs onto the template's named slots.
        let tpl = report_templates::weekly_status_template(lang);
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), project_name.to_string());
        vars.insert("period".into(), period.to_string());
        vars.insert("completed".into(), git_log.to_string());
        vars.insert("in_progress".into(), jira_summary.to_string());
        vars.insert("blocked".into(), notes.to_string());
        vars.insert("next_steps".into(), "To be determined".into());
        let rendered = tpl.render(&vars);
        Ok(ToolResult {
            success: true,
            output: rendered,
            error: None,
        })
    }
    /// Detect project risks from deadline, velocity, and blocker signals and
    /// render them as a risk register. Detection thresholds scale with the
    /// configured sensitivity.
    fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let deadlines = args
            .get("deadlines")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let velocity = args
            .get("velocity")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let blockers = args
            .get("blockers")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);
        // Lowercase once up front; each string is scanned for several
        // keywords below (previously re-lowercased per keyword).
        let deadlines_lc = deadlines.to_lowercase();
        let velocity_lc = velocity.to_lowercase();
        let mut risks = Vec::new();
        // Heuristic risk detection based on signals; higher sensitivity
        // lowers the blocker-count thresholds via `factor`.
        let factor = self.risk_sensitivity.threshold_factor();
        if !blockers.is_empty() {
            let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count();
            let severity = if (blocker_count as f64) > 3.0 * factor {
                "critical"
            } else if (blocker_count as f64) > 1.0 * factor {
                "high"
            } else {
                "medium"
            };
            risks.push(RiskItem {
                title: "Active blockers detected".into(),
                severity: severity.into(),
                detail: format!("{blocker_count} blocker(s) identified"),
                mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(),
            });
        }
        if deadlines_lc.contains("overdue") || deadlines_lc.contains("missed") {
            risks.push(RiskItem {
                title: "Deadline risk".into(),
                severity: "high".into(),
                detail: "Overdue or missed deadlines detected in project context".into(),
                mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(),
            });
        }
        if velocity_lc.contains("declining") || velocity_lc.contains("slow") {
            risks.push(RiskItem {
                title: "Velocity degradation".into(),
                severity: "medium".into(),
                detail: "Team velocity is declining or below expectations".into(),
                mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(),
            });
        }
        // Always report something so the register is never empty.
        if risks.is_empty() {
            risks.push(RiskItem {
                title: "No significant risks detected".into(),
                severity: "low".into(),
                detail: "Current project signals within normal parameters".into(),
                mitigation: "Continue monitoring".into(),
            });
        }
        let tpl = report_templates::risk_register_template(lang);
        let risks_text = risks
            .iter()
            .map(|r| {
                format!(
                    "- [{}] {}: {}",
                    r.severity.to_uppercase(),
                    r.title,
                    r.detail
                )
            })
            .collect::<Vec<_>>()
            .join("\n");
        let mitigations_text = risks
            .iter()
            .map(|r| format!("- {}: {}", r.title, r.mitigation))
            .collect::<Vec<_>>()
            .join("\n");
        let mut vars = HashMap::new();
        vars.insert(
            "project_name".into(),
            args.get("project_name")
                .and_then(|v| v.as_str())
                .unwrap_or("Unknown")
                .to_string(),
        );
        vars.insert("risks".into(), risks_text);
        vars.insert("mitigations".into(), mitigations_text);
        Ok(ToolResult {
            success: true,
            output: tpl.render(&vars),
            error: None,
        })
    }
    /// Draft a client/internal update email from highlights and concerns.
    /// Greeting and sign-off vary with `audience` and `tone`.
    fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let project_name = args
            .get("project_name")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?;
        let audience = args
            .get("audience")
            .and_then(|v| v.as_str())
            .unwrap_or("client");
        let tone = args
            .get("tone")
            .and_then(|v| v.as_str())
            .unwrap_or("formal");
        let highlights = args
            .get("highlights")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?;
        let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or("");
        let greeting = match (audience, tone) {
            ("client", "casual") => "Hi there,".to_string(),
            ("client", _) => "Dear valued partner,".to_string(),
            ("internal", "casual") => "Hey team,".to_string(),
            ("internal", _) => "Dear team,".to_string(),
            (_, "casual") => "Hi,".to_string(),
            _ => "Dear reader,".to_string(),
        };
        let closing = match tone {
            "casual" => "Cheers",
            _ => "Best regards",
        };
        let mut body = format!(
            "{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}"
        );
        // Concerns section is only emitted when the caller provided one.
        if !concerns.is_empty() {
            let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}");
        }
        let _ = write!(
            body,
            "\n\nPlease do not hesitate to reach out with any questions.\n\n{closing}"
        );
        Ok(ToolResult {
            success: true,
            output: body,
            error: None,
        })
    }
    /// Render a sprint review from completed / in-progress / blocked items.
    /// All inputs are optional and default to placeholder text.
    fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let sprint_dates = args
            .get("sprint_dates")
            .and_then(|v| v.as_str())
            .unwrap_or("current sprint");
        let completed = args
            .get("completed")
            .and_then(|v| v.as_str())
            .unwrap_or("None specified");
        let in_progress = args
            .get("in_progress")
            .and_then(|v| v.as_str())
            .unwrap_or("None specified");
        let blocked = args
            .get("blocked")
            .and_then(|v| v.as_str())
            .unwrap_or("None");
        let velocity = args
            .get("velocity")
            .and_then(|v| v.as_str())
            .unwrap_or("Not calculated");
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);
        let tpl = report_templates::sprint_review_template(lang);
        let mut vars = HashMap::new();
        vars.insert("sprint_dates".into(), sprint_dates.to_string());
        vars.insert("completed".into(), completed.to_string());
        vars.insert("in_progress".into(), in_progress.to_string());
        vars.insert("blocked".into(), blocked.to_string());
        vars.insert("velocity".into(), velocity.to_string());
        Ok(ToolResult {
            success: true,
            output: tpl.render(&vars),
            error: None,
        })
    }
    /// Estimate effort for each non-empty line of `tasks` using the
    /// keyword/length heuristics in `estimate_task_effort`. Fails (as an
    /// unsuccessful result, not an `Err`) when no tasks are given.
    fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or("");
        if tasks.trim().is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("No task descriptions provided".into()),
            });
        }
        let mut estimates = Vec::new();
        for line in tasks.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            let (size, rationale) = estimate_task_effort(line);
            estimates.push(format!("- **{size}** | {line}\n  Rationale: {rationale}"));
        }
        let output = format!(
            "## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L (1-3d), XL (3-5d), XXL (>5d)_",
            estimates.join("\n")
        );
        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}
/// A single detected risk rendered into the risk register.
struct RiskItem {
    // Short human-readable risk title.
    title: String,
    // Severity label produced by the scan: "low", "medium", "high", or "critical".
    severity: String,
    // What was observed that triggered the risk.
    detail: String,
    // Recommended countermeasure text.
    mitigation: String,
}
/// Heuristic effort estimation from task description text.
///
/// Returns a T-shirt size plus a one-line rationale. Matching is a simple
/// case-insensitive substring check against signal keywords, with the
/// description's word count used to pick within a band (and as the sole
/// criterion when no keyword matches).
fn estimate_task_effort(description: &str) -> (&'static str, &'static str) {
    let text = description.to_lowercase();
    let words = description.split_whitespace().count();
    // Shared keyword check over the lowercased description.
    let has_any = |signals: &[&str]| signals.iter().any(|sig| text.contains(sig));

    // Structural/architectural work dominates everything else.
    if has_any(&[
        "refactor",
        "rewrite",
        "migrate",
        "redesign",
        "architecture",
        "infrastructure",
    ]) {
        return if words > 15 {
            (
                "XXL",
                "Large-scope structural change with extensive description",
            )
        } else {
            ("XL", "Structural change requiring significant effort")
        };
    }
    // Feature-building verbs land in the middle band.
    if has_any(&[
        "implement",
        "create",
        "build",
        "integrate",
        "add feature",
        "new module",
    ]) {
        return if words > 12 {
            ("L", "Feature implementation with detailed requirements")
        } else {
            ("M", "Standard feature implementation")
        };
    }
    // Small tweaks and maintenance chores.
    if has_any(&[
        "fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config",
    ]) {
        return if words > 10 {
            ("S", "Small change with additional context")
        } else {
            ("XS", "Minor targeted change")
        };
    }
    // Fallback: description length as a proxy for complexity.
    match words {
        w if w > 20 => ("L", "Complex task inferred from detailed description"),
        w if w > 10 => ("M", "Moderate task inferred from description length"),
        _ => ("S", "Simple task inferred from brief description"),
    }
}
#[async_trait]
impl Tool for ProjectIntelTool {
    fn name(&self) -> &str {
        "project_intel"
    }
    fn description(&self) -> &str {
        "Project delivery intelligence: generate status reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool."
    }
    /// JSON schema advertised to the model. Only `action` is required; which
    /// of the other fields are consulted depends on the chosen action, as
    /// noted in each description.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"],
                    "description": "The analysis action to perform"
                },
                "project_name": {
                    "type": "string",
                    "description": "Project name (for status_report, risk_scan, draft_update)"
                },
                "period": {
                    "type": "string",
                    "description": "Reporting period: week, sprint, or month (for status_report)"
                },
                "language": {
                    "type": "string",
                    "description": "Report language: en, de, fr, it (default from config)"
                },
                "git_log": {
                    "type": "string",
                    "description": "Git log summary text (for status_report)"
                },
                "jira_summary": {
                    "type": "string",
                    "description": "Jira/issue tracker summary (for status_report)"
                },
                "notes": {
                    "type": "string",
                    "description": "Additional notes or context"
                },
                "deadlines": {
                    "type": "string",
                    "description": "Deadline information (for risk_scan)"
                },
                "velocity": {
                    "type": "string",
                    "description": "Team velocity data (for risk_scan, sprint_summary)"
                },
                "blockers": {
                    "type": "string",
                    "description": "Current blockers (for risk_scan)"
                },
                "audience": {
                    "type": "string",
                    "enum": ["client", "internal"],
                    "description": "Target audience (for draft_update)"
                },
                "tone": {
                    "type": "string",
                    "enum": ["formal", "casual"],
                    "description": "Communication tone (for draft_update)"
                },
                "highlights": {
                    "type": "string",
                    "description": "Key highlights for the update (for draft_update)"
                },
                "concerns": {
                    "type": "string",
                    "description": "Items requiring attention (for draft_update)"
                },
                "sprint_dates": {
                    "type": "string",
                    "description": "Sprint date range (for sprint_summary)"
                },
                "completed": {
                    "type": "string",
                    "description": "Completed items (for sprint_summary)"
                },
                "in_progress": {
                    "type": "string",
                    "description": "In-progress items (for sprint_summary)"
                },
                "blocked": {
                    "type": "string",
                    "description": "Blocked items (for sprint_summary)"
                },
                "tasks": {
                    "type": "string",
                    "description": "Task descriptions, one per line (for effort_estimate)"
                }
            },
            "required": ["action"]
        })
    }
    /// Dispatch on `action`. Note the asymmetry: a missing `action` is a hard
    /// `Err` (tests rely on this), while an unknown action value is reported
    /// as an unsuccessful `ToolResult`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;
        match action {
            "status_report" => self.execute_status_report(&args),
            "risk_scan" => self.execute_risk_scan(&args),
            "draft_update" => self.execute_draft_update(&args),
            "sprint_summary" => self.execute_sprint_summary(&args),
            "effort_estimate" => self.execute_effort_estimate(&args),
            other => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action '{other}'. Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate"
                )),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn tool() -> ProjectIntelTool {
ProjectIntelTool::new("en".into(), "medium".into())
}
#[test]
fn tool_name_and_description() {
let t = tool();
assert_eq!(t.name(), "project_intel");
assert!(!t.description().is_empty());
}
#[test]
fn parameters_schema_has_action() {
let t = tool();
let schema = t.parameters_schema();
assert!(schema["properties"]["action"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&serde_json::Value::String("action".into())));
}
#[tokio::test]
async fn status_report_renders() {
let t = tool();
let result = t
.execute(json!({
"action": "status_report",
"project_name": "TestProject",
"period": "week",
"git_log": "- feat: added login"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("TestProject"));
assert!(result.output.contains("added login"));
}
#[tokio::test]
async fn risk_scan_detects_blockers() {
let t = tool();
let result = t
.execute(json!({
"action": "risk_scan",
"blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("blocker"));
}
#[tokio::test]
async fn risk_scan_detects_deadline_risk() {
let t = tool();
let result = t
.execute(json!({
"action": "risk_scan",
"deadlines": "Sprint deadline overdue by 3 days"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Deadline risk"));
}
#[tokio::test]
async fn risk_scan_no_signals_returns_low_risk() {
let t = tool();
let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No significant risks"));
}
#[tokio::test]
async fn draft_update_formal_client() {
let t = tool();
let result = t
.execute(json!({
"action": "draft_update",
"project_name": "Portal",
"audience": "client",
"tone": "formal",
"highlights": "Phase 1 delivered"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Dear valued partner"));
assert!(result.output.contains("Portal"));
assert!(result.output.contains("Phase 1 delivered"));
}
#[tokio::test]
async fn draft_update_casual_internal() {
let t = tool();
let result = t
.execute(json!({
"action": "draft_update",
"project_name": "ZeroClaw",
"audience": "internal",
"tone": "casual",
"highlights": "Core loop stabilized"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Hey team"));
assert!(result.output.contains("Cheers"));
}
#[tokio::test]
async fn sprint_summary_renders() {
let t = tool();
let result = t
.execute(json!({
"action": "sprint_summary",
"sprint_dates": "2026-03-01 to 2026-03-14",
"completed": "- Login page\n- API endpoints",
"in_progress": "- Dashboard",
"blocked": "- Payment integration"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Login page"));
assert!(result.output.contains("Dashboard"));
}
#[tokio::test]
async fn effort_estimate_basic() {
let t = tool();
let result = t
.execute(json!({
"action": "effort_estimate",
"tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("XS"));
assert!(result.output.contains("Refactor database layer"));
}
#[tokio::test]
async fn effort_estimate_empty_tasks_fails() {
let t = tool();
let result = t
.execute(json!({ "action": "effort_estimate", "tasks": "" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("No task descriptions"));
}
// Unrecognized actions fail gracefully (Ok result carrying an error),
// not with a panic or an Err from execute().
#[tokio::test]
async fn unknown_action_returns_error() {
    let t = tool();
    let result = t
        .execute(json!({ "action": "invalid_thing" }))
        .await
        .unwrap();
    assert!(!result.success);
    assert!(result.error.unwrap().contains("Unknown action"));
}
// Omitting the required "action" field is a hard Err from execute(),
// unlike an unknown action which returns Ok with success = false.
#[tokio::test]
async fn missing_action_returns_error() {
    let t = tool();
    let result = t.execute(json!({})).await;
    assert!(result.is_err());
}
// Spot-check the effort heuristic across the XS..XXL size buckets using
// representative task descriptions.
#[test]
fn effort_estimate_heuristics_coverage() {
    assert_eq!(estimate_task_effort("Fix typo").0, "XS");
    assert_eq!(estimate_task_effort("Update config values").0, "XS");
    assert_eq!(
        estimate_task_effort("Implement new notification system").0,
        "M"
    );
    assert_eq!(
        estimate_task_effort("Refactor the entire authentication module").0,
        "XL"
    );
    assert_eq!(
        estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0,
        "XXL"
    );
}
// Higher sensitivity must mean a strictly lower (easier to cross) threshold:
// High < Medium < Low.
#[test]
fn risk_sensitivity_threshold_ordering() {
    assert!(
        RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor()
    );
    assert!(
        RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor()
    );
}
// Known sensitivity names parse to their variants; unrecognized input
// falls back to Medium.
#[test]
fn risk_sensitivity_from_str_variants() {
    assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low);
    assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High);
    assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium);
    assert_eq!(
        RiskSensitivity::from_str("unknown"),
        RiskSensitivity::Medium
    );
}
// With "high" sensitivity, even a single blocker is escalated to at least
// HIGH severity in the risk scan output.
#[tokio::test]
async fn high_sensitivity_detects_single_blocker_as_high() {
    let t = ProjectIntelTool::new("en".into(), "high".into());
    let result = t
        .execute(json!({
            "action": "risk_scan",
            "blockers": "Single blocker"
        }))
        .await
        .unwrap();
    assert!(result.success);
    assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]"));
}
}
+582
View File
@@ -0,0 +1,582 @@
//! Report template engine for project delivery intelligence.
//!
//! Provides built-in templates for weekly status, sprint review, risk register,
//! and milestone reports with multi-language support (EN, DE, FR, IT).
use std::collections::HashMap;
use std::fmt::Write as _;
/// Supported report output formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReportFormat {
    /// Plain Markdown: sections rendered as `##` headings.
    Markdown,
    /// HTML: sections rendered as `<h2>`/`<p>` pairs with entity escaping.
    Html,
}
/// A named section within a report template.
#[derive(Debug, Clone)]
pub struct TemplateSection {
    // Section title; may contain `{{key}}` placeholders.
    pub heading: String,
    // Section content; may contain `{{key}}` placeholders.
    pub body: String,
}
/// A report template with named sections and variable placeholders.
#[derive(Debug, Clone)]
pub struct ReportTemplate {
    // Human-readable, localized template name.
    pub name: String,
    // Sections rendered in order, top to bottom.
    pub sections: Vec<TemplateSection>,
    // Output format used by `render`.
    pub format: ReportFormat,
}
/// Escape a string for safe inclusion in HTML output.
///
/// Replaces `&`, `<`, `>`, `"`, and `'` with their HTML entities in a single
/// pass. (The previous chain of five `str::replace` calls allocated an
/// intermediate `String` per call; this version allocates once and appends.)
fn escape_html(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            other => out.push(other),
        }
    }
    out
}
impl ReportTemplate {
/// Render the template by substituting `{{key}}` placeholders with values.
pub fn render(&self, vars: &HashMap<String, String>) -> String {
let mut out = String::new();
for section in &self.sections {
let heading = substitute(&section.heading, vars);
let body = substitute(&section.body, vars);
match self.format {
ReportFormat::Markdown => {
let _ = write!(out, "## {heading}\n\n{body}\n\n");
}
ReportFormat::Html => {
let heading = escape_html(&heading);
let body = escape_html(&body);
let _ = write!(out, "<h2>{heading}</h2>\n<p>{body}</p>\n");
}
}
}
out.trim_end().to_string()
}
}
/// Single-pass placeholder substitution.
///
/// Scans `template` left-to-right for `{{key}}` tokens and replaces each with
/// the corresponding value from `vars`; unknown placeholders are emitted
/// verbatim. Because the scan is single-pass, values that themselves contain
/// `{{...}}` sequences are emitted literally and never re-expanded,
/// preventing injection of new placeholders.
///
/// Literal text is appended as whole `&str` slices, so multi-byte UTF-8
/// template text survives intact. (The previous implementation pushed
/// individual bytes as `char`s, which reinterpreted UTF-8 continuation bytes
/// as Latin-1 code points and mangled any non-ASCII template text.)
fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;
    // Each iteration consumes the literal text up to, and then the whole of,
    // one `{{...}}` token. `find` returns byte offsets on char boundaries,
    // so all slicing below is UTF-8 safe.
    while let Some(open) = rest.find("{{") {
        let Some(close) = rest[open + 2..].find("}}") else {
            // Unterminated `{{`: no further placeholders are possible.
            break;
        };
        let end = open + 2 + close + 2; // index just past the closing `}}`
        result.push_str(&rest[..open]);
        let key = &rest[open + 2..open + 2 + close];
        match vars.get(key) {
            Some(value) => result.push_str(value),
            // Unknown placeholder: emit as-is.
            None => result.push_str(&rest[open..end]),
        }
        rest = &rest[end..];
    }
    result.push_str(rest);
    result
}
// ── Built-in templates ────────────────────────────────────────────
/// Return the built-in weekly status template for the given language.
///
/// Recognized languages are `"de"`, `"fr"`, and `"it"`; any other value
/// falls back to English. Output format is Markdown.
pub fn weekly_status_template(lang: &str) -> ReportTemplate {
    // Per-language (name, [(heading, body); 5]) table; English is the fallback.
    let (name, rows): (&str, [(&str, &str); 5]) = match lang {
        "de" => (
            "Wochenstatus",
            [
                ("Zusammenfassung", "Projekt: {{project_name}} | Zeitraum: {{period}}"),
                ("Erledigt", "{{completed}}"),
                ("In Bearbeitung", "{{in_progress}}"),
                ("Blockiert", "{{blocked}}"),
                ("Naechste Schritte", "{{next_steps}}"),
            ],
        ),
        "fr" => (
            "Statut hebdomadaire",
            [
                ("Resume", "Projet: {{project_name}} | Periode: {{period}}"),
                ("Termine", "{{completed}}"),
                ("En cours", "{{in_progress}}"),
                ("Bloque", "{{blocked}}"),
                ("Prochaines etapes", "{{next_steps}}"),
            ],
        ),
        "it" => (
            "Stato settimanale",
            [
                ("Riepilogo", "Progetto: {{project_name}} | Periodo: {{period}}"),
                ("Completato", "{{completed}}"),
                ("In corso", "{{in_progress}}"),
                ("Bloccato", "{{blocked}}"),
                ("Prossimi passi", "{{next_steps}}"),
            ],
        ),
        _ => (
            "Weekly Status",
            [
                ("Summary", "Project: {{project_name}} | Period: {{period}}"),
                ("Completed", "{{completed}}"),
                ("In Progress", "{{in_progress}}"),
                ("Blocked", "{{blocked}}"),
                ("Next Steps", "{{next_steps}}"),
            ],
        ),
    };
    ReportTemplate {
        name: name.into(),
        sections: rows
            .iter()
            .map(|&(heading, body)| TemplateSection {
                heading: heading.into(),
                body: body.into(),
            })
            .collect(),
        format: ReportFormat::Markdown,
    }
}
/// Return the built-in sprint review template for the given language.
///
/// Recognized languages are `"de"`, `"fr"`, and `"it"`; any other value
/// falls back to English. Output format is Markdown.
pub fn sprint_review_template(lang: &str) -> ReportTemplate {
    // Per-language (name, [(heading, body); 5]) table; English is the fallback.
    let (name, rows): (&str, [(&str, &str); 5]) = match lang {
        "de" => (
            "Sprint-Uebersicht",
            [
                ("Sprint", "{{sprint_dates}}"),
                ("Erledigt", "{{completed}}"),
                ("In Bearbeitung", "{{in_progress}}"),
                ("Blockiert", "{{blocked}}"),
                ("Velocity", "{{velocity}}"),
            ],
        ),
        "fr" => (
            "Revue de sprint",
            [
                ("Sprint", "{{sprint_dates}}"),
                ("Termine", "{{completed}}"),
                ("En cours", "{{in_progress}}"),
                ("Bloque", "{{blocked}}"),
                ("Velocite", "{{velocity}}"),
            ],
        ),
        "it" => (
            "Revisione sprint",
            [
                ("Sprint", "{{sprint_dates}}"),
                ("Completato", "{{completed}}"),
                ("In corso", "{{in_progress}}"),
                ("Bloccato", "{{blocked}}"),
                ("Velocita", "{{velocity}}"),
            ],
        ),
        _ => (
            "Sprint Review",
            [
                ("Sprint", "{{sprint_dates}}"),
                ("Completed", "{{completed}}"),
                ("In Progress", "{{in_progress}}"),
                ("Blocked", "{{blocked}}"),
                ("Velocity", "{{velocity}}"),
            ],
        ),
    };
    ReportTemplate {
        name: name.into(),
        sections: rows
            .iter()
            .map(|&(heading, body)| TemplateSection {
                heading: heading.into(),
                body: body.into(),
            })
            .collect(),
        format: ReportFormat::Markdown,
    }
}
/// Return the built-in risk register template for the given language.
///
/// Recognized languages are `"de"`, `"fr"`, and `"it"`; any other value
/// falls back to English. Output format is Markdown.
pub fn risk_register_template(lang: &str) -> ReportTemplate {
    // Per-language (name, [(heading, body); 3]) table; English is the fallback.
    let (name, rows): (&str, [(&str, &str); 3]) = match lang {
        "de" => (
            "Risikoregister",
            [
                ("Projekt", "{{project_name}}"),
                ("Risiken", "{{risks}}"),
                ("Massnahmen", "{{mitigations}}"),
            ],
        ),
        "fr" => (
            "Registre des risques",
            [
                ("Projet", "{{project_name}}"),
                ("Risques", "{{risks}}"),
                ("Mesures", "{{mitigations}}"),
            ],
        ),
        "it" => (
            "Registro dei rischi",
            [
                ("Progetto", "{{project_name}}"),
                ("Rischi", "{{risks}}"),
                ("Mitigazioni", "{{mitigations}}"),
            ],
        ),
        _ => (
            "Risk Register",
            [
                ("Project", "{{project_name}}"),
                ("Risks", "{{risks}}"),
                ("Mitigations", "{{mitigations}}"),
            ],
        ),
    };
    ReportTemplate {
        name: name.into(),
        sections: rows
            .iter()
            .map(|&(heading, body)| TemplateSection {
                heading: heading.into(),
                body: body.into(),
            })
            .collect(),
        format: ReportFormat::Markdown,
    }
}
/// Return the built-in milestone report template for the given language.
///
/// Recognized languages are `"de"`, `"fr"`, and `"it"`; any other value
/// falls back to English. Output format is Markdown.
pub fn milestone_report_template(lang: &str) -> ReportTemplate {
    // Per-language (name, [(heading, body); 3]) table; English is the fallback.
    let (name, rows): (&str, [(&str, &str); 3]) = match lang {
        "de" => (
            "Meilensteinbericht",
            [
                ("Projekt", "{{project_name}}"),
                ("Meilensteine", "{{milestones}}"),
                ("Status", "{{status}}"),
            ],
        ),
        "fr" => (
            "Rapport de jalons",
            [
                ("Projet", "{{project_name}}"),
                ("Jalons", "{{milestones}}"),
                ("Statut", "{{status}}"),
            ],
        ),
        "it" => (
            "Report milestone",
            [
                ("Progetto", "{{project_name}}"),
                ("Milestone", "{{milestones}}"),
                ("Stato", "{{status}}"),
            ],
        ),
        _ => (
            "Milestone Report",
            [
                ("Project", "{{project_name}}"),
                ("Milestones", "{{milestones}}"),
                ("Status", "{{status}}"),
            ],
        ),
    };
    ReportTemplate {
        name: name.into(),
        sections: rows
            .iter()
            .map(|&(heading, body)| TemplateSection {
                heading: heading.into(),
                body: body.into(),
            })
            .collect(),
        format: ReportFormat::Markdown,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // English weekly template substitutes every variable and renders `##` headings.
    #[test]
    fn weekly_status_renders_with_variables() {
        let tpl = weekly_status_template("en");
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "ZeroClaw".into());
        vars.insert("period".into(), "2026-W10".into());
        vars.insert("completed".into(), "- Task A\n- Task B".into());
        vars.insert("in_progress".into(), "- Task C".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "- Task D".into());
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("Project: ZeroClaw"));
        assert!(rendered.contains("Period: 2026-W10"));
        assert!(rendered.contains("- Task A"));
        assert!(rendered.contains("## Completed"));
    }
    // German localization produces German section headings.
    #[test]
    fn weekly_status_de_renders_german_headings() {
        let tpl = weekly_status_template("de");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Zusammenfassung"));
        assert!(rendered.contains("## Erledigt"));
    }
    // French localization produces French section headings.
    #[test]
    fn weekly_status_fr_renders_french_headings() {
        let tpl = weekly_status_template("fr");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Resume"));
        assert!(rendered.contains("## Termine"));
    }
    // Italian localization produces Italian section headings.
    #[test]
    fn weekly_status_it_renders_italian_headings() {
        let tpl = weekly_status_template("it");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Riepilogo"));
        assert!(rendered.contains("## Completato"));
    }
    // Switching the format to Html emits <h2>/<p> markup instead of Markdown.
    #[test]
    fn html_format_renders_tags() {
        let mut tpl = weekly_status_template("en");
        tpl.format = ReportFormat::Html;
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "Test".into());
        vars.insert("period".into(), "W1".into());
        vars.insert("completed".into(), "Done".into());
        vars.insert("in_progress".into(), "WIP".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "Next".into());
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("<h2>Summary</h2>"));
        assert!(rendered.contains("<p>Project: Test | Period: W1</p>"));
    }
    // Sprint review must always include a Velocity section.
    #[test]
    fn sprint_review_template_has_velocity_section() {
        let tpl = sprint_review_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Velocity"));
    }
    // Risk register must carry both Risks and Mitigations sections.
    #[test]
    fn risk_register_template_has_risk_sections() {
        let tpl = risk_register_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Risks"));
        assert!(section_headings.contains(&"Mitigations"));
    }
    // Every supported language yields a named, three-section milestone template.
    #[test]
    fn milestone_template_all_languages() {
        for lang in &["en", "de", "fr", "it"] {
            let tpl = milestone_report_template(lang);
            assert!(!tpl.name.is_empty());
            assert_eq!(tpl.sections.len(), 3);
        }
    }
    // Placeholders without a matching variable are emitted verbatim.
    #[test]
    fn substitute_leaves_unknown_placeholders() {
        let vars = HashMap::new();
        let result = substitute("Hello {{name}}", &vars);
        assert_eq!(result, "Hello {{name}}");
    }
    // Every occurrence of a placeholder is substituted, not just the first.
    #[test]
    fn substitute_replaces_all_occurrences() {
        let mut vars = HashMap::new();
        vars.insert("x".into(), "1".into());
        let result = substitute("{{x}} and {{x}}", &vars);
        assert_eq!(result, "1 and 1");
    }
}
+59 -2
View File
@@ -2,6 +2,7 @@ mod cloudflare;
mod custom;
mod ngrok;
mod none;
mod openvpn;
mod tailscale;
pub use cloudflare::CloudflareTunnel;
@@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
pub use ngrok::NgrokTunnel;
#[allow(unused_imports)]
pub use none::NoneTunnel;
pub use openvpn::OpenVpnTunnel;
pub use tailscale::TailscaleTunnel;
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
@@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
))))
}
"openvpn" => {
let ov = config
.openvpn
.as_ref()
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
Ok(Some(Box::new(OpenVpnTunnel::new(
ov.config_file.clone(),
ov.auth_file.clone(),
ov.advertise_address.clone(),
ov.connect_timeout_secs,
ov.extra_args.clone(),
))))
}
"custom" => {
let cu = config
.custom
@@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
))))
}
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"),
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
}
}
@@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
mod tests {
use super::*;
use crate::config::schema::{
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig,
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
TunnelConfig,
};
use tokio::process::Command;
@@ -315,6 +332,46 @@ mod tests {
assert!(t.public_url().is_none());
}
#[test]
fn factory_openvpn_missing_config_errors() {
let cfg = TunnelConfig {
provider: "openvpn".into(),
..TunnelConfig::default()
};
assert_tunnel_err(&cfg, "[tunnel.openvpn]");
}
#[test]
fn factory_openvpn_with_config_ok() {
let cfg = TunnelConfig {
provider: "openvpn".into(),
openvpn: Some(OpenVpnTunnelConfig {
config_file: "client.ovpn".into(),
auth_file: None,
advertise_address: None,
connect_timeout_secs: 30,
extra_args: vec![],
}),
..TunnelConfig::default()
};
let t = create_tunnel(&cfg).unwrap();
assert!(t.is_some());
assert_eq!(t.unwrap().name(), "openvpn");
}
#[test]
fn openvpn_tunnel_name() {
let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert_eq!(t.name(), "openvpn");
assert!(t.public_url().is_none());
}
#[tokio::test]
async fn openvpn_health_false_before_start() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert!(!tunnel.health_check().await);
}
#[tokio::test]
async fn kill_shared_no_process_is_ok() {
let proc = new_shared_process();
+254
View File
@@ -0,0 +1,254 @@
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
use anyhow::{bail, Result};
use tokio::io::AsyncBufReadExt;
use tokio::process::Command;
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
///
/// Requires the `openvpn` binary installed and accessible. On most systems,
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
///
/// The tunnel exposes the gateway via the VPN network using a configured
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
pub struct OpenVpnTunnel {
    // Path to the `.ovpn` file passed via `--config`.
    config_file: String,
    // Optional credentials file passed via `--auth-user-pass`.
    auth_file: Option<String>,
    // Address reported as the public URL after connecting; when absent,
    // `start` falls back to `http://{local_host}:{local_port}`.
    advertise_address: Option<String>,
    // Seconds to wait for "Initialization Sequence Completed" on stderr.
    connect_timeout_secs: u64,
    // Extra CLI arguments appended to the `openvpn` invocation.
    extra_args: Vec<String>,
    // Shared handle to the spawned child process; empty until `start` succeeds.
    proc: SharedProcess,
}
impl OpenVpnTunnel {
/// Create a new OpenVPN tunnel instance.
///
/// * `config_file` — path to the `.ovpn` configuration file.
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
/// * `advertise_address` — optional public address to advertise once connected.
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
pub fn new(
config_file: String,
auth_file: Option<String>,
advertise_address: Option<String>,
connect_timeout_secs: u64,
extra_args: Vec<String>,
) -> Self {
Self {
config_file,
auth_file,
advertise_address,
connect_timeout_secs,
extra_args,
proc: new_shared_process(),
}
}
/// Build the openvpn command arguments.
fn build_args(&self) -> Vec<String> {
let mut args = vec!["--config".to_string(), self.config_file.clone()];
if let Some(ref auth) = self.auth_file {
args.push("--auth-user-pass".to_string());
args.push(auth.clone());
}
args.extend(self.extra_args.iter().cloned());
args
}
}
#[async_trait::async_trait]
impl Tunnel for OpenVpnTunnel {
    // Stable provider identifier used in logs and the channel factory.
    fn name(&self) -> &str {
        "openvpn"
    }
    /// Spawn the `openvpn` process and wait for the "Initialization Sequence
    /// Completed" marker on stderr. Returns the public URL on success.
    ///
    /// # Errors
    /// Fails when the config file is missing, the process cannot be spawned,
    /// stderr closes or errors before the marker appears, or the overall
    /// `connect_timeout_secs` deadline elapses.
    async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
        // Validate config file exists before spawning
        if !std::path::Path::new(&self.config_file).exists() {
            bail!("OpenVPN config file not found: {}", self.config_file);
        }
        let args = self.build_args();
        // stdout is discarded; OpenVPN's status log goes to stderr.
        // `kill_on_drop` guarantees the child dies if this handle is dropped
        // before it is stored in `self.proc`.
        let mut child = Command::new("openvpn")
            .args(&args)
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::piped())
            .kill_on_drop(true)
            .spawn()?;
        // Wait for "Initialization Sequence Completed" in stderr
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;
        let mut reader = tokio::io::BufReader::new(stderr).lines();
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_secs(self.connect_timeout_secs);
        let mut connected = false;
        // Read line-by-line with a short per-read timeout so a silent process
        // cannot park us indefinitely between deadline checks.
        // NOTE(review): the deadline is only checked between reads, so the
        // total wait can overshoot `connect_timeout_secs` by up to ~3s.
        while tokio::time::Instant::now() < deadline {
            let line =
                tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;
            match line {
                Ok(Ok(Some(l))) => {
                    tracing::debug!("openvpn: {l}");
                    if l.contains("Initialization Sequence Completed") {
                        connected = true;
                        break;
                    }
                }
                // EOF on stderr: the process closed its output (exited or is
                // about to) before signalling a completed connection.
                Ok(Ok(None)) => {
                    bail!("OpenVPN process exited before connection was established");
                }
                Ok(Err(e)) => {
                    bail!("Error reading openvpn output: {e}");
                }
                Err(_) => {
                    // Timeout on individual line read, continue waiting
                }
            }
        }
        if !connected {
            // Best-effort kill; the error reported is the timeout itself.
            child.kill().await.ok();
            bail!(
                "OpenVPN connection timed out after {}s waiting for initialization",
                self.connect_timeout_secs
            );
        }
        // Prefer the configured VPN-side address; otherwise advertise the
        // local listener directly.
        let public_url = self
            .advertise_address
            .clone()
            .unwrap_or_else(|| format!("http://{local_host}:{local_port}"));
        // Drain stderr in background to prevent OS pipe buffer from filling and
        // blocking the openvpn process.
        tokio::spawn(async move {
            while let Ok(Some(line)) = reader.next_line().await {
                tracing::trace!("openvpn: {line}");
            }
        });
        // Publish the running child so stop()/health_check()/public_url()
        // can observe it.
        let mut guard = self.proc.lock().await;
        *guard = Some(TunnelProcess {
            child,
            public_url: public_url.clone(),
        });
        Ok(public_url)
    }
    /// Kill the openvpn child process and release its resources.
    async fn stop(&self) -> Result<()> {
        kill_shared(&self.proc).await
    }
    /// Return `true` if the openvpn child process is still running.
    async fn health_check(&self) -> bool {
        let guard = self.proc.lock().await;
        // `Child::id()` returns None once the child has been reaped.
        guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
    }
    /// Return the public URL if the tunnel has been started.
    fn public_url(&self) -> Option<String> {
        // `try_lock` keeps this sync accessor non-blocking: if the lock is
        // contended, report None rather than wait.
        self.proc
            .try_lock()
            .ok()
            .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // The constructor must store every argument verbatim.
    #[test]
    fn constructor_stores_fields() {
        let tunnel = OpenVpnTunnel::new(
            "/etc/openvpn/client.ovpn".into(),
            Some("/etc/openvpn/auth.txt".into()),
            Some("10.8.0.2:42617".into()),
            45,
            vec!["--verb".into(), "3".into()],
        );
        assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn");
        assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
        assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617"));
        assert_eq!(tunnel.connect_timeout_secs, 45);
        assert_eq!(tunnel.extra_args, vec!["--verb", "3"]);
    }
    // Minimal invocation: only --config <file>.
    #[test]
    fn build_args_basic() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        let args = tunnel.build_args();
        assert_eq!(args, vec!["--config", "client.ovpn"]);
    }
    // Auth file adds --auth-user-pass; extra args are appended last, in order.
    #[test]
    fn build_args_with_auth_and_extras() {
        let tunnel = OpenVpnTunnel::new(
            "client.ovpn".into(),
            Some("auth.txt".into()),
            None,
            30,
            vec!["--verb".into(), "5".into()],
        );
        let args = tunnel.build_args();
        assert_eq!(
            args,
            vec![
                "--config",
                "client.ovpn",
                "--auth-user-pass",
                "auth.txt",
                "--verb",
                "5"
            ]
        );
    }
    // No public URL exists before the tunnel is started.
    #[test]
    fn public_url_is_none_before_start() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(tunnel.public_url().is_none());
    }
    // Health check reports unhealthy when no child process has been spawned.
    #[tokio::test]
    async fn health_check_is_false_before_start() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(!tunnel.health_check().await);
    }
    // stop() on a never-started tunnel is a harmless no-op.
    #[tokio::test]
    async fn stop_without_started_process_is_ok() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        let result = tunnel.stop().await;
        assert!(result.is_ok());
    }
    // start() validates the config path up front, before spawning openvpn.
    #[tokio::test]
    async fn start_with_missing_config_file_errors() {
        let tunnel = OpenVpnTunnel::new(
            "/nonexistent/path/to/client.ovpn".into(),
            None,
            None,
            30,
            vec![],
        );
        let result = tunnel.start("127.0.0.1", 8080).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("config file not found"));
    }
}