Merge remote-tracking branch 'origin/master' into work/cloud-ops-v2
# Conflicts: # src/config/mod.rs # src/config/schema.rs # src/onboard/wizard.rs # src/tools/mod.rs
This commit is contained in:
commit
e7ad69d69a
@ -30,6 +30,7 @@ pub mod mattermost;
|
||||
pub mod nextcloud_talk;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub mod nostr;
|
||||
pub mod notion;
|
||||
pub mod qq;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
@ -62,6 +63,7 @@ pub use mattermost::MattermostChannel;
|
||||
pub use nextcloud_talk::NextcloudTalkChannel;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub use nostr::NostrChannel;
|
||||
pub use notion::NotionChannel;
|
||||
pub use qq::QQChannel;
|
||||
pub use signal::SignalChannel;
|
||||
pub use slack::SlackChannel;
|
||||
@ -2981,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con
|
||||
channel.name()
|
||||
);
|
||||
}
|
||||
// Notion is a top-level config section, not part of ChannelsConfig
|
||||
{
|
||||
let notion_configured =
|
||||
config.notion.enabled && !config.notion.database_id.trim().is_empty();
|
||||
println!(" {} Notion", if notion_configured { "✅" } else { "❌" });
|
||||
}
|
||||
if !cfg!(feature = "channel-matrix") {
|
||||
println!(
|
||||
" ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
|
||||
@ -3413,6 +3421,34 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
// Notion database poller channel
|
||||
if config.notion.enabled && !config.notion.database_id.trim().is_empty() {
|
||||
let notion_api_key = if config.notion.api_key.trim().is_empty() {
|
||||
std::env::var("NOTION_API_KEY").unwrap_or_default()
|
||||
} else {
|
||||
config.notion.api_key.trim().to_string()
|
||||
};
|
||||
if notion_api_key.trim().is_empty() {
|
||||
tracing::warn!(
|
||||
"Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
|
||||
);
|
||||
} else {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Notion",
|
||||
channel: Arc::new(NotionChannel::new(
|
||||
notion_api_key,
|
||||
config.notion.database_id.clone(),
|
||||
config.notion.poll_interval_secs,
|
||||
config.notion.status_property.clone(),
|
||||
config.notion.input_property.clone(),
|
||||
config.notion.result_property.clone(),
|
||||
config.notion.max_concurrent,
|
||||
config.notion.recover_stale,
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
channels
|
||||
}
|
||||
|
||||
|
||||
614
src/channels/notion.rs
Normal file
614
src/channels/notion.rs
Normal file
@ -0,0 +1,614 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use anyhow::{bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
|
||||
const NOTION_VERSION: &str = "2022-06-28";
|
||||
const MAX_RESULT_LENGTH: usize = 2000;
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_BASE_DELAY_MS: u64 = 2000;
|
||||
/// Maximum number of characters to include from an error response body.
|
||||
const MAX_ERROR_BODY_CHARS: usize = 500;
|
||||
|
||||
/// Find the largest byte index <= `max_bytes` that falls on a UTF-8 char boundary.
///
/// Index 0 is always a boundary, so the result is well-defined for any input;
/// when `max_bytes` reaches past the end of `s`, the full length is returned.
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
    if max_bytes >= s.len() {
        return s.len();
    }
    // Walk down from the cap until we land on a boundary; 0 always qualifies.
    (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0)
}
|
||||
|
||||
/// Notion channel — polls a Notion database for pending tasks and writes results back.
///
/// The channel connects to the Notion API, queries a database for rows with a "pending"
/// status, dispatches them as channel messages, and writes results back when processing
/// completes. It supports crash recovery by resetting stale "running" tasks on startup.
pub struct NotionChannel {
    /// Notion integration token, sent as a `Bearer` Authorization header.
    api_key: String,
    /// ID of the database that is polled for tasks.
    database_id: String,
    /// Seconds to sleep between polling passes in `listen`.
    poll_interval_secs: u64,
    /// Name of the status column (a Notion "select" or "status" property).
    status_property: String,
    /// Name of the column holding the task input text ("title" or "rich_text").
    input_property: String,
    /// Name of the rich-text column that receives results.
    result_property: String,
    /// Maximum number of tasks allowed in flight at once (enforced by `claim_task`).
    max_concurrent: usize,
    /// Detected type of the status property ("select" or "status").
    /// Initialized to "select"; refined by schema detection when `listen` starts.
    status_type: Arc<RwLock<String>>,
    /// Page IDs currently being processed (the claim/release set).
    inflight: Arc<RwLock<HashSet<String>>>,
    /// Shared HTTP client used for all Notion API calls.
    http: reqwest::Client,
    /// Whether to reset stale "running" tasks back to "pending" on startup.
    recover_stale: bool,
}
|
||||
|
||||
impl NotionChannel {
    /// Create a new Notion channel with the given configuration.
    ///
    /// The status property type starts as `"select"` and is re-detected from the
    /// live database schema when `listen` begins.
    pub fn new(
        api_key: String,
        database_id: String,
        poll_interval_secs: u64,
        status_property: String,
        input_property: String,
        result_property: String,
        max_concurrent: usize,
        recover_stale: bool,
    ) -> Self {
        Self {
            api_key,
            database_id,
            poll_interval_secs,
            status_property,
            input_property,
            result_property,
            max_concurrent,
            // Default until detect_status_type() inspects the real schema.
            status_type: Arc::new(RwLock::new("select".to_string())),
            inflight: Arc::new(RwLock::new(HashSet::new())),
            http: reqwest::Client::new(),
            recover_stale,
        }
    }

    /// Build the standard Notion API headers (Authorization, version, content-type).
    ///
    /// # Errors
    /// Fails only when the API key contains bytes that are invalid in an HTTP
    /// header value.
    fn headers(&self) -> Result<reqwest::header::HeaderMap> {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "Authorization",
            format!("Bearer {}", self.api_key)
                .parse()
                .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
        );
        // Static, known-valid header values: these parses cannot fail.
        headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
        headers.insert("Content-Type", "application/json".parse().unwrap());
        Ok(headers)
    }

    /// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx).
    ///
    /// Up to `MAX_RETRIES` attempts with exponential backoff starting at
    /// `RETRY_BASE_DELAY_MS`. Non-retryable client errors (4xx other than 429)
    /// fail immediately, carrying a truncated copy of the response body.
    ///
    /// # Errors
    /// Returns the last transport or HTTP-status error once retries are exhausted.
    async fn api_call(
        &self,
        method: reqwest::Method,
        url: &str,
        body: Option<serde_json::Value>,
    ) -> Result<serde_json::Value> {
        let mut last_err = None;
        for attempt in 0..MAX_RETRIES {
            let mut req = self
                .http
                .request(method.clone(), url)
                .headers(self.headers()?);
            if let Some(ref b) = body {
                req = req.json(b);
            }
            match req.send().await {
                Ok(resp) => {
                    let status = resp.status();
                    if status.is_success() {
                        return resp
                            .json()
                            .await
                            .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
                    }
                    let status_code = status.as_u16();
                    // Only retry on 429 (rate limit) or 5xx (server errors)
                    if status_code != 429 && (400..500).contains(&status_code) {
                        let body_text = resp.text().await.unwrap_or_default();
                        let truncated =
                            crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
                        bail!("Notion API error {status_code}: {truncated}");
                    }
                    last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
                }
                Err(e) => {
                    last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
                }
            }
            // Exponential backoff: 2s, 4s, 8s, ...
            // NOTE(review): this sleep also runs after the FINAL failed attempt,
            // delaying the returned error by one full backoff period — consider
            // skipping it on the last iteration.
            let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
            tracing::warn!(
                "Notion API call failed (attempt {}/{}), retrying in {}ms",
                attempt + 1,
                MAX_RETRIES,
                delay
            );
            tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
        }
        Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
    }

    /// Query the database schema and detect whether Status uses "select" or "status" type.
    ///
    /// Falls back to "select" when the property is missing or its type cannot
    /// be read from the schema response.
    async fn detect_status_type(&self) -> Result<String> {
        let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
        let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
        let status_type = resp
            .get("properties")
            .and_then(|p| p.get(&self.status_property))
            .and_then(|s| s.get("type"))
            .and_then(|t| t.as_str())
            .unwrap_or("select")
            .to_string();
        Ok(status_type)
    }

    /// Query for rows where Status = "pending".
    ///
    /// Returns an empty list when the response has no `results` array.
    async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
        let status_type = self.status_type.read().await.clone();
        let filter = build_status_filter(&self.status_property, &status_type, "pending");
        let resp = self
            .api_call(
                reqwest::Method::POST,
                &url,
                Some(serde_json::json!({ "filter": filter })),
            )
            .await?;
        Ok(resp
            .get("results")
            .and_then(|r| r.as_array())
            .cloned()
            .unwrap_or_default())
    }

    /// Atomically claim a task. Returns true if this caller got it.
    ///
    /// Refuses when the page is already in flight or when `max_concurrent`
    /// tasks are running. The single write lock makes check-and-insert atomic.
    async fn claim_task(&self, page_id: &str) -> bool {
        let mut inflight = self.inflight.write().await;
        if inflight.contains(page_id) {
            return false;
        }
        if inflight.len() >= self.max_concurrent {
            return false;
        }
        inflight.insert(page_id.to_string());
        true
    }

    /// Release a task from the inflight set.
    ///
    /// A no-op when the page was never claimed.
    async fn release_task(&self, page_id: &str) {
        let mut inflight = self.inflight.write().await;
        inflight.remove(page_id);
    }

    /// Update a row's status.
    ///
    /// Uses the detected property type to shape the payload ("select" vs "status").
    async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
        let status_type = self.status_type.read().await.clone();
        let payload = serde_json::json!({
            "properties": {
                &self.status_property: build_status_payload(&status_type, status_value),
            }
        });
        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
            .await?;
        Ok(())
    }

    /// Write result text to the Result column.
    ///
    /// Text beyond the Notion rich-text limit is truncated inside
    /// `build_rich_text_payload`.
    async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
        let payload = serde_json::json!({
            "properties": {
                &self.result_property: build_rich_text_payload(result_text),
            }
        });
        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
            .await?;
        Ok(())
    }

    /// On startup, reset "running" tasks back to "pending" for crash recovery.
    ///
    /// Pages left in "running" by a crashed poller are moved back to "pending"
    /// and annotated in the Result column; failures to reset an individual page
    /// are logged and skipped so the rest can still be recovered.
    async fn recover_stale(&self) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
        let status_type = self.status_type.read().await.clone();
        let filter = build_status_filter(&self.status_property, &status_type, "running");
        let resp = self
            .api_call(
                reqwest::Method::POST,
                &url,
                Some(serde_json::json!({ "filter": filter })),
            )
            .await?;
        let stale = resp
            .get("results")
            .and_then(|r| r.as_array())
            .cloned()
            .unwrap_or_default();
        if stale.is_empty() {
            return Ok(());
        }
        tracing::warn!(
            "Found {} stale task(s) in 'running' state, resetting to 'pending'",
            stale.len()
        );
        for task in &stale {
            if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
                let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
                let payload = serde_json::json!({
                    "properties": {
                        &self.status_property: build_status_payload(&status_type, "pending"),
                        &self.result_property: build_rich_text_payload(
                            "Reset: poller restarted while task was running"
                        ),
                    }
                });
                // Log only an 8-char prefix of the page ID (boundary-safe slice).
                let short_id_end = floor_utf8_char_boundary(page_id, 8);
                let short_id = &page_id[..short_id_end];
                if let Err(e) = self
                    .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
                    .await
                {
                    tracing::error!("Could not reset stale task {short_id}: {e}");
                } else {
                    tracing::info!("Reset stale task {short_id} to pending");
                }
            }
        }
        Ok(())
    }
}
|
||||
|
||||
#[async_trait]
impl Channel for NotionChannel {
    /// Stable channel identifier used for routing and logging.
    fn name(&self) -> &str {
        "notion"
    }

    /// Write a processing result back to the originating Notion page.
    ///
    /// `message.recipient` carries the page ID. The page's status is set to
    /// "done", the result column receives `message.content`, and the page is
    /// released from the inflight set so it can be claimed again.
    async fn send(&self, message: &SendMessage) -> Result<()> {
        // recipient is the page_id for Notion
        let page_id = &message.recipient;
        let status_type = self.status_type.read().await.clone();
        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
        let payload = serde_json::json!({
            "properties": {
                &self.status_property: build_status_payload(&status_type, "done"),
                &self.result_property: build_rich_text_payload(&message.content),
            }
        });
        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
            .await?;
        // NOTE(review): if the PATCH above fails, `?` returns early and the
        // page stays in the inflight set (it is never released here).
        self.release_task(page_id).await;
        Ok(())
    }

    /// Poll the database forever, dispatching pending tasks as channel messages.
    ///
    /// Startup: detects the status property type (fatal on failure), then runs
    /// stale-task recovery if configured. Each poll pass claims pending rows
    /// (bounded by `max_concurrent`), marks them "running", and forwards them
    /// on `tx`. Returns `Ok(())` when the receiver side of `tx` is dropped.
    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
        // Detect status property type
        match self.detect_status_type().await {
            Ok(st) => {
                tracing::info!("Notion status property type: {st}");
                *self.status_type.write().await = st;
            }
            Err(e) => {
                // Without the schema we cannot build correct filters — abort.
                bail!("Failed to detect Notion database schema: {e}");
            }
        }

        // Crash recovery
        if self.recover_stale {
            // Best-effort: recovery errors are logged, polling proceeds anyway.
            if let Err(e) = self.recover_stale().await {
                tracing::error!("Notion stale task recovery failed: {e}");
            }
        }

        // Polling loop
        loop {
            match self.query_pending().await {
                Ok(tasks) => {
                    if !tasks.is_empty() {
                        tracing::info!("Notion: found {} pending task(s)", tasks.len());
                    }
                    for task in tasks {
                        // Rows without an "id" cannot be tracked — skip them.
                        let page_id = match task.get("id").and_then(|v| v.as_str()) {
                            Some(id) => id.to_string(),
                            None => continue,
                        };

                        let input_text = extract_text_from_property(
                            task.get("properties")
                                .and_then(|p| p.get(&self.input_property)),
                        );

                        if input_text.trim().is_empty() {
                            // Log only a short, boundary-safe ID prefix.
                            let short_end = floor_utf8_char_boundary(&page_id, 8);
                            tracing::warn!(
                                "Notion: empty input for task {}, skipping",
                                &page_id[..short_end]
                            );
                            continue;
                        }

                        // Already in flight, or at max_concurrent capacity.
                        if !self.claim_task(&page_id).await {
                            continue;
                        }

                        // Set status to running
                        if let Err(e) = self.set_status(&page_id, "running").await {
                            tracing::error!("Notion: failed to set running status: {e}");
                            // Undo the claim so the row can be retried next pass.
                            self.release_task(&page_id).await;
                            continue;
                        }

                        let timestamp = std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs();

                        // A send error means the receiver is gone: shut down cleanly.
                        if tx
                            .send(ChannelMessage {
                                id: page_id.clone(),
                                sender: "notion".into(),
                                reply_target: page_id,
                                content: input_text,
                                channel: "notion".into(),
                                timestamp,
                                thread_ts: None,
                            })
                            .await
                            .is_err()
                        {
                            tracing::info!("Notion channel shutting down");
                            return Ok(());
                        }
                    }
                }
                Err(e) => {
                    // Transient poll failures are logged; next pass retries.
                    tracing::error!("Notion poll error: {e}");
                }
            }

            tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
        }
    }

    /// Report whether the configured database is reachable with current credentials.
    async fn health_check(&self) -> bool {
        let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
        self.api_call(reqwest::Method::GET, &url, None)
            .await
            .is_ok()
    }
}
|
||||
|
||||
// ── Helper functions ──────────────────────────────────────────────
|
||||
|
||||
/// Build a Notion API filter object for the given status property.
|
||||
fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
|
||||
if status_type == "status" {
|
||||
serde_json::json!({
|
||||
"property": property,
|
||||
"status": { "equals": value }
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"property": property,
|
||||
"select": { "equals": value }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Notion API property-update payload for a status field.
|
||||
fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
|
||||
if status_type == "status" {
|
||||
serde_json::json!({ "status": { "name": value } })
|
||||
} else {
|
||||
serde_json::json!({ "select": { "name": value } })
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Notion API rich-text property payload, truncating if necessary.
|
||||
fn build_rich_text_payload(value: &str) -> serde_json::Value {
|
||||
let truncated = truncate_result(value);
|
||||
serde_json::json!({
|
||||
"rich_text": [{
|
||||
"text": { "content": truncated }
|
||||
}]
|
||||
})
|
||||
}
|
||||
|
||||
/// Truncate result text to fit within the Notion rich-text content limit.
|
||||
fn truncate_result(value: &str) -> String {
|
||||
if value.len() <= MAX_RESULT_LENGTH {
|
||||
return value.to_string();
|
||||
}
|
||||
let cut = MAX_RESULT_LENGTH.saturating_sub(30);
|
||||
// Ensure we cut on a char boundary
|
||||
let end = floor_utf8_char_boundary(value, cut);
|
||||
format!("{}\n\n... [output truncated]", &value[..end])
|
||||
}
|
||||
|
||||
/// Extract plain text from a Notion property (title or rich_text type).
|
||||
fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
|
||||
let Some(prop) = prop else {
|
||||
return String::new();
|
||||
};
|
||||
let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
let array_key = match ptype {
|
||||
"title" => "title",
|
||||
"rich_text" => "rich_text",
|
||||
_ => return String::new(),
|
||||
};
|
||||
prop.get(array_key)
|
||||
.and_then(|arr| arr.as_array())
|
||||
.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway channel for claim/release tests; only the
    /// `max_concurrent` knob varies between tests.
    fn test_channel(max_concurrent: usize) -> NotionChannel {
        NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            max_concurrent,
            false,
        )
    }

    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = test_channel(4);

        assert!(channel.claim_task("page-1").await);
        // Second claim for same page should fail
        assert!(!channel.claim_task("page-1").await);
        // Different page should succeed
        assert!(channel.claim_task("page-2").await);

        // After release, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }

    #[test]
    fn result_truncation_within_limit() {
        let input = "hello world";
        assert_eq!(truncate_result(input), input);
    }

    #[test]
    fn result_truncation_over_limit() {
        let oversized = "a".repeat(MAX_RESULT_LENGTH + 100);
        let output = truncate_result(&oversized);
        assert!(output.len() <= MAX_RESULT_LENGTH);
        assert!(output.ends_with("... [output truncated]"));
    }

    #[test]
    fn result_truncation_multibyte_safe() {
        // 700 three-byte chars force a cut inside a multibyte sequence
        // unless truncation respects char boundaries.
        let input: String = std::iter::repeat('\u{6E2C}').take(700).collect();
        let output = truncate_result(&input);
        // Should not panic and should be valid UTF-8
        assert!(output.len() <= MAX_RESULT_LENGTH);
        assert!(output.ends_with("... [output truncated]"));
    }

    #[test]
    fn status_payload_select_type() {
        assert_eq!(
            build_status_payload("select", "pending"),
            serde_json::json!({ "select": { "name": "pending" } })
        );
    }

    #[test]
    fn status_payload_status_type() {
        assert_eq!(
            build_status_payload("status", "done"),
            serde_json::json!({ "status": { "name": "done" } })
        );
    }

    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        assert_eq!(
            payload["rich_text"][0]["text"]["content"].as_str(),
            Some("test output")
        );
    }

    #[test]
    fn status_filter_select_type() {
        assert_eq!(
            build_status_filter("Status", "select", "pending"),
            serde_json::json!({
                "property": "Status",
                "select": { "equals": "pending" }
            })
        );
    }

    #[test]
    fn status_filter_status_type() {
        assert_eq!(
            build_status_filter("Status", "status", "running"),
            serde_json::json!({
                "property": "Status",
                "status": { "equals": "running" }
            })
        );
    }

    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }

    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }

    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }

    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = test_channel(2);

        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        // Third claim should be rejected (at capacity)
        assert!(!channel.claim_task("page-3").await);

        // After releasing one, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}
