Merge remote-tracking branch 'upstream/main' into feat/channel-bluebubbles
# Conflicts: # src/channels/mod.rs
This commit is contained in:
@@ -212,7 +212,7 @@ jobs:
|
||||
|
||||
- name: Link check (offline, added links only)
|
||||
if: steps.collect_links.outputs.count != '0'
|
||||
uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2
|
||||
uses: lycheeverse/lychee-action@8646ba30535128ac92d33dfc9133794bfdd9b411 # v2
|
||||
with:
|
||||
fail: true
|
||||
args: >-
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
name: Deploy Web to GitHub Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, dev]
|
||||
paths:
|
||||
- 'web/**'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
concurrency:
|
||||
group: "pages"
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
working-directory: ./web
|
||||
run: npm run build
|
||||
|
||||
- name: Setup Pages
|
||||
uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4
|
||||
with:
|
||||
path: ./web/dist
|
||||
|
||||
deploy:
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
|
||||
- name: Link check (added links)
|
||||
if: github.event_name != 'workflow_dispatch' && steps.links.outputs.count != '0'
|
||||
uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2
|
||||
uses: lycheeverse/lychee-action@8646ba30535128ac92d33dfc9133794bfdd9b411 # v2
|
||||
with:
|
||||
fail: true
|
||||
args: >-
|
||||
|
||||
+6
-1
@@ -1,4 +1,6 @@
|
||||
/target
|
||||
/target_ci
|
||||
/target_review*
|
||||
firmware/*/target
|
||||
*.db
|
||||
*.db-journal
|
||||
@@ -12,6 +14,7 @@ site/node_modules/
|
||||
site/.vite/
|
||||
site/public/docs-content/
|
||||
gh-pages/
|
||||
|
||||
.idea
|
||||
|
||||
# Environment files (may contain secrets)
|
||||
@@ -30,10 +33,12 @@ venv/
|
||||
|
||||
# Secret keys and credentials
|
||||
.secret_key
|
||||
otp-secret
|
||||
*.key
|
||||
*.pem
|
||||
credentials.json
|
||||
/config.toml
|
||||
.worktrees/
|
||||
|
||||
# Nix
|
||||
result
|
||||
result
|
||||
|
||||
Generated
+49
-35
@@ -540,16 +540,6 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a45f9771ced8a774de5e5ebffbe520f52e3943bf5a9a6baa3a5d14a5de1afe6"
|
||||
|
||||
[[package]]
|
||||
name = "bcder"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f7c42c9913f68cf9390a225e81ad56a5c515347287eb98baa710090ca1de86d"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bech32"
|
||||
version = "0.11.1"
|
||||
@@ -1655,9 +1645,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
|
||||
dependencies = [
|
||||
"const-oid",
|
||||
"der_derive",
|
||||
"flagset",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der_derive"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.5.8"
|
||||
@@ -2204,6 +2207,12 @@ version = "0.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
|
||||
|
||||
[[package]]
|
||||
name = "flagset"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe"
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.1.9"
|
||||
@@ -4604,16 +4613,6 @@ dependencies = [
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pem"
|
||||
version = "3.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.2"
|
||||
@@ -6821,6 +6820,27 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tls_codec"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b"
|
||||
dependencies = [
|
||||
"tls_codec_derive",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tls_codec_derive"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.49.0"
|
||||
@@ -6876,16 +6896,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres-rustls"
|
||||
version = "0.12.0"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
|
||||
checksum = "27d684bad428a0f2481f42241f821db42c54e2dc81d8c00db8536c506b0a0144"
|
||||
dependencies = [
|
||||
"const-oid",
|
||||
"ring",
|
||||
"rustls",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-rustls",
|
||||
"x509-certificate",
|
||||
"x509-cert",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8997,22 +9018,15 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x509-certificate"
|
||||
version = "0.23.1"
|
||||
name = "x509-cert"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "66534846dec7a11d7c50a74b7cdb208b9a581cad890b7866430d438455847c85"
|
||||
checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94"
|
||||
dependencies = [
|
||||
"bcder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"const-oid",
|
||||
"der",
|
||||
"hex",
|
||||
"pem",
|
||||
"ring",
|
||||
"signature",
|
||||
"spki",
|
||||
"thiserror 1.0.69",
|
||||
"zeroize",
|
||||
"tls_codec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
+1
-2
@@ -110,7 +110,7 @@ prost = { version = "0.14", default-features = false, features = ["derive"], opt
|
||||
# Memory / persistence
|
||||
rusqlite = { version = "0.37", features = ["bundled"] }
|
||||
postgres = { version = "0.19", features = ["with-chrono-0_4"], optional = true }
|
||||
tokio-postgres-rustls = { version = "0.12", optional = true }
|
||||
tokio-postgres-rustls = { version = "0.13", optional = true }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
|
||||
chrono-tz = "0.10"
|
||||
cron = "0.15"
|
||||
@@ -238,7 +238,6 @@ whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-pr
|
||||
# Keep disabled by default to preserve current runtime behavior.
|
||||
firecrawl = []
|
||||
web-fetch-html2md = []
|
||||
web-fetch-plaintext = []
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
|
||||
@@ -46,7 +46,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<strong>Fast, small, and fully autonomous AI assistant infrastructure</strong><br />
|
||||
<strong>Fast, small, and fully autonomous Operating System</strong><br />
|
||||
Deploy anywhere. Swap anything.
|
||||
</p>
|
||||
|
||||
@@ -81,6 +81,45 @@ Use this board for important notices (breaking changes, security advisories, mai
|
||||
- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels).
|
||||
- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Option 1: Homebrew (macOS/Linuxbrew)
|
||||
|
||||
```bash
|
||||
brew install zeroclaw
|
||||
```
|
||||
|
||||
### Option 2: Clone + Bootstrap
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zeroclaw-labs/zeroclaw.git
|
||||
cd zeroclaw
|
||||
./bootstrap.sh
|
||||
```
|
||||
|
||||
> **Note:** Source builds require ~2GB RAM and ~6GB disk. For resource-constrained systems, use `./bootstrap.sh --prefer-prebuilt` to download a pre-built binary instead.
|
||||
|
||||
### Option 3: Cargo Install
|
||||
|
||||
```bash
|
||||
cargo install zeroclaw
|
||||
```
|
||||
|
||||
### First Run
|
||||
|
||||
```bash
|
||||
# Start the gateway daemon
|
||||
zeroclaw gateway start
|
||||
|
||||
# Open the web UI
|
||||
zeroclaw dashboard
|
||||
|
||||
# Or chat directly
|
||||
zeroclaw chat "Hello!"
|
||||
```
|
||||
|
||||
For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md).
|
||||
|
||||
## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible)
|
||||
|
||||
Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware.
|
||||
@@ -99,16 +138,9 @@ Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge
|
||||
<img src="zero-claw.jpeg" alt="ZeroClaw vs OpenClaw Comparison" width="800" />
|
||||
</p>
|
||||
|
||||
### 🙏 Special Thanks
|
||||
---
|
||||
|
||||
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:
|
||||
|
||||
- **Harvard University** — for fostering intellectual curiosity and pushing the boundaries of what's possible.
|
||||
- **MIT** — for championing open knowledge, open source, and the belief that technology should be accessible to everyone.
|
||||
- **Sundai Club** — for the community, the energy, and the relentless drive to build things that matter.
|
||||
- **The World & Beyond** 🌍✨ — to every contributor, dreamer, and builder out there making open source a force for good. This is for you.
|
||||
|
||||
We're building in the open because the best ideas come from everywhere. If you're reading this, you're part of it. Welcome. 🦀❤️
|
||||
For full documentation, see [`docs/README.md`](docs/README.md) | [`docs/SUMMARY.md`](docs/SUMMARY.md)
|
||||
|
||||
## ⚠️ Official Repository & Impersonation Warning
|
||||
|
||||
@@ -133,31 +165,9 @@ ZeroClaw is dual-licensed for maximum openness and contributor protection:
|
||||
|
||||
You may choose either license. **Contributors automatically grant rights under both** — see [CLA.md](CLA.md) for the full contributor agreement.
|
||||
|
||||
### Trademark
|
||||
|
||||
The **ZeroClaw** name and logo are trademarks of ZeroClaw Labs. This license does not grant permission to use them to imply endorsement or affiliation. See [TRADEMARK.md](TRADEMARK.md) for permitted and prohibited uses.
|
||||
|
||||
### Contributor Protections
|
||||
|
||||
- You **retain copyright** of your contributions
|
||||
- **Patent grant** (Apache 2.0) shields you from patent claims by other contributors
|
||||
- Your contributions are **permanently attributed** in commit history and [NOTICE](NOTICE)
|
||||
- No trademark rights are transferred by contributing
|
||||
|
||||
## Contributing
|
||||
|
||||
New to ZeroClaw? Look for issues labeled [`good first issue`](https://github.com/zeroclaw-labs/zeroclaw/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) — see our [Contributing Guide](CONTRIBUTING.md#first-time-contributors) for how to get started.
|
||||
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md) and [CLA.md](CLA.md). Implement a trait, submit a PR:
|
||||
|
||||
- CI workflow guide: [docs/ci-map.md](docs/ci-map.md)
|
||||
- New `Provider` → `src/providers/`
|
||||
- New `Channel` → `src/channels/`
|
||||
- New `Observer` → `src/observability/`
|
||||
- New `Tool` → `src/tools/`
|
||||
- New `Memory` → `src/memory/`
|
||||
- New `Tunnel` → `src/tunnel/`
|
||||
- New `Skill` → `~/.zeroclaw/workspace/skills/<name>/`
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md) and [CLA.md](CLA.md). Implement a trait, submit a PR.
|
||||
|
||||
---
|
||||
|
||||
|
||||
+2
-1
@@ -2,7 +2,7 @@
|
||||
|
||||
This file is the canonical table of contents for the documentation system.
|
||||
|
||||
Last refreshed: **February 25, 2026**.
|
||||
Last refreshed: **February 28, 2026**.
|
||||
|
||||
## Language Entry
|
||||
|
||||
@@ -110,5 +110,6 @@ Last refreshed: **February 25, 2026**.
|
||||
- [project/README.md](project/README.md)
|
||||
- [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md)
|
||||
- [docs-audit-2026-02-24.md](docs-audit-2026-02-24.md)
|
||||
- [project/m4-5-rfi-spike-2026-02-28.md](project/m4-5-rfi-spike-2026-02-28.md)
|
||||
- [i18n-gap-backlog.md](i18n-gap-backlog.md)
|
||||
- [docs-inventory.md](docs-inventory.md)
|
||||
|
||||
@@ -146,6 +146,7 @@ If `[channels_config.matrix]`, `[channels_config.lark]`, or `[channels_config.fe
|
||||
| Napcat | websocket receive + HTTP send (OneBot) | No (typically local/LAN) |
|
||||
| Linq | webhook (`/linq`) | Yes (public HTTPS callback) |
|
||||
| iMessage | local integration | No |
|
||||
| ACP | stdio (JSON-RPC 2.0) | No |
|
||||
| Nostr | relay websocket (NIP-04 / NIP-17) | No |
|
||||
|
||||
---
|
||||
@@ -160,7 +161,7 @@ For channels with inbound sender allowlists:
|
||||
|
||||
Field names differ by channel:
|
||||
|
||||
- `allowed_users` (Telegram/Discord/Slack/Mattermost/Matrix/IRC/Lark/Feishu/DingTalk/QQ/Napcat/Nextcloud Talk)
|
||||
- `allowed_users` (Telegram/Discord/Slack/Mattermost/Matrix/IRC/Lark/Feishu/DingTalk/QQ/Napcat/Nextcloud Talk/ACP)
|
||||
- `allowed_from` (Signal)
|
||||
- `allowed_numbers` (WhatsApp)
|
||||
- `allowed_senders` (Email/Linq)
|
||||
@@ -540,6 +541,25 @@ Notes:
|
||||
allowed_contacts = ["*"]
|
||||
```
|
||||
|
||||
### 4.18 ACP
|
||||
|
||||
ACP (Agent Client Protocol) enables ZeroClaw to act as a client for OpenCode ACP server,
|
||||
allowing remote control of OpenCode behavior through JSON-RPC 2.0 communication over stdio.
|
||||
|
||||
```toml
|
||||
[channels_config.acp]
|
||||
opencode_path = "opencode" # optional, default: "opencode"
|
||||
workdir = "/path/to/workspace" # optional
|
||||
extra_args = [] # optional additional arguments to `opencode acp`
|
||||
allowed_users = ["*"] # empty = deny all, "*" = allow all
|
||||
```
|
||||
|
||||
Notes:
|
||||
- ACP uses JSON-RPC 2.0 protocol over stdio with newline-delimited messages.
|
||||
- Requires `opencode` binary in PATH or specified via `opencode_path`.
|
||||
- The channel starts OpenCode subprocess via `opencode acp` command.
|
||||
- Responses from OpenCode can be sent back to the originating channel when configured.
|
||||
|
||||
---
|
||||
|
||||
## 5. Validation Workflow
|
||||
@@ -588,7 +608,7 @@ RUST_LOG=info zeroclaw daemon 2>&1 | tee /tmp/zeroclaw.log
|
||||
Then filter channel/gateway events:
|
||||
|
||||
```bash
|
||||
rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|DingTalk|QQ|iMessage|Nostr|Webhook|Channel" /tmp/zeroclaw.log
|
||||
rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|DingTalk|QQ|iMessage|Nostr|Webhook|Channel|ACP" /tmp/zeroclaw.log
|
||||
```
|
||||
|
||||
### 7.2 Keyword table
|
||||
@@ -610,6 +630,7 @@ rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|D
|
||||
| QQ | `QQ: connected and identified` | `QQ: ignoring C2C message from unauthorized user:` / `QQ: ignoring group message from unauthorized user:` | `QQ: received Reconnect (op 7)` / `QQ: received Invalid Session (op 9)` / `QQ: message channel closed` |
|
||||
| Nextcloud Talk (gateway) | `POST /nextcloud-talk — Nextcloud Talk bot webhook` | `Nextcloud Talk webhook signature verification failed` / `Nextcloud Talk: ignoring message from unauthorized actor:` | `Nextcloud Talk send failed:` / `LLM error for Nextcloud Talk message:` |
|
||||
| iMessage | `iMessage channel listening (AppleScript bridge)...` | (contact allowlist enforced by `allowed_contacts`) | `iMessage poll error:` |
|
||||
| ACP | `ACP channel started` | `ACP: ignoring message from unauthorized user:` | `ACP process exited unexpectedly:` / `ACP JSON-RPC timeout:` / `ACP process spawn failed:` |
|
||||
| Nostr | `Nostr channel listening as npub1...` | `Nostr: ignoring NIP-04 message from unauthorized pubkey:` / `Nostr: ignoring NIP-17 message from unauthorized pubkey:` | `Failed to decrypt NIP-04 message:` / `Failed to unwrap NIP-17 gift wrap:` / `Nostr relay pool shut down` |
|
||||
|
||||
### 7.3 Runtime supervisor keywords
|
||||
|
||||
@@ -38,6 +38,38 @@ Notes:
|
||||
- Unset keeps the provider's built-in default.
|
||||
- Environment override: `ZEROCLAW_MODEL_SUPPORT_VISION` or `MODEL_SUPPORT_VISION` (values: `true`/`false`/`1`/`0`/`yes`/`no`/`on`/`off`).
|
||||
|
||||
## `[model_providers.<profile>]`
|
||||
|
||||
Use named profiles to map a logical provider id to a provider name/base URL and optional profile-scoped credentials.
|
||||
|
||||
| Key | Default | Notes |
|
||||
|---|---|---|
|
||||
| `name` | unset | Optional provider id override (for example `openai`, `openai-codex`) |
|
||||
| `base_url` | unset | Optional OpenAI-compatible endpoint URL |
|
||||
| `wire_api` | unset | Optional protocol mode: `responses` or `chat_completions` |
|
||||
| `model` | unset | Optional profile-scoped default model |
|
||||
| `api_key` | unset | Optional profile-scoped API key (used when top-level `api_key` is empty) |
|
||||
| `requires_openai_auth` | `false` | Load OpenAI auth material (`OPENAI_API_KEY` / Codex auth file) |
|
||||
|
||||
Notes:
|
||||
|
||||
- If both top-level `api_key` and profile `api_key` are present, top-level `api_key` wins.
|
||||
- If top-level `default_model` is still the global OpenRouter default, profile `model` is used as an automatic compatibility override.
|
||||
- Secrets encryption applies to profile API keys when `secrets.encrypt = true`.
|
||||
|
||||
Example:
|
||||
|
||||
```toml
|
||||
default_provider = "sub2api"
|
||||
|
||||
[model_providers.sub2api]
|
||||
name = "sub2api"
|
||||
base_url = "https://api.example.com/v1"
|
||||
wire_api = "chat_completions"
|
||||
model = "qwen-max"
|
||||
api_key = "sk-profile-key"
|
||||
```
|
||||
|
||||
## `[observability]`
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This inventory classifies documentation by intent and canonical location.
|
||||
|
||||
Last reviewed: **February 24, 2026**.
|
||||
Last reviewed: **February 28, 2026**.
|
||||
|
||||
## Classification Legend
|
||||
|
||||
@@ -124,6 +124,7 @@ These are valuable context, but **not strict runtime contracts**.
|
||||
|---|---|
|
||||
| `docs/project-triage-snapshot-2026-02-18.md` | Snapshot |
|
||||
| `docs/docs-audit-2026-02-24.md` | Snapshot (docs architecture audit) |
|
||||
| `docs/project/m4-5-rfi-spike-2026-02-28.md` | Snapshot (M4-5 workspace split RFI baseline and execution plan) |
|
||||
| `docs/i18n-gap-backlog.md` | Snapshot (i18n depth gap tracking) |
|
||||
|
||||
## Maintenance Contract
|
||||
|
||||
@@ -6,6 +6,7 @@ Time-bound project status snapshots for planning documentation and operations wo
|
||||
|
||||
- [../project-triage-snapshot-2026-02-18.md](../project-triage-snapshot-2026-02-18.md)
|
||||
- [../docs-audit-2026-02-24.md](../docs-audit-2026-02-24.md)
|
||||
- [m4-5-rfi-spike-2026-02-28.md](m4-5-rfi-spike-2026-02-28.md)
|
||||
|
||||
## Scope
|
||||
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
# M4-5 Multi-Crate Workspace RFI Spike (2026-02-28)
|
||||
|
||||
Status: RFI complete, extraction execution pending.
|
||||
Issue: [#2263](https://github.com/zeroclaw-labs/zeroclaw/issues/2263)
|
||||
Linear parent: RMN-243
|
||||
|
||||
## Scope
|
||||
|
||||
This spike is strictly no-behavior-change planning for the M4-5 workspace split.
|
||||
|
||||
Goals:
|
||||
- capture reproducible compile baseline metrics
|
||||
- define crate boundary and dependency contract
|
||||
- define CI/feature-matrix impact and rollback posture
|
||||
- define stacked PR slicing plan (XS/S/M)
|
||||
|
||||
Out of scope:
|
||||
- broad API redesign
|
||||
- feature additions bundled with structure work
|
||||
- one-shot mega-PR extraction
|
||||
|
||||
## Baseline Compile Metrics
|
||||
|
||||
### Repro command
|
||||
|
||||
```bash
|
||||
scripts/ci/m4_5_rfi_baseline.sh /tmp/zeroclaw-m4rfi-target
|
||||
```
|
||||
|
||||
### Preflight compile blockers observed on `origin/main`
|
||||
|
||||
Before timing could run cleanly, two compile blockers were found:
|
||||
|
||||
- `src/gateway/mod.rs:2176`: `run_gateway_chat_with_tools` call missing `session_id` argument
|
||||
- `src/providers/cursor.rs:233`: `ChatResponse` initializer missing `quota_metadata`
|
||||
|
||||
RFI includes minimal compile-compat fixes for these two blockers so measurements are reproducible.
|
||||
|
||||
### Measured results (Apple Silicon macOS, local workspace)
|
||||
|
||||
| Phase | real(s) | status |
|
||||
|---|---:|---|
|
||||
| A: cold `cargo check --workspace --locked` | 306.47 | pass |
|
||||
| B: cold-ish `cargo build --workspace --locked` | 219.07 | pass |
|
||||
| C: warm `cargo check --workspace --locked` | 0.84 | pass |
|
||||
| D: incremental `cargo check` after touching `src/main.rs` | 6.19 | pass |
|
||||
|
||||
Observations:
|
||||
- cold check is the dominant iteration tax
|
||||
- warm-check performance is excellent once target artifacts exist
|
||||
- incremental behavior is acceptable but sensitive to wide root-crate coupling
|
||||
|
||||
## Current Workspace Snapshot
|
||||
|
||||
Current workspace members:
|
||||
- `.` (`zeroclaw` monolith crate)
|
||||
- `crates/robot-kit`
|
||||
|
||||
Code concentration still sits in the monolith. Large hotspots include:
|
||||
- `src/config/schema.rs`
|
||||
- `src/channels/mod.rs`
|
||||
- `src/onboard/wizard.rs`
|
||||
- `src/agent/loop_.rs`
|
||||
- `src/gateway/mod.rs`
|
||||
|
||||
## Proposed Boundary Contract
|
||||
|
||||
Target crate topology for staged extraction:
|
||||
|
||||
1. `crates/zeroclaw-types`
|
||||
- shared DTOs, enums, IDs, lightweight cross-domain traits
|
||||
- no provider/channel/network dependencies
|
||||
|
||||
1. `crates/zeroclaw-core`
|
||||
- config structs + validation, provider trait contracts, routing primitives, policy helpers
|
||||
- depends on `zeroclaw-types`
|
||||
|
||||
1. `crates/zeroclaw-memory`
|
||||
- memory traits/backends + hygiene/snapshot plumbing
|
||||
- depends on `zeroclaw-types`, `zeroclaw-core` contracts only where required
|
||||
|
||||
1. `crates/zeroclaw-channels`
|
||||
- channel adapters + inbound normalization
|
||||
- depends on `zeroclaw-types`, `zeroclaw-core`, `zeroclaw-memory`
|
||||
|
||||
1. `crates/zeroclaw-api`
|
||||
- gateway/webhook/http orchestration
|
||||
- depends on `zeroclaw-types`, `zeroclaw-core`, `zeroclaw-memory`, `zeroclaw-channels`
|
||||
|
||||
1. `crates/zeroclaw-bin` (or keep root binary package name stable)
|
||||
- CLI entrypoints + wiring only
|
||||
|
||||
Dependency rules:
|
||||
- no downward imports from foundational crates into higher layers
|
||||
- channels must not depend on gateway/http crate
|
||||
- keep provider-specific SDK deps out of `zeroclaw-types`
|
||||
- maintain feature-flag parity at workspace root during migration
|
||||
|
||||
## CI / Feature-Matrix Impact
|
||||
|
||||
Required CI adjustments during migration:
|
||||
- add workspace compile lane (`cargo check --workspace --locked`)
|
||||
- add package-focused lanes for extracted crates (`-p zeroclaw-types`, `-p zeroclaw-core`, etc.)
|
||||
- keep existing runtime behavior lanes (`test`, `sec-audit`, `codeql`) unchanged until final convergence
|
||||
- update path filters so crate-local changes trigger only relevant crate tests plus contract smoke tests
|
||||
|
||||
Guardrails:
|
||||
- changed-line strict-delta lint remains mandatory
|
||||
- each extraction PR must include no-behavior-change assertion in PR body
|
||||
- each step must include explicit rollback note
|
||||
|
||||
## Rollback Strategy
|
||||
|
||||
Per-step rollback (stack-safe):
|
||||
1. revert latest extraction PR only
|
||||
2. re-run workspace compile + existing CI matrix
|
||||
3. keep binary entrypoint and config contract untouched until final extraction stage
|
||||
|
||||
Abort criteria:
|
||||
- unexpected runtime behavior drift
|
||||
- CI lane expansion causes recurring queue stalls without signal gain
|
||||
- feature-flag compatibility regressions
|
||||
|
||||
## Stacked PR Slicing Plan
|
||||
|
||||
### PR-1 (XS)
|
||||
|
||||
- add crate shells + workspace wiring (`types/core`), no symbol moves
|
||||
- objective: establish scaffolding and CI package lanes
|
||||
|
||||
### PR-2 (S)
|
||||
|
||||
- extract low-churn shared types into `zeroclaw-types`
|
||||
- add re-export shim layer to preserve existing import paths
|
||||
|
||||
### PR-3 (S)
|
||||
|
||||
- extract config/provider contracts into `zeroclaw-core`
|
||||
- keep runtime call sites unchanged via compatibility re-exports
|
||||
|
||||
### PR-4 (M)
|
||||
|
||||
- extract memory subsystem crate and move wiring boundaries
|
||||
- run full memory + gateway regression suite
|
||||
|
||||
### PR-5 (M)
|
||||
|
||||
- extract channels/api orchestration seams
|
||||
- finalize package ownership and remove temporary re-export shims
|
||||
|
||||
## Next Execution Step
|
||||
|
||||
Open first no-behavior-change extraction PR from this RFI baseline:
|
||||
- scope: workspace crate scaffolding + CI package lanes only
|
||||
- no runtime behavior changes
|
||||
- explicit rollback command included in PR body
|
||||
Executable
+67
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
|
||||
cat <<'USAGE'
|
||||
Usage: scripts/ci/m4_5_rfi_baseline.sh [target_dir]
|
||||
|
||||
Run reproducible compile-timing probes for the current workspace.
|
||||
The script prints a markdown table with real-time seconds and pass/fail status
|
||||
for each benchmark phase.
|
||||
USAGE
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
|
||||
TARGET_DIR="${1:-${ROOT_DIR}/target-rfi}"
|
||||
|
||||
cd "${ROOT_DIR}"
|
||||
|
||||
if [[ ! -f Cargo.toml ]]; then
|
||||
echo "error: Cargo.toml not found at ${ROOT_DIR}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run_timed() {
|
||||
local label="$1"
|
||||
shift
|
||||
|
||||
local timing_file
|
||||
timing_file="$(mktemp)"
|
||||
local status="pass"
|
||||
|
||||
if /usr/bin/time -p "$@" >/dev/null 2>"${timing_file}"; then
|
||||
status="pass"
|
||||
else
|
||||
status="fail"
|
||||
fi
|
||||
|
||||
local real_time
|
||||
real_time="$(awk '/^real / { print $2 }' "${timing_file}")"
|
||||
rm -f "${timing_file}"
|
||||
|
||||
if [[ -z "${real_time}" ]]; then
|
||||
real_time="n/a"
|
||||
fi
|
||||
|
||||
printf '| %s | %s | %s |\n' "${label}" "${real_time}" "${status}"
|
||||
|
||||
[[ "${status}" == "pass" ]]
|
||||
}
|
||||
|
||||
printf '# M4-5 RFI Baseline\n\n'
|
||||
printf '- Timestamp (UTC): %s\n' "$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
printf '- Commit: `%s`\n' "$(git rev-parse --short HEAD)"
|
||||
printf '- Target dir: `%s`\n\n' "${TARGET_DIR}"
|
||||
printf '| Phase | real(s) | status |\n'
|
||||
printf '|---|---:|---|\n'
|
||||
|
||||
rm -rf "${TARGET_DIR}"
|
||||
|
||||
set +e
|
||||
run_timed "A: cold cargo check" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
|
||||
run_timed "B: cold-ish cargo build" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo build --workspace --locked
|
||||
run_timed "C: warm cargo check" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
|
||||
touch src/main.rs
|
||||
run_timed "D: incremental cargo check after touch src/main.rs" env CARGO_TARGET_DIR="${TARGET_DIR}" cargo check --workspace --locked
|
||||
set -e
|
||||
+24
-10
@@ -18,6 +18,8 @@ use std::io::Write as IoWrite;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
|
||||
|
||||
pub struct Agent {
|
||||
provider: Box<dyn Provider>,
|
||||
tools: Vec<Box<dyn Tool>>,
|
||||
@@ -218,9 +220,7 @@ impl AgentBuilder {
|
||||
.memory_loader
|
||||
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
|
||||
config: self.config.unwrap_or_default(),
|
||||
model_name: self
|
||||
.model_name
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
|
||||
model_name: crate::config::resolve_default_model_id(self.model_name.as_deref(), None),
|
||||
temperature: self.temperature.unwrap_or(0.7),
|
||||
workspace_dir: self
|
||||
.workspace_dir
|
||||
@@ -298,11 +298,10 @@ impl Agent {
|
||||
|
||||
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
|
||||
let model_name = config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
config.default_model.as_deref(),
|
||||
Some(provider_name),
|
||||
);
|
||||
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
@@ -598,6 +597,17 @@ impl Agent {
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
final_text.clone(),
|
||||
)));
|
||||
if self.auto_save && final_text.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let _ = self
|
||||
.memory
|
||||
.store(
|
||||
"assistant_resp",
|
||||
&final_text,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
self.trim_history();
|
||||
|
||||
return Ok(final_text);
|
||||
@@ -714,8 +724,12 @@ pub async fn run(
|
||||
let model_name = effective_config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| {
|
||||
crate::config::default_model_fallback_for_provider(Some(&provider_name)).to_string()
|
||||
});
|
||||
|
||||
agent.observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.clone(),
|
||||
|
||||
+64
-20
@@ -10,7 +10,7 @@ use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context as _, Result};
|
||||
use regex::{Regex, RegexSet};
|
||||
use rustyline::completion::{Completer, Pair};
|
||||
use rustyline::error::ReadlineError;
|
||||
@@ -33,6 +33,7 @@ mod execution;
|
||||
mod history;
|
||||
mod parsing;
|
||||
|
||||
use crate::agent::session::{resolve_session_id, shared_session_manager};
|
||||
use context::{build_context, build_hardware_context};
|
||||
use detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
|
||||
use execution::{
|
||||
@@ -981,7 +982,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&parse_issue),
|
||||
Some(parse_issue),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"response_excerpt": truncate_with_ellipsis(
|
||||
@@ -1816,10 +1817,12 @@ pub async fn run(
|
||||
.or(config.default_provider.as_deref())
|
||||
.unwrap_or("openrouter");
|
||||
|
||||
let model_name = model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref())
|
||||
.unwrap_or("anthropic/claude-sonnet-4");
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref()),
|
||||
Some(provider_name),
|
||||
);
|
||||
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
@@ -1840,7 +1843,7 @@ pub async fn run(
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
model_name,
|
||||
&model_name,
|
||||
&provider_runtime_options,
|
||||
)?;
|
||||
|
||||
@@ -2003,7 +2006,7 @@ pub async fn run(
|
||||
let native_tools = provider.supports_native_tools();
|
||||
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
|
||||
&config.workspace_dir,
|
||||
model_name,
|
||||
&model_name,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
Some(&config.identity),
|
||||
@@ -2042,7 +2045,7 @@ pub async fn run(
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
|
||||
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score, None).await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -2085,7 +2088,7 @@ pub async fn run(
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
@@ -2101,6 +2104,17 @@ pub async fn run(
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let assistant_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(
|
||||
&assistant_key,
|
||||
&response,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
println!("{response}");
|
||||
observer.record_event(&ObserverEvent::TurnComplete);
|
||||
} else {
|
||||
@@ -2191,8 +2205,13 @@ pub async fn run(
|
||||
}
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
config.memory.min_relevance_score,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -2246,7 +2265,7 @@ pub async fn run(
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
@@ -2288,6 +2307,17 @@ pub async fn run(
|
||||
}
|
||||
};
|
||||
final_output = response.clone();
|
||||
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let assistant_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(
|
||||
&assistant_key,
|
||||
&response,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
if let Err(e) = crate::channels::Channel::send(
|
||||
&cli,
|
||||
&crate::channels::traits::SendMessage::new(format!("\n{response}\n"), "user"),
|
||||
@@ -2302,7 +2332,7 @@ pub async fn run(
|
||||
if let Ok(compacted) = auto_compact_history(
|
||||
&mut history,
|
||||
provider.as_ref(),
|
||||
model_name,
|
||||
&model_name,
|
||||
config.agent.max_history_messages,
|
||||
)
|
||||
.await
|
||||
@@ -2332,6 +2362,14 @@ pub async fn run(
|
||||
/// Process a single message through the full agent (with tools, peripherals, memory).
|
||||
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
|
||||
pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
process_message_with_session(config, message, None).await
|
||||
}
|
||||
|
||||
pub async fn process_message_with_session(
|
||||
config: Config,
|
||||
message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
@@ -2375,10 +2413,10 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
tools_registry.extend(peripheral_tools);
|
||||
|
||||
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
let model_name = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
config.default_model.as_deref(),
|
||||
Some(provider_name),
|
||||
);
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
@@ -2495,7 +2533,13 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
}
|
||||
system_prompt.push_str(&build_shell_policy_instructions(&config.autonomy));
|
||||
|
||||
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -4545,7 +4589,7 @@ Tail"#;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_context(&mem, "status updates", 0.0).await;
|
||||
let context = build_context(&mem, "status updates", 0.0, None).await;
|
||||
assert!(context.contains("user_msg_real"));
|
||||
assert!(!context.contains("assistant_resp_poisoned"));
|
||||
assert!(!context.contains("fabricated event"));
|
||||
|
||||
@@ -8,11 +8,12 @@ pub(super) async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
min_relevance_score: f64,
|
||||
session_id: Option<&str>,
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
|
||||
@@ -220,7 +220,9 @@ impl LoopDetector {
|
||||
|
||||
fn hash_output(output: &str) -> u64 {
|
||||
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
|
||||
&output[..OUTPUT_HASH_PREFIX_BYTES]
|
||||
// Use floor_utf8_char_boundary to avoid panic on multi-byte UTF-8 characters
|
||||
let boundary = crate::util::floor_utf8_char_boundary(output, OUTPUT_HASH_PREFIX_BYTES);
|
||||
&output[..boundary]
|
||||
} else {
|
||||
output
|
||||
};
|
||||
@@ -386,4 +388,26 @@ mod tests {
|
||||
det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 11. UTF-8 boundary safety: hash_output must not panic on CJK text
|
||||
#[test]
|
||||
fn hash_output_utf8_boundary_safe() {
|
||||
// Create a string where byte 4096 lands inside a multi-byte char
|
||||
// Chinese chars are 3 bytes each, so 1366 chars = 4098 bytes
|
||||
let cjk_text: String = "文".repeat(1366); // 4098 bytes
|
||||
assert!(cjk_text.len() > super::OUTPUT_HASH_PREFIX_BYTES);
|
||||
|
||||
// This should NOT panic
|
||||
let hash1 = super::hash_output(&cjk_text);
|
||||
|
||||
// Different content should produce different hash
|
||||
let cjk_text2: String = "字".repeat(1366);
|
||||
let hash2 = super::hash_output(&cjk_text2);
|
||||
assert_ne!(hash1, hash2);
|
||||
|
||||
// Mixed ASCII + CJK at boundary
|
||||
let mixed = "a".repeat(4094) + "文文"; // 4094 + 6 = 4100 bytes, boundary at 4096
|
||||
let hash3 = super::hash_output(&mixed);
|
||||
assert!(hash3 != 0); // Just verify it runs
|
||||
}
|
||||
}
|
||||
|
||||
+106
-3
@@ -28,8 +28,13 @@ pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
|
||||
}
|
||||
|
||||
let start = if has_system { 1 } else { 0 };
|
||||
let to_remove = non_system_count - max_history;
|
||||
history.drain(start..start + to_remove);
|
||||
let mut trim_end = start + (non_system_count - max_history);
|
||||
// Never keep a leading `role=tool` at the trim boundary. Tool-message runs
|
||||
// must remain attached to their preceding assistant(tool_calls) message.
|
||||
while trim_end < history.len() && history[trim_end].role == "tool" {
|
||||
trim_end += 1;
|
||||
}
|
||||
history.drain(start..trim_end);
|
||||
}
|
||||
|
||||
pub(super) fn build_compaction_transcript(messages: &[ChatMessage]) -> String {
|
||||
@@ -80,7 +85,11 @@ pub(super) async fn auto_compact_history(
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let compact_end = start + compact_count;
|
||||
let mut compact_end = start + compact_count;
|
||||
// Do not split assistant(tool_calls) -> tool runs across compaction boundary.
|
||||
while compact_end < history.len() && history[compact_end].role == "tool" {
|
||||
compact_end += 1;
|
||||
}
|
||||
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
@@ -104,3 +113,97 @@ pub(super) async fn auto_compact_history(
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::{ChatRequest, ChatResponse, Provider};
|
||||
use async_trait::async_trait;
|
||||
|
||||
struct StaticSummaryProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StaticSummaryProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("- summarized context".to_string())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
Ok(ChatResponse {
|
||||
text: Some("- summarized context".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_with_tool_call(id: &str) -> ChatMessage {
|
||||
ChatMessage::assistant(format!(
|
||||
"{{\"content\":\"\",\"tool_calls\":[{{\"id\":\"{id}\",\"name\":\"shell\",\"arguments\":\"{{}}\"}}]}}"
|
||||
))
|
||||
}
|
||||
|
||||
fn tool_result(id: &str) -> ChatMessage {
|
||||
ChatMessage::tool(format!("{{\"tool_call_id\":\"{id}\",\"content\":\"ok\"}}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trim_history_avoids_orphan_tool_at_boundary() {
|
||||
let mut history = vec![
|
||||
ChatMessage::user("old"),
|
||||
assistant_with_tool_call("call_1"),
|
||||
tool_result("call_1"),
|
||||
ChatMessage::user("recent"),
|
||||
];
|
||||
|
||||
trim_history(&mut history, 2);
|
||||
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "user");
|
||||
assert_eq!(history[0].content, "recent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auto_compact_history_does_not_split_tool_run_boundary() {
|
||||
let mut history = vec![
|
||||
ChatMessage::user("oldest"),
|
||||
assistant_with_tool_call("call_2"),
|
||||
tool_result("call_2"),
|
||||
];
|
||||
for idx in 0..19 {
|
||||
history.push(ChatMessage::user(format!("recent-{idx}")));
|
||||
}
|
||||
// 22 non-system messages => compaction with max_history=21 would
|
||||
// previously cut right before the tool result (index 2).
|
||||
assert_eq!(history.len(), 22);
|
||||
|
||||
let compacted =
|
||||
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
|
||||
assert!(compacted);
|
||||
assert_eq!(history[0].role, "assistant");
|
||||
assert!(
|
||||
history[0].content.contains("[Compaction summary]"),
|
||||
"summary message should replace compacted range"
|
||||
);
|
||||
assert_ne!(
|
||||
history[1].role, "tool",
|
||||
"first retained message must not be an orphan tool result"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -7,6 +7,7 @@ pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
pub mod quota_aware;
|
||||
pub mod research;
|
||||
pub mod session;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@@ -14,4 +15,4 @@ mod tests;
|
||||
#[allow(unused_imports)]
|
||||
pub use agent::{Agent, AgentBuilder};
|
||||
#[allow(unused_imports)]
|
||||
pub use loop_::{process_message, run};
|
||||
pub use loop_::{process_message, process_message_with_session, run};
|
||||
|
||||
@@ -0,0 +1,569 @@
|
||||
use crate::providers::ChatMessage;
|
||||
use crate::{
|
||||
config::AgentSessionBackend, config::AgentSessionConfig, config::AgentSessionStrategy,
|
||||
};
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::sync::{LazyLock, Mutex as StdMutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time;
|
||||
|
||||
static SHARED_SESSION_MANAGERS: LazyLock<StdMutex<HashMap<String, Arc<dyn SessionManager>>>> =
|
||||
LazyLock::new(|| StdMutex::new(HashMap::new()));
|
||||
|
||||
pub fn resolve_session_id(
|
||||
session_config: &AgentSessionConfig,
|
||||
sender_id: &str,
|
||||
channel_name: Option<&str>,
|
||||
) -> String {
|
||||
fn escape_part(raw: &str) -> String {
|
||||
raw.replace(':', "%3A")
|
||||
}
|
||||
|
||||
match session_config.strategy {
|
||||
AgentSessionStrategy::Main => "main".to_string(),
|
||||
AgentSessionStrategy::PerChannel => escape_part(channel_name.unwrap_or("main")),
|
||||
AgentSessionStrategy::PerSender => match channel_name {
|
||||
Some(channel) => format!("{}:{sender_id}", escape_part(channel)),
|
||||
None => sender_id.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_session_manager(
|
||||
session_config: &AgentSessionConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> Result<Option<Arc<dyn SessionManager>>> {
|
||||
let ttl = Duration::from_secs(session_config.ttl_seconds);
|
||||
let max_messages = session_config.max_messages;
|
||||
match session_config.backend {
|
||||
AgentSessionBackend::None => Ok(None),
|
||||
AgentSessionBackend::Memory => Ok(Some(MemorySessionManager::new(ttl, max_messages))),
|
||||
AgentSessionBackend::Sqlite => {
|
||||
let path = SqliteSessionManager::default_db_path(workspace_dir);
|
||||
Ok(Some(SqliteSessionManager::new(path, ttl, max_messages)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shared_session_manager(
|
||||
session_config: &AgentSessionConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> Result<Option<Arc<dyn SessionManager>>> {
|
||||
let key = format!("{}:{session_config:?}", workspace_dir.display());
|
||||
|
||||
{
|
||||
let map = SHARED_SESSION_MANAGERS.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(mgr) = map.get(&key) {
|
||||
return Ok(Some(mgr.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
let mgr_opt = create_session_manager(session_config, workspace_dir)?;
|
||||
if let Some(mgr) = mgr_opt.as_ref() {
|
||||
let mut map = SHARED_SESSION_MANAGERS.lock().unwrap_or_else(|e| e.into_inner());
|
||||
map.insert(key, mgr.clone());
|
||||
}
|
||||
Ok(mgr_opt)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
id: String,
|
||||
manager: Arc<dyn SessionManager>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub async fn get_history(&self) -> Result<Vec<ChatMessage>> {
|
||||
self.manager.get_history(&self.id).await
|
||||
}
|
||||
|
||||
pub async fn update_history(&self, history: Vec<ChatMessage>) -> Result<()> {
|
||||
self.manager.set_history(&self.id, history).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionManager: Send + Sync {
|
||||
fn clone_arc(&self) -> Arc<dyn SessionManager>;
|
||||
async fn ensure_exists(&self, _session_id: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>>;
|
||||
async fn set_history(&self, session_id: &str, history: Vec<ChatMessage>) -> Result<()>;
|
||||
async fn delete(&self, session_id: &str) -> Result<()>;
|
||||
async fn cleanup_expired(&self) -> Result<usize>;
|
||||
|
||||
async fn get_or_create(&self, session_id: &str) -> Result<Session> {
|
||||
self.ensure_exists(session_id).await?;
|
||||
Ok(Session {
|
||||
id: session_id.to_string(),
|
||||
manager: self.clone_arc(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn unix_seconds_now() -> i64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or(Duration::from_secs(0))
|
||||
.as_secs() as i64
|
||||
}
|
||||
|
||||
fn trim_non_system(history: &mut Vec<ChatMessage>, max_messages: usize) {
|
||||
history.retain(|m| m.role != "system");
|
||||
if max_messages == 0 || history.len() <= max_messages {
|
||||
return;
|
||||
}
|
||||
let drop_count = history.len() - max_messages;
|
||||
history.drain(0..drop_count);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MemorySessionState {
|
||||
history: RwLock<Vec<ChatMessage>>,
|
||||
updated_at_unix: AtomicI64,
|
||||
}
|
||||
|
||||
struct MemorySessionManagerInner {
|
||||
sessions: RwLock<HashMap<String, Arc<MemorySessionState>>>,
|
||||
ttl: Duration,
|
||||
max_messages: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MemorySessionManager {
|
||||
inner: Arc<MemorySessionManagerInner>,
|
||||
}
|
||||
|
||||
impl MemorySessionManager {
|
||||
pub fn new(ttl: Duration, max_messages: usize) -> Arc<Self> {
|
||||
let mgr = Arc::new(Self {
|
||||
inner: Arc::new(MemorySessionManagerInner {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
ttl,
|
||||
max_messages,
|
||||
}),
|
||||
});
|
||||
mgr.spawn_cleanup_task();
|
||||
mgr
|
||||
}
|
||||
|
||||
fn spawn_cleanup_task(self: &Arc<Self>) {
|
||||
let mgr = Arc::clone(self);
|
||||
let interval = cleanup_interval(mgr.inner.ttl);
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = time::interval(interval);
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
let _ = mgr.cleanup_expired().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionManager for MemorySessionManager {
|
||||
fn clone_arc(&self) -> Arc<dyn SessionManager> {
|
||||
Arc::new(self.clone())
|
||||
}
|
||||
|
||||
async fn ensure_exists(&self, session_id: &str) -> Result<()> {
|
||||
let mut sessions = self.inner.sessions.write().await;
|
||||
if sessions.contains_key(session_id) {
|
||||
return Ok(());
|
||||
}
|
||||
let now = unix_seconds_now();
|
||||
sessions.insert(
|
||||
session_id.to_string(),
|
||||
Arc::new(MemorySessionState {
|
||||
history: RwLock::new(Vec::new()),
|
||||
updated_at_unix: AtomicI64::new(now),
|
||||
}),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>> {
|
||||
let state = {
|
||||
let sessions = self.inner.sessions.read().await;
|
||||
sessions.get(session_id).cloned()
|
||||
};
|
||||
let Some(state) = state else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
let history = state.history.read().await;
|
||||
let mut history = history.clone();
|
||||
trim_non_system(&mut history, self.inner.max_messages);
|
||||
Ok(history)
|
||||
}
|
||||
|
||||
async fn set_history(&self, session_id: &str, mut history: Vec<ChatMessage>) -> Result<()> {
|
||||
trim_non_system(&mut history, self.inner.max_messages);
|
||||
let now = unix_seconds_now();
|
||||
let state = {
|
||||
let mut sessions = self.inner.sessions.write().await;
|
||||
sessions
|
||||
.entry(session_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(MemorySessionState {
|
||||
history: RwLock::new(Vec::new()),
|
||||
updated_at_unix: AtomicI64::new(now),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
state.updated_at_unix.store(now, Ordering::Relaxed);
|
||||
let mut stored = state.history.write().await;
|
||||
*stored = history;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete(&self, session_id: &str) -> Result<()> {
|
||||
let mut sessions = self.inner.sessions.write().await;
|
||||
sessions.remove(session_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) -> Result<usize> {
|
||||
if self.inner.ttl.is_zero() {
|
||||
return Ok(0);
|
||||
}
|
||||
let cutoff = unix_seconds_now() - self.inner.ttl.as_secs() as i64;
|
||||
let mut sessions = self.inner.sessions.write().await;
|
||||
let before = sessions.len();
|
||||
sessions.retain(|_, s| s.updated_at_unix.load(Ordering::Relaxed) >= cutoff);
|
||||
Ok(before.saturating_sub(sessions.len()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SqliteSessionManager {
|
||||
conn: Arc<Mutex<Connection>>,
|
||||
ttl: Duration,
|
||||
max_messages: usize,
|
||||
}
|
||||
|
||||
impl SqliteSessionManager {
|
||||
pub fn new(db_path: PathBuf, ttl: Duration, max_messages: usize) -> Result<Arc<Self>> {
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let conn = Connection::open(&db_path)?;
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;",
|
||||
)?;
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS agent_sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
history_json TEXT NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_agent_sessions_updated_at
|
||||
ON agent_sessions(updated_at);",
|
||||
)?;
|
||||
|
||||
let mgr = Arc::new(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
ttl,
|
||||
max_messages,
|
||||
});
|
||||
mgr.spawn_cleanup_task();
|
||||
Ok(mgr)
|
||||
}
|
||||
|
||||
pub fn default_db_path(workspace_dir: &Path) -> PathBuf {
|
||||
workspace_dir.join("memory").join("sessions.db")
|
||||
}
|
||||
|
||||
fn spawn_cleanup_task(self: &Arc<Self>) {
|
||||
let mgr = Arc::clone(self);
|
||||
let interval = cleanup_interval(mgr.ttl);
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = time::interval(interval);
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
let _ = mgr.cleanup_expired().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn force_expire_session(&self, session_id: &str, age: Duration) -> Result<()> {
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
let age_secs = age.as_secs() as i64;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let conn = conn.lock();
|
||||
let new_time = unix_seconds_now() - age_secs;
|
||||
conn.execute(
|
||||
"UPDATE agent_sessions SET updated_at = ?2 WHERE session_id = ?1",
|
||||
params![session_id, new_time],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionManager for SqliteSessionManager {
|
||||
fn clone_arc(&self) -> Arc<dyn SessionManager> {
|
||||
Arc::new(self.clone())
|
||||
}
|
||||
|
||||
async fn ensure_exists(&self, session_id: &str) -> Result<()> {
|
||||
let now = unix_seconds_now();
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let conn = conn.lock();
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO agent_sessions(session_id, history_json, updated_at)
|
||||
VALUES(?1, '[]', ?2)",
|
||||
params![session_id, now],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
|
||||
async fn get_history(&self, session_id: &str) -> Result<Vec<ChatMessage>> {
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
let max_messages = self.max_messages;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let conn = conn.lock();
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT history_json FROM agent_sessions WHERE session_id = ?1",
|
||||
)?;
|
||||
let mut rows = stmt.query(params![session_id])?;
|
||||
if let Some(row) = rows.next()? {
|
||||
let json: String = row.get(0)?;
|
||||
let mut history: Vec<ChatMessage> = serde_json::from_str(&json)
|
||||
.with_context(|| format!("Failed to parse session history for session_id={session_id}"))?;
|
||||
trim_non_system(&mut history, max_messages);
|
||||
return Ok(history);
|
||||
}
|
||||
Ok(Vec::new())
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
|
||||
async fn set_history(&self, session_id: &str, mut history: Vec<ChatMessage>) -> Result<()> {
|
||||
trim_non_system(&mut history, self.max_messages);
|
||||
let json = serde_json::to_string(&history)?;
|
||||
let now = unix_seconds_now();
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let conn = conn.lock();
|
||||
conn.execute(
|
||||
"INSERT INTO agent_sessions(session_id, history_json, updated_at)
|
||||
VALUES(?1, ?2, ?3)
|
||||
ON CONFLICT(session_id) DO UPDATE SET history_json=excluded.history_json, updated_at=excluded.updated_at",
|
||||
params![session_id, json, now],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
|
||||
async fn delete(&self, session_id: &str) -> Result<()> {
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let conn = conn.lock();
|
||||
conn.execute(
|
||||
"DELETE FROM agent_sessions WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) -> Result<usize> {
|
||||
if self.ttl.is_zero() {
|
||||
return Ok(0);
|
||||
}
|
||||
let conn = self.conn.clone();
|
||||
let ttl_secs = self.ttl.as_secs() as i64;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let cutoff = unix_seconds_now() - ttl_secs;
|
||||
let conn = conn.lock();
|
||||
let removed = conn.execute(
|
||||
"DELETE FROM agent_sessions WHERE updated_at < ?1",
|
||||
params![cutoff],
|
||||
)?;
|
||||
Ok(removed)
|
||||
})
|
||||
.await
|
||||
.context("SQLite blocking task panicked")?
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup_interval(ttl: Duration) -> Duration {
|
||||
if ttl.is_zero() {
|
||||
return Duration::from_secs(60);
|
||||
}
|
||||
let half = ttl / 2;
|
||||
if half < Duration::from_secs(30) {
|
||||
Duration::from_secs(30)
|
||||
} else if half > Duration::from_secs(300) {
|
||||
Duration::from_secs(300)
|
||||
} else {
|
||||
half
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resolve_session_id_respects_strategy() {
|
||||
let mut cfg = AgentSessionConfig::default();
|
||||
cfg.strategy = AgentSessionStrategy::Main;
|
||||
assert_eq!(resolve_session_id(&cfg, "u1", Some("whatsapp")), "main");
|
||||
|
||||
cfg.strategy = AgentSessionStrategy::PerChannel;
|
||||
assert_eq!(resolve_session_id(&cfg, "u1", Some("whatsapp")), "whatsapp");
|
||||
assert_eq!(resolve_session_id(&cfg, "u1", None), "main");
|
||||
|
||||
cfg.strategy = AgentSessionStrategy::PerSender;
|
||||
assert_eq!(
|
||||
resolve_session_id(&cfg, "u1", Some("whatsapp")),
|
||||
"whatsapp:u1"
|
||||
);
|
||||
assert_eq!(resolve_session_id(&cfg, "u1", None), "u1");
|
||||
|
||||
assert_eq!(
|
||||
resolve_session_id(&cfg, "u1", Some("matrix:@alice")),
|
||||
"matrix%3A@alice:u1"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn memory_session_accumulates_history() -> Result<()> {
|
||||
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 50);
|
||||
let session = mgr.get_or_create("s1").await?;
|
||||
|
||||
assert!(session.get_history().await?.is_empty());
|
||||
|
||||
session
|
||||
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
|
||||
.await?;
|
||||
assert_eq!(session.get_history().await?.len(), 2);
|
||||
|
||||
let mut h = session.get_history().await?;
|
||||
h.push(ChatMessage::user("again"));
|
||||
h.push(ChatMessage::assistant("ok2"));
|
||||
session.update_history(h).await?;
|
||||
assert_eq!(session.get_history().await?.len(), 4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn memory_sessions_do_not_mix_histories() -> Result<()> {
|
||||
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 50);
|
||||
let a = mgr.get_or_create("a").await?;
|
||||
let b = mgr.get_or_create("b").await?;
|
||||
|
||||
a.update_history(vec![ChatMessage::user("u1"), ChatMessage::assistant("a1")])
|
||||
.await?;
|
||||
b.update_history(vec![ChatMessage::user("u2"), ChatMessage::assistant("b1")])
|
||||
.await?;
|
||||
|
||||
let ha = a.get_history().await?;
|
||||
let hb = b.get_history().await?;
|
||||
assert_eq!(ha[0].content, "u1");
|
||||
assert_eq!(hb[0].content, "u2");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn max_messages_trims_oldest_non_system() -> Result<()> {
|
||||
let mgr = MemorySessionManager::new(Duration::from_secs(3600), 2);
|
||||
let session = mgr.get_or_create("s1").await?;
|
||||
session
|
||||
.update_history(vec![
|
||||
ChatMessage::system("s"),
|
||||
ChatMessage::user("1"),
|
||||
ChatMessage::assistant("2"),
|
||||
ChatMessage::user("3"),
|
||||
ChatMessage::assistant("4"),
|
||||
])
|
||||
.await?;
|
||||
let h = session.get_history().await?;
|
||||
assert_eq!(h.len(), 2);
|
||||
assert_eq!(h[0].content, "3");
|
||||
assert_eq!(h[1].content, "4");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_session_persists_across_instances() -> Result<()> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let db_path = dir.path().join("sessions.db");
|
||||
|
||||
{
|
||||
let mgr = SqliteSessionManager::new(db_path.clone(), Duration::from_secs(3600), 50)?;
|
||||
let session = mgr.get_or_create("s1").await?;
|
||||
session
|
||||
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mgr2 = SqliteSessionManager::new(db_path, Duration::from_secs(3600), 50)?;
|
||||
let session2 = mgr2.get_or_create("s1").await?;
|
||||
let history = session2.get_history().await?;
|
||||
assert_eq!(history.len(), 2);
|
||||
assert_eq!(history[0].role, "user");
|
||||
assert_eq!(history[1].role, "assistant");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_session_cleanup_expires() -> Result<()> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let db_path = dir.path().join("sessions.db");
|
||||
// TTL 1 second
|
||||
let mgr = SqliteSessionManager::new(db_path, Duration::from_secs(1), 50)?;
|
||||
let session = mgr.get_or_create("s1").await?;
|
||||
session
|
||||
.update_history(vec![ChatMessage::user("hi"), ChatMessage::assistant("ok")])
|
||||
.await?;
|
||||
|
||||
// Force expire by setting age to 2 seconds
|
||||
mgr.force_expire_session("s1", Duration::from_secs(2))
|
||||
.await?;
|
||||
|
||||
let removed = mgr.cleanup_expired().await?;
|
||||
assert!(removed >= 1);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
+11
-6
@@ -636,7 +636,7 @@ async fn history_trims_after_max_messages() {
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
async fn auto_save_stores_only_user_messages_in_memory() {
|
||||
async fn auto_save_stores_user_and_assistant_messages_in_memory() {
|
||||
let (mem, _tmp) = make_sqlite_memory();
|
||||
let provider = Box::new(ScriptedProvider::new(vec![text_response(
|
||||
"I remember everything",
|
||||
@@ -651,11 +651,11 @@ async fn auto_save_stores_only_user_messages_in_memory() {
|
||||
|
||||
let _ = agent.turn("Remember this fact").await.unwrap();
|
||||
|
||||
// Auto-save only persists user-stated input, never assistant-generated summaries.
|
||||
// Auto-save persists both user input and assistant output for traceability.
|
||||
let count = mem.count().await.unwrap();
|
||||
assert_eq!(
|
||||
count, 1,
|
||||
"Expected exactly 1 user memory entry, got {count}"
|
||||
count, 2,
|
||||
"Expected user + assistant memory entries, got {count}"
|
||||
);
|
||||
|
||||
let stored = mem.get("user_msg").await.unwrap();
|
||||
@@ -668,8 +668,13 @@ async fn auto_save_stores_only_user_messages_in_memory() {
|
||||
|
||||
let assistant = mem.get("assistant_resp").await.unwrap();
|
||||
assert!(
|
||||
assistant.is_none(),
|
||||
"assistant_resp should not be auto-saved anymore"
|
||||
assistant.is_some(),
|
||||
"Expected assistant_resp key to be present"
|
||||
);
|
||||
assert_eq!(
|
||||
assistant.unwrap().content,
|
||||
"I remember everything",
|
||||
"Assistant response should be persisted when auto-save is enabled"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use serde::Deserialize;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use zeroclaw::config::schema::McpServerConfig;
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
struct FileMcp {
|
||||
#[serde(default)]
|
||||
enabled: bool,
|
||||
#[serde(default)]
|
||||
servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
struct FileRoot {
|
||||
#[serde(default)]
|
||||
mcp: FileMcp,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
let (enabled, servers) = match std::fs::read_to_string("config.toml") {
|
||||
Ok(s) => {
|
||||
let start = s
|
||||
.lines()
|
||||
.position(|line| line.trim() == "[mcp]")
|
||||
.unwrap_or(0);
|
||||
let slice = s.lines().skip(start).collect::<Vec<_>>().join("\n");
|
||||
let root: FileRoot = toml::from_str(&slice).context("failed to parse ./config.toml")?;
|
||||
(root.mcp.enabled, root.mcp.servers)
|
||||
}
|
||||
Err(_) => {
|
||||
let config = zeroclaw::Config::load_or_init().await?;
|
||||
(config.mcp.enabled, config.mcp.servers)
|
||||
}
|
||||
};
|
||||
|
||||
if !enabled || servers.is_empty() {
|
||||
bail!("MCP is disabled or no servers configured");
|
||||
}
|
||||
|
||||
let registry = zeroclaw::tools::McpRegistry::connect_all(&servers).await?;
|
||||
let tool_count = registry.tool_names().len();
|
||||
tracing::info!(
|
||||
"MCP smoke ok: {} server(s), {} tool(s)",
|
||||
registry.server_count(),
|
||||
tool_count
|
||||
);
|
||||
|
||||
if registry.server_count() == 0 {
|
||||
bail!("no MCP servers connected");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,901 @@
|
||||
//! ACP (Agent Client Protocol) channel for ZeroClaw.
|
||||
//!
|
||||
//! This channel enables ZeroClaw to act as an ACP client, connecting to an OpenCode
|
||||
//! ACP server via `opencode acp` command for JSON-RPC 2.0 communication over stdio.
|
||||
//! This allows users to control OpenCode behavior from any channel via social apps.
|
||||
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use crate::config::schema::AcpConfig;
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Monotonic counter for message IDs in ACP JSON-RPC requests.
|
||||
static ACP_MESSAGE_ID: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
/// ACP channel implementation for connecting to OpenCode ACP server.
|
||||
///
|
||||
/// The channel starts an OpenCode subprocess via `opencode acp` command and
|
||||
/// communicates using JSON-RPC 2.0 over stdio. Messages from social apps are
|
||||
/// forwarded as prompts to OpenCode, and responses are sent back through the
|
||||
/// originating channel.
|
||||
pub struct AcpChannel {
|
||||
/// OpenCode binary path (default: "opencode")
|
||||
opencode_path: String,
|
||||
/// Working directory for OpenCode process
|
||||
workdir: Option<String>,
|
||||
/// Additional arguments to pass to `opencode acp`
|
||||
extra_args: Vec<String>,
|
||||
/// Allowed user identifiers (empty = deny all, "*" = allow all)
|
||||
allowed_users: Vec<String>,
|
||||
/// Optional pairing guard for authentication
|
||||
pairing: Option<crate::security::pairing::PairingGuard>,
|
||||
/// HTTP client for potential future HTTP transport support
|
||||
client: reqwest::Client,
|
||||
/// Active OpenCode subprocess and its I/O handles
|
||||
process: Arc<Mutex<Option<AcpProcess>>>,
|
||||
/// Serializes ACP send operations to avoid concurrent process take/spawn races.
|
||||
send_operation_lock: Arc<Mutex<()>>,
|
||||
/// Next message ID for JSON-RPC requests
|
||||
next_message_id: Arc<AtomicU64>,
|
||||
/// Optional response channel for sending ACP responses back to original channel
|
||||
response_channel: Option<Arc<dyn Channel>>,
|
||||
}
|
||||
/// Active ACP process with I/O handles and session state.
|
||||
struct AcpProcess {
|
||||
/// Child process handle
|
||||
child: Child,
|
||||
/// Stdin handle for sending JSON-RPC requests
|
||||
stdin: tokio::process::ChildStdin,
|
||||
/// Stdout handle for receiving JSON-RPC responses
|
||||
stdout: BufReader<tokio::process::ChildStdout>,
|
||||
/// Session ID from ACP server (after initialize + session/new)
|
||||
session_id: Option<String>,
|
||||
/// JSON-RPC message ID counter (per-process)
|
||||
message_id: u64,
|
||||
/// Pending responses keyed by request ID
|
||||
pending_responses: VecDeque<PendingResponse>,
|
||||
}
|
||||
|
||||
/// Pending JSON-RPC response awaiting completion.
|
||||
struct PendingResponse {
|
||||
request_id: u64,
|
||||
method: String,
|
||||
created_at: std::time::Instant,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 request structure.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 response structure.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
#[serde(flatten)]
|
||||
result_or_error: JsonRpcResultOrError,
|
||||
}
|
||||
|
||||
/// JSON-RPC result or error.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum JsonRpcResultOrError {
|
||||
Result { result: Value },
|
||||
Error { error: JsonRpcError },
|
||||
}
|
||||
|
||||
/// JSON-RPC error object.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct JsonRpcError {
|
||||
code: i32,
|
||||
message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
data: Option<Value>,
|
||||
}
|
||||
|
||||
/// ACP initialization parameters.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct InitializeParams {
|
||||
protocol_version: u64,
|
||||
client_capabilities: ClientCapabilities,
|
||||
client_info: ClientInfo,
|
||||
}
|
||||
|
||||
/// Client capabilities declaration.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
struct ClientCapabilities {
|
||||
fs: FsCapabilities,
|
||||
terminal: bool,
|
||||
#[serde(rename = "_meta")]
|
||||
meta: Option<Value>,
|
||||
}
|
||||
|
||||
/// Filesystem capabilities.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
struct FsCapabilities {
|
||||
read_text_file: bool,
|
||||
write_text_file: bool,
|
||||
}
|
||||
|
||||
/// Client information.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct ClientInfo {
|
||||
name: String,
|
||||
title: String,
|
||||
version: String,
|
||||
}
|
||||
|
||||
/// ACP session/new parameters.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct SessionNewParams {
|
||||
cwd: String,
|
||||
mcp_servers: Vec<Value>,
|
||||
}
|
||||
|
||||
/// ACP session/prompt parameters.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct SessionPromptParams {
|
||||
session_id: String,
|
||||
prompt: Vec<PromptItem>,
|
||||
}
|
||||
|
||||
/// Prompt item (text content).
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct PromptItem {
|
||||
#[serde(rename = "type")]
|
||||
item_type: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl AcpChannel {
|
||||
/// Create a new ACP channel with the given configuration.
|
||||
pub fn new(config: AcpConfig) -> Self {
|
||||
Self {
|
||||
opencode_path: config
|
||||
.opencode_path
|
||||
.unwrap_or_else(|| "opencode".to_string()),
|
||||
workdir: config.workdir,
|
||||
extra_args: config.extra_args,
|
||||
allowed_users: config.allowed_users,
|
||||
pairing: None, // TODO: Implement pairing if needed
|
||||
client: reqwest::Client::new(),
|
||||
process: Arc::new(Mutex::new(None)),
|
||||
send_operation_lock: Arc::new(Mutex::new(())),
|
||||
next_message_id: Arc::new(AtomicU64::new(0)),
|
||||
response_channel: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a user is allowed to interact with this channel.
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
self.allowed_users
|
||||
.iter()
|
||||
.any(|allowed| allowed == "*" || allowed == user_id)
|
||||
}
|
||||
|
||||
/// Set the response channel for sending ACP responses back to original channel
|
||||
pub fn set_response_channel(&mut self, channel: Arc<dyn Channel>) {
|
||||
self.response_channel = Some(channel);
|
||||
}
|
||||
|
||||
/// Start the OpenCode ACP subprocess and establish connection.
|
||||
fn start_process(&self) -> Result<AcpProcess> {
|
||||
let mut command = Command::new(&self.opencode_path);
|
||||
command.arg("acp");
|
||||
|
||||
if let Some(workdir) = &self.workdir {
|
||||
command.current_dir(workdir);
|
||||
}
|
||||
|
||||
for arg in &self.extra_args {
|
||||
command.arg(arg);
|
||||
}
|
||||
|
||||
command.stdin(std::process::Stdio::piped());
|
||||
command.stdout(std::process::Stdio::piped());
|
||||
// Inherit stderr so the child cannot block on an unread stderr pipe.
|
||||
command.stderr(std::process::Stdio::inherit());
|
||||
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to start OpenCode process: {}", self.opencode_path))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.context("Failed to take stdin from child process")?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.context("Failed to take stdout from child process")?;
|
||||
let stdout_reader = BufReader::new(stdout);
|
||||
|
||||
let process = AcpProcess {
|
||||
child,
|
||||
stdin,
|
||||
stdout: stdout_reader,
|
||||
session_id: None,
|
||||
message_id: 0,
|
||||
pending_responses: VecDeque::new(),
|
||||
};
|
||||
|
||||
Ok(process)
|
||||
}
|
||||
|
||||
/// Send a JSON-RPC request and wait for response.
|
||||
async fn send_json_rpc_request(
|
||||
&self,
|
||||
process: &mut AcpProcess,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> Result<Value> {
|
||||
let request_id = process.message_id;
|
||||
process.message_id += 1;
|
||||
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: request_id,
|
||||
method: method.to_string(),
|
||||
params,
|
||||
};
|
||||
|
||||
let json_str = serde_json::to_string(&request).with_context(|| {
|
||||
format!(
|
||||
"Failed to serialize JSON-RPC request for method: {}",
|
||||
method
|
||||
)
|
||||
})?;
|
||||
|
||||
// Write message with newline delimiter (ACP protocol requirement)
|
||||
process.stdin.write_all(json_str.as_bytes()).await?;
|
||||
process.stdin.write_all(b"\n").await?;
|
||||
process.stdin.flush().await?;
|
||||
|
||||
// Read response line with timeout
|
||||
let mut line = String::new();
|
||||
let timeout_duration = std::time::Duration::from_secs(30);
|
||||
match tokio::time::timeout(timeout_duration, process.stdout.read_line(&mut line)).await {
|
||||
Ok(read_result) => {
|
||||
read_result
|
||||
.with_context(|| format!("Failed to read response for method: {}", method))?;
|
||||
}
|
||||
Err(_) => {
|
||||
anyhow::bail!("Timeout waiting for ACP response for method: {}", method);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON-RPC response
|
||||
let response: JsonRpcResponse = serde_json::from_str(&line)
|
||||
.with_context(|| format!("Failed to parse JSON-RPC response: {}", line))?;
|
||||
|
||||
// Verify response ID matches request ID
|
||||
if response.id != request_id {
|
||||
anyhow::bail!(
|
||||
"Response ID mismatch: expected {}, got {}",
|
||||
request_id,
|
||||
response.id
|
||||
);
|
||||
}
|
||||
|
||||
match response.result_or_error {
|
||||
JsonRpcResultOrError::Result { result } => Ok(result),
|
||||
JsonRpcResultOrError::Error { error } => {
|
||||
anyhow::bail!("ACP JSON-RPC error ({}): {}", error.code, error.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize ACP connection with the server.
|
||||
async fn initialize_acp(&self, process: &mut AcpProcess) -> Result<()> {
|
||||
let params = InitializeParams {
|
||||
protocol_version: 1,
|
||||
client_capabilities: ClientCapabilities::default(),
|
||||
client_info: ClientInfo {
|
||||
name: "ZeroClaw".to_string(),
|
||||
title: "ZeroClaw ACP Client".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let params_value =
|
||||
serde_json::to_value(params).context("Failed to serialize initialize params")?;
|
||||
|
||||
let response = self
|
||||
.send_json_rpc_request(process, "initialize", Some(params_value))
|
||||
.await?;
|
||||
|
||||
// TODO: Parse response and store capabilities
|
||||
tracing::info!("ACP initialized successfully: {:?}", response);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new ACP session.
|
||||
async fn create_session(&self, process: &mut AcpProcess) -> Result<String> {
|
||||
let cwd = self.workdir.clone().unwrap_or_else(|| {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| ".".into())
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
});
|
||||
|
||||
let params = SessionNewParams {
|
||||
cwd,
|
||||
mcp_servers: vec![],
|
||||
};
|
||||
|
||||
let params_value =
|
||||
serde_json::to_value(params).context("Failed to serialize session/new params")?;
|
||||
|
||||
let response = self
|
||||
.send_json_rpc_request(process, "session/new", Some(params_value))
|
||||
.await?;
|
||||
|
||||
// Parse response to extract session_id
|
||||
let session_id = response
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Invalid session/new response: missing session_id: {:?}",
|
||||
response
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!("ACP session created: {}", session_id);
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Send a prompt to the ACP session.
|
||||
async fn send_prompt(
|
||||
&self,
|
||||
process: &mut AcpProcess,
|
||||
session_id: &str,
|
||||
prompt_text: &str,
|
||||
) -> Result<String> {
|
||||
let params = SessionPromptParams {
|
||||
session_id: session_id.to_string(),
|
||||
prompt: vec![PromptItem {
|
||||
item_type: "text".to_string(),
|
||||
text: prompt_text.to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
let params_value =
|
||||
serde_json::to_value(params).context("Failed to serialize session/prompt params")?;
|
||||
|
||||
let response = self
|
||||
.send_json_rpc_request(process, "session/prompt", Some(params_value))
|
||||
.await?;
|
||||
|
||||
// Parse response to extract the actual response text
|
||||
// The response may contain a "response" field with text content
|
||||
let response_text = response
|
||||
.get("response")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Invalid session/prompt response: missing string field `response` for prompt {:?}: {:?}",
|
||||
prompt_text, response
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(response_text)
|
||||
}
|
||||
|
||||
fn process_is_running(process: &mut AcpProcess) -> bool {
|
||||
matches!(process.child.try_wait(), Ok(None))
|
||||
}
|
||||
|
||||
async fn initialize_fresh_process(&self) -> Result<AcpProcess> {
|
||||
let mut new_process = self.start_process()?;
|
||||
self.initialize_acp(&mut new_process).await?;
|
||||
let session_id = self.create_session(&mut new_process).await?;
|
||||
new_process.session_id = Some(session_id);
|
||||
Ok(new_process)
|
||||
}
|
||||
|
||||
async fn checkout_process_for_send(&self) -> Result<AcpProcess> {
|
||||
let mut process_opt = {
|
||||
let mut process_guard = self.process.lock().await;
|
||||
process_guard.take()
|
||||
};
|
||||
|
||||
let needs_restart = match process_opt.as_mut() {
|
||||
Some(process) => !Self::process_is_running(process),
|
||||
None => true,
|
||||
};
|
||||
|
||||
if needs_restart {
|
||||
process_opt = Some(self.initialize_fresh_process().await?);
|
||||
}
|
||||
|
||||
process_opt.context("ACP process disappeared unexpectedly")
|
||||
}
|
||||
|
||||
async fn restore_process(&self, process: Option<AcpProcess>) {
|
||||
let mut process_guard = self.process.lock().await;
|
||||
*process_guard = process;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for AcpChannel {
|
||||
fn name(&self) -> &str {
|
||||
"acp"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> Result<()> {
|
||||
const MAX_SEND_ATTEMPTS: usize = 2;
|
||||
|
||||
let _send_guard = self.send_operation_lock.lock().await;
|
||||
|
||||
// Check if user is allowed
|
||||
if !self.is_user_allowed(&message.recipient) {
|
||||
tracing::warn!(
|
||||
"ACP: ignoring message from unauthorized user: {}",
|
||||
message.recipient
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Strip tool call tags from outgoing messages
|
||||
let content = super::strip_tool_call_tags(&message.content);
|
||||
|
||||
let mut last_error = None;
|
||||
for attempt in 0..MAX_SEND_ATTEMPTS {
|
||||
let mut process = self.checkout_process_for_send().await?;
|
||||
let session_id = process
|
||||
.session_id
|
||||
.as_ref()
|
||||
.context("No active ACP session")?
|
||||
.clone();
|
||||
|
||||
match self.send_prompt(&mut process, &session_id, &content).await {
|
||||
Ok(response) => {
|
||||
if Self::process_is_running(&mut process) {
|
||||
self.restore_process(Some(process)).await;
|
||||
} else {
|
||||
self.restore_process(None).await;
|
||||
}
|
||||
|
||||
// Send response back through response_channel if set
|
||||
if let Some(response_channel) = &self.response_channel {
|
||||
let response_message =
|
||||
SendMessage::new(response, message.recipient.clone());
|
||||
if let Err(e) = response_channel.send(&response_message).await {
|
||||
tracing::warn!(
|
||||
"Failed to send ACP response through response channel: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Log if no response channel configured
|
||||
tracing::info!(
|
||||
"ACP response ready (no response channel configured): {}",
|
||||
response
|
||||
);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
Err(error) => {
|
||||
// Drop unhealthy process on failure and retry once with a fresh process.
|
||||
self.restore_process(None).await;
|
||||
if attempt + 1 < MAX_SEND_ATTEMPTS {
|
||||
tracing::warn!(
|
||||
"ACP prompt failed (attempt {}/{}), restarting ACP process: {}",
|
||||
attempt + 1,
|
||||
MAX_SEND_ATTEMPTS,
|
||||
error
|
||||
);
|
||||
}
|
||||
last_error = Some(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("ACP send failed with unknown error")))
|
||||
}
|
||||
|
||||
async fn listen(&self, _tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
|
||||
// ACP is primarily a client-side protocol where we send prompts
|
||||
// and receive responses. For channel listening, we might need to
|
||||
// handle incoming messages from other sources that should trigger
|
||||
// ACP prompts.
|
||||
|
||||
// Since ACP is more about sending commands to OpenCode rather than
|
||||
// listening for incoming messages, we implement a minimal listener
|
||||
// that just keeps the channel alive.
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
let mut process_guard = self.process.lock().await;
|
||||
let Some(process) = process_guard.as_mut() else {
|
||||
return false;
|
||||
};
|
||||
let is_running = Self::process_is_running(process);
|
||||
if !is_running {
|
||||
*process_guard = None;
|
||||
}
|
||||
is_running
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::AcpConfig;
|
||||
|
||||
#[test]
|
||||
fn acp_channel_name() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert_eq!(channel.name(), "acp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_empty_allowlist_denies_all() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(!channel.is_user_allowed("anyone"));
|
||||
assert!(!channel.is_user_allowed("user123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_wildcard_allows_all() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["*".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(channel.is_user_allowed("anyone"));
|
||||
assert!(channel.is_user_allowed("user123"));
|
||||
assert!(channel.is_user_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_specific_users() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["user1".to_string(), "user2".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(channel.is_user_allowed("user1"));
|
||||
assert!(channel.is_user_allowed("user2"));
|
||||
assert!(!channel.is_user_allowed("user3"));
|
||||
assert!(!channel.is_user_allowed("User1")); // case sensitive
|
||||
assert!(!channel.is_user_allowed("user"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_wildcard_and_specific() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["user1".to_string(), "*".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(channel.is_user_allowed("user1"));
|
||||
assert!(channel.is_user_allowed("anyone"));
|
||||
assert!(channel.is_user_allowed("user2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_empty_user_id() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["user1".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(!channel.is_user_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_exact_match_not_substring() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["user123".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(channel.is_user_allowed("user123"));
|
||||
assert!(!channel.is_user_allowed("user12"));
|
||||
assert!(!channel.is_user_allowed("user1234"));
|
||||
assert!(!channel.is_user_allowed("user"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_case_sensitive() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec!["User".to_string()],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(channel.is_user_allowed("User"));
|
||||
assert!(!channel.is_user_allowed("user"));
|
||||
assert!(!channel.is_user_allowed("USER"));
|
||||
}
|
||||
|
||||
// JSON-RPC data structure tests
|
||||
#[test]
|
||||
fn jsonrpc_request_serialization() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: 42,
|
||||
method: "test".to_string(),
|
||||
params: Some(serde_json::json!({"key": "value"})),
|
||||
};
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
let expected = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{"key":"value"}}"#;
|
||||
assert_eq!(json, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_request_without_params() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: 1,
|
||||
method: "ping".to_string(),
|
||||
params: None,
|
||||
};
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
let expected = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
|
||||
assert_eq!(json, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_response_deserialization() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":42,"result":{"status":"ok"}}"#;
|
||||
let response: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(response.jsonrpc, "2.0");
|
||||
assert_eq!(response.id, 42);
|
||||
match response.result_or_error {
|
||||
JsonRpcResultOrError::Result { result } => {
|
||||
assert_eq!(result, serde_json::json!({"status": "ok"}));
|
||||
}
|
||||
JsonRpcResultOrError::Error { .. } => panic!("Expected result, got error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_error_deserialization() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":42,"error":{"code":-32700,"message":"Parse error"}}"#;
|
||||
let response: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(response.jsonrpc, "2.0");
|
||||
assert_eq!(response.id, 42);
|
||||
match response.result_or_error {
|
||||
JsonRpcResultOrError::Error { error } => {
|
||||
assert_eq!(error.code, -32700);
|
||||
assert_eq!(error.message, "Parse error");
|
||||
assert!(error.data.is_none());
|
||||
}
|
||||
JsonRpcResultOrError::Result { .. } => panic!("Expected error, got result"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initialize_params_serialization() {
|
||||
let params = InitializeParams {
|
||||
protocol_version: 1,
|
||||
client_capabilities: ClientCapabilities::default(),
|
||||
client_info: ClientInfo {
|
||||
name: "ZeroClaw".to_string(),
|
||||
title: "ZeroClaw ACP Client".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_value(¶ms).unwrap();
|
||||
assert_eq!(json["protocol_version"], 1);
|
||||
assert_eq!(json["client_info"]["name"], "ZeroClaw");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_new_params_serialization() {
|
||||
let params = SessionNewParams {
|
||||
cwd: "/tmp".to_string(),
|
||||
mcp_servers: vec![],
|
||||
};
|
||||
let json = serde_json::to_value(¶ms).unwrap();
|
||||
assert_eq!(json["cwd"], "/tmp");
|
||||
assert_eq!(json["mcp_servers"], serde_json::json!([]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_prompt_params_serialization() {
|
||||
let params = SessionPromptParams {
|
||||
session_id: "session-123".to_string(),
|
||||
prompt: vec![PromptItem {
|
||||
item_type: "text".to_string(),
|
||||
text: "Hello".to_string(),
|
||||
}],
|
||||
};
|
||||
let json = serde_json::to_value(¶ms).unwrap();
|
||||
assert_eq!(json["session_id"], "session-123");
|
||||
assert_eq!(json["prompt"][0]["type"], "text");
|
||||
assert_eq!(json["prompt"][0]["text"], "Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acp_channel_set_response_channel() {
|
||||
use super::Channel;
|
||||
use crate::channels::traits::SendMessage;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Mock channel for testing
|
||||
struct MockChannel;
|
||||
#[async_trait::async_trait]
|
||||
impl Channel for MockChannel {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn send(&self, _message: &SendMessage) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
_tx: tokio::sync::mpsc::Sender<crate::channels::traits::ChannelMessage>,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
|
||||
let mut channel = AcpChannel::new(config);
|
||||
let mock_channel = Arc::new(MockChannel);
|
||||
|
||||
// Initially no response channel
|
||||
// (Cannot directly access private field, but we can test via public API)
|
||||
|
||||
// Set response channel
|
||||
channel.set_response_channel(mock_channel.clone());
|
||||
|
||||
// Verify channel can be set (no panic)
|
||||
// This test mainly ensures the method exists and works
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
// Note: More comprehensive tests would require mocking the OpenCode process
|
||||
// which is beyond the scope of basic unit tests.
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn spawn_test_process(command: &str, args: &[&str]) -> AcpProcess {
|
||||
let mut child = Command::new(command)
|
||||
.args(args)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn()
|
||||
.expect("failed to spawn test ACP process");
|
||||
|
||||
let stdin = child.stdin.take().expect("test process stdin");
|
||||
let stdout = BufReader::new(child.stdout.take().expect("test process stdout"));
|
||||
|
||||
AcpProcess {
|
||||
child,
|
||||
stdin,
|
||||
stdout,
|
||||
session_id: Some("test-session".to_string()),
|
||||
message_id: 0,
|
||||
pending_responses: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn cleanup_test_process(channel: &AcpChannel) {
|
||||
let process = {
|
||||
let mut guard = channel.process.lock().await;
|
||||
guard.take()
|
||||
};
|
||||
if let Some(mut process) = process {
|
||||
let _ = process.child.kill().await;
|
||||
let _ = process.child.wait().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn acp_health_check_false_when_no_process() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
assert!(!channel.health_check().await);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn acp_health_check_true_when_process_running() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
|
||||
let process = spawn_test_process("sh", &["-c", "sleep 5"]).await;
|
||||
{
|
||||
let mut guard = channel.process.lock().await;
|
||||
*guard = Some(process);
|
||||
}
|
||||
|
||||
assert!(channel.health_check().await);
|
||||
cleanup_test_process(&channel).await;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn acp_health_check_false_after_process_exit() {
|
||||
let config = AcpConfig {
|
||||
opencode_path: None,
|
||||
workdir: None,
|
||||
extra_args: vec![],
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let channel = AcpChannel::new(config);
|
||||
|
||||
let process = spawn_test_process("sh", &["-c", "true"]).await;
|
||||
{
|
||||
let mut guard = channel.process.lock().await;
|
||||
*guard = Some(process);
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
assert!(!channel.health_check().await);
|
||||
}
|
||||
}
|
||||
+162
-13
@@ -174,7 +174,6 @@ struct LarkEvent {
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkEventHeader {
|
||||
event_type: String,
|
||||
#[allow(dead_code)]
|
||||
event_id: String,
|
||||
}
|
||||
|
||||
@@ -217,6 +216,10 @@ const LARK_TOKEN_REFRESH_SKEW: Duration = Duration::from_secs(120);
|
||||
const LARK_DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(7200);
|
||||
/// Feishu/Lark API business code for expired/invalid tenant access token.
|
||||
const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
|
||||
/// Retention window for seen event/message dedupe keys.
|
||||
const LARK_EVENT_DEDUP_TTL: Duration = Duration::from_secs(30 * 60);
|
||||
/// Periodic cleanup interval for the dedupe cache.
|
||||
const LARK_EVENT_DEDUP_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
|
||||
const LARK_IMAGE_DOWNLOAD_FALLBACK_TEXT: &str =
|
||||
"[Image message received but could not be downloaded]";
|
||||
|
||||
@@ -367,8 +370,10 @@ pub struct LarkChannel {
|
||||
receive_mode: crate::config::schema::LarkReceiveMode,
|
||||
/// Cached tenant access token
|
||||
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
|
||||
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
/// Dedup set for recently seen event/message keys across WS + webhook paths.
|
||||
recent_event_keys: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
/// Last time we ran TTL cleanup over the dedupe cache.
|
||||
recent_event_cleanup_at: Arc<RwLock<Instant>>,
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
@@ -412,7 +417,8 @@ impl LarkChannel {
|
||||
platform,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||
tenant_token: Arc::new(RwLock::new(None)),
|
||||
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
recent_event_keys: Arc::new(RwLock::new(HashMap::new())),
|
||||
recent_event_cleanup_at: Arc::new(RwLock::new(Instant::now())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -520,6 +526,42 @@ impl LarkChannel {
|
||||
}
|
||||
}
|
||||
|
||||
fn dedupe_event_key(event_id: Option<&str>, message_id: Option<&str>) -> Option<String> {
|
||||
let normalized_event = event_id.map(str::trim).filter(|value| !value.is_empty());
|
||||
if let Some(event_id) = normalized_event {
|
||||
return Some(format!("event:{event_id}"));
|
||||
}
|
||||
|
||||
let normalized_message = message_id.map(str::trim).filter(|value| !value.is_empty());
|
||||
normalized_message.map(|message_id| format!("message:{message_id}"))
|
||||
}
|
||||
|
||||
async fn try_mark_event_key_seen(&self, dedupe_key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
if self.recent_event_keys.read().await.contains_key(dedupe_key) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let should_cleanup = {
|
||||
let last_cleanup = self.recent_event_cleanup_at.read().await;
|
||||
now.duration_since(*last_cleanup) >= LARK_EVENT_DEDUP_CLEANUP_INTERVAL
|
||||
};
|
||||
|
||||
let mut seen = self.recent_event_keys.write().await;
|
||||
if seen.contains_key(dedupe_key) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if should_cleanup {
|
||||
seen.retain(|_, t| now.duration_since(*t) < LARK_EVENT_DEDUP_TTL);
|
||||
let mut last_cleanup = self.recent_event_cleanup_at.write().await;
|
||||
*last_cleanup = now;
|
||||
}
|
||||
|
||||
seen.insert(dedupe_key.to_string(), now);
|
||||
true
|
||||
}
|
||||
|
||||
async fn fetch_image_marker(&self, image_key: &str) -> anyhow::Result<String> {
|
||||
if image_key.trim().is_empty() {
|
||||
anyhow::bail!("empty image_key");
|
||||
@@ -880,17 +922,14 @@ impl LarkChannel {
|
||||
|
||||
let lark_msg = &recv.message;
|
||||
|
||||
// Dedup
|
||||
{
|
||||
let now = Instant::now();
|
||||
let mut seen = self.ws_seen_ids.write().await;
|
||||
// GC
|
||||
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
|
||||
if seen.contains_key(&lark_msg.message_id) {
|
||||
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
|
||||
if let Some(dedupe_key) = Self::dedupe_event_key(
|
||||
Some(event.header.event_id.as_str()),
|
||||
Some(lark_msg.message_id.as_str()),
|
||||
) {
|
||||
if !self.try_mark_event_key_seen(&dedupe_key).await {
|
||||
tracing::debug!("Lark WS: duplicate event dropped ({dedupe_key})");
|
||||
continue;
|
||||
}
|
||||
seen.insert(lark_msg.message_id.clone(), now);
|
||||
}
|
||||
|
||||
// Decode content by type (mirrors clawdbot-feishu parsing)
|
||||
@@ -1290,6 +1329,22 @@ impl LarkChannel {
|
||||
Some(e) => e,
|
||||
None => return messages,
|
||||
};
|
||||
let event_id = payload
|
||||
.pointer("/header/event_id")
|
||||
.and_then(|id| id.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|id| !id.is_empty());
|
||||
let message_id = event
|
||||
.pointer("/message/message_id")
|
||||
.and_then(|id| id.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|id| !id.is_empty());
|
||||
if let Some(dedupe_key) = Self::dedupe_event_key(event_id, message_id) {
|
||||
if !self.try_mark_event_key_seen(&dedupe_key).await {
|
||||
tracing::debug!("Lark webhook: duplicate event dropped ({dedupe_key})");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
|
||||
let open_id = event
|
||||
.pointer("/sender/sender_id/open_id")
|
||||
@@ -2318,6 +2373,100 @@ mod tests {
|
||||
assert_eq!(msgs[0].content, LARK_IMAGE_DOWNLOAD_FALLBACK_TEXT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_event_payload_async_dedupes_repeated_event_id() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
"event_type": "im.message.receive_v1",
|
||||
"event_id": "evt_abc"
|
||||
},
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_id": "om_first",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"hello\"}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let first = ch.parse_event_payload_async(&payload).await;
|
||||
let second = ch.parse_event_payload_async(&payload).await;
|
||||
assert_eq!(first.len(), 1);
|
||||
assert!(second.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_event_payload_async_dedupes_by_message_id_without_event_id() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_id": "om_fallback",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"hello\"}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let first = ch.parse_event_payload_async(&payload).await;
|
||||
let second = ch.parse_event_payload_async(&payload).await;
|
||||
assert_eq!(first.len(), 1);
|
||||
assert!(second.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn try_mark_event_key_seen_cleans_up_expired_keys_periodically() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
|
||||
{
|
||||
let mut seen = ch.recent_event_keys.write().await;
|
||||
seen.insert(
|
||||
"event:stale".to_string(),
|
||||
Instant::now() - LARK_EVENT_DEDUP_TTL - Duration::from_secs(5),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let mut cleanup_at = ch.recent_event_cleanup_at.write().await;
|
||||
*cleanup_at =
|
||||
Instant::now() - LARK_EVENT_DEDUP_CLEANUP_INTERVAL - Duration::from_secs(1);
|
||||
}
|
||||
|
||||
assert!(ch.try_mark_event_key_seen("event:fresh").await);
|
||||
let seen = ch.recent_event_keys.read().await;
|
||||
assert!(!seen.contains_key("event:stale"));
|
||||
assert!(seen.contains_key("event:fresh"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_empty_text_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
|
||||
@@ -400,8 +400,7 @@ impl Channel for LinqChannel {
|
||||
/// The signature is sent in `X-Webhook-Signature` (hex-encoded) and the
|
||||
/// timestamp in `X-Webhook-Timestamp`. Reject timestamps older than 300s.
|
||||
pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signature: &str) -> bool {
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use ring::hmac;
|
||||
|
||||
// Reject stale timestamps (>300s old)
|
||||
if let Ok(ts) = timestamp.parse::<i64>() {
|
||||
@@ -417,10 +416,6 @@ pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signatur
|
||||
|
||||
// Compute HMAC-SHA256 over "{timestamp}.{body}"
|
||||
let message = format!("{timestamp}.{body}");
|
||||
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(secret.as_bytes()) else {
|
||||
return false;
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
let signature_hex = signature
|
||||
.trim()
|
||||
.strip_prefix("sha256=")
|
||||
@@ -430,8 +425,8 @@ pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signatur
|
||||
return false;
|
||||
};
|
||||
|
||||
// Constant-time comparison via HMAC verify.
|
||||
mac.verify_slice(&provided).is_ok()
|
||||
let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes());
|
||||
hmac::verify(&key, message.as_bytes(), &provided).is_ok()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+505
-10
@@ -16,6 +16,7 @@
|
||||
|
||||
pub mod bluebubbles;
|
||||
pub mod clawdtalk;
|
||||
pub mod acp;
|
||||
pub mod cli;
|
||||
pub mod dingtalk;
|
||||
pub mod discord;
|
||||
@@ -45,6 +46,7 @@ pub mod whatsapp_storage;
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
pub mod whatsapp_web;
|
||||
|
||||
pub use acp::AcpChannel;
|
||||
pub use bluebubbles::BlueBubblesChannel;
|
||||
pub use clawdtalk::ClawdTalkChannel;
|
||||
pub use cli::CliChannel;
|
||||
@@ -77,6 +79,7 @@ use crate::agent::loop_::{
|
||||
build_shell_policy_instructions, build_tool_instructions_from_specs,
|
||||
run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
|
||||
};
|
||||
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
|
||||
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
|
||||
use crate::config::{Config, NonCliNaturalLanguageApprovalMode};
|
||||
use crate::identity;
|
||||
@@ -100,6 +103,8 @@ use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Per-sender conversation history for channel messages.
|
||||
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
|
||||
type ConversationLockMap =
|
||||
Arc<tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>;
|
||||
/// Maximum history messages to keep per sender.
|
||||
const MAX_CHANNEL_HISTORY: usize = 50;
|
||||
/// Minimum user-message length (in chars) for auto-save to memory.
|
||||
@@ -130,6 +135,9 @@ const MEMORY_CONTEXT_ENTRY_MAX_CHARS: usize = 800;
|
||||
const MEMORY_CONTEXT_MAX_CHARS: usize = 4_000;
|
||||
const CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES: usize = 12;
|
||||
const CHANNEL_HISTORY_COMPACT_CONTENT_CHARS: usize = 600;
|
||||
const CHANNEL_CONTEXT_TOKEN_ESTIMATE_LIMIT: usize = 90_000;
|
||||
const CHANNEL_CONTEXT_TOKEN_ESTIMATE_TARGET: usize = 80_000;
|
||||
const CHANNEL_CONTEXT_MIN_KEEP_NON_SYSTEM_MESSAGES: usize = 10;
|
||||
/// Guardrail for hook-modified outbound channel content.
|
||||
const CHANNEL_HOOK_MAX_OUTBOUND_CHARS: usize = 20_000;
|
||||
|
||||
@@ -278,6 +286,9 @@ struct ChannelRuntimeContext {
|
||||
max_tool_iterations: usize,
|
||||
min_relevance_score: f64,
|
||||
conversation_histories: ConversationHistoryMap,
|
||||
conversation_locks: ConversationLockMap,
|
||||
session_config: crate::config::AgentSessionConfig,
|
||||
session_manager: Option<Arc<dyn SessionManager + Send + Sync>>,
|
||||
provider_cache: ProviderCacheMap,
|
||||
route_overrides: RouteSelectionMap,
|
||||
api_key: Option<String>,
|
||||
@@ -338,6 +349,10 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("assistant_resp_{}", conversation_memory_key(msg))
|
||||
}
|
||||
|
||||
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
||||
// QQ uses thread_ts as a passive-reply message id, not a thread identifier.
|
||||
// Using it in history keys would reset context on every incoming message.
|
||||
@@ -346,8 +361,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
||||
}
|
||||
|
||||
// Include thread_ts for per-topic session isolation in forum groups
|
||||
match &msg.thread_ts {
|
||||
let channel = msg.channel.as_str();
|
||||
match msg.thread_ts.as_deref().filter(|_| channel != "qq") {
|
||||
Some(tid) => format!("{}_{}_{}", msg.channel, tid, msg.sender),
|
||||
None if channel == "qq" => format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender),
|
||||
None => format!("{}_{}", msg.channel, msg.sender),
|
||||
}
|
||||
}
|
||||
@@ -961,10 +978,10 @@ fn resolved_default_provider(config: &Config) -> String {
|
||||
}
|
||||
|
||||
fn resolved_default_model(config: &Config) -> String {
|
||||
config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4.6".to_string())
|
||||
crate::config::resolve_default_model_id(
|
||||
config.default_model.as_deref(),
|
||||
config.default_provider.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
|
||||
@@ -1731,6 +1748,40 @@ fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatM
|
||||
}
|
||||
}
|
||||
|
||||
fn estimated_message_tokens(message: &ChatMessage) -> usize {
|
||||
(message.content.chars().count().saturating_add(2) / 3).saturating_add(4)
|
||||
}
|
||||
|
||||
fn estimated_history_tokens(history: &[ChatMessage]) -> usize {
|
||||
history.iter().map(estimated_message_tokens).sum()
|
||||
}
|
||||
|
||||
fn trim_channel_prompt_history(history: &mut Vec<ChatMessage>) -> bool {
|
||||
let mut total = estimated_history_tokens(history);
|
||||
if total <= CHANNEL_CONTEXT_TOKEN_ESTIMATE_LIMIT {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut trimmed = false;
|
||||
loop {
|
||||
if total <= CHANNEL_CONTEXT_TOKEN_ESTIMATE_TARGET {
|
||||
break;
|
||||
}
|
||||
let non_system = history.iter().filter(|m| m.role != "system").count();
|
||||
if non_system <= CHANNEL_CONTEXT_MIN_KEEP_NON_SYSTEM_MESSAGES {
|
||||
break;
|
||||
}
|
||||
let Some(idx) = history.iter().position(|m| m.role != "system") else {
|
||||
break;
|
||||
};
|
||||
let removed = history.remove(idx);
|
||||
total = total.saturating_sub(estimated_message_tokens(&removed));
|
||||
trimmed = true;
|
||||
}
|
||||
|
||||
trimmed
|
||||
}
|
||||
|
||||
fn rollback_orphan_user_turn(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
sender_key: &str,
|
||||
@@ -2691,10 +2742,11 @@ async fn build_memory_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
min_relevance_score: f64,
|
||||
session_id: Option<&str>,
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let mut included = 0usize;
|
||||
let mut used_chars = 0usize;
|
||||
|
||||
@@ -3158,6 +3210,9 @@ async fn process_channel_message(
|
||||
msg: traits::ChannelMessage,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
let sender_id = msg.sender.as_str();
|
||||
let channel_name = msg.channel.as_str();
|
||||
tracing::debug!(sender_id, channel_name, "process_message called");
|
||||
if cancellation_token.is_cancelled() {
|
||||
return;
|
||||
}
|
||||
@@ -3251,6 +3306,62 @@ or tune thresholds in config.",
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let conversation_lock = {
|
||||
let mut locks = ctx.conversation_locks.lock().await;
|
||||
locks
|
||||
.entry(history_key.clone())
|
||||
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
let _conversation_guard = conversation_lock.lock().await;
|
||||
let mut session: Option<Session> = None;
|
||||
if let Some(manager) = ctx.session_manager.as_ref() {
|
||||
let session_id = resolve_session_id(
|
||||
&ctx.session_config,
|
||||
msg.sender.as_str(),
|
||||
Some(msg.channel.as_str()),
|
||||
);
|
||||
tracing::debug!(session_id, "session_id resolved");
|
||||
match manager.get_or_create(&session_id).await {
|
||||
Ok(opened) => {
|
||||
session = Some(opened);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to open session: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(session) = session.as_ref() {
|
||||
let should_seed = {
|
||||
let histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
!histories.contains_key(&history_key)
|
||||
};
|
||||
|
||||
if should_seed {
|
||||
match session.get_history().await {
|
||||
Ok(history) => {
|
||||
tracing::debug!(history_len = history.len(), "session history loaded");
|
||||
let filtered: Vec<ChatMessage> =
|
||||
history
|
||||
.into_iter()
|
||||
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
|
||||
.collect();
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
histories.entry(history_key.clone()).or_insert(filtered);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to load session history: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try classification first, fall back to sender/default route
|
||||
let route = classify_message_route(ctx.as_ref(), &msg.content)
|
||||
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
|
||||
@@ -3282,7 +3393,7 @@ or tune thresholds in config.",
|
||||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
None,
|
||||
Some(&history_key),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -3338,6 +3449,7 @@ or tune thresholds in config.",
|
||||
ctx.memory.as_ref(),
|
||||
&msg.content,
|
||||
ctx.min_relevance_score,
|
||||
Some(&history_key),
|
||||
)
|
||||
.await;
|
||||
if !memory_context.is_empty() {
|
||||
@@ -3367,6 +3479,7 @@ or tune thresholds in config.",
|
||||
));
|
||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||
history.extend(prior_turns);
|
||||
let _ = trim_channel_prompt_history(&mut history);
|
||||
let use_streaming = target_channel
|
||||
.as_ref()
|
||||
.is_some_and(|ch| ch.supports_draft_updates());
|
||||
@@ -3661,6 +3774,20 @@ or tune thresholds in config.",
|
||||
&history_key,
|
||||
ChatMessage::assistant(&history_response),
|
||||
);
|
||||
if ctx.auto_save_memory
|
||||
&& delivered_response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
{
|
||||
let assistant_key = assistant_memory_key(&msg);
|
||||
let _ = ctx
|
||||
.memory
|
||||
.store(
|
||||
&assistant_key,
|
||||
&delivered_response,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
@@ -3968,7 +4095,7 @@ async fn run_message_dispatch_loop(
|
||||
}
|
||||
}
|
||||
|
||||
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||||
Box::pin(process_channel_message(worker_ctx, msg, cancellation_token)).await;
|
||||
|
||||
if interrupt_enabled {
|
||||
let mut active = in_flight.lock().await;
|
||||
@@ -4608,6 +4735,7 @@ fn collect_configured_channels(
|
||||
sl.bot_token.clone(),
|
||||
sl.app_token.clone(),
|
||||
sl.channel_id.clone(),
|
||||
sl.channel_ids.clone(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_group_reply_policy(
|
||||
@@ -4892,6 +5020,12 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref acp) = config.channels_config.acp {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "ACP",
|
||||
channel: Arc::new(AcpChannel::new(acp.clone())),
|
||||
});
|
||||
}
|
||||
channels
|
||||
}
|
||||
|
||||
@@ -5335,6 +5469,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
.as_ref()
|
||||
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||||
|
||||
let session_manager = shared_session_manager(&config.agent.session, &config.workspace_dir)?
|
||||
.map(|mgr| mgr as Arc<dyn SessionManager + Send + Sync>);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name,
|
||||
provider: Arc::clone(&provider),
|
||||
@@ -5349,6 +5486,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
max_tool_iterations: config.agent.max_tool_iterations,
|
||||
min_relevance_score: config.memory.min_relevance_score,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
session_config: config.agent.session.clone(),
|
||||
session_manager,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: config.api_key.clone(),
|
||||
@@ -5412,11 +5552,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::large_futures)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::{ChatMessage, Provider};
|
||||
use crate::security::AutonomyLevel;
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
@@ -5720,6 +5862,9 @@ mod tests {
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||||
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -5774,6 +5919,9 @@ mod tests {
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -5831,6 +5979,9 @@ mod tests {
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -5888,6 +6039,11 @@ mod tests {
|
||||
reactions_removed: tokio::sync::Mutex<Vec<(String, String, String)>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct QqRecordingChannel {
|
||||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TelegramRecordingChannel {
|
||||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||||
@@ -6044,6 +6200,36 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Channel for QqRecordingChannel {
|
||||
fn name(&self) -> &str {
|
||||
"qq"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
self.sent_messages
|
||||
.lock()
|
||||
.await
|
||||
.push(format!("{}:{}", message.recipient, message.content));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct SlowProvider {
|
||||
delay: Duration,
|
||||
}
|
||||
@@ -6429,6 +6615,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6492,6 +6681,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
auto_approve: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
@@ -6506,6 +6702,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6556,6 +6755,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
auto_approve: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
@@ -6570,6 +6776,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6634,6 +6843,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
auto_approve: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
@@ -6648,6 +6864,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6711,6 +6930,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
auto_approve: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
@@ -6725,6 +6951,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6794,6 +7023,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6844,6 +7076,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
auto_approve: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
let approval_manager = Arc::new(ApprovalManager::from_config(&autonomy_cfg));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingAliasProvider),
|
||||
@@ -6858,6 +7097,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -6931,6 +7173,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7032,6 +7277,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7184,6 +7432,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7296,6 +7547,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7403,6 +7657,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7495,6 +7752,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7586,6 +7846,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7683,6 +7946,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7692,7 +7958,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
zeroclaw_dir: Some(temp.path().to_path_buf()),
|
||||
..providers::ProviderRuntimeOptions::default()
|
||||
},
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
workspace_dir: Arc::new(temp.path().join("workspace")),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
@@ -7831,6 +8097,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -7925,6 +8194,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8072,6 +8344,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8189,6 +8464,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8286,6 +8564,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8405,6 +8686,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8522,6 +8806,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(route_overrides)),
|
||||
api_key: None,
|
||||
@@ -8598,6 +8885,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8688,6 +8978,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -8811,6 +9104,27 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert_eq!(policy.outbound_leak_guard.sensitivity, 0.95);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_runtime_defaults_from_config_file_uses_provider_fallback_when_model_missing() {
|
||||
let temp = tempfile::TempDir::new().expect("temp dir");
|
||||
let config_path = temp.path().join("config.toml");
|
||||
let workspace_dir = temp.path().join("workspace");
|
||||
std::fs::create_dir_all(&workspace_dir).expect("workspace dir");
|
||||
|
||||
let mut cfg = Config::default();
|
||||
cfg.config_path = config_path.clone();
|
||||
cfg.workspace_dir = workspace_dir;
|
||||
cfg.default_provider = Some("openai".to_string());
|
||||
cfg.default_model = None;
|
||||
cfg.save().await.expect("save config");
|
||||
|
||||
let (defaults, _policy) = load_runtime_defaults_from_config_file(&config_path)
|
||||
.await
|
||||
.expect("runtime defaults");
|
||||
assert_eq!(defaults.default_provider, "openai");
|
||||
assert_eq!(defaults.model, "gpt-5.2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn maybe_apply_runtime_config_update_refreshes_autonomy_policy_and_excluded_tools() {
|
||||
let temp = tempfile::TempDir::new().expect("temp dir");
|
||||
@@ -8844,6 +9158,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: Some("http://127.0.0.1:11434".to_string()),
|
||||
@@ -8964,6 +9281,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 12,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9029,6 +9349,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 3,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9206,6 +9529,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9293,6 +9619,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9392,6 +9721,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9473,6 +9805,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9539,6 +9874,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -9981,6 +10319,26 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assistant_memory_key_is_namespaced_from_user_key() {
|
||||
let msg = traits::ChannelMessage {
|
||||
id: "msg_abc123".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "hello".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
let user_key = conversation_memory_key(&msg);
|
||||
let assistant_key = assistant_memory_key(&msg);
|
||||
|
||||
assert!(assistant_key.starts_with("assistant_resp_"));
|
||||
assert!(assistant_key.ends_with(&user_key));
|
||||
assert_ne!(assistant_key, user_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversation_history_key_ignores_qq_message_id_thread() {
|
||||
let msg1 = traits::ChannelMessage {
|
||||
@@ -10092,11 +10450,37 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_memory_context(&mem, "age", 0.0).await;
|
||||
let context = build_memory_context(&mem, "age", 0.0, None).await;
|
||||
assert!(context.contains("[Memory context]"));
|
||||
assert!(context.contains("Age is 45"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_memory_context_respects_session_scope() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store(
|
||||
"session_a_fact",
|
||||
"Session A remembers age 45",
|
||||
MemoryCategory::Conversation,
|
||||
Some("session-a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"session_b_fact",
|
||||
"Session B remembers age 31",
|
||||
MemoryCategory::Conversation,
|
||||
Some("session-b"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session_a_context = build_memory_context(&mem, "age", 0.0, Some("session-a")).await;
|
||||
assert!(session_a_context.contains("age 45"));
|
||||
assert!(!session_a_context.contains("age 31"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_restores_per_sender_history_on_follow_ups() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
@@ -10121,6 +10505,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -10190,6 +10577,102 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(calls[1][3].1.contains("follow up"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_qq_keeps_history_across_distinct_message_ids() {
|
||||
let channel_impl = Arc::new(QqRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: provider_impl.clone(),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
approval_manager: Arc::new(ApprovalManager::from_config(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
safety_heartbeat: None,
|
||||
startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx.clone(),
|
||||
traits::ChannelMessage {
|
||||
id: "msg-a".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "group:1".to_string(),
|
||||
content: "hello".to_string(),
|
||||
channel: "qq".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: Some("msg-1".to_string()),
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-b".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "group:1".to_string(),
|
||||
content: "follow up".to_string(),
|
||||
channel: "qq".to_string(),
|
||||
timestamp: 2,
|
||||
thread_ts: Some("msg-2".to_string()),
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let calls = provider_impl
|
||||
.calls
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].len(), 2);
|
||||
assert_eq!(calls[0][0].0, "system");
|
||||
assert_eq!(calls[0][1].0, "user");
|
||||
assert_eq!(calls[1].len(), 4);
|
||||
assert_eq!(calls[1][0].0, "system");
|
||||
assert_eq!(calls[1][1].0, "user");
|
||||
assert_eq!(calls[1][2].0, "assistant");
|
||||
assert_eq!(calls[1][3].0, "user");
|
||||
assert!(calls[1][1].1.contains("hello"));
|
||||
assert!(calls[1][2].1.contains("response-1"));
|
||||
assert!(calls[1][3].1.contains("follow up"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_enriches_current_turn_without_persisting_context() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
@@ -10213,6 +10696,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -10309,6 +10795,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -11088,6 +11577,9 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
@@ -11161,6 +11653,9 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Nextcloud Talk channel in webhook mode.
|
||||
@@ -247,6 +245,8 @@ pub fn verify_nextcloud_talk_signature(
|
||||
body: &str,
|
||||
signature: &str,
|
||||
) -> bool {
|
||||
use ring::hmac;
|
||||
|
||||
let random = random.trim();
|
||||
if random.is_empty() {
|
||||
tracing::warn!("Nextcloud Talk: missing X-Nextcloud-Talk-Random header");
|
||||
@@ -265,17 +265,15 @@ pub fn verify_nextcloud_talk_signature(
|
||||
};
|
||||
|
||||
let payload = format!("{random}{body}");
|
||||
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(secret.as_bytes()) else {
|
||||
return false;
|
||||
};
|
||||
mac.update(payload.as_bytes());
|
||||
|
||||
mac.verify_slice(&provided).is_ok()
|
||||
let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes());
|
||||
hmac::verify(&key, payload.as_bytes(), &provided).is_ok()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
fn make_channel() -> NextcloudTalkChannel {
|
||||
NextcloudTalkChannel::new(
|
||||
|
||||
+303
-24
@@ -3,39 +3,52 @@ use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use reqwest::header::HeaderMap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedSlackDisplayName {
|
||||
display_name: String,
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
/// Slack channel — polls conversations.history via Web API
|
||||
pub struct SlackChannel {
|
||||
bot_token: String,
|
||||
app_token: Option<String>,
|
||||
channel_id: Option<String>,
|
||||
channel_ids: Vec<String>,
|
||||
allowed_users: Vec<String>,
|
||||
mention_only: bool,
|
||||
group_reply_allowed_sender_ids: Vec<String>,
|
||||
user_display_name_cache: Mutex<HashMap<String, CachedSlackDisplayName>>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
const SLACK_HISTORY_DEFAULT_RETRY_AFTER_SECS: u64 = 1;
|
||||
const SLACK_HISTORY_MAX_BACKOFF_SECS: u64 = 120;
|
||||
const SLACK_HISTORY_MAX_JITTER_MS: u64 = 500;
|
||||
const SLACK_USER_CACHE_TTL_SECS: u64 = 6 * 60 * 60;
|
||||
|
||||
impl SlackChannel {
|
||||
pub fn new(
|
||||
bot_token: String,
|
||||
app_token: Option<String>,
|
||||
channel_id: Option<String>,
|
||||
channel_ids: Vec<String>,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
app_token,
|
||||
channel_id,
|
||||
channel_ids,
|
||||
allowed_users,
|
||||
mention_only: false,
|
||||
group_reply_allowed_sender_ids: Vec::new(),
|
||||
user_display_name_cache: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +124,22 @@ impl SlackChannel {
|
||||
Self::normalized_channel_id(self.channel_id.as_deref())
|
||||
}
|
||||
|
||||
/// Resolve the effective channel scope:
|
||||
/// explicit `channel_ids` list first, then single `channel_id`, otherwise wildcard discovery.
|
||||
fn scoped_channel_ids(&self) -> Option<Vec<String>> {
|
||||
let mut seen = HashSet::new();
|
||||
let ids: Vec<String> = self
|
||||
.channel_ids
|
||||
.iter()
|
||||
.filter_map(|entry| Self::normalized_channel_id(Some(entry)))
|
||||
.filter(|id| seen.insert(id.clone()))
|
||||
.collect();
|
||||
if !ids.is_empty() {
|
||||
return Some(ids);
|
||||
}
|
||||
self.configured_channel_id().map(|id| vec![id])
|
||||
}
|
||||
|
||||
fn configured_app_token(&self) -> Option<String> {
|
||||
self.app_token
|
||||
.as_deref()
|
||||
@@ -130,6 +159,137 @@ impl SlackChannel {
|
||||
normalized
|
||||
}
|
||||
|
||||
fn user_cache_ttl() -> Duration {
|
||||
Duration::from_secs(SLACK_USER_CACHE_TTL_SECS)
|
||||
}
|
||||
|
||||
fn sanitize_display_name(name: &str) -> Option<String> {
|
||||
let trimmed = name.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_user_display_name(payload: &serde_json::Value) -> Option<String> {
|
||||
let user = payload.get("user")?;
|
||||
let profile = user.get("profile");
|
||||
|
||||
let candidates = [
|
||||
profile
|
||||
.and_then(|p| p.get("display_name"))
|
||||
.and_then(|v| v.as_str()),
|
||||
profile
|
||||
.and_then(|p| p.get("display_name_normalized"))
|
||||
.and_then(|v| v.as_str()),
|
||||
profile
|
||||
.and_then(|p| p.get("real_name_normalized"))
|
||||
.and_then(|v| v.as_str()),
|
||||
profile
|
||||
.and_then(|p| p.get("real_name"))
|
||||
.and_then(|v| v.as_str()),
|
||||
user.get("real_name").and_then(|v| v.as_str()),
|
||||
user.get("name").and_then(|v| v.as_str()),
|
||||
];
|
||||
|
||||
for candidate in candidates.into_iter().flatten() {
|
||||
if let Some(display_name) = Self::sanitize_display_name(candidate) {
|
||||
return Some(display_name);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn cached_sender_display_name(&self, user_id: &str) -> Option<String> {
|
||||
let now = Instant::now();
|
||||
let Ok(mut cache) = self.user_display_name_cache.lock() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
if let Some(entry) = cache.get(user_id) {
|
||||
if now <= entry.expires_at {
|
||||
return Some(entry.display_name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
cache.remove(user_id);
|
||||
None
|
||||
}
|
||||
|
||||
fn cache_sender_display_name(&self, user_id: &str, display_name: &str) {
|
||||
let Ok(mut cache) = self.user_display_name_cache.lock() else {
|
||||
return;
|
||||
};
|
||||
cache.insert(
|
||||
user_id.to_string(),
|
||||
CachedSlackDisplayName {
|
||||
display_name: display_name.to_string(),
|
||||
expires_at: Instant::now() + Self::user_cache_ttl(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
async fn fetch_sender_display_name(&self, user_id: &str) -> Option<String> {
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get("https://slack.com/api/users.info")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.query(&[("user", user_id)])
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
tracing::warn!("Slack users.info request failed for {user_id}: {err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
|
||||
|
||||
if !status.is_success() {
|
||||
let sanitized = crate::providers::sanitize_api_error(&body);
|
||||
tracing::warn!("Slack users.info failed for {user_id} ({status}): {sanitized}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let payload: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
|
||||
if payload.get("ok") == Some(&serde_json::Value::Bool(false)) {
|
||||
let err = payload
|
||||
.get("error")
|
||||
.and_then(|e| e.as_str())
|
||||
.unwrap_or("unknown");
|
||||
tracing::warn!("Slack users.info returned error for {user_id}: {err}");
|
||||
return None;
|
||||
}
|
||||
|
||||
Self::extract_user_display_name(&payload)
|
||||
}
|
||||
|
||||
async fn resolve_sender_identity(&self, user_id: &str) -> String {
|
||||
let user_id = user_id.trim();
|
||||
if user_id.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
if let Some(display_name) = self.cached_sender_display_name(user_id) {
|
||||
return display_name;
|
||||
}
|
||||
|
||||
if let Some(display_name) = self.fetch_sender_display_name(user_id).await {
|
||||
self.cache_sender_display_name(user_id, &display_name);
|
||||
return display_name;
|
||||
}
|
||||
|
||||
user_id.to_string()
|
||||
}
|
||||
|
||||
fn is_group_channel_id(channel_id: &str) -> bool {
|
||||
matches!(channel_id.chars().next(), Some('C' | 'G'))
|
||||
}
|
||||
@@ -327,7 +487,7 @@ impl SlackChannel {
|
||||
&self,
|
||||
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||
bot_user_id: &str,
|
||||
scoped_channel: Option<String>,
|
||||
scoped_channels: Option<Vec<String>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut last_ts_by_channel: HashMap<String, String> = HashMap::new();
|
||||
|
||||
@@ -425,8 +585,8 @@ impl SlackChannel {
|
||||
if channel_id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Some(ref configured_channel) = scoped_channel {
|
||||
if channel_id != *configured_channel {
|
||||
if let Some(ref configured_channels) = scoped_channels {
|
||||
if !configured_channels.iter().any(|id| id == &channel_id) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -476,10 +636,11 @@ impl SlackChannel {
|
||||
};
|
||||
|
||||
last_ts_by_channel.insert(channel_id.clone(), ts.to_string());
|
||||
let sender = self.resolve_sender_identity(user).await;
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: format!("slack_{channel_id}_{ts}"),
|
||||
sender: user.to_string(),
|
||||
sender,
|
||||
reply_target: channel_id.clone(),
|
||||
content: normalized_text,
|
||||
channel: "slack".to_string(),
|
||||
@@ -695,11 +856,11 @@ impl Channel for SlackChannel {
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let bot_user_id = self.get_bot_user_id().await.unwrap_or_default();
|
||||
let scoped_channel = self.configured_channel_id();
|
||||
let scoped_channels = self.scoped_channel_ids();
|
||||
if self.configured_app_token().is_some() {
|
||||
tracing::info!("Slack channel listening in Socket Mode");
|
||||
return self
|
||||
.listen_socket_mode(tx, &bot_user_id, scoped_channel)
|
||||
.listen_socket_mode(tx, &bot_user_id, scoped_channels)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -707,19 +868,23 @@ impl Channel for SlackChannel {
|
||||
let mut last_discovery = Instant::now();
|
||||
let mut last_ts_by_channel: HashMap<String, String> = HashMap::new();
|
||||
|
||||
if let Some(ref channel_id) = scoped_channel {
|
||||
tracing::info!("Slack channel listening on #{channel_id}...");
|
||||
if let Some(ref channel_ids) = scoped_channels {
|
||||
tracing::info!(
|
||||
"Slack channel listening on {} configured channel(s): {}",
|
||||
channel_ids.len(),
|
||||
channel_ids.join(", ")
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Slack channel_id not set (or '*'); listening across all accessible channels."
|
||||
"Slack channel_id/channel_ids not set (or wildcard only); listening across all accessible channels."
|
||||
);
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
let target_channels = if let Some(ref channel_id) = scoped_channel {
|
||||
vec![channel_id.clone()]
|
||||
let target_channels = if let Some(ref channel_ids) = scoped_channels {
|
||||
channel_ids.clone()
|
||||
} else {
|
||||
if discovered_channels.is_empty()
|
||||
|| last_discovery.elapsed() >= Duration::from_secs(60)
|
||||
@@ -820,10 +985,11 @@ impl Channel for SlackChannel {
|
||||
};
|
||||
|
||||
last_ts_by_channel.insert(channel_id.clone(), ts.to_string());
|
||||
let sender = self.resolve_sender_identity(user).await;
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: format!("slack_{channel_id}_{ts}"),
|
||||
sender: user.to_string(),
|
||||
sender,
|
||||
reply_target: channel_id.clone(),
|
||||
content: normalized_text,
|
||||
channel: "slack".to_string(),
|
||||
@@ -860,26 +1026,32 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn slack_channel_name() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
assert_eq!(ch.name(), "slack");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_channel_with_channel_id() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, Some("C12345".into()), vec![]);
|
||||
let ch = SlackChannel::new(
|
||||
"xoxb-fake".into(),
|
||||
None,
|
||||
Some("C12345".into()),
|
||||
vec![],
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.channel_id, Some("C12345".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_group_reply_policy_defaults_to_all_messages() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
|
||||
assert!(!ch.mention_only);
|
||||
assert!(ch.group_reply_allowed_sender_ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_group_reply_policy_applies_sender_overrides() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()])
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()])
|
||||
.with_group_reply_policy(true, vec![" U111 ".into(), "U111".into(), "U222".into()]);
|
||||
|
||||
assert!(ch.mention_only);
|
||||
@@ -906,16 +1078,55 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn configured_app_token_ignores_blank_values() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some(" ".into()), None, vec![]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some(" ".into()), None, vec![], vec![]);
|
||||
assert_eq!(ch.configured_app_token(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn configured_app_token_trims_value() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some(" xapp-123 ".into()), None, vec![]);
|
||||
let ch = SlackChannel::new(
|
||||
"xoxb-fake".into(),
|
||||
Some(" xapp-123 ".into()),
|
||||
None,
|
||||
vec![],
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.configured_app_token().as_deref(), Some("xapp-123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_channel_ids_prefers_explicit_list() {
|
||||
let ch = SlackChannel::new(
|
||||
"xoxb-fake".into(),
|
||||
None,
|
||||
Some("C_SINGLE".into()),
|
||||
vec!["C_LIST1".into(), "D_DM1".into()],
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(
|
||||
ch.scoped_channel_ids(),
|
||||
Some(vec!["C_LIST1".to_string(), "D_DM1".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_channel_ids_falls_back_to_single_channel_id() {
|
||||
let ch = SlackChannel::new(
|
||||
"xoxb-fake".into(),
|
||||
None,
|
||||
Some("C_SINGLE".into()),
|
||||
vec![],
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.scoped_channel_ids(), Some(vec!["C_SINGLE".to_string()]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_channel_ids_returns_none_for_wildcard_mode() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
assert_eq!(ch.scoped_channel_ids(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_group_channel_id_detects_channel_prefixes() {
|
||||
assert!(SlackChannel::is_group_channel_id("C123"));
|
||||
@@ -941,17 +1152,83 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_everyone() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
assert!(!ch.is_user_allowed("U12345"));
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_everyone() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["*".into()]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
|
||||
assert!(ch.is_user_allowed("U12345"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_user_display_name_prefers_profile_display_name() {
|
||||
let payload = serde_json::json!({
|
||||
"ok": true,
|
||||
"user": {
|
||||
"name": "fallback_name",
|
||||
"profile": {
|
||||
"display_name": "Display Name",
|
||||
"real_name": "Real Name"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
SlackChannel::extract_user_display_name(&payload).as_deref(),
|
||||
Some("Display Name")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_user_display_name_falls_back_to_username() {
|
||||
let payload = serde_json::json!({
|
||||
"ok": true,
|
||||
"user": {
|
||||
"name": "fallback_name",
|
||||
"profile": {
|
||||
"display_name": " ",
|
||||
"real_name": ""
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
SlackChannel::extract_user_display_name(&payload).as_deref(),
|
||||
Some("fallback_name")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cached_sender_display_name_returns_none_when_expired() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
|
||||
{
|
||||
let mut cache = ch.user_display_name_cache.lock().unwrap();
|
||||
cache.insert(
|
||||
"U123".to_string(),
|
||||
CachedSlackDisplayName {
|
||||
display_name: "Expired Name".to_string(),
|
||||
expires_at: Instant::now() - Duration::from_secs(1),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(ch.cached_sender_display_name("U123"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cached_sender_display_name_returns_cached_value_when_valid() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["*".into()]);
|
||||
ch.cache_sender_display_name("U123", "Cached Name");
|
||||
|
||||
assert_eq!(
|
||||
ch.cached_sender_display_name("U123").as_deref(),
|
||||
Some("Cached Name")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_incoming_content_requires_mention_when_enabled() {
|
||||
assert!(SlackChannel::normalize_incoming_content("hello", true, "U_BOT").is_none());
|
||||
@@ -975,6 +1252,7 @@ mod tests {
|
||||
"xoxb-fake".into(),
|
||||
None,
|
||||
None,
|
||||
vec![],
|
||||
vec!["U111".into(), "U222".into()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("U111"));
|
||||
@@ -984,20 +1262,20 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn allowlist_exact_match_not_substring() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
|
||||
assert!(!ch.is_user_allowed("U1111"));
|
||||
assert!(!ch.is_user_allowed("U11"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_empty_user_id() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
|
||||
assert!(!ch.is_user_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_case_sensitive() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec!["U111".into()]);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec!["U111".into()]);
|
||||
assert!(ch.is_user_allowed("U111"));
|
||||
assert!(!ch.is_user_allowed("u111"));
|
||||
}
|
||||
@@ -1008,6 +1286,7 @@ mod tests {
|
||||
"xoxb-fake".into(),
|
||||
None,
|
||||
None,
|
||||
vec![],
|
||||
vec!["U111".into(), "*".into()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("U111"));
|
||||
|
||||
@@ -765,7 +765,7 @@ impl TelegramChannel {
|
||||
}
|
||||
|
||||
fn log_poll_transport_error(sanitized: &str, consecutive_failures: u32) {
|
||||
if consecutive_failures >= 6 && consecutive_failures % 6 == 0 {
|
||||
if consecutive_failures >= 6 && consecutive_failures.is_multiple_of(6) {
|
||||
tracing::warn!(
|
||||
"Telegram poll transport error persists (consecutive={}): {}",
|
||||
consecutive_failures,
|
||||
@@ -3109,8 +3109,8 @@ impl Channel for TelegramChannel {
|
||||
let thread_id = parsed_thread_id.or(thread_ts);
|
||||
|
||||
let raw_args = arguments.to_string();
|
||||
let args_preview = if raw_args.len() > 260 {
|
||||
format!("{}...", &raw_args[..260])
|
||||
let args_preview = if raw_args.chars().count() > 260 {
|
||||
crate::util::truncate_with_ellipsis(&raw_args, 260)
|
||||
} else {
|
||||
raw_args
|
||||
};
|
||||
|
||||
+6
-3
@@ -4,9 +4,11 @@ pub mod traits;
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AgentsIpcConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig,
|
||||
BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config,
|
||||
build_runtime_proxy_client_with_timeouts, default_model_fallback_for_provider,
|
||||
resolve_default_model_id, runtime_proxy_config, set_runtime_proxy_config, AgentConfig,
|
||||
AgentSessionBackend, AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig, AuditConfig,
|
||||
AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ComposioConfig, Config,
|
||||
CoordinationConfig, CostConfig, CronConfig, DelegateAgentConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
|
||||
@@ -23,6 +25,7 @@ pub use schema::{
|
||||
StorageProviderSection, StreamMode, SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig,
|
||||
TunnelConfig, UrlAccessConfig, WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy,
|
||||
WasmRuntimeConfig, WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
DEFAULT_MODEL_FALLBACK,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+485
-19
@@ -1,5 +1,7 @@
|
||||
use crate::config::traits::ChannelConfig;
|
||||
use crate::providers::{is_glm_alias, is_zai_alias};
|
||||
use crate::providers::{
|
||||
canonical_china_provider_name, is_glm_alias, is_qwen_oauth_alias, is_zai_alias,
|
||||
};
|
||||
use crate::security::{AutonomyLevel, DomainMatcher};
|
||||
use anyhow::{Context, Result};
|
||||
use directories::UserDirs;
|
||||
@@ -14,6 +16,100 @@ use tokio::fs::File;
|
||||
use tokio::fs::{self, OpenOptions};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
/// Default fallback model when none is configured. Uses a format compatible with
|
||||
/// OpenRouter and other multi-provider gateways. For Anthropic direct API, this
|
||||
/// model ID will be normalized by the provider layer.
|
||||
pub const DEFAULT_MODEL_FALLBACK: &str = "anthropic/claude-sonnet-4.6";
|
||||
|
||||
fn canonical_provider_for_model_defaults(provider_name: &str) -> String {
|
||||
if let Some(canonical) = canonical_china_provider_name(provider_name) {
|
||||
return if canonical == "doubao" {
|
||||
"volcengine".to_string()
|
||||
} else {
|
||||
canonical.to_string()
|
||||
};
|
||||
}
|
||||
|
||||
match provider_name {
|
||||
"grok" => "xai".to_string(),
|
||||
"together" => "together-ai".to_string(),
|
||||
"google" | "google-gemini" => "gemini".to_string(),
|
||||
"github-copilot" => "copilot".to_string(),
|
||||
"openai_codex" | "codex" => "openai-codex".to_string(),
|
||||
"kimi_coding" | "kimi_for_coding" => "kimi-code".to_string(),
|
||||
"nvidia-nim" | "build.nvidia.com" => "nvidia".to_string(),
|
||||
"aws-bedrock" => "bedrock".to_string(),
|
||||
"llama.cpp" => "llamacpp".to_string(),
|
||||
_ => provider_name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a provider-aware fallback model ID when `default_model` is missing.
|
||||
pub fn default_model_fallback_for_provider(provider_name: Option<&str>) -> &'static str {
|
||||
let normalized_provider = provider_name
|
||||
.unwrap_or("openrouter")
|
||||
.trim()
|
||||
.to_ascii_lowercase()
|
||||
.replace('_', "-");
|
||||
|
||||
if normalized_provider == "qwen-coding-plan" {
|
||||
return "qwen3-coder-plus";
|
||||
}
|
||||
|
||||
let canonical_provider = if is_qwen_oauth_alias(&normalized_provider) {
|
||||
"qwen-code".to_string()
|
||||
} else {
|
||||
canonical_provider_for_model_defaults(&normalized_provider)
|
||||
};
|
||||
|
||||
match canonical_provider.as_str() {
|
||||
"anthropic" => "claude-sonnet-4-5-20250929",
|
||||
"openai" => "gpt-5.2",
|
||||
"openai-codex" => "gpt-5-codex",
|
||||
"venice" => "zai-org-glm-5",
|
||||
"groq" => "llama-3.3-70b-versatile",
|
||||
"mistral" => "mistral-large-latest",
|
||||
"deepseek" => "deepseek-chat",
|
||||
"xai" => "grok-4-1-fast-reasoning",
|
||||
"perplexity" => "sonar-pro",
|
||||
"fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
"novita" => "minimax/minimax-m2.5",
|
||||
"together-ai" => "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"cohere" => "command-a-03-2025",
|
||||
"moonshot" => "kimi-k2.5",
|
||||
"hunyuan" => "hunyuan-t1-latest",
|
||||
"glm" | "zai" => "glm-5",
|
||||
"minimax" => "MiniMax-M2.5",
|
||||
"qwen" => "qwen-plus",
|
||||
"volcengine" => "doubao-1-5-pro-32k-250115",
|
||||
"siliconflow" => "Pro/zai-org/GLM-4.7",
|
||||
"qwen-code" => "qwen3-coder-plus",
|
||||
"ollama" => "llama3.2",
|
||||
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF",
|
||||
"sglang" | "vllm" | "osaurus" | "copilot" => "default",
|
||||
"gemini" => "gemini-2.5-pro",
|
||||
"kimi-code" => "kimi-for-coding",
|
||||
"bedrock" => "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"nvidia" => "meta/llama-3.3-70b-instruct",
|
||||
_ => DEFAULT_MODEL_FALLBACK,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves the model ID used by runtime components.
|
||||
/// Preference order:
|
||||
/// 1) Explicit configured model (if non-empty)
|
||||
/// 2) Provider-aware fallback
|
||||
pub fn resolve_default_model_id(
|
||||
default_model: Option<&str>,
|
||||
provider_name: Option<&str>,
|
||||
) -> String {
|
||||
if let Some(model) = default_model.map(str::trim).filter(|m| !m.is_empty()) {
|
||||
return model.to_string();
|
||||
}
|
||||
|
||||
default_model_fallback_for_provider(provider_name).to_string()
|
||||
}
|
||||
|
||||
const SUPPORTED_PROXY_SERVICE_KEYS: &[&str] = &[
|
||||
"provider.anthropic",
|
||||
"provider.compatible",
|
||||
@@ -60,6 +156,8 @@ const SUPPORTED_PROXY_SERVICE_SELECTORS: &[&str] = &[
|
||||
static RUNTIME_PROXY_CONFIG: OnceLock<RwLock<ProxyConfig>> = OnceLock::new();
|
||||
static RUNTIME_PROXY_CLIENT_CACHE: OnceLock<RwLock<HashMap<String, reqwest::Client>>> =
|
||||
OnceLock::new();
|
||||
const DEFAULT_PROVIDER_NAME: &str = "openrouter";
|
||||
const DEFAULT_MODEL_NAME: &str = "anthropic/claude-sonnet-4.6";
|
||||
|
||||
// ── Top-level config ──────────────────────────────────────────────
|
||||
|
||||
@@ -305,6 +403,12 @@ pub struct ModelProviderConfig {
|
||||
/// Provider protocol variant ("responses" or "chat_completions").
|
||||
#[serde(default)]
|
||||
pub wire_api: Option<String>,
|
||||
/// Optional profile-scoped default model.
|
||||
#[serde(default, alias = "model")]
|
||||
pub default_model: Option<String>,
|
||||
/// Optional profile-scoped API key.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// If true, load OpenAI auth material (OPENAI_API_KEY or ~/.codex/auth.json).
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
@@ -416,6 +520,7 @@ impl std::fmt::Debug for Config {
|
||||
self.channels_config.dingtalk.is_some(),
|
||||
self.channels_config.napcat.is_some(),
|
||||
self.channels_config.qq.is_some(),
|
||||
self.channels_config.acp.is_some(),
|
||||
self.channels_config.nostr.is_some(),
|
||||
self.channels_config.clawdtalk.is_some(),
|
||||
]
|
||||
@@ -726,6 +831,8 @@ pub struct AgentConfig {
|
||||
/// When true: bootstrap_max_chars=6000, rag_chunk_limit=2. Use for 13B or smaller models.
|
||||
#[serde(default)]
|
||||
pub compact_context: bool,
|
||||
#[serde(default)]
|
||||
pub session: AgentSessionConfig,
|
||||
/// Maximum tool-call loop turns per user message. Default: `20`.
|
||||
/// Setting to `0` falls back to the safe default of `20`.
|
||||
#[serde(default = "default_agent_max_tool_iterations")]
|
||||
@@ -770,6 +877,47 @@ pub struct AgentConfig {
|
||||
pub safety_heartbeat_turn_interval: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AgentSessionBackend {
|
||||
Memory,
|
||||
Sqlite,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum AgentSessionStrategy {
|
||||
PerSender,
|
||||
PerChannel,
|
||||
Main,
|
||||
}
|
||||
|
||||
/// Session persistence configuration (`[agent.session]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentSessionConfig {
|
||||
/// Session backend to use. Options: "memory", "sqlite", "none".
|
||||
/// Default: "none" (no persistence).
|
||||
/// Set to "none" to disable session persistence entirely.
|
||||
#[serde(default = "default_agent_session_backend")]
|
||||
pub backend: AgentSessionBackend,
|
||||
|
||||
/// Strategy for resolving session IDs. Options: "per-sender", "per-channel", "main".
|
||||
/// Default: "per-sender" (each user gets a unique session per channel).
|
||||
#[serde(default = "default_agent_session_strategy")]
|
||||
pub strategy: AgentSessionStrategy,
|
||||
|
||||
/// Time-to-live for sessions in seconds.
|
||||
/// Default: 3600 (1 hour).
|
||||
#[serde(default = "default_agent_session_ttl_seconds")]
|
||||
pub ttl_seconds: u64,
|
||||
|
||||
/// Maximum number of messages to retain per session.
|
||||
/// Default: 50.
|
||||
#[serde(default = "default_agent_session_max_messages")]
|
||||
pub max_messages: usize,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
20
|
||||
}
|
||||
@@ -782,6 +930,22 @@ fn default_agent_tool_dispatcher() -> String {
|
||||
"auto".into()
|
||||
}
|
||||
|
||||
fn default_agent_session_backend() -> AgentSessionBackend {
|
||||
AgentSessionBackend::None
|
||||
}
|
||||
|
||||
fn default_agent_session_strategy() -> AgentSessionStrategy {
|
||||
AgentSessionStrategy::PerSender
|
||||
}
|
||||
|
||||
fn default_agent_session_ttl_seconds() -> u64 {
|
||||
3600
|
||||
}
|
||||
|
||||
fn default_agent_session_max_messages() -> usize {
|
||||
default_agent_max_history_messages()
|
||||
}
|
||||
|
||||
fn default_loop_detection_no_progress_threshold() -> usize {
|
||||
3
|
||||
}
|
||||
@@ -806,6 +970,7 @@ impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
compact_context: true,
|
||||
session: AgentSessionConfig::default(),
|
||||
max_tool_iterations: default_agent_max_tool_iterations(),
|
||||
max_history_messages: default_agent_max_history_messages(),
|
||||
parallel_tools: false,
|
||||
@@ -819,6 +984,17 @@ impl Default for AgentConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentSessionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
backend: default_agent_session_backend(),
|
||||
strategy: default_agent_session_strategy(),
|
||||
ttl_seconds: default_agent_session_ttl_seconds(),
|
||||
max_messages: default_agent_session_max_messages(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Skills loading configuration (`[skills]` section).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -3874,6 +4050,8 @@ impl<T: ChannelConfig> crate::config::traits::ConfigHandle for ConfigWrapper<T>
|
||||
pub struct ChannelsConfig {
|
||||
/// Enable the CLI interactive channel. Default: `true`.
|
||||
pub cli: bool,
|
||||
/// ACP (Agent Client Protocol) channel configuration.
|
||||
pub acp: Option<AcpConfig>,
|
||||
/// Telegram bot channel configuration.
|
||||
pub telegram: Option<TelegramConfig>,
|
||||
/// Discord bot channel configuration.
|
||||
@@ -4021,6 +4199,10 @@ impl ChannelsConfig {
|
||||
Box::new(ConfigWrapper::new(self.nostr.as_ref())),
|
||||
self.nostr.is_some(),
|
||||
),
|
||||
(
|
||||
Box::new(ConfigWrapper::new(self.acp.as_ref())),
|
||||
self.acp.is_some(),
|
||||
),
|
||||
(
|
||||
Box::new(ConfigWrapper::new(self.clawdtalk.as_ref())),
|
||||
self.clawdtalk.is_some(),
|
||||
@@ -4046,6 +4228,7 @@ impl Default for ChannelsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
cli: true,
|
||||
acp: None,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
slack: None,
|
||||
@@ -4267,7 +4450,12 @@ pub struct SlackConfig {
|
||||
pub app_token: Option<String>,
|
||||
/// Optional channel ID to restrict the bot to a single channel.
|
||||
/// Omit (or set `"*"`) to listen across all accessible channels.
|
||||
/// Ignored when `channel_ids` is non-empty.
|
||||
pub channel_id: Option<String>,
|
||||
/// Explicit list of channel/DM IDs to listen on simultaneously.
|
||||
/// Takes precedence over `channel_id`. Empty = fall back to `channel_id`.
|
||||
#[serde(default)]
|
||||
pub channel_ids: Vec<String>,
|
||||
/// Allowed Slack user IDs. Empty = deny all.
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
@@ -5654,9 +5842,9 @@ impl Default for Config {
|
||||
config_path: zeroclaw_dir.join("config.toml"),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".to_string()),
|
||||
default_provider: Some(DEFAULT_PROVIDER_NAME.to_string()),
|
||||
provider_api: None,
|
||||
default_model: Some("anthropic/claude-sonnet-4.6".to_string()),
|
||||
default_model: Some(DEFAULT_MODEL_NAME.to_string()),
|
||||
model_providers: HashMap::new(),
|
||||
provider: ProviderConfig::default(),
|
||||
default_temperature: 0.7,
|
||||
@@ -5848,7 +6036,10 @@ pub(crate) async fn persist_active_workspace_config_dir(config_dir: &Path) -> Re
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
sync_directory(&default_config_dir).await?;
|
||||
#[cfg(not(unix))]
|
||||
sync_directory(&default_config_dir)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -6614,6 +6805,10 @@ impl Config {
|
||||
config.workspace_dir = workspace_dir;
|
||||
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||
for (profile_name, profile) in config.model_providers.iter_mut() {
|
||||
let secret_path = format!("config.model_providers.{profile_name}.api_key");
|
||||
decrypt_optional_secret(&store, &mut profile.api_key, &secret_path)?;
|
||||
}
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.transcription.api_key,
|
||||
@@ -6850,6 +7045,18 @@ impl Config {
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToString::to_string);
|
||||
let profile_default_model = profile
|
||||
.default_model
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToString::to_string);
|
||||
let profile_api_key = profile
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToString::to_string);
|
||||
|
||||
if self
|
||||
.api_url
|
||||
@@ -6862,6 +7069,30 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
if self
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.is_none_or(|value| value.is_empty())
|
||||
{
|
||||
if let Some(profile_api_key) = profile_api_key {
|
||||
self.api_key = Some(profile_api_key);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(profile_default_model) = profile_default_model {
|
||||
let can_apply_profile_model =
|
||||
self.default_model
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.is_none_or(|value| {
|
||||
value.is_empty() || value.eq_ignore_ascii_case(DEFAULT_MODEL_NAME)
|
||||
});
|
||||
if can_apply_profile_model {
|
||||
self.default_model = Some(profile_default_model);
|
||||
}
|
||||
}
|
||||
|
||||
if profile.requires_openai_auth
|
||||
&& self
|
||||
.api_key
|
||||
@@ -6908,6 +7139,10 @@ impl Config {
|
||||
/// Called after TOML deserialization and env-override application to catch
|
||||
/// obviously invalid values early instead of failing at arbitrary runtime points.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if let Some(acp) = &self.channels_config.acp {
|
||||
acp.validate()?;
|
||||
}
|
||||
|
||||
// Gateway
|
||||
if self.gateway.host.trim().is_empty() {
|
||||
anyhow::bail!("gateway.host must not be empty");
|
||||
@@ -7336,6 +7571,24 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(default_hint) = self
|
||||
.default_model
|
||||
.as_deref()
|
||||
.and_then(|model| model.strip_prefix("hint:"))
|
||||
.map(str::trim)
|
||||
.filter(|hint| !hint.is_empty())
|
||||
{
|
||||
if !self
|
||||
.model_routes
|
||||
.iter()
|
||||
.any(|route| route.hint.trim() == default_hint)
|
||||
{
|
||||
anyhow::bail!(
|
||||
"default_model uses hint '{default_hint}', but no matching [[model_routes]] entry exists"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if self
|
||||
.provider
|
||||
.transport
|
||||
@@ -7541,7 +7794,9 @@ impl Config {
|
||||
} else if let Ok(provider) = std::env::var("PROVIDER") {
|
||||
let should_apply_legacy_provider =
|
||||
self.default_provider.as_deref().map_or(true, |configured| {
|
||||
configured.trim().eq_ignore_ascii_case("openrouter")
|
||||
configured
|
||||
.trim()
|
||||
.eq_ignore_ascii_case(DEFAULT_PROVIDER_NAME)
|
||||
});
|
||||
if should_apply_legacy_provider && !provider.is_empty() {
|
||||
self.default_provider = Some(provider);
|
||||
@@ -8125,6 +8380,10 @@ impl Config {
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||
|
||||
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||
for (profile_name, profile) in config_to_save.model_providers.iter_mut() {
|
||||
let secret_path = format!("config.model_providers.{profile_name}.api_key");
|
||||
encrypt_optional_secret(&store, &mut profile.api_key, &secret_path)?;
|
||||
}
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.transcription.api_key,
|
||||
@@ -8291,7 +8550,10 @@ impl Config {
|
||||
})?;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
sync_directory(parent_dir).await?;
|
||||
#[cfg(not(unix))]
|
||||
sync_directory(parent_dir)?;
|
||||
|
||||
if had_existing_config {
|
||||
let _ = fs::remove_file(&backup_path).await;
|
||||
@@ -8301,32 +8563,83 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn sync_directory(path: &Path) -> Result<()> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let dir = File::open(path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
|
||||
dir.sync_all()
|
||||
.await
|
||||
.with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
let dir = File::open(path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
|
||||
dir.sync_all()
|
||||
.await
|
||||
.with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn sync_directory(path: &Path) -> Result<()> {
|
||||
let _ = path;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// ACP (Agent Client Protocol) channel configuration.
|
||||
///
|
||||
/// Enables ZeroClaw to act as an ACP client, connecting to an OpenCode ACP server
|
||||
/// via `opencode acp` command for JSON-RPC 2.0 communication over stdio.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AcpConfig {
|
||||
/// OpenCode binary path (default: "opencode").
|
||||
#[serde(default = "default_acp_opencode_path")]
|
||||
pub opencode_path: Option<String>,
|
||||
/// Working directory for OpenCode process.
|
||||
pub workdir: Option<String>,
|
||||
/// Additional arguments to pass to `opencode acp`.
|
||||
#[serde(default)]
|
||||
pub extra_args: Vec<String>,
|
||||
/// Allowed user identifiers (empty = deny all, "*" = allow all).
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_acp_opencode_path() -> Option<String> {
|
||||
Some("opencode".to_string())
|
||||
}
|
||||
|
||||
impl AcpConfig {
|
||||
fn validate(&self) -> Result<()> {
|
||||
if self
|
||||
.opencode_path
|
||||
.as_deref()
|
||||
.is_some_and(|path| path.trim().is_empty())
|
||||
{
|
||||
anyhow::bail!("channels_config.acp.opencode_path must not be empty when set");
|
||||
}
|
||||
|
||||
if self
|
||||
.workdir
|
||||
.as_deref()
|
||||
.is_some_and(|dir| dir.trim().is_empty())
|
||||
{
|
||||
anyhow::bail!("channels_config.acp.workdir must not be empty when set");
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = path;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ChannelConfig for AcpConfig {
|
||||
fn name() -> &'static str {
|
||||
"ACP"
|
||||
}
|
||||
|
||||
fn desc() -> &'static str {
|
||||
"Agent Client Protocol channel for OpenCode integration"
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
use tokio::test;
|
||||
use tokio_stream::wrappers::ReadDirStream;
|
||||
@@ -8518,7 +8831,7 @@ mod tests {
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
async fn save_sets_config_permissions_on_new_file() {
|
||||
let temp = TempDir::new().expect("temp dir");
|
||||
let temp = tempfile::TempDir::new().expect("temp dir");
|
||||
let config_path = temp.path().join("config.toml");
|
||||
let workspace_dir = temp.path().join("workspace");
|
||||
|
||||
@@ -8832,6 +9145,7 @@ ws_url = "ws://127.0.0.1:3002"
|
||||
goal_loop: GoalLoopConfig::default(),
|
||||
channels_config: ChannelsConfig {
|
||||
cli: true,
|
||||
acp: None,
|
||||
telegram: Some(TelegramConfig {
|
||||
bot_token: "123:ABC".into(),
|
||||
allowed_users: vec!["user1".into()],
|
||||
@@ -9199,7 +9513,10 @@ tool_dispatcher = "xml"
|
||||
));
|
||||
fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
#[cfg(unix)]
|
||||
sync_directory(&dir).await.unwrap();
|
||||
#[cfg(not(unix))]
|
||||
sync_directory(&dir).unwrap();
|
||||
|
||||
let _ = fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
@@ -9763,6 +10080,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
async fn channels_config_with_imessage_and_matrix() {
|
||||
let c = ChannelsConfig {
|
||||
cli: true,
|
||||
acp: None,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
slack: None,
|
||||
@@ -9834,6 +10152,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
async fn slack_config_deserializes_without_allowed_users() {
|
||||
let json = r#"{"bot_token":"xoxb-tok"}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.channel_ids.is_empty());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert_eq!(
|
||||
parsed.effective_group_reply_mode(),
|
||||
@@ -9845,6 +10164,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
async fn slack_config_deserializes_with_allowed_users() {
|
||||
let json = r#"{"bot_token":"xoxb-tok","allowed_users":["U111"]}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.channel_ids.is_empty());
|
||||
assert_eq!(parsed.allowed_users, vec!["U111"]);
|
||||
}
|
||||
|
||||
@@ -9866,6 +10186,7 @@ bot_token = "xoxb-tok"
|
||||
channel_id = "C123"
|
||||
"#;
|
||||
let parsed: SlackConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.channel_ids.is_empty());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert_eq!(parsed.channel_id.as_deref(), Some("C123"));
|
||||
assert_eq!(
|
||||
@@ -10045,6 +10366,7 @@ channel_id = "C123"
|
||||
async fn channels_config_with_whatsapp() {
|
||||
let c = ChannelsConfig {
|
||||
cli: true,
|
||||
acp: None,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
slack: None,
|
||||
@@ -10642,6 +10964,8 @@ model = "gpt-5.3-codex"
|
||||
name = "sub2api"
|
||||
base_url = "https://api.tonsof.blue/v1"
|
||||
wire_api = "responses"
|
||||
model = "gpt-5.3-codex"
|
||||
api_key = "profile-key"
|
||||
requires_openai_auth = true
|
||||
"#;
|
||||
|
||||
@@ -10653,6 +10977,8 @@ requires_openai_auth = true
|
||||
.get("sub2api")
|
||||
.expect("profile should exist");
|
||||
assert_eq!(profile.wire_api.as_deref(), Some("responses"));
|
||||
assert_eq!(profile.default_model.as_deref(), Some("gpt-5.3-codex"));
|
||||
assert_eq!(profile.api_key.as_deref(), Some("profile-key"));
|
||||
assert!(profile.requires_openai_auth);
|
||||
}
|
||||
|
||||
@@ -10813,6 +11139,67 @@ provider_api = "not-a-real-mode"
|
||||
.contains("model_routes[0].max_tokens must be greater than 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn default_model_hint_requires_matching_model_route() {
|
||||
let mut config = Config::default();
|
||||
config.default_model = Some("hint:reasoning".to_string());
|
||||
config.model_routes = vec![ModelRouteConfig {
|
||||
hint: "fast".to_string(),
|
||||
provider: "openrouter".to_string(),
|
||||
model: "openai/gpt-5.2".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let err = config
|
||||
.validate()
|
||||
.expect_err("default_model hint without matching route should fail");
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("default_model uses hint 'reasoning'"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn default_model_hint_accepts_matching_model_route() {
|
||||
let mut config = Config::default();
|
||||
config.default_model = Some("hint:reasoning".to_string());
|
||||
config.model_routes = vec![ModelRouteConfig {
|
||||
hint: "reasoning".to_string(),
|
||||
provider: "openrouter".to_string(),
|
||||
model: "openai/gpt-5.2".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let result = config.validate();
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"matching default hint route should validate"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn default_model_hint_accepts_matching_model_route_with_whitespace() {
|
||||
let mut config = Config::default();
|
||||
config.default_model = Some("hint: reasoning ".to_string());
|
||||
config.model_routes = vec![ModelRouteConfig {
|
||||
hint: " reasoning ".to_string(),
|
||||
provider: "openrouter".to_string(),
|
||||
model: "openai/gpt-5.2".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let result = config.validate();
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"trimmed default hint should match trimmed route hint"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_transport_normalizes_aliases() {
|
||||
let mut config = Config::default();
|
||||
@@ -10897,6 +11284,31 @@ provider_api = "not-a-real-mode"
|
||||
std::env::remove_var("ZEROCLAW_MODEL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn resolve_default_model_id_prefers_configured_model() {
|
||||
let resolved =
|
||||
resolve_default_model_id(Some(" anthropic/claude-opus-4.6 "), Some("openrouter"));
|
||||
assert_eq!(resolved, "anthropic/claude-opus-4.6");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn resolve_default_model_id_uses_provider_specific_fallback() {
|
||||
let openai = resolve_default_model_id(None, Some("openai"));
|
||||
assert_eq!(openai, "gpt-5.2");
|
||||
|
||||
let bedrock = resolve_default_model_id(None, Some("aws-bedrock"));
|
||||
assert_eq!(bedrock, "anthropic.claude-sonnet-4-5-20250929-v1:0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn resolve_default_model_id_handles_special_provider_aliases() {
|
||||
let qwen_coding_plan = resolve_default_model_id(None, Some("qwen-coding-plan"));
|
||||
assert_eq!(qwen_coding_plan, "qwen3-coder-plus");
|
||||
|
||||
let google_alias = resolve_default_model_id(None, Some("google-gemini"));
|
||||
assert_eq!(google_alias, "gemini-2.5-pro");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn model_provider_profile_maps_to_custom_endpoint() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
@@ -10908,6 +11320,8 @@ provider_api = "not-a-real-mode"
|
||||
name: Some("sub2api".to_string()),
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: None,
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
requires_openai_auth: false,
|
||||
},
|
||||
)]),
|
||||
@@ -10936,6 +11350,8 @@ provider_api = "not-a-real-mode"
|
||||
name: Some("sub2api".to_string()),
|
||||
base_url: Some("https://api.tonsof.blue".to_string()),
|
||||
wire_api: Some("responses".to_string()),
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
requires_openai_auth: true,
|
||||
},
|
||||
)]),
|
||||
@@ -10998,6 +11414,8 @@ provider_api = "not-a-real-mode"
|
||||
name: Some("sub2api".to_string()),
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: Some("ws".to_string()),
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
requires_openai_auth: false,
|
||||
},
|
||||
)]),
|
||||
@@ -11010,6 +11428,54 @@ provider_api = "not-a-real-mode"
|
||||
.contains("wire_api must be one of: responses, chat_completions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn model_provider_profile_uses_profile_api_key_when_global_is_missing() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config {
|
||||
default_provider: Some("sub2api".to_string()),
|
||||
api_key: None,
|
||||
model_providers: HashMap::from([(
|
||||
"sub2api".to_string(),
|
||||
ModelProviderConfig {
|
||||
name: Some("sub2api".to_string()),
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: None,
|
||||
default_model: None,
|
||||
api_key: Some("profile-api-key".to_string()),
|
||||
requires_openai_auth: false,
|
||||
},
|
||||
)]),
|
||||
..Config::default()
|
||||
};
|
||||
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.api_key.as_deref(), Some("profile-api-key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn model_provider_profile_can_override_default_model_when_openrouter_default_is_set() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config {
|
||||
default_provider: Some("sub2api".to_string()),
|
||||
default_model: Some(DEFAULT_MODEL_NAME.to_string()),
|
||||
model_providers: HashMap::from([(
|
||||
"sub2api".to_string(),
|
||||
ModelProviderConfig {
|
||||
name: Some("sub2api".to_string()),
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: None,
|
||||
default_model: Some("qwen-max".to_string()),
|
||||
api_key: None,
|
||||
requires_openai_auth: false,
|
||||
},
|
||||
)]),
|
||||
..Config::default()
|
||||
};
|
||||
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.default_model.as_deref(), Some("qwen-max"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_model_fallback() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
|
||||
+21
-17
@@ -59,7 +59,7 @@ pub async fn run(config: Config) -> Result<()> {
|
||||
|
||||
pub async fn execute_job_now(config: &Config, job: &CronJob) -> (bool, String) {
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
execute_job_with_retry(config, &security, job).await
|
||||
Box::pin(execute_job_with_retry(config, &security, job)).await
|
||||
}
|
||||
|
||||
async fn execute_job_with_retry(
|
||||
@@ -74,7 +74,7 @@ async fn execute_job_with_retry(
|
||||
for attempt in 0..=retries {
|
||||
let (success, output) = match job.job_type {
|
||||
JobType::Shell => run_job_command(config, security, job).await,
|
||||
JobType::Agent => run_agent_job(config, security, job).await,
|
||||
JobType::Agent => Box::pin(run_agent_job(config, security, job)).await,
|
||||
};
|
||||
last_output = output;
|
||||
|
||||
@@ -107,18 +107,21 @@ async fn process_due_jobs(
|
||||
crate::health::mark_component_ok(component);
|
||||
|
||||
let max_concurrent = config.scheduler.max_concurrent.max(1);
|
||||
let mut in_flight =
|
||||
stream::iter(
|
||||
jobs.into_iter().map(|job| {
|
||||
let config = config.clone();
|
||||
let security = Arc::clone(security);
|
||||
let component = component.to_owned();
|
||||
async move {
|
||||
execute_and_persist_job(&config, security.as_ref(), &job, &component).await
|
||||
}
|
||||
}),
|
||||
)
|
||||
.buffer_unordered(max_concurrent);
|
||||
let mut in_flight = stream::iter(jobs.into_iter().map(|job| {
|
||||
let config = config.clone();
|
||||
let security = Arc::clone(security);
|
||||
let component = component.to_owned();
|
||||
async move {
|
||||
Box::pin(execute_and_persist_job(
|
||||
&config,
|
||||
security.as_ref(),
|
||||
&job,
|
||||
&component,
|
||||
))
|
||||
.await
|
||||
}
|
||||
}))
|
||||
.buffer_unordered(max_concurrent);
|
||||
|
||||
while let Some((job_id, success, output)) = in_flight.next().await {
|
||||
if !success {
|
||||
@@ -137,7 +140,7 @@ async fn execute_and_persist_job(
|
||||
warn_if_high_frequency_agent_job(job);
|
||||
|
||||
let started_at = Utc::now();
|
||||
let (success, output) = execute_job_with_retry(config, security, job).await;
|
||||
let (success, output) = Box::pin(execute_job_with_retry(config, security, job)).await;
|
||||
let finished_at = Utc::now();
|
||||
let success = persist_job_result(config, job, success, &output, started_at, finished_at).await;
|
||||
|
||||
@@ -176,7 +179,7 @@ async fn run_agent_job(
|
||||
|
||||
let run_result = match job.session_target {
|
||||
SessionTarget::Main | SessionTarget::Isolated => {
|
||||
crate::agent::run(
|
||||
Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
Some(prefixed_prompt),
|
||||
None,
|
||||
@@ -184,7 +187,7 @@ async fn run_agent_job(
|
||||
config.default_temperature,
|
||||
vec![],
|
||||
false,
|
||||
)
|
||||
))
|
||||
.await
|
||||
}
|
||||
};
|
||||
@@ -364,6 +367,7 @@ pub(crate) async fn deliver_announcement(
|
||||
sl.bot_token.clone(),
|
||||
sl.app_token.clone(),
|
||||
sl.channel_id.clone(),
|
||||
sl.channel_ids.clone(),
|
||||
sl.allowed_users.clone(),
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
|
||||
+3
-3
@@ -65,7 +65,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
|
||||
max_backoff,
|
||||
move || {
|
||||
let cfg = channels_cfg.clone();
|
||||
async move { crate::channels::start_channels(cfg).await }
|
||||
async move { Box::pin(crate::channels::start_channels(cfg)).await }
|
||||
},
|
||||
));
|
||||
} else {
|
||||
@@ -214,7 +214,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
for task in tasks {
|
||||
let prompt = format!("[Heartbeat Task] {task}");
|
||||
let temp = config.default_temperature;
|
||||
match crate::agent::run(
|
||||
match Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
Some(prompt),
|
||||
None,
|
||||
@@ -222,7 +222,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
temp,
|
||||
vec![],
|
||||
false,
|
||||
)
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
|
||||
@@ -115,7 +115,9 @@ impl TaskClassifier {
|
||||
|
||||
/// Load all 44 BLS occupations with wage data
|
||||
fn load_occupations() -> Vec<Occupation> {
|
||||
use OccupationCategory::*;
|
||||
use OccupationCategory::{
|
||||
BusinessFinance, HealthcareSocialServices, LegalMediaOperations, TechnologyEngineering,
|
||||
};
|
||||
|
||||
vec![
|
||||
// Technology & Engineering
|
||||
@@ -732,11 +734,11 @@ impl TaskClassifier {
|
||||
};
|
||||
|
||||
// Scale by instruction length
|
||||
let length_factor = (word_count as f64 / 20.0).max(0.5).min(2.0);
|
||||
let length_factor = (word_count as f64 / 20.0).clamp(0.5, 2.0);
|
||||
let hours = base_hours * length_factor;
|
||||
|
||||
// Clamp to valid range
|
||||
hours.max(0.25).min(40.0)
|
||||
hours.clamp(0.25, 40.0)
|
||||
}
|
||||
|
||||
/// Get all occupations
|
||||
|
||||
@@ -9,12 +9,13 @@ use std::fmt;
|
||||
/// Survival status based on balance percentage relative to initial capital.
|
||||
///
|
||||
/// Mirrors the ClawWork LiveBench agent survival states.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SurvivalStatus {
|
||||
/// Balance > 80% of initial - Agent is profitable and healthy
|
||||
Thriving,
|
||||
/// Balance 40-80% of initial - Agent is maintaining stability
|
||||
#[default]
|
||||
Stable,
|
||||
/// Balance 10-40% of initial - Agent is losing money, needs attention
|
||||
Struggling,
|
||||
@@ -100,12 +101,6 @@ impl fmt::Display for SurvivalStatus {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SurvivalStatus {
|
||||
fn default() -> Self {
|
||||
Self::Stable
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
+81
-23
@@ -90,6 +90,17 @@ fn qq_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String {
|
||||
format!("qq_{}_{}", msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn gateway_message_session_id(msg: &crate::channels::traits::ChannelMessage) -> String {
|
||||
if msg.channel == "qq" || msg.channel == "napcat" {
|
||||
return format!("{}_{}", msg.channel, msg.sender);
|
||||
}
|
||||
|
||||
match &msg.thread_ts {
|
||||
Some(thread_id) => format!("{}_{}_{}", msg.channel, thread_id, msg.sender),
|
||||
None => format!("{}_{}", msg.channel, msg.sender),
|
||||
}
|
||||
}
|
||||
|
||||
fn hash_webhook_secret(value: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
@@ -1024,9 +1035,10 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
|
||||
pub(super) async fn run_gateway_chat_with_tools(
|
||||
state: &AppState,
|
||||
message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let config = state.config.lock().clone();
|
||||
crate::agent::process_message(config, message).await
|
||||
crate::agent::process_message_with_session(config, message, session_id).await
|
||||
}
|
||||
|
||||
fn gateway_outbound_leak_guard_snapshot(
|
||||
@@ -1062,6 +1074,8 @@ pub struct WebhookBody {
|
||||
pub message: String,
|
||||
#[serde(default)]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
@@ -1579,6 +1593,11 @@ async fn handle_webhook(
|
||||
}
|
||||
|
||||
let message = webhook_body.message.trim();
|
||||
let webhook_session_id = webhook_body
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty());
|
||||
if message.is_empty() {
|
||||
let err = serde_json::json!({
|
||||
"error": "The `message` field is required and must be a non-empty string."
|
||||
@@ -1590,7 +1609,12 @@ async fn handle_webhook(
|
||||
let key = webhook_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
message,
|
||||
MemoryCategory::Conversation,
|
||||
webhook_session_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1786,8 +1810,7 @@ async fn handle_whatsapp_verify(
|
||||
/// Returns true if the signature is valid, false otherwise.
|
||||
/// See: <https://developers.facebook.com/docs/graph-api/webhooks/getting-started#verification-requests>
|
||||
pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header: &str) -> bool {
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use ring::hmac;
|
||||
|
||||
// Signature format: "sha256=<hex_signature>"
|
||||
let Some(hex_sig) = signature_header.strip_prefix("sha256=") else {
|
||||
@@ -1799,14 +1822,8 @@ pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header
|
||||
return false;
|
||||
};
|
||||
|
||||
// Compute HMAC-SHA256
|
||||
let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(app_secret.as_bytes()) else {
|
||||
return false;
|
||||
};
|
||||
mac.update(body);
|
||||
|
||||
// Constant-time comparison
|
||||
mac.verify_slice(&expected).is_ok()
|
||||
let key = hmac::Key::new(hmac::HMAC_SHA256, app_secret.as_bytes());
|
||||
hmac::verify(&key, body, &expected).is_ok()
|
||||
}
|
||||
|
||||
/// POST /whatsapp — incoming message webhook
|
||||
@@ -1868,17 +1885,23 @@ async fn handle_whatsapp_message(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = gateway_message_session_id(msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
let key = whatsapp_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -1990,18 +2013,24 @@ async fn handle_linq_webhook(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = gateway_message_session_id(msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
let key = linq_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -2034,6 +2063,7 @@ async fn handle_linq_webhook(
|
||||
}
|
||||
|
||||
/// POST /github — incoming GitHub webhook (issue/PR comments)
|
||||
#[allow(clippy::large_futures)]
|
||||
async fn handle_github_webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
@@ -2167,7 +2197,7 @@ async fn handle_github_webhook(
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, None).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -2351,18 +2381,24 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = gateway_message_session_id(msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
let key = wati_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -2463,16 +2499,22 @@ async fn handle_nextcloud_talk_webhook(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = gateway_message_session_id(msg);
|
||||
|
||||
if state.auto_save {
|
||||
let key = nextcloud_talk_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -2558,16 +2600,22 @@ async fn handle_qq_webhook(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = gateway_message_session_id(msg);
|
||||
|
||||
if state.auto_save {
|
||||
let key = qq_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match run_gateway_chat_with_tools(&state, &msg.content, Some(&session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = gateway_outbound_leak_guard_snapshot(&state);
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -3365,6 +3413,7 @@ Reminder set successfully."#;
|
||||
let body = Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
}));
|
||||
let first = handle_webhook(
|
||||
State(state.clone()),
|
||||
@@ -3379,6 +3428,7 @@ Reminder set successfully."#;
|
||||
let body = Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
}));
|
||||
let second = handle_webhook(State(state), test_connect_info(), headers, body)
|
||||
.await
|
||||
@@ -3437,6 +3487,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
@@ -3490,6 +3541,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: " ".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
@@ -3544,6 +3596,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: "stream me".into(),
|
||||
stream: Some(true),
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
@@ -3722,6 +3775,7 @@ Reminder set successfully."#;
|
||||
let body1 = Ok(Json(WebhookBody {
|
||||
message: "hello one".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
}));
|
||||
let first = handle_webhook(
|
||||
State(state.clone()),
|
||||
@@ -3736,6 +3790,7 @@ Reminder set successfully."#;
|
||||
let body2 = Ok(Json(WebhookBody {
|
||||
message: "hello two".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
}));
|
||||
let second = handle_webhook(State(state), test_connect_info(), headers, body2)
|
||||
.await
|
||||
@@ -3809,6 +3864,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
@@ -3871,6 +3927,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
@@ -3929,6 +3986,7 @@ Reminder set successfully."#;
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
stream: None,
|
||||
session_id: None,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -131,6 +131,11 @@ pub async fn handle_api_chat(
|
||||
};
|
||||
|
||||
let message = chat_body.message.trim();
|
||||
let session_id = chat_body
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty());
|
||||
if message.is_empty() {
|
||||
let err = serde_json::json!({ "error": "Message cannot be empty" });
|
||||
return (StatusCode::BAD_REQUEST, Json(err));
|
||||
@@ -141,7 +146,7 @@ pub async fn handle_api_chat(
|
||||
let key = api_chat_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.store(&key, message, MemoryCategory::Conversation, session_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -186,7 +191,7 @@ pub async fn handle_api_chat(
|
||||
});
|
||||
|
||||
// ── Run the full agent loop ──
|
||||
match run_gateway_chat_with_tools(&state, &enriched_message).await {
|
||||
match run_gateway_chat_with_tools(&state, &enriched_message, session_id).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = state.config.lock().security.outbound_leak_guard.clone();
|
||||
let safe_response = sanitize_gateway_response(
|
||||
@@ -519,6 +524,11 @@ pub async fn handle_v1_chat_completions_with_tools(
|
||||
};
|
||||
|
||||
let is_stream = request.stream.unwrap_or(false);
|
||||
let session_id = request
|
||||
.user
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty());
|
||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let created = unix_timestamp();
|
||||
|
||||
@@ -527,7 +537,7 @@ pub async fn handle_v1_chat_completions_with_tools(
|
||||
let key = api_chat_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &message, MemoryCategory::Conversation, None)
|
||||
.store(&key, &message, MemoryCategory::Conversation, session_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -562,7 +572,7 @@ pub async fn handle_v1_chat_completions_with_tools(
|
||||
);
|
||||
|
||||
// ── Run the full agent loop ──
|
||||
let reply = match run_gateway_chat_with_tools(&state, &enriched_message).await {
|
||||
let reply = match run_gateway_chat_with_tools(&state, &enriched_message, session_id).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = state.config.lock().security.outbound_leak_guard.clone();
|
||||
let safe = sanitize_gateway_response(
|
||||
|
||||
+261
-21
@@ -11,6 +11,7 @@
|
||||
|
||||
use super::AppState;
|
||||
use crate::agent::loop_::{build_shell_policy_instructions, build_tool_instructions_from_specs};
|
||||
use crate::memory::MemoryCategory;
|
||||
use crate::providers::ChatMessage;
|
||||
use axum::{
|
||||
extract::{
|
||||
@@ -20,9 +21,180 @@ use axum::{
|
||||
http::{header, HeaderMap},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
const EMPTY_WS_RESPONSE_FALLBACK: &str =
|
||||
"Tool execution completed, but the model returned no final text response. Please ask me to summarize the result.";
|
||||
const WS_HISTORY_MEMORY_KEY_PREFIX: &str = "gateway_ws_history";
|
||||
const MAX_WS_PERSISTED_TURNS: usize = 128;
|
||||
const MAX_WS_SESSION_ID_LEN: usize = 128;
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
struct WsQueryParams {
|
||||
token: Option<String>,
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||
struct WsHistoryTurn {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, PartialEq, Eq)]
|
||||
struct WsPersistedHistory {
|
||||
version: u8,
|
||||
messages: Vec<WsHistoryTurn>,
|
||||
}
|
||||
|
||||
fn normalize_ws_session_id(candidate: Option<&str>) -> Option<String> {
|
||||
let raw = candidate?.trim();
|
||||
if raw.is_empty() || raw.len() > MAX_WS_SESSION_ID_LEN {
|
||||
return None;
|
||||
}
|
||||
|
||||
if raw
|
||||
.chars()
|
||||
.all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_')
|
||||
{
|
||||
return Some(raw.to_string());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn parse_ws_query_params(raw_query: Option<&str>) -> WsQueryParams {
|
||||
let Some(query) = raw_query else {
|
||||
return WsQueryParams::default();
|
||||
};
|
||||
|
||||
let mut params = WsQueryParams::default();
|
||||
for kv in query.split('&') {
|
||||
let mut parts = kv.splitn(2, '=');
|
||||
let key = parts.next().unwrap_or("").trim();
|
||||
let value = parts.next().unwrap_or("").trim();
|
||||
if value.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match key {
|
||||
"token" if params.token.is_none() => {
|
||||
params.token = Some(value.to_string());
|
||||
}
|
||||
"session_id" if params.session_id.is_none() => {
|
||||
params.session_id = normalize_ws_session_id(Some(value));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
params
|
||||
}
|
||||
|
||||
fn ws_history_memory_key(session_id: &str) -> String {
|
||||
format!("{WS_HISTORY_MEMORY_KEY_PREFIX}:{session_id}")
|
||||
}
|
||||
|
||||
fn ws_history_turns_from_chat(history: &[ChatMessage]) -> Vec<WsHistoryTurn> {
|
||||
let mut turns = history
|
||||
.iter()
|
||||
.filter_map(|msg| match msg.role.as_str() {
|
||||
"user" | "assistant" => {
|
||||
let content = msg.content.trim();
|
||||
if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(WsHistoryTurn {
|
||||
role: msg.role.clone(),
|
||||
content: content.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if turns.len() > MAX_WS_PERSISTED_TURNS {
|
||||
let keep_from = turns.len().saturating_sub(MAX_WS_PERSISTED_TURNS);
|
||||
turns.drain(0..keep_from);
|
||||
}
|
||||
turns
|
||||
}
|
||||
|
||||
fn restore_chat_history(system_prompt: &str, turns: &[WsHistoryTurn]) -> Vec<ChatMessage> {
|
||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||
for turn in turns {
|
||||
match turn.role.as_str() {
|
||||
"user" => history.push(ChatMessage::user(&turn.content)),
|
||||
"assistant" => history.push(ChatMessage::assistant(&turn.content)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
history
|
||||
}
|
||||
|
||||
async fn load_ws_history(
|
||||
state: &AppState,
|
||||
session_id: &str,
|
||||
system_prompt: &str,
|
||||
) -> Vec<ChatMessage> {
|
||||
let key = ws_history_memory_key(session_id);
|
||||
let Some(entry) = state.mem.get(&key).await.ok().flatten() else {
|
||||
return vec![ChatMessage::system(system_prompt)];
|
||||
};
|
||||
|
||||
let parsed = serde_json::from_str::<WsPersistedHistory>(&entry.content)
|
||||
.map(|history| history.messages)
|
||||
.or_else(|_| serde_json::from_str::<Vec<WsHistoryTurn>>(&entry.content));
|
||||
|
||||
match parsed {
|
||||
Ok(turns) => restore_chat_history(system_prompt, &turns),
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Failed to parse persisted websocket history for session {}: {}",
|
||||
session_id,
|
||||
err
|
||||
);
|
||||
vec![ChatMessage::system(system_prompt)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_ws_history(state: &AppState, session_id: &str, history: &[ChatMessage]) {
|
||||
let payload = WsPersistedHistory {
|
||||
version: 1,
|
||||
messages: ws_history_turns_from_chat(history),
|
||||
};
|
||||
let serialized = match serde_json::to_string(&payload) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Failed to serialize websocket history for session {}: {}",
|
||||
session_id,
|
||||
err
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let key = ws_history_memory_key(session_id);
|
||||
if let Err(err) = state
|
||||
.mem
|
||||
.store(
|
||||
&key,
|
||||
&serialized,
|
||||
MemoryCategory::Conversation,
|
||||
Some(session_id),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to persist websocket history for session {}: {}",
|
||||
session_id,
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn sanitize_ws_response(
|
||||
response: &str,
|
||||
@@ -168,10 +340,11 @@ pub async fn handle_ws_chat(
|
||||
RawQuery(query): RawQuery,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
let query_params = parse_ws_query_params(query.as_deref());
|
||||
// Auth via Authorization header or websocket protocol token.
|
||||
if state.pairing.require_pairing() {
|
||||
let query_token = extract_query_token(query.as_deref());
|
||||
let token = extract_ws_bearer_token(&headers, query_token.as_deref()).unwrap_or_default();
|
||||
let token =
|
||||
extract_ws_bearer_token(&headers, query_params.token.as_deref()).unwrap_or_default();
|
||||
if !state.pairing.is_authenticated(&token) {
|
||||
return (
|
||||
axum::http::StatusCode::UNAUTHORIZED,
|
||||
@@ -181,13 +354,16 @@ pub async fn handle_ws_chat(
|
||||
}
|
||||
}
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state))
|
||||
let session_id = query_params
|
||||
.session_id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
// Maintain conversation history for this WebSocket session
|
||||
let mut history: Vec<ChatMessage> = Vec::new();
|
||||
async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: String) {
|
||||
let ws_session_id = format!("ws_{}", Uuid::new_v4());
|
||||
|
||||
// Build system prompt once for the session
|
||||
let system_prompt = {
|
||||
@@ -200,8 +376,17 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
)
|
||||
};
|
||||
|
||||
// Add system message to history
|
||||
history.push(ChatMessage::system(&system_prompt));
|
||||
// Restore persisted history (if any) and replay to the client before processing new input.
|
||||
let mut history = load_ws_history(&state, &session_id, &system_prompt).await;
|
||||
let persisted_turns = ws_history_turns_from_chat(&history);
|
||||
let history_payload = serde_json::json!({
|
||||
"type": "history",
|
||||
"session_id": session_id.as_str(),
|
||||
"messages": persisted_turns,
|
||||
});
|
||||
let _ = socket
|
||||
.send(Message::Text(history_payload.to_string().into()))
|
||||
.await;
|
||||
|
||||
while let Some(msg) = socket.recv().await {
|
||||
let msg = match msg {
|
||||
@@ -250,6 +435,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
|
||||
// Add user message to history
|
||||
history.push(ChatMessage::user(&content));
|
||||
persist_ws_history(&state, &session_id, &history).await;
|
||||
|
||||
// Get provider info
|
||||
let provider_label = state
|
||||
@@ -267,7 +453,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
}));
|
||||
|
||||
// Full agentic loop with tools (includes WASM skills, shell, memory, etc.)
|
||||
match super::run_gateway_chat_with_tools(&state, &content).await {
|
||||
match super::run_gateway_chat_with_tools(&state, &content, Some(&ws_session_id)).await {
|
||||
Ok(response) => {
|
||||
let leak_guard_cfg = { state.config.lock().security.outbound_leak_guard.clone() };
|
||||
let safe_response = finalize_ws_response(
|
||||
@@ -278,6 +464,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState) {
|
||||
);
|
||||
// Add assistant response to history
|
||||
history.push(ChatMessage::assistant(&safe_response));
|
||||
persist_ws_history(&state, &session_id, &history).await;
|
||||
|
||||
// Send the full response as a done message
|
||||
let done = serde_json::json!({
|
||||
@@ -345,18 +532,7 @@ fn extract_ws_bearer_token(headers: &HeaderMap, query_token: Option<&str>) -> Op
|
||||
}
|
||||
|
||||
fn extract_query_token(raw_query: Option<&str>) -> Option<String> {
|
||||
let query = raw_query?;
|
||||
for kv in query.split('&') {
|
||||
let mut parts = kv.splitn(2, '=');
|
||||
if parts.next() != Some("token") {
|
||||
continue;
|
||||
}
|
||||
let token = parts.next().unwrap_or("").trim();
|
||||
if !token.is_empty() {
|
||||
return Some(token.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
parse_ws_query_params(raw_query).token
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -445,6 +621,70 @@ mod tests {
|
||||
assert!(extract_query_token(Some("foo=1")).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ws_query_params_reads_token_and_session_id() {
|
||||
let parsed = parse_ws_query_params(Some("foo=1&session_id=sess_123&token=query-token"));
|
||||
assert_eq!(parsed.token.as_deref(), Some("query-token"));
|
||||
assert_eq!(parsed.session_id.as_deref(), Some("sess_123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ws_query_params_rejects_invalid_session_id() {
|
||||
let parsed = parse_ws_query_params(Some("session_id=../../etc/passwd"));
|
||||
assert!(parsed.session_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ws_history_turns_from_chat_skips_system_and_non_dialog_turns() {
|
||||
let history = vec![
|
||||
ChatMessage::system("sys"),
|
||||
ChatMessage::user(" hello "),
|
||||
ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: "ignored".to_string(),
|
||||
},
|
||||
ChatMessage::assistant(" world "),
|
||||
];
|
||||
|
||||
let turns = ws_history_turns_from_chat(&history);
|
||||
assert_eq!(
|
||||
turns,
|
||||
vec![
|
||||
WsHistoryTurn {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string()
|
||||
},
|
||||
WsHistoryTurn {
|
||||
role: "assistant".to_string(),
|
||||
content: "world".to_string()
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn restore_chat_history_applies_system_prompt_once() {
|
||||
let turns = vec![
|
||||
WsHistoryTurn {
|
||||
role: "user".to_string(),
|
||||
content: "u1".to_string(),
|
||||
},
|
||||
WsHistoryTurn {
|
||||
role: "assistant".to_string(),
|
||||
content: "a1".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let restored = restore_chat_history("sys", &turns);
|
||||
assert_eq!(restored.len(), 3);
|
||||
assert_eq!(restored[0].role, "system");
|
||||
assert_eq!(restored[0].content, "sys");
|
||||
assert_eq!(restored[1].role, "user");
|
||||
assert_eq!(restored[1].content, "u1");
|
||||
assert_eq!(restored[2].role, "assistant");
|
||||
assert_eq!(restored[2].content, "a1");
|
||||
}
|
||||
|
||||
struct MockScheduleTool;
|
||||
|
||||
#[async_trait]
|
||||
|
||||
+6
-8
@@ -59,9 +59,7 @@ mod agent;
|
||||
mod approval;
|
||||
mod auth;
|
||||
mod channels;
|
||||
mod rag {
|
||||
pub use zeroclaw::rag::*;
|
||||
}
|
||||
mod rag;
|
||||
mod config;
|
||||
mod coordination;
|
||||
mod cost;
|
||||
@@ -858,7 +856,7 @@ async fn main() -> Result<()> {
|
||||
}?;
|
||||
// Auto-start channels if user said yes during wizard
|
||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||
channels::start_channels(config).await?;
|
||||
Box::pin(channels::start_channels(config)).await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
@@ -919,7 +917,7 @@ async fn main() -> Result<()> {
|
||||
// Single-shot mode (-m) runs non-interactively: no TTY approval prompt,
|
||||
// so tools are not denied by a stdin read returning EOF.
|
||||
let interactive = message.is_none();
|
||||
agent::run(
|
||||
Box::pin(agent::run(
|
||||
config,
|
||||
message,
|
||||
provider,
|
||||
@@ -927,7 +925,7 @@ async fn main() -> Result<()> {
|
||||
temperature,
|
||||
peripheral,
|
||||
interactive,
|
||||
)
|
||||
))
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
@@ -1166,8 +1164,8 @@ async fn main() -> Result<()> {
|
||||
},
|
||||
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
|
||||
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
|
||||
other => channels::handle_command(other, &config).await,
|
||||
},
|
||||
|
||||
|
||||
+35
-4
@@ -3,6 +3,7 @@ pub enum MemoryBackendKind {
|
||||
Sqlite,
|
||||
SqliteQdrantHybrid,
|
||||
Lucid,
|
||||
CortexMem,
|
||||
Postgres,
|
||||
Qdrant,
|
||||
Markdown,
|
||||
@@ -39,6 +40,15 @@ const LUCID_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: true,
|
||||
};
|
||||
|
||||
const CORTEX_MEM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "cortex-mem",
|
||||
label: "Cortex-Mem bridge — optional CLI sync with local SQLite fallback",
|
||||
auto_save_default: true,
|
||||
uses_sqlite_hygiene: true,
|
||||
sqlite_based: true,
|
||||
optional_dependency: true,
|
||||
};
|
||||
|
||||
const MARKDOWN_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "markdown",
|
||||
label: "Markdown Files — simple, human-readable, no dependencies",
|
||||
@@ -93,9 +103,10 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 4] = [
|
||||
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [
|
||||
SQLITE_PROFILE,
|
||||
LUCID_PROFILE,
|
||||
CORTEX_MEM_PROFILE,
|
||||
MARKDOWN_PROFILE,
|
||||
NONE_PROFILE,
|
||||
];
|
||||
@@ -113,6 +124,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
|
||||
"sqlite" => MemoryBackendKind::Sqlite,
|
||||
"sqlite_qdrant_hybrid" | "hybrid" => MemoryBackendKind::SqliteQdrantHybrid,
|
||||
"lucid" => MemoryBackendKind::Lucid,
|
||||
"cortex-mem" | "cortex_mem" | "cortexmem" | "cortex" => MemoryBackendKind::CortexMem,
|
||||
"postgres" => MemoryBackendKind::Postgres,
|
||||
"qdrant" => MemoryBackendKind::Qdrant,
|
||||
"markdown" => MemoryBackendKind::Markdown,
|
||||
@@ -126,6 +138,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
|
||||
MemoryBackendKind::Sqlite => SQLITE_PROFILE,
|
||||
MemoryBackendKind::SqliteQdrantHybrid => SQLITE_QDRANT_HYBRID_PROFILE,
|
||||
MemoryBackendKind::Lucid => LUCID_PROFILE,
|
||||
MemoryBackendKind::CortexMem => CORTEX_MEM_PROFILE,
|
||||
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
|
||||
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
|
||||
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
|
||||
@@ -146,6 +159,14 @@ mod tests {
|
||||
MemoryBackendKind::SqliteQdrantHybrid
|
||||
);
|
||||
assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid);
|
||||
assert_eq!(
|
||||
classify_memory_backend("cortex-mem"),
|
||||
MemoryBackendKind::CortexMem
|
||||
);
|
||||
assert_eq!(
|
||||
classify_memory_backend("cortex_mem"),
|
||||
MemoryBackendKind::CortexMem
|
||||
);
|
||||
assert_eq!(
|
||||
classify_memory_backend("postgres"),
|
||||
MemoryBackendKind::Postgres
|
||||
@@ -173,11 +194,12 @@ mod tests {
|
||||
#[test]
|
||||
fn selectable_backends_are_ordered_for_onboarding() {
|
||||
let backends = selectable_memory_backends();
|
||||
assert_eq!(backends.len(), 4);
|
||||
assert_eq!(backends.len(), 5);
|
||||
assert_eq!(backends[0].key, "sqlite");
|
||||
assert_eq!(backends[1].key, "lucid");
|
||||
assert_eq!(backends[2].key, "markdown");
|
||||
assert_eq!(backends[3].key, "none");
|
||||
assert_eq!(backends[2].key, "cortex-mem");
|
||||
assert_eq!(backends[3].key, "markdown");
|
||||
assert_eq!(backends[4].key, "none");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -188,6 +210,15 @@ mod tests {
|
||||
assert!(profile.uses_sqlite_hygiene);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cortex_profile_is_sqlite_based_optional_backend() {
|
||||
let profile = memory_backend_profile("cortex-mem");
|
||||
assert_eq!(profile.key, "cortex-mem");
|
||||
assert!(profile.sqlite_based);
|
||||
assert!(profile.optional_dependency);
|
||||
assert!(profile.uses_sqlite_hygiene);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_profile_preserves_extensibility_defaults() {
|
||||
let profile = memory_backend_profile("custom-memory");
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
use super::lucid::LucidMemory;
|
||||
use super::sqlite::SqliteMemory;
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct CortexMemMemory {
|
||||
inner: LucidMemory,
|
||||
}
|
||||
|
||||
impl CortexMemMemory {
|
||||
const DEFAULT_CORTEX_CMD: &'static str = "cortex-mem";
|
||||
|
||||
pub fn new(workspace_dir: &Path, local: SqliteMemory) -> Self {
|
||||
let cortex_cmd = std::env::var("ZEROCLAW_CORTEX_CMD")
|
||||
.or_else(|_| std::env::var("ZEROCLAW_LUCID_CMD"))
|
||||
.unwrap_or_else(|_| Self::DEFAULT_CORTEX_CMD.to_string());
|
||||
|
||||
let inner = LucidMemory::new_with_command(workspace_dir, local, cortex_cmd);
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn new_with_command_for_test(workspace_dir: &Path, local: SqliteMemory, command: &str) -> Self {
|
||||
let inner = LucidMemory::new_with_command(workspace_dir, local, command.to_string());
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for CortexMemMemory {
|
||||
fn name(&self) -> &str {
|
||||
"cortex-mem"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.inner.store(key, content, category, session_id).await
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.inner.recall(query, limit, session_id).await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
self.inner.get(key).await
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.inner.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
self.inner.forget(key).await
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
self.inner.count().await
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.inner.health_check().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn cortex_backend_reports_expected_name() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sqlite = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let memory = CortexMemMemory::new(tmp.path(), sqlite);
|
||||
assert_eq!(memory.name(), "cortex-mem");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cortex_backend_keeps_local_store_when_bridge_command_fails() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sqlite = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let memory =
|
||||
CortexMemMemory::new_with_command_for_test(tmp.path(), sqlite, "missing-cortex-cli");
|
||||
|
||||
memory
|
||||
.store("cortex_key", "local first", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stored = memory.get("cortex_key").await.unwrap();
|
||||
assert!(stored.is_some(), "expected local sqlite entry to be present");
|
||||
assert_eq!(stored.unwrap().content, "local first");
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,14 @@ impl LucidMemory {
|
||||
pub fn new(workspace_dir: &Path, local: SqliteMemory) -> Self {
|
||||
let lucid_cmd = std::env::var("ZEROCLAW_LUCID_CMD")
|
||||
.unwrap_or_else(|_| Self::DEFAULT_LUCID_CMD.to_string());
|
||||
Self::new_with_command(workspace_dir, local, lucid_cmd)
|
||||
}
|
||||
|
||||
pub(crate) fn new_with_command(
|
||||
workspace_dir: &Path,
|
||||
local: SqliteMemory,
|
||||
lucid_cmd: String,
|
||||
) -> Self {
|
||||
let token_budget = std::env::var("ZEROCLAW_LUCID_BUDGET")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<usize>().ok())
|
||||
|
||||
+27
-1
@@ -1,6 +1,7 @@
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod cortex;
|
||||
pub mod embeddings;
|
||||
pub mod hybrid;
|
||||
pub mod hygiene;
|
||||
@@ -21,6 +22,7 @@ pub use backend::{
|
||||
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
|
||||
selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile,
|
||||
};
|
||||
pub use cortex::CortexMemMemory;
|
||||
pub use hybrid::SqliteQdrantHybridMemory;
|
||||
pub use lucid::LucidMemory;
|
||||
pub use markdown::MarkdownMemory;
|
||||
@@ -58,6 +60,10 @@ where
|
||||
let local = sqlite_builder()?;
|
||||
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
|
||||
}
|
||||
MemoryBackendKind::CortexMem => {
|
||||
let local = sqlite_builder()?;
|
||||
Ok(Box::new(CortexMemMemory::new(workspace_dir, local)))
|
||||
}
|
||||
MemoryBackendKind::Postgres => postgres_builder(),
|
||||
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
|
||||
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
|
||||
@@ -217,6 +223,7 @@ pub fn create_memory_with_storage_and_routes(
|
||||
MemoryBackendKind::Sqlite
|
||||
| MemoryBackendKind::SqliteQdrantHybrid
|
||||
| MemoryBackendKind::Lucid
|
||||
| MemoryBackendKind::CortexMem
|
||||
)
|
||||
{
|
||||
if let Err(e) = snapshot::export_snapshot(workspace_dir) {
|
||||
@@ -232,6 +239,7 @@ pub fn create_memory_with_storage_and_routes(
|
||||
MemoryBackendKind::Sqlite
|
||||
| MemoryBackendKind::SqliteQdrantHybrid
|
||||
| MemoryBackendKind::Lucid
|
||||
| MemoryBackendKind::CortexMem
|
||||
)
|
||||
&& snapshot::should_hydrate(workspace_dir)
|
||||
{
|
||||
@@ -381,7 +389,7 @@ pub fn create_memory_for_migration(
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
if matches!(classify_memory_backend(backend), MemoryBackendKind::None) {
|
||||
anyhow::bail!(
|
||||
"memory backend 'none' disables persistence; choose sqlite, lucid, or markdown before migration"
|
||||
"memory backend 'none' disables persistence; choose sqlite, lucid, cortex-mem, or markdown before migration"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -477,6 +485,17 @@ mod tests {
|
||||
assert_eq!(mem.name(), "lucid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_cortex_mem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "cortex-mem".into(),
|
||||
..MemoryConfig::default()
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||
assert_eq!(mem.name(), "cortex-mem");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_sqlite_qdrant_hybrid() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
@@ -521,6 +540,13 @@ mod tests {
|
||||
assert_eq!(mem.name(), "lucid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_factory_cortex_mem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = create_memory_for_migration("cortex-mem", tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "cortex-mem");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_factory_none_is_rejected() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@@ -48,7 +48,7 @@ impl CostObserver {
|
||||
// Try model family matching (e.g., "claude-sonnet-4" matches any claude-sonnet-4-*)
|
||||
for (key, pricing) in &self.prices {
|
||||
// Strip provider prefix if present
|
||||
let key_model = key.split('/').last().unwrap_or(key);
|
||||
let key_model = key.split('/').next_back().unwrap_or(key);
|
||||
|
||||
// Check if model starts with the key (family match)
|
||||
if model.starts_with(key_model) || key_model.starts_with(model) {
|
||||
|
||||
+12
-5
@@ -381,7 +381,7 @@ fn apply_provider_update(
|
||||
// ── Quick setup (zero prompts) ───────────────────────────────────
|
||||
|
||||
/// Non-interactive setup: generates a sensible default config instantly.
|
||||
/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite|lucid`.
|
||||
/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite|lucid|cortex-mem`.
|
||||
/// Use `zeroclaw onboard --interactive` for the full wizard.
|
||||
fn backend_key_from_choice(choice: usize) -> &'static str {
|
||||
selectable_memory_backends()
|
||||
@@ -787,8 +787,7 @@ fn default_model_for_provider(provider: &str) -> String {
|
||||
"qwen-code" => "qwen3-coder-plus".into(),
|
||||
"ollama" => "llama3.2".into(),
|
||||
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF".into(),
|
||||
"sglang" | "vllm" | "osaurus" => "default".into(),
|
||||
"copilot" => "default".into(),
|
||||
"sglang" | "vllm" | "osaurus" | "copilot" => "default".into(),
|
||||
"gemini" => "gemini-2.5-pro".into(),
|
||||
"kimi-code" => "kimi-for-coding".into(),
|
||||
"bedrock" => "anthropic.claude-sonnet-4-5-20250929-v1:0".into(),
|
||||
@@ -4479,6 +4478,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
} else {
|
||||
Some(channel)
|
||||
},
|
||||
channel_ids: vec![],
|
||||
allowed_users,
|
||||
group_reply: None,
|
||||
});
|
||||
@@ -8328,8 +8328,9 @@ mod tests {
|
||||
fn backend_key_from_choice_maps_supported_backends() {
|
||||
assert_eq!(backend_key_from_choice(0), "sqlite");
|
||||
assert_eq!(backend_key_from_choice(1), "lucid");
|
||||
assert_eq!(backend_key_from_choice(2), "markdown");
|
||||
assert_eq!(backend_key_from_choice(3), "none");
|
||||
assert_eq!(backend_key_from_choice(2), "cortex-mem");
|
||||
assert_eq!(backend_key_from_choice(3), "markdown");
|
||||
assert_eq!(backend_key_from_choice(4), "none");
|
||||
assert_eq!(backend_key_from_choice(999), "sqlite");
|
||||
}
|
||||
|
||||
@@ -8341,6 +8342,12 @@ mod tests {
|
||||
assert!(lucid.sqlite_based);
|
||||
assert!(lucid.optional_dependency);
|
||||
|
||||
let cortex_mem = memory_backend_profile("cortex-mem");
|
||||
assert!(cortex_mem.auto_save_default);
|
||||
assert!(cortex_mem.uses_sqlite_hygiene);
|
||||
assert!(cortex_mem.sqlite_based);
|
||||
assert!(cortex_mem.optional_dependency);
|
||||
|
||||
let markdown = memory_backend_profile("markdown");
|
||||
assert!(markdown.auto_save_default);
|
||||
assert!(!markdown.uses_sqlite_hygiene);
|
||||
|
||||
@@ -126,7 +126,7 @@ pub fn discover_plugins(workspace_dir: Option<&Path>, extra_paths: &[PathBuf]) -
|
||||
let mut deduped: Vec<DiscoveredPlugin> = Vec::with_capacity(seen.len());
|
||||
// Collect in insertion order of the winning index
|
||||
let mut indices: Vec<usize> = seen.values().copied().collect();
|
||||
indices.sort();
|
||||
indices.sort_unstable();
|
||||
for i in indices {
|
||||
deduped.push(all_plugins.swap_remove(i));
|
||||
}
|
||||
|
||||
@@ -13,7 +13,10 @@ use tracing::{info, warn};
|
||||
use crate::config::PluginsConfig;
|
||||
|
||||
use super::discovery::discover_plugins;
|
||||
use super::registry::*;
|
||||
use super::registry::{
|
||||
DiagnosticLevel, PluginDiagnostic, PluginHookRegistration, PluginOrigin, PluginRecord,
|
||||
PluginRegistry, PluginStatus, PluginToolRegistration,
|
||||
};
|
||||
use super::traits::{Plugin, PluginApi, PluginLogger};
|
||||
|
||||
/// Resolve whether a discovered plugin should be enabled.
|
||||
|
||||
+146
-47
@@ -16,6 +16,9 @@ use hmac::{Hmac, Mac};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Hostname prefix for the Bedrock Runtime endpoint.
|
||||
const ENDPOINT_PREFIX: &str = "bedrock-runtime";
|
||||
@@ -27,6 +30,7 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
|
||||
// ── AWS Credentials ─────────────────────────────────────────────
|
||||
|
||||
/// Resolved AWS credentials for SigV4 signing.
|
||||
#[derive(Clone)]
|
||||
struct AwsCredentials {
|
||||
access_key_id: String,
|
||||
secret_access_key: String,
|
||||
@@ -134,11 +138,68 @@ impl AwsCredentials {
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve credentials: env vars first, then EC2 IMDS.
|
||||
/// Fetch credentials from ECS container credential endpoint.
|
||||
/// Available when running on ECS/Fargate with a task IAM role.
|
||||
async fn from_ecs() -> anyhow::Result<Self> {
|
||||
// Try relative URI first (standard ECS), then full URI (ECS Anywhere / custom)
|
||||
let uri = std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
|
||||
.ok()
|
||||
.map(|rel| format!("http://169.254.170.2{rel}"))
|
||||
.or_else(|| std::env::var("AWS_CONTAINER_CREDENTIALS_FULL_URI").ok());
|
||||
|
||||
let uri = uri.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Neither AWS_CONTAINER_CREDENTIALS_RELATIVE_URI nor \
|
||||
AWS_CONTAINER_CREDENTIALS_FULL_URI is set"
|
||||
)
|
||||
})?;
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(3))
|
||||
.build()?;
|
||||
|
||||
let mut req = client.get(&uri);
|
||||
// ECS Anywhere / full URI may require an authorization token
|
||||
if let Ok(token) = std::env::var("AWS_CONTAINER_AUTHORIZATION_TOKEN") {
|
||||
req = req.header("Authorization", token);
|
||||
}
|
||||
|
||||
let creds_json: serde_json::Value = req.send().await?.json().await?;
|
||||
|
||||
let access_key_id = creds_json["AccessKeyId"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing AccessKeyId in ECS credential response"))?
|
||||
.to_string();
|
||||
let secret_access_key = creds_json["SecretAccessKey"]
|
||||
.as_str()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Missing SecretAccessKey in ECS credential response")
|
||||
})?
|
||||
.to_string();
|
||||
let session_token = creds_json["Token"].as_str().map(|s| s.to_string());
|
||||
|
||||
let region = env_optional("AWS_REGION")
|
||||
.or_else(|| env_optional("AWS_DEFAULT_REGION"))
|
||||
.unwrap_or_else(|| DEFAULT_REGION.to_string());
|
||||
|
||||
tracing::info!("Loaded AWS credentials from ECS container credential endpoint");
|
||||
|
||||
Ok(Self {
|
||||
access_key_id,
|
||||
secret_access_key,
|
||||
session_token,
|
||||
region,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve credentials: env vars → ECS endpoint → EC2 IMDS.
|
||||
async fn resolve() -> anyhow::Result<Self> {
|
||||
if let Ok(creds) = Self::from_env() {
|
||||
return Ok(creds);
|
||||
}
|
||||
if let Ok(creds) = Self::from_ecs().await {
|
||||
return Ok(creds);
|
||||
}
|
||||
Self::from_imds().await
|
||||
}
|
||||
|
||||
@@ -176,6 +237,57 @@ fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
/// How long credentials are considered fresh before re-fetching.
|
||||
/// ECS STS tokens typically expire after 6-12 hours; we refresh well
|
||||
/// before that to avoid any requests hitting expired tokens.
|
||||
const CREDENTIAL_TTL_SECS: u64 = 50 * 60; // 50 minutes
|
||||
|
||||
/// Thread-safe credential cache that auto-refreshes from the ECS
|
||||
/// container credential endpoint (or env vars / IMDS) when the
|
||||
/// cached credentials are older than [`CREDENTIAL_TTL_SECS`].
|
||||
struct CachedCredentials {
|
||||
inner: Arc<RwLock<Option<(AwsCredentials, Instant)>>>,
|
||||
}
|
||||
|
||||
impl CachedCredentials {
|
||||
/// Create a new cache, optionally pre-populated with initial credentials.
|
||||
fn new(initial: Option<AwsCredentials>) -> Self {
|
||||
let entry = initial.map(|c| (c, Instant::now()));
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(entry)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current credentials, refreshing if stale or missing.
|
||||
async fn get(&self) -> anyhow::Result<AwsCredentials> {
|
||||
// Fast path: read lock, check freshness
|
||||
{
|
||||
let guard = self.inner.read().await;
|
||||
if let Some((ref creds, fetched_at)) = *guard {
|
||||
if fetched_at.elapsed().as_secs() < CREDENTIAL_TTL_SECS {
|
||||
return Ok(creds.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: write lock, re-fetch
|
||||
let mut guard = self.inner.write().await;
|
||||
// Double-check after acquiring write lock (another task may have refreshed)
|
||||
if let Some((ref creds, fetched_at)) = *guard {
|
||||
if fetched_at.elapsed().as_secs() < CREDENTIAL_TTL_SECS {
|
||||
return Ok(creds.clone());
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Refreshing AWS credentials (TTL expired or first fetch)");
|
||||
let fresh = AwsCredentials::resolve().await?;
|
||||
let cloned = fresh.clone();
|
||||
*guard = Some((fresh, Instant::now()));
|
||||
Ok(cloned)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Derive the SigV4 signing key via HMAC chain.
|
||||
fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
|
||||
let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
|
||||
@@ -454,19 +566,21 @@ struct ResponseToolUseWrapper {
|
||||
// ── BedrockProvider ─────────────────────────────────────────────
|
||||
|
||||
pub struct BedrockProvider {
|
||||
credentials: Option<AwsCredentials>,
|
||||
credentials: CachedCredentials,
|
||||
}
|
||||
|
||||
impl BedrockProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
credentials: AwsCredentials::from_env().ok(),
|
||||
credentials: CachedCredentials::new(AwsCredentials::from_env().ok()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_async() -> Self {
|
||||
let credentials = AwsCredentials::resolve().await.ok();
|
||||
Self { credentials }
|
||||
let initial = AwsCredentials::resolve().await.ok();
|
||||
Self {
|
||||
credentials: CachedCredentials::new(initial),
|
||||
}
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
@@ -504,22 +618,10 @@ impl BedrockProvider {
|
||||
format!("/model/{encoded}/converse-stream")
|
||||
}
|
||||
|
||||
fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> {
|
||||
self.credentials.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"AWS Bedrock credentials not set. Set AWS_ACCESS_KEY_ID and \
|
||||
AWS_SECRET_ACCESS_KEY environment variables, or run on an EC2 \
|
||||
instance with an IAM role attached."
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve credentials: use cached if available, otherwise fetch from IMDS.
|
||||
async fn resolve_credentials(&self) -> anyhow::Result<AwsCredentials> {
|
||||
if let Ok(creds) = AwsCredentials::from_env() {
|
||||
return Ok(creds);
|
||||
}
|
||||
AwsCredentials::from_imds().await
|
||||
/// Get credentials, auto-refreshing from the ECS endpoint / env vars /
|
||||
/// IMDS when they are older than [`CREDENTIAL_TTL_SECS`].
|
||||
async fn get_credentials(&self) -> anyhow::Result<AwsCredentials> {
|
||||
self.credentials.get().await
|
||||
}
|
||||
|
||||
// ── Cache heuristics (same thresholds as AnthropicProvider) ──
|
||||
@@ -1243,7 +1345,7 @@ impl Provider for BedrockProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let credentials = self.resolve_credentials().await?;
|
||||
let credentials = self.get_credentials().await?;
|
||||
|
||||
let system = system_prompt.map(|text| {
|
||||
let mut blocks = vec![SystemBlock::Text(TextBlock {
|
||||
@@ -1285,7 +1387,7 @@ impl Provider for BedrockProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credentials = self.resolve_credentials().await?;
|
||||
let credentials = self.get_credentials().await?;
|
||||
|
||||
let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
|
||||
|
||||
@@ -1344,18 +1446,6 @@ impl Provider for BedrockProvider {
|
||||
temperature: f64,
|
||||
options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
let credentials = match self.require_credentials() {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return stream::once(async {
|
||||
Err(StreamError::Provider(
|
||||
"AWS Bedrock credentials not set".to_string(),
|
||||
))
|
||||
})
|
||||
.boxed();
|
||||
}
|
||||
};
|
||||
|
||||
let system = system_prompt.map(|text| {
|
||||
let mut blocks = vec![SystemBlock::Text(TextBlock {
|
||||
text: text.to_string(),
|
||||
@@ -1381,13 +1471,7 @@ impl Provider for BedrockProvider {
|
||||
tool_config: None,
|
||||
};
|
||||
|
||||
// Clone what we need for the async block
|
||||
let credentials = AwsCredentials {
|
||||
access_key_id: credentials.access_key_id.clone(),
|
||||
secret_access_key: credentials.secret_access_key.clone(),
|
||||
session_token: credentials.session_token.clone(),
|
||||
region: credentials.region.clone(),
|
||||
};
|
||||
let cred_cache = self.credentials.inner.clone();
|
||||
let model = model.to_string();
|
||||
let count_tokens = options.count_tokens;
|
||||
let client = self.http_client();
|
||||
@@ -1397,6 +1481,21 @@ impl Provider for BedrockProvider {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Resolve credentials inside the async context so we get
|
||||
// TTL-validated, auto-refreshing credentials (not stale sync cache).
|
||||
let cred_handle = CachedCredentials { inner: cred_cache };
|
||||
let credentials = match cred_handle.get().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
let _ = tx
|
||||
.send(Err(StreamError::Provider(format!(
|
||||
"AWS Bedrock credentials not available: {e}"
|
||||
))))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let payload = match serde_json::to_vec(&request) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
@@ -1530,7 +1629,7 @@ impl Provider for BedrockProvider {
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
if let Some(ref creds) = self.credentials {
|
||||
if let Ok(creds) = self.get_credentials().await {
|
||||
let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region);
|
||||
let _ = self.http_client().get(&url).send().await;
|
||||
}
|
||||
@@ -1696,7 +1795,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_credentials() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
|
||||
let result = provider
|
||||
.chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7)
|
||||
.await;
|
||||
@@ -1992,14 +2091,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_without_credentials_is_noop() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
|
||||
let result = provider.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_reports_native_tool_calling() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
|
||||
let caps = provider.capabilities();
|
||||
assert!(caps.native_tool_calling);
|
||||
}
|
||||
@@ -2053,7 +2152,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn supports_streaming_returns_true() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { credentials: CachedCredentials::new(None) };
|
||||
assert!(provider.supports_streaming());
|
||||
}
|
||||
|
||||
|
||||
+150
-11
@@ -39,7 +39,8 @@ pub mod traits;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{
|
||||
ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilityError,
|
||||
ToolCall, ToolResultMessage,
|
||||
is_user_or_assistant_role, ToolCall, ToolResultMessage, ROLE_ASSISTANT, ROLE_SYSTEM, ROLE_TOOL,
|
||||
ROLE_USER,
|
||||
};
|
||||
|
||||
use crate::auth::AuthService;
|
||||
@@ -1511,21 +1512,52 @@ pub fn create_routed_provider_with_options(
|
||||
);
|
||||
}
|
||||
|
||||
// Keep a default provider for non-routed model hints.
|
||||
let default_provider = create_resilient_provider_with_options(
|
||||
let default_hint = default_model
|
||||
.strip_prefix("hint:")
|
||||
.map(str::trim)
|
||||
.filter(|hint| !hint.is_empty());
|
||||
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
let mut has_primary_provider = false;
|
||||
|
||||
// Keep a default provider for non-routed requests. When default_model is a hint,
|
||||
// route-specific providers can satisfy startup even if the primary fails.
|
||||
match create_resilient_provider_with_options(
|
||||
primary_name,
|
||||
api_key,
|
||||
api_url,
|
||||
reliability,
|
||||
options,
|
||||
)?;
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> =
|
||||
vec![(primary_name.to_string(), default_provider)];
|
||||
) {
|
||||
Ok(default_provider) => {
|
||||
providers.push((primary_name.to_string(), default_provider));
|
||||
has_primary_provider = true;
|
||||
}
|
||||
Err(error) => {
|
||||
if default_hint.is_some() {
|
||||
tracing::warn!(
|
||||
provider = primary_name,
|
||||
model = default_model,
|
||||
"Primary provider failed during routed init; continuing with hint-based routes: {error}"
|
||||
);
|
||||
} else {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build hint routes with dedicated provider instances so per-route API keys
|
||||
// and max_tokens overrides do not bleed across routes.
|
||||
let mut routes: Vec<(String, router::Route)> = Vec::new();
|
||||
for route in model_routes {
|
||||
let route_hint = route.hint.trim();
|
||||
if route_hint.is_empty() {
|
||||
tracing::warn!(
|
||||
provider = route.provider.as_str(),
|
||||
"Ignoring routed provider with empty hint"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let routed_credential = route.api_key.as_ref().and_then(|raw_key| {
|
||||
let trimmed_key = raw_key.trim();
|
||||
(!trimmed_key.is_empty()).then_some(trimmed_key)
|
||||
@@ -1554,10 +1586,10 @@ pub fn create_routed_provider_with_options(
|
||||
&route_options,
|
||||
) {
|
||||
Ok(provider) => {
|
||||
let provider_id = format!("{}#{}", route.provider, route.hint);
|
||||
let provider_id = format!("{}#{}", route.provider, route_hint);
|
||||
providers.push((provider_id.clone(), provider));
|
||||
routes.push((
|
||||
route.hint.clone(),
|
||||
route_hint.to_string(),
|
||||
router::Route {
|
||||
provider_name: provider_id,
|
||||
model: route.model.clone(),
|
||||
@@ -1567,19 +1599,42 @@ pub fn create_routed_provider_with_options(
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
provider = route.provider.as_str(),
|
||||
hint = route.hint.as_str(),
|
||||
hint = route_hint,
|
||||
"Ignoring routed provider that failed to initialize: {error}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(hint) = default_hint {
|
||||
if !routes
|
||||
.iter()
|
||||
.any(|(route_hint, _)| route_hint.trim() == hint)
|
||||
{
|
||||
anyhow::bail!(
|
||||
"default_model uses hint '{hint}', but no matching [[model_routes]] entry initialized successfully"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if providers.is_empty() {
|
||||
anyhow::bail!("No providers initialized for routed configuration");
|
||||
}
|
||||
|
||||
// Keep only successfully initialized routed providers and preserve
|
||||
// their provider-id bindings (e.g. "<provider>#<hint>").
|
||||
|
||||
Ok(Box::new(
|
||||
router::RouterProvider::new(providers, routes, default_model.to_string())
|
||||
.with_vision_override(options.model_support_vision),
|
||||
router::RouterProvider::new(
|
||||
providers,
|
||||
routes,
|
||||
if has_primary_provider {
|
||||
String::new()
|
||||
} else {
|
||||
default_model.to_string()
|
||||
},
|
||||
)
|
||||
.with_vision_override(options.model_support_vision),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -3124,6 +3179,90 @@ mod tests {
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routed_provider_supports_hint_default_when_primary_init_fails() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "reasoning".to_string(),
|
||||
provider: "lmstudio".to_string(),
|
||||
model: "qwen2.5-coder".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let provider = create_routed_provider_with_options(
|
||||
"provider-that-does-not-exist",
|
||||
None,
|
||||
None,
|
||||
&reliability,
|
||||
&routes,
|
||||
"hint:reasoning",
|
||||
&ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert!(
|
||||
provider.is_ok(),
|
||||
"hint default should allow startup from route providers"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routed_provider_normalizes_whitespace_in_hint_routes() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: " reasoning ".to_string(),
|
||||
provider: "lmstudio".to_string(),
|
||||
model: "qwen2.5-coder".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let provider = create_routed_provider_with_options(
|
||||
"provider-that-does-not-exist",
|
||||
None,
|
||||
None,
|
||||
&reliability,
|
||||
&routes,
|
||||
"hint: reasoning ",
|
||||
&ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert!(
|
||||
provider.is_ok(),
|
||||
"trimmed default hint should match trimmed route hint"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routed_provider_rejects_unresolved_hint_default() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "fast".to_string(),
|
||||
provider: "lmstudio".to_string(),
|
||||
model: "qwen2.5-coder".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let err = match create_routed_provider_with_options(
|
||||
"provider-that-does-not-exist",
|
||||
None,
|
||||
None,
|
||||
&reliability,
|
||||
&routes,
|
||||
"hint:reasoning",
|
||||
&ProviderRuntimeOptions::default(),
|
||||
) {
|
||||
Ok(_) => panic!("missing default hint route should fail initialization"),
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("default_model uses hint 'reasoning'"));
|
||||
}
|
||||
|
||||
// --- parse_provider_profile ---
|
||||
|
||||
#[test]
|
||||
|
||||
+42
-5
@@ -48,12 +48,17 @@ impl RouterProvider {
|
||||
let resolved_routes: HashMap<String, (usize, String)> = routes
|
||||
.into_iter()
|
||||
.filter_map(|(hint, route)| {
|
||||
let normalized_hint = hint.trim();
|
||||
if normalized_hint.is_empty() {
|
||||
tracing::warn!("Route hint is empty after trimming, skipping");
|
||||
return None;
|
||||
}
|
||||
let index = name_to_index.get(route.provider_name.as_str()).copied();
|
||||
match index {
|
||||
Some(i) => Some((hint, (i, route.model))),
|
||||
Some(i) => Some((normalized_hint.to_string(), (i, route.model))),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
hint = hint,
|
||||
hint = normalized_hint,
|
||||
provider = route.provider_name,
|
||||
"Route references unknown provider, skipping"
|
||||
);
|
||||
@@ -63,10 +68,17 @@ impl RouterProvider {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let default_index = default_model
|
||||
.strip_prefix("hint:")
|
||||
.map(str::trim)
|
||||
.filter(|hint| !hint.is_empty())
|
||||
.and_then(|hint| resolved_routes.get(hint).map(|(idx, _)| *idx))
|
||||
.unwrap_or(0);
|
||||
|
||||
Self {
|
||||
routes: resolved_routes,
|
||||
providers,
|
||||
default_index: 0,
|
||||
default_index,
|
||||
default_model,
|
||||
vision_override: None,
|
||||
}
|
||||
@@ -85,11 +97,12 @@ impl RouterProvider {
|
||||
/// Resolve a model parameter to a (provider_index, actual_model) pair.
|
||||
fn resolve(&self, model: &str) -> (usize, String) {
|
||||
if let Some(hint) = model.strip_prefix("hint:") {
|
||||
if let Some((idx, resolved_model)) = self.routes.get(hint) {
|
||||
let normalized_hint = hint.trim();
|
||||
if let Some((idx, resolved_model)) = self.routes.get(normalized_hint) {
|
||||
return (*idx, resolved_model.clone());
|
||||
}
|
||||
tracing::warn!(
|
||||
hint = hint,
|
||||
hint = normalized_hint,
|
||||
"Unknown route hint, falling back to default provider"
|
||||
);
|
||||
}
|
||||
@@ -375,6 +388,30 @@ mod tests {
|
||||
assert_eq!(model, "claude-opus");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_trims_whitespace_in_hint_reference() {
|
||||
let (router, _) = make_router(
|
||||
vec![("fast", "ok"), ("smart", "ok")],
|
||||
vec![("reasoning", "smart", "claude-opus")],
|
||||
);
|
||||
|
||||
let (idx, model) = router.resolve("hint: reasoning ");
|
||||
assert_eq!(idx, 1);
|
||||
assert_eq!(model, "claude-opus");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_matches_routes_with_whitespace_hint_config() {
|
||||
let (router, _) = make_router(
|
||||
vec![("fast", "ok"), ("smart", "ok")],
|
||||
vec![(" reasoning ", "smart", "claude-opus")],
|
||||
);
|
||||
|
||||
let (idx, model) = router.resolve("hint:reasoning");
|
||||
assert_eq!(idx, 1);
|
||||
assert_eq!(model, "claude-opus");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skips_routes_with_unknown_provider() {
|
||||
let (router, _) = make_router(
|
||||
|
||||
+13
-4
@@ -11,31 +11,40 @@ pub struct ChatMessage {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
pub const ROLE_SYSTEM: &str = "system";
|
||||
pub const ROLE_USER: &str = "user";
|
||||
pub const ROLE_ASSISTANT: &str = "assistant";
|
||||
pub const ROLE_TOOL: &str = "tool";
|
||||
|
||||
pub fn is_user_or_assistant_role(role: &str) -> bool {
|
||||
role == ROLE_USER || role == ROLE_ASSISTANT
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "system".into(),
|
||||
role: ROLE_SYSTEM.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "user".into(),
|
||||
role: ROLE_USER.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "assistant".into(),
|
||||
role: ROLE_ASSISTANT.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "tool".into(),
|
||||
role: ROLE_TOOL.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ fn shannon_entropy(bytes: &[u8]) -> f64 {
|
||||
.iter()
|
||||
.filter(|&&count| count > 0)
|
||||
.map(|&count| {
|
||||
let p = count as f64 / len;
|
||||
let p = f64::from(count) / len;
|
||||
-p * p.log2()
|
||||
})
|
||||
.sum()
|
||||
@@ -396,7 +396,7 @@ mod tests {
|
||||
assert!(patterns.iter().any(|p| p.contains("Stripe")));
|
||||
assert!(redacted.contains("[REDACTED"));
|
||||
}
|
||||
_ => panic!("Should detect Stripe key"),
|
||||
LeakResult::Clean => panic!("Should detect Stripe key"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,7 +409,7 @@ mod tests {
|
||||
LeakResult::Detected { patterns, .. } => {
|
||||
assert!(patterns.iter().any(|p| p.contains("AWS")));
|
||||
}
|
||||
_ => panic!("Should detect AWS key"),
|
||||
LeakResult::Clean => panic!("Should detect AWS key"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,7 +427,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
|
||||
assert!(patterns.iter().any(|p| p.contains("private key")));
|
||||
assert!(redacted.contains("[REDACTED_PRIVATE_KEY]"));
|
||||
}
|
||||
_ => panic!("Should detect private key"),
|
||||
LeakResult::Clean => panic!("Should detect private key"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,7 +441,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
|
||||
assert!(patterns.iter().any(|p| p.contains("JWT")));
|
||||
assert!(redacted.contains("[REDACTED_JWT]"));
|
||||
}
|
||||
_ => panic!("Should detect JWT"),
|
||||
LeakResult::Clean => panic!("Should detect JWT"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,7 +454,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
|
||||
LeakResult::Detected { patterns, .. } => {
|
||||
assert!(patterns.iter().any(|p| p.contains("PostgreSQL")));
|
||||
}
|
||||
_ => panic!("Should detect database URL"),
|
||||
LeakResult::Clean => panic!("Should detect database URL"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -514,7 +514,7 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
|
||||
assert!(patterns.iter().any(|p| p.contains("High-entropy token")));
|
||||
assert!(redacted.contains("[REDACTED_HIGH_ENTROPY_TOKEN]"));
|
||||
}
|
||||
_ => panic!("expected high-entropy detection"),
|
||||
LeakResult::Clean => panic!("expected high-entropy detection"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,8 @@ fn char_class_perplexity(prefix: &str, suffix: &str) -> f64 {
|
||||
let class = classify_char(ch);
|
||||
if let Some(p) = suffix_prev {
|
||||
let numerator = f64::from(transition[p][class] + 1);
|
||||
let denominator = f64::from(row_totals[p] + CLASS_COUNT as u32);
|
||||
let class_count_u32 = u32::try_from(CLASS_COUNT).unwrap_or(u32::MAX);
|
||||
let denominator = f64::from(row_totals[p] + class_count_u32);
|
||||
nll += -(numerator / denominator).ln();
|
||||
pairs += 1;
|
||||
}
|
||||
|
||||
+3
-2
@@ -457,7 +457,8 @@ fn load_skill_md(path: &Path, dir: &Path) -> Result<Skill> {
|
||||
if let Ok(raw) = std::fs::read(&meta_path) {
|
||||
if let Ok(meta) = serde_json::from_slice::<serde_json::Value>(&raw) {
|
||||
if let Some(slug) = meta.get("slug").and_then(|v| v.as_str()) {
|
||||
let normalized = normalize_skill_name(slug.split('/').last().unwrap_or(slug));
|
||||
let normalized =
|
||||
normalize_skill_name(slug.split('/').next_back().unwrap_or(slug));
|
||||
if !normalized.is_empty() {
|
||||
name = normalized;
|
||||
}
|
||||
@@ -1714,7 +1715,7 @@ fn extract_zip_skill_meta(
|
||||
f.read_to_end(&mut buf).ok();
|
||||
if let Ok(meta) = serde_json::from_slice::<serde_json::Value>(&buf) {
|
||||
let slug_raw = meta.get("slug").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let base = slug_raw.split('/').last().unwrap_or(slug_raw);
|
||||
let base = slug_raw.split('/').next_back().unwrap_or(slug_raw);
|
||||
let name = normalize_skill_name(base);
|
||||
if !name.is_empty() {
|
||||
let version = meta
|
||||
|
||||
+140
-1
@@ -11,6 +11,8 @@ pub struct CronAddTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
const MIN_AGENT_EVERY_MS: u64 = 5 * 60 * 1000;
|
||||
|
||||
impl CronAddTool {
|
||||
pub fn new(config: Arc<Config>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { config, security }
|
||||
@@ -56,6 +58,8 @@ impl Tool for CronAddTool {
|
||||
fn description(&self) -> &str {
|
||||
"Create a scheduled cron job (shell or agent) with cron/at/every schedules. \
|
||||
Use job_type='agent' with a prompt to run the AI agent on schedule. \
|
||||
Use schedule.kind='at' for one-time reminders/delayed sends (recommended). \
|
||||
Agent jobs with schedule.kind='cron' or schedule.kind='every' are recurring and require explicit recurring confirmation. \
|
||||
To deliver output to a channel (Discord, Telegram, Slack, Mattermost, QQ, Napcat, Lark, Feishu, Email), set \
|
||||
delivery={\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id_or_chat_id>\"}. \
|
||||
This is the preferred tool for sending scheduled/delayed messages to users via channels."
|
||||
@@ -68,13 +72,18 @@ impl Tool for CronAddTool {
|
||||
"name": { "type": "string" },
|
||||
"schedule": {
|
||||
"type": "object",
|
||||
"description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}"
|
||||
"description": "Schedule object: {kind:'cron',expr,tz?} recurring | {kind:'at',at} one-time | {kind:'every',every_ms} recurring interval"
|
||||
},
|
||||
"job_type": { "type": "string", "enum": ["shell", "agent"] },
|
||||
"command": { "type": "string" },
|
||||
"prompt": { "type": "string" },
|
||||
"session_target": { "type": "string", "enum": ["isolated", "main"] },
|
||||
"model": { "type": "string" },
|
||||
"recurring_confirmed": {
|
||||
"type": "boolean",
|
||||
"description": "Required for agent recurring schedules (schedule.kind='cron' or 'every'). Set true only when recurring behavior is intentional.",
|
||||
"default": false
|
||||
},
|
||||
"delivery": {
|
||||
"type": "object",
|
||||
"description": "Delivery config to send job output to a channel. Example: {\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id>\"}",
|
||||
@@ -216,6 +225,49 @@ impl Tool for CronAddTool {
|
||||
.get("model")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(str::to_string);
|
||||
let recurring_confirmed = args
|
||||
.get("recurring_confirmed")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false);
|
||||
|
||||
match &schedule {
|
||||
Schedule::Every { every_ms } => {
|
||||
if !recurring_confirmed {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Agent jobs with recurring schedules require recurring_confirmed=true. \
|
||||
For one-time reminders, use schedule.kind='at' with an RFC3339 timestamp."
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
if *every_ms < MIN_AGENT_EVERY_MS {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Agent schedule.kind='every' must be >= {MIN_AGENT_EVERY_MS} ms (5 minutes)"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
Schedule::Cron { .. } => {
|
||||
if !recurring_confirmed {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Agent jobs with recurring schedules require recurring_confirmed=true. \
|
||||
For one-time reminders, use schedule.kind='at' with an RFC3339 timestamp."
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
Schedule::At { .. } => {}
|
||||
}
|
||||
|
||||
let delivery = match args.get("delivery") {
|
||||
Some(v) => match serde_json::from_value::<DeliveryConfig>(v.clone()) {
|
||||
@@ -482,4 +534,91 @@ mod tests {
|
||||
.unwrap_or_default()
|
||||
.contains("Missing 'prompt'"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn agent_every_requires_recurring_confirmation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": { "kind": "every", "every_ms": 300000 },
|
||||
"job_type": "agent",
|
||||
"prompt": "Send me a recurring status update"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.unwrap_or_default()
|
||||
.contains("recurring_confirmed=true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn agent_cron_requires_recurring_confirmation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": { "kind": "cron", "expr": "*/5 * * * *" },
|
||||
"job_type": "agent",
|
||||
"prompt": "Send recurring reminders"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.unwrap_or_default()
|
||||
.contains("recurring_confirmed=true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn agent_every_rejects_high_frequency_intervals() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": { "kind": "every", "every_ms": 60000 },
|
||||
"job_type": "agent",
|
||||
"prompt": "Send me updates frequently",
|
||||
"recurring_confirmed": true
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.unwrap_or_default()
|
||||
.contains("must be >= 300000 ms"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn agent_every_with_explicit_confirmation_succeeds() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": { "kind": "every", "every_ms": 300000 },
|
||||
"job_type": "agent",
|
||||
"prompt": "Share a heartbeat summary",
|
||||
"recurring_confirmed": true
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
assert!(result.output.contains("next_run"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +116,8 @@ impl Tool for CronRunTool {
|
||||
}
|
||||
|
||||
let started_at = Utc::now();
|
||||
let (success, output) = cron::scheduler::execute_job_now(&self.config, &job).await;
|
||||
let (success, output) =
|
||||
Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await;
|
||||
let finished_at = Utc::now();
|
||||
let duration_ms = (finished_at - started_at).num_milliseconds();
|
||||
let status = if success { "ok" } else { "error" };
|
||||
|
||||
@@ -803,7 +803,7 @@ mod tests {
|
||||
"coder".to_string(),
|
||||
DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||
model: crate::config::DEFAULT_MODEL_FALLBACK.to_string(),
|
||||
system_prompt: None,
|
||||
api_key: Some("delegate-test-credential".to_string()),
|
||||
temperature: None,
|
||||
|
||||
@@ -301,11 +301,11 @@ mod tests {
|
||||
name: "nonexistent".to_string(),
|
||||
command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
|
||||
args: vec![],
|
||||
env: Default::default(),
|
||||
env: std::collections::HashMap::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: Default::default(),
|
||||
headers: std::collections::HashMap::default(),
|
||||
};
|
||||
let result = McpServer::connect(config).await;
|
||||
assert!(result.is_err());
|
||||
@@ -320,11 +320,11 @@ mod tests {
|
||||
name: "bad".to_string(),
|
||||
command: "/usr/bin/does_not_exist_zc_test".to_string(),
|
||||
args: vec![],
|
||||
env: Default::default(),
|
||||
env: std::collections::HashMap::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: Default::default(),
|
||||
headers: std::collections::HashMap::default(),
|
||||
}];
|
||||
let registry = McpRegistry::connect_all(&configs)
|
||||
.await
|
||||
|
||||
+617
-27
@@ -1,12 +1,16 @@
|
||||
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::{oneshot, Mutex, Notify};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::config::schema::{McpServerConfig, McpTransport};
|
||||
use crate::tools::mcp_protocol::{JsonRpcRequest, JsonRpcResponse};
|
||||
use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR};
|
||||
|
||||
/// Maximum bytes for a single JSON-RPC response.
|
||||
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
||||
@@ -95,6 +99,14 @@ impl McpTransportConn for StdioTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let line = serde_json::to_string(request)?;
|
||||
self.send_raw(&line).await?;
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
@@ -158,6 +170,15 @@ impl McpTransportConn for HttpTransport {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
@@ -173,69 +194,602 @@ impl McpTransportConn for HttpTransport {
|
||||
// ── SSE Transport ─────────────────────────────────────────────────────────
|
||||
|
||||
/// SSE-based transport (HTTP POST for requests, SSE for responses).
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
enum SseStreamState {
|
||||
Unknown,
|
||||
Connected,
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
pub struct SseTransport {
|
||||
base_url: String,
|
||||
sse_url: String,
|
||||
server_name: String,
|
||||
client: reqwest::Client,
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
#[allow(dead_code)]
|
||||
event_source: Option<tokio::task::JoinHandle<()>>,
|
||||
stream_state: SseStreamState,
|
||||
shared: std::sync::Arc<Mutex<SseSharedState>>,
|
||||
notify: std::sync::Arc<Notify>,
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
reader_task: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let base_url = config
|
||||
let sse_url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for SSE transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
base_url,
|
||||
sse_url,
|
||||
server_name: config.name.clone(),
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
event_source: None,
|
||||
stream_state: SseStreamState::Unknown,
|
||||
shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
|
||||
notify: std::sync::Arc::new(Notify::new()),
|
||||
shutdown_tx: None,
|
||||
reader_task: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn ensure_connected(&mut self) -> Result<()> {
|
||||
if self.stream_state == SseStreamState::Unsupported {
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(task) = &self.reader_task {
|
||||
if !task.is_finished() {
|
||||
self.stream_state = SseStreamState::Connected;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&self.sse_url)
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE GET to MCP server failed")?;
|
||||
if resp.status() == reqwest::StatusCode::NOT_FOUND
|
||||
|| resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
|
||||
{
|
||||
self.stream_state = SseStreamState::Unsupported;
|
||||
return Ok(());
|
||||
}
|
||||
if !resp.status().is_success() {
|
||||
return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
|
||||
}
|
||||
let is_event_stream = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
if !is_event_stream {
|
||||
self.stream_state = SseStreamState::Unsupported;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let shared = self.shared.clone();
|
||||
let notify = self.notify.clone();
|
||||
let sse_url = self.sse_url.clone();
|
||||
let server_name = self.server_name.clone();
|
||||
|
||||
self.reader_task = Some(tokio::spawn(async move {
|
||||
let stream = resp
|
||||
.bytes_stream()
|
||||
.map(|item| item.map_err(std::io::Error::other));
|
||||
let reader = tokio_util::io::StreamReader::new(stream);
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
|
||||
let mut cur_event: Option<String> = None;
|
||||
let mut cur_id: Option<String> = None;
|
||||
let mut cur_data: Vec<String> = Vec::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut shutdown_rx => {
|
||||
break;
|
||||
}
|
||||
line = lines.next_line() => {
|
||||
let Ok(line_opt) = line else { break; };
|
||||
let Some(mut line) = line_opt else { break; };
|
||||
if line.ends_with('\r') {
|
||||
line.pop();
|
||||
}
|
||||
if line.is_empty() {
|
||||
if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let event = cur_event.take();
|
||||
let data = cur_data.join("\n");
|
||||
cur_data.clear();
|
||||
let id = cur_id.take();
|
||||
handle_sse_event(&server_name, &sse_url, &shared, ¬ify, event.as_deref(), id.as_deref(), data).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
cur_event = Some(rest.trim().to_string());
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
cur_data.push(rest.to_string());
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("id:") {
|
||||
cur_id = Some(rest.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pending = {
|
||||
let mut guard = shared.lock().await;
|
||||
std::mem::take(&mut guard.pending)
|
||||
};
|
||||
for (_, tx) in pending {
|
||||
let _ = tx.send(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: INTERNAL_ERROR,
|
||||
message: "SSE connection closed".to_string(),
|
||||
data: None,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}));
|
||||
self.stream_state = SseStreamState::Connected;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_message_url(&self) -> Result<(String, bool)> {
|
||||
let guard = self.shared.lock().await;
|
||||
if let Some(url) = &guard.message_url {
|
||||
return Ok((url.clone(), guard.message_url_from_endpoint));
|
||||
}
|
||||
drop(guard);
|
||||
|
||||
let derived = derive_message_url(&self.sse_url, "messages")
|
||||
.or_else(|| derive_message_url(&self.sse_url, "message"))
|
||||
.ok_or_else(|| anyhow!("invalid SSE URL"))?;
|
||||
let mut guard = self.shared.lock().await;
|
||||
if guard.message_url.is_none() {
|
||||
guard.message_url = Some(derived.clone());
|
||||
guard.message_url_from_endpoint = false;
|
||||
}
|
||||
Ok((derived, false))
|
||||
}
|
||||
|
||||
fn maybe_try_alternate_message_url(
|
||||
&self,
|
||||
current_url: &str,
|
||||
from_endpoint: bool,
|
||||
) -> Option<String> {
|
||||
if from_endpoint {
|
||||
return None;
|
||||
}
|
||||
let alt = if current_url.ends_with("/messages") {
|
||||
derive_message_url(&self.sse_url, "message")
|
||||
} else {
|
||||
derive_message_url(&self.sse_url, "messages")
|
||||
}?;
|
||||
if alt == current_url {
|
||||
return None;
|
||||
}
|
||||
Some(alt)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SseSharedState {
|
||||
message_url: Option<String>,
|
||||
message_url_from_endpoint: bool,
|
||||
pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
|
||||
}
|
||||
|
||||
fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
|
||||
let url = reqwest::Url::parse(sse_url).ok()?;
|
||||
let mut segments: Vec<&str> = url.path_segments()?.collect();
|
||||
if segments.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if segments.last().copied() == Some("sse") {
|
||||
segments.pop();
|
||||
segments.push(message_path);
|
||||
let mut new_url = url.clone();
|
||||
new_url.set_path(&format!("/{}", segments.join("/")));
|
||||
return Some(new_url.to_string());
|
||||
}
|
||||
let mut new_url = url.clone();
|
||||
let mut path = url.path().trim_end_matches('/').to_string();
|
||||
path.push('/');
|
||||
path.push_str(message_path);
|
||||
new_url.set_path(&path);
|
||||
Some(new_url.to_string())
|
||||
}
|
||||
|
||||
async fn handle_sse_event(
|
||||
server_name: &str,
|
||||
sse_url: &str,
|
||||
shared: &std::sync::Arc<Mutex<SseSharedState>>,
|
||||
notify: &std::sync::Arc<Notify>,
|
||||
event: Option<&str>,
|
||||
_id: Option<&str>,
|
||||
data: String,
|
||||
) {
|
||||
let event = event.unwrap_or("message");
|
||||
let trimmed = data.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
|
||||
if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
|
||||
let mut guard = shared.lock().await;
|
||||
guard.message_url = Some(url);
|
||||
guard.message_url_from_endpoint = true;
|
||||
drop(guard);
|
||||
notify.notify_waiters();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !event.eq_ignore_ascii_case("message") {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
|
||||
let _ = serde_json::from_value::<JsonRpcRequest>(value);
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(id_val) = resp.id.clone() else {
|
||||
return;
|
||||
};
|
||||
let id = match id_val.as_u64() {
|
||||
Some(v) => v,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let tx = {
|
||||
let mut guard = shared.lock().await;
|
||||
guard.pending.remove(&id)
|
||||
};
|
||||
if let Some(tx) = tx {
|
||||
let _ = tx.send(resp);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"MCP SSE `{}` received response for unknown id {}",
|
||||
server_name,
|
||||
id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
|
||||
if data.starts_with('{') {
|
||||
let v: serde_json::Value = serde_json::from_str(data).ok()?;
|
||||
let endpoint = v.get("endpoint")?.as_str()?;
|
||||
return parse_endpoint_from_data(sse_url, endpoint);
|
||||
}
|
||||
if data.starts_with("http://") || data.starts_with("https://") {
|
||||
return Some(data.to_string());
|
||||
}
|
||||
let base = reqwest::Url::parse(sse_url).ok()?;
|
||||
base.join(data).ok().map(|u| u.to_string())
|
||||
}
|
||||
|
||||
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
|
||||
let text = resp_text.trim_start_matches('\u{feff}');
|
||||
let mut current_data_lines: Vec<&str> = Vec::new();
|
||||
let mut last_event_data_lines: Vec<&str> = Vec::new();
|
||||
|
||||
for raw_line in text.lines() {
|
||||
let line = raw_line.trim_end_matches('\r').trim_start();
|
||||
if line.is_empty() {
|
||||
if !current_data_lines.is_empty() {
|
||||
last_event_data_lines = std::mem::take(&mut current_data_lines);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
current_data_lines.push(rest);
|
||||
}
|
||||
}
|
||||
|
||||
if !current_data_lines.is_empty() {
|
||||
last_event_data_lines = current_data_lines;
|
||||
}
|
||||
|
||||
if last_event_data_lines.is_empty() {
|
||||
return Cow::Borrowed(text.trim());
|
||||
}
|
||||
|
||||
if last_event_data_lines.len() == 1 {
|
||||
return Cow::Borrowed(last_event_data_lines[0].trim());
|
||||
}
|
||||
|
||||
let joined = last_event_data_lines.join("\n");
|
||||
Cow::Owned(joined.trim().to_string())
|
||||
}
|
||||
|
||||
async fn read_first_jsonrpc_from_sse_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<Option<JsonRpcResponse>> {
|
||||
let stream = resp
|
||||
.bytes_stream()
|
||||
.map(|item| item.map_err(std::io::Error::other));
|
||||
let reader = tokio_util::io::StreamReader::new(stream);
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
|
||||
let mut cur_event: Option<String> = None;
|
||||
let mut cur_data: Vec<String> = Vec::new();
|
||||
|
||||
while let Ok(line_opt) = lines.next_line().await {
|
||||
let Some(mut line) = line_opt else { break };
|
||||
if line.ends_with('\r') {
|
||||
line.pop();
|
||||
}
|
||||
if line.is_empty() {
|
||||
if cur_event.is_none() && cur_data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let event = cur_event.take();
|
||||
let data = cur_data.join("\n");
|
||||
cur_data.clear();
|
||||
|
||||
let event = event.unwrap_or_else(|| "message".to_string());
|
||||
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if !event.eq_ignore_ascii_case("message") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let trimmed = data.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let json_str = extract_json_from_sse_text(trimmed);
|
||||
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
|
||||
return Ok(Some(resp));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
cur_event = Some(rest.trim().to_string());
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
cur_data.push(rest.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for SseTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
// For SSE, we POST the request and the response comes via SSE stream.
|
||||
// Simplified implementation: treat as HTTP for now, proper SSE would
|
||||
// maintain a persistent event stream.
|
||||
self.ensure_connected().await?;
|
||||
|
||||
let id = request.id.as_ref().and_then(|v| v.as_u64());
|
||||
let body = serde_json::to_string(request)?;
|
||||
let url = format!("{}/message", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.body(body)
|
||||
.header("Content-Type", "application/json");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
|
||||
if self.stream_state == SseStreamState::Connected && !from_endpoint {
|
||||
for _ in 0..3 {
|
||||
{
|
||||
let guard = self.shared.lock().await;
|
||||
if guard.message_url_from_endpoint {
|
||||
if let Some(url) = &guard.message_url {
|
||||
message_url = url.clone();
|
||||
from_endpoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
|
||||
}
|
||||
}
|
||||
let primary_url = if from_endpoint {
|
||||
message_url.clone()
|
||||
} else {
|
||||
self.sse_url.clone()
|
||||
};
|
||||
let secondary_url = if message_url == self.sse_url {
|
||||
None
|
||||
} else if primary_url == message_url {
|
||||
Some(self.sse_url.clone())
|
||||
} else {
|
||||
Some(message_url.clone())
|
||||
};
|
||||
let has_secondary = secondary_url.is_some();
|
||||
|
||||
let mut rx = None;
|
||||
if let Some(id) = id {
|
||||
if self.stream_state == SseStreamState::Connected {
|
||||
let (tx, ch) = oneshot::channel();
|
||||
{
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.insert(id, tx);
|
||||
}
|
||||
rx = Some((id, ch));
|
||||
}
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
let mut got_direct = None;
|
||||
let mut last_status = None;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
for (i, url) in std::iter::once(primary_url)
|
||||
.chain(secondary_url.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(120))
|
||||
.body(body.clone())
|
||||
.header("Content-Type", "application/json");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"))
|
||||
{
|
||||
req = req.header("Accept", "application/json, text/event-stream");
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
let status = resp.status();
|
||||
last_status = Some(status);
|
||||
|
||||
if (status == reqwest::StatusCode::NOT_FOUND
|
||||
|| status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
|
||||
&& i == 0
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
break;
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
got_direct = Some(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
let is_sse = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
|
||||
if is_sse {
|
||||
if i == 0 && has_secondary {
|
||||
match timeout(
|
||||
Duration::from_secs(3),
|
||||
read_first_jsonrpc_from_sse_response(resp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if let Some(resp) = res? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let text = if i == 0 && has_secondary {
|
||||
match timeout(Duration::from_secs(3), resp.text()).await {
|
||||
Ok(Ok(t)) => t,
|
||||
Ok(Err(_)) => String::new(),
|
||||
Err(_) => continue,
|
||||
}
|
||||
} else {
|
||||
resp.text().await.unwrap_or_default()
|
||||
};
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
|
||||
extract_json_from_sse_text(trimmed)
|
||||
} else {
|
||||
Cow::Borrowed(trimmed)
|
||||
};
|
||||
if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
|
||||
got_direct = Some(mcp_resp);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// For now, parse response directly. Full SSE would read from event stream.
|
||||
let resp_text = resp.text().await.context("failed to read SSE response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
if let Some((id, _)) = rx.as_ref() {
|
||||
if got_direct.is_some() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
} else if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(mcp_resp)
|
||||
if let Some(resp) = got_direct {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
bail!("MCP server returned HTTP {}", status);
|
||||
}
|
||||
} else {
|
||||
bail!("MCP request not sent");
|
||||
}
|
||||
|
||||
let Some((_id, rx)) = rx else {
|
||||
bail!("MCP server returned no response");
|
||||
};
|
||||
|
||||
rx.await.map_err(|_| anyhow!("SSE response channel closed"))
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
if let Some(task) = self.reader_task.take() {
|
||||
task.abort();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -282,4 +836,40 @@ mod tests {
|
||||
};
|
||||
assert!(SseTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_data_no_space() {
|
||||
let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_with_event_and_id() {
|
||||
let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_multiline_data() {
|
||||
let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
|
||||
let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_uses_last_event_with_data() {
|
||||
let input =
|
||||
": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -950,8 +950,9 @@ impl Tool for ModelRoutingConfigTool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::sync::OnceLock;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
@@ -995,11 +996,9 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
async fn env_lock() -> tokio::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.expect("env lock poisoned")
|
||||
LOCK.get_or_init(|| Mutex::new(())).lock().await
|
||||
}
|
||||
|
||||
async fn test_config(tmp: &TempDir) -> Arc<Config> {
|
||||
@@ -1218,7 +1217,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_reports_env_backed_credentials_for_routes_and_agents() {
|
||||
let _env_lock = env_lock();
|
||||
let _env_lock = env_lock().await;
|
||||
let _provider_guard = EnvGuard::set("TELNYX_API_KEY", Some("test-telnyx-key"));
|
||||
let _generic_guard = EnvGuard::set("ZEROCLAW_API_KEY", None);
|
||||
let _api_key_guard = EnvGuard::set("API_KEY", None);
|
||||
|
||||
+84
-39
@@ -10,11 +10,17 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Canonical provider list for error messages and the tool description.
|
||||
/// `fast_html2md` is kept as a deprecated alias for `nanohtml2text`.
|
||||
const WEB_FETCH_PROVIDER_HELP: &str =
|
||||
"Supported providers: 'nanohtml2text' (default), 'firecrawl', 'tavily'. \
|
||||
Deprecated alias: 'fast_html2md' (maps to 'nanohtml2text').";
|
||||
|
||||
/// Web fetch tool: fetches a web page and returns text/markdown content for LLM consumption.
|
||||
///
|
||||
/// Providers:
|
||||
/// - `fast_html2md`: fetch with reqwest, convert HTML to markdown
|
||||
/// - `nanohtml2text`: fetch with reqwest, convert HTML to plaintext
|
||||
/// - `nanohtml2text` (default): fetch with reqwest, strip noise elements, convert HTML to plaintext
|
||||
/// - `fast_html2md` (deprecated alias): same as nanohtml2text unless `web-fetch-html2md` feature is compiled in
|
||||
/// - `firecrawl`: fetch using Firecrawl cloud/self-hosted API
|
||||
/// - `tavily`: fetch using Tavily Extract API
|
||||
pub struct WebFetchTool {
|
||||
@@ -33,6 +39,7 @@ pub struct WebFetchTool {
|
||||
|
||||
impl WebFetchTool {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Creates a new `WebFetchTool`. `api_key` accepts comma-separated values for round-robin rotation.
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
provider: String,
|
||||
@@ -59,7 +66,7 @@ impl WebFetchTool {
|
||||
Self {
|
||||
security,
|
||||
provider: if provider.is_empty() {
|
||||
"fast_html2md".to_string()
|
||||
"nanohtml2text".to_string()
|
||||
} else {
|
||||
provider
|
||||
},
|
||||
@@ -75,6 +82,7 @@ impl WebFetchTool {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the next API key from the rotation pool using round-robin, or `None` if unconfigured.
|
||||
fn get_next_api_key(&self) -> Option<String> {
|
||||
if self.api_keys.is_empty() {
|
||||
return None;
|
||||
@@ -83,6 +91,7 @@ impl WebFetchTool {
|
||||
Some(self.api_keys[idx].clone())
|
||||
}
|
||||
|
||||
/// Validates and normalises a URL against the allowlist, blocklist, and SSRF policy.
|
||||
fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
|
||||
validate_url(
|
||||
raw_url,
|
||||
@@ -99,6 +108,7 @@ impl WebFetchTool {
|
||||
)
|
||||
}
|
||||
|
||||
/// Truncates text to `max_response_size` characters and appends a marker if trimmed.
|
||||
fn truncate_response(&self, text: &str) -> String {
|
||||
if text.len() > self.max_response_size {
|
||||
let mut truncated = text
|
||||
@@ -112,6 +122,7 @@ impl WebFetchTool {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the configured timeout, substituting a safe 30 s default if zero is set.
|
||||
fn effective_timeout_secs(&self) -> u64 {
|
||||
if self.timeout_secs == 0 {
|
||||
tracing::warn!("web_fetch: timeout_secs is 0, using safe default of 30s");
|
||||
@@ -121,40 +132,64 @@ impl WebFetchTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
/// Strips noisy structural HTML elements (nav, scripts, footers, etc.) before text
|
||||
/// extraction to reduce boilerplate in the LLM output.
|
||||
fn strip_noise_elements(html: &str) -> anyhow::Result<String> {
|
||||
// Rust regex does not support backreferences, so run one pass per tag.
|
||||
// OnceLock stores Result<_, String> so that a compile failure is surfaced as an
|
||||
// error rather than a panic. String is used instead of anyhow::Error because it
|
||||
// is Clone + Sync, which OnceLock requires.
|
||||
use std::sync::OnceLock;
|
||||
static NOISE_RES: OnceLock<Result<Vec<regex::Regex>, String>> = OnceLock::new();
|
||||
let regexes = NOISE_RES
|
||||
.get_or_init(|| {
|
||||
[
|
||||
"script", "style", "nav", "header", "footer", "aside", "noscript", "form",
|
||||
"button",
|
||||
]
|
||||
.iter()
|
||||
.map(|tag| {
|
||||
regex::Regex::new(&format!(r"(?si)<{tag}[^>]*>.*?</{tag}>"))
|
||||
.map_err(|e| e.to_string())
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.as_ref()
|
||||
.map_err(|e| anyhow::anyhow!("noise regex init failed: {e}"))?;
|
||||
let mut result = html.to_string();
|
||||
for re in regexes {
|
||||
result = re.replace_all(&result, " ").into_owned();
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Strips noise elements then converts HTML to plain text using the configured provider.
|
||||
/// `fast_html2md` is a deprecated alias that maps to `nanohtml2text` when the
|
||||
/// `web-fetch-html2md` feature is not compiled in.
|
||||
fn convert_html_to_output(&self, body: &str) -> anyhow::Result<String> {
|
||||
let cleaned = Self::strip_noise_elements(body)?;
|
||||
match self.provider.as_str() {
|
||||
"fast_html2md" => {
|
||||
#[cfg(feature = "web-fetch-html2md")]
|
||||
{
|
||||
Ok(html2md::rewrite_html(body, false))
|
||||
Ok(html2md::rewrite_html(&cleaned, false))
|
||||
}
|
||||
#[cfg(not(feature = "web-fetch-html2md"))]
|
||||
{
|
||||
anyhow::bail!(
|
||||
"web_fetch provider 'fast_html2md' requires Cargo feature 'web-fetch-html2md'"
|
||||
);
|
||||
}
|
||||
}
|
||||
"nanohtml2text" => {
|
||||
#[cfg(feature = "web-fetch-plaintext")]
|
||||
{
|
||||
Ok(nanohtml2text::html2text(body))
|
||||
}
|
||||
#[cfg(not(feature = "web-fetch-plaintext"))]
|
||||
{
|
||||
anyhow::bail!(
|
||||
"web_fetch provider 'nanohtml2text' requires Cargo feature 'web-fetch-plaintext'"
|
||||
);
|
||||
// Feature not compiled in; fall through to nanohtml2text.
|
||||
Ok(nanohtml2text::html2text(&cleaned))
|
||||
}
|
||||
}
|
||||
"nanohtml2text" => Ok(nanohtml2text::html2text(&cleaned)),
|
||||
_ => anyhow::bail!(
|
||||
"Unknown web_fetch provider: '{}'. Set [web_fetch].provider to 'fast_html2md', 'nanohtml2text', 'firecrawl', or 'tavily' in config.toml",
|
||||
self.provider
|
||||
"Unknown web_fetch provider: '{}'. {}",
|
||||
self.provider,
|
||||
WEB_FETCH_PROVIDER_HELP
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `reqwest::Client` with the configured timeout, user-agent, and proxy settings.
|
||||
fn build_http_client(&self) -> anyhow::Result<reqwest::Client> {
|
||||
let builder = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.effective_timeout_secs()))
|
||||
@@ -165,6 +200,8 @@ impl WebFetchTool {
|
||||
Ok(builder.build()?)
|
||||
}
|
||||
|
||||
/// Fetches `url` with reqwest, handles one redirect (re-validated), and converts the
|
||||
/// response body to text via the configured HTML provider.
|
||||
async fn fetch_with_http_provider(&self, url: &str) -> anyhow::Result<String> {
|
||||
let client = self.build_http_client()?;
|
||||
let response = client.get(url).send().await?;
|
||||
@@ -221,6 +258,7 @@ impl WebFetchTool {
|
||||
)
|
||||
}
|
||||
|
||||
/// Fetches `url` via the Firecrawl scrape API and returns the extracted markdown content.
|
||||
#[cfg(feature = "firecrawl")]
|
||||
async fn fetch_with_firecrawl(&self, url: &str) -> anyhow::Result<String> {
|
||||
let auth_token = self.get_next_api_key().ok_or_else(|| {
|
||||
@@ -301,6 +339,7 @@ impl WebFetchTool {
|
||||
anyhow::bail!("web_fetch provider 'firecrawl' requires Cargo feature 'firecrawl'")
|
||||
}
|
||||
|
||||
/// Fetches `url` via the Tavily Extract API and returns the raw extracted content.
|
||||
async fn fetch_with_tavily(&self, url: &str) -> anyhow::Result<String> {
|
||||
let api_key = self.get_next_api_key().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
@@ -374,7 +413,7 @@ impl Tool for WebFetchTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fetch a web page and return markdown/text content for LLM consumption. Providers: fast_html2md, nanohtml2text, firecrawl, tavily. Security: allowlist-only domains, blocked_domains, and no local/private hosts."
|
||||
"Fetch a web page and return text content for LLM consumption. Strips navigation, scripts, and boilerplate before extraction. Providers: nanohtml2text (default), firecrawl, tavily. Deprecated alias: fast_html2md. Security: allowlist-only domains, blocked_domains, and no local/private hosts."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@@ -428,8 +467,9 @@ impl Tool for WebFetchTool {
|
||||
"firecrawl" => self.fetch_with_firecrawl(&url).await,
|
||||
"tavily" => self.fetch_with_tavily(&url).await,
|
||||
_ => Err(anyhow::anyhow!(
|
||||
"Unknown web_fetch provider: '{}'. Set [web_fetch].provider to 'fast_html2md', 'nanohtml2text', 'firecrawl', or 'tavily' in config.toml",
|
||||
self.provider
|
||||
"Unknown web_fetch provider: '{}'. {}",
|
||||
self.provider,
|
||||
WEB_FETCH_PROVIDER_HELP
|
||||
)),
|
||||
};
|
||||
|
||||
@@ -505,22 +545,12 @@ mod tests {
|
||||
assert!(required.iter().any(|v| v.as_str() == Some("url")));
|
||||
}
|
||||
|
||||
#[cfg(feature = "web-fetch-html2md")]
|
||||
// Previously gated on cfg(feature = "web-fetch-html2md") / cfg(feature = "web-fetch-plaintext")
|
||||
// — neither feature was declared in Cargo.toml so these tests never ran.
|
||||
// Now always-on: fast_html2md falls back to nanohtml2text when uncompiled.
|
||||
#[test]
|
||||
fn html_to_markdown_conversion_preserves_structure() {
|
||||
fn html_conversion_removes_tags() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
let html = "<html><body><h1>Title</h1><ul><li>Hello</li></ul></body></html>";
|
||||
let markdown = tool.convert_html_to_output(html).unwrap();
|
||||
assert!(markdown.contains("Title"));
|
||||
assert!(markdown.contains("Hello"));
|
||||
assert!(!markdown.contains("<h1>"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "web-fetch-plaintext")]
|
||||
#[test]
|
||||
fn html_to_plaintext_conversion_removes_html_tags() {
|
||||
let tool =
|
||||
test_tool_with_provider(vec!["example.com"], vec![], "nanohtml2text", None, None);
|
||||
let html = "<html><body><h1>Title</h1><p>Hello <b>world</b></p></body></html>";
|
||||
let text = tool.convert_html_to_output(html).unwrap();
|
||||
assert!(text.contains("Title"));
|
||||
@@ -528,6 +558,21 @@ mod tests {
|
||||
assert!(!text.contains("<h1>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_noise_removes_nav_scripts_footer() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
let html = "<html><body>\
|
||||
<nav><a>Home</a><a>Menu</a></nav>\
|
||||
<script>var x = 1;</script>\
|
||||
<article><p>Real content here</p></article>\
|
||||
<footer>Copyright 2025</footer>\
|
||||
</body></html>";
|
||||
let text = tool.convert_html_to_output(html).unwrap();
|
||||
assert!(text.contains("Real content"));
|
||||
assert!(!text.contains("var x"));
|
||||
assert!(!text.contains("Copyright 2025"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_accepts_exact_domain() {
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
[build]
|
||||
command = "npm run build"
|
||||
publish = "dist"
|
||||
|
||||
[[redirects]]
|
||||
from = "/*"
|
||||
to = "/index.html"
|
||||
status = 200
|
||||
Generated
+15
-1
@@ -7,11 +7,13 @@
|
||||
"": {
|
||||
"name": "zeroclaw-web",
|
||||
"version": "0.1.0",
|
||||
"license": "(MIT OR Apache-2.0)",
|
||||
"dependencies": {
|
||||
"lucide-react": "^0.468.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-router-dom": "^7.1.1"
|
||||
"react-router-dom": "^7.1.1",
|
||||
"smol-toml": "^1.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
@@ -2327,6 +2329,18 @@
|
||||
"integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/smol-toml": {
|
||||
"version": "1.6.0",
|
||||
"resolved": "https://registry.npmjs.org/smol-toml/-/smol-toml-1.6.0.tgz",
|
||||
"integrity": "sha512-4zemZi0HvTnYwLfrpk/CF9LOd9Lt87kAt50GnqhMpyF9U3poDAP2+iukq2bZsO/ufegbYehBkqINbsWxj4l4cw==",
|
||||
"license": "BSD-3-Clause",
|
||||
"engines": {
|
||||
"node": ">= 18"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/cyyynthia"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map-js": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
|
||||
|
||||
+2
-1
@@ -13,7 +13,8 @@
|
||||
"lucide-react": "^0.468.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-router-dom": "^7.1.1"
|
||||
"react-router-dom": "^7.1.1",
|
||||
"smol-toml": "^1.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
import { useState, useMemo } from 'react';
|
||||
import { Search } from 'lucide-react';
|
||||
import { CONFIG_SECTIONS } from './configSections';
|
||||
import ConfigSection from './ConfigSection';
|
||||
import type { FieldDef } from './types';
|
||||
|
||||
const CATEGORY_ORDER = [
|
||||
{ key: 'all', label: 'All' },
|
||||
{ key: 'general', label: 'General' },
|
||||
{ key: 'security', label: 'Security' },
|
||||
{ key: 'channels', label: 'Channels' },
|
||||
{ key: 'runtime', label: 'Runtime' },
|
||||
{ key: 'tools', label: 'Tools' },
|
||||
{ key: 'memory', label: 'Memory' },
|
||||
{ key: 'network', label: 'Network' },
|
||||
{ key: 'advanced', label: 'Advanced' },
|
||||
] as const;
|
||||
|
||||
interface Props {
|
||||
getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
|
||||
setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
|
||||
isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
|
||||
}
|
||||
|
||||
export default function ConfigFormEditor({
|
||||
getFieldValue,
|
||||
setFieldValue,
|
||||
isFieldMasked,
|
||||
}: Props) {
|
||||
const [search, setSearch] = useState('');
|
||||
const [activeCategory, setActiveCategory] = useState('all');
|
||||
|
||||
const isSearching = search.trim().length > 0;
|
||||
|
||||
const filteredSections = useMemo(() => {
|
||||
if (isSearching) {
|
||||
const q = search.toLowerCase();
|
||||
return CONFIG_SECTIONS.map((section) => {
|
||||
const titleMatch = section.title.toLowerCase().includes(q);
|
||||
const descMatch = section.description?.toLowerCase().includes(q);
|
||||
|
||||
if (titleMatch || descMatch) {
|
||||
return { section, fields: undefined };
|
||||
}
|
||||
|
||||
const matchingFields = section.fields.filter(
|
||||
(f: FieldDef) =>
|
||||
f.label.toLowerCase().includes(q) ||
|
||||
f.key.toLowerCase().includes(q) ||
|
||||
f.description?.toLowerCase().includes(q),
|
||||
);
|
||||
|
||||
if (matchingFields.length > 0) {
|
||||
return { section, fields: matchingFields };
|
||||
}
|
||||
|
||||
return null;
|
||||
}).filter(Boolean) as { section: (typeof CONFIG_SECTIONS)[0]; fields: FieldDef[] | undefined }[];
|
||||
}
|
||||
|
||||
// Category filter
|
||||
const sections = activeCategory === 'all'
|
||||
? CONFIG_SECTIONS
|
||||
: CONFIG_SECTIONS.filter((s) => s.category === activeCategory);
|
||||
|
||||
return sections.map((s) => ({ section: s, fields: undefined }));
|
||||
}, [search, isSearching, activeCategory]);
|
||||
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
{/* Search */}
|
||||
<div className="relative">
|
||||
<Search className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" />
|
||||
<input
|
||||
type="text"
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
placeholder="Search config fields..."
|
||||
className="w-full bg-gray-800 border border-gray-700 rounded-lg pl-9 pr-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Category pills — hidden during search */}
|
||||
{!isSearching && (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{CATEGORY_ORDER.map(({ key, label }) => (
|
||||
<button
|
||||
key={key}
|
||||
onClick={() => setActiveCategory(key)}
|
||||
className={`px-3 py-1 rounded-lg text-sm font-medium transition-colors ${
|
||||
activeCategory === key
|
||||
? 'bg-blue-600 text-white'
|
||||
: 'bg-gray-900 text-gray-400 border border-gray-700 hover:bg-gray-800 hover:text-gray-200'
|
||||
}`}
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Sections */}
|
||||
{filteredSections.length === 0 ? (
|
||||
<div className="text-center py-12 text-gray-500 text-sm">
|
||||
No matching config fields found.
|
||||
</div>
|
||||
) : (
|
||||
filteredSections.map(({ section, fields }) => (
|
||||
<ConfigSection
|
||||
key={section.path || '_root'}
|
||||
section={fields ? { ...section, defaultCollapsed: false } : section}
|
||||
getFieldValue={getFieldValue}
|
||||
setFieldValue={setFieldValue}
|
||||
isFieldMasked={isFieldMasked}
|
||||
visibleFields={fields}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
interface Props {
|
||||
rawToml: string;
|
||||
onChange: (raw: string) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export default function ConfigRawEditor({ rawToml, onChange, disabled }: Props) {
|
||||
return (
|
||||
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
|
||||
<div className="flex items-center justify-between px-4 py-2 border-b border-gray-800 bg-gray-800/50">
|
||||
<span className="text-xs text-gray-400 font-medium uppercase tracking-wider">
|
||||
TOML Configuration
|
||||
</span>
|
||||
<span className="text-xs text-gray-500">
|
||||
{rawToml.split('\n').length} lines
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
value={rawToml}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
disabled={disabled}
|
||||
spellCheck={false}
|
||||
aria-label="Raw TOML configuration editor"
|
||||
className="w-full min-h-[500px] bg-gray-950 text-gray-200 font-mono text-sm p-4 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-inset disabled:opacity-50"
|
||||
style={{ tabSize: 4 }}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { ChevronRight, ChevronDown } from 'lucide-react';
|
||||
import type { SectionDef, FieldDef } from './types';
|
||||
import TextField from './fields/TextField';
|
||||
import NumberField from './fields/NumberField';
|
||||
import ToggleField from './fields/ToggleField';
|
||||
import SelectField from './fields/SelectField';
|
||||
import TagListField from './fields/TagListField';
|
||||
|
||||
interface Props {
|
||||
section: SectionDef;
|
||||
getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
|
||||
setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
|
||||
isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
|
||||
visibleFields?: FieldDef[];
|
||||
}
|
||||
|
||||
function renderField(
|
||||
field: FieldDef,
|
||||
value: unknown,
|
||||
onChange: (v: unknown) => void,
|
||||
isMasked: boolean,
|
||||
) {
|
||||
const props = { field, value, onChange, isMasked };
|
||||
switch (field.type) {
|
||||
case 'text':
|
||||
case 'password':
|
||||
return <TextField {...props} />;
|
||||
case 'number':
|
||||
return <NumberField {...props} />;
|
||||
case 'toggle':
|
||||
return <ToggleField {...props} />;
|
||||
case 'select':
|
||||
return <SelectField {...props} />;
|
||||
case 'tag-list':
|
||||
return <TagListField {...props} />;
|
||||
default:
|
||||
return <TextField {...props} />;
|
||||
}
|
||||
}
|
||||
|
||||
export default function ConfigSection({
|
||||
section,
|
||||
getFieldValue,
|
||||
setFieldValue,
|
||||
isFieldMasked,
|
||||
visibleFields,
|
||||
}: Props) {
|
||||
const [collapsed, setCollapsed] = useState(section.defaultCollapsed ?? false);
|
||||
const sectionPanelId = useMemo(
|
||||
() =>
|
||||
`config-section-${(section.path || 'root').replace(/[^a-zA-Z0-9_-]/g, '-')}`,
|
||||
[section.path],
|
||||
);
|
||||
const Icon = section.icon;
|
||||
const fields = visibleFields ?? section.fields;
|
||||
|
||||
useEffect(() => {
|
||||
setCollapsed(section.defaultCollapsed ?? false);
|
||||
}, [section.path, section.defaultCollapsed]);
|
||||
|
||||
return (
|
||||
<div className="bg-gray-900 rounded-xl border border-gray-800">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setCollapsed(!collapsed)}
|
||||
aria-expanded={!collapsed}
|
||||
aria-controls={sectionPanelId}
|
||||
className="w-full flex items-center gap-3 px-4 py-3 hover:bg-gray-800/30 transition-colors rounded-t-xl"
|
||||
>
|
||||
{collapsed ? (
|
||||
<ChevronRight className="h-4 w-4 text-gray-500 flex-shrink-0" />
|
||||
) : (
|
||||
<ChevronDown className="h-4 w-4 text-gray-500 flex-shrink-0" />
|
||||
)}
|
||||
<Icon className="h-4 w-4 text-blue-400 flex-shrink-0" />
|
||||
<span className="text-sm font-medium text-white">{section.title}</span>
|
||||
{section.description && (
|
||||
<span className="text-xs text-gray-500 hidden sm:inline">
|
||||
— {section.description}
|
||||
</span>
|
||||
)}
|
||||
<span className="ml-auto text-xs text-gray-600">
|
||||
{fields.length} {fields.length === 1 ? 'field' : 'fields'}
|
||||
</span>
|
||||
</button>
|
||||
|
||||
{!collapsed && (
|
||||
<div
|
||||
id={sectionPanelId}
|
||||
className="border-t border-gray-800 px-4 py-4 grid grid-cols-1 sm:grid-cols-2 gap-x-4 gap-y-4"
|
||||
>
|
||||
{fields.map((field) => {
|
||||
const value = getFieldValue(section.path, field.key);
|
||||
const masked = isFieldMasked(section.path, field.key);
|
||||
const spanFull = field.type === 'tag-list';
|
||||
|
||||
return (
|
||||
<div key={field.key} className={`flex flex-col${spanFull ? ' sm:col-span-2' : ''}`}>
|
||||
<label className="flex items-center gap-2 text-sm font-medium text-gray-300 mb-1.5">
|
||||
<span>{field.label}</span>
|
||||
{field.sensitive && (
|
||||
<span className="text-[10px] text-yellow-400 bg-yellow-900/30 border border-yellow-800/50 px-1.5 py-0.5 rounded">
|
||||
sensitive
|
||||
</span>
|
||||
)}
|
||||
{masked && (
|
||||
<span className="text-[10px] text-blue-400 bg-blue-900/30 border border-blue-800/50 px-1.5 py-0.5 rounded">
|
||||
masked
|
||||
</span>
|
||||
)}
|
||||
</label>
|
||||
{field.description && field.type !== 'text' && field.type !== 'password' && field.type !== 'number' && (
|
||||
<p className="text-xs text-gray-500 mb-1.5">{field.description}</p>
|
||||
)}
|
||||
<div className="mt-auto">
|
||||
{renderField(
|
||||
field,
|
||||
value,
|
||||
(v) => setFieldValue(section.path, field.key, v),
|
||||
masked,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,41 @@
|
||||
import type { FieldProps } from '../types';
|
||||
|
||||
export default function NumberField({ field, value, onChange }: FieldProps) {
|
||||
const numValue = value === undefined || value === null || value === '' ? '' : Number(value);
|
||||
|
||||
return (
|
||||
<input
|
||||
type="number"
|
||||
value={numValue}
|
||||
onChange={(e) => {
|
||||
const raw = e.target.value;
|
||||
if (raw === '') {
|
||||
onChange(undefined);
|
||||
return;
|
||||
}
|
||||
const n = Number(raw);
|
||||
if (!isNaN(n)) {
|
||||
onChange(n);
|
||||
}
|
||||
}}
|
||||
onBlur={(e) => {
|
||||
if (field.step !== undefined && field.step < 1) {
|
||||
return;
|
||||
}
|
||||
const raw = e.target.value;
|
||||
if (raw === '') {
|
||||
return;
|
||||
}
|
||||
const n = Number(raw);
|
||||
if (!isNaN(n)) {
|
||||
onChange(Math.floor(n));
|
||||
}
|
||||
}}
|
||||
min={field.min}
|
||||
max={field.max}
|
||||
step={field.step ?? 1}
|
||||
placeholder={field.description ?? ''}
|
||||
className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
import type { FieldProps } from '../types';
|
||||
|
||||
export default function SelectField({ field, value, onChange }: FieldProps) {
|
||||
const strValue = (value as string) ?? '';
|
||||
|
||||
return (
|
||||
<select
|
||||
value={strValue}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
>
|
||||
<option value="">Select...</option>
|
||||
{field.options?.map((opt) => (
|
||||
<option key={opt.value} value={opt.value}>
|
||||
{opt.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
import { useState } from 'react';
|
||||
import { X } from 'lucide-react';
|
||||
import type { FieldProps } from '../types';
|
||||
|
||||
export default function TagListField({ field, value, onChange }: FieldProps) {
|
||||
const [input, setInput] = useState('');
|
||||
const tags: string[] = Array.isArray(value) ? value : [];
|
||||
|
||||
const addTag = (tag: string) => {
|
||||
const trimmed = tag.trim();
|
||||
if (trimmed && !tags.includes(trimmed)) {
|
||||
onChange([...tags, trimmed]);
|
||||
}
|
||||
setInput('');
|
||||
};
|
||||
|
||||
const removeTag = (index: number) => {
|
||||
onChange(tags.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter' || e.key === ',') {
|
||||
e.preventDefault();
|
||||
addTag(input);
|
||||
} else if (e.key === 'Backspace' && input === '' && tags.length > 0) {
|
||||
removeTag(tags.length - 1);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex flex-wrap gap-1.5 mb-2">
|
||||
{tags.map((tag, i) => (
|
||||
<span
|
||||
key={tag}
|
||||
className="inline-flex items-center gap-1 bg-gray-700 text-gray-200 rounded-full px-2.5 py-0.5 text-xs"
|
||||
>
|
||||
{tag}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => removeTag(i)}
|
||||
className="text-gray-400 hover:text-white transition-colors"
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
onBlur={() => { if (input.trim()) addTag(input); }}
|
||||
placeholder={field.tagPlaceholder ?? 'Type and press Enter to add'}
|
||||
className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
import { useState } from 'react';
|
||||
import { Eye, EyeOff, Lock } from 'lucide-react';
|
||||
import type { FieldProps } from '../types';
|
||||
|
||||
export default function TextField({ field, value, onChange, isMasked }: FieldProps) {
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
const isPassword = field.type === 'password';
|
||||
const strValue = isMasked ? '' : ((value as string) ?? '');
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<input
|
||||
type={isPassword && !showPassword ? 'password' : 'text'}
|
||||
value={strValue}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
placeholder={isMasked ? 'Configured (masked)' : field.description ?? ''}
|
||||
className="w-full bg-gray-800 border border-gray-700 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500 pr-16"
|
||||
/>
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2 flex items-center gap-1">
|
||||
{isMasked && (
|
||||
<Lock className="h-3.5 w-3.5 text-yellow-500" />
|
||||
)}
|
||||
{isPassword && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowPassword(!showPassword)}
|
||||
className="p-1 text-gray-400 hover:text-gray-200 transition-colors"
|
||||
>
|
||||
{showPassword ? (
|
||||
<EyeOff className="h-3.5 w-3.5" />
|
||||
) : (
|
||||
<Eye className="h-3.5 w-3.5" />
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
import type { FieldProps } from '../types';
|
||||
|
||||
export default function ToggleField({ field, value, onChange }: FieldProps) {
|
||||
const isOn = Boolean(value);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
role="switch"
|
||||
aria-checked={isOn}
|
||||
aria-label={field.label}
|
||||
onClick={() => onChange(!isOn)}
|
||||
className={`relative inline-flex h-6 w-11 items-center rounded-full transition-colors ${
|
||||
isOn ? 'bg-blue-600' : 'bg-gray-700'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`inline-block h-4 w-4 transform rounded-full bg-white transition-transform ${
|
||||
isOn ? 'translate-x-6' : 'translate-x-1'
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<span className="text-sm text-gray-400">{isOn ? 'Enabled' : 'Disabled'}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
|
||||
export type FieldType =
|
||||
| 'text'
|
||||
| 'password'
|
||||
| 'number'
|
||||
| 'toggle'
|
||||
| 'select'
|
||||
| 'tag-list';
|
||||
|
||||
export interface FieldDef {
|
||||
key: string;
|
||||
label: string;
|
||||
type: FieldType;
|
||||
description?: string;
|
||||
sensitive?: boolean;
|
||||
defaultValue?: unknown;
|
||||
options?: { value: string; label: string }[];
|
||||
min?: number;
|
||||
max?: number;
|
||||
step?: number;
|
||||
tagPlaceholder?: string;
|
||||
}
|
||||
|
||||
export interface SectionDef {
|
||||
path: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
icon: LucideIcon;
|
||||
fields: FieldDef[];
|
||||
defaultCollapsed?: boolean;
|
||||
category?: string;
|
||||
}
|
||||
|
||||
export interface FieldProps {
|
||||
field: FieldDef;
|
||||
value: unknown;
|
||||
onChange: (value: unknown) => void;
|
||||
isMasked: boolean;
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
import { useState, useCallback, useRef, useEffect } from 'react';
|
||||
import { parse, stringify } from 'smol-toml';
|
||||
import { getConfig, putConfig } from '@/lib/api';
|
||||
|
||||
const MASKED = '***MASKED***';
|
||||
|
||||
type ParsedConfig = Record<string, unknown>;
|
||||
|
||||
function deepClone<T>(obj: T): T {
|
||||
return JSON.parse(JSON.stringify(obj));
|
||||
}
|
||||
|
||||
/** Recursively scan for MASKED strings and collect their dotted paths. */
|
||||
function scanMasked(obj: unknown, prefix: string, out: Set<string>) {
|
||||
if (obj === null || obj === undefined) return;
|
||||
if (typeof obj === 'string' && obj === MASKED) {
|
||||
out.add(prefix);
|
||||
return;
|
||||
}
|
||||
if (Array.isArray(obj)) {
|
||||
obj.forEach((item, i) => {
|
||||
scanMasked(item, `${prefix}.${i}`, out);
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (typeof obj === 'object') {
|
||||
for (const [k, v] of Object.entries(obj as Record<string, unknown>)) {
|
||||
scanMasked(v, prefix ? `${prefix}.${k}` : k, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Navigate into an object by dotted path segments, returning the value. */
|
||||
function getNestedValue(obj: unknown, segments: string[]): unknown {
|
||||
let current: unknown = obj;
|
||||
for (const seg of segments) {
|
||||
if (current === null || current === undefined || typeof current !== 'object') return undefined;
|
||||
current = (current as Record<string, unknown>)[seg];
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
/** Set a value in an object by dotted path segments, creating intermediates. */
|
||||
function setNestedValue(obj: Record<string, unknown>, segments: string[], value: unknown) {
|
||||
if (segments.length === 0) return;
|
||||
let current: Record<string, unknown> = obj;
|
||||
for (let i = 0; i < segments.length - 1; i++) {
|
||||
const seg: string = segments[i]!;
|
||||
if (current[seg] === undefined || current[seg] === null || typeof current[seg] !== 'object') {
|
||||
current[seg] = {};
|
||||
}
|
||||
current = current[seg] as Record<string, unknown>;
|
||||
}
|
||||
const lastSeg: string = segments[segments.length - 1]!;
|
||||
if (value === undefined || value === '') {
|
||||
delete current[lastSeg];
|
||||
} else {
|
||||
current[lastSeg] = value;
|
||||
}
|
||||
}
|
||||
|
||||
export type EditorMode = 'form' | 'raw';
|
||||
|
||||
export interface ConfigFormState {
|
||||
loading: boolean;
|
||||
saving: boolean;
|
||||
error: string | null;
|
||||
success: string | null;
|
||||
mode: EditorMode;
|
||||
rawToml: string;
|
||||
parsed: ParsedConfig;
|
||||
maskedPaths: Set<string>;
|
||||
dirtyPaths: Set<string>;
|
||||
setMode: (mode: EditorMode) => boolean;
|
||||
getFieldValue: (sectionPath: string, fieldKey: string) => unknown;
|
||||
setFieldValue: (sectionPath: string, fieldKey: string, value: unknown) => void;
|
||||
isFieldMasked: (sectionPath: string, fieldKey: string) => boolean;
|
||||
isFieldDirty: (sectionPath: string, fieldKey: string) => boolean;
|
||||
setRawToml: (raw: string) => void;
|
||||
save: () => Promise<void>;
|
||||
reload: () => Promise<void>;
|
||||
clearMessages: () => void;
|
||||
}
|
||||
|
||||
export function useConfigForm(): ConfigFormState {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [success, setSuccess] = useState<string | null>(null);
|
||||
const [mode, setModeState] = useState<EditorMode>('form');
|
||||
const [rawToml, setRawTomlState] = useState('');
|
||||
const [parsed, setParsed] = useState<ParsedConfig>({});
|
||||
const maskedPathsRef = useRef<Set<string>>(new Set());
|
||||
const dirtyPathsRef = useRef<Set<string>>(new Set());
|
||||
const successTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const [, forceRender] = useState(0);
|
||||
|
||||
const loadConfig = useCallback(async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const data = await getConfig();
|
||||
const raw = typeof data === 'string' ? data : JSON.stringify(data, null, 2);
|
||||
setRawTomlState(raw);
|
||||
|
||||
try {
|
||||
const obj = parse(raw) as ParsedConfig;
|
||||
setParsed(obj);
|
||||
const masked = new Set<string>();
|
||||
scanMasked(obj, '', masked);
|
||||
maskedPathsRef.current = masked;
|
||||
} catch {
|
||||
// If TOML parse fails, start in raw mode
|
||||
setParsed({});
|
||||
maskedPathsRef.current = new Set();
|
||||
setModeState('raw');
|
||||
}
|
||||
|
||||
dirtyPathsRef.current = new Set();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load configuration');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Load once on mount.
|
||||
const hasLoaded = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!hasLoaded.current) {
|
||||
hasLoaded.current = true;
|
||||
void loadConfig();
|
||||
}
|
||||
}, [loadConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (successTimeoutRef.current) {
|
||||
clearTimeout(successTimeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const fieldPath = (sectionPath: string, fieldKey: string) =>
|
||||
sectionPath ? `${sectionPath}.${fieldKey}` : fieldKey;
|
||||
|
||||
const fieldSegments = (sectionPath: string, fieldKey: string) => {
|
||||
const full = fieldPath(sectionPath, fieldKey);
|
||||
return full.split('.').filter(Boolean);
|
||||
};
|
||||
|
||||
const getFieldValue = useCallback(
|
||||
(sectionPath: string, fieldKey: string): unknown => {
|
||||
const segments = fieldSegments(sectionPath, fieldKey);
|
||||
return getNestedValue(parsed, segments);
|
||||
},
|
||||
[parsed],
|
||||
);
|
||||
|
||||
const setFieldValue = useCallback(
|
||||
(sectionPath: string, fieldKey: string, value: unknown) => {
|
||||
const fp = fieldPath(sectionPath, fieldKey);
|
||||
const segments = fieldSegments(sectionPath, fieldKey);
|
||||
|
||||
setParsed((prev) => {
|
||||
const next = deepClone(prev);
|
||||
setNestedValue(next, segments, value);
|
||||
return next;
|
||||
});
|
||||
|
||||
dirtyPathsRef.current.add(fp);
|
||||
forceRender((n) => n + 1);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const isFieldMasked = useCallback(
|
||||
(sectionPath: string, fieldKey: string): boolean => {
|
||||
const fp = fieldPath(sectionPath, fieldKey);
|
||||
return maskedPathsRef.current.has(fp) && !dirtyPathsRef.current.has(fp);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const isFieldDirty = useCallback(
|
||||
(sectionPath: string, fieldKey: string): boolean => {
|
||||
const fp = fieldPath(sectionPath, fieldKey);
|
||||
return dirtyPathsRef.current.has(fp);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const syncFormToRaw = useCallback((): string => {
|
||||
try {
|
||||
const toml = stringify(parsed);
|
||||
return toml;
|
||||
} catch {
|
||||
return rawToml;
|
||||
}
|
||||
}, [parsed, rawToml]);
|
||||
|
||||
const syncRawToForm = useCallback(
|
||||
(raw: string): boolean => {
|
||||
try {
|
||||
const obj = parse(raw) as ParsedConfig;
|
||||
setParsed(obj);
|
||||
// Re-scan masked paths from fresh parse, preserving dirty overrides
|
||||
const masked = new Set<string>();
|
||||
scanMasked(obj, '', masked);
|
||||
maskedPathsRef.current = masked;
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const setMode = useCallback(
|
||||
(newMode: EditorMode): boolean => {
|
||||
if (newMode === mode) return true;
|
||||
|
||||
if (newMode === 'raw') {
|
||||
// form → raw: serialize parsed to TOML
|
||||
const toml = syncFormToRaw();
|
||||
setRawTomlState(toml);
|
||||
setModeState('raw');
|
||||
return true;
|
||||
} else {
|
||||
// raw → form: parse TOML
|
||||
if (syncRawToForm(rawToml)) {
|
||||
setModeState('form');
|
||||
return true;
|
||||
} else {
|
||||
setError('Invalid TOML syntax. Fix errors before switching to Form view.');
|
||||
return false;
|
||||
}
|
||||
}
|
||||
},
|
||||
[mode, syncFormToRaw, syncRawToForm, rawToml],
|
||||
);
|
||||
|
||||
const setRawToml = useCallback((raw: string) => {
|
||||
setRawTomlState(raw);
|
||||
}, []);
|
||||
|
||||
const save = useCallback(async () => {
|
||||
setSaving(true);
|
||||
setError(null);
|
||||
setSuccess(null);
|
||||
if (successTimeoutRef.current) {
|
||||
clearTimeout(successTimeoutRef.current);
|
||||
}
|
||||
|
||||
try {
|
||||
let toml: string;
|
||||
if (mode === 'form') {
|
||||
toml = syncFormToRaw();
|
||||
} else {
|
||||
toml = rawToml;
|
||||
}
|
||||
await putConfig(toml);
|
||||
setSuccess('Configuration saved successfully.');
|
||||
|
||||
// Auto-dismiss success after 4 seconds
|
||||
successTimeoutRef.current = setTimeout(() => setSuccess(null), 4000);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to save configuration');
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
}, [mode, syncFormToRaw, rawToml]);
|
||||
|
||||
const reload = useCallback(async () => {
|
||||
await loadConfig();
|
||||
}, [loadConfig]);
|
||||
|
||||
const clearMessages = useCallback(() => {
|
||||
setError(null);
|
||||
setSuccess(null);
|
||||
if (successTimeoutRef.current) {
|
||||
clearTimeout(successTimeoutRef.current);
|
||||
successTimeoutRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
return {
|
||||
loading,
|
||||
saving,
|
||||
error,
|
||||
success,
|
||||
mode,
|
||||
rawToml,
|
||||
parsed,
|
||||
maskedPaths: maskedPathsRef.current,
|
||||
dirtyPaths: dirtyPathsRef.current,
|
||||
setMode,
|
||||
getFieldValue,
|
||||
setFieldValue,
|
||||
isFieldMasked,
|
||||
isFieldDirty,
|
||||
setRawToml,
|
||||
save,
|
||||
reload,
|
||||
clearMessages,
|
||||
};
|
||||
}
|
||||
+17
-1
@@ -19,6 +19,7 @@ export interface WebSocketClientOptions {
|
||||
|
||||
const DEFAULT_RECONNECT_DELAY = 1000;
|
||||
const MAX_RECONNECT_DELAY = 30000;
|
||||
const WS_SESSION_STORAGE_KEY = 'zeroclaw.ws.session_id';
|
||||
|
||||
export class WebSocketClient {
|
||||
private ws: WebSocket | null = null;
|
||||
@@ -35,6 +36,7 @@ export class WebSocketClient {
|
||||
private readonly reconnectDelay: number;
|
||||
private readonly maxReconnectDelay: number;
|
||||
private readonly autoReconnect: boolean;
|
||||
private readonly sessionId: string;
|
||||
|
||||
constructor(options: WebSocketClientOptions = {}) {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
@@ -44,6 +46,7 @@ export class WebSocketClient {
|
||||
this.maxReconnectDelay = options.maxReconnectDelay ?? MAX_RECONNECT_DELAY;
|
||||
this.autoReconnect = options.autoReconnect ?? true;
|
||||
this.currentDelay = this.reconnectDelay;
|
||||
this.sessionId = this.resolveSessionId();
|
||||
}
|
||||
|
||||
/** Open the WebSocket connection. */
|
||||
@@ -52,7 +55,7 @@ export class WebSocketClient {
|
||||
this.clearReconnectTimer();
|
||||
|
||||
const token = getToken();
|
||||
const url = `${this.baseUrl}/ws/chat`;
|
||||
const url = `${this.baseUrl}/ws/chat?session_id=${encodeURIComponent(this.sessionId)}`;
|
||||
const protocols = ['zeroclaw.v1'];
|
||||
if (token) {
|
||||
protocols.push(`bearer.${token}`);
|
||||
@@ -126,4 +129,17 @@ export class WebSocketClient {
|
||||
this.reconnectTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private resolveSessionId(): string {
|
||||
const existing = window.localStorage.getItem(WS_SESSION_STORAGE_KEY);
|
||||
if (existing && /^[A-Za-z0-9_-]{1,128}$/.test(existing)) {
|
||||
return existing;
|
||||
}
|
||||
|
||||
const generated =
|
||||
globalThis.crypto?.randomUUID?.().replace(/-/g, '_') ??
|
||||
`sess_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 10)}`;
|
||||
window.localStorage.setItem(WS_SESSION_STORAGE_KEY, generated);
|
||||
return generated;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,22 @@ export default function AgentChat() {
|
||||
|
||||
ws.onMessage = (msg: WsMessage) => {
|
||||
switch (msg.type) {
|
||||
case 'history': {
|
||||
const restored = (msg.messages ?? [])
|
||||
.filter((entry) => entry.content?.trim())
|
||||
.map((entry) => ({
|
||||
id: makeMessageId(),
|
||||
role: entry.role === 'user' ? 'user' : 'agent',
|
||||
content: entry.content.trim(),
|
||||
timestamp: new Date(),
|
||||
}));
|
||||
|
||||
setMessages(restored);
|
||||
setTyping(false);
|
||||
pendingContentRef.current = '';
|
||||
break;
|
||||
}
|
||||
|
||||
case 'chunk':
|
||||
setTyping(true);
|
||||
pendingContentRef.current += msg.content ?? '';
|
||||
|
||||
+93
-65
@@ -1,50 +1,61 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import {
|
||||
Settings,
|
||||
Save,
|
||||
CheckCircle,
|
||||
AlertTriangle,
|
||||
ShieldAlert,
|
||||
FileText,
|
||||
SlidersHorizontal,
|
||||
} from 'lucide-react';
|
||||
import { getConfig, putConfig } from '@/lib/api';
|
||||
import { useConfigForm, type EditorMode } from '@/components/config/useConfigForm';
|
||||
import ConfigFormEditor from '@/components/config/ConfigFormEditor';
|
||||
import ConfigRawEditor from '@/components/config/ConfigRawEditor';
|
||||
|
||||
function ModeTab({
|
||||
mode,
|
||||
active,
|
||||
icon: Icon,
|
||||
label,
|
||||
onClick,
|
||||
}: {
|
||||
mode: EditorMode;
|
||||
active: boolean;
|
||||
icon: React.ComponentType<{ className?: string }>;
|
||||
label: string;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm font-medium transition-colors ${
|
||||
active
|
||||
? 'bg-blue-600 text-white'
|
||||
: 'text-gray-400 hover:text-gray-200 hover:bg-gray-800'
|
||||
}`}
|
||||
aria-pressed={active}
|
||||
data-mode={mode}
|
||||
>
|
||||
<Icon className="h-3.5 w-3.5" />
|
||||
{label}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Config() {
|
||||
const [config, setConfig] = useState('');
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [success, setSuccess] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
getConfig()
|
||||
.then((data) => {
|
||||
// The API may return either a raw string or a JSON string
|
||||
setConfig(typeof data === 'string' ? data : JSON.stringify(data, null, 2));
|
||||
})
|
||||
.catch((err) => setError(err.message))
|
||||
.finally(() => setLoading(false));
|
||||
}, []);
|
||||
|
||||
const handleSave = async () => {
|
||||
setSaving(true);
|
||||
setError(null);
|
||||
setSuccess(null);
|
||||
try {
|
||||
await putConfig(config);
|
||||
setSuccess('Configuration saved successfully.');
|
||||
} catch (err: unknown) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to save configuration');
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Auto-dismiss success after 4 seconds
|
||||
useEffect(() => {
|
||||
if (!success) return;
|
||||
const timer = setTimeout(() => setSuccess(null), 4000);
|
||||
return () => clearTimeout(timer);
|
||||
}, [success]);
|
||||
const {
|
||||
loading,
|
||||
saving,
|
||||
error,
|
||||
success,
|
||||
mode,
|
||||
rawToml,
|
||||
setMode,
|
||||
getFieldValue,
|
||||
setFieldValue,
|
||||
isFieldMasked,
|
||||
setRawToml,
|
||||
save,
|
||||
} = useConfigForm();
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
@@ -62,14 +73,34 @@ export default function Config() {
|
||||
<Settings className="h-5 w-5 text-blue-400" />
|
||||
<h2 className="text-base font-semibold text-white">Configuration</h2>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={saving}
|
||||
className="flex items-center gap-2 bg-blue-600 hover:bg-blue-700 text-white text-sm font-medium px-4 py-2 rounded-lg transition-colors disabled:opacity-50"
|
||||
>
|
||||
<Save className="h-4 w-4" />
|
||||
{saving ? 'Saving...' : 'Save'}
|
||||
</button>
|
||||
<div className="flex items-center gap-3">
|
||||
{/* Mode toggle */}
|
||||
<div className="flex items-center gap-1 bg-gray-900 border border-gray-800 rounded-lg p-0.5">
|
||||
<ModeTab
|
||||
mode="form"
|
||||
active={mode === 'form'}
|
||||
icon={SlidersHorizontal}
|
||||
label="Form"
|
||||
onClick={() => setMode('form')}
|
||||
/>
|
||||
<ModeTab
|
||||
mode="raw"
|
||||
active={mode === 'raw'}
|
||||
icon={FileText}
|
||||
label="Raw"
|
||||
onClick={() => setMode('raw')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={save}
|
||||
disabled={saving}
|
||||
className="flex items-center gap-2 bg-blue-600 hover:bg-blue-700 text-white text-sm font-medium px-4 py-2 rounded-lg transition-colors disabled:opacity-50"
|
||||
>
|
||||
<Save className="h-4 w-4" />
|
||||
{saving ? 'Saving...' : 'Save'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Sensitive fields note */}
|
||||
@@ -80,8 +111,9 @@ export default function Config() {
|
||||
Sensitive fields are masked
|
||||
</p>
|
||||
<p className="text-sm text-yellow-400/70 mt-0.5">
|
||||
API keys, tokens, and passwords are hidden for security. To update a
|
||||
masked field, replace the entire masked value with your new value.
|
||||
{mode === 'form'
|
||||
? 'Masked fields show "Configured (masked)" as a placeholder. Leave them untouched to preserve existing values, or enter a new value to update.'
|
||||
: 'API keys, tokens, and passwords are hidden for security. To update a masked field, replace the entire masked value with your new value.'}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -102,24 +134,20 @@ export default function Config() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Config Editor */}
|
||||
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
|
||||
<div className="flex items-center justify-between px-4 py-2 border-b border-gray-800 bg-gray-800/50">
|
||||
<span className="text-xs text-gray-400 font-medium uppercase tracking-wider">
|
||||
TOML Configuration
|
||||
</span>
|
||||
<span className="text-xs text-gray-500">
|
||||
{config.split('\n').length} lines
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
value={config}
|
||||
onChange={(e) => setConfig(e.target.value)}
|
||||
spellCheck={false}
|
||||
className="w-full min-h-[500px] bg-gray-950 text-gray-200 font-mono text-sm p-4 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-inset"
|
||||
style={{ tabSize: 4 }}
|
||||
{/* Editor */}
|
||||
{mode === 'form' ? (
|
||||
<ConfigFormEditor
|
||||
getFieldValue={getFieldValue}
|
||||
setFieldValue={setFieldValue}
|
||||
isFieldMasked={isFieldMasked}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<ConfigRawEditor
|
||||
rawToml={rawToml}
|
||||
onChange={setRawToml}
|
||||
disabled={saving}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -131,11 +131,16 @@ export interface SSEEvent {
|
||||
}
|
||||
|
||||
export interface WsMessage {
|
||||
type: 'message' | 'chunk' | 'tool_call' | 'tool_result' | 'done' | 'error';
|
||||
type: 'message' | 'chunk' | 'tool_call' | 'tool_result' | 'done' | 'error' | 'history';
|
||||
content?: string;
|
||||
full_response?: string;
|
||||
name?: string;
|
||||
args?: any;
|
||||
output?: string;
|
||||
message?: string;
|
||||
session_id?: string;
|
||||
messages?: Array<{
|
||||
role: 'user' | 'assistant';
|
||||
content: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user