Compare commits

...

54 Commits

Author SHA1 Message Date
argenis de la rosa 5d9f83aee8 fix: add SecurityOpsConfig to re-exports, fix test constructors
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 02:28:41 -04:00
argenis de la rosa 8f9ce15e33 Merge remote-tracking branch 'origin/master' into work/security-ops
# Conflicts:
#	src/config/mod.rs
#	src/config/schema.rs
#	src/onboard/wizard.rs
#	src/tools/mod.rs
2026-03-16 02:16:55 -04:00
Argenis dcc0a629ec feat(tools): add project delivery intelligence tool (#3656)
Add a new read-only project_intel tool that provides:
- Status report generation (weekly/sprint/month)
- Risk scanning with configurable sensitivity
- Client update drafting (formal/casual, client/internal)
- Sprint summary generation
- Heuristic effort estimation

Includes multi-language report templates (EN, DE, FR, IT),
ProjectIntelConfig schema with validation, and comprehensive tests.

Also fixes missing approval_manager field in 4 ChannelRuntimeContext
test constructors.

Supersedes #3591 — rebased on latest master. Original work by @rareba.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 02:10:14 -04:00
Argenis a8c6363cde feat(nodes): add secure HMAC-SHA256 node transport layer (#3654)
Add a new `nodes` module with HMAC-SHA256 authenticated transport for
secure inter-node communication over standard HTTPS. Includes replay
protection via timestamped nonces and constant-time signature
comparison.

Also adds `NodeTransportConfig` to the config schema and fixes missing
`approval_manager` field in four `ChannelRuntimeContext` test
constructors that failed compilation on latest master.

Original work by @rareba. Rebased on latest master to resolve merge
conflicts (SwarmConfig/SwarmStrategy exports, duplicate MCP validation,
test constructor fields).

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 01:53:47 -04:00
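[Editor's note] The constant-time signature comparison mentioned in the commit above can be sketched in plain Rust. This is an illustrative stand-alone function, not the actual transport code; the real implementation may differ.

```rust
/// Constant-time equality check for HMAC tags (illustrative sketch):
/// accumulate XOR differences across every byte instead of returning at
/// the first mismatch, so running time does not leak how many leading
/// bytes of the attacker-supplied tag were correct.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff: u8 = 0;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }
    diff == 0
}
```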
Argenis d9ab017df0 feat(tools): add Microsoft 365 integration via Graph API (#3653)
Add Microsoft 365 tool providing access to Outlook mail, Teams messages,
Calendar events, OneDrive files, and SharePoint search via Microsoft
Graph API. Includes OAuth2 token caching (client credentials and device
code flows), security policy enforcement, and config validation.

Rebased on latest master, resolving conflicts with SwarmConfig exports
and adding approval_manager to ChannelRuntimeContext test constructors.

Original work by @rareba.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 01:44:39 -04:00
Argenis 249434edb2 feat(notion): add Notion database poller channel and API tool (#3650)
Add Notion integration with two components:
- NotionChannel: polls a Notion database for tasks with configurable
  status properties, concurrency limits, and stale task recovery
- NotionTool: provides CRUD operations (query_database, read_page,
  create_page, update_page) for agent-driven Notion interactions

Includes config schema (NotionConfig), onboarding wizard support,
and full unit test coverage for both channel and tool.

Supersedes #3609 — rebased on latest master to resolve merge conflicts
with swarm feature additions in config/mod.rs.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 00:55:23 -04:00
Argenis 62781a8d45 fix(lint): Box::pin crate::agent::run calls to satisfy large_futures (#3675)
Wrap all crate::agent::run() calls with Box::pin() across scheduler,
daemon, gateway tests, and main.rs to satisfy clippy::large_futures.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 00:54:27 -04:00
Chris Hengge 0adec305f9 fix(tools): qualify is_service_environment with super:: inside mod native_backend (#3659)
Commit 811fab3b added is_service_environment() as a top-level function and
called it from two sites. The call at line 445 is at module scope and resolves
fine. The call at line 1473 is inside mod native_backend, which is a child
module — Rust does not implicitly import parent-scope items, so the unqualified
name fails with E0425 (cannot find function in this scope).

Fix: prefix the call with super:: so it resolves to the parent module's
function, matching how mod native_backend already imports other parent items
(e.g. use super::BrowserAction).

The browser-native feature flag is required to reproduce:
  cargo check --features browser-native  # fails without this fix
  cargo check --features browser-native  # clean with this fix

Co-authored-by: Argenis <theonlyhennygod@gmail.com>
2026-03-16 00:35:09 -04:00
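[Editor's note] The scoping rule the commit above describes can be reproduced in miniature. The stub body below is illustrative only; the point is that the child module must qualify the parent-scope call with `super::`.

```rust
// Parent-scope function (stub body for illustration).
fn is_service_environment() -> bool {
    false
}

mod native_backend {
    // Rust does not implicitly import parent-scope items into a child
    // module; an unqualified `is_service_environment()` here fails with
    // E0425. Qualifying with `super::` resolves to the parent's function,
    // the same way `use super::SomeItem;` would.
    pub fn backend_check() -> bool {
        super::is_service_environment()
    }
}
```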
Argenis 75701195d7 feat(security): add Nevis IAM integration for SSO/MFA authentication (#3651)
* feat(security): add Nevis IAM integration for SSO/MFA authentication

Add NevisAuthProvider supporting OAuth2/OIDC token validation (local JWKS +
remote introspection), FIDO2/passkey/OTP MFA verification, session management,
and health checks. Add IamPolicy engine mapping Nevis roles to ZeroClaw tool
and workspace permissions with deny-by-default enforcement and audit logging.

Add NevisConfig and NevisRoleMappingConfig to config schema with client_secret
wired through SecretStore encrypt/decrypt. All features disabled by default.

Rebased on latest master to resolve merge conflicts in security/mod.rs (redact
function) and config/schema.rs (test section).

Original work by @rareba. Supersedes #3593.

Co-Authored-By: rareba <5985289+rareba@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: cargo fmt Box::pin calls

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: rareba <5985289+rareba@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 00:34:52 -04:00
Argenis 82fe2e53fd feat(tunnel): add OpenVPN tunnel provider (#3648)
* feat(tunnel): add OpenVPN tunnel provider

Add OpenVPN as a new tunnel provider alongside cloudflare, tailscale,
ngrok, and custom. Includes config schema, validation, factory wiring,
and comprehensive unit tests.

Co-authored-by: rareba <rareba@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add missing approval_manager field to ChannelRuntimeContext constructors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: rareba <rareba@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 00:34:34 -04:00
Argenis 13a933f6a6 Merge branch 'master' into work/security-ops 2026-03-15 23:34:50 -04:00
Argenis 327e2b4c47 style: cargo fmt Box::pin calls in cron scheduler (#3667)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:34:26 -04:00
Argenis 421671b796 Merge branch 'master' into work/security-ops 2026-03-15 23:25:48 -04:00
Argenis 5a5d9ae5f9 fix(lint): Box::pin large futures in cron scheduler and cron_run tool (#3666)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:25:28 -04:00
argenis de la rosa 835214e23f feat(security): add MCSS security operations tool
Add managed cybersecurity service (MCSS) tool with alert triage,
incident response playbook execution, vulnerability scan parsing,
and security report generation. Includes SecurityOpsConfig, playbook
engine with approval gating, vulnerability scoring, and full test
coverage. Also fixes pre-existing missing approval_manager field in
ChannelRuntimeContext test constructors.

Original work by @rareba. Supersedes #3599.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:12:16 -04:00
Argenis d8f228bd15 Merge pull request #3661 from zeroclaw-labs/work/multi-client-workspaces
feat(workspace): add multi-client workspace isolation
2026-03-15 23:10:01 -04:00
Argenis d5eeaed3d9 Merge pull request #3649 from zeroclaw-labs/work/capability-tool-access
feat(security): add capability-based tool access control
2026-03-15 23:09:51 -04:00
Argenis 429094b049 Merge pull request #3660 from zeroclaw-labs/fix/cron-large-future
fix: add Box::pin for large future and missing approval_manager in tests
2026-03-15 23:08:47 -04:00
argenis de la rosa 80213b08ef feat(workspace): add multi-client workspace isolation
Add workspace profile management, security boundary enforcement, and
a workspace management tool for isolated client engagements.

Original work by @rareba. Supersedes #3597 — rebased on latest master.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 22:41:18 -04:00
argenis de la rosa bf67124499 fix: add Box::pin for large future and missing approval_manager in tests
- Box::pin the cron_run execute_job_now call to satisfy clippy::large_futures
- Add missing approval_manager field to 4 query_classification test constructors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 22:32:06 -04:00
argenis de la rosa cb250dfecf fix: add missing approval_manager field to ChannelRuntimeContext test constructors
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 20:03:59 -04:00
argenis de la rosa fabd35c4ea feat(security): add capability-based tool access control
Add an optional `allowed_tools` parameter that restricts which tools are
available to the agent. When `Some(list)`, only tools whose name appears
in the list are retained; when `None`, all tools remain available
(backward compatible). This enables fine-grained capability control for
cron jobs, heartbeat tasks, and CLI invocations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 19:34:34 -04:00
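[Editor's note] The `Some(list)` / `None` semantics described above can be sketched with a `retain` filter. Tools are represented as plain names here for illustration; the function name is hypothetical, not the zeroclaw API.

```rust
/// Sketch of the capability allowlist: `None` keeps every tool
/// (backward compatible); `Some(list)` retains only tools whose name
/// appears in the list. Unknown allowlist entries match nothing.
fn apply_allowed_tools(mut tools: Vec<String>, allowed: Option<&[String]>) -> Vec<String> {
    if let Some(allow_list) = allowed {
        tools.retain(|t| allow_list.iter().any(|name| name == t));
    }
    tools
}
```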
Argenis a695ca4b9c fix(onboard): auto-detect TTY instead of --interactive flag (#3573)
Remove the --interactive flag from `zeroclaw onboard`. The command now
auto-detects whether stdin/stdout are a TTY: if yes and no provider
flags are given, it launches the full interactive wizard; otherwise it
runs the quick (scriptable) setup path.

This means all three install methods work with a single flow:
  curl -fsSL https://zeroclawlabs.ai/install.sh | bash
  cargo install zeroclawlabs && zeroclaw onboard
  docker run … zeroclaw onboard --api-key …
2026-03-15 19:25:55 -04:00
Argenis 811fab3b87 fix(service): headless browser works in service mode (systemd/OpenRC) (#3645)
When zeroclaw runs as a service, the process inherits a minimal
environment without HOME, DISPLAY, or user namespaces. Headless
browsers (Chromium/Firefox) need HOME for profile/cache dirs and
fail with sandbox errors without user namespaces.

- Detect service environment via INVOCATION_ID, JOURNAL_STREAM,
  or missing HOME on Linux
- Auto-apply --no-sandbox and --disable-dev-shm-usage for Chrome
  in service mode
- Set HOME fallback and CHROMIUM_FLAGS on agent-browser commands
- systemd unit: add Environment=HOME=%h and PassEnvironment
- OpenRC script: export HOME=/var/lib/zeroclaw with start_pre()
  to create the directory

Closes #3584
2026-03-15 19:16:36 -04:00
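[Editor's note] The detection heuristic in the commit above can be expressed as a pure function. It is parameterized for testability here (the real code would read `std::env` directly); the function name is illustrative.

```rust
/// Sketch of service-environment detection: systemd sets INVOCATION_ID
/// and JOURNAL_STREAM for units it manages, and a missing HOME also
/// indicates a minimal service environment on Linux.
fn is_service_environment_from(
    invocation_id: Option<&str>,
    journal_stream: Option<&str>,
    home: Option<&str>,
) -> bool {
    invocation_id.is_some() || journal_stream.is_some() || home.is_none()
}
```

In the real binary this would be called with values such as `std::env::var("INVOCATION_ID").ok().as_deref()`.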
Argenis 1a5d91fe69 fix(channels): wire query_classification config into channel message processing (#3619)
The QueryClassificationConfig was parsed from config but never applied
during channel message processing. This adds the query_classification
field to ChannelRuntimeContext and invokes the classifier in
process_channel_message to override the route when a classification
rule matches a model_routes hint.

Closes #3579
2026-03-15 19:16:32 -04:00
Argenis 6eec1c81b9 fix(ci): use ubuntu-22.04 for Linux release builds (#3573)
Build against glibc 2.35 to ensure binary compatibility with Ubuntu 22.04+.
2026-03-15 18:57:30 -04:00
Argenis 602db8bca1 fix: exclude name field from Mistral tool_calls (#3572)
* fix: exclude name field from Mistral tool_calls (#3572)

Add skip_serializing_if to the compatibility fields (name, arguments,
parameters) on the ToolCall struct so they are omitted from the JSON
payload when None. Mistral's API returns 422 "Extra inputs are not
permitted" when these extra null fields are present in tool_calls.

* fix: format serde attribute for CI lint compliance
2026-03-15 18:38:41 -04:00
SimianAstronaut7 314e1d3ae8 Merge pull request #3638 from zeroclaw-labs/work-issues/3487-channel-approval-manager
fix(security): enforce approval policy for channel-driven runs
2026-03-15 16:11:14 -04:00
SimianAstronaut7 82be05b1e9 Merge pull request #3636 from zeroclaw-labs/work-issues/3628-surface-tool-failures-in-chat
feat(agent): surface tool call failure reasons in chat
2026-03-15 16:07:38 -04:00
SimianAstronaut7 1373659058 Merge pull request #3634 from zeroclaw-labs/work-issues/3477-fix-matrix-channel-key
fix(channel): use plain "matrix" channel key for consistent outbound routing
2026-03-15 16:07:36 -04:00
Argenis c7f064e866 fix(channels): surface visible warning when whatsapp-web feature is missing (#3629)
The WhatsApp Web QR code was not shown during onboarding channel launch
because the wizard allowed configuring WhatsApp Web mode even when the
binary was built without the `whatsapp-web` feature flag. At runtime,
the channel was silently skipped with only a tracing::warn that most
users never see.

- Add compile-time warning in the onboarding wizard when WhatsApp Web
  mode is selected but the feature is not compiled in
- Add eprintln! in collect_configured_channels so users see a visible
  terminal warning when the feature is missing at startup

Closes #3577
2026-03-15 16:07:13 -04:00
Giulio V 9c1d63e109 feat(hands): add autonomous knowledge-accumulating agent packages (#3603)
Introduce the Hands system — autonomous agent packages that run on
schedules and accumulate knowledge over time. Each Hand maintains a
rolling context of findings across runs so the agent grows smarter
with every execution.

This PR adds:
- Hand definition type (TOML-deserializable, reuses cron Schedule)
- HandRun / HandRunStatus for execution records
- HandContext for rolling cross-run knowledge accumulation
- File-based persistence (load/save context as JSON)
- Directory-based Hand loading from ~/.zeroclaw/hands/*.toml
- 20 unit tests covering deserialization, persistence roundtrip,
  history capping, fact deduplication, and error handling

Execution integration with the agent loop is deferred to a follow-up.
2026-03-15 16:06:14 -04:00
Argenis 966edf1553 Merge pull request #3635 from zeroclaw-labs/chore/bump-v0.3.4
chore: bump version to v0.3.4
2026-03-15 15:59:51 -04:00
simianastronaut a1af84d992 fix(security): enforce approval policy for channel-driven runs
Channel-driven runs (Telegram, Matrix, Discord, etc.) previously bypassed
the ApprovalManager entirely — `None` was passed into the tool-call loop,
so `auto_approve`, `always_ask`, and supervised approval checks were
silently skipped for all non-CLI execution paths.

Add a non-interactive mode to ApprovalManager that enforces the same
autonomy config policies but auto-denies tools requiring interactive
approval (since no operator is present on channel runs). Specifically:

- Add `ApprovalManager::for_non_interactive()` constructor that creates
  a manager which auto-denies tools needing approval instead of prompting
- Add `is_non_interactive()` method so the tool-call loop can distinguish
  interactive (CLI prompt) from non-interactive (auto-deny) managers
- Update tool-call loop: non-interactive managers auto-deny instead of
  the previous auto-approve behavior for non-CLI channels
- Wire the non-interactive approval manager into ChannelRuntimeContext
  so channel runs enforce the full approval policy
- Add 8 tests covering non-interactive approval behavior

Security implications:
- `always_ask` tools are now denied on channels (previously bypassed)
- Supervised-mode unknown tools are now denied on channels (previously
  bypassed)
- `auto_approve` tools continue to work on channels unchanged
- `full` autonomy mode is unaffected (no approval needed regardless)
- `read_only` mode is unaffected (blocks execution elsewhere)

Closes #3487

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:56:57 -04:00
simianastronaut 0ad1965081 feat(agent): surface tool call failure reasons in chat progress messages
When a tool call fails (security policy block, hook cancellation, user
denial, or execution error), the failure reason is now included in the
progress message sent to the chat channel via on_delta. Previously only
a ❌ icon was shown; now users see the actual reason (e.g. "Command not
allowed by security policy") without needing to check `zeroclaw doctor
traces`.

Closes #3628

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:49:27 -04:00
argenis de la rosa 70e8e7ebcd chore: bump version to v0.3.4
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 15:44:59 -04:00
Alix-007 2bcb82c5b3 fix(python): point docs URL at master branch (#3334)
Co-authored-by: Alix-007 <Alix-007@users.noreply.github.com>
2026-03-15 15:43:35 -04:00
simianastronaut e211b5c3e3 fix(channel): use plain "matrix" channel key for consistent outbound routing
The Matrix channel listener was building channel keys as `matrix:<room_id>`,
but the runtime channel mapping expects the plain key `matrix`. This mismatch
caused replies to silently drop in deployments using the Matrix channel.

Closes #3477

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:42:43 -04:00
Argenis 8691476577 Merge pull request #3624 from zeroclaw-labs/feat/multi-swarm-and-bugfixes
feat(swarm): multi-agent swarm orchestration + bug fixes (#3572, #3573)
2026-03-15 15:42:03 -04:00
SimianAstronaut7 e34a804255 Merge pull request #3632 from zeroclaw-labs/work-issues/3544-fix-codex-sse-buffering
fix(provider): use incremental SSE stream reading for openai-codex responses
2026-03-15 15:34:39 -04:00
SimianAstronaut7 6120b3f705 Merge pull request #3630 from zeroclaw-labs/work-issues/3567-allow-commands-bypass-high-risk
fix(security): let explicit allowed_commands bypass high-risk block
2026-03-15 15:34:37 -04:00
SimianAstronaut7 f175261e32 Merge pull request #3631 from zeroclaw-labs/work-issues/3486-fix-matrix-image-marker
fix(channels): use canonical IMAGE marker in Matrix channel
2026-03-15 15:34:31 -04:00
simianastronaut fd9f66cad7 fix(provider): use incremental SSE stream reading for openai-codex responses
Replace full-body buffering (`response.text().await`) in
`decode_responses_body()` with incremental `bytes_stream()` chunk
processing.  The previous approach held the HTTP connection open until
every byte had arrived; on high-latency links the long-lived connection
would frequently drop mid-read, producing the "error decoding response
body" failure on the first attempt (succeeding only after retry).

Reading chunks incrementally lets each network segment complete within
its own timeout window, eliminating the systematic first-attempt failure.

Closes #3544

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:22:55 -04:00
simianastronaut d928ebc92e fix(channels): use canonical IMAGE marker in Matrix channel
Matrix image messages used lowercase `[image: ...]` format instead of
the canonical `[IMAGE:...]` marker used by all other channels (Telegram,
Slack, Discord, QQ, LinQ). This caused Matrix image attachments to
bypass the multimodal vision pipeline which looks for `[IMAGE:...]`.

Closes #3486

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:14:53 -04:00
simianastronaut 9fca9f478a fix(security): let explicit allowed_commands bypass high-risk block
When `block_high_risk_commands = true`, commands like `curl` and `wget`
were unconditionally blocked even if explicitly listed in
`allowed_commands`. This made it impossible to use legitimate API calls
in full autonomy mode.

Now, if a command is explicitly named in `allowed_commands` (not via
the wildcard `*`), it is exempt from the `block_high_risk_commands`
gate. The wildcard entry intentionally does NOT grant this exemption,
preserving the safety net for broad allowlists.

Other security gates (supervised-mode approval, rate limiting, path
policy, argument validation) remain fully enforced.

Closes #3567

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 15:13:32 -04:00
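[Editor's note] The exemption rule above — an explicit `allowed_commands` entry bypasses the high-risk gate, the wildcard does not — can be sketched as follows. Names are illustrative, not the actual zeroclaw API.

```rust
/// Returns true when the high-risk gate blocks `cmd`. A command named
/// explicitly in `allowed_commands` is exempt; the wildcard "*" entry
/// deliberately does not grant the exemption.
fn high_risk_blocked(
    cmd: &str,
    allowed_commands: &[&str],
    block_high_risk: bool,
    high_risk: &[&str],
) -> bool {
    let explicitly_allowed = allowed_commands.iter().any(|a| *a == cmd && *a != "*");
    block_high_risk && high_risk.contains(&cmd) && !explicitly_allowed
}
```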
SimianAstronaut7 7106632b51 Merge pull request #3627 from zeroclaw-labs/work-issues/3533-fix-utf8-slice-panic
fix(agent): use char-boundary-safe slicing to prevent CJK text panic
2026-03-15 14:36:49 -04:00
SimianAstronaut7 b834278754 Merge pull request #3626 from zeroclaw-labs/work-issues/3563-fix-cron-add-nl-security
fix(cron): add --agent flag so natural language prompts bypass shell security
2026-03-15 14:36:46 -04:00
SimianAstronaut7 186f6d9797 Merge pull request #3625 from zeroclaw-labs/work-issues/3568-http-request-private-hosts
feat(tool): add allow_private_hosts option to http_request tool
2026-03-15 14:36:44 -04:00
simianastronaut 6cdc92a256 fix(agent): use char-boundary-safe slicing to prevent CJK text panic
Replace unsafe byte-index string slicing (`&text[..N]`) with
char-boundary-safe alternatives in memory consolidation and security
redaction to prevent panics when multi-byte UTF-8 characters (e.g.
Chinese/Japanese/Korean) span the slice boundary.

Fixes the same class of bug as the prior fix in `execute_one_tool`
(commit 8fcbb6eb), applied to two remaining instances:
- `src/memory/consolidation.rs`: truncation at byte 4000 and 200
- `src/security/mod.rs`: `redact()` prefix at byte 4

Adds regression tests with CJK input for both locations.

Closes #3533

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 14:27:09 -04:00
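[Editor's note] The class of fix above can be illustrated with a minimal boundary-safe truncation helper (illustrative, not the project's actual code): `&text[..N]` panics when byte index N falls inside a multi-byte UTF-8 character.

```rust
/// Truncate to at most `max_bytes` bytes without splitting a multi-byte
/// UTF-8 character: walk the cut point back until it lands on a char
/// boundary, where a plain byte slice would panic on e.g. CJK input.
fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    let mut end = max_bytes;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
```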
simianastronaut 02599dcd3c fix(cron): add --agent flag to CLI cron commands to bypass shell security validation
The CLI `cron add` command always routed the second positional argument
through shell security policy validation, which blocked natural language
prompts like "Check server health: disk space, memory, CPU load". This
adds an `--agent` flag to `cron add`, `cron add-at`, `cron add-every`,
and `cron once` so that natural language prompts are correctly stored as
agent jobs without shell command validation.

Closes #3563

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 14:26:45 -04:00
simianastronaut fe64d7ef7e feat(tool): add allow_private_hosts option to http_request tool (#3568)
The http_request tool unconditionally blocked all private/LAN hosts with
no opt-out, preventing legitimate use cases like calling a local Home
Assistant instance or internal APIs. This adds an `allow_private_hosts`
config flag (default: false) under `[http_request]` that, when set to
true, skips the private-host SSRF check while still enforcing the domain
allowlist.

Closes #3568

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 14:23:54 -04:00
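[Editor's note] A minimal sketch of the SSRF guard described above, using the standard library's address classification (illustrative function name; the real implementation also handles hostname resolution and the domain allowlist):

```rust
use std::net::IpAddr;

/// Returns true when the target IP should be refused: private, loopback,
/// and link-local ranges are blocked unless `allow_private_hosts` opts out.
fn private_host_blocked(ip: IpAddr, allow_private_hosts: bool) -> bool {
    if allow_private_hosts {
        return false;
    }
    match ip {
        IpAddr::V4(v4) => v4.is_private() || v4.is_loopback() || v4.is_link_local(),
        IpAddr::V6(v6) => v6.is_loopback(),
    }
}
```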
argenis de la rosa 996dbe95cf feat(swarm): multi-agent swarm orchestration, Mistral tool fix, restore --interactive
- Add SwarmTool with sequential (pipeline), parallel (fan-out/fan-in),
  and router (LLM-selected) strategies for multi-agent workflows
- Add SwarmConfig and SwarmStrategy to config schema
- Fix Mistral 422 error by adding skip_serializing_if to ToolCall
  compat fields (name, arguments, parameters, kind) — Fixes #3572
- Restore `zeroclaw onboard --interactive` flag with run_wizard
  routing and mutual-exclusion validation — Fixes #3573
- 20 new swarm tests, 2 serialization tests, 1 CLI test, config tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 14:23:20 -04:00
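[Editor's note] The sequential (pipeline) strategy listed above can be sketched as a fold, where each agent's output becomes the next agent's input. `&dyn Fn(&str) -> String` stands in for a full agent invocation; this is a sketch, not the SwarmTool implementation.

```rust
/// Run agents in sequence, threading each output into the next input.
fn run_sequential_pipeline(input: &str, agents: &[&dyn Fn(&str) -> String]) -> String {
    agents
        .iter()
        .fold(input.to_string(), |acc, agent| agent(&acc))
}
```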
Argenis 45f953be6d Merge pull request #3578 from zeroclaw-labs/chore/bump-v0.3.3
chore: bump version to v0.3.3
2026-03-15 09:53:13 -04:00
argenis de la rosa 82f29bbcb1 chore: bump version to v0.3.3
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 09:41:22 -04:00
54 changed files with 12663 additions and 163 deletions
+4 -2
@@ -155,11 +155,13 @@ jobs:
fail-fast: false
matrix:
include:
- os: ubuntu-latest
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
# ensuring compatibility with Ubuntu 22.04+ (#3573).
- os: ubuntu-22.04
target: x86_64-unknown-linux-gnu
artifact: zeroclaw
ext: tar.gz
- os: ubuntu-latest
- os: ubuntu-22.04
target: aarch64-unknown-linux-gnu
artifact: zeroclaw
ext: tar.gz
+4 -2
@@ -156,11 +156,13 @@ jobs:
fail-fast: false
matrix:
include:
- os: ubuntu-latest
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
# ensuring compatibility with Ubuntu 22.04+ (#3573).
- os: ubuntu-22.04
target: x86_64-unknown-linux-gnu
artifact: zeroclaw
ext: tar.gz
- os: ubuntu-latest
- os: ubuntu-22.04
target: aarch64-unknown-linux-gnu
artifact: zeroclaw
ext: tar.gz
Generated
+1 -1
@@ -7945,7 +7945,7 @@ dependencies = [
[[package]]
name = "zeroclawlabs"
-version = "0.3.2"
+version = "0.3.4"
dependencies = [
"anyhow",
"async-imap",
+1 -1
@@ -4,7 +4,7 @@ resolver = "2"
[package]
name = "zeroclawlabs"
-version = "0.3.2"
+version = "0.3.4"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"
+78 -1
@@ -37,6 +37,7 @@ pub struct Agent {
classification_config: crate::config::QueryClassificationConfig,
available_hints: Vec<String>,
route_model_by_hint: HashMap<String, String>,
allowed_tools: Option<Vec<String>>,
}
pub struct AgentBuilder {
@@ -58,6 +59,7 @@ pub struct AgentBuilder {
classification_config: Option<crate::config::QueryClassificationConfig>,
available_hints: Option<Vec<String>>,
route_model_by_hint: Option<HashMap<String, String>>,
allowed_tools: Option<Vec<String>>,
}
impl AgentBuilder {
@@ -81,6 +83,7 @@ impl AgentBuilder {
classification_config: None,
available_hints: None,
route_model_by_hint: None,
allowed_tools: None,
}
}
@@ -180,10 +183,19 @@ impl AgentBuilder {
self
}
pub fn allowed_tools(mut self, allowed_tools: Option<Vec<String>>) -> Self {
self.allowed_tools = allowed_tools;
self
}
pub fn build(self) -> Result<Agent> {
-let tools = self
+let mut tools = self
.tools
.ok_or_else(|| anyhow::anyhow!("tools are required"))?;
let allowed = self.allowed_tools.clone();
if let Some(ref allow_list) = allowed {
tools.retain(|t| allow_list.iter().any(|name| name == t.name()));
}
let tool_specs = tools.iter().map(|tool| tool.spec()).collect();
Ok(Agent {
@@ -223,6 +235,7 @@ impl AgentBuilder {
classification_config: self.classification_config.unwrap_or_default(),
available_hints: self.available_hints.unwrap_or_default(),
route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
allowed_tools: allowed,
})
}
}
@@ -892,4 +905,68 @@ mod tests {
let seen = seen_models.lock();
assert_eq!(seen.as_slice(), &["hint:fast".to_string()]);
}
#[test]
fn builder_allowed_tools_none_keeps_all_tools() {
let provider = Box::new(MockProvider {
responses: Mutex::new(vec![]),
});
let memory_cfg = crate::config::MemoryConfig {
backend: "none".into(),
..crate::config::MemoryConfig::default()
};
let mem: Arc<dyn Memory> = Arc::from(
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
.expect("memory creation should succeed with valid config"),
);
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
let agent = Agent::builder()
.provider(provider)
.tools(vec![Box::new(MockTool)])
.memory(mem)
.observer(observer)
.tool_dispatcher(Box::new(NativeToolDispatcher))
.workspace_dir(std::path::PathBuf::from("/tmp"))
.allowed_tools(None)
.build()
.expect("agent builder should succeed with valid config");
assert_eq!(agent.tool_specs.len(), 1);
assert_eq!(agent.tool_specs[0].name, "echo");
}
#[test]
fn builder_allowed_tools_some_filters_tools() {
let provider = Box::new(MockProvider {
responses: Mutex::new(vec![]),
});
let memory_cfg = crate::config::MemoryConfig {
backend: "none".into(),
..crate::config::MemoryConfig::default()
};
let mem: Arc<dyn Memory> = Arc::from(
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
.expect("memory creation should succeed with valid config"),
);
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
let agent = Agent::builder()
.provider(provider)
.tools(vec![Box::new(MockTool)])
.memory(mem)
.observer(observer)
.tool_dispatcher(Box::new(NativeToolDispatcher))
.workspace_dir(std::path::PathBuf::from("/tmp"))
.allowed_tools(Some(vec!["nonexistent".to_string()]))
.build()
.expect("agent builder should succeed with valid config");
assert!(
agent.tool_specs.is_empty(),
"No tools should match a non-existent allowlist entry"
);
}
}
+233 -10
@@ -93,6 +93,24 @@ pub(crate) fn filter_tool_specs_for_turn(
.collect()
}
/// Filters a tool spec list by an optional capability allowlist.
///
/// When `allowed` is `None`, all specs pass through unchanged.
/// When `allowed` is `Some(list)`, only specs whose name appears in the list
/// are retained. Unknown names in the allowlist are silently ignored.
pub(crate) fn filter_by_allowed_tools(
specs: Vec<crate::tools::ToolSpec>,
allowed: Option<&[String]>,
) -> Vec<crate::tools::ToolSpec> {
match allowed {
None => specs,
Some(list) => specs
.into_iter()
.filter(|spec| list.iter().any(|name| name == &spec.name))
.collect(),
}
}
/// Computes the list of MCP tool names that should be excluded for a given turn
/// based on `tool_filter_groups` and the user message.
///
@@ -2660,6 +2678,15 @@ pub(crate) async fn run_tool_call_loop(
"arguments": scrub_credentials(&tool_args.to_string()),
}),
);
if let Some(ref tx) = on_delta {
let _ = tx
.send(format!(
"\u{274c} {}: {}\n",
call.name,
truncate_with_ellipsis(&scrub_credentials(&cancelled), 200)
))
.await;
}
ordered_results[idx] = Some((
call.name.clone(),
call.tool_call_id.clone(),
@@ -2687,11 +2714,13 @@ pub(crate) async fn run_tool_call_loop(
arguments: tool_args.clone(),
};
-// Only prompt interactively on CLI; auto-approve on other channels.
-let decision = if channel_name == "cli" {
-mgr.prompt_cli(&request)
+// Interactive CLI: prompt the operator.
+// Non-interactive (channels): auto-deny since no operator
+// is present to approve.
+let decision = if mgr.is_non_interactive() {
+ApprovalResponse::No
} else {
-ApprovalResponse::Yes
+mgr.prompt_cli(&request)
};
mgr.record_decision(&tool_name, &tool_args, decision, channel_name);
@@ -2712,6 +2741,11 @@ pub(crate) async fn run_tool_call_loop(
"arguments": scrub_credentials(&tool_args.to_string()),
}),
);
if let Some(ref tx) = on_delta {
let _ = tx
.send(format!("\u{274c} {}: {}\n", tool_name, denied))
.await;
}
ordered_results[idx] = Some((
tool_name.clone(),
call.tool_call_id.clone(),
@@ -2748,6 +2782,11 @@ pub(crate) async fn run_tool_call_loop(
"deduplicated": true,
}),
);
if let Some(ref tx) = on_delta {
let _ = tx
.send(format!("\u{274c} {}: {}\n", tool_name, duplicate))
.await;
}
ordered_results[idx] = Some((
tool_name.clone(),
call.tool_call_id.clone(),
@@ -2850,13 +2889,19 @@ pub(crate) async fn run_tool_call_loop(
// ── Progress: tool completion ───────────────────────
if let Some(ref tx) = on_delta {
let secs = outcome.duration.as_secs();
-let icon = if outcome.success {
-"\u{2705}"
+let progress_msg = if outcome.success {
+format!("\u{2705} {} ({secs}s)\n", call.name)
+} else if let Some(ref reason) = outcome.error_reason {
+format!(
+"\u{274c} {} ({secs}s): {}\n",
+call.name,
+truncate_with_ellipsis(reason, 200)
+)
} else {
-"\u{274c}"
+format!("\u{274c} {} ({secs}s)\n", call.name)
};
tracing::debug!(tool = %call.name, secs, "Sending progress complete to draft");
-let _ = tx.send(format!("{icon} {} ({secs}s)\n", call.name)).await;
+let _ = tx.send(progress_msg).await;
}
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
@@ -2967,6 +3012,7 @@ pub async fn run(
peripheral_overrides: Vec<String>,
interactive: bool,
session_state_file: Option<PathBuf>,
allowed_tools: Option<Vec<String>>,
) -> Result<String> {
// ── Wire up agnostic subsystems ──────────────────────────────
let base_observer = observability::create_observer(&config.observability);
@@ -3028,6 +3074,19 @@ pub async fn run(
tools_registry.extend(peripheral_tools);
}
// ── Capability-based tool access control ─────────────────────
// When `allowed_tools` is `Some(list)`, restrict the tool registry to only
// those tools whose name appears in the list. Unknown names are silently
// ignored. When `None`, all tools remain available (backward compatible).
if let Some(ref allow_list) = allowed_tools {
tools_registry.retain(|t| allow_list.iter().any(|name| name == t.name()));
tracing::info!(
allowed = allow_list.len(),
retained = tools_registry.len(),
"Applied capability-based tool access filter"
);
}
// ── Wire MCP tools (non-fatal) — CLI path ────────────────────
// NOTE: MCP tools are injected after built-in tool filtering
// (filter_primary_agent_tools_or_fail / agent.allowed_tools / agent.denied_tools).
@@ -3847,7 +3906,7 @@ mod tests {
use std::time::Duration;
#[test]
- fn test_scrub_credentials() {
+ fn scrub_credentials_redacts_bearer_token() {
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
@@ -3858,7 +3917,7 @@ mod tests {
}
#[test]
- fn test_scrub_credentials_json() {
+ fn scrub_credentials_redacts_json_api_key() {
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
@@ -4135,6 +4194,52 @@ mod tests {
}
}
/// A tool that always returns a failure with a given error reason.
struct FailingTool {
tool_name: String,
error_reason: String,
}
impl FailingTool {
fn new(name: &str, error_reason: &str) -> Self {
Self {
tool_name: name.to_string(),
error_reason: error_reason.to_string(),
}
}
}
#[async_trait]
impl Tool for FailingTool {
fn name(&self) -> &str {
&self.tool_name
}
fn description(&self) -> &str {
"A tool that always fails for testing failure surfacing"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"command": { "type": "string" }
}
})
}
async fn execute(
&self,
_args: serde_json::Value,
) -> anyhow::Result<crate::tools::ToolResult> {
Ok(crate::tools::ToolResult {
success: false,
output: String::new(),
error: Some(self.error_reason.clone()),
})
}
}
#[tokio::test]
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
let calls = Arc::new(AtomicUsize::new(0));
@@ -6501,4 +6606,122 @@ Let me check the result."#;
let tokens = super::estimate_history_tokens(&history);
assert_eq!(tokens, 23);
}
#[tokio::test]
async fn run_tool_call_loop_surfaces_tool_failure_reason_in_on_delta() {
let provider = ScriptedProvider::from_text_responses(vec![
r#"<tool_call>
{"name":"failing_shell","arguments":{"command":"rm -rf /"}}
</tool_call>"#,
"I could not execute that command.",
]);
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(FailingTool::new(
"failing_shell",
"Command not allowed by security policy: rm -rf /",
))];
let mut history = vec![
ChatMessage::system("test-system"),
ChatMessage::user("delete everything"),
];
let observer = NoopObserver;
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
let result = run_tool_call_loop(
&provider,
&mut history,
&tools_registry,
&observer,
"mock-provider",
"mock-model",
0.0,
true,
None,
"telegram",
&crate::config::MultimodalConfig::default(),
4,
None,
Some(tx),
None,
&[],
&[],
)
.await
.expect("tool loop should complete");
// Collect all messages sent to the on_delta channel.
let mut deltas = Vec::new();
while let Ok(msg) = rx.try_recv() {
deltas.push(msg);
}
let all_deltas = deltas.join("");
// The failure reason should appear in the progress messages.
assert!(
all_deltas.contains("Command not allowed by security policy"),
"on_delta messages should include the tool failure reason, got: {all_deltas}"
);
// Should also contain the cross mark (❌) icon to indicate failure.
assert!(
all_deltas.contains('\u{274c}'),
"on_delta messages should include ❌ for failed tool calls, got: {all_deltas}"
);
assert_eq!(result, "I could not execute that command.");
}
// ── filter_by_allowed_tools tests ─────────────────────────────────────
#[test]
fn filter_by_allowed_tools_none_passes_all() {
let specs = vec![
make_spec("shell"),
make_spec("memory_store"),
make_spec("file_read"),
];
let result = filter_by_allowed_tools(specs, None);
assert_eq!(result.len(), 3);
}
#[test]
fn filter_by_allowed_tools_some_restricts_to_listed() {
let specs = vec![
make_spec("shell"),
make_spec("memory_store"),
make_spec("file_read"),
];
let allowed = vec!["shell".to_string(), "memory_store".to_string()];
let result = filter_by_allowed_tools(specs, Some(&allowed));
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
assert_eq!(names.len(), 2);
assert!(names.contains(&"shell"));
assert!(names.contains(&"memory_store"));
assert!(!names.contains(&"file_read"));
}
#[test]
fn filter_by_allowed_tools_unknown_names_silently_ignored() {
let specs = vec![make_spec("shell"), make_spec("file_read")];
let allowed = vec![
"shell".to_string(),
"nonexistent_tool".to_string(),
"another_missing".to_string(),
];
let result = filter_by_allowed_tools(specs, Some(&allowed));
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
assert_eq!(names.len(), 1);
assert!(names.contains(&"shell"));
}
#[test]
fn filter_by_allowed_tools_empty_list_excludes_all() {
let specs = vec![make_spec("shell"), make_spec("file_read")];
let allowed: Vec<String> = vec![];
let result = filter_by_allowed_tools(specs, Some(&allowed));
assert!(result.is_empty());
}
}
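The `filter_by_allowed_tools` implementation itself is outside this hunk; only its tests appear above. A minimal standalone sketch consistent with those tests (the `ToolSpec` struct and the function body here are assumptions, not the crate's actual code) could look like:

```rust
// Hypothetical sketch of the allow-list filter exercised by the tests
// above: `None` passes every spec through, `Some(list)` keeps only specs
// whose name is listed (unknown names are silently ignored), and an
// empty list excludes everything.
#[derive(Debug, Clone)]
pub struct ToolSpec {
    pub name: String,
}

pub fn filter_by_allowed_tools(specs: Vec<ToolSpec>, allowed: Option<&Vec<String>>) -> Vec<ToolSpec> {
    match allowed {
        None => specs,
        Some(list) => specs
            .into_iter()
            .filter(|s| list.iter().any(|name| name == &s.name))
            .collect(),
    }
}

fn main() {
    let specs = vec![
        ToolSpec { name: "shell".into() },
        ToolSpec { name: "file_read".into() },
    ];
    // Unknown names in the allow list are ignored rather than erroring.
    let allowed = vec!["shell".to_string(), "nonexistent_tool".to_string()];
    let kept = filter_by_allowed_tools(specs, Some(&allowed));
    assert_eq!(kept.len(), 1);
    assert_eq!(kept[0].name, "shell");
}
```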
+128 -4
@@ -44,11 +44,18 @@ pub struct ApprovalLogEntry {
// ── ApprovalManager ──────────────────────────────────────────────
- /// Manages the interactive approval workflow.
+ /// Manages the approval workflow for tool calls.
///
/// - Checks config-level `auto_approve` / `always_ask` lists
/// - Maintains a session-scoped "always" allowlist
/// - Records an audit trail of all decisions
///
/// Two modes:
/// - **Interactive** (CLI): tools needing approval trigger a stdin prompt.
/// - **Non-interactive** (channels): tools needing approval are auto-denied
/// because there is no interactive operator to approve them. `auto_approve`
/// policy is still enforced, and `always_ask` / supervised-default tools are
/// denied rather than silently allowed.
pub struct ApprovalManager {
/// Tools that never need approval (from config).
auto_approve: HashSet<String>,
@@ -56,6 +63,9 @@ pub struct ApprovalManager {
always_ask: HashSet<String>,
/// Autonomy level from config.
autonomy_level: AutonomyLevel,
/// When `true`, tools that would require interactive approval are
/// auto-denied instead. Used for channel-driven (non-CLI) runs.
non_interactive: bool,
/// Session-scoped allowlist built from "Always" responses.
session_allowlist: Mutex<HashSet<String>>,
/// Audit trail of approval decisions.
@@ -63,17 +73,40 @@ pub struct ApprovalManager {
}
impl ApprovalManager {
- /// Create from autonomy config.
+ /// Create an interactive (CLI) approval manager from autonomy config.
pub fn from_config(config: &AutonomyConfig) -> Self {
Self {
auto_approve: config.auto_approve.iter().cloned().collect(),
always_ask: config.always_ask.iter().cloned().collect(),
autonomy_level: config.level,
non_interactive: false,
session_allowlist: Mutex::new(HashSet::new()),
audit_log: Mutex::new(Vec::new()),
}
}
/// Create a non-interactive approval manager for channel-driven runs.
///
/// Enforces the same `auto_approve` / `always_ask` / supervised policies
/// as the CLI manager, but tools that would require interactive approval
/// are auto-denied instead of prompting (since there is no operator).
pub fn for_non_interactive(config: &AutonomyConfig) -> Self {
Self {
auto_approve: config.auto_approve.iter().cloned().collect(),
always_ask: config.always_ask.iter().cloned().collect(),
autonomy_level: config.level,
non_interactive: true,
session_allowlist: Mutex::new(HashSet::new()),
audit_log: Mutex::new(Vec::new()),
}
}
/// Returns `true` when this manager operates in non-interactive mode
/// (i.e. for channel-driven runs where no operator can approve).
pub fn is_non_interactive(&self) -> bool {
self.non_interactive
}
/// Check whether a tool call requires interactive approval.
///
/// Returns `true` if the call needs a prompt, `false` if it can proceed.
@@ -147,8 +180,8 @@ impl ApprovalManager {
/// Prompt the user on the CLI and return their decision.
///
- /// For non-CLI channels, returns `Yes` automatically (interactive
- /// approval is only supported on CLI for now).
+ /// Only called for interactive (CLI) managers. Non-interactive managers
+ /// auto-deny in the tool-call loop before reaching this point.
pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse {
prompt_cli_interactive(request)
}
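How the tool-call loop consumes these two modes can be reduced to a small decision table. A hedged sketch (the `decide` helper and `Decision` enum are illustrative stand-ins, not the crate's actual API):

```rust
// Illustrative decision flow, assuming a manager that exposes
// needs_approval() and is_non_interactive() as documented above.
#[derive(Debug, PartialEq)]
enum Decision {
    Proceed,
    PromptOperator, // interactive (CLI) runs only
    AutoDeny,       // non-interactive (channel) runs
}

fn decide(needs_approval: bool, non_interactive: bool) -> Decision {
    match (needs_approval, non_interactive) {
        (false, _) => Decision::Proceed,
        (true, false) => Decision::PromptOperator,
        (true, true) => Decision::AutoDeny,
    }
}

fn main() {
    // auto_approve tools proceed in either mode.
    assert_eq!(decide(false, true), Decision::Proceed);
    // always_ask tools prompt on the CLI...
    assert_eq!(decide(true, false), Decision::PromptOperator);
    // ...but are auto-denied on channels, where no operator is present.
    assert_eq!(decide(true, true), Decision::AutoDeny);
}
```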
@@ -401,6 +434,97 @@ mod tests {
assert!(summary.contains("just a string"));
}
// ── non-interactive (channel) mode ────────────────────────
#[test]
fn non_interactive_manager_reports_non_interactive() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
assert!(mgr.is_non_interactive());
}
#[test]
fn interactive_manager_reports_interactive() {
let mgr = ApprovalManager::from_config(&supervised_config());
assert!(!mgr.is_non_interactive());
}
#[test]
fn non_interactive_auto_approve_tools_skip_approval() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
// auto_approve tools (file_read, memory_recall) should not need approval.
assert!(!mgr.needs_approval("file_read"));
assert!(!mgr.needs_approval("memory_recall"));
}
#[test]
fn non_interactive_always_ask_tools_need_approval() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
// always_ask tools (shell) still report as needing approval,
// so the tool-call loop will auto-deny them in non-interactive mode.
assert!(mgr.needs_approval("shell"));
}
#[test]
fn non_interactive_unknown_tools_need_approval_in_supervised() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
// Unknown tools in supervised mode need approval (will be auto-denied
// by the tool-call loop for non-interactive managers).
assert!(mgr.needs_approval("file_write"));
assert!(mgr.needs_approval("http_request"));
}
#[test]
fn non_interactive_full_autonomy_never_needs_approval() {
let mgr = ApprovalManager::for_non_interactive(&full_config());
// Full autonomy means no approval needed, even in non-interactive mode.
assert!(!mgr.needs_approval("shell"));
assert!(!mgr.needs_approval("file_write"));
assert!(!mgr.needs_approval("anything"));
}
#[test]
fn non_interactive_readonly_never_needs_approval() {
let config = AutonomyConfig {
level: AutonomyLevel::ReadOnly,
..AutonomyConfig::default()
};
let mgr = ApprovalManager::for_non_interactive(&config);
// ReadOnly blocks execution elsewhere; approval manager does not prompt.
assert!(!mgr.needs_approval("shell"));
}
#[test]
fn non_interactive_session_allowlist_still_works() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
assert!(mgr.needs_approval("file_write"));
// Simulate an "Always" decision (would come from a prior channel run
// if the tool was auto-approved somehow, e.g. via config change).
mgr.record_decision(
"file_write",
&serde_json::json!({"path": "test.txt"}),
ApprovalResponse::Always,
"telegram",
);
assert!(!mgr.needs_approval("file_write"));
}
#[test]
fn non_interactive_always_ask_overrides_session_allowlist() {
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
mgr.record_decision(
"shell",
&serde_json::json!({"command": "ls"}),
ApprovalResponse::Always,
"telegram",
);
// shell is in always_ask, so it still needs approval even after "Always".
assert!(mgr.needs_approval("shell"));
}
// ── ApprovalResponse serde ───────────────────────────────
#[test]
+2 -2
@@ -746,7 +746,7 @@ impl Channel for MatrixChannel {
MessageType::Notice(content) => (content.body.clone(), None),
MessageType::Image(content) => {
let dl = media_info(&content.source, &content.body);
- (format!("[image: {}]", content.body), dl)
+ (format!("[IMAGE:{}]", content.body), dl)
}
MessageType::File(content) => {
let dl = media_info(&content.source, &content.body);
@@ -888,7 +888,7 @@ impl Channel for MatrixChannel {
sender: sender.clone(),
reply_target: format!("{}||{}", sender, room.room_id()),
content: body,
- channel: format!("matrix:{}", room.room_id()),
+ channel: "matrix".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
+583 -2
@@ -30,6 +30,7 @@ pub mod mattermost;
pub mod nextcloud_talk;
#[cfg(feature = "channel-nostr")]
pub mod nostr;
pub mod notion;
pub mod qq;
pub mod session_store;
pub mod signal;
@@ -62,6 +63,7 @@ pub use mattermost::MattermostChannel;
pub use nextcloud_talk::NextcloudTalkChannel;
#[cfg(feature = "channel-nostr")]
pub use nostr::NostrChannel;
pub use notion::NotionChannel;
pub use qq::QQChannel;
pub use signal::SignalChannel;
pub use slack::SlackChannel;
@@ -76,6 +78,7 @@ pub use whatsapp::WhatsAppChannel;
pub use whatsapp_web::WhatsAppWebChannel;
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
use crate::approval::ApprovalManager;
use crate::config::Config;
use crate::identity;
use crate::memory::{self, Memory};
@@ -311,9 +314,15 @@ struct ChannelRuntimeContext {
non_cli_excluded_tools: Arc<Vec<String>>,
tool_call_dedup_exempt: Arc<Vec<String>>,
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
query_classification: crate::config::QueryClassificationConfig,
ack_reactions: bool,
show_tool_calls: bool,
session_store: Option<Arc<session_store::SessionStore>>,
/// Non-interactive approval manager for channel-driven runs.
/// Enforces `auto_approve` / `always_ask` / supervised policy from
/// `[autonomy]` config; auto-denies tools that would need interactive
/// approval since no operator is present on channel runs.
approval_manager: Arc<ApprovalManager>,
}
#[derive(Clone)]
@@ -1786,7 +1795,31 @@ async fn process_channel_message(
}
let history_key = conversation_history_key(&msg);
- let route = get_route_selection(ctx.as_ref(), &history_key);
+ let mut route = get_route_selection(ctx.as_ref(), &history_key);
// ── Query classification: override route when a rule matches ──
if let Some(hint) = crate::agent::classifier::classify(&ctx.query_classification, &msg.content)
{
if let Some(matched_route) = ctx
.model_routes
.iter()
.find(|r| r.hint.eq_ignore_ascii_case(&hint))
{
tracing::info!(
target: "query_classification",
hint = hint.as_str(),
provider = matched_route.provider.as_str(),
model = matched_route.model.as_str(),
channel = %msg.channel,
"Channel message classified — overriding route"
);
route = ChannelRouteSelection {
provider: matched_route.provider.clone(),
model: matched_route.model.clone(),
};
}
}
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
Ok(provider) => provider,
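The classification override above boils down to a small pure function. A hedged sketch, with simplified struct names standing in for `ChannelRouteSelection` and `ModelRouteConfig`:

```rust
// Sketch: a classifier hint reroutes only when a configured model route
// has a matching hint (compared case-insensitively); otherwise the
// default route from get_route_selection is kept unchanged.
#[derive(Debug, Clone, PartialEq)]
struct Route {
    provider: String,
    model: String,
}

struct ModelRoute {
    hint: String,
    provider: String,
    model: String,
}

fn apply_classification(default: Route, hint: Option<&str>, routes: &[ModelRoute]) -> Route {
    if let Some(h) = hint {
        if let Some(m) = routes.iter().find(|r| r.hint.eq_ignore_ascii_case(h)) {
            return Route {
                provider: m.provider.clone(),
                model: m.model.clone(),
            };
        }
    }
    default
}

fn main() {
    let default = Route { provider: "test-provider".into(), model: "default-model".into() };
    let routes = [ModelRoute {
        hint: "vision".into(),
        provider: "vision-provider".into(),
        model: "gpt-4-vision".into(),
    }];
    // A matching hint (any case) switches provider and model.
    let routed = apply_classification(default.clone(), Some("VISION"), &routes);
    assert_eq!(routed.model, "gpt-4-vision");
    // No hint, or an unmatched hint, leaves the default route untouched.
    let kept = apply_classification(default.clone(), None, &routes);
    assert_eq!(kept, default);
}
```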
@@ -2025,7 +2058,7 @@ async fn process_channel_message(
route.model.as_str(),
runtime_defaults.temperature,
true,
- None,
+ Some(&*ctx.approval_manager),
msg.channel.as_str(),
&ctx.multimodal,
ctx.max_tool_iterations,
@@ -2950,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con
channel.name()
);
}
// Notion is a top-level config section, not part of ChannelsConfig
{
let notion_configured =
config.notion.enabled && !config.notion.database_id.trim().is_empty();
println!(" {} Notion", if notion_configured { "✅" } else { "❌" });
}
if !cfg!(feature = "channel-matrix") {
println!(
" ⚠️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
@@ -3235,6 +3274,8 @@ fn collect_configured_channels(
#[cfg(not(feature = "whatsapp-web"))]
{
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
eprintln!(" ⚠ WhatsApp Web is configured but the 'whatsapp-web' feature is not compiled in.");
eprintln!(" Rebuild with: cargo build --features whatsapp-web");
}
}
_ => {
@@ -3380,6 +3421,34 @@ fn collect_configured_channels(
});
}
// Notion database poller channel
if config.notion.enabled && !config.notion.database_id.trim().is_empty() {
let notion_api_key = if config.notion.api_key.trim().is_empty() {
std::env::var("NOTION_API_KEY").unwrap_or_default()
} else {
config.notion.api_key.trim().to_string()
};
if notion_api_key.trim().is_empty() {
tracing::warn!(
"Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
);
} else {
channels.push(ConfiguredChannel {
display_name: "Notion",
channel: Arc::new(NotionChannel::new(
notion_api_key,
config.notion.database_id.clone(),
config.notion.poll_interval_secs,
config.notion.status_property.clone(),
config.notion.input_property.clone(),
config.notion.result_property.clone(),
config.notion.max_concurrent,
config.notion.recover_stale,
)),
});
}
}
channels
}
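The API-key resolution above (config value first, then the `NOTION_API_KEY` environment variable, with an empty result meaning "not configured") can be isolated for clarity. `resolve_notion_key` is a hypothetical helper, not a function in this diff, and the environment lookup is passed in as a parameter rather than read via `std::env::var` to keep the sketch deterministic:

```rust
// Sketch of the lookup order: a non-empty, trimmed config value wins;
// otherwise fall back to the NOTION_API_KEY environment variable.
// Returns None when neither source yields a usable key (the channel
// code above logs a warning and skips registration in that case).
fn resolve_notion_key(config_key: &str, env_key: Option<&str>) -> Option<String> {
    let key = if config_key.trim().is_empty() {
        env_key.unwrap_or_default().to_string()
    } else {
        config_key.trim().to_string()
    };
    if key.trim().is_empty() {
        None
    } else {
        Some(key)
    }
}

fn main() {
    // A config value takes precedence over the environment.
    assert_eq!(resolve_notion_key(" secret ", Some("env")), Some("secret".to_string()));
    // An empty config falls back to the environment variable.
    assert_eq!(resolve_notion_key("", Some("env")), Some("env".to_string()));
    // Neither set: no usable key.
    assert_eq!(resolve_notion_key("", None), None);
}
```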
@@ -3835,6 +3904,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()),
tool_call_dedup_exempt: Arc::new(config.agent.tool_call_dedup_exempt.clone()),
model_routes: Arc::new(config.model_routes.clone()),
query_classification: config.query_classification.clone(),
ack_reactions: config.channels_config.ack_reactions,
show_tool_calls: config.channels_config.show_tool_calls,
session_store: if config.channels_config.session_persistence {
@@ -3851,6 +3921,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
} else {
None
},
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
});
// Hydrate in-memory conversation histories from persisted JSONL session files.
@@ -4136,9 +4207,13 @@ mod tests {
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
};
assert!(compact_sender_history(&ctx, &sender));
@@ -4240,9 +4315,13 @@ mod tests {
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
};
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
@@ -4300,9 +4379,13 @@ mod tests {
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
};
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
@@ -4818,9 +4901,13 @@ BTC is currently around $65,000 based on latest tool output."#
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -4886,9 +4973,13 @@ BTC is currently around $65,000 based on latest tool output."#
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -4968,9 +5059,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5035,9 +5130,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5112,9 +5211,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5209,9 +5312,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5288,9 +5395,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5382,9 +5493,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5461,9 +5576,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5530,9 +5649,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -5710,9 +5833,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@@ -5798,9 +5925,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -5904,6 +6035,10 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
query_classification: crate::config::QueryClassificationConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -6001,9 +6136,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -6083,9 +6222,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -6150,9 +6293,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -6775,9 +6922,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -6868,9 +7019,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -6961,9 +7116,13 @@ BTC is currently around $65,000 based on latest tool output."#
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -7518,9 +7677,13 @@ This is an example JSON object for profile settings."#;
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
// Simulate a photo attachment message with [IMAGE:] marker.
@@ -7592,9 +7755,13 @@ This is an example JSON object for profile settings."#;
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(Vec::new()),
query_classification: crate::config::QueryClassificationConfig::default(),
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
@@ -7674,6 +7841,420 @@ This is an example JSON object for profile settings."#;
}
}
// ── Query classification in channel message processing ─────────
#[tokio::test]
async fn process_channel_message_applies_query_classification_route() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
let classification_config = crate::config::QueryClassificationConfig {
enabled: true,
rules: vec![crate::config::schema::ClassificationRule {
hint: "vision".into(),
keywords: vec!["analyze-image".into()],
..Default::default()
}],
};
let model_routes = vec![crate::config::ModelRouteConfig {
hint: "vision".into(),
provider: "vision-provider".into(),
model: "gpt-4-vision".into(),
api_key: None,
}];
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::clone(&default_provider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("default-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: InterruptOnNewMessageConfig {
telegram: false,
slack: false,
},
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(model_routes),
query_classification: classification_config,
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-qc-1".to_string(),
sender: "alice".to_string(),
reply_target: "chat-1".to_string(),
content: "please analyze-image from the dataset".to_string(),
channel: "telegram".to_string(),
timestamp: 1,
thread_ts: None,
},
CancellationToken::new(),
)
.await;
// Vision provider should have been called instead of the default.
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 1);
assert_eq!(
vision_provider_impl
.models
.lock()
.unwrap_or_else(|e| e.into_inner())
.as_slice(),
&["gpt-4-vision".to_string()]
);
}
#[tokio::test]
async fn process_channel_message_classification_disabled_uses_default_route() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
// Classification is disabled — matching keyword should NOT trigger reroute.
let classification_config = crate::config::QueryClassificationConfig {
enabled: false,
rules: vec![crate::config::schema::ClassificationRule {
hint: "vision".into(),
keywords: vec!["analyze-image".into()],
..Default::default()
}],
};
let model_routes = vec![crate::config::ModelRouteConfig {
hint: "vision".into(),
provider: "vision-provider".into(),
model: "gpt-4-vision".into(),
api_key: None,
}];
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::clone(&default_provider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("default-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: InterruptOnNewMessageConfig {
telegram: false,
slack: false,
},
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(model_routes),
query_classification: classification_config,
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-qc-disabled".to_string(),
sender: "alice".to_string(),
reply_target: "chat-1".to_string(),
content: "please analyze-image from the dataset".to_string(),
channel: "telegram".to_string(),
timestamp: 1,
thread_ts: None,
},
CancellationToken::new(),
)
.await;
// Default provider should be used since classification is disabled.
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn process_channel_message_classification_no_match_uses_default_route() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
// Classification enabled with a rule that won't match the message.
let classification_config = crate::config::QueryClassificationConfig {
enabled: true,
rules: vec![crate::config::schema::ClassificationRule {
hint: "vision".into(),
keywords: vec!["analyze-image".into()],
..Default::default()
}],
};
let model_routes = vec![crate::config::ModelRouteConfig {
hint: "vision".into(),
provider: "vision-provider".into(),
model: "gpt-4-vision".into(),
api_key: None,
}];
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::clone(&default_provider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("default-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: InterruptOnNewMessageConfig {
telegram: false,
slack: false,
},
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(model_routes),
query_classification: classification_config,
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-qc-nomatch".to_string(),
sender: "alice".to_string(),
reply_target: "chat-1".to_string(),
content: "just a regular text message".to_string(),
channel: "telegram".to_string(),
timestamp: 1,
thread_ts: None,
},
CancellationToken::new(),
)
.await;
// Default provider should be used since no classification rule matched.
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn process_channel_message_classification_priority_selects_highest() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
let fast_provider_impl = Arc::new(ModelCaptureProvider::default());
let fast_provider: Arc<dyn Provider> = fast_provider_impl.clone();
let code_provider_impl = Arc::new(ModelCaptureProvider::default());
let code_provider: Arc<dyn Provider> = code_provider_impl.clone();
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
provider_cache_seed.insert("fast-provider".to_string(), fast_provider);
provider_cache_seed.insert("code-provider".to_string(), code_provider);
// Both rules match "code" keyword, but "code" rule has higher priority.
let classification_config = crate::config::QueryClassificationConfig {
enabled: true,
rules: vec![
crate::config::schema::ClassificationRule {
hint: "fast".into(),
keywords: vec!["code".into()],
priority: 1,
..Default::default()
},
crate::config::schema::ClassificationRule {
hint: "code".into(),
keywords: vec!["code".into()],
priority: 10,
..Default::default()
},
],
};
let model_routes = vec![
crate::config::ModelRouteConfig {
hint: "fast".into(),
provider: "fast-provider".into(),
model: "fast-model".into(),
api_key: None,
},
crate::config::ModelRouteConfig {
hint: "code".into(),
provider: "code-provider".into(),
model: "code-model".into(),
api_key: None,
},
];
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::clone(&default_provider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("default-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: InterruptOnNewMessageConfig {
telegram: false,
slack: false,
},
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
non_cli_excluded_tools: Arc::new(Vec::new()),
tool_call_dedup_exempt: Arc::new(Vec::new()),
model_routes: Arc::new(model_routes),
query_classification: classification_config,
ack_reactions: true,
show_tool_calls: true,
session_store: None,
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-qc-prio".to_string(),
sender: "alice".to_string(),
reply_target: "chat-1".to_string(),
content: "write some code for me".to_string(),
channel: "telegram".to_string(),
timestamp: 1,
thread_ts: None,
},
CancellationToken::new(),
)
.await;
// Higher-priority "code" rule (priority=10) should win over "fast" (priority=1).
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
assert_eq!(fast_provider_impl.call_count.load(Ordering::SeqCst), 0);
assert_eq!(code_provider_impl.call_count.load(Ordering::SeqCst), 1);
assert_eq!(
code_provider_impl
.models
.lock()
.unwrap_or_else(|e| e.into_inner())
.as_slice(),
&["code-model".to_string()]
);
}
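The priority behavior these tests exercise can be sketched in isolation: every classification rule whose keyword appears in the message is a candidate, and the candidate with the highest `priority` wins; no match falls through to the default route. The `Rule` struct and `classify` function below are illustrative names, not the crate's real types.

```rust
struct Rule {
    hint: &'static str,
    keywords: &'static [&'static str],
    priority: i32,
}

/// Return the hint of the highest-priority rule with a matching keyword,
/// or None if nothing matches (caller then uses the default route).
fn classify(message: &str, rules: &[Rule]) -> Option<&'static str> {
    rules
        .iter()
        .filter(|r| r.keywords.iter().any(|k| message.contains(*k)))
        .max_by_key(|r| r.priority)
        .map(|r| r.hint)
}

fn main() {
    let rules = [
        Rule { hint: "fast", keywords: &["code"], priority: 1 },
        Rule { hint: "code", keywords: &["code"], priority: 10 },
    ];
    // Both rules match "code", but the higher-priority rule wins.
    assert_eq!(classify("write some code for me", &rules), Some("code"));
    // No keyword match: fall back to the default route.
    assert_eq!(classify("just a regular text message", &rules), None);
}
```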
#[test]
fn build_channel_by_id_unconfigured_telegram_returns_error() {
let config = Config::default();
@@ -0,0 +1,614 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use anyhow::{bail, Result};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
const NOTION_VERSION: &str = "2022-06-28";
const MAX_RESULT_LENGTH: usize = 2000;
const MAX_RETRIES: u32 = 3;
const RETRY_BASE_DELAY_MS: u64 = 2000;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;
/// Find the largest byte index <= `max_bytes` that falls on a UTF-8 char boundary.
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
if max_bytes >= s.len() {
return s.len();
}
let mut idx = max_bytes;
while idx > 0 && !s.is_char_boundary(idx) {
idx -= 1;
}
idx
}
/// Notion channel — polls a Notion database for pending tasks and writes results back.
///
/// The channel connects to the Notion API, queries a database for rows with a "pending"
/// status, dispatches them as channel messages, and writes results back when processing
/// completes. It supports crash recovery by resetting stale "running" tasks on startup.
pub struct NotionChannel {
api_key: String,
database_id: String,
poll_interval_secs: u64,
status_property: String,
input_property: String,
result_property: String,
max_concurrent: usize,
status_type: Arc<RwLock<String>>,
inflight: Arc<RwLock<HashSet<String>>>,
http: reqwest::Client,
recover_stale: bool,
}
impl NotionChannel {
/// Create a new Notion channel with the given configuration.
pub fn new(
api_key: String,
database_id: String,
poll_interval_secs: u64,
status_property: String,
input_property: String,
result_property: String,
max_concurrent: usize,
recover_stale: bool,
) -> Self {
Self {
api_key,
database_id,
poll_interval_secs,
status_property,
input_property,
result_property,
max_concurrent,
status_type: Arc::new(RwLock::new("select".to_string())),
inflight: Arc::new(RwLock::new(HashSet::new())),
http: reqwest::Client::new(),
recover_stale,
}
}
/// Build the standard Notion API headers (Authorization, version, content-type).
fn headers(&self) -> Result<reqwest::header::HeaderMap> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", self.api_key)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
);
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
headers.insert("Content-Type", "application/json".parse().unwrap());
Ok(headers)
}
/// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx).
async fn api_call(
&self,
method: reqwest::Method,
url: &str,
body: Option<serde_json::Value>,
) -> Result<serde_json::Value> {
let mut last_err = None;
for attempt in 0..MAX_RETRIES {
let mut req = self
.http
.request(method.clone(), url)
.headers(self.headers()?);
if let Some(ref b) = body {
req = req.json(b);
}
match req.send().await {
Ok(resp) => {
let status = resp.status();
if status.is_success() {
return resp
.json()
.await
.map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
}
let status_code = status.as_u16();
// Only retry on 429 (rate limit) or 5xx (server errors)
if status_code != 429 && (400..500).contains(&status_code) {
let body_text = resp.text().await.unwrap_or_default();
let truncated =
crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
bail!("Notion API error {status_code}: {truncated}");
}
last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
}
Err(e) => {
last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
}
}
let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
tracing::warn!(
"Notion API call failed (attempt {}/{}), retrying in {}ms",
attempt + 1,
MAX_RETRIES,
delay
);
tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
}
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
}
/// Query the database schema and detect whether Status uses "select" or "status" type.
async fn detect_status_type(&self) -> Result<String> {
let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
let status_type = resp
.get("properties")
.and_then(|p| p.get(&self.status_property))
.and_then(|s| s.get("type"))
.and_then(|t| t.as_str())
.unwrap_or("select")
.to_string();
Ok(status_type)
}
/// Query for rows where Status = "pending".
async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
let status_type = self.status_type.read().await.clone();
let filter = build_status_filter(&self.status_property, &status_type, "pending");
let resp = self
.api_call(
reqwest::Method::POST,
&url,
Some(serde_json::json!({ "filter": filter })),
)
.await?;
Ok(resp
.get("results")
.and_then(|r| r.as_array())
.cloned()
.unwrap_or_default())
}
/// Atomically claim a task. Returns true if this caller got it.
async fn claim_task(&self, page_id: &str) -> bool {
let mut inflight = self.inflight.write().await;
if inflight.contains(page_id) {
return false;
}
if inflight.len() >= self.max_concurrent {
return false;
}
inflight.insert(page_id.to_string());
true
}
/// Release a task from the inflight set.
async fn release_task(&self, page_id: &str) {
let mut inflight = self.inflight.write().await;
inflight.remove(page_id);
}
/// Update a row's status.
async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let status_type = self.status_type.read().await.clone();
let payload = serde_json::json!({
"properties": {
&self.status_property: build_status_payload(&status_type, status_value),
}
});
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
.await?;
Ok(())
}
/// Write result text to the Result column.
async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let payload = serde_json::json!({
"properties": {
&self.result_property: build_rich_text_payload(result_text),
}
});
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
.await?;
Ok(())
}
/// On startup, reset "running" tasks back to "pending" for crash recovery.
async fn recover_stale(&self) -> Result<()> {
let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
let status_type = self.status_type.read().await.clone();
let filter = build_status_filter(&self.status_property, &status_type, "running");
let resp = self
.api_call(
reqwest::Method::POST,
&url,
Some(serde_json::json!({ "filter": filter })),
)
.await?;
let stale = resp
.get("results")
.and_then(|r| r.as_array())
.cloned()
.unwrap_or_default();
if stale.is_empty() {
return Ok(());
}
tracing::warn!(
"Found {} stale task(s) in 'running' state, resetting to 'pending'",
stale.len()
);
for task in &stale {
if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
let payload = serde_json::json!({
"properties": {
&self.status_property: build_status_payload(&status_type, "pending"),
&self.result_property: build_rich_text_payload(
"Reset: poller restarted while task was running"
),
}
});
let short_id_end = floor_utf8_char_boundary(page_id, 8);
let short_id = &page_id[..short_id_end];
if let Err(e) = self
.api_call(reqwest::Method::PATCH, &page_url, Some(payload))
.await
{
tracing::error!("Could not reset stale task {short_id}: {e}");
} else {
tracing::info!("Reset stale task {short_id} to pending");
}
}
}
Ok(())
}
}
#[async_trait]
impl Channel for NotionChannel {
fn name(&self) -> &str {
"notion"
}
async fn send(&self, message: &SendMessage) -> Result<()> {
// recipient is the page_id for Notion
let page_id = &message.recipient;
let status_type = self.status_type.read().await.clone();
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let payload = serde_json::json!({
"properties": {
&self.status_property: build_status_payload(&status_type, "done"),
&self.result_property: build_rich_text_payload(&message.content),
}
});
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
.await?;
self.release_task(page_id).await;
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
// Detect status property type
match self.detect_status_type().await {
Ok(st) => {
tracing::info!("Notion status property type: {st}");
*self.status_type.write().await = st;
}
Err(e) => {
bail!("Failed to detect Notion database schema: {e}");
}
}
// Crash recovery
if self.recover_stale {
if let Err(e) = self.recover_stale().await {
tracing::error!("Notion stale task recovery failed: {e}");
}
}
// Polling loop
loop {
match self.query_pending().await {
Ok(tasks) => {
if !tasks.is_empty() {
tracing::info!("Notion: found {} pending task(s)", tasks.len());
}
for task in tasks {
let page_id = match task.get("id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => continue,
};
let input_text = extract_text_from_property(
task.get("properties")
.and_then(|p| p.get(&self.input_property)),
);
if input_text.trim().is_empty() {
let short_end = floor_utf8_char_boundary(&page_id, 8);
tracing::warn!(
"Notion: empty input for task {}, skipping",
&page_id[..short_end]
);
continue;
}
if !self.claim_task(&page_id).await {
continue;
}
// Set status to running
if let Err(e) = self.set_status(&page_id, "running").await {
tracing::error!("Notion: failed to set running status: {e}");
self.release_task(&page_id).await;
continue;
}
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if tx
.send(ChannelMessage {
id: page_id.clone(),
sender: "notion".into(),
reply_target: page_id,
content: input_text,
channel: "notion".into(),
timestamp,
thread_ts: None,
})
.await
.is_err()
{
tracing::info!("Notion channel shutting down");
return Ok(());
}
}
}
Err(e) => {
tracing::error!("Notion poll error: {e}");
}
}
tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
}
}
async fn health_check(&self) -> bool {
let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
self.api_call(reqwest::Method::GET, &url, None)
.await
.is_ok()
}
}
// ── Helper functions ──────────────────────────────────────────────
/// Build a Notion API filter object for the given status property.
fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
if status_type == "status" {
serde_json::json!({
"property": property,
"status": { "equals": value }
})
} else {
serde_json::json!({
"property": property,
"select": { "equals": value }
})
}
}
/// Build a Notion API property-update payload for a status field.
fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
if status_type == "status" {
serde_json::json!({ "status": { "name": value } })
} else {
serde_json::json!({ "select": { "name": value } })
}
}
/// Build a Notion API rich-text property payload, truncating if necessary.
fn build_rich_text_payload(value: &str) -> serde_json::Value {
let truncated = truncate_result(value);
serde_json::json!({
"rich_text": [{
"text": { "content": truncated }
}]
})
}
/// Truncate result text to fit within the Notion rich-text content limit.
fn truncate_result(value: &str) -> String {
if value.len() <= MAX_RESULT_LENGTH {
return value.to_string();
}
let cut = MAX_RESULT_LENGTH.saturating_sub(30);
// Ensure we cut on a char boundary
let end = floor_utf8_char_boundary(value, cut);
format!("{}\n\n... [output truncated]", &value[..end])
}
/// Extract plain text from a Notion property (title or rich_text type).
fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
let Some(prop) = prop else {
return String::new();
};
let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
let array_key = match ptype {
"title" => "title",
"rich_text" => "rich_text",
_ => return String::new(),
};
prop.get(array_key)
.and_then(|arr| arr.as_array())
.map(|items| {
items
.iter()
.filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("")
})
.unwrap_or_default()
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn claim_task_deduplication() {
let channel = NotionChannel::new(
"test-key".into(),
"test-db".into(),
5,
"Status".into(),
"Input".into(),
"Result".into(),
4,
false,
);
assert!(channel.claim_task("page-1").await);
// Second claim for same page should fail
assert!(!channel.claim_task("page-1").await);
// Different page should succeed
assert!(channel.claim_task("page-2").await);
// After release, can claim again
channel.release_task("page-1").await;
assert!(channel.claim_task("page-1").await);
}
#[test]
fn result_truncation_within_limit() {
let short = "hello world";
assert_eq!(truncate_result(short), short);
}
#[test]
fn result_truncation_over_limit() {
let long = "a".repeat(MAX_RESULT_LENGTH + 100);
let truncated = truncate_result(&long);
assert!(truncated.len() <= MAX_RESULT_LENGTH);
assert!(truncated.ends_with("... [output truncated]"));
}
#[test]
fn result_truncation_multibyte_safe() {
// Build a string that would cut in the middle of a multibyte char
let mut s = String::new();
for _ in 0..700 {
s.push('\u{6E2C}'); // 3-byte UTF-8 char
}
let truncated = truncate_result(&s);
// Should not panic and should be valid UTF-8
assert!(truncated.len() <= MAX_RESULT_LENGTH);
assert!(truncated.ends_with("... [output truncated]"));
}
#[test]
fn status_payload_select_type() {
let payload = build_status_payload("select", "pending");
assert_eq!(
payload,
serde_json::json!({ "select": { "name": "pending" } })
);
}
#[test]
fn status_payload_status_type() {
let payload = build_status_payload("status", "done");
assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } }));
}
#[test]
fn rich_text_payload_construction() {
let payload = build_rich_text_payload("test output");
let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
assert_eq!(text, "test output");
}
#[test]
fn status_filter_select_type() {
let filter = build_status_filter("Status", "select", "pending");
assert_eq!(
filter,
serde_json::json!({
"property": "Status",
"select": { "equals": "pending" }
})
);
}
#[test]
fn status_filter_status_type() {
let filter = build_status_filter("Status", "status", "running");
assert_eq!(
filter,
serde_json::json!({
"property": "Status",
"status": { "equals": "running" }
})
);
}
#[test]
fn extract_text_from_title_property() {
let prop = serde_json::json!({
"type": "title",
"title": [
{ "plain_text": "Hello " },
{ "plain_text": "World" }
]
});
assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
}
#[test]
fn extract_text_from_rich_text_property() {
let prop = serde_json::json!({
"type": "rich_text",
"rich_text": [{ "plain_text": "task content" }]
});
assert_eq!(extract_text_from_property(Some(&prop)), "task content");
}
#[test]
fn extract_text_from_none() {
assert_eq!(extract_text_from_property(None), "");
}
#[test]
fn extract_text_from_unknown_type() {
let prop = serde_json::json!({ "type": "number", "number": 42 });
assert_eq!(extract_text_from_property(Some(&prop)), "");
}
#[tokio::test]
async fn claim_task_respects_max_concurrent() {
let channel = NotionChannel::new(
"test-key".into(),
"test-db".into(),
5,
"Status".into(),
"Input".into(),
"Result".into(),
2, // max_concurrent = 2
false,
);
assert!(channel.claim_task("page-1").await);
assert!(channel.claim_task("page-2").await);
// Third claim should be rejected (at capacity)
assert!(!channel.claim_task("page-3").await);
// After releasing one, can claim again
channel.release_task("page-1").await;
assert!(channel.claim_task("page-3").await);
}
}
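The claim/release pattern that `claim_task`/`release_task` implement can be shown as a minimal dependency-free sketch: a task is claimed only if it is not already in flight and the in-flight set is below `max_concurrent`. The real channel guards the set with an async `tokio::sync::RwLock`; a `std::sync::Mutex` stands in here so the sketch needs no external crates, and `ClaimSet` is a hypothetical name.

```rust
use std::collections::HashSet;
use std::sync::Mutex;

struct ClaimSet {
    inflight: Mutex<HashSet<String>>,
    max_concurrent: usize,
}

impl ClaimSet {
    fn new(max_concurrent: usize) -> Self {
        Self { inflight: Mutex::new(HashSet::new()), max_concurrent }
    }

    /// Claim a task id; rejects duplicates and claims beyond capacity.
    fn claim(&self, id: &str) -> bool {
        let mut set = self.inflight.lock().unwrap();
        if set.contains(id) || set.len() >= self.max_concurrent {
            return false;
        }
        set.insert(id.to_string());
        true
    }

    fn release(&self, id: &str) {
        self.inflight.lock().unwrap().remove(id);
    }
}

fn main() {
    let claims = ClaimSet::new(2);
    assert!(claims.claim("page-1"));
    assert!(!claims.claim("page-1")); // duplicate rejected
    assert!(claims.claim("page-2"));
    assert!(!claims.claim("page-3")); // at capacity
    claims.release("page-1");
    assert!(claims.claim("page-3")); // capacity freed by release
}
```

Doing the duplicate check and the capacity check under one lock is what makes the claim atomic: two pollers racing on the same page id cannot both succeed.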
@@ -1,5 +1,6 @@
 pub mod schema;
 pub mod traits;
+pub mod workspace;

 #[allow(unused_imports)]
 pub use schema::{
@@ -11,14 +12,16 @@ pub use schema::{
     ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
     GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
     HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
-    McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
-    NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
-    PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
-    QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
-    SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
-    SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
-    StorageProviderSection, StreamMode, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode,
-    TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
+    McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig,
+    MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
+    ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
+    PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
+    QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
+    RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
+    SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
+    StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
+    TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
+    TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig,
 };
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
File diff suppressed because it is too large
@@ -0,0 +1,382 @@
//! Workspace profile management for multi-client isolation.
//!
//! Each workspace represents an isolated client engagement with its own
//! memory namespace, audit trail, secrets scope, and tool restrictions.
//! Profiles are stored under `~/.zeroclaw/workspaces/<client_name>/`.
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// A single client workspace profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkspaceProfile {
/// Human-readable workspace name (also used as directory name).
pub name: String,
/// Allowed domains for network access within this workspace.
#[serde(default)]
pub allowed_domains: Vec<String>,
/// Credential profile name scoped to this workspace.
#[serde(default)]
pub credential_profile: Option<String>,
/// Memory namespace prefix for isolation.
#[serde(default)]
pub memory_namespace: Option<String>,
/// Audit namespace prefix for isolation.
#[serde(default)]
pub audit_namespace: Option<String>,
/// Tool names denied in this workspace (e.g. `["shell"]` to block shell access).
#[serde(default)]
pub tool_restrictions: Vec<String>,
}
impl WorkspaceProfile {
/// Effective memory namespace (falls back to workspace name).
pub fn effective_memory_namespace(&self) -> &str {
self.memory_namespace
.as_deref()
.unwrap_or(self.name.as_str())
}
/// Effective audit namespace (falls back to workspace name).
pub fn effective_audit_namespace(&self) -> &str {
self.audit_namespace
.as_deref()
.unwrap_or(self.name.as_str())
}
/// Returns true if the given tool name is restricted in this workspace.
pub fn is_tool_restricted(&self, tool_name: &str) -> bool {
self.tool_restrictions
.iter()
.any(|r| r.eq_ignore_ascii_case(tool_name))
}
/// Returns true if the given domain is allowed for this workspace.
/// An empty allowlist means all domains are allowed.
pub fn is_domain_allowed(&self, domain: &str) -> bool {
if self.allowed_domains.is_empty() {
return true;
}
let domain_lower = domain.to_ascii_lowercase();
self.allowed_domains
.iter()
.any(|d| domain_lower == d.to_ascii_lowercase())
}
}
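The allowlist semantics documented on `is_domain_allowed` (an empty list permits everything; otherwise matching is exact and case-insensitive) can be sketched as a free function; this is an illustration, not the crate's API.

```rust
/// Sketch of the workspace domain allowlist check: empty allowlist means
/// no restriction; otherwise require a case-insensitive exact match.
fn is_domain_allowed(allowed: &[&str], domain: &str) -> bool {
    if allowed.is_empty() {
        return true; // empty allowlist: all domains allowed
    }
    let domain_lower = domain.to_ascii_lowercase();
    allowed.iter().any(|d| d.to_ascii_lowercase() == domain_lower)
}

fn main() {
    assert!(is_domain_allowed(&[], "anything.example"));             // empty = allow all
    assert!(is_domain_allowed(&["Client.Example"], "client.example")); // case-insensitive
    assert!(!is_domain_allowed(&["client.example"], "other.example"));
}
```

Note the exact-match rule means subdomains are not implicitly covered: allowing `client.example` does not allow `api.client.example`.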
/// Manages loading and switching between client workspace profiles.
#[derive(Debug, Clone)]
pub struct WorkspaceManager {
/// Base directory containing all workspace subdirectories.
workspaces_dir: PathBuf,
/// Loaded workspace profiles keyed by name.
profiles: HashMap<String, WorkspaceProfile>,
/// Currently active workspace name.
active: Option<String>,
}
impl WorkspaceManager {
/// Create a new workspace manager rooted at the given directory.
pub fn new(workspaces_dir: PathBuf) -> Self {
Self {
workspaces_dir,
profiles: HashMap::new(),
active: None,
}
}
/// Load all workspace profiles from disk.
///
/// Each subdirectory of `workspaces_dir` that contains a `profile.toml`
/// is treated as a workspace.
pub async fn load_profiles(&mut self) -> Result<()> {
self.profiles.clear();
let dir = &self.workspaces_dir;
if !dir.exists() {
return Ok(());
}
let mut entries = tokio::fs::read_dir(dir)
.await
.with_context(|| format!("reading workspaces directory: {}", dir.display()))?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if !path.is_dir() {
continue;
}
let profile_path = path.join("profile.toml");
if !profile_path.exists() {
continue;
}
match tokio::fs::read_to_string(&profile_path).await {
Ok(contents) => match toml::from_str::<WorkspaceProfile>(&contents) {
Ok(profile) => {
self.profiles.insert(profile.name.clone(), profile);
}
Err(e) => {
tracing::warn!(
"skipping malformed workspace profile {}: {e}",
profile_path.display()
);
}
},
Err(e) => {
tracing::warn!(
"skipping unreadable workspace profile {}: {e}",
profile_path.display()
);
}
}
}
Ok(())
}
/// Switch to the named workspace. Returns an error if it does not exist.
pub fn switch(&mut self, name: &str) -> Result<&WorkspaceProfile> {
if !self.profiles.contains_key(name) {
bail!("workspace '{}' not found", name);
}
self.active = Some(name.to_string());
Ok(&self.profiles[name])
}
/// Get the currently active workspace profile, if any.
pub fn active_profile(&self) -> Option<&WorkspaceProfile> {
self.active
.as_deref()
.and_then(|name| self.profiles.get(name))
}
/// Get the active workspace name.
pub fn active_name(&self) -> Option<&str> {
self.active.as_deref()
}
/// List all loaded workspace names.
pub fn list(&self) -> Vec<&str> {
let mut names: Vec<&str> = self.profiles.keys().map(String::as_str).collect();
names.sort_unstable();
names
}
/// Get a workspace profile by name.
pub fn get(&self, name: &str) -> Option<&WorkspaceProfile> {
self.profiles.get(name)
}
/// Create a new workspace on disk and register it.
pub async fn create(&mut self, name: &str) -> Result<&WorkspaceProfile> {
if name.is_empty() {
bail!("workspace name must not be empty");
}
// Validate name: alphanumeric, hyphens, underscores only
if !name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
{
bail!(
"workspace name must contain only alphanumeric characters, hyphens, or underscores"
);
}
if self.profiles.contains_key(name) {
bail!("workspace '{}' already exists", name);
}
let ws_dir = self.workspaces_dir.join(name);
tokio::fs::create_dir_all(&ws_dir)
.await
.with_context(|| format!("creating workspace directory: {}", ws_dir.display()))?;
let profile = WorkspaceProfile {
name: name.to_string(),
allowed_domains: Vec::new(),
credential_profile: None,
memory_namespace: Some(name.to_string()),
audit_namespace: Some(name.to_string()),
tool_restrictions: Vec::new(),
};
let toml_str = toml::to_string_pretty(&profile).context("serializing workspace profile")?;
let profile_path = ws_dir.join("profile.toml");
tokio::fs::write(&profile_path, toml_str)
.await
.with_context(|| format!("writing workspace profile: {}", profile_path.display()))?;
self.profiles.insert(name.to_string(), profile);
Ok(&self.profiles[name])
}
/// Export a workspace profile as a sanitized TOML string (no secrets).
pub fn export(&self, name: &str) -> Result<String> {
let profile = self
.profiles
.get(name)
.with_context(|| format!("workspace '{}' not found", name))?;
// Create an export-safe copy with credential_profile redacted
let export = WorkspaceProfile {
credential_profile: profile
.credential_profile
.as_ref()
.map(|_| "***".to_string()),
..profile.clone()
};
toml::to_string_pretty(&export).context("serializing workspace profile for export")
}
/// Directory for a specific workspace.
pub fn workspace_dir(&self, name: &str) -> PathBuf {
self.workspaces_dir.join(name)
}
/// Base workspaces directory.
pub fn workspaces_dir(&self) -> &Path {
&self.workspaces_dir
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn sample_profile(name: &str) -> WorkspaceProfile {
WorkspaceProfile {
name: name.to_string(),
allowed_domains: vec!["example.com".to_string()],
credential_profile: Some("test-creds".to_string()),
memory_namespace: Some(format!("{name}_mem")),
audit_namespace: Some(format!("{name}_audit")),
tool_restrictions: vec!["shell".to_string()],
}
}
#[test]
fn workspace_profile_tool_restriction_check() {
let profile = sample_profile("client_a");
assert!(profile.is_tool_restricted("shell"));
assert!(profile.is_tool_restricted("Shell"));
assert!(!profile.is_tool_restricted("file_read"));
}
#[test]
fn workspace_profile_domain_allowlist_empty_allows_all() {
let mut profile = sample_profile("client_a");
profile.allowed_domains.clear();
assert!(profile.is_domain_allowed("anything.com"));
}
#[test]
fn workspace_profile_domain_allowlist_enforced() {
let profile = sample_profile("client_a");
assert!(profile.is_domain_allowed("example.com"));
assert!(!profile.is_domain_allowed("other.com"));
}
#[test]
fn workspace_profile_effective_namespaces() {
let profile = sample_profile("client_a");
assert_eq!(profile.effective_memory_namespace(), "client_a_mem");
assert_eq!(profile.effective_audit_namespace(), "client_a_audit");
let fallback = WorkspaceProfile {
name: "test_ws".to_string(),
memory_namespace: None,
audit_namespace: None,
..sample_profile("test_ws")
};
assert_eq!(fallback.effective_memory_namespace(), "test_ws");
assert_eq!(fallback.effective_audit_namespace(), "test_ws");
}
#[tokio::test]
async fn workspace_manager_create_and_list() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
mgr.create("client_alpha").await.unwrap();
mgr.create("client_beta").await.unwrap();
let names = mgr.list();
assert_eq!(names, vec!["client_alpha", "client_beta"]);
}
#[tokio::test]
async fn workspace_manager_create_rejects_duplicate() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
mgr.create("client_a").await.unwrap();
let result = mgr.create("client_a").await;
assert!(result.is_err());
}
#[tokio::test]
async fn workspace_manager_create_rejects_invalid_name() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
assert!(mgr.create("").await.is_err());
assert!(mgr.create("bad name").await.is_err());
assert!(mgr.create("../escape").await.is_err());
}
#[tokio::test]
async fn workspace_manager_switch_and_active() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
mgr.create("ws_one").await.unwrap();
assert!(mgr.active_profile().is_none());
mgr.switch("ws_one").unwrap();
assert_eq!(mgr.active_name(), Some("ws_one"));
assert!(mgr.active_profile().is_some());
}
    #[test]
    fn workspace_manager_switch_nonexistent_fails() {
        let mut mgr = WorkspaceManager::new(PathBuf::from("/tmp/nonexistent"));
        assert!(mgr.switch("no_such_ws").is_err());
    }
#[tokio::test]
async fn workspace_manager_load_profiles_from_disk() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
// Create a workspace via the manager
mgr.create("loaded_ws").await.unwrap();
// Create a fresh manager and load from disk
let mut mgr2 = WorkspaceManager::new(tmp.path().to_path_buf());
mgr2.load_profiles().await.unwrap();
assert_eq!(mgr2.list(), vec!["loaded_ws"]);
let profile = mgr2.get("loaded_ws").unwrap();
assert_eq!(profile.name, "loaded_ws");
}
#[tokio::test]
async fn workspace_manager_export_redacts_credentials() {
let tmp = TempDir::new().unwrap();
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
mgr.create("export_test").await.unwrap();
// Manually set a credential profile
if let Some(profile) = mgr.profiles.get_mut("export_test") {
profile.credential_profile = Some("secret-cred-id".to_string());
}
let exported = mgr.export("export_test").unwrap();
assert!(exported.contains("***"));
assert!(!exported.contains("secret-cred-id"));
}
}
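The allowlist matching used by `WorkspaceProfile::is_domain_allowed` can be sketched standalone; the free function below mirrors the method's case-insensitive comparison (the names here are illustrative, not part of the crate):

```rust
/// Mirrors WorkspaceProfile::is_domain_allowed: an empty allowlist
/// permits every domain; otherwise matching is case-insensitive.
fn is_domain_allowed(allowed: &[&str], domain: &str) -> bool {
    if allowed.is_empty() {
        return true;
    }
    let domain_lower = domain.to_ascii_lowercase();
    allowed.iter().any(|a| domain_lower == a.to_ascii_lowercase())
}

fn main() {
    // Empty allowlist: everything passes.
    assert!(is_domain_allowed(&[], "anything.com"));
    // Case-insensitive match against a listed domain.
    assert!(is_domain_allowed(&["Example.COM"], "example.com"));
    // Unlisted domains are rejected once the list is non-empty.
    assert!(!is_domain_allowed(&["example.com"], "other.com"));
    println!("ok");
}
```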
@@ -152,44 +152,122 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<(
crate::CronCommands::Add {
expression,
tz,
agent,
command,
} => {
let schedule = Schedule::Cron {
expr: expression,
tz,
};
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added cron job {}", job.id);
println!(" Expr: {}", job.expression);
println!(" Next: {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
if agent {
let job = add_agent_job(
config,
None,
schedule,
&command,
SessionTarget::Isolated,
None,
None,
false,
)?;
println!("✅ Added agent cron job {}", job.id);
println!(" Expr : {}", job.expression);
println!(" Next : {}", job.next_run.to_rfc3339());
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
} else {
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added cron job {}", job.id);
println!(" Expr: {}", job.expression);
println!(" Next: {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
}
Ok(())
}
crate::CronCommands::AddAt { at, command } => {
crate::CronCommands::AddAt { at, agent, command } => {
let at = chrono::DateTime::parse_from_rfc3339(&at)
.map_err(|e| anyhow::anyhow!("Invalid RFC3339 timestamp for --at: {e}"))?
.with_timezone(&chrono::Utc);
let schedule = Schedule::At { at };
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added one-shot cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
if agent {
let job = add_agent_job(
config,
None,
schedule,
&command,
SessionTarget::Isolated,
None,
None,
true,
)?;
println!("✅ Added one-shot agent cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
} else {
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added one-shot cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
}
Ok(())
}
crate::CronCommands::AddEvery { every_ms, command } => {
crate::CronCommands::AddEvery {
every_ms,
agent,
command,
} => {
let schedule = Schedule::Every { every_ms };
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added interval cron job {}", job.id);
println!(" Every(ms): {every_ms}");
println!(" Next : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
if agent {
let job = add_agent_job(
config,
None,
schedule,
&command,
SessionTarget::Isolated,
None,
None,
false,
)?;
println!("✅ Added interval agent cron job {}", job.id);
println!(" Every(ms): {every_ms}");
println!(" Next : {}", job.next_run.to_rfc3339());
println!(" Prompt : {}", job.prompt.as_deref().unwrap_or_default());
} else {
let job = add_shell_job(config, None, schedule, &command)?;
println!("✅ Added interval cron job {}", job.id);
println!(" Every(ms): {every_ms}");
println!(" Next : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
}
Ok(())
}
crate::CronCommands::Once { delay, command } => {
let job = add_once(config, &delay, &command)?;
println!("✅ Added one-shot cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
crate::CronCommands::Once {
delay,
agent,
command,
} => {
if agent {
let duration = parse_delay(&delay)?;
let at = chrono::Utc::now() + duration;
let schedule = Schedule::At { at };
let job = add_agent_job(
config,
None,
schedule,
&command,
SessionTarget::Isolated,
None,
None,
true,
)?;
println!("✅ Added one-shot agent cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
} else {
let job = add_once(config, &delay, &command)?;
println!("✅ Added one-shot cron job {}", job.id);
println!(" At : {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
}
Ok(())
}
crate::CronCommands::Update {
@@ -686,4 +764,77 @@ mod tests {
.to_string()
.contains("blocked by security policy"));
}
#[test]
fn cli_agent_flag_creates_agent_job() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
handle_command(
crate::CronCommands::Add {
expression: "*/15 * * * *".into(),
tz: None,
agent: true,
command: "Check server health: disk space, memory, CPU load".into(),
},
&config,
)
.unwrap();
let jobs = list_jobs(&config).unwrap();
assert_eq!(jobs.len(), 1);
assert_eq!(jobs[0].job_type, JobType::Agent);
assert_eq!(
jobs[0].prompt.as_deref(),
Some("Check server health: disk space, memory, CPU load")
);
}
#[test]
fn cli_agent_flag_bypasses_shell_security_validation() {
let tmp = TempDir::new().unwrap();
let mut config = test_config(&tmp);
config.autonomy.allowed_commands = vec!["echo".into()];
config.autonomy.level = crate::security::AutonomyLevel::Supervised;
// Without --agent, a natural language string would be blocked by shell
// security policy. With --agent, it routes to agent job and skips
// shell validation entirely.
let result = handle_command(
crate::CronCommands::Add {
expression: "*/15 * * * *".into(),
tz: None,
agent: true,
command: "Check server health: disk space, memory, CPU load".into(),
},
&config,
);
assert!(result.is_ok());
let jobs = list_jobs(&config).unwrap();
assert_eq!(jobs.len(), 1);
assert_eq!(jobs[0].job_type, JobType::Agent);
}
#[test]
fn cli_without_agent_flag_defaults_to_shell_job() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
handle_command(
crate::CronCommands::Add {
expression: "*/5 * * * *".into(),
tz: None,
agent: false,
command: "echo ok".into(),
},
&config,
)
.unwrap();
let jobs = list_jobs(&config).unwrap();
assert_eq!(jobs.len(), 1);
assert_eq!(jobs[0].job_type, JobType::Shell);
assert_eq!(jobs[0].command, "echo ok");
}
}
@@ -53,7 +53,7 @@ pub async fn run(config: Config) -> Result<()> {
pub async fn execute_job_now(config: &Config, job: &CronJob) -> (bool, String) {
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
execute_job_with_retry(config, &security, job).await
Box::pin(execute_job_with_retry(config, &security, job)).await
}
async fn execute_job_with_retry(
@@ -68,7 +68,7 @@ async fn execute_job_with_retry(
for attempt in 0..=retries {
let (success, output) = match job.job_type {
JobType::Shell => run_job_command(config, security, job).await,
JobType::Agent => run_agent_job(config, security, job).await,
JobType::Agent => Box::pin(run_agent_job(config, security, job)).await,
};
last_output = output;
@@ -101,18 +101,21 @@ async fn process_due_jobs(
crate::health::mark_component_ok(component);
let max_concurrent = config.scheduler.max_concurrent.max(1);
let mut in_flight =
stream::iter(
jobs.into_iter().map(|job| {
let config = config.clone();
let security = Arc::clone(security);
let component = component.to_owned();
async move {
execute_and_persist_job(&config, security.as_ref(), &job, &component).await
}
}),
)
.buffer_unordered(max_concurrent);
let mut in_flight = stream::iter(jobs.into_iter().map(|job| {
let config = config.clone();
let security = Arc::clone(security);
let component = component.to_owned();
async move {
Box::pin(execute_and_persist_job(
&config,
security.as_ref(),
&job,
&component,
))
.await
}
}))
.buffer_unordered(max_concurrent);
while let Some((job_id, success, output)) = in_flight.next().await {
if !success {
@@ -131,9 +134,17 @@ async fn execute_and_persist_job(
warn_if_high_frequency_agent_job(job);
let started_at = Utc::now();
let (success, output) = execute_job_with_retry(config, security, job).await;
let (success, output) = Box::pin(execute_job_with_retry(config, security, job)).await;
let finished_at = Utc::now();
let success = persist_job_result(config, job, success, &output, started_at, finished_at).await;
let success = Box::pin(persist_job_result(
config,
job,
success,
&output,
started_at,
finished_at,
))
.await;
(job.id.clone(), success, output)
}
@@ -170,7 +181,7 @@ async fn run_agent_job(
let run_result = match job.session_target {
SessionTarget::Main | SessionTarget::Isolated => {
crate::agent::run(
Box::pin(crate::agent::run(
config.clone(),
Some(prefixed_prompt),
None,
@@ -179,7 +190,8 @@ async fn run_agent_job(
vec![],
false,
None,
)
job.allowed_tools.clone(),
))
.await
}
};
@@ -557,6 +569,7 @@ mod tests {
enabled: true,
delivery: DeliveryConfig::default(),
delete_after_run: false,
allowed_tools: None,
created_at: Utc::now(),
next_run: Utc::now(),
last_run: None,
@@ -742,7 +755,7 @@ mod tests {
.unwrap();
let job = test_job("sh ./retry-once.sh");
let (success, output) = execute_job_with_retry(&config, &security, &job).await;
let (success, output) = Box::pin(execute_job_with_retry(&config, &security, &job)).await;
assert!(success);
assert!(output.contains("recovered"));
}
@@ -757,7 +770,7 @@ mod tests {
let job = test_job("ls always_missing_for_retry_test");
let (success, output) = execute_job_with_retry(&config, &security, &job).await;
let (success, output) = Box::pin(execute_job_with_retry(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("always_missing_for_retry_test"));
}
@@ -771,7 +784,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("agent job failed:"));
}
@@ -786,7 +799,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("blocked by security policy"));
assert!(output.contains("read-only"));
@@ -802,7 +815,7 @@ mod tests {
job.prompt = Some("Say hello".into());
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
let (success, output) = run_agent_job(&config, &security, &job).await;
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
assert!(!success);
assert!(output.contains("blocked by security policy"));
assert!(output.contains("rate limit exceeded"));
@@ -453,6 +453,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
},
last_status: row.get(15)?,
last_output: row.get(16)?,
allowed_tools: None,
})
}
@@ -115,6 +115,11 @@ pub struct CronJob {
pub enabled: bool,
pub delivery: DeliveryConfig,
pub delete_after_run: bool,
/// Optional allowlist of tool names this cron job may use.
/// When `Some(list)`, only tools whose name is in the list are available.
/// When `None`, all tools are available (backward compatible default).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allowed_tools: Option<Vec<String>>,
pub created_at: DateTime<Utc>,
pub next_run: DateTime<Utc>,
pub last_run: Option<DateTime<Utc>>,
@@ -144,6 +149,7 @@ pub struct CronJobPatch {
pub model: Option<String>,
pub session_target: Option<SessionTarget>,
pub delete_after_run: Option<bool>,
pub allowed_tools: Option<Vec<String>>,
}
#[cfg(test)]
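The `allowed_tools` contract documented above (`None` = all tools available, `Some(list)` = only listed tools) can be sketched as a standalone filter; this illustrates the documented semantics and is not the crate's actual lookup code:

```rust
/// None means every tool is available (backward-compatible default);
/// Some(list) restricts availability to the named tools.
fn tool_available(allowed_tools: &Option<Vec<String>>, name: &str) -> bool {
    match allowed_tools {
        None => true,
        Some(list) => list.iter().any(|t| t == name),
    }
}

fn main() {
    // Default: no restriction.
    assert!(tool_available(&None, "shell"));
    // With an allowlist, only listed tools pass.
    let only_web = Some(vec!["web_search".to_string()]);
    assert!(tool_available(&only_web, "web_search"));
    assert!(!tool_available(&only_web, "shell"));
    println!("ok");
}
```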
@@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
max_backoff,
move || {
let cfg = channels_cfg.clone();
async move { crate::channels::start_channels(cfg).await }
async move { Box::pin(crate::channels::start_channels(cfg)).await }
},
));
} else {
@@ -245,7 +245,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
// ── Phase 1: LLM decision (two-phase mode) ──────────────
let tasks_to_run = if two_phase {
let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks);
match crate::agent::run(
match Box::pin(crate::agent::run(
config.clone(),
Some(decision_prompt),
None,
@@ -254,7 +254,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
vec![],
false,
None,
)
None,
))
.await
{
Ok(response) => {
@@ -287,7 +288,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
for task in &tasks_to_run {
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
let temp = config.default_temperature;
match crate::agent::run(
match Box::pin(crate::agent::run(
config.clone(),
Some(prompt),
None,
@@ -296,7 +297,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
vec![],
false,
None,
)
None,
))
.await
{
Ok(output) => {
@@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
/// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk).
async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result<String> {
let config = state.config.lock().clone();
crate::agent::process_message(config, message).await
Box::pin(crate::agent::process_message(config, message)).await
}
/// Webhook request body
@@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message(
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa
@@ -1346,7 +1346,7 @@ async fn handle_linq_webhook(
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via Linq
if let Err(e) = linq
@@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
// Send reply via WATI
if let Err(e) = wati
@@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook(
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
Ok(response) => {
if let Err(e) = nextcloud_talk
.send(&SendMessage::new(response, &msg.reply_target))
@@ -2492,11 +2492,11 @@ mod tests {
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
};
let response = handle_nextcloud_talk_webhook(
let response = Box::pin(handle_nextcloud_talk_webhook(
State(state),
HeaderMap::new(),
Bytes::from_static(br#"{"type":"message"}"#),
)
))
.await
.into_response();
@@ -2558,9 +2558,13 @@ mod tests {
HeaderValue::from_str(invalid_signature).unwrap(),
);
let response = handle_nextcloud_talk_webhook(State(state), headers, Bytes::from(body))
.await
.into_response();
let response = Box::pin(handle_nextcloud_talk_webhook(
State(state),
headers,
Bytes::from(body),
))
.await
.into_response();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
}
@@ -0,0 +1,229 @@
pub mod types;
pub use types::{Hand, HandContext, HandRun, HandRunStatus};
use anyhow::{Context, Result};
use std::path::Path;
/// Load all hand definitions from TOML files in the given directory.
///
/// Each `.toml` file in `hands_dir` is expected to deserialize into a [`Hand`].
/// Files that fail to parse are logged and skipped.
pub fn load_hands(hands_dir: &Path) -> Result<Vec<Hand>> {
if !hands_dir.is_dir() {
return Ok(Vec::new());
}
let mut hands = Vec::new();
let entries = std::fs::read_dir(hands_dir)
.with_context(|| format!("failed to read hands directory: {}", hands_dir.display()))?;
for entry in entries {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) != Some("toml") {
continue;
}
let content = std::fs::read_to_string(&path)
.with_context(|| format!("failed to read hand file: {}", path.display()))?;
match toml::from_str::<Hand>(&content) {
Ok(hand) => hands.push(hand),
Err(e) => {
tracing::warn!(path = %path.display(), error = %e, "skipping malformed hand file");
}
}
}
Ok(hands)
}
/// Load the rolling context for a hand.
///
/// Reads from `{hands_dir}/{name}/context.json`. Returns a fresh
/// [`HandContext`] if the file does not exist yet.
pub fn load_hand_context(hands_dir: &Path, name: &str) -> Result<HandContext> {
let path = hands_dir.join(name).join("context.json");
if !path.exists() {
return Ok(HandContext::new(name));
}
let content = std::fs::read_to_string(&path)
.with_context(|| format!("failed to read hand context: {}", path.display()))?;
let ctx: HandContext = serde_json::from_str(&content)
.with_context(|| format!("failed to parse hand context: {}", path.display()))?;
Ok(ctx)
}
/// Persist the rolling context for a hand.
///
/// Writes to `{hands_dir}/{name}/context.json`, creating the
/// directory if it does not exist.
pub fn save_hand_context(hands_dir: &Path, context: &HandContext) -> Result<()> {
let dir = hands_dir.join(&context.hand_name);
std::fs::create_dir_all(&dir)
.with_context(|| format!("failed to create hand context dir: {}", dir.display()))?;
let path = dir.join("context.json");
let json = serde_json::to_string_pretty(context)?;
std::fs::write(&path, json)
.with_context(|| format!("failed to write hand context: {}", path.display()))?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn write_hand_toml(dir: &Path, filename: &str, content: &str) {
std::fs::write(dir.join(filename), content).unwrap();
}
#[test]
fn load_hands_empty_dir() {
let tmp = TempDir::new().unwrap();
let hands = load_hands(tmp.path()).unwrap();
assert!(hands.is_empty());
}
#[test]
fn load_hands_nonexistent_dir() {
let hands = load_hands(Path::new("/nonexistent/path/hands")).unwrap();
assert!(hands.is_empty());
}
#[test]
fn load_hands_parses_valid_files() {
let tmp = TempDir::new().unwrap();
write_hand_toml(
tmp.path(),
"scanner.toml",
r#"
name = "scanner"
description = "Market scanner"
prompt = "Scan markets."
[schedule]
kind = "cron"
expr = "0 9 * * *"
"#,
);
write_hand_toml(
tmp.path(),
"digest.toml",
r#"
name = "digest"
description = "News digest"
prompt = "Digest news."
[schedule]
kind = "every"
every_ms = 3600000
"#,
);
let hands = load_hands(tmp.path()).unwrap();
assert_eq!(hands.len(), 2);
}
#[test]
fn load_hands_skips_malformed_files() {
let tmp = TempDir::new().unwrap();
write_hand_toml(tmp.path(), "bad.toml", "this is not valid toml struct");
write_hand_toml(
tmp.path(),
"good.toml",
r#"
name = "good"
description = "A good hand"
prompt = "Do good things."
[schedule]
kind = "every"
every_ms = 60000
"#,
);
let hands = load_hands(tmp.path()).unwrap();
assert_eq!(hands.len(), 1);
assert_eq!(hands[0].name, "good");
}
#[test]
fn load_hands_ignores_non_toml_files() {
let tmp = TempDir::new().unwrap();
std::fs::write(tmp.path().join("readme.md"), "# Hands").unwrap();
std::fs::write(tmp.path().join("notes.txt"), "some notes").unwrap();
let hands = load_hands(tmp.path()).unwrap();
assert!(hands.is_empty());
}
#[test]
fn context_roundtrip_through_filesystem() {
let tmp = TempDir::new().unwrap();
let mut ctx = HandContext::new("test-hand");
let run = HandRun {
hand_name: "test-hand".into(),
run_id: "run-001".into(),
started_at: chrono::Utc::now(),
finished_at: Some(chrono::Utc::now()),
status: HandRunStatus::Completed,
findings: vec!["found something".into()],
knowledge_added: vec!["learned something".into()],
duration_ms: Some(500),
};
ctx.record_run(run, 100);
save_hand_context(tmp.path(), &ctx).unwrap();
let loaded = load_hand_context(tmp.path(), "test-hand").unwrap();
assert_eq!(loaded.hand_name, "test-hand");
assert_eq!(loaded.total_runs, 1);
assert_eq!(loaded.history.len(), 1);
assert_eq!(loaded.learned_facts, vec!["learned something"]);
}
#[test]
fn load_context_returns_fresh_when_missing() {
let tmp = TempDir::new().unwrap();
let ctx = load_hand_context(tmp.path(), "nonexistent").unwrap();
assert_eq!(ctx.hand_name, "nonexistent");
assert_eq!(ctx.total_runs, 0);
assert!(ctx.history.is_empty());
}
#[test]
fn save_context_creates_directory() {
let tmp = TempDir::new().unwrap();
let ctx = HandContext::new("new-hand");
save_hand_context(tmp.path(), &ctx).unwrap();
assert!(tmp.path().join("new-hand").join("context.json").exists());
}
#[test]
fn save_then_load_preserves_multiple_runs() {
let tmp = TempDir::new().unwrap();
let mut ctx = HandContext::new("multi");
for i in 0..5 {
let run = HandRun {
hand_name: "multi".into(),
run_id: format!("run-{i:03}"),
started_at: chrono::Utc::now(),
finished_at: Some(chrono::Utc::now()),
status: HandRunStatus::Completed,
findings: vec![format!("finding-{i}")],
knowledge_added: vec![format!("fact-{i}")],
duration_ms: Some(100),
};
ctx.record_run(run, 3);
}
save_hand_context(tmp.path(), &ctx).unwrap();
let loaded = load_hand_context(tmp.path(), "multi").unwrap();
assert_eq!(loaded.total_runs, 5);
assert_eq!(loaded.history.len(), 3, "history capped at max_history=3");
assert_eq!(loaded.learned_facts.len(), 5);
}
}
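`HandContext::record_run` keeps history most-recent-first and capped at `max_history`; a minimal standalone sketch of that trimming behavior, with run records simplified to plain strings:

```rust
/// Mirrors the history handling in HandContext::record_run:
/// newest entry at the front, length capped at max_history.
fn record(history: &mut Vec<String>, run_id: String, max_history: usize) {
    history.insert(0, run_id);
    history.truncate(max_history);
}

fn main() {
    let mut history = Vec::new();
    for i in 0..5 {
        record(&mut history, format!("run-{i}"), 3);
    }
    // Only the three most recent runs survive, newest first.
    assert_eq!(history, vec!["run-4", "run-3", "run-2"]);
    println!("ok");
}
```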
@@ -0,0 +1,345 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::cron::Schedule;
// ── Hand ───────────────────────────────────────────────────────
/// A Hand is an autonomous agent package that runs on a schedule,
/// accumulates knowledge over time, and reports results.
///
/// Hands are defined as TOML files in `~/.zeroclaw/hands/` and each
/// maintains a rolling context of findings across runs so the agent
/// grows smarter with every execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hand {
/// Unique name (also used as directory/file stem)
pub name: String,
/// Human-readable description of what this hand does
pub description: String,
/// The schedule this hand runs on (reuses cron schedule types)
pub schedule: Schedule,
/// System prompt / execution plan for this hand
pub prompt: String,
/// Domain knowledge lines to inject into context
#[serde(default)]
pub knowledge: Vec<String>,
/// Tools this hand is allowed to use (None = all available)
#[serde(default)]
pub allowed_tools: Option<Vec<String>>,
/// Model override for this hand (None = default provider)
#[serde(default)]
pub model: Option<String>,
/// Whether this hand is currently active
#[serde(default = "default_true")]
pub active: bool,
/// Maximum runs to keep in history
#[serde(default = "default_max_runs")]
pub max_history: usize,
}
fn default_true() -> bool {
true
}
fn default_max_runs() -> usize {
100
}
// ── Hand Run ───────────────────────────────────────────────────
/// The status of a single hand execution.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "status")]
pub enum HandRunStatus {
Running,
Completed,
Failed { error: String },
}
/// Record of a single hand execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandRun {
/// Name of the hand that produced this run
pub hand_name: String,
/// Unique identifier for this run
pub run_id: String,
/// When the run started
pub started_at: DateTime<Utc>,
/// When the run finished (None if still running)
pub finished_at: Option<DateTime<Utc>>,
/// Outcome of the run
pub status: HandRunStatus,
/// Key findings/outputs extracted from this run
#[serde(default)]
pub findings: Vec<String>,
/// New knowledge accumulated and stored to memory
#[serde(default)]
pub knowledge_added: Vec<String>,
/// Wall-clock duration in milliseconds
pub duration_ms: Option<u64>,
}
// ── Hand Context ───────────────────────────────────────────────
/// Rolling context that accumulates across hand runs.
///
/// Persisted as `~/.zeroclaw/hands/{name}/context.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandContext {
/// Name of the hand this context belongs to
pub hand_name: String,
/// Past runs, most-recent first, capped at `Hand::max_history`
#[serde(default)]
pub history: Vec<HandRun>,
/// Persistent facts learned across runs
#[serde(default)]
pub learned_facts: Vec<String>,
/// Timestamp of the last completed run
pub last_run: Option<DateTime<Utc>>,
/// Total number of successful runs
#[serde(default)]
pub total_runs: u64,
}
impl HandContext {
/// Create a fresh, empty context for a hand.
pub fn new(hand_name: &str) -> Self {
Self {
hand_name: hand_name.to_string(),
history: Vec::new(),
learned_facts: Vec::new(),
last_run: None,
total_runs: 0,
}
}
/// Record a completed run, updating counters and trimming history.
pub fn record_run(&mut self, run: HandRun, max_history: usize) {
if run.status == HandRunStatus::Completed {
self.total_runs += 1;
self.last_run = run.finished_at;
}
// Merge new knowledge
for fact in &run.knowledge_added {
if !self.learned_facts.contains(fact) {
self.learned_facts.push(fact.clone());
}
}
// Insert at the front (most-recent first)
self.history.insert(0, run);
// Cap history length
self.history.truncate(max_history);
}
}
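The bookkeeping in `record_run` boils down to three moves: deduplicated fact accumulation, front insertion, and capping. A minimal standalone sketch (the `SmallLog` type is illustrative only, not part of this crate):

```rust
// Minimal sketch of the record_run bookkeeping: most-recent-first
// history capped at `max`, plus deduplicated fact accumulation.
// `SmallLog` is illustrative only, not a type from this crate.
struct SmallLog {
    history: Vec<String>,
    facts: Vec<String>,
}

impl SmallLog {
    fn record(&mut self, entry: String, fact: String, max: usize) {
        if !self.facts.contains(&fact) {
            self.facts.push(fact); // dedup, preserving first-seen order
        }
        self.history.insert(0, entry); // newest first
        self.history.truncate(max); // drop the oldest beyond the cap
    }
}

fn main() {
    let mut log = SmallLog { history: Vec::new(), facts: Vec::new() };
    for i in 0..5 {
        log.record(format!("run-{i}"), "fact-A".to_string(), 3);
    }
    assert_eq!(log.history, vec!["run-4", "run-3", "run-2"]);
    assert_eq!(log.facts, vec!["fact-A"]); // added once despite 5 runs
}
```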
#[cfg(test)]
mod tests {
use super::*;
use crate::cron::Schedule;
fn sample_hand() -> Hand {
Hand {
name: "market-scanner".into(),
description: "Scans market trends and reports findings".into(),
schedule: Schedule::Cron {
expr: "0 9 * * 1-5".into(),
tz: Some("America/New_York".into()),
},
prompt: "Scan market trends and report key findings.".into(),
knowledge: vec!["Focus on tech sector.".into()],
allowed_tools: Some(vec!["web_search".into(), "memory".into()]),
model: Some("claude-opus-4-6".into()),
active: true,
max_history: 50,
}
}
fn sample_run(name: &str, status: HandRunStatus) -> HandRun {
let now = Utc::now();
HandRun {
hand_name: name.into(),
run_id: uuid::Uuid::new_v4().to_string(),
started_at: now,
finished_at: Some(now),
status,
findings: vec!["finding-1".into()],
knowledge_added: vec!["learned-fact-A".into()],
duration_ms: Some(1234),
}
}
// ── Deserialization ────────────────────────────────────────
#[test]
fn hand_deserializes_from_toml() {
let toml_str = r#"
name = "market-scanner"
description = "Scans market trends"
prompt = "Scan trends."
[schedule]
kind = "cron"
expr = "0 9 * * 1-5"
tz = "America/New_York"
"#;
let hand: Hand = toml::from_str(toml_str).unwrap();
assert_eq!(hand.name, "market-scanner");
assert!(hand.active, "active should default to true");
assert_eq!(hand.max_history, 100, "max_history should default to 100");
assert!(hand.knowledge.is_empty());
assert!(hand.allowed_tools.is_none());
assert!(hand.model.is_none());
}
#[test]
fn hand_deserializes_full_toml() {
let toml_str = r#"
name = "news-digest"
description = "Daily news digest"
prompt = "Summarize the day's news."
knowledge = ["focus on AI", "include funding rounds"]
allowed_tools = ["web_search"]
model = "claude-opus-4-6"
active = false
max_history = 25
[schedule]
kind = "every"
every_ms = 3600000
"#;
let hand: Hand = toml::from_str(toml_str).unwrap();
assert_eq!(hand.name, "news-digest");
assert!(!hand.active);
assert_eq!(hand.max_history, 25);
assert_eq!(hand.knowledge.len(), 2);
assert_eq!(hand.allowed_tools.as_ref().unwrap().len(), 1);
assert_eq!(hand.model.as_deref(), Some("claude-opus-4-6"));
assert!(matches!(
hand.schedule,
Schedule::Every {
every_ms: 3_600_000
}
));
}
#[test]
fn hand_roundtrip_json() {
let hand = sample_hand();
let json = serde_json::to_string(&hand).unwrap();
let parsed: Hand = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.name, hand.name);
assert_eq!(parsed.max_history, hand.max_history);
}
// ── HandRunStatus ──────────────────────────────────────────
#[test]
fn hand_run_status_serde_roundtrip() {
let statuses = vec![
HandRunStatus::Running,
HandRunStatus::Completed,
HandRunStatus::Failed {
error: "timeout".into(),
},
];
for status in statuses {
let json = serde_json::to_string(&status).unwrap();
let parsed: HandRunStatus = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, status);
}
}
// ── HandContext ────────────────────────────────────────────
#[test]
fn context_new_is_empty() {
let ctx = HandContext::new("test-hand");
assert_eq!(ctx.hand_name, "test-hand");
assert!(ctx.history.is_empty());
assert!(ctx.learned_facts.is_empty());
assert!(ctx.last_run.is_none());
assert_eq!(ctx.total_runs, 0);
}
#[test]
fn context_record_run_increments_counters() {
let mut ctx = HandContext::new("scanner");
let run = sample_run("scanner", HandRunStatus::Completed);
ctx.record_run(run, 100);
assert_eq!(ctx.total_runs, 1);
assert!(ctx.last_run.is_some());
assert_eq!(ctx.history.len(), 1);
assert_eq!(ctx.learned_facts, vec!["learned-fact-A"]);
}
#[test]
fn context_record_failed_run_does_not_increment_total() {
let mut ctx = HandContext::new("scanner");
let run = sample_run(
"scanner",
HandRunStatus::Failed {
error: "boom".into(),
},
);
ctx.record_run(run, 100);
assert_eq!(ctx.total_runs, 0);
assert!(ctx.last_run.is_none());
assert_eq!(ctx.history.len(), 1);
}
#[test]
fn context_caps_history_at_max() {
let mut ctx = HandContext::new("scanner");
for _ in 0..10 {
let run = sample_run("scanner", HandRunStatus::Completed);
ctx.record_run(run, 3);
}
assert_eq!(ctx.history.len(), 3);
assert_eq!(ctx.total_runs, 10);
}
#[test]
fn context_deduplicates_learned_facts() {
let mut ctx = HandContext::new("scanner");
let run1 = sample_run("scanner", HandRunStatus::Completed);
let run2 = sample_run("scanner", HandRunStatus::Completed);
ctx.record_run(run1, 100);
ctx.record_run(run2, 100);
// Both runs add "learned-fact-A" but it should appear only once
assert_eq!(ctx.learned_facts.len(), 1);
}
#[test]
fn context_json_roundtrip() {
let mut ctx = HandContext::new("scanner");
let run = sample_run("scanner", HandRunStatus::Completed);
ctx.record_run(run, 100);
let json = serde_json::to_string_pretty(&ctx).unwrap();
let parsed: HandContext = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.hand_name, "scanner");
assert_eq!(parsed.total_runs, 1);
assert_eq!(parsed.history.len(), 1);
assert_eq!(parsed.learned_facts, vec!["learned-fact-A"]);
}
#[test]
fn most_recent_run_is_first_in_history() {
let mut ctx = HandContext::new("scanner");
for i in 0..3 {
let mut run = sample_run("scanner", HandRunStatus::Completed);
run.findings = vec![format!("finding-{i}")];
ctx.record_run(run, 100);
}
assert_eq!(ctx.history[0].findings[0], "finding-2");
assert_eq!(ctx.history[2].findings[0], "finding-0");
}
}
+21 -6
@@ -48,6 +48,7 @@ pub(crate) mod cron;
pub(crate) mod daemon;
pub(crate) mod doctor;
pub mod gateway;
pub mod hands;
pub(crate) mod hardware;
pub(crate) mod health;
pub(crate) mod heartbeat;
@@ -57,6 +58,7 @@ pub(crate) mod integrations;
pub mod memory;
pub(crate) mod migration;
pub(crate) mod multimodal;
pub mod nodes;
pub mod observability;
pub(crate) mod onboard;
pub mod peripherals;
@@ -280,15 +282,19 @@ Times are evaluated in UTC by default; use --tz with an IANA \
timezone name to override.
Examples:
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
zeroclaw cron add '*/30 * * * *' 'Check system health'")]
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
zeroclaw cron add '*/5 * * * *' 'echo ok'")]
Add {
/// Cron expression
expression: String,
/// Optional IANA timezone (e.g. America/Los_Angeles)
#[arg(long)]
tz: Option<String>,
/// Command to run
/// Treat the argument as an agent prompt instead of a shell command
#[arg(long)]
agent: bool,
/// Command (shell) or prompt (agent) to run
command: String,
},
/// Add a one-shot scheduled task at an RFC3339 timestamp
@@ -303,7 +309,10 @@ Examples:
AddAt {
/// One-shot timestamp in RFC3339 format
at: String,
/// Command to run
/// Treat the argument as an agent prompt instead of a shell command
#[arg(long)]
agent: bool,
/// Command (shell) or prompt (agent) to run
command: String,
},
/// Add a fixed-interval scheduled task
@@ -318,7 +327,10 @@ Examples:
AddEvery {
/// Interval in milliseconds
every_ms: u64,
/// Command to run
/// Treat the argument as an agent prompt instead of a shell command
#[arg(long)]
agent: bool,
/// Command (shell) or prompt (agent) to run
command: String,
},
/// Add a one-shot delayed task (e.g. "30m", "2h", "1d")
@@ -335,7 +347,10 @@ Examples:
Once {
/// Delay duration
delay: String,
/// Command to run
/// Treat the argument as an agent prompt instead of a shell command
#[arg(long)]
agent: bool,
/// Command (shell) or prompt (agent) to run
command: String,
},
/// Remove a scheduled task
+41 -14
@@ -37,7 +37,7 @@ use anyhow::{bail, Context, Result};
use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
use dialoguer::{Input, Password};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::io::{IsTerminal, Write};
use std::path::PathBuf;
use tracing::{info, warn};
use tracing_subscriber::{fmt, EnvFilter};
@@ -325,11 +325,12 @@ override with --tz and an IANA timezone name.
Examples:
zeroclaw cron list
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
zeroclaw cron add '*/30 * * * *' 'Check system health'
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder'
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
zeroclaw cron add '*/5 * * * *' 'echo ok'
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' --agent
zeroclaw cron add-every 60000 'Ping heartbeat'
zeroclaw cron once 30m 'Run backup in 30 minutes'
zeroclaw cron once 30m 'Run backup in 30 minutes' --agent
zeroclaw cron pause <task-id>
zeroclaw cron update <task-id> --expression '0 8 * * *' --tz Europe/London")]
Cron {
@@ -718,10 +719,11 @@ async fn main() -> Result<()> {
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
// Onboard runs quick setup by default, or the interactive wizard with --interactive.
// The onboard wizard uses reqwest::blocking internally, which creates its own
// Tokio runtime. To avoid "Cannot drop a runtime in a context where blocking is
// not allowed", we run the wizard on a blocking thread via spawn_blocking.
// Onboard auto-detects the environment: if stdin/stdout are a TTY and no
// provider flags were given, it runs the full interactive wizard; otherwise
// it runs the quick (scriptable) setup. This means `curl … | bash` and
// `zeroclaw onboard --api-key …` both take the fast path, while a bare
// `zeroclaw onboard` in a terminal launches the wizard.
if let Commands::Onboard {
force,
reinit,
@@ -793,8 +795,16 @@ async fn main() -> Result<()> {
}
}
// Auto-detect: run the interactive wizard when in a TTY with no
// provider flags, quick setup otherwise (scriptable path).
let has_provider_flags =
api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some();
let is_tty = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
let config = if channels_only {
Box::pin(onboard::run_channels_repair_wizard()).await
} else if is_tty && !has_provider_flags {
Box::pin(onboard::run_wizard(force)).await
} else {
onboard::run_quick_setup(
api_key.as_deref(),
@@ -834,7 +844,7 @@ async fn main() -> Result<()> {
// Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?;
Box::pin(channels::start_channels(config)).await?;
}
return Ok(());
}
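The auto-detect rule described in the comment reduces to a small pure predicate; a sketch under stated assumptions (the free function name is illustrative, not the actual code):

```rust
use std::io::IsTerminal;

// Illustrative predicate for the onboard auto-detect rule: run the full
// interactive wizard only when attached to a terminal AND no provider
// flags were passed; every other combination takes the quick path.
fn should_run_wizard(is_tty: bool, has_provider_flags: bool) -> bool {
    is_tty && !has_provider_flags
}

fn main() {
    // A bare `zeroclaw onboard` in a terminal launches the wizard...
    assert!(should_run_wizard(true, false));
    // ...while piped stdin (`curl ... | bash`) or any provider flag
    // takes the scriptable quick path.
    assert!(!should_run_wizard(false, false));
    assert!(!should_run_wizard(true, true));
    assert!(!should_run_wizard(false, true));
    // The real check probes stdin and stdout via std::io::IsTerminal:
    let _attached = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
}
```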
@@ -870,7 +880,7 @@ async fn main() -> Result<()> {
} => {
let final_temperature = temperature.unwrap_or(config.default_temperature);
agent::run(
Box::pin(agent::run(
config,
message,
provider,
@@ -879,7 +889,8 @@ async fn main() -> Result<()> {
peripheral,
true,
session_state_file,
)
None,
))
.await
.map(|_| ())
}
@@ -1178,8 +1189,8 @@ async fn main() -> Result<()> {
},
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await,
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
other => channels::handle_command(other, &config).await,
},
@@ -2206,6 +2217,22 @@ mod tests {
}
}
#[test]
fn onboard_cli_rejects_removed_interactive_flag() {
// --interactive was removed; onboard auto-detects TTY instead.
assert!(Cli::try_parse_from(["zeroclaw", "onboard", "--interactive"]).is_err());
}
#[test]
fn onboard_cli_bare_parses() {
let cli = Cli::try_parse_from(["zeroclaw", "onboard"]).expect("bare onboard should parse");
match cli.command {
Commands::Onboard { .. } => {}
other => panic!("expected onboard command, got {other:?}"),
}
}
#[test]
fn cli_parses_estop_default_engage() {
let cli = Cli::try_parse_from(["zeroclaw", "estop"]).expect("estop command should parse");
+24 -2
@@ -43,8 +43,13 @@ pub async fn consolidate_turn(
let turn_text = format!("User: {user_message}\nAssistant: {assistant_response}");
// Truncate very long turns to avoid wasting tokens on consolidation.
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
let truncated = if turn_text.len() > 4000 {
format!("{}", &turn_text[..4000])
let mut end = 4000;
while end > 0 && !turn_text.is_char_boundary(end) {
end -= 1;
}
turn_text[..end].to_string()
} else {
turn_text.clone()
};
@@ -92,8 +97,13 @@ fn parse_consolidation_response(raw: &str, fallback_text: &str) -> Consolidation
serde_json::from_str(cleaned).unwrap_or_else(|_| {
// Fallback: use truncated turn text as history entry.
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
let summary = if fallback_text.len() > 200 {
format!("{}", &fallback_text[..200])
let mut end = 200;
while end > 0 && !fallback_text.is_char_boundary(end) {
end -= 1;
}
format!("{}…", &fallback_text[..end])
} else {
fallback_text.to_string()
};
@@ -150,4 +160,16 @@ mod tests {
// 200 bytes + "…" (3 bytes in UTF-8) = 203
assert!(result.history_entry.len() <= 203);
}
#[test]
fn fallback_truncates_cjk_text_without_panic() {
// Each CJK character is 3 bytes in UTF-8; byte index 200 may land
// inside a character. This must not panic.
let cjk_text = "二手书项目".repeat(50); // 250 chars = 750 bytes
let result = parse_consolidation_response("invalid", &cjk_text);
assert!(result
.history_entry
.is_char_boundary(result.history_entry.len()));
assert!(result.history_entry.ends_with('…'));
}
}
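The boundary walk used at both truncation sites generalizes to one small helper; a sketch (the function name is illustrative, not from the crate):

```rust
// Truncate a &str to at most `max` bytes without splitting a multi-byte
// UTF-8 character: walk the cut point backwards until it lands on a
// char boundary. Slicing at a non-boundary index would panic.
fn truncate_on_char_boundary(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    let mut end = max;
    while end > 0 && !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}

fn main() {
    // ASCII: the cut lands exactly on the requested byte count.
    assert_eq!(truncate_on_char_boundary("hello world", 5), "hello");
    // CJK: each char is 3 bytes, so a 4-byte budget backs off to 3.
    assert_eq!(truncate_on_char_boundary("二手书", 4), "二");
    // Short inputs pass through untouched.
    assert_eq!(truncate_on_char_boundary("ok", 100), "ok");
}
```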
+3
@@ -0,0 +1,3 @@
pub mod transport;
pub use transport::NodeTransport;
+235
@@ -0,0 +1,235 @@
//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication.
//!
//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols,
//! no custom binary framing, no UDP tunneling. This makes the transport
//! compatible with corporate proxies, firewalls, and IT audit expectations.
use anyhow::{bail, Result};
use chrono::Utc;
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// Signs a request payload with HMAC-SHA256.
///
/// Uses `timestamp` + `nonce` alongside the payload to prevent replay attacks.
pub fn sign_request(
shared_secret: &str,
payload: &[u8],
timestamp: i64,
nonce: &str,
) -> Result<String> {
let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes())
.map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
mac.update(&timestamp.to_le_bytes());
mac.update(nonce.as_bytes());
mac.update(payload);
Ok(hex::encode(mac.finalize().into_bytes()))
}
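Successive `mac.update` calls are equivalent to MAC-ing one concatenated buffer, so the canonical input here is `timestamp_le ‖ nonce ‖ payload`. A sketch of that byte layout (assembly only, no crypto, so it needs no external crates; `mac_input` is a hypothetical helper):

```rust
// Assemble the byte sequence that sign_request feeds to HMAC-SHA256:
// 8 little-endian timestamp bytes, then the nonce, then the payload.
// Field order matters: swapping nonce and payload would authenticate
// a different message and verification would fail.
fn mac_input(timestamp: i64, nonce: &str, payload: &[u8]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(8 + nonce.len() + payload.len());
    buf.extend_from_slice(&timestamp.to_le_bytes());
    buf.extend_from_slice(nonce.as_bytes());
    buf.extend_from_slice(payload);
    buf
}

fn main() {
    let input = mac_input(1_700_000_000, "nonce-1", b"hello");
    assert_eq!(input.len(), 8 + 7 + 5);
    // The timestamp occupies the first 8 bytes, little-endian.
    assert_eq!(&input[..8], &1_700_000_000i64.to_le_bytes()[..]);
    assert_eq!(&input[8..15], &b"nonce-1"[..]);
    assert_eq!(&input[15..], &b"hello"[..]);
}
```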
/// Verify a signed request, rejecting stale timestamps for replay protection.
pub fn verify_request(
shared_secret: &str,
payload: &[u8],
timestamp: i64,
nonce: &str,
signature: &str,
max_age_secs: i64,
) -> Result<bool> {
let now = Utc::now().timestamp();
if (now - timestamp).abs() > max_age_secs {
bail!("Request timestamp too old or too far in future");
}
let expected = sign_request(shared_secret, payload, timestamp, nonce)?;
Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes()))
}
/// Constant-time comparison to prevent timing attacks.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
a.iter()
.zip(b.iter())
.fold(0u8, |acc, (x, y)| acc | (x ^ y))
== 0
}
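The replay window in `verify_request` is symmetric: timestamps too far in the future are rejected as well as stale ones, so a skewed clock cannot be exploited in either direction. As a pure sketch (`is_fresh` is illustrative, not a function from this module):

```rust
// Symmetric freshness check used for replay protection: a request is
// accepted only when |now - timestamp| is within the allowed window,
// rejecting both stale replays and timestamps from the future.
fn is_fresh(now: i64, timestamp: i64, max_age_secs: i64) -> bool {
    (now - timestamp).abs() <= max_age_secs
}

fn main() {
    let now = 1_700_000_000;
    assert!(is_fresh(now, now, 300)); // current
    assert!(is_fresh(now, now - 300, 300)); // edge of the window
    assert!(!is_fresh(now, now - 301, 300)); // stale replay
    assert!(!is_fresh(now, now + 301, 300)); // clock too far ahead
}
```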
// ── Node transport client ───────────────────────────────────────
/// Sends authenticated HTTPS requests to peer nodes.
///
/// Every outgoing request carries three custom headers:
/// - `X-ZeroClaw-Timestamp` — unix epoch seconds
/// - `X-ZeroClaw-Nonce` — random UUID v4
/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest
///
/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`].
pub struct NodeTransport {
http: reqwest::Client,
shared_secret: String,
max_request_age_secs: i64,
}
impl NodeTransport {
pub fn new(shared_secret: String) -> Self {
Self {
http: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("HTTP client build"),
shared_secret,
max_request_age_secs: 300, // 5 min replay window
}
}
/// Send an authenticated request to a peer node.
pub async fn send(
&self,
node_address: &str,
endpoint: &str,
payload: serde_json::Value,
) -> Result<serde_json::Value> {
let body = serde_json::to_vec(&payload)?;
let timestamp = Utc::now().timestamp();
let nonce = uuid::Uuid::new_v4().to_string();
let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?;
let url = format!("https://{node_address}/api/node-control/{endpoint}");
let resp = self
.http
.post(&url)
.header("X-ZeroClaw-Timestamp", timestamp.to_string())
.header("X-ZeroClaw-Nonce", &nonce)
.header("X-ZeroClaw-Signature", &signature)
.header("Content-Type", "application/json")
.body(body)
.send()
.await?;
if !resp.status().is_success() {
bail!(
"Node request failed: {} {}",
resp.status(),
resp.text().await.unwrap_or_default()
);
}
Ok(resp.json().await?)
}
/// Verify an incoming request from a peer node.
pub fn verify_incoming(
&self,
payload: &[u8],
timestamp_header: &str,
nonce_header: &str,
signature_header: &str,
) -> Result<bool> {
let timestamp: i64 = timestamp_header
.parse()
.map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?;
verify_request(
&self.shared_secret,
payload,
timestamp,
nonce_header,
signature_header,
self.max_request_age_secs,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_SECRET: &str = "test-shared-secret-key";
#[test]
fn sign_request_deterministic() {
let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
assert_eq!(sig1, sig2, "Same inputs must produce the same signature");
}
#[test]
fn verify_request_accepts_valid_signature() {
let now = Utc::now().timestamp();
let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap();
let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap();
assert!(ok, "Valid signature must pass verification");
}
#[test]
fn verify_request_rejects_tampered_payload() {
let now = Utc::now().timestamp();
let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap();
let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap();
assert!(!ok, "Tampered payload must fail verification");
}
#[test]
fn verify_request_rejects_expired_timestamp() {
let old = Utc::now().timestamp() - 600;
let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap();
let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300);
assert!(result.is_err(), "Expired timestamp must be rejected");
}
#[test]
fn verify_request_rejects_wrong_secret() {
let now = Utc::now().timestamp();
let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap();
let ok = verify_request("wrong-secret", b"data", now, "nonce-d", &sig, 300).unwrap();
assert!(!ok, "Wrong secret must fail verification");
}
#[test]
fn constant_time_eq_correctness() {
assert!(constant_time_eq(b"abc", b"abc"));
assert!(!constant_time_eq(b"abc", b"abd"));
assert!(!constant_time_eq(b"abc", b"ab"));
assert!(!constant_time_eq(b"", b"a"));
assert!(constant_time_eq(b"", b""));
}
#[test]
fn node_transport_construction() {
let transport = NodeTransport::new("secret-key".into());
assert_eq!(transport.max_request_age_secs, 300);
}
#[test]
fn node_transport_verify_incoming_valid() {
let transport = NodeTransport::new(TEST_SECRET.into());
let now = Utc::now().timestamp();
let payload = b"test-body";
let nonce = "incoming-nonce";
let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap();
let ok = transport
.verify_incoming(payload, &now.to_string(), nonce, &sig)
.unwrap();
assert!(ok, "Valid incoming request must pass verification");
}
#[test]
fn node_transport_verify_incoming_bad_timestamp_header() {
let transport = NodeTransport::new(TEST_SECRET.into());
let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig");
assert!(result.is_err(), "Non-numeric timestamp header must error");
}
#[test]
fn sign_request_different_nonce_different_signature() {
let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap();
let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap();
assert_ne!(
sig1, sig2,
"Different nonces must produce different signatures"
);
}
}
+2 -1
@@ -4,7 +4,7 @@ pub mod wizard;
#[allow(unused_imports)]
pub use wizard::{
run_channels_repair_wizard, run_models_list, run_models_refresh, run_models_refresh_all,
run_models_set, run_models_status, run_quick_setup,
run_models_set, run_models_status, run_quick_setup, run_wizard,
};
#[cfg(test)]
@@ -17,6 +17,7 @@ mod tests {
fn wizard_functions_are_reexported() {
assert_reexport_exists(run_channels_repair_wizard);
assert_reexport_exists(run_quick_setup);
assert_reexport_exists(run_wizard);
assert_reexport_exists(run_models_refresh);
assert_reexport_exists(run_models_list);
assert_reexport_exists(run_models_set);
+31
@@ -144,6 +144,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(),
security: crate::config::SecurityConfig::default(),
security_ops: crate::config::SecurityOpsConfig::default(),
runtime: RuntimeConfig::default(),
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
@@ -159,17 +160,20 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
tunnel: tunnel_config,
gateway: crate::config::GatewayConfig::default(),
composio: composio_config,
microsoft365: crate::config::Microsoft365Config::default(),
secrets: secrets_config,
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
peripherals: crate::config::PeripheralsConfig::default(),
agents: std::collections::HashMap::new(),
swarms: std::collections::HashMap::new(),
hooks: crate::config::HooksConfig::default(),
hardware: hardware_config,
query_classification: crate::config::QueryClassificationConfig::default(),
@@ -177,6 +181,9 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
tts: crate::config::TtsConfig::default(),
mcp: crate::config::McpConfig::default(),
nodes: crate::config::NodesConfig::default(),
workspace: crate::config::WorkspaceConfig::default(),
notion: crate::config::NotionConfig::default(),
node_transport: crate::config::NodeTransportConfig::default(),
};
println!(
@@ -501,6 +508,7 @@ async fn run_quick_setup_with_home(
observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(),
security: crate::config::SecurityConfig::default(),
security_ops: crate::config::SecurityOpsConfig::default(),
runtime: RuntimeConfig::default(),
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
@@ -516,17 +524,20 @@ async fn run_quick_setup_with_home(
tunnel: crate::config::TunnelConfig::default(),
gateway: crate::config::GatewayConfig::default(),
composio: ComposioConfig::default(),
microsoft365: crate::config::Microsoft365Config::default(),
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
peripherals: crate::config::PeripheralsConfig::default(),
agents: std::collections::HashMap::new(),
swarms: std::collections::HashMap::new(),
hooks: crate::config::HooksConfig::default(),
hardware: crate::config::HardwareConfig::default(),
query_classification: crate::config::QueryClassificationConfig::default(),
@@ -534,6 +545,9 @@ async fn run_quick_setup_with_home(
tts: crate::config::TtsConfig::default(),
mcp: crate::config::McpConfig::default(),
nodes: crate::config::NodesConfig::default(),
workspace: crate::config::WorkspaceConfig::default(),
notion: crate::config::NotionConfig::default(),
node_transport: crate::config::NodeTransportConfig::default(),
};
config.save().await?;
@@ -4147,6 +4161,23 @@ fn setup_channels() -> Result<ChannelsConfig> {
.interact()?;
if mode_idx == 0 {
// Compile-time check: warn early if the feature is not enabled.
#[cfg(not(feature = "whatsapp-web"))]
{
println!();
println!(
" {} {}",
style("⚠").yellow().bold(),
style("The 'whatsapp-web' feature is not compiled in. WhatsApp Web will not work at runtime.").yellow()
);
println!(
" {} Rebuild with: {}",
style("→").dim(),
style("cargo build --features whatsapp-web").white().bold()
);
println!();
}
println!(" {}", style("Mode: WhatsApp Web").dim());
print_bullet("1. Build with --features whatsapp-web");
print_bullet(
+55 -5
@@ -500,19 +500,23 @@ struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type")]
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
kind: Option<String>,
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
function: Option<Function>,
// Compatibility: Some providers (e.g., older GLM) may use 'name' directly
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
arguments: Option<String>,
// Compatibility: DeepSeek sometimes wraps arguments differently
#[serde(rename = "parameters", default)]
#[serde(
rename = "parameters",
default,
skip_serializing_if = "Option::is_none"
)]
parameters: Option<serde_json::Value>,
}
@@ -3094,4 +3098,50 @@ mod tests {
// Should not panic
let _client = p.http_client();
}
#[test]
fn tool_call_none_fields_omitted_from_json() {
// Ensures providers like Mistral that reject extra fields (e.g. "name": null)
// don't receive them when the ToolCall compat fields are None.
let tc = ToolCall {
id: Some("call_1".to_string()),
kind: Some("function".to_string()),
function: Some(Function {
name: Some("shell".to_string()),
arguments: Some("{\"command\":\"ls\"}".to_string()),
}),
name: None,
arguments: None,
parameters: None,
};
let json = serde_json::to_value(&tc).unwrap();
assert!(!json.as_object().unwrap().contains_key("name"));
assert!(!json.as_object().unwrap().contains_key("arguments"));
assert!(!json.as_object().unwrap().contains_key("parameters"));
// Standard fields must be present
assert!(json.as_object().unwrap().contains_key("id"));
assert!(json.as_object().unwrap().contains_key("type"));
assert!(json.as_object().unwrap().contains_key("function"));
}
#[test]
fn tool_call_with_compat_fields_serializes_them() {
// When compat fields are Some, they should appear in the output.
let tc = ToolCall {
id: None,
kind: None,
function: None,
name: Some("shell".to_string()),
arguments: Some("{\"command\":\"ls\"}".to_string()),
parameters: None,
};
let json = serde_json::to_value(&tc).unwrap();
assert_eq!(json["name"], "shell");
assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
// None fields should be omitted
assert!(!json.as_object().unwrap().contains_key("id"));
assert!(!json.as_object().unwrap().contains_key("type"));
assert!(!json.as_object().unwrap().contains_key("function"));
assert!(!json.as_object().unwrap().contains_key("parameters"));
}
}
+18 -1
@@ -4,6 +4,7 @@ use crate::multimodal;
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
use crate::providers::ProviderRuntimeOptions;
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -472,8 +473,24 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
None
}
/// Read the response body incrementally via `bytes_stream()` to avoid
/// buffering the entire SSE payload in memory. The previous implementation
/// used `response.text().await?` which holds the HTTP connection open until
/// every byte has arrived — on high-latency links the long-lived connection
/// often drops mid-read, producing the "error decoding response body" failure
/// reported in #3544.
async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<String> {
let body = response.text().await?;
let mut buf: Vec<u8> = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let bytes = chunk
.map_err(|err| anyhow::anyhow!("error reading OpenAI Codex response stream: {err}"))?;
// Accumulate raw bytes and validate UTF-8 once at the end: a multi-byte
// character can be split across two chunks, so validating each chunk in
// isolation would spuriously reject a valid stream.
buf.extend_from_slice(&bytes);
}
let body = String::from_utf8(buf)
.map_err(|err| anyhow::anyhow!("OpenAI Codex response contained invalid UTF-8: {err}"))?;
if let Some(text) = parse_sse_text(&body)? {
return Ok(text);
+449
@@ -0,0 +1,449 @@
//! IAM-aware policy enforcement for Nevis role-to-permission mapping.
//!
//! Evaluates tool and workspace access based on Nevis roles using a
//! deny-by-default policy model. All policy decisions are audit-logged.
use super::nevis::NevisIdentity;
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Maps a single Nevis role to ZeroClaw permissions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleMapping {
/// Nevis role name (case-insensitive matching).
pub nevis_role: String,
/// Tool names this role can access. Use `"all"` to grant all tools.
pub zeroclaw_permissions: Vec<String>,
/// Workspace names this role can access. Use `"all"` for unrestricted.
#[serde(default)]
pub workspace_access: Vec<String>,
}
/// Result of a policy evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
/// Access is allowed.
Allow,
/// Access is denied, with reason.
Deny(String),
}
impl PolicyDecision {
pub fn is_allowed(&self) -> bool {
matches!(self, PolicyDecision::Allow)
}
}
/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions.
///
/// Deny-by-default: if no role mapping grants access, the request is denied.
#[derive(Debug, Clone)]
pub struct IamPolicy {
/// Compiled role mappings indexed by lowercase Nevis role name.
role_map: HashMap<String, CompiledRole>,
}
#[derive(Debug, Clone)]
struct CompiledRole {
/// Whether this role has access to all tools.
all_tools: bool,
/// Specific tool names this role can access (lowercase).
allowed_tools: Vec<String>,
/// Whether this role has access to all workspaces.
all_workspaces: bool,
/// Specific workspace names this role can access (lowercase).
allowed_workspaces: Vec<String>,
}
impl IamPolicy {
/// Build a policy from role mappings (typically from config).
///
/// Returns an error if duplicate normalized role names are detected,
/// since silent last-wins overwrites can accidentally broaden or revoke access.
pub fn from_mappings(mappings: &[RoleMapping]) -> Result<Self> {
let mut role_map = HashMap::new();
for mapping in mappings {
let key = mapping.nevis_role.trim().to_ascii_lowercase();
if key.is_empty() {
continue;
}
let all_tools = mapping
.zeroclaw_permissions
.iter()
.any(|p| p.eq_ignore_ascii_case("all"));
let allowed_tools: Vec<String> = mapping
.zeroclaw_permissions
.iter()
.filter(|p| !p.eq_ignore_ascii_case("all"))
.map(|p| p.trim().to_ascii_lowercase())
.collect();
let all_workspaces = mapping
.workspace_access
.iter()
.any(|w| w.eq_ignore_ascii_case("all"));
let allowed_workspaces: Vec<String> = mapping
.workspace_access
.iter()
.filter(|w| !w.eq_ignore_ascii_case("all"))
.map(|w| w.trim().to_ascii_lowercase())
.collect();
if role_map.contains_key(&key) {
bail!(
"IAM policy: duplicate role mapping for normalized key '{}' \
(from nevis_role '{}') remove or merge the duplicate entry",
key,
mapping.nevis_role
);
}
role_map.insert(
key,
CompiledRole {
all_tools,
allowed_tools,
all_workspaces,
allowed_workspaces,
},
);
}
Ok(Self { role_map })
}
/// Evaluate whether an identity is allowed to use a specific tool.
///
/// Deny-by-default: returns `Deny` unless at least one of the identity's
/// roles grants access to the requested tool.
pub fn evaluate_tool_access(
&self,
identity: &NevisIdentity,
tool_name: &str,
) -> PolicyDecision {
let normalized_tool = tool_name.trim().to_ascii_lowercase();
if normalized_tool.is_empty() {
return PolicyDecision::Deny("empty tool name".into());
}
for role in &identity.roles {
let key = role.trim().to_ascii_lowercase();
if let Some(compiled) = self.role_map.get(&key) {
if compiled.all_tools
|| compiled.allowed_tools.iter().any(|t| t == &normalized_tool)
{
tracing::info!(
user_id = %crate::security::redact(&identity.user_id),
role = %key,
tool = %normalized_tool,
"IAM policy: tool access ALLOWED"
);
return PolicyDecision::Allow;
}
}
}
let reason = format!(
"no role grants access to tool '{normalized_tool}' for user '{}'",
crate::security::redact(&identity.user_id)
);
tracing::info!(
user_id = %crate::security::redact(&identity.user_id),
tool = %normalized_tool,
"IAM policy: tool access DENIED"
);
PolicyDecision::Deny(reason)
}
/// Evaluate whether an identity is allowed to access a specific workspace.
///
/// Deny-by-default: returns `Deny` unless at least one of the identity's
/// roles grants access to the requested workspace.
pub fn evaluate_workspace_access(
&self,
identity: &NevisIdentity,
workspace: &str,
) -> PolicyDecision {
let normalized_ws = workspace.trim().to_ascii_lowercase();
if normalized_ws.is_empty() {
return PolicyDecision::Deny("empty workspace name".into());
}
for role in &identity.roles {
let key = role.trim().to_ascii_lowercase();
if let Some(compiled) = self.role_map.get(&key) {
if compiled.all_workspaces
|| compiled
.allowed_workspaces
.iter()
.any(|w| w == &normalized_ws)
{
tracing::info!(
user_id = %crate::security::redact(&identity.user_id),
role = %key,
workspace = %normalized_ws,
"IAM policy: workspace access ALLOWED"
);
return PolicyDecision::Allow;
}
}
}
let reason = format!(
"no role grants access to workspace '{normalized_ws}' for user '{}'",
crate::security::redact(&identity.user_id)
);
tracing::info!(
user_id = %crate::security::redact(&identity.user_id),
workspace = %normalized_ws,
"IAM policy: workspace access DENIED"
);
PolicyDecision::Deny(reason)
}
/// Check if the policy has any role mappings configured.
pub fn is_empty(&self) -> bool {
self.role_map.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_mappings() -> Vec<RoleMapping> {
vec![
RoleMapping {
nevis_role: "admin".into(),
zeroclaw_permissions: vec!["all".into()],
workspace_access: vec!["all".into()],
},
RoleMapping {
nevis_role: "operator".into(),
zeroclaw_permissions: vec![
"shell".into(),
"file_read".into(),
"file_write".into(),
"memory_search".into(),
],
workspace_access: vec!["production".into(), "staging".into()],
},
RoleMapping {
nevis_role: "viewer".into(),
zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()],
workspace_access: vec!["staging".into()],
},
]
}
fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity {
NevisIdentity {
user_id: "zeroclaw_user".into(),
roles: roles.into_iter().map(String::from).collect(),
scopes: vec!["openid".into()],
mfa_verified: true,
session_expiry: u64::MAX,
}
}
#[test]
fn admin_gets_all_tools() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["admin"]);
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
assert!(policy
.evaluate_tool_access(&identity, "file_read")
.is_allowed());
assert!(policy
.evaluate_tool_access(&identity, "any_tool_name")
.is_allowed());
}
#[test]
fn admin_gets_all_workspaces() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["admin"]);
assert!(policy
.evaluate_workspace_access(&identity, "production")
.is_allowed());
assert!(policy
.evaluate_workspace_access(&identity, "any_workspace")
.is_allowed());
}
#[test]
fn operator_gets_subset_of_tools() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["operator"]);
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
assert!(policy
.evaluate_tool_access(&identity, "file_read")
.is_allowed());
assert!(!policy
.evaluate_tool_access(&identity, "browser")
.is_allowed());
}
#[test]
fn operator_workspace_access_is_scoped() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["operator"]);
assert!(policy
.evaluate_workspace_access(&identity, "production")
.is_allowed());
assert!(policy
.evaluate_workspace_access(&identity, "staging")
.is_allowed());
assert!(!policy
.evaluate_workspace_access(&identity, "development")
.is_allowed());
}
#[test]
fn viewer_is_read_only() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["viewer"]);
assert!(policy
.evaluate_tool_access(&identity, "file_read")
.is_allowed());
assert!(policy
.evaluate_tool_access(&identity, "memory_search")
.is_allowed());
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
assert!(!policy
.evaluate_tool_access(&identity, "file_write")
.is_allowed());
}
#[test]
fn deny_by_default_for_unknown_role() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["unknown_role"]);
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
assert!(!policy
.evaluate_workspace_access(&identity, "production")
.is_allowed());
}
#[test]
fn deny_by_default_for_no_roles() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec![]);
assert!(!policy
.evaluate_tool_access(&identity, "file_read")
.is_allowed());
}
#[test]
fn multiple_roles_union_permissions() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["viewer", "operator"]);
// viewer has file_read, operator has shell — both should be accessible
assert!(policy
.evaluate_tool_access(&identity, "file_read")
.is_allowed());
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
}
#[test]
fn role_matching_is_case_insensitive() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["ADMIN"]);
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
}
#[test]
fn tool_matching_is_case_insensitive() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["operator"]);
assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed());
assert!(policy
.evaluate_tool_access(&identity, "File_Read")
.is_allowed());
}
#[test]
fn empty_tool_name_is_denied() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["admin"]);
assert!(!policy.evaluate_tool_access(&identity, "").is_allowed());
assert!(!policy.evaluate_tool_access(&identity, " ").is_allowed());
}
#[test]
fn empty_workspace_name_is_denied() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["admin"]);
assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed());
}
#[test]
fn empty_mappings_deny_everything() {
let policy = IamPolicy::from_mappings(&[]).unwrap();
let identity = identity_with_roles(vec!["admin"]);
assert!(policy.is_empty());
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
}
#[test]
fn policy_decision_deny_contains_reason() {
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
let identity = identity_with_roles(vec!["viewer"]);
let decision = policy.evaluate_tool_access(&identity, "shell");
match decision {
PolicyDecision::Deny(reason) => {
assert!(reason.contains("shell"));
}
PolicyDecision::Allow => panic!("expected deny"),
}
}
#[test]
fn duplicate_normalized_roles_are_rejected() {
let mappings = vec![
RoleMapping {
nevis_role: "admin".into(),
zeroclaw_permissions: vec!["all".into()],
workspace_access: vec!["all".into()],
},
RoleMapping {
nevis_role: " ADMIN ".into(),
zeroclaw_permissions: vec!["file_read".into()],
workspace_access: vec![],
},
];
let err = IamPolicy::from_mappings(&mappings).unwrap_err();
assert!(
err.to_string().contains("duplicate role mapping"),
"Expected duplicate role error, got: {err}"
);
}
#[test]
fn empty_role_name_in_mapping_is_skipped() {
let mappings = vec![RoleMapping {
nevis_role: " ".into(),
zeroclaw_permissions: vec!["all".into()],
workspace_access: vec![],
}];
let policy = IamPolicy::from_mappings(&mappings).unwrap();
assert!(policy.is_empty());
}
}
+27 -3
View File
@@ -29,15 +29,20 @@ pub mod domain_matcher;
pub mod estop;
#[cfg(target_os = "linux")]
pub mod firejail;
pub mod iam_policy;
#[cfg(feature = "sandbox-landlock")]
pub mod landlock;
pub mod leak_detector;
pub mod nevis;
pub mod otp;
pub mod pairing;
pub mod playbook;
pub mod policy;
pub mod prompt_guard;
pub mod secrets;
pub mod traits;
pub mod vulnerability;
pub mod workspace_boundary;
#[allow(unused_imports)]
pub use audit::{AuditEvent, AuditEventType, AuditLogger};
@@ -55,19 +60,29 @@ pub use policy::{AutonomyLevel, SecurityPolicy};
pub use secrets::SecretStore;
#[allow(unused_imports)]
pub use traits::{NoopSandbox, Sandbox};
// Nevis IAM integration
#[allow(unused_imports)]
pub use iam_policy::{IamPolicy, PolicyDecision};
#[allow(unused_imports)]
pub use nevis::{NevisAuthProvider, NevisIdentity};
// Prompt injection defense exports
#[allow(unused_imports)]
pub use leak_detector::{LeakDetector, LeakResult};
#[allow(unused_imports)]
pub use prompt_guard::{GuardAction, GuardResult, PromptGuard};
#[allow(unused_imports)]
pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary};
/// Redact sensitive values for safe logging. Shows first 4 characters + "***" suffix.
/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings.
/// This function intentionally breaks the data-flow taint chain for static analysis.
pub fn redact(value: &str) -> String {
let char_count = value.chars().count();
if char_count <= 4 {
"***".to_string()
} else {
format!("{}***", &value[..4])
let prefix: String = value.chars().take(4).collect();
format!("{prefix}***")
}
}
@@ -102,4 +117,13 @@ mod tests {
assert_eq!(redact(""), "***");
assert_eq!(redact("12345"), "1234***");
}
#[test]
fn redact_handles_multibyte_utf8_without_panic() {
// CJK characters are 3 bytes each; slicing at byte 4 would panic
// without char-boundary-safe handling.
let result = redact("密码是很长的秘密");
assert!(result.ends_with("***"));
assert!(result.is_char_boundary(result.len()));
}
}
+587
View File
@@ -0,0 +1,587 @@
//! Nevis IAM authentication provider for ZeroClaw.
//!
//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token
//! validation, FIDO2/passkey verification, and session management. Maps Nevis
//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`].
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Identity resolved from a validated Nevis token or session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NevisIdentity {
/// Unique user identifier from Nevis.
pub user_id: String,
/// Nevis roles assigned to this user.
pub roles: Vec<String>,
/// OAuth2 scopes granted to this session.
pub scopes: Vec<String>,
/// Whether the user completed MFA (FIDO2/passkey/OTP) in this session.
pub mfa_verified: bool,
/// When this session expires (seconds since UNIX epoch).
pub session_expiry: u64,
}
/// Token validation strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationMode {
/// Validate JWT locally using cached JWKS keys.
Local,
/// Validate token by calling the Nevis introspection endpoint.
Remote,
}
impl TokenValidationMode {
pub fn from_str_config(s: &str) -> Result<Self> {
match s.to_ascii_lowercase().as_str() {
"local" => Ok(Self::Local),
"remote" => Ok(Self::Remote),
other => bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'"),
}
}
}
/// Authentication provider backed by a Nevis instance.
///
/// Validates tokens, manages sessions, and resolves identities. The provider
/// is designed to be shared across concurrent requests (`Send + Sync`).
pub struct NevisAuthProvider {
/// Base URL of the Nevis instance (e.g. `https://nevis.example.com`).
instance_url: String,
/// Nevis realm to authenticate against.
realm: String,
/// OAuth2 client ID registered in Nevis.
client_id: String,
/// OAuth2 client secret (decrypted at startup).
client_secret: Option<String>,
/// Token validation strategy.
validation_mode: TokenValidationMode,
/// JWKS endpoint for local token validation.
jwks_url: Option<String>,
/// Whether MFA is required for all authentications.
require_mfa: bool,
/// Session timeout duration.
session_timeout: Duration,
/// HTTP client for Nevis API calls.
http_client: reqwest::Client,
}
impl std::fmt::Debug for NevisAuthProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NevisAuthProvider")
.field("instance_url", &self.instance_url)
.field("realm", &self.realm)
.field("client_id", &self.client_id)
.field(
"client_secret",
&self.client_secret.as_ref().map(|_| "[REDACTED]"),
)
.field("validation_mode", &self.validation_mode)
.field("jwks_url", &self.jwks_url)
.field("require_mfa", &self.require_mfa)
.field("session_timeout", &self.session_timeout)
.finish_non_exhaustive()
}
}
// Safety: All fields are Send + Sync. The doc comment promises concurrent use,
// so enforce it at compile time to prevent regressions.
#[allow(clippy::used_underscore_items)]
const _: () = {
fn _assert_send_sync<T: Send + Sync>() {}
fn _assert() {
_assert_send_sync::<NevisAuthProvider>();
}
};
impl NevisAuthProvider {
/// Create a new Nevis auth provider from config values.
///
/// `client_secret` should already be decrypted by the config loader.
pub fn new(
instance_url: String,
realm: String,
client_id: String,
client_secret: Option<String>,
token_validation: &str,
jwks_url: Option<String>,
require_mfa: bool,
session_timeout_secs: u64,
) -> Result<Self> {
let validation_mode = TokenValidationMode::from_str_config(token_validation)?;
if validation_mode == TokenValidationMode::Local && jwks_url.is_none() {
bail!(
"Nevis token_validation is 'local' but no jwks_url is configured. \
Either set jwks_url or use token_validation = 'remote'."
);
}
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.context("Failed to create HTTP client for Nevis")?;
Ok(Self {
instance_url,
realm,
client_id,
client_secret,
validation_mode,
jwks_url,
require_mfa,
session_timeout: Duration::from_secs(session_timeout_secs),
http_client,
})
}
/// Validate a bearer token and resolve the caller's identity.
///
/// Returns `NevisIdentity` on success, or an error if the token is invalid,
/// expired, or MFA requirements are not met.
pub async fn validate_token(&self, token: &str) -> Result<NevisIdentity> {
if token.is_empty() {
bail!("empty bearer token");
}
let identity = match self.validation_mode {
TokenValidationMode::Local => self.validate_token_local(token).await?,
TokenValidationMode::Remote => self.validate_token_remote(token).await?,
};
if self.require_mfa && !identity.mfa_verified {
bail!(
"MFA is required but user '{}' has not completed MFA verification",
crate::security::redact(&identity.user_id)
);
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if identity.session_expiry > 0 && identity.session_expiry < now {
bail!("Nevis session expired");
}
Ok(identity)
}
/// Validate token by calling the Nevis introspection endpoint.
async fn validate_token_remote(&self, token: &str) -> Result<NevisIdentity> {
let introspect_url = format!(
"{}/auth/realms/{}/protocol/openid-connect/token/introspect",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let mut form = vec![("token", token), ("client_id", &self.client_id)];
// client_secret is optional (public clients don't need it)
if let Some(secret) = self.client_secret.as_deref() {
form.push(("client_secret", secret));
}
let resp = self
.http_client
.post(&introspect_url)
.form(&form)
.send()
.await
.context("Failed to reach Nevis introspection endpoint")?;
if !resp.status().is_success() {
bail!(
"Nevis introspection returned HTTP {}",
resp.status().as_u16()
);
}
let body: IntrospectionResponse = resp
.json()
.await
.context("Failed to parse Nevis introspection response")?;
if !body.active {
bail!("Token is not active (revoked or expired)");
}
let user_id = body
.sub
.filter(|s| !s.trim().is_empty())
.context("Token has missing or empty `sub` claim")?;
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
roles.sort();
roles.dedup();
Ok(NevisIdentity {
user_id,
roles,
scopes: body
.scope
.unwrap_or_default()
.split_whitespace()
.map(String::from)
.collect(),
mfa_verified: body.acr.as_deref() == Some("mfa")
|| body
.amr
.iter()
.flatten()
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
session_expiry: body.exp.unwrap_or(0),
})
}
/// Validate token locally using JWKS.
///
/// Local JWT/JWKS validation is not yet implemented. Rather than silently
/// falling back to the remote introspection endpoint (which would hide a
/// misconfiguration), this returns an explicit error directing the operator
/// to use `token_validation = "remote"` until local JWKS support is added.
#[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented
async fn validate_token_local(&self, token: &str) -> Result<NevisIdentity> {
// JWT structure check: header.payload.signature
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
bail!("Invalid JWT structure: expected 3 dot-separated parts");
}
bail!(
"Local JWKS token validation is not yet implemented. \
Set token_validation = \"remote\" to use the Nevis introspection endpoint."
);
}
/// Validate a Nevis session token (cookie-based sessions).
pub async fn validate_session(&self, session_token: &str) -> Result<NevisIdentity> {
if session_token.is_empty() {
bail!("empty session token");
}
let session_url = format!(
"{}/auth/realms/{}/protocol/openid-connect/userinfo",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let resp = self
.http_client
.get(&session_url)
.bearer_auth(session_token)
.send()
.await
.context("Failed to reach Nevis userinfo endpoint")?;
if !resp.status().is_success() {
bail!(
"Nevis session validation returned HTTP {}",
resp.status().as_u16()
);
}
let body: UserInfoResponse = resp
.json()
.await
.context("Failed to parse Nevis userinfo response")?;
if body.sub.trim().is_empty() {
bail!("Userinfo response has missing or empty `sub` claim");
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
roles.sort();
roles.dedup();
let identity = NevisIdentity {
user_id: body.sub,
roles,
scopes: body
.scope
.unwrap_or_default()
.split_whitespace()
.map(String::from)
.collect(),
mfa_verified: body.acr.as_deref() == Some("mfa")
|| body
.amr
.iter()
.flatten()
.any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
session_expiry: now + self.session_timeout.as_secs(),
};
if self.require_mfa && !identity.mfa_verified {
bail!(
"MFA is required but user '{}' has not completed MFA verification",
crate::security::redact(&identity.user_id)
);
}
Ok(identity)
}
/// Health check against the Nevis instance.
pub async fn health_check(&self) -> Result<()> {
let health_url = format!(
"{}/auth/realms/{}",
self.instance_url.trim_end_matches('/'),
self.realm,
);
let resp = self
.http_client
.get(&health_url)
.send()
.await
.context("Nevis health check failed: cannot reach instance")?;
if !resp.status().is_success() {
bail!("Nevis health check failed: HTTP {}", resp.status().as_u16());
}
Ok(())
}
/// Getter for instance URL (for diagnostics).
pub fn instance_url(&self) -> &str {
&self.instance_url
}
/// Getter for realm.
pub fn realm(&self) -> &str {
&self.realm
}
}
// ── Wire types for Nevis API responses ─────────────────────────────
#[derive(Debug, Deserialize)]
struct IntrospectionResponse {
active: bool,
sub: Option<String>,
scope: Option<String>,
exp: Option<u64>,
#[serde(rename = "realm_access")]
realm_access: Option<RealmAccess>,
/// Authentication Context Class Reference
acr: Option<String>,
/// Authentication Methods References
amr: Option<Vec<String>>,
}
#[derive(Debug, Deserialize)]
struct RealmAccess {
#[serde(default)]
roles: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct UserInfoResponse {
sub: String,
#[serde(rename = "realm_access")]
realm_access: Option<RealmAccess>,
scope: Option<String>,
acr: Option<String>,
/// Authentication Methods References
amr: Option<Vec<String>>,
}
// ── Tests ──────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn token_validation_mode_from_str() {
assert_eq!(
TokenValidationMode::from_str_config("local").unwrap(),
TokenValidationMode::Local
);
assert_eq!(
TokenValidationMode::from_str_config("REMOTE").unwrap(),
TokenValidationMode::Remote
);
assert!(TokenValidationMode::from_str_config("invalid").is_err());
}
#[test]
fn local_mode_requires_jwks_url() {
let result = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"local",
None, // no JWKS URL
false,
3600,
);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("jwks_url"));
}
#[test]
fn remote_mode_works_without_jwks_url() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"remote",
None,
false,
3600,
);
assert!(provider.is_ok());
}
#[test]
fn provider_stores_config_correctly() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"test-realm".into(),
"zeroclaw-client".into(),
Some("test-secret".into()),
"remote",
None,
true,
7200,
)
.unwrap();
assert_eq!(provider.instance_url(), "https://nevis.example.com");
assert_eq!(provider.realm(), "test-realm");
assert!(provider.require_mfa);
assert_eq!(provider.session_timeout, Duration::from_secs(7200));
}
#[test]
fn debug_redacts_client_secret() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"test-realm".into(),
"zeroclaw-client".into(),
Some("super-secret-value".into()),
"remote",
None,
false,
3600,
)
.unwrap();
let debug_output = format!("{:?}", provider);
assert!(
!debug_output.contains("super-secret-value"),
"Debug output must not contain the raw client_secret"
);
assert!(
debug_output.contains("[REDACTED]"),
"Debug output must show [REDACTED] for client_secret"
);
}
#[tokio::test]
async fn validate_token_rejects_empty() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"remote",
None,
false,
3600,
)
.unwrap();
let err = provider.validate_token("").await.unwrap_err();
assert!(err.to_string().contains("empty bearer token"));
}
#[tokio::test]
async fn validate_session_rejects_empty() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"remote",
None,
false,
3600,
)
.unwrap();
let err = provider.validate_session("").await.unwrap_err();
assert!(err.to_string().contains("empty session token"));
}
#[test]
fn nevis_identity_serde_roundtrip() {
let identity = NevisIdentity {
user_id: "zeroclaw_user".into(),
roles: vec!["admin".into(), "operator".into()],
scopes: vec!["openid".into(), "profile".into()],
mfa_verified: true,
session_expiry: 1_700_000_000,
};
let json = serde_json::to_string(&identity).unwrap();
let parsed: NevisIdentity = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.user_id, "zeroclaw_user");
assert_eq!(parsed.roles.len(), 2);
assert!(parsed.mfa_verified);
}
#[tokio::test]
async fn local_validation_rejects_malformed_jwt() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"local",
Some("https://nevis.example.com/.well-known/jwks.json".into()),
false,
3600,
)
.unwrap();
let err = provider.validate_token("not-a-jwt").await.unwrap_err();
assert!(err.to_string().contains("Invalid JWT structure"));
}
#[tokio::test]
async fn local_validation_errors_instead_of_silent_fallback() {
let provider = NevisAuthProvider::new(
"https://nevis.example.com".into(),
"master".into(),
"zeroclaw-client".into(),
None,
"local",
Some("https://nevis.example.com/.well-known/jwks.json".into()),
false,
3600,
)
.unwrap();
// A well-formed JWT structure should hit the "not yet implemented" error
// instead of silently falling back to remote introspection.
let err = provider
.validate_token("header.payload.signature")
.await
.unwrap_err();
assert!(err.to_string().contains("not yet implemented"));
}
}
+459
View File
@@ -0,0 +1,459 @@
//! Incident response playbook definitions and execution engine.
//!
//! Playbooks define structured response procedures for security incidents.
//! Each playbook has named steps, some of which require human approval before
//! execution. Playbooks are loaded from JSON files in the configured directory.
use serde::{Deserialize, Serialize};
use std::path::Path;
/// A single step in an incident response playbook.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlaybookStep {
/// Machine-readable action identifier (e.g. "isolate_host", "block_ip").
pub action: String,
/// Human-readable description of what this step does.
pub description: String,
/// Whether this step requires explicit human approval before execution.
#[serde(default)]
pub requires_approval: bool,
/// Timeout in seconds for this step. Default: 300 (5 minutes).
#[serde(default = "default_timeout_secs")]
pub timeout_secs: u64,
}
fn default_timeout_secs() -> u64 {
300
}
/// An incident response playbook.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Playbook {
/// Unique playbook name (e.g. "suspicious_login").
pub name: String,
/// Human-readable description.
pub description: String,
/// Ordered list of response steps.
pub steps: Vec<PlaybookStep>,
/// Minimum alert severity that triggers this playbook (low/medium/high/critical).
#[serde(default = "default_severity_filter")]
pub severity_filter: String,
/// Step indices (0-based) that can be auto-approved when below max_auto_severity.
#[serde(default)]
pub auto_approve_steps: Vec<usize>,
}
fn default_severity_filter() -> String {
"medium".into()
}
/// Result of executing a single playbook step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepExecutionResult {
pub step_index: usize,
pub action: String,
pub status: StepStatus,
pub message: String,
}
/// Status of a playbook step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StepStatus {
/// Step completed successfully.
Completed,
/// Step is waiting for human approval.
PendingApproval,
/// Step was skipped (e.g. not applicable).
Skipped,
/// Step failed with an error.
Failed,
}
impl std::fmt::Display for StepStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Completed => write!(f, "completed"),
Self::PendingApproval => write!(f, "pending_approval"),
Self::Skipped => write!(f, "skipped"),
Self::Failed => write!(f, "failed"),
}
}
}
/// Load all playbook definitions from a directory of JSON files.
pub fn load_playbooks(dir: &Path) -> Vec<Playbook> {
let mut playbooks = Vec::new();
if !dir.exists() || !dir.is_dir() {
return builtin_playbooks();
}
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map_or(false, |ext| ext == "json") {
match std::fs::read_to_string(&path) {
Ok(contents) => match serde_json::from_str::<Playbook>(&contents) {
Ok(pb) => playbooks.push(pb),
Err(e) => {
tracing::warn!("Failed to parse playbook {}: {e}", path.display());
}
},
Err(e) => {
tracing::warn!("Failed to read playbook {}: {e}", path.display());
}
}
}
}
}
// Merge built-in playbooks that aren't overridden by user-defined ones
for builtin in builtin_playbooks() {
if !playbooks.iter().any(|p| p.name == builtin.name) {
playbooks.push(builtin);
}
}
playbooks
}
/// Severity ordering for comparison: low < medium < high < critical.
pub fn severity_level(severity: &str) -> u8 {
match severity.to_lowercase().as_str() {
"low" => 1,
"medium" => 2,
"high" => 3,
"critical" => 4,
// Deny-by-default: unknown severities get the highest level to prevent
// auto-approval of unrecognized severity labels.
_ => u8::MAX,
}
}
/// Check whether a step can be auto-approved given config constraints.
pub fn can_auto_approve(
playbook: &Playbook,
step_index: usize,
alert_severity: &str,
max_auto_severity: &str,
) -> bool {
// An unrecognized max_auto_severity maps to u8::MAX, which would permit
// every alert level; deny-by-default treats it as "no auto-approval".
if severity_level(max_auto_severity) == u8::MAX {
return false;
}
// Never auto-approve if alert severity exceeds the configured max
if severity_level(alert_severity) > severity_level(max_auto_severity) {
return false;
}
// Only auto-approve steps explicitly listed in auto_approve_steps
playbook.auto_approve_steps.contains(&step_index)
}
/// Evaluate a playbook step. Returns the result with approval gating.
///
/// Steps that require approval and cannot be auto-approved will return
/// `StepStatus::PendingApproval` without executing.
pub fn evaluate_step(
playbook: &Playbook,
step_index: usize,
alert_severity: &str,
max_auto_severity: &str,
require_approval: bool,
) -> StepExecutionResult {
let step = match playbook.steps.get(step_index) {
Some(s) => s,
None => {
return StepExecutionResult {
step_index,
action: "unknown".into(),
status: StepStatus::Failed,
message: format!("Step index {step_index} out of range"),
};
}
};
// Enforce approval gates: steps that require approval must either be
// auto-approved or wait for human approval. Never mark an unexecuted
// approval-gated step as Completed.
if step.requires_approval
&& (!require_approval
|| !can_auto_approve(playbook, step_index, alert_severity, max_auto_severity))
{
return StepExecutionResult {
step_index,
action: step.action.clone(),
status: StepStatus::PendingApproval,
message: format!(
"Step '{}' requires human approval (severity: {alert_severity})",
step.description
),
};
}
// Step is approved (either doesn't require approval, or was auto-approved)
// Actual execution would be delegated to the appropriate tool/system
StepExecutionResult {
step_index,
action: step.action.clone(),
status: StepStatus::Completed,
message: format!("Executed: {}", step.description),
}
}
/// Built-in playbook definitions for common incident types.
pub fn builtin_playbooks() -> Vec<Playbook> {
vec![
Playbook {
name: "suspicious_login".into(),
description: "Respond to suspicious login activity detected by SIEM".into(),
steps: vec![
PlaybookStep {
action: "gather_login_context".into(),
description: "Collect login metadata: IP, geo, device fingerprint, time".into(),
requires_approval: false,
timeout_secs: 60,
},
PlaybookStep {
action: "check_threat_intel".into(),
description: "Query threat intelligence for source IP reputation".into(),
requires_approval: false,
timeout_secs: 30,
},
PlaybookStep {
action: "notify_user".into(),
description: "Send verification notification to account owner".into(),
requires_approval: true,
timeout_secs: 300,
},
PlaybookStep {
action: "force_password_reset".into(),
description: "Force password reset if login confirmed unauthorized".into(),
requires_approval: true,
timeout_secs: 120,
},
],
severity_filter: "medium".into(),
auto_approve_steps: vec![0, 1],
},
Playbook {
name: "malware_detected".into(),
description: "Respond to malware detection on endpoint".into(),
steps: vec![
PlaybookStep {
action: "isolate_endpoint".into(),
description: "Network-isolate the affected endpoint".into(),
requires_approval: true,
timeout_secs: 60,
},
PlaybookStep {
action: "collect_forensics".into(),
description: "Capture memory dump and disk image for analysis".into(),
requires_approval: false,
timeout_secs: 600,
},
PlaybookStep {
action: "scan_lateral_movement".into(),
description: "Check for lateral movement indicators on adjacent hosts".into(),
requires_approval: false,
timeout_secs: 300,
},
PlaybookStep {
action: "remediate_endpoint".into(),
description: "Remove malware and restore endpoint to clean state".into(),
requires_approval: true,
timeout_secs: 600,
},
],
severity_filter: "high".into(),
auto_approve_steps: vec![1, 2],
},
Playbook {
name: "data_exfiltration_attempt".into(),
description: "Respond to suspected data exfiltration".into(),
steps: vec![
PlaybookStep {
action: "block_egress".into(),
description: "Block suspicious outbound connections".into(),
requires_approval: true,
timeout_secs: 30,
},
PlaybookStep {
action: "identify_data_scope".into(),
description: "Determine what data may have been accessed or transferred".into(),
requires_approval: false,
timeout_secs: 300,
},
PlaybookStep {
action: "preserve_evidence".into(),
description: "Preserve network logs and access records".into(),
requires_approval: false,
timeout_secs: 120,
},
PlaybookStep {
action: "escalate_to_legal".into(),
description: "Notify legal and compliance teams".into(),
requires_approval: true,
timeout_secs: 60,
},
],
severity_filter: "critical".into(),
auto_approve_steps: vec![1, 2],
},
Playbook {
name: "brute_force".into(),
description: "Respond to brute force authentication attempts".into(),
steps: vec![
PlaybookStep {
action: "block_source_ip".into(),
description: "Block the attacking source IP at firewall".into(),
requires_approval: true,
timeout_secs: 30,
},
PlaybookStep {
action: "check_compromised_accounts".into(),
description: "Check if any accounts were successfully compromised".into(),
requires_approval: false,
timeout_secs: 120,
},
PlaybookStep {
action: "enable_rate_limiting".into(),
description: "Enable enhanced rate limiting on auth endpoints".into(),
requires_approval: true,
timeout_secs: 60,
},
],
severity_filter: "medium".into(),
auto_approve_steps: vec![1],
},
]
}
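As the loader tests below show, custom playbooks are plain JSON files in the playbooks directory, merged with the builtins and overriding them by name. A minimal sketch of such a file, reusing the serde field names above (the file name and values are illustrative):

```json
{
  "name": "suspicious_login",
  "description": "Org-specific override of the builtin playbook",
  "steps": [
    {
      "action": "page_oncall",
      "description": "Page the on-call analyst via the paging system",
      "requires_approval": false,
      "timeout_secs": 60
    }
  ],
  "severity_filter": "medium",
  "auto_approve_steps": [0]
}
```

Because the `name` matches a builtin, this file would replace the builtin `suspicious_login` playbook entirely rather than being appended.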
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn builtin_playbooks_are_valid() {
let playbooks = builtin_playbooks();
assert_eq!(playbooks.len(), 4);
let names: Vec<&str> = playbooks.iter().map(|p| p.name.as_str()).collect();
assert!(names.contains(&"suspicious_login"));
assert!(names.contains(&"malware_detected"));
assert!(names.contains(&"data_exfiltration_attempt"));
assert!(names.contains(&"brute_force"));
for pb in &playbooks {
assert!(!pb.steps.is_empty(), "Playbook {} has no steps", pb.name);
assert!(!pb.description.is_empty());
}
}
#[test]
fn severity_level_ordering() {
assert!(severity_level("low") < severity_level("medium"));
assert!(severity_level("medium") < severity_level("high"));
assert!(severity_level("high") < severity_level("critical"));
assert_eq!(severity_level("unknown"), u8::MAX);
}
#[test]
fn auto_approve_respects_severity_cap() {
let pb = &builtin_playbooks()[0]; // suspicious_login
// Step 0 is in auto_approve_steps
assert!(can_auto_approve(pb, 0, "low", "low"));
assert!(can_auto_approve(pb, 0, "low", "medium"));
// Alert severity exceeds max -> cannot auto-approve
assert!(!can_auto_approve(pb, 0, "high", "low"));
assert!(!can_auto_approve(pb, 0, "critical", "medium"));
// Step 2 is NOT in auto_approve_steps
assert!(!can_auto_approve(pb, 2, "low", "critical"));
}
#[test]
fn evaluate_step_requires_approval() {
let pb = &builtin_playbooks()[0]; // suspicious_login
// Step 2 (notify_user) requires approval, high severity, max=low -> pending
let result = evaluate_step(pb, 2, "high", "low", true);
assert_eq!(result.status, StepStatus::PendingApproval);
assert_eq!(result.action, "notify_user");
// Step 0 (gather_login_context) does NOT require approval -> completed
let result = evaluate_step(pb, 0, "high", "low", true);
assert_eq!(result.status, StepStatus::Completed);
}
#[test]
fn evaluate_step_out_of_range() {
let pb = &builtin_playbooks()[0];
let result = evaluate_step(pb, 99, "low", "low", true);
assert_eq!(result.status, StepStatus::Failed);
}
#[test]
fn playbook_json_roundtrip() {
let pb = &builtin_playbooks()[0];
let json = serde_json::to_string(pb).unwrap();
let parsed: Playbook = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, *pb);
}
#[test]
fn load_playbooks_from_nonexistent_dir_returns_builtins() {
let playbooks = load_playbooks(Path::new("/nonexistent/dir"));
assert_eq!(playbooks.len(), 4);
}
#[test]
fn load_playbooks_merges_custom_and_builtin() {
let dir = tempfile::tempdir().unwrap();
let custom = Playbook {
name: "custom_playbook".into(),
description: "A custom playbook".into(),
steps: vec![PlaybookStep {
action: "custom_action".into(),
description: "Do something custom".into(),
requires_approval: true,
timeout_secs: 60,
}],
severity_filter: "low".into(),
auto_approve_steps: vec![],
};
let json = serde_json::to_string(&custom).unwrap();
std::fs::write(dir.path().join("custom.json"), json).unwrap();
let playbooks = load_playbooks(dir.path());
// 4 builtins + 1 custom
assert_eq!(playbooks.len(), 5);
assert!(playbooks.iter().any(|p| p.name == "custom_playbook"));
}
#[test]
fn load_playbooks_custom_overrides_builtin() {
let dir = tempfile::tempdir().unwrap();
let override_pb = Playbook {
name: "suspicious_login".into(),
description: "Custom override".into(),
steps: vec![PlaybookStep {
action: "custom_step".into(),
description: "Overridden step".into(),
requires_approval: false,
timeout_secs: 30,
}],
severity_filter: "low".into(),
auto_approve_steps: vec![0],
};
let json = serde_json::to_string(&override_pb).unwrap();
std::fs::write(dir.path().join("suspicious_login.json"), json).unwrap();
let playbooks = load_playbooks(dir.path());
// 3 remaining builtins + 1 overridden = 4
assert_eq!(playbooks.len(), 4);
let sl = playbooks
.iter()
.find(|p| p.name == "suspicious_login")
.unwrap();
assert_eq!(sl.description, "Custom override");
}
}
+144 -3
@@ -793,6 +793,8 @@ impl SecurityPolicy {
// 1. Allowlist check (is the base command permitted at all?)
// 2. Risk classification (high / medium / low)
// 3. Policy flags (block_high_risk_commands, require_approval_for_medium_risk)
// — explicit allowlist entries exempt a command from the high-risk block,
// but the wildcard "*" does NOT grant an exemption.
// 4. Autonomy level × approval status (supervised requires explicit approval)
// This ordering ensures deny-by-default: unknown commands are rejected
// before any risk or autonomy logic runs.
@@ -810,7 +812,7 @@ impl SecurityPolicy {
let risk = self.command_risk_level(command);
if risk == CommandRiskLevel::High {
if self.block_high_risk_commands {
if self.block_high_risk_commands && !self.is_command_explicitly_allowed(command) {
return Err("Command blocked: high-risk command is disallowed by policy".into());
}
if self.autonomy == AutonomyLevel::Supervised && !approved {
@@ -834,6 +836,48 @@ impl SecurityPolicy {
Ok(risk)
}
/// Check whether **every** segment of a command is explicitly listed in
/// `allowed_commands` — i.e., matched by a concrete entry rather than by
/// the wildcard `"*"`.
///
/// This is used to exempt explicitly-allowlisted high-risk commands from
/// the `block_high_risk_commands` gate. The wildcard entry intentionally
/// does **not** qualify as an explicit allowlist match, so that operators
/// who set `allowed_commands = ["*"]` still get the high-risk safety net.
fn is_command_explicitly_allowed(&self, command: &str) -> bool {
let segments = split_unquoted_segments(command);
for segment in &segments {
let cmd_part = skip_env_assignments(segment);
let mut words = cmd_part.split_whitespace();
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
let base_cmd_owned = command_basename(executable).to_ascii_lowercase();
let base_cmd = strip_windows_exe_suffix(&base_cmd_owned);
if base_cmd.is_empty() {
continue;
}
let explicitly_listed = self.allowed_commands.iter().any(|allowed| {
let allowed = strip_wrapping_quotes(allowed).trim();
// Skip wildcard — it does not count as an explicit entry.
if allowed.is_empty() || allowed == "*" {
return false;
}
is_allowlist_entry_match(allowed, executable, base_cmd)
});
if !explicitly_listed {
return false;
}
}
// At least one real command must be present.
segments.iter().any(|s| {
let s = skip_env_assignments(s.trim());
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
})
}
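The wildcard-versus-explicit distinction is easiest to see from the operator's side. A hypothetical policy fragment (the key names mirror the `SecurityPolicy` fields used in the tests below; the TOML layout itself is an assumption):

```toml
# Wildcard allowlist: every command passes the basic allowlist check,
# but "*" is NOT an explicit entry, so high-risk commands (rm, curl,
# wget, ...) are still rejected by block_high_risk_commands.
allowed_commands = ["*"]
block_high_risk_commands = true

# Explicit allowlist: the operator deliberately listed curl, so curl is
# exempt from the high-risk block; wget is not listed and stays blocked.
# allowed_commands = ["curl"]
```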
// ── Layered Command Allowlist ──────────────────────────────────────────
// Defence-in-depth: five independent gates run in order before the
// per-segment allowlist check. Each gate targets a specific bypass
@@ -1503,10 +1547,13 @@ mod tests {
}
#[test]
fn validate_command_blocks_high_risk_by_default() {
fn validate_command_blocks_high_risk_via_wildcard() {
// Wildcard allows the command through is_command_allowed, but
// block_high_risk_commands still rejects it because "*" does not
// count as an explicit allowlist entry.
let p = SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
allowed_commands: vec!["rm".into()],
allowed_commands: vec!["*".into()],
..SecurityPolicy::default()
};
@@ -1515,6 +1562,100 @@ mod tests {
assert!(result.unwrap_err().contains("high-risk"));
}
#[test]
fn validate_command_allows_explicitly_listed_high_risk() {
// When a high-risk command is explicitly in allowed_commands, the
// block_high_risk_commands gate is bypassed — the operator has made
// a deliberate decision to permit it.
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
allowed_commands: vec!["curl".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let result = p.validate_command_execution("curl https://api.example.com/data", true);
assert_eq!(result.unwrap(), CommandRiskLevel::High);
}
#[test]
fn validate_command_allows_wget_when_explicitly_listed() {
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
allowed_commands: vec!["wget".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let result =
p.validate_command_execution("wget https://releases.example.com/v1.tar.gz", true);
assert_eq!(result.unwrap(), CommandRiskLevel::High);
}
#[test]
fn validate_command_blocks_non_listed_high_risk_when_another_is_allowed() {
// Allowing curl explicitly should not exempt wget.
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
allowed_commands: vec!["curl".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let result = p.validate_command_execution("wget https://evil.com", true);
assert!(result.is_err());
assert!(result.unwrap_err().contains("not allowed"));
}
#[test]
fn validate_command_explicit_rm_bypasses_high_risk_block() {
// Operator explicitly listed "rm" — they accept the risk.
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
allowed_commands: vec!["rm".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let result = p.validate_command_execution("rm -rf /tmp/test", true);
assert_eq!(result.unwrap(), CommandRiskLevel::High);
}
#[test]
fn validate_command_high_risk_still_needs_approval_in_supervised() {
// Even when explicitly allowed, supervised mode still requires
// approval for high-risk commands (the approval gate is separate
// from the block gate).
let p = SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
allowed_commands: vec!["curl".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let denied = p.validate_command_execution("curl https://api.example.com", false);
assert!(denied.is_err());
assert!(denied.unwrap_err().contains("requires explicit approval"));
let allowed = p.validate_command_execution("curl https://api.example.com", true);
assert_eq!(allowed.unwrap(), CommandRiskLevel::High);
}
#[test]
fn validate_command_pipe_needs_all_segments_explicitly_allowed() {
// When a pipeline contains a high-risk command, every segment
// must be explicitly allowed for the exemption to apply.
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
allowed_commands: vec!["curl".into(), "grep".into()],
block_high_risk_commands: true,
..SecurityPolicy::default()
};
let result = p.validate_command_execution("curl https://api.example.com | grep data", true);
assert_eq!(result.unwrap(), CommandRiskLevel::High);
}
#[test]
fn validate_command_full_mode_skips_medium_risk_approval_gate() {
let p = SecurityPolicy {
+397
@@ -0,0 +1,397 @@
//! Vulnerability scan result parsing and management.
//!
//! Parses vulnerability scan outputs from common scanners (Nessus, Qualys, generic
//! CVSS JSON) and provides priority scoring with business context adjustments.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// A single vulnerability finding.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Finding {
/// CVE identifier (e.g. "CVE-2024-1234"). May be empty for non-CVE findings.
#[serde(default)]
pub cve_id: String,
/// CVSS base score (0.0 - 10.0).
pub cvss_score: f64,
/// Severity label: "low", "medium", "high", "critical".
pub severity: String,
/// Affected asset identifier (hostname, IP, or service name).
pub affected_asset: String,
/// Description of the vulnerability.
pub description: String,
/// Recommended remediation steps.
#[serde(default)]
pub remediation: String,
/// Whether the asset is internet-facing (increases effective priority).
#[serde(default)]
pub internet_facing: bool,
/// Whether the asset is in a production environment.
#[serde(default = "default_true")]
pub production: bool,
}
fn default_true() -> bool {
true
}
/// A parsed vulnerability scan report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VulnerabilityReport {
/// When the scan was performed.
pub scan_date: DateTime<Utc>,
/// Scanner that produced the results (e.g. "nessus", "qualys", "generic").
pub scanner: String,
/// Individual findings from the scan.
pub findings: Vec<Finding>,
}
/// Compute effective priority score for a finding.
///
/// Base: CVSS score (0-10). Adjustments:
/// - Internet-facing: +2.0 (capped at 10.0)
/// - Production: +1.0 (capped at 10.0)
pub fn effective_priority(finding: &Finding) -> f64 {
let mut score = finding.cvss_score;
if finding.internet_facing {
score += 2.0;
}
if finding.production {
score += 1.0;
}
score.min(10.0)
}
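The adjustment rule reduces to a small pure function; a standalone sketch with worked values (for illustration only, since the real function takes a `Finding`):

```rust
// Sketch of the priority adjustment described above: base CVSS plus
// +2.0 for internet-facing and +1.0 for production, capped at 10.0.
fn effective_priority(cvss: f64, internet_facing: bool, production: bool) -> f64 {
    let mut score = cvss;
    if internet_facing {
        score += 2.0;
    }
    if production {
        score += 1.0;
    }
    score.min(10.0)
}

fn main() {
    // 7.5 + 2.0 + 1.0 = 10.5, capped at 10.0.
    assert!((effective_priority(7.5, true, true) - 10.0).abs() < f64::EPSILON);
    // A non-production, internal asset keeps its raw CVSS score.
    assert!((effective_priority(4.3, false, false) - 4.3).abs() < f64::EPSILON);
}
```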
/// Classify CVSS score into severity label.
pub fn cvss_to_severity(cvss: f64) -> &'static str {
match cvss {
s if s >= 9.0 => "critical",
s if s >= 7.0 => "high",
s if s >= 4.0 => "medium",
s if s > 0.0 => "low",
_ => "informational",
}
}
/// Parse a generic CVSS JSON vulnerability report.
///
/// Expects a JSON object with:
/// - `scan_date`: ISO 8601 date string
/// - `scanner`: string
/// - `findings`: array of Finding objects
pub fn parse_vulnerability_json(json_str: &str) -> anyhow::Result<VulnerabilityReport> {
let report: VulnerabilityReport = serde_json::from_str(json_str)
.map_err(|e| anyhow::anyhow!("Failed to parse vulnerability report: {e}"))?;
for (i, finding) in report.findings.iter().enumerate() {
if !(0.0..=10.0).contains(&finding.cvss_score) {
anyhow::bail!(
"findings[{}].cvss_score must be between 0.0 and 10.0, got {}",
i,
finding.cvss_score
);
}
}
Ok(report)
}
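A minimal input accepted by this parser might look as follows (values are illustrative; `scan_date` must be a timestamp chrono can deserialize, and the `remediation`, `internet_facing`, and `production` fields may be omitted thanks to their serde defaults):

```json
{
  "scan_date": "2026-03-01T12:00:00Z",
  "scanner": "generic",
  "findings": [
    {
      "cve_id": "CVE-2024-0001",
      "cvss_score": 9.8,
      "severity": "critical",
      "affected_asset": "web-server-01",
      "description": "Remote code execution in web framework",
      "remediation": "Upgrade to version 2.1.0",
      "internet_facing": true,
      "production": true
    }
  ]
}
```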
/// Generate a summary of the vulnerability report.
pub fn generate_summary(report: &VulnerabilityReport) -> String {
if report.findings.is_empty() {
return format!(
"Vulnerability scan by {} on {}: No findings.",
report.scanner,
report.scan_date.format("%Y-%m-%d")
);
}
let total = report.findings.len();
let critical = report
.findings
.iter()
.filter(|f| f.severity.eq_ignore_ascii_case("critical"))
.count();
let high = report
.findings
.iter()
.filter(|f| f.severity.eq_ignore_ascii_case("high"))
.count();
let medium = report
.findings
.iter()
.filter(|f| f.severity.eq_ignore_ascii_case("medium"))
.count();
let low = report
.findings
.iter()
.filter(|f| f.severity.eq_ignore_ascii_case("low"))
.count();
let informational = report
.findings
.iter()
.filter(|f| f.severity.eq_ignore_ascii_case("informational"))
.count();
// Sort by effective priority descending
let mut sorted: Vec<&Finding> = report.findings.iter().collect();
sorted.sort_by(|a, b| {
effective_priority(b)
.partial_cmp(&effective_priority(a))
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut summary = format!(
"## Vulnerability Scan Summary\n\
**Scanner:** {} | **Date:** {}\n\
**Total findings:** {} (Critical: {}, High: {}, Medium: {}, Low: {}, Informational: {})\n\n",
report.scanner,
report.scan_date.format("%Y-%m-%d"),
total,
critical,
high,
medium,
low,
informational
);
// Top 10 by effective priority
summary.push_str("### Top Findings by Priority\n\n");
for (i, finding) in sorted.iter().take(10).enumerate() {
let priority = effective_priority(finding);
let context = match (finding.internet_facing, finding.production) {
(true, true) => " [internet-facing, production]",
(true, false) => " [internet-facing]",
(false, true) => " [production]",
(false, false) => "",
};
let _ = writeln!(
summary,
"{}. **{}** (CVSS: {:.1}, Priority: {:.1}){}\n Asset: {} | {}",
i + 1,
if finding.cve_id.is_empty() {
"No CVE"
} else {
&finding.cve_id
},
finding.cvss_score,
priority,
context,
finding.affected_asset,
finding.description
);
if !finding.remediation.is_empty() {
let _ = writeln!(summary, " Remediation: {}", finding.remediation);
}
summary.push('\n');
}
// Remediation recommendations
if critical > 0 || high > 0 {
summary.push_str("### Remediation Recommendations\n\n");
if critical > 0 {
let _ = writeln!(
summary,
"- **URGENT:** {} critical findings require immediate remediation",
critical
);
}
if high > 0 {
let _ = writeln!(
summary,
"- **HIGH:** {} high-severity findings should be addressed within 7 days",
high
);
}
        let internet_facing_critical = sorted
            .iter()
            .filter(|f| {
                f.internet_facing
                    && (f.severity.eq_ignore_ascii_case("critical")
                        || f.severity.eq_ignore_ascii_case("high"))
            })
            .count();
if internet_facing_critical > 0 {
let _ = writeln!(
summary,
"- **PRIORITY:** {} critical/high findings on internet-facing assets",
internet_facing_critical
);
}
}
summary
}
#[cfg(test)]
mod tests {
use super::*;
fn sample_findings() -> Vec<Finding> {
vec![
Finding {
cve_id: "CVE-2024-0001".into(),
cvss_score: 9.8,
severity: "critical".into(),
affected_asset: "web-server-01".into(),
description: "Remote code execution in web framework".into(),
remediation: "Upgrade to version 2.1.0".into(),
internet_facing: true,
production: true,
},
Finding {
cve_id: "CVE-2024-0002".into(),
cvss_score: 7.5,
severity: "high".into(),
affected_asset: "db-server-01".into(),
description: "SQL injection in query parser".into(),
remediation: "Apply patch KB-12345".into(),
internet_facing: false,
production: true,
},
Finding {
cve_id: "CVE-2024-0003".into(),
cvss_score: 4.3,
severity: "medium".into(),
affected_asset: "staging-app-01".into(),
description: "Information disclosure via debug endpoint".into(),
remediation: "Disable debug endpoint in config".into(),
internet_facing: false,
production: false,
},
]
}
#[test]
fn effective_priority_adds_context_bonuses() {
let mut f = Finding {
cve_id: String::new(),
cvss_score: 7.0,
severity: "high".into(),
affected_asset: "host".into(),
description: "test".into(),
remediation: String::new(),
internet_facing: false,
production: false,
};
assert!((effective_priority(&f) - 7.0).abs() < f64::EPSILON);
f.internet_facing = true;
assert!((effective_priority(&f) - 9.0).abs() < f64::EPSILON);
f.production = true;
assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); // capped
// High CVSS + both bonuses still caps at 10.0
f.cvss_score = 9.5;
assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON);
}
#[test]
fn cvss_to_severity_classification() {
assert_eq!(cvss_to_severity(9.8), "critical");
assert_eq!(cvss_to_severity(9.0), "critical");
assert_eq!(cvss_to_severity(8.5), "high");
assert_eq!(cvss_to_severity(7.0), "high");
assert_eq!(cvss_to_severity(5.0), "medium");
assert_eq!(cvss_to_severity(4.0), "medium");
assert_eq!(cvss_to_severity(3.9), "low");
assert_eq!(cvss_to_severity(0.1), "low");
assert_eq!(cvss_to_severity(0.0), "informational");
}
#[test]
fn parse_vulnerability_json_roundtrip() {
let report = VulnerabilityReport {
scan_date: Utc::now(),
scanner: "nessus".into(),
findings: sample_findings(),
};
let json = serde_json::to_string(&report).unwrap();
let parsed = parse_vulnerability_json(&json).unwrap();
assert_eq!(parsed.scanner, "nessus");
assert_eq!(parsed.findings.len(), 3);
assert_eq!(parsed.findings[0].cve_id, "CVE-2024-0001");
}
#[test]
fn parse_vulnerability_json_rejects_invalid() {
let result = parse_vulnerability_json("not json");
assert!(result.is_err());
}
#[test]
fn generate_summary_includes_key_sections() {
let report = VulnerabilityReport {
scan_date: Utc::now(),
scanner: "qualys".into(),
findings: sample_findings(),
};
let summary = generate_summary(&report);
assert!(summary.contains("qualys"));
assert!(summary.contains("Total findings:** 3"));
assert!(summary.contains("Critical: 1"));
assert!(summary.contains("High: 1"));
assert!(summary.contains("CVE-2024-0001"));
assert!(summary.contains("URGENT"));
assert!(summary.contains("internet-facing"));
}
#[test]
fn parse_vulnerability_json_rejects_out_of_range_cvss() {
let report = VulnerabilityReport {
scan_date: Utc::now(),
scanner: "test".into(),
findings: vec![Finding {
cve_id: "CVE-2024-9999".into(),
cvss_score: 11.0,
severity: "critical".into(),
affected_asset: "host".into(),
description: "bad score".into(),
remediation: String::new(),
internet_facing: false,
production: false,
}],
};
let json = serde_json::to_string(&report).unwrap();
let result = parse_vulnerability_json(&json);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("cvss_score must be between 0.0 and 10.0"));
}
#[test]
fn parse_vulnerability_json_rejects_negative_cvss() {
let report = VulnerabilityReport {
scan_date: Utc::now(),
scanner: "test".into(),
findings: vec![Finding {
cve_id: "CVE-2024-9998".into(),
cvss_score: -1.0,
severity: "low".into(),
affected_asset: "host".into(),
description: "negative score".into(),
remediation: String::new(),
internet_facing: false,
production: false,
}],
};
let json = serde_json::to_string(&report).unwrap();
let result = parse_vulnerability_json(&json);
assert!(result.is_err());
}
#[test]
fn generate_summary_empty_findings() {
let report = VulnerabilityReport {
scan_date: Utc::now(),
scanner: "nessus".into(),
findings: vec![],
};
let summary = generate_summary(&report);
assert!(summary.contains("No findings"));
}
}
+211
@@ -0,0 +1,211 @@
//! Workspace isolation boundary enforcement.
//!
//! Prevents cross-workspace data access and enforces per-workspace
//! domain allowlists and tool restrictions.
use crate::config::workspace::WorkspaceProfile;
use std::path::Path;
/// Outcome of a workspace boundary check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundaryVerdict {
/// Access is allowed.
Allow,
/// Access is denied with a reason.
Deny(String),
}
/// Enforces isolation boundaries for the active workspace.
#[derive(Debug, Clone)]
pub struct WorkspaceBoundary {
/// The active workspace profile (if workspace isolation is active).
profile: Option<WorkspaceProfile>,
/// Whether cross-workspace search is allowed.
cross_workspace_search: bool,
}
impl WorkspaceBoundary {
/// Create a boundary enforcer for the given active workspace.
pub fn new(profile: Option<WorkspaceProfile>, cross_workspace_search: bool) -> Self {
Self {
profile,
cross_workspace_search,
}
}
/// Create a boundary enforcer with no active workspace (no restrictions).
pub fn inactive() -> Self {
Self {
profile: None,
cross_workspace_search: false,
}
}
/// Check whether a tool is allowed in the current workspace.
pub fn check_tool_access(&self, tool_name: &str) -> BoundaryVerdict {
if let Some(profile) = &self.profile {
if profile.is_tool_restricted(tool_name) {
return BoundaryVerdict::Deny(format!(
"tool '{}' is restricted in workspace '{}'",
tool_name, profile.name
));
}
}
BoundaryVerdict::Allow
}
/// Check whether a domain is allowed in the current workspace.
pub fn check_domain_access(&self, domain: &str) -> BoundaryVerdict {
if let Some(profile) = &self.profile {
if !profile.is_domain_allowed(domain) {
return BoundaryVerdict::Deny(format!(
"domain '{}' is not in the allowlist for workspace '{}'",
domain, profile.name
));
}
}
BoundaryVerdict::Allow
}
/// Check whether accessing a path is allowed given workspace isolation.
///
/// When a workspace is active, paths outside the workspace directory
/// and paths belonging to other workspaces are denied.
pub fn check_path_access(&self, path: &Path, workspaces_base: &Path) -> BoundaryVerdict {
let profile = match &self.profile {
Some(p) => p,
None => return BoundaryVerdict::Allow,
};
// If the path is under the workspaces base, verify it belongs to the active workspace
if let Ok(relative) = path.strip_prefix(workspaces_base) {
let first_component = relative
.components()
.next()
.and_then(|c| c.as_os_str().to_str());
if let Some(ws_name) = first_component {
if ws_name != profile.name {
if self.cross_workspace_search {
// Cross-workspace search is enabled; this check allows the access,
// relying on callers to use it only for read-like operations.
return BoundaryVerdict::Allow;
}
return BoundaryVerdict::Deny(format!(
"access to workspace '{}' is denied from workspace '{}'",
ws_name, profile.name
));
}
}
}
BoundaryVerdict::Allow
}
/// Whether workspace isolation is active.
pub fn is_active(&self) -> bool {
self.profile.is_some()
}
/// Get the active workspace name, if any.
pub fn active_workspace_name(&self) -> Option<&str> {
self.profile.as_ref().map(|p| p.name.as_str())
}
}
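The first-component ownership check at the heart of `check_path_access` can be sketched on its own (a hypothetical helper, not part of the module):

```rust
use std::path::Path;

// Sketch of the cross-workspace path rule: a path under the workspaces
// base belongs to the workspace named by its first relative component;
// paths outside the base belong to no workspace and are not isolated.
fn owning_workspace<'a>(path: &'a Path, base: &Path) -> Option<&'a str> {
    path.strip_prefix(base)
        .ok()?
        .components()
        .next()?
        .as_os_str()
        .to_str()
}

fn main() {
    let base = Path::new("/srv/workspaces");
    // A path inside the base resolves to its workspace name.
    assert_eq!(
        owning_workspace(Path::new("/srv/workspaces/client_a/data.db"), base),
        Some("client_a")
    );
    // A path outside the base has no owning workspace.
    assert_eq!(owning_workspace(Path::new("/tmp/something"), base), None);
}
```

Access is then denied only when the owning workspace exists and differs from the active profile's name (unless cross-workspace search is enabled), which matches the test cases below.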
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
fn test_profile() -> WorkspaceProfile {
WorkspaceProfile {
name: "client_a".to_string(),
allowed_domains: vec!["api.example.com".to_string()],
credential_profile: None,
memory_namespace: Some("client_a".to_string()),
audit_namespace: Some("client_a".to_string()),
tool_restrictions: vec!["shell".to_string()],
}
}
#[test]
fn boundary_inactive_allows_everything() {
let boundary = WorkspaceBoundary::inactive();
assert_eq!(boundary.check_tool_access("shell"), BoundaryVerdict::Allow);
assert_eq!(
boundary.check_domain_access("any.domain"),
BoundaryVerdict::Allow
);
assert!(!boundary.is_active());
}
#[test]
fn boundary_denies_restricted_tool() {
let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
assert!(matches!(
boundary.check_tool_access("shell"),
BoundaryVerdict::Deny(_)
));
assert_eq!(
boundary.check_tool_access("file_read"),
BoundaryVerdict::Allow
);
}
#[test]
fn boundary_denies_unlisted_domain() {
let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
assert_eq!(
boundary.check_domain_access("api.example.com"),
BoundaryVerdict::Allow
);
assert!(matches!(
boundary.check_domain_access("evil.com"),
BoundaryVerdict::Deny(_)
));
}
#[test]
fn boundary_denies_cross_workspace_path_access() {
let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");
// Access to own workspace is allowed
let own_path = base.join("client_a").join("data.db");
assert_eq!(
boundary.check_path_access(&own_path, &base),
BoundaryVerdict::Allow
);
// Access to other workspace is denied
let other_path = base.join("client_b").join("data.db");
assert!(matches!(
boundary.check_path_access(&other_path, &base),
BoundaryVerdict::Deny(_)
));
}
#[test]
fn boundary_allows_cross_workspace_when_enabled() {
let boundary = WorkspaceBoundary::new(Some(test_profile()), true);
let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");
let other_path = base.join("client_b").join("data.db");
assert_eq!(
boundary.check_path_access(&other_path, &base),
BoundaryVerdict::Allow
);
}
#[test]
fn boundary_allows_paths_outside_workspaces_dir() {
let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");
let outside_path = PathBuf::from("/tmp/something");
assert_eq!(
boundary.check_path_access(&outside_path, &base),
BoundaryVerdict::Allow
);
}
}
+91 -6
@@ -442,8 +442,24 @@ fn install_linux_systemd(config: &Config) -> Result<()> {
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
let unit = format!(
"[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
exe.display()
"[Unit]\n\
Description=ZeroClaw daemon\n\
After=network.target\n\
\n\
[Service]\n\
Type=simple\n\
ExecStart={exe} daemon\n\
Restart=always\n\
RestartSec=3\n\
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
Environment=HOME=%h\n\
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
# so graphical/headless browsers can function correctly.\n\
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
\n\
[Install]\n\
WantedBy=default.target\n",
exe = exe.display()
);
fs::write(&file, unit)?;
@@ -826,8 +842,8 @@ fn generate_openrc_script(exe_path: &Path, config_dir: &Path) -> String {
name="zeroclaw"
description="ZeroClaw daemon"
command="{}"
command_args="--config-dir {} daemon"
command="{exe}"
command_args="--config-dir {config_dir} daemon"
command_background="yes"
command_user="zeroclaw:zeroclaw"
pidfile="/run/${{RC_SVCNAME}}.pid"
@@ -835,13 +851,21 @@ umask 027
output_log="/var/log/zeroclaw/access.log"
error_log="/var/log/zeroclaw/error.log"
# Provide HOME so headless browsers can create profile/cache directories.
# Without this, Chromium/Firefox fail with sandbox or profile errors.
export HOME="/var/lib/zeroclaw"
depend() {{
need net
after firewall
}}
start_pre() {{
checkpath --directory --owner zeroclaw:zeroclaw --mode 0750 /var/lib/zeroclaw
}}
"#,
exe_path.display(),
config_dir.display()
exe = exe_path.display(),
config_dir = config_dir.display(),
)
}
@@ -1196,6 +1220,67 @@ mod tests {
assert!(script.contains("after firewall"));
}
#[test]
fn generate_openrc_script_sets_home_for_browser() {
use std::path::PathBuf;
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
assert!(
script.contains("export HOME=\"/var/lib/zeroclaw\""),
"OpenRC script must set HOME for headless browser support"
);
}
#[test]
fn generate_openrc_script_creates_home_directory() {
use std::path::PathBuf;
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
assert!(
script.contains("start_pre()"),
"OpenRC script must have start_pre to create HOME dir"
);
assert!(
script.contains("checkpath --directory --owner zeroclaw:zeroclaw"),
"start_pre must ensure /var/lib/zeroclaw exists with correct ownership"
);
}
#[test]
fn systemd_unit_contains_home_and_pass_environment() {
let unit = "[Unit]\n\
Description=ZeroClaw daemon\n\
After=network.target\n\
\n\
[Service]\n\
Type=simple\n\
ExecStart=/usr/local/bin/zeroclaw daemon\n\
Restart=always\n\
RestartSec=3\n\
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
Environment=HOME=%h\n\
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
# so graphical/headless browsers can function correctly.\n\
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
\n\
[Install]\n\
WantedBy=default.target\n"
.to_string();
assert!(
unit.contains("Environment=HOME=%h"),
"systemd unit must set HOME for headless browser support"
);
assert!(
unit.contains("PassEnvironment=DISPLAY XDG_RUNTIME_DIR"),
"systemd unit must pass through display/runtime env vars"
);
}
#[test]
fn warn_if_binary_in_home_detects_home_path() {
use std::path::PathBuf;
@@ -440,6 +440,12 @@ impl BrowserTool {
async fn run_command(&self, args: &[&str]) -> anyhow::Result<AgentBrowserResponse> {
let mut cmd = Command::new("agent-browser");
// When running as a service (systemd/OpenRC), the process may lack
// HOME which browsers need for profile directories.
if is_service_environment() {
ensure_browser_env(&mut cmd);
}
// Add session if configured
if let Some(ref session) = self.session_name {
cmd.arg("--session").arg(session);
@@ -1461,6 +1467,14 @@ mod native_backend {
args.push(Value::String("--disable-gpu".to_string()));
}
// When running as a service (systemd/OpenRC), the browser sandbox
// fails because the process lacks a user namespace / session.
// --no-sandbox and --disable-dev-shm-usage are required in this context.
if super::is_service_environment() {
args.push(Value::String("--no-sandbox".to_string()));
args.push(Value::String("--disable-dev-shm-usage".to_string()));
}
if !args.is_empty() {
chrome_options.insert("args".to_string(), Value::Array(args));
}
@@ -2111,6 +2125,44 @@ fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
|| v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
}
/// Detect whether the current process is running inside a service environment
/// (e.g. systemd, OpenRC, or launchd) where the browser sandbox and
/// environment setup may be restricted.
fn is_service_environment() -> bool {
if std::env::var_os("INVOCATION_ID").is_some() {
return true;
}
if std::env::var_os("JOURNAL_STREAM").is_some() {
return true;
}
#[cfg(target_os = "linux")]
if std::path::Path::new("/run/openrc").exists() && std::env::var_os("HOME").is_none() {
return true;
}
#[cfg(target_os = "linux")]
if std::env::var_os("HOME").is_none() {
return true;
}
false
}
/// Ensure environment variables required by headless browsers are present
/// when running inside a service context.
fn ensure_browser_env(cmd: &mut Command) {
if std::env::var_os("HOME").is_none() {
cmd.env("HOME", "/tmp");
}
let existing = std::env::var("CHROMIUM_FLAGS").unwrap_or_default();
if !existing.contains("--no-sandbox") {
let new_flags = if existing.is_empty() {
"--no-sandbox --disable-dev-shm-usage".to_string()
} else {
format!("{existing} --no-sandbox --disable-dev-shm-usage")
};
cmd.env("CHROMIUM_FLAGS", new_flags);
}
}
fn host_matches_allowlist(host: &str, allowed: &[String]) -> bool {
allowed.iter().any(|pattern| {
if pattern == "*" {
@@ -2492,4 +2544,78 @@ mod tests {
state.reset_session().await;
});
}
#[test]
fn ensure_browser_env_sets_home_when_missing() {
let original_home = std::env::var_os("HOME");
unsafe { std::env::remove_var("HOME") };
let mut cmd = Command::new("true");
ensure_browser_env(&mut cmd);
// Function completes without panic — HOME and CHROMIUM_FLAGS set on cmd.
if let Some(home) = original_home {
unsafe { std::env::set_var("HOME", home) };
}
}
#[test]
fn ensure_browser_env_sets_chromium_flags() {
let original = std::env::var_os("CHROMIUM_FLAGS");
unsafe { std::env::remove_var("CHROMIUM_FLAGS") };
let mut cmd = Command::new("true");
ensure_browser_env(&mut cmd);
if let Some(val) = original {
unsafe { std::env::set_var("CHROMIUM_FLAGS", val) };
}
}
#[test]
fn is_service_environment_detects_invocation_id() {
let original = std::env::var_os("INVOCATION_ID");
unsafe { std::env::set_var("INVOCATION_ID", "test-unit-id") };
assert!(is_service_environment());
if let Some(val) = original {
unsafe { std::env::set_var("INVOCATION_ID", val) };
} else {
unsafe { std::env::remove_var("INVOCATION_ID") };
}
}
#[test]
fn is_service_environment_detects_journal_stream() {
let original = std::env::var_os("JOURNAL_STREAM");
unsafe { std::env::set_var("JOURNAL_STREAM", "8:12345") };
assert!(is_service_environment());
if let Some(val) = original {
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
} else {
unsafe { std::env::remove_var("JOURNAL_STREAM") };
}
}
#[test]
fn is_service_environment_false_in_normal_context() {
let inv = std::env::var_os("INVOCATION_ID");
let journal = std::env::var_os("JOURNAL_STREAM");
unsafe { std::env::remove_var("INVOCATION_ID") };
unsafe { std::env::remove_var("JOURNAL_STREAM") };
if std::env::var_os("HOME").is_some() {
assert!(!is_service_environment());
}
if let Some(val) = inv {
unsafe { std::env::set_var("INVOCATION_ID", val) };
}
if let Some(val) = journal {
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
}
}
}
@@ -116,7 +116,8 @@ impl Tool for CronRunTool {
}
let started_at = Utc::now();
-let (success, output) = cron::scheduler::execute_job_now(&self.config, &job).await;
+let (success, output) =
+Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await;
let finished_at = Utc::now();
let duration_ms = (finished_at - started_at).num_milliseconds();
let status = if success { "ok" } else { "error" };
@@ -12,6 +12,7 @@ pub struct HttpRequestTool {
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
}
impl HttpRequestTool {
@@ -20,12 +21,14 @@ impl HttpRequestTool {
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
) -> Self {
Self {
security,
allowed_domains: normalize_allowed_domains(allowed_domains),
max_response_size,
timeout_secs,
allow_private_hosts,
}
}
@@ -52,7 +55,7 @@ impl HttpRequestTool {
let host = extract_host(url)?;
-if is_private_or_local_host(&host) {
+if !self.allow_private_hosts && is_private_or_local_host(&host) {
anyhow::bail!("Blocked local/private host: {host}");
}
@@ -454,6 +457,13 @@ mod tests {
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
test_tool_with_private(allowed_domains, false)
}
fn test_tool_with_private(
allowed_domains: Vec<&str>,
allow_private_hosts: bool,
) -> HttpRequestTool {
let security = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
..SecurityPolicy::default()
@@ -463,6 +473,7 @@ mod tests {
allowed_domains.into_iter().map(String::from).collect(),
1_000_000,
30,
allow_private_hosts,
)
}
@@ -570,7 +581,7 @@ mod tests {
#[test]
fn validate_requires_allowlist() {
let security = Arc::new(SecurityPolicy::default());
-let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30);
+let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30, false);
let err = tool
.validate_url("https://example.com")
.unwrap_err()
@@ -686,7 +697,7 @@ mod tests {
autonomy: AutonomyLevel::ReadOnly,
..SecurityPolicy::default()
});
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
let result = tool
.execute(json!({"url": "https://example.com"}))
.await
@@ -701,7 +712,7 @@ mod tests {
max_actions_per_hour: 0,
..SecurityPolicy::default()
});
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
let result = tool
.execute(json!({"url": "https://example.com"}))
.await
@@ -724,6 +735,7 @@ mod tests {
vec!["example.com".into()],
10,
30,
false,
);
let text = "hello world this is long";
let truncated = tool.truncate_response(text);
@@ -738,6 +750,7 @@ mod tests {
vec!["example.com".into()],
0, // max_response_size = 0 means no limit
30,
false,
);
let text = "a".repeat(10_000_000);
assert_eq!(tool.truncate_response(&text), text);
@@ -750,6 +763,7 @@ mod tests {
vec!["example.com".into()],
5,
30,
false,
);
let text = "hello world";
let truncated = tool.truncate_response(text);
@@ -935,4 +949,70 @@ mod tests {
.to_string();
assert!(err.contains("IPv6"));
}
// ── allow_private_hosts opt-in tests ────────────────────────
#[test]
fn default_blocks_private_hosts() {
let tool = test_tool(vec!["localhost", "192.168.1.5", "*"]);
assert!(tool
.validate_url("https://localhost:8080")
.unwrap_err()
.to_string()
.contains("local/private"));
assert!(tool
.validate_url("https://192.168.1.5")
.unwrap_err()
.to_string()
.contains("local/private"));
assert!(tool
.validate_url("https://10.0.0.1")
.unwrap_err()
.to_string()
.contains("local/private"));
}
#[test]
fn allow_private_hosts_permits_localhost() {
let tool = test_tool_with_private(vec!["localhost"], true);
assert!(tool.validate_url("https://localhost:8080").is_ok());
}
#[test]
fn allow_private_hosts_permits_private_ipv4() {
let tool = test_tool_with_private(vec!["192.168.1.5"], true);
assert!(tool.validate_url("https://192.168.1.5").is_ok());
}
#[test]
fn allow_private_hosts_permits_rfc1918_with_wildcard() {
let tool = test_tool_with_private(vec!["*"], true);
assert!(tool.validate_url("https://10.0.0.1").is_ok());
assert!(tool.validate_url("https://172.16.0.1").is_ok());
assert!(tool.validate_url("https://192.168.1.1").is_ok());
assert!(tool.validate_url("http://localhost:8123").is_ok());
}
#[test]
fn allow_private_hosts_still_requires_allowlist() {
let tool = test_tool_with_private(vec!["example.com"], true);
let err = tool
.validate_url("https://192.168.1.5")
.unwrap_err()
.to_string();
assert!(
err.contains("allowed_domains"),
"Private host should still need allowlist match, got: {err}"
);
}
#[test]
fn allow_private_hosts_false_still_blocks() {
let tool = test_tool_with_private(vec!["*"], false);
assert!(tool
.validate_url("https://localhost:8080")
.unwrap_err()
.to_string()
.contains("local/private"));
}
}
@@ -0,0 +1,400 @@
use anyhow::Context;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use tokio::sync::Mutex;
/// Cached OAuth2 token state persisted to disk between runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTokenState {
pub access_token: String,
pub refresh_token: Option<String>,
/// Unix timestamp (seconds) when the access token expires.
pub expires_at: i64,
}
impl CachedTokenState {
/// Returns `true` when the token is expired or will expire within 60 seconds.
pub fn is_expired(&self) -> bool {
let now = chrono::Utc::now().timestamp();
self.expires_at <= now + 60
}
}
/// Thread-safe token cache with disk persistence.
pub struct TokenCache {
inner: RwLock<Option<CachedTokenState>>,
/// Serialises the slow acquire/refresh path so only one caller performs the
/// network round-trip while others wait and then read the updated cache.
acquire_lock: Mutex<()>,
config: super::types::Microsoft365ResolvedConfig,
cache_path: PathBuf,
}
impl TokenCache {
pub fn new(
config: super::types::Microsoft365ResolvedConfig,
zeroclaw_dir: &std::path::Path,
) -> anyhow::Result<Self> {
if config.token_cache_encrypted {
anyhow::bail!(
"microsoft365: token_cache_encrypted is enabled but encryption is not yet \
implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \
to false or wait for encryption support."
);
}
// Scope cache file to (tenant_id, client_id, auth_flow) so config
// changes never reuse tokens from a different account/flow.
let mut hasher = DefaultHasher::new();
config.tenant_id.hash(&mut hasher);
config.client_id.hash(&mut hasher);
config.auth_flow.hash(&mut hasher);
let fingerprint = format!("{:016x}", hasher.finish());
let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json"));
let cached = Self::load_from_disk(&cache_path);
Ok(Self {
inner: RwLock::new(cached),
acquire_lock: Mutex::new(()),
config,
cache_path,
})
}
/// Get a valid access token, refreshing or re-authenticating as needed.
pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
// Fast path: cached and not expired.
{
let guard = self.inner.read();
if let Some(ref state) = *guard {
if !state.is_expired() {
return Ok(state.access_token.clone());
}
}
}
// Slow path: serialise through a mutex so only one caller performs the
// network round-trip while concurrent callers wait and re-check.
let _lock = self.acquire_lock.lock().await;
// Re-check after acquiring the lock — another caller may have refreshed
// while we were waiting.
{
let guard = self.inner.read();
if let Some(ref state) = *guard {
if !state.is_expired() {
return Ok(state.access_token.clone());
}
}
}
let new_state = self.acquire_token(client).await?;
let token = new_state.access_token.clone();
self.persist_to_disk(&new_state);
*self.inner.write() = Some(new_state);
Ok(token)
}
async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
// Try refresh first if we have a refresh token and the flow supports it.
// Client credentials flow does not issue refresh tokens, so skip the
// attempt entirely to avoid a wasted round-trip.
if self.config.auth_flow.as_str() != "client_credentials" {
// Clone the token out so the RwLock guard is dropped before the await.
let refresh_token_copy = {
let guard = self.inner.read();
guard.as_ref().and_then(|state| state.refresh_token.clone())
};
if let Some(refresh_tok) = refresh_token_copy {
match self.refresh_token(client, &refresh_tok).await {
Ok(new_state) => return Ok(new_state),
Err(e) => {
tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
}
}
}
}
match self.config.auth_flow.as_str() {
"client_credentials" => self.client_credentials_flow(client).await,
"device_code" => self.device_code_flow(client).await,
other => anyhow::bail!("Unsupported auth flow: {other}"),
}
}
async fn client_credentials_flow(
&self,
client: &reqwest::Client,
) -> anyhow::Result<CachedTokenState> {
let client_secret = self
.config
.client_secret
.as_deref()
.context("client_credentials flow requires client_secret")?;
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let scope = self.config.scopes.join(" ");
let resp = client
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", client_secret),
("scope", &scope),
])
.send()
.await
.context("ms365: failed to request client_credentials token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::debug!("ms365: client_credentials raw OAuth error: {body}");
anyhow::bail!("ms365: client_credentials token request failed ({status})");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
let device_code_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
self.config.tenant_id
);
let scope = self.config.scopes.join(" ");
let resp = client
.post(&device_code_url)
.form(&[
("client_id", self.config.client_id.as_str()),
("scope", &scope),
])
.send()
.await
.context("ms365: failed to request device code")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::debug!("ms365: device_code initiation raw error: {body}");
anyhow::bail!("ms365: device code request failed ({status})");
}
let device_resp: DeviceCodeResponse = resp
.json()
.await
.context("ms365: failed to parse device code response")?;
// Log only a generic prompt; the full device_resp.message may contain
// sensitive verification URIs or codes that should not appear in logs.
tracing::info!(
"ms365: device code auth required — follow the instructions shown to the user"
);
// Print the user-facing message to stderr so the operator can act on it
// without it being captured in structured log sinks.
eprintln!("ms365: {}", device_resp.message);
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let interval = device_resp.interval.max(5);
let max_polls = u32::try_from(
(device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1),
)
.unwrap_or(u32::MAX);
for _ in 0..max_polls {
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
let poll_resp = client
.post(&token_url)
.form(&[
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("client_id", self.config.client_id.as_str()),
("device_code", &device_resp.device_code),
])
.send()
.await
.context("ms365: failed to poll device code token")?;
if poll_resp.status().is_success() {
let token_resp: TokenResponse = poll_resp
.json()
.await
.context("ms365: failed to parse token response")?;
return Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
});
}
let body = poll_resp.text().await.unwrap_or_default();
if body.contains("authorization_pending") {
continue;
}
if body.contains("slow_down") {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
tracing::debug!("ms365: device code polling raw error: {body}");
anyhow::bail!("ms365: device code polling failed");
}
anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
}
async fn refresh_token(
&self,
client: &reqwest::Client,
refresh_token: &str,
) -> anyhow::Result<CachedTokenState> {
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let mut params = vec![
("grant_type", "refresh_token"),
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
];
if let Some(ref secret) = self.config.client_secret {
params.push(("client_secret", secret.as_str()));
}
let resp = client
.post(&token_url)
.form(&params)
.send()
.await
.context("ms365: failed to refresh token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::debug!("ms365: token refresh raw error: {body}");
anyhow::bail!("ms365: token refresh failed ({status})");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse refresh token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp
.refresh_token
.or_else(|| Some(refresh_token.to_string())),
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
let data = std::fs::read_to_string(path).ok()?;
serde_json::from_str(&data).ok()
}
fn persist_to_disk(&self, state: &CachedTokenState) {
if let Ok(json) = serde_json::to_string_pretty(state) {
if let Err(e) = std::fs::write(&self.cache_path, json) {
tracing::warn!("ms365: failed to persist token cache: {e}");
}
}
}
}
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default = "default_expires_in")]
expires_in: i64,
}
fn default_expires_in() -> i64 {
3600
}
#[derive(Deserialize)]
struct DeviceCodeResponse {
device_code: String,
message: String,
#[serde(default = "default_device_interval")]
interval: u64,
#[serde(default = "default_device_expires_in")]
expires_in: i64,
}
fn default_device_interval() -> u64 {
5
}
fn default_device_expires_in() -> i64 {
900
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn token_is_expired_when_past_deadline() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() - 10,
};
assert!(state.is_expired());
}
#[test]
fn token_is_expired_within_buffer() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() + 30,
};
assert!(state.is_expired());
}
#[test]
fn token_is_valid_when_far_from_expiry() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() + 3600,
};
assert!(!state.is_expired());
}
#[test]
fn load_from_disk_returns_none_for_missing_file() {
let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
assert!(TokenCache::load_from_disk(path).is_none());
}
}
@@ -0,0 +1,495 @@
use anyhow::Context;
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
/// Build the user path segment: `/me` or `/users/{user_id}`.
/// The user_id is percent-encoded to prevent path-traversal attacks.
fn user_path(user_id: &str) -> String {
if user_id == "me" {
"/me".to_string()
} else {
format!("/users/{}", urlencoding::encode(user_id))
}
}
/// Percent-encode a single path segment to prevent path-traversal attacks.
fn encode_path_segment(segment: &str) -> String {
urlencoding::encode(segment).into_owned()
}
/// List mail messages for a user.
pub async fn mail_list(
client: &reqwest::Client,
token: &str,
user_id: &str,
folder: Option<&str>,
top: u32,
) -> anyhow::Result<serde_json::Value> {
let base = user_path(user_id);
let path = match folder {
Some(f) => format!(
"{GRAPH_BASE}{base}/mailFolders/{}/messages",
encode_path_segment(f)
),
None => format!("{GRAPH_BASE}{base}/messages"),
};
let resp = client
.get(&path)
.bearer_auth(token)
.query(&[("$top", top.to_string())])
.send()
.await
.context("ms365: mail_list request failed")?;
handle_json_response(resp, "mail_list").await
}
/// Send a mail message.
pub async fn mail_send(
client: &reqwest::Client,
token: &str,
user_id: &str,
to: &[String],
subject: &str,
body: &str,
) -> anyhow::Result<()> {
let base = user_path(user_id);
let url = format!("{GRAPH_BASE}{base}/sendMail");
let to_recipients: Vec<serde_json::Value> = to
.iter()
.map(|addr| {
serde_json::json!({
"emailAddress": { "address": addr }
})
})
.collect();
let payload = serde_json::json!({
"message": {
"subject": subject,
"body": {
"contentType": "Text",
"content": body
},
"toRecipients": to_recipients
}
});
let resp = client
.post(&url)
.bearer_auth(token)
.json(&payload)
.send()
.await
.context("ms365: mail_send request failed")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
tracing::debug!("ms365: mail_send raw error body: {body}");
anyhow::bail!("ms365: mail_send failed ({status}, code={code})");
}
Ok(())
}
/// List messages in a Teams channel.
pub async fn teams_message_list(
client: &reqwest::Client,
token: &str,
team_id: &str,
channel_id: &str,
top: u32,
) -> anyhow::Result<serde_json::Value> {
let url = format!(
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
encode_path_segment(team_id),
encode_path_segment(channel_id)
);
let resp = client
.get(&url)
.bearer_auth(token)
.query(&[("$top", top.to_string())])
.send()
.await
.context("ms365: teams_message_list request failed")?;
handle_json_response(resp, "teams_message_list").await
}
/// Send a message to a Teams channel.
pub async fn teams_message_send(
client: &reqwest::Client,
token: &str,
team_id: &str,
channel_id: &str,
body: &str,
) -> anyhow::Result<()> {
let url = format!(
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
encode_path_segment(team_id),
encode_path_segment(channel_id)
);
let payload = serde_json::json!({
"body": {
"content": body
}
});
let resp = client
.post(&url)
.bearer_auth(token)
.json(&payload)
.send()
.await
.context("ms365: teams_message_send request failed")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
tracing::debug!("ms365: teams_message_send raw error body: {body}");
anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})");
}
Ok(())
}
/// List calendar events in a date range.
pub async fn calendar_events_list(
client: &reqwest::Client,
token: &str,
user_id: &str,
start: &str,
end: &str,
top: u32,
) -> anyhow::Result<serde_json::Value> {
let base = user_path(user_id);
let url = format!("{GRAPH_BASE}{base}/calendarView");
let resp = client
.get(&url)
.bearer_auth(token)
.query(&[
("startDateTime", start.to_string()),
("endDateTime", end.to_string()),
("$top", top.to_string()),
])
.send()
.await
.context("ms365: calendar_events_list request failed")?;
handle_json_response(resp, "calendar_events_list").await
}
/// Create a calendar event.
pub async fn calendar_event_create(
client: &reqwest::Client,
token: &str,
user_id: &str,
subject: &str,
start: &str,
end: &str,
attendees: &[String],
body_text: Option<&str>,
) -> anyhow::Result<String> {
let base = user_path(user_id);
let url = format!("{GRAPH_BASE}{base}/events");
let attendee_list: Vec<serde_json::Value> = attendees
.iter()
.map(|email| {
serde_json::json!({
"emailAddress": { "address": email },
"type": "required"
})
})
.collect();
let mut payload = serde_json::json!({
"subject": subject,
"start": {
"dateTime": start,
"timeZone": "UTC"
},
"end": {
"dateTime": end,
"timeZone": "UTC"
},
"attendees": attendee_list
});
if let Some(text) = body_text {
payload["body"] = serde_json::json!({
"contentType": "Text",
"content": text
});
}
let resp = client
.post(&url)
.bearer_auth(token)
.json(&payload)
.send()
.await
.context("ms365: calendar_event_create request failed")?;
let value = handle_json_response(resp, "calendar_event_create").await?;
let event_id = value["id"].as_str().unwrap_or("unknown").to_string();
Ok(event_id)
}
/// Delete a calendar event by ID.
pub async fn calendar_event_delete(
client: &reqwest::Client,
token: &str,
user_id: &str,
event_id: &str,
) -> anyhow::Result<()> {
let base = user_path(user_id);
let url = format!(
"{GRAPH_BASE}{base}/events/{}",
encode_path_segment(event_id)
);
let resp = client
.delete(&url)
.bearer_auth(token)
.send()
.await
.context("ms365: calendar_event_delete request failed")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
tracing::debug!("ms365: calendar_event_delete raw error body: {body}");
anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})");
}
Ok(())
}
/// List children of a OneDrive folder.
pub async fn onedrive_list(
client: &reqwest::Client,
token: &str,
user_id: &str,
path: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let base = user_path(user_id);
let url = match path {
Some(p) if !p.is_empty() => {
let encoded = urlencoding::encode(p);
format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children")
}
_ => format!("{GRAPH_BASE}{base}/drive/root/children"),
};
let resp = client
.get(&url)
.bearer_auth(token)
.send()
.await
.context("ms365: onedrive_list request failed")?;
handle_json_response(resp, "onedrive_list").await
}
/// Download a OneDrive item by ID, with a maximum size guard.
pub async fn onedrive_download(
client: &reqwest::Client,
token: &str,
user_id: &str,
item_id: &str,
max_size: usize,
) -> anyhow::Result<Vec<u8>> {
let base = user_path(user_id);
let url = format!(
"{GRAPH_BASE}{base}/drive/items/{}/content",
encode_path_segment(item_id)
);
let resp = client
.get(&url)
.bearer_auth(token)
.send()
.await
.context("ms365: onedrive_download request failed")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
tracing::debug!("ms365: onedrive_download raw error body: {body}");
anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})");
}
let bytes = resp
.bytes()
.await
.context("ms365: failed to read download body")?;
if bytes.len() > max_size {
anyhow::bail!(
"ms365: downloaded file exceeds max_size ({} > {max_size})",
bytes.len()
);
}
Ok(bytes.to_vec())
}
/// Search SharePoint for documents matching a query.
pub async fn sharepoint_search(
client: &reqwest::Client,
token: &str,
query: &str,
top: u32,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{GRAPH_BASE}/search/query");
let payload = serde_json::json!({
"requests": [{
"entityTypes": ["driveItem", "listItem", "site"],
"query": {
"queryString": query
},
"from": 0,
"size": top
}]
});
let resp = client
.post(&url)
.bearer_auth(token)
.json(&payload)
.send()
.await
.context("ms365: sharepoint_search request failed")?;
handle_json_response(resp, "sharepoint_search").await
}
/// Extract a short, safe error code from a Graph API JSON error body.
/// Returns `None` when the body is not a recognised Graph error envelope.
fn extract_graph_error_code(body: &str) -> Option<String> {
let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
parsed
.get("error")
.and_then(|e| e.get("code"))
.and_then(|c| c.as_str())
.map(|s| s.to_string())
}
/// Parse a JSON response body, returning an error on non-success status.
/// Raw Graph API error bodies are not propagated; only the HTTP status and a
/// short error code (when available) are surfaced to avoid leaking internal
/// API details.
async fn handle_json_response(
resp: reqwest::Response,
operation: &str,
) -> anyhow::Result<serde_json::Value> {
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
tracing::debug!("ms365: {operation} raw error body: {body}");
anyhow::bail!("ms365: {operation} failed ({status}, code={code})");
}
resp.json()
.await
.with_context(|| format!("ms365: failed to parse {operation} response"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn user_path_me() {
assert_eq!(user_path("me"), "/me");
}
#[test]
fn user_path_specific_user() {
assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com");
}
#[test]
fn mail_list_url_no_folder() {
let base = user_path("me");
let url = format!("{GRAPH_BASE}{base}/messages");
assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages");
}
#[test]
fn mail_list_url_with_folder() {
let base = user_path("me");
let folder = "inbox";
let url = format!(
"{GRAPH_BASE}{base}/mailFolders/{}/messages",
encode_path_segment(folder)
);
assert_eq!(
url,
"https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages"
);
}
#[test]
fn calendar_view_url() {
let base = user_path("user@example.com");
let url = format!("{GRAPH_BASE}{base}/calendarView");
assert_eq!(
url,
"https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView"
);
}
#[test]
fn teams_message_url() {
let url = format!(
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
encode_path_segment("team-123"),
encode_path_segment("channel-456")
);
assert_eq!(
url,
"https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages"
);
}
#[test]
fn onedrive_root_url() {
let base = user_path("me");
let url = format!("{GRAPH_BASE}{base}/drive/root/children");
assert_eq!(
url,
"https://graph.microsoft.com/v1.0/me/drive/root/children"
);
}
#[test]
fn onedrive_path_url() {
let base = user_path("me");
let encoded = urlencoding::encode("Documents/Reports");
let url = format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children");
assert_eq!(
url,
"https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children"
);
}
#[test]
fn sharepoint_search_url() {
let url = format!("{GRAPH_BASE}/search/query");
assert_eq!(url, "https://graph.microsoft.com/v1.0/search/query");
}
}
@@ -0,0 +1,567 @@
//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar,
//! OneDrive, and SharePoint via a single action-dispatched tool surface.
//!
//! Auth is handled through direct HTTP calls to the Microsoft identity platform
//! (client credentials or device code flow) with token caching.
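//!
//! For the client-credentials flow, the token request follows the standard
//! Microsoft identity platform v2.0 shape (sketch taken from the public OAuth2
//! documentation, not a call into this crate):
//!
//! ```text
//! POST https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token
//! Content-Type: application/x-www-form-urlencoded
//!
//! client_id={client_id}
//!   &client_secret={client_secret}
//!   &scope=https%3A%2F%2Fgraph.microsoft.com%2F.default
//!   &grant_type=client_credentials
//! ```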
pub mod auth;
pub mod graph_client;
pub mod types;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use crate::tools::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Maximum download size for OneDrive files (10 MB).
const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024;
/// Default number of items to return in list operations.
const DEFAULT_TOP: u32 = 25;
pub struct Microsoft365Tool {
config: types::Microsoft365ResolvedConfig,
security: Arc<SecurityPolicy>,
token_cache: Arc<auth::TokenCache>,
http_client: reqwest::Client,
}
impl Microsoft365Tool {
pub fn new(
config: types::Microsoft365ResolvedConfig,
security: Arc<SecurityPolicy>,
zeroclaw_dir: &std::path::Path,
) -> anyhow::Result<Self> {
let http_client =
crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10);
let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?);
Ok(Self {
config,
security,
token_cache,
http_client,
})
}
async fn get_token(&self) -> anyhow::Result<String> {
self.token_cache.get_token(&self.http_client).await
}
fn user_id(&self) -> &str {
&self.config.user_id
}
async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
match action {
"mail_list" => self.handle_mail_list(args).await,
"mail_send" => self.handle_mail_send(args).await,
"teams_message_list" => self.handle_teams_message_list(args).await,
"teams_message_send" => self.handle_teams_message_send(args).await,
"calendar_events_list" => self.handle_calendar_events_list(args).await,
"calendar_event_create" => self.handle_calendar_event_create(args).await,
"calendar_event_delete" => self.handle_calendar_event_delete(args).await,
"onedrive_list" => self.handle_onedrive_list(args).await,
"onedrive_download" => self.handle_onedrive_download(args).await,
"sharepoint_search" => self.handle_sharepoint_search(args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {action}")),
}),
}
}
// ── Read actions ────────────────────────────────────────────────
async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let folder = args["folder"].as_str();
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result =
graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_teams_message_list(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let team_id = args["team_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
let channel_id = args["channel_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result =
graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top)
.await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_calendar_events_list(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let start = args["start"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
let end = args["end"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result = graph_client::calendar_events_list(
&self.http_client,
&token,
self.user_id(),
start,
end,
top,
)
.await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let path = args["path"].as_str();
let result =
graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
async fn handle_onedrive_download(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let item_id = args["item_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("item_id is required"))?;
let max_size = args["max_size"]
.as_u64()
.and_then(|v| usize::try_from(v).ok())
.unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE)
.min(MAX_ONEDRIVE_DOWNLOAD_SIZE);
let bytes = graph_client::onedrive_download(
&self.http_client,
&token,
self.user_id(),
item_id,
max_size,
)
.await?;
// Return base64-encoded for binary safety.
use base64::Engine;
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
Ok(ToolResult {
success: true,
output: format!(
"Downloaded {} bytes (base64 encoded):\n{encoded}",
bytes.len()
),
error: None,
})
}
async fn handle_sharepoint_search(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let query = args["query"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("query is required"))?;
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
.unwrap_or(DEFAULT_TOP);
let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?;
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&result)?,
error: None,
})
}
// ── Write actions ───────────────────────────────────────────────
async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let to: Vec<String> = args["to"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))?
.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect();
if to.is_empty() {
anyhow::bail!("to must contain at least one email address");
}
let subject = args["subject"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
let body = args["body"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
graph_client::mail_send(
&self.http_client,
&token,
self.user_id(),
&to,
subject,
body,
)
.await?;
Ok(ToolResult {
success: true,
output: format!("Email sent to: {}", to.join(", ")),
error: None,
})
}
async fn handle_teams_message_send(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let team_id = args["team_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
let channel_id = args["channel_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
let body = args["body"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body)
.await?;
Ok(ToolResult {
success: true,
output: "Teams message sent".to_string(),
error: None,
})
}
async fn handle_calendar_event_create(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let subject = args["subject"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
let start = args["start"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
let end = args["end"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
let attendees: Vec<String> = args["attendees"]
.as_array()
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let body_text = args["body"].as_str();
let event_id = graph_client::calendar_event_create(
&self.http_client,
&token,
self.user_id(),
subject,
start,
end,
&attendees,
body_text,
)
.await?;
Ok(ToolResult {
success: true,
output: format!("Calendar event created (id: {event_id})"),
error: None,
})
}
async fn handle_calendar_event_delete(
&self,
args: &serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.security
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete")
.map_err(|e| anyhow::anyhow!(e))?;
let token = self.get_token().await?;
let event_id = args["event_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("event_id is required"))?;
graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id)
.await?;
Ok(ToolResult {
success: true,
output: format!("Calendar event {event_id} deleted"),
error: None,
})
}
}
#[async_trait]
impl Tool for Microsoft365Tool {
fn name(&self) -> &str {
"microsoft365"
}
fn description(&self) -> &str {
"Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \
OneDrive files, and SharePoint search via Microsoft Graph API"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"required": ["action"],
"properties": {
"action": {
"type": "string",
"enum": [
"mail_list",
"mail_send",
"teams_message_list",
"teams_message_send",
"calendar_events_list",
"calendar_event_create",
"calendar_event_delete",
"onedrive_list",
"onedrive_download",
"sharepoint_search"
],
"description": "The Microsoft 365 action to perform"
},
"folder": {
"type": "string",
"description": "Mail folder ID (for mail_list, e.g. 'inbox', 'sentitems')"
},
"to": {
"type": "array",
"items": { "type": "string" },
"description": "Recipient email addresses (for mail_send)"
},
"subject": {
"type": "string",
"description": "Email subject or calendar event subject"
},
"body": {
"type": "string",
"description": "Message body text"
},
"team_id": {
"type": "string",
"description": "Teams team ID (for teams_message_list/send)"
},
"channel_id": {
"type": "string",
"description": "Teams channel ID (for teams_message_list/send)"
},
"start": {
"type": "string",
"description": "Start datetime in ISO 8601 format (for calendar actions)"
},
"end": {
"type": "string",
"description": "End datetime in ISO 8601 format (for calendar actions)"
},
"attendees": {
"type": "array",
"items": { "type": "string" },
"description": "Attendee email addresses (for calendar_event_create)"
},
"event_id": {
"type": "string",
"description": "Calendar event ID (for calendar_event_delete)"
},
"path": {
"type": "string",
"description": "OneDrive folder path (for onedrive_list)"
},
"item_id": {
"type": "string",
"description": "OneDrive item ID (for onedrive_download)"
},
"max_size": {
"type": "integer",
"description": "Maximum download size in bytes (for onedrive_download, default 10MB)"
},
"query": {
"type": "string",
"description": "Search query (for sharepoint_search)"
},
"top": {
"type": "integer",
"description": "Maximum number of items to return (default 25)"
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args["action"].as_str() {
Some(a) => a.to_string(),
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'action' parameter is required".to_string()),
});
}
};
match self.dispatch(&action, &args).await {
Ok(result) => Ok(result),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("microsoft365.{action} failed: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn schema_skeleton_is_valid_json() {
// Constructing the tool requires live config, so just verify the
// hand-written schema skeleton parses as JSON.
let schema_str = r#"{"type":"object","required":["action"]}"#;
let _: serde_json::Value = serde_json::from_str(schema_str).unwrap();
}
#[test]
fn parameters_schema_has_action_enum() {
// Constructing the tool requires live config, so validate a local copy
// of the action enum instead.
let schema = json!({
"type": "object",
"required": ["action"],
"properties": {
"action": {
"type": "string",
"enum": [
"mail_list",
"mail_send",
"teams_message_list",
"teams_message_send",
"calendar_events_list",
"calendar_event_create",
"calendar_event_delete",
"onedrive_list",
"onedrive_download",
"sharepoint_search"
]
}
}
});
let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
assert_eq!(actions.len(), 10);
assert!(actions.contains(&json!("mail_list")));
assert!(actions.contains(&json!("sharepoint_search")));
}
#[test]
fn action_dispatch_table_is_exhaustive() {
let valid_actions = [
"mail_list",
"mail_send",
"teams_message_list",
"teams_message_send",
"calendar_events_list",
"calendar_event_create",
"calendar_event_delete",
"onedrive_list",
"onedrive_download",
"sharepoint_search",
];
assert_eq!(valid_actions.len(), 10);
assert!(!valid_actions.contains(&"invalid_action"));
}
}
@@ -0,0 +1,55 @@
use serde::{Deserialize, Serialize};
/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied.
#[derive(Clone, Serialize, Deserialize)]
pub struct Microsoft365ResolvedConfig {
pub tenant_id: String,
pub client_id: String,
pub client_secret: Option<String>,
pub auth_flow: String,
pub scopes: Vec<String>,
pub token_cache_encrypted: bool,
pub user_id: String,
}
impl std::fmt::Debug for Microsoft365ResolvedConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Microsoft365ResolvedConfig")
.field("tenant_id", &self.tenant_id)
.field("client_id", &self.client_id)
.field("client_secret", &self.client_secret.as_ref().map(|_| "***"))
.field("auth_flow", &self.auth_flow)
.field("scopes", &self.scopes)
.field("token_cache_encrypted", &self.token_cache_encrypted)
.field("user_id", &self.user_id)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolved_config_serialization_roundtrip() {
let config = Microsoft365ResolvedConfig {
tenant_id: "test-tenant".into(),
client_id: "test-client".into(),
client_secret: Some("secret".into()),
auth_flow: "client_credentials".into(),
scopes: vec!["https://graph.microsoft.com/.default".into()],
token_cache_encrypted: false,
user_id: "me".into(),
};
let json = serde_json::to_string(&config).unwrap();
let parsed: Microsoft365ResolvedConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.tenant_id, "test-tenant");
assert_eq!(parsed.client_id, "test-client");
assert_eq!(parsed.client_secret.as_deref(), Some("secret"));
assert_eq!(parsed.auth_flow, "client_credentials");
assert_eq!(parsed.scopes.len(), 1);
assert_eq!(parsed.user_id, "me");
}
}
@@ -48,19 +48,26 @@ pub mod mcp_transport;
pub mod memory_forget;
pub mod memory_recall;
pub mod memory_store;
pub mod microsoft365;
pub mod model_routing_config;
pub mod node_tool;
pub mod notion_tool;
pub mod pdf_read;
pub mod project_intel;
pub mod proxy_config;
pub mod pushover;
pub mod report_templates;
pub mod schedule;
pub mod schema;
pub mod screenshot;
pub mod security_ops;
pub mod shell;
pub mod swarm;
pub mod tool_search;
pub mod traits;
pub mod web_fetch;
pub mod web_search_tool;
pub mod workspace_tool;
pub use browser::{BrowserTool, ComputerUseConfig};
pub use browser_open::BrowserOpenTool;
@@ -92,23 +99,29 @@ pub use mcp_tool::McpToolWrapper;
pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use microsoft365::Microsoft365Tool;
pub use model_routing_config::ModelRoutingConfigTool;
#[allow(unused_imports)]
pub use node_tool::NodeTool;
pub use notion_tool::NotionTool;
pub use pdf_read::PdfReadTool;
pub use project_intel::ProjectIntelTool;
pub use proxy_config::ProxyConfigTool;
pub use pushover::PushoverTool;
pub use schedule::ScheduleTool;
#[allow(unused_imports)]
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use screenshot::ScreenshotTool;
pub use security_ops::SecurityOpsTool;
pub use shell::ShellTool;
pub use swarm::SwarmTool;
pub use tool_search::ToolSearchTool;
pub use traits::Tool;
#[allow(unused_imports)]
pub use traits::{ToolResult, ToolSpec};
pub use web_fetch::WebFetchTool;
pub use web_search_tool::WebSearchTool;
pub use workspace_tool::WorkspaceTool;
use crate::config::{Config, DelegateAgentConfig};
use crate::memory::Memory;
@@ -314,6 +327,7 @@ pub fn all_tools_with_runtime(
http_config.allowed_domains.clone(),
http_config.max_response_size,
http_config.timeout_secs,
http_config.allow_private_hosts,
)));
}
@@ -339,6 +353,37 @@ pub fn all_tools_with_runtime(
)));
}
// Notion API tool (conditionally registered)
if root_config.notion.enabled {
let notion_api_key = if root_config.notion.api_key.trim().is_empty() {
std::env::var("NOTION_API_KEY").unwrap_or_default()
} else {
root_config.notion.api_key.trim().to_string()
};
if notion_api_key.trim().is_empty() {
tracing::warn!(
"Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
);
} else {
tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone())));
}
}
// Project delivery intelligence
if root_config.project_intel.enabled {
tool_arcs.push(Arc::new(ProjectIntelTool::new(
root_config.project_intel.default_language.clone(),
root_config.project_intel.risk_sensitivity.clone(),
)));
}
// MCSS Security Operations
if root_config.security_ops.enabled {
tool_arcs.push(Arc::new(SecurityOpsTool::new(
root_config.security_ops.clone(),
)));
}
// PDF extraction (feature-gated at compile time via rag-pdf)
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
@@ -356,7 +401,80 @@ pub fn all_tools_with_runtime(
}
}
// Microsoft 365 Graph API integration
if root_config.microsoft365.enabled {
let ms_cfg = &root_config.microsoft365;
let tenant_id = ms_cfg
.tenant_id
.as_deref()
.unwrap_or_default()
.trim()
.to_string();
let client_id = ms_cfg
.client_id
.as_deref()
.unwrap_or_default()
.trim()
.to_string();
if !tenant_id.is_empty() && !client_id.is_empty() {
// Fail fast: client_credentials flow requires a client_secret at registration time.
if ms_cfg.auth_flow.trim() == "client_credentials"
&& ms_cfg
.client_secret
.as_deref()
.map_or(true, |s| s.trim().is_empty())
{
tracing::error!(
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
);
return (boxed_registry_from_arcs(tool_arcs), None);
}
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
tenant_id,
client_id,
client_secret: ms_cfg.client_secret.clone(),
auth_flow: ms_cfg.auth_flow.clone(),
scopes: ms_cfg.scopes.clone(),
token_cache_encrypted: ms_cfg.token_cache_encrypted,
user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(),
};
// Store token cache in the config directory (next to config.toml),
// not the workspace directory, to keep bearer tokens out of the
// project tree.
let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir);
match Microsoft365Tool::new(resolved, security.clone(), cache_dir) {
Ok(tool) => tool_arcs.push(Arc::new(tool)),
Err(e) => {
tracing::error!("microsoft365: failed to initialize tool: {e}");
}
}
} else {
tracing::warn!(
"microsoft365: skipped registration because tenant_id or client_id is empty"
);
}
}
// Add delegation tool when agents are configured
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
let trimmed_value = value.trim();
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
});
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: root_config.api_url.clone(),
zeroclaw_dir: root_config
.config_path
.parent()
.map(std::path::PathBuf::from),
secrets_encrypt: root_config.secrets.encrypt,
reasoning_enabled: root_config.runtime.reasoning_enabled,
provider_timeout_secs: Some(root_config.provider_timeout_secs),
extra_headers: root_config.extra_headers.clone(),
api_path: root_config.api_path.clone(),
};
let delegate_handle: Option<DelegateParentToolsHandle> = if agents.is_empty() {
None
} else {
@@ -364,28 +482,12 @@ pub fn all_tools_with_runtime(
.iter()
.map(|(name, cfg)| (name.clone(), cfg.clone()))
.collect();
let parent_tools = Arc::new(RwLock::new(tool_arcs.clone()));
let delegate_tool = DelegateTool::new_with_options(
delegate_agents,
delegate_fallback_credential.clone(),
security.clone(),
provider_runtime_options.clone(),
)
.with_parent_tools(Arc::clone(&parent_tools))
.with_multimodal_config(root_config.multimodal.clone());
@@ -393,6 +495,38 @@ pub fn all_tools_with_runtime(
Some(parent_tools)
};
// Add swarm tool when swarms are configured
if !root_config.swarms.is_empty() {
let swarm_agents: HashMap<String, DelegateAgentConfig> = agents
.iter()
.map(|(name, cfg)| (name.clone(), cfg.clone()))
.collect();
tool_arcs.push(Arc::new(SwarmTool::new(
root_config.swarms.clone(),
swarm_agents,
delegate_fallback_credential,
security.clone(),
provider_runtime_options,
)));
}
// Workspace management tool (conditionally registered when workspace isolation is enabled)
if root_config.workspace.enabled {
let workspaces_dir = if root_config.workspace.workspaces_dir.starts_with("~/") {
let home = directories::UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.unwrap_or_else(|| std::path::PathBuf::from("."));
home.join(&root_config.workspace.workspaces_dir[2..])
} else {
std::path::PathBuf::from(&root_config.workspace.workspaces_dir)
};
let ws_manager = crate::config::workspace::WorkspaceManager::new(workspaces_dir);
tool_arcs.push(Arc::new(WorkspaceTool::new(
Arc::new(tokio::sync::RwLock::new(ws_manager)),
security.clone(),
)));
}
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
}
@@ -0,0 +1,438 @@
use super::traits::{Tool, ToolResult};
use crate::security::{policy::ToolOperation, SecurityPolicy};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
const NOTION_VERSION: &str = "2022-06-28";
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;
/// Tool for interacting with the Notion API — query databases, read/create/update pages,
/// and search the workspace. Each action is gated by the appropriate security operation
/// (Read for queries, Act for mutations).
pub struct NotionTool {
api_key: String,
http: reqwest::Client,
security: Arc<SecurityPolicy>,
}
impl NotionTool {
/// Create a new Notion tool with the given API key and security policy.
pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
Self {
api_key,
http: reqwest::Client::new(),
security,
}
}
/// Build the standard Notion API headers (Authorization, version, content-type).
fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", self.api_key)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
);
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
headers.insert("Content-Type", "application/json".parse().unwrap());
Ok(headers)
}
/// Query a Notion database with an optional filter.
async fn query_database(
&self,
database_id: &str,
filter: Option<&serde_json::Value>,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
let mut body = json!({});
if let Some(f) = filter {
body["filter"] = f.clone();
}
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion query_database failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Read a single Notion page by ID.
async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let resp = self
.http
.get(&url)
.headers(self.headers()?)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion read_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Create a new Notion page, optionally within a database.
async fn create_page(
&self,
properties: &serde_json::Value,
database_id: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages");
let mut body = json!({ "properties": properties });
if let Some(db_id) = database_id {
body["parent"] = json!({ "database_id": db_id });
}
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion create_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Update an existing Notion page's properties.
async fn update_page(
&self,
page_id: &str,
properties: &serde_json::Value,
) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
let body = json!({ "properties": properties });
let resp = self
.http
.patch(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion update_page failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
/// Search the Notion workspace by query string.
async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
let url = format!("{NOTION_API_BASE}/search");
let body = json!({ "query": query });
let resp = self
.http
.post(&url)
.headers(self.headers()?)
.json(&body)
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
anyhow::bail!("Notion search failed ({status}): {truncated}");
}
resp.json().await.map_err(Into::into)
}
}
#[async_trait]
impl Tool for NotionTool {
fn name(&self) -> &str {
"notion"
}
fn description(&self) -> &str {
"Interact with Notion: query databases, read/create/update pages, and search the workspace."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["query_database", "read_page", "create_page", "update_page", "search"],
"description": "The Notion API action to perform"
},
"database_id": {
"type": "string",
"description": "Database ID (required for query_database, optional for create_page)"
},
"page_id": {
"type": "string",
"description": "Page ID (required for read_page and update_page)"
},
"filter": {
"type": "object",
"description": "Notion filter object for query_database"
},
"properties": {
"type": "object",
"description": "Properties object for create_page and update_page"
},
"query": {
"type": "string",
"description": "Search query string for the search action"
}
},
"required": ["action"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|v| v.as_str()) {
Some(a) => a,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: action".into()),
});
}
};
// Enforce granular security: Read for queries, Act for mutations
let operation = match action {
"query_database" | "read_page" | "search" => ToolOperation::Read,
"create_page" | "update_page" => ToolOperation::Act,
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
)),
});
}
};
if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let result = match action {
"query_database" => {
let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("query_database requires database_id parameter".into()),
});
}
};
let filter = args.get("filter");
self.query_database(database_id, filter).await
}
"read_page" => {
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("read_page requires page_id parameter".into()),
});
}
};
self.read_page(page_id).await
}
"create_page" => {
let properties = match args.get("properties") {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("create_page requires properties parameter".into()),
});
}
};
let database_id = args.get("database_id").and_then(|v| v.as_str());
self.create_page(properties, database_id).await
}
"update_page" => {
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("update_page requires page_id parameter".into()),
});
}
};
let properties = match args.get("properties") {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("update_page requires properties parameter".into()),
});
}
};
self.update_page(page_id, properties).await
}
"search" => {
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
self.search(query).await
}
_ => unreachable!(), // Already handled above
};
match result {
Ok(value) => Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::SecurityPolicy;
fn test_tool() -> NotionTool {
let security = Arc::new(SecurityPolicy::default());
NotionTool::new("test-key".into(), security)
}
#[test]
fn tool_name_is_notion() {
let tool = test_tool();
assert_eq!(tool.name(), "notion");
}
#[test]
fn parameters_schema_has_required_action() {
let tool = test_tool();
let schema = tool.parameters_schema();
let required = schema["required"].as_array().unwrap();
assert!(required.iter().any(|v| v.as_str() == Some("action")));
}
#[test]
fn parameters_schema_defines_all_actions() {
let tool = test_tool();
let schema = tool.parameters_schema();
let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
let action_strs: Vec<&str> = actions.iter().filter_map(|v| v.as_str()).collect();
assert!(action_strs.contains(&"query_database"));
assert!(action_strs.contains(&"read_page"));
assert!(action_strs.contains(&"create_page"));
assert!(action_strs.contains(&"update_page"));
assert!(action_strs.contains(&"search"));
}
#[tokio::test]
async fn execute_missing_action_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("action"));
}
#[tokio::test]
async fn execute_unknown_action_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("Unknown action"));
}
#[tokio::test]
async fn execute_query_database_missing_id_returns_error() {
let tool = test_tool();
let result = tool
.execute(json!({"action": "query_database"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("database_id"));
}
#[tokio::test]
async fn execute_read_page_missing_id_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({"action": "read_page"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("page_id"));
}
#[tokio::test]
async fn execute_create_page_missing_properties_returns_error() {
let tool = test_tool();
let result = tool
.execute(json!({"action": "create_page"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("properties"));
}
#[tokio::test]
async fn execute_update_page_missing_page_id_returns_error() {
let tool = test_tool();
let result = tool
.execute(json!({"action": "update_page", "properties": {}}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("page_id"));
}
#[tokio::test]
async fn execute_update_page_missing_properties_returns_error() {
let tool = test_tool();
let result = tool
.execute(json!({"action": "update_page", "page_id": "test-id"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("properties"));
}
}
@@ -0,0 +1,750 @@
//! Project delivery intelligence tool.
//!
//! Provides read-only analysis and generation for project management:
//! status reports, risk detection, client communication drafting,
//! sprint summaries, and effort estimation.
use super::report_templates;
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::fmt::Write as _;
/// Project intelligence tool for consulting project management.
///
/// All actions are read-only analysis/generation; nothing is modified externally.
pub struct ProjectIntelTool {
default_language: String,
risk_sensitivity: RiskSensitivity,
}
/// Risk detection sensitivity level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskSensitivity {
Low,
Medium,
High,
}
impl RiskSensitivity {
fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"low" => Self::Low,
"high" => Self::High,
// Anything unrecognized falls back to Medium.
_ => Self::Medium,
}
}
/// Threshold multiplier: higher sensitivity means lower thresholds.
fn threshold_factor(self) -> f64 {
match self {
Self::Low => 1.5,
Self::Medium => 1.0,
Self::High => 0.5,
}
}
}
impl ProjectIntelTool {
pub fn new(default_language: String, risk_sensitivity: String) -> Self {
Self {
default_language,
risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity),
}
}
fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let project_name = args
.get("project_name")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?;
let period = args
.get("period")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?;
let lang = args
.get("language")
.and_then(|v| v.as_str())
.unwrap_or(&self.default_language);
let git_log = args
.get("git_log")
.and_then(|v| v.as_str())
.unwrap_or("No git data provided");
let jira_summary = args
.get("jira_summary")
.and_then(|v| v.as_str())
.unwrap_or("No Jira data provided");
let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or("");
let tpl = report_templates::weekly_status_template(lang);
let mut vars = HashMap::new();
vars.insert("project_name".into(), project_name.to_string());
vars.insert("period".into(), period.to_string());
vars.insert("completed".into(), git_log.to_string());
vars.insert("in_progress".into(), jira_summary.to_string());
vars.insert("blocked".into(), notes.to_string());
vars.insert("next_steps".into(), "To be determined".into());
let rendered = tpl.render(&vars);
Ok(ToolResult {
success: true,
output: rendered,
error: None,
})
}
fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let deadlines = args
.get("deadlines")
.and_then(|v| v.as_str())
.unwrap_or_default();
let velocity = args
.get("velocity")
.and_then(|v| v.as_str())
.unwrap_or_default();
let blockers = args
.get("blockers")
.and_then(|v| v.as_str())
.unwrap_or_default();
let lang = args
.get("language")
.and_then(|v| v.as_str())
.unwrap_or(&self.default_language);
let mut risks = Vec::new();
// Heuristic risk detection based on signals
let factor = self.risk_sensitivity.threshold_factor();
if !blockers.is_empty() {
let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count();
let severity = if (blocker_count as f64) > 3.0 * factor {
"critical"
} else if (blocker_count as f64) > 1.0 * factor {
"high"
} else {
"medium"
};
risks.push(RiskItem {
title: "Active blockers detected".into(),
severity: severity.into(),
detail: format!("{blocker_count} blocker(s) identified"),
mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(),
});
}
if deadlines.to_lowercase().contains("overdue")
|| deadlines.to_lowercase().contains("missed")
{
risks.push(RiskItem {
title: "Deadline risk".into(),
severity: "high".into(),
detail: "Overdue or missed deadlines detected in project context".into(),
mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(),
});
}
if velocity.to_lowercase().contains("declining") || velocity.to_lowercase().contains("slow")
{
risks.push(RiskItem {
title: "Velocity degradation".into(),
severity: "medium".into(),
detail: "Team velocity is declining or below expectations".into(),
mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(),
});
}
if risks.is_empty() {
risks.push(RiskItem {
title: "No significant risks detected".into(),
severity: "low".into(),
detail: "Current project signals within normal parameters".into(),
mitigation: "Continue monitoring".into(),
});
}
let tpl = report_templates::risk_register_template(lang);
let risks_text = risks
.iter()
.map(|r| {
format!(
"- [{}] {}: {}",
r.severity.to_uppercase(),
r.title,
r.detail
)
})
.collect::<Vec<_>>()
.join("\n");
let mitigations_text = risks
.iter()
.map(|r| format!("- {}: {}", r.title, r.mitigation))
.collect::<Vec<_>>()
.join("\n");
let mut vars = HashMap::new();
vars.insert(
"project_name".into(),
args.get("project_name")
.and_then(|v| v.as_str())
.unwrap_or("Unknown")
.to_string(),
);
vars.insert("risks".into(), risks_text);
vars.insert("mitigations".into(), mitigations_text);
Ok(ToolResult {
success: true,
output: tpl.render(&vars),
error: None,
})
}
fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let project_name = args
.get("project_name")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?;
let audience = args
.get("audience")
.and_then(|v| v.as_str())
.unwrap_or("client");
let tone = args
.get("tone")
.and_then(|v| v.as_str())
.unwrap_or("formal");
let highlights = args
.get("highlights")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?;
let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or("");
let greeting = match (audience, tone) {
("client", "casual") => "Hi there,".to_string(),
("client", _) => "Dear valued partner,".to_string(),
("internal", "casual") => "Hey team,".to_string(),
("internal", _) => "Dear team,".to_string(),
(_, "casual") => "Hi,".to_string(),
_ => "Dear reader,".to_string(),
};
let closing = match tone {
"casual" => "Cheers",
_ => "Best regards",
};
let mut body = format!(
"{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}"
);
if !concerns.is_empty() {
let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}");
}
let _ = write!(
body,
"\n\nPlease do not hesitate to reach out with any questions.\n\n{closing}"
);
Ok(ToolResult {
success: true,
output: body,
error: None,
})
}
fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let sprint_dates = args
.get("sprint_dates")
.and_then(|v| v.as_str())
.unwrap_or("current sprint");
let completed = args
.get("completed")
.and_then(|v| v.as_str())
.unwrap_or("None specified");
let in_progress = args
.get("in_progress")
.and_then(|v| v.as_str())
.unwrap_or("None specified");
let blocked = args
.get("blocked")
.and_then(|v| v.as_str())
.unwrap_or("None");
let velocity = args
.get("velocity")
.and_then(|v| v.as_str())
.unwrap_or("Not calculated");
let lang = args
.get("language")
.and_then(|v| v.as_str())
.unwrap_or(&self.default_language);
let tpl = report_templates::sprint_review_template(lang);
let mut vars = HashMap::new();
vars.insert("sprint_dates".into(), sprint_dates.to_string());
vars.insert("completed".into(), completed.to_string());
vars.insert("in_progress".into(), in_progress.to_string());
vars.insert("blocked".into(), blocked.to_string());
vars.insert("velocity".into(), velocity.to_string());
Ok(ToolResult {
success: true,
output: tpl.render(&vars),
error: None,
})
}
fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or("");
if tasks.trim().is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("No task descriptions provided".into()),
});
}
let mut estimates = Vec::new();
for line in tasks.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
let (size, rationale) = estimate_task_effort(line);
estimates.push(format!("- **{size}** | {line}\n Rationale: {rationale}"));
}
let output = format!(
"## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L (1-3d), XL (3-5d), XXL (>5d)_",
estimates.join("\n")
);
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
struct RiskItem {
title: String,
severity: String,
detail: String,
mitigation: String,
}
/// Heuristic effort estimation from task description text.
fn estimate_task_effort(description: &str) -> (&'static str, &'static str) {
let lower = description.to_lowercase();
let word_count = description.split_whitespace().count();
// Signal-based heuristics
let complexity_signals = [
"refactor",
"rewrite",
"migrate",
"redesign",
"architecture",
"infrastructure",
];
let medium_signals = [
"implement",
"create",
"build",
"integrate",
"add feature",
"new module",
];
let small_signals = [
"fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config",
];
if complexity_signals.iter().any(|s| lower.contains(s)) {
if word_count > 15 {
return (
"XXL",
"Large-scope structural change with extensive description",
);
}
return ("XL", "Structural change requiring significant effort");
}
if medium_signals.iter().any(|s| lower.contains(s)) {
if word_count > 12 {
return ("L", "Feature implementation with detailed requirements");
}
return ("M", "Standard feature implementation");
}
if small_signals.iter().any(|s| lower.contains(s)) {
if word_count > 10 {
return ("S", "Small change with additional context");
}
return ("XS", "Minor targeted change");
}
// Fallback: estimate by description length as a proxy for complexity
if word_count > 20 {
("L", "Complex task inferred from detailed description")
} else if word_count > 10 {
("M", "Moderate task inferred from description length")
} else {
("S", "Simple task inferred from brief description")
}
}
#[async_trait]
impl Tool for ProjectIntelTool {
fn name(&self) -> &str {
"project_intel"
}
fn description(&self) -> &str {
"Project delivery intelligence: generate status reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"],
"description": "The analysis action to perform"
},
"project_name": {
"type": "string",
"description": "Project name (for status_report, risk_scan, draft_update)"
},
"period": {
"type": "string",
"description": "Reporting period: week, sprint, or month (for status_report)"
},
"language": {
"type": "string",
"description": "Report language: en, de, fr, it (default from config)"
},
"git_log": {
"type": "string",
"description": "Git log summary text (for status_report)"
},
"jira_summary": {
"type": "string",
"description": "Jira/issue tracker summary (for status_report)"
},
"notes": {
"type": "string",
"description": "Additional notes or context"
},
"deadlines": {
"type": "string",
"description": "Deadline information (for risk_scan)"
},
"velocity": {
"type": "string",
"description": "Team velocity data (for risk_scan, sprint_summary)"
},
"blockers": {
"type": "string",
"description": "Current blockers (for risk_scan)"
},
"audience": {
"type": "string",
"enum": ["client", "internal"],
"description": "Target audience (for draft_update)"
},
"tone": {
"type": "string",
"enum": ["formal", "casual"],
"description": "Communication tone (for draft_update)"
},
"highlights": {
"type": "string",
"description": "Key highlights for the update (for draft_update)"
},
"concerns": {
"type": "string",
"description": "Items requiring attention (for draft_update)"
},
"sprint_dates": {
"type": "string",
"description": "Sprint date range (for sprint_summary)"
},
"completed": {
"type": "string",
"description": "Completed items (for sprint_summary)"
},
"in_progress": {
"type": "string",
"description": "In-progress items (for sprint_summary)"
},
"blocked": {
"type": "string",
"description": "Blocked items (for sprint_summary)"
},
"tasks": {
"type": "string",
"description": "Task descriptions, one per line (for effort_estimate)"
}
},
"required": ["action"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args
.get("action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;
match action {
"status_report" => self.execute_status_report(&args),
"risk_scan" => self.execute_risk_scan(&args),
"draft_update" => self.execute_draft_update(&args),
"sprint_summary" => self.execute_sprint_summary(&args),
"effort_estimate" => self.execute_effort_estimate(&args),
other => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action '{other}'. Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate"
)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn tool() -> ProjectIntelTool {
ProjectIntelTool::new("en".into(), "medium".into())
}
#[test]
fn tool_name_and_description() {
let t = tool();
assert_eq!(t.name(), "project_intel");
assert!(!t.description().is_empty());
}
#[test]
fn parameters_schema_has_action() {
let t = tool();
let schema = t.parameters_schema();
assert!(schema["properties"]["action"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&serde_json::Value::String("action".into())));
}
#[tokio::test]
async fn status_report_renders() {
let t = tool();
let result = t
.execute(json!({
"action": "status_report",
"project_name": "TestProject",
"period": "week",
"git_log": "- feat: added login"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("TestProject"));
assert!(result.output.contains("added login"));
}
#[tokio::test]
async fn risk_scan_detects_blockers() {
let t = tool();
let result = t
.execute(json!({
"action": "risk_scan",
"blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("blocker"));
}
#[tokio::test]
async fn risk_scan_detects_deadline_risk() {
let t = tool();
let result = t
.execute(json!({
"action": "risk_scan",
"deadlines": "Sprint deadline overdue by 3 days"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Deadline risk"));
}
#[tokio::test]
async fn risk_scan_no_signals_returns_low_risk() {
let t = tool();
let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No significant risks"));
}
#[tokio::test]
async fn draft_update_formal_client() {
let t = tool();
let result = t
.execute(json!({
"action": "draft_update",
"project_name": "Portal",
"audience": "client",
"tone": "formal",
"highlights": "Phase 1 delivered"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Dear valued partner"));
assert!(result.output.contains("Portal"));
assert!(result.output.contains("Phase 1 delivered"));
}
#[tokio::test]
async fn draft_update_casual_internal() {
let t = tool();
let result = t
.execute(json!({
"action": "draft_update",
"project_name": "ZeroClaw",
"audience": "internal",
"tone": "casual",
"highlights": "Core loop stabilized"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Hey team"));
assert!(result.output.contains("Cheers"));
}
#[tokio::test]
async fn sprint_summary_renders() {
let t = tool();
let result = t
.execute(json!({
"action": "sprint_summary",
"sprint_dates": "2026-03-01 to 2026-03-14",
"completed": "- Login page\n- API endpoints",
"in_progress": "- Dashboard",
"blocked": "- Payment integration"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Login page"));
assert!(result.output.contains("Dashboard"));
}
#[tokio::test]
async fn effort_estimate_basic() {
let t = tool();
let result = t
.execute(json!({
"action": "effort_estimate",
"tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("XS"));
assert!(result.output.contains("Refactor database layer"));
}
#[tokio::test]
async fn effort_estimate_empty_tasks_fails() {
let t = tool();
let result = t
.execute(json!({ "action": "effort_estimate", "tasks": "" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("No task descriptions"));
}
#[tokio::test]
async fn unknown_action_returns_error() {
let t = tool();
let result = t
.execute(json!({ "action": "invalid_thing" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}
#[tokio::test]
async fn missing_action_returns_error() {
let t = tool();
let result = t.execute(json!({})).await;
assert!(result.is_err());
}
#[test]
fn effort_estimate_heuristics_coverage() {
assert_eq!(estimate_task_effort("Fix typo").0, "XS");
assert_eq!(estimate_task_effort("Update config values").0, "XS");
assert_eq!(
estimate_task_effort("Implement new notification system").0,
"M"
);
assert_eq!(
estimate_task_effort("Refactor the entire authentication module").0,
"XL"
);
assert_eq!(
estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0,
"XXL"
);
}
#[test]
fn risk_sensitivity_threshold_ordering() {
assert!(
RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor()
);
assert!(
RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor()
);
}
#[test]
fn risk_sensitivity_from_str_variants() {
assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low);
assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High);
assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium);
assert_eq!(
RiskSensitivity::from_str("unknown"),
RiskSensitivity::Medium
);
}
#[tokio::test]
async fn high_sensitivity_detects_single_blocker_as_high() {
let t = ProjectIntelTool::new("en".into(), "high".into());
let result = t
.execute(json!({
"action": "risk_scan",
"blockers": "Single blocker"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]"));
}
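// Added sketch test (hypothetical name): with medium sensitivity the
// threshold factor is 1.0, so two blockers should rate as high but not
// critical (2 > 1.0, yet 2 <= 3.0).
#[tokio::test]
async fn medium_sensitivity_rates_two_blockers_high_not_critical() {
let t = ProjectIntelTool::new("en".into(), "medium".into());
let result = t
.execute(json!({
"action": "risk_scan",
"blockers": "Blocker one\nBlocker two"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("[HIGH]"));
assert!(!result.output.contains("[CRITICAL]"));
}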
}
@@ -0,0 +1,582 @@
//! Report template engine for project delivery intelligence.
//!
//! Provides built-in templates for weekly status, sprint review, risk register,
//! and milestone reports with multi-language support (EN, DE, FR, IT).
use std::collections::HashMap;
use std::fmt::Write as _;
/// Supported report output formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReportFormat {
Markdown,
Html,
}
/// A named section within a report template.
#[derive(Debug, Clone)]
pub struct TemplateSection {
pub heading: String,
pub body: String,
}
/// A report template with named sections and variable placeholders.
#[derive(Debug, Clone)]
pub struct ReportTemplate {
pub name: String,
pub sections: Vec<TemplateSection>,
pub format: ReportFormat,
}
/// Escape a string for safe inclusion in HTML output.
fn escape_html(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&#x27;")
}
impl ReportTemplate {
/// Render the template by substituting `{{key}}` placeholders with values.
pub fn render(&self, vars: &HashMap<String, String>) -> String {
let mut out = String::new();
for section in &self.sections {
let heading = substitute(&section.heading, vars);
let body = substitute(&section.body, vars);
match self.format {
ReportFormat::Markdown => {
let _ = write!(out, "## {heading}\n\n{body}\n\n");
}
ReportFormat::Html => {
let heading = escape_html(&heading);
let body = escape_html(&body);
let _ = write!(out, "<h2>{heading}</h2>\n<p>{body}</p>\n");
}
}
}
out.trim_end().to_string()
}
}
/// Single-pass placeholder substitution.
///
/// Scans `template` left-to-right for `{{key}}` tokens and replaces them with
/// the corresponding value from `vars`. Because the scan is single-pass,
/// values that themselves contain `{{...}}` sequences are emitted literally
/// and never re-expanded, preventing injection of new placeholders.
fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
let mut result = String::with_capacity(template.len());
let bytes = template.as_bytes();
let len = bytes.len();
let mut i = 0;
while i < len {
if i + 1 < len && bytes[i] == b'{' && bytes[i + 1] == b'{' {
// Find the closing `}}`.
if let Some(close) = template[i + 2..].find("}}") {
let key = &template[i + 2..i + 2 + close];
if let Some(value) = vars.get(key) {
result.push_str(value);
} else {
// Unknown placeholder: emit as-is.
result.push_str(&template[i..i + 2 + close + 2]);
}
i += 2 + close + 2;
continue;
}
}
// Copy the next full character so multi-byte UTF-8 is preserved
// (pushing a raw byte as `char` would corrupt non-ASCII input).
let ch = template[i..].chars().next().unwrap();
result.push(ch);
i += ch.len_utf8();
}
result
}
// ── Built-in templates ────────────────────────────────────────────
/// Return the built-in weekly status template for the given language.
pub fn weekly_status_template(lang: &str) -> ReportTemplate {
let (name, sections) = match lang {
"de" => (
"Wochenstatus",
vec![
TemplateSection {
heading: "Zusammenfassung".into(),
body: "Projekt: {{project_name}} | Zeitraum: {{period}}".into(),
},
TemplateSection {
heading: "Erledigt".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In Bearbeitung".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Blockiert".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Naechste Schritte".into(),
body: "{{next_steps}}".into(),
},
],
),
"fr" => (
"Statut hebdomadaire",
vec![
TemplateSection {
heading: "Resume".into(),
body: "Projet: {{project_name}} | Periode: {{period}}".into(),
},
TemplateSection {
heading: "Termine".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "En cours".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Bloque".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Prochaines etapes".into(),
body: "{{next_steps}}".into(),
},
],
),
"it" => (
"Stato settimanale",
vec![
TemplateSection {
heading: "Riepilogo".into(),
body: "Progetto: {{project_name}} | Periodo: {{period}}".into(),
},
TemplateSection {
heading: "Completato".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In corso".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Bloccato".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Prossimi passi".into(),
body: "{{next_steps}}".into(),
},
],
),
_ => (
"Weekly Status",
vec![
TemplateSection {
heading: "Summary".into(),
body: "Project: {{project_name}} | Period: {{period}}".into(),
},
TemplateSection {
heading: "Completed".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In Progress".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Blocked".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Next Steps".into(),
body: "{{next_steps}}".into(),
},
],
),
};
ReportTemplate {
name: name.into(),
sections,
format: ReportFormat::Markdown,
}
}
/// Return the built-in sprint review template for the given language.
pub fn sprint_review_template(lang: &str) -> ReportTemplate {
let (name, sections) = match lang {
"de" => (
"Sprint-Uebersicht",
vec![
TemplateSection {
heading: "Sprint".into(),
body: "{{sprint_dates}}".into(),
},
TemplateSection {
heading: "Erledigt".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In Bearbeitung".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Blockiert".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Velocity".into(),
body: "{{velocity}}".into(),
},
],
),
"fr" => (
"Revue de sprint",
vec![
TemplateSection {
heading: "Sprint".into(),
body: "{{sprint_dates}}".into(),
},
TemplateSection {
heading: "Termine".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "En cours".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Bloque".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Velocite".into(),
body: "{{velocity}}".into(),
},
],
),
"it" => (
"Revisione sprint",
vec![
TemplateSection {
heading: "Sprint".into(),
body: "{{sprint_dates}}".into(),
},
TemplateSection {
heading: "Completato".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In corso".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Bloccato".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Velocita".into(),
body: "{{velocity}}".into(),
},
],
),
_ => (
"Sprint Review",
vec![
TemplateSection {
heading: "Sprint".into(),
body: "{{sprint_dates}}".into(),
},
TemplateSection {
heading: "Completed".into(),
body: "{{completed}}".into(),
},
TemplateSection {
heading: "In Progress".into(),
body: "{{in_progress}}".into(),
},
TemplateSection {
heading: "Blocked".into(),
body: "{{blocked}}".into(),
},
TemplateSection {
heading: "Velocity".into(),
body: "{{velocity}}".into(),
},
],
),
};
ReportTemplate {
name: name.into(),
sections,
format: ReportFormat::Markdown,
}
}
/// Return the built-in risk register template for the given language.
pub fn risk_register_template(lang: &str) -> ReportTemplate {
let (name, sections) = match lang {
"de" => (
"Risikoregister",
vec![
TemplateSection {
heading: "Projekt".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Risiken".into(),
body: "{{risks}}".into(),
},
TemplateSection {
heading: "Massnahmen".into(),
body: "{{mitigations}}".into(),
},
],
),
"fr" => (
"Registre des risques",
vec![
TemplateSection {
heading: "Projet".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Risques".into(),
body: "{{risks}}".into(),
},
TemplateSection {
heading: "Mesures".into(),
body: "{{mitigations}}".into(),
},
],
),
"it" => (
"Registro dei rischi",
vec![
TemplateSection {
heading: "Progetto".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Rischi".into(),
body: "{{risks}}".into(),
},
TemplateSection {
heading: "Mitigazioni".into(),
body: "{{mitigations}}".into(),
},
],
),
_ => (
"Risk Register",
vec![
TemplateSection {
heading: "Project".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Risks".into(),
body: "{{risks}}".into(),
},
TemplateSection {
heading: "Mitigations".into(),
body: "{{mitigations}}".into(),
},
],
),
};
ReportTemplate {
name: name.into(),
sections,
format: ReportFormat::Markdown,
}
}
/// Return the built-in milestone report template for the given language.
pub fn milestone_report_template(lang: &str) -> ReportTemplate {
let (name, sections) = match lang {
"de" => (
"Meilensteinbericht",
vec![
TemplateSection {
heading: "Projekt".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Meilensteine".into(),
body: "{{milestones}}".into(),
},
TemplateSection {
heading: "Status".into(),
body: "{{status}}".into(),
},
],
),
"fr" => (
"Rapport de jalons",
vec![
TemplateSection {
heading: "Projet".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Jalons".into(),
body: "{{milestones}}".into(),
},
TemplateSection {
heading: "Statut".into(),
body: "{{status}}".into(),
},
],
),
"it" => (
"Report milestone",
vec![
TemplateSection {
heading: "Progetto".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Milestone".into(),
body: "{{milestones}}".into(),
},
TemplateSection {
heading: "Stato".into(),
body: "{{status}}".into(),
},
],
),
_ => (
"Milestone Report",
vec![
TemplateSection {
heading: "Project".into(),
body: "{{project_name}}".into(),
},
TemplateSection {
heading: "Milestones".into(),
body: "{{milestones}}".into(),
},
TemplateSection {
heading: "Status".into(),
body: "{{status}}".into(),
},
],
),
};
ReportTemplate {
name: name.into(),
sections,
format: ReportFormat::Markdown,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn weekly_status_renders_with_variables() {
let tpl = weekly_status_template("en");
let mut vars = HashMap::new();
vars.insert("project_name".into(), "ZeroClaw".into());
vars.insert("period".into(), "2026-W10".into());
vars.insert("completed".into(), "- Task A\n- Task B".into());
vars.insert("in_progress".into(), "- Task C".into());
vars.insert("blocked".into(), "None".into());
vars.insert("next_steps".into(), "- Task D".into());
let rendered = tpl.render(&vars);
assert!(rendered.contains("Project: ZeroClaw"));
assert!(rendered.contains("Period: 2026-W10"));
assert!(rendered.contains("- Task A"));
assert!(rendered.contains("## Completed"));
}
#[test]
fn weekly_status_de_renders_german_headings() {
let tpl = weekly_status_template("de");
let vars = HashMap::new();
let rendered = tpl.render(&vars);
assert!(rendered.contains("## Zusammenfassung"));
assert!(rendered.contains("## Erledigt"));
}
#[test]
fn weekly_status_fr_renders_french_headings() {
let tpl = weekly_status_template("fr");
let vars = HashMap::new();
let rendered = tpl.render(&vars);
assert!(rendered.contains("## Resume"));
assert!(rendered.contains("## Termine"));
}
#[test]
fn weekly_status_it_renders_italian_headings() {
let tpl = weekly_status_template("it");
let vars = HashMap::new();
let rendered = tpl.render(&vars);
assert!(rendered.contains("## Riepilogo"));
assert!(rendered.contains("## Completato"));
}
#[test]
fn html_format_renders_tags() {
let mut tpl = weekly_status_template("en");
tpl.format = ReportFormat::Html;
let mut vars = HashMap::new();
vars.insert("project_name".into(), "Test".into());
vars.insert("period".into(), "W1".into());
vars.insert("completed".into(), "Done".into());
vars.insert("in_progress".into(), "WIP".into());
vars.insert("blocked".into(), "None".into());
vars.insert("next_steps".into(), "Next".into());
let rendered = tpl.render(&vars);
assert!(rendered.contains("<h2>Summary</h2>"));
assert!(rendered.contains("<p>Project: Test | Period: W1</p>"));
}
#[test]
fn sprint_review_template_has_velocity_section() {
let tpl = sprint_review_template("en");
let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
assert!(section_headings.contains(&"Velocity"));
}
#[test]
fn risk_register_template_has_risk_sections() {
let tpl = risk_register_template("en");
let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
assert!(section_headings.contains(&"Risks"));
assert!(section_headings.contains(&"Mitigations"));
}
#[test]
fn milestone_template_all_languages() {
for lang in &["en", "de", "fr", "it"] {
let tpl = milestone_report_template(lang);
assert!(!tpl.name.is_empty());
assert_eq!(tpl.sections.len(), 3);
}
}
#[test]
fn substitute_leaves_unknown_placeholders() {
let vars = HashMap::new();
let result = substitute("Hello {{name}}", &vars);
assert_eq!(result, "Hello {{name}}");
}
#[test]
fn substitute_replaces_all_occurrences() {
let mut vars = HashMap::new();
vars.insert("x".into(), "1".into());
let result = substitute("{{x}} and {{x}}", &vars);
assert_eq!(result, "1 and 1");
}
}
@@ -0,0 +1,659 @@
//! Security operations tool for managed cybersecurity service (MCSS) workflows.
//!
//! Provides alert triage, incident response playbook execution, vulnerability
//! scan parsing, and security report generation. All actions that modify state
//! enforce human approval gates unless explicitly configured otherwise.
use async_trait::async_trait;
use serde_json::json;
use std::path::PathBuf;
use super::traits::{Tool, ToolResult};
use crate::config::SecurityOpsConfig;
use crate::security::playbook::{
evaluate_step, load_playbooks, severity_level, Playbook, StepStatus,
};
use crate::security::vulnerability::{generate_summary, parse_vulnerability_json};
/// Security operations tool — triage alerts, run playbooks, parse vulns, generate reports.
pub struct SecurityOpsTool {
config: SecurityOpsConfig,
playbooks: Vec<Playbook>,
}
impl SecurityOpsTool {
pub fn new(config: SecurityOpsConfig) -> Self {
let playbooks_dir = expand_tilde(&config.playbooks_dir);
let playbooks = load_playbooks(&playbooks_dir);
Self { config, playbooks }
}
/// Triage an alert: classify severity and recommend response.
fn triage_alert(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let alert = args
.get("alert")
.ok_or_else(|| anyhow::anyhow!("Missing required 'alert' parameter"))?;
// Extract key fields for classification
let alert_type = alert
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let source = alert
.get("source")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let severity = alert
.get("severity")
.and_then(|v| v.as_str())
.unwrap_or("medium");
let description = alert
.get("description")
.and_then(|v| v.as_str())
.unwrap_or("");
// Classify and find matching playbooks
let matching_playbooks: Vec<&Playbook> = self
.playbooks
.iter()
.filter(|pb| {
severity_level(severity) >= severity_level(&pb.severity_filter)
&& (pb.name.contains(alert_type)
|| alert_type.contains(&pb.name)
|| description
.to_lowercase()
.contains(&pb.name.replace('_', " ")))
})
.collect();
let playbook_names: Vec<&str> =
matching_playbooks.iter().map(|p| p.name.as_str()).collect();
let output = json!({
"classification": {
"alert_type": alert_type,
"source": source,
"severity": severity,
"severity_level": severity_level(severity),
"priority": if severity_level(severity) >= 3 { "immediate" } else { "standard" },
},
"recommended_playbooks": playbook_names,
"recommended_action": if matching_playbooks.is_empty() {
"Manual investigation required — no matching playbook found"
} else {
"Execute recommended playbook(s)"
},
"auto_triage": self.config.auto_triage,
});
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&output)?,
error: None,
})
}
/// Execute a playbook step with approval gating.
fn run_playbook(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let playbook_name = args
.get("playbook")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required 'playbook' parameter"))?;
let step_index =
usize::try_from(args.get("step").and_then(|v| v.as_u64()).ok_or_else(|| {
anyhow::anyhow!("Missing required 'step' parameter (0-based index)")
})?)
.map_err(|_| anyhow::anyhow!("'step' parameter value too large for this platform"))?;
let alert_severity = args
.get("alert_severity")
.and_then(|v| v.as_str())
.unwrap_or("medium");
let playbook = self
.playbooks
.iter()
.find(|p| p.name == playbook_name)
.ok_or_else(|| anyhow::anyhow!("Playbook '{}' not found", playbook_name))?;
let result = evaluate_step(
playbook,
step_index,
alert_severity,
&self.config.max_auto_severity,
self.config.require_approval_for_actions,
);
let output = json!({
"playbook": playbook_name,
"step_index": result.step_index,
"action": result.action,
"status": result.status.to_string(),
"message": result.message,
"requires_manual_approval": result.status == StepStatus::PendingApproval,
});
Ok(ToolResult {
success: result.status != StepStatus::Failed,
output: serde_json::to_string_pretty(&output)?,
error: if result.status == StepStatus::Failed {
Some(result.message)
} else {
None
},
})
}
/// Parse vulnerability scan results.
fn parse_vulnerability(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let scan_data = args
.get("scan_data")
.ok_or_else(|| anyhow::anyhow!("Missing required 'scan_data' parameter"))?;
        let json_str = match scan_data.as_str() {
            Some(s) => s.to_string(),
            None => serde_json::to_string(scan_data)?,
        };
let report = parse_vulnerability_json(&json_str)?;
let summary = generate_summary(&report);
let output = json!({
"scanner": report.scanner,
"scan_date": report.scan_date.to_rfc3339(),
"total_findings": report.findings.len(),
"by_severity": {
"critical": report.findings.iter().filter(|f| f.severity == "critical").count(),
"high": report.findings.iter().filter(|f| f.severity == "high").count(),
"medium": report.findings.iter().filter(|f| f.severity == "medium").count(),
"low": report.findings.iter().filter(|f| f.severity == "low").count(),
},
"summary": summary,
});
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&output)?,
error: None,
})
}
/// Generate a client-facing security posture report.
fn generate_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let client_name = args
.get("client_name")
.and_then(|v| v.as_str())
.unwrap_or("Client");
let period = args
.get("period")
.and_then(|v| v.as_str())
.unwrap_or("current");
let alert_stats = args.get("alert_stats");
let vuln_summary = args
.get("vuln_summary")
.and_then(|v| v.as_str())
.unwrap_or("");
let report = format!(
"# Security Posture Report — {client_name}\n\
**Period:** {period}\n\
**Generated:** {}\n\n\
## Executive Summary\n\n\
This report provides an overview of the security posture for {client_name} \
during the {period} period.\n\n\
## Alert Summary\n\n\
{}\n\n\
## Vulnerability Assessment\n\n\
{}\n\n\
## Recommendations\n\n\
1. Address all critical and high-severity findings immediately\n\
2. Review and update incident response playbooks quarterly\n\
3. Conduct regular vulnerability scans on all internet-facing assets\n\
4. Ensure all endpoints have current security patches\n\n\
---\n\
*Report generated by ZeroClaw MCSS Agent*\n",
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"),
alert_stats
.map(|s| serde_json::to_string_pretty(s).unwrap_or_default())
.unwrap_or_else(|| "No alert statistics provided.".into()),
if vuln_summary.is_empty() {
"No vulnerability data provided."
} else {
vuln_summary
},
);
Ok(ToolResult {
success: true,
output: report,
error: None,
})
}
/// List available playbooks.
fn list_playbooks(&self) -> anyhow::Result<ToolResult> {
if self.playbooks.is_empty() {
return Ok(ToolResult {
success: true,
output: "No playbooks available.".into(),
error: None,
});
}
let playbook_list: Vec<serde_json::Value> = self
.playbooks
.iter()
.map(|pb| {
json!({
"name": pb.name,
"description": pb.description,
"steps": pb.steps.len(),
"severity_filter": pb.severity_filter,
"auto_approve_steps": pb.auto_approve_steps,
})
})
.collect();
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&playbook_list)?,
error: None,
})
}
/// Summarize alert volume, categories, and resolution times.
fn alert_stats(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let alerts = args
.get("alerts")
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow::anyhow!("Missing required 'alerts' array parameter"))?;
let total = alerts.len();
let mut by_severity = std::collections::HashMap::new();
let mut by_category = std::collections::HashMap::new();
let mut resolved_count = 0u64;
let mut total_resolution_secs = 0u64;
for alert in alerts {
let severity = alert
.get("severity")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
*by_severity.entry(severity.to_string()).or_insert(0u64) += 1;
let category = alert
.get("category")
.and_then(|v| v.as_str())
.unwrap_or("uncategorized");
*by_category.entry(category.to_string()).or_insert(0u64) += 1;
if let Some(resolution_secs) = alert.get("resolution_secs").and_then(|v| v.as_u64()) {
resolved_count += 1;
total_resolution_secs += resolution_secs;
}
}
let avg_resolution = if resolved_count > 0 {
total_resolution_secs as f64 / resolved_count as f64
} else {
0.0
};
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let avg_resolution_secs_u64 = avg_resolution.max(0.0) as u64;
let output = json!({
"total_alerts": total,
"resolved": resolved_count,
"unresolved": total as u64 - resolved_count,
"by_severity": by_severity,
"by_category": by_category,
"avg_resolution_secs": avg_resolution,
"avg_resolution_human": format_duration_secs(avg_resolution_secs_u64),
});
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&output)?,
error: None,
})
}
}
fn format_duration_secs(secs: u64) -> String {
if secs < 60 {
format!("{secs}s")
} else if secs < 3600 {
format!("{}m {}s", secs / 60, secs % 60)
} else {
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
}
}
/// Expand ~ to home directory.
fn expand_tilde(path: &str) -> PathBuf {
if let Some(rest) = path.strip_prefix("~/") {
if let Some(user_dirs) = directories::UserDirs::new() {
return user_dirs.home_dir().join(rest);
}
}
PathBuf::from(path)
}
#[async_trait]
impl Tool for SecurityOpsTool {
fn name(&self) -> &str {
"security_ops"
}
fn description(&self) -> &str {
"Security operations tool for managed cybersecurity services. Actions: \
triage_alert (classify/prioritize alerts), run_playbook (execute incident response steps), \
parse_vulnerability (parse scan results), generate_report (create security posture reports), \
list_playbooks (list available playbooks), alert_stats (summarize alert metrics)."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"required": ["action"],
"properties": {
"action": {
"type": "string",
"enum": ["triage_alert", "run_playbook", "parse_vulnerability", "generate_report", "list_playbooks", "alert_stats"],
"description": "The security operation to perform"
},
"alert": {
"type": "object",
"description": "Alert JSON for triage_alert (requires: type, severity; optional: source, description)"
},
"playbook": {
"type": "string",
"description": "Playbook name for run_playbook"
},
"step": {
"type": "integer",
"description": "0-based step index for run_playbook"
},
"alert_severity": {
"type": "string",
"description": "Alert severity context for run_playbook"
},
"scan_data": {
"description": "Vulnerability scan data (JSON string or object) for parse_vulnerability"
},
"client_name": {
"type": "string",
"description": "Client name for generate_report"
},
"period": {
"type": "string",
"description": "Reporting period for generate_report"
},
"alert_stats": {
"type": "object",
"description": "Alert statistics to include in generate_report"
},
"vuln_summary": {
"type": "string",
"description": "Vulnerability summary to include in generate_report"
},
"alerts": {
"type": "array",
"description": "Array of alert objects for alert_stats"
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args
.get("action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;
match action {
"triage_alert" => self.triage_alert(&args),
"run_playbook" => self.run_playbook(&args),
"parse_vulnerability" => self.parse_vulnerability(&args),
"generate_report" => self.generate_report(&args),
"list_playbooks" => self.list_playbooks(),
"alert_stats" => self.alert_stats(&args),
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action '{action}'. Valid: triage_alert, run_playbook, \
parse_vulnerability, generate_report, list_playbooks, alert_stats"
)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_config() -> SecurityOpsConfig {
SecurityOpsConfig {
enabled: true,
playbooks_dir: "/nonexistent".into(),
auto_triage: false,
require_approval_for_actions: true,
max_auto_severity: "low".into(),
report_output_dir: "/tmp/reports".into(),
siem_integration: None,
}
}
fn test_tool() -> SecurityOpsTool {
SecurityOpsTool::new(test_config())
}
#[test]
fn tool_name_and_schema() {
let tool = test_tool();
assert_eq!(tool.name(), "security_ops");
let schema = tool.parameters_schema();
assert!(schema["properties"]["action"].is_object());
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("action")));
}
#[tokio::test]
async fn triage_alert_classifies_severity() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "triage_alert",
"alert": {
"type": "suspicious_login",
"source": "siem",
"severity": "high",
"description": "Multiple failed login attempts followed by successful login"
}
}))
.await
.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["classification"]["severity"], "high");
assert_eq!(output["classification"]["priority"], "immediate");
// Should match suspicious_login playbook
let playbooks = output["recommended_playbooks"].as_array().unwrap();
assert!(playbooks.iter().any(|p| p == "suspicious_login"));
}
#[tokio::test]
async fn triage_alert_missing_alert_param() {
let tool = test_tool();
let result = tool.execute(json!({"action": "triage_alert"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn run_playbook_requires_approval() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "run_playbook",
"playbook": "suspicious_login",
"step": 2,
"alert_severity": "high"
}))
.await
.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["status"], "pending_approval");
assert_eq!(output["requires_manual_approval"], true);
}
#[tokio::test]
async fn run_playbook_executes_safe_step() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "run_playbook",
"playbook": "suspicious_login",
"step": 0,
"alert_severity": "medium"
}))
.await
.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["status"], "completed");
}
#[tokio::test]
async fn run_playbook_not_found() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "run_playbook",
"playbook": "nonexistent",
"step": 0
}))
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn parse_vulnerability_valid_report() {
let tool = test_tool();
let scan_data = json!({
"scan_date": "2025-01-15T10:00:00Z",
"scanner": "nessus",
"findings": [
{
"cve_id": "CVE-2024-0001",
"cvss_score": 9.8,
"severity": "critical",
"affected_asset": "web-01",
"description": "RCE in web framework",
"remediation": "Upgrade",
"internet_facing": true,
"production": true
}
]
});
let result = tool
.execute(json!({
"action": "parse_vulnerability",
"scan_data": scan_data
}))
.await
.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["total_findings"], 1);
assert_eq!(output["by_severity"]["critical"], 1);
}
#[tokio::test]
async fn generate_report_produces_markdown() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "generate_report",
"client_name": "ZeroClaw Corp",
"period": "Q1 2025"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("ZeroClaw Corp"));
assert!(result.output.contains("Q1 2025"));
assert!(result.output.contains("Security Posture Report"));
}
#[tokio::test]
async fn list_playbooks_returns_builtins() {
let tool = test_tool();
let result = tool
.execute(json!({"action": "list_playbooks"}))
.await
.unwrap();
assert!(result.success);
let output: Vec<serde_json::Value> = serde_json::from_str(&result.output).unwrap();
assert_eq!(output.len(), 4);
let names: Vec<&str> = output.iter().map(|p| p["name"].as_str().unwrap()).collect();
assert!(names.contains(&"suspicious_login"));
assert!(names.contains(&"malware_detected"));
}
#[tokio::test]
async fn alert_stats_computes_summary() {
let tool = test_tool();
let result = tool
.execute(json!({
"action": "alert_stats",
"alerts": [
{"severity": "critical", "category": "malware", "resolution_secs": 3600},
{"severity": "high", "category": "phishing", "resolution_secs": 1800},
{"severity": "medium", "category": "malware"},
{"severity": "low", "category": "policy_violation", "resolution_secs": 600}
]
}))
.await
.unwrap();
assert!(result.success);
let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
assert_eq!(output["total_alerts"], 4);
assert_eq!(output["resolved"], 3);
assert_eq!(output["unresolved"], 1);
assert_eq!(output["by_severity"]["critical"], 1);
assert_eq!(output["by_category"]["malware"], 2);
}
#[tokio::test]
async fn unknown_action_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({"action": "bad_action"})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}
#[test]
fn format_duration_secs_readable() {
assert_eq!(format_duration_secs(45), "45s");
assert_eq!(format_duration_secs(125), "2m 5s");
assert_eq!(format_duration_secs(3665), "1h 1m");
}
}
@@ -0,0 +1,953 @@
use super::traits::{Tool, ToolResult};
use crate::config::{DelegateAgentConfig, SwarmConfig, SwarmStrategy};
use crate::providers::{self, Provider};
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Default timeout for individual agent calls within a swarm.
const SWARM_AGENT_TIMEOUT_SECS: u64 = 120;
/// Tool that orchestrates multiple agents as a swarm. Supports sequential
/// (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies.
pub struct SwarmTool {
swarms: Arc<HashMap<String, SwarmConfig>>,
agents: Arc<HashMap<String, DelegateAgentConfig>>,
security: Arc<SecurityPolicy>,
fallback_credential: Option<String>,
provider_runtime_options: providers::ProviderRuntimeOptions,
}
impl SwarmTool {
pub fn new(
swarms: HashMap<String, SwarmConfig>,
agents: HashMap<String, DelegateAgentConfig>,
fallback_credential: Option<String>,
security: Arc<SecurityPolicy>,
provider_runtime_options: providers::ProviderRuntimeOptions,
) -> Self {
Self {
swarms: Arc::new(swarms),
agents: Arc::new(agents),
security,
fallback_credential,
provider_runtime_options,
}
}
fn create_provider_for_agent(
&self,
agent_config: &DelegateAgentConfig,
agent_name: &str,
) -> Result<Box<dyn Provider>, ToolResult> {
let credential = agent_config
.api_key
.clone()
.or_else(|| self.fallback_credential.clone());
providers::create_provider_with_options(
&agent_config.provider,
credential.as_deref(),
&self.provider_runtime_options,
)
.map_err(|e| ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Failed to create provider '{}' for agent '{agent_name}': {e}",
agent_config.provider
)),
})
}
async fn call_agent(
&self,
agent_name: &str,
agent_config: &DelegateAgentConfig,
prompt: &str,
timeout_secs: u64,
) -> Result<String, String> {
let provider = self
.create_provider_for_agent(agent_config, agent_name)
.map_err(|r| r.error.unwrap_or_default())?;
let temperature = agent_config.temperature.unwrap_or(0.7);
let result = tokio::time::timeout(
Duration::from_secs(timeout_secs),
provider.chat_with_system(
agent_config.system_prompt.as_deref(),
prompt,
&agent_config.model,
temperature,
),
)
.await;
match result {
Ok(Ok(response)) => {
if response.trim().is_empty() {
Ok("[Empty response]".to_string())
} else {
Ok(response)
}
}
Ok(Err(e)) => Err(format!("Agent '{agent_name}' failed: {e}")),
Err(_) => Err(format!(
"Agent '{agent_name}' timed out after {timeout_secs}s"
)),
}
}
async fn execute_sequential(
&self,
swarm_config: &SwarmConfig,
prompt: &str,
context: &str,
) -> anyhow::Result<ToolResult> {
let mut current_input = if context.is_empty() {
prompt.to_string()
} else {
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
};
        // Split the swarm's time budget evenly across agents; clamp so integer
        // division can never drop the per-agent timeout to zero seconds.
        let per_agent_timeout =
            (swarm_config.timeout_secs / swarm_config.agents.len().max(1) as u64).max(1);
let mut results = Vec::new();
for (i, agent_name) in swarm_config.agents.iter().enumerate() {
let agent_config = match self.agents.get(agent_name) {
Some(cfg) => cfg,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
});
}
};
let agent_prompt = if i == 0 {
current_input.clone()
} else {
format!("[Previous agent output]\n{current_input}\n\n[Original task]\n{prompt}")
};
match self
.call_agent(agent_name, agent_config, &agent_prompt, per_agent_timeout)
.await
{
Ok(output) => {
results.push(format!(
"[{agent_name} ({}/{})] {output}",
agent_config.provider, agent_config.model
));
current_input = output;
}
Err(e) => {
return Ok(ToolResult {
success: false,
output: results.join("\n\n"),
error: Some(e),
});
}
}
}
Ok(ToolResult {
success: true,
output: format!(
"[Swarm sequential — {} agents]\n\n{}",
swarm_config.agents.len(),
results.join("\n\n")
),
error: None,
})
}
async fn execute_parallel(
&self,
swarm_config: &SwarmConfig,
prompt: &str,
context: &str,
) -> anyhow::Result<ToolResult> {
let full_prompt = if context.is_empty() {
prompt.to_string()
} else {
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
};
let mut join_set = tokio::task::JoinSet::new();
for agent_name in &swarm_config.agents {
let agent_config = match self.agents.get(agent_name) {
Some(cfg) => cfg.clone(),
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
});
}
};
let credential = agent_config
.api_key
.clone()
.or_else(|| self.fallback_credential.clone());
let provider = match providers::create_provider_with_options(
&agent_config.provider,
credential.as_deref(),
&self.provider_runtime_options,
) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Failed to create provider for agent '{agent_name}': {e}"
)),
});
}
};
let name = agent_name.clone();
let prompt_clone = full_prompt.clone();
let timeout = swarm_config.timeout_secs;
let model = agent_config.model.clone();
let temperature = agent_config.temperature.unwrap_or(0.7);
let system_prompt = agent_config.system_prompt.clone();
let provider_name = agent_config.provider.clone();
join_set.spawn(async move {
let result = tokio::time::timeout(
Duration::from_secs(timeout),
provider.chat_with_system(
system_prompt.as_deref(),
&prompt_clone,
&model,
temperature,
),
)
.await;
let output = match result {
Ok(Ok(text)) => {
if text.trim().is_empty() {
"[Empty response]".to_string()
} else {
text
}
}
Ok(Err(e)) => format!("[Error] {e}"),
Err(_) => format!("[Timed out after {timeout}s]"),
};
(name, provider_name, model, output)
});
}
let mut results = Vec::new();
while let Some(join_result) = join_set.join_next().await {
match join_result {
Ok((name, provider_name, model, output)) => {
results.push(format!("[{name} ({provider_name}/{model})]\n{output}"));
}
Err(e) => {
results.push(format!("[join error] {e}"));
}
}
}
Ok(ToolResult {
success: true,
output: format!(
"[Swarm parallel — {} agents]\n\n{}",
swarm_config.agents.len(),
results.join("\n\n---\n\n")
),
error: None,
})
}
async fn execute_router(
&self,
swarm_config: &SwarmConfig,
prompt: &str,
context: &str,
) -> anyhow::Result<ToolResult> {
if swarm_config.agents.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Router swarm has no agents to choose from".into()),
});
}
// Build agent descriptions for the router prompt
let agent_descriptions: Vec<String> = swarm_config
.agents
.iter()
.filter_map(|name| {
self.agents.get(name).map(|cfg| {
let desc = cfg
.system_prompt
.as_deref()
.unwrap_or("General purpose agent");
format!(
"- {name}: {desc} (provider: {}, model: {})",
cfg.provider, cfg.model
)
})
})
.collect();
// Use the first agent's provider for routing
let first_agent_name = &swarm_config.agents[0];
let first_agent_config = match self.agents.get(first_agent_name) {
Some(cfg) => cfg,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Swarm references unknown agent '{first_agent_name}'"
)),
});
}
};
let router_provider = self
.create_provider_for_agent(first_agent_config, first_agent_name)
.map_err(|r| anyhow::anyhow!(r.error.unwrap_or_default()))?;
let base_router_prompt = swarm_config
.router_prompt
.as_deref()
.unwrap_or("Pick the single best agent for this task.");
let routing_prompt = format!(
"{base_router_prompt}\n\nAvailable agents:\n{}\n\nUser task: {prompt}\n\n\
Respond with ONLY the agent name, nothing else.",
agent_descriptions.join("\n")
);
let chosen = tokio::time::timeout(
Duration::from_secs(SWARM_AGENT_TIMEOUT_SECS),
router_provider.chat_with_system(
Some("You are a routing assistant. Respond with only the agent name."),
&routing_prompt,
&first_agent_config.model,
0.0,
),
)
.await;
let chosen_name = match chosen {
Ok(Ok(name)) => name.trim().to_string(),
Ok(Err(e)) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Router LLM call failed: {e}")),
});
}
Err(_) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Router LLM call timed out".into()),
});
}
};
// Case-insensitive matching with fallback to first agent
let matched_name = swarm_config
.agents
.iter()
.find(|name| name.eq_ignore_ascii_case(&chosen_name))
.cloned()
.unwrap_or_else(|| swarm_config.agents[0].clone());
let agent_config = match self.agents.get(&matched_name) {
Some(cfg) => cfg,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Router selected unknown agent '{matched_name}'")),
});
}
};
let full_prompt = if context.is_empty() {
prompt.to_string()
} else {
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
};
match self
.call_agent(
&matched_name,
agent_config,
&full_prompt,
swarm_config.timeout_secs,
)
.await
{
Ok(output) => Ok(ToolResult {
success: true,
output: format!(
"[Swarm router — selected '{matched_name}' ({}/{})]\n{output}",
agent_config.provider, agent_config.model
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
}
#[async_trait]
impl Tool for SwarmTool {
fn name(&self) -> &str {
"swarm"
}
fn description(&self) -> &str {
"Orchestrate a swarm of agents to collaboratively handle a task. Supports sequential \
(pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies."
}
fn parameters_schema(&self) -> serde_json::Value {
let swarm_names: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
json!({
"type": "object",
"additionalProperties": false,
"properties": {
"swarm": {
"type": "string",
"minLength": 1,
"description": format!(
"Name of the swarm to invoke. Available: {}",
if swarm_names.is_empty() {
"(none configured)".to_string()
} else {
swarm_names.join(", ")
}
)
},
"prompt": {
"type": "string",
"minLength": 1,
"description": "The task/prompt to send to the swarm"
},
"context": {
"type": "string",
"description": "Optional context to include (e.g. relevant code, prior findings)"
}
},
"required": ["swarm", "prompt"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let swarm_name = args
.get("swarm")
.and_then(|v| v.as_str())
.map(str::trim)
.ok_or_else(|| anyhow::anyhow!("Missing 'swarm' parameter"))?;
if swarm_name.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'swarm' parameter must not be empty".into()),
});
}
let prompt = args
.get("prompt")
.and_then(|v| v.as_str())
.map(str::trim)
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
if prompt.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'prompt' parameter must not be empty".into()),
});
}
let context = args
.get("context")
.and_then(|v| v.as_str())
.map(str::trim)
.unwrap_or("");
let swarm_config = match self.swarms.get(swarm_name) {
Some(cfg) => cfg,
None => {
let available: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown swarm '{swarm_name}'. Available swarms: {}",
if available.is_empty() {
"(none configured)".to_string()
} else {
available.join(", ")
}
)),
});
}
};
if swarm_config.agents.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Swarm '{swarm_name}' has no agents configured")),
});
}
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "swarm")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
match swarm_config.strategy {
SwarmStrategy::Sequential => {
self.execute_sequential(swarm_config, prompt, context).await
}
SwarmStrategy::Parallel => self.execute_parallel(swarm_config, prompt, context).await,
SwarmStrategy::Router => self.execute_router(swarm_config, prompt, context).await,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy::default())
}
fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
let mut agents = HashMap::new();
agents.insert(
"researcher".to_string(),
DelegateAgentConfig {
provider: "ollama".to_string(),
model: "llama3".to_string(),
system_prompt: Some("You are a research assistant.".to_string()),
api_key: None,
temperature: Some(0.3),
max_depth: 3,
agentic: false,
allowed_tools: Vec::new(),
max_iterations: 10,
},
);
agents.insert(
"writer".to_string(),
DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "anthropic/claude-sonnet-4-20250514".to_string(),
system_prompt: Some("You are a technical writer.".to_string()),
api_key: Some("test-key".to_string()),
temperature: Some(0.5),
max_depth: 3,
agentic: false,
allowed_tools: Vec::new(),
max_iterations: 10,
},
);
agents
}
fn sample_swarms() -> HashMap<String, SwarmConfig> {
let mut swarms = HashMap::new();
swarms.insert(
"pipeline".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string(), "writer".to_string()],
strategy: SwarmStrategy::Sequential,
router_prompt: None,
description: Some("Research then write".to_string()),
timeout_secs: 300,
},
);
swarms.insert(
"fanout".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string(), "writer".to_string()],
strategy: SwarmStrategy::Parallel,
router_prompt: None,
description: None,
timeout_secs: 300,
},
);
swarms.insert(
"router".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string(), "writer".to_string()],
strategy: SwarmStrategy::Router,
router_prompt: Some("Pick the best agent.".to_string()),
description: None,
timeout_secs: 300,
},
);
swarms
}
#[test]
fn name_and_schema() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
assert_eq!(tool.name(), "swarm");
let schema = tool.parameters_schema();
assert!(schema["properties"]["swarm"].is_object());
assert!(schema["properties"]["prompt"].is_object());
assert!(schema["properties"]["context"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&json!("swarm")));
assert!(required.contains(&json!("prompt")));
assert_eq!(schema["additionalProperties"], json!(false));
}
#[test]
fn description_not_empty() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
assert!(!tool.description().is_empty());
}
#[test]
fn schema_lists_swarm_names() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let schema = tool.parameters_schema();
let desc = schema["properties"]["swarm"]["description"]
.as_str()
.unwrap();
assert!(desc.contains("pipeline") || desc.contains("fanout") || desc.contains("router"));
}
#[test]
fn empty_swarms_schema() {
let tool = SwarmTool::new(
HashMap::new(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let schema = tool.parameters_schema();
let desc = schema["properties"]["swarm"]["description"]
.as_str()
.unwrap();
assert!(desc.contains("none configured"));
}
#[tokio::test]
async fn unknown_swarm_returns_error() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "nonexistent", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown swarm"));
}
#[tokio::test]
async fn missing_swarm_param() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool.execute(json!({"prompt": "test"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn missing_prompt_param() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool.execute(json!({"swarm": "pipeline"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn blank_swarm_rejected() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": " ", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("must not be empty"));
}
#[tokio::test]
async fn blank_prompt_rejected() {
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "pipeline", "prompt": " "}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("must not be empty"));
}
#[tokio::test]
async fn swarm_with_missing_agent_returns_error() {
let mut swarms = HashMap::new();
swarms.insert(
"broken".to_string(),
SwarmConfig {
agents: vec!["nonexistent_agent".to_string()],
strategy: SwarmStrategy::Sequential,
router_prompt: None,
description: None,
timeout_secs: 60,
},
);
let tool = SwarmTool::new(
swarms,
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "broken", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("unknown agent"));
}
#[tokio::test]
async fn swarm_with_empty_agents_returns_error() {
let mut swarms = HashMap::new();
swarms.insert(
"empty".to_string(),
SwarmConfig {
agents: Vec::new(),
strategy: SwarmStrategy::Parallel,
router_prompt: None,
description: None,
timeout_secs: 60,
},
);
let tool = SwarmTool::new(
swarms,
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "empty", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("no agents configured"));
}
#[tokio::test]
async fn swarm_blocked_in_readonly_mode() {
let readonly = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::ReadOnly,
..SecurityPolicy::default()
});
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
readonly,
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("read-only mode"));
}
#[tokio::test]
async fn swarm_blocked_when_rate_limited() {
let limited = Arc::new(SecurityPolicy {
max_actions_per_hour: 0,
..SecurityPolicy::default()
});
let tool = SwarmTool::new(
sample_swarms(),
sample_agents(),
None,
limited,
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("Rate limit exceeded"));
}
#[tokio::test]
async fn sequential_invalid_provider_returns_error() {
let mut swarms = HashMap::new();
swarms.insert(
"seq".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string()],
strategy: SwarmStrategy::Sequential,
router_prompt: None,
description: None,
timeout_secs: 60,
},
);
// researcher uses "ollama" which won't be running in CI
let tool = SwarmTool::new(
swarms,
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "seq", "prompt": "test"}))
.await
.unwrap();
// Should fail at provider creation or call level
assert!(!result.success);
}
#[tokio::test]
async fn parallel_invalid_provider_returns_error() {
let mut swarms = HashMap::new();
swarms.insert(
"par".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string()],
strategy: SwarmStrategy::Parallel,
router_prompt: None,
description: None,
timeout_secs: 60,
},
);
let tool = SwarmTool::new(
swarms,
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "par", "prompt": "test"}))
.await
.unwrap();
        // Parallel strategy may surface provider failures either as error
        // annotations in a successful output or as a top-level error.
assert!(result.success || result.error.is_some());
}
#[tokio::test]
async fn router_invalid_provider_returns_error() {
let mut swarms = HashMap::new();
swarms.insert(
"rout".to_string(),
SwarmConfig {
agents: vec!["researcher".to_string()],
strategy: SwarmStrategy::Router,
router_prompt: Some("Pick.".to_string()),
description: None,
timeout_secs: 60,
},
);
let tool = SwarmTool::new(
swarms,
sample_agents(),
None,
test_security(),
providers::ProviderRuntimeOptions::default(),
);
let result = tool
.execute(json!({"swarm": "rout", "prompt": "test"}))
.await
.unwrap();
assert!(!result.success);
}
}
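The strategy dispatch at the top of this file is the whole behavioural difference between swarm kinds: `Sequential` threads each agent's output into the next prompt and aborts on the first failure, while `Parallel` fans the same prompt out and annotates per-agent failures instead of aborting. A hypothetical std-only reduction, with synchronous `Agent` fn pointers standing in for the real async provider calls:

```rust
use std::thread;

// Stand-in for an LLM-backed agent: prompt in, output (or error) out.
type Agent = fn(&str) -> Result<String, String>;

/// Sequential: each agent's output becomes the next agent's prompt.
fn run_sequential(agents: &[Agent], prompt: &str) -> Result<String, String> {
    let mut current = prompt.to_string();
    for agent in agents {
        current = agent(&current)?; // first failure aborts the pipeline
    }
    Ok(current)
}

/// Parallel: every agent gets the same prompt; failures are annotated
/// per agent rather than failing the whole run.
fn run_parallel(agents: &'static [Agent], prompt: &str) -> String {
    let handles: Vec<_> = agents
        .iter()
        .map(|agent| {
            let p = prompt.to_string();
            thread::spawn(move || agent(&p))
        })
        .collect();
    handles
        .into_iter()
        .enumerate()
        .map(|(i, handle)| match handle.join().expect("worker panicked") {
            Ok(out) => format!("[agent {i}] {out}"),
            Err(e) => format!("[agent {i}] error: {e}"), // annotated, run continues
        })
        .collect::<Vec<_>>()
        .join("\n")
}

fn research(p: &str) -> Result<String, String> {
    Ok(format!("notes on {p}"))
}

fn write_up(p: &str) -> Result<String, String> {
    Ok(format!("article from {p}"))
}

static AGENTS: [Agent; 2] = [research, write_up];

fn main() {
    let seq = run_sequential(&AGENTS, "rust").unwrap();
    assert_eq!(seq, "article from notes on rust");

    let par = run_parallel(&AGENTS, "rust");
    assert!(par.contains("[agent 0] notes on rust"));
    assert!(par.contains("[agent 1] article from rust"));
}
```

`Router`, the third arm, would first ask a designated agent which of the others to invoke; it is omitted from this sketch for brevity.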
@@ -0,0 +1,356 @@
//! Tool for managing multi-client workspaces.
//!
//! Provides `workspace` subcommands: list, switch, create, info, export.
use super::traits::{Tool, ToolResult};
use crate::config::workspace::WorkspaceManager;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Agent-callable tool for workspace management operations.
pub struct WorkspaceTool {
manager: Arc<RwLock<WorkspaceManager>>,
security: Arc<SecurityPolicy>,
}
impl WorkspaceTool {
pub fn new(manager: Arc<RwLock<WorkspaceManager>>, security: Arc<SecurityPolicy>) -> Self {
Self { manager, security }
}
}
#[async_trait]
impl Tool for WorkspaceTool {
fn name(&self) -> &str {
"workspace"
}
fn description(&self) -> &str {
"Manage multi-client workspaces. Subcommands: list, switch, create, info, export. Each workspace provides isolated memory, audit, secrets, and tool restrictions."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list", "switch", "create", "info", "export"],
"description": "Workspace action to perform"
},
"name": {
"type": "string",
"description": "Workspace name (required for switch, create, export)"
}
},
"required": ["action"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args
.get("action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'action' parameter"))?;
let name = args.get("name").and_then(|v| v.as_str());
match action {
"list" => {
let mgr = self.manager.read().await;
let names = mgr.list();
let active = mgr.active_name();
if names.is_empty() {
return Ok(ToolResult {
success: true,
output: "No workspaces configured.".to_string(),
error: None,
});
}
let mut output = format!("Workspaces ({}):\n", names.len());
for ws_name in &names {
let marker = if Some(*ws_name) == active {
" (active)"
} else {
""
};
let _ = writeln!(output, " - {ws_name}{marker}");
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
"switch" => {
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "workspace")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let ws_name = name.ok_or_else(|| {
anyhow::anyhow!("'name' parameter is required for switch action")
})?;
let mut mgr = self.manager.write().await;
match mgr.switch(ws_name) {
Ok(profile) => Ok(ToolResult {
success: true,
output: format!(
"Switched to workspace '{}'. Memory namespace: {}, Audit namespace: {}",
profile.name,
profile.effective_memory_namespace(),
profile.effective_audit_namespace()
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
"create" => {
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "workspace")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let ws_name = name.ok_or_else(|| {
anyhow::anyhow!("'name' parameter is required for create action")
})?;
let mut mgr = self.manager.write().await;
match mgr.create(ws_name).await {
Ok(profile) => {
let name = profile.name.clone();
let dir = mgr.workspace_dir(ws_name);
Ok(ToolResult {
success: true,
output: format!("Created workspace '{}' at {}", name, dir.display()),
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
"info" => {
let mgr = self.manager.read().await;
let target_name = name.or_else(|| mgr.active_name());
match target_name {
Some(ws_name) => match mgr.get(ws_name) {
Some(profile) => {
let is_active = mgr.active_name() == Some(ws_name);
let mut output = format!("Workspace: {}\n", profile.name);
let _ = writeln!(
output,
" Status: {}",
if is_active { "active" } else { "inactive" }
);
let _ = writeln!(
output,
" Memory namespace: {}",
profile.effective_memory_namespace()
);
let _ = writeln!(
output,
" Audit namespace: {}",
profile.effective_audit_namespace()
);
if !profile.allowed_domains.is_empty() {
let _ = writeln!(
output,
" Allowed domains: {}",
profile.allowed_domains.join(", ")
);
}
if !profile.tool_restrictions.is_empty() {
let _ = writeln!(
output,
" Restricted tools: {}",
profile.tool_restrictions.join(", ")
);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("workspace '{}' not found", ws_name)),
}),
},
None => Ok(ToolResult {
success: true,
output: "No workspace is currently active. Use 'workspace switch <name>' to activate one.".to_string(),
error: None,
}),
}
}
"export" => {
let mgr = self.manager.read().await;
let ws_name = name.or_else(|| mgr.active_name()).ok_or_else(|| {
anyhow::anyhow!("'name' parameter is required when no workspace is active")
})?;
match mgr.export(ws_name) {
Ok(toml_str) => Ok(ToolResult {
success: true,
output: format!(
"Exported workspace '{}' config (secrets redacted):\n\n{}",
ws_name, toml_str
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
other => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"unknown workspace action '{}'. Expected: list, switch, create, info, export",
other
)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::SecurityPolicy;
use tempfile::TempDir;
fn test_tool(tmp: &TempDir) -> WorkspaceTool {
let mgr = WorkspaceManager::new(tmp.path().to_path_buf());
WorkspaceTool::new(
Arc::new(RwLock::new(mgr)),
Arc::new(SecurityPolicy::default()),
)
}
#[tokio::test]
async fn workspace_tool_list_empty() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
let result = tool.execute(json!({"action": "list"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No workspaces"));
}
#[tokio::test]
async fn workspace_tool_create_and_list() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
let result = tool
.execute(json!({"action": "create", "name": "test_client"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("test_client"));
let result = tool.execute(json!({"action": "list"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("test_client"));
}
#[tokio::test]
async fn workspace_tool_switch_and_info() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
tool.execute(json!({"action": "create", "name": "ws_test"}))
.await
.unwrap();
let result = tool
.execute(json!({"action": "switch", "name": "ws_test"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Switched to workspace"));
let result = tool.execute(json!({"action": "info"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("ws_test"));
assert!(result.output.contains("active"));
}
#[tokio::test]
async fn workspace_tool_export_redacts() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
tool.execute(json!({"action": "create", "name": "export_ws"}))
.await
.unwrap();
let result = tool
.execute(json!({"action": "export", "name": "export_ws"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("export_ws"));
}
#[tokio::test]
async fn workspace_tool_unknown_action() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
let result = tool.execute(json!({"action": "destroy"})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("unknown workspace action"));
}
#[tokio::test]
async fn workspace_tool_switch_nonexistent() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(&tmp);
let result = tool
.execute(json!({"action": "switch", "name": "ghost"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
}
}
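The manager behind this tool keys isolated namespaces off the workspace name and tracks a single active profile, which is all the `switch`/`info` handlers above rely on. A hypothetical, heavily reduced sketch of that model (the real `WorkspaceManager` also persists profiles to disk and redacts secrets on export; the `memory/`/`audit/` prefixes are assumed here for illustration):

```rust
use std::collections::BTreeMap;

#[derive(Clone)]
struct Profile {
    name: String,
}

impl Profile {
    // Assumed convention: namespaces are derived from the workspace name,
    // so two workspaces can never share memory or audit storage.
    fn memory_namespace(&self) -> String {
        format!("memory/{}", self.name)
    }
    fn audit_namespace(&self) -> String {
        format!("audit/{}", self.name)
    }
}

struct Manager {
    workspaces: BTreeMap<String, Profile>,
    active: Option<String>,
}

impl Manager {
    fn new() -> Self {
        Self { workspaces: BTreeMap::new(), active: None }
    }

    fn create(&mut self, name: &str) -> &Profile {
        self.workspaces
            .entry(name.to_string())
            .or_insert_with(|| Profile { name: name.to_string() })
    }

    /// Activating an unknown workspace is an error and leaves the
    /// previously active workspace untouched.
    fn switch(&mut self, name: &str) -> Result<&Profile, String> {
        if !self.workspaces.contains_key(name) {
            return Err(format!("workspace '{name}' not found"));
        }
        self.active = Some(name.to_string());
        Ok(&self.workspaces[name])
    }
}

fn main() {
    let mut mgr = Manager::new();
    mgr.create("acme");
    let profile = mgr.switch("acme").unwrap().clone();
    assert_eq!(profile.memory_namespace(), "memory/acme");
    assert_eq!(profile.audit_namespace(), "audit/acme");
    assert!(mgr.switch("ghost").is_err());
    assert_eq!(mgr.active.as_deref(), Some("acme"));
}
```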
@@ -2,6 +2,7 @@ mod cloudflare;
mod custom;
mod ngrok;
mod none;
mod openvpn;
mod tailscale;
pub use cloudflare::CloudflareTunnel;
@@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
pub use ngrok::NgrokTunnel;
#[allow(unused_imports)]
pub use none::NoneTunnel;
pub use openvpn::OpenVpnTunnel;
pub use tailscale::TailscaleTunnel;
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
@@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
))))
}
"openvpn" => {
let ov = config
.openvpn
.as_ref()
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
Ok(Some(Box::new(OpenVpnTunnel::new(
ov.config_file.clone(),
ov.auth_file.clone(),
ov.advertise_address.clone(),
ov.connect_timeout_secs,
ov.extra_args.clone(),
))))
}
"custom" => {
let cu = config
.custom
@@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
))))
}
        other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
}
}
@@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
mod tests {
use super::*;
    use crate::config::schema::{
        CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
        TunnelConfig,
    };
use tokio::process::Command;
@@ -315,6 +332,46 @@ mod tests {
assert!(t.public_url().is_none());
}
#[test]
fn factory_openvpn_missing_config_errors() {
let cfg = TunnelConfig {
provider: "openvpn".into(),
..TunnelConfig::default()
};
assert_tunnel_err(&cfg, "[tunnel.openvpn]");
}
#[test]
fn factory_openvpn_with_config_ok() {
let cfg = TunnelConfig {
provider: "openvpn".into(),
openvpn: Some(OpenVpnTunnelConfig {
config_file: "client.ovpn".into(),
auth_file: None,
advertise_address: None,
connect_timeout_secs: 30,
extra_args: vec![],
}),
..TunnelConfig::default()
};
let t = create_tunnel(&cfg).unwrap();
assert!(t.is_some());
assert_eq!(t.unwrap().name(), "openvpn");
}
#[test]
fn openvpn_tunnel_name() {
let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert_eq!(t.name(), "openvpn");
assert!(t.public_url().is_none());
}
#[tokio::test]
async fn openvpn_health_false_before_start() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert!(!tunnel.health_check().await);
}
#[tokio::test]
async fn kill_shared_no_process_is_ok() {
let proc = new_shared_process();
@@ -0,0 +1,254 @@
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
use anyhow::{bail, Result};
use tokio::io::AsyncBufReadExt;
use tokio::process::Command;
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
///
/// Requires the `openvpn` binary installed and accessible. On most systems,
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
///
/// The tunnel exposes the gateway via the VPN network using a configured
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
pub struct OpenVpnTunnel {
config_file: String,
auth_file: Option<String>,
advertise_address: Option<String>,
connect_timeout_secs: u64,
extra_args: Vec<String>,
proc: SharedProcess,
}
impl OpenVpnTunnel {
/// Create a new OpenVPN tunnel instance.
///
/// * `config_file` — path to the `.ovpn` configuration file.
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
/// * `advertise_address` — optional public address to advertise once connected.
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
pub fn new(
config_file: String,
auth_file: Option<String>,
advertise_address: Option<String>,
connect_timeout_secs: u64,
extra_args: Vec<String>,
) -> Self {
Self {
config_file,
auth_file,
advertise_address,
connect_timeout_secs,
extra_args,
proc: new_shared_process(),
}
}
/// Build the openvpn command arguments.
fn build_args(&self) -> Vec<String> {
let mut args = vec!["--config".to_string(), self.config_file.clone()];
if let Some(ref auth) = self.auth_file {
args.push("--auth-user-pass".to_string());
args.push(auth.clone());
}
args.extend(self.extra_args.iter().cloned());
args
}
}
#[async_trait::async_trait]
impl Tunnel for OpenVpnTunnel {
fn name(&self) -> &str {
"openvpn"
}
/// Spawn the `openvpn` process and wait for the "Initialization Sequence
/// Completed" marker on stderr. Returns the public URL on success.
async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
// Validate config file exists before spawning
if !std::path::Path::new(&self.config_file).exists() {
bail!("OpenVPN config file not found: {}", self.config_file);
}
let args = self.build_args();
let mut child = Command::new("openvpn")
.args(&args)
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
// Wait for "Initialization Sequence Completed" in stderr
let stderr = child
.stderr
.take()
.ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;
let mut reader = tokio::io::BufReader::new(stderr).lines();
let deadline = tokio::time::Instant::now()
+ tokio::time::Duration::from_secs(self.connect_timeout_secs);
let mut connected = false;
while tokio::time::Instant::now() < deadline {
let line =
tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;
match line {
Ok(Ok(Some(l))) => {
tracing::debug!("openvpn: {l}");
if l.contains("Initialization Sequence Completed") {
connected = true;
break;
}
}
Ok(Ok(None)) => {
bail!("OpenVPN process exited before connection was established");
}
Ok(Err(e)) => {
bail!("Error reading openvpn output: {e}");
}
Err(_) => {
// Timeout on individual line read, continue waiting
}
}
}
if !connected {
child.kill().await.ok();
bail!(
"OpenVPN connection timed out after {}s waiting for initialization",
self.connect_timeout_secs
);
}
let public_url = self
.advertise_address
.clone()
.unwrap_or_else(|| format!("http://{local_host}:{local_port}"));
// Drain stderr in background to prevent OS pipe buffer from filling and
// blocking the openvpn process.
tokio::spawn(async move {
while let Ok(Some(line)) = reader.next_line().await {
tracing::trace!("openvpn: {line}");
}
});
let mut guard = self.proc.lock().await;
*guard = Some(TunnelProcess {
child,
public_url: public_url.clone(),
});
Ok(public_url)
}
/// Kill the openvpn child process and release its resources.
async fn stop(&self) -> Result<()> {
kill_shared(&self.proc).await
}
/// Return `true` if the openvpn child process is still running.
async fn health_check(&self) -> bool {
let guard = self.proc.lock().await;
guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
}
/// Return the public URL if the tunnel has been started.
fn public_url(&self) -> Option<String> {
self.proc
.try_lock()
.ok()
.and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
}
}
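The readiness loop in `start` bounds each individual line read so a silent stream can never stall past the overall deadline, and distinguishes "no output yet" from "process exited". A hypothetical std-only reduction of the same pattern, with an mpsc channel standing in for the tokio stderr reader:

```rust
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

/// Wait for `marker` to appear on `rx`, giving up once `timeout` elapses.
/// Mirrors the per-line timeout loop in `OpenVpnTunnel::start`.
fn wait_for_marker(rx: &mpsc::Receiver<String>, marker: &str, timeout: Duration) -> bool {
    let deadline = Instant::now() + timeout;
    while Instant::now() < deadline {
        // Bound each read so a quiet stream cannot block forever.
        match rx.recv_timeout(Duration::from_millis(100)) {
            Ok(line) if line.contains(marker) => return true,
            Ok(_) => {} // unrelated log line; keep scanning
            Err(mpsc::RecvTimeoutError::Timeout) => {} // nothing yet; retry until deadline
            Err(mpsc::RecvTimeoutError::Disconnected) => return false, // writer exited early
        }
    }
    false
}

fn main() {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        tx.send("TLS: soft reset".to_string()).ok();
        tx.send("Initialization Sequence Completed".to_string()).ok();
    });
    assert!(wait_for_marker(&rx, "Initialization Sequence Completed", Duration::from_secs(5)));

    // A channel whose writer is already gone reports failure immediately,
    // matching the "process exited before connection" branch above.
    let (tx2, rx2) = mpsc::channel::<String>();
    drop(tx2);
    assert!(!wait_for_marker(&rx2, "never", Duration::from_secs(5)));
}
```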
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn constructor_stores_fields() {
let tunnel = OpenVpnTunnel::new(
"/etc/openvpn/client.ovpn".into(),
Some("/etc/openvpn/auth.txt".into()),
Some("10.8.0.2:42617".into()),
45,
vec!["--verb".into(), "3".into()],
);
assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn");
assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617"));
assert_eq!(tunnel.connect_timeout_secs, 45);
assert_eq!(tunnel.extra_args, vec!["--verb", "3"]);
}
#[test]
fn build_args_basic() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
let args = tunnel.build_args();
assert_eq!(args, vec!["--config", "client.ovpn"]);
}
#[test]
fn build_args_with_auth_and_extras() {
let tunnel = OpenVpnTunnel::new(
"client.ovpn".into(),
Some("auth.txt".into()),
None,
30,
vec!["--verb".into(), "5".into()],
);
let args = tunnel.build_args();
assert_eq!(
args,
vec![
"--config",
"client.ovpn",
"--auth-user-pass",
"auth.txt",
"--verb",
"5"
]
);
}
#[test]
fn public_url_is_none_before_start() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert!(tunnel.public_url().is_none());
}
#[tokio::test]
async fn health_check_is_false_before_start() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
assert!(!tunnel.health_check().await);
}
#[tokio::test]
async fn stop_without_started_process_is_ok() {
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
let result = tunnel.stop().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn start_with_missing_config_file_errors() {
let tunnel = OpenVpnTunnel::new(
"/nonexistent/path/to/client.ovpn".into(),
None,
None,
30,
vec![],
);
let result = tunnel.start("127.0.0.1", 8080).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("config file not found"));
}
}