|
||||
@ -6,17 +6,19 @@ pub mod workspace;
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig,
|
||||
BuiltinHooksConfig, ChannelsConfig, ClassificationRule, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DelegateAgentConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GoogleTtsConfig, HardwareConfig, HardwareTransport,
|
||||
HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig,
|
||||
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig,
|
||||
MultimodalConfig, NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BackupConfig, BrowserComputerUseConfig,
|
||||
BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, CloudOpsConfig,
|
||||
ComposioConfig, Config, ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig,
|
||||
PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
|
||||
TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
|
||||
1129
src/config/schema.rs
1129
src/config/schema.rs
File diff suppressed because it is too large
Load Diff
@ -784,7 +784,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("agent job failed:"));
|
||||
}
|
||||
@ -799,7 +799,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("blocked by security policy"));
|
||||
assert!(output.contains("read-only"));
|
||||
@ -815,7 +815,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("blocked by security policy"));
|
||||
assert!(output.contains("rate limit exceeded"));
|
||||
|
||||
@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
|
||||
max_backoff,
|
||||
move || {
|
||||
let cfg = channels_cfg.clone();
|
||||
async move { crate::channels::start_channels(cfg).await }
|
||||
async move { Box::pin(crate::channels::start_channels(cfg)).await }
|
||||
},
|
||||
));
|
||||
} else {
|
||||
|
||||
@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
|
||||
/// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk).
|
||||
async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result<String> {
|
||||
let config = state.config.lock().clone();
|
||||
crate::agent::process_message(config, message).await
|
||||
Box::pin(crate::agent::process_message(config, message)).await
|
||||
}
|
||||
|
||||
/// Webhook request body
|
||||
@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message(
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via WhatsApp
|
||||
if let Err(e) = wa
|
||||
@ -1346,7 +1346,7 @@ async fn handle_linq_webhook(
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via Linq
|
||||
if let Err(e) = linq
|
||||
@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via WATI
|
||||
if let Err(e) = wati
|
||||
@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook(
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
if let Err(e) = nextcloud_talk
|
||||
.send(&SendMessage::new(response, &msg.reply_target))
|
||||
|
||||
@ -58,6 +58,7 @@ pub(crate) mod integrations;
|
||||
pub mod memory;
|
||||
pub(crate) mod migration;
|
||||
pub(crate) mod multimodal;
|
||||
pub mod nodes;
|
||||
pub mod observability;
|
||||
pub(crate) mod onboard;
|
||||
pub mod peripherals;
|
||||
|
||||
@ -844,7 +844,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Auto-start channels if user said yes during wizard
|
||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||
channels::start_channels(config).await?;
|
||||
Box::pin(channels::start_channels(config)).await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
@ -1189,8 +1189,8 @@ async fn main() -> Result<()> {
|
||||
},
|
||||
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
|
||||
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
|
||||
other => channels::handle_command(other, &config).await,
|
||||
},
|
||||
|
||||
|
||||
3
src/nodes/mod.rs
Normal file
3
src/nodes/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod transport;
|
||||
|
||||
pub use transport::NodeTransport;
|
||||
235
src/nodes/transport.rs
Normal file
235
src/nodes/transport.rs
Normal file
@ -0,0 +1,235 @@
|
||||
//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication.
|
||||
//!
|
||||
//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols,
|
||||
//! no custom binary framing, no UDP tunneling. This makes the transport
|
||||
//! compatible with corporate proxies, firewalls, and IT audit expectations.
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use chrono::Utc;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Signs a request payload with HMAC-SHA256.
|
||||
///
|
||||
/// Uses `timestamp` + `nonce` alongside the payload to prevent replay attacks.
|
||||
pub fn sign_request(
|
||||
shared_secret: &str,
|
||||
payload: &[u8],
|
||||
timestamp: i64,
|
||||
nonce: &str,
|
||||
) -> Result<String> {
|
||||
let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes())
|
||||
.map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
|
||||
mac.update(×tamp.to_le_bytes());
|
||||
mac.update(nonce.as_bytes());
|
||||
mac.update(payload);
|
||||
Ok(hex::encode(mac.finalize().into_bytes()))
|
||||
}
|
||||
|
||||
/// Verify a signed request, rejecting stale timestamps for replay protection.
|
||||
pub fn verify_request(
|
||||
shared_secret: &str,
|
||||
payload: &[u8],
|
||||
timestamp: i64,
|
||||
nonce: &str,
|
||||
signature: &str,
|
||||
max_age_secs: i64,
|
||||
) -> Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
if (now - timestamp).abs() > max_age_secs {
|
||||
bail!("Request timestamp too old or too far in future");
|
||||
}
|
||||
|
||||
let expected = sign_request(shared_secret, payload, timestamp, nonce)?;
|
||||
Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes()))
|
||||
}
|
||||
|
||||
/// Constant-time comparison to prevent timing attacks.
///
/// Accumulates the XOR of every byte pair so the loop runs over the full
/// length regardless of where the first mismatch occurs. A length mismatch
/// returns early, which is acceptable since lengths are not secret here.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
|
||||
|
||||
// ── Node transport client ───────────────────────────────────────
|
||||
|
||||
/// Sends authenticated HTTPS requests to peer nodes.
///
/// Every outgoing request carries three custom headers:
/// - `X-ZeroClaw-Timestamp` — unix epoch seconds
/// - `X-ZeroClaw-Nonce` — random UUID v4
/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest
///
/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`].
pub struct NodeTransport {
    /// Reused HTTP client; built with a 30 s request timeout in [`Self::new`].
    http: reqwest::Client,
    /// Secret shared with peer nodes, used for both signing and verification.
    shared_secret: String,
    /// Maximum accepted age (seconds) of an incoming request's timestamp.
    max_request_age_secs: i64,
}
|
||||
|
||||
impl NodeTransport {
    /// Build a transport with a 30-second HTTP timeout and a 5-minute
    /// replay window for incoming requests.
    pub fn new(shared_secret: String) -> Self {
        Self {
            http: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(30))
                .build()
                // No TLS/proxy customization is applied, so a builder failure
                // would indicate a programming bug rather than bad input.
                .expect("HTTP client build"),
            shared_secret,
            max_request_age_secs: 300, // 5 min replay window
        }
    }

    /// Send an authenticated request to a peer node.
    ///
    /// `node_address` is the peer host (and optional port); `endpoint` is
    /// appended to the fixed `/api/node-control/` path. The JSON `payload`
    /// is serialized once, and that exact byte sequence is both signed and
    /// sent, so the body on the wire always matches the signature.
    ///
    /// # Errors
    /// Fails on serialization, signing, or transport errors, on any non-2xx
    /// response (the response text is folded into the error message), or
    /// when the reply body is not valid JSON.
    pub async fn send(
        &self,
        node_address: &str,
        endpoint: &str,
        payload: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let body = serde_json::to_vec(&payload)?;
        let timestamp = Utc::now().timestamp();
        // Fresh UUID per request; included in the signed material.
        let nonce = uuid::Uuid::new_v4().to_string();
        let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?;

        // HTTPS only — plain HTTP peers are intentionally unsupported.
        let url = format!("https://{node_address}/api/node-control/{endpoint}");
        let resp = self
            .http
            .post(&url)
            .header("X-ZeroClaw-Timestamp", timestamp.to_string())
            .header("X-ZeroClaw-Nonce", &nonce)
            .header("X-ZeroClaw-Signature", &signature)
            .header("Content-Type", "application/json")
            .body(body)
            .send()
            .await?;

        if !resp.status().is_success() {
            bail!(
                "Node request failed: {} {}",
                resp.status(),
                resp.text().await.unwrap_or_default()
            );
        }

        Ok(resp.json().await?)
    }

    /// Verify an incoming request from a peer node.
    ///
    /// Parses the timestamp header and delegates to [`verify_request`].
    ///
    /// NOTE(review): nonces are signed but not tracked, so a byte-identical
    /// request can be replayed within `max_request_age_secs` — confirm
    /// whether the receiving endpoint stores seen nonces elsewhere.
    ///
    /// # Errors
    /// Fails when the timestamp header is not a valid `i64` or when the
    /// timestamp falls outside the replay window.
    pub fn verify_incoming(
        &self,
        payload: &[u8],
        timestamp_header: &str,
        nonce_header: &str,
        signature_header: &str,
    ) -> Result<bool> {
        let timestamp: i64 = timestamp_header
            .parse()
            .map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?;
        verify_request(
            &self.shared_secret,
            payload,
            timestamp,
            nonce_header,
            signature_header,
            self.max_request_age_secs,
        )
    }
}
|
||||
|
||||
// Unit tests for the signing/verification helpers and NodeTransport.
// Network-facing behavior (`send`) is intentionally not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &str = "test-shared-secret-key";

    #[test]
    fn sign_request_deterministic() {
        let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
        let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
        assert_eq!(sig1, sig2, "Same inputs must produce the same signature");
    }

    #[test]
    fn verify_request_accepts_valid_signature() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap();
        let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap();
        assert!(ok, "Valid signature must pass verification");
    }

    #[test]
    fn verify_request_rejects_tampered_payload() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap();
        let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap();
        assert!(!ok, "Tampered payload must fail verification");
    }

    #[test]
    fn verify_request_rejects_expired_timestamp() {
        // 600 s old against a 300 s window: expired is an Err, not Ok(false).
        let old = Utc::now().timestamp() - 600;
        let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap();
        let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300);
        assert!(result.is_err(), "Expired timestamp must be rejected");
    }

    #[test]
    fn verify_request_rejects_wrong_secret() {
        let now = Utc::now().timestamp();
        let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap();
        let ok = verify_request("wrong-secret", b"data", now, "nonce-d", &sig, 300).unwrap();
        assert!(!ok, "Wrong secret must fail verification");
    }

    #[test]
    fn constant_time_eq_correctness() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"ab"));
        assert!(!constant_time_eq(b"", b"a"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn node_transport_construction() {
        let transport = NodeTransport::new("secret-key".into());
        assert_eq!(transport.max_request_age_secs, 300);
    }

    #[test]
    fn node_transport_verify_incoming_valid() {
        let transport = NodeTransport::new(TEST_SECRET.into());
        let now = Utc::now().timestamp();
        let payload = b"test-body";
        let nonce = "incoming-nonce";
        let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap();

        let ok = transport
            .verify_incoming(payload, &now.to_string(), nonce, &sig)
            .unwrap();
        assert!(ok, "Valid incoming request must pass verification");
    }

    #[test]
    fn node_transport_verify_incoming_bad_timestamp_header() {
        let transport = NodeTransport::new(TEST_SECRET.into());
        let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig");
        assert!(result.is_err(), "Non-numeric timestamp header must error");
    }

    #[test]
    fn sign_request_different_nonce_different_signature() {
        let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap();
        let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap();
        assert_ne!(
            sig1, sig2,
            "Different nonces must produce different signatures"
        );
    }
}
|
||||
@ -143,7 +143,12 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
extra_headers: std::collections::HashMap::new(),
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
backup: crate::config::BackupConfig::default(),
|
||||
data_retention: crate::config::DataRetentionConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
@ -159,12 +164,14 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
tunnel: tunnel_config,
|
||||
gateway: crate::config::GatewayConfig::default(),
|
||||
composio: composio_config,
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: secrets_config,
|
||||
browser: BrowserConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
@ -179,9 +186,8 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
nodes: crate::config::NodesConfig::default(),
|
||||
workspace: crate::config::WorkspaceConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
notion: crate::config::NotionConfig::default(),
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@ -505,7 +511,12 @@ async fn run_quick_setup_with_home(
|
||||
extra_headers: std::collections::HashMap::new(),
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
backup: crate::config::BackupConfig::default(),
|
||||
data_retention: crate::config::DataRetentionConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
@ -521,12 +532,14 @@ async fn run_quick_setup_with_home(
|
||||
tunnel: crate::config::TunnelConfig::default(),
|
||||
gateway: crate::config::GatewayConfig::default(),
|
||||
composio: ComposioConfig::default(),
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
@ -541,9 +554,8 @@ async fn run_quick_setup_with_home(
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
nodes: crate::config::NodesConfig::default(),
|
||||
workspace: crate::config::WorkspaceConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
notion: crate::config::NotionConfig::default(),
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
|
||||
449
src/security/iam_policy.rs
Normal file
449
src/security/iam_policy.rs
Normal file
@ -0,0 +1,449 @@
|
||||
//! IAM-aware policy enforcement for Nevis role-to-permission mapping.
|
||||
//!
|
||||
//! Evaluates tool and workspace access based on Nevis roles using a
|
||||
//! deny-by-default policy model. All policy decisions are audit-logged.
|
||||
|
||||
use super::nevis::NevisIdentity;
|
||||
use anyhow::{bail, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Maps a single Nevis role to ZeroClaw permissions.
///
/// Entries are normalized (trimmed, lowercased) when compiled by
/// `IamPolicy::from_mappings`, so matching is case-insensitive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleMapping {
    /// Nevis role name (case-insensitive matching).
    pub nevis_role: String,
    /// Tool names this role can access. Use `"all"` to grant all tools.
    pub zeroclaw_permissions: Vec<String>,
    /// Workspace names this role can access. Use `"all"` for unrestricted.
    /// Defaults to empty when omitted in config, which grants no workspaces
    /// under the deny-by-default policy.
    #[serde(default)]
    pub workspace_access: Vec<String>,
}
|
||||
|
||||
/// Result of a policy evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
    /// Access is allowed.
    Allow,
    /// Access is denied, with reason.
    Deny(String),
}

impl PolicyDecision {
    /// Returns `true` only for [`PolicyDecision::Allow`].
    pub fn is_allowed(&self) -> bool {
        match self {
            PolicyDecision::Allow => true,
            PolicyDecision::Deny(_) => false,
        }
    }
}
|
||||
|
||||
/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions.
///
/// Deny-by-default: if no role mapping grants access, the request is denied.
#[derive(Debug, Clone)]
pub struct IamPolicy {
    /// Compiled role mappings indexed by lowercase Nevis role name.
    /// Built once in [`IamPolicy::from_mappings`]; read-only afterwards.
    role_map: HashMap<String, CompiledRole>,
}
|
||||
|
||||
// Normalized form of a RoleMapping: wildcard flags split out and all
// names pre-trimmed/lowercased so evaluation is a plain comparison.
#[derive(Debug, Clone)]
struct CompiledRole {
    /// Whether this role has access to all tools.
    all_tools: bool,
    /// Specific tool names this role can access (lowercase).
    allowed_tools: Vec<String>,
    /// Whether this role has access to all workspaces.
    all_workspaces: bool,
    /// Specific workspace names this role can access (lowercase).
    allowed_workspaces: Vec<String>,
}
|
||||
|
||||
impl IamPolicy {
    /// Build a policy from role mappings (typically from config).
    ///
    /// Returns an error if duplicate normalized role names are detected,
    /// since silent last-wins overwrites can accidentally broaden or revoke access.
    ///
    /// NOTE(review): the `"all"` wildcard check compares the raw entry, while
    /// non-wildcard entries are trimmed afterwards — so `" all "` is NOT
    /// treated as a wildcard and instead becomes a literal tool named "all".
    /// Confirm whether wildcard detection should trim first.
    pub fn from_mappings(mappings: &[RoleMapping]) -> Result<Self> {
        let mut role_map = HashMap::new();

        for mapping in mappings {
            // Normalize the role key; blank role names are silently skipped.
            let key = mapping.nevis_role.trim().to_ascii_lowercase();
            if key.is_empty() {
                continue;
            }

            let all_tools = mapping
                .zeroclaw_permissions
                .iter()
                .any(|p| p.eq_ignore_ascii_case("all"));
            let allowed_tools: Vec<String> = mapping
                .zeroclaw_permissions
                .iter()
                .filter(|p| !p.eq_ignore_ascii_case("all"))
                .map(|p| p.trim().to_ascii_lowercase())
                .collect();

            let all_workspaces = mapping
                .workspace_access
                .iter()
                .any(|w| w.eq_ignore_ascii_case("all"));
            let allowed_workspaces: Vec<String> = mapping
                .workspace_access
                .iter()
                .filter(|w| !w.eq_ignore_ascii_case("all"))
                .map(|w| w.trim().to_ascii_lowercase())
                .collect();

            // Duplicates are a hard error rather than last-wins (see doc above).
            if role_map.contains_key(&key) {
                bail!(
                    "IAM policy: duplicate role mapping for normalized key '{}' \
                     (from nevis_role '{}') — remove or merge the duplicate entry",
                    key,
                    mapping.nevis_role
                );
            }

            role_map.insert(
                key,
                CompiledRole {
                    all_tools,
                    allowed_tools,
                    all_workspaces,
                    allowed_workspaces,
                },
            );
        }

        Ok(Self { role_map })
    }

    /// Evaluate whether an identity is allowed to use a specific tool.
    ///
    /// Deny-by-default: returns `Deny` unless at least one of the identity's
    /// roles grants access to the requested tool. Both the allow and the deny
    /// paths emit an audit log line with the (redacted) user id.
    pub fn evaluate_tool_access(
        &self,
        identity: &NevisIdentity,
        tool_name: &str,
    ) -> PolicyDecision {
        let normalized_tool = tool_name.trim().to_ascii_lowercase();
        if normalized_tool.is_empty() {
            return PolicyDecision::Deny("empty tool name".into());
        }

        // First role that grants access wins; roles are a permission union.
        for role in &identity.roles {
            let key = role.trim().to_ascii_lowercase();
            if let Some(compiled) = self.role_map.get(&key) {
                if compiled.all_tools
                    || compiled.allowed_tools.iter().any(|t| t == &normalized_tool)
                {
                    tracing::info!(
                        user_id = %crate::security::redact(&identity.user_id),
                        role = %key,
                        tool = %normalized_tool,
                        "IAM policy: tool access ALLOWED"
                    );
                    return PolicyDecision::Allow;
                }
            }
        }

        let reason = format!(
            "no role grants access to tool '{normalized_tool}' for user '{}'",
            crate::security::redact(&identity.user_id)
        );
        tracing::info!(
            user_id = %crate::security::redact(&identity.user_id),
            tool = %normalized_tool,
            "IAM policy: tool access DENIED"
        );
        PolicyDecision::Deny(reason)
    }

    /// Evaluate whether an identity is allowed to access a specific workspace.
    ///
    /// Deny-by-default: returns `Deny` unless at least one of the identity's
    /// roles grants access to the requested workspace. Mirrors
    /// [`Self::evaluate_tool_access`], including the audit logging.
    pub fn evaluate_workspace_access(
        &self,
        identity: &NevisIdentity,
        workspace: &str,
    ) -> PolicyDecision {
        let normalized_ws = workspace.trim().to_ascii_lowercase();
        if normalized_ws.is_empty() {
            return PolicyDecision::Deny("empty workspace name".into());
        }

        for role in &identity.roles {
            let key = role.trim().to_ascii_lowercase();
            if let Some(compiled) = self.role_map.get(&key) {
                if compiled.all_workspaces
                    || compiled
                        .allowed_workspaces
                        .iter()
                        .any(|w| w == &normalized_ws)
                {
                    tracing::info!(
                        user_id = %crate::security::redact(&identity.user_id),
                        role = %key,
                        workspace = %normalized_ws,
                        "IAM policy: workspace access ALLOWED"
                    );
                    return PolicyDecision::Allow;
                }
            }
        }

        let reason = format!(
            "no role grants access to workspace '{normalized_ws}' for user '{}'",
            crate::security::redact(&identity.user_id)
        );
        tracing::info!(
            user_id = %crate::security::redact(&identity.user_id),
            workspace = %normalized_ws,
            "IAM policy: workspace access DENIED"
        );
        PolicyDecision::Deny(reason)
    }

    /// Check if the policy has any role mappings configured.
    pub fn is_empty(&self) -> bool {
        self.role_map.is_empty()
    }
}
|
||||
|
||||
// Unit tests for IamPolicy: wildcard grants, scoped grants, deny-by-default,
// case-insensitivity, and compile-time validation of the mappings.
#[cfg(test)]
mod tests {
    use super::*;

    // Three-tier fixture: admin (all/all), operator (tool+workspace subset),
    // viewer (read-only tools, single workspace).
    fn test_mappings() -> Vec<RoleMapping> {
        vec![
            RoleMapping {
                nevis_role: "admin".into(),
                zeroclaw_permissions: vec!["all".into()],
                workspace_access: vec!["all".into()],
            },
            RoleMapping {
                nevis_role: "operator".into(),
                zeroclaw_permissions: vec![
                    "shell".into(),
                    "file_read".into(),
                    "file_write".into(),
                    "memory_search".into(),
                ],
                workspace_access: vec!["production".into(), "staging".into()],
            },
            RoleMapping {
                nevis_role: "viewer".into(),
                zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()],
                workspace_access: vec!["staging".into()],
            },
        ]
    }

    // Minimal identity with the given roles; other fields irrelevant to policy.
    fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity {
        NevisIdentity {
            user_id: "zeroclaw_user".into(),
            roles: roles.into_iter().map(String::from).collect(),
            scopes: vec!["openid".into()],
            mfa_verified: true,
            session_expiry: u64::MAX,
        }
    }

    #[test]
    fn admin_gets_all_tools() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);

        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "any_tool_name")
            .is_allowed());
    }

    #[test]
    fn admin_gets_all_workspaces() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);

        assert!(policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
        assert!(policy
            .evaluate_workspace_access(&identity, "any_workspace")
            .is_allowed());
    }

    #[test]
    fn operator_gets_subset_of_tools() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);

        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(!policy
            .evaluate_tool_access(&identity, "browser")
            .is_allowed());
    }

    #[test]
    fn operator_workspace_access_is_scoped() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);

        assert!(policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
        assert!(policy
            .evaluate_workspace_access(&identity, "staging")
            .is_allowed());
        assert!(!policy
            .evaluate_workspace_access(&identity, "development")
            .is_allowed());
    }

    #[test]
    fn viewer_is_read_only() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer"]);

        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "memory_search")
            .is_allowed());
        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(!policy
            .evaluate_tool_access(&identity, "file_write")
            .is_allowed());
    }

    #[test]
    fn deny_by_default_for_unknown_role() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["unknown_role"]);

        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
        assert!(!policy
            .evaluate_workspace_access(&identity, "production")
            .is_allowed());
    }

    #[test]
    fn deny_by_default_for_no_roles() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec![]);

        assert!(!policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
    }

    #[test]
    fn multiple_roles_union_permissions() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer", "operator"]);

        // viewer has file_read, operator has shell — both should be accessible
        assert!(policy
            .evaluate_tool_access(&identity, "file_read")
            .is_allowed());
        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }

    #[test]
    fn role_matching_is_case_insensitive() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["ADMIN"]);

        assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }

    #[test]
    fn tool_matching_is_case_insensitive() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["operator"]);

        assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed());
        assert!(policy
            .evaluate_tool_access(&identity, "File_Read")
            .is_allowed());
    }

    #[test]
    fn empty_tool_name_is_denied() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);

        assert!(!policy.evaluate_tool_access(&identity, "").is_allowed());
        assert!(!policy.evaluate_tool_access(&identity, "  ").is_allowed());
    }

    #[test]
    fn empty_workspace_name_is_denied() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["admin"]);

        assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed());
    }

    #[test]
    fn empty_mappings_deny_everything() {
        let policy = IamPolicy::from_mappings(&[]).unwrap();
        let identity = identity_with_roles(vec!["admin"]);

        assert!(policy.is_empty());
        assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
    }

    #[test]
    fn policy_decision_deny_contains_reason() {
        let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
        let identity = identity_with_roles(vec!["viewer"]);

        let decision = policy.evaluate_tool_access(&identity, "shell");
        match decision {
            PolicyDecision::Deny(reason) => {
                assert!(reason.contains("shell"));
            }
            PolicyDecision::Allow => panic!("expected deny"),
        }
    }

    #[test]
    fn duplicate_normalized_roles_are_rejected() {
        // "admin" and " ADMIN " normalize to the same key and must collide.
        let mappings = vec![
            RoleMapping {
                nevis_role: "admin".into(),
                zeroclaw_permissions: vec!["all".into()],
                workspace_access: vec!["all".into()],
            },
            RoleMapping {
                nevis_role: " ADMIN ".into(),
                zeroclaw_permissions: vec!["file_read".into()],
                workspace_access: vec![],
            },
        ];
        let err = IamPolicy::from_mappings(&mappings).unwrap_err();
        assert!(
            err.to_string().contains("duplicate role mapping"),
            "Expected duplicate role error, got: {err}"
        );
    }

    #[test]
    fn empty_role_name_in_mapping_is_skipped() {
        let mappings = vec![RoleMapping {
            nevis_role: "  ".into(),
            zeroclaw_permissions: vec!["all".into()],
            workspace_access: vec![],
        }];
        let policy = IamPolicy::from_mappings(&mappings).unwrap();
        assert!(policy.is_empty());
    }
}
|
||||
@ -29,15 +29,19 @@ pub mod domain_matcher;
|
||||
pub mod estop;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod firejail;
|
||||
pub mod iam_policy;
|
||||
#[cfg(feature = "sandbox-landlock")]
|
||||
pub mod landlock;
|
||||
pub mod leak_detector;
|
||||
pub mod nevis;
|
||||
pub mod otp;
|
||||
pub mod pairing;
|
||||
pub mod playbook;
|
||||
pub mod policy;
|
||||
pub mod prompt_guard;
|
||||
pub mod secrets;
|
||||
pub mod traits;
|
||||
pub mod vulnerability;
|
||||
pub mod workspace_boundary;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
@ -56,6 +60,11 @@ pub use policy::{AutonomyLevel, SecurityPolicy};
|
||||
pub use secrets::SecretStore;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{NoopSandbox, Sandbox};
|
||||
// Nevis IAM integration
|
||||
#[allow(unused_imports)]
|
||||
pub use iam_policy::{IamPolicy, PolicyDecision};
|
||||
#[allow(unused_imports)]
|
||||
pub use nevis::{NevisAuthProvider, NevisIdentity};
|
||||
// Prompt injection defense exports
|
||||
#[allow(unused_imports)]
|
||||
pub use leak_detector::{LeakDetector, LeakResult};
|
||||
@ -64,19 +73,16 @@ pub use prompt_guard::{GuardAction, GuardResult, PromptGuard};
|
||||
#[allow(unused_imports)]
|
||||
pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary};
|
||||
|
||||
/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix.
|
||||
/// Redact sensitive values for safe logging. Shows first 4 characters + "***" suffix.
|
||||
/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings.
|
||||
/// This function intentionally breaks the data-flow taint chain for static analysis.
|
||||
pub fn redact(value: &str) -> String {
|
||||
if value.len() <= 4 {
|
||||
let char_count = value.chars().count();
|
||||
if char_count <= 4 {
|
||||
"***".to_string()
|
||||
} else {
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let prefix = value
|
||||
.char_indices()
|
||||
.nth(4)
|
||||
.map(|(byte_idx, _)| &value[..byte_idx])
|
||||
.unwrap_or(value);
|
||||
format!("{}***", prefix)
|
||||
let prefix: String = value.chars().take(4).collect();
|
||||
format!("{prefix}***")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
587
src/security/nevis.rs
Normal file
587
src/security/nevis.rs
Normal file
@ -0,0 +1,587 @@
|
||||
//! Nevis IAM authentication provider for ZeroClaw.
|
||||
//!
|
||||
//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token
|
||||
//! validation, FIDO2/passkey verification, and session management. Maps Nevis
|
||||
//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`].
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Identity resolved from a validated Nevis token or session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NevisIdentity {
    /// Unique user identifier from Nevis.
    pub user_id: String,
    /// Nevis roles assigned to this user.
    /// Matched trim+case-insensitively by the IAM policy engine.
    pub roles: Vec<String>,
    /// OAuth2 scopes granted to this session.
    pub scopes: Vec<String>,
    /// Whether the user completed MFA (FIDO2/passkey/OTP) in this session.
    pub mfa_verified: bool,
    /// When this session expires (seconds since UNIX epoch).
    pub session_expiry: u64,
}
|
||||
|
||||
/// Token validation strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationMode {
    /// Validate JWT locally using cached JWKS keys.
    Local,
    /// Validate token by calling the Nevis introspection endpoint.
    Remote,
}

