Merge remote-tracking branch 'upstream/main' into feat/channel-bluebubbles

# Conflicts:
#	src/channels/mod.rs
This commit is contained in:
xj
2026-02-28 14:52:34 -08:00
80 changed files with 7535 additions and 531 deletions
+1 -1
View File
@@ -212,7 +212,7 @@ jobs:
- name: Link check (offline, added links only)
if: steps.collect_links.outputs.count != '0'
uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2
uses: lycheeverse/lychee-action@8646ba30535128ac92d33dfc9133794bfdd9b411 # v2
with:
fail: true
args: >-
+56
View File
@@ -0,0 +1,56 @@
# Deploys the ./web static site to GitHub Pages on pushes to main/dev
# (or manually via workflow_dispatch). Build and deploy are split into
# two jobs so the deploy environment gates only the publish step.
name: Deploy Web to GitHub Pages

on:
  push:
    branches: [main, dev]
    paths:
      - 'web/**'
  workflow_dispatch:

# Minimal permissions: Pages deployment needs pages:write plus an OIDC
# id-token; everything else stays read-only.
permissions:
  contents: read
  pages: write
  id-token: write

# Only one Pages deployment at a time; in-flight deploys are not cancelled
# so a published site is never left half-updated.
concurrency:
  group: "pages"
  cancel-in-progress: false

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
      - name: Setup Node
        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
        with:
          node-version: '20'
      - name: Install dependencies
        working-directory: ./web
        run: npm ci
      - name: Build
        working-directory: ./web
        run: npm run build
      - name: Setup Pages
        uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5
      - name: Upload artifact
        uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4
        with:
          path: ./web/dist

  deploy:
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    runs-on: ubuntu-latest
    needs: build
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4
+1 -1
View File
@@ -184,7 +184,7 @@ jobs:
- name: Link check (added links)
if: github.event_name != 'workflow_dispatch' && steps.links.outputs.count != '0'
uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2
uses: lycheeverse/lychee-action@8646ba30535128ac92d33dfc9133794bfdd9b411 # v2
with:
fail: true
args: >-
+6 -1
View File
@@ -1,4 +1,6 @@
/target
/target_ci
/target_review*
firmware/*/target
*.db
*.db-journal
@@ -12,6 +14,7 @@ site/node_modules/
site/.vite/
site/public/docs-content/
gh-pages/
.idea
# Environment files (may contain secrets)
@@ -30,10 +33,12 @@ venv/
# Secret keys and credentials
.secret_key
otp-secret
*.key
*.pem
credentials.json
/config.toml
.worktrees/
# Nix
result
result
Generated
+49 -35
View File
@@ -540,16 +540,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a45f9771ced8a774de5e5ebffbe520f52e3943bf5a9a6baa3a5d14a5de1afe6"
[[package]]
name = "bcder"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f7c42c9913f68cf9390a225e81ad56a5c515347287eb98baa710090ca1de86d"
dependencies = [
"bytes",
"smallvec",
]
[[package]]
name = "bech32"
version = "0.11.1"
@@ -1655,9 +1645,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"const-oid",
"der_derive",
"flagset",
"zeroize",
]
[[package]]
name = "der_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "deranged"
version = "0.5.8"
@@ -2204,6 +2207,12 @@ version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
[[package]]
name = "flagset"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe"
[[package]]
name = "flate2"
version = "1.1.9"
@@ -4604,16 +4613,6 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "pem"
version = "3.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be"
dependencies = [
"base64",
"serde_core",
]
[[package]]
name = "percent-encoding"
version = "2.3.2"
@@ -6821,6 +6820,27 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tls_codec"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b"
dependencies = [
"tls_codec_derive",
"zeroize",
]
[[package]]
name = "tls_codec_derive"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "tokio"
version = "1.49.0"
@@ -6876,16 +6896,17 @@ dependencies = [
[[package]]
name = "tokio-postgres-rustls"
version = "0.12.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
checksum = "27d684bad428a0f2481f42241f821db42c54e2dc81d8c00db8536c506b0a0144"
dependencies = [
"const-oid",
"ring",
"rustls",
"tokio",
"tokio-postgres",
"tokio-rustls",
"x509-certificate",
"x509-cert",
]
[[package]]
@@ -8997,22 +9018,15 @@ dependencies = [
]
[[package]]
name = "x509-certificate"
version = "0.23.1"
name = "x509-cert"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66534846dec7a11d7c50a74b7cdb208b9a581cad890b7866430d438455847c85"
checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94"
dependencies = [
"bcder",
"bytes",
"chrono",
"const-oid",
"der",
"hex",
"pem",
"ring",
"signature",
"spki",
"thiserror 1.0.69",
"zeroize",
"tls_codec",
]
[[package]]
+1 -2
View File
@@ -110,7 +110,7 @@ prost = { version = "0.14", default-features = false, features = ["derive"], opt
# Memory / persistence
rusqlite = { version = "0.37", features = ["bundled"] }
postgres = { version = "0.19", features = ["with-chrono-0_4"], optional = true }
tokio-postgres-rustls = { version = "0.12", optional = true }
tokio-postgres-rustls = { version = "0.13", optional = true }
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
chrono-tz = "0.10"
cron = "0.15"
@@ -238,7 +238,6 @@ whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-pr
# Keep disabled by default to preserve current runtime behavior.
firecrawl = []
web-fetch-html2md = []
web-fetch-plaintext = []
[profile.release]
opt-level = "z" # Optimize for size
+43 -33
View File
@@ -46,7 +46,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
</p>
<p align="center">
<strong>Fast, small, and fully autonomous AI assistant infrastructure</strong><br />
<strong>Fast, small, and fully autonomous Operating System</strong><br />
Deploy anywhere. Swap anything.
</p>
@@ -81,6 +81,45 @@ Use this board for important notices (breaking changes, security advisories, mai
- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels).
- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints.
## Quick Start
### Option 1: Homebrew (macOS/Linuxbrew)
```bash
brew install zeroclaw
```
### Option 2: Clone + Bootstrap
```bash
git clone https://github.com/zeroclaw-labs/zeroclaw.git
cd zeroclaw
./bootstrap.sh
```
> **Note:** Source builds require ~2GB RAM and ~6GB disk. For resource-constrained systems, use `./bootstrap.sh --prefer-prebuilt` to download a pre-built binary instead.
### Option 3: Cargo Install
```bash
cargo install zeroclaw
```
### First Run
```bash
# Start the gateway daemon
zeroclaw gateway start
# Open the web UI
zeroclaw dashboard
# Or chat directly
zeroclaw chat "Hello!"
```
For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md).
## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible)
Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware.
@@ -99,16 +138,9 @@ Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge
<img src="zero-claw.jpeg" alt="ZeroClaw vs OpenClaw Comparison" width="800" />
</p>
### 🙏 Special Thanks
---
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:
- **Harvard University** — for fostering intellectual curiosity and pushing the boundaries of what's possible.
- **MIT** — for championing open knowledge, open source, and the belief that technology should be accessible to everyone.
- **Sundai Club** — for the community, the energy, and the relentless drive to build things that matter.
- **The World & Beyond** 🌍✨ — to every contributor, dreamer, and builder out there making open source a force for good. This is for you.
We're building in the open because the best ideas come from everywhere. If you're reading this, you're part of it. Welcome. 🦀❤️
For full documentation, see [`docs/README.md`](docs/README.md) | [`docs/SUMMARY.md`](docs/SUMMARY.md)
## ⚠️ Official Repository & Impersonation Warning
@@ -133,31 +165,9 @@ ZeroClaw is dual-licensed for maximum openness and contributor protection:
You may choose either license. **Contributors automatically grant rights under both** — see [CLA.md](CLA.md) for the full contributor agreement.
### Trademark
The **ZeroClaw** name and logo are trademarks of ZeroClaw Labs. This license does not grant permission to use them to imply endorsement or affiliation. See [TRADEMARK.md](TRADEMARK.md) for permitted and prohibited uses.
### Contributor Protections
- You **retain copyright** of your contributions
- **Patent grant** (Apache 2.0) shields you from patent claims by other contributors
- Your contributions are **permanently attributed** in commit history and [NOTICE](NOTICE)
- No trademark rights are transferred by contributing
## Contributing
New to ZeroClaw? Look for issues labeled [`good first issue`](https://github.com/zeroclaw-labs/zeroclaw/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) — see our [Contributing Guide](CONTRIBUTING.md#first-time-contributors) for how to get started.
See [CONTRIBUTING.md](CONTRIBUTING.md) and [CLA.md](CLA.md). Implement a trait, submit a PR:
- CI workflow guide: [docs/ci-map.md](docs/ci-map.md)
- New `Provider` → `src/providers/`
- New `Channel` → `src/channels/`
- New `Observer` → `src/observability/`
- New `Tool` → `src/tools/`
- New `Memory` → `src/memory/`
- New `Tunnel` → `src/tunnel/`
- New `Skill` → `~/.zeroclaw/workspace/skills/<name>/`
See [CONTRIBUTING.md](CONTRIBUTING.md) and [CLA.md](CLA.md). Implement a trait, submit a PR.
---
+2 -1
View File
@@ -2,7 +2,7 @@
This file is the canonical table of contents for the documentation system.
Last refreshed: **February 25, 2026**.
Last refreshed: **February 28, 2026**.
## Language Entry
@@ -110,5 +110,6 @@ Last refreshed: **February 25, 2026**.
- [project/README.md](project/README.md)
- [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md)
- [docs-audit-2026-02-24.md](docs-audit-2026-02-24.md)
- [project/m4-5-rfi-spike-2026-02-28.md](project/m4-5-rfi-spike-2026-02-28.md)
- [i18n-gap-backlog.md](i18n-gap-backlog.md)
- [docs-inventory.md](docs-inventory.md)
+23 -2
View File
@@ -146,6 +146,7 @@ If `[channels_config.matrix]`, `[channels_config.lark]`, or `[channels_config.fe
| Napcat | websocket receive + HTTP send (OneBot) | No (typically local/LAN) |
| Linq | webhook (`/linq`) | Yes (public HTTPS callback) |
| iMessage | local integration | No |
| ACP | stdio (JSON-RPC 2.0) | No |
| Nostr | relay websocket (NIP-04 / NIP-17) | No |
---
@@ -160,7 +161,7 @@ For channels with inbound sender allowlists:
Field names differ by channel:
- `allowed_users` (Telegram/Discord/Slack/Mattermost/Matrix/IRC/Lark/Feishu/DingTalk/QQ/Napcat/Nextcloud Talk)
- `allowed_users` (Telegram/Discord/Slack/Mattermost/Matrix/IRC/Lark/Feishu/DingTalk/QQ/Napcat/Nextcloud Talk/ACP)
- `allowed_from` (Signal)
- `allowed_numbers` (WhatsApp)
- `allowed_senders` (Email/Linq)
@@ -540,6 +541,25 @@ Notes:
allowed_contacts = ["*"]
```
### 4.18 ACP
ACP (Agent Client Protocol) enables ZeroClaw to act as a client for OpenCode ACP server,
allowing remote control of OpenCode behavior through JSON-RPC 2.0 communication over stdio.
```toml
[channels_config.acp]
opencode_path = "opencode" # optional, default: "opencode"
workdir = "/path/to/workspace" # optional
extra_args = [] # optional additional arguments to `opencode acp`
allowed_users = ["*"] # empty = deny all, "*" = allow all
```
Notes:
- ACP uses JSON-RPC 2.0 protocol over stdio with newline-delimited messages.
- Requires `opencode` binary in PATH or specified via `opencode_path`.
- The channel starts OpenCode subprocess via `opencode acp` command.
- Responses from OpenCode can be sent back to the originating channel when configured.
---
## 5. Validation Workflow
@@ -588,7 +608,7 @@ RUST_LOG=info zeroclaw daemon 2>&1 | tee /tmp/zeroclaw.log
Then filter channel/gateway events:
```bash
rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|DingTalk|QQ|iMessage|Nostr|Webhook|Channel" /tmp/zeroclaw.log
rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|DingTalk|QQ|iMessage|Nostr|Webhook|Channel|ACP" /tmp/zeroclaw.log
```
### 7.2 Keyword table
@@ -610,6 +630,7 @@ rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|D
| QQ | `QQ: connected and identified` | `QQ: ignoring C2C message from unauthorized user:` / `QQ: ignoring group message from unauthorized user:` | `QQ: received Reconnect (op 7)` / `QQ: received Invalid Session (op 9)` / `QQ: message channel closed` |
| Nextcloud Talk (gateway) | `POST /nextcloud-talk — Nextcloud Talk bot webhook` | `Nextcloud Talk webhook signature verification failed` / `Nextcloud Talk: ignoring message from unauthorized actor:` | `Nextcloud Talk send failed:` / `LLM error for Nextcloud Talk message:` |
| iMessage | `iMessage channel listening (AppleScript bridge)...` | (contact allowlist enforced by `allowed_contacts`) | `iMessage poll error:` |
| ACP | `ACP channel started` | `ACP: ignoring message from unauthorized user:` | `ACP process exited unexpectedly:` / `ACP JSON-RPC timeout:` / `ACP process spawn failed:` |
| Nostr | `Nostr channel listening as npub1...` | `Nostr: ignoring NIP-04 message from unauthorized pubkey:` / `Nostr: ignoring NIP-17 message from unauthorized pubkey:` | `Failed to decrypt NIP-04 message:` / `Failed to unwrap NIP-17 gift wrap:` / `Nostr relay pool shut down` |
### 7.3 Runtime supervisor keywords
+32
View File
@@ -38,6 +38,38 @@ Notes:
- Unset keeps the provider's built-in default.
- Environment override: `ZEROCLAW_MODEL_SUPPORT_VISION` or `MODEL_SUPPORT_VISION` (values: `true`/`false`/`1`/`0`/`yes`/`no`/`on`/`off`).
## `[model_providers.<profile>]`
Use named profiles to map a logical provider id to a provider name/base URL and optional profile-scoped credentials.
| Key | Default | Notes |
|---|---|---|
| `name` | unset | Optional provider id override (for example `openai`, `openai-codex`) |
| `base_url` | unset | Optional OpenAI-compatible endpoint URL |
| `wire_api` | unset | Optional protocol mode: `responses` or `chat_completions` |
| `model` | unset | Optional profile-scoped default model |
| `api_key` | unset | Optional profile-scoped API key (used when top-level `api_key` is empty) |
| `requires_openai_auth` | `false` | Load OpenAI auth material (`OPENAI_API_KEY` / Codex auth file) |
Notes:
- If both top-level `api_key` and profile `api_key` are present, top-level `api_key` wins.
- If top-level `default_model` is still the global OpenRouter default, profile `model` is used as an automatic compatibility override.
- Secrets encryption applies to profile API keys when `secrets.encrypt = true`.
Example:
```toml
default_provider = "sub2api"
[model_providers.sub2api]
name = "sub2api"
base_url = "https://api.example.com/v1"
wire_api = "chat_completions"
model = "qwen-max"
api_key = "sk-profile-key"
```
## `[observability]`
| Key | Default | Purpose |
+2 -1
View File
@@ -2,7 +2,7 @@
This inventory classifies documentation by intent and canonical location.
Last reviewed: **February 24, 2026**.
Last reviewed: **February 28, 2026**.
## Classification Legend
@@ -124,6 +124,7 @@ These are valuable context, but **not strict runtime contracts**.
|---|---|
| `docs/project-triage-snapshot-2026-02-18.md` | Snapshot |
| `docs/docs-audit-2026-02-24.md` | Snapshot (docs architecture audit) |
| `docs/project/m4-5-rfi-spike-2026-02-28.md` | Snapshot (M4-5 workspace split RFI baseline and execution plan) |
| `docs/i18n-gap-backlog.md` | Snapshot (i18n depth gap tracking) |
## Maintenance Contract
+1
View File
@@ -6,6 +6,7 @@ Time-bound project status snapshots for planning documentation and operations wo
- [../project-triage-snapshot-2026-02-18.md](../project-triage-snapshot-2026-02-18.md)
- [../docs-audit-2026-02-24.md](../docs-audit-2026-02-24.md)
- [m4-5-rfi-spike-2026-02-28.md](m4-5-rfi-spike-2026-02-28.md)
## Scope
+156
View File
@@ -0,0 +1,156 @@
# M4-5 Multi-Crate Workspace RFI Spike (2026-02-28)
Status: RFI complete, extraction execution pending.
Issue: [#2263](https://github.com/zeroclaw-labs/zeroclaw/issues/2263)
Linear parent: RMN-243
## Scope
This spike is strictly no-behavior-change planning for the M4-5 workspace split.
Goals:
- capture reproducible compile baseline metrics
- define crate boundary and dependency contract
- define CI/feature-matrix impact and rollback posture
- define stacked PR slicing plan (XS/S/M)
Out of scope:
- broad API redesign
- feature additions bundled with structure work
- one-shot mega-PR extraction
## Baseline Compile Metrics
### Repro command
```bash
scripts/ci/m4_5_rfi_baseline.sh /tmp/zeroclaw-m4rfi-target
```
### Preflight compile blockers observed on `origin/main`
Before timing could run cleanly, two compile blockers were found:
- `src/gateway/mod.rs:2176`: `run_gateway_chat_with_tools` call missing `session_id` argument
- `src/providers/cursor.rs:233`: `ChatResponse` initializer missing `quota_metadata`
RFI includes minimal compile-compat fixes for these two blockers so measurements are reproducible.
### Measured results (Apple Silicon macOS, local workspace)
| Phase | real(s) | status |
|---|---:|---|
| A: cold `cargo check --workspace --locked` | 306.47 | pass |
| B: cold-ish `cargo build --workspace --locked` | 219.07 | pass |
| C: warm `cargo check --workspace --locked` | 0.84 | pass |
| D: incremental `cargo check` after touching `src/main.rs` | 6.19 | pass |
Observations:
- cold check is the dominant iteration tax
- warm-check performance is excellent once target artifacts exist
- incremental behavior is acceptable but sensitive to wide root-crate coupling
## Current Workspace Snapshot
Current workspace members:
- `.` (`zeroclaw` monolith crate)
- `crates/robot-kit`
Code concentration still sits in the monolith. Large hotspots include:
- `src/config/schema.rs`
- `src/channels/mod.rs`
- `src/onboard/wizard.rs`
- `src/agent/loop_.rs`
- `src/gateway/mod.rs`
## Proposed Boundary Contract
Target crate topology for staged extraction:
1. `crates/zeroclaw-types`
- shared DTOs, enums, IDs, lightweight cross-domain traits
- no provider/channel/network dependencies
1. `crates/zeroclaw-core`
- config structs + validation, provider trait contracts, routing primitives, policy helpers
- depends on `zeroclaw-types`
1. `crates/zeroclaw-memory`
- memory traits/backends + hygiene/snapshot plumbing
- depends on `zeroclaw-types`, `zeroclaw-core` contracts only where required
1. `crates/zeroclaw-channels`
- channel adapters + inbound normalization
- depends on `zeroclaw-types`, `zeroclaw-core`, `zeroclaw-memory`
1. `crates/zeroclaw-api`
- gateway/webhook/http orchestration
- depends on `zeroclaw-types`, `zeroclaw-core`, `zeroclaw-memory`, `zeroclaw-channels`
1. `crates/zeroclaw-bin` (or keep root binary package name stable)
- CLI entrypoints + wiring only
Dependency rules:
- no downward imports from foundational crates into higher layers
- channels must not depend on gateway/http crate
- keep provider-specific SDK deps out of `zeroclaw-types`
- maintain feature-flag parity at workspace root during migration
## CI / Feature-Matrix Impact
Required CI adjustments during migration:
- add workspace compile lane (`cargo check --workspace --locked`)
- add package-focused lanes for extracted crates (`-p zeroclaw-types`, `-p zeroclaw-core`, etc.)
- keep existing runtime behavior lanes (`test`, `sec-audit`, `codeql`) unchanged until final convergence
- update path filters so crate-local changes trigger only relevant crate tests plus contract smoke tests
Guardrails:
- changed-line strict-delta lint remains mandatory
- each extraction PR must include no-behavior-change assertion in PR body
- each step must include explicit rollback note
## Rollback Strategy
Per-step rollback (stack-safe):
1. revert latest extraction PR only
2. re-run workspace compile + existing CI matrix
3. keep binary entrypoint and config contract untouched until final extraction stage
Abort criteria:
- unexpected runtime behavior drift
- CI lane expansion causes recurring queue stalls without signal gain
- feature-flag compatibility regressions
## Stacked PR Slicing Plan
### PR-1 (XS)
- add crate shells + workspace wiring (`types/core`), no symbol moves
- objective: establish scaffolding and CI package lanes
### PR-2 (S)
- extract low-churn shared types into `zeroclaw-types`
- add re-export shim layer to preserve existing import paths
### PR-3 (S)
- extract config/provider contracts into `zeroclaw-core`
- keep runtime call sites unchanged via compatibility re-exports
### PR-4 (M)
- extract memory subsystem crate and move wiring boundaries
- run full memory + gateway regression suite
### PR-5 (M)
- extract channels/api orchestration seams
- finalize package ownership and remove temporary re-export shims
## Next Execution Step
Open first no-behavior-change extraction PR from this RFI baseline:
- scope: workspace crate scaffolding + CI package lanes only
- no runtime behavior changes
- explicit rollback command included in PR body
+67
View File
@@ -0,0 +1,67 @@
#!/usr/bin/env bash
# Reproducible compile-timing probes for the M4-5 workspace-split RFI.
# Emits a markdown header plus one table row per benchmark phase.
set -euo pipefail

if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
  cat <<'USAGE'
Usage: scripts/ci/m4_5_rfi_baseline.sh [target_dir]
Run reproducible compile-timing probes for the current workspace.
The script prints a markdown table with real-time seconds and pass/fail status
for each benchmark phase.
USAGE
  exit 0
fi

# Resolve the repository root from this script's own location so the
# script works regardless of the caller's cwd.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
TARGET_DIR="${1:-${ROOT_DIR}/target-rfi}"

cd "${ROOT_DIR}"
if [[ ! -f Cargo.toml ]]; then
  echo "error: Cargo.toml not found at ${ROOT_DIR}" >&2
  exit 1
fi

# run_timed LABEL CMD...
# Times CMD via /usr/bin/time -p, prints "| LABEL | real-seconds | pass/fail |"
# as a markdown row, and returns non-zero when CMD failed. Callers invoke it
# under `set +e` so a failing phase does not abort the remaining phases.
run_timed() {
  local row_label="$1"
  shift
  local stderr_capture outcome elapsed
  stderr_capture="$(mktemp)"
  outcome="pass"
  # `time -p` writes its report to stderr; capture it for the awk parse below.
  /usr/bin/time -p "$@" >/dev/null 2>"${stderr_capture}" || outcome="fail"
  elapsed="$(awk '/^real / { print $2 }' "${stderr_capture}")"
  rm -f "${stderr_capture}"
  [[ -n "${elapsed}" ]] || elapsed="n/a"
  printf '| %s | %s | %s |\n' "${row_label}" "${elapsed}" "${outcome}"
  [[ "${outcome}" == "pass" ]]
}

printf '# M4-5 RFI Baseline\n\n'
printf '- Timestamp (UTC): %s\n' "$(date -u +%Y-%m-%dT%H:%M:%SZ)"
printf '- Commit: `%s`\n' "$(git rev-parse --short HEAD)"
printf '- Target dir: `%s`\n\n' "${TARGET_DIR}"
printf '| Phase | real(s) | status |\n'
printf '|---|---:|---|\n'

# Start from a clean target dir so phase A is a true cold build.
rm -rf "${TARGET_DIR}"

# Run every phase even if one fails; each row reports its own status.
set +e
run_timed "A: cold cargo check" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
run_timed "B: cold-ish cargo build" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo build --workspace --locked
run_timed "C: warm cargo check" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
touch src/main.rs
run_timed "D: incremental cargo check after touch src/main.rs" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
set -e
+24 -10
View File
@@ -18,6 +18,8 @@ use std::io::Write as IoWrite;
use std::sync::Arc;
use std::time::Instant;
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
pub struct Agent {
provider: Box<dyn Provider>,
tools: Vec<Box<dyn Tool>>,
@@ -218,9 +220,7 @@ impl AgentBuilder {
.memory_loader
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
config: self.config.unwrap_or_default(),
model_name: self
.model_name
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
model_name: crate::config::resolve_default_model_id(self.model_name.as_deref(), None),
temperature: self.temperature.unwrap_or(0.7),
workspace_dir: self
.workspace_dir
@@ -298,11 +298,10 @@ impl Agent {
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
let model_name = config
.default_model
.as_deref()
.unwrap_or("anthropic/claude-sonnet-4-20250514")
.to_string();
let model_name = crate::config::resolve_default_model_id(
config.default_model.as_deref(),
Some(provider_name),
);
let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
@@ -598,6 +597,17 @@ impl Agent {
.push(ConversationMessage::Chat(ChatMessage::assistant(
final_text.clone(),
)));
if self.auto_save && final_text.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let _ = self
.memory
.store(
"assistant_resp",
&final_text,
MemoryCategory::Conversation,
None,
)
.await;
}
self.trim_history();
return Ok(final_text);
@@ -714,8 +724,12 @@ pub async fn run(
let model_name = effective_config
.default_model
.as_deref()
.unwrap_or("anthropic/claude-sonnet-4-20250514")
.to_string();
.map(str::trim)
.filter(|m| !m.is_empty())
.map(str::to_string)
.unwrap_or_else(|| {
crate::config::default_model_fallback_for_provider(Some(&provider_name)).to_string()
});
agent.observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.clone(),
+64 -20
View File
@@ -10,7 +10,7 @@ use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use anyhow::{Context as _, Result};
use regex::{Regex, RegexSet};
use rustyline::completion::{Completer, Pair};
use rustyline::error::ReadlineError;
@@ -33,6 +33,7 @@ mod execution;
mod history;
mod parsing;
use crate::agent::session::{resolve_session_id, shared_session_manager};
use context::{build_context, build_hardware_context};
use detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
use execution::{
@@ -981,7 +982,7 @@ pub(crate) async fn run_tool_call_loop(
Some(model),
Some(&turn_id),
Some(false),
Some(&parse_issue),
Some(parse_issue),
serde_json::json!({
"iteration": iteration + 1,
"response_excerpt": truncate_with_ellipsis(
@@ -1816,10 +1817,12 @@ pub async fn run(
.or(config.default_provider.as_deref())
.unwrap_or("openrouter");
let model_name = model_override
.as_deref()
.or(config.default_model.as_deref())
.unwrap_or("anthropic/claude-sonnet-4");
let model_name = crate::config::resolve_default_model_id(
model_override
.as_deref()
.or(config.default_model.as_deref()),
Some(provider_name),
);
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
@@ -1840,7 +1843,7 @@ pub async fn run(
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
model_name,
&model_name,
&provider_runtime_options,
)?;
@@ -2003,7 +2006,7 @@ pub async fn run(
let native_tools = provider.supports_native_tools();
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
&config.workspace_dir,
model_name,
&model_name,
&tool_descs,
&skills,
Some(&config.identity),
@@ -2042,7 +2045,7 @@ pub async fn run(
// Inject memory + hardware RAG context into user message
let mem_context =
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score, None).await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -2085,7 +2088,7 @@ pub async fn run(
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
@@ -2101,6 +2104,17 @@ pub async fn run(
)
.await?;
final_output = response.clone();
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let assistant_key = autosave_memory_key("assistant_resp");
let _ = mem
.store(
&assistant_key,
&response,
MemoryCategory::Conversation,
None,
)
.await;
}
println!("{response}");
observer.record_event(&ObserverEvent::TurnComplete);
} else {
@@ -2191,8 +2205,13 @@ pub async fn run(
}
// Inject memory + hardware RAG context into user message
let mem_context =
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
let mem_context = build_context(
mem.as_ref(),
&user_input,
config.memory.min_relevance_score,
None,
)
.await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -2246,7 +2265,7 @@ pub async fn run(
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
@@ -2288,6 +2307,17 @@ pub async fn run(
}
};
final_output = response.clone();
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let assistant_key = autosave_memory_key("assistant_resp");
let _ = mem
.store(
&assistant_key,
&response,
MemoryCategory::Conversation,
None,
)
.await;
}
if let Err(e) = crate::channels::Channel::send(
&cli,
&crate::channels::traits::SendMessage::new(format!("\n{response}\n"), "user"),
@@ -2302,7 +2332,7 @@ pub async fn run(
if let Ok(compacted) = auto_compact_history(
&mut history,
provider.as_ref(),
model_name,
&model_name,
config.agent.max_history_messages,
)
.await
@@ -2332,6 +2362,14 @@ pub async fn run(
/// Process a single message through the full agent (with tools, peripherals, memory).
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
pub async fn process_message(config: Config, message: &str) -> Result<String> {
process_message_with_session(config, message, None).await
}
pub async fn process_message_with_session(
config: Config,
message: &str,
session_id: Option<&str>,
) -> Result<String> {
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let runtime: Arc<dyn runtime::RuntimeAdapter> =
@@ -2375,10 +2413,10 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
tools_registry.extend(peripheral_tools);
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
let model_name = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let model_name = crate::config::resolve_default_model_id(
config.default_model.as_deref(),
Some(provider_name),
);
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
@@ -2495,7 +2533,13 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
}
system_prompt.push_str(&build_shell_policy_instructions(&config.autonomy));
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
let mem_context = build_context(
mem.as_ref(),
message,
config.memory.min_relevance_score,
session_id,
)
.await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -4545,7 +4589,7 @@ Tail"#;
.await
.unwrap();
let context = build_context(&mem, "status updates", 0.0).await;
let context = build_context(&mem, "status updates", 0.0, None).await;
assert!(context.contains("user_msg_real"));
assert!(!context.contains("assistant_resp_poisoned"));
assert!(!context.contains("fabricated event"));
+2 -1
View File
@@ -8,11 +8,12 @@ pub(super) async fn build_context(
mem: &dyn Memory,
user_msg: &str,
min_relevance_score: f64,
session_id: Option<&str>,
) -> String {
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
let relevant: Vec<_> = entries
.iter()
.filter(|e| match e.score {
+25 -1
View File
@@ -220,7 +220,9 @@ impl LoopDetector {
fn hash_output(output: &str) -> u64 {
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
&output[..OUTPUT_HASH_PREFIX_BYTES]
// Use floor_utf8_char_boundary to avoid panic on multi-byte UTF-8 characters
let boundary = crate::util::floor_utf8_char_boundary(output, OUTPUT_HASH_PREFIX_BYTES);
&output[..boundary]
} else {
output
};
@@ -386,4 +388,26 @@ mod tests {
det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
assert_eq!(det.check(), DetectionVerdict::Continue);
}
// 11. UTF-8 boundary safety: hash_output must not panic on CJK text
#[test]
fn hash_output_utf8_boundary_safe() {
// Create a string where byte 4096 lands inside a multi-byte char
// Chinese chars are 3 bytes each, so 1366 chars = 4098 bytes
let cjk_text: String = "".repeat(1366); // 4098 bytes
assert!(cjk_text.len() > super::OUTPUT_HASH_PREFIX_BYTES);
// This should NOT panic
let hash1 = super::hash_output(&cjk_text);
// Different content should produce different hash
let cjk_text2: String = "".repeat(1366);
let hash2 = super::hash_output(&cjk_text2);
assert_ne!(hash1, hash2);
// Mixed ASCII + CJK at boundary
let mixed = "a".repeat(4094) + "文文"; // 4094 + 6 = 4100 bytes, boundary at 4096
let hash3 = super::hash_output(&mixed);
assert!(hash3 != 0); // Just verify it runs
}
}
+106 -3
View File
@@ -28,8 +28,13 @@ pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
}
let start = if has_system { 1 } else { 0 };
let to_remove = non_system_count - max_history;
history.drain(start..start + to_remove);
let mut trim_end = start + (non_system_count - max_history);
// Never keep a leading `role=tool` at the trim boundary. Tool-message runs
// must remain attached to their preceding assistant(tool_calls) message.
while trim_end < history.len() && history[trim_end].role == "tool" {
trim_end += 1;
}
history.drain(start..trim_end);
}
pub(super) fn build_compaction_transcript(messages: &[ChatMessage]) -> String {
@@ -80,7 +85,11 @@ pub(super) async fn auto_compact_history(
return Ok(false);
}
let compact_end = start + compact_count;
let mut compact_end = start + compact_count;
// Do not split assistant(tool_calls) -> tool runs across compaction boundary.
while compact_end < history.len() && history[compact_end].role == "tool" {
compact_end += 1;
}
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
let transcript = build_compaction_transcript(&to_compact);
@@ -104,3 +113,97 @@ pub(super) async fn auto_compact_history(
Ok(true)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::{ChatRequest, ChatResponse, Provider};
use async_trait::async_trait;
struct StaticSummaryProvider;
#[async_trait]
impl Provider for StaticSummaryProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok("- summarized context".to_string())
}
async fn chat(
&self,
_request: ChatRequest<'_>,
_model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
Ok(ChatResponse {
text: Some("- summarized context".to_string()),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
quota_metadata: None,
})
}
}
fn assistant_with_tool_call(id: &str) -> ChatMessage {
ChatMessage::assistant(format!(
"{{\"content\":\"\",\"tool_calls\":[{{\"id\":\"{id}\",\"name\":\"shell\",\"arguments\":\"{{}}\"}}]}}"
))
}
fn tool_result(id: &str) -> ChatMessage {
ChatMessage::tool(format!("{{\"tool_call_id\":\"{id}\",\"content\":\"ok\"}}"))
}
#[test]
fn trim_history_avoids_orphan_tool_at_boundary() {
let mut history = vec![
ChatMessage::user("old"),
assistant_with_tool_call("call_1"),
tool_result("call_1"),
ChatMessage::user("recent"),
];
trim_history(&mut history, 2);
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "user");
assert_eq!(history[0].content, "recent");
}
#[tokio::test]
async fn auto_compact_history_does_not_split_tool_run_boundary() {
let mut history = vec![
ChatMessage::user("oldest"),
assistant_with_tool_call("call_2"),
tool_result("call_2"),
];
for idx in 0..19 {
history.push(ChatMessage::user(format!("recent-{idx}")));
}
// 22 non-system messages => compaction with max_history=21 would
// previously cut right before the tool result (index 2).
assert_eq!(history.len(), 22);
let compacted =
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21)
.await
.expect("compaction should succeed");
assert!(compacted);
assert_eq!(history[0].role, "assistant");
assert!(
history[0].content.contains("[Compaction summary]"),
"summary message should replace compacted range"
);
assert_ne!(
history[1].role, "tool",
"first retained message must not be an orphan tool result"
);
}
}
+2 -1
View File
@@ -7,6 +7,7 @@ pub mod memory_loader;
pub mod prompt;
pub mod quota_aware;
pub mod research;
pub mod session;
#[cfg(test)]
mod tests;
@@ -14,4 +15,4 @@ mod tests;
#[allow(unused_imports)]
pub use agent::{Agent, AgentBuilder};
#[allow(unused_imports)]
pub use loop_::{process_message, run};
pub use loop_::{process_message, process_message_with_session, run};
+569
View File
@@ -0,0 +1,569 @@
use crate::providers::ChatMessage;
use crate::{
config::AgentSessionBackend, config::AgentSessionConfig, config::AgentSessionStrategy,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{LazyLock, Mutex as StdMutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio::time;
static SHARED_SESSION_MANAGERS: LazyLock<StdMutex<HashMap<String, Arc<dyn SessionManager>>>> =
LazyLock::new(|| StdMutex::new(HashMap::new()));
pub fn resolve_session_id(
session_config: &AgentSessionConfig,
sender_id: &str,
channel_name: Option<&str>,
) -> String {
fn escape_part(raw: &str) -> String {
raw.replace(':', "%3A")
}
match session_config.strategy {
AgentSessionStrategy::Main => "main".to_string(),
AgentSessionStrategy::PerChannel => escape_part(channel_name.unwrap_or("main")),
AgentSessionStrategy::PerSender => match channel_name {
Some(channel) => format!("{}:{sender_id}", escape_part(channel)),
None => sender_id.to_string(),
},
}
}
pub fn create_session_manager(
session_config: &AgentSessionConfig,
workspace_dir: &Path,
) -> Result<Option<Arc<dyn SessionManager>>> {
let ttl = Duration::from_secs(session_config.ttl_seconds);
let max_messages = session_config.max_messages;
match session_config.backend {
AgentSessionBackend::None => Ok(None),
AgentSessionBackend::Memory => Ok(Some(MemorySessionManager::new(ttl, max_messages))),
AgentSessionBackend::Sqlite => {
let path = SqliteSessionManager::default_db_path(workspace_dir);
Ok(Some(SqliteSessionManager::new(path, ttl, max_messages)?))
}
}
}
pub fn shared_session_manager(
session_config: &AgentSessionConfig,
workspace_dir: &Path,
) -> Result<Option<Arc<dyn SessionManager>>> {
let key = format!("{}:{session_config:?}", workspace_dir.display());
{
let map = SHARED_SESSION_MANAGERS.lock().unwrap_or_else(|e| e.into_inner());
if let Some(mgr) = map.get(&key) {
return Ok(Some(mgr.clone()));
}
}
let mgr_opt = create_session_manager(session_config, workspace_dir)?;
if let Some(mgr) = mgr_opt.as_ref() {
let mut map = SHARED_SESSION_MANAGERS.lock().unwrap_or_else(|e| e.into_inner());
map.insert(key, mgr.clone());
}
Ok(mgr_opt)
}
#[derive(Clone)]
pub struct Session {
id: String,
manager: Arc<dyn SessionManager>,
}
impl Session {
pub fn id(&self) -> &str {
&self.id
}
pub async fn get_history(&self) -> Result<Vec<ChatMessage>> {
self.manager.get_history(&self.id).await
}
pub async fn update_history(&self, history: Vec<ChatMessage>) -> Result<()> {
self.manager.set_history(&self.id, history).await
}
}
#[async_trait]
pub trait SessionManager: Send + Sync {
fn clone_arc(&self) -> Arc<dyn SessionManager>;
async fn ensure_exists(&self, _session_id: &str) -> Result<()> {
Ok(())
}
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>>;
async fn set_history(&self, session_id: &str, history: Vec<ChatMessage>) -> Result<()>;
async fn delete(&self, session_id: &str) -> Result<()>;
async fn cleanup_expired(&self) -> Result<usize>;
async fn get_or_create(&self, session_id: &str) -> Result<Session> {
self.ensure_exists(session_id).await?;
Ok(Session {
id: session_id.to_string(),
manager: self.clone_arc(),
})
}
}
fn unix_seconds_now() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs() as i64
}
fn trim_non_system(history: &mut Vec<ChatMessage>, max_messages: usize) {
history.retain(|m| m.role != "system");
if max_messages == 0 || history.len() <= max_messages {
return;
}
let drop_count = history.len() - max_messages;
history.drain(0..drop_count);
}
#[derive(Debug)]
struct MemorySessionState {
history: RwLock<Vec<ChatMessage>>,
updated_at_unix: AtomicI64,
}
struct MemorySessionManagerInner {
sessions: RwLock<HashMap<String, Arc<MemorySessionState>>>,
ttl: Duration,
max_messages: usize,
}
#[derive(Clone)]
pub struct MemorySessionManager {
inner: Arc<MemorySessionManagerInner>,
}
impl MemorySessionManager {
pub fn new(ttl: Duration, max_messages: usize) -> Arc<Self> {
let mgr = Arc::new(Self {
inner: Arc::new(MemorySessionManagerInner {
sessions: RwLock::new(HashMap::new()),
ttl,
max_messages,
}),
});
mgr.spawn_cleanup_task();
mgr
}
fn spawn_cleanup_task(self: &Arc<Self>) {
let mgr = Arc::clone(self);
let interval = cleanup_interval(mgr.inner.ttl);
tokio::spawn(async move {
let mut ticker = time::interval(interval);
loop {
ticker.tick().await;
let _ = mgr.cleanup_expired().await;
}
});
}
}
#[async_trait]
impl SessionManager for MemorySessionManager {
fn clone_arc(&self) -> Arc<dyn SessionManager> {
Arc::new(self.clone())
}
async fn ensure_exists(&self, session_id: &str) -> Result<()> {
let mut sessions = self.inner.sessions.write().await;
if sessions.contains_key(session_id) {
return Ok(());
}
let now = unix_seconds_now();
sessions.insert(
session_id.to_string(),
Arc::new(MemorySessionState {
history: RwLock::new(Vec::new()),
updated_at_unix: AtomicI64::new(now),
}),
);
Ok(())
}
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>> {
let state = {
let sessions = self.inner.sessions.read().await;
sessions.get(session_id).cloned()
};
let Some(state) = state else {
return Ok(Vec::new());
};
let history = state.history.read().await;
let mut history = history.clone();
trim_non_system(&mut history, self.inner.max_messages);
Ok(history)
}
async fn set_history(&self, session_id: &str, mut history: Vec<ChatMessage>) -> Result<()> {
trim_non_system(&mut history, self.inner.max_messages);
let now = unix_seconds_now();
let state = {
let mut sessions = self.inner.sessions.write().await;
sessions
.entry(session_id.to_string())
.or_insert_with(|| {
Arc::new(MemorySessionState {
history: RwLock::new(Vec::new()),
updated_at_unix: AtomicI64::new(now),
})
})
.clone()
};
state.updated_at_unix.store(now, Ordering::Relaxed);
let mut stored = state.history.write().await;
*stored = history;
Ok(())
}
async fn delete(&self, session_id: &str) -> Result<()> {
let mut sessions = self.inner.sessions.write().await;
sessions.remove(session_id);
Ok(())
}
async fn cleanup_expired(&self) -> Result<usize> {
if self.inner.ttl.is_zero() {
return Ok(0);
}
let cutoff = unix_seconds_now() - self.inner.ttl.as_secs() as i64;
let mut sessions = self.inner.sessions.write().await;
let before = sessions.len();
sessions.retain(|_, s| s.updated_at_unix.load(Ordering::Relaxed) >= cutoff);
Ok(before.saturating_sub(sessions.len()))
}
}
#[derive(Clone)]
pub struct SqliteSessionManager {
conn: Arc<Mutex<Connection>>,
ttl: Duration,
max_messages: usize,
}
impl SqliteSessionManager {
pub fn new(db_path: PathBuf, ttl: Duration, max_messages: usize) -> Result<Arc<Self>> {
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(&db_path)?;
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;",
)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS agent_sessions (
session_id TEXT PRIMARY KEY,
history_json TEXT NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_agent_sessions_updated_at
ON agent_sessions(updated_at);",
)?;
let mgr = Arc::new(Self {
conn: Arc::new(Mutex::new(conn)),
ttl,
max_messages,
});
mgr.spawn_cleanup_task();
Ok(mgr)
}
pub fn default_db_path(workspace_dir: &Path) -> PathBuf {
workspace_dir.join("memory").join("sessions.db")
}
fn spawn_cleanup_task(self: &Arc<Self>) {
let mgr = Arc::clone(self);
let interval = cleanup_interval(mgr.ttl);
tokio::spawn(async move {
let mut ticker = time::interval(interval);
loop {
ticker.tick().await;
let _ = mgr.cleanup_expired().await;
}
});
}
#[cfg(test)]
pub async fn force_expire_session(&self, session_id: &str, age: Duration) -> Result<()> {
let conn = self.conn.clone();
let session_id = session_id.to_string();
let age_secs = age.as_secs() as i64;
tokio::task::spawn_blocking(move || {
let conn = conn.lock();
let new_time = unix_seconds_now() - age_secs;
conn.execute(
"UPDATE agent_sessions SET updated_at = ?2 WHERE session_id = ?1",
params![session_id, new_time],
)?;
Ok(())
})
.await
.context("SQLite blocking task panicked")?
}
}
#[async_trait]
impl SessionManager for SqliteSessionManager {
fn clone_arc(&self) -> Arc<dyn SessionManager> {
Arc::new(self.clone())
}
async fn ensure_exists(&self, session_id: &str) -> Result<()> {
let now = unix_seconds_now();
let conn = self.conn.clone();
let session_id = session_id.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock();
conn.execute(
"INSERT OR IGNORE INTO agent_sessions(session_id, history_json, updated_at)
VALUES(?1, '[]', ?2)",
params![session_id, now],
)?;
Ok(())
})
.await
.context("SQLite blocking task panicked")?
}
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>> {
let conn = self.conn.clone();
let session_id = session_id.to_string();
let max_messages = self.max_messages;
tokio::task::spawn_blocking(move || {
let conn = conn.lock();
let mut stmt = conn.prepare(
"SELECT history_json FROM agent_sessions WHERE session_id = ?1",
)?;
let mut rows = stmt.query(params![session_id])?;
if let Some(row) = rows.next()? {
let json: String = row.get(0)?;
let mut history: Vec<ChatMessage> = serde_json::from_str(&json)
.with_context(|| format!("Failed to parse session history for session_id={session_id}"))?;
trim_non_system(&mut history, max_messages);
return Ok(history);
}
Ok(Vec::new())
})
.await
.context("SQLite blocking task panicked")?
}
async fn set_history(&self, session_id: &str, mut history: Vec<ChatMessage>) -> Result<()> {
trim_non_system(&mut history, self.max_messages);
let json = serde_json::to_string(&history)?;
let now = unix_seconds_now();
let conn = self.conn.clone();
let session_id = session_id.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock();
conn.execute(
"INSERT INTO agent_sessions(session_id, history_json, updated_at)
VALUES(?1, ?2, ?3)
ON CONFLICT(session_id) DO UPDATE SET history_json=excluded.history_json, updated_at=excluded.updated_at",
params![session_id, json, now],
)?;
Ok(())
})
.await
.context("SQLite blocking task panicked")?
}
async fn delete(&self, session_id: &str) -> Result<()> {
let conn = self.conn.clone();
let session_id = session_id.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock();
conn.execute(
"DELETE FROM agent_sessions WHERE session_id = ?1",
params![session_id],
)?;
Ok(())
})
.await
.context("SQLite blocking task panicked")?
}
async fn cleanup_expired(&self) -> Result<usize> {
if self.ttl.is_zero() {
return Ok(0);
}
let conn = self.conn.clone();
let ttl_secs = self.ttl.as_secs() as i64;
tokio::task::spawn_blocking(move || {
let cutoff = unix_seconds_now() - ttl_secs;
let conn = conn.lock();
let removed = conn.execute(
"DELETE FROM agent_sessions WHERE updated_at < ?1",
params![cutoff],
)?;
Ok(removed)
})
.await
.context("SQLite blocking task panicked")?
}
}
fn cleanup_interval(ttl: Duration) -> Duration {
if ttl.is_zero() {
return Duration::from_secs(60);
}
let half = ttl / 2;
if half < Duration::from_secs(30) {
Duration::from_secs(30)
} else if half > Duration::from_secs(300) {
Duration::from_secs(300)
} else {
half
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolve_session_id_respects_strategy() {
let mut cfg = AgentSessionConfig::default();
cfg.strategy = AgentSessionStrategy::Main;
assert_eq!(resolve_session_id(&cfg, "u1", Some("whatsapp")), "main");
cfg.strategy = AgentSessionStrategy::PerChannel;
assert_eq!(resolve_session_id(&cfg, "u1", Some("whatsapp")), "whatsapp");
assert_eq!(resolve_session_id(&cfg, "u1", None), "main");
cfg.strategy = AgentSessionStrategy::PerSender;
assert_eq!(
resolve_session_id(&cfg, "u1", Some("whatsapp")),
"whatsapp:u1"
);
assert_eq!(resolve_session_id(&cfg, "u1", None), "u1");
assert_eq!(
resolve_session_id(&cfg, "u1", Some("matrix:@alice")),
"matrix%3A@alice:u1"
);
}
#[tokio::test]
async fn memory_session_accumulates_history() -> Result<()> {
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 50);
let session = mgr.get_or_create("s1").await?;
assert!(session.get_history().await?.is_empty());
session
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
.await?;
assert_eq!(session.get_history().await?.len(), 2);
let mut h = session.get_history().await?;
h.push(ChatMessage::user("again"));
h.push(ChatMessage::assistant("ok2"));
session.update_history(h).await?;
assert_eq!(session.get_history().await?.len(), 4);
Ok(())
}
#[tokio::test]
async fn memory_sessions_do_not_mix_histories() -> Result<()> {
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 50);
let a = mgr.get_or_create("a").await?;
let b = mgr.get_or_create("b").await?;
a.update_history(vec![ChatMessage::user("u1"), ChatMessage::assistant("a1")])
.await?;
b.update_history(vec![ChatMessage::user("u2"), ChatMessage::assistant("b1")])
.await?;
let ha = a.get_history().await?;
let hb = b.get_history().await?;
assert_eq!(ha[0].content, "u1");
assert_eq!(hb[0].content, "u2");
Ok(())
}
#[tokio::test]
async fn max_messages_trims_oldest_non_system() -> Result<()> {
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 2);
let session = mgr.get_or_create("s1").await?;
session
.update_history(vec![
ChatMessage::system("s"),
ChatMessage::user("1"),
ChatMessage::assistant("2"),
ChatMessage::user("3"),
ChatMessage::assistant("4"),
])
.await?;
let h = session.get_history().await?;
assert_eq!(h.len(), 2);
assert_eq!(h[0].content, "3");
assert_eq!(h[1].content, "4");
Ok(())
}
#[tokio::test]
async fn sqlite_session_persists_across_instances() -> Result<()> {
let dir = tempfile::tempdir()?;
let db_path = dir.path().join("sessions.db");
{
let mgr = SqliteSessionManager::new(db_path.clone(), Duration::from_secs(3600), 50)?;
let session = mgr.get_or_create("s1").await?;
session
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
.await?;
}
let mgr2 = SqliteSessionManager::new(db_path, Duration::from_secs(3600), 50)?;
let session2 = mgr2.get_or_create("s1").await?;
let history = session2.get_history().await?;
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "user");
assert_eq!(history[1].role, "assistant");
Ok(())
}
#[tokio::test]
async fn sqlite_session_cleanup_expires() -> Result<()> {
let dir = tempfile::tempdir()?;
let db_path = dir.path().join("sessions.db");
// TTL 1 second
let mgr = SqliteSessionManager::new(db_path, Duration::from_secs(1), 50)?;
let session = mgr.get_or_create("s1").await?;
session
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
.await?;
// Force expire by setting age to 2 seconds
mgr.force_expire_session("s1", Duration::from_secs(2))
.await?;
let removed = mgr.cleanup_expired().await?;
assert!(removed >= 1);
Ok(())
}
}
+11 -6
View File
@@ -636,7 +636,7 @@ async fn history_trims_after_max_messages() {
// ═══════════════════════════════════════════════════════════════════════════
#[tokio::test]
async fn auto_save_stores_only_user_messages_in_memory() {
async fn auto_save_stores_user_and_assistant_messages_in_memory() {
let (mem, _tmp) = make_sqlite_memory();
let provider = Box::new(ScriptedProvider::new(vec![text_response(
"I remember everything",
@@ -651,11 +651,11 @@ async fn auto_save_stores_only_user_messages_in_memory() {
let _ = agent.turn("Remember this fact").await.unwrap();
// Auto-save only persists user-stated input, never assistant-generated summaries.
// Auto-save persists both user input and assistant output for traceability.
let count = mem.count().await.unwrap();
assert_eq!(
count, 1,
"Expected exactly 1 user memory entry, got {count}"
count, 2,
"Expected user + assistant memory entries, got {count}"
);
let stored = mem.get("user_msg").await.unwrap();
@@ -668,8 +668,13 @@ async fn auto_save_stores_only_user_messages_in_memory() {
let assistant = mem.get("assistant_resp").await.unwrap();
assert!(
assistant.is_none(),
"assistant_resp should not be auto-saved anymore"
assistant.is_some(),
"Expected assistant_resp key to be present"
);
assert_eq!(
assistant.unwrap().content,
"I remember everything",
"Assistant response should be persisted when auto-save is enabled"
);
}
+59
View File
@@ -0,0 +1,59 @@
use anyhow::{bail, Context, Result};
use serde::Deserialize;
use tracing_subscriber::EnvFilter;
use zeroclaw::config::schema::McpServerConfig;
#[derive(Default, Deserialize)]
struct FileMcp {
#[serde(default)]
enabled: bool,
#[serde(default)]
servers: Vec<McpServerConfig>,
}
#[derive(Default, Deserialize)]
struct FileRoot {
#[serde(default)]
mcp: FileMcp,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.init();
let (enabled, servers) = match std::fs::read_to_string("config.toml") {
Ok(s) => {
let start = s
.lines()
.position(|line| line.trim() == "[mcp]")
.unwrap_or(0);
let slice = s.lines().skip(start).collect::<Vec<_>>().join("\n");
let root: FileRoot = toml::from_str(&slice).context("failed to parse ./config.toml")?;
(root.mcp.enabled, root.mcp.servers)
}
Err(_) => {
let config = zeroclaw::Config::load_or_init().await?;
(config.mcp.enabled, config.mcp.servers)
}
};
if !enabled || servers.is_empty() {
bail!("MCP is disabled or no servers configured");
}
let registry = zeroclaw::tools::McpRegistry::connect_all(&servers).await?;
let tool_count = registry.tool_names().len();
tracing::info!(
"MCP smoke ok: {} server(s), {} tool(s)",
registry.server_count(),
tool_count
);
if registry.server_count() == 0 {
bail!("no MCP servers connected");
}
Ok(())
}
+901
View File
@@ -0,0 +1,901 @@
//! ACP (Agent Client Protocol) channel for ZeroClaw.
//!
//! This channel enables ZeroClaw to act as an ACP client, connecting to an OpenCode
//! ACP server via `opencode acp` command for JSON-RPC 2.0 communication over stdio.
//! This allows users to control OpenCode behavior from any channel via social apps.
use super::traits::{Channel, ChannelMessage, SendMessage};
use crate::config::schema::AcpConfig;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::VecDeque;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::mpsc;
use tokio::sync::Mutex;
/// Monotonic counter for message IDs in ACP JSON-RPC requests.
static ACP_MESSAGE_ID: AtomicU64 = AtomicU64::new(0);
/// ACP channel implementation for connecting to OpenCode ACP server.
///
/// The channel starts an OpenCode subprocess via `opencode acp` command and
/// communicates using JSON-RPC 2.0 over stdio. Messages from social apps are
/// forwarded as prompts to OpenCode, and responses are sent back through the
/// originating channel.
pub struct AcpChannel {
/// OpenCode binary path (default: "opencode")
opencode_path: String,
/// Working directory for OpenCode process
workdir: Option<String>,
/// Additional arguments to pass to `opencode acp`
extra_args: Vec<String>,
/// Allowed user identifiers (empty = deny all, "*" = allow all)
allowed_users: Vec<String>,
/// Optional pairing guard for authentication
pairing: Option<crate::security::pairing::PairingGuard>,
/// HTTP client for potential future HTTP transport support
client: reqwest::Client,
/// Active OpenCode subprocess and its I/O handles
process: Arc<Mutex<Option<AcpProcess>>>,
/// Serializes ACP send operations to avoid concurrent process take/spawn races.
send_operation_lock: Arc<Mutex<()>>,
/// Next message ID for JSON-RPC requests
next_message_id: Arc<AtomicU64>,
/// Optional response channel for sending ACP responses back to original channel
response_channel: Option<Arc<dyn Channel>>,
}
/// Active ACP process with I/O handles and session state.
struct AcpProcess {
/// Child process handle
child: Child,
/// Stdin handle for sending JSON-RPC requests
stdin: tokio::process::ChildStdin,
/// Stdout handle for receiving JSON-RPC responses
stdout: BufReader<tokio::process::ChildStdout>,
/// Session ID from ACP server (after initialize + session/new)
session_id: Option<String>,
/// JSON-RPC message ID counter (per-process)
message_id: u64,
/// Pending responses keyed by request ID
pending_responses: VecDeque<PendingResponse>,
}
/// Pending JSON-RPC response awaiting completion.
struct PendingResponse {
request_id: u64,
method: String,
created_at: std::time::Instant,
}
/// JSON-RPC 2.0 request structure.
#[derive(Debug, Clone, Serialize)]
struct JsonRpcRequest {
jsonrpc: String,
id: u64,
method: String,
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<Value>,
}
/// JSON-RPC 2.0 response structure.
#[derive(Debug, Clone, Deserialize)]
struct JsonRpcResponse {
jsonrpc: String,
id: u64,
#[serde(flatten)]
result_or_error: JsonRpcResultOrError,
}
/// JSON-RPC result or error.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum JsonRpcResultOrError {
Result { result: Value },
Error { error: JsonRpcError },
}
/// JSON-RPC error object.
#[derive(Debug, Clone, Deserialize)]
struct JsonRpcError {
code: i32,
message: String,
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<Value>,
}
/// ACP initialization parameters.
#[derive(Debug, Clone, Serialize)]
struct InitializeParams {
protocol_version: u64,
client_capabilities: ClientCapabilities,
client_info: ClientInfo,
}
/// Client capabilities declaration.
#[derive(Debug, Clone, Serialize, Default)]
struct ClientCapabilities {
fs: FsCapabilities,
terminal: bool,
#[serde(rename = "_meta")]
meta: Option<Value>,
}
/// Filesystem capabilities.
#[derive(Debug, Clone, Serialize, Default)]
struct FsCapabilities {
read_text_file: bool,
write_text_file: bool,
}
/// Client information.
#[derive(Debug, Clone, Serialize)]
struct ClientInfo {
name: String,
title: String,
version: String,
}
/// ACP session/new parameters.
#[derive(Debug, Clone, Serialize)]
struct SessionNewParams {
cwd: String,
mcp_servers: Vec<Value>,
}
/// ACP session/prompt parameters.
#[derive(Debug, Clone, Serialize)]
struct SessionPromptParams {
session_id: String,
prompt: Vec<PromptItem>,
}
/// Prompt item (text content).
#[derive(Debug, Clone, Serialize)]
struct PromptItem {
#[serde(rename = "type")]
item_type: String,
text: String,
}
impl AcpChannel {
/// Create a new ACP channel with the given configuration.
pub fn new(config: AcpConfig) -> Self {
Self {
opencode_path: config
.opencode_path
.unwrap_or_else(|| "opencode".to_string()),
workdir: config.workdir,
extra_args: config.extra_args,
allowed_users: config.allowed_users,
pairing: None, // TODO: Implement pairing if needed
client: reqwest::Client::new(),
process: Arc::new(Mutex::new(None)),
send_operation_lock: Arc::new(Mutex::new(())),
next_message_id: Arc::new(AtomicU64::new(0)),
response_channel: None,
}
}
/// Check if a user is allowed to interact with this channel.
fn is_user_allowed(&self, user_id: &str) -> bool {
self.allowed_users
.iter()
.any(|allowed| allowed == "*" || allowed == user_id)
}
/// Set the response channel for sending ACP responses back to original channel
pub fn set_response_channel(&mut self, channel: Arc<dyn Channel>) {
self.response_channel = Some(channel);
}
/// Start the OpenCode ACP subprocess and establish connection.
fn start_process(&self) -> Result<AcpProcess> {
let mut command = Command::new(&self.opencode_path);
command.arg("acp");
if let Some(workdir) = &self.workdir {
command.current_dir(workdir);
}
for arg in &self.extra_args {
command.arg(arg);
}
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
// Inherit stderr so the child cannot block on an unread stderr pipe.
command.stderr(std::process::Stdio::inherit());
let mut child = command
.spawn()
.with_context(|| format!("Failed to start OpenCode process: {}", self.opencode_path))?;
let stdin = child
.stdin
.take()
.context("Failed to take stdin from child process")?;
let stdout = child
.stdout
.take()
.context("Failed to take stdout from child process")?;
let stdout_reader = BufReader::new(stdout);
let process = AcpProcess {
child,
stdin,
stdout: stdout_reader,
session_id: None,
message_id: 0,
pending_responses: VecDeque::new(),
};
Ok(process)
}
/// Send a JSON-RPC request and wait for response.
async fn send_json_rpc_request(
&self,
process: &mut AcpProcess,
method: &str,
params: Option<Value>,
) -> Result<Value> {
let request_id = process.message_id;
process.message_id += 1;
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: request_id,
method: method.to_string(),
params,
};
let json_str = serde_json::to_string(&request).with_context(|| {
format!(
"Failed to serialize JSON-RPC request for method: {}",
method
)
})?;
// Write message with newline delimiter (ACP protocol requirement)
process.stdin.write_all(json_str.as_bytes()).await?;
process.stdin.write_all(b"\n").await?;
process.stdin.flush().await?;
// Read response line with timeout
let mut line = String::new();
let timeout_duration = std::time::Duration::from_secs(30);
match tokio::time::timeout(timeout_duration, process.stdout.read_line(&mut line)).await {
Ok(read_result) => {
read_result
.with_context(|| format!("Failed to read response for method: {}", method))?;
}
Err(_) => {
anyhow::bail!("Timeout waiting for ACP response for method: {}", method);
}
}
// Parse JSON-RPC response
let response: JsonRpcResponse = serde_json::from_str(&line)
.with_context(|| format!("Failed to parse JSON-RPC response: {}", line))?;
// Verify response ID matches request ID
if response.id != request_id {
anyhow::bail!(
"Response ID mismatch: expected {}, got {}",
request_id,
response.id
);
}
match response.result_or_error {
JsonRpcResultOrError::Result { result } => Ok(result),
JsonRpcResultOrError::Error { error } => {
anyhow::bail!("ACP JSON-RPC error ({}): {}", error.code, error.message);
}
}
}
/// Initialize ACP connection with the server.
async fn initialize_acp(&self, process: &mut AcpProcess) -> Result<()> {
let params = InitializeParams {
protocol_version: 1,
client_capabilities: ClientCapabilities::default(),
client_info: ClientInfo {
name: "ZeroClaw".to_string(),
title: "ZeroClaw ACP Client".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
};
let params_value =
serde_json::to_value(params).context("Failed to serialize initialize params")?;
let response = self
.send_json_rpc_request(process, "initialize", Some(params_value))
.await?;
// TODO: Parse response and store capabilities
tracing::info!("ACP initialized successfully: {:?}", response);
Ok(())
}
/// Create a new ACP session.
async fn create_session(&self, process: &mut AcpProcess) -> Result<String> {
let cwd = self.workdir.clone().unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_else(|_| ".".into())
.to_string_lossy()
.to_string()
});
let params = SessionNewParams {
cwd,
mcp_servers: vec![],
};
let params_value =
serde_json::to_value(params).context("Failed to serialize session/new params")?;
let response = self
.send_json_rpc_request(process, "session/new", Some(params_value))
.await?;
// Parse response to extract session_id
let session_id = response
.get("session_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.with_context(|| {
format!(
"Invalid session/new response: missing session_id: {:?}",
response
)
})?;
tracing::info!("ACP session created: {}", session_id);
Ok(session_id)
}
/// Send a prompt to the ACP session.
async fn send_prompt(
&self,
process: &mut AcpProcess,
session_id: &str,
prompt_text: &str,
) -> Result<String> {
let params = SessionPromptParams {
session_id: session_id.to_string(),
prompt: vec![PromptItem {
item_type: "text".to_string(),
text: prompt_text.to_string(),
}],
};
let params_value =
serde_json::to_value(params).context("Failed to serialize session/prompt params")?;
let response = self
.send_json_rpc_request(process, "session/prompt", Some(params_value))
.await?;
// Parse response to extract the actual response text
// The response may contain a "response" field with text content
let response_text = response
.get("response")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.with_context(|| {
format!(
"Invalid session/prompt response: missing string field `response` for prompt {:?}: {:?}",
prompt_text, response
)
})?;
Ok(response_text)
}
fn process_is_running(process: &mut AcpProcess) -> bool {
matches!(process.child.try_wait(), Ok(None))
}
async fn initialize_fresh_process(&self) -> Result<AcpProcess> {
let mut new_process = self.start_process()?;
self.initialize_acp(&mut new_process).await?;
let session_id = self.create_session(&mut new_process).await?;
new_process.session_id = Some(session_id);
Ok(new_process)
}
async fn checkout_process_for_send(&self) -> Result<AcpProcess> {
let mut process_opt = {
let mut process_guard = self.process.lock().await;
process_guard.take()
};
let needs_restart = match process_opt.as_mut() {
Some(process) => !Self::process_is_running(process),
None => true,
};
if needs_restart {
process_opt = Some(self.initialize_fresh_process().await?);
}
process_opt.context("ACP process disappeared unexpectedly")
}
async fn restore_process(&self, process: Option<AcpProcess>) {
let mut process_guard = self.process.lock().await;
*process_guard = process;
}
}
#[async_trait]
impl Channel for AcpChannel {
    /// Channel identifier used for routing/config lookup.
    fn name(&self) -> &str {
        "acp"
    }
    /// Deliver a message to the ACP process as a `session/prompt`.
    ///
    /// Flow:
    /// 1. Serialize all sends behind `send_operation_lock` so only one prompt
    ///    is in flight at a time.
    /// 2. Silently drop (returning `Ok`) messages whose recipient is not on
    ///    the allowlist.
    /// 3. Check out the shared process (restarting if dead), send the prompt,
    ///    and restore the process only if it is still running afterwards.
    /// 4. On failure, discard the process and retry once with a fresh one
    ///    (`MAX_SEND_ATTEMPTS` = 2); the last error is returned if both fail.
    ///
    /// The model's reply is forwarded via `response_channel` when configured,
    /// otherwise it is only logged.
    async fn send(&self, message: &SendMessage) -> Result<()> {
        const MAX_SEND_ATTEMPTS: usize = 2;
        let _send_guard = self.send_operation_lock.lock().await;
        // Check if user is allowed
        if !self.is_user_allowed(&message.recipient) {
            tracing::warn!(
                "ACP: ignoring message from unauthorized user: {}",
                message.recipient
            );
            return Ok(());
        }
        // Strip tool call tags from outgoing messages
        let content = super::strip_tool_call_tags(&message.content);
        let mut last_error = None;
        for attempt in 0..MAX_SEND_ATTEMPTS {
            let mut process = self.checkout_process_for_send().await?;
            // NOTE(review): if this `?` fires, the checked-out process is
            // dropped rather than restored — the shared slot stays empty.
            let session_id = process
                .session_id
                .as_ref()
                .context("No active ACP session")?
                .clone();
            match self.send_prompt(&mut process, &session_id, &content).await {
                Ok(response) => {
                    // Only hand a live process back to the shared slot.
                    if Self::process_is_running(&mut process) {
                        self.restore_process(Some(process)).await;
                    } else {
                        self.restore_process(None).await;
                    }
                    // Send response back through response_channel if set
                    if let Some(response_channel) = &self.response_channel {
                        let response_message =
                            SendMessage::new(response, message.recipient.clone());
                        if let Err(e) = response_channel.send(&response_message).await {
                            tracing::warn!(
                                "Failed to send ACP response through response channel: {}",
                                e
                            );
                        }
                    } else {
                        // Log if no response channel configured
                        tracing::info!(
                            "ACP response ready (no response channel configured): {}",
                            response
                        );
                    }
                    return Ok(());
                }
                Err(error) => {
                    // Drop unhealthy process on failure and retry once with a fresh process.
                    // NOTE(review): dropping relies on the Child handle being
                    // cleaned up on drop — confirm start_process sets
                    // kill_on_drop (or equivalent) to avoid orphaned children.
                    self.restore_process(None).await;
                    if attempt + 1 < MAX_SEND_ATTEMPTS {
                        tracing::warn!(
                            "ACP prompt failed (attempt {}/{}), restarting ACP process: {}",
                            attempt + 1,
                            MAX_SEND_ATTEMPTS,
                            error
                        );
                    }
                    last_error = Some(error);
                }
            }
        }
        Err(last_error.unwrap_or_else(|| anyhow::anyhow!("ACP send failed with unknown error")))
    }
    /// Keep-alive listener; ACP never produces inbound channel messages here.
    async fn listen(&self, _tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
        // ACP is primarily a client-side protocol where we send prompts
        // and receive responses. For channel listening, we might need to
        // handle incoming messages from other sources that should trigger
        // ACP prompts.
        // Since ACP is more about sending commands to OpenCode rather than
        // listening for incoming messages, we implement a minimal listener
        // that just keeps the channel alive.
        loop {
            tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
        }
    }
    /// Report whether the managed ACP process is currently alive.
    ///
    /// Side effect: a dead process handle is removed from the shared slot so
    /// the next send spawns a fresh one.
    async fn health_check(&self) -> bool {
        let mut process_guard = self.process.lock().await;
        let Some(process) = process_guard.as_mut() else {
            return false;
        };
        let is_running = Self::process_is_running(process);
        if !is_running {
            *process_guard = None;
        }
        is_running
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::AcpConfig;

    /// Build an `AcpConfig` with the given allowlist; all other fields are
    /// left unset, which is all these tests need.
    fn test_config(allowed_users: Vec<String>) -> AcpConfig {
        AcpConfig {
            opencode_path: None,
            workdir: None,
            extra_args: vec![],
            allowed_users,
        }
    }

    #[test]
    fn acp_channel_name() {
        let channel = AcpChannel::new(test_config(vec![]));
        assert_eq!(channel.name(), "acp");
    }

    #[test]
    fn acp_channel_empty_allowlist_denies_all() {
        // An empty allowlist is deny-all, not allow-all.
        let channel = AcpChannel::new(test_config(vec![]));
        assert!(!channel.is_user_allowed("anyone"));
        assert!(!channel.is_user_allowed("user123"));
    }

    #[test]
    fn acp_channel_wildcard_allows_all() {
        let channel = AcpChannel::new(test_config(vec!["*".to_string()]));
        assert!(channel.is_user_allowed("anyone"));
        assert!(channel.is_user_allowed("user123"));
        assert!(channel.is_user_allowed(""));
    }

    #[test]
    fn acp_channel_specific_users() {
        let channel =
            AcpChannel::new(test_config(vec!["user1".to_string(), "user2".to_string()]));
        assert!(channel.is_user_allowed("user1"));
        assert!(channel.is_user_allowed("user2"));
        assert!(!channel.is_user_allowed("user3"));
        assert!(!channel.is_user_allowed("User1")); // case sensitive
        assert!(!channel.is_user_allowed("user"));
    }

    #[test]
    fn acp_channel_wildcard_and_specific() {
        // A wildcard anywhere in the list allows everyone.
        let channel = AcpChannel::new(test_config(vec!["user1".to_string(), "*".to_string()]));
        assert!(channel.is_user_allowed("user1"));
        assert!(channel.is_user_allowed("anyone"));
        assert!(channel.is_user_allowed("user2"));
    }

    #[test]
    fn acp_channel_empty_user_id() {
        let channel = AcpChannel::new(test_config(vec!["user1".to_string()]));
        assert!(!channel.is_user_allowed(""));
    }

    #[test]
    fn acp_channel_exact_match_not_substring() {
        let channel = AcpChannel::new(test_config(vec!["user123".to_string()]));
        assert!(channel.is_user_allowed("user123"));
        assert!(!channel.is_user_allowed("user12"));
        assert!(!channel.is_user_allowed("user1234"));
        assert!(!channel.is_user_allowed("user"));
    }

    #[test]
    fn acp_channel_case_sensitive() {
        let channel = AcpChannel::new(test_config(vec!["User".to_string()]));
        assert!(channel.is_user_allowed("User"));
        assert!(!channel.is_user_allowed("user"));
        assert!(!channel.is_user_allowed("USER"));
    }

    // JSON-RPC data structure tests

    #[test]
    fn jsonrpc_request_serialization() {
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: 42,
            method: "test".to_string(),
            params: Some(serde_json::json!({"key": "value"})),
        };
        let json = serde_json::to_string(&request).unwrap();
        let expected = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{"key":"value"}}"#;
        assert_eq!(json, expected);
    }

    #[test]
    fn jsonrpc_request_without_params() {
        // `params: None` must be omitted entirely, not serialized as null.
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: 1,
            method: "ping".to_string(),
            params: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        let expected = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
        assert_eq!(json, expected);
    }

    #[test]
    fn jsonrpc_response_deserialization() {
        let json = r#"{"jsonrpc":"2.0","id":42,"result":{"status":"ok"}}"#;
        let response: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, 42);
        match response.result_or_error {
            JsonRpcResultOrError::Result { result } => {
                assert_eq!(result, serde_json::json!({"status": "ok"}));
            }
            JsonRpcResultOrError::Error { .. } => panic!("Expected result, got error"),
        }
    }

    #[test]
    fn jsonrpc_error_deserialization() {
        let json = r#"{"jsonrpc":"2.0","id":42,"error":{"code":-32700,"message":"Parse error"}}"#;
        let response: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, 42);
        match response.result_or_error {
            JsonRpcResultOrError::Error { error } => {
                assert_eq!(error.code, -32700);
                assert_eq!(error.message, "Parse error");
                assert!(error.data.is_none());
            }
            JsonRpcResultOrError::Result { .. } => panic!("Expected error, got result"),
        }
    }

    #[test]
    fn initialize_params_serialization() {
        let params = InitializeParams {
            protocol_version: 1,
            client_capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "ZeroClaw".to_string(),
                title: "ZeroClaw ACP Client".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["protocol_version"], 1);
        assert_eq!(json["client_info"]["name"], "ZeroClaw");
    }

    #[test]
    fn session_new_params_serialization() {
        let params = SessionNewParams {
            cwd: "/tmp".to_string(),
            mcp_servers: vec![],
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["cwd"], "/tmp");
        assert_eq!(json["mcp_servers"], serde_json::json!([]));
    }

    #[test]
    fn session_prompt_params_serialization() {
        // `item_type` must serialize under the wire name "type".
        let params = SessionPromptParams {
            session_id: "session-123".to_string(),
            prompt: vec![PromptItem {
                item_type: "text".to_string(),
                text: "Hello".to_string(),
            }],
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["session_id"], "session-123");
        assert_eq!(json["prompt"][0]["type"], "text");
        assert_eq!(json["prompt"][0]["text"], "Hello");
    }

    #[test]
    fn acp_channel_set_response_channel() {
        use super::Channel;
        use crate::channels::traits::SendMessage;
        use std::sync::Arc;

        // Minimal no-op channel used only as a response sink.
        struct MockChannel;

        #[async_trait::async_trait]
        impl Channel for MockChannel {
            fn name(&self) -> &str {
                "mock"
            }
            async fn send(&self, _message: &SendMessage) -> Result<()> {
                Ok(())
            }
            async fn listen(
                &self,
                _tx: tokio::sync::mpsc::Sender<crate::channels::traits::ChannelMessage>,
            ) -> Result<()> {
                Ok(())
            }
            async fn health_check(&self) -> bool {
                true
            }
        }

        let mut channel = AcpChannel::new(test_config(vec![]));
        // The response channel field is private, so the test is simply that
        // wiring a sink in does not panic.
        channel.set_response_channel(Arc::new(MockChannel));
    }

    // Note: More comprehensive tests would require mocking the OpenCode process
    // which is beyond the scope of basic unit tests.

    /// Spawn a real child process wired up like an ACP server handle, for
    /// health-check tests.
    #[cfg(unix)]
    async fn spawn_test_process(command: &str, args: &[&str]) -> AcpProcess {
        let mut child = Command::new(command)
            .args(args)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            .spawn()
            .expect("failed to spawn test ACP process");
        let stdin = child.stdin.take().expect("test process stdin");
        let stdout = BufReader::new(child.stdout.take().expect("test process stdout"));
        AcpProcess {
            child,
            stdin,
            stdout,
            session_id: Some("test-session".to_string()),
            message_id: 0,
            pending_responses: VecDeque::new(),
        }
    }

    /// Kill and reap whatever process the channel currently holds, so tests
    /// don't leak children.
    #[cfg(unix)]
    async fn cleanup_test_process(channel: &AcpChannel) {
        let process = channel.process.lock().await.take();
        if let Some(mut process) = process {
            let _ = process.child.kill().await;
            let _ = process.child.wait().await;
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn acp_health_check_false_when_no_process() {
        let channel = AcpChannel::new(test_config(vec![]));
        assert!(!channel.health_check().await);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn acp_health_check_true_when_process_running() {
        let channel = AcpChannel::new(test_config(vec![]));
        let process = spawn_test_process("sh", &["-c", "sleep 5"]).await;
        {
            let mut guard = channel.process.lock().await;
            *guard = Some(process);
        }
        assert!(channel.health_check().await);
        cleanup_test_process(&channel).await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn acp_health_check_false_after_process_exit() {
        let channel = AcpChannel::new(test_config(vec![]));
        let process = spawn_test_process("sh", &["-c", "true"]).await;
        {
            let mut guard = channel.process.lock().await;
            *guard = Some(process);
        }
        // Give the short-lived child time to exit before probing.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(!channel.health_check().await);
    }
}
+162 -13
View File
@@ -174,7 +174,6 @@ struct LarkEvent {
#[derive(Debug, serde::Deserialize)]
struct LarkEventHeader {
event_type: String,
#[allow(dead_code)]
event_id: String,
}
@@ -217,6 +216,10 @@ const LARK_TOKEN_REFRESH_SKEW: Duration = Duration::from_secs(120);
const LARK_DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(7200);
/// Feishu/Lark API business code for expired/invalid tenant access token.
const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
/// Retention window for seen event/message dedupe keys.
const LARK_EVENT_DEDUP_TTL: Duration = Duration::from_secs(30 * 60);
/// Periodic cleanup interval for the dedupe cache.
const LARK_EVENT_DEDUP_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
const LARK_IMAGE_DOWNLOAD_FALLBACK_TEXT: &str =
"[Image message received but could not be downloaded]";
@@ -367,8 +370,10 @@ pub struct LarkChannel {
receive_mode: crate::config::schema::LarkReceiveMode,
/// Cached tenant access token
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
/// Dedup set for recently seen event/message keys across WS + webhook paths.
recent_event_keys: Arc<RwLock<HashMap<String, Instant>>>,
/// Last time we ran TTL cleanup over the dedupe cache.
recent_event_cleanup_at: Arc<RwLock<Instant>>,
}
impl LarkChannel {
@@ -412,7 +417,8 @@ impl LarkChannel {
platform,
receive_mode: crate::config::schema::LarkReceiveMode::default(),
tenant_token: Arc::new(RwLock::new(None)),
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
recent_event_keys: Arc::new(RwLock::new(HashMap::new())),
recent_event_cleanup_at: Arc::new(RwLock::new(Instant::now())),
}
}
@@ -520,6 +526,42 @@ impl LarkChannel {
}
}
fn dedupe_event_key(event_id: Option<&str>, message_id: Option<&str>) -> Option<String> {
let normalized_event = event_id.map(str::trim).filter(|value| !value.is_empty());
if let Some(event_id) = normalized_event {
return Some(format!("event:{event_id}"));
}
let normalized_message = message_id.map(str::trim).filter(|value| !value.is_empty());
normalized_message.map(|message_id| format!("message:{message_id}"))
}
async fn try_mark_event_key_seen(&self, dedupe_key: &str) -> bool {
let now = Instant::now();
if self.recent_event_keys.read().await.contains_key(dedupe_key) {
return false;
}
let should_cleanup = {
let last_cleanup = self.recent_event_cleanup_at.read().await;
now.duration_since(*last_cleanup) >= LARK_EVENT_DEDUP_CLEANUP_INTERVAL
};
let mut seen = self.recent_event_keys.write().await;
if seen.contains_key(dedupe_key) {
return false;
}
if should_cleanup {
seen.retain(|_, t| now.duration_since(*t) < LARK_EVENT_DEDUP_TTL);
let mut last_cleanup = self.recent_event_cleanup_at.write().await;
*last_cleanup = now;
}
seen.insert(dedupe_key.to_string(), now);
true
}
async fn fetch_image_marker(&self, image_key: &str) -> anyhow::Result<String> {
if image_key.trim().is_empty() {
anyhow::bail!("empty image_key");
@@ -880,17 +922,14 @@ impl LarkChannel {
let lark_msg = &recv.message;
// Dedup
{
let now = Instant::now();
let mut seen = self.ws_seen_ids.write().await;
// GC
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
if seen.contains_key(&lark_msg.message_id) {
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
if let Some(dedupe_key) = Self::dedupe_event_key(
Some(event.header.event_id.as_str()),
Some(lark_msg.message_id.as_str()),
) {
if !self.try_mark_event_key_seen(&dedupe_key).await {
tracing::debug!("Lark WS: duplicate event dropped ({dedupe_key})");
continue;
}
seen.insert(lark_msg.message_id.clone(), now);
}
// Decode content by type (mirrors clawdbot-feishu parsing)
@@ -1290,6 +1329,22 @@ impl LarkChannel {
Some(e) => e,
None => return messages,
};
let event_id = payload
.pointer("/header/event_id")
.and_then(|id| id.as_str())
.map(str::trim)
.filter(|id| !id.is_empty());
let message_id = event
.pointer("/message/message_id")
.and_then(|id| id.as_str())
.map(str::trim)
.filter(|id| !id.is_empty());
if let Some(dedupe_key) = Self::dedupe_event_key(event_id, message_id) {
if !self.try_mark_event_key_seen(&dedupe_key).await {
tracing::debug!("Lark webhook: duplicate event dropped ({dedupe_key})");
return messages;
}
}
let open_id = event
.pointer("/sender/sender_id/open_id")
@@ -2318,6 +2373,100 @@ mod tests {
assert_eq!(msgs[0].content, LARK_IMAGE_DOWNLOAD_FALLBACK_TEXT);
}
#[tokio::test]
async fn lark_parse_event_payload_async_dedupes_repeated_event_id() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
None,
vec!["*".into()],
true,
);
let payload = serde_json::json!({
"header": {
"event_type": "im.message.receive_v1",
"event_id": "evt_abc"
},
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_id": "om_first",
"message_type": "text",
"content": "{\"text\":\"hello\"}",
"chat_id": "oc_chat"
}
}
});
let first = ch.parse_event_payload_async(&payload).await;
let second = ch.parse_event_payload_async(&payload).await;
assert_eq!(first.len(), 1);
assert!(second.is_empty());
}
#[tokio::test]
async fn lark_parse_event_payload_async_dedupes_by_message_id_without_event_id() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
None,
vec!["*".into()],
true,
);
let payload = serde_json::json!({
"header": {
"event_type": "im.message.receive_v1"
},
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_id": "om_fallback",
"message_type": "text",
"content": "{\"text\":\"hello\"}",
"chat_id": "oc_chat"
}
}
});
let first = ch.parse_event_payload_async(&payload).await;
let second = ch.parse_event_payload_async(&payload).await;
assert_eq!(first.len(), 1);
assert!(second.is_empty());
}
#[tokio::test]
async fn try_mark_event_key_seen_cleans_up_expired_keys_periodically() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
None,
vec!["*".into()],
true,
);
{
let mut seen = ch.recent_event_keys.write().await;
seen.insert(
"event:stale".to_string(),
Instant::now() - LARK_EVENT_DEDUP_TTL - Duration::from_secs(5),
);
}
{
let mut cleanup_at = ch.recent_event_cleanup_at.write().await;
*cleanup_at =
Instant::now() - LARK_EVENT_DEDUP_CLEANUP_INTERVAL - Duration::from_secs(1);
}
assert!(ch.try_mark_event_key_seen("event:fresh").await);
let seen = ch.recent_event_keys.read().await;
assert!(!seen.contains_key("event:stale"));
assert!(seen.contains_key("event:fresh"));
}
#[test]
fn lark_parse_empty_text_skipped() {
let ch = LarkChannel::new(
+3 -8
View File
@@ -400,8 +400,7 @@ impl Channel for LinqChannel {
/// The signature is sent in `X-Webhook-Signature` (hex-encoded) and the
/// timestamp in `X-Webhook-Timestamp`. Reject timestamps older than 300s.
pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signature: &str) -> bool {
use hmac::{Hmac, Mac};
use sha2::Sha256;
use ring::hmac;
// Reject stale timestamps (>300s old)
if let Ok(ts) = timestamp.parse::<i64>() {
@@ -417,10 +416,6 @@ pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signatur
// Compute HMAC-SHA256 over "{timestamp}.{body}"
let message = format!("{timestamp}.{body}");
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(secret.as_bytes()) else {
return false;
};
mac.update(message.as_bytes());
let signature_hex = signature
.trim()
.strip_prefix("sha256=")
@@ -430,8 +425,8 @@ pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signatur
return false;
};
// Constant-time comparison via HMAC verify.
mac.verify_slice(&provided).is_ok()
let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes());
hmac::verify(&key, message.as_bytes(), &provided).is_ok()
}
#[cfg(test)]
+505 -10
View File
@@ -16,6 +16,7 @@
pub mod bluebubbles;
pub mod clawdtalk;
pub mod acp;
pub mod cli;
pub mod dingtalk;
pub mod discord;
@@ -45,6 +46,7 @@ pub mod whatsapp_storage;
#[cfg(feature = "whatsapp-web")]
pub mod whatsapp_web;
pub use acp::AcpChannel;
pub use bluebubbles::BlueBubblesChannel;
pub use clawdtalk::ClawdTalkChannel;
pub use cli::CliChannel;
@@ -77,6 +79,7 @@ use crate::agent::loop_::{
build_shell_policy_instructions, build_tool_instructions_from_specs,
run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
};
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
use crate::config::{Config, NonCliNaturalLanguageApprovalMode};
use crate::identity;
@@ -100,6 +103,8 @@ use tokio_util::sync::CancellationToken;
/// Per-sender conversation history for channel messages.
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
type ConversationLockMap =
Arc<tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>;
/// Maximum history messages to keep per sender.
const MAX_CHANNEL_HISTORY: usize = 50;
/// Minimum user-message length (in chars) for auto-save to memory.
@@ -130,6 +135,9 @@ const MEMORY_CONTEXT_ENTRY_MAX_CHARS: usize = 800;
const MEMORY_CONTEXT_MAX_CHARS: usize = 4_000;
const CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES: usize = 12;
const CHANNEL_HISTORY_COMPACT_CONTENT_CHARS: usize = 600;
const CHANNEL_CONTEXT_TOKEN_ESTIMATE_LIMIT: usize = 90_000;
const CHANNEL_CONTEXT_TOKEN_ESTIMATE_TARGET: usize = 80_000;
const CHANNEL_CONTEXT_MIN_KEEP_NON_SYSTEM_MESSAGES: usize = 10;
/// Guardrail for hook-modified outbound channel content.
const CHANNEL_HOOK_MAX_OUTBOUND_CHARS: usize = 20_000;
@@ -278,6 +286,9 @@ struct ChannelRuntimeContext {
max_tool_iterations: usize,
min_relevance_score: f64,
conversation_histories: ConversationHistoryMap,
conversation_locks: ConversationLockMap,
session_config: crate::config::AgentSessionConfig,
session_manager: Option<Arc<dyn SessionManager + Send + Sync>>,
provider_cache: ProviderCacheMap,
route_overrides: RouteSelectionMap,
api_key: Option<String>,
@@ -338,6 +349,10 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
}
}
fn assistant_memory_key(msg: &traits::ChannelMessage) -> String {
format!("assistant_resp_{}", conversation_memory_key(msg))
}
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
// QQ uses thread_ts as a passive-reply message id, not a thread identifier.
// Using it in history keys would reset context on every incoming message.
@@ -346,8 +361,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
}
// Include thread_ts for per-topic session isolation in forum groups
match &msg.thread_ts {
let channel = msg.channel.as_str();
match msg.thread_ts.as_deref().filter(|_| channel != "qq") {
Some(tid) => format!("{}_{}_{}", msg.channel, tid, msg.sender),
None if channel == "qq" => format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender),
None => format!("{}_{}", msg.channel, msg.sender),
}
}
@@ -961,10 +978,10 @@ fn resolved_default_provider(config: &Config) -> String {
}
fn resolved_default_model(config: &Config) -> String {
config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4.6".to_string())
crate::config::resolve_default_model_id(
config.default_model.as_deref(),
config.default_provider.as_deref(),
)
}
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
@@ -1731,6 +1748,40 @@ fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatM
}
}
fn estimated_message_tokens(message: &ChatMessage) -> usize {
(message.content.chars().count().saturating_add(2) / 3).saturating_add(4)
}
fn estimated_history_tokens(history: &[ChatMessage]) -> usize {
history.iter().map(estimated_message_tokens).sum()
}
fn trim_channel_prompt_history(history: &mut Vec<ChatMessage>) -> bool {
let mut total = estimated_history_tokens(history);
if total <= CHANNEL_CONTEXT_TOKEN_ESTIMATE_LIMIT {
return false;
}
let mut trimmed = false;
loop {
if total <= CHANNEL_CONTEXT_TOKEN_ESTIMATE_TARGET {
break;
}
let non_system = history.iter().filter(|m| m.role != "system").count();
if non_system <= CHANNEL_CONTEXT_MIN_KEEP_NON_SYSTEM_MESSAGES {
break;
}
let Some(idx) = history.iter().position(|m| m.role != "system") else {
break;
};
let removed = history.remove(idx);
total = total.saturating_sub(estimated_message_tokens(&removed));
trimmed = true;
}
trimmed
}
fn rollback_orphan_user_turn(
ctx: &ChannelRuntimeContext,
sender_key: &str,
@@ -2691,10 +2742,11 @@ async fn build_memory_context(
mem: &dyn Memory,
user_msg: &str,
min_relevance_score: f64,
session_id: Option<&str>,
) -> String {
let mut context = String::new();
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
let mut included = 0usize;
let mut used_chars = 0usize;
@@ -3158,6 +3210,9 @@ async fn process_channel_message(
msg: traits::ChannelMessage,
cancellation_token: CancellationToken,
) {
let sender_id = msg.sender.as_str();
let channel_name = msg.channel.as_str();
tracing::debug!(sender_id, channel_name, "process_message called");
if cancellation_token.is_cancelled() {
return;
}
@@ -3251,6 +3306,62 @@ or tune thresholds in config.",
}
let history_key = conversation_history_key(&msg);
let conversation_lock = {
let mut locks = ctx.conversation_locks.lock().await;
locks
.entry(history_key.clone())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
};
let _conversation_guard = conversation_lock.lock().await;
let mut session: Option<Session> = None;
if let Some(manager) = ctx.session_manager.as_ref() {
let session_id = resolve_session_id(
&ctx.session_config,
msg.sender.as_str(),
Some(msg.channel.as_str()),
);
tracing::debug!(session_id, "session_id resolved");
match manager.get_or_create(&session_id).await {
Ok(opened) => {
session = Some(opened);
}
Err(err) => {
tracing::warn!("Failed to open session: {err}");
}
}
}
if let Some(session) = session.as_ref() {
let should_seed = {
let histories = ctx
.conversation_histories
.lock()
.unwrap_or_else(|e| e.into_inner());
!histories.contains_key(&history_key)
};
if should_seed {
match session.get_history().await {
Ok(history) => {
tracing::debug!(history_len = history.len(), "session history loaded");
let filtered: Vec<ChatMessage> =
history
.into_iter()
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
.collect();
let mut histories = ctx
.conversation_histories
.lock()
.unwrap_or_else(|e| e.into_inner());
histories.entry(history_key.clone()).or_insert(filtered);
}
Err(err) => {
tracing::warn!("Failed to load session history: {err}");
}
}
}
}
// Try classification first, fall back to sender/default route
let route = classify_message_route(ctx.as_ref(), &msg.content)
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
@@ -3282,7 +3393,7 @@ or tune thresholds in config.",
&autosave_key,
&msg.content,
crate::memory::MemoryCategory::Conversation,
None,
Some(&history_key),
)
.await;
}
@@ -3338,6 +3449,7 @@ or tune thresholds in config.",
ctx.memory.as_ref(),
&msg.content,
ctx.min_relevance_score,
Some(&history_key),
)
.await;
if !memory_context.is_empty() {
@@ -3367,6 +3479,7 @@ or tune thresholds in config.",
));
let mut history = vec![ChatMessage::system(system_prompt)];
history.extend(prior_turns);
let _ = trim_channel_prompt_history(&mut history);
let use_streaming = target_channel
.as_ref()
.is_some_and(|ch| ch.supports_draft_updates());
@@ -3661,6 +3774,20 @@ or tune thresholds in config.",
&history_key,
ChatMessage::assistant(&history_response),
);
if ctx.auto_save_memory
&& delivered_response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
{
let assistant_key = assistant_memory_key(&msg);
let _ = ctx
.memory
.store(
&assistant_key,
&delivered_response,
crate::memory::MemoryCategory::Conversation,
None,
)
.await;
}
println!(
" 🤖 Reply ({}ms): {}",
started_at.elapsed().as_millis(),
@@ -3968,7 +4095,7 @@ async fn run_message_dispatch_loop(
}
}
process_channel_message(worker_ctx, msg, cancellation_token).await;
Box::pin(process_channel_message(worker_ctx, msg, cancellation_token)).await;
if interrupt_enabled {
let mut active = in_flight.lock().await;
@@ -4608,6 +4735,7 @@ fn collect_configured_channels(
sl.bot_token.clone(),
sl.app_token.clone(),
sl.channel_id.clone(),
sl.channel_ids.clone(),
sl.allowed_users.clone(),
)
.with_group_reply_policy(
@@ -4892,6 +5020,12 @@ fn collect_configured_channels(
});
}
if let Some(ref acp) = config.channels_config.acp {
channels.push(ConfiguredChannel {
display_name: "ACP",
channel: Arc::new(AcpChannel::new(acp.clone())),
});
}
channels
}
@@ -5335,6 +5469,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
.as_ref()
.is_some_and(|tg| tg.interrupt_on_new_message);
let session_manager = shared_session_manager(&config.agent.session, &config.workspace_dir)?
.map(|mgr| mgr as Arc<dyn SessionManager + Send + Sync>);
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name,
provider: Arc::clone(&provider),
@@ -5349,6 +5486,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
max_tool_iterations: config.agent.max_tool_iterations,
min_relevance_score: config.memory.min_relevance_score,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
session_config: config.agent.session.clone(),
session_manager,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: config.api_key.clone(),
@@ -5412,11 +5552,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
}
#[cfg(test)]
#[allow(clippy::large_futures)]
mod tests {
use super::*;
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::observability::NoopObserver;
use crate::providers::{ChatMessage, Provider};
use crate::security::AutonomyLevel;
use crate::tools::{Tool, ToolResult};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -5720,6 +5862,9 @@ mod tests {
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(histories)),
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -5774,6 +5919,9 @@ mod tests {
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -5831,6 +5979,9 @@ mod tests {
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(histories)),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -5888,6 +6039,11 @@ mod tests {
reactions_removed: tokio::sync::Mutex<Vec<(String, String, String)>>,
}
#[derive(Default)]
struct QqRecordingChannel {
sent_messages: tokio::sync::Mutex<Vec<String>>,
}
#[derive(Default)]
struct TelegramRecordingChannel {
sent_messages: tokio::sync::Mutex<Vec<String>>,
@@ -6044,6 +6200,36 @@ mod tests {
}
}
#[async_trait::async_trait]
impl Channel for QqRecordingChannel {
fn name(&self) -> &str {
"qq"
}
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
self.sent_messages
.lock()
.await
.push(format!("{}:{}", message.recipient, message.content));
Ok(())
}
async fn listen(
&self,
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
) -> anyhow::Result<()> {
Ok(())
}
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
Ok(())
}
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
Ok(())
}
}
struct SlowProvider {
delay: Duration,
}
@@ -6429,6 +6615,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6492,6 +6681,13 @@ BTC is currently around $65,000 based on latest tool output."#
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let autonomy_cfg = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
auto_approve: vec!["mock_price".to_string()],
..crate::config::AutonomyConfig::default()
};
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider),
@@ -6506,6 +6702,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6556,6 +6755,13 @@ BTC is currently around $65,000 based on latest tool output."#
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let autonomy_cfg = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
auto_approve: vec!["mock_price".to_string()],
..crate::config::AutonomyConfig::default()
};
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider),
@@ -6570,6 +6776,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6634,6 +6843,13 @@ BTC is currently around $65,000 based on latest tool output."#
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let autonomy_cfg = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
auto_approve: vec!["mock_price".to_string()],
..crate::config::AutonomyConfig::default()
};
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider),
@@ -6648,6 +6864,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6711,6 +6930,13 @@ BTC is currently around $65,000 based on latest tool output."#
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let autonomy_cfg = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
auto_approve: vec!["mock_price".to_string()],
..crate::config::AutonomyConfig::default()
};
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider),
@@ -6725,6 +6951,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6794,6 +7023,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6844,6 +7076,13 @@ BTC is currently around $65,000 based on latest tool output."#
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let autonomy_cfg = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
auto_approve: vec!["mock_price".to_string()],
..crate::config::AutonomyConfig::default()
};
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingAliasProvider),
@@ -6858,6 +7097,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -6931,6 +7173,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7032,6 +7277,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7184,6 +7432,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7296,6 +7547,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7403,6 +7657,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7495,6 +7752,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7586,6 +7846,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7683,6 +7946,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7692,7 +7958,7 @@ BTC is currently around $65,000 based on latest tool output."#
zeroclaw_dir: Some(temp.path().to_path_buf()),
..providers::ProviderRuntimeOptions::default()
},
workspace_dir: Arc::new(std::env::temp_dir()),
workspace_dir: Arc::new(temp.path().join("workspace")),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: false,
multimodal: crate::config::MultimodalConfig::default(),
@@ -7831,6 +8097,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -7925,6 +8194,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8072,6 +8344,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8189,6 +8464,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8286,6 +8564,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8405,6 +8686,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8522,6 +8806,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(route_overrides)),
api_key: None,
@@ -8598,6 +8885,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8688,6 +8978,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -8811,6 +9104,27 @@ BTC is currently around $65,000 based on latest tool output."#
assert_eq!(policy.outbound_leak_guard.sensitivity, 0.95);
}
#[tokio::test]
async fn load_runtime_defaults_from_config_file_uses_provider_fallback_when_model_missing() {
let temp = tempfile::TempDir::new().expect("temp dir");
let config_path = temp.path().join("config.toml");
let workspace_dir = temp.path().join("workspace");
std::fs::create_dir_all(&workspace_dir).expect("workspace dir");
let mut cfg = Config::default();
cfg.config_path = config_path.clone();
cfg.workspace_dir = workspace_dir;
cfg.default_provider = Some("openai".to_string());
cfg.default_model = None;
cfg.save().await.expect("save config");
let (defaults, _policy) = load_runtime_defaults_from_config_file(&config_path)
.await
.expect("runtime defaults");
assert_eq!(defaults.default_provider, "openai");
assert_eq!(defaults.model, "gpt-5.2");
}
#[tokio::test]
async fn maybe_apply_runtime_config_update_refreshes_autonomy_policy_and_excluded_tools() {
let temp = tempfile::TempDir::new().expect("temp dir");
@@ -8844,6 +9158,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: Some("http://127.0.0.1:11434".to_string()),
@@ -8964,6 +9281,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 12,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9029,6 +9349,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 3,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9206,6 +9529,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9293,6 +9619,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9392,6 +9721,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9473,6 +9805,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9539,6 +9874,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 10,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -9981,6 +10319,26 @@ BTC is currently around $65,000 based on latest tool output."#
);
}
#[test]
fn assistant_memory_key_is_namespaced_from_user_key() {
let msg = traits::ChannelMessage {
id: "msg_abc123".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "hello".into(),
channel: "slack".into(),
timestamp: 1,
thread_ts: None,
};
let user_key = conversation_memory_key(&msg);
let assistant_key = assistant_memory_key(&msg);
assert!(assistant_key.starts_with("assistant_resp_"));
assert!(assistant_key.ends_with(&user_key));
assert_ne!(assistant_key, user_key);
}
#[test]
fn conversation_history_key_ignores_qq_message_id_thread() {
let msg1 = traits::ChannelMessage {
@@ -10092,11 +10450,37 @@ BTC is currently around $65,000 based on latest tool output."#
.await
.unwrap();
let context = build_memory_context(&mem, "age", 0.0).await;
let context = build_memory_context(&mem, "age", 0.0, None).await;
assert!(context.contains("[Memory context]"));
assert!(context.contains("Age is 45"));
}
#[tokio::test]
async fn build_memory_context_respects_session_scope() {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store(
"session_a_fact",
"Session A remembers age 45",
MemoryCategory::Conversation,
Some("session-a"),
)
.await
.unwrap();
mem.store(
"session_b_fact",
"Session B remembers age 31",
MemoryCategory::Conversation,
Some("session-b"),
)
.await
.unwrap();
let session_a_context = build_memory_context(&mem, "age", 0.0, Some("session-a")).await;
assert!(session_a_context.contains("age 45"));
assert!(!session_a_context.contains("age 31"));
}
#[tokio::test]
async fn process_channel_message_restores_per_sender_history_on_follow_ups() {
let channel_impl = Arc::new(RecordingChannel::default());
@@ -10121,6 +10505,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -10190,6 +10577,102 @@ BTC is currently around $65,000 based on latest tool output."#
assert!(calls[1][3].1.contains("follow up"));
}
#[tokio::test]
async fn process_channel_message_qq_keeps_history_across_distinct_message_ids() {
let channel_impl = Arc::new(QqRecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let provider_impl = Arc::new(HistoryCaptureProvider::default());
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: provider_impl.clone(),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("test-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: false,
multimodal: crate::config::MultimodalConfig::default(),
hooks: None,
non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
query_classification: crate::config::QueryClassificationConfig::default(),
model_routes: Vec::new(),
approval_manager: Arc::new(ApprovalManager::from_config(
&crate::config::AutonomyConfig::default(),
)),
safety_heartbeat: None,
startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
});
process_channel_message(
runtime_ctx.clone(),
traits::ChannelMessage {
id: "msg-a".to_string(),
sender: "alice".to_string(),
reply_target: "group:1".to_string(),
content: "hello".to_string(),
channel: "qq".to_string(),
timestamp: 1,
thread_ts: Some("msg-1".to_string()),
},
CancellationToken::new(),
)
.await;
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-b".to_string(),
sender: "alice".to_string(),
reply_target: "group:1".to_string(),
content: "follow up".to_string(),
channel: "qq".to_string(),
timestamp: 2,
thread_ts: Some("msg-2".to_string()),
},
CancellationToken::new(),
)
.await;
let calls = provider_impl
.calls
.lock()
.unwrap_or_else(|e| e.into_inner());
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].len(), 2);
assert_eq!(calls[0][0].0, "system");
assert_eq!(calls[0][1].0, "user");
assert_eq!(calls[1].len(), 4);
assert_eq!(calls[1][0].0, "system");
assert_eq!(calls[1][1].0, "user");
assert_eq!(calls[1][2].0, "assistant");
assert_eq!(calls[1][3].0, "user");
assert!(calls[1][1].1.contains("hello"));
assert!(calls[1][2].1.contains("response-1"));
assert!(calls[1][3].1.contains("follow up"));
}
#[tokio::test]
async fn process_channel_message_enriches_current_turn_without_persisting_context() {
let channel_impl = Arc::new(RecordingChannel::default());
@@ -10213,6 +10696,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -10309,6 +10795,9 @@ BTC is currently around $65,000 based on latest tool output."#
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(histories)),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -11088,6 +11577,9 @@ BTC is currently around $65,000 based on latest tool output."#;
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
@@ -11161,6 +11653,9 @@ BTC is currently around $65,000 based on latest tool output."#;
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
conversation_locks: Default::default(),
session_config: crate::config::AgentSessionConfig::default(),
session_manager: None,
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
+6 -8
View File
@@ -1,7 +1,5 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use async_trait::async_trait;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use uuid::Uuid;
/// Nextcloud Talk channel in webhook mode.
@@ -247,6 +245,8 @@ pub fn verify_nextcloud_talk_signature(
body: &str,
signature: &str,
) -> bool {
use ring::hmac;
let random = random.trim();
if random.is_empty() {
tracing::warn!("Nextcloud Talk: missing X-Nextcloud-Talk-Random header");
@@ -265,17 +265,15 @@ pub fn verify_nextcloud_talk_signature(
};
let payload = format!("{random}{body}");
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(secret.as_bytes()) else {
return false;
};
mac.update(payload.as_bytes());
mac.verify_slice(&provided).is_ok()
let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes());
hmac::verify(&key, payload.as_bytes(), &provided).is_ok()
}
#[cfg(test)]
mod tests {
use super::*;
use hmac::{Hmac, Mac};
use sha2::Sha256;
fn make_channel() -> NextcloudTalkChannel {
NextcloudTalkChannel::new(
+303 -24
View File
@@ -3,39 +3,52 @@ use async_trait::async_trait;
use chrono::Utc;
use futures_util::{SinkExt, StreamExt};
use reqwest::header::HeaderMap;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Mutex;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio_tungstenite::tungstenite::Message as WsMessage;
#[derive(Clone)]
struct CachedSlackDisplayName {
display_name: String,
expires_at: Instant,
}
/// Slack channel — polls conversations.history via Web API
pub struct SlackChannel {
bot_token: String,
app_token: Option<String>,
channel_id: Option<String>,
channel_ids: Vec<String>,
allowed_users: Vec<String>,
mention_only: bool,
group_reply_allowed_sender_ids: Vec<String>,
user_display_name_cache: Mutex<HashMap<String, CachedSlackDisplayName>>,
}
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
const SLACK_HISTORY_DEFAULT_RETRY_AFTER_SECS: u64 = 1;
const SLACK_HISTORY_MAX_BACKOFF_SECS: u64 = 120;
const SLACK_HISTORY_MAX_JITTER_MS: u64 = 500;
const SLACK_USER_CACHE_TTL_SECS: u64 = 6 * 60 * 60;
impl SlackChannel {
pub fn new(
bot_token: String,
app_token: Option<String>,
channel_id: Option<String>,
channel_ids: Vec<String>,
allowed_users: Vec<String>,
) -> Self {
Self {
bot_token,
app_token,
channel_id,
channel_ids,
allowed_users,
mention_only: false,
group_reply_allowed_sender_ids: Vec::new(),
user_display_name_cache: Mutex::new(HashMap::new()),
}
}
@@ -111,6 +124,22 @@ impl SlackChannel {
Self::normalized_channel_id(self.channel_id.as_deref())
}
/// Resolve the effective channel scope:
/// explicit `channel_ids` list first, then single `channel_id`, otherwise wildcard discovery.
fn scoped_channel_ids(&self) -> Option<Vec<String>> {
let mut seen = HashSet::new();
let ids: Vec<String> = self
.channel_ids
.iter()
.filter_map(|entry| Self::normalized_channel_id(Some(entry)))
.filter(|id| seen.insert(id.clone()))
.collect();
if !ids.is_empty() {
return Some(ids);
}
self.configured_channel_id().map(|id| vec![id])
}
fn configured_app_token(&self) -> Option<String> {
self.app_token
.as_deref()
@@ -130,6 +159,137 @@ impl SlackChannel {
normalized
}
fn user_cache_ttl() -> Duration {
Duration::from_secs(SLACK_USER_CACHE_TTL_SECS)
}
fn sanitize_display_name(name: &str) -> Option<String> {
let trimmed = name.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}
/// Pick the best human-readable name from a Slack `users.info` payload.
///
/// Preference order: `profile.display_name`, `profile.display_name_normalized`,
/// `profile.real_name_normalized`, `profile.real_name`, then top-level
/// `user.real_name` and `user.name`. The first candidate that survives
/// `sanitize_display_name` (non-blank after trimming) wins.
/// Returns `None` when `user` is absent or every candidate is blank.
fn extract_user_display_name(payload: &serde_json::Value) -> Option<String> {
let user = payload.get("user")?;
let profile = user.get("profile");
// Candidates listed in preference order; order is load-bearing.
let candidates = [
profile
.and_then(|p| p.get("display_name"))
.and_then(|v| v.as_str()),
profile
.and_then(|p| p.get("display_name_normalized"))
.and_then(|v| v.as_str()),
profile
.and_then(|p| p.get("real_name_normalized"))
.and_then(|v| v.as_str()),
profile
.and_then(|p| p.get("real_name"))
.and_then(|v| v.as_str()),
user.get("real_name").and_then(|v| v.as_str()),
user.get("name").and_then(|v| v.as_str()),
];
for candidate in candidates.into_iter().flatten() {
if let Some(display_name) = Self::sanitize_display_name(candidate) {
return Some(display_name);
}
}
None
}
/// Return the cached display name for `user_id` if present and not expired.
///
/// On a miss or expiry any stale entry for the user is evicted; a poisoned
/// cache lock is treated as a miss rather than a panic.
fn cached_sender_display_name(&self, user_id: &str) -> Option<String> {
let now = Instant::now();
let Ok(mut cache) = self.user_display_name_cache.lock() else {
return None;
};
if let Some(entry) = cache.get(user_id) {
if now <= entry.expires_at {
return Some(entry.display_name.clone());
}
}
// Reached only on miss/expiry: drop the stale entry (no-op if absent).
cache.remove(user_id);
None
}
/// Store `display_name` for `user_id` with a fresh TTL.
/// A poisoned cache lock makes this a no-op; the next resolution
/// simply re-fetches from the API.
fn cache_sender_display_name(&self, user_id: &str, display_name: &str) {
    if let Ok(mut cache) = self.user_display_name_cache.lock() {
        let entry = CachedSlackDisplayName {
            display_name: display_name.to_string(),
            expires_at: Instant::now() + Self::user_cache_ttl(),
        };
        cache.insert(user_id.to_string(), entry);
    }
}
/// Look up a user's display name via Slack `users.info`.
///
/// Returns `None` on transport errors, non-2xx statuses, or an
/// `ok: false` API response; failures are logged at warn level and
/// the caller falls back to the raw user ID.
async fn fetch_sender_display_name(&self, user_id: &str) -> Option<String> {
let resp = match self
.http_client()
.get("https://slack.com/api/users.info")
.bearer_auth(&self.bot_token)
.query(&[("user", user_id)])
.send()
.await
{
Ok(response) => response,
Err(err) => {
tracing::warn!("Slack users.info request failed for {user_id}: {err}");
return None;
}
};
let status = resp.status();
// Read the body even on error responses so diagnostics can include it.
let body = resp
.text()
.await
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
if !status.is_success() {
// Route the raw body through the provider-layer sanitizer before logging.
let sanitized = crate::providers::sanitize_api_error(&body);
tracing::warn!("Slack users.info failed for {user_id} ({status}): {sanitized}");
return None;
}
// Unparseable JSON degrades to Value::Null, which fails extraction below.
let payload: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
// Slack signals API-level failure with `ok: false` plus an `error` code.
if payload.get("ok") == Some(&serde_json::Value::Bool(false)) {
let err = payload
.get("error")
.and_then(|e| e.as_str())
.unwrap_or("unknown");
tracing::warn!("Slack users.info returned error for {user_id}: {err}");
return None;
}
Self::extract_user_display_name(&payload)
}
/// Resolve a Slack user ID to a display name, preferring the TTL cache
/// and falling back to a `users.info` fetch (whose result is cached).
/// On any failure the trimmed user ID itself is returned; blank input
/// yields an empty string.
async fn resolve_sender_identity(&self, user_id: &str) -> String {
let user_id = user_id.trim();
if user_id.is_empty() {
return String::new();
}
if let Some(display_name) = self.cached_sender_display_name(user_id) {
return display_name;
}
if let Some(display_name) = self.fetch_sender_display_name(user_id).await {
self.cache_sender_display_name(user_id, &display_name);
return display_name;
}
user_id.to_string()
}
/// True for group-style Slack IDs: public channels ('C…') and
/// private groups ('G…'); DMs ('D…') and blank IDs are not groups.
fn is_group_channel_id(channel_id: &str) -> bool {
    channel_id.starts_with('C') || channel_id.starts_with('G')
}
@@ -327,7 +487,7 @@ impl SlackChannel {
&self,
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
bot_user_id: &str,
scoped_channel: Option<String>,
scoped_channels: Option<Vec<String>>,
) -> anyhow::Result<()> {
let mut last_ts_by_channel: HashMap<String, String> = HashMap::new();
@@ -425,8 +585,8 @@ impl SlackChannel {
if channel_id.is_empty() {
continue;
}
if let Some(ref configured_channel) = scoped_channel {
if channel_id != *configured_channel {
if let Some(ref configured_channels) = scoped_channels {
if !configured_channels.iter().any(|id| id == &channel_id) {
continue;
}
}
@@ -476,10 +636,11 @@ impl SlackChannel {
};
last_ts_by_channel.insert(channel_id.clone(), ts.to_string());
let sender = self.resolve_sender_identity(user).await;
let channel_msg = ChannelMessage {
id: format!("slack_{channel_id}_{ts}"),
sender: user.to_string(),
sender,
reply_target: channel_id.clone(),
content: normalized_text,
channel: "slack".to_string(),
@@ -695,11 +856,11 @@ impl Channel for SlackChannel {
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let bot_user_id = self.get_bot_user_id().await.unwrap_or_default();
let scoped_channel = self.configured_channel_id();
let scoped_channels = self.scoped_channel_ids();
if self.configured_app_token().is_some() {
tracing::info!("Slack channel listening in Socket Mode");
return self
.listen_socket_mode(tx, &bot_user_id, scoped_channel)
.listen_socket_mode(tx, &bot_user_id, scoped_channels)
.await;
}
@@ -707,19 +868,23 @@ impl Channel for SlackChannel {
let mut last_discovery = Instant::now();
let mut last_ts_by_channel: HashMap<String, String> = HashMap::new();
if let Some(ref channel_id) = scoped_channel {
tracing::info!("Slack channel listening on #{channel_id}...");
if let Some(ref channel_ids) = scoped_channels {
tracing::info!(
"Slack channel listening on {} configured channel(s): {}",
channel_ids.len(),
channel_ids.join(", ")
);
} else {
tracing::info!(
"Slack channel_id not set (or '*'); listening across all accessible channels."
"Slack channel_id/channel_ids not set (or wildcard only); listening across all accessible channels."
);
}
loop {
tokio::time::sleep(Duration::from_secs(3)).await;
let target_channels = if let Some(ref channel_id) = scoped_channel {
vec![channel_id.clone()]
let target_channels = if let Some(ref channel_ids) = scoped_channels {
channel_ids.clone()
} else {
if discovered_channels.is_empty()
|| last_discovery.elapsed() >= Duration::from_secs(60)
@@ -820,10 +985,11 @@ impl Channel for SlackChannel {
};
last_ts_by_channel.insert(channel_id.clone(), ts.to_string());
let sender = self.resolve_sender_identity(user).await;
let channel_msg = ChannelMessage {
id: format!("slack_{channel_id}_{ts}"),
sender: user.to_string(),
sender,
reply_target: channel_id.clone(),
content: normalized_text,
channel: "slack".to_string(),
@@ -860,26 +1026,32 @@ mod tests {
#[test]
fn slack_channel_name() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
assert_eq!(ch.name(), "slack");
}
#[test]
fn slack_channel_with_channel_id() {
let ch = SlackChannel::new("xoxb-fake".into(), None, Some("C12345".into()), vec![]);
let ch = SlackChannel::new(
"xoxb-fake".into(),
None,
Some("C12345".into()),
vec![],
vec![],
);
assert_eq!(ch.channel_id, Some("C12345".to_string()));
}
#[test]
fn slack_group_reply_policy_defaults_to_all_messages() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
assert!(!ch.mention_only);
assert!(ch.group_reply_allowed_sender_ids.is_empty());
}
#[test]
fn slack_group_reply_policy_applies_sender_overrides() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()])
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()])
.with_group_reply_policy(true, vec![" U111 ".into(), "U111".into(), "U222".into()]);
assert!(ch.mention_only);
@@ -906,16 +1078,55 @@ mod tests {
#[test]
fn configured_app_token_ignores_blank_values() {
let ch = SlackChannel::new("xoxb-fake".into(), Some(" ".into()), None, vec![]);
let ch = SlackChannel::new("xoxb-fake".into(), Some(" ".into()), None, vec![], vec![]);
assert_eq!(ch.configured_app_token(), None);
}
#[test]
fn configured_app_token_trims_value() {
let ch = SlackChannel::new("xoxb-fake".into(), Some(" xapp-123 ".into()), None, vec![]);
let ch = SlackChannel::new(
"xoxb-fake".into(),
Some(" xapp-123 ".into()),
None,
vec![],
vec![],
);
assert_eq!(ch.configured_app_token().as_deref(), Some("xapp-123"));
}
#[test]
// `channel_ids` wins over `channel_id`; order is preserved after normalization.
fn scoped_channel_ids_prefers_explicit_list() {
let ch = SlackChannel::new(
"xoxb-fake".into(),
None,
Some("C_SINGLE".into()),
vec!["C_LIST1".into(), "D_DM1".into()],
vec![],
);
assert_eq!(
ch.scoped_channel_ids(),
Some(vec!["C_LIST1".to_string(), "D_DM1".to_string()])
);
}
#[test]
// With an empty `channel_ids` list, the single `channel_id` is used.
fn scoped_channel_ids_falls_back_to_single_channel_id() {
let ch = SlackChannel::new(
"xoxb-fake".into(),
None,
Some("C_SINGLE".into()),
vec![],
vec![],
);
assert_eq!(ch.scoped_channel_ids(), Some(vec!["C_SINGLE".to_string()]));
}
#[test]
// No configured scope at all => None, meaning wildcard channel discovery.
fn scoped_channel_ids_returns_none_for_wildcard_mode() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
assert_eq!(ch.scoped_channel_ids(), None);
}
#[test]
fn is_group_channel_id_detects_channel_prefixes() {
assert!(SlackChannel::is_group_channel_id("C123"));
@@ -941,17 +1152,83 @@ mod tests {
#[test]
fn empty_allowlist_denies_everyone() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
assert!(!ch.is_user_allowed("U12345"));
assert!(!ch.is_user_allowed("anyone"));
}
#[test]
fn wildcard_allows_everyone() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
assert!(ch.is_user_allowed("U12345"));
}
#[test]
// `profile.display_name` outranks `profile.real_name` and `user.name`.
fn extract_user_display_name_prefers_profile_display_name() {
let payload = serde_json::json!({
"ok": true,
"user": {
"name": "fallback_name",
"profile": {
"display_name": "Display Name",
"real_name": "Real Name"
}
}
});
assert_eq!(
SlackChannel::extract_user_display_name(&payload).as_deref(),
Some("Display Name")
);
}
#[test]
// Blank profile fields are skipped; the top-level `user.name` is used.
fn extract_user_display_name_falls_back_to_username() {
let payload = serde_json::json!({
"ok": true,
"user": {
"name": "fallback_name",
"profile": {
"display_name": " ",
"real_name": ""
}
}
});
assert_eq!(
SlackChannel::extract_user_display_name(&payload).as_deref(),
Some("fallback_name")
);
}
#[test]
// An entry whose `expires_at` is already past must read as a miss.
fn cached_sender_display_name_returns_none_when_expired() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
{
let mut cache = ch.user_display_name_cache.lock().unwrap();
cache.insert(
"U123".to_string(),
CachedSlackDisplayName {
display_name: "Expired Name".to_string(),
expires_at: Instant::now() - Duration::from_secs(1),
},
);
}
assert_eq!(ch.cached_sender_display_name("U123"), None);
}
#[test]
// A freshly cached name is returned while its TTL is still valid.
fn cached_sender_display_name_returns_cached_value_when_valid() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
ch.cache_sender_display_name("U123", "Cached Name");
assert_eq!(
ch.cached_sender_display_name("U123").as_deref(),
Some("Cached Name")
);
}
#[test]
fn normalize_incoming_content_requires_mention_when_enabled() {
assert!(SlackChannel::normalize_incoming_content("hello", true, "U_BOT").is_none());
@@ -975,6 +1252,7 @@ mod tests {
"xoxb-fake".into(),
None,
None,
vec![],
vec!["U111".into(), "U222".into()],
);
assert!(ch.is_user_allowed("U111"));
@@ -984,20 +1262,20 @@ mod tests {
#[test]
fn allowlist_exact_match_not_substring() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
assert!(!ch.is_user_allowed("U1111"));
assert!(!ch.is_user_allowed("U11"));
}
#[test]
fn allowlist_empty_user_id() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
assert!(!ch.is_user_allowed(""));
}
#[test]
fn allowlist_case_sensitive() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
assert!(ch.is_user_allowed("U111"));
assert!(!ch.is_user_allowed("u111"));
}
@@ -1008,6 +1286,7 @@ mod tests {
"xoxb-fake".into(),
None,
None,
vec![],
vec!["U111".into(), "*".into()],
);
assert!(ch.is_user_allowed("U111"));
+3 -3
View File
@@ -765,7 +765,7 @@ impl TelegramChannel {
}
fn log_poll_transport_error(sanitized: &str, consecutive_failures: u32) {
if consecutive_failures >= 6 && consecutive_failures % 6 == 0 {
if consecutive_failures >= 6 && consecutive_failures.is_multiple_of(6) {
tracing::warn!(
"Telegram poll transport error persists (consecutive={}): {}",
consecutive_failures,
@@ -3109,8 +3109,8 @@ impl Channel for TelegramChannel {
let thread_id = parsed_thread_id.or(thread_ts);
let raw_args = arguments.to_string();
let args_preview = if raw_args.len() > 260 {
format!("{}...", &raw_args[..260])
let args_preview = if raw_args.chars().count() > 260 {
crate::util::truncate_with_ellipsis(&raw_args, 260)
} else {
raw_args
};
+6 -3
View File
@@ -4,9 +4,11 @@ pub mod traits;
#[allow(unused_imports)]
pub use schema::{
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
AgentConfig, AgentsIpcConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig,
BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config,
build_runtime_proxy_client_with_timeouts, default_model_fallback_for_provider,
resolve_default_model_id, runtime_proxy_config, set_runtime_proxy_config, AgentConfig,
AgentSessionBackend, AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig, AuditConfig,
AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
ClassificationRule, ComposioConfig, Config,
CoordinationConfig, CostConfig, CronConfig, DelegateAgentConfig, DiscordConfig,
DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig,
FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
@@ -23,6 +25,7 @@ pub use schema::{
StorageProviderSection, StreamMode, SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig,
TunnelConfig, UrlAccessConfig, WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy,
WasmRuntimeConfig, WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
DEFAULT_MODEL_FALLBACK,
};
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
+485 -19
View File
@@ -1,5 +1,7 @@
use crate::config::traits::ChannelConfig;
use crate::providers::{is_glm_alias, is_zai_alias};
use crate::providers::{
canonical_china_provider_name, is_glm_alias, is_qwen_oauth_alias, is_zai_alias,
};
use crate::security::{AutonomyLevel, DomainMatcher};
use anyhow::{Context, Result};
use directories::UserDirs;
@@ -14,6 +16,100 @@ use tokio::fs::File;
use tokio::fs::{self, OpenOptions};
use tokio::io::AsyncWriteExt;
/// Default fallback model when none is configured. Uses a format compatible with
/// OpenRouter and other multi-provider gateways. For Anthropic direct API, this
/// model ID will be normalized by the provider layer.
/// Also the terminal fallback of `default_model_fallback_for_provider`.
pub const DEFAULT_MODEL_FALLBACK: &str = "anthropic/claude-sonnet-4.6";
/// Map assorted provider aliases onto the canonical names used by the
/// model-default lookup table; unknown names pass through unchanged.
fn canonical_provider_for_model_defaults(provider_name: &str) -> String {
    // China-market aliases are collapsed first; "doubao" models are
    // served through Volcengine, so remap that one explicitly.
    if let Some(canonical) = canonical_china_provider_name(provider_name) {
        if canonical == "doubao" {
            return "volcengine".to_string();
        }
        return canonical.to_string();
    }
    let mapped: &str = match provider_name {
        "grok" => "xai",
        "together" => "together-ai",
        "google" | "google-gemini" => "gemini",
        "github-copilot" => "copilot",
        "openai_codex" | "codex" => "openai-codex",
        "kimi_coding" | "kimi_for_coding" => "kimi-code",
        "nvidia-nim" | "build.nvidia.com" => "nvidia",
        "aws-bedrock" => "bedrock",
        "llama.cpp" => "llamacpp",
        other => other,
    };
    mapped.to_string()
}
/// Returns a provider-aware fallback model ID when `default_model` is missing.
///
/// The provider name is normalized (default "openrouter", lowercased,
/// '_' folded to '-'), special-cased aliases are resolved, and the result
/// is looked up in a per-provider table; unknown providers fall back to
/// `DEFAULT_MODEL_FALLBACK`.
pub fn default_model_fallback_for_provider(provider_name: Option<&str>) -> &'static str {
let normalized_provider = provider_name
.unwrap_or("openrouter")
.trim()
.to_ascii_lowercase()
.replace('_', "-");
// Special-cased before canonicalization: this alias has its own model.
if normalized_provider == "qwen-coding-plan" {
return "qwen3-coder-plus";
}
// Qwen OAuth aliases map to the "qwen-code" defaults; everything else
// goes through the shared alias canonicalization.
let canonical_provider = if is_qwen_oauth_alias(&normalized_provider) {
"qwen-code".to_string()
} else {
canonical_provider_for_model_defaults(&normalized_provider)
};
// Per-provider default-model table.
match canonical_provider.as_str() {
"anthropic" => "claude-sonnet-4-5-20250929",
"openai" => "gpt-5.2",
"openai-codex" => "gpt-5-codex",
"venice" => "zai-org-glm-5",
"groq" => "llama-3.3-70b-versatile",
"mistral" => "mistral-large-latest",
"deepseek" => "deepseek-chat",
"xai" => "grok-4-1-fast-reasoning",
"perplexity" => "sonar-pro",
"fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct",
"novita" => "minimax/minimax-m2.5",
"together-ai" => "meta-llama/Llama-3.3-70B-Instruct-Turbo",
"cohere" => "command-a-03-2025",
"moonshot" => "kimi-k2.5",
"hunyuan" => "hunyuan-t1-latest",
"glm" | "zai" => "glm-5",
"minimax" => "MiniMax-M2.5",
"qwen" => "qwen-plus",
"volcengine" => "doubao-1-5-pro-32k-250115",
"siliconflow" => "Pro/zai-org/GLM-4.7",
"qwen-code" => "qwen3-coder-plus",
"ollama" => "llama3.2",
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF",
// Local/gateway runtimes expose a single "default" model endpoint.
"sglang" | "vllm" | "osaurus" | "copilot" => "default",
"gemini" => "gemini-2.5-pro",
"kimi-code" => "kimi-for-coding",
"bedrock" => "anthropic.claude-sonnet-4-5-20250929-v1:0",
"nvidia" => "meta/llama-3.3-70b-instruct",
_ => DEFAULT_MODEL_FALLBACK,
}
}
/// Resolves the model ID used by runtime components.
/// Preference order:
/// 1) Explicit configured model (if non-empty)
/// 2) Provider-aware fallback
pub fn resolve_default_model_id(
    default_model: Option<&str>,
    provider_name: Option<&str>,
) -> String {
    match default_model.map(str::trim) {
        Some(explicit) if !explicit.is_empty() => explicit.to_string(),
        _ => default_model_fallback_for_provider(provider_name).to_string(),
    }
}
const SUPPORTED_PROXY_SERVICE_KEYS: &[&str] = &[
"provider.anthropic",
"provider.compatible",
@@ -60,6 +156,8 @@ const SUPPORTED_PROXY_SERVICE_SELECTORS: &[&str] = &[
static RUNTIME_PROXY_CONFIG: OnceLock<RwLock<ProxyConfig>> = OnceLock::new();
static RUNTIME_PROXY_CLIENT_CACHE: OnceLock<RwLock<HashMap<String, reqwest::Client>>> =
OnceLock::new();
// Canonical fallback provider/model, used by `Config::default()` and by
// the legacy-env and profile-override checks later in this file.
const DEFAULT_PROVIDER_NAME: &str = "openrouter";
const DEFAULT_MODEL_NAME: &str = "anthropic/claude-sonnet-4.6";
// ── Top-level config ──────────────────────────────────────────────
@@ -305,6 +403,12 @@ pub struct ModelProviderConfig {
/// Provider protocol variant ("responses" or "chat_completions").
#[serde(default)]
pub wire_api: Option<String>,
/// Optional profile-scoped default model.
#[serde(default, alias = "model")]
pub default_model: Option<String>,
/// Optional profile-scoped API key.
#[serde(default)]
pub api_key: Option<String>,
/// If true, load OpenAI auth material (OPENAI_API_KEY or ~/.codex/auth.json).
#[serde(default)]
pub requires_openai_auth: bool,
@@ -416,6 +520,7 @@ impl std::fmt::Debug for Config {
self.channels_config.dingtalk.is_some(),
self.channels_config.napcat.is_some(),
self.channels_config.qq.is_some(),
self.channels_config.acp.is_some(),
self.channels_config.nostr.is_some(),
self.channels_config.clawdtalk.is_some(),
]
@@ -726,6 +831,8 @@ pub struct AgentConfig {
/// When true: bootstrap_max_chars=6000, rag_chunk_limit=2. Use for 13B or smaller models.
#[serde(default)]
pub compact_context: bool,
#[serde(default)]
pub session: AgentSessionConfig,
/// Maximum tool-call loop turns per user message. Default: `20`.
/// Setting to `0` falls back to the safe default of `20`.
#[serde(default = "default_agent_max_tool_iterations")]
@@ -770,6 +877,47 @@ pub struct AgentConfig {
pub safety_heartbeat_turn_interval: usize,
}
/// Storage backend for agent session persistence (serialized lowercase).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AgentSessionBackend {
/// In-memory session storage.
Memory,
/// SQLite-backed session storage.
Sqlite,
/// Session persistence disabled (the serde default).
None,
}
/// How session IDs are derived (serialized kebab-case, e.g. "per-sender").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum AgentSessionStrategy {
/// One session per sender per channel (the serde default).
PerSender,
/// One shared session per channel.
PerChannel,
/// A single shared "main" session.
Main,
}
/// Session persistence configuration (`[agent.session]` section).
/// All fields carry serde defaults, so the section may be omitted entirely
/// (see `impl Default for AgentSessionConfig`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentSessionConfig {
/// Session backend to use. Options: "memory", "sqlite", "none".
/// Default: "none" (no persistence).
/// Set to "none" to disable session persistence entirely.
#[serde(default = "default_agent_session_backend")]
pub backend: AgentSessionBackend,
/// Strategy for resolving session IDs. Options: "per-sender", "per-channel", "main".
/// Default: "per-sender" (each user gets a unique session per channel).
#[serde(default = "default_agent_session_strategy")]
pub strategy: AgentSessionStrategy,
/// Time-to-live for sessions in seconds.
/// Default: 3600 (1 hour).
#[serde(default = "default_agent_session_ttl_seconds")]
pub ttl_seconds: u64,
/// Maximum number of messages to retain per session.
/// Default: 50.
#[serde(default = "default_agent_session_max_messages")]
pub max_messages: usize,
}
// Serde default: safe cap on tool-call loop turns per user message
// (also used as the fallback when the configured value is 0).
fn default_agent_max_tool_iterations() -> usize {
20
}
@@ -782,6 +930,22 @@ fn default_agent_tool_dispatcher() -> String {
"auto".into()
}
// Serde default helpers for `AgentSessionConfig`.
fn default_agent_session_backend() -> AgentSessionBackend {
AgentSessionBackend::None
}
fn default_agent_session_strategy() -> AgentSessionStrategy {
AgentSessionStrategy::PerSender
}
fn default_agent_session_ttl_seconds() -> u64 {
3600
}
// Mirrors the agent-level history cap so session retention and
// history limits agree by default.
fn default_agent_session_max_messages() -> usize {
default_agent_max_history_messages()
}
// Serde default for the loop-detection no-progress threshold
// (consumer is outside this view).
fn default_loop_detection_no_progress_threshold() -> usize {
3
}
@@ -806,6 +970,7 @@ impl Default for AgentConfig {
fn default() -> Self {
Self {
compact_context: true,
session: AgentSessionConfig::default(),
max_tool_iterations: default_agent_max_tool_iterations(),
max_history_messages: default_agent_max_history_messages(),
parallel_tools: false,
@@ -819,6 +984,17 @@ impl Default for AgentConfig {
}
}
/// Defaults mirror the per-field serde defaults, so programmatic
/// construction matches a config file with `[agent.session]` omitted.
impl Default for AgentSessionConfig {
fn default() -> Self {
Self {
backend: default_agent_session_backend(),
strategy: default_agent_session_strategy(),
ttl_seconds: default_agent_session_ttl_seconds(),
max_messages: default_agent_session_max_messages(),
}
}
}
/// Skills loading configuration (`[skills]` section).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
@@ -3874,6 +4050,8 @@ impl<T: ChannelConfig> crate::config::traits::ConfigHandle for ConfigWrapper<T>
pub struct ChannelsConfig {
/// Enable the CLI interactive channel. Default: `true`.
pub cli: bool,
/// ACP (Agent Client Protocol) channel configuration.
pub acp: Option<AcpConfig>,
/// Telegram bot channel configuration.
pub telegram: Option<TelegramConfig>,
/// Discord bot channel configuration.
@@ -4021,6 +4199,10 @@ impl ChannelsConfig {
Box::new(ConfigWrapper::new(self.nostr.as_ref())),
self.nostr.is_some(),
),
(
Box::new(ConfigWrapper::new(self.acp.as_ref())),
self.acp.is_some(),
),
(
Box::new(ConfigWrapper::new(self.clawdtalk.as_ref())),
self.clawdtalk.is_some(),
@@ -4046,6 +4228,7 @@ impl Default for ChannelsConfig {
fn default() -> Self {
Self {
cli: true,
acp: None,
telegram: None,
discord: None,
slack: None,
@@ -4267,7 +4450,12 @@ pub struct SlackConfig {
pub app_token: Option<String>,
/// Optional channel ID to restrict the bot to a single channel.
/// Omit (or set `"*"`) to listen across all accessible channels.
/// Ignored when `channel_ids` is non-empty.
pub channel_id: Option<String>,
/// Explicit list of channel/DM IDs to listen on simultaneously.
/// Takes precedence over `channel_id`. Empty = fall back to `channel_id`.
#[serde(default)]
pub channel_ids: Vec<String>,
/// Allowed Slack user IDs. Empty = deny all.
#[serde(default)]
pub allowed_users: Vec<String>,
@@ -5654,9 +5842,9 @@ impl Default for Config {
config_path: zeroclaw_dir.join("config.toml"),
api_key: None,
api_url: None,
default_provider: Some("openrouter".to_string()),
default_provider: Some(DEFAULT_PROVIDER_NAME.to_string()),
provider_api: None,
default_model: Some("anthropic/claude-sonnet-4.6".to_string()),
default_model: Some(DEFAULT_MODEL_NAME.to_string()),
model_providers: HashMap::new(),
provider: ProviderConfig::default(),
default_temperature: 0.7,
@@ -5848,7 +6036,10 @@ pub(crate) async fn persist_active_workspace_config_dir(config_dir: &Path) -> Re
);
}
#[cfg(unix)]
sync_directory(&default_config_dir).await?;
#[cfg(not(unix))]
sync_directory(&default_config_dir)?;
Ok(())
}
@@ -6614,6 +6805,10 @@ impl Config {
config.workspace_dir = workspace_dir;
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
for (profile_name, profile) in config.model_providers.iter_mut() {
let secret_path = format!("config.model_providers.{profile_name}.api_key");
decrypt_optional_secret(&store, &mut profile.api_key, &secret_path)?;
}
decrypt_optional_secret(
&store,
&mut config.transcription.api_key,
@@ -6850,6 +7045,18 @@ impl Config {
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToString::to_string);
let profile_default_model = profile
.default_model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToString::to_string);
let profile_api_key = profile
.api_key
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToString::to_string);
if self
.api_url
@@ -6862,6 +7069,30 @@ impl Config {
}
}
if self
.api_key
.as_deref()
.map(str::trim)
.is_none_or(|value| value.is_empty())
{
if let Some(profile_api_key) = profile_api_key {
self.api_key = Some(profile_api_key);
}
}
if let Some(profile_default_model) = profile_default_model {
let can_apply_profile_model =
self.default_model
.as_deref()
.map(str::trim)
.is_none_or(|value| {
value.is_empty() || value.eq_ignore_ascii_case(DEFAULT_MODEL_NAME)
});
if can_apply_profile_model {
self.default_model = Some(profile_default_model);
}
}
if profile.requires_openai_auth
&& self
.api_key
@@ -6908,6 +7139,10 @@ impl Config {
/// Called after TOML deserialization and env-override application to catch
/// obviously invalid values early instead of failing at arbitrary runtime points.
pub fn validate(&self) -> Result<()> {
if let Some(acp) = &self.channels_config.acp {
acp.validate()?;
}
// Gateway
if self.gateway.host.trim().is_empty() {
anyhow::bail!("gateway.host must not be empty");
@@ -7336,6 +7571,24 @@ impl Config {
}
}
if let Some(default_hint) = self
.default_model
.as_deref()
.and_then(|model| model.strip_prefix("hint:"))
.map(str::trim)
.filter(|hint| !hint.is_empty())
{
if !self
.model_routes
.iter()
.any(|route| route.hint.trim() == default_hint)
{
anyhow::bail!(
"default_model uses hint '{default_hint}', but no matching [[model_routes]] entry exists"
);
}
}
if self
.provider
.transport
@@ -7541,7 +7794,9 @@ impl Config {
} else if let Ok(provider) = std::env::var("PROVIDER") {
let should_apply_legacy_provider =
self.default_provider.as_deref().map_or(true, |configured| {
configured.trim().eq_ignore_ascii_case("openrouter")
configured
.trim()
.eq_ignore_ascii_case(DEFAULT_PROVIDER_NAME)
});
if should_apply_legacy_provider && !provider.is_empty() {
self.default_provider = Some(provider);
@@ -8125,6 +8380,10 @@ impl Config {
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
for (profile_name, profile) in config_to_save.model_providers.iter_mut() {
let secret_path = format!("config.model_providers.{profile_name}.api_key");
encrypt_optional_secret(&store, &mut profile.api_key, &secret_path)?;
}
encrypt_optional_secret(
&store,
&mut config_to_save.transcription.api_key,
@@ -8291,7 +8550,10 @@ impl Config {
})?;
}
#[cfg(unix)]
sync_directory(parent_dir).await?;
#[cfg(not(unix))]
sync_directory(parent_dir)?;
if had_existing_config {
let _ = fs::remove_file(&backup_path).await;
@@ -8301,32 +8563,83 @@ impl Config {
}
}
#[cfg(unix)]
async fn sync_directory(path: &Path) -> Result<()> {
#[cfg(unix)]
{
let dir = File::open(path)
.await
.with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
dir.sync_all()
.await
.with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
Ok(())
}
let dir = File::open(path)
.await
.with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
dir.sync_all()
.await
.with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
Ok(())
}
#[cfg(not(unix))]
fn sync_directory(path: &Path) -> Result<()> {
let _ = path;
Ok(())
}
/// ACP (Agent Client Protocol) channel configuration.
///
/// Enables ZeroClaw to act as an ACP client, connecting to an OpenCode ACP server
/// via `opencode acp` command for JSON-RPC 2.0 communication over stdio.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AcpConfig {
/// OpenCode binary path (default: "opencode").
/// NOTE: the serde default applies only when the key is absent;
/// an explicit `null` deserializes to `None`.
#[serde(default = "default_acp_opencode_path")]
pub opencode_path: Option<String>,
/// Working directory for OpenCode process.
pub workdir: Option<String>,
/// Additional arguments to pass to `opencode acp`.
#[serde(default)]
pub extra_args: Vec<String>,
/// Allowed user identifiers (empty = deny all, "*" = allow all).
#[serde(default)]
pub allowed_users: Vec<String>,
}
// Serde default for `AcpConfig::opencode_path`.
fn default_acp_opencode_path() -> Option<String> {
Some("opencode".to_string())
}
impl AcpConfig {
fn validate(&self) -> Result<()> {
if self
.opencode_path
.as_deref()
.is_some_and(|path| path.trim().is_empty())
{
anyhow::bail!("channels_config.acp.opencode_path must not be empty when set");
}
if self
.workdir
.as_deref()
.is_some_and(|dir| dir.trim().is_empty())
{
anyhow::bail!("channels_config.acp.workdir must not be empty when set");
}
#[cfg(not(unix))]
{
let _ = path;
Ok(())
}
}
/// Display metadata consumed by channel name/presence reporting
/// (see `name_and_presence` in `config::mod`).
impl ChannelConfig for AcpConfig {
fn name() -> &'static str {
"ACP"
}
fn desc() -> &'static str {
"Agent Client Protocol channel for OpenCode integration"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use tempfile::TempDir;
use tokio::sync::{Mutex, MutexGuard};
use tokio::test;
use tokio_stream::wrappers::ReadDirStream;
@@ -8518,7 +8831,7 @@ mod tests {
#[cfg(unix)]
#[test]
async fn save_sets_config_permissions_on_new_file() {
let temp = TempDir::new().expect("temp dir");
let temp = tempfile::TempDir::new().expect("temp dir");
let config_path = temp.path().join("config.toml");
let workspace_dir = temp.path().join("workspace");
@@ -8832,6 +9145,7 @@ ws_url = "ws://127.0.0.1:3002"
goal_loop: GoalLoopConfig::default(),
channels_config: ChannelsConfig {
cli: true,
acp: None,
telegram: Some(TelegramConfig {
bot_token: "123:ABC".into(),
allowed_users: vec!["user1".into()],
@@ -9199,7 +9513,10 @@ tool_dispatcher = "xml"
));
fs::create_dir_all(&dir).await.unwrap();
#[cfg(unix)]
sync_directory(&dir).await.unwrap();
#[cfg(not(unix))]
sync_directory(&dir).unwrap();
let _ = fs::remove_dir_all(&dir).await;
}
@@ -9763,6 +10080,7 @@ allowed_users = ["@ops:matrix.org"]
async fn channels_config_with_imessage_and_matrix() {
let c = ChannelsConfig {
cli: true,
acp: None,
telegram: None,
discord: None,
slack: None,
@@ -9834,6 +10152,7 @@ allowed_users = ["@ops:matrix.org"]
async fn slack_config_deserializes_without_allowed_users() {
let json = r#"{"bot_token":"xoxb-tok"}"#;
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert!(parsed.channel_ids.is_empty());
assert!(parsed.allowed_users.is_empty());
assert_eq!(
parsed.effective_group_reply_mode(),
@@ -9845,6 +10164,7 @@ allowed_users = ["@ops:matrix.org"]
async fn slack_config_deserializes_with_allowed_users() {
let json = r#"{"bot_token":"xoxb-tok","allowed_users":["U111"]}"#;
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert!(parsed.channel_ids.is_empty());
assert_eq!(parsed.allowed_users, vec!["U111"]);
}
@@ -9866,6 +10186,7 @@ bot_token = "xoxb-tok"
channel_id = "C123"
"#;
let parsed: SlackConfig = toml::from_str(toml_str).unwrap();
assert!(parsed.channel_ids.is_empty());
assert!(parsed.allowed_users.is_empty());
assert_eq!(parsed.channel_id.as_deref(), Some("C123"));
assert_eq!(
@@ -10045,6 +10366,7 @@ channel_id = "C123"
async fn channels_config_with_whatsapp() {
let c = ChannelsConfig {
cli: true,
acp: None,
telegram: None,
discord: None,
slack: None,
@@ -10642,6 +10964,8 @@ model = "gpt-5.3-codex"
name = "sub2api"
base_url = "https://api.tonsof.blue/v1"
wire_api = "responses"
model = "gpt-5.3-codex"
api_key = "profile-key"
requires_openai_auth = true
"#;
@@ -10653,6 +10977,8 @@ requires_openai_auth = true
.get("sub2api")
.expect("profile should exist");
assert_eq!(profile.wire_api.as_deref(), Some("responses"));
assert_eq!(profile.default_model.as_deref(), Some("gpt-5.3-codex"));
assert_eq!(profile.api_key.as_deref(), Some("profile-key"));
assert!(profile.requires_openai_auth);
}
@@ -10813,6 +11139,67 @@ provider_api = "not-a-real-mode"
.contains("model_routes[0].max_tokens must be greater than 0"));
}
// A `default_model` of the form "hint:<name>" is only valid when some
// model_routes entry declares that hint; with only a "fast" route present,
// validate() must fail and name the unmatched hint in its message.
#[test]
async fn default_model_hint_requires_matching_model_route() {
    let mut config = Config::default();
    config.default_model = Some("hint:reasoning".to_string());
    config.model_routes = vec![ModelRouteConfig {
        hint: "fast".to_string(),
        provider: "openrouter".to_string(),
        model: "openai/gpt-5.2".to_string(),
        max_tokens: None,
        api_key: None,
        transport: None,
    }];
    let err = config
        .validate()
        .expect_err("default_model hint without matching route should fail");
    assert!(err
        .to_string()
        .contains("default_model uses hint 'reasoning'"));
}
// Happy path for hint routing: validate() succeeds when a model_routes
// entry carries exactly the hint referenced by `default_model`.
#[test]
async fn default_model_hint_accepts_matching_model_route() {
    let mut config = Config::default();
    config.default_model = Some("hint:reasoning".to_string());
    config.model_routes = vec![ModelRouteConfig {
        hint: "reasoning".to_string(),
        provider: "openrouter".to_string(),
        model: "openai/gpt-5.2".to_string(),
        max_tokens: None,
        api_key: None,
        transport: None,
    }];
    let result = config.validate();
    assert!(
        result.is_ok(),
        "matching default hint route should validate"
    );
}
// Hint comparison is whitespace-insensitive: both the "hint: reasoning "
// value in default_model and the padded route hint are trimmed before
// matching, so validation still succeeds.
#[test]
async fn default_model_hint_accepts_matching_model_route_with_whitespace() {
    let mut config = Config::default();
    config.default_model = Some("hint: reasoning ".to_string());
    config.model_routes = vec![ModelRouteConfig {
        hint: " reasoning ".to_string(),
        provider: "openrouter".to_string(),
        model: "openai/gpt-5.2".to_string(),
        max_tokens: None,
        api_key: None,
        transport: None,
    }];
    let result = config.validate();
    assert!(
        result.is_ok(),
        "trimmed default hint should match trimmed route hint"
    );
}
#[test]
async fn provider_transport_normalizes_aliases() {
let mut config = Config::default();
@@ -10897,6 +11284,31 @@ provider_api = "not-a-real-mode"
std::env::remove_var("ZEROCLAW_MODEL");
}
// An explicitly configured model wins over any provider fallback, and
// surrounding whitespace is trimmed from the configured value.
#[test]
async fn resolve_default_model_id_prefers_configured_model() {
    let resolved =
        resolve_default_model_id(Some(" anthropic/claude-opus-4.6 "), Some("openrouter"));
    assert_eq!(resolved, "anthropic/claude-opus-4.6");
}
// With no model configured, each provider resolves to its own documented
// fallback model id rather than a single global default.
#[test]
async fn resolve_default_model_id_uses_provider_specific_fallback() {
    let openai = resolve_default_model_id(None, Some("openai"));
    assert_eq!(openai, "gpt-5.2");
    let bedrock = resolve_default_model_id(None, Some("aws-bedrock"));
    assert_eq!(bedrock, "anthropic.claude-sonnet-4-5-20250929-v1:0");
}
// Alias provider names ("qwen-coding-plan", "google-gemini") map to the
// same fallback models as their canonical providers.
#[test]
async fn resolve_default_model_id_handles_special_provider_aliases() {
    let qwen_coding_plan = resolve_default_model_id(None, Some("qwen-coding-plan"));
    assert_eq!(qwen_coding_plan, "qwen3-coder-plus");
    let google_alias = resolve_default_model_id(None, Some("google-gemini"));
    assert_eq!(google_alias, "gemini-2.5-pro");
}
#[test]
async fn model_provider_profile_maps_to_custom_endpoint() {
let _env_guard = env_override_lock().await;
@@ -10908,6 +11320,8 @@ provider_api = "not-a-real-mode"
name: Some("sub2api".to_string()),
base_url: Some("https://api.tonsof.blue/v1".to_string()),
wire_api: None,
default_model: None,
api_key: None,
requires_openai_auth: false,
},
)]),
@@ -10936,6 +11350,8 @@ provider_api = "not-a-real-mode"
name: Some("sub2api".to_string()),
base_url: Some("https://api.tonsof.blue".to_string()),
wire_api: Some("responses".to_string()),
default_model: None,
api_key: None,
requires_openai_auth: true,
},
)]),
@@ -10998,6 +11414,8 @@ provider_api = "not-a-real-mode"
name: Some("sub2api".to_string()),
base_url: Some("https://api.tonsof.blue/v1".to_string()),
wire_api: Some("ws".to_string()),
default_model: None,
api_key: None,
requires_openai_auth: false,
},
)]),
@@ -11010,6 +11428,54 @@ provider_api = "not-a-real-mode"
.contains("wire_api must be one of: responses, chat_completions"));
}
// When the global api_key is unset, apply_env_overrides falls back to the
// api_key declared on the active model-provider profile.
#[test]
async fn model_provider_profile_uses_profile_api_key_when_global_is_missing() {
    // Serialize env-dependent tests; apply_env_overrides reads process env.
    let _env_guard = env_override_lock().await;
    let mut config = Config {
        default_provider: Some("sub2api".to_string()),
        api_key: None,
        model_providers: HashMap::from([(
            "sub2api".to_string(),
            ModelProviderConfig {
                name: Some("sub2api".to_string()),
                base_url: Some("https://api.tonsof.blue/v1".to_string()),
                wire_api: None,
                default_model: None,
                api_key: Some("profile-api-key".to_string()),
                requires_openai_auth: false,
            },
        )]),
        ..Config::default()
    };
    config.apply_env_overrides();
    assert_eq!(config.api_key.as_deref(), Some("profile-api-key"));
}
// With config.default_model still at the stock DEFAULT_MODEL_NAME, the
// active profile's default_model ("qwen-max") is adopted by
// apply_env_overrides.
#[test]
async fn model_provider_profile_can_override_default_model_when_openrouter_default_is_set() {
    // Serialize env-dependent tests; apply_env_overrides reads process env.
    let _env_guard = env_override_lock().await;
    let mut config = Config {
        default_provider: Some("sub2api".to_string()),
        default_model: Some(DEFAULT_MODEL_NAME.to_string()),
        model_providers: HashMap::from([(
            "sub2api".to_string(),
            ModelProviderConfig {
                name: Some("sub2api".to_string()),
                base_url: Some("https://api.tonsof.blue/v1".to_string()),
                wire_api: None,
                default_model: Some("qwen-max".to_string()),
                api_key: None,
                requires_openai_auth: false,
            },
        )]),
        ..Config::default()
    };
    config.apply_env_overrides();
    assert_eq!(config.default_model.as_deref(), Some("qwen-max"));
}
#[test]
async fn env_override_model_fallback() {
let _env_guard = env_override_lock().await;
+21 -17
View File
@@ -59,7 +59,7 @@ pub async fn run(config: Config) -> Result<()> {
pub async fn execute_job_now(config: &Config, job: &CronJob) -> (bool, String) {
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
execute_job_with_retry(config, &security, job).await
Box::pin(execute_job_with_retry(config, &security, job)).await
}
async fn execute_job_with_retry(
@@ -74,7 +74,7 @@ async fn execute_job_with_retry(
for attempt in 0..=retries {
let (success, output) = match job.job_type {
JobType::Shell => run_job_command(config, security, job).await,
JobType::Agent => run_agent_job(config, security, job).await,
JobType::Agent => Box::pin(run_agent_job(config, security, job)).await,
};
last_output = output;
@@ -107,18 +107,21 @@ async fn process_due_jobs(
crate::health::mark_component_ok(component);
let max_concurrent = config.scheduler.max_concurrent.max(1);
let mut in_flight =
stream::iter(
jobs.into_iter().map(|job| {
let config = config.clone();
let security = Arc::clone(security);
let component = component.to_owned();
async move {
execute_and_persist_job(&config, security.as_ref(), &job, &component).await
}
}),
)
.buffer_unordered(max_concurrent);
let mut in_flight = stream::iter(jobs.into_iter().map(|job| {
let config = config.clone();
let security = Arc::clone(security);
let component = component.to_owned();
async move {
Box::pin(execute_and_persist_job(
&config,
security.as_ref(),
&job,
&component,
))
.await
}
}))
.buffer_unordered(max_concurrent);
while let Some((job_id, success, output)) = in_flight.next().await {
if !success {
@@ -137,7 +140,7 @@ async fn execute_and_persist_job(
warn_if_high_frequency_agent_job(job);
let started_at = Utc::now();
let (success, output) = execute_job_with_retry(config, security, job).await;
let (success, output) = Box::pin(execute_job_with_retry(config, security, job)).await;
let finished_at = Utc::now();
let success = persist_job_result(config, job, success, &output, started_at, finished_at).await;
@@ -176,7 +179,7 @@ async fn run_agent_job(
let run_result = match job.session_target {
SessionTarget::Main | SessionTarget::Isolated => {
crate::agent::run(
Box::pin(crate::agent::run(
config.clone(),
Some(prefixed_prompt),
None,
@@ -184,7 +187,7 @@ async fn run_agent_job(
config.default_temperature,
vec![],
false,
)
))
.await
}
};
@@ -364,6 +367,7 @@ pub(crate) async fn deliver_announcement(
sl.bot_token.clone(),
sl.app_token.clone(),
sl.channel_id.clone(),
sl.channel_ids.clone(),
sl.allowed_users.clone(),
);
channel.send(&SendMessage::new(output, target)).await?;
+3 -3
View File
@@ -65,7 +65,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
max_backoff,
move || {
let cfg = channels_cfg.clone();
async move { crate::channels::start_channels(cfg).await }
async move { Box::pin(crate::channels::start_channels(cfg)).await }
},
));
} else {
@@ -214,7 +214,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
for task in tasks {
let prompt = format!("[Heartbeat Task] {task}");
let temp = config.default_temperature;
match crate::agent::run(
match Box::pin(crate::agent::run(
config.clone(),
Some(prompt),
None,
@@ -222,7 +222,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
temp,
vec![],
false,
)
))
.await
{
Ok(output) => {
+5 -3
View File
@@ -115,7 +115,9 @@ impl TaskClassifier {
/// Load all 44 BLS occupations with wage data
fn load_occupations() -> Vec<Occupation> {
use OccupationCategory::*;
use OccupationCategory::{
BusinessFinance, HealthcareSocialServices, LegalMediaOperations, TechnologyEngineering,
};
vec![
// Technology & Engineering
@@ -732,11 +734,11 @@ impl TaskClassifier {
};
// Scale by instruction length
let length_factor = (word_count as f64 / 20.0).max(0.5).min(2.0);
let length_factor = (word_count as f64 / 20.0).clamp(0.5, 2.0);
let hours = base_hours * length_factor;
// Clamp to valid range
hours.max(0.25).min(40.0)
hours.clamp(0.25, 40.0)
}
/// Get all occupations
+2 -7
View File
@@ -9,12 +9,13 @@ use std::fmt;
/// Survival status based on balance percentage relative to initial capital.
///
/// Mirrors the ClawWork LiveBench agent survival states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SurvivalStatus {
/// Balance > 80% of initial - Agent is profitable and healthy
Thriving,
/// Balance 40-80% of initial - Agent is maintaining stability
#[default]
Stable,
/// Balance 10-40% of initial - Agent is losing money, needs attention
Struggling,
@@ -100,12 +101,6 @@ impl fmt::Display for SurvivalStatus {
}
}
impl Default for SurvivalStatus {
fn default() -> Self {
Self::Stable
}
}
#[cfg(test)]
mod tests {
use super::*;
+81 -23
View File
@@ -90,6 +90,17 @@ fn qq_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String {
format!("qq_{}_{}", msg.sender, msg.id)
}
/// Derive a stable session id for an inbound gateway channel message.
///
/// QQ/NapCat sessions are keyed by channel and sender only; every other
/// channel additionally scopes the session by `thread_ts` when present.
fn gateway_message_session_id(msg: &crate::channels::traits::ChannelMessage) -> String {
    // QQ/NapCat deliberately ignore thread ids when building the key.
    let thread_scope = if msg.channel == "qq" || msg.channel == "napcat" {
        None
    } else {
        msg.thread_ts.as_ref()
    };
    match thread_scope {
        Some(thread_id) => format!("{}_{}_{}", msg.channel, thread_id, msg.sender),
        None => format!("{}_{}", msg.channel, msg.sender),
    }
}
fn hash_webhook_secret(value: &str) -> String {
use sha2::{Digest, Sha256};
@@ -1024,9 +1035,10 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
pub(super) async fn run_gateway_chat_with_tools(
state: &AppState,
message: &str,
session_id: Option<&str>,
) -> anyhow::Result<String> {
let config = state.config.lock().clone();
crate::agent::process_message(config, message).await
crate::agent::process_message_with_session(config, message, session_id).await
}
fn gateway_outbound_leak_guard_snapshot(
@@ -1062,6 +1074,8 @@ pub struct WebhookBody {
pub message: String,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default)]
pub session_id: Option<String>,
}
#[derive(Debug, Clone, serde::Deserialize)]
@@ -1579,6 +1593,11 @@ async fn handle_webhook(
}
let message = webhook_body.message.trim();
let webhook_session_id = webhook_body
.session_id
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty());
if message.is_empty() {
let err = serde_json::json!({
"error": "The `message` field is required and must be a non-empty string."
@@ -1590,7 +1609,12 @@ async fn handle_webhook(
let key = webhook_memory_key();
let _ = state
.mem
.store(&key, message, MemoryCategory::Conversation, None)
.store(
&key,
message,
MemoryCategory::Conversation,
webhook_session_id,
)
.await;
}
@@ -1786,8 +1810,7 @@ async fn handle_whatsapp_verify(
/// Returns true if the signature is valid, false otherwise.
/// See: <https://developers.facebook.com/docs/graph-api/webhooks/getting-started#verification-requests>
pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header: &str) -> bool {
use hmac::{Hmac, Mac};
use sha2::Sha256;
use ring::hmac;
// Signature format: "sha256=<hex_signature>"
let Some(hex_sig) = signature_header.strip_prefix("sha256=") else {
@@ -1799,14 +1822,8 @@ pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header
return false;
};
// Compute HMAC-SHA256
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(app_secret.as_bytes()) else {
return false;
};
mac.update(body);
// Constant-time comparison
mac.verify_slice(&expected).is_ok()
let key = hmac::Key::new(hmac::HMAC_SHA256, app_secret.as_bytes());
hmac::verify(&key, body, &expected).is_ok()
}
/// POST /whatsapp — incoming message webhook
@@ -1868,17 +1885,23 @@ async fn handle_whatsapp_message(
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
let session_id = gateway_message_session_id(msg);
// Auto-save to memory
if state.auto_save {
let key = whatsapp_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.store(
&key,
&msg.content,
MemoryCategory::Conversation,
Some(&session_id),
)
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -1990,18 +2013,24 @@ async fn handle_linq_webhook(
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
let session_id = gateway_message_session_id(msg);
// Auto-save to memory
if state.auto_save {
let key = linq_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.store(
&key,
&msg.content,
MemoryCategory::Conversation,
Some(&session_id),
)
.await;
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -2034,6 +2063,7 @@ async fn handle_linq_webhook(
}
/// POST /github — incoming GitHub webhook (issue/PR comments)
#[allow(clippy::large_futures)]
async fn handle_github_webhook(
State(state): State<AppState>,
headers: HeaderMap,
@@ -2167,7 +2197,7 @@ async fn handle_github_webhook(
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, None).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -2351,18 +2381,24 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
let session_id = gateway_message_session_id(msg);
// Auto-save to memory
if state.auto_save {
let key = wati_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.store(
&key,
&msg.content,
MemoryCategory::Conversation,
Some(&session_id),
)
.await;
}
// Call the LLM
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -2463,16 +2499,22 @@ async fn handle_nextcloud_talk_webhook(
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
let session_id = gateway_message_session_id(msg);
if state.auto_save {
let key = nextcloud_talk_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.store(
&key,
&msg.content,
MemoryCategory::Conversation,
Some(&session_id),
)
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -2558,16 +2600,22 @@ async fn handle_qq_webhook(
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
let session_id = gateway_message_session_id(msg);
if state.auto_save {
let key = qq_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.store(
&key,
&msg.content,
MemoryCategory::Conversation,
Some(&session_id),
)
.await;
}
match run_gateway_chat_with_tools(&state, &msg.content).await {
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
Ok(response) => {
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
let safe_response = sanitize_gateway_response(
@@ -3365,6 +3413,7 @@ Reminder set successfully."#;
let body = Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
}));
let first = handle_webhook(
State(state.clone()),
@@ -3379,6 +3428,7 @@ Reminder set successfully."#;
let body = Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
}));
let second = handle_webhook(State(state), test_connect_info(), headers, body)
.await
@@ -3437,6 +3487,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
})),
)
.await
@@ -3490,6 +3541,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: " ".into(),
stream: None,
session_id: None,
})),
)
.await
@@ -3544,6 +3596,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: "stream me".into(),
stream: Some(true),
session_id: None,
})),
)
.await
@@ -3722,6 +3775,7 @@ Reminder set successfully."#;
let body1 = Ok(Json(WebhookBody {
message: "hello one".into(),
stream: None,
session_id: None,
}));
let first = handle_webhook(
State(state.clone()),
@@ -3736,6 +3790,7 @@ Reminder set successfully."#;
let body2 = Ok(Json(WebhookBody {
message: "hello two".into(),
stream: None,
session_id: None,
}));
let second = handle_webhook(State(state), test_connect_info(), headers, body2)
.await
@@ -3809,6 +3864,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
})),
)
.await
@@ -3871,6 +3927,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
})),
)
.await
@@ -3929,6 +3986,7 @@ Reminder set successfully."#;
Ok(Json(WebhookBody {
message: "hello".into(),
stream: None,
session_id: None,
})),
)
.await
+14 -4
View File
@@ -131,6 +131,11 @@ pub async fn handle_api_chat(
};
let message = chat_body.message.trim();
let session_id = chat_body
.session_id
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty());
if message.is_empty() {
let err = serde_json::json!({ "error": "Message cannot be empty" });
return (StatusCode::BAD_REQUEST, Json(err));
@@ -141,7 +146,7 @@ pub async fn handle_api_chat(
let key = api_chat_memory_key();
let _ = state
.mem
.store(&key, message, MemoryCategory::Conversation, None)
.store(&key, message, MemoryCategory::Conversation, session_id)
.await;
}
@@ -186,7 +191,7 @@ pub async fn handle_api_chat(
});
// ── Run the full agent loop ──
match run_gateway_chat_with_tools(&state, &enriched_message).await {
match run_gateway_chat_with_tools(&state, &enriched_message, session_id).await {
Ok(response) => {
let leak_guard_cfg = state.config.lock().security.outbound_leak_guard.clone();
let safe_response = sanitize_gateway_response(
@@ -519,6 +524,11 @@ pub async fn handle_v1_chat_completions_with_tools(
};
let is_stream = request.stream.unwrap_or(false);
let session_id = request
.user
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty());
let request_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
let created = unix_timestamp();
@@ -527,7 +537,7 @@ pub async fn handle_v1_chat_completions_with_tools(
let key = api_chat_memory_key();
let _ = state
.mem
.store(&key, &message, MemoryCategory::Conversation, None)
.store(&key, &message, MemoryCategory::Conversation, session_id)
.await;
}
@@ -562,7 +572,7 @@ pub async fn handle_v1_chat_completions_with_tools(
);
// ── Run the full agent loop ──
let reply = match run_gateway_chat_with_tools(&state, &enriched_message).await {
let reply = match run_gateway_chat_with_tools(&state, &enriched_message, session_id).await {
Ok(response) => {
let leak_guard_cfg = state.config.lock().security.outbound_leak_guard.clone();
let safe = sanitize_gateway_response(
+261 -21
View File
@@ -11,6 +11,7 @@
use super::AppState;
use crate::agent::loop_::{build_shell_policy_instructions, build_tool_instructions_from_specs};
use crate::memory::MemoryCategory;
use crate::providers::ChatMessage;
use axum::{
extract::{
@@ -20,9 +21,180 @@ use axum::{
http::{header, HeaderMap},
response::IntoResponse,
};
use uuid::Uuid;
// Fallback text used when the model returns no final response after tools ran.
const EMPTY_WS_RESPONSE_FALLBACK: &str =
    "Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";

// Memory-store key prefix for persisted websocket transcripts.
const WS_HISTORY_MEMORY_KEY_PREFIX: &str = "gateway_ws_history";
// Upper bound on dialog turns kept per persisted session.
const MAX_WS_PERSISTED_TURNS: usize = 128;
// Maximum accepted length (in bytes) of a client-supplied session id.
const MAX_WS_SESSION_ID_LEN: usize = 128;
/// Query-string parameters recognized by the websocket chat endpoint.
#[derive(Debug, Default, PartialEq, Eq)]
struct WsQueryParams {
    // Bearer token supplied via `?token=`; checked when pairing is required.
    token: Option<String>,
    // Validated `?session_id=` value (see `normalize_ws_session_id`).
    session_id: Option<String>,
}
/// A single persisted dialog turn from a websocket chat session.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
struct WsHistoryTurn {
    // Either "user" or "assistant"; other roles are never persisted.
    role: String,
    // Message text, trimmed before persistence.
    content: String,
}
/// Versioned envelope for websocket chat history stored in the memory backend.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, PartialEq, Eq)]
struct WsPersistedHistory {
    // Payload schema version; written as 1 by `persist_ws_history`.
    version: u8,
    // Dialog turns in chronological order, oldest first.
    messages: Vec<WsHistoryTurn>,
}
/// Validate and canonicalize a caller-supplied websocket session id.
///
/// Accepts a trimmed, non-empty value of at most `MAX_WS_SESSION_ID_LEN`
/// bytes consisting solely of ASCII alphanumerics, `-`, or `_`. Anything
/// else (including path-like ids) yields `None` so the id is safe to embed
/// in memory keys.
fn normalize_ws_session_id(candidate: Option<&str>) -> Option<String> {
    let trimmed = candidate?.trim();
    let within_limit = !trimmed.is_empty() && trimmed.len() <= MAX_WS_SESSION_ID_LEN;
    let charset_ok = trimmed
        .chars()
        .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_'));
    (within_limit && charset_ok).then(|| trimmed.to_string())
}
/// Extract `token` and `session_id` from a raw websocket query string.
///
/// Pairs with empty values are ignored. The first non-empty `token` wins;
/// `session_id` pairs are consumed until one survives
/// `normalize_ws_session_id` (invalid candidates are discarded, letting a
/// later valid pair still apply).
fn parse_ws_query_params(raw_query: Option<&str>) -> WsQueryParams {
    let mut params = WsQueryParams::default();
    let Some(query) = raw_query else {
        return params;
    };
    for pair in query.split('&') {
        // A pair without '=' behaves like `key=` (empty value, skipped below).
        let (key, value) = match pair.split_once('=') {
            Some((k, v)) => (k.trim(), v.trim()),
            None => (pair.trim(), ""),
        };
        if value.is_empty() {
            continue;
        }
        match key {
            "token" if params.token.is_none() => {
                params.token = Some(value.to_string());
            }
            "session_id" if params.session_id.is_none() => {
                params.session_id = normalize_ws_session_id(Some(value));
            }
            _ => {}
        }
    }
    params
}
/// Memory-store key under which a websocket session's transcript lives.
fn ws_history_memory_key(session_id: &str) -> String {
    format!("{WS_HISTORY_MEMORY_KEY_PREFIX}:{session_id}")
}
/// Project a chat transcript down to the turns worth persisting.
///
/// Keeps only user/assistant messages whose trimmed content is non-empty,
/// then caps the result at the most recent `MAX_WS_PERSISTED_TURNS` turns
/// so persisted history stays bounded.
fn ws_history_turns_from_chat(history: &[ChatMessage]) -> Vec<WsHistoryTurn> {
    let mut turns: Vec<WsHistoryTurn> = Vec::new();
    for msg in history {
        // System/tool/etc. roles never make it into persisted history.
        if msg.role != "user" && msg.role != "assistant" {
            continue;
        }
        let trimmed = msg.content.trim();
        if trimmed.is_empty() {
            continue;
        }
        turns.push(WsHistoryTurn {
            role: msg.role.clone(),
            content: trimmed.to_string(),
        });
    }
    // Drop the oldest turns once over the cap; the recent tail is kept.
    let overflow = turns.len().saturating_sub(MAX_WS_PERSISTED_TURNS);
    if overflow > 0 {
        turns.drain(0..overflow);
    }
    turns
}
/// Rebuild an in-memory chat history: one system message followed by the
/// persisted user/assistant turns in order. Turns with any other role are
/// silently dropped.
fn restore_chat_history(system_prompt: &str, turns: &[WsHistoryTurn]) -> Vec<ChatMessage> {
    std::iter::once(ChatMessage::system(system_prompt))
        .chain(turns.iter().filter_map(|turn| match turn.role.as_str() {
            "user" => Some(ChatMessage::user(&turn.content)),
            "assistant" => Some(ChatMessage::assistant(&turn.content)),
            _ => None,
        }))
        .collect()
}
/// Load a session's persisted history from the memory store and prepend the
/// current system prompt.
///
/// Accepts both the versioned `WsPersistedHistory` envelope and a bare
/// `Vec<WsHistoryTurn>` array (presumably an earlier persisted layout —
/// confirm before removing the fallback). Any load or parse failure
/// degrades to a fresh history containing only the system message.
async fn load_ws_history(
    state: &AppState,
    session_id: &str,
    system_prompt: &str,
) -> Vec<ChatMessage> {
    let key = ws_history_memory_key(session_id);
    // A missing entry (or a store error) simply means a brand-new session.
    let Some(entry) = state.mem.get(&key).await.ok().flatten() else {
        return vec![ChatMessage::system(system_prompt)];
    };
    // Try the current envelope first, then fall back to the bare-array form.
    let parsed = serde_json::from_str::<WsPersistedHistory>(&entry.content)
        .map(|history| history.messages)
        .or_else(|_| serde_json::from_str::<Vec<WsHistoryTurn>>(&entry.content));
    match parsed {
        Ok(turns) => restore_chat_history(system_prompt, &turns),
        Err(err) => {
            // Corrupt payloads are logged and discarded rather than
            // failing the websocket upgrade.
            tracing::warn!(
                "Failed to parse persisted websocket history for session {}: {}",
                session_id,
                err
            );
            vec![ChatMessage::system(system_prompt)]
        }
    }
}
/// Serialize the current chat history and write it to the memory store
/// under the session's history key.
///
/// Best-effort: serialization or store failures are logged at `warn` and
/// otherwise swallowed so a persistence hiccup never interrupts the socket.
async fn persist_ws_history(state: &AppState, session_id: &str, history: &[ChatMessage]) {
    // version 1 is the current envelope layout read back by load_ws_history.
    let payload = WsPersistedHistory {
        version: 1,
        messages: ws_history_turns_from_chat(history),
    };
    let serialized = match serde_json::to_string(&payload) {
        Ok(value) => value,
        Err(err) => {
            tracing::warn!(
                "Failed to serialize websocket history for session {}: {}",
                session_id,
                err
            );
            return;
        }
    };
    let key = ws_history_memory_key(session_id);
    if let Err(err) = state
        .mem
        .store(
            &key,
            &serialized,
            MemoryCategory::Conversation,
            Some(session_id),
        )
        .await
    {
        tracing::warn!(
            "Failed to persist websocket history for session {}: {}",
            session_id,
            err
        );
    }
}
fn sanitize_ws_response(
response: &str,
@@ -168,10 +340,11 @@ pub async fn handle_ws_chat(
RawQuery(query): RawQuery,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
let query_params = parse_ws_query_params(query.as_deref());
// Auth via Authorization header or websocket protocol token.
if state.pairing.require_pairing() {
let query_token = extract_query_token(query.as_deref());
let token = extract_ws_bearer_token(&headers, query_token.as_deref()).unwrap_or_default();
let token =
extract_ws_bearer_token(&headers, query_params.token.as_deref()).unwrap_or_default();
if !state.pairing.is_authenticated(&token) {
return (
axum::http::StatusCode::UNAUTHORIZED,
@@ -181,13 +354,16 @@ pub async fn handle_ws_chat(
}
}
ws.on_upgrade(move |socket| handle_socket(socket, state))
let session_id = query_params
.session_id
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
.into_response()
}
async fn handle_socket(mut socket: WebSocket, state: AppState) {
// Maintain conversation history for this WebSocket session
let mut history: Vec<ChatMessage> = Vec::new();
async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: String) {
let ws_session_id = format!("ws_{}", Uuid::new_v4());
// Build system prompt once for the session
let system_prompt = {
@@ -200,8 +376,17 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
)
};
// Add system message to history
history.push(ChatMessage::system(&system_prompt));
// Restore persisted history (if any) and replay to the client before processing new input.
let mut history = load_ws_history(&state, &session_id, &system_prompt).await;
let persisted_turns = ws_history_turns_from_chat(&history);
let history_payload = serde_json::json!({
"type": "history",
"session_id": session_id.as_str(),
"messages": persisted_turns,
});
let _ = socket
.send(Message::Text(history_payload.to_string().into()))
.await;
while let Some(msg) = socket.recv().await {
let msg = match msg {
@@ -250,6 +435,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
// Add user message to history
history.push(ChatMessage::user(&content));
persist_ws_history(&state, &session_id, &history).await;
// Get provider info
let provider_label = state
@@ -267,7 +453,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
}));
// Full agentic loop with tools (includes WASM skills, shell, memory, etc.)
match super::run_gateway_chat_with_tools(&state, &content).await {
match super::run_gateway_chat_with_tools(&state, &content, Some(&ws_session_id)).await {
Ok(response) => {
let leak_guard_cfg = { state.config.lock().security.outbound_leak_guard.clone() };
let safe_response = finalize_ws_response(
@@ -278,6 +464,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
);
// Add assistant response to history
history.push(ChatMessage::assistant(&safe_response));
persist_ws_history(&state, &session_id, &history).await;
// Send the full response as a done message
let done = serde_json::json!({
@@ -345,18 +532,7 @@ fn extract_ws_bearer_token(headers: &HeaderMap, query_token: Option<&str>) -> Op
}
fn extract_query_token(raw_query: Option<&str>) -> Option<String> {
let query = raw_query?;
for kv in query.split('&') {
let mut parts = kv.splitn(2, '=');
if parts.next() != Some("token") {
continue;
}
let token = parts.next().unwrap_or("").trim();
if !token.is_empty() {
return Some(token.to_string());
}
}
None
parse_ws_query_params(raw_query).token
}
#[cfg(test)]
@@ -445,6 +621,70 @@ mod tests {
assert!(extract_query_token(Some("foo=1")).is_none());
}
// The query parser picks out both `token` and a valid `session_id`
// regardless of pair order or unrelated keys.
#[test]
fn parse_ws_query_params_reads_token_and_session_id() {
    let parsed = parse_ws_query_params(Some("foo=1&session_id=sess_123&token=query-token"));
    assert_eq!(parsed.token.as_deref(), Some("query-token"));
    assert_eq!(parsed.session_id.as_deref(), Some("sess_123"));
}
// Path-like session ids fail normalization (non [A-Za-z0-9_-] characters)
// and are dropped rather than passed through.
#[test]
fn parse_ws_query_params_rejects_invalid_session_id() {
    let parsed = parse_ws_query_params(Some("session_id=../../etc/passwd"));
    assert!(parsed.session_id.is_none());
}
// Only user/assistant turns survive the persistence projection; system and
// tool messages are dropped and surviving content is trimmed.
#[test]
fn ws_history_turns_from_chat_skips_system_and_non_dialog_turns() {
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user(" hello "),
        ChatMessage {
            role: "tool".to_string(),
            content: "ignored".to_string(),
        },
        ChatMessage::assistant(" world "),
    ];
    let turns = ws_history_turns_from_chat(&history);
    assert_eq!(
        turns,
        vec![
            WsHistoryTurn {
                role: "user".to_string(),
                content: "hello".to_string()
            },
            WsHistoryTurn {
                role: "assistant".to_string(),
                content: "world".to_string()
            }
        ]
    );
}
// Restoration yields exactly one leading system message followed by the
// persisted turns in their original order.
#[test]
fn restore_chat_history_applies_system_prompt_once() {
    let turns = vec![
        WsHistoryTurn {
            role: "user".to_string(),
            content: "u1".to_string(),
        },
        WsHistoryTurn {
            role: "assistant".to_string(),
            content: "a1".to_string(),
        },
    ];
    let restored = restore_chat_history("sys", &turns);
    assert_eq!(restored.len(), 3);
    assert_eq!(restored[0].role, "system");
    assert_eq!(restored[0].content, "sys");
    assert_eq!(restored[1].role, "user");
    assert_eq!(restored[1].content, "u1");
    assert_eq!(restored[2].role, "assistant");
    assert_eq!(restored[2].content, "a1");
}
struct MockScheduleTool;
#[async_trait]
+6 -8
View File
@@ -59,9 +59,7 @@ mod agent;
mod approval;
mod auth;
mod channels;
mod rag {
pub use zeroclaw::rag::*;
}
mod rag;
mod config;
mod coordination;
mod cost;
@@ -858,7 +856,7 @@ async fn main() -> Result<()> {
}?;
// Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?;
Box::pin(channels::start_channels(config)).await?;
}
return Ok(());
}
@@ -919,7 +917,7 @@ async fn main() -> Result<()> {
// Single-shot mode (-m) runs non-interactively: no TTY approval prompt,
// so tools are not denied by a stdin read returning EOF.
let interactive = message.is_none();
agent::run(
Box::pin(agent::run(
config,
message,
provider,
@@ -927,7 +925,7 @@ async fn main() -> Result<()> {
temperature,
peripheral,
interactive,
)
))
.await
.map(|_| ())
}
@@ -1166,8 +1164,8 @@ async fn main() -> Result<()> {
},
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await,
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
other => channels::handle_command(other, &config).await,
},
+35 -4
View File
@@ -3,6 +3,7 @@ pub enum MemoryBackendKind {
Sqlite,
SqliteQdrantHybrid,
Lucid,
CortexMem,
Postgres,
Qdrant,
Markdown,
@@ -39,6 +40,15 @@ const LUCID_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: true,
};
const CORTEX_MEM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "cortex-mem",
label: "Cortex-Mem bridge — optional CLI sync with local SQLite fallback",
auto_save_default: true,
uses_sqlite_hygiene: true,
sqlite_based: true,
optional_dependency: true,
};
const MARKDOWN_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "markdown",
label: "Markdown Files — simple, human-readable, no dependencies",
@@ -93,9 +103,10 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: false,
};
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 4] = [
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [
SQLITE_PROFILE,
LUCID_PROFILE,
CORTEX_MEM_PROFILE,
MARKDOWN_PROFILE,
NONE_PROFILE,
];
@@ -113,6 +124,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
"sqlite" => MemoryBackendKind::Sqlite,
"sqlite_qdrant_hybrid" | "hybrid" => MemoryBackendKind::SqliteQdrantHybrid,
"lucid" => MemoryBackendKind::Lucid,
"cortex-mem" | "cortex_mem" | "cortexmem" | "cortex" => MemoryBackendKind::CortexMem,
"postgres" => MemoryBackendKind::Postgres,
"qdrant" => MemoryBackendKind::Qdrant,
"markdown" => MemoryBackendKind::Markdown,
@@ -126,6 +138,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
MemoryBackendKind::Sqlite => SQLITE_PROFILE,
MemoryBackendKind::SqliteQdrantHybrid => SQLITE_QDRANT_HYBRID_PROFILE,
MemoryBackendKind::Lucid => LUCID_PROFILE,
MemoryBackendKind::CortexMem => CORTEX_MEM_PROFILE,
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
@@ -146,6 +159,14 @@ mod tests {
MemoryBackendKind::SqliteQdrantHybrid
);
assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid);
assert_eq!(
classify_memory_backend("cortex-mem"),
MemoryBackendKind::CortexMem
);
assert_eq!(
classify_memory_backend("cortex_mem"),
MemoryBackendKind::CortexMem
);
assert_eq!(
classify_memory_backend("postgres"),
MemoryBackendKind::Postgres
@@ -173,11 +194,12 @@ mod tests {
#[test]
fn selectable_backends_are_ordered_for_onboarding() {
let backends = selectable_memory_backends();
assert_eq!(backends.len(), 4);
assert_eq!(backends.len(), 5);
assert_eq!(backends[0].key, "sqlite");
assert_eq!(backends[1].key, "lucid");
assert_eq!(backends[2].key, "markdown");
assert_eq!(backends[3].key, "none");
assert_eq!(backends[2].key, "cortex-mem");
assert_eq!(backends[3].key, "markdown");
assert_eq!(backends[4].key, "none");
}
#[test]
@@ -188,6 +210,15 @@ mod tests {
assert!(profile.uses_sqlite_hygiene);
}
#[test]
fn cortex_profile_is_sqlite_based_optional_backend() {
let profile = memory_backend_profile("cortex-mem");
assert_eq!(profile.key, "cortex-mem");
assert!(profile.sqlite_based);
assert!(profile.optional_dependency);
assert!(profile.uses_sqlite_hygiene);
}
#[test]
fn unknown_profile_preserves_extensibility_defaults() {
let profile = memory_backend_profile("custom-memory");
+109
View File
@@ -0,0 +1,109 @@
use super::lucid::LucidMemory;
use super::sqlite::SqliteMemory;
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use async_trait::async_trait;
use std::path::Path;
/// Memory backend that bridges to the external `cortex-mem` CLI while
/// persisting everything through the Lucid/SQLite pipeline.
///
/// Thin newtype over `LucidMemory`: every `Memory` trait call delegates to
/// `inner`, which keeps the local SQLite store usable even when the bridge
/// CLI command is missing or fails (see the tests below).
pub struct CortexMemMemory {
    // Delegate that handles both CLI sync and the local SQLite fallback.
    inner: LucidMemory,
}
impl CortexMemMemory {
    /// Bridge CLI binary used when no override env var is present.
    const DEFAULT_CORTEX_CMD: &'static str = "cortex-mem";

    /// Build a cortex-mem backend rooted at `workspace_dir`, persisting
    /// locally through `local`. The bridge command is resolved from
    /// `ZEROCLAW_CORTEX_CMD`, then `ZEROCLAW_LUCID_CMD`, then the default.
    pub fn new(workspace_dir: &Path, local: SqliteMemory) -> Self {
        let command = ["ZEROCLAW_CORTEX_CMD", "ZEROCLAW_LUCID_CMD"]
            .iter()
            .find_map(|var| std::env::var(var).ok())
            .unwrap_or_else(|| Self::DEFAULT_CORTEX_CMD.to_string());
        Self {
            inner: LucidMemory::new_with_command(workspace_dir, local, command),
        }
    }

    /// Test hook: force a specific bridge command (e.g. a deliberately
    /// missing binary) without touching the process environment.
    #[cfg(test)]
    fn new_with_command_for_test(workspace_dir: &Path, local: SqliteMemory, command: &str) -> Self {
        Self {
            inner: LucidMemory::new_with_command(workspace_dir, local, command.to_string()),
        }
    }
}
#[async_trait]
impl Memory for CortexMemMemory {
    /// Backend identifier used by the memory factory and status reporting.
    fn name(&self) -> &str {
        "cortex-mem"
    }

    /// Persist an entry; delegates to the inner Lucid/SQLite pipeline.
    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.inner.store(key, content, category, session_id).await
    }

    /// Search for entries matching `query`, capped at `limit` results.
    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.inner.recall(query, limit, session_id).await
    }

    /// Fetch a single entry by exact key, if present.
    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
        self.inner.get(key).await
    }

    /// List entries, optionally filtered by category and/or session.
    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.inner.list(category, session_id).await
    }

    /// Remove the entry for `key`; the boolean result is passed through
    /// from the inner backend unchanged.
    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
        self.inner.forget(key).await
    }

    /// Entry count as reported by the inner backend.
    async fn count(&self) -> anyhow::Result<usize> {
        self.inner.count().await
    }

    /// Delegated health probe of the underlying store.
    async fn health_check(&self) -> bool {
        self.inner.health_check().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// The backend must self-identify as "cortex-mem" for factory routing.
    #[tokio::test]
    async fn cortex_backend_reports_expected_name() {
        let workdir = TempDir::new().unwrap();
        let local = SqliteMemory::new(workdir.path()).unwrap();
        let backend = CortexMemMemory::new(workdir.path(), local);
        assert_eq!(backend.name(), "cortex-mem");
    }

    /// Writes must land in the local SQLite store even when the bridge CLI
    /// cannot be spawned at all.
    #[tokio::test]
    async fn cortex_backend_keeps_local_store_when_bridge_command_fails() {
        let workdir = TempDir::new().unwrap();
        let local = SqliteMemory::new(workdir.path()).unwrap();
        let backend =
            CortexMemMemory::new_with_command_for_test(workdir.path(), local, "missing-cortex-cli");

        backend
            .store("cortex_key", "local first", MemoryCategory::Conversation, None)
            .await
            .unwrap();

        let stored = backend.get("cortex_key").await.unwrap();
        assert!(stored.is_some(), "expected local sqlite entry to be present");
        assert_eq!(stored.unwrap().content, "local first");
    }
}
+7
View File
@@ -34,7 +34,14 @@ impl LucidMemory {
pub fn new(workspace_dir: &Path, local: SqliteMemory) -> Self {
let lucid_cmd = std::env::var("ZEROCLAW_LUCID_CMD")
.unwrap_or_else(|_| Self::DEFAULT_LUCID_CMD.to_string());
Self::new_with_command(workspace_dir, local, lucid_cmd)
}
pub(crate) fn new_with_command(
workspace_dir: &Path,
local: SqliteMemory,
lucid_cmd: String,
) -> Self {
let token_budget = std::env::var("ZEROCLAW_LUCID_BUDGET")
.ok()
.and_then(|v| v.parse::<usize>().ok())
+27 -1
View File
@@ -1,6 +1,7 @@
pub mod backend;
pub mod chunker;
pub mod cli;
pub mod cortex;
pub mod embeddings;
pub mod hybrid;
pub mod hygiene;
@@ -21,6 +22,7 @@ pub use backend::{
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile,
};
pub use cortex::CortexMemMemory;
pub use hybrid::SqliteQdrantHybridMemory;
pub use lucid::LucidMemory;
pub use markdown::MarkdownMemory;
@@ -58,6 +60,10 @@ where
let local = sqlite_builder()?;
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
}
MemoryBackendKind::CortexMem => {
let local = sqlite_builder()?;
Ok(Box::new(CortexMemMemory::new(workspace_dir, local)))
}
MemoryBackendKind::Postgres => postgres_builder(),
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
@@ -217,6 +223,7 @@ pub fn create_memory_with_storage_and_routes(
MemoryBackendKind::Sqlite
| MemoryBackendKind::SqliteQdrantHybrid
| MemoryBackendKind::Lucid
| MemoryBackendKind::CortexMem
)
{
if let Err(e) = snapshot::export_snapshot(workspace_dir) {
@@ -232,6 +239,7 @@ pub fn create_memory_with_storage_and_routes(
MemoryBackendKind::Sqlite
| MemoryBackendKind::SqliteQdrantHybrid
| MemoryBackendKind::Lucid
| MemoryBackendKind::CortexMem
)
&& snapshot::should_hydrate(workspace_dir)
{
@@ -381,7 +389,7 @@ pub fn create_memory_for_migration(
) -> anyhow::Result<Box<dyn Memory>> {
if matches!(classify_memory_backend(backend), MemoryBackendKind::None) {
anyhow::bail!(
"memory backend 'none' disables persistence; choose sqlite, lucid, or markdown before migration"
"memory backend 'none' disables persistence; choose sqlite, lucid, cortex-mem, or markdown before migration"
);
}
@@ -477,6 +485,17 @@ mod tests {
assert_eq!(mem.name(), "lucid");
}
#[test]
fn factory_cortex_mem() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "cortex-mem".into(),
..MemoryConfig::default()
};
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
assert_eq!(mem.name(), "cortex-mem");
}
#[test]
fn factory_sqlite_qdrant_hybrid() {
let tmp = TempDir::new().unwrap();
@@ -521,6 +540,13 @@ mod tests {
assert_eq!(mem.name(), "lucid");
}
#[test]
fn migration_factory_cortex_mem() {
let tmp = TempDir::new().unwrap();
let mem = create_memory_for_migration("cortex-mem", tmp.path()).unwrap();
assert_eq!(mem.name(), "cortex-mem");
}
#[test]
fn migration_factory_none_is_rejected() {
let tmp = TempDir::new().unwrap();
+1 -1
View File
@@ -48,7 +48,7 @@ impl CostObserver {
// Try model family matching (e.g., "claude-sonnet-4" matches any claude-sonnet-4-*)
for (key, pricing) in &self.prices {
// Strip provider prefix if present
let key_model = key.split('/').last().unwrap_or(key);
let key_model = key.split('/').next_back().unwrap_or(key);
// Check if model starts with the key (family match)
if model.starts_with(key_model) || key_model.starts_with(model) {
+12 -5
View File
@@ -381,7 +381,7 @@ fn apply_provider_update(
// ── Quick setup (zero prompts) ───────────────────────────────────
/// Non-interactive setup: generates a sensible default config instantly.
/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite|lucid`.
/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite|lucid|cortex-mem`.
/// Use `zeroclaw onboard --interactive` for the full wizard.
fn backend_key_from_choice(choice: usize) -> &'static str {
selectable_memory_backends()
@@ -787,8 +787,7 @@ fn default_model_for_provider(provider: &str) -> String {
"qwen-code" => "qwen3-coder-plus".into(),
"ollama" => "llama3.2".into(),
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF".into(),
"sglang" | "vllm" | "osaurus" => "default".into(),
"copilot" => "default".into(),
"sglang" | "vllm" | "osaurus" | "copilot" => "default".into(),
"gemini" => "gemini-2.5-pro".into(),
"kimi-code" => "kimi-for-coding".into(),
"bedrock" => "anthropic.claude-sonnet-4-5-20250929-v1:0".into(),
@@ -4479,6 +4478,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
} else {
Some(channel)
},
channel_ids: vec![],
allowed_users,
group_reply: None,
});
@@ -8328,8 +8328,9 @@ mod tests {
fn backend_key_from_choice_maps_supported_backends() {
assert_eq!(backend_key_from_choice(0), "sqlite");
assert_eq!(backend_key_from_choice(1), "lucid");
assert_eq!(backend_key_from_choice(2), "markdown");
assert_eq!(backend_key_from_choice(3), "none");
assert_eq!(backend_key_from_choice(2), "cortex-mem");
assert_eq!(backend_key_from_choice(3), "markdown");
assert_eq!(backend_key_from_choice(4), "none");
assert_eq!(backend_key_from_choice(999), "sqlite");
}
@@ -8341,6 +8342,12 @@ mod tests {
assert!(lucid.sqlite_based);
assert!(lucid.optional_dependency);
let cortex_mem = memory_backend_profile("cortex-mem");
assert!(cortex_mem.auto_save_default);
assert!(cortex_mem.uses_sqlite_hygiene);
assert!(cortex_mem.sqlite_based);
assert!(cortex_mem.optional_dependency);
let markdown = memory_backend_profile("markdown");
assert!(markdown.auto_save_default);
assert!(!markdown.uses_sqlite_hygiene);
+1 -1
View File
@@ -126,7 +126,7 @@ pub fn discover_plugins(workspace_dir: Option<&Path>, extra_paths: &[PathBuf]) -
let mut deduped: Vec<DiscoveredPlugin> = Vec::with_capacity(seen.len());
// Collect in insertion order of the winning index
let mut indices: Vec<usize> = seen.values().copied().collect();
indices.sort();
indices.sort_unstable();
for i in indices {
deduped.push(all_plugins.swap_remove(i));
}
+4 -1
View File
@@ -13,7 +13,10 @@ use tracing::{info, warn};
use crate::config::PluginsConfig;
use super::discovery::discover_plugins;
use super::registry::*;
use super::registry::{
DiagnosticLevel, PluginDiagnostic, PluginHookRegistration, PluginOrigin, PluginRecord,
PluginRegistry, PluginStatus, PluginToolRegistration,
};
use super::traits::{Plugin, PluginApi, PluginLogger};
/// Resolve whether a discovered plugin should be enabled.
+146 -47
View File
@@ -16,6 +16,9 @@ use hmac::{Hmac, Mac};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
/// Hostname prefix for the Bedrock Runtime endpoint.
const ENDPOINT_PREFIX: &str = "bedrock-runtime";
@@ -27,6 +30,7 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
// ── AWS Credentials ─────────────────────────────────────────────
/// Resolved AWS credentials for SigV4 signing.
#[derive(Clone)]
struct AwsCredentials {
access_key_id: String,
secret_access_key: String,
@@ -134,11 +138,68 @@ impl AwsCredentials {
})
}
/// Resolve credentials: env vars first, then EC2 IMDS.
/// Fetch credentials from the ECS container credential endpoint.
/// Available when running on ECS/Fargate with a task IAM role.
///
/// Resolution follows the AWS SDK convention:
/// `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` (standard ECS, served from the
/// link-local address 169.254.170.2) first, then
/// `AWS_CONTAINER_CREDENTIALS_FULL_URI` (ECS Anywhere / custom endpoints).
async fn from_ecs() -> anyhow::Result<Self> {
    // Try relative URI first (standard ECS), then full URI (ECS Anywhere / custom)
    let uri = std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
        .ok()
        .map(|rel| format!("http://169.254.170.2{rel}"))
        .or_else(|| std::env::var("AWS_CONTAINER_CREDENTIALS_FULL_URI").ok());

    let uri = uri.ok_or_else(|| {
        anyhow::anyhow!(
            "Neither AWS_CONTAINER_CREDENTIALS_RELATIVE_URI nor \
             AWS_CONTAINER_CREDENTIALS_FULL_URI is set"
        )
    })?;

    // Short timeout: the endpoint is local to the task, so anything slower
    // than a few seconds means we are not actually running on ECS.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(3))
        .build()?;

    let mut req = client.get(&uri);
    // Full-URI endpoints may require an authorization token. Per the AWS
    // SDK credential-provider spec, AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE
    // takes precedence over the inline AWS_CONTAINER_AUTHORIZATION_TOKEN.
    let token = std::env::var("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE")
        .ok()
        .and_then(|path| std::fs::read_to_string(path).ok())
        .map(|contents| contents.trim_end().to_string())
        .or_else(|| std::env::var("AWS_CONTAINER_AUTHORIZATION_TOKEN").ok());
    if let Some(token) = token {
        req = req.header("Authorization", token);
    }

    let creds_json: serde_json::Value = req.send().await?.json().await?;

    let access_key_id = creds_json["AccessKeyId"]
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("Missing AccessKeyId in ECS credential response"))?
        .to_string();
    let secret_access_key = creds_json["SecretAccessKey"]
        .as_str()
        .ok_or_else(|| {
            anyhow::anyhow!("Missing SecretAccessKey in ECS credential response")
        })?
        .to_string();
    // Optional STS session token; absent for long-lived credentials.
    let session_token = creds_json["Token"].as_str().map(|s| s.to_string());

    // Region is not part of the credential payload; fall back to env, then default.
    let region = env_optional("AWS_REGION")
        .or_else(|| env_optional("AWS_DEFAULT_REGION"))
        .unwrap_or_else(|| DEFAULT_REGION.to_string());

    tracing::info!("Loaded AWS credentials from ECS container credential endpoint");

    Ok(Self {
        access_key_id,
        secret_access_key,
        session_token,
        region,
    })
}
/// Resolve credentials: env vars → ECS endpoint → EC2 IMDS.
async fn resolve() -> anyhow::Result<Self> {
    // Cheapest source first; each fallback only runs when the previous
    // source yields an error.
    match Self::from_env() {
        Ok(creds) => Ok(creds),
        Err(_) => match Self::from_ecs().await {
            Ok(creds) => Ok(creds),
            Err(_) => Self::from_imds().await,
        },
    }
}
@@ -176,6 +237,57 @@ fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
mac.finalize().into_bytes().to_vec()
}
/// How long credentials are considered fresh before re-fetching.
/// ECS STS tokens typically expire after 6-12 hours; we refresh well
/// before that to avoid any requests hitting expired tokens.
const CREDENTIAL_TTL_SECS: u64 = 50 * 60; // 50 minutes

/// Thread-safe credential cache that auto-refreshes from the ECS
/// container credential endpoint (or env vars / IMDS) when the
/// cached credentials are older than [`CREDENTIAL_TTL_SECS`].
struct CachedCredentials {
    // (credentials, fetch time) — `None` until the first successful resolve.
    // Arc-wrapped so streaming tasks can share the same cache handle.
    inner: Arc<RwLock<Option<(AwsCredentials, Instant)>>>,
}
impl CachedCredentials {
    /// Create a new cache, optionally pre-populated with initial credentials.
    /// A pre-populated entry is timestamped "now", so it gets a full TTL.
    fn new(initial: Option<AwsCredentials>) -> Self {
        let entry = initial.map(|c| (c, Instant::now()));
        Self {
            inner: Arc::new(RwLock::new(entry)),
        }
    }

    /// Get current credentials, refreshing if stale or missing.
    ///
    /// Returns a clone of the cached `AwsCredentials` when they are younger
    /// than [`CREDENTIAL_TTL_SECS`]; otherwise re-resolves via
    /// `AwsCredentials::resolve()` and caches the fresh pair. On refresh
    /// failure the error propagates and any stale entry is left in place.
    async fn get(&self) -> anyhow::Result<AwsCredentials> {
        // Fast path: read lock, check freshness
        {
            let guard = self.inner.read().await;
            if let Some((ref creds, fetched_at)) = *guard {
                if fetched_at.elapsed().as_secs() < CREDENTIAL_TTL_SECS {
                    return Ok(creds.clone());
                }
            }
        }
        // Slow path: write lock, re-fetch
        let mut guard = self.inner.write().await;
        // Double-check after acquiring write lock (another task may have refreshed)
        if let Some((ref creds, fetched_at)) = *guard {
            if fetched_at.elapsed().as_secs() < CREDENTIAL_TTL_SECS {
                return Ok(creds.clone());
            }
        }
        tracing::info!("Refreshing AWS credentials (TTL expired or first fetch)");
        let fresh = AwsCredentials::resolve().await?;
        let cloned = fresh.clone();
        *guard = Some((fresh, Instant::now()));
        Ok(cloned)
    }
}
/// Derive the SigV4 signing key via HMAC chain.
fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
@@ -454,19 +566,21 @@ struct ResponseToolUseWrapper {
// ── BedrockProvider ─────────────────────────────────────────────
pub struct BedrockProvider {
credentials: Option<AwsCredentials>,
credentials: CachedCredentials,
}
impl BedrockProvider {
pub fn new() -> Self {
Self {
credentials: AwsCredentials::from_env().ok(),
credentials: CachedCredentials::new(AwsCredentials::from_env().ok()),
}
}
pub async fn new_async() -> Self {
let credentials = AwsCredentials::resolve().await.ok();
Self { credentials }
let initial = AwsCredentials::resolve().await.ok();
Self {
credentials: CachedCredentials::new(initial),
}
}
fn http_client(&self) -> Client {
@@ -504,22 +618,10 @@ impl BedrockProvider {
format!("/model/{encoded}/converse-stream")
}
fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> {
self.credentials.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"AWS Bedrock credentials not set. Set AWS_ACCESS_KEY_ID and \
AWS_SECRET_ACCESS_KEY environment variables, or run on an EC2 \
instance with an IAM role attached."
)
})
}
/// Resolve credentials: use cached if available, otherwise fetch from IMDS.
async fn resolve_credentials(&self) -> anyhow::Result<AwsCredentials> {
if let Ok(creds) = AwsCredentials::from_env() {
return Ok(creds);
}
AwsCredentials::from_imds().await
/// Get credentials, auto-refreshing from the ECS endpoint / env vars /
/// IMDS when they are older than [`CREDENTIAL_TTL_SECS`].
async fn get_credentials(&self) -> anyhow::Result<AwsCredentials> {
self.credentials.get().await
}
// ── Cache heuristics (same thresholds as AnthropicProvider) ──
@@ -1243,7 +1345,7 @@ impl Provider for BedrockProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let credentials = self.resolve_credentials().await?;
let credentials = self.get_credentials().await?;
let system = system_prompt.map(|text| {
let mut blocks = vec![SystemBlock::Text(TextBlock {
@@ -1285,7 +1387,7 @@ impl Provider for BedrockProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credentials = self.resolve_credentials().await?;
let credentials = self.get_credentials().await?;
let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
@@ -1344,18 +1446,6 @@ impl Provider for BedrockProvider {
temperature: f64,
options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
let credentials = match self.require_credentials() {
Ok(c) => c,
Err(_) => {
return stream::once(async {
Err(StreamError::Provider(
"AWS Bedrock credentials not set".to_string(),
))
})
.boxed();
}
};
let system = system_prompt.map(|text| {
let mut blocks = vec![SystemBlock::Text(TextBlock {
text: text.to_string(),
@@ -1381,13 +1471,7 @@ impl Provider for BedrockProvider {
tool_config: None,
};
// Clone what we need for the async block
let credentials = AwsCredentials {
access_key_id: credentials.access_key_id.clone(),
secret_access_key: credentials.secret_access_key.clone(),
session_token: credentials.session_token.clone(),
region: credentials.region.clone(),
};
let cred_cache = self.credentials.inner.clone();
let model = model.to_string();
let count_tokens = options.count_tokens;
let client = self.http_client();
@@ -1397,6 +1481,21 @@ impl Provider for BedrockProvider {
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
tokio::spawn(async move {
// Resolve credentials inside the async context so we get
// TTL-validated, auto-refreshing credentials (not stale sync cache).
let cred_handle = CachedCredentials { inner: cred_cache };
let credentials = match cred_handle.get().await {
Ok(c) => c,
Err(e) => {
let _ = tx
.send(Err(StreamError::Provider(format!(
"AWS Bedrock credentials not available: {e}"
))))
.await;
return;
}
};
let payload = match serde_json::to_vec(&request) {
Ok(p) => p,
Err(e) => {
@@ -1530,7 +1629,7 @@ impl Provider for BedrockProvider {
}
async fn warmup(&self) -> anyhow::Result<()> {
if let Some(ref creds) = self.credentials {
if let Ok(creds) = self.get_credentials().await {
let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region);
let _ = self.http_client().get(&url).send().await;
}
@@ -1696,7 +1795,7 @@ mod tests {
#[tokio::test]
async fn chat_fails_without_credentials() {
let provider = BedrockProvider { credentials: None };
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
let result = provider
.chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7)
.await;
@@ -1992,14 +2091,14 @@ mod tests {
#[tokio::test]
async fn warmup_without_credentials_is_noop() {
let provider = BedrockProvider { credentials: None };
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
let result = provider.warmup().await;
assert!(result.is_ok());
}
#[test]
fn capabilities_reports_native_tool_calling() {
let provider = BedrockProvider { credentials: None };
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
let caps = provider.capabilities();
assert!(caps.native_tool_calling);
}
@@ -2053,7 +2152,7 @@ mod tests {
#[test]
fn supports_streaming_returns_true() {
let provider = BedrockProvider { credentials: None };
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
assert!(provider.supports_streaming());
}
+150 -11
View File
@@ -39,7 +39,8 @@ pub mod traits;
#[allow(unused_imports)]
pub use traits::{
ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilityError,
ToolCall, ToolResultMessage,
is_user_or_assistant_role, ToolCall, ToolResultMessage, ROLE_ASSISTANT, ROLE_SYSTEM, ROLE_TOOL,
ROLE_USER,
};
use crate::auth::AuthService;
@@ -1511,21 +1512,52 @@ pub fn create_routed_provider_with_options(
);
}
// Keep a default provider for non-routed model hints.
let default_provider = create_resilient_provider_with_options(
let default_hint = default_model
.strip_prefix("hint:")
.map(str::trim)
.filter(|hint| !hint.is_empty());
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
let mut has_primary_provider = false;
// Keep a default provider for non-routed requests. When default_model is a hint,
// route-specific providers can satisfy startup even if the primary fails.
match create_resilient_provider_with_options(
primary_name,
api_key,
api_url,
reliability,
options,
)?;
let mut providers: Vec<(String, Box<dyn Provider>)> =
vec![(primary_name.to_string(), default_provider)];
) {
Ok(default_provider) => {
providers.push((primary_name.to_string(), default_provider));
has_primary_provider = true;
}
Err(error) => {
if default_hint.is_some() {
tracing::warn!(
provider = primary_name,
model = default_model,
"Primary provider failed during routed init; continuing with hint-based routes: {error}"
);
} else {
return Err(error);
}
}
}
// Build hint routes with dedicated provider instances so per-route API keys
// and max_tokens overrides do not bleed across routes.
let mut routes: Vec<(String, router::Route)> = Vec::new();
for route in model_routes {
let route_hint = route.hint.trim();
if route_hint.is_empty() {
tracing::warn!(
provider = route.provider.as_str(),
"Ignoring routed provider with empty hint"
);
continue;
}
let routed_credential = route.api_key.as_ref().and_then(|raw_key| {
let trimmed_key = raw_key.trim();
(!trimmed_key.is_empty()).then_some(trimmed_key)
@@ -1554,10 +1586,10 @@ pub fn create_routed_provider_with_options(
&route_options,
) {
Ok(provider) => {
let provider_id = format!("{}#{}", route.provider, route.hint);
let provider_id = format!("{}#{}", route.provider, route_hint);
providers.push((provider_id.clone(), provider));
routes.push((
route.hint.clone(),
route_hint.to_string(),
router::Route {
provider_name: provider_id,
model: route.model.clone(),
@@ -1567,19 +1599,42 @@ pub fn create_routed_provider_with_options(
Err(error) => {
tracing::warn!(
provider = route.provider.as_str(),
hint = route.hint.as_str(),
hint = route_hint,
"Ignoring routed provider that failed to initialize: {error}"
);
}
}
}
if let Some(hint) = default_hint {
if !routes
.iter()
.any(|(route_hint, _)| route_hint.trim() == hint)
{
anyhow::bail!(
"default_model uses hint '{hint}', but no matching [[model_routes]] entry initialized successfully"
);
}
}
if providers.is_empty() {
anyhow::bail!("No providers initialized for routed configuration");
}
// Keep only successfully initialized routed providers and preserve
// their provider-id bindings (e.g. "<provider>#<hint>").
Ok(Box::new(
router::RouterProvider::new(providers, routes, default_model.to_string())
.with_vision_override(options.model_support_vision),
router::RouterProvider::new(
providers,
routes,
if has_primary_provider {
String::new()
} else {
default_model.to_string()
},
)
.with_vision_override(options.model_support_vision),
))
}
@@ -3124,6 +3179,90 @@ mod tests {
assert!(provider.is_ok());
}
#[test]
fn routed_provider_supports_hint_default_when_primary_init_fails() {
let reliability = crate::config::ReliabilityConfig::default();
let routes = vec![crate::config::ModelRouteConfig {
hint: "reasoning".to_string(),
provider: "lmstudio".to_string(),
model: "qwen2.5-coder".to_string(),
max_tokens: None,
api_key: None,
transport: None,
}];
let provider = create_routed_provider_with_options(
"provider-that-does-not-exist",
None,
None,
&reliability,
&routes,
"hint:reasoning",
&ProviderRuntimeOptions::default(),
);
assert!(
provider.is_ok(),
"hint default should allow startup from route providers"
);
}
#[test]
fn routed_provider_normalizes_whitespace_in_hint_routes() {
let reliability = crate::config::ReliabilityConfig::default();
let routes = vec![crate::config::ModelRouteConfig {
hint: " reasoning ".to_string(),
provider: "lmstudio".to_string(),
model: "qwen2.5-coder".to_string(),
max_tokens: None,
api_key: None,
transport: None,
}];
let provider = create_routed_provider_with_options(
"provider-that-does-not-exist",
None,
None,
&reliability,
&routes,
"hint: reasoning ",
&ProviderRuntimeOptions::default(),
);
assert!(
provider.is_ok(),
"trimmed default hint should match trimmed route hint"
);
}
#[test]
fn routed_provider_rejects_unresolved_hint_default() {
let reliability = crate::config::ReliabilityConfig::default();
let routes = vec![crate::config::ModelRouteConfig {
hint: "fast".to_string(),
provider: "lmstudio".to_string(),
model: "qwen2.5-coder".to_string(),
max_tokens: None,
api_key: None,
transport: None,
}];
let err = match create_routed_provider_with_options(
"provider-that-does-not-exist",
None,
None,
&reliability,
&routes,
"hint:reasoning",
&ProviderRuntimeOptions::default(),
) {
Ok(_) => panic!("missing default hint route should fail initialization"),
Err(err) => err,
};
assert!(err
.to_string()
.contains("default_model uses hint 'reasoning'"));
}
// --- parse_provider_profile ---
#[test]
+42 -5
View File
@@ -48,12 +48,17 @@ impl RouterProvider {
let resolved_routes: HashMap<String, (usize, String)> = routes
.into_iter()
.filter_map(|(hint, route)| {
let normalized_hint = hint.trim();
if normalized_hint.is_empty() {
tracing::warn!("Route hint is empty after trimming, skipping");
return None;
}
let index = name_to_index.get(route.provider_name.as_str()).copied();
match index {
Some(i) => Some((hint, (i, route.model))),
Some(i) => Some((normalized_hint.to_string(), (i, route.model))),
None => {
tracing::warn!(
hint = hint,
hint = normalized_hint,
provider = route.provider_name,
"Route references unknown provider, skipping"
);
@@ -63,10 +68,17 @@ impl RouterProvider {
})
.collect();
let default_index = default_model
.strip_prefix("hint:")
.map(str::trim)
.filter(|hint| !hint.is_empty())
.and_then(|hint| resolved_routes.get(hint).map(|(idx, _)| *idx))
.unwrap_or(0);
Self {
routes: resolved_routes,
providers,
default_index: 0,
default_index,
default_model,
vision_override: None,
}
@@ -85,11 +97,12 @@ impl RouterProvider {
/// Resolve a model parameter to a (provider_index, actual_model) pair.
fn resolve(&self, model: &str) -> (usize, String) {
if let Some(hint) = model.strip_prefix("hint:") {
if let Some((idx, resolved_model)) = self.routes.get(hint) {
let normalized_hint = hint.trim();
if let Some((idx, resolved_model)) = self.routes.get(normalized_hint) {
return (*idx, resolved_model.clone());
}
tracing::warn!(
hint = hint,
hint = normalized_hint,
"Unknown route hint, falling back to default provider"
);
}
@@ -375,6 +388,30 @@ mod tests {
assert_eq!(model, "claude-opus");
}
#[test]
fn resolve_trims_whitespace_in_hint_reference() {
let (router, _) = make_router(
vec![("fast", "ok"), ("smart", "ok")],
vec![("reasoning", "smart", "claude-opus")],
);
let (idx, model) = router.resolve("hint: reasoning ");
assert_eq!(idx, 1);
assert_eq!(model, "claude-opus");
}
#[test]
fn resolve_matches_routes_with_whitespace_hint_config() {
let (router, _) = make_router(
vec![("fast", "ok"), ("smart", "ok")],
vec![(" reasoning ", "smart", "claude-opus")],
);
let (idx, model) = router.resolve("hint:reasoning");
assert_eq!(idx, 1);
assert_eq!(model, "claude-opus");
}
#[test]
fn skips_routes_with_unknown_provider() {
let (router, _) = make_router(
+13 -4
View File
@@ -11,31 +11,40 @@ pub struct ChatMessage {
pub content: String,
}
/// Canonical chat role strings shared by every provider.
pub const ROLE_SYSTEM: &str = "system";
pub const ROLE_USER: &str = "user";
pub const ROLE_ASSISTANT: &str = "assistant";
pub const ROLE_TOOL: &str = "tool";

/// True for the two conversational roles (`user` / `assistant`);
/// false for `system`, `tool`, and anything else.
pub fn is_user_or_assistant_role(role: &str) -> bool {
    matches!(role, ROLE_USER | ROLE_ASSISTANT)
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: "system".into(),
role: ROLE_SYSTEM.into(),
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: "user".into(),
role: ROLE_USER.into(),
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: "assistant".into(),
role: ROLE_ASSISTANT.into(),
content: content.into(),
}
}
pub fn tool(content: impl Into<String>) -> Self {
Self {
role: "tool".into(),
role: ROLE_TOOL.into(),
content: content.into(),
}
}
+7 -7
View File
@@ -369,7 +369,7 @@ fn shannon_entropy(bytes: &[u8]) -> f64 {
.iter()
.filter(|&&count| count > 0)
.map(|&count| {
let p = count as f64 / len;
let p = f64::from(count) / len;
-p * p.log2()
})
.sum()
@@ -396,7 +396,7 @@ mod tests {
assert!(patterns.iter().any(|p| p.contains("Stripe")));
assert!(redacted.contains("[REDACTED"));
}
_ => panic!("Should detect Stripe key"),
LeakResult::Clean => panic!("Should detect Stripe key"),
}
}
@@ -409,7 +409,7 @@ mod tests {
LeakResult::Detected { patterns, .. } => {
assert!(patterns.iter().any(|p| p.contains("AWS")));
}
_ => panic!("Should detect AWS key"),
LeakResult::Clean => panic!("Should detect AWS key"),
}
}
@@ -427,7 +427,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
assert!(patterns.iter().any(|p| p.contains("private key")));
assert!(redacted.contains("[REDACTED_PRIVATE_KEY]"));
}
_ => panic!("Should detect private key"),
LeakResult::Clean => panic!("Should detect private key"),
}
}
@@ -441,7 +441,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
assert!(patterns.iter().any(|p| p.contains("JWT")));
assert!(redacted.contains("[REDACTED_JWT]"));
}
_ => panic!("Should detect JWT"),
LeakResult::Clean => panic!("Should detect JWT"),
}
}
@@ -454,7 +454,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
LeakResult::Detected { patterns, .. } => {
assert!(patterns.iter().any(|p| p.contains("PostgreSQL")));
}
_ => panic!("Should detect database URL"),
LeakResult::Clean => panic!("Should detect database URL"),
}
}
@@ -514,7 +514,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
assert!(patterns.iter().any(|p| p.contains("High-entropy token")));
assert!(redacted.contains("[REDACTED_HIGH_ENTROPY_TOKEN]"));
}
_ => panic!("expected high-entropy detection"),
LeakResult::Clean => panic!("expected high-entropy detection"),
}
}
+2 -1
View File
@@ -61,7 +61,8 @@ fn char_class_perplexity(prefix: &str, suffix: &str) -> f64 {
let class = classify_char(ch);
if let Some(p) = suffix_prev {
let numerator = f64::from(transition[p][class] + 1);
let denominator = f64::from(row_totals[p] + CLASS_COUNT as u32);
let class_count_u32 = u32::try_from(CLASS_COUNT).unwrap_or(u32::MAX);
let denominator = f64::from(row_totals[p] + class_count_u32);
nll += -(numerator / denominator).ln();
pairs += 1;
}
+3 -2
View File
@@ -457,7 +457,8 @@ fn load_skill_md(path: &Path, dir: &Path) -> Result<Skill> {
if let Ok(raw) = std::fs::read(&meta_path) {
if let Ok(meta) = serde_json::from_slice::<serde_json::Value>(&raw) {
if let Some(slug) = meta.get("slug").and_then(|v| v.as_str()) {
let normalized = normalize_skill_name(slug.split('/').last().unwrap_or(slug));
let normalized =
normalize_skill_name(slug.split('/').next_back().unwrap_or(slug));
if !normalized.is_empty() {
name = normalized;
}
@@ -1714,7 +1715,7 @@ fn extract_zip_skill_meta(
f.read_to_end(&mut buf).ok();
if let Ok(meta) = serde_json::from_slice::<serde_json::Value>(&buf) {
let slug_raw = meta.get("slug").and_then(|v| v.as_str()).unwrap_or("");
let base = slug_raw.split('/').last().unwrap_or(slug_raw);
let base = slug_raw.split('/').next_back().unwrap_or(slug_raw);
let name = normalize_skill_name(base);
if !name.is_empty() {
let version = meta
+140 -1
View File
@@ -11,6 +11,8 @@ pub struct CronAddTool {
security: Arc<SecurityPolicy>,
}
const MIN_AGENT_EVERY_MS: u64 = 5 * 60 * 1000;
impl CronAddTool {
pub fn new(config: Arc<Config>, security: Arc<SecurityPolicy>) -> Self {
Self { config, security }
@@ -56,6 +58,8 @@ impl Tool for CronAddTool {
fn description(&self) -> &str {
"Create a scheduled cron job (shell or agent) with cron/at/every schedules. \
Use job_type='agent' with a prompt to run the AI agent on schedule. \
Use schedule.kind='at' for one-time reminders/delayed sends (recommended). \
Agent jobs with schedule.kind='cron' or schedule.kind='every' are recurring and require explicit recurring confirmation. \
To deliver output to a channel (Discord, Telegram, Slack, Mattermost, QQ, Napcat, Lark, Feishu, Email), set \
delivery={\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id_or_chat_id>\"}. \
This is the preferred tool for sending scheduled/delayed messages to users via channels."
@@ -68,13 +72,18 @@ impl Tool for CronAddTool {
"name": { "type": "string" },
"schedule": {
"type": "object",
"description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}"
"description": "Schedule object: {kind:'cron',expr,tz?} recurring | {kind:'at',at} one-time | {kind:'every',every_ms} recurring interval"
},
"job_type": { "type": "string", "enum": ["shell", "agent"] },
"command": { "type": "string" },
"prompt": { "type": "string" },
"session_target": { "type": "string", "enum": ["isolated", "main"] },
"model": { "type": "string" },
"recurring_confirmed": {
"type": "boolean",
"description": "Required for agent recurring schedules (schedule.kind='cron' or 'every'). Set true only when recurring behavior is intentional.",
"default": false
},
"delivery": {
"type": "object",
"description": "Delivery config to send job output to a channel. Example: {\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id>\"}",
@@ -216,6 +225,49 @@ impl Tool for CronAddTool {
.get("model")
.and_then(serde_json::Value::as_str)
.map(str::to_string);
let recurring_confirmed = args
.get("recurring_confirmed")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false);
match &schedule {
Schedule::Every { every_ms } => {
if !recurring_confirmed {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Agent jobs with recurring schedules require recurring_confirmed=true. \
For one-time reminders, use schedule.kind='at' with an RFC3339 timestamp."
.to_string(),
),
});
}
if *every_ms < MIN_AGENT_EVERY_MS {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Agent schedule.kind='every' must be >= {MIN_AGENT_EVERY_MS} ms (5 minutes)"
)),
});
}
}
Schedule::Cron { .. } => {
if !recurring_confirmed {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Agent jobs with recurring schedules require recurring_confirmed=true. \
For one-time reminders, use schedule.kind='at' with an RFC3339 timestamp."
.to_string(),
),
});
}
}
Schedule::At { .. } => {}
}
let delivery = match args.get("delivery") {
Some(v) => match serde_json::from_value::<DeliveryConfig>(v.clone()) {
@@ -482,4 +534,91 @@ mod tests {
.unwrap_or_default()
.contains("Missing 'prompt'"));
}
#[tokio::test]
async fn agent_every_requires_recurring_confirmation() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": { "kind": "every", "every_ms": 300000 },
"job_type": "agent",
"prompt": "Send me a recurring status update"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.unwrap_or_default()
.contains("recurring_confirmed=true"));
}
#[tokio::test]
async fn agent_cron_requires_recurring_confirmation() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": { "kind": "cron", "expr": "*/5 * * * *" },
"job_type": "agent",
"prompt": "Send recurring reminders"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.unwrap_or_default()
.contains("recurring_confirmed=true"));
}
#[tokio::test]
async fn agent_every_rejects_high_frequency_intervals() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": { "kind": "every", "every_ms": 60000 },
"job_type": "agent",
"prompt": "Send me updates frequently",
"recurring_confirmed": true
}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.unwrap_or_default()
.contains("must be >= 300000 ms"));
}
#[tokio::test]
async fn agent_every_with_explicit_confirmation_succeeds() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": { "kind": "every", "every_ms": 300000 },
"job_type": "agent",
"prompt": "Share a heartbeat summary",
"recurring_confirmed": true
}))
.await
.unwrap();
assert!(result.success, "{:?}", result.error);
assert!(result.output.contains("next_run"));
}
}
+2 -1
View File
@@ -116,7 +116,8 @@ impl Tool for CronRunTool {
}
let started_at = Utc::now();
let (success, output) = cron::scheduler::execute_job_now(&self.config, &job).await;
let (success, output) =
Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await;
let finished_at = Utc::now();
let duration_ms = (finished_at - started_at).num_milliseconds();
let status = if success { "ok" } else { "error" };
+1 -1
View File
@@ -803,7 +803,7 @@ mod tests {
"coder".to_string(),
DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "anthropic/claude-sonnet-4-20250514".to_string(),
model: crate::config::DEFAULT_MODEL_FALLBACK.to_string(),
system_prompt: None,
api_key: Some("delegate-test-credential".to_string()),
temperature: None,
+4 -4
View File
@@ -301,11 +301,11 @@ mod tests {
name: "nonexistent".to_string(),
command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
args: vec![],
env: Default::default(),
env: std::collections::HashMap::default(),
tool_timeout_secs: None,
transport: McpTransport::Stdio,
url: None,
headers: Default::default(),
headers: std::collections::HashMap::default(),
};
let result = McpServer::connect(config).await;
assert!(result.is_err());
@@ -320,11 +320,11 @@ mod tests {
name: "bad".to_string(),
command: "/usr/bin/does_not_exist_zc_test".to_string(),
args: vec![],
env: Default::default(),
env: std::collections::HashMap::default(),
tool_timeout_secs: None,
transport: McpTransport::Stdio,
url: None,
headers: Default::default(),
headers: std::collections::HashMap::default(),
}];
let registry = McpRegistry::connect_all(&configs)
.await
+617 -27
View File
@@ -1,12 +1,16 @@
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
use std::borrow::Cow;
use anyhow::{anyhow, bail, Context, Result};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{oneshot, Mutex, Notify};
use tokio::time::{timeout, Duration};
use tokio_stream::StreamExt;
use crate::config::schema::{McpServerConfig, McpTransport};
use crate::tools::mcp_protocol::{JsonRpcRequest, JsonRpcResponse};
use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR};
/// Maximum bytes for a single JSON-RPC response.
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
@@ -95,6 +99,14 @@ impl McpTransportConn for StdioTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
let line = serde_json::to_string(request)?;
self.send_raw(&line).await?;
if request.id.is_none() {
return Ok(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: None,
});
}
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
.await
.context("timeout waiting for MCP response")??;
@@ -158,6 +170,15 @@ impl McpTransportConn for HttpTransport {
bail!("MCP server returned HTTP {}", resp.status());
}
if request.id.is_none() {
return Ok(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: None,
});
}
let resp_text = resp.text().await.context("failed to read HTTP response")?;
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
@@ -173,69 +194,602 @@ impl McpTransportConn for HttpTransport {
// ── SSE Transport ─────────────────────────────────────────────────────────
/// SSE-based transport (HTTP POST for requests, SSE for responses).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SseStreamState {
Unknown,
Connected,
Unsupported,
}
pub struct SseTransport {
base_url: String,
sse_url: String,
server_name: String,
client: reqwest::Client,
headers: std::collections::HashMap<String, String>,
#[allow(dead_code)]
event_source: Option<tokio::task::JoinHandle<()>>,
stream_state: SseStreamState,
shared: std::sync::Arc<Mutex<SseSharedState>>,
notify: std::sync::Arc<Notify>,
shutdown_tx: Option<oneshot::Sender<()>>,
reader_task: Option<tokio::task::JoinHandle<()>>,
}
impl SseTransport {
pub fn new(config: &McpServerConfig) -> Result<Self> {
let base_url = config
let sse_url = config
.url
.as_ref()
.ok_or_else(|| anyhow!("URL required for SSE transport"))?
.clone();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(120))
.build()
.context("failed to build HTTP client")?;
Ok(Self {
base_url,
sse_url,
server_name: config.name.clone(),
client,
headers: config.headers.clone(),
event_source: None,
stream_state: SseStreamState::Unknown,
shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
notify: std::sync::Arc::new(Notify::new()),
shutdown_tx: None,
reader_task: None,
})
}
async fn ensure_connected(&mut self) -> Result<()> {
if self.stream_state == SseStreamState::Unsupported {
return Ok(());
}
if let Some(task) = &self.reader_task {
if !task.is_finished() {
self.stream_state = SseStreamState::Connected;
return Ok(());
}
}
let mut req = self
.client
.get(&self.sse_url)
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache");
for (key, value) in &self.headers {
req = req.header(key, value);
}
let resp = req.send().await.context("SSE GET to MCP server failed")?;
if resp.status() == reqwest::StatusCode::NOT_FOUND
|| resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
{
self.stream_state = SseStreamState::Unsupported;
return Ok(());
}
if !resp.status().is_success() {
return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
}
let is_event_stream = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
if !is_event_stream {
self.stream_state = SseStreamState::Unsupported;
return Ok(());
}
let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
self.shutdown_tx = Some(shutdown_tx);
let shared = self.shared.clone();
let notify = self.notify.clone();
let sse_url = self.sse_url.clone();
let server_name = self.server_name.clone();
self.reader_task = Some(tokio::spawn(async move {
let stream = resp
.bytes_stream()
.map(|item| item.map_err(std::io::Error::other));
let reader = tokio_util::io::StreamReader::new(stream);
let mut lines = BufReader::new(reader).lines();
let mut cur_event: Option<String> = None;
let mut cur_id: Option<String> = None;
let mut cur_data: Vec<String> = Vec::new();
loop {
tokio::select! {
_ = &mut shutdown_rx => {
break;
}
line = lines.next_line() => {
let Ok(line_opt) = line else { break; };
let Some(mut line) = line_opt else { break; };
if line.ends_with('\r') {
line.pop();
}
if line.is_empty() {
if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
continue;
}
let event = cur_event.take();
let data = cur_data.join("\n");
cur_data.clear();
let id = cur_id.take();
handle_sse_event(&server_name, &sse_url, &shared, &notify, event.as_deref(), id.as_deref(), data).await;
continue;
}
if line.starts_with(':') {
continue;
}
if let Some(rest) = line.strip_prefix("event:") {
cur_event = Some(rest.trim().to_string());
}
if let Some(rest) = line.strip_prefix("data:") {
let rest = rest.strip_prefix(' ').unwrap_or(rest);
cur_data.push(rest.to_string());
}
if let Some(rest) = line.strip_prefix("id:") {
cur_id = Some(rest.trim().to_string());
}
}
}
}
let pending = {
let mut guard = shared.lock().await;
std::mem::take(&mut guard.pending)
};
for (_, tx) in pending {
let _ = tx.send(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: Some(JsonRpcError {
code: INTERNAL_ERROR,
message: "SSE connection closed".to_string(),
data: None,
}),
});
}
}));
self.stream_state = SseStreamState::Connected;
Ok(())
}
async fn get_message_url(&self) -> Result<(String, bool)> {
let guard = self.shared.lock().await;
if let Some(url) = &guard.message_url {
return Ok((url.clone(), guard.message_url_from_endpoint));
}
drop(guard);
let derived = derive_message_url(&self.sse_url, "messages")
.or_else(|| derive_message_url(&self.sse_url, "message"))
.ok_or_else(|| anyhow!("invalid SSE URL"))?;
let mut guard = self.shared.lock().await;
if guard.message_url.is_none() {
guard.message_url = Some(derived.clone());
guard.message_url_from_endpoint = false;
}
Ok((derived, false))
}
fn maybe_try_alternate_message_url(
&self,
current_url: &str,
from_endpoint: bool,
) -> Option<String> {
if from_endpoint {
return None;
}
let alt = if current_url.ends_with("/messages") {
derive_message_url(&self.sse_url, "message")
} else {
derive_message_url(&self.sse_url, "messages")
}?;
if alt == current_url {
return None;
}
Some(alt)
}
}
#[derive(Default)]
struct SseSharedState {
message_url: Option<String>,
message_url_from_endpoint: bool,
pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
}
fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
let url = reqwest::Url::parse(sse_url).ok()?;
let mut segments: Vec<&str> = url.path_segments()?.collect();
if segments.is_empty() {
return None;
}
if segments.last().copied() == Some("sse") {
segments.pop();
segments.push(message_path);
let mut new_url = url.clone();
new_url.set_path(&format!("/{}", segments.join("/")));
return Some(new_url.to_string());
}
let mut new_url = url.clone();
let mut path = url.path().trim_end_matches('/').to_string();
path.push('/');
path.push_str(message_path);
new_url.set_path(&path);
Some(new_url.to_string())
}
async fn handle_sse_event(
server_name: &str,
sse_url: &str,
shared: &std::sync::Arc<Mutex<SseSharedState>>,
notify: &std::sync::Arc<Notify>,
event: Option<&str>,
_id: Option<&str>,
data: String,
) {
let event = event.unwrap_or("message");
let trimmed = data.trim();
if trimmed.is_empty() {
return;
}
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
let mut guard = shared.lock().await;
guard.message_url = Some(url);
guard.message_url_from_endpoint = true;
drop(guard);
notify.notify_waiters();
}
return;
}
if !event.eq_ignore_ascii_case("message") {
return;
}
let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
return;
};
let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
let _ = serde_json::from_value::<JsonRpcRequest>(value);
return;
};
let Some(id_val) = resp.id.clone() else {
return;
};
let id = match id_val.as_u64() {
Some(v) => v,
None => return,
};
let tx = {
let mut guard = shared.lock().await;
guard.pending.remove(&id)
};
if let Some(tx) = tx {
let _ = tx.send(resp);
} else {
tracing::debug!(
"MCP SSE `{}` received response for unknown id {}",
server_name,
id
);
}
}
fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
if data.starts_with('{') {
let v: serde_json::Value = serde_json::from_str(data).ok()?;
let endpoint = v.get("endpoint")?.as_str()?;
return parse_endpoint_from_data(sse_url, endpoint);
}
if data.starts_with("http://") || data.starts_with("https://") {
return Some(data.to_string());
}
let base = reqwest::Url::parse(sse_url).ok()?;
base.join(data).ok().map(|u| u.to_string())
}
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
let text = resp_text.trim_start_matches('\u{feff}');
let mut current_data_lines: Vec<&str> = Vec::new();
let mut last_event_data_lines: Vec<&str> = Vec::new();
for raw_line in text.lines() {
let line = raw_line.trim_end_matches('\r').trim_start();
if line.is_empty() {
if !current_data_lines.is_empty() {
last_event_data_lines = std::mem::take(&mut current_data_lines);
}
continue;
}
if line.starts_with(':') {
continue;
}
if let Some(rest) = line.strip_prefix("data:") {
let rest = rest.strip_prefix(' ').unwrap_or(rest);
current_data_lines.push(rest);
}
}
if !current_data_lines.is_empty() {
last_event_data_lines = current_data_lines;
}
if last_event_data_lines.is_empty() {
return Cow::Borrowed(text.trim());
}
if last_event_data_lines.len() == 1 {
return Cow::Borrowed(last_event_data_lines[0].trim());
}
let joined = last_event_data_lines.join("\n");
Cow::Owned(joined.trim().to_string())
}
async fn read_first_jsonrpc_from_sse_response(
resp: reqwest::Response,
) -> Result<Option<JsonRpcResponse>> {
let stream = resp
.bytes_stream()
.map(|item| item.map_err(std::io::Error::other));
let reader = tokio_util::io::StreamReader::new(stream);
let mut lines = BufReader::new(reader).lines();
let mut cur_event: Option<String> = None;
let mut cur_data: Vec<String> = Vec::new();
while let Ok(line_opt) = lines.next_line().await {
let Some(mut line) = line_opt else { break };
if line.ends_with('\r') {
line.pop();
}
if line.is_empty() {
if cur_event.is_none() && cur_data.is_empty() {
continue;
}
let event = cur_event.take();
let data = cur_data.join("\n");
cur_data.clear();
let event = event.unwrap_or_else(|| "message".to_string());
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
{
continue;
}
if !event.eq_ignore_ascii_case("message") {
continue;
}
let trimmed = data.trim();
if trimmed.is_empty() {
continue;
}
let json_str = extract_json_from_sse_text(trimmed);
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
return Ok(Some(resp));
}
continue;
}
if line.starts_with(':') {
continue;
}
if let Some(rest) = line.strip_prefix("event:") {
cur_event = Some(rest.trim().to_string());
}
if let Some(rest) = line.strip_prefix("data:") {
let rest = rest.strip_prefix(' ').unwrap_or(rest);
cur_data.push(rest.to_string());
}
}
Ok(None)
}
#[async_trait::async_trait]
impl McpTransportConn for SseTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
// For SSE, we POST the request and the response comes via SSE stream.
// Simplified implementation: treat as HTTP for now, proper SSE would
// maintain a persistent event stream.
self.ensure_connected().await?;
let id = request.id.as_ref().and_then(|v| v.as_u64());
let body = serde_json::to_string(request)?;
let url = format!("{}/message", self.base_url.trim_end_matches('/'));
let mut req = self
.client
.post(&url)
.body(body)
.header("Content-Type", "application/json");
for (key, value) in &self.headers {
req = req.header(key, value);
let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
if self.stream_state == SseStreamState::Connected && !from_endpoint {
for _ in 0..3 {
{
let guard = self.shared.lock().await;
if guard.message_url_from_endpoint {
if let Some(url) = &guard.message_url {
message_url = url.clone();
from_endpoint = true;
break;
}
}
}
let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
}
}
let primary_url = if from_endpoint {
message_url.clone()
} else {
self.sse_url.clone()
};
let secondary_url = if message_url == self.sse_url {
None
} else if primary_url == message_url {
Some(self.sse_url.clone())
} else {
Some(message_url.clone())
};
let has_secondary = secondary_url.is_some();
let mut rx = None;
if let Some(id) = id {
if self.stream_state == SseStreamState::Connected {
let (tx, ch) = oneshot::channel();
{
let mut guard = self.shared.lock().await;
guard.pending.insert(id, tx);
}
rx = Some((id, ch));
}
}
let resp = req.send().await.context("SSE POST to MCP server failed")?;
let mut got_direct = None;
let mut last_status = None;
if !resp.status().is_success() {
bail!("MCP server returned HTTP {}", resp.status());
for (i, url) in std::iter::once(primary_url)
.chain(secondary_url.into_iter())
.enumerate()
{
let mut req = self
.client
.post(&url)
.timeout(Duration::from_secs(120))
.body(body.clone())
.header("Content-Type", "application/json");
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"))
{
req = req.header("Accept", "application/json, text/event-stream");
}
let resp = req.send().await.context("SSE POST to MCP server failed")?;
let status = resp.status();
last_status = Some(status);
if (status == reqwest::StatusCode::NOT_FOUND
|| status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
&& i == 0
{
continue;
}
if !status.is_success() {
break;
}
if request.id.is_none() {
got_direct = Some(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: None,
});
break;
}
let is_sse = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
if is_sse {
if i == 0 && has_secondary {
match timeout(
Duration::from_secs(3),
read_first_jsonrpc_from_sse_response(resp),
)
.await
{
Ok(res) => {
if let Some(resp) = res? {
got_direct = Some(resp);
}
break;
}
Err(_) => continue,
}
}
if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
got_direct = Some(resp);
}
break;
}
let text = if i == 0 && has_secondary {
match timeout(Duration::from_secs(3), resp.text()).await {
Ok(Ok(t)) => t,
Ok(Err(_)) => String::new(),
Err(_) => continue,
}
} else {
resp.text().await.unwrap_or_default()
};
let trimmed = text.trim();
if !trimmed.is_empty() {
let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
extract_json_from_sse_text(trimmed)
} else {
Cow::Borrowed(trimmed)
};
if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
got_direct = Some(mcp_resp);
}
}
break;
}
// For now, parse response directly. Full SSE would read from event stream.
let resp_text = resp.text().await.context("failed to read SSE response")?;
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
if let Some((id, _)) = rx.as_ref() {
if got_direct.is_some() {
let mut guard = self.shared.lock().await;
guard.pending.remove(id);
} else if let Some(status) = last_status {
if !status.is_success() {
let mut guard = self.shared.lock().await;
guard.pending.remove(id);
}
}
}
Ok(mcp_resp)
if let Some(resp) = got_direct {
return Ok(resp);
}
if let Some(status) = last_status {
if !status.is_success() {
bail!("MCP server returned HTTP {}", status);
}
} else {
bail!("MCP request not sent");
}
let Some((_id, rx)) = rx else {
bail!("MCP server returned no response");
};
rx.await.map_err(|_| anyhow!("SSE response channel closed"))
}
async fn close(&mut self) -> Result<()> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(task) = self.reader_task.take() {
task.abort();
}
Ok(())
}
}
@@ -282,4 +836,40 @@ mod tests {
};
assert!(SseTransport::new(&config).is_err());
}
#[test]
fn test_extract_json_from_sse_data_no_space() {
let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
#[test]
fn test_extract_json_from_sse_with_event_and_id() {
let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
#[test]
fn test_extract_json_from_sse_multiline_data() {
let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n";
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
#[test]
fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
#[test]
fn test_extract_json_from_sse_uses_last_event_with_data() {
let input =
": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
}
+5 -6
View File
@@ -950,8 +950,9 @@ impl Tool for ModelRoutingConfigTool {
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
use std::sync::{Mutex, OnceLock};
use std::sync::OnceLock;
use tempfile::TempDir;
use tokio::sync::Mutex;
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
@@ -995,11 +996,9 @@ mod tests {
}
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
async fn env_lock() -> tokio::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
LOCK.get_or_init(|| Mutex::new(())).lock().await
}
async fn test_config(tmp: &TempDir) -> Arc<Config> {
@@ -1218,7 +1217,7 @@ mod tests {
#[tokio::test]
async fn get_reports_env_backed_credentials_for_routes_and_agents() {
let _env_lock = env_lock();
let _env_lock = env_lock().await;
let _provider_guard = EnvGuard::set("TELNYX_API_KEY", Some("test-telnyx-key"));
let _generic_guard = EnvGuard::set("ZEROCLAW_API_KEY", None);
let _api_key_guard = EnvGuard::set("API_KEY", None);
+84 -39
View File
@@ -10,11 +10,17 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Canonical provider list for error messages and the tool description.
/// `fast_html2md` is kept as a deprecated alias for `nanohtml2text`.
const WEB_FETCH_PROVIDER_HELP: &str =
"Supported providers: 'nanohtml2text' (default), 'firecrawl', 'tavily'. \
Deprecated alias: 'fast_html2md' (maps to 'nanohtml2text').";
/// Web fetch tool: fetches a web page and returns text/markdown content for LLM consumption.
///
/// Providers:
/// - `fast_html2md`: fetch with reqwest, convert HTML to markdown
/// - `nanohtml2text`: fetch with reqwest, convert HTML to plaintext
/// - `nanohtml2text` (default): fetch with reqwest, strip noise elements, convert HTML to plaintext
/// - `fast_html2md` (deprecated alias): same as nanohtml2text unless `web-fetch-html2md` feature is compiled in
/// - `firecrawl`: fetch using Firecrawl cloud/self-hosted API
/// - `tavily`: fetch using Tavily Extract API
pub struct WebFetchTool {
@@ -33,6 +39,7 @@ pub struct WebFetchTool {
impl WebFetchTool {
#[allow(clippy::too_many_arguments)]
/// Creates a new `WebFetchTool`. `api_key` accepts comma-separated values for round-robin rotation.
pub fn new(
security: Arc<SecurityPolicy>,
provider: String,
@@ -59,7 +66,7 @@ impl WebFetchTool {
Self {
security,
provider: if provider.is_empty() {
"fast_html2md".to_string()
"nanohtml2text".to_string()
} else {
provider
},
@@ -75,6 +82,7 @@ impl WebFetchTool {
}
}
/// Returns the next API key from the rotation pool using round-robin, or `None` if unconfigured.
fn get_next_api_key(&self) -> Option<String> {
if self.api_keys.is_empty() {
return None;
@@ -83,6 +91,7 @@ impl WebFetchTool {
Some(self.api_keys[idx].clone())
}
/// Validates and normalises a URL against the allowlist, blocklist, and SSRF policy.
fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
validate_url(
raw_url,
@@ -99,6 +108,7 @@ impl WebFetchTool {
)
}
/// Truncates text to `max_response_size` characters and appends a marker if trimmed.
fn truncate_response(&self, text: &str) -> String {
if text.len() > self.max_response_size {
let mut truncated = text
@@ -112,6 +122,7 @@ impl WebFetchTool {
}
}
/// Returns the configured timeout, substituting a safe 30 s default if zero is set.
fn effective_timeout_secs(&self) -> u64 {
if self.timeout_secs == 0 {
tracing::warn!("web_fetch: timeout_secs is 0, using safe default of 30s");
@@ -121,40 +132,64 @@ impl WebFetchTool {
}
}
#[allow(unused_variables)]
/// Strips noisy structural HTML elements (nav, scripts, footers, etc.) before text
/// extraction to reduce boilerplate in the LLM output.
fn strip_noise_elements(html: &str) -> anyhow::Result<String> {
// Rust regex does not support backreferences, so run one pass per tag.
// OnceLock stores Result<_, String> so that a compile failure is surfaced as an
// error rather than a panic. String is used instead of anyhow::Error because it
// is Clone + Sync, which OnceLock requires.
use std::sync::OnceLock;
static NOISE_RES: OnceLock<Result<Vec<regex::Regex>, String>> = OnceLock::new();
let regexes = NOISE_RES
.get_or_init(|| {
[
"script", "style", "nav", "header", "footer", "aside", "noscript", "form",
"button",
]
.iter()
.map(|tag| {
regex::Regex::new(&format!(r"(?si)<{tag}[^>]*>.*?</{tag}>"))
.map_err(|e| e.to_string())
})
.collect::<Result<Vec<_>, _>>()
})
.as_ref()
.map_err(|e| anyhow::anyhow!("noise regex init failed: {e}"))?;
let mut result = html.to_string();
for re in regexes {
result = re.replace_all(&result, " ").into_owned();
}
Ok(result)
}
/// Strips noise elements then converts HTML to plain text using the configured provider.
/// `fast_html2md` is a deprecated alias that maps to `nanohtml2text` when the
/// `web-fetch-html2md` feature is not compiled in.
fn convert_html_to_output(&self, body: &str) -> anyhow::Result<String> {
let cleaned = Self::strip_noise_elements(body)?;
match self.provider.as_str() {
"fast_html2md" => {
#[cfg(feature = "web-fetch-html2md")]
{
Ok(html2md::rewrite_html(body, false))
Ok(html2md::rewrite_html(&cleaned, false))
}
#[cfg(not(feature = "web-fetch-html2md"))]
{
anyhow::bail!(
"web_fetch provider 'fast_html2md' requires Cargo feature 'web-fetch-html2md'"
);
}
}
"nanohtml2text" => {
#[cfg(feature = "web-fetch-plaintext")]
{
Ok(nanohtml2text::html2text(body))
}
#[cfg(not(feature = "web-fetch-plaintext"))]
{
anyhow::bail!(
"web_fetch provider 'nanohtml2text' requires Cargo feature 'web-fetch-plaintext'"
);
// Feature not compiled in; fall through to nanohtml2text.
Ok(nanohtml2text::html2text(&cleaned))
}
}
"nanohtml2text" => Ok(nanohtml2text::html2text(&cleaned)),
_ => anyhow::bail!(
"Unknown web_fetch provider: '{}'. Set [web_fetch].provider to 'fast_html2md', 'nanohtml2text', 'firecrawl', or 'tavily' in config.toml",
self.provider
"Unknown web_fetch provider: '{}'. {}",
self.provider,
WEB_FETCH_PROVIDER_HELP
),
}
}
/// Builds a `reqwest::Client` with the configured timeout, user-agent, and proxy settings.
fn build_http_client(&self) -> anyhow::Result<reqwest::Client> {
let builder = reqwest::Client::builder()
.timeout(Duration::from_secs(self.effective_timeout_secs()))
@@ -165,6 +200,8 @@ impl WebFetchTool {
Ok(builder.build()?)
}
/// Fetches `url` with reqwest, handles one redirect (re-validated), and converts the
/// response body to text via the configured HTML provider.
async fn fetch_with_http_provider(&self, url: &str) -> anyhow::Result<String> {
let client = self.build_http_client()?;
let response = client.get(url).send().await?;
@@ -221,6 +258,7 @@ impl WebFetchTool {
)
}
/// Fetches `url` via the Firecrawl scrape API and returns the extracted markdown content.
#[cfg(feature = "firecrawl")]
async fn fetch_with_firecrawl(&self, url: &str) -> anyhow::Result<String> {
let auth_token = self.get_next_api_key().ok_or_else(|| {
@@ -301,6 +339,7 @@ impl WebFetchTool {
anyhow::bail!("web_fetch provider 'firecrawl' requires Cargo feature 'firecrawl'")
}
/// Fetches `url` via the Tavily Extract API and returns the raw extracted content.
async fn fetch_with_tavily(&self, url: &str) -> anyhow::Result<String> {
let api_key = self.get_next_api_key().ok_or_else(|| {
anyhow::anyhow!(
@@ -374,7 +413,7 @@ impl Tool for WebFetchTool {
}
fn description(&self) -> &str {
"Fetch a web page and return markdown/text content for LLM consumption. Providers: fast_html2md, nanohtml2text, firecrawl, tavily. Security: allowlist-only domains, blocked_domains, and no local/private hosts."
"Fetch a web page and return text content for LLM consumption. Strips navigation, scripts, and boilerplate before extraction. Providers: nanohtml2text (default), firecrawl, tavily. Deprecated alias: fast_html2md. Security: allowlist-only domains, blocked_domains, and no local/private hosts."
}
fn parameters_schema(&self) -> serde_json::Value {
@@ -428,8 +467,9 @@ impl Tool for WebFetchTool {
"firecrawl" => self.fetch_with_firecrawl(&url).await,
"tavily" => self.fetch_with_tavily(&url).await,
_ => Err(anyhow::anyhow!(
"Unknown web_fetch provider: '{}'. Set [web_fetch].provider to 'fast_html2md', 'nanohtml2text', 'firecrawl', or 'tavily' in config.toml",
self.provider
"Unknown web_fetch provider: '{}'. {}",
self.provider,
WEB_FETCH_PROVIDER_HELP
)),
};
@@ -505,22 +545,12 @@ mod tests {
assert!(required.iter().any(|v| v.as_str() == Some("url")));
}
#[cfg(feature = "web-fetch-html2md")]
// Previously gated on cfg(feature = "web-fetch-html2md") / cfg(feature = "web-fetch-plaintext")
// — neither feature was declared in Cargo.toml so these tests never ran.
// Now always-on: fast_html2md falls back to nanohtml2text when uncompiled.
#[test]
fn html_to_markdown_conversion_preserves_structure() {
fn html_conversion_removes_tags() {
let tool = test_tool(vec!["example.com"]);
let html = "<html><body><h1>Title</h1><ul><li>Hello</li></ul></body></html>";
let markdown = tool.convert_html_to_output(html).unwrap();
assert!(markdown.contains("Title"));
assert!(markdown.contains("Hello"));
assert!(!markdown.contains("<h1>"));
}
#[cfg(feature = "web-fetch-plaintext")]
#[test]
fn html_to_plaintext_conversion_removes_html_tags() {
let tool =
test_tool_with_provider(vec!["example.com"], vec![], "nanohtml2text", None, None);
let html = "<html><body><h1>Title</h1><p>Hello <b>world</b></p></body></html>";
let text = tool.convert_html_to_output(html).unwrap();
assert!(text.contains("Title"));
@@ -528,6 +558,21 @@ mod tests {
assert!(!text.contains("<h1>"));
}
#[test]
// Verifies that convert_html_to_output drops <nav>/<script>/<footer> content
// while preserving the article body text.
fn strip_noise_removes_nav_scripts_footer() {
    let tool = test_tool(vec!["example.com"]);
    // `\` line continuations keep this a single string literal; leading
    // whitespace after each continuation is stripped by the Rust lexer.
    let html = "<html><body>\
        <nav><a>Home</a><a>Menu</a></nav>\
        <script>var x = 1;</script>\
        <article><p>Real content here</p></article>\
        <footer>Copyright 2025</footer>\
        </body></html>";
    let text = tool.convert_html_to_output(html).unwrap();
    assert!(text.contains("Real content"));
    assert!(!text.contains("var x"));
    assert!(!text.contains("Copyright 2025"));
}
#[test]
fn validate_accepts_exact_domain() {
let tool = test_tool(vec!["example.com"]);
+8
View File
@@ -0,0 +1,8 @@
[build]
command = "npm run build"
publish = "dist"
[[redirects]]
from = "/*"
to = "/index.html"
status = 200
+15 -1
View File
@@ -7,11 +7,13 @@
"": {
"name": "zeroclaw-web",
"version": "0.1.0",
"license": "(MIT OR Apache-2.0)",
"dependencies": {
"lucide-react": "^0.468.0",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-router-dom": "^7.1.1"
"react-router-dom": "^7.1.1",
"smol-toml": "^1.3.1"
},
"devDependencies": {
"@tailwindcss/vite": "^4.0.0",
@@ -2327,6 +2329,18 @@
"integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==",
"license": "MIT"
},
"node_modules/smol-toml": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/smol-toml/-/smol-toml-1.6.0.tgz",
"integrity": "sha512-4zemZi0HvTnYwLfrpk/CF9LOd9Lt87kAt50GnqhMpyF9U3poDAP2+iukq2bZsO/ufegbYehBkqINbsWxj4l4cw==",
"license": "BSD-3-Clause",
"engines": {
"node": ">= 18"
},
"funding": {
"url": "https://github.com/sponsors/cyyynthia"
}
},
"node_modules/source-map-js": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
+2 -1
View File
@@ -13,7 +13,8 @@
"lucide-react": "^0.468.0",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-router-dom": "^7.1.1"
"react-router-dom": "^7.1.1",
"smol-toml": "^1.3.1"
},
"devDependencies": {
"@tailwindcss/vite": "^4.0.0",
@@ -0,0 +1,121 @@
import { useState, useMemo } from 'react';
import { Search } from 'lucide-react';
import { CONFIG_SECTIONS } from './configSections';
import ConfigSection from './ConfigSection';
import type { FieldDef } from './types';
// Fixed display order for the category filter pills; 'all' bypasses filtering.
const CATEGORY_ORDER = [
  { key: 'all', label: 'All' },
  { key: 'general', label: 'General' },
  { key: 'security', label: 'Security' },
  { key: 'channels', label: 'Channels' },
  { key: 'runtime', label: 'Runtime' },
  { key: 'tools', label: 'Tools' },
  { key: 'memory', label: 'Memory' },
  { key: 'network', label: 'Network' },
  { key: 'advanced', label: 'Advanced' },
] as const;

// Field accessors injected by the parent config form state (useConfigForm).
interface Props {
  getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
  setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
  isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
}
/**
 * Structured ("form") view of the configuration.
 *
 * Renders a search box, category filter pills, and one collapsible
 * ConfigSection per matching section. While searching, category pills are
 * hidden and each matched section is forced open with only its matching
 * fields visible (fields === undefined means "show all fields").
 */
export default function ConfigFormEditor({
  getFieldValue,
  setFieldValue,
  isFieldMasked,
}: Props) {
  const [search, setSearch] = useState('');
  const [activeCategory, setActiveCategory] = useState('all');
  const isSearching = search.trim().length > 0;

  const filteredSections = useMemo(() => {
    if (isSearching) {
      const q = search.toLowerCase();
      return CONFIG_SECTIONS.map((section) => {
        // A section-level match shows the whole section.
        const titleMatch = section.title.toLowerCase().includes(q);
        const descMatch = section.description?.toLowerCase().includes(q);
        if (titleMatch || descMatch) {
          return { section, fields: undefined };
        }
        // Otherwise show only the fields that match.
        const matchingFields = section.fields.filter(
          (f: FieldDef) =>
            f.label.toLowerCase().includes(q) ||
            f.key.toLowerCase().includes(q) ||
            f.description?.toLowerCase().includes(q),
        );
        if (matchingFields.length > 0) {
          return { section, fields: matchingFields };
        }
        return null;
      }).filter(
        // Type predicate instead of `.filter(Boolean) as ...`: same runtime
        // behavior, but the narrowing is checked by the compiler rather than
        // asserted with an unchecked cast.
        (entry): entry is {
          section: (typeof CONFIG_SECTIONS)[number];
          fields: FieldDef[] | undefined;
        } => entry !== null,
      );
    }
    // Category filter
    const sections = activeCategory === 'all'
      ? CONFIG_SECTIONS
      : CONFIG_SECTIONS.filter((s) => s.category === activeCategory);
    return sections.map((s) => ({ section: s, fields: undefined }));
  }, [search, isSearching, activeCategory]);

  return (
    <div className="space-y-3">
      {/* Search */}
      <div className="relative">
        <Search className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" />
        <input
          type="text"
          value={search}
          onChange={(e) => setSearch(e.target.value)}
          placeholder="Search config fields..."
          className="w-full bg-gray-800 border border-gray-700 rounded-lg pl-9 pr-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
        />
      </div>

      {/* Category pills — hidden during search */}
      {!isSearching && (
        <div className="flex flex-wrap gap-2">
          {CATEGORY_ORDER.map(({ key, label }) => (
            <button
              key={key}
              onClick={() => setActiveCategory(key)}
              className={`px-3 py-1 rounded-lg text-sm font-medium transition-colors ${
                activeCategory === key
                  ? 'bg-blue-600 text-white'
                  : 'bg-gray-900 text-gray-400 border border-gray-700 hover:bg-gray-800 hover:text-gray-200'
              }`}
            >
              {label}
            </button>
          ))}
        </div>
      )}

      {/* Sections */}
      {filteredSections.length === 0 ? (
        <div className="text-center py-12 text-gray-500 text-sm">
          No matching config fields found.
        </div>
      ) : (
        filteredSections.map(({ section, fields }) => (
          <ConfigSection
            key={section.path || '_root'}
            section={fields ? { ...section, defaultCollapsed: false } : section}
            getFieldValue={getFieldValue}
            setFieldValue={setFieldValue}
            isFieldMasked={isFieldMasked}
            visibleFields={fields}
          />
        ))
      )}
    </div>
  );
}
@@ -0,0 +1,29 @@
interface Props {
rawToml: string;
onChange: (raw: string) => void;
disabled?: boolean;
}
export default function ConfigRawEditor({ rawToml, onChange, disabled }: Props) {
return (
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
<div className="flex items-center justify-between px-4 py-2 border-b border-gray-800 bg-gray-800/50">
<span className="text-xs text-gray-400 font-medium uppercase tracking-wider">
TOML Configuration
</span>
<span className="text-xs text-gray-500">
{rawToml.split('\n').length} lines
</span>
</div>
<textarea
value={rawToml}
onChange={(e) => onChange(e.target.value)}
disabled={disabled}
spellCheck={false}
aria-label="Raw TOML configuration editor"
className="w-full min-h-[500px] bg-gray-950 text-gray-200 font-mono text-sm p-4 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-inset disabled:opacity-50"
style={{ tabSize: 4 }}
/>
</div>
);
}
+131
View File
@@ -0,0 +1,131 @@
import { useEffect, useMemo, useState } from 'react';
import { ChevronRight, ChevronDown } from 'lucide-react';
import type { SectionDef, FieldDef } from './types';
import TextField from './fields/TextField';
import NumberField from './fields/NumberField';
import ToggleField from './fields/ToggleField';
import SelectField from './fields/SelectField';
import TagListField from './fields/TagListField';
// Props for a single collapsible config section panel.
interface Props {
  section: SectionDef;
  getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
  setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
  isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
  // When searching, only these fields are rendered; undefined = all fields.
  visibleFields?: FieldDef[];
}
/**
 * Picks the widget component for a field definition.
 * Unrecognized types — and both 'text' and 'password' — render as TextField,
 * which switches to password behavior itself based on field.type.
 */
function renderField(
  field: FieldDef,
  value: unknown,
  onChange: (v: unknown) => void,
  isMasked: boolean,
) {
  const props = { field, value, onChange, isMasked };
  if (field.type === 'number') return <NumberField {...props} />;
  if (field.type === 'toggle') return <ToggleField {...props} />;
  if (field.type === 'select') return <SelectField {...props} />;
  if (field.type === 'tag-list') return <TagListField {...props} />;
  return <TextField {...props} />;
}
/**
 * One collapsible panel of related config fields.
 *
 * Header button toggles collapse (with aria-expanded/aria-controls wiring);
 * the body lays fields out in a responsive two-column grid, with tag-list
 * fields spanning both columns.
 */
export default function ConfigSection({
  section,
  getFieldValue,
  setFieldValue,
  isFieldMasked,
  visibleFields,
}: Props) {
  const [collapsed, setCollapsed] = useState(section.defaultCollapsed ?? false);
  // Stable, DOM-id-safe identifier linking the header button to its panel.
  const sectionPanelId = useMemo(
    () =>
      `config-section-${(section.path || 'root').replace(/[^a-zA-Z0-9_-]/g, '-')}`,
    [section.path],
  );
  const Icon = section.icon;
  // During search the parent passes the matching subset; otherwise show all.
  const fields = visibleFields ?? section.fields;

  // Re-sync collapse state when this component instance is reused for a
  // different section (or its default changes), e.g. after filter changes.
  useEffect(() => {
    setCollapsed(section.defaultCollapsed ?? false);
  }, [section.path, section.defaultCollapsed]);

  return (
    <div className="bg-gray-900 rounded-xl border border-gray-800">
      <button
        type="button"
        onClick={() => setCollapsed(!collapsed)}
        aria-expanded={!collapsed}
        aria-controls={sectionPanelId}
        className="w-full flex items-center gap-3 px-4 py-3 hover:bg-gray-800/30 transition-colors rounded-t-xl"
      >
        {collapsed ? (
          <ChevronRight className="h-4 w-4 text-gray-500 flex-shrink-0" />
        ) : (
          <ChevronDown className="h-4 w-4 text-gray-500 flex-shrink-0" />
        )}
        <Icon className="h-4 w-4 text-blue-400 flex-shrink-0" />
        <span className="text-sm font-medium text-white">{section.title}</span>
        {section.description && (
          <span className="text-xs text-gray-500 hidden sm:inline">
            {section.description}
          </span>
        )}
        <span className="ml-auto text-xs text-gray-600">
          {fields.length} {fields.length === 1 ? 'field' : 'fields'}
        </span>
      </button>
      {!collapsed && (
        <div
          id={sectionPanelId}
          className="border-t border-gray-800 px-4 py-4 grid grid-cols-1 sm:grid-cols-2 gap-x-4 gap-y-4"
        >
          {fields.map((field) => {
            const value = getFieldValue(section.path, field.key);
            const masked = isFieldMasked(section.path, field.key);
            const spanFull = field.type === 'tag-list';
            return (
              <div key={field.key} className={`flex flex-col${spanFull ? ' sm:col-span-2' : ''}`}>
                <label className="flex items-center gap-2 text-sm font-medium text-gray-300 mb-1.5">
                  <span>{field.label}</span>
                  {field.sensitive && (
                    <span className="text-[10px] text-yellow-400 bg-yellow-900/30 border border-yellow-800/50 px-1.5 py-0.5 rounded">
                      sensitive
                    </span>
                  )}
                  {masked && (
                    <span className="text-[10px] text-blue-400 bg-blue-900/30 border border-blue-800/50 px-1.5 py-0.5 rounded">
                      masked
                    </span>
                  )}
                </label>
                {field.description && field.type !== 'text' && field.type !== 'password' && field.type !== 'number' && (
                  <p className="text-xs text-gray-500 mb-1.5">{field.description}</p>
                )}
                <div className="mt-auto">
                  {renderField(
                    field,
                    value,
                    (v) => setFieldValue(section.path, field.key, v),
                    masked,
                  )}
                </div>
              </div>
            );
          })}
        </div>
      )}
    </div>
  );
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,41 @@
import type { FieldProps } from '../types';
/**
 * Numeric input bound to a config field.
 *
 * - Clearing the input calls onChange(undefined) so the key is removed from
 *   the config rather than saved as 0 or "".
 * - On blur, the value is floored to an integer unless the field declares a
 *   fractional step (field.step < 1).
 */
export default function NumberField({ field, value, onChange }: FieldProps) {
  // Normalize the incoming value for the controlled input; guard against a
  // non-numeric stored value rendering the literal text "NaN".
  const coerced = value === undefined || value === null || value === '' ? '' : Number(value);
  const numValue = typeof coerced === 'number' && Number.isNaN(coerced) ? '' : coerced;
  return (
    <input
      type="number"
      value={numValue}
      onChange={(e) => {
        const raw = e.target.value;
        if (raw === '') {
          // Treat a cleared input as "unset", not zero.
          onChange(undefined);
          return;
        }
        const n = Number(raw);
        // Number.isNaN avoids the coercing global isNaN.
        if (!Number.isNaN(n)) {
          onChange(n);
        }
      }}
      onBlur={(e) => {
        // Fields with fractional steps may keep their decimals.
        if (field.step !== undefined && field.step < 1) {
          return;
        }
        const raw = e.target.value;
        if (raw === '') {
          return;
        }
        const n = Number(raw);
        if (!Number.isNaN(n)) {
          onChange(Math.floor(n));
        }
      }}
      min={field.min}
      max={field.max}
      step={field.step ?? 1}
      placeholder={field.description ?? ''}
      className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
    />
  );
}
@@ -0,0 +1,20 @@
import type { FieldProps } from '../types';
/**
 * Dropdown for enum-style config fields. An empty selection maps to the
 * placeholder "Select..." entry (empty-string value).
 */
export default function SelectField({ field, value, onChange }: FieldProps) {
  const selected = (value as string) ?? '';
  const choices = field.options ?? [];
  return (
    <select
      value={selected}
      onChange={(e) => onChange(e.target.value)}
      className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
    >
      <option value="">Select...</option>
      {choices.map(({ value: optionValue, label }) => (
        <option key={optionValue} value={optionValue}>
          {label}
        </option>
      ))}
    </select>
  );
}
@@ -0,0 +1,60 @@
import { useState } from 'react';
import { X } from 'lucide-react';
import type { FieldProps } from '../types';
/**
 * Editable list of string tags. Enter or comma commits the draft; Backspace
 * on an empty draft removes the last tag; blur commits any pending draft.
 * Duplicate and empty tags are silently ignored.
 */
export default function TagListField({ field, value, onChange }: FieldProps) {
  const [draft, setDraft] = useState('');
  const tags: string[] = Array.isArray(value) ? value : [];

  // Commit a candidate: trim, drop empties/duplicates, then reset the draft.
  const commit = (candidate: string) => {
    const cleaned = candidate.trim();
    if (cleaned && !tags.includes(cleaned)) {
      onChange([...tags, cleaned]);
    }
    setDraft('');
  };

  const removeAt = (index: number) => {
    onChange(tags.filter((_, i) => i !== index));
  };

  const onKey = (e: React.KeyboardEvent<HTMLInputElement>) => {
    if (e.key === 'Enter' || e.key === ',') {
      e.preventDefault();
      commit(draft);
      return;
    }
    if (e.key === 'Backspace' && draft === '' && tags.length > 0) {
      removeAt(tags.length - 1);
    }
  };

  return (
    <div>
      <div className="flex flex-wrap gap-1.5 mb-2">
        {tags.map((tag, i) => (
          <span
            key={tag}
            className="inline-flex items-center gap-1 bg-gray-700 text-gray-200 rounded-full px-2.5 py-0.5 text-xs"
          >
            {tag}
            <button
              type="button"
              onClick={() => removeAt(i)}
              className="text-gray-400 hover:text-white transition-colors"
            >
              <X className="h-3 w-3" />
            </button>
          </span>
        ))}
      </div>
      <input
        type="text"
        value={draft}
        onChange={(e) => setDraft(e.target.value)}
        onKeyDown={onKey}
        onBlur={() => { if (draft.trim()) commit(draft); }}
        placeholder={field.tagPlaceholder ?? 'Type and press Enter to add'}
        className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
      />
    </div>
  );
}
@@ -0,0 +1,39 @@
import { useState } from 'react';
import { Eye, EyeOff, Lock } from 'lucide-react';
import type { FieldProps } from '../types';
/**
 * Text / password input for a config field.
 *
 * When the server reports the value as masked, the input is shown empty with
 * a "Configured (masked)" placeholder and a lock icon; typing replaces the
 * stored secret. Password fields get a show/hide toggle.
 */
export default function TextField({ field, value, onChange, isMasked }: FieldProps) {
  const [showPassword, setShowPassword] = useState(false);
  const isPassword = field.type === 'password';
  // Never echo a masked secret back into the DOM.
  const strValue = isMasked ? '' : ((value as string) ?? '');
  return (
    <div className="relative">
      <input
        type={isPassword && !showPassword ? 'password' : 'text'}
        value={strValue}
        onChange={(e) => onChange(e.target.value)}
        placeholder={isMasked ? 'Configured (masked)' : field.description ?? ''}
        className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500 pr-16"
      />
      <div className="absolute right-2 top-1/2 -translate-y-1/2 flex items-center gap-1">
        {isMasked && (
          <Lock className="h-3.5 w-3.5 text-yellow-500" />
        )}
        {isPassword && (
          <button
            type="button"
            onClick={() => setShowPassword(!showPassword)}
            // a11y fix: icon-only button needs an accessible name.
            aria-label={showPassword ? 'Hide value' : 'Show value'}
            className="p-1 text-gray-400 hover:text-gray-200 transition-colors"
          >
            {showPassword ? (
              <EyeOff className="h-3.5 w-3.5" />
            ) : (
              <Eye className="h-3.5 w-3.5" />
            )}
          </button>
        )}
      </div>
    </div>
  );
}
@@ -0,0 +1,27 @@
import type { FieldProps } from '../types';
/** On/off switch for a boolean config field, with a text state label. */
export default function ToggleField({ field, value, onChange }: FieldProps) {
  const enabled = Boolean(value);
  const trackTone = enabled ? 'bg-blue-600' : 'bg-gray-700';
  const knobShift = enabled ? 'translate-x-6' : 'translate-x-1';
  return (
    <div className="flex items-center gap-3">
      <button
        type="button"
        role="switch"
        aria-checked={enabled}
        aria-label={field.label}
        onClick={() => onChange(!enabled)}
        className={`relative inline-flex h-6 w-11 items-center rounded-full transition-colors ${trackTone}`}
      >
        <span
          className={`inline-block h-4 w-4 transform rounded-full bg-white transition-transform ${knobShift}`}
        />
      </button>
      <span className="text-sm text-gray-400">{enabled ? 'Enabled' : 'Disabled'}</span>
    </div>
  );
}
+40
View File
@@ -0,0 +1,40 @@
import type { LucideIcon } from 'lucide-react';
/** Widget kind used by renderField to pick a field component. */
export type FieldType =
  | 'text'
  | 'password'
  | 'number'
  | 'toggle'
  | 'select'
  | 'tag-list';

/** Declarative description of one editable config field. */
export interface FieldDef {
  key: string;               // key within the section's TOML table
  label: string;             // display label
  type: FieldType;
  description?: string;      // help text / placeholder
  sensitive?: boolean;       // show the "sensitive" badge
  defaultValue?: unknown;
  options?: { value: string; label: string }[]; // for 'select'
  min?: number;              // for 'number'
  max?: number;              // for 'number'
  step?: number;             // for 'number'; step < 1 allows decimals
  tagPlaceholder?: string;   // for 'tag-list'
}

/** One collapsible group of fields, addressed by its dotted TOML path. */
export interface SectionDef {
  path: string;              // dotted TOML table path; '' = document root
  title: string;
  description?: string;
  icon: LucideIcon;
  fields: FieldDef[];
  defaultCollapsed?: boolean;
  category?: string;         // matches a CATEGORY_ORDER key
}

/** Props shared by every field widget component. */
export interface FieldProps {
  field: FieldDef;
  value: unknown;
  onChange: (value: unknown) => void;
  isMasked: boolean;         // true while the server-masked value is untouched
}
+307
View File
@@ -0,0 +1,307 @@
import { useState, useCallback, useRef, useEffect } from 'react';
import { parse, stringify } from 'smol-toml';
import { getConfig, putConfig } from '@/lib/api';
const MASKED = '***MASKED***';

type ParsedConfig = Record<string, unknown>;

/**
 * Deep-clones a parsed config tree.
 *
 * Prefers structuredClone (Node 17+ / all modern browsers): unlike a JSON
 * round-trip it preserves Date-like values (e.g. TOML datetimes from the
 * parser) instead of silently flattening them to strings, which would
 * corrupt them on re-serialization. Falls back to JSON cloning when
 * structuredClone is unavailable or rejects the value.
 */
function deepClone<T>(obj: T): T {
  if (typeof structuredClone === 'function') {
    try {
      return structuredClone(obj);
    } catch {
      // Value not structured-cloneable; fall through to the JSON path.
    }
  }
  return JSON.parse(JSON.stringify(obj)) as T;
}
/**
 * Recursively scan for MASKED strings and collect their dotted paths.
 * Array elements contribute their numeric index as a path segment.
 */
function scanMasked(obj: unknown, prefix: string, out: Set<string>) {
  // A masked leaf records its path and stops descending.
  if (obj === MASKED) {
    out.add(prefix);
    return;
  }
  // Primitives, null, and undefined cannot contain nested values.
  if (obj === null || obj === undefined || typeof obj !== 'object') {
    return;
  }
  if (Array.isArray(obj)) {
    obj.forEach((item, index) => scanMasked(item, `${prefix}.${index}`, out));
    return;
  }
  for (const [key, child] of Object.entries(obj as Record<string, unknown>)) {
    scanMasked(child, prefix ? `${prefix}.${key}` : key, out);
  }
}
/** Navigate into an object by dotted path segments, returning the value. */
function getNestedValue(obj: unknown, segments: string[]): unknown {
  // Fold over the segments: descend while the accumulator is indexable,
  // otherwise collapse to undefined (equivalent to an early-exit loop).
  return segments.reduce<unknown>(
    (acc, seg) =>
      acc !== null && acc !== undefined && typeof acc === 'object'
        ? (acc as Record<string, unknown>)[seg]
        : undefined,
    obj,
  );
}
/**
 * Set a value in an object by dotted path segments, creating intermediates.
 * Mutates `obj` in place. An undefined or empty-string value deletes the key
 * (so cleared form fields drop out of the serialized TOML).
 */
function setNestedValue(obj: Record<string, unknown>, segments: string[], value: unknown) {
  if (segments.length === 0) return;
  const parents = segments.slice(0, -1);
  const leaf = segments[segments.length - 1]!;
  let cursor: Record<string, unknown> = obj;
  for (const seg of parents) {
    const child = cursor[seg];
    // Replace anything non-traversable with a fresh intermediate table.
    if (child === null || child === undefined || typeof child !== 'object') {
      cursor[seg] = {};
    }
    cursor = cursor[seg] as Record<string, unknown>;
  }
  if (value === undefined || value === '') {
    delete cursor[leaf];
  } else {
    cursor[leaf] = value;
  }
}
/** Editor presentation mode: structured form or raw TOML text. */
export type EditorMode = 'form' | 'raw';

/** Public surface returned by useConfigForm. */
export interface ConfigFormState {
  loading: boolean;
  saving: boolean;
  error: string | null;
  success: string | null;
  mode: EditorMode;
  rawToml: string;
  parsed: ParsedConfig;
  // Dotted paths whose server value is the mask sentinel.
  maskedPaths: Set<string>;
  // Dotted paths the user has edited this session.
  dirtyPaths: Set<string>;
  // Returns false when switching raw -> form fails due to invalid TOML.
  setMode: (mode: EditorMode) => boolean;
  getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
  setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
  isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
  isFieldDirty: (sectionPath: string, fieldKey: string) => boolean;
  setRawToml: (raw: string) => void;
  save: () => Promise<void>;
  reload: () => Promise<void>;
  clearMessages: () => void;
}
/**
 * State + actions for the config editor: loads TOML from the API, exposes it
 * both as a parsed object (form mode) and raw text (raw mode), tracks which
 * fields are server-masked secrets vs. locally edited, and saves back TOML.
 */
export function useConfigForm(): ConfigFormState {
  const [loading, setLoading] = useState(true);
  const [saving, setSaving] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [success, setSuccess] = useState<string | null>(null);
  const [mode, setModeState] = useState<EditorMode>('form');
  const [rawToml, setRawTomlState] = useState('');
  const [parsed, setParsed] = useState<ParsedConfig>({});
  // Refs rather than state so mutations don't re-render by themselves; the
  // forceRender counter below flushes UI updates that depend on them.
  const maskedPathsRef = useRef<Set<string>>(new Set());
  const dirtyPathsRef = useRef<Set<string>>(new Set());
  const successTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
  const [, forceRender] = useState(0);

  // Fetch the config, keep the raw text, and attempt to parse it; a parse
  // failure degrades gracefully to raw-only editing.
  const loadConfig = useCallback(async () => {
    setLoading(true);
    setError(null);
    try {
      const data = await getConfig();
      const raw = typeof data === 'string' ? data : JSON.stringify(data, null, 2);
      setRawTomlState(raw);
      try {
        const obj = parse(raw) as ParsedConfig;
        setParsed(obj);
        const masked = new Set<string>();
        scanMasked(obj, '', masked);
        maskedPathsRef.current = masked;
      } catch {
        // If TOML parse fails, start in raw mode
        setParsed({});
        maskedPathsRef.current = new Set();
        setModeState('raw');
      }
      dirtyPathsRef.current = new Set();
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Failed to load configuration');
    } finally {
      setLoading(false);
    }
  }, []);

  // Load once on mount.
  const hasLoaded = useRef(false);
  useEffect(() => {
    if (!hasLoaded.current) {
      hasLoaded.current = true;
      void loadConfig();
    }
  }, [loadConfig]);

  // Cancel any pending success auto-dismiss timer on unmount.
  useEffect(() => {
    return () => {
      if (successTimeoutRef.current) {
        clearTimeout(successTimeoutRef.current);
      }
    };
  }, []);

  // Dotted-path helpers; recreated each render (cheap) so the callbacks
  // below stay dependency-free.
  const fieldPath = (sectionPath: string, fieldKey: string) =>
    sectionPath ? `${sectionPath}.${fieldKey}` : fieldKey;

  const fieldSegments = (sectionPath: string, fieldKey: string) => {
    const full = fieldPath(sectionPath, fieldKey);
    return full.split('.').filter(Boolean);
  };

  const getFieldValue = useCallback(
    (sectionPath: string, fieldKey: string): unknown => {
      const segments = fieldSegments(sectionPath, fieldKey);
      return getNestedValue(parsed, segments);
    },
    [parsed],
  );

  // Write a field into the parsed tree, mark it dirty (dropping its masked
  // badge), and force a render since dirtyPathsRef is not React state.
  const setFieldValue = useCallback(
    (sectionPath: string, fieldKey: string, value: unknown) => {
      const fp = fieldPath(sectionPath, fieldKey);
      const segments = fieldSegments(sectionPath, fieldKey);
      setParsed((prev) => {
        const next = deepClone(prev);
        setNestedValue(next, segments, value);
        return next;
      });
      dirtyPathsRef.current.add(fp);
      forceRender((n) => n + 1);
    },
    [],
  );

  // A field is "masked" while the server-provided sentinel is still present
  // and the user has not typed a replacement.
  const isFieldMasked = useCallback(
    (sectionPath: string, fieldKey: string): boolean => {
      const fp = fieldPath(sectionPath, fieldKey);
      return maskedPathsRef.current.has(fp) && !dirtyPathsRef.current.has(fp);
    },
    [],
  );

  const isFieldDirty = useCallback(
    (sectionPath: string, fieldKey: string): boolean => {
      const fp = fieldPath(sectionPath, fieldKey);
      return dirtyPathsRef.current.has(fp);
    },
    [],
  );

  // Serialize the parsed tree back to TOML; on failure keep the last known
  // raw text rather than losing the user's document.
  const syncFormToRaw = useCallback((): string => {
    try {
      const toml = stringify(parsed);
      return toml;
    } catch {
      return rawToml;
    }
  }, [parsed, rawToml]);

  // Parse raw text into the form model; returns false on invalid TOML.
  const syncRawToForm = useCallback(
    (raw: string): boolean => {
      try {
        const obj = parse(raw) as ParsedConfig;
        setParsed(obj);
        // Re-scan masked paths from fresh parse, preserving dirty overrides
        const masked = new Set<string>();
        scanMasked(obj, '', masked);
        maskedPathsRef.current = masked;
        return true;
      } catch {
        return false;
      }
    },
    [],
  );

  // Switch editor modes, converting state in the appropriate direction.
  // Returns false (and sets an error) if raw TOML is invalid.
  const setMode = useCallback(
    (newMode: EditorMode): boolean => {
      if (newMode === mode) return true;
      if (newMode === 'raw') {
        // form → raw: serialize parsed to TOML
        const toml = syncFormToRaw();
        setRawTomlState(toml);
        setModeState('raw');
        return true;
      } else {
        // raw → form: parse TOML
        if (syncRawToForm(rawToml)) {
          setModeState('form');
          return true;
        } else {
          setError('Invalid TOML syntax. Fix errors before switching to Form view.');
          return false;
        }
      }
    },
    [mode, syncFormToRaw, syncRawToForm, rawToml],
  );

  const setRawToml = useCallback((raw: string) => {
    setRawTomlState(raw);
  }, []);

  // Persist whichever representation is currently authoritative.
  const save = useCallback(async () => {
    setSaving(true);
    setError(null);
    setSuccess(null);
    if (successTimeoutRef.current) {
      clearTimeout(successTimeoutRef.current);
    }
    try {
      let toml: string;
      if (mode === 'form') {
        toml = syncFormToRaw();
      } else {
        toml = rawToml;
      }
      await putConfig(toml);
      setSuccess('Configuration saved successfully.');
      // Auto-dismiss success after 4 seconds
      successTimeoutRef.current = setTimeout(() => setSuccess(null), 4000);
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Failed to save configuration');
    } finally {
      setSaving(false);
    }
  }, [mode, syncFormToRaw, rawToml]);

  const reload = useCallback(async () => {
    await loadConfig();
  }, [loadConfig]);

  const clearMessages = useCallback(() => {
    setError(null);
    setSuccess(null);
    if (successTimeoutRef.current) {
      clearTimeout(successTimeoutRef.current);
      successTimeoutRef.current = null;
    }
  }, []);

  return {
    loading,
    saving,
    error,
    success,
    mode,
    rawToml,
    parsed,
    maskedPaths: maskedPathsRef.current,
    dirtyPaths: dirtyPathsRef.current,
    setMode,
    getFieldValue,
    setFieldValue,
    isFieldMasked,
    isFieldDirty,
    setRawToml,
    save,
    reload,
    clearMessages,
  };
}
+17 -1
View File
@@ -19,6 +19,7 @@ export interface WebSocketClientOptions {
const DEFAULT_RECONNECT_DELAY = 1000;
const MAX_RECONNECT_DELAY = 30000;
const WS_SESSION_STORAGE_KEY = 'zeroclaw.ws.session_id';
export class WebSocketClient {
private ws: WebSocket | null = null;
@@ -35,6 +36,7 @@ export class WebSocketClient {
private readonly reconnectDelay: number;
private readonly maxReconnectDelay: number;
private readonly autoReconnect: boolean;
private readonly sessionId: string;
constructor(options: WebSocketClientOptions = {}) {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -44,6 +46,7 @@ export class WebSocketClient {
this.maxReconnectDelay = options.maxReconnectDelay ?? MAX_RECONNECT_DELAY;
this.autoReconnect = options.autoReconnect ?? true;
this.currentDelay = this.reconnectDelay;
this.sessionId = this.resolveSessionId();
}
/** Open the WebSocket connection. */
@@ -52,7 +55,7 @@ export class WebSocketClient {
this.clearReconnectTimer();
const token = getToken();
const url = `${this.baseUrl}/ws/chat`;
const url = `${this.baseUrl}/ws/chat?session_id=${encodeURIComponent(this.sessionId)}`;
const protocols = ['zeroclaw.v1'];
if (token) {
protocols.push(`bearer.${token}`);
@@ -126,4 +129,17 @@ export class WebSocketClient {
this.reconnectTimer = null;
}
}
/**
 * Returns a stable per-browser session id, minting and persisting a new one
 * when localStorage holds no valid entry (1–128 chars of [A-Za-z0-9_-]).
 */
private resolveSessionId(): string {
  const stored = window.localStorage.getItem(WS_SESSION_STORAGE_KEY);
  if (stored && /^[A-Za-z0-9_-]{1,128}$/.test(stored)) {
    return stored;
  }
  // Prefer a UUID (underscored to stay within the allowed charset); fall back
  // to a timestamp+random id where crypto.randomUUID is unavailable.
  const fromUuid = globalThis.crypto?.randomUUID?.().replace(/-/g, '_');
  const fresh =
    fromUuid ?? `sess_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 10)}`;
  window.localStorage.setItem(WS_SESSION_STORAGE_KEY, fresh);
  return fresh;
}
}
+16
View File
@@ -54,6 +54,22 @@ export default function AgentChat() {
ws.onMessage = (msg: WsMessage) => {
switch (msg.type) {
case 'history': {
const restored = (msg.messages ?? [])
.filter((entry) => entry.content?.trim())
.map((entry) => ({
id: makeMessageId(),
role: entry.role === 'user' ? 'user' : 'agent',
content: entry.content.trim(),
timestamp: new Date(),
}));
setMessages(restored);
setTyping(false);
pendingContentRef.current = '';
break;
}
case 'chunk':
setTyping(true);
pendingContentRef.current += msg.content ?? '';
+93 -65
View File
@@ -1,50 +1,61 @@
import { useState, useEffect } from 'react';
import {
Settings,
Save,
CheckCircle,
AlertTriangle,
ShieldAlert,
FileText,
SlidersHorizontal,
} from 'lucide-react';
import { getConfig, putConfig } from '@/lib/api';
import { useConfigForm, type EditorMode } from '@/components/config/useConfigForm';
import ConfigFormEditor from '@/components/config/ConfigFormEditor';
import ConfigRawEditor from '@/components/config/ConfigRawEditor';
/** One segment of the form/raw editor mode toggle. */
function ModeTab({
  mode,
  active,
  icon: Icon,
  label,
  onClick,
}: {
  mode: EditorMode;
  active: boolean;
  icon: React.ComponentType<{ className?: string }>;
  label: string;
  onClick: () => void;
}) {
  const base = 'flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm font-medium transition-colors';
  const tone = active
    ? 'bg-blue-600 text-white'
    : 'text-gray-400 hover:text-gray-200 hover:bg-gray-800';
  return (
    <button
      onClick={onClick}
      className={`${base} ${tone}`}
      aria-pressed={active}
      data-mode={mode}
    >
      <Icon className="h-3.5 w-3.5" />
      {label}
    </button>
  );
}
export default function Config() {
const [config, setConfig] = useState('');
const [loading, setLoading] = useState(true);
const [saving, setSaving] = useState(false);
const [error, setError] = useState<string | null>(null);
const [success, setSuccess] = useState<string | null>(null);
useEffect(() => {
getConfig()
.then((data) => {
// The API may return either a raw string or a JSON string
setConfig(typeof data === 'string' ? data : JSON.stringify(data, null, 2));
})
.catch((err) => setError(err.message))
.finally(() => setLoading(false));
}, []);
const handleSave = async () => {
setSaving(true);
setError(null);
setSuccess(null);
try {
await putConfig(config);
setSuccess('Configuration saved successfully.');
} catch (err: unknown) {
setError(err instanceof Error ? err.message : 'Failed to save configuration');
} finally {
setSaving(false);
}
};
// Auto-dismiss success after 4 seconds
useEffect(() => {
if (!success) return;
const timer = setTimeout(() => setSuccess(null), 4000);
return () => clearTimeout(timer);
}, [success]);
const {
loading,
saving,
error,
success,
mode,
rawToml,
setMode,
getFieldValue,
setFieldValue,
isFieldMasked,
setRawToml,
save,
} = useConfigForm();
if (loading) {
return (
@@ -62,14 +73,34 @@ export default function Config() {
<Settings className="h-5 w-5 text-blue-400" />
<h2 className="text-base font-semibold text-white">Configuration</h2>
</div>
<button
onClick={handleSave}
disabled={saving}
className="flex items-center gap-2 bg-blue-600 hover:bg-blue-700 text-white text-sm font-medium px-4 py-2 rounded-lg transition-colors disabled:opacity-50"
>
<Save className="h-4 w-4" />
{saving ? 'Saving...' : 'Save'}
</button>
<div className="flex items-center gap-3">
{/* Mode toggle */}
<div className="flex items-center gap-1 bg-gray-900 border border-gray-800 rounded-lg p-0.5">
<ModeTab
mode="form"
active={mode === 'form'}
icon={SlidersHorizontal}
label="Form"
onClick={() => setMode('form')}
/>
<ModeTab
mode="raw"
active={mode === 'raw'}
icon={FileText}
label="Raw"
onClick={() => setMode('raw')}
/>
</div>
<button
onClick={save}
disabled={saving}
className="flex items-center gap-2 bg-blue-600 hover:bg-blue-700 text-white text-sm font-medium px-4 py-2 rounded-lg transition-colors disabled:opacity-50"
>
<Save className="h-4 w-4" />
{saving ? 'Saving...' : 'Save'}
</button>
</div>
</div>
{/* Sensitive fields note */}
@@ -80,8 +111,9 @@ export default function Config() {
Sensitive fields are masked
</p>
<p className="text-sm text-yellow-400/70 mt-0.5">
API keys, tokens, and passwords are hidden for security. To update a
masked field, replace the entire masked value with your new value.
{mode === 'form'
? 'Masked fields show "Configured (masked)" as a placeholder. Leave them untouched to preserve existing values, or enter a new value to update.'
: 'API keys, tokens, and passwords are hidden for security. To update a masked field, replace the entire masked value with your new value.'}
</p>
</div>
</div>
@@ -102,24 +134,20 @@ export default function Config() {
</div>
)}
{/* Config Editor */}
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
<div className="flex items-center justify-between px-4 py-2 border-b border-gray-800 bg-gray-800/50">
<span className="text-xs text-gray-400 font-medium uppercase tracking-wider">
TOML Configuration
</span>
<span className="text-xs text-gray-500">
{config.split('\n').length} lines
</span>
</div>
<textarea
value={config}
onChange={(e) => setConfig(e.target.value)}
spellCheck={false}
className="w-full min-h-[500px] bg-gray-950 text-gray-200 font-mono text-sm p-4 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-inset"
style={{ tabSize: 4 }}
{/* Editor */}
{mode === 'form' ? (
<ConfigFormEditor
getFieldValue={getFieldValue}
setFieldValue={setFieldValue}
isFieldMasked={isFieldMasked}
/>
</div>
) : (
<ConfigRawEditor
rawToml={rawToml}
onChange={setRawToml}
disabled={saving}
/>
)}
</div>
);
}
+6 -1
View File
@@ -131,11 +131,16 @@ export interface SSEEvent {
}
export interface WsMessage {
type: 'message' | 'chunk' | 'tool_call' | 'tool_result' | 'done' | 'error';
type: 'message' | 'chunk' | 'tool_call' | 'tool_result' | 'done' | 'error' | 'history';
content?: string;
full_response?: string;
name?: string;
args?: any;
output?: string;
message?: string;
session_id?: string;
messages?: Array<{
role: 'user' | 'assistant';
content: string;
}>;
}