impl TokenValidationMode {
    /// Parse a config value: `"local"` or `"remote"`, case-insensitively.
    ///
    /// # Errors
    /// Returns an error naming the (lowercased) invalid value otherwise.
    pub fn from_str_config(s: &str) -> Result<Self> {
        match s.to_ascii_lowercase().as_str() {
            "local" => Ok(Self::Local),
            "remote" => Ok(Self::Remote),
            other => bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'"),
        }
    }
}
|
||||
|
||||
/// Authentication provider backed by a Nevis instance.
///
/// Validates tokens, manages sessions, and resolves identities. The provider
/// is designed to be shared across concurrent requests (`Send + Sync`).
pub struct NevisAuthProvider {
    /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`).
    instance_url: String,
    /// Nevis realm to authenticate against.
    realm: String,
    /// OAuth2 client ID registered in Nevis.
    client_id: String,
    /// OAuth2 client secret (decrypted at startup).
    /// Printed as "[REDACTED]" by the manual `Debug` impl below.
    client_secret: Option<String>,
    /// Token validation strategy.
    validation_mode: TokenValidationMode,
    /// JWKS endpoint for local token validation.
    /// Required when `validation_mode` is `Local` (enforced in `new`).
    jwks_url: Option<String>,
    /// Whether MFA is required for all authentications.
    require_mfa: bool,
    /// Session timeout duration.
    session_timeout: Duration,
    /// HTTP client for Nevis API calls (30 s timeout, set in `new`).
    http_client: reqwest::Client,
}
|
||||
|
||||
// Manual Debug: a derived impl would print the client secret. The secret is
// replaced with "[REDACTED]" and the http_client field is elided entirely
// via `finish_non_exhaustive`.
impl std::fmt::Debug for NevisAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NevisAuthProvider")
            .field("instance_url", &self.instance_url)
            .field("realm", &self.realm)
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                // Presence is shown (Some/None) but never the value.
                &self.client_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("validation_mode", &self.validation_mode)
            .field("jwks_url", &self.jwks_url)
            .field("require_mfa", &self.require_mfa)
            .field("session_timeout", &self.session_timeout)
            .finish_non_exhaustive()
    }
}
|
||||
|
||||
// Safety: All fields are Send + Sync. The doc comment promises concurrent use,
|
||||
// so enforce it at compile time to prevent regressions.
|
||||
#[allow(clippy::used_underscore_items)]
|
||||
const _: () = {
|
||||
fn _assert_send_sync<T: Send + Sync>() {}
|
||||
fn _assert() {
|
||||
_assert_send_sync::<NevisAuthProvider>();
|
||||
}
|
||||
};
|
||||
|
||||
impl NevisAuthProvider {
|
||||
/// Create a new Nevis auth provider from config values.
|
||||
///
|
||||
/// `client_secret` should already be decrypted by the config loader.
|
||||
pub fn new(
|
||||
instance_url: String,
|
||||
realm: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
token_validation: &str,
|
||||
jwks_url: Option<String>,
|
||||
require_mfa: bool,
|
||||
session_timeout_secs: u64,
|
||||
) -> Result<Self> {
|
||||
let validation_mode = TokenValidationMode::from_str_config(token_validation)?;
|
||||
|
||||
if validation_mode == TokenValidationMode::Local && jwks_url.is_none() {
|
||||
bail!(
|
||||
"Nevis token_validation is 'local' but no jwks_url is configured. \
|
||||
Either set jwks_url or use token_validation = 'remote'."
|
||||
);
|
||||
}
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.context("Failed to create HTTP client for Nevis")?;
|
||||
|
||||
Ok(Self {
|
||||
instance_url,
|
||||
realm,
|
||||
client_id,
|
||||
client_secret,
|
||||
validation_mode,
|
||||
jwks_url,
|
||||
require_mfa,
|
||||
session_timeout: Duration::from_secs(session_timeout_secs),
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate a bearer token and resolve the caller's identity.
|
||||
///
|
||||
/// Returns `NevisIdentity` on success, or an error if the token is invalid,
|
||||
/// expired, or MFA requirements are not met.
|
||||
pub async fn validate_token(&self, token: &str) -> Result<NevisIdentity> {
|
||||
if token.is_empty() {
|
||||
bail!("empty bearer token");
|
||||
}
|
||||
|
||||
let identity = match self.validation_mode {
|
||||
TokenValidationMode::Local => self.validate_token_local(token).await?,
|
||||
TokenValidationMode::Remote => self.validate_token_remote(token).await?,
|
||||
};
|
||||
|
||||
if self.require_mfa && !identity.mfa_verified {
|
||||
bail!(
|
||||
"MFA is required but user '{}' has not completed MFA verification",
|
||||
crate::security::redact(&identity.user_id)
|
||||
);
|
||||
}
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
if identity.session_expiry > 0 && identity.session_expiry < now {
|
||||
bail!("Nevis session expired");
|
||||
}
|
||||
|
||||
Ok(identity)
|
||||
}
|
||||
|
||||
/// Validate token by calling the Nevis introspection endpoint.
|
||||
async fn validate_token_remote(&self, token: &str) -> Result<NevisIdentity> {
|
||||
let introspect_url = format!(
|
||||
"{}/auth/realms/{}/protocol/openid-connect/token/introspect",
|
||||
self.instance_url.trim_end_matches('/'),
|
||||
self.realm,
|
||||
);
|
||||
|
||||
let mut form = vec![("token", token), ("client_id", &self.client_id)];
|
||||
// client_secret is optional (public clients don't need it)
|
||||
let secret_ref;
|
||||
if let Some(ref secret) = self.client_secret {
|
||||
secret_ref = secret.as_str();
|
||||
form.push(("client_secret", secret_ref));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(&introspect_url)
|
||||
.form(&form)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to reach Nevis introspection endpoint")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!(
|
||||
"Nevis introspection returned HTTP {}",
|
||||
resp.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
let body: IntrospectionResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Nevis introspection response")?;
|
||||
|
||||
if !body.active {
|
||||
bail!("Token is not active (revoked or expired)");
|
||||
}
|
||||
|
||||
let user_id = body
|
||||
.sub
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.context("Token has missing or empty `sub` claim")?;
|
||||
|
||||
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
|
||||
roles.sort();
|
||||
roles.dedup();
|
||||
|
||||
Ok(NevisIdentity {
|
||||
user_id,
|
||||
roles,
|
||||
scopes: body
|
||||
.scope
|
||||
.unwrap_or_default()
|
||||
.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect(),
|
||||
mfa_verified: body.acr.as_deref() == Some("mfa")
|
||||
|| body
|
||||
.amr
|
||||
.iter()
|
||||
.flatten()
|
||||
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
|
||||
session_expiry: body.exp.unwrap_or(0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate token locally using JWKS.
|
||||
///
|
||||
/// Local JWT/JWKS validation is not yet implemented. Rather than silently
|
||||
/// falling back to the remote introspection endpoint (which would hide a
|
||||
/// misconfiguration), this returns an explicit error directing the operator
|
||||
/// to use `token_validation = "remote"` until local JWKS support is added.
|
||||
#[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented
|
||||
async fn validate_token_local(&self, token: &str) -> Result<NevisIdentity> {
|
||||
// JWT structure check: header.payload.signature
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
bail!("Invalid JWT structure: expected 3 dot-separated parts");
|
||||
}
|
||||
|
||||
bail!(
|
||||
"Local JWKS token validation is not yet implemented. \
|
||||
Set token_validation = \"remote\" to use the Nevis introspection endpoint."
|
||||
);
|
||||
}
|
||||
|
||||
/// Validate a Nevis session token (cookie-based sessions).
|
||||
pub async fn validate_session(&self, session_token: &str) -> Result<NevisIdentity> {
|
||||
if session_token.is_empty() {
|
||||
bail!("empty session token");
|
||||
}
|
||||
|
||||
let session_url = format!(
|
||||
"{}/auth/realms/{}/protocol/openid-connect/userinfo",
|
||||
self.instance_url.trim_end_matches('/'),
|
||||
self.realm,
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&session_url)
|
||||
.bearer_auth(session_token)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to reach Nevis userinfo endpoint")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!(
|
||||
"Nevis session validation returned HTTP {}",
|
||||
resp.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
let body: UserInfoResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Nevis userinfo response")?;
|
||||
|
||||
if body.sub.trim().is_empty() {
|
||||
bail!("Userinfo response has missing or empty `sub` claim");
|
||||
}
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
|
||||
roles.sort();
|
||||
roles.dedup();
|
||||
|
||||
let identity = NevisIdentity {
|
||||
user_id: body.sub,
|
||||
roles,
|
||||
scopes: body
|
||||
.scope
|
||||
.unwrap_or_default()
|
||||
.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect(),
|
||||
mfa_verified: body.acr.as_deref() == Some("mfa")
|
||||
|| body
|
||||
.amr
|
||||
.iter()
|
||||
.flatten()
|
||||
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
|
||||
session_expiry: now + self.session_timeout.as_secs(),
|
||||
};
|
||||
|
||||
if self.require_mfa && !identity.mfa_verified {
|
||||
bail!(
|
||||
"MFA is required but user '{}' has not completed MFA verification",
|
||||
crate::security::redact(&identity.user_id)
|
||||
);
|
||||
}
|
||||
|
||||
Ok(identity)
|
||||
}
|
||||
|
||||
/// Health check against the Nevis instance.
|
||||
pub async fn health_check(&self) -> Result<()> {
|
||||
let health_url = format!(
|
||||
"{}/auth/realms/{}",
|
||||
self.instance_url.trim_end_matches('/'),
|
||||
self.realm,
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&health_url)
|
||||
.send()
|
||||
.await
|
||||
.context("Nevis health check failed: cannot reach instance")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("Nevis health check failed: HTTP {}", resp.status().as_u16());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Getter for instance URL (for diagnostics).
|
||||
pub fn instance_url(&self) -> &str {
|
||||
&self.instance_url
|
||||
}
|
||||
|
||||
/// Getter for realm.
|
||||
pub fn realm(&self) -> &str {
|
||||
&self.realm
|
||||
}
|
||||
}
|
||||
|
||||
// ── Wire types for Nevis API responses ─────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IntrospectionResponse {
|
||||
active: bool,
|
||||
sub: Option<String>,
|
||||
scope: Option<String>,
|
||||
exp: Option<u64>,
|
||||
#[serde(rename = "realm_access")]
|
||||
realm_access: Option<RealmAccess>,
|
||||
/// Authentication Context Class Reference
|
||||
acr: Option<String>,
|
||||
/// Authentication Methods References
|
||||
amr: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Realm-level role container as returned inside Nevis token/userinfo payloads.
#[derive(Debug, Deserialize)]
struct RealmAccess {
    /// Role names; defaults to empty when the field is absent.
    #[serde(default)]
    roles: Vec<String>,
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UserInfoResponse {
|
||||
sub: String,
|
||||
#[serde(rename = "realm_access")]
|
||||
realm_access: Option<RealmAccess>,
|
||||
scope: Option<String>,
|
||||
acr: Option<String>,
|
||||
/// Authentication Methods References
|
||||
amr: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // `from_str_config` must be case-insensitive and reject unknown modes.
    #[test]
    fn token_validation_mode_from_str() {
        assert_eq!(
            TokenValidationMode::from_str_config("local").unwrap(),
            TokenValidationMode::Local
        );
        assert_eq!(
            TokenValidationMode::from_str_config("REMOTE").unwrap(),
            TokenValidationMode::Remote
        );
        assert!(TokenValidationMode::from_str_config("invalid").is_err());
    }

    // 'local' mode without a JWKS URL must fail at construction time with an
    // error that mentions `jwks_url`.
    #[test]
    fn local_mode_requires_jwks_url() {
        let result = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            None, // no JWKS URL
            false,
            3600,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("jwks_url"));
    }

    // 'remote' mode uses the introspection endpoint and needs no JWKS URL.
    #[test]
    fn remote_mode_works_without_jwks_url() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        );
        assert!(provider.is_ok());
    }

    // The constructor must faithfully store the config values it was given.
    #[test]
    fn provider_stores_config_correctly() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("test-secret".into()),
            "remote",
            None,
            true,
            7200,
        )
        .unwrap();

        assert_eq!(provider.instance_url(), "https://nevis.example.com");
        assert_eq!(provider.realm(), "test-realm");
        assert!(provider.require_mfa);
        assert_eq!(provider.session_timeout, Duration::from_secs(7200));
    }

    // The Debug impl must never leak the raw client secret.
    #[test]
    fn debug_redacts_client_secret() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("super-secret-value".into()),
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let debug_output = format!("{:?}", provider);
        assert!(
            !debug_output.contains("super-secret-value"),
            "Debug output must not contain the raw client_secret"
        );
        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output must show [REDACTED] for client_secret"
        );
    }

    // Empty bearer tokens are rejected before any network call is made.
    #[tokio::test]
    async fn validate_token_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_token("").await.unwrap_err();
        assert!(err.to_string().contains("empty bearer token"));
    }

    // Empty session tokens are rejected before any network call is made.
    #[tokio::test]
    async fn validate_session_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_session("").await.unwrap_err();
        assert!(err.to_string().contains("empty session token"));
    }

    // NevisIdentity must survive a serde JSON round trip unchanged.
    #[test]
    fn nevis_identity_serde_roundtrip() {
        let identity = NevisIdentity {
            user_id: "zeroclaw_user".into(),
            roles: vec!["admin".into(), "operator".into()],
            scopes: vec!["openid".into(), "profile".into()],
            mfa_verified: true,
            session_expiry: 1_700_000_000,
        };

        let json = serde_json::to_string(&identity).unwrap();
        let parsed: NevisIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.user_id, "zeroclaw_user");
        assert_eq!(parsed.roles.len(), 2);
        assert!(parsed.mfa_verified);
    }

    // Local validation checks JWT shape (3 dot-separated parts) first.
    #[tokio::test]
    async fn local_validation_rejects_malformed_jwt() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_token("not-a-jwt").await.unwrap_err();
        assert!(err.to_string().contains("Invalid JWT structure"));
    }

    // Local mode must fail loudly rather than silently falling back to remote
    // introspection, which would mask a misconfiguration.
    #[tokio::test]
    async fn local_validation_errors_instead_of_silent_fallback() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();

        // A well-formed JWT structure should hit the "not yet implemented" error
        // instead of silently falling back to remote introspection.
        let err = provider
            .validate_token("header.payload.signature")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("not yet implemented"));
    }
}
|
||||
459
src/security/playbook.rs
Normal file
459
src/security/playbook.rs
Normal file
@ -0,0 +1,459 @@
|
||||
//! Incident response playbook definitions and execution engine.
|
||||
//!
|
||||
//! Playbooks define structured response procedures for security incidents.
|
||||
//! Each playbook has named steps, some of which require human approval before
|
||||
//! execution. Playbooks are loaded from JSON files in the configured directory.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// A single step in an incident response playbook.
///
/// Deserialized from playbook JSON; `requires_approval` defaults to `false`
/// and `timeout_secs` to 300 when the fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlaybookStep {
    /// Machine-readable action identifier (e.g. "isolate_host", "block_ip").
    pub action: String,
    /// Human-readable description of what this step does.
    pub description: String,
    /// Whether this step requires explicit human approval before execution.
    #[serde(default)]
    pub requires_approval: bool,
    /// Timeout in seconds for this step. Default: 300 (5 minutes).
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
}
|
||||
|
||||
// Serde default for `PlaybookStep::timeout_secs`: 300 seconds (5 minutes).
fn default_timeout_secs() -> u64 {
    300
}
|
||||
|
||||
/// An incident response playbook.
///
/// Built-in playbooks come from `builtin_playbooks()`; user-defined JSON files
/// loaded by `load_playbooks` override built-ins that share the same `name`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Playbook {
    /// Unique playbook name (e.g. "suspicious_login").
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Ordered list of response steps.
    pub steps: Vec<PlaybookStep>,
    /// Minimum alert severity that triggers this playbook (low/medium/high/critical).
    /// Defaults to "medium" when absent from JSON.
    #[serde(default = "default_severity_filter")]
    pub severity_filter: String,
    /// Step indices (0-based) that can be auto-approved when below max_auto_severity.
    #[serde(default)]
    pub auto_approve_steps: Vec<usize>,
}
|
||||
|
||||
// Serde default for `Playbook::severity_filter`.
fn default_severity_filter() -> String {
    String::from("medium")
}
|
||||
|
||||
/// Result of executing a single playbook step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepExecutionResult {
    /// 0-based index of the evaluated step within the playbook.
    pub step_index: usize,
    /// Action identifier copied from the step ("unknown" when the index was
    /// out of range).
    pub action: String,
    /// Outcome of the evaluation (completed/pending/skipped/failed).
    pub status: StepStatus,
    /// Human-readable explanation of the outcome.
    pub message: String,
}
|
||||
|
||||
/// Status of a playbook step.
///
/// The `Display` impl renders these as snake_case labels
/// ("completed", "pending_approval", "skipped", "failed").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StepStatus {
    /// Step completed successfully.
    Completed,
    /// Step is waiting for human approval.
    PendingApproval,
    /// Step was skipped (e.g. not applicable).
    Skipped,
    /// Step failed with an error.
    Failed,
}
|
||||
|
||||
impl std::fmt::Display for StepStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::PendingApproval => write!(f, "pending_approval"),
|
||||
Self::Skipped => write!(f, "skipped"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all playbook definitions from a directory of JSON files.
|
||||
pub fn load_playbooks(dir: &Path) -> Vec<Playbook> {
|
||||
let mut playbooks = Vec::new();
|
||||
|
||||
if !dir.exists() || !dir.is_dir() {
|
||||
return builtin_playbooks();
|
||||
}
|
||||
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map_or(false, |ext| ext == "json") {
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(contents) => match serde_json::from_str::<Playbook>(&contents) {
|
||||
Ok(pb) => playbooks.push(pb),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse playbook {}: {e}", path.display());
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to read playbook {}: {e}", path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge built-in playbooks that aren't overridden by user-defined ones
|
||||
for builtin in builtin_playbooks() {
|
||||
if !playbooks.iter().any(|p| p.name == builtin.name) {
|
||||
playbooks.push(builtin);
|
||||
}
|
||||
}
|
||||
|
||||
playbooks
|
||||
}
|
||||
|
||||
/// Severity ordering for comparison: low < medium < high < critical.
///
/// Comparison is case-insensitive. Unknown severities map to `u8::MAX`
/// (deny-by-default) so an unrecognized label can never slip past an
/// auto-approval severity cap.
pub fn severity_level(severity: &str) -> u8 {
    let normalized = severity.to_lowercase();
    match normalized.as_str() {
        "low" => 1,
        "medium" => 2,
        "high" => 3,
        "critical" => 4,
        _ => u8::MAX,
    }
}
|
||||
|
||||
/// Check whether a step can be auto-approved given config constraints.
|
||||
pub fn can_auto_approve(
|
||||
playbook: &Playbook,
|
||||
step_index: usize,
|
||||
alert_severity: &str,
|
||||
max_auto_severity: &str,
|
||||
) -> bool {
|
||||
// Never auto-approve if alert severity exceeds the configured max
|
||||
if severity_level(alert_severity) > severity_level(max_auto_severity) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only auto-approve steps explicitly listed in auto_approve_steps
|
||||
playbook.auto_approve_steps.contains(&step_index)
|
||||
}
|
||||
|
||||
/// Evaluate a playbook step. Returns the result with approval gating.
///
/// Steps that require approval and cannot be auto-approved will return
/// `StepStatus::PendingApproval` without executing.
///
/// Parameters:
/// - `playbook`: the playbook containing the step.
/// - `step_index`: 0-based index into `playbook.steps`; an out-of-range index
///   yields `StepStatus::Failed` with action "unknown".
/// - `alert_severity` / `max_auto_severity`: fed to `can_auto_approve` to
///   decide whether an approval-gated step may proceed automatically.
/// - `require_approval`: when `false`, approval-gated steps are always parked
///   as `PendingApproval` — the auto-approve path is only consulted when this
///   flag is `true`. NOTE(review): the name suggests the opposite polarity;
///   confirm with callers that this inversion is intended.
pub fn evaluate_step(
    playbook: &Playbook,
    step_index: usize,
    alert_severity: &str,
    max_auto_severity: &str,
    require_approval: bool,
) -> StepExecutionResult {
    // Out-of-range indices are reported as a failed result, not a panic.
    let step = match playbook.steps.get(step_index) {
        Some(s) => s,
        None => {
            return StepExecutionResult {
                step_index,
                action: "unknown".into(),
                status: StepStatus::Failed,
                message: format!("Step index {step_index} out of range"),
            };
        }
    };

    // Enforce approval gates: steps that require approval must either be
    // auto-approved or wait for human approval. Never mark an unexecuted
    // approval-gated step as Completed.
    if step.requires_approval
        && (!require_approval
            || !can_auto_approve(playbook, step_index, alert_severity, max_auto_severity))
    {
        return StepExecutionResult {
            step_index,
            action: step.action.clone(),
            status: StepStatus::PendingApproval,
            message: format!(
                "Step '{}' requires human approval (severity: {alert_severity})",
                step.description
            ),
        };
    }

    // Step is approved (either doesn't require approval, or was auto-approved)
    // Actual execution would be delegated to the appropriate tool/system
    StepExecutionResult {
        step_index,
        action: step.action.clone(),
        status: StepStatus::Completed,
        message: format!("Executed: {}", step.description),
    }
}
|
||||
|
||||
/// Built-in playbook definitions for common incident types.
|
||||
pub fn builtin_playbooks() -> Vec<Playbook> {
|
||||
vec![
|
||||
Playbook {
|
||||
name: "suspicious_login".into(),
|
||||
description: "Respond to suspicious login activity detected by SIEM".into(),
|
||||
steps: vec![
|
||||
PlaybookStep {
|
||||
action: "gather_login_context".into(),
|
||||
description: "Collect login metadata: IP, geo, device fingerprint, time".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "check_threat_intel".into(),
|
||||
description: "Query threat intelligence for source IP reputation".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 30,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "notify_user".into(),
|
||||
description: "Send verification notification to account owner".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "force_password_reset".into(),
|
||||
description: "Force password reset if login confirmed unauthorized".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 120,
|
||||
},
|
||||
],
|
||||
severity_filter: "medium".into(),
|
||||
auto_approve_steps: vec![0, 1],
|
||||
},
|
||||
Playbook {
|
||||
name: "malware_detected".into(),
|
||||
description: "Respond to malware detection on endpoint".into(),
|
||||
steps: vec![
|
||||
PlaybookStep {
|
||||
action: "isolate_endpoint".into(),
|
||||
description: "Network-isolate the affected endpoint".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "collect_forensics".into(),
|
||||
description: "Capture memory dump and disk image for analysis".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 600,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "scan_lateral_movement".into(),
|
||||
description: "Check for lateral movement indicators on adjacent hosts".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "remediate_endpoint".into(),
|
||||
description: "Remove malware and restore endpoint to clean state".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 600,
|
||||
},
|
||||
],
|
||||
severity_filter: "high".into(),
|
||||
auto_approve_steps: vec![1, 2],
|
||||
},
|
||||
Playbook {
|
||||
name: "data_exfiltration_attempt".into(),
|
||||
description: "Respond to suspected data exfiltration".into(),
|
||||
steps: vec![
|
||||
PlaybookStep {
|
||||
action: "block_egress".into(),
|
||||
description: "Block suspicious outbound connections".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 30,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "identify_data_scope".into(),
|
||||
description: "Determine what data may have been accessed or transferred".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "preserve_evidence".into(),
|
||||
description: "Preserve network logs and access records".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 120,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "escalate_to_legal".into(),
|
||||
description: "Notify legal and compliance teams".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
],
|
||||
severity_filter: "critical".into(),
|
||||
auto_approve_steps: vec![1, 2],
|
||||
},
|
||||
Playbook {
|
||||
name: "brute_force".into(),
|
||||
description: "Respond to brute force authentication attempts".into(),
|
||||
steps: vec![
|
||||
PlaybookStep {
|
||||
action: "block_source_ip".into(),
|
||||
description: "Block the attacking source IP at firewall".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 30,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "check_compromised_accounts".into(),
|
||||
description: "Check if any accounts were successfully compromised".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 120,
|
||||
},
|
||||
PlaybookStep {
|
||||
action: "enable_rate_limiting".into(),
|
||||
description: "Enable enhanced rate limiting on auth endpoints".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
],
|
||||
severity_filter: "medium".into(),
|
||||
auto_approve_steps: vec![1],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // All four built-ins must exist, each with steps and a description.
    #[test]
    fn builtin_playbooks_are_valid() {
        let playbooks = builtin_playbooks();
        assert_eq!(playbooks.len(), 4);

        let names: Vec<&str> = playbooks.iter().map(|p| p.name.as_str()).collect();
        assert!(names.contains(&"suspicious_login"));
        assert!(names.contains(&"malware_detected"));
        assert!(names.contains(&"data_exfiltration_attempt"));
        assert!(names.contains(&"brute_force"));

        for pb in &playbooks {
            assert!(!pb.steps.is_empty(), "Playbook {} has no steps", pb.name);
            assert!(!pb.description.is_empty());
        }
    }

    // low < medium < high < critical; unknown labels are treated as maximal
    // (deny-by-default).
    #[test]
    fn severity_level_ordering() {
        assert!(severity_level("low") < severity_level("medium"));
        assert!(severity_level("medium") < severity_level("high"));
        assert!(severity_level("high") < severity_level("critical"));
        assert_eq!(severity_level("unknown"), u8::MAX);
    }

    // Auto-approval requires both the severity cap and explicit listing in
    // auto_approve_steps.
    #[test]
    fn auto_approve_respects_severity_cap() {
        let pb = &builtin_playbooks()[0]; // suspicious_login

        // Step 0 is in auto_approve_steps
        assert!(can_auto_approve(pb, 0, "low", "low"));
        assert!(can_auto_approve(pb, 0, "low", "medium"));

        // Alert severity exceeds max -> cannot auto-approve
        assert!(!can_auto_approve(pb, 0, "high", "low"));
        assert!(!can_auto_approve(pb, 0, "critical", "medium"));

        // Step 2 is NOT in auto_approve_steps
        assert!(!can_auto_approve(pb, 2, "low", "critical"));
    }

    // Approval-gated steps park as PendingApproval; ungated steps complete.
    #[test]
    fn evaluate_step_requires_approval() {
        let pb = &builtin_playbooks()[0]; // suspicious_login

        // Step 2 (notify_user) requires approval, high severity, max=low -> pending
        let result = evaluate_step(pb, 2, "high", "low", true);
        assert_eq!(result.status, StepStatus::PendingApproval);
        assert_eq!(result.action, "notify_user");

        // Step 0 (gather_login_context) does NOT require approval -> completed
        let result = evaluate_step(pb, 0, "high", "low", true);
        assert_eq!(result.status, StepStatus::Completed);
    }

    // Out-of-range step indices must report Failed, not panic.
    #[test]
    fn evaluate_step_out_of_range() {
        let pb = &builtin_playbooks()[0];
        let result = evaluate_step(pb, 99, "low", "low", true);
        assert_eq!(result.status, StepStatus::Failed);
    }

    // Playbooks must survive a serde JSON round trip unchanged.
    #[test]
    fn playbook_json_roundtrip() {
        let pb = &builtin_playbooks()[0];
        let json = serde_json::to_string(pb).unwrap();
        let parsed: Playbook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, *pb);
    }

    // A missing directory falls back to the built-in set.
    #[test]
    fn load_playbooks_from_nonexistent_dir_returns_builtins() {
        let playbooks = load_playbooks(Path::new("/nonexistent/dir"));
        assert_eq!(playbooks.len(), 4);
    }

    // Custom playbooks with new names are added alongside the built-ins.
    #[test]
    fn load_playbooks_merges_custom_and_builtin() {
        let dir = tempfile::tempdir().unwrap();
        let custom = Playbook {
            name: "custom_playbook".into(),
            description: "A custom playbook".into(),
            steps: vec![PlaybookStep {
                action: "custom_action".into(),
                description: "Do something custom".into(),
                requires_approval: true,
                timeout_secs: 60,
            }],
            severity_filter: "low".into(),
            auto_approve_steps: vec![],
        };
        let json = serde_json::to_string(&custom).unwrap();
        std::fs::write(dir.path().join("custom.json"), json).unwrap();

        let playbooks = load_playbooks(dir.path());
        // 4 builtins + 1 custom
        assert_eq!(playbooks.len(), 5);
        assert!(playbooks.iter().any(|p| p.name == "custom_playbook"));
    }

    // A custom playbook sharing a built-in's name replaces the built-in.
    #[test]
    fn load_playbooks_custom_overrides_builtin() {
        let dir = tempfile::tempdir().unwrap();
        let override_pb = Playbook {
            name: "suspicious_login".into(),
            description: "Custom override".into(),
            steps: vec![PlaybookStep {
                action: "custom_step".into(),
                description: "Overridden step".into(),
                requires_approval: false,
                timeout_secs: 30,
            }],
            severity_filter: "low".into(),
            auto_approve_steps: vec![0],
        };
        let json = serde_json::to_string(&override_pb).unwrap();
        std::fs::write(dir.path().join("suspicious_login.json"), json).unwrap();

        let playbooks = load_playbooks(dir.path());
        // 3 remaining builtins + 1 overridden = 4
        assert_eq!(playbooks.len(), 4);
        let sl = playbooks
            .iter()
            .find(|p| p.name == "suspicious_login")
            .unwrap();
        assert_eq!(sl.description, "Custom override");
    }
}
|
||||
397
src/security/vulnerability.rs
Normal file
397
src/security/vulnerability.rs
Normal file
@ -0,0 +1,397 @@
|
||||
//! Vulnerability scan result parsing and management.
|
||||
//!
|
||||
//! Parses vulnerability scan outputs from common scanners (Nessus, Qualys, generic
|
||||
//! CVSS JSON) and provides priority scoring with business context adjustments.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// A single vulnerability finding.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Finding {
|
||||
/// CVE identifier (e.g. "CVE-2024-1234"). May be empty for non-CVE findings.
|
||||
#[serde(default)]
|
||||
pub cve_id: String,
|
||||
/// CVSS base score (0.0 - 10.0).
|
||||
pub cvss_score: f64,
|
||||
/// Severity label: "low", "medium", "high", "critical".
|
||||
pub severity: String,
|
||||
/// Affected asset identifier (hostname, IP, or service name).
|
||||
pub affected_asset: String,
|
||||
/// Description of the vulnerability.
|
||||
pub description: String,
|
||||
/// Recommended remediation steps.
|
||||
#[serde(default)]
|
||||
pub remediation: String,
|
||||
/// Whether the asset is internet-facing (increases effective priority).
|
||||
#[serde(default)]
|
||||
pub internet_facing: bool,
|
||||
/// Whether the asset is in a production environment.
|
||||
#[serde(default = "default_true")]
|
||||
pub production: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// A parsed vulnerability scan report.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VulnerabilityReport {
|
||||
/// When the scan was performed.
|
||||
pub scan_date: DateTime<Utc>,
|
||||
/// Scanner that produced the results (e.g. "nessus", "qualys", "generic").
|
||||
pub scanner: String,
|
||||
/// Individual findings from the scan.
|
||||
pub findings: Vec<Finding>,
|
||||
}
|
||||
|
||||
/// Compute effective priority score for a finding.
|
||||
///
|
||||
/// Base: CVSS score (0-10). Adjustments:
|
||||
/// - Internet-facing: +2.0 (capped at 10.0)
|
||||
/// - Production: +1.0 (capped at 10.0)
|
||||
pub fn effective_priority(finding: &Finding) -> f64 {
|
||||
let mut score = finding.cvss_score;
|
||||
if finding.internet_facing {
|
||||
score += 2.0;
|
||||
}
|
||||
if finding.production {
|
||||
score += 1.0;
|
||||
}
|
||||
score.min(10.0)
|
||||
}
|
||||
|
||||
/// Classify a CVSS base score into a severity label.
///
/// Thresholds follow the CVSS v3 qualitative bands: >= 9.0 critical,
/// >= 7.0 high, >= 4.0 medium, > 0.0 low; zero, negative, or NaN scores
/// are reported as "informational".
pub fn cvss_to_severity(cvss: f64) -> &'static str {
    if cvss >= 9.0 {
        "critical"
    } else if cvss >= 7.0 {
        "high"
    } else if cvss >= 4.0 {
        "medium"
    } else if cvss > 0.0 {
        "low"
    } else {
        "informational"
    }
}
|
||||
|
||||
/// Parse a generic CVSS JSON vulnerability report.
|
||||
///
|
||||
/// Expects a JSON object with:
|
||||
/// - `scan_date`: ISO 8601 date string
|
||||
/// - `scanner`: string
|
||||
/// - `findings`: array of Finding objects
|
||||
pub fn parse_vulnerability_json(json_str: &str) -> anyhow::Result<VulnerabilityReport> {
|
||||
let report: VulnerabilityReport = serde_json::from_str(json_str)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse vulnerability report: {e}"))?;
|
||||
|
||||
for (i, finding) in report.findings.iter().enumerate() {
|
||||
if !(0.0..=10.0).contains(&finding.cvss_score) {
|
||||
anyhow::bail!(
|
||||
"findings[{}].cvss_score must be between 0.0 and 10.0, got {}",
|
||||
i,
|
||||
finding.cvss_score
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(report)
|
||||
}
|
||||
|
||||
/// Generate a summary of the vulnerability report.
|
||||
pub fn generate_summary(report: &VulnerabilityReport) -> String {
|
||||
if report.findings.is_empty() {
|
||||
return format!(
|
||||
"Vulnerability scan by {} on {}: No findings.",
|
||||
report.scanner,
|
||||
report.scan_date.format("%Y-%m-%d")
|
||||
);
|
||||
}
|
||||
|
||||
let total = report.findings.len();
|
||||
let critical = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("critical"))
|
||||
.count();
|
||||
let high = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("high"))
|
||||
.count();
|
||||
let medium = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("medium"))
|
||||
.count();
|
||||
let low = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("low"))
|
||||
.count();
|
||||
let informational = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("informational"))
|
||||
.count();
|
||||
|
||||
// Sort by effective priority descending
|
||||
let mut sorted: Vec<&Finding> = report.findings.iter().collect();
|
||||
sorted.sort_by(|a, b| {
|
||||
effective_priority(b)
|
||||
.partial_cmp(&effective_priority(a))
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut summary = format!(
|
||||
"## Vulnerability Scan Summary\n\
|
||||
**Scanner:** {} | **Date:** {}\n\
|
||||
**Total findings:** {} (Critical: {}, High: {}, Medium: {}, Low: {}, Informational: {})\n\n",
|
||||
report.scanner,
|
||||
report.scan_date.format("%Y-%m-%d"),
|
||||
total,
|
||||
critical,
|
||||
high,
|
||||
medium,
|
||||
low,
|
||||
informational
|
||||
);
|
||||
|
||||
// Top 10 by effective priority
|
||||
summary.push_str("### Top Findings by Priority\n\n");
|
||||
for (i, finding) in sorted.iter().take(10).enumerate() {
|
||||
let priority = effective_priority(finding);
|
||||
let context = match (finding.internet_facing, finding.production) {
|
||||
(true, true) => " [internet-facing, production]",
|
||||
(true, false) => " [internet-facing]",
|
||||
(false, true) => " [production]",
|
||||
(false, false) => "",
|
||||
};
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"{}. **{}** (CVSS: {:.1}, Priority: {:.1}){}\n Asset: {} | {}",
|
||||
i + 1,
|
||||
if finding.cve_id.is_empty() {
|
||||
"No CVE"
|
||||
} else {
|
||||
&finding.cve_id
|
||||
},
|
||||
finding.cvss_score,
|
||||
priority,
|
||||
context,
|
||||
finding.affected_asset,
|
||||
finding.description
|
||||
);
|
||||
if !finding.remediation.is_empty() {
|
||||
let _ = writeln!(summary, " Remediation: {}", finding.remediation);
|
||||
}
|
||||
summary.push('\n');
|
||||
}
|
||||
|
||||
// Remediation recommendations
|
||||
if critical > 0 || high > 0 {
|
||||
summary.push_str("### Remediation Recommendations\n\n");
|
||||
if critical > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **URGENT:** {} critical findings require immediate remediation",
|
||||
critical
|
||||
);
|
||||
}
|
||||
if high > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **HIGH:** {} high-severity findings should be addressed within 7 days",
|
||||
high
|
||||
);
|
||||
}
|
||||
let internet_facing_critical = sorted
|
||||
.iter()
|
||||
.filter(|f| f.internet_facing && (f.severity == "critical" || f.severity == "high"))
|
||||
.count();
|
||||
if internet_facing_critical > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **PRIORITY:** {} critical/high findings on internet-facing assets",
|
||||
internet_facing_critical
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
summary
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Finding` with the core fields; remediation empty, flags off.
    fn finding(cve: &str, score: f64, severity: &str, asset: &str, desc: &str) -> Finding {
        Finding {
            cve_id: cve.into(),
            cvss_score: score,
            severity: severity.into(),
            affected_asset: asset.into(),
            description: desc.into(),
            remediation: String::new(),
            internet_facing: false,
            production: false,
        }
    }

    fn sample_findings() -> Vec<Finding> {
        let mut rce = finding(
            "CVE-2024-0001",
            9.8,
            "critical",
            "web-server-01",
            "Remote code execution in web framework",
        );
        rce.remediation = "Upgrade to version 2.1.0".into();
        rce.internet_facing = true;
        rce.production = true;

        let mut sqli = finding(
            "CVE-2024-0002",
            7.5,
            "high",
            "db-server-01",
            "SQL injection in query parser",
        );
        sqli.remediation = "Apply patch KB-12345".into();
        sqli.production = true;

        let mut info_leak = finding(
            "CVE-2024-0003",
            4.3,
            "medium",
            "staging-app-01",
            "Information disclosure via debug endpoint",
        );
        info_leak.remediation = "Disable debug endpoint in config".into();

        vec![rce, sqli, info_leak]
    }

    #[test]
    fn effective_priority_adds_context_bonuses() {
        let mut f = finding("", 7.0, "high", "host", "test");

        assert!((effective_priority(&f) - 7.0).abs() < f64::EPSILON);

        f.internet_facing = true;
        assert!((effective_priority(&f) - 9.0).abs() < f64::EPSILON);

        f.production = true;
        assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); // capped

        // High CVSS + both bonuses still caps at 10.0.
        f.cvss_score = 9.5;
        assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cvss_to_severity_classification() {
        for (score, expected) in [
            (9.8, "critical"),
            (9.0, "critical"),
            (8.5, "high"),
            (7.0, "high"),
            (5.0, "medium"),
            (4.0, "medium"),
            (3.9, "low"),
            (0.1, "low"),
            (0.0, "informational"),
        ] {
            assert_eq!(cvss_to_severity(score), expected, "score {score}");
        }
    }

    #[test]
    fn parse_vulnerability_json_roundtrip() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "nessus".into(),
            findings: sample_findings(),
        };

        let json = serde_json::to_string(&report).unwrap();
        let parsed = parse_vulnerability_json(&json).unwrap();

        assert_eq!(parsed.scanner, "nessus");
        assert_eq!(parsed.findings.len(), 3);
        assert_eq!(parsed.findings[0].cve_id, "CVE-2024-0001");
    }

    #[test]
    fn parse_vulnerability_json_rejects_invalid() {
        assert!(parse_vulnerability_json("not json").is_err());
    }

    #[test]
    fn generate_summary_includes_key_sections() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "qualys".into(),
            findings: sample_findings(),
        };

        let summary = generate_summary(&report);

        for needle in [
            "qualys",
            "Total findings:** 3",
            "Critical: 1",
            "High: 1",
            "CVE-2024-0001",
            "URGENT",
            "internet-facing",
        ] {
            assert!(summary.contains(needle), "summary missing {needle:?}");
        }
    }

    #[test]
    fn parse_vulnerability_json_rejects_out_of_range_cvss() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "test".into(),
            findings: vec![finding("CVE-2024-9999", 11.0, "critical", "host", "bad score")],
        };
        let json = serde_json::to_string(&report).unwrap();
        let err = parse_vulnerability_json(&json).unwrap_err().to_string();
        assert!(err.contains("cvss_score must be between 0.0 and 10.0"));
    }

    #[test]
    fn parse_vulnerability_json_rejects_negative_cvss() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "test".into(),
            findings: vec![finding("CVE-2024-9998", -1.0, "low", "host", "negative score")],
        };
        let json = serde_json::to_string(&report).unwrap();
        assert!(parse_vulnerability_json(&json).is_err());
    }

    #[test]
    fn generate_summary_empty_findings() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "nessus".into(),
            findings: vec![],
        };
        assert!(generate_summary(&report).contains("No findings"));
    }
}
|
||||
466
src/tools/backup_tool.rs
Normal file
466
src/tools/backup_tool.rs
Normal file
@ -0,0 +1,466 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Workspace backup tool: create, list, verify, and restore timestamped backups
/// with SHA-256 manifest integrity checking.
pub struct BackupTool {
    /// Workspace root; backups live under `<workspace_dir>/backups`.
    workspace_dir: PathBuf,
    /// Subdirectories of the workspace (relative names) included in each backup.
    include_dirs: Vec<String>,
    /// Maximum number of backups retained; older ones are pruned after `create`.
    max_keep: usize,
}
|
||||
|
||||
impl BackupTool {
|
||||
pub fn new(workspace_dir: PathBuf, include_dirs: Vec<String>, max_keep: usize) -> Self {
|
||||
Self {
|
||||
workspace_dir,
|
||||
include_dirs,
|
||||
max_keep,
|
||||
}
|
||||
}
|
||||
|
||||
fn backups_dir(&self) -> PathBuf {
|
||||
self.workspace_dir.join("backups")
|
||||
}
|
||||
|
||||
async fn cmd_create(&self) -> anyhow::Result<ToolResult> {
|
||||
let ts = chrono::Utc::now().format("%Y%m%dT%H%M%SZ");
|
||||
let name = format!("backup-{ts}");
|
||||
let backup_dir = self.backups_dir().join(&name);
|
||||
fs::create_dir_all(&backup_dir).await?;
|
||||
|
||||
for sub in &self.include_dirs {
|
||||
let src = self.workspace_dir.join(sub);
|
||||
if src.is_dir() {
|
||||
let dst = backup_dir.join(sub);
|
||||
copy_dir_recursive(&src, &dst).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let checksums = compute_checksums(&backup_dir).await?;
|
||||
let file_count = checksums.len();
|
||||
let manifest = serde_json::to_string_pretty(&checksums)?;
|
||||
fs::write(backup_dir.join("manifest.json"), &manifest).await?;
|
||||
|
||||
// Enforce max_keep: remove oldest backups beyond the limit.
|
||||
self.enforce_max_keep().await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"backup": name,
|
||||
"file_count": file_count,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn enforce_max_keep(&self) -> anyhow::Result<()> {
|
||||
let mut backups = self.list_backup_dirs().await?;
|
||||
// Sorted newest-first; drop excess from the tail.
|
||||
while backups.len() > self.max_keep {
|
||||
if let Some(old) = backups.pop() {
|
||||
fs::remove_dir_all(old).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_backup_dirs(&self) -> anyhow::Result<Vec<PathBuf>> {
|
||||
let dir = self.backups_dir();
|
||||
if !dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut entries = Vec::new();
|
||||
let mut rd = fs::read_dir(&dir).await?;
|
||||
while let Some(e) = rd.next_entry().await? {
|
||||
let p = e.path();
|
||||
if p.is_dir() && e.file_name().to_string_lossy().starts_with("backup-") {
|
||||
entries.push(p);
|
||||
}
|
||||
}
|
||||
entries.sort();
|
||||
entries.reverse(); // newest first
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
async fn cmd_list(&self) -> anyhow::Result<ToolResult> {
|
||||
let dirs = self.list_backup_dirs().await?;
|
||||
let mut items = Vec::new();
|
||||
for d in &dirs {
|
||||
let name = d
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
let manifest_path = d.join("manifest.json");
|
||||
let file_count = if manifest_path.is_file() {
|
||||
let data = fs::read_to_string(&manifest_path).await?;
|
||||
let map: HashMap<String, String> = serde_json::from_str(&data).unwrap_or_default();
|
||||
map.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let meta = fs::metadata(d).await?;
|
||||
let created = meta
|
||||
.created()
|
||||
.or_else(|_| meta.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
let dt: chrono::DateTime<chrono::Utc> = created.into();
|
||||
items.push(json!({
|
||||
"name": name,
|
||||
"file_count": file_count,
|
||||
"created": dt.to_rfc3339(),
|
||||
}));
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&items)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_verify(&self, backup_name: &str) -> anyhow::Result<ToolResult> {
|
||||
let backup_dir = self.backups_dir().join(backup_name);
|
||||
if !backup_dir.is_dir() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Backup not found: {backup_name}")),
|
||||
});
|
||||
}
|
||||
let manifest_path = backup_dir.join("manifest.json");
|
||||
let data = fs::read_to_string(&manifest_path).await?;
|
||||
let expected: HashMap<String, String> = serde_json::from_str(&data)?;
|
||||
let actual = compute_checksums(&backup_dir).await?;
|
||||
|
||||
let mut mismatches = Vec::new();
|
||||
for (path, expected_hash) in &expected {
|
||||
match actual.get(path) {
|
||||
Some(actual_hash) if actual_hash == expected_hash => {}
|
||||
Some(actual_hash) => mismatches.push(json!({
|
||||
"file": path,
|
||||
"expected": expected_hash,
|
||||
"actual": actual_hash,
|
||||
})),
|
||||
None => mismatches.push(json!({
|
||||
"file": path,
|
||||
"error": "missing",
|
||||
})),
|
||||
}
|
||||
}
|
||||
let pass = mismatches.is_empty();
|
||||
Ok(ToolResult {
|
||||
success: pass,
|
||||
output: json!({
|
||||
"backup": backup_name,
|
||||
"pass": pass,
|
||||
"checked": expected.len(),
|
||||
"mismatches": mismatches,
|
||||
})
|
||||
.to_string(),
|
||||
error: if pass {
|
||||
None
|
||||
} else {
|
||||
Some("Integrity check failed".into())
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_restore(&self, backup_name: &str, confirm: bool) -> anyhow::Result<ToolResult> {
|
||||
let backup_dir = self.backups_dir().join(backup_name);
|
||||
if !backup_dir.is_dir() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Backup not found: {backup_name}")),
|
||||
});
|
||||
}
|
||||
|
||||
// Collect restorable subdirectories (skip manifest.json).
|
||||
let mut restore_items: Vec<String> = Vec::new();
|
||||
let mut rd = fs::read_dir(&backup_dir).await?;
|
||||
while let Some(e) = rd.next_entry().await? {
|
||||
let name = e.file_name().to_string_lossy().to_string();
|
||||
if name == "manifest.json" {
|
||||
continue;
|
||||
}
|
||||
if e.path().is_dir() {
|
||||
restore_items.push(name);
|
||||
}
|
||||
}
|
||||
|
||||
if !confirm {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"dry_run": true,
|
||||
"backup": backup_name,
|
||||
"would_restore": restore_items,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
for sub in &restore_items {
|
||||
let src = backup_dir.join(sub);
|
||||
let dst = self.workspace_dir.join(sub);
|
||||
copy_dir_recursive(&src, &dst).await?;
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"restored": backup_name,
|
||||
"directories": restore_items,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for BackupTool {
|
||||
fn name(&self) -> &str {
|
||||
"backup"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Create, list, verify, and restore workspace backups"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "verify", "restore"],
|
||||
"description": "Backup command to execute"
|
||||
},
|
||||
"backup_name": {
|
||||
"type": "string",
|
||||
"description": "Name of backup (for verify/restore)"
|
||||
},
|
||||
"confirm": {
|
||||
"type": "boolean",
|
||||
"description": "Confirm restore (required for actual restore, default false)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing 'command' parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match command {
|
||||
"create" => self.cmd_create().await,
|
||||
"list" => self.cmd_list().await,
|
||||
"verify" => {
|
||||
let name = args
|
||||
.get("backup_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for verify"))?;
|
||||
self.cmd_verify(name).await
|
||||
}
|
||||
"restore" => {
|
||||
let name = args
|
||||
.get("backup_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for restore"))?;
|
||||
let confirm = args
|
||||
.get("confirm")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
self.cmd_restore(name, confirm).await
|
||||
}
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown command: {other}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Helpers ------------------------------------------------------------------
|
||||
|
||||
async fn copy_dir_recursive(src: &Path, dst: &Path) -> anyhow::Result<()> {
|
||||
fs::create_dir_all(dst).await?;
|
||||
let mut rd = fs::read_dir(src).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let src_path = entry.path();
|
||||
let dst_path = dst.join(entry.file_name());
|
||||
if src_path.is_dir() {
|
||||
Box::pin(copy_dir_recursive(&src_path, &dst_path)).await?;
|
||||
} else {
|
||||
fs::copy(&src_path, &dst_path).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn compute_checksums(dir: &Path) -> anyhow::Result<HashMap<String, String>> {
|
||||
let mut map = HashMap::new();
|
||||
let base = dir.to_path_buf();
|
||||
walk_and_hash(&base, dir, &mut map).await?;
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
async fn walk_and_hash(
|
||||
base: &Path,
|
||||
dir: &Path,
|
||||
map: &mut HashMap<String, String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
Box::pin(walk_and_hash(base, &path, map)).await?;
|
||||
} else {
|
||||
let rel = path
|
||||
.strip_prefix(base)
|
||||
.unwrap_or(&path)
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
if rel == "manifest.json" {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path).await?;
|
||||
let hash = hex::encode(Sha256::digest(&bytes));
|
||||
map.insert(rel, hash);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Tool rooted at `tmp`, backing up `config` and `memory`, keeping 10.
    fn make_tool(tmp: &TempDir) -> BackupTool {
        BackupTool::new(
            tmp.path().to_path_buf(),
            vec!["config".into(), "memory".into()],
            10,
        )
    }

    /// Seed `<tmp>/config/a.toml` with the given contents.
    fn seed_config(tmp: &TempDir, contents: &str) {
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(cfg_dir.join("a.toml"), contents).unwrap();
    }

    /// Run `create` and return the parsed output plus the new backup's name.
    async fn create_backup(tool: &BackupTool) -> (serde_json::Value, String) {
        let res = tool.execute(json!({"command": "create"})).await.unwrap();
        assert!(res.success, "create failed: {:?}", res.error);
        let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        let name = parsed["backup"].as_str().unwrap().to_string();
        (parsed, name)
    }

    #[tokio::test]
    async fn create_backup_produces_manifest() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "key = 1");

        let tool = make_tool(&tmp);
        let (parsed, name) = create_backup(&tool).await;
        assert_eq!(parsed["file_count"], 1);

        // Manifest should exist inside the backup directory.
        let manifest = tmp
            .path()
            .join("backups")
            .join(&name)
            .join("manifest.json");
        assert!(manifest.exists());
    }

    #[tokio::test]
    async fn verify_backup_detects_corruption() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "original");

        let tool = make_tool(&tmp);
        let (_, name) = create_backup(&tool).await;

        // Corrupt a file inside the backup.
        let backed_up = tmp.path().join("backups").join(&name).join("config/a.toml");
        std::fs::write(&backed_up, "corrupted").unwrap();

        let res = tool
            .execute(json!({"command": "verify", "backup_name": name.as_str()}))
            .await
            .unwrap();
        assert!(!res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert!(!v["mismatches"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn restore_requires_confirmation() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "v1");

        let tool = make_tool(&tmp);
        let (_, name) = create_backup(&tool).await;

        // Without confirm: dry-run.
        let res = tool
            .execute(json!({"command": "restore", "backup_name": name.as_str()}))
            .await
            .unwrap();
        assert!(res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert_eq!(v["dry_run"], true);

        // With confirm: actual restore.
        let res = tool
            .execute(json!({
                "command": "restore",
                "backup_name": name.as_str(),
                "confirm": true
            }))
            .await
            .unwrap();
        assert!(res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert!(v.get("restored").is_some());
    }

    #[tokio::test]
    async fn list_backups_sorted_newest_first() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "v1");

        let tool = make_tool(&tmp);
        tool.execute(json!({"command": "create"})).await.unwrap();
        // Delay to ensure different second-resolution timestamps.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        tool.execute(json!({"command": "create"})).await.unwrap();

        let res = tool.execute(json!({"command": "list"})).await.unwrap();
        assert!(res.success);
        let items: Vec<serde_json::Value> = serde_json::from_str(&res.output).unwrap();
        assert_eq!(items.len(), 2);
        // Newest first by name (ISO8601 names sort lexicographically).
        assert!(items[0]["name"].as_str().unwrap() >= items[1]["name"].as_str().unwrap());
    }
}
|
||||
@ -1470,7 +1470,7 @@ mod native_backend {
|
||||
// When running as a service (systemd/OpenRC), the browser sandbox
|
||||
// fails because the process lacks a user namespace / session.
|
||||
// --no-sandbox and --disable-dev-shm-usage are required in this context.
|
||||
if is_service_environment() {
|
||||
if super::is_service_environment() {
|
||||
args.push(Value::String("--no-sandbox".to_string()));
|
||||
args.push(Value::String("--disable-dev-shm-usage".to_string()));
|
||||
}
|
||||
|
||||
320
src/tools/data_management.rs
Normal file
320
src/tools/data_management.rs
Normal file
@ -0,0 +1,320 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Workspace data lifecycle tool: retention status, time-based purge, and
|
||||
/// storage statistics.
|
||||
pub struct DataManagementTool {
|
||||
workspace_dir: PathBuf,
|
||||
retention_days: u64,
|
||||
}
|
||||
|
||||
impl DataManagementTool {
|
||||
pub fn new(workspace_dir: PathBuf, retention_days: u64) -> Self {
|
||||
Self {
|
||||
workspace_dir,
|
||||
retention_days,
|
||||
}
|
||||
}
|
||||
|
||||
async fn cmd_retention_status(&self) -> anyhow::Result<ToolResult> {
|
||||
let cutoff = chrono::Utc::now()
|
||||
- chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX));
|
||||
let cutoff_ts = cutoff.timestamp().try_into().unwrap_or(0u64);
|
||||
let count = count_files_older_than(&self.workspace_dir, cutoff_ts).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff": cutoff.to_rfc3339(),
|
||||
"affected_files": count,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_purge(&self, dry_run: bool) -> anyhow::Result<ToolResult> {
|
||||
let cutoff = chrono::Utc::now()
|
||||
- chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX));
|
||||
let cutoff_ts: u64 = cutoff.timestamp().try_into().unwrap_or(0);
|
||||
let (deleted, bytes) = purge_old_files(&self.workspace_dir, cutoff_ts, dry_run).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"dry_run": dry_run,
|
||||
"files": deleted,
|
||||
"bytes_freed": bytes,
|
||||
"bytes_freed_human": format_bytes(bytes),
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_stats(&self) -> anyhow::Result<ToolResult> {
|
||||
let (total_files, total_bytes, breakdown) = dir_stats(&self.workspace_dir).await?;
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"total_files": total_files,
|
||||
"total_size": total_bytes,
|
||||
"total_size_human": format_bytes(total_bytes),
|
||||
"subdirectories": breakdown,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DataManagementTool {
|
||||
fn name(&self) -> &str {
|
||||
"data_management"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Workspace data retention, purge, and storage statistics"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"enum": ["retention_status", "purge", "stats"],
|
||||
"description": "Data management command"
|
||||
},
|
||||
"dry_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, purge only lists what would be deleted (default true)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing 'command' parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match command {
|
||||
"retention_status" => self.cmd_retention_status().await,
|
||||
"purge" => {
|
||||
let dry_run = args
|
||||
.get("dry_run")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
self.cmd_purge(dry_run).await
|
||||
}
|
||||
"stats" => self.cmd_stats().await,
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown command: {other}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Helpers ------------------------------------------------------------------
|
||||
|
||||
/// Render a byte count as a human-readable string using binary units
/// (1 KB = 1024 bytes), with one decimal place for KB and above.
fn format_bytes(bytes: u64) -> String {
    const UNITS: [(u64, &str); 3] = [
        (1024 * 1024 * 1024, "GB"),
        (1024 * 1024, "MB"),
        (1024, "KB"),
    ];
    for (scale, suffix) in UNITS {
        if bytes >= scale {
            return format!("{:.1} {suffix}", bytes as f64 / scale as f64);
        }
    }
    format!("{bytes} B")
}
|
||||
|
||||
async fn count_files_older_than(dir: &Path, cutoff_epoch: u64) -> anyhow::Result<usize> {
|
||||
let mut count = 0;
|
||||
if !dir.is_dir() {
|
||||
return Ok(0);
|
||||
}
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
count += Box::pin(count_files_older_than(&path, cutoff_epoch)).await?;
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
let epoch = modified
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
if epoch < cutoff_epoch {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Recursively delete files under `dir` whose modification time is older than
/// `cutoff_epoch` (seconds since the Unix epoch).
///
/// Returns `(files, bytes)` — the count and total size of files that were
/// (or, when `dry_run` is true, would have been) removed. Individual deletion
/// failures are deliberately ignored: the purge is best-effort.
async fn purge_old_files(
    dir: &Path,
    cutoff_epoch: u64,
    dry_run: bool,
) -> anyhow::Result<(usize, u64)> {
    let mut deleted = 0usize;
    let mut bytes = 0u64;
    // NOTE(review): `is_dir` is a blocking syscall inside an async fn — fine
    // for local filesystems, worth confirming for slow network mounts.
    if !dir.is_dir() {
        return Ok((0, 0));
    }
    let mut rd = fs::read_dir(dir).await?;
    while let Some(entry) = rd.next_entry().await? {
        let path = entry.path();
        if path.is_dir() {
            // Recurse; the future is boxed so the async fn stays sized.
            let (d, b) = Box::pin(purge_old_files(&path, cutoff_epoch, dry_run)).await?;
            deleted += d;
            bytes += b;
        } else if let Ok(meta) = fs::metadata(&path).await {
            // A file whose mtime cannot be read is treated as epoch 0, i.e.
            // always older than the cutoff.
            let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH);
            let epoch = modified
                .duration_since(std::time::SystemTime::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            if epoch < cutoff_epoch {
                // Count/measure before deleting so dry-run and real runs
                // report identical numbers.
                bytes += meta.len();
                deleted += 1;
                if !dry_run {
                    // Best-effort: a failed removal is still counted above.
                    let _ = fs::remove_file(&path).await;
                }
            }
        }
    }
    Ok((deleted, bytes))
}
|
||||
|
||||
async fn dir_stats(root: &Path) -> anyhow::Result<(usize, u64, serde_json::Value)> {
|
||||
let mut total_files = 0usize;
|
||||
let mut total_bytes = 0u64;
|
||||
let mut breakdown = serde_json::Map::new();
|
||||
|
||||
if !root.is_dir() {
|
||||
return Ok((0, 0, serde_json::Value::Object(breakdown)));
|
||||
}
|
||||
|
||||
let mut rd = fs::read_dir(root).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let (f, b) = count_dir_contents(&path).await?;
|
||||
total_files += f;
|
||||
total_bytes += b;
|
||||
breakdown.insert(
|
||||
name,
|
||||
json!({"files": f, "size": b, "size_human": format_bytes(b)}),
|
||||
);
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
total_files += 1;
|
||||
total_bytes += meta.len();
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
total_files,
|
||||
total_bytes,
|
||||
serde_json::Value::Object(breakdown),
|
||||
))
|
||||
}
|
||||
|
||||
async fn count_dir_contents(dir: &Path) -> anyhow::Result<(usize, u64)> {
|
||||
let mut files = 0usize;
|
||||
let mut bytes = 0u64;
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let (f, b) = Box::pin(count_dir_contents(&path)).await?;
|
||||
files += f;
|
||||
bytes += b;
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
files += 1;
|
||||
bytes += meta.len();
|
||||
}
|
||||
}
|
||||
Ok((files, bytes))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a tool rooted at the temp dir with a 90-day retention window.
    fn make_tool(tmp: &TempDir) -> DataManagementTool {
        DataManagementTool::new(tmp.path().to_path_buf(), 90)
    }

    #[tokio::test]
    async fn retention_status_reports_correct_cutoff() {
        let tmp = TempDir::new().unwrap();
        let tool = make_tool(&tmp);
        let res = tool
            .execute(json!({"command": "retention_status"}))
            .await
            .unwrap();
        assert!(res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert_eq!(v["retention_days"], 90);
        assert!(v["cutoff"].is_string());
    }

    #[tokio::test]
    async fn purge_dry_run_does_not_delete() {
        let tmp = TempDir::new().unwrap();
        // Create a file with an old modification time by writing it (it will have
        // the current mtime, so it should not be purged with a 90-day retention).
        std::fs::write(tmp.path().join("recent.txt"), "data").unwrap();

        let tool = make_tool(&tmp);
        let res = tool
            .execute(json!({"command": "purge", "dry_run": true}))
            .await
            .unwrap();
        assert!(res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert_eq!(v["dry_run"], true);
        // Recent file should not be counted for purge.
        assert_eq!(v["files"], 0);
        // File still exists.
        assert!(tmp.path().join("recent.txt").exists());
    }

    #[tokio::test]
    async fn stats_counts_files_correctly() {
        let tmp = TempDir::new().unwrap();
        let sub = tmp.path().join("subdir");
        std::fs::create_dir_all(&sub).unwrap();
        std::fs::write(sub.join("a.txt"), "hello").unwrap();
        std::fs::write(sub.join("b.txt"), "world").unwrap();
        // One file at the root plus two in the subdirectory = 3 total.
        std::fs::write(tmp.path().join("root.txt"), "top").unwrap();

        let tool = make_tool(&tmp);
        let res = tool.execute(json!({"command": "stats"})).await.unwrap();
        assert!(res.success);
        let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
        assert_eq!(v["total_files"], 3);
    }
}
|
||||
400
src/tools/microsoft365/auth.rs
Normal file
400
src/tools/microsoft365/auth.rs
Normal file
@ -0,0 +1,400 @@
|
||||
use anyhow::Context;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Cached OAuth2 token state persisted to disk between runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTokenState {
    // Bearer token presented to the Graph API.
    pub access_token: String,
    // `None` for flows that do not issue refresh tokens (e.g. client_credentials).
    pub refresh_token: Option<String>,
    /// Unix timestamp (seconds) when the access token expires.
    pub expires_at: i64,
}
|
||||
|
||||
impl CachedTokenState {
|
||||
/// Returns `true` when the token is expired or will expire within 60 seconds.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
self.expires_at <= now + 60
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe token cache with disk persistence.
pub struct TokenCache {
    // Most-recent token state; `None` until the first acquisition succeeds.
    inner: RwLock<Option<CachedTokenState>>,
    /// Serialises the slow acquire/refresh path so only one caller performs the
    /// network round-trip while others wait and then read the updated cache.
    acquire_lock: Mutex<()>,
    // Resolved Microsoft 365 settings (tenant, client, auth flow, scopes).
    config: super::types::Microsoft365ResolvedConfig,
    // On-disk JSON cache location, fingerprinted per config (see `new`).
    cache_path: PathBuf,
}
|
||||
|
||||
impl TokenCache {
    /// Build a cache for `config`, storing the token file under `zeroclaw_dir`.
    ///
    /// Fails fast when `token_cache_encrypted` is requested, since encryption
    /// is not implemented and tokens would otherwise land on disk in plaintext.
    pub fn new(
        config: super::types::Microsoft365ResolvedConfig,
        zeroclaw_dir: &std::path::Path,
    ) -> anyhow::Result<Self> {
        if config.token_cache_encrypted {
            anyhow::bail!(
                "microsoft365: token_cache_encrypted is enabled but encryption is not yet \
                 implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \
                 to false or wait for encryption support."
            );
        }

        // Scope cache file to (tenant_id, client_id, auth_flow) so config
        // changes never reuse tokens from a different account/flow.
        let mut hasher = DefaultHasher::new();
        config.tenant_id.hash(&mut hasher);
        config.client_id.hash(&mut hasher);
        config.auth_flow.hash(&mut hasher);
        let fingerprint = format!("{:016x}", hasher.finish());

        let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json"));
        // A missing or unreadable cache file simply means "no cached token".
        let cached = Self::load_from_disk(&cache_path);
        Ok(Self {
            inner: RwLock::new(cached),
            acquire_lock: Mutex::new(()),
            config,
            cache_path,
        })
    }

    /// Get a valid access token, refreshing or re-authenticating as needed.
    pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
        // Fast path: cached and not expired.
        {
            let guard = self.inner.read();
            if let Some(ref state) = *guard {
                if !state.is_expired() {
                    return Ok(state.access_token.clone());
                }
            }
        }

        // Slow path: serialise through a mutex so only one caller performs the
        // network round-trip while concurrent callers wait and re-check.
        let _lock = self.acquire_lock.lock().await;

        // Re-check after acquiring the lock — another caller may have refreshed
        // while we were waiting.
        {
            let guard = self.inner.read();
            if let Some(ref state) = *guard {
                if !state.is_expired() {
                    return Ok(state.access_token.clone());
                }
            }
        }

        // Acquire a fresh token, persist it, then publish it to readers.
        let new_state = self.acquire_token(client).await?;
        let token = new_state.access_token.clone();
        self.persist_to_disk(&new_state);
        *self.inner.write() = Some(new_state);
        Ok(token)
    }

    /// Obtain a new token: prefer a refresh when possible, otherwise run the
    /// configured interactive/non-interactive flow from scratch.
    async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
        // Try refresh first if we have a refresh token and the flow supports it.
        // Client credentials flow does not issue refresh tokens, so skip the
        // attempt entirely to avoid a wasted round-trip.
        if self.config.auth_flow.as_str() != "client_credentials" {
            // Clone the token out so the RwLock guard is dropped before the await.
            let refresh_token_copy = {
                let guard = self.inner.read();
                guard.as_ref().and_then(|state| state.refresh_token.clone())
            };
            if let Some(refresh_tok) = refresh_token_copy {
                match self.refresh_token(client, &refresh_tok).await {
                    Ok(new_state) => return Ok(new_state),
                    Err(e) => {
                        // Fall through to a full re-auth; refresh failure is
                        // expected when the refresh token was revoked/expired.
                        tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
                    }
                }
            }
        }

        match self.config.auth_flow.as_str() {
            "client_credentials" => self.client_credentials_flow(client).await,
            "device_code" => self.device_code_flow(client).await,
            other => anyhow::bail!("Unsupported auth flow: {other}"),
        }
    }

    /// OAuth2 client_credentials grant: app-only token, no user interaction.
    async fn client_credentials_flow(
        &self,
        client: &reqwest::Client,
    ) -> anyhow::Result<CachedTokenState> {
        let client_secret = self
            .config
            .client_secret
            .as_deref()
            .context("client_credentials flow requires client_secret")?;

        let token_url = format!(
            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
            self.config.tenant_id
        );

        let scope = self.config.scopes.join(" ");

        let resp = client
            .post(&token_url)
            .form(&[
                ("grant_type", "client_credentials"),
                ("client_id", &self.config.client_id),
                ("client_secret", client_secret),
                ("scope", &scope),
            ])
            .send()
            .await
            .context("ms365: failed to request client_credentials token")?;

        if !resp.status().is_success() {
            // Raw body goes to debug logs only; the surfaced error carries
            // just the HTTP status.
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            tracing::debug!("ms365: client_credentials raw OAuth error: {body}");
            anyhow::bail!("ms365: client_credentials token request failed ({status})");
        }

        let token_resp: TokenResponse = resp
            .json()
            .await
            .context("ms365: failed to parse token response")?;

        Ok(CachedTokenState {
            access_token: token_resp.access_token,
            refresh_token: token_resp.refresh_token,
            expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
        })
    }

    /// OAuth2 device authorization grant: prints a user prompt, then polls
    /// the token endpoint until the user approves or the code expires.
    async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
        let device_code_url = format!(
            "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
            self.config.tenant_id
        );
        let scope = self.config.scopes.join(" ");

        let resp = client
            .post(&device_code_url)
            .form(&[
                ("client_id", self.config.client_id.as_str()),
                ("scope", &scope),
            ])
            .send()
            .await
            .context("ms365: failed to request device code")?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            tracing::debug!("ms365: device_code initiation raw error: {body}");
            anyhow::bail!("ms365: device code request failed ({status})");
        }

        let device_resp: DeviceCodeResponse = resp
            .json()
            .await
            .context("ms365: failed to parse device code response")?;

        // Log only a generic prompt; the full device_resp.message may contain
        // sensitive verification URIs or codes that should not appear in logs.
        tracing::info!(
            "ms365: device code auth required — follow the instructions shown to the user"
        );
        // Print the user-facing message to stderr so the operator can act on it
        // without it being captured in structured log sinks.
        eprintln!("ms365: {}", device_resp.message);

        let token_url = format!(
            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
            self.config.tenant_id
        );

        // Respect the server-requested interval, with a 5s floor.
        let interval = device_resp.interval.max(5);
        // Poll until the device code's lifetime is exhausted.
        let max_polls = u32::try_from(
            (device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1),
        )
        .unwrap_or(u32::MAX);

        for _ in 0..max_polls {
            tokio::time::sleep(std::time::Duration::from_secs(interval)).await;

            let poll_resp = client
                .post(&token_url)
                .form(&[
                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
                    ("client_id", self.config.client_id.as_str()),
                    ("device_code", &device_resp.device_code),
                ])
                .send()
                .await
                .context("ms365: failed to poll device code token")?;

            if poll_resp.status().is_success() {
                let token_resp: TokenResponse = poll_resp
                    .json()
                    .await
                    .context("ms365: failed to parse token response")?;
                return Ok(CachedTokenState {
                    access_token: token_resp.access_token,
                    refresh_token: token_resp.refresh_token,
                    expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
                });
            }

            // Standard device-flow error codes: keep polling on "pending",
            // back off on "slow_down", abort on anything else.
            let body = poll_resp.text().await.unwrap_or_default();
            if body.contains("authorization_pending") {
                continue;
            }
            if body.contains("slow_down") {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                continue;
            }
            tracing::debug!("ms365: device code polling raw error: {body}");
            anyhow::bail!("ms365: device code polling failed");
        }

        anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
    }

    /// Exchange a refresh token for a fresh access token.
    async fn refresh_token(
        &self,
        client: &reqwest::Client,
        refresh_token: &str,
    ) -> anyhow::Result<CachedTokenState> {
        let token_url = format!(
            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
            self.config.tenant_id
        );

        let mut params = vec![
            ("grant_type", "refresh_token"),
            ("client_id", self.config.client_id.as_str()),
            ("refresh_token", refresh_token),
        ];

        // The secret must outlive `params`, hence the deferred binding.
        let secret_ref;
        if let Some(ref secret) = self.config.client_secret {
            secret_ref = secret.as_str();
            params.push(("client_secret", secret_ref));
        }

        let resp = client
            .post(&token_url)
            .form(&params)
            .send()
            .await
            .context("ms365: failed to refresh token")?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            tracing::debug!("ms365: token refresh raw error: {body}");
            anyhow::bail!("ms365: token refresh failed ({status})");
        }

        let token_resp: TokenResponse = resp
            .json()
            .await
            .context("ms365: failed to parse refresh token response")?;

        Ok(CachedTokenState {
            access_token: token_resp.access_token,
            // Keep the old refresh token when the server does not rotate it.
            refresh_token: token_resp
                .refresh_token
                .or_else(|| Some(refresh_token.to_string())),
            expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
        })
    }

    /// Best-effort load of a previously persisted token; any read or parse
    /// failure is treated as "no cached token".
    fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
        let data = std::fs::read_to_string(path).ok()?;
        serde_json::from_str(&data).ok()
    }

    /// Best-effort write of the token state; failures are logged, not fatal.
    // NOTE(review): the file is written with default permissions, so on Unix
    // the plaintext tokens may be group/world-readable — consider 0o600.
    fn persist_to_disk(&self, state: &CachedTokenState) {
        if let Ok(json) = serde_json::to_string_pretty(state) {
            if let Err(e) = std::fs::write(&self.cache_path, json) {
                tracing::warn!("ms365: failed to persist token cache: {e}");
            }
        }
    }
}
|
||||
|
||||
/// Subset of the OAuth2 token endpoint response that this module consumes.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    // Absent for flows that do not issue refresh tokens.
    #[serde(default)]
    refresh_token: Option<String>,
    // Lifetime in seconds; defaults to one hour when the server omits it.
    #[serde(default = "default_expires_in")]
    expires_in: i64,
}
|
||||
|
||||
/// Fallback access-token lifetime (one hour, in seconds) used when the token
/// response omits `expires_in`.
fn default_expires_in() -> i64 {
    60 * 60
}
|
||||
|
||||
/// Subset of the OAuth2 device authorization endpoint response we consume.
#[derive(Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    // User-facing instructions; may contain the verification URI and code,
    // so it must not be routed into structured logs.
    message: String,
    // Minimum polling interval in seconds.
    #[serde(default = "default_device_interval")]
    interval: u64,
    // Device code lifetime in seconds.
    #[serde(default = "default_device_expires_in")]
    expires_in: i64,
}
|
||||
|
||||
/// Fallback device-code polling interval (seconds) when the response omits
/// `interval`.
fn default_device_interval() -> u64 {
    5
}
|
||||
|
||||
/// Fallback device-code lifetime (fifteen minutes, in seconds) when the
/// response omits `expires_in`.
fn default_device_expires_in() -> i64 {
    15 * 60
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_is_expired_when_past_deadline() {
        // A deadline already in the past is trivially expired.
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() - 10,
        };
        assert!(state.is_expired());
    }

    #[test]
    fn token_is_expired_within_buffer() {
        // 30s of remaining lifetime is inside the 60s early-refresh buffer.
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + 30,
        };
        assert!(state.is_expired());
    }

    #[test]
    fn token_is_valid_when_far_from_expiry() {
        // An hour of remaining lifetime is comfortably outside the buffer.
        let state = CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + 3600,
        };
        assert!(!state.is_expired());
    }

    #[test]
    fn load_from_disk_returns_none_for_missing_file() {
        let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
        assert!(TokenCache::load_from_disk(path).is_none());
    }
}
|
||||
495
src/tools/microsoft365/graph_client.rs
Normal file
495
src/tools/microsoft365/graph_client.rs
Normal file
@ -0,0 +1,495 @@
|
||||
use anyhow::Context;
|
||||
|
||||
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
|
||||
|
||||
/// Build the user path segment: `/me` or `/users/{user_id}`.
|
||||
/// The user_id is percent-encoded to prevent path-traversal attacks.
|
||||
fn user_path(user_id: &str) -> String {
|
||||
if user_id == "me" {
|
||||
"/me".to_string()
|
||||
} else {
|
||||
format!("/users/{}", urlencoding::encode(user_id))
|
||||
}
|
||||
}
|
||||
|
||||
/// Percent-encode a single path segment to prevent path-traversal attacks.
|
||||
fn encode_path_segment(segment: &str) -> String {
|
||||
urlencoding::encode(segment).into_owned()
|
||||
}
|
||||
|
||||
/// List mail messages for a user.
|
||||
pub async fn mail_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
folder: Option<&str>,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let path = match folder {
|
||||
Some(f) => format!(
|
||||
"{GRAPH_BASE}{base}/mailFolders/{}/messages",
|
||||
encode_path_segment(f)
|
||||
),
|
||||
None => format!("{GRAPH_BASE}{base}/messages"),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&path)
|
||||
.bearer_auth(token)
|
||||
.query(&[("$top", top.to_string())])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: mail_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "mail_list").await
|
||||
}
|
||||
|
||||
/// Send a mail message.
|
||||
pub async fn mail_send(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
to: &[String],
|
||||
subject: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/sendMail");
|
||||
|
||||
let to_recipients: Vec<serde_json::Value> = to
|
||||
.iter()
|
||||
.map(|addr| {
|
||||
serde_json::json!({
|
||||
"emailAddress": { "address": addr }
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"message": {
|
||||
"subject": subject,
|
||||
"body": {
|
||||
"contentType": "Text",
|
||||
"content": body
|
||||
},
|
||||
"toRecipients": to_recipients
|
||||
}
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: mail_send request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: mail_send raw error body: {body}");
|
||||
anyhow::bail!("ms365: mail_send failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List messages in a Teams channel.
|
||||
pub async fn teams_message_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
team_id: &str,
|
||||
channel_id: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
|
||||
encode_path_segment(team_id),
|
||||
encode_path_segment(channel_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.query(&[("$top", top.to_string())])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: teams_message_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "teams_message_list").await
|
||||
}
|
||||
|
||||
/// Send a message to a Teams channel.
|
||||
pub async fn teams_message_send(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
team_id: &str,
|
||||
channel_id: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
|
||||
encode_path_segment(team_id),
|
||||
encode_path_segment(channel_id)
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"body": {
|
||||
"content": body
|
||||
}
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: teams_message_send request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: teams_message_send raw error body: {body}");
|
||||
anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List calendar events in a date range.
|
||||
pub async fn calendar_events_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
start: &str,
|
||||
end: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/calendarView");
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.query(&[
|
||||
("startDateTime", start.to_string()),
|
||||
("endDateTime", end.to_string()),
|
||||
("$top", top.to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_events_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "calendar_events_list").await
|
||||
}
|
||||
|
||||
/// Create a calendar event.
|
||||
pub async fn calendar_event_create(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
subject: &str,
|
||||
start: &str,
|
||||
end: &str,
|
||||
attendees: &[String],
|
||||
body_text: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/events");
|
||||
|
||||
let attendee_list: Vec<serde_json::Value> = attendees
|
||||
.iter()
|
||||
.map(|email| {
|
||||
serde_json::json!({
|
||||
"emailAddress": { "address": email },
|
||||
"type": "required"
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut payload = serde_json::json!({
|
||||
"subject": subject,
|
||||
"start": {
|
||||
"dateTime": start,
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"end": {
|
||||
"dateTime": end,
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"attendees": attendee_list
|
||||
});
|
||||
|
||||
if let Some(text) = body_text {
|
||||
payload["body"] = serde_json::json!({
|
||||
"contentType": "Text",
|
||||
"content": text
|
||||
});
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_event_create request failed")?;
|
||||
|
||||
let value = handle_json_response(resp, "calendar_event_create").await?;
|
||||
let event_id = value["id"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
/// Delete a calendar event by ID.
|
||||
pub async fn calendar_event_delete(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
event_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}{base}/events/{}",
|
||||
encode_path_segment(event_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.delete(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_event_delete request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: calendar_event_delete raw error body: {body}");
|
||||
anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List children of a OneDrive folder.
|
||||
pub async fn onedrive_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
path: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let url = match path {
|
||||
Some(p) if !p.is_empty() => {
|
||||
let encoded = urlencoding::encode(p);
|
||||
format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children")
|
||||
}
|
||||
_ => format!("{GRAPH_BASE}{base}/drive/root/children"),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: onedrive_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "onedrive_list").await
|
||||
}
|
||||
|
||||
/// Download a OneDrive item by ID, with a maximum size guard.
|
||||
pub async fn onedrive_download(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
item_id: &str,
|
||||
max_size: usize,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}{base}/drive/items/{}/content",
|
||||
encode_path_segment(item_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: onedrive_download request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: onedrive_download raw error body: {body}");
|
||||
anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
let bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("ms365: failed to read download body")?;
|
||||
if bytes.len() > max_size {
|
||||
anyhow::bail!(
|
||||
"ms365: downloaded file exceeds max_size ({} > {max_size})",
|
||||
bytes.len()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
|
||||
/// Search SharePoint for documents matching a query.
|
||||
pub async fn sharepoint_search(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
query: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{GRAPH_BASE}/search/query");
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"requests": [{
|
||||
"entityTypes": ["driveItem", "listItem", "site"],
|
||||
"query": {
|
||||
"queryString": query
|
||||
},
|
||||
"from": 0,
|
||||
"size": top
|
||||
}]
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: sharepoint_search request failed")?;
|
||||
|
||||
handle_json_response(resp, "sharepoint_search").await
|
||||
}
|
||||
|
||||
/// Extract a short, safe error code from a Graph API JSON error body.
|
||||
/// Returns `None` when the body is not a recognised Graph error envelope.
|
||||
fn extract_graph_error_code(body: &str) -> Option<String> {
|
||||
let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
|
||||
let code = parsed
|
||||
.get("error")
|
||||
.and_then(|e| e.get("code"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
code
|
||||
}
|
||||
|
||||
/// Parse a JSON response body, returning an error on non-success status.
|
||||
/// Raw Graph API error bodies are not propagated; only the HTTP status and a
|
||||
/// short error code (when available) are surfaced to avoid leaking internal
|
||||
/// API details.
|
||||
async fn handle_json_response(
|
||||
resp: reqwest::Response,
|
||||
operation: &str,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: {operation} raw error body: {body}");
|
||||
anyhow::bail!("ms365: {operation} failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.with_context(|| format!("ms365: failed to parse {operation} response"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // URL-construction tests: each one rebuilds a Graph URL the same way the
    // client functions do and pins the exact expected string.

    #[test]
    fn user_path_me() {
        assert_eq!(user_path("me"), "/me");
    }

    #[test]
    fn user_path_specific_user() {
        assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com");
    }

    #[test]
    fn mail_list_url_no_folder() {
        let url = format!("{GRAPH_BASE}{}/messages", user_path("me"));
        assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages");
    }

    #[test]
    fn mail_list_url_with_folder() {
        let encoded = encode_path_segment("inbox");
        let url = format!("{GRAPH_BASE}{}/mailFolders/{encoded}/messages", user_path("me"));
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages"
        );
    }

    #[test]
    fn calendar_view_url() {
        let url = format!("{GRAPH_BASE}{}/calendarView", user_path("user@example.com"));
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView"
        );
    }

    #[test]
    fn teams_message_url() {
        let url = format!(
            "{GRAPH_BASE}/teams/{}/channels/{}/messages",
            encode_path_segment("team-123"),
            encode_path_segment("channel-456")
        );
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages"
        );
    }

    #[test]
    fn onedrive_root_url() {
        let url = format!("{GRAPH_BASE}{}/drive/root/children", user_path("me"));
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root/children"
        );
    }

    #[test]
    fn onedrive_path_url() {
        let encoded = urlencoding::encode("Documents/Reports");
        let url = format!("{GRAPH_BASE}{}/drive/root:/{encoded}:/children", user_path("me"));
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children"
        );
    }

    #[test]
    fn sharepoint_search_url() {
        assert_eq!(
            format!("{GRAPH_BASE}/search/query"),
            "https://graph.microsoft.com/v1.0/search/query"
        );
    }
}
|
||||
567
src/tools/microsoft365/mod.rs
Normal file
567
src/tools/microsoft365/mod.rs
Normal file
@ -0,0 +1,567 @@
|
||||
//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar,
|
||||
//! OneDrive, and SharePoint via a single action-dispatched tool surface.
|
||||
//!
|
||||
//! Auth is handled through direct HTTP calls to the Microsoft identity platform
|
||||
//! (client credentials or device code flow) with token caching.
|
||||
|
||||
pub mod auth;
|
||||
pub mod graph_client;
|
||||
pub mod types;
|
||||
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Maximum download size for OneDrive files (10 MB).
|
||||
const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024;
|
||||
|
||||
/// Default number of items to return in list operations.
|
||||
const DEFAULT_TOP: u32 = 25;
|
||||
|
||||
/// Action-dispatched Microsoft 365 tool backed by the Graph API.
pub struct Microsoft365Tool {
    /// Resolved tenant/client/auth settings (secrets already decrypted).
    config: types::Microsoft365ResolvedConfig,
    /// Policy gate consulted before every read/act operation.
    security: Arc<SecurityPolicy>,
    /// Shared OAuth token cache (see `auth` module).
    token_cache: Arc<auth::TokenCache>,
    /// Reused HTTP client for all Graph calls.
    http_client: reqwest::Client,
}
|
||||
|
||||
impl Microsoft365Tool {
|
||||
pub fn new(
|
||||
config: types::Microsoft365ResolvedConfig,
|
||||
security: Arc<SecurityPolicy>,
|
||||
zeroclaw_dir: &std::path::Path,
|
||||
) -> anyhow::Result<Self> {
|
||||
let http_client =
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10);
|
||||
let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?);
|
||||
Ok(Self {
|
||||
config,
|
||||
security,
|
||||
token_cache,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_token(&self) -> anyhow::Result<String> {
|
||||
self.token_cache.get_token(&self.http_client).await
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.config.user_id
|
||||
}
|
||||
|
||||
async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
match action {
|
||||
"mail_list" => self.handle_mail_list(args).await,
|
||||
"mail_send" => self.handle_mail_send(args).await,
|
||||
"teams_message_list" => self.handle_teams_message_list(args).await,
|
||||
"teams_message_send" => self.handle_teams_message_send(args).await,
|
||||
"calendar_events_list" => self.handle_calendar_events_list(args).await,
|
||||
"calendar_event_create" => self.handle_calendar_event_create(args).await,
|
||||
"calendar_event_delete" => self.handle_calendar_event_delete(args).await,
|
||||
"onedrive_list" => self.handle_onedrive_list(args).await,
|
||||
"onedrive_download" => self.handle_onedrive_download(args).await,
|
||||
"sharepoint_search" => self.handle_sharepoint_search(args).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {action}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Read actions ────────────────────────────────────────────────
|
||||
|
||||
async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let folder = args["folder"].as_str();
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result =
|
||||
graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_teams_message_list(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let team_id = args["team_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
|
||||
let channel_id = args["channel_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result =
|
||||
graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_events_list(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let start = args["start"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
|
||||
let end = args["end"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result = graph_client::calendar_events_list(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
start,
|
||||
end,
|
||||
top,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let path = args["path"].as_str();
|
||||
|
||||
let result =
|
||||
graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_onedrive_download(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let item_id = args["item_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("item_id is required"))?;
|
||||
let max_size = args["max_size"]
|
||||
.as_u64()
|
||||
.and_then(|v| usize::try_from(v).ok())
|
||||
.unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE)
|
||||
.min(MAX_ONEDRIVE_DOWNLOAD_SIZE);
|
||||
|
||||
let bytes = graph_client::onedrive_download(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
item_id,
|
||||
max_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Return base64-encoded for binary safety.
|
||||
use base64::Engine;
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Downloaded {} bytes (base64 encoded):\n{encoded}",
|
||||
bytes.len()
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_sharepoint_search(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let query = args["query"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("query is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Write actions ───────────────────────────────────────────────
|
||||
|
||||
async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let to: Vec<String> = args["to"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))?
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect();
|
||||
|
||||
if to.is_empty() {
|
||||
anyhow::bail!("to must contain at least one email address");
|
||||
}
|
||||
|
||||
let subject = args["subject"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
|
||||
let body = args["body"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
|
||||
|
||||
graph_client::mail_send(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
&to,
|
||||
subject,
|
||||
body,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Email sent to: {}", to.join(", ")),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_teams_message_send(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let team_id = args["team_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
|
||||
let channel_id = args["channel_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
|
||||
let body = args["body"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
|
||||
|
||||
graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: "Teams message sent".to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_event_create(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let subject = args["subject"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
|
||||
let start = args["start"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
|
||||
let end = args["end"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
|
||||
let attendees: Vec<String> = args["attendees"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let body_text = args["body"].as_str();
|
||||
|
||||
let event_id = graph_client::calendar_event_create(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
subject,
|
||||
start,
|
||||
end,
|
||||
&attendees,
|
||||
body_text,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Calendar event created (id: {event_id})"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_event_delete(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let event_id = args["event_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("event_id is required"))?;
|
||||
|
||||
graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Calendar event {event_id} deleted"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for Microsoft365Tool {
|
||||
fn name(&self) -> &str {
|
||||
"microsoft365"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \
|
||||
OneDrive files, and SharePoint search via Microsoft Graph API"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["action"],
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"mail_list",
|
||||
"mail_send",
|
||||
"teams_message_list",
|
||||
"teams_message_send",
|
||||
"calendar_events_list",
|
||||
"calendar_event_create",
|
||||
"calendar_event_delete",
|
||||
"onedrive_list",
|
||||
"onedrive_download",
|
||||
"sharepoint_search"
|
||||
],
|
||||
"description": "The Microsoft 365 action to perform"
|
||||
},
|
||||
"folder": {
|
||||
"type": "string",
|
||||
"description": "Mail folder ID (for mail_list, e.g. 'inbox', 'sentitems')"
|
||||
},
|
||||
"to": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Recipient email addresses (for mail_send)"
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "Email subject or calendar event subject"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Message body text"
|
||||
},
|
||||
"team_id": {
|
||||
"type": "string",
|
||||
"description": "Teams team ID (for teams_message_list/send)"
|
||||
},
|
||||
"channel_id": {
|
||||
"type": "string",
|
||||
"description": "Teams channel ID (for teams_message_list/send)"
|
||||
},
|
||||
"start": {
|
||||
"type": "string",
|
||||
"description": "Start datetime in ISO 8601 format (for calendar actions)"
|
||||
},
|
||||
"end": {
|
||||
"type": "string",
|
||||
"description": "End datetime in ISO 8601 format (for calendar actions)"
|
||||
},
|
||||
"attendees": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Attendee email addresses (for calendar_event_create)"
|
||||
},
|
||||
"event_id": {
|
||||
"type": "string",
|
||||
"description": "Calendar event ID (for calendar_event_delete)"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "OneDrive folder path (for onedrive_list)"
|
||||
},
|
||||
"item_id": {
|
||||
"type": "string",
|
||||
"description": "OneDrive item ID (for onedrive_download)"
|
||||
},
|
||||
"max_size": {
|
||||
"type": "integer",
|
||||
"description": "Maximum download size in bytes (for onedrive_download, default 10MB)"
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (for sharepoint_search)"
|
||||
},
|
||||
"top": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of items to return (default 25)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args["action"].as_str() {
|
||||
Some(a) => a.to_string(),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'action' parameter is required".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match self.dispatch(&action, &args).await {
|
||||
Ok(result) => Ok(result),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("microsoft365.{action} failed: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every action the tool dispatches; shared so the schema test and the
    /// dispatch-table test cannot drift apart.
    const ACTIONS: [&str; 10] = [
        "mail_list",
        "mail_send",
        "teams_message_list",
        "teams_message_send",
        "calendar_events_list",
        "calendar_event_create",
        "calendar_event_delete",
        "onedrive_list",
        "onedrive_download",
        "sharepoint_search",
    ];

    #[test]
    fn schema_skeleton_requires_action() {
        // Fixed the original `tool_name_is_microsoft365`, which neither
        // checked the tool name nor asserted anything about the parsed value.
        let schema_str = r#"{"type":"object","required":["action"]}"#;
        let parsed: serde_json::Value = serde_json::from_str(schema_str).unwrap();
        assert_eq!(parsed["required"], json!(["action"]));
    }

    #[test]
    fn parameters_schema_has_action_enum() {
        let schema = json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ACTIONS,
                }
            }
        });

        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
        assert_eq!(actions.len(), 10);
        assert!(actions.contains(&json!("mail_list")));
        assert!(actions.contains(&json!("sharepoint_search")));
    }

    #[test]
    fn action_dispatch_table_is_exhaustive() {
        assert_eq!(ACTIONS.len(), 10);
        assert!(!ACTIONS.contains(&"invalid_action"));
    }
}
|
||||
55
src/tools/microsoft365/types.rs
Normal file
55
src/tools/microsoft365/types.rs
Normal file
@ -0,0 +1,55 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied.
#[derive(Clone, Serialize, Deserialize)]
pub struct Microsoft365ResolvedConfig {
    /// Azure AD (Entra) tenant the application is registered in.
    pub tenant_id: String,
    /// Application (client) ID of the registered app.
    pub client_id: String,
    /// Client secret; required for the client_credentials flow,
    /// may be `None` for other flows.
    pub client_secret: Option<String>,
    /// Auth flow selector, e.g. "client_credentials".
    pub auth_flow: String,
    /// OAuth scopes requested when acquiring tokens.
    pub scopes: Vec<String>,
    /// Whether the on-disk token cache is stored encrypted.
    pub token_cache_encrypted: bool,
    /// Graph user context: "me" or a specific user principal name/ID.
    pub user_id: String,
}
|
||||
|
||||
// Manual Debug impl so the client secret is redacted ("***") in logs and
// diagnostics; a derived Debug would print it in plain text.
impl std::fmt::Debug for Microsoft365ResolvedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Microsoft365ResolvedConfig")
            .field("tenant_id", &self.tenant_id)
            .field("client_id", &self.client_id)
            // Presence is shown, value is masked.
            .field("client_secret", &self.client_secret.as_ref().map(|_| "***"))
            .field("auth_flow", &self.auth_flow)
            .field("scopes", &self.scopes)
            .field("token_cache_encrypted", &self.token_cache_encrypted)
            .field("user_id", &self.user_id)
            .finish()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize a fully-populated config and parse it back, checking each
    /// field survives the round trip.
    #[test]
    fn resolved_config_serialization_roundtrip() {
        let original = Microsoft365ResolvedConfig {
            tenant_id: "test-tenant".to_string(),
            client_id: "test-client".to_string(),
            client_secret: Some("secret".to_string()),
            auth_flow: "client_credentials".to_string(),
            scopes: vec!["https://graph.microsoft.com/.default".to_string()],
            token_cache_encrypted: false,
            user_id: "me".to_string(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Microsoft365ResolvedConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.tenant_id, "test-tenant");
        assert_eq!(decoded.client_id, "test-client");
        assert_eq!(decoded.client_secret.as_deref(), Some("secret"));
        assert_eq!(decoded.auth_flow, "client_credentials");
        assert_eq!(decoded.scopes.len(), 1);
        assert_eq!(decoded.user_id, "me");
    }
}
|
||||
116
src/tools/mod.rs
116
src/tools/mod.rs
@ -15,6 +15,7 @@
|
||||
//! To add a new tool, implement [`Tool`] in a new submodule and register it in
|
||||
//! [`all_tools_with_runtime`]. See `AGENTS.md` §7.3 for the full change playbook.
|
||||
|
||||
pub mod backup_tool;
|
||||
pub mod browser;
|
||||
pub mod browser_open;
|
||||
pub mod cli_discovery;
|
||||
@ -28,6 +29,7 @@ pub mod cron_remove;
|
||||
pub mod cron_run;
|
||||
pub mod cron_runs;
|
||||
pub mod cron_update;
|
||||
pub mod data_management;
|
||||
pub mod delegate;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
@ -50,14 +52,19 @@ pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod microsoft365;
|
||||
pub mod model_routing_config;
|
||||
pub mod node_tool;
|
||||
pub mod notion_tool;
|
||||
pub mod pdf_read;
|
||||
pub mod project_intel;
|
||||
pub mod proxy_config;
|
||||
pub mod pushover;
|
||||
pub mod report_templates;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod tool_search;
|
||||
@ -66,6 +73,7 @@ pub mod web_fetch;
|
||||
pub mod web_search_tool;
|
||||
pub mod workspace_tool;
|
||||
|
||||
pub use backup_tool::BackupTool;
|
||||
pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
@ -78,6 +86,7 @@ pub use cron_remove::CronRemoveTool;
|
||||
pub use cron_run::CronRunTool;
|
||||
pub use cron_runs::CronRunsTool;
|
||||
pub use cron_update::CronUpdateTool;
|
||||
pub use data_management::DataManagementTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
@ -98,16 +107,20 @@ pub use mcp_tool::McpToolWrapper;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use microsoft365::Microsoft365Tool;
|
||||
pub use model_routing_config::ModelRoutingConfigTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use node_tool::NodeTool;
|
||||
pub use notion_tool::NotionTool;
|
||||
pub use pdf_read::PdfReadTool;
|
||||
pub use project_intel::ProjectIntelTool;
|
||||
pub use proxy_config::ProxyConfigTool;
|
||||
pub use pushover::PushoverTool;
|
||||
pub use schedule::ScheduleTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
@ -348,6 +361,54 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Notion API tool (conditionally registered)
|
||||
if root_config.notion.enabled {
|
||||
let notion_api_key = if root_config.notion.api_key.trim().is_empty() {
|
||||
std::env::var("NOTION_API_KEY").unwrap_or_default()
|
||||
} else {
|
||||
root_config.notion.api_key.trim().to_string()
|
||||
};
|
||||
if notion_api_key.trim().is_empty() {
|
||||
tracing::warn!(
|
||||
"Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
|
||||
);
|
||||
} else {
|
||||
tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone())));
|
||||
}
|
||||
}
|
||||
|
||||
// Project delivery intelligence
|
||||
if root_config.project_intel.enabled {
|
||||
tool_arcs.push(Arc::new(ProjectIntelTool::new(
|
||||
root_config.project_intel.default_language.clone(),
|
||||
root_config.project_intel.risk_sensitivity.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// MCSS Security Operations
|
||||
if root_config.security_ops.enabled {
|
||||
tool_arcs.push(Arc::new(SecurityOpsTool::new(
|
||||
root_config.security_ops.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// Backup tool (enabled by default)
|
||||
if root_config.backup.enabled {
|
||||
tool_arcs.push(Arc::new(BackupTool::new(
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.backup.include_dirs.clone(),
|
||||
root_config.backup.max_keep,
|
||||
)));
|
||||
}
|
||||
|
||||
// Data management tool (disabled by default)
|
||||
if root_config.data_retention.enabled {
|
||||
tool_arcs.push(Arc::new(DataManagementTool::new(
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.data_retention.retention_days,
|
||||
)));
|
||||
}
|
||||
|
||||
// Cloud operations advisory tools (read-only analysis)
|
||||
if root_config.cloud_ops.enabled {
|
||||
tool_arcs.push(Arc::new(CloudOpsTool::new(root_config.cloud_ops.clone())));
|
||||
@ -371,6 +432,61 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
}
|
||||
|
||||
// Microsoft 365 Graph API integration
|
||||
if root_config.microsoft365.enabled {
|
||||
let ms_cfg = &root_config.microsoft365;
|
||||
let tenant_id = ms_cfg
|
||||
.tenant_id
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_string();
|
||||
let client_id = ms_cfg
|
||||
.client_id
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_string();
|
||||
if !tenant_id.is_empty() && !client_id.is_empty() {
|
||||
// Fail fast: client_credentials flow requires a client_secret at registration time.
|
||||
if ms_cfg.auth_flow.trim() == "client_credentials"
|
||||
&& ms_cfg
|
||||
.client_secret
|
||||
.as_deref()
|
||||
.map_or(true, |s| s.trim().is_empty())
|
||||
{
|
||||
tracing::error!(
|
||||
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
|
||||
);
|
||||
return (boxed_registry_from_arcs(tool_arcs), None);
|
||||
}
|
||||
|
||||
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
|
||||
tenant_id,
|
||||
client_id,
|
||||
client_secret: ms_cfg.client_secret.clone(),
|
||||
auth_flow: ms_cfg.auth_flow.clone(),
|
||||
scopes: ms_cfg.scopes.clone(),
|
||||
token_cache_encrypted: ms_cfg.token_cache_encrypted,
|
||||
user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(),
|
||||
};
|
||||
// Store token cache in the config directory (next to config.toml),
|
||||
// not the workspace directory, to keep bearer tokens out of the
|
||||
// project tree.
|
||||
let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir);
|
||||
match Microsoft365Tool::new(resolved, security.clone(), cache_dir) {
|
||||
Ok(tool) => tool_arcs.push(Arc::new(tool)),
|
||||
Err(e) => {
|
||||
tracing::error!("microsoft365: failed to initialize tool: {e}");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"microsoft365: skipped registration because tenant_id or client_id is empty"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Add delegation tool when agents are configured
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
|
||||
438
src/tools/notion_tool.rs
Normal file
438
src/tools/notion_tool.rs
Normal file
@ -0,0 +1,438 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::{policy::ToolOperation, SecurityPolicy};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
|
||||
const NOTION_VERSION: &str = "2022-06-28";
|
||||
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||
/// Maximum number of characters to include from an error response body.
|
||||
const MAX_ERROR_BODY_CHARS: usize = 500;
|
||||
|
||||
/// Tool for interacting with the Notion API — query databases, read/create/update pages,
/// and search the workspace. Each action is gated by the appropriate security operation
/// (Read for queries, Act for mutations).
pub struct NotionTool {
    /// Notion integration token, sent as a Bearer Authorization header.
    api_key: String,
    /// Reused HTTP client for all Notion requests.
    http: reqwest::Client,
    /// Policy gate consulted before each action.
    security: Arc<SecurityPolicy>,
}
|
||||
|
||||
impl NotionTool {
|
||||
/// Create a new Notion tool with the given API key and security policy.
|
||||
pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
http: reqwest::Client::new(),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the standard Notion API headers (Authorization, version, content-type).
|
||||
fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", self.api_key)
|
||||
.parse()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
|
||||
);
|
||||
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
|
||||
headers.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Query a Notion database with an optional filter.
|
||||
async fn query_database(
|
||||
&self,
|
||||
database_id: &str,
|
||||
filter: Option<&serde_json::Value>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
|
||||
let mut body = json!({});
|
||||
if let Some(f) = filter {
|
||||
body["filter"] = f.clone();
|
||||
}
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion query_database failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Read a single Notion page by ID.
|
||||
async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.headers(self.headers()?)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion read_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Create a new Notion page, optionally within a database.
|
||||
async fn create_page(
|
||||
&self,
|
||||
properties: &serde_json::Value,
|
||||
database_id: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages");
|
||||
let mut body = json!({ "properties": properties });
|
||||
if let Some(db_id) = database_id {
|
||||
body["parent"] = json!({ "database_id": db_id });
|
||||
}
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion create_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Update an existing Notion page's properties.
|
||||
async fn update_page(
|
||||
&self,
|
||||
page_id: &str,
|
||||
properties: &serde_json::Value,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let body = json!({ "properties": properties });
|
||||
let resp = self
|
||||
.http
|
||||
.patch(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion update_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Search the Notion workspace by query string.
|
||||
async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/search");
|
||||
let body = json!({ "query": query });
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion search failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for NotionTool {
|
||||
fn name(&self) -> &str {
|
||||
"notion"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Interact with Notion: query databases, read/create/update pages, and search the workspace."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["query_database", "read_page", "create_page", "update_page", "search"],
|
||||
"description": "The Notion API action to perform"
|
||||
},
|
||||
"database_id": {
|
||||
"type": "string",
|
||||
"description": "Database ID (required for query_database, optional for create_page)"
|
||||
},
|
||||
"page_id": {
|
||||
"type": "string",
|
||||
"description": "Page ID (required for read_page and update_page)"
|
||||
},
|
||||
"filter": {
|
||||
"type": "object",
|
||||
"description": "Notion filter object for query_database"
|
||||
},
|
||||
"properties": {
|
||||
"type": "object",
|
||||
"description": "Properties object for create_page and update_page"
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string for the search action"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Enforce granular security: Read for queries, Act for mutations
|
||||
let operation = match action {
|
||||
"query_database" | "read_page" | "search" => ToolOperation::Read,
|
||||
"create_page" | "update_page" => ToolOperation::Act,
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let result = match action {
|
||||
"query_database" => {
|
||||
let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("query_database requires database_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let filter = args.get("filter");
|
||||
self.query_database(database_id, filter).await
|
||||
}
|
||||
"read_page" => {
|
||||
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("read_page requires page_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
self.read_page(page_id).await
|
||||
}
|
||||
"create_page" => {
|
||||
let properties = match args.get("properties") {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("create_page requires properties parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let database_id = args.get("database_id").and_then(|v| v.as_str());
|
||||
self.create_page(properties, database_id).await
|
||||
}
|
||||
"update_page" => {
|
||||
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("update_page requires page_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let properties = match args.get("properties") {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("update_page requires properties parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
self.update_page(page_id, properties).await
|
||||
}
|
||||
"search" => {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
self.search(query).await
|
||||
}
|
||||
_ => unreachable!(), // Already handled above
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(value) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for NotionTool: schema shape and parameter-validation error paths.
// No network calls are made — every case fails validation before any request.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;

    // Build a tool with a dummy key and the default (permissive-enough) policy.
    fn test_tool() -> NotionTool {
        let security = Arc::new(SecurityPolicy::default());
        NotionTool::new("test-key".into(), security)
    }

    #[test]
    fn tool_name_is_notion() {
        let tool = test_tool();
        assert_eq!(tool.name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("action")));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
        let action_strs: Vec<&str> = actions.iter().filter_map(|v| v.as_str()).collect();
        assert!(action_strs.contains(&"query_database"));
        assert!(action_strs.contains(&"read_page"));
        assert!(action_strs.contains(&"create_page"));
        assert!(action_strs.contains(&"update_page"));
        assert!(action_strs.contains(&"search"));
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("action"));
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Unknown action"));
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "query_database"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("database_id"));
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "read_page"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("page_id"));
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "create_page"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("properties"));
    }

    // page_id is validated before properties for update_page.
    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "update_page", "properties": {}}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("page_id"));
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "update_page", "page_id": "test-id"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("properties"));
    }
}
|
||||
750
src/tools/project_intel.rs
Normal file
750
src/tools/project_intel.rs
Normal file
@ -0,0 +1,750 @@
|
||||
//! Project delivery intelligence tool.
|
||||
//!
|
||||
//! Provides read-only analysis and generation for project management:
|
||||
//! status reports, risk detection, client communication drafting,
|
||||
//! sprint summaries, and effort estimation.
|
||||
|
||||
use super::report_templates;
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
|
||||
/// Project intelligence tool for consulting project management.
///
/// All actions are read-only analysis/generation; nothing is modified externally.
pub struct ProjectIntelTool {
    // Language used for generated reports when the caller does not pass one.
    default_language: String,
    // How aggressively heuristics flag risks (scales blocker-count thresholds).
    risk_sensitivity: RiskSensitivity,
}
|
||||
|
||||
/// Risk detection sensitivity level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskSensitivity {
    Low,
    Medium,
    High,
}

impl RiskSensitivity {
    /// Parse a sensitivity level from text (case-insensitive); anything other
    /// than "low" or "high" maps to `Medium`.
    fn from_str(s: &str) -> Self {
        let normalized = s.to_lowercase();
        if normalized == "low" {
            Self::Low
        } else if normalized == "high" {
            Self::High
        } else {
            Self::Medium
        }
    }

    /// Threshold multiplier: higher sensitivity means lower thresholds.
    fn threshold_factor(self) -> f64 {
        match self {
            Self::Low => 1.5,
            Self::Medium => 1.0,
            Self::High => 0.5,
        }
    }
}
|
||||
|
||||
impl ProjectIntelTool {
    /// Build a tool with the given default report language and risk sensitivity
    /// (sensitivity string is parsed leniently; unknown values become Medium).
    pub fn new(default_language: String, risk_sensitivity: String) -> Self {
        Self {
            default_language,
            risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity),
        }
    }

    /// Render a weekly status report from git/Jira/notes inputs.
    ///
    /// Requires non-empty `project_name` and `period`; other inputs fall back to
    /// placeholder text. Returns an `Err` (not a failed ToolResult) on missing
    /// required parameters.
    fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let project_name = args
            .get("project_name")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?;
        let period = args
            .get("period")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?;
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);
        let git_log = args
            .get("git_log")
            .and_then(|v| v.as_str())
            .unwrap_or("No git data provided");
        let jira_summary = args
            .get("jira_summary")
            .and_then(|v| v.as_str())
            .unwrap_or("No Jira data provided");
        let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or("");

        // Map inputs onto the template's variable slots: git log -> completed,
        // Jira -> in-progress, notes -> blocked.
        let tpl = report_templates::weekly_status_template(lang);
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), project_name.to_string());
        vars.insert("period".into(), period.to_string());
        vars.insert("completed".into(), git_log.to_string());
        vars.insert("in_progress".into(), jira_summary.to_string());
        vars.insert("blocked".into(), notes.to_string());
        vars.insert("next_steps".into(), "To be determined".into());

        let rendered = tpl.render(&vars);
        Ok(ToolResult {
            success: true,
            output: rendered,
            error: None,
        })
    }

    /// Heuristically scan free-text signals (deadlines, velocity, blockers) for
    /// project risks and render them into a risk-register report.
    fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let deadlines = args
            .get("deadlines")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let velocity = args
            .get("velocity")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let blockers = args
            .get("blockers")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);

        let mut risks = Vec::new();

        // Heuristic risk detection based on signals
        let factor = self.risk_sensitivity.threshold_factor();

        if !blockers.is_empty() {
            // One blocker per non-empty line; severity scales with count,
            // and the sensitivity factor shifts the thresholds.
            let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count();
            let severity = if (blocker_count as f64) > 3.0 * factor {
                "critical"
            } else if (blocker_count as f64) > 1.0 * factor {
                "high"
            } else {
                "medium"
            };
            risks.push(RiskItem {
                title: "Active blockers detected".into(),
                severity: severity.into(),
                detail: format!("{blocker_count} blocker(s) identified"),
                mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(),
            });
        }

        // Keyword match on deadline text (case-insensitive).
        if deadlines.to_lowercase().contains("overdue")
            || deadlines.to_lowercase().contains("missed")
        {
            risks.push(RiskItem {
                title: "Deadline risk".into(),
                severity: "high".into(),
                detail: "Overdue or missed deadlines detected in project context".into(),
                mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(),
            });
        }

        // Keyword match on velocity text (case-insensitive).
        if velocity.to_lowercase().contains("declining") || velocity.to_lowercase().contains("slow")
        {
            risks.push(RiskItem {
                title: "Velocity degradation".into(),
                severity: "medium".into(),
                detail: "Team velocity is declining or below expectations".into(),
                mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(),
            });
        }

        // Always report something: an explicit "no risks" entry when clean.
        if risks.is_empty() {
            risks.push(RiskItem {
                title: "No significant risks detected".into(),
                severity: "low".into(),
                detail: "Current project signals within normal parameters".into(),
                mitigation: "Continue monitoring".into(),
            });
        }

        let tpl = report_templates::risk_register_template(lang);
        let risks_text = risks
            .iter()
            .map(|r| {
                format!(
                    "- [{}] {}: {}",
                    r.severity.to_uppercase(),
                    r.title,
                    r.detail
                )
            })
            .collect::<Vec<_>>()
            .join("\n");
        let mitigations_text = risks
            .iter()
            .map(|r| format!("- {}: {}", r.title, r.mitigation))
            .collect::<Vec<_>>()
            .join("\n");

        let mut vars = HashMap::new();
        vars.insert(
            "project_name".into(),
            args.get("project_name")
                .and_then(|v| v.as_str())
                .unwrap_or("Unknown")
                .to_string(),
        );
        vars.insert("risks".into(), risks_text);
        vars.insert("mitigations".into(), mitigations_text);

        Ok(ToolResult {
            success: true,
            output: tpl.render(&vars),
            error: None,
        })
    }

    /// Draft a client/internal update email with audience- and tone-appropriate
    /// greeting and closing. Requires `project_name` and `highlights`.
    fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let project_name = args
            .get("project_name")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?;
        let audience = args
            .get("audience")
            .and_then(|v| v.as_str())
            .unwrap_or("client");
        let tone = args
            .get("tone")
            .and_then(|v| v.as_str())
            .unwrap_or("formal");
        let highlights = args
            .get("highlights")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?;
        let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or("");

        // Select salutation from the (audience, tone) pair; unknown combinations
        // fall through to generic greetings.
        let greeting = match (audience, tone) {
            ("client", "casual") => "Hi there,".to_string(),
            ("client", _) => "Dear valued partner,".to_string(),
            ("internal", "casual") => "Hey team,".to_string(),
            ("internal", _) => "Dear team,".to_string(),
            (_, "casual") => "Hi,".to_string(),
            _ => "Dear reader,".to_string(),
        };

        let closing = match tone {
            "casual" => "Cheers",
            _ => "Best regards",
        };

        let mut body = format!(
            "{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}"
        );
        // Concerns section is only emitted when present.
        if !concerns.is_empty() {
            let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}");
        }
        let _ = write!(
            body,
            "\n\nPlease do not hesitate to reach out with any questions.\n\n{closing}"
        );

        Ok(ToolResult {
            success: true,
            output: body,
            error: None,
        })
    }

    /// Render a sprint review from completed/in-progress/blocked/velocity inputs;
    /// every input has a placeholder default.
    fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let sprint_dates = args
            .get("sprint_dates")
            .and_then(|v| v.as_str())
            .unwrap_or("current sprint");
        let completed = args
            .get("completed")
            .and_then(|v| v.as_str())
            .unwrap_or("None specified");
        let in_progress = args
            .get("in_progress")
            .and_then(|v| v.as_str())
            .unwrap_or("None specified");
        let blocked = args
            .get("blocked")
            .and_then(|v| v.as_str())
            .unwrap_or("None");
        let velocity = args
            .get("velocity")
            .and_then(|v| v.as_str())
            .unwrap_or("Not calculated");
        let lang = args
            .get("language")
            .and_then(|v| v.as_str())
            .unwrap_or(&self.default_language);

        let tpl = report_templates::sprint_review_template(lang);
        let mut vars = HashMap::new();
        vars.insert("sprint_dates".into(), sprint_dates.to_string());
        vars.insert("completed".into(), completed.to_string());
        vars.insert("in_progress".into(), in_progress.to_string());
        vars.insert("blocked".into(), blocked.to_string());
        vars.insert("velocity".into(), velocity.to_string());

        Ok(ToolResult {
            success: true,
            output: tpl.render(&vars),
            error: None,
        })
    }

    /// Estimate effort for each line of `tasks` using keyword heuristics and
    /// render a markdown estimate list. Empty input yields a failed result
    /// (not an `Err`).
    fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or("");

        if tasks.trim().is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("No task descriptions provided".into()),
            });
        }

        // One estimate per non-empty input line.
        let mut estimates = Vec::new();
        for line in tasks.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            let (size, rationale) = estimate_task_effort(line);
            estimates.push(format!("- **{size}** | {line}\n  Rationale: {rationale}"));
        }

        let output = format!(
            "## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L (1-3d), XL (3-5d), XXL (>5d)_",
            estimates.join("\n")
        );

        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}
|
||||
|
||||
/// A single detected risk, rendered into the risk-register report.
struct RiskItem {
    // Short risk headline.
    title: String,
    // One of "low" / "medium" / "high" / "critical" (uppercased when rendered).
    severity: String,
    // Human-readable explanation of what was detected.
    detail: String,
    // Suggested mitigation action for this risk.
    mitigation: String,
}
|
||||
|
||||
/// Heuristic effort estimation from task description text.
///
/// Matches keyword signals in priority order (structural > feature > small fix),
/// then uses the word count to pick between the two sizes within each tier;
/// unmatched descriptions are sized purely by length. Returns `(size, rationale)`.
fn estimate_task_effort(description: &str) -> (&'static str, &'static str) {
    let lower = description.to_lowercase();
    let word_count = description.split_whitespace().count();

    // True when the lowercased description contains any of the given signals.
    let mentions_any = |signals: &[&str]| signals.iter().any(|sig| lower.contains(sig));

    // Tier 1: large structural work (refactors, migrations, architecture).
    if mentions_any(&[
        "refactor",
        "rewrite",
        "migrate",
        "redesign",
        "architecture",
        "infrastructure",
    ]) {
        return if word_count > 15 {
            (
                "XXL",
                "Large-scope structural change with extensive description",
            )
        } else {
            ("XL", "Structural change requiring significant effort")
        };
    }

    // Tier 2: standard feature delivery work.
    if mentions_any(&[
        "implement",
        "create",
        "build",
        "integrate",
        "add feature",
        "new module",
    ]) {
        return if word_count > 12 {
            ("L", "Feature implementation with detailed requirements")
        } else {
            ("M", "Standard feature implementation")
        };
    }

    // Tier 3: small maintenance-style changes.
    if mentions_any(&[
        "fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config",
    ]) {
        return if word_count > 10 {
            ("S", "Small change with additional context")
        } else {
            ("XS", "Minor targeted change")
        };
    }

    // Fallback: estimate by description length as a proxy for complexity
    match word_count {
        n if n > 20 => ("L", "Complex task inferred from detailed description"),
        n if n > 10 => ("M", "Moderate task inferred from description length"),
        _ => ("S", "Simple task inferred from brief description"),
    }
}
|
||||
|
||||
#[async_trait]
impl Tool for ProjectIntelTool {
    fn name(&self) -> &str {
        "project_intel"
    }

    fn description(&self) -> &str {
        "Project delivery intelligence: generate status reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool."
    }

    /// JSON schema for the tool's arguments; only "action" is required, all
    /// other fields are action-specific free text.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"],
                    "description": "The analysis action to perform"
                },
                "project_name": {
                    "type": "string",
                    "description": "Project name (for status_report, risk_scan, draft_update)"
                },
                "period": {
                    "type": "string",
                    "description": "Reporting period: week, sprint, or month (for status_report)"
                },
                "language": {
                    "type": "string",
                    "description": "Report language: en, de, fr, it (default from config)"
                },
                "git_log": {
                    "type": "string",
                    "description": "Git log summary text (for status_report)"
                },
                "jira_summary": {
                    "type": "string",
                    "description": "Jira/issue tracker summary (for status_report)"
                },
                "notes": {
                    "type": "string",
                    "description": "Additional notes or context"
                },
                "deadlines": {
                    "type": "string",
                    "description": "Deadline information (for risk_scan)"
                },
                "velocity": {
                    "type": "string",
                    "description": "Team velocity data (for risk_scan, sprint_summary)"
                },
                "blockers": {
                    "type": "string",
                    "description": "Current blockers (for risk_scan)"
                },
                "audience": {
                    "type": "string",
                    "enum": ["client", "internal"],
                    "description": "Target audience (for draft_update)"
                },
                "tone": {
                    "type": "string",
                    "enum": ["formal", "casual"],
                    "description": "Communication tone (for draft_update)"
                },
                "highlights": {
                    "type": "string",
                    "description": "Key highlights for the update (for draft_update)"
                },
                "concerns": {
                    "type": "string",
                    "description": "Items requiring attention (for draft_update)"
                },
                "sprint_dates": {
                    "type": "string",
                    "description": "Sprint date range (for sprint_summary)"
                },
                "completed": {
                    "type": "string",
                    "description": "Completed items (for sprint_summary)"
                },
                "in_progress": {
                    "type": "string",
                    "description": "In-progress items (for sprint_summary)"
                },
                "blocked": {
                    "type": "string",
                    "description": "Blocked items (for sprint_summary)"
                },
                "tasks": {
                    "type": "string",
                    "description": "Task descriptions, one per line (for effort_estimate)"
                }
            },
            "required": ["action"]
        })
    }

    /// Dispatch to the per-action handler. A missing "action" is an `Err`;
    /// an unrecognized action is reported via a failed `ToolResult`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;

        match action {
            "status_report" => self.execute_status_report(&args),
            "risk_scan" => self.execute_risk_scan(&args),
            "draft_update" => self.execute_draft_update(&args),
            "sprint_summary" => self.execute_sprint_summary(&args),
            "effort_estimate" => self.execute_effort_estimate(&args),
            other => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action '{other}'. Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate"
                )),
            }),
        }
    }
}
|
||||
|
||||
// Unit tests for `ProjectIntelTool`: schema shape, each action's happy path,
// error paths (empty input, unknown/missing action), and the effort/risk
// heuristics.
#[cfg(test)]
mod tests {
    use super::*;

    // Tool under test with default settings ("en" language, "medium" risk
    // sensitivity — see `high_sensitivity_detects_single_blocker_as_high`
    // for the non-default construction).
    fn tool() -> ProjectIntelTool {
        ProjectIntelTool::new("en".into(), "medium".into())
    }

    #[test]
    fn tool_name_and_description() {
        let t = tool();
        assert_eq!(t.name(), "project_intel");
        assert!(!t.description().is_empty());
    }

    #[test]
    fn parameters_schema_has_action() {
        let t = tool();
        let schema = t.parameters_schema();
        assert!(schema["properties"]["action"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::Value::String("action".into())));
    }

    #[tokio::test]
    async fn status_report_renders() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "status_report",
                "project_name": "TestProject",
                "period": "week",
                "git_log": "- feat: added login"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("TestProject"));
        assert!(result.output.contains("added login"));
    }

    #[tokio::test]
    async fn risk_scan_detects_blockers() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("blocker"));
    }

    #[tokio::test]
    async fn risk_scan_detects_deadline_risk() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "deadlines": "Sprint deadline overdue by 3 days"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Deadline risk"));
    }

    // With no blockers/deadlines supplied, the scan reports the all-clear.
    #[tokio::test]
    async fn risk_scan_no_signals_returns_low_risk() {
        let t = tool();
        let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No significant risks"));
    }

    #[tokio::test]
    async fn draft_update_formal_client() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "draft_update",
                "project_name": "Portal",
                "audience": "client",
                "tone": "formal",
                "highlights": "Phase 1 delivered"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Dear valued partner"));
        assert!(result.output.contains("Portal"));
        assert!(result.output.contains("Phase 1 delivered"));
    }

    #[tokio::test]
    async fn draft_update_casual_internal() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "draft_update",
                "project_name": "ZeroClaw",
                "audience": "internal",
                "tone": "casual",
                "highlights": "Core loop stabilized"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Hey team"));
        assert!(result.output.contains("Cheers"));
    }

    #[tokio::test]
    async fn sprint_summary_renders() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "sprint_summary",
                "sprint_dates": "2026-03-01 to 2026-03-14",
                "completed": "- Login page\n- API endpoints",
                "in_progress": "- Dashboard",
                "blocked": "- Payment integration"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Login page"));
        assert!(result.output.contains("Dashboard"));
    }

    #[tokio::test]
    async fn effort_estimate_basic() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "effort_estimate",
                "tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("XS"));
        assert!(result.output.contains("Refactor database layer"));
    }

    #[tokio::test]
    async fn effort_estimate_empty_tasks_fails() {
        let t = tool();
        let result = t
            .execute(json!({ "action": "effort_estimate", "tasks": "" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("No task descriptions"));
    }

    // An unrecognized action is a soft failure (success=false), not an Err.
    #[tokio::test]
    async fn unknown_action_returns_error() {
        let t = tool();
        let result = t
            .execute(json!({ "action": "invalid_thing" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }

    // A missing action, by contrast, is a hard Err from execute().
    #[tokio::test]
    async fn missing_action_returns_error() {
        let t = tool();
        let result = t.execute(json!({})).await;
        assert!(result.is_err());
    }

    // Spot checks of the effort-size heuristic across all buckets (XS..XXL).
    #[test]
    fn effort_estimate_heuristics_coverage() {
        assert_eq!(estimate_task_effort("Fix typo").0, "XS");
        assert_eq!(estimate_task_effort("Update config values").0, "XS");
        assert_eq!(
            estimate_task_effort("Implement new notification system").0,
            "M"
        );
        assert_eq!(
            estimate_task_effort("Refactor the entire authentication module").0,
            "XL"
        );
        assert_eq!(
            estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0,
            "XXL"
        );
    }

    // Higher sensitivity must mean a strictly lower trigger threshold.
    #[test]
    fn risk_sensitivity_threshold_ordering() {
        assert!(
            RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor()
        );
        assert!(
            RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor()
        );
    }

    // Unknown strings fall back to Medium.
    #[test]
    fn risk_sensitivity_from_str_variants() {
        assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low);
        assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High);
        assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium);
        assert_eq!(
            RiskSensitivity::from_str("unknown"),
            RiskSensitivity::Medium
        );
    }

    #[tokio::test]
    async fn high_sensitivity_detects_single_blocker_as_high() {
        let t = ProjectIntelTool::new("en".into(), "high".into());
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "blockers": "Single blocker"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]"));
    }
}
|
||||
582
src/tools/report_templates.rs
Normal file
582
src/tools/report_templates.rs
Normal file
@ -0,0 +1,582 @@
|
||||
//! Report template engine for project delivery intelligence.
|
||||
//!
|
||||
//! Provides built-in templates for weekly status, sprint review, risk register,
|
||||
//! and milestone reports with multi-language support (EN, DE, FR, IT).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
|
||||
/// Supported report output formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReportFormat {
    /// Markdown output: each section becomes a `## heading` block.
    Markdown,
    /// Minimal HTML output: `<h2>`/`<p>` tags with entity-escaped content.
    Html,
}
|
||||
|
||||
/// A named section within a report template.
#[derive(Debug, Clone)]
pub struct TemplateSection {
    /// Section title; may itself contain `{{key}}` placeholders.
    pub heading: String,
    /// Section body text; may contain `{{key}}` placeholders.
    pub body: String,
}
|
||||
|
||||
/// A report template with named sections and variable placeholders.
#[derive(Debug, Clone)]
pub struct ReportTemplate {
    /// Human-readable template name (localized by the builder functions).
    pub name: String,
    /// Ordered sections rendered top to bottom.
    pub sections: Vec<TemplateSection>,
    /// Output format used by [`ReportTemplate::render`].
    pub format: ReportFormat,
}
|
||||
|
||||
/// Escape a string for safe inclusion in HTML output.
///
/// Replaces the five HTML-significant characters (`&`, `<`, `>`, `"`, `'`)
/// with their entity references. `&` is escaped first so the `&` produced by
/// the later replacements is never double-escaped.
///
/// Fix: the previous version replaced each character with itself (the entity
/// text had been lost), making the function an identity — HTML output was
/// emitted completely unescaped.
fn escape_html(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#39;")
}
|
||||
|
||||
impl ReportTemplate {
|
||||
/// Render the template by substituting `{{key}}` placeholders with values.
|
||||
pub fn render(&self, vars: &HashMap<String, String>) -> String {
|
||||
let mut out = String::new();
|
||||
for section in &self.sections {
|
||||
let heading = substitute(§ion.heading, vars);
|
||||
let body = substitute(§ion.body, vars);
|
||||
match self.format {
|
||||
ReportFormat::Markdown => {
|
||||
let _ = write!(out, "## {heading}\n\n{body}\n\n");
|
||||
}
|
||||
ReportFormat::Html => {
|
||||
let heading = escape_html(&heading);
|
||||
let body = escape_html(&body);
|
||||
let _ = write!(out, "<h2>{heading}</h2>\n<p>{body}</p>\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
out.trim_end().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Single-pass placeholder substitution.
///
/// Scans `template` left-to-right for `{{key}}` tokens and replaces them with
/// the corresponding value from `vars`. Because the scan is single-pass,
/// values that themselves contain `{{...}}` sequences are emitted literally
/// and never re-expanded, preventing injection of new placeholders.
///
/// Unknown placeholders and unterminated `{{` sequences are copied through
/// verbatim. Works on `&str` slices throughout, so multi-byte UTF-8 content
/// is preserved intact (the previous version pushed individual bytes as
/// `char`s, corrupting any non-ASCII text).
fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        // Copy everything before the opening braces untouched.
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 2..];
        match after_open.find("}}") {
            Some(close) => {
                let key = &after_open[..close];
                if let Some(value) = vars.get(key) {
                    result.push_str(value);
                } else {
                    // Unknown placeholder: emit the whole `{{key}}` as-is.
                    result.push_str(&rest[open..open + 2 + close + 2]);
                }
                rest = &after_open[close + 2..];
            }
            None => {
                // No closing `}}` anywhere: the remainder is literal text.
                result.push_str(&rest[open..]);
                rest = "";
            }
        }
    }
    result.push_str(rest);
    result
}
|
||||
|
||||
// ── Built-in templates ────────────────────────────────────────────
|
||||
|
||||
/// Return the built-in weekly status template for the given language.
|
||||
pub fn weekly_status_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Wochenstatus",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Zusammenfassung".into(),
|
||||
body: "Projekt: {{project_name}} | Zeitraum: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Erledigt".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Bearbeitung".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blockiert".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Naechste Schritte".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Statut hebdomadaire",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Resume".into(),
|
||||
body: "Projet: {{project_name}} | Periode: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Termine".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "En cours".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloque".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Prochaines etapes".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Stato settimanale",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Riepilogo".into(),
|
||||
body: "Progetto: {{project_name}} | Periodo: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completato".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In corso".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloccato".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Prossimi passi".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Weekly Status",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Summary".into(),
|
||||
body: "Project: {{project_name}} | Period: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completed".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Progress".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blocked".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Next Steps".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in sprint review template for the given language.
|
||||
pub fn sprint_review_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Sprint-Uebersicht",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Erledigt".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Bearbeitung".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blockiert".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocity".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Revue de sprint",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Termine".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "En cours".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloque".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocite".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Revisione sprint",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completato".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In corso".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloccato".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocita".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Sprint Review",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completed".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Progress".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blocked".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocity".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in risk register template for the given language.
|
||||
pub fn risk_register_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Risikoregister",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projekt".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risiken".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Massnahmen".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Registre des risques",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projet".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risques".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mesures".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Registro dei rischi",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Progetto".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Rischi".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mitigazioni".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Risk Register",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Project".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risks".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mitigations".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in milestone report template for the given language.
|
||||
pub fn milestone_report_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Meilensteinbericht",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projekt".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Meilensteine".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Status".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Rapport de jalons",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projet".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Jalons".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Statut".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Report milestone",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Progetto".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Milestone".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Stato".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Milestone Report",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Project".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Milestones".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Status".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for the template engine: rendering with variables, all four
// languages, both output formats, section composition, and the substitution
// rules (unknown placeholders kept, all occurrences replaced).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn weekly_status_renders_with_variables() {
        let tpl = weekly_status_template("en");
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "ZeroClaw".into());
        vars.insert("period".into(), "2026-W10".into());
        vars.insert("completed".into(), "- Task A\n- Task B".into());
        vars.insert("in_progress".into(), "- Task C".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "- Task D".into());

        let rendered = tpl.render(&vars);
        assert!(rendered.contains("Project: ZeroClaw"));
        assert!(rendered.contains("Period: 2026-W10"));
        assert!(rendered.contains("- Task A"));
        assert!(rendered.contains("## Completed"));
    }

    #[test]
    fn weekly_status_de_renders_german_headings() {
        let tpl = weekly_status_template("de");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Zusammenfassung"));
        assert!(rendered.contains("## Erledigt"));
    }

    #[test]
    fn weekly_status_fr_renders_french_headings() {
        let tpl = weekly_status_template("fr");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Resume"));
        assert!(rendered.contains("## Termine"));
    }

    #[test]
    fn weekly_status_it_renders_italian_headings() {
        let tpl = weekly_status_template("it");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Riepilogo"));
        assert!(rendered.contains("## Completato"));
    }

    // Switching the format field to Html flips the output to tag form.
    #[test]
    fn html_format_renders_tags() {
        let mut tpl = weekly_status_template("en");
        tpl.format = ReportFormat::Html;
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "Test".into());
        vars.insert("period".into(), "W1".into());
        vars.insert("completed".into(), "Done".into());
        vars.insert("in_progress".into(), "WIP".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "Next".into());

        let rendered = tpl.render(&vars);
        assert!(rendered.contains("<h2>Summary</h2>"));
        assert!(rendered.contains("<p>Project: Test | Period: W1</p>"));
    }

    #[test]
    fn sprint_review_template_has_velocity_section() {
        let tpl = sprint_review_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Velocity"));
    }

    #[test]
    fn risk_register_template_has_risk_sections() {
        let tpl = risk_register_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Risks"));
        assert!(section_headings.contains(&"Mitigations"));
    }

    #[test]
    fn milestone_template_all_languages() {
        for lang in &["en", "de", "fr", "it"] {
            let tpl = milestone_report_template(lang);
            assert!(!tpl.name.is_empty());
            assert_eq!(tpl.sections.len(), 3);
        }
    }

    #[test]
    fn substitute_leaves_unknown_placeholders() {
        let vars = HashMap::new();
        let result = substitute("Hello {{name}}", &vars);
        assert_eq!(result, "Hello {{name}}");
    }

    #[test]
    fn substitute_replaces_all_occurrences() {
        let mut vars = HashMap::new();
        vars.insert("x".into(), "1".into());
        let result = substitute("{{x}} and {{x}}", &vars);
        assert_eq!(result, "1 and 1");
    }
}
|
||||
659
src/tools/security_ops.rs
Normal file
659
src/tools/security_ops.rs
Normal file
@ -0,0 +1,659 @@
|
||||
//! Security operations tool for managed cybersecurity service (MCSS) workflows.
|
||||
//!
|
||||
//! Provides alert triage, incident response playbook execution, vulnerability
|
||||
//! scan parsing, and security report generation. All actions that modify state
|
||||
//! enforce human approval gates unless explicitly configured otherwise.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::SecurityOpsConfig;
|
||||
use crate::security::playbook::{
|
||||
evaluate_step, load_playbooks, severity_level, Playbook, StepStatus,
|
||||
};
|
||||
use crate::security::vulnerability::{generate_summary, parse_vulnerability_json};
|
||||
|
||||
/// Security operations tool — triage alerts, run playbooks, parse vulns, generate reports.
pub struct SecurityOpsTool {
    /// Operator configuration: approval gates, auto-triage flag, severity
    /// cap, and the playbooks directory.
    config: SecurityOpsConfig,
    /// Playbooks loaded once at construction from `config.playbooks_dir`.
    playbooks: Vec<Playbook>,
}
|
||||
|
||||
impl SecurityOpsTool {
    /// Build the tool, eagerly loading all playbooks from the configured
    /// directory (with `~` expanded to the user's home).
    pub fn new(config: SecurityOpsConfig) -> Self {
        let playbooks_dir = expand_tilde(&config.playbooks_dir);
        let playbooks = load_playbooks(&playbooks_dir);
        Self { config, playbooks }
    }

    /// Triage an alert: classify severity and recommend response.
    fn triage_alert(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let alert = args
            .get("alert")
            .ok_or_else(|| anyhow::anyhow!("Missing required 'alert' parameter"))?;

        // Extract key fields for classification; missing fields fall back to
        // neutral defaults rather than failing the triage.
        let alert_type = alert
            .get("type")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");
        let source = alert
            .get("source")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");
        let severity = alert
            .get("severity")
            .and_then(|v| v.as_str())
            .unwrap_or("medium");
        let description = alert
            .get("description")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        // Classify and find matching playbooks. A playbook matches when the
        // alert meets its severity floor AND its name relates to the alert
        // type or appears (underscores as spaces) in the description.
        let matching_playbooks: Vec<&Playbook> = self
            .playbooks
            .iter()
            .filter(|pb| {
                severity_level(severity) >= severity_level(&pb.severity_filter)
                    && (pb.name.contains(alert_type)
                        || alert_type.contains(&pb.name)
                        || description
                            .to_lowercase()
                            .contains(&pb.name.replace('_', " ")))
            })
            .collect();

        let playbook_names: Vec<&str> =
            matching_playbooks.iter().map(|p| p.name.as_str()).collect();

        // Level >= 3 is treated as requiring immediate attention.
        let output = json!({
            "classification": {
                "alert_type": alert_type,
                "source": source,
                "severity": severity,
                "severity_level": severity_level(severity),
                "priority": if severity_level(severity) >= 3 { "immediate" } else { "standard" },
            },
            "recommended_playbooks": playbook_names,
            "recommended_action": if matching_playbooks.is_empty() {
                "Manual investigation required — no matching playbook found"
            } else {
                "Execute recommended playbook(s)"
            },
            "auto_triage": self.config.auto_triage,
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }

    /// Execute a playbook step with approval gating.
    fn run_playbook(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let playbook_name = args
            .get("playbook")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'playbook' parameter"))?;

        // u64 -> usize conversion is checked so a huge 'step' value cannot
        // silently truncate on 32-bit targets.
        let step_index =
            usize::try_from(args.get("step").and_then(|v| v.as_u64()).ok_or_else(|| {
                anyhow::anyhow!("Missing required 'step' parameter (0-based index)")
            })?)
            .map_err(|_| anyhow::anyhow!("'step' parameter value too large for this platform"))?;

        let alert_severity = args
            .get("alert_severity")
            .and_then(|v| v.as_str())
            .unwrap_or("medium");

        let playbook = self
            .playbooks
            .iter()
            .find(|p| p.name == playbook_name)
            .ok_or_else(|| anyhow::anyhow!("Playbook '{}' not found", playbook_name))?;

        // The configured approval gates (max auto severity, require-approval
        // flag) are applied inside evaluate_step; steps that need a human
        // come back as PendingApproval rather than executing.
        let result = evaluate_step(
            playbook,
            step_index,
            alert_severity,
            &self.config.max_auto_severity,
            self.config.require_approval_for_actions,
        );

        let output = json!({
            "playbook": playbook_name,
            "step_index": result.step_index,
            "action": result.action,
            "status": result.status.to_string(),
            "message": result.message,
            "requires_manual_approval": result.status == StepStatus::PendingApproval,
        });

        // Only Failed counts as an unsuccessful tool call; PendingApproval is
        // a successful result that asks for a human decision.
        Ok(ToolResult {
            success: result.status != StepStatus::Failed,
            output: serde_json::to_string_pretty(&output)?,
            error: if result.status == StepStatus::Failed {
                Some(result.message)
            } else {
                None
            },
        })
    }

    /// Parse vulnerability scan results.
    fn parse_vulnerability(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let scan_data = args
            .get("scan_data")
            .ok_or_else(|| anyhow::anyhow!("Missing required 'scan_data' parameter"))?;

        // Accept either a JSON string or an already-parsed JSON object.
        let json_str = if scan_data.is_string() {
            // Guarded by is_string() above, so this unwrap cannot panic.
            scan_data.as_str().unwrap().to_string()
        } else {
            serde_json::to_string(scan_data)?
        };

        let report = parse_vulnerability_json(&json_str)?;
        let summary = generate_summary(&report);

        let output = json!({
            "scanner": report.scanner,
            "scan_date": report.scan_date.to_rfc3339(),
            "total_findings": report.findings.len(),
            "by_severity": {
                "critical": report.findings.iter().filter(|f| f.severity == "critical").count(),
                "high": report.findings.iter().filter(|f| f.severity == "high").count(),
                "medium": report.findings.iter().filter(|f| f.severity == "medium").count(),
                "low": report.findings.iter().filter(|f| f.severity == "low").count(),
            },
            "summary": summary,
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }

    /// Generate a client-facing security posture report.
    fn generate_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let client_name = args
            .get("client_name")
            .and_then(|v| v.as_str())
            .unwrap_or("Client");
        let period = args
            .get("period")
            .and_then(|v| v.as_str())
            .unwrap_or("current");
        let alert_stats = args.get("alert_stats");
        let vuln_summary = args
            .get("vuln_summary")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        // Markdown report; missing inputs are replaced by explicit
        // "No ... provided" placeholders rather than empty sections.
        let report = format!(
            "# Security Posture Report — {client_name}\n\
             **Period:** {period}\n\
             **Generated:** {}\n\n\
             ## Executive Summary\n\n\
             This report provides an overview of the security posture for {client_name} \
             during the {period} period.\n\n\
             ## Alert Summary\n\n\
             {}\n\n\
             ## Vulnerability Assessment\n\n\
             {}\n\n\
             ## Recommendations\n\n\
             1. Address all critical and high-severity findings immediately\n\
             2. Review and update incident response playbooks quarterly\n\
             3. Conduct regular vulnerability scans on all internet-facing assets\n\
             4. Ensure all endpoints have current security patches\n\n\
             ---\n\
             *Report generated by ZeroClaw MCSS Agent*\n",
            chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"),
            alert_stats
                .map(|s| serde_json::to_string_pretty(s).unwrap_or_default())
                .unwrap_or_else(|| "No alert statistics provided.".into()),
            if vuln_summary.is_empty() {
                "No vulnerability data provided."
            } else {
                vuln_summary
            },
        );

        Ok(ToolResult {
            success: true,
            output: report,
            error: None,
        })
    }

    /// List available playbooks.
    fn list_playbooks(&self) -> anyhow::Result<ToolResult> {
        if self.playbooks.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: "No playbooks available.".into(),
                error: None,
            });
        }

        let playbook_list: Vec<serde_json::Value> = self
            .playbooks
            .iter()
            .map(|pb| {
                json!({
                    "name": pb.name,
                    "description": pb.description,
                    "steps": pb.steps.len(),
                    "severity_filter": pb.severity_filter,
                    "auto_approve_steps": pb.auto_approve_steps,
                })
            })
            .collect();

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&playbook_list)?,
            error: None,
        })
    }

    /// Summarize alert volume, categories, and resolution times.
    fn alert_stats(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let alerts = args
            .get("alerts")
            .and_then(|v| v.as_array())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'alerts' array parameter"))?;

        let total = alerts.len();
        let mut by_severity = std::collections::HashMap::new();
        let mut by_category = std::collections::HashMap::new();
        let mut resolved_count = 0u64;
        let mut total_resolution_secs = 0u64;

        for alert in alerts {
            let severity = alert
                .get("severity")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");
            *by_severity.entry(severity.to_string()).or_insert(0u64) += 1;

            let category = alert
                .get("category")
                .and_then(|v| v.as_str())
                .unwrap_or("uncategorized");
            *by_category.entry(category.to_string()).or_insert(0u64) += 1;

            // Only alerts carrying 'resolution_secs' count as resolved.
            if let Some(resolution_secs) = alert.get("resolution_secs").and_then(|v| v.as_u64()) {
                resolved_count += 1;
                total_resolution_secs += resolution_secs;
            }
        }

        // Average over resolved alerts only; 0.0 when nothing is resolved
        // (avoids division by zero).
        let avg_resolution = if resolved_count > 0 {
            total_resolution_secs as f64 / resolved_count as f64
        } else {
            0.0
        };

        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let avg_resolution_secs_u64 = avg_resolution.max(0.0) as u64;

        let output = json!({
            "total_alerts": total,
            "resolved": resolved_count,
            "unresolved": total as u64 - resolved_count,
            "by_severity": by_severity,
            "by_category": by_category,
            "avg_resolution_secs": avg_resolution,
            "avg_resolution_human": format_duration_secs(avg_resolution_secs_u64),
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }
}
|
||||
|
||||
/// Render a duration in whole seconds as a compact human-readable string.
///
/// Uses at most two units: `"45s"`, `"2m 5s"`, `"1h 1m"`. Seconds are
/// dropped once the duration reaches an hour.
fn format_duration_secs(secs: u64) -> String {
    match secs {
        s if s < 60 => format!("{s}s"),
        s if s < 3600 => format!("{}m {}s", s / 60, s % 60),
        s => format!("{}h {}m", s / 3600, (s % 3600) / 60),
    }
}
|
||||
|
||||
/// Expand ~ to home directory.
|
||||
fn expand_tilde(path: &str) -> PathBuf {
|
||||
if let Some(rest) = path.strip_prefix("~/") {
|
||||
if let Some(user_dirs) = directories::UserDirs::new() {
|
||||
return user_dirs.home_dir().join(rest);
|
||||
}
|
||||
}
|
||||
PathBuf::from(path)
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for SecurityOpsTool {
    /// Stable tool identifier used for registration and action dispatch.
    fn name(&self) -> &str {
        "security_ops"
    }

    /// One-paragraph summary of the tool and its supported actions,
    /// surfaced to the model when the tool is advertised.
    fn description(&self) -> &str {
        "Security operations tool for managed cybersecurity services. Actions: \
         triage_alert (classify/prioritize alerts), run_playbook (execute incident response steps), \
         parse_vulnerability (parse scan results), generate_report (create security posture reports), \
         list_playbooks (list available playbooks), alert_stats (summarize alert metrics)."
    }

    /// JSON Schema for the accepted arguments. Only `action` is required;
    /// every other property applies to a specific action (noted in each
    /// property's description).
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["triage_alert", "run_playbook", "parse_vulnerability", "generate_report", "list_playbooks", "alert_stats"],
                    "description": "The security operation to perform"
                },
                "alert": {
                    "type": "object",
                    "description": "Alert JSON for triage_alert (requires: type, severity; optional: source, description)"
                },
                "playbook": {
                    "type": "string",
                    "description": "Playbook name for run_playbook"
                },
                "step": {
                    "type": "integer",
                    "description": "0-based step index for run_playbook"
                },
                "alert_severity": {
                    "type": "string",
                    "description": "Alert severity context for run_playbook"
                },
                "scan_data": {
                    "description": "Vulnerability scan data (JSON string or object) for parse_vulnerability"
                },
                "client_name": {
                    "type": "string",
                    "description": "Client name for generate_report"
                },
                "period": {
                    "type": "string",
                    "description": "Reporting period for generate_report"
                },
                "alert_stats": {
                    "type": "object",
                    "description": "Alert statistics to include in generate_report"
                },
                "vuln_summary": {
                    "type": "string",
                    "description": "Vulnerability summary to include in generate_report"
                },
                "alerts": {
                    "type": "array",
                    "description": "Array of alert objects for alert_stats"
                }
            }
        })
    }

    /// Dispatch to the handler for the requested `action`.
    ///
    /// A missing `action` parameter is an `Err`; an *unknown* action is
    /// reported as an unsuccessful `ToolResult` instead, so the caller gets
    /// a recoverable, user-facing message rather than a hard failure.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;

        match action {
            "triage_alert" => self.triage_alert(&args),
            "run_playbook" => self.run_playbook(&args),
            "parse_vulnerability" => self.parse_vulnerability(&args),
            "generate_report" => self.generate_report(&args),
            "list_playbooks" => self.list_playbooks(),
            "alert_stats" => self.alert_stats(&args),
            _ => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action '{action}'. Valid: triage_alert, run_playbook, \
                     parse_vulnerability, generate_report, list_playbooks, alert_stats"
                )),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config with a nonexistent playbooks dir, so the tool uses its
    /// built-in playbooks (see `list_playbooks_returns_builtins`).
    fn test_config() -> SecurityOpsConfig {
        SecurityOpsConfig {
            enabled: true,
            playbooks_dir: "/nonexistent".into(),
            auto_triage: false,
            require_approval_for_actions: true,
            max_auto_severity: "low".into(),
            report_output_dir: "/tmp/reports".into(),
            siem_integration: None,
        }
    }

    /// Convenience constructor used by every test below.
    fn test_tool() -> SecurityOpsTool {
        SecurityOpsTool::new(test_config())
    }

    #[test]
    fn tool_name_and_schema() {
        let tool = test_tool();
        assert_eq!(tool.name(), "security_ops");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["action"].is_object());
        // `action` must be declared required in the schema.
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("action")));
    }

    #[tokio::test]
    async fn triage_alert_classifies_severity() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "triage_alert",
                "alert": {
                    "type": "suspicious_login",
                    "source": "siem",
                    "severity": "high",
                    "description": "Multiple failed login attempts followed by successful login"
                }
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        // High severity should map to the "immediate" priority tier.
        assert_eq!(output["classification"]["severity"], "high");
        assert_eq!(output["classification"]["priority"], "immediate");
        // Should match suspicious_login playbook
        let playbooks = output["recommended_playbooks"].as_array().unwrap();
        assert!(playbooks.iter().any(|p| p == "suspicious_login"));
    }

    #[tokio::test]
    async fn triage_alert_missing_alert_param() {
        let tool = test_tool();
        // Omitting the required `alert` object must be a hard error.
        let result = tool.execute(json!({"action": "triage_alert"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn run_playbook_requires_approval() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "suspicious_login",
                "step": 2,
                "alert_severity": "high"
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        // Config sets require_approval_for_actions, so this step must gate.
        assert_eq!(output["status"], "pending_approval");
        assert_eq!(output["requires_manual_approval"], true);
    }

    #[tokio::test]
    async fn run_playbook_executes_safe_step() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "suspicious_login",
                "step": 0,
                "alert_severity": "medium"
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["status"], "completed");
    }

    #[tokio::test]
    async fn run_playbook_not_found() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "nonexistent",
                "step": 0
            }))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn parse_vulnerability_valid_report() {
        let tool = test_tool();
        let scan_data = json!({
            "scan_date": "2025-01-15T10:00:00Z",
            "scanner": "nessus",
            "findings": [
                {
                    "cve_id": "CVE-2024-0001",
                    "cvss_score": 9.8,
                    "severity": "critical",
                    "affected_asset": "web-01",
                    "description": "RCE in web framework",
                    "remediation": "Upgrade",
                    "internet_facing": true,
                    "production": true
                }
            ]
        });

        let result = tool
            .execute(json!({
                "action": "parse_vulnerability",
                "scan_data": scan_data
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["total_findings"], 1);
        assert_eq!(output["by_severity"]["critical"], 1);
    }

    #[tokio::test]
    async fn generate_report_produces_markdown() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "generate_report",
                "client_name": "ZeroClaw Corp",
                "period": "Q1 2025"
            }))
            .await
            .unwrap();

        assert!(result.success);
        // The report should interpolate client name, period, and title.
        assert!(result.output.contains("ZeroClaw Corp"));
        assert!(result.output.contains("Q1 2025"));
        assert!(result.output.contains("Security Posture Report"));
    }

    #[tokio::test]
    async fn list_playbooks_returns_builtins() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "list_playbooks"}))
            .await
            .unwrap();

        assert!(result.success);
        let output: Vec<serde_json::Value> = serde_json::from_str(&result.output).unwrap();
        // Four built-in playbooks are expected when no dir is loaded.
        assert_eq!(output.len(), 4);
        let names: Vec<&str> = output.iter().map(|p| p["name"].as_str().unwrap()).collect();
        assert!(names.contains(&"suspicious_login"));
        assert!(names.contains(&"malware_detected"));
    }

    #[tokio::test]
    async fn alert_stats_computes_summary() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "alert_stats",
                "alerts": [
                    {"severity": "critical", "category": "malware", "resolution_secs": 3600},
                    {"severity": "high", "category": "phishing", "resolution_secs": 1800},
                    {"severity": "medium", "category": "malware"},
                    {"severity": "low", "category": "policy_violation", "resolution_secs": 600}
                ]
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["total_alerts"], 4);
        // Only alerts carrying resolution_secs count as resolved.
        assert_eq!(output["resolved"], 3);
        assert_eq!(output["unresolved"], 1);
        assert_eq!(output["by_severity"]["critical"], 1);
        assert_eq!(output["by_category"]["malware"], 2);
    }

    #[tokio::test]
    async fn unknown_action_returns_error() {
        let tool = test_tool();
        // Unknown actions are a soft failure (ToolResult), not an Err.
        let result = tool.execute(json!({"action": "bad_action"})).await.unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }

    #[test]
    fn format_duration_secs_readable() {
        assert_eq!(format_duration_secs(45), "45s");
        assert_eq!(format_duration_secs(125), "2m 5s");
        assert_eq!(format_duration_secs(3665), "1h 1m");
    }
}
|
||||
@ -2,6 +2,7 @@ mod cloudflare;
|
||||
mod custom;
|
||||
mod ngrok;
|
||||
mod none;
|
||||
mod openvpn;
|
||||
mod tailscale;
|
||||
|
||||
pub use cloudflare::CloudflareTunnel;
|
||||
@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
|
||||
pub use ngrok::NgrokTunnel;
|
||||
#[allow(unused_imports)]
|
||||
pub use none::NoneTunnel;
|
||||
pub use openvpn::OpenVpnTunnel;
|
||||
pub use tailscale::TailscaleTunnel;
|
||||
|
||||
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
|
||||
@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
"openvpn" => {
|
||||
let ov = config
|
||||
.openvpn
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
|
||||
Ok(Some(Box::new(OpenVpnTunnel::new(
|
||||
ov.config_file.clone(),
|
||||
ov.auth_file.clone(),
|
||||
ov.advertise_address.clone(),
|
||||
ov.connect_timeout_secs,
|
||||
ov.extra_args.clone(),
|
||||
))))
|
||||
}
|
||||
|
||||
"custom" => {
|
||||
let cu = config
|
||||
.custom
|
||||
@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"),
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::{
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig,
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
|
||||
TunnelConfig,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
@ -315,6 +332,46 @@ mod tests {
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[test]
fn factory_openvpn_missing_config_errors() {
    // provider = "openvpn" without a [tunnel.openvpn] section must error.
    let cfg = TunnelConfig {
        provider: "openvpn".into(),
        ..TunnelConfig::default()
    };
    assert_tunnel_err(&cfg, "[tunnel.openvpn]");
}
|
||||
|
||||
#[test]
fn factory_openvpn_with_config_ok() {
    // A fully-populated openvpn section should yield an OpenVpnTunnel.
    let cfg = TunnelConfig {
        provider: "openvpn".into(),
        openvpn: Some(OpenVpnTunnelConfig {
            config_file: "client.ovpn".into(),
            auth_file: None,
            advertise_address: None,
            connect_timeout_secs: 30,
            extra_args: vec![],
        }),
        ..TunnelConfig::default()
    };
    let t = create_tunnel(&cfg).unwrap();
    assert!(t.is_some());
    assert_eq!(t.unwrap().name(), "openvpn");
}
|
||||
|
||||
#[test]
fn openvpn_tunnel_name() {
    // Fresh tunnel: correct name, no public URL before start().
    let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
    assert_eq!(t.name(), "openvpn");
    assert!(t.public_url().is_none());
}
|
||||
|
||||
#[tokio::test]
async fn openvpn_health_false_before_start() {
    // No child process spawned yet, so the health check must fail.
    let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
    assert!(!tunnel.health_check().await);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_shared_no_process_is_ok() {
|
||||
let proc = new_shared_process();
|
||||
|
||||
254
src/tunnel/openvpn.rs
Normal file
254
src/tunnel/openvpn.rs
Normal file
@ -0,0 +1,254 @@
|
||||
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
|
||||
use anyhow::{bail, Result};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
///
/// Requires the `openvpn` binary installed and accessible. On most systems,
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
///
/// The tunnel exposes the gateway via the VPN network using a configured
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
pub struct OpenVpnTunnel {
    /// Path to the `.ovpn` file passed via `--config`.
    config_file: String,
    /// Optional credentials file passed via `--auth-user-pass`.
    auth_file: Option<String>,
    /// Optional public address reported as the tunnel URL once connected;
    /// when absent, `start` falls back to the local host:port.
    advertise_address: Option<String>,
    /// Seconds `start` waits for "Initialization Sequence Completed".
    connect_timeout_secs: u64,
    /// Extra CLI arguments appended to the `openvpn` invocation.
    extra_args: Vec<String>,
    /// Shared handle to the spawned `openvpn` child process, if started.
    proc: SharedProcess,
}
|
||||
|
||||
impl OpenVpnTunnel {
|
||||
/// Create a new OpenVPN tunnel instance.
|
||||
///
|
||||
/// * `config_file` — path to the `.ovpn` configuration file.
|
||||
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
|
||||
/// * `advertise_address` — optional public address to advertise once connected.
|
||||
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
|
||||
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
|
||||
pub fn new(
|
||||
config_file: String,
|
||||
auth_file: Option<String>,
|
||||
advertise_address: Option<String>,
|
||||
connect_timeout_secs: u64,
|
||||
extra_args: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config_file,
|
||||
auth_file,
|
||||
advertise_address,
|
||||
connect_timeout_secs,
|
||||
extra_args,
|
||||
proc: new_shared_process(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the openvpn command arguments.
|
||||
fn build_args(&self) -> Vec<String> {
|
||||
let mut args = vec!["--config".to_string(), self.config_file.clone()];
|
||||
|
||||
if let Some(ref auth) = self.auth_file {
|
||||
args.push("--auth-user-pass".to_string());
|
||||
args.push(auth.clone());
|
||||
}
|
||||
|
||||
args.extend(self.extra_args.iter().cloned());
|
||||
args
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
impl Tunnel for OpenVpnTunnel {
    /// Provider identifier used by the tunnel factory and logs.
    fn name(&self) -> &str {
        "openvpn"
    }

    /// Spawn the `openvpn` process and wait for the "Initialization Sequence
    /// Completed" marker on stderr. Returns the public URL on success.
    async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
        // Validate config file exists before spawning
        if !std::path::Path::new(&self.config_file).exists() {
            bail!("OpenVPN config file not found: {}", self.config_file);
        }

        let args = self.build_args();

        // kill_on_drop ensures no orphaned openvpn process survives if this
        // handle is dropped before an explicit stop().
        let mut child = Command::new("openvpn")
            .args(&args)
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::piped())
            .kill_on_drop(true)
            .spawn()?;

        // Wait for "Initialization Sequence Completed" in stderr
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;

        let mut reader = tokio::io::BufReader::new(stderr).lines();
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_secs(self.connect_timeout_secs);

        // Poll stderr line-by-line with a short (3s) per-read timeout so the
        // overall deadline is re-checked even when openvpn goes quiet.
        let mut connected = false;
        while tokio::time::Instant::now() < deadline {
            let line =
                tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;

            match line {
                Ok(Ok(Some(l))) => {
                    tracing::debug!("openvpn: {l}");
                    if l.contains("Initialization Sequence Completed") {
                        connected = true;
                        break;
                    }
                }
                Ok(Ok(None)) => {
                    // EOF: the child closed stderr, i.e. it died early.
                    bail!("OpenVPN process exited before connection was established");
                }
                Ok(Err(e)) => {
                    bail!("Error reading openvpn output: {e}");
                }
                Err(_) => {
                    // Timeout on individual line read, continue waiting
                }
            }
        }

        if !connected {
            // Best-effort kill; the error below is the meaningful signal.
            child.kill().await.ok();
            bail!(
                "OpenVPN connection timed out after {}s waiting for initialization",
                self.connect_timeout_secs
            );
        }

        let public_url = self
            .advertise_address
            .clone()
            .unwrap_or_else(|| format!("http://{local_host}:{local_port}"));

        // Drain stderr in background to prevent OS pipe buffer from filling and
        // blocking the openvpn process.
        tokio::spawn(async move {
            while let Ok(Some(line)) = reader.next_line().await {
                tracing::trace!("openvpn: {line}");
            }
        });

        // Record the running child so stop()/health_check()/public_url()
        // can observe it.
        let mut guard = self.proc.lock().await;
        *guard = Some(TunnelProcess {
            child,
            public_url: public_url.clone(),
        });

        Ok(public_url)
    }

    /// Kill the openvpn child process and release its resources.
    async fn stop(&self) -> Result<()> {
        kill_shared(&self.proc).await
    }

    /// Return `true` if the openvpn child process is still running.
    async fn health_check(&self) -> bool {
        let guard = self.proc.lock().await;
        guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
    }

    /// Return the public URL if the tunnel has been started.
    fn public_url(&self) -> Option<String> {
        // try_lock: this is a sync method, so avoid blocking; returns None
        // if the lock is momentarily held elsewhere.
        self.proc
            .try_lock()
            .ok()
            .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor should store every field verbatim.
    #[test]
    fn constructor_stores_fields() {
        let tunnel = OpenVpnTunnel::new(
            "/etc/openvpn/client.ovpn".into(),
            Some("/etc/openvpn/auth.txt".into()),
            Some("10.8.0.2:42617".into()),
            45,
            vec!["--verb".into(), "3".into()],
        );
        assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn");
        assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
        assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617"));
        assert_eq!(tunnel.connect_timeout_secs, 45);
        assert_eq!(tunnel.extra_args, vec!["--verb", "3"]);
    }

    /// Minimal invocation: just `--config <file>`.
    #[test]
    fn build_args_basic() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        let args = tunnel.build_args();
        assert_eq!(args, vec!["--config", "client.ovpn"]);
    }

    /// Auth file and extra args are appended in order after --config.
    #[test]
    fn build_args_with_auth_and_extras() {
        let tunnel = OpenVpnTunnel::new(
            "client.ovpn".into(),
            Some("auth.txt".into()),
            None,
            30,
            vec!["--verb".into(), "5".into()],
        );
        let args = tunnel.build_args();
        assert_eq!(
            args,
            vec![
                "--config",
                "client.ovpn",
                "--auth-user-pass",
                "auth.txt",
                "--verb",
                "5"
            ]
        );
    }

    #[test]
    fn public_url_is_none_before_start() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(tunnel.public_url().is_none());
    }

    #[tokio::test]
    async fn health_check_is_false_before_start() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(!tunnel.health_check().await);
    }

    /// stop() on a never-started tunnel is a no-op, not an error.
    #[tokio::test]
    async fn stop_without_started_process_is_ok() {
        let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        let result = tunnel.stop().await;
        assert!(result.is_ok());
    }

    /// A missing config file must fail fast before spawning openvpn.
    #[tokio::test]
    async fn start_with_missing_config_file_errors() {
        let tunnel = OpenVpnTunnel::new(
            "/nonexistent/path/to/client.ovpn".into(),
            None,
            None,
            30,
            vec![],
        );
        let result = tunnel.start("127.0.0.1", 8080).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("config file not found"));
    }
}
|
||||
Loading…
Reference in New Issue
Block a user