Compare commits
26 Commits
dependabot
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
feaca20582 | ||
|
|
40af505e90 | ||
|
|
f812dbcb85 | ||
|
|
59225d97b3 | ||
|
|
483f773e1d | ||
|
|
b4bbe820a2 | ||
|
|
1702bb2747 | ||
|
|
b6f661c3c5 | ||
|
|
ac543cff20 | ||
|
|
c88affa020 | ||
|
|
67998ad702 | ||
|
|
c104b23ddb | ||
|
|
698adca707 | ||
|
|
50a877b4c1 | ||
|
|
fa7b615508 | ||
|
|
ea9eccfe8b | ||
|
|
eb036b4d95 | ||
|
|
1c07d5b411 | ||
|
|
0fe3834349 | ||
|
|
33f9d66b54 | ||
|
|
368f39829f | ||
|
|
9376c26018 | ||
|
|
08e131d7c6 | ||
|
|
36db977b35 | ||
|
|
92b0ebb61a | ||
|
|
9c312180a2 |
301
.github/labeler.yml
vendored
301
.github/labeler.yml
vendored
@ -36,6 +36,145 @@
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/**"
|
||||
|
||||
"channel:bluesky":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/bluesky.rs"
|
||||
|
||||
"channel:clawdtalk":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/clawdtalk.rs"
|
||||
|
||||
"channel:cli":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/cli.rs"
|
||||
|
||||
"channel:dingtalk":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/dingtalk.rs"
|
||||
|
||||
"channel:discord":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/discord.rs"
|
||||
- "src/channels/discord_history.rs"
|
||||
|
||||
"channel:email":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/email_channel.rs"
|
||||
- "src/channels/gmail_push.rs"
|
||||
|
||||
"channel:imessage":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/imessage.rs"
|
||||
|
||||
"channel:irc":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/irc.rs"
|
||||
|
||||
"channel:lark":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/lark.rs"
|
||||
|
||||
"channel:linq":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/linq.rs"
|
||||
|
||||
"channel:matrix":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/matrix.rs"
|
||||
|
||||
"channel:mattermost":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/mattermost.rs"
|
||||
|
||||
"channel:mochat":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/mochat.rs"
|
||||
|
||||
"channel:mqtt":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/mqtt.rs"
|
||||
|
||||
"channel:nextcloud-talk":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/nextcloud_talk.rs"
|
||||
|
||||
"channel:nostr":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/nostr.rs"
|
||||
|
||||
"channel:notion":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/notion.rs"
|
||||
|
||||
"channel:qq":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/qq.rs"
|
||||
|
||||
"channel:reddit":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/reddit.rs"
|
||||
|
||||
"channel:signal":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/signal.rs"
|
||||
|
||||
"channel:slack":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/slack.rs"
|
||||
|
||||
"channel:telegram":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/telegram.rs"
|
||||
|
||||
"channel:twitter":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/twitter.rs"
|
||||
|
||||
"channel:wati":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/wati.rs"
|
||||
|
||||
"channel:webhook":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/webhook.rs"
|
||||
|
||||
"channel:wecom":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/wecom.rs"
|
||||
|
||||
"channel:whatsapp":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/channels/whatsapp.rs"
|
||||
- "src/channels/whatsapp_storage.rs"
|
||||
- "src/channels/whatsapp_web.rs"
|
||||
|
||||
"gateway":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@ -101,6 +240,73 @@
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/**"
|
||||
|
||||
"provider:anthropic":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/anthropic.rs"
|
||||
|
||||
"provider:azure-openai":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/azure_openai.rs"
|
||||
|
||||
"provider:bedrock":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/bedrock.rs"
|
||||
|
||||
"provider:claude-code":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/claude_code.rs"
|
||||
|
||||
"provider:compatible":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/compatible.rs"
|
||||
|
||||
"provider:copilot":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/copilot.rs"
|
||||
|
||||
"provider:gemini":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/gemini.rs"
|
||||
- "src/providers/gemini_cli.rs"
|
||||
|
||||
"provider:glm":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/glm.rs"
|
||||
|
||||
"provider:kilocli":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/kilocli.rs"
|
||||
|
||||
"provider:ollama":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/ollama.rs"
|
||||
|
||||
"provider:openai":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/openai.rs"
|
||||
- "src/providers/openai_codex.rs"
|
||||
|
||||
"provider:openrouter":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/openrouter.rs"
|
||||
|
||||
"provider:telnyx":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/providers/telnyx.rs"
|
||||
|
||||
"service":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@ -121,6 +327,101 @@
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/**"
|
||||
|
||||
"tool:browser":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/browser.rs"
|
||||
- "src/tools/browser_delegate.rs"
|
||||
- "src/tools/browser_open.rs"
|
||||
- "src/tools/text_browser.rs"
|
||||
- "src/tools/screenshot.rs"
|
||||
|
||||
"tool:composio":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/composio.rs"
|
||||
|
||||
"tool:cron":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/cron_add.rs"
|
||||
- "src/tools/cron_list.rs"
|
||||
- "src/tools/cron_remove.rs"
|
||||
- "src/tools/cron_run.rs"
|
||||
- "src/tools/cron_runs.rs"
|
||||
- "src/tools/cron_update.rs"
|
||||
|
||||
"tool:file":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/file_edit.rs"
|
||||
- "src/tools/file_read.rs"
|
||||
- "src/tools/file_write.rs"
|
||||
- "src/tools/glob_search.rs"
|
||||
- "src/tools/content_search.rs"
|
||||
|
||||
"tool:google-workspace":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/google_workspace.rs"
|
||||
|
||||
"tool:mcp":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/mcp_client.rs"
|
||||
- "src/tools/mcp_deferred.rs"
|
||||
- "src/tools/mcp_protocol.rs"
|
||||
- "src/tools/mcp_tool.rs"
|
||||
- "src/tools/mcp_transport.rs"
|
||||
|
||||
"tool:memory":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/memory_forget.rs"
|
||||
- "src/tools/memory_recall.rs"
|
||||
- "src/tools/memory_store.rs"
|
||||
|
||||
"tool:microsoft365":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/microsoft365/**"
|
||||
|
||||
"tool:shell":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/shell.rs"
|
||||
- "src/tools/node_tool.rs"
|
||||
- "src/tools/cli_discovery.rs"
|
||||
|
||||
"tool:sop":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/sop_advance.rs"
|
||||
- "src/tools/sop_approve.rs"
|
||||
- "src/tools/sop_execute.rs"
|
||||
- "src/tools/sop_list.rs"
|
||||
- "src/tools/sop_status.rs"
|
||||
|
||||
"tool:web":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/web_fetch.rs"
|
||||
- "src/tools/web_search_tool.rs"
|
||||
- "src/tools/web_search_provider_routing.rs"
|
||||
- "src/tools/http_request.rs"
|
||||
|
||||
"tool:security":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/security_ops.rs"
|
||||
- "src/tools/verifiable_intent.rs"
|
||||
|
||||
"tool:cloud":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "src/tools/cloud_ops.rs"
|
||||
- "src/tools/cloud_patterns.rs"
|
||||
|
||||
"tunnel":
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
2
.github/workflows/ci-run.yml
vendored
2
.github/workflows/ci-run.yml
vendored
@ -7,7 +7,7 @@ on:
|
||||
branches: [master]
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.event.pull_request.number || github.sha }}
|
||||
group: ci-${{ github.event.pull_request.number || 'push-master' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
|
||||
19
.github/workflows/pr-path-labeler.yml
vendored
Normal file
19
.github/workflows/pr-path-labeler.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: PR Path Labeler
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
label:
|
||||
name: Apply path labels
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- uses: actions/labeler@8558fd74291d67161a8a78ce36a881fa63b766a9 # v5
|
||||
with:
|
||||
sync-labels: true
|
||||
67
.github/workflows/release-beta-on-push.yml
vendored
67
.github/workflows/release-beta-on-push.yml
vendored
@ -266,9 +266,65 @@ jobs:
|
||||
path: zeroclaw-${{ matrix.target }}.${{ matrix.ext }}
|
||||
retention-days: 7
|
||||
|
||||
build-desktop:
|
||||
name: Build Desktop App (macOS Universal)
|
||||
needs: [version]
|
||||
if: needs.version.outputs.skip != 'true'
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 40
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable
|
||||
with:
|
||||
toolchain: 1.92.0
|
||||
targets: aarch64-apple-darwin,x86_64-apple-darwin
|
||||
|
||||
- uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2
|
||||
with:
|
||||
prefix-key: macos-tauri
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install Tauri CLI
|
||||
run: cargo install tauri-cli --locked
|
||||
|
||||
- name: Sync Tauri version with Cargo.toml
|
||||
shell: bash
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
cd apps/tauri
|
||||
if command -v jq >/dev/null 2>&1; then
|
||||
jq --arg v "$VERSION" '.version = $v' tauri.conf.json > tmp.json && mv tmp.json tauri.conf.json
|
||||
else
|
||||
sed -i '' "s/\"version\": \"[^\"]*\"/\"version\": \"$VERSION\"/" tauri.conf.json
|
||||
fi
|
||||
echo "Tauri version set to: $VERSION"
|
||||
|
||||
- name: Build Tauri app (universal binary)
|
||||
working-directory: apps/tauri
|
||||
run: cargo tauri build --target universal-apple-darwin
|
||||
|
||||
- name: Prepare desktop release assets
|
||||
run: |
|
||||
mkdir -p desktop-assets
|
||||
find target -name '*.dmg' -exec cp {} desktop-assets/ZeroClaw.dmg \; 2>/dev/null || true
|
||||
find target -name '*.app.tar.gz' -exec cp {} desktop-assets/ZeroClaw-macos.app.tar.gz \; 2>/dev/null || true
|
||||
find target -name '*.app.tar.gz.sig' -exec cp {} desktop-assets/ZeroClaw-macos.app.tar.gz.sig \; 2>/dev/null || true
|
||||
echo "--- Desktop assets ---"
|
||||
ls -lh desktop-assets/
|
||||
|
||||
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
with:
|
||||
name: desktop-macos
|
||||
path: desktop-assets/*
|
||||
retention-days: 7
|
||||
|
||||
publish:
|
||||
name: Publish Beta Release
|
||||
needs: [version, release-notes, build]
|
||||
needs: [version, release-notes, build, build-desktop]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
@ -278,16 +334,21 @@ jobs:
|
||||
pattern: zeroclaw-*
|
||||
path: artifacts
|
||||
|
||||
- uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: desktop-macos
|
||||
path: artifacts/desktop-macos
|
||||
|
||||
- name: Generate checksums
|
||||
run: |
|
||||
cd artifacts
|
||||
find . -type f \( -name '*.tar.gz' -o -name '*.zip' \) -exec sha256sum {} + | sed 's| \./[^/]*/| |' > SHA256SUMS
|
||||
find . -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name '*.dmg' \) -exec sha256sum {} + | sed 's| \./[^/]*/| |' > SHA256SUMS
|
||||
cat SHA256SUMS
|
||||
|
||||
- name: Collect release assets
|
||||
run: |
|
||||
mkdir -p release-assets
|
||||
find artifacts -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name 'SHA256SUMS' \) -exec cp {} release-assets/ \;
|
||||
find artifacts -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name '*.dmg' -o -name 'SHA256SUMS' \) -exec cp {} release-assets/ \;
|
||||
cp install.sh release-assets/
|
||||
echo "--- Assets ---"
|
||||
ls -lh release-assets/
|
||||
|
||||
66
.github/workflows/release-stable-manual.yml
vendored
66
.github/workflows/release-stable-manual.yml
vendored
@ -273,9 +273,64 @@ jobs:
|
||||
path: zeroclaw-${{ matrix.target }}.${{ matrix.ext }}
|
||||
retention-days: 14
|
||||
|
||||
build-desktop:
|
||||
name: Build Desktop App (macOS Universal)
|
||||
needs: [validate]
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 40
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable
|
||||
with:
|
||||
toolchain: 1.92.0
|
||||
targets: aarch64-apple-darwin,x86_64-apple-darwin
|
||||
|
||||
- uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2
|
||||
with:
|
||||
prefix-key: macos-tauri
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install Tauri CLI
|
||||
run: cargo install tauri-cli --locked
|
||||
|
||||
- name: Sync Tauri version with Cargo.toml
|
||||
shell: bash
|
||||
run: |
|
||||
VERSION=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
cd apps/tauri
|
||||
if command -v jq >/dev/null 2>&1; then
|
||||
jq --arg v "$VERSION" '.version = $v' tauri.conf.json > tmp.json && mv tmp.json tauri.conf.json
|
||||
else
|
||||
sed -i '' "s/\"version\": \"[^\"]*\"/\"version\": \"$VERSION\"/" tauri.conf.json
|
||||
fi
|
||||
echo "Tauri version set to: $VERSION"
|
||||
|
||||
- name: Build Tauri app (universal binary)
|
||||
working-directory: apps/tauri
|
||||
run: cargo tauri build --target universal-apple-darwin
|
||||
|
||||
- name: Prepare desktop release assets
|
||||
run: |
|
||||
mkdir -p desktop-assets
|
||||
find target -name '*.dmg' -exec cp {} desktop-assets/ZeroClaw.dmg \; 2>/dev/null || true
|
||||
find target -name '*.app.tar.gz' -exec cp {} desktop-assets/ZeroClaw-macos.app.tar.gz \; 2>/dev/null || true
|
||||
find target -name '*.app.tar.gz.sig' -exec cp {} desktop-assets/ZeroClaw-macos.app.tar.gz.sig \; 2>/dev/null || true
|
||||
echo "--- Desktop assets ---"
|
||||
ls -lh desktop-assets/
|
||||
|
||||
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
with:
|
||||
name: desktop-macos
|
||||
path: desktop-assets/*
|
||||
retention-days: 14
|
||||
|
||||
publish:
|
||||
name: Publish Stable Release
|
||||
needs: [validate, release-notes, build]
|
||||
needs: [validate, release-notes, build, build-desktop]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
@ -285,16 +340,21 @@ jobs:
|
||||
pattern: zeroclaw-*
|
||||
path: artifacts
|
||||
|
||||
- uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4
|
||||
with:
|
||||
name: desktop-macos
|
||||
path: artifacts/desktop-macos
|
||||
|
||||
- name: Generate checksums
|
||||
run: |
|
||||
cd artifacts
|
||||
find . -type f \( -name '*.tar.gz' -o -name '*.zip' \) -exec sha256sum {} + | sed 's| \./[^/]*/| |' > SHA256SUMS
|
||||
find . -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name '*.dmg' \) -exec sha256sum {} + | sed 's| \./[^/]*/| |' > SHA256SUMS
|
||||
cat SHA256SUMS
|
||||
|
||||
- name: Collect release assets
|
||||
run: |
|
||||
mkdir -p release-assets
|
||||
find artifacts -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name 'SHA256SUMS' \) -exec cp {} release-assets/ \;
|
||||
find artifacts -type f \( -name '*.tar.gz' -o -name '*.zip' -o -name '*.dmg' -o -name 'SHA256SUMS' \) -exec cp {} release-assets/ \;
|
||||
cp install.sh release-assets/
|
||||
echo "--- Assets ---"
|
||||
ls -lh release-assets/
|
||||
|
||||
92
AGENTS.md
Normal file
92
AGENTS.md
Normal file
@ -0,0 +1,92 @@
|
||||
# AGENTS.md — ZeroClaw
|
||||
|
||||
Cross-tool agent instructions for any AI coding assistant working on this repository.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
cargo fmt --all -- --check
|
||||
cargo clippy --all-targets -- -D warnings
|
||||
cargo test
|
||||
```
|
||||
|
||||
Full pre-PR validation (recommended):
|
||||
|
||||
```bash
|
||||
./dev/ci.sh all
|
||||
```
|
||||
|
||||
Docs-only changes: run markdown lint and link-integrity checks. If touching bootstrap scripts: `bash -n install.sh`.
|
||||
|
||||
## Project Snapshot
|
||||
|
||||
ZeroClaw is a Rust-first autonomous agent runtime optimized for performance, efficiency, stability, extensibility, sustainability, and security.
|
||||
|
||||
Core architecture is trait-driven and modular. Extend by implementing traits and registering in factory modules.
|
||||
|
||||
Key extension points:
|
||||
|
||||
- `src/providers/traits.rs` (`Provider`)
|
||||
- `src/channels/traits.rs` (`Channel`)
|
||||
- `src/tools/traits.rs` (`Tool`)
|
||||
- `src/memory/traits.rs` (`Memory`)
|
||||
- `src/observability/traits.rs` (`Observer`)
|
||||
- `src/runtime/traits.rs` (`RuntimeAdapter`)
|
||||
- `src/peripherals/traits.rs` (`Peripheral`) — hardware boards (STM32, RPi GPIO)
|
||||
|
||||
## Repository Map
|
||||
|
||||
- `src/main.rs` — CLI entrypoint and command routing
|
||||
- `src/lib.rs` — module exports and shared command enums
|
||||
- `src/config/` — schema + config loading/merging
|
||||
- `src/agent/` — orchestration loop
|
||||
- `src/gateway/` — webhook/gateway server
|
||||
- `src/security/` — policy, pairing, secret store
|
||||
- `src/memory/` — markdown/sqlite memory backends + embeddings/vector merge
|
||||
- `src/providers/` — model providers and resilient wrapper
|
||||
- `src/channels/` — Telegram/Discord/Slack/etc channels
|
||||
- `src/tools/` — tool execution surface (shell, file, memory, browser)
|
||||
- `src/peripherals/` — hardware peripherals (STM32, RPi GPIO)
|
||||
- `src/runtime/` — runtime adapters (currently native)
|
||||
- `docs/` — topic-based documentation (setup-guides, reference, ops, security, hardware, contributing, maintainers)
|
||||
- `.github/` — CI, templates, automation workflows
|
||||
|
||||
## Risk Tiers
|
||||
|
||||
- **Low risk**: docs/chore/tests-only changes
|
||||
- **Medium risk**: most `src/**` behavior changes without boundary/security impact
|
||||
- **High risk**: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**`, access-control boundaries
|
||||
|
||||
When uncertain, classify as higher risk.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Read before write** — inspect existing module, factory wiring, and adjacent tests before editing.
|
||||
2. **One concern per PR** — avoid mixed feature+refactor+infra patches.
|
||||
3. **Implement minimal patch** — no speculative abstractions, no config keys without a concrete use case.
|
||||
4. **Validate by risk tier** — docs-only: lightweight checks. Code changes: full relevant checks.
|
||||
5. **Document impact** — update PR notes for behavior, risk, side effects, and rollback.
|
||||
6. **Queue hygiene** — stacked PR: declare `Depends on #...`. Replacing old PR: declare `Supersedes #...`.
|
||||
|
||||
Branch/commit/PR rules:
|
||||
- Work from a non-`master` branch. Open a PR to `master`; do not push directly.
|
||||
- Use conventional commit titles. Prefer small PRs (`size: XS/S/M`).
|
||||
- Follow `.github/pull_request_template.md` fully.
|
||||
- Never commit secrets, personal data, or real identity information (see `@docs/contributing/pr-discipline.md`).
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not add heavy dependencies for minor convenience.
|
||||
- Do not silently weaken security policy or access constraints.
|
||||
- Do not add speculative config/feature flags "just in case".
|
||||
- Do not mix massive formatting-only changes with functional changes.
|
||||
- Do not modify unrelated modules "while here".
|
||||
- Do not bypass failing checks without explicit explanation.
|
||||
- Do not hide behavior-changing side effects in refactor commits.
|
||||
- Do not include personal identity or sensitive information in test data, examples, docs, or commits.
|
||||
|
||||
## Linked References
|
||||
|
||||
- `@docs/contributing/change-playbooks.md` — adding providers, channels, tools, peripherals; security/gateway changes; architecture boundaries
|
||||
- `@docs/contributing/pr-discipline.md` — privacy rules, superseded-PR attribution/templates, handoff template
|
||||
- `@docs/contributing/docs-contract.md` — docs system contract, i18n rules, locale parity
|
||||
92
CLAUDE.md
92
CLAUDE.md
@ -1,90 +1,16 @@
|
||||
# CLAUDE.md — ZeroClaw
|
||||
# CLAUDE.md — ZeroClaw (Claude Code)
|
||||
|
||||
## Commands
|
||||
> **Shared instructions live in [`AGENTS.md`](./AGENTS.md).**
|
||||
> This file contains only Claude Code-specific directives.
|
||||
|
||||
```bash
|
||||
cargo fmt --all -- --check
|
||||
cargo clippy --all-targets -- -D warnings
|
||||
cargo test
|
||||
```
|
||||
## Claude Code Settings
|
||||
|
||||
Full pre-PR validation (recommended):
|
||||
Claude Code should read and follow all instructions in `AGENTS.md` at the repository root for project conventions, commands, risk tiers, workflow rules, and anti-patterns.
|
||||
|
||||
```bash
|
||||
./dev/ci.sh all
|
||||
```
|
||||
## Hooks
|
||||
|
||||
Docs-only changes: run markdown lint and link-integrity checks. If touching bootstrap scripts: `bash -n install.sh`.
|
||||
_No custom hooks defined yet._
|
||||
|
||||
## Project Snapshot
|
||||
## Slash Commands
|
||||
|
||||
ZeroClaw is a Rust-first autonomous agent runtime optimized for performance, efficiency, stability, extensibility, sustainability, and security.
|
||||
|
||||
Core architecture is trait-driven and modular. Extend by implementing traits and registering in factory modules.
|
||||
|
||||
Key extension points:
|
||||
|
||||
- `src/providers/traits.rs` (`Provider`)
|
||||
- `src/channels/traits.rs` (`Channel`)
|
||||
- `src/tools/traits.rs` (`Tool`)
|
||||
- `src/memory/traits.rs` (`Memory`)
|
||||
- `src/observability/traits.rs` (`Observer`)
|
||||
- `src/runtime/traits.rs` (`RuntimeAdapter`)
|
||||
- `src/peripherals/traits.rs` (`Peripheral`) — hardware boards (STM32, RPi GPIO)
|
||||
|
||||
## Repository Map
|
||||
|
||||
- `src/main.rs` — CLI entrypoint and command routing
|
||||
- `src/lib.rs` — module exports and shared command enums
|
||||
- `src/config/` — schema + config loading/merging
|
||||
- `src/agent/` — orchestration loop
|
||||
- `src/gateway/` — webhook/gateway server
|
||||
- `src/security/` — policy, pairing, secret store
|
||||
- `src/memory/` — markdown/sqlite memory backends + embeddings/vector merge
|
||||
- `src/providers/` — model providers and resilient wrapper
|
||||
- `src/channels/` — Telegram/Discord/Slack/etc channels
|
||||
- `src/tools/` — tool execution surface (shell, file, memory, browser)
|
||||
- `src/peripherals/` — hardware peripherals (STM32, RPi GPIO)
|
||||
- `src/runtime/` — runtime adapters (currently native)
|
||||
- `docs/` — topic-based documentation (setup-guides, reference, ops, security, hardware, contributing, maintainers)
|
||||
- `.github/` — CI, templates, automation workflows
|
||||
|
||||
## Risk Tiers
|
||||
|
||||
- **Low risk**: docs/chore/tests-only changes
|
||||
- **Medium risk**: most `src/**` behavior changes without boundary/security impact
|
||||
- **High risk**: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**`, access-control boundaries
|
||||
|
||||
When uncertain, classify as higher risk.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Read before write** — inspect existing module, factory wiring, and adjacent tests before editing.
|
||||
2. **One concern per PR** — avoid mixed feature+refactor+infra patches.
|
||||
3. **Implement minimal patch** — no speculative abstractions, no config keys without a concrete use case.
|
||||
4. **Validate by risk tier** — docs-only: lightweight checks. Code changes: full relevant checks.
|
||||
5. **Document impact** — update PR notes for behavior, risk, side effects, and rollback.
|
||||
6. **Queue hygiene** — stacked PR: declare `Depends on #...`. Replacing old PR: declare `Supersedes #...`.
|
||||
|
||||
Branch/commit/PR rules:
|
||||
- Work from a non-`master` branch. Open a PR to `master`; do not push directly.
|
||||
- Use conventional commit titles. Prefer small PRs (`size: XS/S/M`).
|
||||
- Follow `.github/pull_request_template.md` fully.
|
||||
- Never commit secrets, personal data, or real identity information (see `@docs/contributing/pr-discipline.md`).
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not add heavy dependencies for minor convenience.
|
||||
- Do not silently weaken security policy or access constraints.
|
||||
- Do not add speculative config/feature flags "just in case".
|
||||
- Do not mix massive formatting-only changes with functional changes.
|
||||
- Do not modify unrelated modules "while here".
|
||||
- Do not bypass failing checks without explicit explanation.
|
||||
- Do not hide behavior-changing side effects in refactor commits.
|
||||
- Do not include personal identity or sensitive information in test data, examples, docs, or commits.
|
||||
|
||||
## Linked References
|
||||
|
||||
- `@docs/contributing/change-playbooks.md` — adding providers, channels, tools, peripherals; security/gateway changes; architecture boundaries
|
||||
- `@docs/contributing/pr-discipline.md` — privacy rules, superseded-PR attribution/templates, handoff template
|
||||
- `@docs/contributing/docs-contract.md` — docs system contract, i18n rules, locale parity
|
||||
_No custom slash commands defined yet._
|
||||
|
||||
73
Cargo.lock
generated
73
Cargo.lock
generated
@ -6,7 +6,7 @@ version = 4
|
||||
name = "aardvark-sys"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"libloading 0.9.0",
|
||||
"libloading 0.8.9",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
@ -2542,9 +2542,9 @@ checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
|
||||
|
||||
[[package]]
|
||||
name = "embed-resource"
|
||||
version = "3.0.7"
|
||||
version = "3.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47ec73ddcf6b7f23173d5c3c5a32b5507dc0a734de7730aa14abc5d5e296bb5f"
|
||||
checksum = "63a1d0de4f2249aa0ff5884d7080814f446bb241a559af6c170a41e878ed2d45"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"memchr",
|
||||
@ -2616,18 +2616,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f"
|
||||
checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.9"
|
||||
version = "0.11.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d"
|
||||
checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a"
|
||||
dependencies = [
|
||||
"env_filter",
|
||||
"log",
|
||||
@ -4689,16 +4689,6 @@ dependencies = [
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.16"
|
||||
@ -6886,7 +6876,7 @@ version = "3.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f"
|
||||
dependencies = [
|
||||
"toml_edit 0.25.5+spec-1.1.0",
|
||||
"toml_edit 0.25.8+spec-1.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -8361,9 +8351,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "1.0.4"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
|
||||
checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@ -9633,7 +9623,7 @@ checksum = "cf92845e79fc2e2def6a5d828f0801e29a2f8acc037becc5ab08595c7d5e9863"
|
||||
dependencies = [
|
||||
"indexmap 2.13.0",
|
||||
"serde_core",
|
||||
"serde_spanned 1.0.4",
|
||||
"serde_spanned 1.1.0",
|
||||
"toml_datetime 0.7.5+spec-1.1.0",
|
||||
"toml_parser",
|
||||
"toml_writer",
|
||||
@ -9642,14 +9632,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "1.0.7+spec-1.1.0"
|
||||
version = "1.1.0+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd28d57d8a6f6e458bc0b8784f8fdcc4b99a437936056fa122cb234f18656a96"
|
||||
checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc"
|
||||
dependencies = [
|
||||
"indexmap 2.13.0",
|
||||
"serde_core",
|
||||
"serde_spanned 1.0.4",
|
||||
"toml_datetime 1.0.1+spec-1.1.0",
|
||||
"serde_spanned 1.1.0",
|
||||
"toml_datetime 1.1.0+spec-1.1.0",
|
||||
"toml_parser",
|
||||
"toml_writer",
|
||||
"winnow 1.0.0",
|
||||
@ -9675,9 +9665,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "1.0.1+spec-1.1.0"
|
||||
version = "1.1.0+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b320e741db58cac564e26c607d3cc1fdc4a88fd36c879568c07856ed83ff3e9"
|
||||
checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@ -9720,21 +9710,21 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.25.5+spec-1.1.0"
|
||||
version = "0.25.8+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ca1a40644a28bce036923f6a431df0b34236949d111cc07cb6dca830c9ef2e1"
|
||||
checksum = "16bff38f1d86c47f9ff0647e6838d7bb362522bdf44006c7068c2b1e606f1f3c"
|
||||
dependencies = [
|
||||
"indexmap 2.13.0",
|
||||
"toml_datetime 1.0.1+spec-1.1.0",
|
||||
"toml_datetime 1.1.0+spec-1.1.0",
|
||||
"toml_parser",
|
||||
"winnow 1.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.10+spec-1.1.0"
|
||||
version = "1.1.0+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7df25b4befd31c4816df190124375d5a20c6b6921e2cad937316de3fccd63420"
|
||||
checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011"
|
||||
dependencies = [
|
||||
"winnow 1.0.0",
|
||||
]
|
||||
@ -9747,9 +9737,9 @@ checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.7+spec-1.1.0"
|
||||
version = "1.1.0+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f17aaa1c6e3dc22b1da4b6bba97d066e354c7945cac2f7852d4e4e7ca7a6b56d"
|
||||
checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed"
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
@ -9991,9 +9981,9 @@ checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c"
|
||||
|
||||
[[package]]
|
||||
name = "type1-encoding-parser"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3d6cc09e1a99c7e01f2afe4953789311a1c50baebbdac5b477ecf78e2e92a5b"
|
||||
checksum = "fa10c302f5a53b7ad27fd42a3996e23d096ba39b5b8dd6d9e683a05b01bee749"
|
||||
dependencies = [
|
||||
"pom",
|
||||
]
|
||||
@ -12368,13 +12358,13 @@ dependencies = [
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"toml 1.0.7+spec-1.1.0",
|
||||
"toml 1.1.0+spec-1.1.0",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.9"
|
||||
version = "0.6.1"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
@ -12446,10 +12436,11 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-serial",
|
||||
"tokio-socks",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite 0.29.0",
|
||||
"tokio-util",
|
||||
"toml 1.0.7+spec-1.1.0",
|
||||
"toml 1.1.0+spec-1.1.0",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@ -12586,9 +12577,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "8.3.1"
|
||||
version = "8.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c546feb4481b0fbafb4ef0d79b6204fc41c6f9884b1b73b1d73f82442fc0845"
|
||||
checksum = "7756d0206d058333667493c4014f545f4b9603c4330ccd6d9b3f86dcab59f7d9"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"flate2",
|
||||
|
||||
@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.9"
|
||||
version = "0.6.1"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@ -150,6 +150,7 @@ which = "8.0"
|
||||
|
||||
# WebSocket client channels (Discord/Lark/DingTalk/Nostr)
|
||||
tokio-tungstenite = { version = "0.29", features = ["rustls-tls-webpki-roots"] }
|
||||
tokio-socks = "0.5"
|
||||
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
||||
nostr-sdk = { version = "0.44", default-features = false, features = ["nip04", "nip59"], optional = true }
|
||||
regex = "1.10"
|
||||
@ -224,7 +225,7 @@ landlock = { version = "0.4", optional = true }
|
||||
libc = "0.2"
|
||||
|
||||
[features]
|
||||
default = ["observability-prometheus", "channel-nostr", "skill-creation"]
|
||||
default = ["observability-prometheus", "channel-nostr", "channel-lark", "skill-creation"]
|
||||
channel-nostr = ["dep:nostr-sdk"]
|
||||
hardware = ["nusb", "tokio-serial"]
|
||||
channel-matrix = ["dep:matrix-sdk"]
|
||||
|
||||
@ -12,7 +12,7 @@ RUN npm run build
|
||||
FROM rust:1.94-slim@sha256:da9dab7a6b8dd428e71718402e97207bb3e54167d37b5708616050b1e8f60ed6 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
ARG ZEROCLAW_CARGO_FEATURES="memory-postgres"
|
||||
ARG ZEROCLAW_CARGO_FEATURES="memory-postgres,channel-lark"
|
||||
|
||||
# Install build dependencies
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
@ -79,6 +79,10 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace && \
|
||||
'port = 42617' \
|
||||
'host = "[::]"' \
|
||||
'allow_public_bind = true' \
|
||||
'' \
|
||||
'[autonomy]' \
|
||||
'level = "supervised"' \
|
||||
'auto_approve = ["file_read", "file_write", "file_edit", "memory_recall", "memory_store", "web_search_tool", "web_fetch", "calculator", "glob_search", "content_search", "image_info", "weather", "git_operations"]' \
|
||||
> /zeroclaw-data/.zeroclaw/config.toml && \
|
||||
chown -R 65534:65534 /zeroclaw-data
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ RUN npm run build
|
||||
FROM rust:1.94-bookworm AS builder
|
||||
|
||||
WORKDIR /app
|
||||
ARG ZEROCLAW_CARGO_FEATURES="memory-postgres"
|
||||
ARG ZEROCLAW_CARGO_FEATURES="memory-postgres,channel-lark"
|
||||
|
||||
# Install build dependencies
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
@ -89,6 +89,10 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace && \
|
||||
'port = 42617' \
|
||||
'host = "[::]"' \
|
||||
'allow_public_bind = true' \
|
||||
'' \
|
||||
'[autonomy]' \
|
||||
'level = "supervised"' \
|
||||
'auto_approve = ["file_read", "file_write", "file_edit", "memory_recall", "memory_store", "web_search_tool", "web_fetch", "calculator", "glob_search", "content_search", "image_info", "weather", "git_operations"]' \
|
||||
> /zeroclaw-data/.zeroclaw/config.toml && \
|
||||
chown -R 65534:65534 /zeroclaw-data
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
{
|
||||
"$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/dev/crates/tauri-cli/config.schema.json",
|
||||
"productName": "ZeroClaw",
|
||||
"version": "0.1.0",
|
||||
"version": "0.6.1",
|
||||
"identifier": "ai.zeroclawlabs.desktop",
|
||||
"build": {
|
||||
"devUrl": "http://127.0.0.1:42617/_app/",
|
||||
|
||||
@ -21,5 +21,5 @@ repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
# 4. Replace stub method bodies with FFI calls via mod bindings
|
||||
|
||||
[dependencies]
|
||||
libloading = "0.9"
|
||||
libloading = "0.8"
|
||||
thiserror = "2.0"
|
||||
|
||||
@ -20,6 +20,7 @@ Selected allowlist (all actions currently used across Quality Gate, Release Beta
|
||||
| `docker/setup-buildx-action@v3` | release, promote-release | Docker Buildx setup |
|
||||
| `docker/login-action@v3` | release, promote-release | GHCR authentication |
|
||||
| `docker/build-push-action@v6` | release, promote-release | Multi-platform Docker image build and push |
|
||||
| `actions/labeler@v5` | pr-path-labeler | Apply path/scope labels from `labeler.yml` |
|
||||
|
||||
Equivalent allowlist patterns:
|
||||
|
||||
@ -36,6 +37,7 @@ Equivalent allowlist patterns:
|
||||
| Quality Gate | `.github/workflows/checks-on-pr.yml` | Pull requests to `master` |
|
||||
| Release Beta | `.github/workflows/release-beta-on-push.yml` | Push to `master` |
|
||||
| Release Stable | `.github/workflows/release-stable-manual.yml` | Manual `workflow_dispatch` |
|
||||
| PR Path Labeler | `.github/workflows/pr-path-labeler.yml` | `pull_request_target` (opened, synchronize, reopened) |
|
||||
|
||||
## Change Control
|
||||
|
||||
@ -62,6 +64,7 @@ gh api repos/zeroclaw-labs/zeroclaw/actions/permissions/selected-actions
|
||||
|
||||
## Change Log
|
||||
|
||||
- 2026-03-23: Added PR Path Labeler (`pr-path-labeler.yml`) using `actions/labeler@v5`. No allowlist change needed — covered by existing `actions/*` pattern.
|
||||
- 2026-03-10: Renamed workflows — CI → Quality Gate (`checks-on-pr.yml`), Beta Release → Release Beta (`release-beta-on-push.yml`), Promote Release → Release Stable (`release-stable-manual.yml`). Added `lint` and `security` jobs to Quality Gate. Added Cross-Platform Build (`cross-platform-build-manual.yml`).
|
||||
- 2026-03-05: Complete workflow overhaul — replaced 22 workflows with 3 (CI, Beta Release, Promote Release)
|
||||
- Removed patterns no longer in use: `DavidAnson/markdownlint-cli2-action@*`, `lycheeverse/lychee-action@*`, `EmbarkStudios/cargo-deny-action@*`, `rustsec/audit-check@*`, `rhysd/actionlint@*`, `sigstore/cosign-installer@*`, `Checkmarx/vorpal-reviewdog-github-action@*`, `useblacksmith/*`
|
||||
|
||||
213
docs/contributing/label-registry.md
Normal file
213
docs/contributing/label-registry.md
Normal file
@ -0,0 +1,213 @@
|
||||
# Label Registry
|
||||
|
||||
Single reference for every label used on PRs and issues. Labels are grouped by category. Each entry lists the label name, definition, and how it is applied.
|
||||
|
||||
Sources consolidated here:
|
||||
|
||||
- `.github/labeler.yml` (path-label config for `actions/labeler`)
|
||||
- `.github/label-policy.json` (contributor tier thresholds)
|
||||
- `docs/contributing/pr-workflow.md` (size, risk, and triage label definitions)
|
||||
- `docs/contributing/ci-map.md` (automation behavior and high-risk path heuristics)
|
||||
|
||||
Note: The CI was simplified to 4 workflows (`ci.yml`, `release.yml`, `ci-full.yml`, `promote-release.yml`). Workflows that previously automated size, risk, contributor tier, and triage labels (`pr-labeler.yml`, `pr-auto-response.yml`, `pr-check-stale.yml`, and supporting scripts) were removed. Only path labels via `pr-path-labeler.yml` are currently automated.
|
||||
|
||||
---
|
||||
|
||||
## Path labels
|
||||
|
||||
Applied automatically by `pr-path-labeler.yml` using `actions/labeler`. Matches changed files against glob patterns in `.github/labeler.yml`.
|
||||
|
||||
### Base scope labels
|
||||
|
||||
| Label | Matches |
|
||||
|---|---|
|
||||
| `docs` | `docs/**`, `**/*.md`, `**/*.mdx`, `LICENSE`, `.markdownlint-cli2.yaml` |
|
||||
| `dependencies` | `Cargo.toml`, `Cargo.lock`, `deny.toml`, `.github/dependabot.yml` |
|
||||
| `ci` | `.github/**`, `.githooks/**` |
|
||||
| `core` | `src/*.rs` |
|
||||
| `agent` | `src/agent/**` |
|
||||
| `channel` | `src/channels/**` |
|
||||
| `gateway` | `src/gateway/**` |
|
||||
| `config` | `src/config/**` |
|
||||
| `cron` | `src/cron/**` |
|
||||
| `daemon` | `src/daemon/**` |
|
||||
| `doctor` | `src/doctor/**` |
|
||||
| `health` | `src/health/**` |
|
||||
| `heartbeat` | `src/heartbeat/**` |
|
||||
| `integration` | `src/integrations/**` |
|
||||
| `memory` | `src/memory/**` |
|
||||
| `security` | `src/security/**` |
|
||||
| `runtime` | `src/runtime/**` |
|
||||
| `onboard` | `src/onboard/**` |
|
||||
| `provider` | `src/providers/**` |
|
||||
| `service` | `src/service/**` |
|
||||
| `skillforge` | `src/skillforge/**` |
|
||||
| `skills` | `src/skills/**` |
|
||||
| `tool` | `src/tools/**` |
|
||||
| `tunnel` | `src/tunnel/**` |
|
||||
| `observability` | `src/observability/**` |
|
||||
| `tests` | `tests/**` |
|
||||
| `scripts` | `scripts/**` |
|
||||
| `dev` | `dev/**` |
|
||||
|
||||
### Per-component channel labels
|
||||
|
||||
Each channel gets a specific label in addition to the base `channel` label.
|
||||
|
||||
| Label | Matches |
|
||||
|---|---|
|
||||
| `channel:bluesky` | `bluesky.rs` |
|
||||
| `channel:clawdtalk` | `clawdtalk.rs` |
|
||||
| `channel:cli` | `cli.rs` |
|
||||
| `channel:dingtalk` | `dingtalk.rs` |
|
||||
| `channel:discord` | `discord.rs`, `discord_history.rs` |
|
||||
| `channel:email` | `email_channel.rs`, `gmail_push.rs` |
|
||||
| `channel:imessage` | `imessage.rs` |
|
||||
| `channel:irc` | `irc.rs` |
|
||||
| `channel:lark` | `lark.rs` |
|
||||
| `channel:linq` | `linq.rs` |
|
||||
| `channel:matrix` | `matrix.rs` |
|
||||
| `channel:mattermost` | `mattermost.rs` |
|
||||
| `channel:mochat` | `mochat.rs` |
|
||||
| `channel:mqtt` | `mqtt.rs` |
|
||||
| `channel:nextcloud-talk` | `nextcloud_talk.rs` |
|
||||
| `channel:nostr` | `nostr.rs` |
|
||||
| `channel:notion` | `notion.rs` |
|
||||
| `channel:qq` | `qq.rs` |
|
||||
| `channel:reddit` | `reddit.rs` |
|
||||
| `channel:signal` | `signal.rs` |
|
||||
| `channel:slack` | `slack.rs` |
|
||||
| `channel:telegram` | `telegram.rs` |
|
||||
| `channel:twitter` | `twitter.rs` |
|
||||
| `channel:wati` | `wati.rs` |
|
||||
| `channel:webhook` | `webhook.rs` |
|
||||
| `channel:wecom` | `wecom.rs` |
|
||||
| `channel:whatsapp` | `whatsapp.rs`, `whatsapp_storage.rs`, `whatsapp_web.rs` |
|
||||
|
||||
### Per-component provider labels
|
||||
|
||||
| Label | Matches |
|
||||
|---|---|
|
||||
| `provider:anthropic` | `anthropic.rs` |
|
||||
| `provider:azure-openai` | `azure_openai.rs` |
|
||||
| `provider:bedrock` | `bedrock.rs` |
|
||||
| `provider:claude-code` | `claude_code.rs` |
|
||||
| `provider:compatible` | `compatible.rs` |
|
||||
| `provider:copilot` | `copilot.rs` |
|
||||
| `provider:gemini` | `gemini.rs`, `gemini_cli.rs` |
|
||||
| `provider:glm` | `glm.rs` |
|
||||
| `provider:kilocli` | `kilocli.rs` |
|
||||
| `provider:ollama` | `ollama.rs` |
|
||||
| `provider:openai` | `openai.rs`, `openai_codex.rs` |
|
||||
| `provider:openrouter` | `openrouter.rs` |
|
||||
| `provider:telnyx` | `telnyx.rs` |
|
||||
|
||||
### Per-group tool labels
|
||||
|
||||
Tools are grouped by logical function rather than one label per file.
|
||||
|
||||
| Label | Matches |
|
||||
|---|---|
|
||||
| `tool:browser` | `browser.rs`, `browser_delegate.rs`, `browser_open.rs`, `text_browser.rs`, `screenshot.rs` |
|
||||
| `tool:cloud` | `cloud_ops.rs`, `cloud_patterns.rs` |
|
||||
| `tool:composio` | `composio.rs` |
|
||||
| `tool:cron` | `cron_add.rs`, `cron_list.rs`, `cron_remove.rs`, `cron_run.rs`, `cron_runs.rs`, `cron_update.rs` |
|
||||
| `tool:file` | `file_edit.rs`, `file_read.rs`, `file_write.rs`, `glob_search.rs`, `content_search.rs` |
|
||||
| `tool:google-workspace` | `google_workspace.rs` |
|
||||
| `tool:mcp` | `mcp_client.rs`, `mcp_deferred.rs`, `mcp_protocol.rs`, `mcp_tool.rs`, `mcp_transport.rs` |
|
||||
| `tool:memory` | `memory_forget.rs`, `memory_recall.rs`, `memory_store.rs` |
|
||||
| `tool:microsoft365` | `microsoft365/**` |
|
||||
| `tool:security` | `security_ops.rs`, `verifiable_intent.rs` |
|
||||
| `tool:shell` | `shell.rs`, `node_tool.rs`, `cli_discovery.rs` |
|
||||
| `tool:sop` | `sop_advance.rs`, `sop_approve.rs`, `sop_execute.rs`, `sop_list.rs`, `sop_status.rs` |
|
||||
| `tool:web` | `web_fetch.rs`, `web_search_tool.rs`, `web_search_provider_routing.rs`, `http_request.rs` |
|
||||
|
||||
---
|
||||
|
||||
## Size labels
|
||||
|
||||
Defined in `pr-workflow.md` §6.1. Based on effective changed line count, normalized for docs-only and lockfile-heavy PRs.
|
||||
|
||||
| Label | Threshold |
|
||||
|---|---|
|
||||
| `size: XS` | <= 80 lines |
|
||||
| `size: S` | <= 250 lines |
|
||||
| `size: M` | <= 500 lines |
|
||||
| `size: L` | <= 1000 lines |
|
||||
| `size: XL` | > 1000 lines |
|
||||
|
||||
**Applied by:** manual. The workflows that previously computed size labels (`pr-labeler.yml` and supporting scripts) were removed during CI simplification.
|
||||
|
||||
---
|
||||
|
||||
## Risk labels
|
||||
|
||||
Defined in `pr-workflow.md` §13.2 and `ci-map.md`. Based on a heuristic combining touched paths and change size.
|
||||
|
||||
| Label | Meaning |
|
||||
|---|---|
|
||||
| `risk: low` | No high-risk paths touched, small change |
|
||||
| `risk: medium` | Behavioral `src/**` changes without boundary/security impact |
|
||||
| `risk: high` | Touches high-risk paths (see below) or large security-adjacent change |
|
||||
| `risk: manual` | Maintainer override that freezes automated risk recalculation |
|
||||
|
||||
High-risk paths: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**`.
|
||||
|
||||
The boundary between low and medium is not formally defined beyond "no high-risk paths."
|
||||
|
||||
**Applied by:** manual. Previously automated via `pr-labeler.yml`; removed during CI simplification.
|
||||
|
||||
---
|
||||
|
||||
## Contributor tier labels
|
||||
|
||||
Defined in `.github/label-policy.json`. Based on the author's merged PR count queried from the GitHub API.
|
||||
|
||||
| Label | Minimum merged PRs |
|
||||
|---|---|
|
||||
| `trusted contributor` | 5 |
|
||||
| `experienced contributor` | 10 |
|
||||
| `principal contributor` | 20 |
|
||||
| `distinguished contributor` | 50 |
|
||||
|
||||
**Applied by:** manual. Previously automated via `pr-labeler.yml` and `pr-auto-response.yml`; removed during CI simplification.
|
||||
|
||||
---
|
||||
|
||||
## Response and triage labels
|
||||
|
||||
Defined in `pr-workflow.md` §8. Applied manually.
|
||||
|
||||
| Label | Purpose | Applied by |
|
||||
|---|---|---|
|
||||
| `r:needs-repro` | Incomplete bug report; request deterministic repro | Manual |
|
||||
| `r:support` | Usage/help item better handled outside bug backlog | Manual |
|
||||
| `invalid` | Not a valid bug/feature request | Manual |
|
||||
| `duplicate` | Duplicate of existing issue | Manual |
|
||||
| `stale-candidate` | Dormant PR/issue; candidate for closing | Manual |
|
||||
| `superseded` | Replaced by a newer PR | Manual |
|
||||
| `no-stale` | Exempt from stale automation; accepted but blocked work | Manual |
|
||||
|
||||
**Automation:** none currently. The workflows that handled label-driven issue closing (`pr-auto-response.yml`) and stale detection (`pr-check-stale.yml`) were removed during CI simplification.
|
||||
|
||||
---
|
||||
|
||||
## Implementation status
|
||||
|
||||
| Category | Count | Automated | Workflow |
|
||||
|---|---|---|---|
|
||||
| Path (base scope) | 27 | Yes | `pr-path-labeler.yml` |
|
||||
| Path (per-component) | 52 | Yes | `pr-path-labeler.yml` |
|
||||
| Size | 5 | No | Manual |
|
||||
| Risk | 4 | No | Manual |
|
||||
| Contributor tier | 4 | No | Manual |
|
||||
| Response/triage | 7 | No | Manual |
|
||||
| **Total** | **99** | | |
|
||||
|
||||
---
|
||||
|
||||
## Maintenance
|
||||
|
||||
- **Owner:** maintainers responsible for label policy and PR triage automation.
|
||||
- **Update trigger:** new channels, providers, or tools added to the source tree; label policy changes; triage workflow changes.
|
||||
- **Source of truth:** this document consolidates definitions from the four source files listed at the top. When definitions conflict, update the source file first, then sync this registry.
|
||||
@ -109,3 +109,11 @@ allow_override = false
|
||||
# [cost.prices."openai/gpt-4o-mini"]
|
||||
# input = 0.15
|
||||
# output = 0.60
|
||||
|
||||
# ── Voice Transcription ─────────────────────────────────────────
|
||||
# [transcription]
|
||||
# enabled = true
|
||||
# default_provider = "groq"
|
||||
# Also transcribe non-PTT (forwarded / regular) audio on WhatsApp.
|
||||
# Default: false (only voice notes are transcribed).
|
||||
# transcribe_non_ptt_audio = false
|
||||
|
||||
163
install.sh
163
install.sh
@ -230,6 +230,49 @@ detect_release_target() {
|
||||
esac
|
||||
}
|
||||
|
||||
detect_device_class() {
|
||||
# Containers are never desktops
|
||||
if _is_container_runtime; then
|
||||
echo "container"
|
||||
return
|
||||
fi
|
||||
|
||||
# Termux / Android
|
||||
if [[ -n "${TERMUX_VERSION:-}" || -d "/data/data/com.termux" ]]; then
|
||||
echo "mobile"
|
||||
return
|
||||
fi
|
||||
|
||||
local os arch
|
||||
os="$(uname -s)"
|
||||
arch="$(uname -m)"
|
||||
|
||||
case "$os" in
|
||||
Darwin)
|
||||
# macOS is always a desktop
|
||||
echo "desktop"
|
||||
;;
|
||||
Linux)
|
||||
# Raspberry Pi / ARM SBCs — treat as embedded (typically headless)
|
||||
case "$arch" in
|
||||
armv6l|armv7l)
|
||||
echo "embedded"
|
||||
return
|
||||
;;
|
||||
esac
|
||||
# Check for a display server (X11 or Wayland)
|
||||
if [[ -n "${DISPLAY:-}" || -n "${WAYLAND_DISPLAY:-}" || -n "${XDG_SESSION_TYPE:-}" ]]; then
|
||||
echo "desktop"
|
||||
else
|
||||
echo "server"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "server"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
should_attempt_prebuilt_for_resources() {
|
||||
local workspace="${1:-.}"
|
||||
local min_ram_mb min_disk_mb total_ram_mb free_disk_mb low_resource
|
||||
@ -1155,6 +1198,9 @@ while [[ $# -gt 0 ]]; do
|
||||
done
|
||||
|
||||
OS_NAME="$(uname -s)"
|
||||
DEVICE_CLASS="$(detect_device_class)"
|
||||
step_dot "Device: $OS_NAME/$(uname -m) ($DEVICE_CLASS)"
|
||||
|
||||
if [[ "$GUIDED_MODE" == "auto" ]]; then
|
||||
if [[ "$OS_NAME" == "Linux" && "$ORIGINAL_ARG_COUNT" -eq 0 && -t 0 && -t 1 ]]; then
|
||||
GUIDED_MODE="on"
|
||||
@ -1479,63 +1525,64 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
# --- Build desktop app (macOS only) ---
|
||||
if [[ "$SKIP_BUILD" == false && "$OS_NAME" == "Darwin" && -d "$WORK_DIR/apps/tauri" ]]; then
|
||||
echo
|
||||
echo -e "${BOLD}Desktop app preflight${RESET}"
|
||||
# --- Companion desktop app (device-class-aware) ---
|
||||
# The desktop app is a pre-built download from the website, not built from source.
|
||||
# This keeps the one-liner install fast and the CLI binary small.
|
||||
DESKTOP_DOWNLOAD_URL="https://www.zeroclawlabs.ai/download"
|
||||
DESKTOP_APP_DETECTED=false
|
||||
|
||||
_desktop_ok=true
|
||||
|
||||
# Check Rust toolchain
|
||||
if have_cmd cargo && have_cmd rustc; then
|
||||
step_ok "Rust $(rustc --version | awk '{print $2}') found"
|
||||
else
|
||||
step_fail "Rust toolchain not found — required for desktop app"
|
||||
_desktop_ok=false
|
||||
fi
|
||||
|
||||
# Check Xcode CLT (needed for linking native frameworks)
|
||||
if xcode-select -p >/dev/null 2>&1; then
|
||||
step_ok "Xcode Command Line Tools installed"
|
||||
else
|
||||
step_fail "Xcode Command Line Tools not found — run: xcode-select --install"
|
||||
_desktop_ok=false
|
||||
fi
|
||||
|
||||
# Check that the Tauri CLI is available (cargo-tauri or tauri-cli)
|
||||
if have_cmd cargo-tauri; then
|
||||
step_ok "cargo-tauri $(cargo tauri --version 2>/dev/null | awk '{print $NF}') found"
|
||||
else
|
||||
step_dot "cargo-tauri not found — installing"
|
||||
if cargo install tauri-cli --locked 2>/dev/null; then
|
||||
step_ok "cargo-tauri installed"
|
||||
else
|
||||
warn "Failed to install cargo-tauri — desktop app build may fail"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check node/npm (needed for web frontend that Tauri embeds)
|
||||
if have_cmd node && have_cmd npm; then
|
||||
step_ok "Node.js $(node --version) found"
|
||||
else
|
||||
warn "node/npm not found — desktop app needs the web dashboard built first"
|
||||
fi
|
||||
|
||||
if [[ "$_desktop_ok" == true ]]; then
|
||||
step_dot "Building desktop app (zeroclaw-desktop)"
|
||||
if cargo build -p zeroclaw-desktop --release --locked 2>/dev/null; then
|
||||
step_ok "Desktop app built"
|
||||
# Copy binary to cargo bin for easy access
|
||||
if [[ -x "$WORK_DIR/target/release/zeroclaw-desktop" ]]; then
|
||||
cp -f "$WORK_DIR/target/release/zeroclaw-desktop" "$HOME/.cargo/bin/zeroclaw-desktop" 2>/dev/null && \
|
||||
step_ok "zeroclaw-desktop installed to ~/.cargo/bin" || true
|
||||
if [[ "$DEVICE_CLASS" == "desktop" ]]; then
|
||||
# Check if the companion app is already installed
|
||||
case "$OS_NAME" in
|
||||
Darwin)
|
||||
if [[ -d "/Applications/ZeroClaw.app" ]] || [[ -d "$HOME/Applications/ZeroClaw.app" ]]; then
|
||||
DESKTOP_APP_DETECTED=true
|
||||
step_ok "Companion app found (ZeroClaw.app)"
|
||||
fi
|
||||
else
|
||||
warn "Desktop app build failed — you can build later with: cargo build -p zeroclaw-desktop --release"
|
||||
fi
|
||||
else
|
||||
warn "Skipping desktop app build — fix missing dependencies above and re-run"
|
||||
;;
|
||||
Linux)
|
||||
if have_cmd zeroclaw-desktop; then
|
||||
DESKTOP_APP_DETECTED=true
|
||||
step_ok "Companion app found (zeroclaw-desktop)"
|
||||
elif [[ -x "$HOME/.local/bin/zeroclaw-desktop" ]]; then
|
||||
DESKTOP_APP_DETECTED=true
|
||||
step_ok "Companion app found (~/.local/bin/zeroclaw-desktop)"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ "$DESKTOP_APP_DETECTED" == false ]]; then
|
||||
echo
|
||||
echo -e "${BOLD}Companion App${RESET}"
|
||||
echo -e " Menu bar access to your ZeroClaw agent."
|
||||
echo -e " Works alongside the CLI — connects to the same gateway."
|
||||
echo
|
||||
case "$OS_NAME" in
|
||||
Darwin)
|
||||
echo -e " ${BOLD}Download for macOS:${RESET} ${BLUE}${DESKTOP_DOWNLOAD_URL}${RESET}"
|
||||
;;
|
||||
Linux)
|
||||
echo -e " ${BOLD}Download for Linux:${RESET} ${BLUE}${DESKTOP_DOWNLOAD_URL}${RESET}"
|
||||
;;
|
||||
esac
|
||||
echo -e " ${DIM}Or run: zeroclaw desktop --install${RESET}"
|
||||
fi
|
||||
elif [[ "$DEVICE_CLASS" != "desktop" ]]; then
|
||||
# Non-desktop device — explain why companion app is not offered
|
||||
case "$DEVICE_CLASS" in
|
||||
mobile)
|
||||
step_dot "Mobile device — use the web dashboard at http://127.0.0.1:42617"
|
||||
;;
|
||||
embedded)
|
||||
step_dot "Embedded device ($(uname -m)) — use the web dashboard"
|
||||
;;
|
||||
container)
|
||||
step_dot "Container runtime — use the web dashboard"
|
||||
;;
|
||||
server)
|
||||
step_dot "Headless server — use the web dashboard"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
ZEROCLAW_BIN=""
|
||||
@ -1704,8 +1751,12 @@ echo -e "${BOLD}Next steps:${RESET}"
|
||||
echo -e " ${DIM}zeroclaw status${RESET}"
|
||||
echo -e " ${DIM}zeroclaw agent -m \"Hello, ZeroClaw!\"${RESET}"
|
||||
echo -e " ${DIM}zeroclaw gateway${RESET}"
|
||||
if [[ "$OS_NAME" == "Darwin" ]] && have_cmd zeroclaw-desktop; then
|
||||
echo -e " ${DIM}zeroclaw-desktop${RESET} ${DIM}# Launch the menu bar app${RESET}"
|
||||
if [[ "$DEVICE_CLASS" == "desktop" ]]; then
|
||||
if [[ "$DESKTOP_APP_DETECTED" == true ]]; then
|
||||
echo -e " ${DIM}zeroclaw desktop${RESET} ${DIM}# Launch the menu bar app${RESET}"
|
||||
else
|
||||
echo -e " ${DIM}zeroclaw desktop --install${RESET} ${DIM}# Download the companion app${RESET}"
|
||||
fi
|
||||
fi
|
||||
echo
|
||||
echo -e "${BOLD}Docs:${RESET} ${BLUE}https://www.zeroclawlabs.ai/docs${RESET}"
|
||||
|
||||
3
skills/browser/TEST.sh
Normal file
3
skills/browser/TEST.sh
Normal file
@ -0,0 +1,3 @@
|
||||
# Browser skill tests
|
||||
# Format: command | expected_exit_code | expected_output_pattern
|
||||
echo "browser skill loaded" | 0 | browser skill loaded
|
||||
@ -377,7 +377,7 @@ impl Agent {
|
||||
None
|
||||
};
|
||||
|
||||
let (mut tools, delegate_handle, _reaction_handle, _channel_map_handle) =
|
||||
let (mut tools, delegate_handle, _reaction_handle, _channel_map_handle, _ask_user_handle) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
@ -653,6 +653,24 @@ impl Agent {
|
||||
return format!("hint:{}", decision.hint);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: auto-classify by complexity when no rule matched.
|
||||
if let Some(ref ac) = self.config.auto_classify {
|
||||
let tier = super::eval::estimate_complexity(user_message);
|
||||
if let Some(hint) = ac.hint_for(tier) {
|
||||
if self.available_hints.contains(&hint.to_string()) {
|
||||
tracing::info!(
|
||||
target: "query_classification",
|
||||
hint = hint,
|
||||
complexity = ?tier,
|
||||
message_length = user_message.len(),
|
||||
"Auto-classified by complexity"
|
||||
);
|
||||
return format!("hint:{hint}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.model_name.clone()
|
||||
}
|
||||
|
||||
|
||||
155
src/agent/context_analyzer.rs
Normal file
155
src/agent/context_analyzer.rs
Normal file
@ -0,0 +1,155 @@
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Signals extracted from conversation context to guide tool filtering.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextSignals {
|
||||
/// Tool names likely needed. Empty vec means no filtering.
|
||||
pub suggested_tools: Vec<String>,
|
||||
/// Whether full history is relevant.
|
||||
pub history_relevant: bool,
|
||||
}
|
||||
|
||||
/// Analyze context to determine which tools are likely needed.
|
||||
pub fn analyze_turn_context(
|
||||
history: &[ChatMessage],
|
||||
_user_message: &str,
|
||||
iteration: usize,
|
||||
last_tool_calls: &[String],
|
||||
) -> ContextSignals {
|
||||
if iteration == 0 {
|
||||
return ContextSignals {
|
||||
suggested_tools: Vec::new(),
|
||||
history_relevant: true,
|
||||
};
|
||||
}
|
||||
|
||||
let mut tools: HashSet<String> = HashSet::new();
|
||||
for tool in last_tool_calls {
|
||||
tools.insert(tool.clone());
|
||||
}
|
||||
|
||||
if let Some(last_assistant) = history.iter().rev().find(|m| m.role == "assistant") {
|
||||
for word in last_assistant.content.split_whitespace() {
|
||||
for tool_name in tools_for_keyword(word) {
|
||||
tools.insert(tool_name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut suggested: Vec<String> = tools.into_iter().collect();
|
||||
suggested.sort();
|
||||
|
||||
ContextSignals {
|
||||
suggested_tools: suggested,
|
||||
history_relevant: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn tools_for_keyword(keyword: &str) -> &'static [&'static str] {
|
||||
match keyword.to_lowercase().as_str() {
|
||||
"file" | "read" | "write" | "edit" | "path" | "directory" => {
|
||||
&["file_read", "file_write", "file_edit", "glob_search"]
|
||||
}
|
||||
"shell" | "command" | "run" | "execute" | "install" | "build" => &["shell"],
|
||||
"memory" | "remember" | "recall" | "store" | "forget" => &["memory_store", "memory_recall"],
|
||||
"search" | "find" | "grep" | "look" => {
|
||||
&["content_search", "glob_search", "web_search_tool"]
|
||||
}
|
||||
"browser" | "website" | "url" | "http" | "fetch" => &["web_fetch", "web_search_tool"],
|
||||
"image" | "screenshot" | "picture" => &["image_info"],
|
||||
"git" | "commit" | "branch" | "push" | "pull" => &["git_operations", "shell"],
|
||||
_ => &[],
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `ChatMessage` fixture with the given role and content.
    fn make_message(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    // Iteration 0 is the first turn: there is no prior assistant output or
    // tool usage to mine, so no suggestions should be produced.
    #[test]
    fn iteration_zero_returns_empty_suggestions() {
        let history = vec![make_message("user", "hello")];
        let signals = analyze_turn_context(&history, "do something", 0, &[]);
        assert!(signals.suggested_tools.is_empty());
        assert!(signals.history_relevant);
    }

    // Tools used on the previous iteration should carry over as suggestions.
    #[test]
    fn iteration_one_includes_last_tools() {
        let history = vec![
            make_message("user", "hello"),
            make_message("assistant", "sure"),
        ];
        let last_tools = vec!["shell".to_string(), "file_read".to_string()];
        let signals = analyze_turn_context(&history, "next step", 1, &last_tools);
        assert!(signals.suggested_tools.contains(&"shell".to_string()));
        assert!(signals.suggested_tools.contains(&"file_read".to_string()));
    }

    // Keywords in the last assistant message ("read the file") should map to
    // the corresponding tool suggestion.
    #[test]
    fn keyword_extraction_from_assistant_message() {
        let history = vec![
            make_message("user", "help me"),
            make_message("assistant", "I will read the file at that path"),
        ];
        let signals = analyze_turn_context(&history, "ok", 1, &[]);
        assert!(signals.suggested_tools.contains(&"file_read".to_string()));
    }

    #[test]
    fn shell_keywords_suggest_shell_tool() {
        let history = vec![
            make_message("user", "build the project"),
            make_message("assistant", "I will run the build command"),
        ];
        let signals = analyze_turn_context(&history, "go", 1, &[]);
        assert!(signals.suggested_tools.contains(&"shell".to_string()));
    }

    // Memory-related phrasing should suggest both store and recall tools.
    #[test]
    fn memory_keywords_suggest_memory_tools() {
        let history = vec![
            make_message("user", "save this"),
            make_message("assistant", "I will store that in memory"),
        ];
        let signals = analyze_turn_context(&history, "ok", 1, &[]);
        assert!(signals
            .suggested_tools
            .contains(&"memory_store".to_string()));
        assert!(signals
            .suggested_tools
            .contains(&"memory_recall".to_string()));
    }

    // A message hitting several keyword categories should merge all of the
    // corresponding tool suggestions.
    #[test]
    fn combined_keywords_merge_tools() {
        let history = vec![
            make_message("user", "do stuff"),
            make_message(
                "assistant",
                "I need to read the file and run a shell command to search",
            ),
        ];
        let signals = analyze_turn_context(&history, "go", 1, &[]);
        assert!(signals.suggested_tools.contains(&"file_read".to_string()));
        assert!(signals.suggested_tools.contains(&"shell".to_string()));
        assert!(signals
            .suggested_tools
            .contains(&"content_search".to_string()));
    }

    // Empty history must not panic and must yield no suggestions.
    #[test]
    fn empty_history_iteration_one() {
        let history: Vec<ChatMessage> = vec![];
        let signals = analyze_turn_context(&history, "hello", 1, &[]);
        assert!(signals.suggested_tools.is_empty());
    }
}
|
||||
415
src/agent/eval.rs
Normal file
415
src/agent/eval.rs
Normal file
@ -0,0 +1,415 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use schemars::JsonSchema;
|
||||
|
||||
// ── Complexity estimation ───────────────────────────────────────
|
||||
|
||||
/// Coarse complexity tier for a user message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityTier {
    /// Short, simple query (greetings, yes/no, lookups).
    Simple,
    /// Typical request — not trivially simple, not deeply complex.
    Standard,
    /// Long or reasoning-heavy request (code, multi-step, analysis).
    Complex,
}

/// Heuristic keywords that signal reasoning complexity.
const REASONING_KEYWORDS: &[&str] = &[
    "explain",
    "why",
    "analyze",
    "compare",
    "design",
    "implement",
    "refactor",
    "debug",
    "optimize",
    "architecture",
    "trade-off",
    "tradeoff",
    "reasoning",
    "step by step",
    "think through",
    "evaluate",
    "critique",
    "pros and cons",
];

/// Estimate the complexity of a user message without an LLM call.
///
/// Rules (applied in order):
/// - **Complex**: message > 200 chars, OR contains a code fence, OR ≥ 2
///   reasoning keywords.
/// - **Simple**: message < 50 chars AND no reasoning keywords.
/// - **Standard**: everything else.
pub fn estimate_complexity(message: &str) -> ComplexityTier {
    let lower = message.to_lowercase();
    // Count characters, not bytes: the documented thresholds are in chars,
    // and `str::len()` would over-count multibyte (non-ASCII) text.
    let len = message.chars().count();

    let keyword_count = REASONING_KEYWORDS
        .iter()
        .filter(|kw| lower.contains(**kw))
        .count();

    let has_code_fence = message.contains("```");

    if len > 200 || has_code_fence || keyword_count >= 2 {
        return ComplexityTier::Complex;
    }

    if len < 50 && keyword_count == 0 {
        return ComplexityTier::Simple;
    }

    ComplexityTier::Standard
}
|
||||
|
||||
// ── Auto-classify config ────────────────────────────────────────
|
||||
|
||||
/// Configuration for automatic complexity-based classification.
///
/// When the rule-based classifier in `QueryClassificationConfig` produces no
/// match, the eval layer can fall back to `estimate_complexity` and map the
/// resulting tier to a routing hint.
///
/// Every hint is optional; a missing hint means "no routing suggestion" for
/// that tier.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct AutoClassifyConfig {
    /// Hint to use for `Simple` complexity tier (e.g. `"fast"`).
    #[serde(default)]
    pub simple_hint: Option<String>,
    /// Hint to use for `Standard` complexity tier.
    #[serde(default)]
    pub standard_hint: Option<String>,
    /// Hint to use for `Complex` complexity tier (e.g. `"reasoning"`).
    #[serde(default)]
    pub complex_hint: Option<String>,
}
|
||||
|
||||
impl AutoClassifyConfig {
|
||||
/// Map a complexity tier to the configured hint, if any.
|
||||
pub fn hint_for(&self, tier: ComplexityTier) -> Option<&str> {
|
||||
match tier {
|
||||
ComplexityTier::Simple => self.simple_hint.as_deref(),
|
||||
ComplexityTier::Standard => self.standard_hint.as_deref(),
|
||||
ComplexityTier::Complex => self.complex_hint.as_deref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Post-response eval ──────────────────────────────────────────
|
||||
|
||||
/// Configuration for the post-response quality evaluator.
///
/// Disabled by default; when enabled, responses scoring below
/// `min_quality_score` may be retried with a higher-tier model, up to
/// `max_retries` times.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct EvalConfig {
    /// Enable the eval quality gate.
    #[serde(default)]
    pub enabled: bool,
    /// Minimum quality score (0.0–1.0) to accept a response.
    /// Below this threshold, a retry with a higher-tier model is suggested.
    #[serde(default = "default_min_quality_score")]
    pub min_quality_score: f64,
    /// Maximum retries with escalated models before accepting whatever we get.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
}
|
||||
|
||||
/// Serde default for `EvalConfig::min_quality_score`.
fn default_min_quality_score() -> f64 {
    0.5
}

/// Serde default for `EvalConfig::max_retries`.
fn default_max_retries() -> u32 {
    1
}
|
||||
|
||||
impl Default for EvalConfig {
    /// Disabled by default; thresholds mirror the serde field defaults so the
    /// two construction paths cannot diverge.
    fn default() -> Self {
        Self {
            enabled: false,
            min_quality_score: default_min_quality_score(),
            max_retries: default_max_retries(),
        }
    }
}
|
||||
|
||||
/// Result of evaluating a response against quality heuristics.
#[derive(Debug, Clone)]
pub struct EvalResult {
    /// Aggregate quality score from 0.0 (terrible) to 1.0 (excellent).
    pub score: f64,
    /// Individual check outcomes (for observability).
    pub checks: Vec<EvalCheck>,
    /// If score < threshold, the suggested higher-tier hint for retry.
    /// `None` when the score is acceptable or no escalation is configured.
    pub retry_hint: Option<String>,
}
|
||||
|
||||
/// Outcome of a single heuristic quality check.
#[derive(Debug, Clone)]
pub struct EvalCheck {
    /// Stable identifier of the check (e.g. `"non_empty"`).
    pub name: &'static str,
    /// Whether the response satisfied this check.
    pub passed: bool,
    /// Contribution of this check to the weighted aggregate score.
    pub weight: f64,
}
|
||||
|
||||
/// Code-related keywords in user queries.
///
/// Matched case-insensitively against the query; when any match, the
/// response is expected to contain a code block (see `evaluate_response`).
const CODE_KEYWORDS: &[&str] = &[
    "code",
    "function",
    "implement",
    "class",
    "struct",
    "module",
    "script",
    "program",
    "bug",
    "error",
    "compile",
    "syntax",
    "refactor",
];
|
||||
|
||||
/// Evaluate a response against heuristic quality checks. No LLM call.
|
||||
///
|
||||
/// Checks:
|
||||
/// 1. **Non-empty**: response must not be empty.
|
||||
/// 2. **Not a cop-out**: response must not be just "I don't know" or similar.
|
||||
/// 3. **Sufficient length**: response length should be proportional to query complexity.
|
||||
/// 4. **Code presence**: if the query mentions code keywords, the response should
|
||||
/// contain a code block.
|
||||
pub fn evaluate_response(
|
||||
query: &str,
|
||||
response: &str,
|
||||
complexity: ComplexityTier,
|
||||
auto_classify: Option<&AutoClassifyConfig>,
|
||||
) -> EvalResult {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// Check 1: Non-empty
|
||||
let non_empty = !response.trim().is_empty();
|
||||
checks.push(EvalCheck {
|
||||
name: "non_empty",
|
||||
passed: non_empty,
|
||||
weight: 0.3,
|
||||
});
|
||||
|
||||
// Check 2: Not a cop-out
|
||||
let lower_resp = response.to_lowercase();
|
||||
let cop_out_phrases = [
|
||||
"i don't know",
|
||||
"i'm not sure",
|
||||
"i cannot",
|
||||
"i can't help",
|
||||
"as an ai",
|
||||
];
|
||||
let is_cop_out = cop_out_phrases
|
||||
.iter()
|
||||
.any(|phrase| lower_resp.starts_with(phrase));
|
||||
let not_cop_out = !is_cop_out || response.len() > 200; // long responses with caveats are fine
|
||||
checks.push(EvalCheck {
|
||||
name: "not_cop_out",
|
||||
passed: not_cop_out,
|
||||
weight: 0.25,
|
||||
});
|
||||
|
||||
// Check 3: Sufficient length for complexity
|
||||
let min_len = match complexity {
|
||||
ComplexityTier::Simple => 5,
|
||||
ComplexityTier::Standard => 20,
|
||||
ComplexityTier::Complex => 50,
|
||||
};
|
||||
let sufficient_length = response.len() >= min_len;
|
||||
checks.push(EvalCheck {
|
||||
name: "sufficient_length",
|
||||
passed: sufficient_length,
|
||||
weight: 0.2,
|
||||
});
|
||||
|
||||
// Check 4: Code presence when expected
|
||||
let query_lower = query.to_lowercase();
|
||||
let expects_code = CODE_KEYWORDS.iter().any(|kw| query_lower.contains(kw));
|
||||
let has_code = response.contains("```") || response.contains(" "); // code block or indented
|
||||
let code_check_passed = !expects_code || has_code;
|
||||
checks.push(EvalCheck {
|
||||
name: "code_presence",
|
||||
passed: code_check_passed,
|
||||
weight: 0.25,
|
||||
});
|
||||
|
||||
// Compute weighted score
|
||||
let total_weight: f64 = checks.iter().map(|c| c.weight).sum();
|
||||
let earned: f64 = checks.iter().filter(|c| c.passed).map(|c| c.weight).sum();
|
||||
let score = if total_weight > 0.0 {
|
||||
earned / total_weight
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
// Determine retry hint: if score is low, suggest escalating
|
||||
let retry_hint = if score <= default_min_quality_score() {
|
||||
// Try to escalate: Simple→Standard→Complex
|
||||
let next_tier = match complexity {
|
||||
ComplexityTier::Simple => Some(ComplexityTier::Standard),
|
||||
ComplexityTier::Standard => Some(ComplexityTier::Complex),
|
||||
ComplexityTier::Complex => None, // already at max
|
||||
};
|
||||
next_tier.and_then(|tier| {
|
||||
auto_classify
|
||||
.and_then(|ac| ac.hint_for(tier))
|
||||
.map(String::from)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EvalResult {
|
||||
score,
|
||||
checks,
|
||||
retry_hint,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // ── estimate_complexity ─────────────────────────────────────

    #[test]
    fn simple_short_message() {
        assert_eq!(estimate_complexity("hi"), ComplexityTier::Simple);
        assert_eq!(estimate_complexity("hello"), ComplexityTier::Simple);
        assert_eq!(estimate_complexity("yes"), ComplexityTier::Simple);
    }

    // Length alone (> 200) is enough to classify as Complex.
    #[test]
    fn complex_long_message() {
        let long = "a".repeat(201);
        assert_eq!(estimate_complexity(&long), ComplexityTier::Complex);
    }

    // A code fence forces Complex regardless of length.
    #[test]
    fn complex_code_fence() {
        let msg = "Here is some code:\n```rust\nfn main() {}\n```";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Complex);
    }

    // Two or more reasoning keywords force Complex.
    #[test]
    fn complex_multiple_reasoning_keywords() {
        let msg = "Please explain why this design is better and analyze the trade-off";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Complex);
    }

    #[test]
    fn standard_medium_message() {
        // 50+ chars but no code fence, < 2 reasoning keywords
        let msg = "Can you help me find a good restaurant in this area please?";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Standard);
    }

    #[test]
    fn standard_short_with_one_keyword() {
        // < 50 chars but has 1 reasoning keyword → still not Simple
        let msg = "explain this";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Standard);
    }

    // ── auto_classify ───────────────────────────────────────────

    #[test]
    fn auto_classify_maps_tiers_to_hints() {
        let ac = AutoClassifyConfig {
            simple_hint: Some("fast".into()),
            standard_hint: None,
            complex_hint: Some("reasoning".into()),
        };
        assert_eq!(ac.hint_for(ComplexityTier::Simple), Some("fast"));
        assert_eq!(ac.hint_for(ComplexityTier::Standard), None);
        assert_eq!(ac.hint_for(ComplexityTier::Complex), Some("reasoning"));
    }

    // ── evaluate_response ───────────────────────────────────────

    #[test]
    fn empty_response_scores_low() {
        let result = evaluate_response("hello", "", ComplexityTier::Simple, None);
        assert!(result.score <= 0.5, "empty response should score low");
    }

    #[test]
    fn good_response_scores_high() {
        let result = evaluate_response(
            "what is 2+2?",
            "The answer is 4.",
            ComplexityTier::Simple,
            None,
        );
        assert!(
            result.score >= 0.9,
            "good simple response should score high, got {}",
            result.score
        );
    }

    // A response opening with "I don't know" fails the cop-out check.
    #[test]
    fn cop_out_response_penalized() {
        let result = evaluate_response(
            "explain quantum computing",
            "I don't know much about that.",
            ComplexityTier::Standard,
            None,
        );
        assert!(
            result.score < 1.0,
            "cop-out should be penalized, got {}",
            result.score
        );
    }

    // A code-flavored query answered without any code block should fail
    // the code_presence check specifically.
    #[test]
    fn code_query_without_code_response_penalized() {
        let result = evaluate_response(
            "write a function to sort an array",
            "You should use a sorting algorithm.",
            ComplexityTier::Standard,
            None,
        );
        // "code_presence" check should fail
        let code_check = result.checks.iter().find(|c| c.name == "code_presence");
        assert!(
            code_check.is_some() && !code_check.unwrap().passed,
            "code check should fail"
        );
    }

    #[test]
    fn retry_hint_escalation() {
        let ac = AutoClassifyConfig {
            simple_hint: Some("fast".into()),
            standard_hint: Some("default".into()),
            complex_hint: Some("reasoning".into()),
        };
        // Empty response for a Simple query → should suggest Standard hint
        let result = evaluate_response("hello", "", ComplexityTier::Simple, Some(&ac));
        assert_eq!(result.retry_hint, Some("default".into()));
    }

    #[test]
    fn no_retry_when_already_complex() {
        let ac = AutoClassifyConfig {
            simple_hint: Some("fast".into()),
            standard_hint: Some("default".into()),
            complex_hint: Some("reasoning".into()),
        };
        // Empty response for Complex → no escalation possible
        let result =
            evaluate_response("explain everything", "", ComplexityTier::Complex, Some(&ac));
        assert_eq!(result.retry_hint, None);
    }

    // Default config is disabled with the documented thresholds.
    #[test]
    fn max_retries_defaults() {
        let config = EvalConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_retries, 1);
        assert!((config.min_quality_score - 0.5).abs() < f64::EPSILON);
    }
}
|
||||
283
src/agent/history_pruner.rs
Normal file
283
src/agent/history_pruner.rs
Normal file
@ -0,0 +1,283 @@
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Serde default for `HistoryPrunerConfig::max_tokens`.
fn default_max_tokens() -> usize {
    8192
}

/// Serde default for `HistoryPrunerConfig::keep_recent`.
fn default_keep_recent() -> usize {
    4
}

/// Serde default for `HistoryPrunerConfig::collapse_tool_results`.
fn default_collapse() -> bool {
    true
}
|
||||
|
||||
/// Configuration for token-budget-driven pruning of chat history.
///
/// System messages and the `keep_recent` most recent messages are always
/// preserved; everything else may be collapsed or dropped.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HistoryPrunerConfig {
    /// Enable history pruning. Default: false.
    #[serde(default)]
    pub enabled: bool,
    /// Maximum estimated tokens for message history. Default: 8192.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Keep the N most recent messages untouched. Default: 4.
    #[serde(default = "default_keep_recent")]
    pub keep_recent: usize,
    /// Collapse old tool call/result pairs into short summaries. Default: true.
    #[serde(default = "default_collapse")]
    pub collapse_tool_results: bool,
}
|
||||
|
||||
impl Default for HistoryPrunerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
max_tokens: 8192,
|
||||
keep_recent: 4,
|
||||
collapse_tool_results: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stats
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Summary of what a `prune_history` call did to the message list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PruneStats {
    /// Message count before pruning.
    pub messages_before: usize,
    /// Message count after pruning.
    pub messages_after: usize,
    /// Number of assistant + tool pairs collapsed into summaries.
    pub collapsed_pairs: usize,
    /// Number of messages removed to satisfy the token budget.
    pub dropped_messages: usize,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Token estimation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn estimate_tokens(messages: &[ChatMessage]) -> usize {
|
||||
messages.iter().map(|m| m.content.len() / 4).sum()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Protected-index helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn protected_indices(messages: &[ChatMessage], keep_recent: usize) -> Vec<bool> {
|
||||
let len = messages.len();
|
||||
let mut protected = vec![false; len];
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
if msg.role == "system" {
|
||||
protected[i] = true;
|
||||
}
|
||||
}
|
||||
let recent_start = len.saturating_sub(keep_recent);
|
||||
for p in protected.iter_mut().skip(recent_start) {
|
||||
*p = true;
|
||||
}
|
||||
protected
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub fn prune_history(messages: &mut Vec<ChatMessage>, config: &HistoryPrunerConfig) -> PruneStats {
|
||||
let messages_before = messages.len();
|
||||
if !config.enabled || messages.is_empty() {
|
||||
return PruneStats {
|
||||
messages_before,
|
||||
messages_after: messages_before,
|
||||
collapsed_pairs: 0,
|
||||
dropped_messages: 0,
|
||||
};
|
||||
}
|
||||
|
||||
let mut collapsed_pairs: usize = 0;
|
||||
|
||||
// Phase 1 – collapse assistant+tool pairs
|
||||
if config.collapse_tool_results {
|
||||
let mut i = 0;
|
||||
while i + 1 < messages.len() {
|
||||
let protected = protected_indices(messages, config.keep_recent);
|
||||
if messages[i].role == "assistant"
|
||||
&& messages[i + 1].role == "tool"
|
||||
&& !protected[i]
|
||||
&& !protected[i + 1]
|
||||
{
|
||||
let tool_content = &messages[i + 1].content;
|
||||
let truncated: String = tool_content.chars().take(100).collect();
|
||||
let summary = format!("[Tool result: {truncated}...]");
|
||||
messages[i] = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: summary,
|
||||
};
|
||||
messages.remove(i + 1);
|
||||
collapsed_pairs += 1;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2 – budget enforcement
|
||||
let mut dropped_messages: usize = 0;
|
||||
while estimate_tokens(messages) > config.max_tokens {
|
||||
let protected = protected_indices(messages, config.keep_recent);
|
||||
if let Some(idx) = protected
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &p)| !p)
|
||||
.map(|(i, _)| i)
|
||||
{
|
||||
messages.remove(idx);
|
||||
dropped_messages += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
PruneStats {
|
||||
messages_before,
|
||||
messages_after: messages.len(),
|
||||
collapsed_pairs,
|
||||
dropped_messages,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `ChatMessage` fixture with the given role and content.
    fn msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    // With `enabled: false` the pruner must leave messages untouched and
    // report zero work.
    #[test]
    fn prune_disabled_is_noop() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi there!"),
        ];
        let config = HistoryPrunerConfig {
            enabled: false,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "You are helpful.");
        assert_eq!(stats.messages_before, 3);
        assert_eq!(stats.messages_after, 3);
        assert_eq!(stats.collapsed_pairs, 0);
    }

    // Under the token budget with collapsing off: nothing to do.
    #[test]
    fn prune_under_budget_no_change() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi!"),
        ];
        let config = HistoryPrunerConfig {
            enabled: true,
            max_tokens: 8192,
            keep_recent: 2,
            collapse_tool_results: false,
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(messages.len(), 3);
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(stats.dropped_messages, 0);
    }

    // An unprotected assistant+tool pair collapses into one summary message.
    #[test]
    fn prune_collapses_tool_pairs() {
        let tool_result = "a".repeat(160);
        let mut messages = vec![
            msg("system", "sys"),
            msg("assistant", "calling tool X"),
            msg("tool", &tool_result),
            msg("user", "thanks"),
            msg("assistant", "done"),
        ];
        let config = HistoryPrunerConfig {
            enabled: true,
            max_tokens: 100_000,
            keep_recent: 2,
            collapse_tool_results: true,
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(stats.collapsed_pairs, 1);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[1].role, "assistant");
        assert!(messages[1].content.starts_with("[Tool result: "));
    }

    // Over-budget pruning must never touch system messages or the
    // keep_recent newest messages.
    #[test]
    fn prune_preserves_system_and_recent() {
        let big = "x".repeat(40_000);
        let mut messages = vec![
            msg("system", "system prompt"),
            msg("user", &big),
            msg("assistant", "old reply"),
            msg("user", "recent1"),
            msg("assistant", "recent2"),
        ];
        let config = HistoryPrunerConfig {
            enabled: true,
            max_tokens: 100,
            keep_recent: 2,
            collapse_tool_results: false,
        };
        let stats = prune_history(&mut messages, &config);
        assert!(messages.iter().any(|m| m.role == "system"));
        assert!(messages.iter().any(|m| m.content == "recent1"));
        assert!(messages.iter().any(|m| m.content == "recent2"));
        assert!(stats.dropped_messages > 0);
    }

    // Budget enforcement drops the oldest unprotected messages first.
    #[test]
    fn prune_drops_oldest_when_over_budget() {
        let filler = "y".repeat(400);
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", &filler),
            msg("assistant", &filler),
            msg("user", "recent-user"),
            msg("assistant", "recent-assistant"),
        ];
        let config = HistoryPrunerConfig {
            enabled: true,
            max_tokens: 150,
            keep_recent: 2,
            collapse_tool_results: false,
        };
        let stats = prune_history(&mut messages, &config);
        assert!(stats.dropped_messages >= 1);
        assert_eq!(messages[0].role, "system");
        assert!(messages.iter().any(|m| m.content == "recent-user"));
        assert!(messages.iter().any(|m| m.content == "recent-assistant"));
    }

    // Empty input is handled without panicking.
    #[test]
    fn prune_empty_messages() {
        let mut messages: Vec<ChatMessage> = vec![];
        let config = HistoryPrunerConfig {
            enabled: true,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(stats.messages_before, 0);
        assert_eq!(stats.messages_after, 0);
    }
}
|
||||
@ -3622,23 +3622,28 @@ pub async fn run(
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (mut tools_registry, delegate_handle, _reaction_handle, _channel_map_handle) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
mem.clone(),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
let (
|
||||
mut tools_registry,
|
||||
delegate_handle,
|
||||
_reaction_handle,
|
||||
_channel_map_handle,
|
||||
_ask_user_handle,
|
||||
) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
mem.clone(),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
|
||||
let peripheral_tools: Vec<Box<dyn Tool>> =
|
||||
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
|
||||
@ -4039,6 +4044,14 @@ pub async fn run(
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
|
||||
// Prune history for token efficiency (when enabled).
|
||||
if config.agent.history_pruning.enabled {
|
||||
let _stats = crate::agent::history_pruner::prune_history(
|
||||
&mut history,
|
||||
&config.agent.history_pruning,
|
||||
);
|
||||
}
|
||||
|
||||
// Compute per-turn excluded MCP tools from tool_filter_groups.
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
@ -4460,23 +4473,28 @@ pub async fn process_message(
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (mut tools_registry, delegate_handle_pm, _reaction_handle_pm, _channel_map_handle_pm) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
mem.clone(),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
let (
|
||||
mut tools_registry,
|
||||
delegate_handle_pm,
|
||||
_reaction_handle_pm,
|
||||
_channel_map_handle_pm,
|
||||
_ask_user_handle_pm,
|
||||
) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
mem.clone(),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
let peripheral_tools: Vec<Box<dyn Tool>> =
|
||||
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
|
||||
tools_registry.extend(peripheral_tools);
|
||||
@ -8214,6 +8232,7 @@ Let me check the result."#;
|
||||
mode: ToolFilterGroupMode::Always,
|
||||
tools: vec!["mcp_filesystem_*".into()],
|
||||
keywords: vec![],
|
||||
filter_builtins: false,
|
||||
}];
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "anything");
|
||||
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
|
||||
@ -8232,6 +8251,7 @@ Let me check the result."#;
|
||||
mode: ToolFilterGroupMode::Dynamic,
|
||||
tools: vec!["mcp_browser_*".into()],
|
||||
keywords: vec!["browse".into(), "website".into()],
|
||||
filter_builtins: false,
|
||||
}];
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "please browse this page");
|
||||
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
|
||||
@ -8248,6 +8268,7 @@ Let me check the result."#;
|
||||
mode: ToolFilterGroupMode::Dynamic,
|
||||
tools: vec!["mcp_browser_*".into()],
|
||||
keywords: vec!["browse".into(), "website".into()],
|
||||
filter_builtins: false,
|
||||
}];
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "read the file /etc/hosts");
|
||||
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
|
||||
@ -8264,6 +8285,7 @@ Let me check the result."#;
|
||||
mode: ToolFilterGroupMode::Dynamic,
|
||||
tools: vec!["mcp_browser_*".into()],
|
||||
keywords: vec!["Browse".into()],
|
||||
filter_builtins: false,
|
||||
}];
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "BROWSE the site");
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
#[allow(clippy::module_inception)]
|
||||
pub mod agent;
|
||||
pub mod classifier;
|
||||
pub mod context_analyzer;
|
||||
pub mod dispatcher;
|
||||
pub mod eval;
|
||||
pub mod history_pruner;
|
||||
pub mod loop_;
|
||||
pub mod loop_detector;
|
||||
pub mod memory_loader;
|
||||
|
||||
@ -162,7 +162,12 @@ impl Channel for DingTalkChannel {
|
||||
let ws_url = format!("{}?ticket={}", gw.endpoint, gw.ticket);
|
||||
|
||||
tracing::info!("DingTalk: connecting to stream WebSocket...");
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (ws_stream, _) = crate::config::ws_connect_with_proxy(
|
||||
&ws_url,
|
||||
"channel.dingtalk",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
tracing::info!("DingTalk: connected and listening for messages...");
|
||||
|
||||
@ -23,6 +23,7 @@ pub struct DiscordChannel {
|
||||
/// Voice transcription config — when set, audio attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<std::sync::Arc<super::transcription::TranscriptionManager>>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@ -42,6 +43,7 @@ impl DiscordChannel {
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
transcription_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,8 +55,19 @@ impl DiscordChannel {
|
||||
|
||||
/// Configure voice transcription for audio attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
if !config.enabled {
|
||||
return self;
|
||||
}
|
||||
match super::transcription::TranscriptionManager::new(&config) {
|
||||
Ok(m) => {
|
||||
self.transcription_manager = Some(std::sync::Arc::new(m));
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"transcription manager init failed, voice transcription disabled: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
@ -148,7 +161,7 @@ fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
|
||||
async fn transcribe_discord_audio_attachments(
|
||||
attachments: &[serde_json::Value],
|
||||
client: &reqwest::Client,
|
||||
config: &crate::config::TranscriptionConfig,
|
||||
manager: &super::transcription::TranscriptionManager,
|
||||
) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
for att in attachments {
|
||||
@ -187,7 +200,7 @@ async fn transcribe_discord_audio_attachments(
|
||||
}
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, name, config).await {
|
||||
match manager.transcribe(&audio_data, name).await {
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
@ -662,7 +675,12 @@ impl Channel for DiscordChannel {
|
||||
let ws_url = format!("{gw_url}/?v=10&encoding=json");
|
||||
tracing::info!("Discord: connecting to gateway...");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (ws_stream, _) = crate::config::ws_connect_with_proxy(
|
||||
&ws_url,
|
||||
"channel.discord",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
@ -835,11 +853,11 @@ impl Channel for DiscordChannel {
|
||||
let mut text_parts = process_attachments(&atts, &client).await;
|
||||
|
||||
// Transcribe audio attachments when transcription is configured
|
||||
if let Some(ref transcription_config) = self.transcription {
|
||||
if let Some(ref transcription_manager) = self.transcription_manager {
|
||||
let voice_text = transcribe_discord_audio_attachments(
|
||||
&atts,
|
||||
&client,
|
||||
transcription_config,
|
||||
transcription_manager,
|
||||
)
|
||||
.await;
|
||||
if !voice_text.is_empty() {
|
||||
|
||||
@ -240,7 +240,12 @@ impl Channel for DiscordHistoryChannel {
|
||||
let ws_url = format!("{gw_url}/?v=10&encoding=json");
|
||||
tracing::info!("DiscordHistory: connecting to gateway...");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (ws_stream, _) = crate::config::ws_connect_with_proxy(
|
||||
&ws_url,
|
||||
"channel.discord",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
|
||||
@ -734,7 +734,12 @@ impl LarkChannel {
|
||||
.unwrap_or(0);
|
||||
tracing::info!("Lark: connecting to {wss_url}");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
|
||||
let (ws_stream, _) = crate::config::ws_connect_with_proxy(
|
||||
&wss_url,
|
||||
"channel.lark",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
tracing::info!("Lark: WS connected (service_id={service_id})");
|
||||
|
||||
@ -1932,6 +1937,7 @@ fn pick_uniform_index(len: usize) -> usize {
|
||||
loop {
|
||||
let value = rand::random::<u64>();
|
||||
if value < reject_threshold {
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
return (value % upper) as usize;
|
||||
}
|
||||
}
|
||||
@ -3335,7 +3341,7 @@ mod tests {
|
||||
let tc = crate::config::TranscriptionConfig {
|
||||
enabled: true,
|
||||
default_provider: "groq".to_string(),
|
||||
api_key: Some("".to_string()),
|
||||
api_key: Some(String::new()),
|
||||
..Default::default()
|
||||
};
|
||||
let ch = make_channel().with_transcription(tc);
|
||||
@ -3461,10 +3467,15 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn lark_audio_file_key_missing_returns_none() {
|
||||
let ch = make_channel();
|
||||
let tc = crate::config::TranscriptionConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
};
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
tc.default_provider = "local_whisper".to_string();
|
||||
tc.local_whisper = Some(crate::config::LocalWhisperConfig {
|
||||
url: "http://localhost:0/v1/transcribe".to_string(),
|
||||
bearer_token: "unused".to_string(),
|
||||
max_audio_bytes: 10 * 1024 * 1024,
|
||||
timeout_secs: 30,
|
||||
});
|
||||
let ch = ch.with_transcription(tc);
|
||||
let manager = ch.transcription_manager.as_deref().unwrap();
|
||||
|
||||
@ -3540,6 +3551,7 @@ mod tests {
|
||||
.await;
|
||||
|
||||
let mut config = crate::config::TranscriptionConfig::default();
|
||||
config.enabled = true;
|
||||
config.local_whisper = Some(crate::config::LocalWhisperConfig {
|
||||
url: format!("{}/v1/transcribe", whisper_server.uri()),
|
||||
bearer_token: "test-token".to_string(),
|
||||
@ -3553,6 +3565,9 @@ mod tests {
|
||||
let ch = ch.with_transcription(config);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": { "open_id": "ou_testuser123" }
|
||||
@ -3596,7 +3611,7 @@ mod tests {
|
||||
Mock::given(method("GET"))
|
||||
.and(path_regex("/im/v1/messages/.+/resources/.+"))
|
||||
.respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
|
||||
"code": 99991663,
|
||||
"code": 99_991_663,
|
||||
"msg": "token invalid"
|
||||
})))
|
||||
.up_to_n_times(1)
|
||||
|
||||
@ -42,6 +42,8 @@ pub struct MatrixChannel {
|
||||
http_client: Client,
|
||||
reaction_events: Arc<RwLock<HashMap<String, String>>>,
|
||||
voice_mode: Arc<AtomicBool>,
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<Arc<super::transcription::TranscriptionManager>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MatrixChannel {
|
||||
@ -215,9 +217,30 @@ impl MatrixChannel {
|
||||
http_client: Client::new(),
|
||||
reaction_events: Arc::new(RwLock::new(HashMap::new())),
|
||||
voice_mode: Arc::new(AtomicBool::new(false)),
|
||||
transcription: None,
|
||||
transcription_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio messages.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if !config.enabled {
|
||||
return self;
|
||||
}
|
||||
match super::transcription::TranscriptionManager::new(&config) {
|
||||
Ok(m) => {
|
||||
self.transcription_manager = Some(Arc::new(m));
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"transcription manager init failed, voice transcription disabled: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn encode_path_segment(value: &str) -> String {
|
||||
fn should_encode(byte: u8) -> bool {
|
||||
!matches!(
|
||||
@ -763,6 +786,7 @@ impl Channel for MatrixChannel {
|
||||
let homeserver_for_handler = self.homeserver.clone();
|
||||
let access_token_for_handler = self.access_token.clone();
|
||||
let voice_mode_for_handler = Arc::clone(&self.voice_mode);
|
||||
let transcription_mgr_for_handler = self.transcription_manager.clone();
|
||||
|
||||
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||
let tx = tx_handler.clone();
|
||||
@ -774,6 +798,7 @@ impl Channel for MatrixChannel {
|
||||
let homeserver = homeserver_for_handler.clone();
|
||||
let access_token = access_token_for_handler.clone();
|
||||
let voice_mode = Arc::clone(&voice_mode_for_handler);
|
||||
let transcription_mgr = transcription_mgr_for_handler.clone();
|
||||
|
||||
async move {
|
||||
if !MatrixChannel::room_matches_target(
|
||||
@ -875,51 +900,36 @@ impl Channel for MatrixChannel {
|
||||
|
||||
// Voice transcription: if this was an audio message, transcribe it
|
||||
let body = if body.starts_with("[audio:") {
|
||||
if let Some(path_start) = body.find("saved to ") {
|
||||
if let (Some(path_start), Some(ref manager)) = (body.find("saved to "), &transcription_mgr) {
|
||||
let audio_path = body[path_start + 9..].to_string();
|
||||
let wav_path = format!("{}.16k.wav", audio_path);
|
||||
let convert_ok = tokio::process::Command::new("ffmpeg")
|
||||
.args([
|
||||
"-y",
|
||||
"-i",
|
||||
&audio_path,
|
||||
"-ar",
|
||||
"16000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-f",
|
||||
"wav",
|
||||
&wav_path,
|
||||
])
|
||||
.stderr(std::process::Stdio::null())
|
||||
.output()
|
||||
.await
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false);
|
||||
if convert_ok {
|
||||
let transcription = tokio::process::Command::new("whisper-cpp")
|
||||
.args([
|
||||
"-m",
|
||||
"/tmp/ggml-base.en.bin",
|
||||
"-f",
|
||||
&wav_path,
|
||||
"--no-timestamps",
|
||||
"-nt",
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.ok()
|
||||
.filter(|o| o.status.success())
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||
.filter(|s| !s.is_empty());
|
||||
if let Some(text) = transcription {
|
||||
voice_mode.store(true, Ordering::Relaxed);
|
||||
format!("[Voice message]: {}", text)
|
||||
} else {
|
||||
let file_name = audio_path
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("audio.ogg")
|
||||
.to_string();
|
||||
match tokio::fs::read(&audio_path).await {
|
||||
Ok(audio_data) => {
|
||||
match manager.transcribe(&audio_data, &file_name).await {
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
tracing::info!("Matrix: voice transcription returned empty text, skipping");
|
||||
body
|
||||
} else {
|
||||
voice_mode.store(true, Ordering::Relaxed);
|
||||
format!("[Voice message]: {}", trimmed)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Matrix: voice transcription failed: {e}");
|
||||
body
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Matrix: failed to read audio file {}: {e}", audio_path);
|
||||
body
|
||||
}
|
||||
} else {
|
||||
body
|
||||
}
|
||||
} else {
|
||||
body
|
||||
@ -1226,6 +1236,31 @@ impl Channel for MatrixChannel {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn redact_message(
|
||||
&self,
|
||||
_channel_id: &str,
|
||||
message_id: &str,
|
||||
reason: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let client = self
|
||||
.sdk_client
|
||||
.get()
|
||||
.ok_or_else(|| anyhow::anyhow!("Matrix SDK client not initialized"))?;
|
||||
|
||||
let target_room_id = self.target_room_id().await?;
|
||||
let target_room: OwnedRoomId = target_room_id.parse()?;
|
||||
let room = client
|
||||
.get_room(&target_room)
|
||||
.ok_or_else(|| anyhow::anyhow!("Matrix room not found for message redaction"))?;
|
||||
|
||||
let event_id: OwnedEventId = message_id
|
||||
.parse()
|
||||
.map_err(|_| anyhow::anyhow!("Invalid event ID: {}", message_id))?;
|
||||
|
||||
room.redact(&event_id, reason.as_deref(), None).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@ -1212,6 +1212,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
},
|
||||
);
|
||||
assert!(ch.transcription_manager.is_some());
|
||||
@ -1234,6 +1235,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
},
|
||||
);
|
||||
assert!(ch.transcription_manager.is_none());
|
||||
@ -1376,6 +1378,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
},
|
||||
);
|
||||
|
||||
@ -1448,6 +1451,7 @@ mod tests {
|
||||
max_audio_bytes: 25_000_000,
|
||||
timeout_secs: 300,
|
||||
}),
|
||||
transcribe_non_ptt_audio: false,
|
||||
});
|
||||
|
||||
let post = json!({
|
||||
@ -1497,6 +1501,7 @@ mod tests {
|
||||
max_audio_bytes: 25_000_000,
|
||||
timeout_secs: 300,
|
||||
}),
|
||||
transcribe_non_ptt_audio: false,
|
||||
});
|
||||
|
||||
let post = json!({
|
||||
|
||||
@ -266,6 +266,7 @@ enum ChannelRuntimeCommand {
|
||||
SetProvider(String),
|
||||
ShowModel,
|
||||
SetModel(String),
|
||||
ShowConfig,
|
||||
NewSession,
|
||||
}
|
||||
|
||||
@ -720,8 +721,30 @@ fn strip_tool_result_content(text: &str) -> String {
|
||||
cleaned.to_string()
|
||||
}
|
||||
|
||||
/// Remove a leading `[Used tools: ...]` line from a cached assistant turn.
|
||||
///
|
||||
/// The tool-context summary is prepended to history entries so the LLM retains
|
||||
/// awareness of prior tool usage. However, when these entries are loaded back
|
||||
/// into the LLM context, the bracket-format leaks into generated output and
|
||||
/// gets forwarded to end users as-is (bug #4400). Stripping the prefix on
|
||||
/// reload prevents the model from learning and reproducing this internal format.
|
||||
fn strip_tool_summary_prefix(text: &str) -> String {
|
||||
if let Some(rest) = text.strip_prefix("[Used tools:") {
|
||||
// Find the closing bracket, then skip it and any leading newline(s).
|
||||
if let Some(bracket_end) = rest.find(']') {
|
||||
let after_bracket = &rest[bracket_end + 1..];
|
||||
let trimmed = after_bracket.trim_start_matches('\n');
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
return trimmed.to_string();
|
||||
}
|
||||
}
|
||||
text.to_string()
|
||||
}
|
||||
|
||||
fn supports_runtime_model_switch(channel_name: &str) -> bool {
|
||||
matches!(channel_name, "telegram" | "discord" | "matrix")
|
||||
matches!(channel_name, "telegram" | "discord" | "matrix" | "slack")
|
||||
}
|
||||
|
||||
fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
|
||||
@ -759,6 +782,9 @@ fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRun
|
||||
Some(ChannelRuntimeCommand::SetModel(model))
|
||||
}
|
||||
}
|
||||
"/config" if supports_runtime_model_switch(channel_name) => {
|
||||
Some(ChannelRuntimeCommand::ShowConfig)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -1436,6 +1462,171 @@ fn build_providers_help_response(current: &ChannelRouteSelection) -> String {
|
||||
response
|
||||
}
|
||||
|
||||
/// Build a plain-text `/config` response for non-Slack channels.
|
||||
fn build_config_text_response(
|
||||
current: &ChannelRouteSelection,
|
||||
_workspace_dir: &Path,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
) -> String {
|
||||
let mut resp = String::new();
|
||||
let _ = writeln!(
|
||||
resp,
|
||||
"Current provider: `{}`\nCurrent model: `{}`",
|
||||
current.provider, current.model
|
||||
);
|
||||
resp.push_str("\nAvailable providers:\n");
|
||||
for p in providers::list_providers() {
|
||||
let _ = writeln!(resp, "- `{}`", p.name);
|
||||
}
|
||||
if !model_routes.is_empty() {
|
||||
resp.push_str("\nConfigured model routes:\n");
|
||||
for route in model_routes {
|
||||
let _ = writeln!(
|
||||
resp,
|
||||
" `{}` -> {} ({})",
|
||||
route.hint, route.model, route.provider
|
||||
);
|
||||
}
|
||||
}
|
||||
resp.push_str(
|
||||
"\nUse `/models <provider>` to switch provider.\nUse `/model <model-id>` to switch model.",
|
||||
);
|
||||
resp
|
||||
}
|
||||
|
||||
/// Prefix used to signal that a runtime command response contains raw Block Kit
|
||||
/// JSON instead of plain text. [`SlackChannel::send`] detects this and posts
|
||||
/// the blocks directly via `chat.postMessage`.
|
||||
const BLOCK_KIT_PREFIX: &str = "__ZEROCLAW_BLOCK_KIT__";
|
||||
|
||||
/// Build a Slack Block Kit JSON payload for the `/config` interactive UI.
|
||||
fn build_config_block_kit(
|
||||
current: &ChannelRouteSelection,
|
||||
workspace_dir: &Path,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
) -> String {
|
||||
let provider_options: Vec<serde_json::Value> = providers::list_providers()
|
||||
.iter()
|
||||
.map(|p| {
|
||||
serde_json::json!({
|
||||
"text": { "type": "plain_text", "text": p.display_name },
|
||||
"value": p.name
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build model options from model_routes + cached models.
|
||||
let mut model_options: Vec<serde_json::Value> = model_routes
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let label = if r.hint.is_empty() {
|
||||
r.model.clone()
|
||||
} else {
|
||||
format!("{} ({})", r.model, r.hint)
|
||||
};
|
||||
serde_json::json!({
|
||||
"text": { "type": "plain_text", "text": label },
|
||||
"value": r.model
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let cached = load_cached_model_preview(workspace_dir, ¤t.provider);
|
||||
for model_id in cached {
|
||||
if !model_options.iter().any(|o| {
|
||||
o.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.is_some_and(|v| v == model_id)
|
||||
}) {
|
||||
model_options.push(serde_json::json!({
|
||||
"text": { "type": "plain_text", "text": model_id },
|
||||
"value": model_id
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// If the current model is not in the list, prepend it.
|
||||
if !model_options.iter().any(|o| {
|
||||
o.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.is_some_and(|v| v == current.model)
|
||||
}) {
|
||||
model_options.insert(
|
||||
0,
|
||||
serde_json::json!({
|
||||
"text": { "type": "plain_text", "text": ¤t.model },
|
||||
"value": ¤t.model
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Find initial options matching current selection.
|
||||
let initial_provider = provider_options
|
||||
.iter()
|
||||
.find(|o| {
|
||||
o.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.is_some_and(|v| v == current.provider)
|
||||
})
|
||||
.cloned();
|
||||
|
||||
let initial_model = model_options
|
||||
.iter()
|
||||
.find(|o| {
|
||||
o.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.is_some_and(|v| v == current.model)
|
||||
})
|
||||
.cloned();
|
||||
|
||||
let mut provider_select = serde_json::json!({
|
||||
"type": "static_select",
|
||||
"action_id": "zeroclaw_config_provider",
|
||||
"placeholder": { "type": "plain_text", "text": "Select provider" },
|
||||
"options": provider_options
|
||||
});
|
||||
if let Some(init) = initial_provider {
|
||||
provider_select["initial_option"] = init;
|
||||
}
|
||||
|
||||
let mut model_select = serde_json::json!({
|
||||
"type": "static_select",
|
||||
"action_id": "zeroclaw_config_model",
|
||||
"placeholder": { "type": "plain_text", "text": "Select model" },
|
||||
"options": model_options
|
||||
});
|
||||
if let Some(init) = initial_model {
|
||||
model_select["initial_option"] = init;
|
||||
}
|
||||
|
||||
let blocks = serde_json::json!([
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": format!(
|
||||
"*Model Configuration*\nCurrent: `{}` / `{}`",
|
||||
current.provider, current.model
|
||||
)
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"block_id": "config_provider_block",
|
||||
"text": { "type": "mrkdwn", "text": "*Provider*" },
|
||||
"accessory": provider_select
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"block_id": "config_model_block",
|
||||
"text": { "type": "mrkdwn", "text": "*Model*" },
|
||||
"accessory": model_select
|
||||
}
|
||||
]);
|
||||
|
||||
blocks.to_string()
|
||||
}
|
||||
|
||||
async fn handle_runtime_command_if_needed(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
msg: &traits::ChannelMessage,
|
||||
@ -1508,6 +1699,19 @@ async fn handle_runtime_command_if_needed(
|
||||
)
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::ShowConfig => {
|
||||
if msg.channel == "slack" {
|
||||
let blocks_json = build_config_block_kit(
|
||||
¤t,
|
||||
ctx.workspace_dir.as_path(),
|
||||
&ctx.model_routes,
|
||||
);
|
||||
// Use a magic prefix so SlackChannel::send() can detect Block Kit JSON.
|
||||
format!("__ZEROCLAW_BLOCK_KIT__{blocks_json}")
|
||||
} else {
|
||||
build_config_text_response(¤t, ctx.workspace_dir.as_path(), &ctx.model_routes)
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::NewSession => {
|
||||
clear_sender_history(ctx, &sender_key);
|
||||
if let Some(ref store) = ctx.session_store {
|
||||
@ -1590,6 +1794,7 @@ async fn build_memory_context(
|
||||
/// during `run_tool_call_loop`. Scans assistant messages for `<tool_call>` tags
|
||||
/// or native tool-call JSON to collect tool names used.
|
||||
/// Returns an empty string when no tools were invoked.
|
||||
#[cfg(test)]
|
||||
fn extract_tool_context_summary(history: &[ChatMessage], start_index: usize) -> String {
|
||||
fn push_unique_tool_name(tool_names: &mut Vec<String>, name: &str) {
|
||||
let candidate = name.trim();
|
||||
@ -1685,12 +1890,27 @@ fn sanitize_channel_response(response: &str, tools: &[Box<dyn Tool>]) -> String
|
||||
.iter()
|
||||
.map(|tool| tool.name().to_ascii_lowercase())
|
||||
.collect();
|
||||
// Strip any [Used tools: ...] prefix that the LLM may have echoed from
|
||||
// history context (#4400).
|
||||
let stripped_summary = strip_tool_summary_prefix(response);
|
||||
// Strip XML-style tool-call tags (e.g. <tool_call>...</tool_call>)
|
||||
let stripped_xml = strip_tool_call_tags(response);
|
||||
let stripped_xml = strip_tool_call_tags(&stripped_summary);
|
||||
// Strip isolated tool-call JSON artifacts
|
||||
let stripped_json = strip_isolated_tool_json_artifacts(&stripped_xml, &known_tool_names);
|
||||
// Strip leading narration lines that announce tool usage
|
||||
strip_tool_narration(&stripped_json)
|
||||
let sanitized = strip_tool_narration(&stripped_json);
|
||||
|
||||
// Scan for credential leaks before returning to caller
|
||||
match crate::security::LeakDetector::new().scan(&sanitized) {
|
||||
crate::security::LeakResult::Clean => sanitized,
|
||||
crate::security::LeakResult::Detected { patterns, redacted } => {
|
||||
tracing::warn!(
|
||||
patterns = ?patterns,
|
||||
"output guardrail: credential leak detected in outbound channel response"
|
||||
);
|
||||
redacted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove leading lines that narrate tool usage (e.g. "Let me check the weather for you.").
|
||||
@ -2230,6 +2450,14 @@ async fn process_channel_message(
|
||||
}
|
||||
}
|
||||
|
||||
// Strip [Used tools: ...] prefixes from cached assistant turns so the
|
||||
// LLM never sees (and reproduces) this internal summary format (#4400).
|
||||
for turn in &mut prior_turns {
|
||||
if turn.role == "assistant" && turn.content.starts_with("[Used tools:") {
|
||||
turn.content = strip_tool_summary_prefix(&turn.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Strip [IMAGE:] markers from *older* history messages when the active
|
||||
// provider does not support vision. This prevents "history poisoning"
|
||||
// where a previously-sent image marker gets reloaded from the JSONL
|
||||
@ -2446,9 +2674,6 @@ async fn process_channel_message(
|
||||
}))
|
||||
};
|
||||
|
||||
// Record history length before tool loop so we can extract tool context after.
|
||||
let history_len_before_tools = history.len();
|
||||
|
||||
enum LlmExecutionResult {
|
||||
Completed(Result<Result<String, anyhow::Error>, tokio::time::error::Elapsed>),
|
||||
Cancelled,
|
||||
@ -2704,15 +2929,15 @@ async fn process_channel_message(
|
||||
}),
|
||||
);
|
||||
|
||||
// Extract condensed tool-use context from the history messages
|
||||
// added during run_tool_call_loop, so the LLM retains awareness
|
||||
// of what it did on subsequent turns.
|
||||
let tool_summary = extract_tool_context_summary(&history, history_len_before_tools);
|
||||
let history_response = if tool_summary.is_empty() || msg.channel == "telegram" {
|
||||
delivered_response.clone()
|
||||
} else {
|
||||
format!("{tool_summary}\n{delivered_response}")
|
||||
};
|
||||
// Previously we prepended a `[Used tools: …]` summary to the
|
||||
// history entry so the LLM retained awareness of prior tool usage.
|
||||
// This caused the model to learn and reproduce the bracket format
|
||||
// in its own output, which leaked to end-users as raw log lines
|
||||
// instead of meaningful responses (#4400). The LLM already
|
||||
// receives tool context through the tool-call/result messages in
|
||||
// the conversation history built by `run_tool_call_loop`, so the
|
||||
// extra summary prefix is unnecessary.
|
||||
let history_response = delivered_response.clone();
|
||||
|
||||
append_sender_turn(
|
||||
ctx.as_ref(),
|
||||
@ -3897,16 +4122,19 @@ fn collect_configured_channels(
|
||||
if let Some(ref mx) = config.channels_config.matrix {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Matrix",
|
||||
channel: Arc::new(MatrixChannel::new_full(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
mx.allowed_rooms.clone(),
|
||||
mx.user_id.clone(),
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
MatrixChannel::new_full(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
mx.allowed_rooms.clone(),
|
||||
mx.user_id.clone(),
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
)
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@ -4400,23 +4628,28 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
};
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let (mut built_tools, delegate_handle_ch, reaction_handle_ch, _channel_map_handle) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&workspace,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
let (
|
||||
mut built_tools,
|
||||
delegate_handle_ch,
|
||||
reaction_handle_ch,
|
||||
_channel_map_handle,
|
||||
ask_user_handle_ch,
|
||||
) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&workspace,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
|
||||
@ -4699,6 +4932,14 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the ask_user tool's channel map now that channels are initialized.
|
||||
if let Some(ref handle) = ask_user_handle_ch {
|
||||
let mut map = handle.write();
|
||||
for (name, ch) in channels_by_name.as_ref() {
|
||||
map.insert(name.clone(), Arc::clone(ch));
|
||||
}
|
||||
}
|
||||
|
||||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||||
|
||||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||||
@ -5024,6 +5265,36 @@ mod tests {
|
||||
assert_eq!(strip_tool_result_content(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_tool_summary_prefix_removes_prefix_and_preserves_content() {
|
||||
let input = "[Used tools: browser_open, shell]\nI opened the page successfully.";
|
||||
assert_eq!(
|
||||
strip_tool_summary_prefix(input),
|
||||
"I opened the page successfully."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_tool_summary_prefix_returns_empty_when_only_prefix() {
|
||||
let input = "[Used tools: browser_open]";
|
||||
assert_eq!(strip_tool_summary_prefix(input), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_tool_summary_prefix_preserves_text_without_prefix() {
|
||||
let input = "Here is the result of the search.";
|
||||
assert_eq!(strip_tool_summary_prefix(input), input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_tool_summary_prefix_handles_multiple_newlines() {
|
||||
let input = "[Used tools: shell]\n\nThe command output is 42.";
|
||||
assert_eq!(
|
||||
strip_tool_summary_prefix(input),
|
||||
"The command output is 42."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_cached_channel_turns_merges_consecutive_user_turns() {
|
||||
let turns = vec![
|
||||
@ -10303,4 +10574,25 @@ This is an example JSON object for profile settings."#;
|
||||
"both Slack thread messages should complete, got: {sent_messages:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_channel_response_redacts_detected_credentials() {
|
||||
let tools: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let leaked = "Temporary key: AKIAABCDEFGHIJKLMNOP"; // gitleaks:allow
|
||||
|
||||
let result = sanitize_channel_response(leaked, &tools);
|
||||
|
||||
assert!(!result.contains("AKIAABCDEFGHIJKLMNOP")); // gitleaks:allow
|
||||
assert!(result.contains("[REDACTED"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_channel_response_passes_clean_text() {
|
||||
let tools: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let clean_text = "This is a normal message with no credentials.";
|
||||
|
||||
let result = sanitize_channel_response(clean_text, &tools);
|
||||
|
||||
assert_eq!(result, clean_text);
|
||||
}
|
||||
}
|
||||
|
||||
@ -976,7 +976,9 @@ impl Channel for QQChannel {
|
||||
let gw_url = self.get_gateway_url(&token).await?;
|
||||
|
||||
tracing::info!("QQ: connecting to gateway WebSocket...");
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&gw_url).await?;
|
||||
let (ws_stream, _) =
|
||||
crate::config::ws_connect_with_proxy(&gw_url, "channel.qq", self.proxy_url.as_deref())
|
||||
.await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
|
||||
@ -37,6 +37,7 @@ pub struct SlackChannel {
|
||||
/// Voice transcription config — when set, audio file attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<std::sync::Arc<super::transcription::TranscriptionManager>>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@ -129,6 +130,7 @@ impl SlackChannel {
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
transcription_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -164,8 +166,19 @@ impl SlackChannel {
|
||||
|
||||
/// Configure voice transcription for audio file attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
if !config.enabled {
|
||||
return self;
|
||||
}
|
||||
match super::transcription::TranscriptionManager::new(&config) {
|
||||
Ok(m) => {
|
||||
self.transcription_manager = Some(std::sync::Arc::new(m));
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"transcription manager init failed, voice transcription disabled: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
@ -179,6 +192,94 @@ impl SlackChannel {
|
||||
)
|
||||
}
|
||||
|
||||
/// Post a new Slack message and return the message timestamp (`ts`).
|
||||
///
|
||||
/// This is a lower-level helper that exposes the `ts` value needed for
|
||||
/// subsequent `chat.update` calls. For simple sends, use the [`Channel::send`]
|
||||
/// trait method instead.
|
||||
pub async fn post_message(&self, channel: &str, text: &str) -> anyhow::Result<String> {
|
||||
let body = serde_json::json!({
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http_client()
|
||||
.post("https://slack.com/api/chat.postMessage")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let raw = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
|
||||
|
||||
if !status.is_success() {
|
||||
let sanitized = crate::providers::sanitize_api_error(&raw);
|
||||
anyhow::bail!("Slack chat.postMessage failed ({status}): {sanitized}");
|
||||
}
|
||||
|
||||
let parsed: serde_json::Value = serde_json::from_str(&raw).unwrap_or_default();
|
||||
if parsed.get("ok") == Some(&serde_json::Value::Bool(false)) {
|
||||
let err = parsed
|
||||
.get("error")
|
||||
.and_then(|e| e.as_str())
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!("Slack chat.postMessage failed: {err}");
|
||||
}
|
||||
|
||||
parsed
|
||||
.get("ts")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.ok_or_else(|| anyhow::anyhow!("Slack chat.postMessage response missing 'ts'"))
|
||||
}
|
||||
|
||||
/// Update an existing Slack message in-place using `chat.update`.
|
||||
///
|
||||
/// `channel` is the channel ID and `ts` is the timestamp of the original
|
||||
/// message (returned by [`post_message`]).
|
||||
pub async fn update_message(&self, channel: &str, ts: &str, text: &str) -> anyhow::Result<()> {
|
||||
let body = serde_json::json!({
|
||||
"channel": channel,
|
||||
"ts": ts,
|
||||
"text": text,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http_client()
|
||||
.post("https://slack.com/api/chat.update")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let raw = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
|
||||
|
||||
if !status.is_success() {
|
||||
let sanitized = crate::providers::sanitize_api_error(&raw);
|
||||
anyhow::bail!("Slack chat.update failed ({status}): {sanitized}");
|
||||
}
|
||||
|
||||
let parsed: serde_json::Value = serde_json::from_str(&raw).unwrap_or_default();
|
||||
if parsed.get("ok") == Some(&serde_json::Value::Bool(false)) {
|
||||
let err = parsed
|
||||
.get("error")
|
||||
.and_then(|e| e.as_str())
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!("Slack chat.update failed: {err}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a Slack user ID is in the allowlist.
|
||||
/// Empty list means deny everyone until explicitly configured.
|
||||
/// `"*"` means allow everyone.
|
||||
@ -1509,7 +1610,7 @@ impl SlackChannel {
|
||||
/// transcription provider. Returns `None` if transcription is not configured
|
||||
/// or if the download/transcription fails.
|
||||
async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let config = self.transcription.as_ref()?;
|
||||
let manager = self.transcription_manager.as_deref()?;
|
||||
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let file_name = Self::slack_file_name(file);
|
||||
@ -1544,7 +1645,8 @@ impl SlackChannel {
|
||||
format!("voice.{mime_ext}")
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
|
||||
match manager
|
||||
.transcribe(&audio_data, &transcription_filename)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
@ -1778,6 +1880,79 @@ impl SlackChannel {
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Parse a Socket Mode `interactive` envelope containing a `block_actions`
|
||||
/// payload from the `/config` Block Kit UI. Translates provider/model
|
||||
/// dropdown selections into synthetic `/models <provider>` or `/model <id>`
|
||||
/// commands so the existing runtime command handler can apply them.
|
||||
fn parse_block_action_as_command(
|
||||
envelope: &serde_json::Value,
|
||||
_bot_user_id: &str,
|
||||
) -> Option<ChannelMessage> {
|
||||
let payload = envelope.get("payload")?;
|
||||
|
||||
let payload_type = payload.get("type").and_then(|v| v.as_str())?;
|
||||
if payload_type != "block_actions" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let actions = payload.get("actions").and_then(|v| v.as_array())?;
|
||||
let action = actions.first()?;
|
||||
|
||||
let action_id = action.get("action_id").and_then(|v| v.as_str())?;
|
||||
let selected_value = action
|
||||
.get("selected_option")
|
||||
.and_then(|o| o.get("value"))
|
||||
.and_then(|v| v.as_str())?;
|
||||
|
||||
let command = match action_id {
|
||||
"zeroclaw_config_provider" => format!("/models {selected_value}"),
|
||||
"zeroclaw_config_model" => format!("/model {selected_value}"),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let user = payload
|
||||
.get("user")
|
||||
.and_then(|u| u.get("id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let channel_id = payload
|
||||
.get("channel")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
|
||||
if channel_id.is_empty() {
|
||||
tracing::warn!("Slack block_actions: missing channel ID in interactive payload");
|
||||
return None;
|
||||
}
|
||||
|
||||
let ts = payload
|
||||
.get("message")
|
||||
.and_then(|m| m.get("ts"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0");
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("slack_{channel_id}_{ts}_action"),
|
||||
sender: user.to_string(),
|
||||
reply_target: channel_id.to_string(),
|
||||
content: command,
|
||||
channel: "slack".to_string(),
|
||||
timestamp: SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
thread_ts: payload
|
||||
.get("message")
|
||||
.and_then(|m| m.get("thread_ts"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string),
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
async fn open_socket_mode_url(&self) -> anyhow::Result<String> {
|
||||
let app_token = self
|
||||
.configured_app_token()
|
||||
@ -1846,7 +2021,13 @@ impl SlackChannel {
|
||||
}
|
||||
};
|
||||
|
||||
let (ws_stream, _) = match tokio_tungstenite::connect_async(&ws_url).await {
|
||||
let (ws_stream, _) = match crate::config::ws_connect_with_proxy(
|
||||
&ws_url,
|
||||
"channel.slack",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => {
|
||||
socket_reconnect_attempt = 0;
|
||||
connection
|
||||
@ -1912,6 +2093,17 @@ impl SlackChannel {
|
||||
tracing::warn!("Slack Socket Mode: received disconnect event");
|
||||
break;
|
||||
}
|
||||
|
||||
// Handle interactive payloads (block_actions from /config UI).
|
||||
if envelope_type == "interactive" {
|
||||
if let Some(msg) = Self::parse_block_action_as_command(&envelope, bot_user_id) {
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if envelope_type != "events_api" {
|
||||
continue;
|
||||
}
|
||||
@ -2403,22 +2595,39 @@ impl Channel for SlackChannel {
|
||||
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
let mut body = serde_json::json!({
|
||||
"channel": message.recipient,
|
||||
"text": message.content
|
||||
});
|
||||
|
||||
// Use Slack's native markdown block for rich formatting when content fits.
|
||||
if message.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
// Detect Block Kit payloads produced by the `/config` command.
|
||||
let body = if let Some(blocks_json) = message.content.strip_prefix(super::BLOCK_KIT_PREFIX)
|
||||
{
|
||||
let blocks: serde_json::Value = serde_json::from_str(blocks_json)
|
||||
.context("invalid Block Kit JSON in runtime command response")?;
|
||||
let mut body = serde_json::json!({
|
||||
"channel": message.recipient,
|
||||
"text": "Model configuration",
|
||||
"blocks": blocks
|
||||
});
|
||||
if let Some(ts) = self.outbound_thread_ts(message) {
|
||||
body["thread_ts"] = serde_json::json!(ts);
|
||||
}
|
||||
body
|
||||
} else {
|
||||
let mut body = serde_json::json!({
|
||||
"channel": message.recipient,
|
||||
"text": message.content
|
||||
}]);
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(ts) = self.outbound_thread_ts(message) {
|
||||
body["thread_ts"] = serde_json::json!(ts);
|
||||
}
|
||||
// Use Slack's native markdown block for rich formatting when content fits.
|
||||
if message.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": message.content
|
||||
}]);
|
||||
}
|
||||
|
||||
if let Some(ts) = self.outbound_thread_ts(message) {
|
||||
body["thread_ts"] = serde_json::json!(ts);
|
||||
}
|
||||
body
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.http_client()
|
||||
|
||||
@ -330,6 +330,7 @@ pub struct TelegramChannel {
|
||||
/// Override for local Bot API servers or testing.
|
||||
api_base: String,
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<std::sync::Arc<super::transcription::TranscriptionManager>>,
|
||||
voice_transcriptions: Mutex<std::collections::HashMap<String, String>>,
|
||||
workspace_dir: Option<std::path::PathBuf>,
|
||||
ack_reactions: bool,
|
||||
@ -375,6 +376,7 @@ impl TelegramChannel {
|
||||
bot_username: Mutex::new(None),
|
||||
api_base: "https://api.telegram.org".to_string(),
|
||||
transcription: None,
|
||||
transcription_manager: None,
|
||||
voice_transcriptions: Mutex::new(std::collections::HashMap::new()),
|
||||
workspace_dir: None,
|
||||
ack_reactions: true,
|
||||
@ -423,8 +425,19 @@ impl TelegramChannel {
|
||||
|
||||
/// Configure voice transcription.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
if !config.enabled {
|
||||
return self;
|
||||
}
|
||||
match super::transcription::TranscriptionManager::new(&config) {
|
||||
Ok(m) => {
|
||||
self.transcription_manager = Some(std::sync::Arc::new(m));
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"transcription manager init failed, voice transcription disabled: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
@ -1167,6 +1180,7 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
/// or the message exceeds duration limits.
|
||||
async fn try_parse_voice_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
let config = self.transcription.as_ref()?;
|
||||
let manager = self.transcription_manager.as_deref()?;
|
||||
let message = update.get("message")?;
|
||||
|
||||
let (file_id, duration) = Self::parse_voice_metadata(message)?;
|
||||
@ -1235,14 +1249,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
}
|
||||
};
|
||||
|
||||
let text =
|
||||
match super::transcription::transcribe_audio(audio_data, &file_name, config).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Voice transcription failed: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let text = match manager.transcribe(&audio_data, &file_name).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Voice transcription failed: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if text.trim().is_empty() {
|
||||
tracing::info!("Voice transcription returned empty text, skipping");
|
||||
@ -4348,10 +4361,12 @@ mod tests {
|
||||
fn with_transcription_sets_config_when_enabled() {
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
tc.api_key = Some("test_key".to_string());
|
||||
|
||||
let ch =
|
||||
TelegramChannel::new("token".into(), vec!["*".into()], false).with_transcription(tc);
|
||||
assert!(ch.transcription.is_some());
|
||||
assert!(ch.transcription_manager.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -4360,6 +4375,7 @@ mod tests {
|
||||
let ch =
|
||||
TelegramChannel::new("token".into(), vec!["*".into()], false).with_transcription(tc);
|
||||
assert!(ch.transcription.is_none());
|
||||
assert!(ch.transcription_manager.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -4382,6 +4398,7 @@ mod tests {
|
||||
async fn try_parse_voice_message_skips_when_duration_exceeds_limit() {
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
tc.api_key = Some("test_key".to_string());
|
||||
tc.max_duration_secs = 5;
|
||||
|
||||
let ch =
|
||||
@ -4403,6 +4420,7 @@ mod tests {
|
||||
async fn try_parse_voice_message_rejects_unauthorized_sender_before_download() {
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
tc.api_key = Some("test_key".to_string());
|
||||
tc.max_duration_secs = 120;
|
||||
|
||||
let ch = TelegramChannel::new("token".into(), vec!["alice".into()], false)
|
||||
@ -4453,15 +4471,17 @@ mod tests {
|
||||
audio_data.len()
|
||||
);
|
||||
|
||||
// 2. Call transcribe_audio() — real Groq Whisper API
|
||||
// 2. Call TranscriptionManager.transcribe() — real Groq Whisper API
|
||||
let config = crate::config::TranscriptionConfig {
|
||||
enabled: true,
|
||||
..Default::default()
|
||||
};
|
||||
let transcript: String =
|
||||
crate::channels::transcription::transcribe_audio(audio_data, "hello.mp3", &config)
|
||||
.await
|
||||
.expect("transcribe_audio should succeed with valid GROQ_API_KEY");
|
||||
let manager = crate::channels::transcription::TranscriptionManager::new(&config)
|
||||
.expect("TranscriptionManager::new should succeed with valid GROQ_API_KEY");
|
||||
let transcript: String = manager
|
||||
.transcribe(&audio_data, "hello.mp3")
|
||||
.await
|
||||
.expect("transcribe should succeed with valid GROQ_API_KEY");
|
||||
|
||||
// 3. Verify Whisper actually recognized "hello"
|
||||
assert!(
|
||||
|
||||
@ -161,6 +161,20 @@ pub trait Channel: Send + Sync {
|
||||
async fn unpin_message(&self, _channel_id: &str, _message_id: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Redact (delete) a message from the channel.
|
||||
///
|
||||
/// `channel_id` is the platform channel/conversation identifier.
|
||||
/// `message_id` is the platform-scoped message identifier.
|
||||
/// `reason` is an optional reason for the redaction (may be visible in audit logs).
|
||||
async fn redact_message(
|
||||
&self,
|
||||
_channel_id: &str,
|
||||
_message_id: &str,
|
||||
_reason: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -279,4 +293,18 @@ mod tests {
|
||||
assert_eq!(received.content, "hello");
|
||||
assert_eq!(received.channel, "dummy");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_redact_message_returns_success() {
|
||||
let channel = DummyChannel;
|
||||
|
||||
assert!(channel
|
||||
.redact_message("chan_1", "msg_1", Some("spam".to_string()))
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(channel
|
||||
.redact_message("chan_1", "msg_2", None)
|
||||
.await
|
||||
.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1153,6 +1153,7 @@ mod tests {
|
||||
assert!(config.assemblyai.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.local_whisper.is_none());
|
||||
assert!(!config.transcribe_non_ptt_audio);
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ──
|
||||
|
||||
@ -723,6 +723,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
};
|
||||
|
||||
let ch = WatiChannel::new(
|
||||
@ -752,6 +753,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
};
|
||||
|
||||
let ch = WatiChannel::new(
|
||||
@ -794,6 +796,7 @@ mod tests {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
};
|
||||
|
||||
let ch = WatiChannel::new(
|
||||
@ -940,6 +943,7 @@ mod tests {
|
||||
max_audio_bytes: 25 * 1024 * 1024,
|
||||
timeout_secs: 300,
|
||||
}),
|
||||
transcribe_non_ptt_audio: false,
|
||||
};
|
||||
|
||||
let ch = WatiChannel::new(
|
||||
@ -992,6 +996,7 @@ mod tests {
|
||||
max_audio_bytes: 25 * 1024 * 1024,
|
||||
timeout_secs: 300,
|
||||
}),
|
||||
transcribe_non_ptt_audio: false,
|
||||
};
|
||||
|
||||
let ch = WatiChannel::new(
|
||||
|
||||
@ -74,6 +74,7 @@ pub struct WhatsAppWebChannel {
|
||||
tx: Arc<Mutex<Option<tokio::sync::mpsc::Sender<ChannelMessage>>>>,
|
||||
/// Voice transcription (STT) config
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<std::sync::Arc<super::transcription::TranscriptionManager>>,
|
||||
/// Text-to-speech config for voice replies
|
||||
tts_config: Option<crate::config::TtsConfig>,
|
||||
/// Chats awaiting a voice reply — maps chat JID to the latest substantive
|
||||
@ -122,6 +123,7 @@ impl WhatsAppWebChannel {
|
||||
client: Arc::new(Mutex::new(None)),
|
||||
tx: Arc::new(Mutex::new(None)),
|
||||
transcription: None,
|
||||
transcription_manager: None,
|
||||
tts_config: None,
|
||||
pending_voice: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
voice_chats: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
@ -131,8 +133,19 @@ impl WhatsAppWebChannel {
|
||||
/// Configure voice transcription (STT) for incoming voice notes.
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
if !config.enabled {
|
||||
return self;
|
||||
}
|
||||
match super::transcription::TranscriptionManager::new(&config) {
|
||||
Ok(m) => {
|
||||
self.transcription_manager = Some(std::sync::Arc::new(m));
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"transcription manager init failed, voice transcription disabled: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
@ -338,8 +351,10 @@ impl WhatsAppWebChannel {
|
||||
client: &wa_rs::Client,
|
||||
audio: &wa_rs_proto::whatsapp::message::AudioMessage,
|
||||
transcription_config: Option<&crate::config::TranscriptionConfig>,
|
||||
transcription_manager: Option<&super::transcription::TranscriptionManager>,
|
||||
) -> Option<String> {
|
||||
let config = transcription_config?;
|
||||
let manager = transcription_manager?;
|
||||
|
||||
// Enforce duration limit
|
||||
if let Some(seconds) = audio.seconds {
|
||||
@ -378,7 +393,7 @@ impl WhatsAppWebChannel {
|
||||
file_name
|
||||
);
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, file_name, config).await {
|
||||
match manager.transcribe(&audio_data, file_name).await {
|
||||
Ok(text) if text.trim().is_empty() => {
|
||||
tracing::info!("WhatsApp Web: voice transcription returned empty text, skipping");
|
||||
None
|
||||
@ -644,6 +659,7 @@ impl Channel for WhatsAppWebChannel {
|
||||
let retry_count_clone = retry_count.clone();
|
||||
let session_revoked_clone = session_revoked.clone();
|
||||
let transcription_config = self.transcription.clone();
|
||||
let transcription_mgr = self.transcription_manager.clone();
|
||||
let voice_chats = self.voice_chats.clone();
|
||||
let wa_mode = self.mode.clone();
|
||||
let wa_dm_policy = self.dm_policy.clone();
|
||||
@ -661,6 +677,7 @@ impl Channel for WhatsAppWebChannel {
|
||||
let retry_count = retry_count_clone.clone();
|
||||
let session_revoked = session_revoked_clone.clone();
|
||||
let transcription_config = transcription_config.clone();
|
||||
let transcription_mgr = transcription_mgr.clone();
|
||||
let voice_chats = voice_chats.clone();
|
||||
let wa_mode = wa_mode.clone();
|
||||
let wa_dm_policy = wa_dm_policy.clone();
|
||||
@ -756,13 +773,20 @@ impl Channel for WhatsAppWebChannel {
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt voice note transcription (ptt = push-to-talk = voice note)
|
||||
// Attempt voice note transcription (ptt = push-to-talk = voice note).
|
||||
// When `transcribe_non_ptt_audio` is enabled in the transcription
|
||||
// config, also transcribe forwarded / regular audio messages.
|
||||
let voice_text = if let Some(ref audio) = msg.audio_message {
|
||||
if audio.ptt == Some(true) {
|
||||
let is_ptt = audio.ptt == Some(true);
|
||||
let non_ptt_enabled = transcription_config
|
||||
.as_ref()
|
||||
.is_some_and(|c| c.transcribe_non_ptt_audio);
|
||||
if is_ptt || non_ptt_enabled {
|
||||
Self::try_transcribe_voice_note(
|
||||
&client,
|
||||
audio,
|
||||
transcription_config.as_ref(),
|
||||
transcription_mgr.as_deref(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
@ -1343,9 +1367,11 @@ mod tests {
|
||||
fn with_transcription_sets_config_when_enabled() {
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
tc.api_key = Some("test_key".to_string());
|
||||
|
||||
let ch = make_channel().with_transcription(tc);
|
||||
assert!(ch.transcription.is_some());
|
||||
assert!(ch.transcription_manager.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1354,6 +1380,7 @@ mod tests {
|
||||
let tc = crate::config::TranscriptionConfig::default(); // enabled = false
|
||||
let ch = make_channel().with_transcription(tc);
|
||||
assert!(ch.transcription.is_none());
|
||||
assert!(ch.transcription_manager.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -7,31 +7,32 @@ pub use schema::{
|
||||
apply_channel_proxy_to_builder, apply_runtime_proxy_to_builder, build_channel_proxy_client,
|
||||
build_channel_proxy_client_with_timeouts, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, CodexCliConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl, CronScheduleDecl,
|
||||
DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GeminiCliConfig, GoogleSttConfig, GoogleTtsConfig,
|
||||
GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig, HardwareConfig, HardwareTransport,
|
||||
HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig,
|
||||
ImageGenConfig, ImageProviderDalleConfig, ImageProviderFluxConfig, ImageProviderImagenConfig,
|
||||
ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig, LarkConfig, LinkEnricherConfig,
|
||||
LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig, MatrixConfig,
|
||||
McpConfig, McpServerConfig, McpTransport, MediaPipelineConfig, MemoryConfig,
|
||||
MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiSttConfig, OpenAiTtsConfig, OpenCodeCliConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
|
||||
PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PiperTtsConfig, PluginsConfig,
|
||||
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, SopConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
|
||||
TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig,
|
||||
TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig,
|
||||
WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
ws_connect_with_proxy, AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig,
|
||||
BackupConfig, BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ClaudeCodeConfig, ClaudeCodeRunnerConfig, CloudOpsConfig, CodexCliConfig,
|
||||
ComposioConfig, Config, ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl,
|
||||
CronScheduleDecl, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig,
|
||||
DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GeminiCliConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkEnricherConfig, LinkedInConfig, LinkedInContentConfig,
|
||||
LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MediaPipelineConfig, MemoryConfig, MemoryPolicyConfig, Microsoft365Config,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
|
||||
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenCodeCliConfig,
|
||||
OpenVpnTunnelConfig, OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig,
|
||||
PeripheralsConfig, PiperTtsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, ShellToolConfig, SkillCreationConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, SopConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
@ -392,6 +392,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub claude_code: ClaudeCodeConfig,
|
||||
|
||||
/// Claude Code task runner with Slack progress and SSH session handoff (`[claude_code_runner]`).
|
||||
#[serde(default)]
|
||||
pub claude_code_runner: ClaudeCodeRunnerConfig,
|
||||
|
||||
/// Codex CLI tool configuration (`[codex_cli]`).
|
||||
#[serde(default)]
|
||||
pub codex_cli: CodexCliConfig,
|
||||
@ -407,6 +411,10 @@ pub struct Config {
|
||||
/// Standard Operating Procedures engine configuration (`[sop]`).
|
||||
#[serde(default)]
|
||||
pub sop: SopConfig,
|
||||
|
||||
/// Shell tool configuration (`[shell_tool]`).
|
||||
#[serde(default)]
|
||||
pub shell_tool: ShellToolConfig,
|
||||
}
|
||||
|
||||
/// Multi-client workspace isolation configuration.
|
||||
@ -827,6 +835,10 @@ pub struct TranscriptionConfig {
|
||||
/// Local/self-hosted Whisper-compatible STT provider.
|
||||
#[serde(default)]
|
||||
pub local_whisper: Option<LocalWhisperConfig>,
|
||||
/// Also transcribe non-PTT (forwarded/regular) audio messages on WhatsApp,
|
||||
/// not just voice notes. Default: `false` (preserves legacy behavior).
|
||||
#[serde(default)]
|
||||
pub transcribe_non_ptt_audio: bool,
|
||||
}
|
||||
|
||||
impl Default for TranscriptionConfig {
|
||||
@ -845,6 +857,7 @@ impl Default for TranscriptionConfig {
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
transcribe_non_ptt_audio: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1186,6 +1199,9 @@ pub struct ToolFilterGroup {
|
||||
/// Ignored when `mode = "always"`.
|
||||
#[serde(default)]
|
||||
pub keywords: Vec<String>,
|
||||
/// When true, also filter built-in tools (not just MCP tools).
|
||||
#[serde(default)]
|
||||
pub filter_builtins: bool,
|
||||
}
|
||||
|
||||
/// OpenAI Whisper STT provider configuration (`[transcription.openai]`).
|
||||
@ -1302,6 +1318,22 @@ pub struct AgentConfig {
|
||||
/// per message. Users can override per-message with `/think:<level>` directives.
|
||||
#[serde(default)]
|
||||
pub thinking: crate::agent::thinking::ThinkingConfig,
|
||||
|
||||
/// History pruning configuration for token efficiency.
|
||||
#[serde(default)]
|
||||
pub history_pruning: crate::agent::history_pruner::HistoryPrunerConfig,
|
||||
|
||||
/// Enable context-aware tool filtering (only surface relevant tools per iteration).
|
||||
#[serde(default)]
|
||||
pub context_aware_tools: bool,
|
||||
|
||||
/// Post-response quality evaluator configuration.
|
||||
#[serde(default)]
|
||||
pub eval: crate::agent::eval::EvalConfig,
|
||||
|
||||
/// Automatic complexity-based classification fallback.
|
||||
#[serde(default)]
|
||||
pub auto_classify: Option<crate::agent::eval::AutoClassifyConfig>,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@ -1337,6 +1369,10 @@ impl Default for AgentConfig {
|
||||
tool_filter_groups: Vec::new(),
|
||||
max_system_prompt_chars: default_max_system_prompt_chars(),
|
||||
thinking: crate::agent::thinking::ThinkingConfig::default(),
|
||||
history_pruning: crate::agent::history_pruner::HistoryPrunerConfig::default(),
|
||||
context_aware_tools: false,
|
||||
eval: crate::agent::eval::EvalConfig::default(),
|
||||
auto_classify: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2497,6 +2533,32 @@ impl Default for TextBrowserConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shell tool ───────────────────────────────────────────────────
|
||||
|
||||
/// Shell tool configuration (`[shell_tool]` section).
|
||||
///
|
||||
/// Controls the behaviour of the `shell` execution tool. The main
|
||||
/// tunable is `timeout_secs` — the maximum wall-clock time a single
|
||||
/// shell command may run before it is killed.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ShellToolConfig {
|
||||
/// Maximum shell command execution time in seconds (default: 60).
|
||||
#[serde(default = "default_shell_tool_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_shell_tool_timeout_secs() -> u64 {
|
||||
60
|
||||
}
|
||||
|
||||
impl Default for ShellToolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
timeout_secs: default_shell_tool_timeout_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Web search ───────────────────────────────────────────────────
|
||||
|
||||
/// Web search tool configuration (`[web_search]` section).
|
||||
@ -3312,6 +3374,48 @@ impl Default for ClaudeCodeConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude Code Runner ──────────────────────────────────────────
|
||||
|
||||
/// Claude Code task runner configuration (`[claude_code_runner]` section).
|
||||
///
|
||||
/// Spawns Claude Code in a tmux session with HTTP hooks that POST tool
|
||||
/// execution events back to ZeroClaw's gateway, updating a Slack message
|
||||
/// in-place with progress plus an SSH handoff link.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ClaudeCodeRunnerConfig {
|
||||
/// Enable the `claude_code_runner` tool
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// SSH host for session handoff links (e.g. "myhost.example.com")
|
||||
#[serde(default)]
|
||||
pub ssh_host: Option<String>,
|
||||
/// Prefix for tmux session names (default: "zc-claude-")
|
||||
#[serde(default = "default_claude_code_runner_tmux_prefix")]
|
||||
pub tmux_prefix: String,
|
||||
/// Session time-to-live in seconds before auto-cleanup (default: 3600)
|
||||
#[serde(default = "default_claude_code_runner_session_ttl")]
|
||||
pub session_ttl: u64,
|
||||
}
|
||||
|
||||
fn default_claude_code_runner_tmux_prefix() -> String {
|
||||
"zc-claude-".into()
|
||||
}
|
||||
|
||||
fn default_claude_code_runner_session_ttl() -> u64 {
|
||||
3600
|
||||
}
|
||||
|
||||
impl Default for ClaudeCodeRunnerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
ssh_host: None,
|
||||
tmux_prefix: default_claude_code_runner_tmux_prefix(),
|
||||
session_ttl: default_claude_code_runner_session_ttl(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Codex CLI ───────────────────────────────────────────────────
|
||||
|
||||
/// Codex CLI tool configuration (`[codex_cli]` section).
|
||||
@ -4055,6 +4159,317 @@ fn apply_explicit_proxy_to_builder(
|
||||
builder
|
||||
}
|
||||
|
||||
// ── Proxy-aware WebSocket connect ────────────────────────────────
|
||||
//
|
||||
// `tokio_tungstenite::connect_async` does not honour proxy settings.
|
||||
// The helpers below resolve the effective proxy URL for a given service
|
||||
// key and, when a proxy is active, establish a tunnelled TCP connection
|
||||
// (HTTP CONNECT for http/https proxies, SOCKS5 for socks5/socks5h)
|
||||
// before handing the stream to `tokio_tungstenite` for the WebSocket
|
||||
// handshake.
|
||||
|
||||
/// Combined async IO trait for boxed WebSocket transport streams.
|
||||
trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
|
||||
|
||||
/// A boxed async IO stream used when a WebSocket connection is tunnelled
|
||||
/// through a proxy. The concrete type varies depending on the proxy
|
||||
/// kind (HTTP CONNECT vs SOCKS5) and the target scheme (ws vs wss).
|
||||
///
|
||||
/// We wrap in a newtype so we can implement `AsyncRead` and `AsyncWrite`
|
||||
/// via delegation, since Rust trait objects cannot combine multiple
|
||||
/// non-auto traits.
|
||||
pub struct BoxedIo(Box<dyn AsyncReadWrite>);
|
||||
|
||||
impl tokio::io::AsyncRead for BoxedIo {
|
||||
fn poll_read(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut *self.0).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for BoxedIo {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut *self.0).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut *self.0).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut *self.0).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for BoxedIo {}
|
||||
|
||||
/// Convenience alias for the WebSocket stream returned by the proxy-aware
|
||||
/// connect helpers.
|
||||
pub type ProxiedWsStream = tokio_tungstenite::WebSocketStream<BoxedIo>;
|
||||
|
||||
/// Resolve the effective proxy URL for a WebSocket connection to the
|
||||
/// given `ws_url`, taking into account the per-channel `proxy_url`
|
||||
/// override, the runtime proxy config, scope and no_proxy list.
|
||||
fn resolve_ws_proxy_url(
|
||||
service_key: &str,
|
||||
ws_url: &str,
|
||||
channel_proxy_url: Option<&str>,
|
||||
) -> Option<String> {
|
||||
// 1. Explicit per-channel proxy always wins.
|
||||
if let Some(url) = normalize_proxy_url_option(channel_proxy_url) {
|
||||
return Some(url);
|
||||
}
|
||||
|
||||
// 2. Consult the runtime proxy config.
|
||||
let cfg = runtime_proxy_config();
|
||||
if !cfg.should_apply_to_service(service_key) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check the no_proxy list against the WebSocket target host.
|
||||
if let Ok(parsed) = reqwest::Url::parse(ws_url) {
|
||||
if let Some(host) = parsed.host_str() {
|
||||
let no_proxy_entries = cfg.normalized_no_proxy();
|
||||
if !no_proxy_entries.is_empty() {
|
||||
let host_lower = host.to_ascii_lowercase();
|
||||
let matches_no_proxy = no_proxy_entries.iter().any(|entry| {
|
||||
let entry = entry.trim().to_ascii_lowercase();
|
||||
if entry == "*" {
|
||||
return true;
|
||||
}
|
||||
if host_lower == entry {
|
||||
return true;
|
||||
}
|
||||
// Support ".example.com" matching "foo.example.com"
|
||||
if let Some(suffix) = entry.strip_prefix('.') {
|
||||
return host_lower.ends_with(suffix) || host_lower == suffix;
|
||||
}
|
||||
// Support "example.com" also matching "foo.example.com"
|
||||
host_lower.ends_with(&format!(".{entry}"))
|
||||
});
|
||||
if matches_no_proxy {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For wss:// prefer https_proxy, for ws:// prefer http_proxy, fall
|
||||
// back to all_proxy in both cases.
|
||||
let is_secure = ws_url.starts_with("wss://") || ws_url.starts_with("wss:");
|
||||
let preferred = if is_secure {
|
||||
normalize_proxy_url_option(cfg.https_proxy.as_deref())
|
||||
} else {
|
||||
normalize_proxy_url_option(cfg.http_proxy.as_deref())
|
||||
};
|
||||
preferred.or_else(|| normalize_proxy_url_option(cfg.all_proxy.as_deref()))
|
||||
}
|
||||
|
||||
/// Connect a WebSocket through the configured proxy (if any).
|
||||
///
|
||||
/// When no proxy applies, this is a thin wrapper around
|
||||
/// `tokio_tungstenite::connect_async`. When a proxy is active the
|
||||
/// function tunnels the TCP connection through the proxy before
|
||||
/// performing the WebSocket upgrade.
|
||||
///
|
||||
/// `service_key` is the proxy-service selector (e.g. `"channel.discord"`).
|
||||
/// `channel_proxy_url` is the optional per-channel proxy override.
|
||||
pub async fn ws_connect_with_proxy(
|
||||
ws_url: &str,
|
||||
service_key: &str,
|
||||
channel_proxy_url: Option<&str>,
|
||||
) -> anyhow::Result<(
|
||||
ProxiedWsStream,
|
||||
tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
|
||||
)> {
|
||||
let proxy_url = resolve_ws_proxy_url(service_key, ws_url, channel_proxy_url);
|
||||
|
||||
match proxy_url {
|
||||
None => {
|
||||
// No proxy — delegate directly.
|
||||
let (stream, resp) = tokio_tungstenite::connect_async(ws_url).await?;
|
||||
// Re-wrap the inner stream into our boxed type so the caller
|
||||
// always gets `ProxiedWsStream`.
|
||||
let inner = stream.into_inner();
|
||||
let boxed = BoxedIo(Box::new(inner));
|
||||
let ws = tokio_tungstenite::WebSocketStream::from_raw_socket(
|
||||
boxed,
|
||||
tokio_tungstenite::tungstenite::protocol::Role::Client,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok((ws, resp))
|
||||
}
|
||||
Some(proxy) => ws_connect_via_proxy(ws_url, &proxy).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Establish a WebSocket connection tunnelled through the given proxy URL.
|
||||
async fn ws_connect_via_proxy(
|
||||
ws_url: &str,
|
||||
proxy_url: &str,
|
||||
) -> anyhow::Result<(
|
||||
ProxiedWsStream,
|
||||
tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
|
||||
)> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt as _};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
let target =
|
||||
reqwest::Url::parse(ws_url).with_context(|| format!("Invalid WebSocket URL: {ws_url}"))?;
|
||||
let target_host = target
|
||||
.host_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("WebSocket URL has no host: {ws_url}"))?
|
||||
.to_string();
|
||||
let target_port = target
|
||||
.port_or_known_default()
|
||||
.unwrap_or(if target.scheme() == "wss" { 443 } else { 80 });
|
||||
|
||||
let proxy = reqwest::Url::parse(proxy_url)
|
||||
.with_context(|| format!("Invalid proxy URL: {proxy_url}"))?;
|
||||
|
||||
let stream: BoxedIo = match proxy.scheme() {
|
||||
"socks5" | "socks5h" | "socks" => {
|
||||
let proxy_addr = format!(
|
||||
"{}:{}",
|
||||
proxy.host_str().unwrap_or("127.0.0.1"),
|
||||
proxy.port_or_known_default().unwrap_or(1080)
|
||||
);
|
||||
let target_addr = format!("{target_host}:{target_port}");
|
||||
let socks_stream = if proxy.username().is_empty() {
|
||||
tokio_socks::tcp::Socks5Stream::connect(proxy_addr.as_str(), target_addr.as_str())
|
||||
.await
|
||||
.with_context(|| format!("SOCKS5 connect to {target_addr} via {proxy_addr}"))?
|
||||
} else {
|
||||
let password = proxy.password().unwrap_or("");
|
||||
tokio_socks::tcp::Socks5Stream::connect_with_password(
|
||||
proxy_addr.as_str(),
|
||||
target_addr.as_str(),
|
||||
proxy.username(),
|
||||
password,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("SOCKS5 auth connect to {target_addr} via {proxy_addr}"))?
|
||||
};
|
||||
let tcp: TcpStream = socks_stream.into_inner();
|
||||
BoxedIo(Box::new(tcp))
|
||||
}
|
||||
"http" | "https" => {
|
||||
let proxy_host = proxy.host_str().unwrap_or("127.0.0.1");
|
||||
let proxy_port = proxy.port_or_known_default().unwrap_or(8080);
|
||||
let proxy_addr = format!("{proxy_host}:{proxy_port}");
|
||||
|
||||
let mut tcp = TcpStream::connect(&proxy_addr)
|
||||
.await
|
||||
.with_context(|| format!("TCP connect to HTTP proxy {proxy_addr}"))?;
|
||||
|
||||
// Send HTTP CONNECT request.
|
||||
let connect_req = format!(
|
||||
"CONNECT {target_host}:{target_port} HTTP/1.1\r\nHost: {target_host}:{target_port}\r\n\r\n"
|
||||
);
|
||||
tcp.write_all(connect_req.as_bytes()).await?;
|
||||
|
||||
// Read the response (we only need the status line).
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let mut total = 0usize;
|
||||
loop {
|
||||
let n = tcp.read(&mut buf[total..]).await?;
|
||||
if n == 0 {
|
||||
anyhow::bail!("HTTP CONNECT proxy closed connection before response");
|
||||
}
|
||||
total += n;
|
||||
// Look for end of HTTP headers.
|
||||
if let Some(pos) = find_header_end(&buf[..total]) {
|
||||
let status_line = std::str::from_utf8(&buf[..pos])
|
||||
.unwrap_or("")
|
||||
.lines()
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
if !status_line.contains("200") {
|
||||
anyhow::bail!(
|
||||
"HTTP CONNECT proxy returned non-200 response: {status_line}"
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if total >= buf.len() {
|
||||
anyhow::bail!("HTTP CONNECT proxy response too large");
|
||||
}
|
||||
}
|
||||
|
||||
BoxedIo(Box::new(tcp))
|
||||
}
|
||||
scheme => {
|
||||
anyhow::bail!("Unsupported proxy scheme '{scheme}' for WebSocket connections");
|
||||
}
|
||||
};
|
||||
|
||||
// If the target is wss://, wrap in TLS.
|
||||
let is_secure = target.scheme() == "wss";
|
||||
let stream: BoxedIo = if is_secure {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
let tls_config = std::sync::Arc::new(
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth(),
|
||||
);
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let server_name = rustls_pki_types::ServerName::try_from(target_host.clone())
|
||||
.with_context(|| format!("Invalid TLS server name: {target_host}"))?;
|
||||
|
||||
// `stream` is `BoxedIo` — we need a concrete `AsyncRead + AsyncWrite`
|
||||
// for `TlsConnector::connect`. Since `BoxedIo` already satisfies
|
||||
// those bounds we can pass it directly.
|
||||
let tls_stream = connector
|
||||
.connect(server_name, stream)
|
||||
.await
|
||||
.with_context(|| format!("TLS handshake with {target_host}"))?;
|
||||
BoxedIo(Box::new(tls_stream))
|
||||
} else {
|
||||
stream
|
||||
};
|
||||
|
||||
// Perform the WebSocket client handshake over the tunnelled stream.
|
||||
let ws_request = tokio_tungstenite::tungstenite::http::Request::builder()
|
||||
.uri(ws_url)
|
||||
.header("Host", format!("{target_host}:{target_port}"))
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.header(
|
||||
"Sec-WebSocket-Key",
|
||||
tokio_tungstenite::tungstenite::handshake::client::generate_key(),
|
||||
)
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.body(())
|
||||
.with_context(|| "Failed to build WebSocket upgrade request")?;
|
||||
|
||||
let (ws_stream, response) = tokio_tungstenite::client_async(ws_request, stream)
|
||||
.await
|
||||
.with_context(|| format!("WebSocket handshake failed for {ws_url}"))?;
|
||||
|
||||
Ok((ws_stream, response))
|
||||
}
|
||||
|
||||
/// Find the `\r\n\r\n` boundary marking the end of HTTP headers.
|
||||
fn find_header_end(buf: &[u8]) -> Option<usize> {
|
||||
buf.windows(4).position(|w| w == b"\r\n\r\n").map(|p| p + 4)
|
||||
}
|
||||
|
||||
fn parse_proxy_scope(raw: &str) -> Option<ProxyScope> {
|
||||
match raw.trim().to_ascii_lowercase().as_str() {
|
||||
"environment" | "env" => Some(ProxyScope::Environment),
|
||||
@ -4117,6 +4532,18 @@ pub struct StorageProviderConfig {
|
||||
/// Optional connection timeout in seconds for remote providers.
|
||||
#[serde(default)]
|
||||
pub connect_timeout_secs: Option<u64>,
|
||||
|
||||
/// Enable pgvector extension for hybrid vector+keyword recall.
|
||||
#[serde(default)]
|
||||
pub pgvector_enabled: bool,
|
||||
|
||||
/// Vector dimensions for pgvector embeddings (default: 1536).
|
||||
#[serde(default = "default_pgvector_dimensions")]
|
||||
pub pgvector_dimensions: usize,
|
||||
}
|
||||
|
||||
fn default_pgvector_dimensions() -> usize {
|
||||
1536
|
||||
}
|
||||
|
||||
fn default_storage_schema() -> String {
|
||||
@ -4135,6 +4562,8 @@ impl Default for StorageProviderConfig {
|
||||
schema: default_storage_schema(),
|
||||
table: default_storage_table(),
|
||||
connect_timeout_secs: None,
|
||||
pgvector_enabled: false,
|
||||
pgvector_dimensions: default_pgvector_dimensions(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -7482,10 +7911,12 @@ impl Default for Config {
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
claude_code_runner: ClaudeCodeRunnerConfig::default(),
|
||||
codex_cli: CodexCliConfig::default(),
|
||||
gemini_cli: GeminiCliConfig::default(),
|
||||
opencode_cli: OpenCodeCliConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
shell_tool: ShellToolConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -10540,10 +10971,12 @@ default_temperature = 0.7
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
claude_code_runner: ClaudeCodeRunnerConfig::default(),
|
||||
codex_cli: CodexCliConfig::default(),
|
||||
gemini_cli: GeminiCliConfig::default(),
|
||||
opencode_cli: OpenCodeCliConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
shell_tool: ShellToolConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
@ -11062,10 +11495,12 @@ default_temperature = 0.7
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
claude_code_runner: ClaudeCodeRunnerConfig::default(),
|
||||
codex_cli: CodexCliConfig::default(),
|
||||
gemini_cli: GeminiCliConfig::default(),
|
||||
opencode_cli: OpenCodeCliConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
shell_tool: ShellToolConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await.unwrap();
|
||||
@ -13740,6 +14175,7 @@ default_model = "persisted-profile"
|
||||
assert_eq!(tc.model, "whisper-large-v3-turbo");
|
||||
assert!(tc.language.is_none());
|
||||
assert_eq!(tc.max_duration_secs, 120);
|
||||
assert!(!tc.transcribe_non_ptt_audio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -14696,4 +15132,55 @@ require_otp_to_resume = true
|
||||
assert_eq!(from_toml.loop_detection_window_size, 20);
|
||||
assert_eq!(from_toml.loop_detection_max_repeats, 3);
|
||||
}
|
||||
|
||||
// ── Docker baked config template ────────────────────────────
|
||||
|
||||
/// The TOML template baked into Docker images (Dockerfile + Dockerfile.debian).
|
||||
/// Kept here so changes to the Dockerfiles can be validated by `cargo test`.
|
||||
const DOCKER_CONFIG_TEMPLATE: &str = r#"
|
||||
workspace_dir = "/zeroclaw-data/workspace"
|
||||
config_path = "/zeroclaw-data/.zeroclaw/config.toml"
|
||||
api_key = ""
|
||||
default_provider = "openrouter"
|
||||
default_model = "anthropic/claude-sonnet-4-20250514"
|
||||
default_temperature = 0.7
|
||||
|
||||
[gateway]
|
||||
port = 42617
|
||||
host = "[::]"
|
||||
allow_public_bind = true
|
||||
|
||||
[autonomy]
|
||||
level = "supervised"
|
||||
auto_approve = ["file_read", "file_write", "file_edit", "memory_recall", "memory_store", "web_search_tool", "web_fetch", "calculator", "glob_search", "content_search", "image_info", "weather", "git_operations"]
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
async fn docker_config_template_is_parseable() {
|
||||
let cfg: Config = toml::from_str(DOCKER_CONFIG_TEMPLATE)
|
||||
.expect("Docker baked config.toml must be valid TOML that deserialises into Config");
|
||||
|
||||
// The [autonomy] section must be present and contain the expected tools.
|
||||
let auto = &cfg.autonomy.auto_approve;
|
||||
for tool in &[
|
||||
"file_read",
|
||||
"file_write",
|
||||
"file_edit",
|
||||
"memory_recall",
|
||||
"memory_store",
|
||||
"web_search_tool",
|
||||
"web_fetch",
|
||||
"calculator",
|
||||
"glob_search",
|
||||
"content_search",
|
||||
"image_info",
|
||||
"weather",
|
||||
"git_operations",
|
||||
] {
|
||||
assert!(
|
||||
auto.iter().any(|t| t == tool),
|
||||
"Docker config auto_approve missing expected tool: {tool}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -390,12 +390,47 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
||||
deliver_announcement(config, channel, target, output).await
|
||||
}
|
||||
|
||||
/// Output that has been scanned for credential leaks and redacted if necessary.
|
||||
/// All channel dispatch must use this type — constructing it requires going through
|
||||
/// `scan_and_redact_output`, which enforces leak detection on every outbound path.
|
||||
pub(crate) struct RedactedOutput(String);
|
||||
|
||||
impl RedactedOutput {
|
||||
/// Access the safe-to-send content.
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Scan cron job output for credential leaks and return redacted output if leaks are detected.
|
||||
/// Logs a warning with channel, target, and detected patterns when credentials are found.
|
||||
fn scan_and_redact_output(channel: &str, target: &str, output: &str) -> RedactedOutput {
|
||||
let leak_detector = crate::security::LeakDetector::new();
|
||||
let leak_check = leak_detector.scan(output);
|
||||
|
||||
match leak_check {
|
||||
crate::security::LeakResult::Detected { patterns, redacted } => {
|
||||
tracing::warn!(
|
||||
channel = %channel,
|
||||
target = %target,
|
||||
patterns = ?patterns,
|
||||
"Credential leak detected in cron job output; redacting before delivery"
|
||||
);
|
||||
RedactedOutput(redacted)
|
||||
}
|
||||
crate::security::LeakResult::Clean => RedactedOutput(output.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn deliver_announcement(
|
||||
config: &Config,
|
||||
channel: &str,
|
||||
target: &str,
|
||||
output: &str,
|
||||
) -> Result<()> {
|
||||
// Scan for credential leaks before delivering cron job output to channel.
|
||||
let safe_output = scan_and_redact_output(channel, target, output);
|
||||
|
||||
match channel.to_ascii_lowercase().as_str() {
|
||||
"telegram" => {
|
||||
let tg = config
|
||||
@ -408,7 +443,9 @@ pub(crate) async fn deliver_announcement(
|
||||
tg.allowed_users.clone(),
|
||||
tg.mention_only,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
"discord" => {
|
||||
let dc = config
|
||||
@ -423,7 +460,9 @@ pub(crate) async fn deliver_announcement(
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
"slack" => {
|
||||
let sl = config
|
||||
@ -439,7 +478,9 @@ pub(crate) async fn deliver_announcement(
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone());
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
"mattermost" => {
|
||||
let mm = config
|
||||
@ -455,7 +496,9 @@ pub(crate) async fn deliver_announcement(
|
||||
mm.thread_replies.unwrap_or(true),
|
||||
mm.mention_only.unwrap_or(false),
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
"signal" => {
|
||||
let sg = config
|
||||
@ -471,7 +514,9 @@ pub(crate) async fn deliver_announcement(
|
||||
sg.ignore_attachments,
|
||||
sg.ignore_stories,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
"matrix" => {
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
@ -491,7 +536,9 @@ pub(crate) async fn deliver_announcement(
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
#[cfg(not(feature = "channel-matrix"))]
|
||||
{
|
||||
@ -521,7 +568,9 @@ pub(crate) async fn deliver_announcement(
|
||||
wa.group_policy.clone(),
|
||||
wa.self_chat_mode,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
@ -539,7 +588,9 @@ pub(crate) async fn deliver_announcement(
|
||||
qq.app_secret.clone(),
|
||||
qq.allowed_users.clone(),
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
channel
|
||||
.send(&SendMessage::new(safe_output.as_str(), target))
|
||||
.await?;
|
||||
}
|
||||
other => anyhow::bail!("unsupported delivery channel: {other}"),
|
||||
}
|
||||
@ -1327,4 +1378,26 @@ mod tests {
|
||||
let overdue = cron::all_overdue_jobs(&config, far_future).unwrap();
|
||||
assert_eq!(overdue.len(), 3, "all_overdue_jobs must return all");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_and_redact_output_redacts_credentials() {
|
||||
let leaked_output = "Deployment key: sk_test_FAKE1234567890abcdefgh"; // gitleaks:allow
|
||||
|
||||
let redacted = scan_and_redact_output("telegram", "123456", leaked_output);
|
||||
|
||||
assert!(
|
||||
!redacted.as_str().contains("sk_test_FAKE1234567890abcdefgh"), // gitleaks:allow
|
||||
"credentials must be redacted"
|
||||
);
|
||||
assert!(redacted.as_str().contains("[REDACTED"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_and_redact_output_preserves_clean_output() {
|
||||
let clean_output = "Deployment completed successfully at 2024-03-15 10:00:00";
|
||||
|
||||
let redacted = scan_and_redact_output("telegram", "123456", clean_output);
|
||||
|
||||
assert_eq!(redacted.as_str(), clean_output);
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,7 +261,13 @@ pub fn update_job(config: &Config, job_id: &str, patch: CronJobPatch) -> Result<
|
||||
job.delete_after_run = delete_after_run;
|
||||
}
|
||||
if let Some(allowed_tools) = patch.allowed_tools {
|
||||
job.allowed_tools = Some(allowed_tools);
|
||||
// Empty list means "clear the allowlist" (all tools available),
|
||||
// not "allow zero tools".
|
||||
if allowed_tools.is_empty() {
|
||||
job.allowed_tools = None;
|
||||
} else {
|
||||
job.allowed_tools = Some(allowed_tools);
|
||||
}
|
||||
}
|
||||
|
||||
if schedule_changed {
|
||||
|
||||
@ -1380,6 +1380,34 @@ pub async fn handle_api_session_rename(
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude Code hook endpoint ────────────────────────────────────
|
||||
|
||||
/// POST /hooks/claude-code — receives HTTP hook events from Claude Code
|
||||
/// sessions spawned by [`ClaudeCodeRunnerTool`].
|
||||
///
|
||||
/// Claude Code posts structured JSON describing tool executions, completions,
|
||||
/// and errors. This handler logs the event and (when a Slack channel is
|
||||
/// configured) could be wired to update a Slack message in-place.
|
||||
pub async fn handle_claude_code_hook(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<crate::tools::claude_code_runner::ClaudeCodeHookEvent>,
|
||||
) -> impl IntoResponse {
|
||||
// Do not require bearer-token auth: Claude Code subprocesses cannot easily
|
||||
// obtain a pairing token, and the hook carries a session_id that ties it
|
||||
// back to a session we spawned.
|
||||
let _ = &state; // retained for future Slack update wiring
|
||||
|
||||
tracing::info!(
|
||||
session_id = %payload.session_id,
|
||||
event_type = %payload.event_type,
|
||||
tool_name = ?payload.tool_name,
|
||||
summary = ?payload.summary,
|
||||
"Claude Code hook event received"
|
||||
);
|
||||
|
||||
Json(serde_json::json!({ "ok": true }))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -440,23 +440,28 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
|
||||
let canvas_store = tools::CanvasStore::new();
|
||||
|
||||
let (mut tools_registry_raw, delegate_handle_gw, _reaction_handle_gw, _channel_map_handle) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
Some(canvas_store.clone()),
|
||||
);
|
||||
let (
|
||||
mut tools_registry_raw,
|
||||
delegate_handle_gw,
|
||||
_reaction_handle_gw,
|
||||
_channel_map_handle,
|
||||
_ask_user_handle_gw,
|
||||
) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
Some(canvas_store.clone()),
|
||||
);
|
||||
|
||||
// ── Wire MCP tools into the gateway tool registry (non-fatal) ───
|
||||
// Without this, the `/api/tools` endpoint misses MCP tools.
|
||||
@ -858,6 +863,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/wati", post(handle_wati_webhook))
|
||||
.route("/nextcloud-talk", post(handle_nextcloud_talk_webhook))
|
||||
.route("/webhook/gmail", post(handle_gmail_push_webhook))
|
||||
// ── Claude Code runner hooks ──
|
||||
.route("/hooks/claude-code", post(api::handle_claude_code_hook))
|
||||
// ── Web Dashboard API routes ──
|
||||
.route("/api/status", get(api::handle_api_status))
|
||||
.route("/api/config", get(api::handle_api_config_get))
|
||||
|
||||
@ -259,6 +259,14 @@ pub enum SkillCommands {
|
||||
/// Skill name to remove
|
||||
name: String,
|
||||
},
|
||||
/// Run TEST.sh validation for a skill (or all skills)
|
||||
Test {
|
||||
/// Skill name to test; omit for all skills
|
||||
name: Option<String>,
|
||||
/// Show verbose output
|
||||
#[arg(long)]
|
||||
verbose: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// Migration subcommands
|
||||
|
||||
135
src/main.rs
135
src/main.rs
@ -535,6 +535,25 @@ Examples:
|
||||
shell: CompletionShell,
|
||||
},
|
||||
|
||||
/// Launch or install the companion desktop app
|
||||
#[command(long_about = "\
|
||||
Launch the ZeroClaw companion desktop app.
|
||||
|
||||
The companion app is a lightweight menu bar / system tray application \
|
||||
that connects to the same gateway as the CLI. It provides quick access \
|
||||
to the dashboard, status monitoring, and device pairing.
|
||||
|
||||
Use --install to download the pre-built companion app for your platform.
|
||||
|
||||
Examples:
|
||||
zeroclaw desktop # launch the companion app
|
||||
zeroclaw desktop --install # download and install it")]
|
||||
Desktop {
|
||||
/// Download and install the companion app
|
||||
#[arg(long)]
|
||||
install: bool,
|
||||
},
|
||||
|
||||
/// Manage WASM plugins
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
Plugin {
|
||||
@ -1324,6 +1343,122 @@ async fn main() -> Result<()> {
|
||||
.await
|
||||
}
|
||||
|
||||
Commands::Desktop {
|
||||
install: do_install,
|
||||
} => {
|
||||
let download_url = "https://www.zeroclawlabs.ai/download";
|
||||
|
||||
if do_install {
|
||||
println!("Download the ZeroClaw companion app:");
|
||||
println!();
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
println!(" macOS: {download_url}");
|
||||
println!();
|
||||
println!("Or install via Homebrew (coming soon):");
|
||||
println!(" brew install --cask zeroclaw");
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
println!(" Linux: {download_url}");
|
||||
println!();
|
||||
println!(" Download the .deb or .AppImage for your architecture.");
|
||||
}
|
||||
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
|
||||
{
|
||||
println!(" {download_url}");
|
||||
}
|
||||
println!();
|
||||
|
||||
// On macOS, open the download page in the browser
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let _ = std::process::Command::new("open").arg(download_url).spawn();
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
let _ = std::process::Command::new("xdg-open")
|
||||
.arg(download_url)
|
||||
.spawn();
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Locate the companion app
|
||||
let desktop_bin = {
|
||||
let mut found = None;
|
||||
|
||||
// 1. macOS: check /Applications/ZeroClaw.app
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let app_paths = [
|
||||
PathBuf::from("/Applications/ZeroClaw.app/Contents/MacOS/ZeroClaw"),
|
||||
PathBuf::from(std::env::var("HOME").unwrap_or_default())
|
||||
.join("Applications/ZeroClaw.app/Contents/MacOS/ZeroClaw"),
|
||||
];
|
||||
for app in &app_paths {
|
||||
if app.is_file() {
|
||||
found = Some(app.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Same directory as the current executable
|
||||
if found.is_none() {
|
||||
if let Ok(exe) = std::env::current_exe() {
|
||||
let sibling = exe.with_file_name("zeroclaw-desktop");
|
||||
if sibling.is_file() {
|
||||
found = Some(sibling);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. ~/.cargo/bin/zeroclaw-desktop or ~/.local/bin/zeroclaw-desktop
|
||||
if found.is_none() {
|
||||
if let Some(home) = std::env::var_os("HOME") {
|
||||
let home = PathBuf::from(home);
|
||||
for dir in &[".cargo/bin", ".local/bin"] {
|
||||
let candidate = home.join(dir).join("zeroclaw-desktop");
|
||||
if candidate.is_file() {
|
||||
found = Some(candidate);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Fallback to PATH lookup
|
||||
if found.is_none() {
|
||||
if let Ok(path) = which::which("zeroclaw-desktop") {
|
||||
found = Some(path);
|
||||
}
|
||||
}
|
||||
|
||||
found
|
||||
};
|
||||
|
||||
match desktop_bin {
|
||||
Some(bin) => {
|
||||
println!("Launching ZeroClaw companion app...");
|
||||
let _child = std::process::Command::new(&bin)
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to launch {}", bin.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
None => {
|
||||
println!("ZeroClaw companion app is not installed.");
|
||||
println!();
|
||||
println!(" Download it at: {download_url}");
|
||||
println!(" Or run: zeroclaw desktop --install");
|
||||
println!();
|
||||
println!("The companion app is a lightweight menu bar app that");
|
||||
println!("connects to the same gateway as the CLI.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Commands::Update {
|
||||
check,
|
||||
force: _force,
|
||||
|
||||
@ -59,6 +59,8 @@ fn create_cli_memory(config: &Config) -> Result<Box<dyn Memory>> {
|
||||
&sp.schema,
|
||||
&sp.table,
|
||||
sp.connect_timeout_secs,
|
||||
Some(sp.pgvector_enabled),
|
||||
Some(sp.pgvector_dimensions),
|
||||
)?;
|
||||
Ok(Box::new(mem))
|
||||
}
|
||||
|
||||
@ -20,6 +20,12 @@ pub struct ConsolidationResult {
|
||||
pub history_entry: String,
|
||||
/// New facts/preferences/decisions to store long-term, or None.
|
||||
pub memory_update: Option<String>,
|
||||
/// Atomic facts extracted from the turn (when consolidation_extract_facts is enabled).
|
||||
#[serde(default)]
|
||||
pub facts: Vec<String>,
|
||||
/// Observed trend or pattern (when consolidation_extract_facts is enabled).
|
||||
#[serde(default)]
|
||||
pub trend: Option<String>,
|
||||
}
|
||||
|
||||
const CONSOLIDATION_SYSTEM_PROMPT: &str = r#"You are a memory consolidation engine. Given a conversation turn, extract:
|
||||
@ -141,6 +147,8 @@ fn parse_consolidation_response(raw: &str, fallback_text: &str) -> Consolidation
|
||||
ConsolidationResult {
|
||||
history_entry: summary,
|
||||
memory_update: None,
|
||||
facts: Vec::new(),
|
||||
trend: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
317
src/memory/knowledge_graph_pg.rs
Normal file
317
src/memory/knowledge_graph_pg.rs
Normal file
@ -0,0 +1,317 @@
|
||||
//! PostgreSQL-backed knowledge graph with optional vector similarity.
|
||||
//!
|
||||
//! Feature-gated behind `memory-postgres`. Uses pure SQL with recursive CTEs
|
||||
//! rather than requiring the AGE extension.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use parking_lot::Mutex;
|
||||
use postgres::{Client, Row};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
pub use super::knowledge_graph::{NodeType, Relation};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgNode {
|
||||
pub id: i64,
|
||||
pub name: String,
|
||||
pub node_type: NodeType,
|
||||
pub content: String,
|
||||
pub tags: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgEdge {
|
||||
pub source_id: i64,
|
||||
pub target_id: i64,
|
||||
pub relation: Relation,
|
||||
pub weight: f64,
|
||||
}
|
||||
|
||||
pub struct PgKnowledgeGraph {
|
||||
client: Arc<Mutex<Client>>,
|
||||
schema: String,
|
||||
}
|
||||
|
||||
async fn run_on_os_thread<F, T>(f: F) -> Result<T>
|
||||
where
|
||||
F: FnOnce() -> Result<T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
std::thread::Builder::new()
|
||||
.name("pg-knowledge-graph-op".to_string())
|
||||
.spawn(move || {
|
||||
let _ = tx.send(f());
|
||||
})
|
||||
.context("failed to spawn pg knowledge graph thread")?;
|
||||
rx.await
|
||||
.map_err(|_| anyhow::anyhow!("pg knowledge graph thread terminated unexpectedly"))?
|
||||
}
|
||||
|
||||
impl PgKnowledgeGraph {
|
||||
pub fn new(client: Arc<Mutex<Client>>, schema: &str) -> Result<Self> {
|
||||
let graph = Self {
|
||||
client,
|
||||
schema: schema.to_string(),
|
||||
};
|
||||
graph.init_schema_sync()?;
|
||||
Ok(graph)
|
||||
}
|
||||
|
||||
fn init_schema_sync(&self) -> Result<()> {
|
||||
let mut client = self.client.lock();
|
||||
let schema = &self.schema;
|
||||
client.batch_execute(&format!(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS "{schema}".kg_nodes (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
tags TEXT[] NOT NULL DEFAULT '{{}}'::TEXT[],
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_kg_nodes_type ON "{schema}".kg_nodes(node_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_kg_nodes_tags ON "{schema}".kg_nodes USING gin(tags);
|
||||
CREATE INDEX IF NOT EXISTS idx_kg_nodes_fts ON "{schema}".kg_nodes
|
||||
USING gin(to_tsvector('simple', name || ' ' || content));
|
||||
CREATE TABLE IF NOT EXISTS "{schema}".kg_edges (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
source_id BIGINT NOT NULL REFERENCES "{schema}".kg_nodes(id) ON DELETE CASCADE,
|
||||
target_id BIGINT NOT NULL REFERENCES "{schema}".kg_nodes(id) ON DELETE CASCADE,
|
||||
relation TEXT NOT NULL,
|
||||
weight DOUBLE PRECISION NOT NULL DEFAULT 1.0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_kg_edges_source ON "{schema}".kg_edges(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_kg_edges_target ON "{schema}".kg_edges(target_id);
|
||||
"#
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn node_type_str(nt: &NodeType) -> &'static str {
|
||||
match nt {
|
||||
NodeType::Pattern => "pattern",
|
||||
NodeType::Decision => "decision",
|
||||
NodeType::Lesson => "lesson",
|
||||
NodeType::Expert => "expert",
|
||||
NodeType::Technology => "technology",
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_node_type(s: &str) -> NodeType {
|
||||
match s {
|
||||
"pattern" => NodeType::Pattern,
|
||||
"decision" => NodeType::Decision,
|
||||
"lesson" => NodeType::Lesson,
|
||||
"expert" => NodeType::Expert,
|
||||
"technology" => NodeType::Technology,
|
||||
_ => NodeType::Pattern,
|
||||
}
|
||||
}
|
||||
|
||||
fn relation_str(r: &Relation) -> &'static str {
|
||||
match r {
|
||||
Relation::Uses => "uses",
|
||||
Relation::Replaces => "replaces",
|
||||
Relation::Extends => "extends",
|
||||
Relation::AuthoredBy => "authored_by",
|
||||
Relation::AppliesTo => "applies_to",
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_relation(s: &str) -> Relation {
|
||||
match s {
|
||||
"uses" => Relation::Uses,
|
||||
"replaces" => Relation::Replaces,
|
||||
"extends" => Relation::Extends,
|
||||
"authored_by" => Relation::AuthoredBy,
|
||||
"applies_to" => Relation::AppliesTo,
|
||||
_ => Relation::Uses,
|
||||
}
|
||||
}
|
||||
|
||||
fn row_to_node(row: &Row) -> PgNode {
|
||||
PgNode {
|
||||
id: row.get(0),
|
||||
name: row.get(1),
|
||||
node_type: Self::parse_node_type(&row.get::<_, String>(2)),
|
||||
content: row.get(3),
|
||||
tags: row.get(4),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_node(
|
||||
&self,
|
||||
name: &str,
|
||||
node_type: NodeType,
|
||||
content: &str,
|
||||
tags: &[String],
|
||||
) -> Result<i64> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
let name = name.to_string();
|
||||
let nt = Self::node_type_str(&node_type).to_string();
|
||||
let content = content.to_string();
|
||||
let tags = tags.to_vec();
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let row = client.query_one(&format!(r#"INSERT INTO "{schema}".kg_nodes (name, node_type, content, tags) VALUES ($1, $2, $3, $4) RETURNING id"#), &[&name, &nt, &content, &tags])?;
|
||||
Ok(row.get(0))
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn add_edge(
|
||||
&self,
|
||||
source_id: i64,
|
||||
target_id: i64,
|
||||
relation: Relation,
|
||||
weight: f64,
|
||||
) -> Result<i64> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
let rel = Self::relation_str(&relation).to_string();
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let row = client.query_one(&format!(r#"INSERT INTO "{schema}".kg_edges (source_id, target_id, relation, weight) VALUES ($1, $2, $3, $4) RETURNING id"#), &[&source_id, &target_id, &rel, &weight])?;
|
||||
Ok(row.get(0))
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn get_node(&self, id: i64) -> Result<Option<PgNode>> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let row = client.query_opt(&format!(r#"SELECT id, name, node_type, content, tags FROM "{schema}".kg_nodes WHERE id = $1"#), &[&id])?;
|
||||
Ok(row.as_ref().map(Self::row_to_node))
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn query_by_tags(&self, tags: &[String], limit: usize) -> Result<Vec<PgNode>> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
let tags = tags.to_vec();
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit = limit as i64;
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let rows = client.query(&format!(r#"SELECT id, name, node_type, content, tags FROM "{schema}".kg_nodes WHERE tags && $1 LIMIT $2"#), &[&tags, &limit])?;
|
||||
Ok(rows.iter().map(Self::row_to_node).collect())
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn query_by_similarity(&self, query: &str, limit: usize) -> Result<Vec<PgNode>> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
let query = query.to_string();
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit = limit as i64;
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let rows = client.query(&format!(r#"SELECT id, name, node_type, content, tags FROM "{schema}".kg_nodes WHERE to_tsvector('simple', name || ' ' || content) @@ plainto_tsquery('simple', $1) LIMIT $2"#), &[&query, &limit])?;
|
||||
Ok(rows.iter().map(Self::row_to_node).collect())
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn find_related(&self, node_id: i64, limit: usize) -> Result<Vec<PgNode>> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit = limit as i64;
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let rows = client.query(&format!(r#"SELECT n.id, n.name, n.node_type, n.content, n.tags FROM "{schema}".kg_nodes n JOIN "{schema}".kg_edges e ON n.id = e.target_id WHERE e.source_id = $1 UNION SELECT n.id, n.name, n.node_type, n.content, n.tags FROM "{schema}".kg_nodes n JOIN "{schema}".kg_edges e ON n.id = e.source_id WHERE e.target_id = $1 LIMIT $2"#), &[&node_id, &limit])?;
|
||||
Ok(rows.iter().map(Self::row_to_node).collect())
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn get_subgraph(&self, root_id: i64, max_depth: u32) -> Result<Vec<PgNode>> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let max_depth = max_depth as i32;
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let rows = client.query(&format!(r#"WITH RECURSIVE reachable AS (SELECT id, name, node_type, content, tags, 0 AS depth FROM "{schema}".kg_nodes WHERE id = $1 UNION SELECT n.id, n.name, n.node_type, n.content, n.tags, r.depth + 1 FROM "{schema}".kg_nodes n JOIN "{schema}".kg_edges e ON n.id = e.target_id JOIN reachable r ON e.source_id = r.id WHERE r.depth < $2) SELECT DISTINCT id, name, node_type, content, tags FROM reachable"#), &[&root_id, &max_depth])?;
|
||||
Ok(rows.iter().map(Self::row_to_node).collect())
|
||||
}).await
|
||||
}
|
||||
|
||||
pub async fn stats(&self) -> Result<(i64, i64)> {
|
||||
let client = self.client.clone();
|
||||
let schema = self.schema.clone();
|
||||
run_on_os_thread(move || {
|
||||
let mut client = client.lock();
|
||||
let nc: i64 = client
|
||||
.query_one(&format!(r#"SELECT COUNT(*) FROM "{schema}".kg_nodes"#), &[])?
|
||||
.get(0);
|
||||
let ec: i64 = client
|
||||
.query_one(&format!(r#"SELECT COUNT(*) FROM "{schema}".kg_edges"#), &[])?
|
||||
.get(0);
|
||||
Ok((nc, ec))
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn node_type_roundtrips() {
|
||||
for nt in &[
|
||||
NodeType::Pattern,
|
||||
NodeType::Decision,
|
||||
NodeType::Lesson,
|
||||
NodeType::Expert,
|
||||
NodeType::Technology,
|
||||
] {
|
||||
let s = PgKnowledgeGraph::node_type_str(nt);
|
||||
assert_eq!(&PgKnowledgeGraph::parse_node_type(s), nt);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relation_roundtrips() {
|
||||
for r in &[
|
||||
Relation::Uses,
|
||||
Relation::Replaces,
|
||||
Relation::Extends,
|
||||
Relation::AuthoredBy,
|
||||
Relation::AppliesTo,
|
||||
] {
|
||||
let s = PgKnowledgeGraph::relation_str(r);
|
||||
assert_eq!(&PgKnowledgeGraph::parse_relation(s), r);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_node_type_defaults_to_pattern() {
|
||||
assert_eq!(
|
||||
PgKnowledgeGraph::parse_node_type("nonexistent"),
|
||||
NodeType::Pattern
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_relation_defaults_to_uses() {
|
||||
assert_eq!(
|
||||
PgKnowledgeGraph::parse_relation("nonexistent"),
|
||||
Relation::Uses
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn init_schema_sql_is_syntactically_valid() {
|
||||
let schema = "test_schema";
|
||||
let sql = format!(
|
||||
r#"CREATE TABLE IF NOT EXISTS "{schema}".kg_nodes (id BIGSERIAL PRIMARY KEY, name TEXT NOT NULL);"#
|
||||
);
|
||||
assert!(sql.contains("BIGSERIAL"));
|
||||
assert!(sql.contains("test_schema"));
|
||||
}
|
||||
}
|
||||
@ -9,6 +9,8 @@ pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod importance;
|
||||
pub mod knowledge_graph;
|
||||
#[cfg(feature = "memory-postgres")]
|
||||
pub mod knowledge_graph_pg;
|
||||
pub mod lucid;
|
||||
pub mod markdown;
|
||||
pub mod none;
|
||||
@ -338,6 +340,8 @@ pub fn create_memory_with_storage_and_routes(
|
||||
&storage_provider.schema,
|
||||
&storage_provider.table,
|
||||
storage_provider.connect_timeout_secs,
|
||||
Some(storage_provider.pgvector_enabled),
|
||||
Some(storage_provider.pgvector_dimensions),
|
||||
)?;
|
||||
Ok(Box::new(memory))
|
||||
}
|
||||
|
||||
@ -19,6 +19,8 @@ const POSTGRES_CONNECT_TIMEOUT_CAP_SECS: u64 = 300;
|
||||
pub struct PostgresMemory {
|
||||
client: Arc<Mutex<Client>>,
|
||||
qualified_table: String,
|
||||
pgvector_enabled: bool,
|
||||
pgvector_dimensions: usize,
|
||||
}
|
||||
|
||||
impl PostgresMemory {
|
||||
@ -27,6 +29,8 @@ impl PostgresMemory {
|
||||
schema: &str,
|
||||
table: &str,
|
||||
connect_timeout_secs: Option<u64>,
|
||||
pgvector_enabled: Option<bool>,
|
||||
pgvector_dimensions: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
validate_identifier(schema, "storage schema")?;
|
||||
validate_identifier(table, "storage table")?;
|
||||
@ -42,10 +46,34 @@ impl PostgresMemory {
|
||||
qualified_table.clone(),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
client: Arc::new(Mutex::new(client)),
|
||||
qualified_table,
|
||||
})
|
||||
let pgvector_enabled = pgvector_enabled.unwrap_or(false);
|
||||
let pgvector_dimensions = pgvector_dimensions.unwrap_or(1536);
|
||||
|
||||
if pgvector_enabled {
|
||||
let client_ref = Arc::new(Mutex::new(client));
|
||||
let ext_ok = {
|
||||
let mut c = client_ref.lock();
|
||||
Self::try_enable_pgvector(&mut c, &qualified_table, pgvector_dimensions).is_ok()
|
||||
};
|
||||
if !ext_ok {
|
||||
tracing::warn!(
|
||||
"pgvector extension not available; falling back to keyword-only recall"
|
||||
);
|
||||
}
|
||||
Ok(Self {
|
||||
client: client_ref,
|
||||
qualified_table,
|
||||
pgvector_enabled: ext_ok,
|
||||
pgvector_dimensions,
|
||||
})
|
||||
} else {
|
||||
Ok(Self {
|
||||
client: Arc::new(Mutex::new(client)),
|
||||
qualified_table,
|
||||
pgvector_enabled: false,
|
||||
pgvector_dimensions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize_client(
|
||||
@ -126,6 +154,27 @@ impl PostgresMemory {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_enable_pgvector(
|
||||
client: &mut Client,
|
||||
qualified_table: &str,
|
||||
dimensions: usize,
|
||||
) -> Result<()> {
|
||||
client.batch_execute("CREATE EXTENSION IF NOT EXISTS vector")?;
|
||||
client.batch_execute(&format!(
|
||||
r#"
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE {qualified_table} ADD COLUMN IF NOT EXISTS namespace TEXT DEFAULT 'default';
|
||||
ALTER TABLE {qualified_table} ADD COLUMN IF NOT EXISTS importance REAL;
|
||||
ALTER TABLE {qualified_table} ADD COLUMN IF NOT EXISTS embedding vector({dimensions});
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgvector columns could not be added: %', SQLERRM;
|
||||
END $$;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_namespace ON {qualified_table}(namespace);
|
||||
"#
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn row_to_entry(row: &Row) -> Result<MemoryEntry> {
|
||||
let timestamp: DateTime<Utc> = row.get(4);
|
||||
|
||||
@ -137,8 +186,10 @@ impl PostgresMemory {
|
||||
timestamp: timestamp.to_rfc3339(),
|
||||
session_id: row.get(5),
|
||||
score: row.try_get(6).ok(),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
namespace: row
|
||||
.try_get::<_, String>(7)
|
||||
.unwrap_or_else(|_| "default".into()),
|
||||
importance: row.try_get(8).ok(),
|
||||
superseded_by: None,
|
||||
})
|
||||
}
|
||||
@ -437,6 +488,8 @@ mod tests {
|
||||
"public",
|
||||
"memories",
|
||||
Some(1),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
});
|
||||
|
||||
|
||||
@ -926,6 +926,38 @@ impl Memory for SqliteMemory {
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn purge_namespace(&self, namespace: &str) -> anyhow::Result<usize> {
|
||||
let conn = self.conn.clone();
|
||||
let namespace = namespace.to_string();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
|
||||
let conn = conn.lock();
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memories WHERE category = ?1",
|
||||
params![namespace],
|
||||
)?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(affected as usize)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn purge_session(&self, session_id: &str) -> anyhow::Result<usize> {
|
||||
let conn = self.conn.clone();
|
||||
let session_id = session_id.to_string();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
|
||||
let conn = conn.lock();
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memories WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
)?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(affected as usize)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self.conn.clone();
|
||||
|
||||
@ -1920,6 +1952,153 @@ mod tests {
|
||||
assert!(all.is_empty());
|
||||
}
|
||||
|
||||
// ── Bulk deletion tests ───────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_namespace_removes_all_matching_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "data1", MemoryCategory::Custom("ns1".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("a2", "data2", MemoryCategory::Custom("ns1".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data3", MemoryCategory::Custom("ns2".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_namespace("ns1").await.unwrap();
|
||||
assert_eq!(count, 2);
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_namespace_preserves_other_namespaces() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "data1", MemoryCategory::Custom("ns1".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data2", MemoryCategory::Custom("ns2".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c1", "data3", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("d1", "data4", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_namespace("ns1").await.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
assert_eq!(mem.count().await.unwrap(), 3);
|
||||
|
||||
let remaining = mem.list(None, None).await.unwrap();
|
||||
assert!(remaining
|
||||
.iter()
|
||||
.all(|e| e.category != MemoryCategory::Custom("ns1".into())));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_namespace_returns_count() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
for i in 0..5 {
|
||||
mem.store(
|
||||
&format!("k{i}"),
|
||||
"data",
|
||||
MemoryCategory::Custom("target".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let count = mem.purge_namespace("target").await.unwrap();
|
||||
assert_eq!(count, 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_session_removes_all_matching_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "data1", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("a2", "data2", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data3", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_session("sess-a").await.unwrap();
|
||||
assert_eq!(count, 2);
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_session_preserves_other_sessions() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "data1", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data2", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c1", "data3", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_session("sess-a").await.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let remaining = mem.list(None, None).await.unwrap();
|
||||
assert!(remaining
|
||||
.iter()
|
||||
.all(|e| e.session_id.as_deref() != Some("sess-a")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_session_returns_count() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
for i in 0..3 {
|
||||
mem.store(
|
||||
&format!("k{i}"),
|
||||
"data",
|
||||
MemoryCategory::Core,
|
||||
Some("target-sess"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let count = mem.purge_session("target-sess").await.unwrap();
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_namespace_empty_namespace_is_noop() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_namespace("").await.unwrap();
|
||||
assert_eq!(count, 0);
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_purge_session_empty_session_is_noop() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, Some("sess"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let count = mem.purge_session("").await.unwrap();
|
||||
assert_eq!(count, 0);
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
// ── Session isolation ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@ -135,6 +135,20 @@ pub trait Memory: Send + Sync {
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
||||
/// Remove all memories in a namespace (category).
|
||||
/// Returns the number of deleted entries.
|
||||
/// Default: returns unsupported error. Backends that support bulk deletion override this.
|
||||
async fn purge_namespace(&self, _namespace: &str) -> anyhow::Result<usize> {
|
||||
anyhow::bail!("purge_namespace not supported by this memory backend")
|
||||
}
|
||||
|
||||
/// Remove all memories in a session.
|
||||
/// Returns the number of deleted entries.
|
||||
/// Default: returns unsupported error. Backends that support bulk deletion override this.
|
||||
async fn purge_session(&self, _session_id: &str) -> anyhow::Result<usize> {
|
||||
anyhow::bail!("purge_session not supported by this memory backend")
|
||||
}
|
||||
|
||||
/// Count total memories
|
||||
async fn count(&self) -> anyhow::Result<usize>;
|
||||
|
||||
|
||||
@ -204,10 +204,12 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
claude_code_runner: crate::config::ClaudeCodeRunnerConfig::default(),
|
||||
codex_cli: crate::config::CodexCliConfig::default(),
|
||||
gemini_cli: crate::config::GeminiCliConfig::default(),
|
||||
opencode_cli: crate::config::OpenCodeCliConfig::default(),
|
||||
sop: crate::config::SopConfig::default(),
|
||||
shell_tool: crate::config::ShellToolConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@ -642,10 +644,12 @@ async fn run_quick_setup_with_home(
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
claude_code_runner: crate::config::ClaudeCodeRunnerConfig::default(),
|
||||
codex_cli: crate::config::CodexCliConfig::default(),
|
||||
gemini_cli: crate::config::GeminiCliConfig::default(),
|
||||
opencode_cli: crate::config::OpenCodeCliConfig::default(),
|
||||
sop: crate::config::SopConfig::default(),
|
||||
shell_tool: crate::config::ShellToolConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
//! AWS Bedrock provider using the Converse API.
|
||||
//!
|
||||
//! Authentication: AWS AKSK (Access Key ID + Secret Access Key)
|
||||
//! via environment variables. SigV4 signing is implemented manually
|
||||
//! using hmac/sha2 crates — no AWS SDK dependency.
|
||||
//! Authentication: supports two methods:
|
||||
//! - **Bearer token**: set `BEDROCK_API_KEY` env var (takes precedence).
|
||||
//! - **SigV4 signing**: AWS AKSK (Access Key ID + Secret Access Key)
|
||||
//! via environment variables or EC2 IMDSv2. SigV4 signing is implemented
|
||||
//! manually using hmac/sha2 crates — no AWS SDK dependency.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
@ -22,6 +24,14 @@ const SIGNING_SERVICE: &str = "bedrock";
|
||||
const DEFAULT_REGION: &str = "us-east-1";
|
||||
const DEFAULT_MAX_TOKENS: u32 = 4096;
|
||||
|
||||
// ── Authentication ──────────────────────────────────────────────
|
||||
|
||||
/// Authentication method for Bedrock: either SigV4 (AKSK) or Bearer token.
|
||||
enum BedrockAuth {
|
||||
SigV4(AwsCredentials),
|
||||
BearerToken(String),
|
||||
}
|
||||
|
||||
// ── AWS Credentials ─────────────────────────────────────────────
|
||||
|
||||
/// Resolved AWS credentials for SigV4 signing.
|
||||
@ -452,19 +462,38 @@ struct ResponseToolUseWrapper {
|
||||
// ── BedrockProvider ─────────────────────────────────────────────
|
||||
|
||||
pub struct BedrockProvider {
|
||||
credentials: Option<AwsCredentials>,
|
||||
auth: Option<BedrockAuth>,
|
||||
}
|
||||
|
||||
impl BedrockProvider {
|
||||
pub fn new() -> Self {
|
||||
// Bearer token takes precedence over SigV4 credentials.
|
||||
if let Some(token) = env_optional("BEDROCK_API_KEY") {
|
||||
return Self {
|
||||
auth: Some(BedrockAuth::BearerToken(token)),
|
||||
};
|
||||
}
|
||||
Self {
|
||||
credentials: AwsCredentials::from_env().ok(),
|
||||
auth: AwsCredentials::from_env().ok().map(BedrockAuth::SigV4),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_async() -> Self {
|
||||
let credentials = AwsCredentials::resolve().await.ok();
|
||||
Self { credentials }
|
||||
// Bearer token takes precedence over SigV4 credentials.
|
||||
if let Some(token) = env_optional("BEDROCK_API_KEY") {
|
||||
return Self {
|
||||
auth: Some(BedrockAuth::BearerToken(token)),
|
||||
};
|
||||
}
|
||||
let auth = AwsCredentials::resolve().await.ok().map(BedrockAuth::SigV4);
|
||||
Self { auth }
|
||||
}
|
||||
|
||||
/// Create a provider using a Bearer token for authentication.
|
||||
pub fn with_bearer_token(token: &str) -> Self {
|
||||
Self {
|
||||
auth: Some(BedrockAuth::BearerToken(token.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
@ -478,6 +507,13 @@ impl BedrockProvider {
|
||||
model_id.replace(':', "%3A")
|
||||
}
|
||||
|
||||
/// Resolve the AWS region from environment variables.
|
||||
fn resolve_region() -> String {
|
||||
env_optional("AWS_REGION")
|
||||
.or_else(|| env_optional("AWS_DEFAULT_REGION"))
|
||||
.unwrap_or_else(|| DEFAULT_REGION.to_string())
|
||||
}
|
||||
|
||||
/// Build the actual request URL. Uses raw model ID (reqwest sends colons as-is).
|
||||
fn endpoint_url(region: &str, model_id: &str) -> String {
|
||||
format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
|
||||
@ -491,22 +527,38 @@ impl BedrockProvider {
|
||||
format!("/model/{encoded}/converse")
|
||||
}
|
||||
|
||||
fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> {
|
||||
self.credentials.as_ref().ok_or_else(|| {
|
||||
fn require_auth(&self) -> anyhow::Result<&BedrockAuth> {
|
||||
self.auth.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"AWS Bedrock credentials not set. Set AWS_ACCESS_KEY_ID and \
|
||||
AWS_SECRET_ACCESS_KEY environment variables, or run on an EC2 \
|
||||
instance with an IAM role attached."
|
||||
"AWS Bedrock credentials not set. Set BEDROCK_API_KEY for Bearer \
|
||||
token auth, or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY for \
|
||||
SigV4 auth, or run on an EC2 instance with an IAM role attached."
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve credentials: use cached if available, otherwise fetch from IMDS.
|
||||
async fn resolve_credentials(&self) -> anyhow::Result<AwsCredentials> {
|
||||
if let Ok(creds) = AwsCredentials::from_env() {
|
||||
return Ok(creds);
|
||||
/// Resolve auth: use cached if available, otherwise try env vars then IMDS.
|
||||
async fn resolve_auth(&self) -> anyhow::Result<BedrockAuth> {
|
||||
// If we already have auth cached, re-resolve from the same source.
|
||||
if let Some(ref auth) = self.auth {
|
||||
match auth {
|
||||
BedrockAuth::BearerToken(token) => {
|
||||
return Ok(BedrockAuth::BearerToken(token.clone()));
|
||||
}
|
||||
BedrockAuth::SigV4(_) => {
|
||||
// Re-resolve SigV4 credentials (they may have rotated).
|
||||
}
|
||||
}
|
||||
}
|
||||
AwsCredentials::from_imds().await
|
||||
// Check Bearer token first.
|
||||
if let Some(token) = env_optional("BEDROCK_API_KEY") {
|
||||
return Ok(BedrockAuth::BearerToken(token));
|
||||
}
|
||||
// Fall back to SigV4.
|
||||
if let Ok(creds) = AwsCredentials::from_env() {
|
||||
return Ok(BedrockAuth::SigV4(creds));
|
||||
}
|
||||
Ok(BedrockAuth::SigV4(AwsCredentials::from_imds().await?))
|
||||
}
|
||||
|
||||
// ── Cache heuristics (same thresholds as AnthropicProvider) ──
|
||||
@ -876,7 +928,7 @@ impl BedrockProvider {
|
||||
|
||||
async fn send_converse_request(
|
||||
&self,
|
||||
credentials: &AwsCredentials,
|
||||
auth: &BedrockAuth,
|
||||
model: &str,
|
||||
request_body: &ConverseRequest,
|
||||
) -> anyhow::Result<ConverseResponse> {
|
||||
@ -912,44 +964,62 @@ impl BedrockProvider {
|
||||
}
|
||||
}
|
||||
}
|
||||
let url = Self::endpoint_url(&credentials.region, model);
|
||||
let canonical_uri = Self::canonical_uri(model);
|
||||
let now = chrono::Utc::now();
|
||||
let host = credentials.host();
|
||||
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
|
||||
|
||||
let mut headers_to_sign = vec![
|
||||
("content-type".to_string(), "application/json".to_string()),
|
||||
("host".to_string(), host),
|
||||
("x-amz-date".to_string(), amz_date.clone()),
|
||||
];
|
||||
if let Some(ref token) = credentials.session_token {
|
||||
headers_to_sign.push(("x-amz-security-token".to_string(), token.clone()));
|
||||
}
|
||||
headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
let response: reqwest::Response = match auth {
|
||||
BedrockAuth::BearerToken(token) => {
|
||||
let region = Self::resolve_region();
|
||||
let url = Self::endpoint_url(®ion, model);
|
||||
|
||||
let authorization = build_authorization_header(
|
||||
credentials,
|
||||
"POST",
|
||||
&canonical_uri,
|
||||
"",
|
||||
&headers_to_sign,
|
||||
&payload,
|
||||
&now,
|
||||
);
|
||||
self.http_client()
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(payload)
|
||||
.send()
|
||||
.await?
|
||||
}
|
||||
BedrockAuth::SigV4(credentials) => {
|
||||
let url = Self::endpoint_url(&credentials.region, model);
|
||||
let canonical_uri = Self::canonical_uri(model);
|
||||
let now = chrono::Utc::now();
|
||||
let host = credentials.host();
|
||||
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
|
||||
|
||||
let mut request = self
|
||||
.http_client()
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.header("x-amz-date", &amz_date)
|
||||
.header("authorization", &authorization);
|
||||
let mut headers_to_sign = vec![
|
||||
("content-type".to_string(), "application/json".to_string()),
|
||||
("host".to_string(), host),
|
||||
("x-amz-date".to_string(), amz_date.clone()),
|
||||
];
|
||||
if let Some(ref session_token) = credentials.session_token {
|
||||
headers_to_sign
|
||||
.push(("x-amz-security-token".to_string(), session_token.clone()));
|
||||
}
|
||||
headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
if let Some(ref token) = credentials.session_token {
|
||||
request = request.header("x-amz-security-token", token);
|
||||
}
|
||||
let authorization = build_authorization_header(
|
||||
credentials,
|
||||
"POST",
|
||||
&canonical_uri,
|
||||
"",
|
||||
&headers_to_sign,
|
||||
&payload,
|
||||
&now,
|
||||
);
|
||||
|
||||
let response: reqwest::Response = request.body(payload).send().await?;
|
||||
let mut request = self
|
||||
.http_client()
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.header("x-amz-date", &amz_date)
|
||||
.header("authorization", &authorization);
|
||||
|
||||
if let Some(ref session_token) = credentials.session_token {
|
||||
request = request.header("x-amz-security-token", session_token);
|
||||
}
|
||||
|
||||
request.body(payload).send().await?
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("Bedrock", response).await);
|
||||
@ -999,7 +1069,7 @@ impl Provider for BedrockProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let credentials = self.resolve_credentials().await?;
|
||||
let auth = self.resolve_auth().await?;
|
||||
|
||||
let system = system_prompt.map(|text| {
|
||||
let mut blocks = vec![SystemBlock::Text(TextBlock {
|
||||
@ -1026,9 +1096,7 @@ impl Provider for BedrockProvider {
|
||||
tool_config: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.send_converse_request(&credentials, model, &request)
|
||||
.await?;
|
||||
let response = self.send_converse_request(&auth, model, &request).await?;
|
||||
|
||||
Self::parse_converse_response(response)
|
||||
.text
|
||||
@ -1041,7 +1109,7 @@ impl Provider for BedrockProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credentials = self.resolve_credentials().await?;
|
||||
let auth = self.resolve_auth().await?;
|
||||
|
||||
let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
|
||||
|
||||
@ -1082,17 +1150,20 @@ impl Provider for BedrockProvider {
|
||||
};
|
||||
|
||||
let response = self
|
||||
.send_converse_request(&credentials, model, &converse_request)
|
||||
.send_converse_request(&auth, model, &converse_request)
|
||||
.await?;
|
||||
|
||||
Ok(Self::parse_converse_response(response))
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
if let Some(ref creds) = self.credentials {
|
||||
let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region);
|
||||
let _ = self.http_client().get(&url).send().await;
|
||||
}
|
||||
let region = match self.auth {
|
||||
Some(BedrockAuth::SigV4(ref creds)) => creds.region.clone(),
|
||||
Some(BedrockAuth::BearerToken(_)) => Self::resolve_region(),
|
||||
None => return Ok(()),
|
||||
};
|
||||
let url = format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/");
|
||||
let _ = self.http_client().get(&url).send().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1104,6 +1175,35 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::providers::traits::ChatMessage;
|
||||
|
||||
/// RAII guard that sets/unsets an env var and restores the original on drop.
|
||||
struct EnvGuard {
|
||||
key: String,
|
||||
original: Option<String>,
|
||||
}
|
||||
|
||||
impl EnvGuard {
|
||||
fn set(key: &str, value: Option<&str>) -> Self {
|
||||
let original = std::env::var(key).ok();
|
||||
match value {
|
||||
Some(v) => std::env::set_var(key, v),
|
||||
None => std::env::remove_var(key),
|
||||
}
|
||||
Self {
|
||||
key: key.to_string(),
|
||||
original,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvGuard {
|
||||
fn drop(&mut self) {
|
||||
match &self.original {
|
||||
Some(v) => std::env::set_var(&self.key, v),
|
||||
None => std::env::remove_var(&self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── SigV4 signing tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
@ -1255,7 +1355,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_credentials() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { auth: None };
|
||||
let result = provider
|
||||
.chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7)
|
||||
.await;
|
||||
@ -1270,6 +1370,45 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// ── Bearer token tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn creates_with_bearer_token() {
|
||||
let provider = BedrockProvider::with_bearer_token("test-api-key");
|
||||
assert!(provider.auth.is_some());
|
||||
assert!(
|
||||
matches!(provider.auth, Some(BedrockAuth::BearerToken(ref t)) if t == "test-api-key")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bearer_token_from_env() {
|
||||
let _guard = EnvGuard::set("BEDROCK_API_KEY", Some("env-bearer-token"));
|
||||
// Clear SigV4 vars to ensure Bearer is chosen.
|
||||
let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
|
||||
let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);
|
||||
|
||||
let provider = BedrockProvider::new();
|
||||
assert!(matches!(
|
||||
provider.auth,
|
||||
Some(BedrockAuth::BearerToken(ref t)) if t == "env-bearer-token"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bearer_token_precedence() {
|
||||
let _bearer_guard = EnvGuard::set("BEDROCK_API_KEY", Some("bearer-key"));
|
||||
let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("AKIAEXAMPLE"));
|
||||
let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret"));
|
||||
|
||||
let provider = BedrockProvider::new();
|
||||
// Bearer token should take priority over SigV4 credentials.
|
||||
assert!(matches!(
|
||||
provider.auth,
|
||||
Some(BedrockAuth::BearerToken(ref t)) if t == "bearer-key"
|
||||
));
|
||||
}
|
||||
|
||||
// ── Endpoint URL tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
@ -1550,14 +1689,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_without_credentials_is_noop() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { auth: None };
|
||||
let result = provider.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_reports_native_tool_calling() {
|
||||
let provider = BedrockProvider { credentials: None };
|
||||
let provider = BedrockProvider { auth: None };
|
||||
let caps = provider.capabilities();
|
||||
assert!(caps.native_tool_calling);
|
||||
}
|
||||
|
||||
@ -890,9 +890,18 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
|
||||
}
|
||||
name if is_glm_alias(name) => vec!["GLM_API_KEY"],
|
||||
name if is_minimax_alias(name) => vec![MINIMAX_OAUTH_TOKEN_ENV, MINIMAX_API_KEY_ENV],
|
||||
// Bedrock uses AWS AKSK from env vars (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY),
|
||||
// not a single API key. Credential resolution happens inside BedrockProvider.
|
||||
"bedrock" | "aws-bedrock" => return None,
|
||||
// Bedrock supports Bearer token auth via BEDROCK_API_KEY env var, in addition
|
||||
// to AWS AKSK (SigV4). If BEDROCK_API_KEY is set, return it; otherwise return
|
||||
// None and let BedrockProvider handle SigV4 credential resolution internally.
|
||||
"bedrock" | "aws-bedrock" => {
|
||||
if let Ok(val) = std::env::var("BEDROCK_API_KEY") {
|
||||
let trimmed = val.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed);
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"],
|
||||
name if is_doubao_alias(name) => {
|
||||
vec!["ARK_API_KEY", "VOLCENGINE_API_KEY", "DOUBAO_API_KEY"]
|
||||
@ -1247,7 +1256,13 @@ fn create_provider_with_url_and_options(
|
||||
api_version.as_deref(),
|
||||
)))
|
||||
}
|
||||
"bedrock" | "aws-bedrock" => Ok(Box::new(bedrock::BedrockProvider::new())),
|
||||
"bedrock" | "aws-bedrock" => {
|
||||
if let Some(api_key) = key {
|
||||
Ok(Box::new(bedrock::BedrockProvider::with_bearer_token(api_key)))
|
||||
} else {
|
||||
Ok(Box::new(bedrock::BedrockProvider::new()))
|
||||
}
|
||||
}
|
||||
name if is_qwen_oauth_alias(name) => {
|
||||
let base_url = api_url
|
||||
.map(str::trim)
|
||||
@ -2268,6 +2283,7 @@ mod tests {
|
||||
fn resolve_provider_credential_bedrock_uses_internal_credential_path() {
|
||||
let _generic_guard = EnvGuard::set("API_KEY", Some("generic-key"));
|
||||
let _override_guard = EnvGuard::set("OPENROUTER_API_KEY", Some("openrouter-key"));
|
||||
let _bedrock_guard = EnvGuard::set("BEDROCK_API_KEY", None);
|
||||
|
||||
assert_eq!(
|
||||
resolve_provider_credential("bedrock", Some("explicit")),
|
||||
@ -2277,6 +2293,20 @@ mod tests {
|
||||
assert!(resolve_provider_credential("aws-bedrock", None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_provider_credential_bedrock_returns_bearer_token_from_env() {
|
||||
let _bedrock_guard = EnvGuard::set("BEDROCK_API_KEY", Some("bedrock-bearer-token"));
|
||||
|
||||
assert_eq!(
|
||||
resolve_provider_credential("bedrock", None),
|
||||
Some("bedrock-bearer-token".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_provider_credential("aws-bedrock", None),
|
||||
Some("bedrock-bearer-token".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_qwen_oauth_context_prefers_explicit_override() {
|
||||
let _env_lock = env_lock();
|
||||
|
||||
@ -171,19 +171,23 @@ impl OpenRouterProvider {
|
||||
if items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
let valid: Vec<NativeToolSpec> = items
|
||||
.iter()
|
||||
.filter(|tool| is_valid_openai_tool_name(&tool.name))
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
if valid.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(valid)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||
@ -628,6 +632,16 @@ impl Provider for OpenRouterProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool name is valid for OpenAI-compatible APIs.
|
||||
/// Must match `^[a-zA-Z0-9_-]{1,64}$`.
|
||||
fn is_valid_openai_tool_name(name: &str) -> bool {
|
||||
!name.is_empty()
|
||||
&& name.len() <= 64
|
||||
&& name
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -1137,4 +1151,69 @@ mod tests {
|
||||
let provider = OpenRouterProvider::new(Some("key"), None).with_timeout_secs(300);
|
||||
assert_eq!(provider.timeout_secs, 300);
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// tool name validation tests
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[test]
|
||||
fn valid_openai_tool_names() {
|
||||
assert!(is_valid_openai_tool_name("shell"));
|
||||
assert!(is_valid_openai_tool_name("file_read"));
|
||||
assert!(is_valid_openai_tool_name("web-search"));
|
||||
assert!(is_valid_openai_tool_name("Tool123"));
|
||||
assert!(is_valid_openai_tool_name("a"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_openai_tool_names() {
|
||||
assert!(!is_valid_openai_tool_name(""));
|
||||
assert!(!is_valid_openai_tool_name("mcp:server.tool"));
|
||||
assert!(!is_valid_openai_tool_name("node.js"));
|
||||
assert!(!is_valid_openai_tool_name("tool name"));
|
||||
assert!(!is_valid_openai_tool_name(
|
||||
"this_tool_name_is_way_too_long_and_exceeds_the_sixty_four_character_limit_xxxxx"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_tools_skips_invalid_names() {
|
||||
use crate::tools::ToolSpec;
|
||||
|
||||
let tools = vec![
|
||||
ToolSpec {
|
||||
name: "valid_tool".into(),
|
||||
description: "A valid tool".into(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: "mcp:server.bad".into(),
|
||||
description: "Invalid name".into(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: "another-valid".into(),
|
||||
description: "Also valid".into(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
},
|
||||
];
|
||||
|
||||
let result = OpenRouterProvider::convert_tools(Some(&tools)).unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "valid_tool");
|
||||
assert_eq!(result[1].function.name, "another-valid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_tools_returns_none_when_all_invalid() {
|
||||
use crate::tools::ToolSpec;
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "mcp:bad.name".into(),
|
||||
description: "Invalid".into(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
assert!(OpenRouterProvider::convert_tools(Some(&tools)).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@ use zip::ZipArchive;
|
||||
mod audit;
|
||||
#[cfg(feature = "skill-creation")]
|
||||
pub mod creator;
|
||||
pub mod testing;
|
||||
|
||||
const OPEN_SKILLS_REPO_URL: &str = "https://github.com/besoeasy/open-skills";
|
||||
const OPEN_SKILLS_SYNC_MARKER: &str = ".zeroclaw-open-skills-sync";
|
||||
@ -1436,6 +1437,44 @@ pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Con
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
crate::SkillCommands::Test { name, verbose } => {
|
||||
let results = if let Some(ref skill_name) = name {
|
||||
// Test a single skill
|
||||
let source_path = PathBuf::from(skill_name);
|
||||
let target = if source_path.exists() {
|
||||
source_path
|
||||
} else {
|
||||
skills_dir(workspace_dir).join(skill_name)
|
||||
};
|
||||
|
||||
if !target.exists() {
|
||||
anyhow::bail!("Skill not found: {}", skill_name);
|
||||
}
|
||||
|
||||
let r = testing::test_skill(&target, skill_name, verbose)?;
|
||||
if r.tests_run == 0 {
|
||||
println!(
|
||||
" {} No TEST.sh found for skill '{}'.",
|
||||
console::style("-").dim(),
|
||||
skill_name,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
vec![r]
|
||||
} else {
|
||||
// Test all skills
|
||||
let dirs = vec![skills_dir(workspace_dir)];
|
||||
testing::test_all_skills(&dirs, verbose)?
|
||||
};
|
||||
|
||||
testing::print_results(&results);
|
||||
|
||||
let any_failed = results.iter().any(|r| !r.failures.is_empty());
|
||||
if any_failed {
|
||||
anyhow::bail!("Some skill tests failed.");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
471
src/skills/testing.rs
Normal file
471
src/skills/testing.rs
Normal file
@ -0,0 +1,471 @@
|
||||
use anyhow::{Context, Result};
|
||||
use regex::Regex;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
const TEST_FILE_NAME: &str = "TEST.sh";
|
||||
|
||||
/// Result of running all tests for a single skill.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SkillTestResult {
|
||||
pub skill_name: String,
|
||||
pub tests_run: usize,
|
||||
pub tests_passed: usize,
|
||||
pub failures: Vec<TestFailure>,
|
||||
}
|
||||
|
||||
/// Details about a single failed test case.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TestFailure {
|
||||
pub command: String,
|
||||
pub expected_exit: i32,
|
||||
pub actual_exit: i32,
|
||||
pub expected_pattern: String,
|
||||
pub actual_output: String,
|
||||
}
|
||||
|
||||
/// A parsed test case from a TEST.sh line.
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestCase {
|
||||
command: String,
|
||||
expected_exit: i32,
|
||||
expected_pattern: String,
|
||||
}
|
||||
|
||||
/// Parse a single TEST.sh line into a `TestCase`.
|
||||
///
|
||||
/// Expected format: `command | expected_exit_code | expected_output_pattern`
|
||||
fn parse_test_line(line: &str) -> Option<TestCase> {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Split on ` | ` (pipe surrounded by spaces) to avoid splitting on shell
|
||||
// pipes inside the command itself. Fall back to bare `|` splitting only if
|
||||
// the line contains exactly two ` | ` delimiters.
|
||||
let parts: Vec<&str> = trimmed.split(" | ").collect();
|
||||
if parts.len() < 3 {
|
||||
// Try splitting on `|` as fallback
|
||||
let parts: Vec<&str> = trimmed.splitn(3, '|').collect();
|
||||
if parts.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let command = parts[0].trim().to_string();
|
||||
let expected_exit = parts[1].trim().parse::<i32>().ok()?;
|
||||
let expected_pattern = parts[2].trim().to_string();
|
||||
return Some(TestCase {
|
||||
command,
|
||||
expected_exit,
|
||||
expected_pattern,
|
||||
});
|
||||
}
|
||||
|
||||
let command = parts[0].trim().to_string();
|
||||
let expected_exit = parts[1].trim().parse::<i32>().ok()?;
|
||||
// Rejoin remaining parts in case the pattern itself contains ` | `
|
||||
let expected_pattern = parts[2..].join(" | ").trim().to_string();
|
||||
|
||||
Some(TestCase {
|
||||
command,
|
||||
expected_exit,
|
||||
expected_pattern,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check whether `output` matches `pattern`.
|
||||
///
|
||||
/// If the pattern looks like a regex (contains regex metacharacters beyond a
|
||||
/// simple `/` path), we attempt a regex match. Otherwise we fall back to a
|
||||
/// simple substring check.
|
||||
fn pattern_matches(output: &str, pattern: &str) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return true;
|
||||
}
|
||||
// Try regex first
|
||||
if let Ok(re) = Regex::new(pattern) {
|
||||
if re.is_match(output) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// Fallback: substring match
|
||||
output.contains(pattern)
|
||||
}
|
||||
|
||||
/// Run a single test case and return a possible failure.
|
||||
fn run_test_case(case: &TestCase, skill_dir: &Path, verbose: bool) -> Option<TestFailure> {
|
||||
if verbose {
|
||||
println!(" running: {}", case.command);
|
||||
}
|
||||
|
||||
let result = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(&case.command)
|
||||
.current_dir(skill_dir)
|
||||
.output();
|
||||
|
||||
let output = match result {
|
||||
Ok(o) => o,
|
||||
Err(err) => {
|
||||
return Some(TestFailure {
|
||||
command: case.command.clone(),
|
||||
expected_exit: case.expected_exit,
|
||||
actual_exit: -1,
|
||||
expected_pattern: case.expected_pattern.clone(),
|
||||
actual_output: format!("failed to execute command: {err}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let actual_exit = output.status.code().unwrap_or(-1);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let combined = format!("{stdout}{stderr}");
|
||||
|
||||
if verbose {
|
||||
if !stdout.is_empty() {
|
||||
println!(" stdout: {}", stdout.trim());
|
||||
}
|
||||
if !stderr.is_empty() {
|
||||
println!(" stderr: {}", stderr.trim());
|
||||
}
|
||||
println!(" exit: {actual_exit}");
|
||||
}
|
||||
|
||||
let exit_ok = actual_exit == case.expected_exit;
|
||||
let pattern_ok = pattern_matches(&combined, &case.expected_pattern);
|
||||
|
||||
if exit_ok && pattern_ok {
|
||||
None
|
||||
} else {
|
||||
Some(TestFailure {
|
||||
command: case.command.clone(),
|
||||
expected_exit: case.expected_exit,
|
||||
actual_exit,
|
||||
expected_pattern: case.expected_pattern.clone(),
|
||||
actual_output: combined.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Test a single skill by parsing and running its TEST.sh.
|
||||
pub fn test_skill(skill_dir: &Path, skill_name: &str, verbose: bool) -> Result<SkillTestResult> {
|
||||
let test_file = skill_dir.join(TEST_FILE_NAME);
|
||||
if !test_file.exists() {
|
||||
return Ok(SkillTestResult {
|
||||
skill_name: skill_name.to_string(),
|
||||
tests_run: 0,
|
||||
tests_passed: 0,
|
||||
failures: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&test_file)
|
||||
.with_context(|| format!("failed to read {}", test_file.display()))?;
|
||||
|
||||
let cases: Vec<TestCase> = content.lines().filter_map(parse_test_line).collect();
|
||||
|
||||
let mut result = SkillTestResult {
|
||||
skill_name: skill_name.to_string(),
|
||||
tests_run: cases.len(),
|
||||
tests_passed: 0,
|
||||
failures: Vec::new(),
|
||||
};
|
||||
|
||||
for case in &cases {
|
||||
match run_test_case(case, skill_dir, verbose) {
|
||||
None => result.tests_passed += 1,
|
||||
Some(failure) => result.failures.push(failure),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Test all skills that have a TEST.sh file within the given skill directories.
|
||||
pub fn test_all_skills(skills_dirs: &[PathBuf], verbose: bool) -> Result<Vec<SkillTestResult>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for dir in skills_dirs {
|
||||
if !dir.exists() || !dir.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let entries = std::fs::read_dir(dir)
|
||||
.with_context(|| format!("failed to read directory {}", dir.display()))?;
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
let test_file = path.join(TEST_FILE_NAME);
|
||||
if !test_file.exists() {
|
||||
continue;
|
||||
}
|
||||
let skill_name = path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
if verbose {
|
||||
println!(" Testing skill: {} ({})", skill_name, path.display());
|
||||
}
|
||||
|
||||
let r = test_skill(&path, &skill_name, verbose)?;
|
||||
results.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Pretty-print test results using the `console` crate.
|
||||
pub fn print_results(results: &[SkillTestResult]) {
|
||||
if results.is_empty() {
|
||||
println!("No skills with {} found.", TEST_FILE_NAME);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
for r in results {
|
||||
if r.tests_run == 0 {
|
||||
println!(
|
||||
" {} {} — no test cases",
|
||||
console::style("-").dim(),
|
||||
r.skill_name,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if r.failures.is_empty() {
|
||||
println!(
|
||||
" {} {} — {}/{} passed",
|
||||
console::style("✓").green().bold(),
|
||||
console::style(&r.skill_name).white().bold(),
|
||||
r.tests_passed,
|
||||
r.tests_run,
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" {} {} — {}/{} passed",
|
||||
console::style("✗").red().bold(),
|
||||
console::style(&r.skill_name).white().bold(),
|
||||
r.tests_passed,
|
||||
r.tests_run,
|
||||
);
|
||||
for f in &r.failures {
|
||||
println!(" command: {}", console::style(&f.command).dim(),);
|
||||
println!(
|
||||
" expected: exit={}, pattern={}",
|
||||
f.expected_exit, f.expected_pattern,
|
||||
);
|
||||
println!(
|
||||
" actual: exit={}, output={}",
|
||||
f.actual_exit,
|
||||
truncate_output(&f.actual_output, 200),
|
||||
);
|
||||
println!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_run: usize = results.iter().map(|r| r.tests_run).sum();
|
||||
let total_passed: usize = results.iter().map(|r| r.tests_passed).sum();
|
||||
let total_failed = total_run - total_passed;
|
||||
|
||||
println!();
|
||||
if total_failed == 0 {
|
||||
println!(
|
||||
" {} All {total_run} test(s) passed across {} skill(s).",
|
||||
console::style("✓").green().bold(),
|
||||
results.len(),
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" {} {total_failed} of {total_run} test(s) failed across {} skill(s).",
|
||||
console::style("✗").red().bold(),
|
||||
results.len(),
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
fn truncate_output(s: &str, max: usize) -> String {
|
||||
let trimmed = s.trim();
|
||||
if trimmed.len() <= max {
|
||||
trimmed.replace('\n', " ")
|
||||
} else {
|
||||
format!("{}...", &trimmed[..max].replace('\n', " "))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
|
||||
#[test]
|
||||
fn parse_comment_and_empty_lines() {
|
||||
assert!(parse_test_line("").is_none());
|
||||
assert!(parse_test_line(" ").is_none());
|
||||
assert!(parse_test_line("# this is a comment").is_none());
|
||||
assert!(parse_test_line(" # indented comment").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_valid_test_line() {
|
||||
let case = parse_test_line("echo hello | 0 | hello").unwrap();
|
||||
assert_eq!(case.command, "echo hello");
|
||||
assert_eq!(case.expected_exit, 0);
|
||||
assert_eq!(case.expected_pattern, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_line_with_spaces_in_pattern() {
|
||||
let case = parse_test_line("echo 'hello world' | 0 | hello world").unwrap();
|
||||
assert_eq!(case.command, "echo 'hello world'");
|
||||
assert_eq!(case.expected_exit, 0);
|
||||
assert_eq!(case.expected_pattern, "hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_invalid_line_missing_parts() {
|
||||
assert!(parse_test_line("just a command").is_none());
|
||||
assert!(parse_test_line("cmd | notanumber | pattern").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_matches_empty() {
|
||||
assert!(pattern_matches("anything", ""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_matches_substring() {
|
||||
assert!(pattern_matches("hello world", "hello"));
|
||||
assert!(pattern_matches("hello world", "world"));
|
||||
assert!(!pattern_matches("hello world", "missing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_matches_regex() {
|
||||
assert!(pattern_matches("hello world 42", r"world \d+"));
|
||||
assert!(pattern_matches("/usr/bin/bash", r"/"));
|
||||
assert!(!pattern_matches("hello", r"^\d+$"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_with_echo() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skill_dir = dir.path().join("echo-skill");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
fs::write(
|
||||
skill_dir.join("TEST.sh"),
|
||||
"# Echo test\necho hello | 0 | hello\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = test_skill(&skill_dir, "echo-skill", false).unwrap();
|
||||
assert_eq!(result.tests_run, 1);
|
||||
assert_eq!(result.tests_passed, 1);
|
||||
assert!(result.failures.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_without_test_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skill_dir = dir.path().join("no-tests");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
let result = test_skill(&skill_dir, "no-tests", false).unwrap();
|
||||
assert_eq!(result.tests_run, 0);
|
||||
assert_eq!(result.tests_passed, 0);
|
||||
assert!(result.failures.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_with_failing_test() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skill_dir = dir.path().join("fail-skill");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
fs::write(skill_dir.join("TEST.sh"), "echo hello | 1 | goodbye\n").unwrap();
|
||||
|
||||
let result = test_skill(&skill_dir, "fail-skill", false).unwrap();
|
||||
assert_eq!(result.tests_run, 1);
|
||||
assert_eq!(result.tests_passed, 0);
|
||||
assert_eq!(result.failures.len(), 1);
|
||||
assert_eq!(result.failures[0].expected_exit, 1);
|
||||
assert_eq!(result.failures[0].actual_exit, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_exit_code_mismatch() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skill_dir = dir.path().join("exit-mismatch");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
fs::write(skill_dir.join("TEST.sh"), "false | 0 | \n").unwrap();
|
||||
|
||||
let result = test_skill(&skill_dir, "exit-mismatch", false).unwrap();
|
||||
assert_eq!(result.tests_run, 1);
|
||||
assert_eq!(result.tests_passed, 0);
|
||||
assert_eq!(result.failures[0].actual_exit, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_aggregation() {
|
||||
let results = [
|
||||
SkillTestResult {
|
||||
skill_name: "a".to_string(),
|
||||
tests_run: 3,
|
||||
tests_passed: 3,
|
||||
failures: Vec::new(),
|
||||
},
|
||||
SkillTestResult {
|
||||
skill_name: "b".to_string(),
|
||||
tests_run: 2,
|
||||
tests_passed: 1,
|
||||
failures: vec![TestFailure {
|
||||
command: "false".to_string(),
|
||||
expected_exit: 0,
|
||||
actual_exit: 1,
|
||||
expected_pattern: String::new(),
|
||||
actual_output: String::new(),
|
||||
}],
|
||||
},
|
||||
];
|
||||
|
||||
let total_run: usize = results.iter().map(|r| r.tests_run).sum();
|
||||
let total_passed: usize = results.iter().map(|r| r.tests_passed).sum();
|
||||
assert_eq!(total_run, 5);
|
||||
assert_eq!(total_passed, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_skills_finds_skills_with_tests() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
|
||||
// Skill with TEST.sh
|
||||
let skill_a = skills_dir.join("skill-a");
|
||||
fs::create_dir_all(&skill_a).unwrap();
|
||||
fs::write(skill_a.join("TEST.sh"), "echo ok | 0 | ok\n").unwrap();
|
||||
|
||||
// Skill without TEST.sh — should be skipped
|
||||
let skill_b = skills_dir.join("skill-b");
|
||||
fs::create_dir_all(&skill_b).unwrap();
|
||||
|
||||
let results = test_all_skills(std::slice::from_ref(&skills_dir), false).unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].skill_name, "skill-a");
|
||||
assert_eq!(results[0].tests_passed, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_output() {
|
||||
assert_eq!(truncate_output("short", 100), "short");
|
||||
let long = "a".repeat(300);
|
||||
let truncated = truncate_output(&long, 200);
|
||||
assert!(truncated.ends_with("..."));
|
||||
assert!(truncated.len() <= 204); // 200 + "..."
|
||||
}
|
||||
}
|
||||
503
src/tools/ask_user.rs
Normal file
503
src/tools/ask_user.rs
Normal file
@ -0,0 +1,503 @@
|
||||
//! Interactive user prompting tool for cross-channel confirmations.
|
||||
//!
|
||||
//! Exposes `ask_user` as an agent-callable tool that sends a question to a
|
||||
//! messaging channel and waits for the user's response. The tool holds a
|
||||
//! late-binding channel map handle that is populated once channels are
|
||||
//! initialized (after tool construction). This mirrors the pattern used by
|
||||
//! [`ReactionTool`](super::reaction::ReactionTool).
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Shared handle giving tools late-bound access to the live channel map.
/// Keys are channel names; the map starts empty and is filled in after
/// channel initialization completes.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;

/// Default timeout in seconds when waiting for a user response (5 minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
/// Agent-callable tool for sending a question to a user and waiting for their response.
pub struct AskUserTool {
    // Policy gate checked before any question is sent.
    security: Arc<SecurityPolicy>,
    // Late-bound channel map; empty until populated after channel init.
    channels: ChannelMapHandle,
}
|
||||
|
||||
impl AskUserTool {
|
||||
/// Create a new ask_user tool with an empty channel map.
|
||||
/// Call [`channel_map_handle`] and write to the returned handle once channels
|
||||
/// are available.
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self {
|
||||
security,
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the shared handle so callers can populate it after channel init.
|
||||
pub fn channel_map_handle(&self) -> ChannelMapHandle {
|
||||
Arc::clone(&self.channels)
|
||||
}
|
||||
|
||||
/// Convenience: populate the channel map from a pre-built map.
|
||||
pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
|
||||
*self.channels.write() = map;
|
||||
}
|
||||
}
|
||||
|
||||
/// Render a question, plus optional numbered choices, as one display string.
///
/// The question itself is wrapped in Markdown bold. When choices are present
/// they are listed as `1.`, `2.`, … with blank lines around the list and a
/// trailing hint telling the user how to reply.
fn format_question(question: &str, choices: Option<&[String]>) -> String {
    let mut text = format!("**{question}**");

    if let Some(choices) = choices {
        // Blank line before the list, then one numbered line per choice.
        text.push('\n');
        for (index, choice) in choices.iter().enumerate() {
            text.push('\n');
            text.push_str(&format!("{}. {choice}", index + 1));
        }
        // Blank line, then the reply hint.
        text.push_str("\n\n_Reply with a number or type your answer._");
    }

    text
}
|
||||
|
||||
#[async_trait]
impl Tool for AskUserTool {
    fn name(&self) -> &str {
        "ask_user"
    }

    fn description(&self) -> &str {
        "Ask the user a question and wait for their response. \
        Sends the question to a messaging channel and blocks until the user replies \
        or the timeout expires. Optionally provide choices for structured responses."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask the user"
                },
                "choices": {
                    "type": "array",
                    "items": { "type": "string" },
                    "description": "Optional list of choices (renders as buttons on Telegram, numbered list on CLI)"
                },
                "timeout_secs": {
                    "type": "integer",
                    "description": "Seconds to wait for a response (default: 300)"
                },
                "channel": {
                    "type": "string",
                    "description": "Target channel name. Defaults to the first available channel if omitted."
                }
            },
            "required": ["question"]
        })
    }

    /// Send the question to the resolved channel, then wait (up to the
    /// timeout) for the next inbound message on that channel; that message's
    /// content becomes the tool output.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Security gate: Act operation
        if let Err(e) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "ask_user")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Action blocked: {e}")),
            });
        }

        // Parse required params; a whitespace-only question is treated as missing.
        let question = args
            .get("question")
            .and_then(|v| v.as_str())
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .ok_or_else(|| anyhow::anyhow!("Missing 'question' parameter"))?
            .to_string();

        // Optional choices: each entry is trimmed and empty entries are dropped.
        let choices: Option<Vec<String>> = args.get("choices").and_then(|v| {
            v.as_array().map(|arr| {
                arr.iter()
                    .filter_map(|item| item.as_str().map(|s| s.trim().to_string()))
                    .filter(|s| !s.is_empty())
                    .collect()
            })
        });

        let timeout_secs = args
            .get("timeout_secs")
            .and_then(|v| v.as_u64())
            .unwrap_or(DEFAULT_TIMEOUT_SECS);

        let requested_channel = args
            .get("channel")
            .and_then(|v| v.as_str())
            .map(|s| s.trim().to_string());

        // Resolve channel from handle — block-scoped to drop the RwLock guard
        // before any `.await` (parking_lot guards are !Send).
        let (channel_name, channel): (String, Arc<dyn Channel>) = {
            let channels = self.channels.read();
            if channels.is_empty() {
                // Soft failure: channels simply haven't been populated yet.
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("No channels available yet (channels not initialized)".to_string()),
                });
            }
            if let Some(ref name) = requested_channel {
                // Unknown channel names propagate as a hard error (Err),
                // listing the channels that are actually registered.
                let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
                    let available: Vec<String> = channels.keys().cloned().collect();
                    anyhow::anyhow!(
                        "Channel '{}' not found. Available: {}",
                        name,
                        available.join(", ")
                    )
                })?;
                (name.clone(), ch)
            } else {
                // No channel requested: take an arbitrary entry (HashMap
                // iteration order is unspecified) — deterministic only when
                // exactly one channel is configured.
                let (name, ch) = channels.iter().next().ok_or_else(|| {
                    anyhow::anyhow!("No channels available. Configure at least one channel.")
                })?;
                (name.clone(), ch.clone())
            }
        };

        // Format and send the question
        let text = format_question(&question, choices.as_deref());
        let msg = SendMessage::new(&text, "");
        if let Err(e) = channel.send(&msg).await {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Failed to send question to channel '{channel_name}': {e}"
                )),
            });
        }

        // Listen for user response with timeout
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChannelMessage>(1);
        let timeout = std::time::Duration::from_secs(timeout_secs);

        // Spawn a listener task on the channel
        let listen_channel = Arc::clone(&channel);
        let listen_handle = tokio::spawn(async move { listen_channel.listen(tx).await });

        let response = tokio::time::timeout(timeout, rx.recv()).await;

        // Abort the listener once we have a response or timeout
        listen_handle.abort();

        match response {
            // First inbound message wins; its content is the answer.
            Ok(Some(msg)) => Ok(ToolResult {
                success: true,
                output: msg.content,
                error: None,
            }),
            // Listener dropped the sender without delivering a message.
            Ok(None) => Ok(ToolResult {
                success: false,
                output: "TIMEOUT".to_string(),
                error: Some("Channel closed before receiving a response".to_string()),
            }),
            // Deadline elapsed with no message at all.
            Err(_) => Ok(ToolResult {
                success: false,
                output: "TIMEOUT".to_string(),
                error: Some(format!(
                    "No response received within {timeout_secs} seconds"
                )),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A stub channel that records sent messages but never produces incoming messages.
    struct SilentChannel {
        channel_name: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl SilentChannel {
        fn new(name: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl Channel for SilentChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Never sends anything — simulates no user response
            tokio::time::sleep(std::time::Duration::from_secs(600)).await;
            Ok(())
        }
    }

    /// A stub channel that immediately responds with a canned message.
    struct RespondingChannel {
        channel_name: String,
        response: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl RespondingChannel {
        fn new(name: &str, response: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                response: response.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl Channel for RespondingChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Immediately delivers one canned reply, then returns.
            let msg = ChannelMessage {
                id: "resp_1".to_string(),
                sender: "user".to_string(),
                reply_target: "user".to_string(),
                content: self.response.clone(),
                channel: self.channel_name.clone(),
                timestamp: 1000,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            };
            let _ = tx.send(msg).await;
            Ok(())
        }
    }

    /// Build an AskUserTool pre-populated with the given named channels.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> AskUserTool {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let map: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|(name, ch)| (name.to_string(), ch))
            .collect();
        tool.populate(map);
        tool
    }

    // ── Metadata tests ──

    #[test]
    fn tool_name_and_description() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        assert_eq!(tool.name(), "ask_user");
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("question"));
    }

    #[test]
    fn parameter_schema_validation() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["question"].is_object());
        assert!(schema["properties"]["choices"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());
        assert!(schema["properties"]["channel"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "question"));
        // choices, timeout_secs, channel are optional
        assert!(!required.iter().any(|v| v == "choices"));
        assert!(!required.iter().any(|v| v == "timeout_secs"));
        assert!(!required.iter().any(|v| v == "channel"));
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let spec = tool.spec();
        assert_eq!(spec.name, "ask_user");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }

    // ── Format question tests ──

    #[test]
    fn format_question_without_choices() {
        let text = format_question("Are you sure?", None);
        assert!(text.contains("Are you sure?"));
        assert!(!text.contains("1."));
    }

    #[test]
    fn format_question_with_choices() {
        let choices = vec!["Yes".to_string(), "No".to_string(), "Maybe".to_string()];
        let text = format_question("Continue?", Some(&choices));
        assert!(text.contains("Continue?"));
        assert!(text.contains("1. Yes"));
        assert!(text.contains("2. No"));
        assert!(text.contains("3. Maybe"));
        assert!(text.contains("Reply with a number"));
    }

    // ── Execute tests ──

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn execute_rejects_empty_question() {
        // Whitespace-only question counts as missing.
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({ "question": " " })).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        // An un-populated channel map is a soft failure, not an Err.
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        // A named channel that isn't registered is a hard error (Err).
        let tool = make_tool_with_channels(vec![(
            "slack",
            Arc::new(SilentChannel::new("slack")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Hello?", "channel": "nonexistent" }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn timeout_returns_timeout_output() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Confirm?",
                "timeout_secs": 1
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(result.output, "TIMEOUT");
        assert!(result.error.as_deref().unwrap().contains("1 seconds"));
    }

    #[tokio::test]
    async fn successful_response_flow() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(RespondingChannel::new("test", "Yes, proceed!")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Should we deploy?",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "Yes, proceed!");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn successful_response_with_choices() {
        let tool = make_tool_with_channels(vec![(
            "telegram",
            Arc::new(RespondingChannel::new("telegram", "2")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Pick an option",
                "choices": ["Option A", "Option B"],
                "channel": "telegram",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "2");
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let handle = tool.channel_map_handle();

        // Initially empty — tool reports not initialized
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);

        // Populate via the handle
        {
            let mut map = handle.write();
            map.insert(
                "cli".to_string(),
                Arc::new(RespondingChannel::new("cli", "ok")) as Arc<dyn Channel>,
            );
        }

        // Now the tool can route to the channel
        let result = tool
            .execute(json!({ "question": "Hello?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "ok");
    }
}
|
||||
520
src/tools/claude_code_runner.rs
Normal file
520
src/tools/claude_code_runner.rs
Normal file
@ -0,0 +1,520 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::ClaudeCodeRunnerConfig;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// Environment variables safe to pass through to the `claude` subprocess.
/// Kept as an explicit allowlist so secrets in the parent environment are
/// never forwarded into the tmux session.
const SAFE_ENV_VARS: &[&str] = &[
    "PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR",
];
|
||||
|
||||
/// Event payload received from Claude Code HTTP hooks.
///
/// Optional fields default to `None` when absent from the incoming JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeCodeHookEvent {
    /// The session identifier (matches the tmux session name suffix).
    pub session_id: String,
    /// Event type from Claude Code (e.g. "tool_use", "tool_result", "completion").
    pub event_type: String,
    /// Tool name when event_type is "tool_use" or "tool_result".
    #[serde(default)]
    pub tool_name: Option<String>,
    /// Human-readable summary of what happened.
    #[serde(default)]
    pub summary: Option<String>,
}
|
||||
|
||||
/// Spawns Claude Code inside a tmux session with HTTP hooks that POST tool
/// execution events back to ZeroClaw's gateway endpoint, enabling live Slack
/// progress updates and SSH session handoff.
///
/// Unlike [`ClaudeCodeTool`](super::claude_code::ClaudeCodeTool) which runs
/// `claude -p` inline and waits for completion, this runner:
///
/// 1. Creates a named tmux session (`<prefix><id>`)
/// 2. Launches `claude` inside it with `--hook-url` pointing at the gateway
/// 3. Returns immediately with the session ID and an SSH attach command
/// 4. Receives streamed progress via the `/hooks/claude-code` endpoint
pub struct ClaudeCodeRunnerTool {
    // Policy gate for rate limiting, autonomy level, and workspace checks.
    security: Arc<SecurityPolicy>,
    // Runner settings: tmux prefix, optional SSH host, session TTL.
    config: ClaudeCodeRunnerConfig,
    /// Base URL of the ZeroClaw gateway (e.g. "http://localhost:3000").
    gateway_url: String,
}
|
||||
|
||||
impl ClaudeCodeRunnerTool {
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
config: ClaudeCodeRunnerConfig,
|
||||
gateway_url: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
config,
|
||||
gateway_url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the tmux session name from the configured prefix and a unique id.
|
||||
fn session_name(&self, id: &str) -> String {
|
||||
format!("{}{}", self.config.tmux_prefix, id)
|
||||
}
|
||||
|
||||
/// Build the SSH attach command for session handoff.
|
||||
fn ssh_attach_command(&self, session_name: &str) -> Option<String> {
|
||||
self.config
|
||||
.ssh_host
|
||||
.as_ref()
|
||||
.map(|host| format!("ssh -t {host} tmux attach-session -t {session_name}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for ClaudeCodeRunnerTool {
    fn name(&self) -> &str {
        "claude_code_runner"
    }

    fn description(&self) -> &str {
        "Spawn a Claude Code task in a tmux session with live Slack progress updates and SSH handoff. Returns immediately with session ID and attach command."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "The coding task to delegate to Claude Code"
                },
                "working_directory": {
                    "type": "string",
                    "description": "Working directory within the workspace (must be inside workspace_dir)"
                },
                "slack_channel": {
                    "type": "string",
                    "description": "Slack channel ID to post progress updates to"
                }
            },
            "required": ["prompt"]
        })
    }

    /// Validate inputs, create a detached tmux session, send the `claude`
    /// command into it, schedule TTL cleanup, and return immediately with
    /// attach instructions.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Rate limit check
        if self.security.is_rate_limited() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Rate limit exceeded: too many actions in the last hour".into()),
            });
        }

        // Enforce act policy
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "claude_code_runner")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        // Extract prompt (required)
        let prompt = args
            .get("prompt")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;

        // Validate working directory: canonicalize both the requested path and
        // the workspace so symlinks can't escape the containment check.
        let work_dir = if let Some(wd) = args.get("working_directory").and_then(|v| v.as_str()) {
            let wd_path = std::path::PathBuf::from(wd);
            let workspace = &self.security.workspace_dir;
            let canonical_wd = match wd_path.canonicalize() {
                Ok(p) => p,
                Err(_) => {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "working_directory '{}' does not exist or is not accessible",
                            wd
                        )),
                    });
                }
            };
            let canonical_ws = match workspace.canonicalize() {
                Ok(p) => p,
                Err(_) => {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "workspace directory '{}' does not exist or is not accessible",
                            workspace.display()
                        )),
                    });
                }
            };
            if !canonical_wd.starts_with(&canonical_ws) {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "working_directory '{}' is outside the workspace '{}'",
                        wd,
                        workspace.display()
                    )),
                });
            }
            canonical_wd
        } else {
            // No directory requested: default to the workspace root.
            self.security.workspace_dir.clone()
        };

        let slack_channel = args
            .get("slack_channel")
            .and_then(|v| v.as_str())
            .map(String::from);

        // Record action budget
        if !self.security.record_action() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Rate limit exceeded: action budget exhausted".into()),
            });
        }

        // Generate a unique session ID (first 8 chars of a v4 UUID string).
        let session_id = uuid::Uuid::new_v4().to_string()[..8].to_string();
        let session_name = self.session_name(&session_id);

        // Build the hook URL for Claude Code to POST events to
        let hook_url = format!("{}/hooks/claude-code", self.gateway_url);

        // Build the claude command that will run inside tmux
        let mut claude_args = vec![
            "claude".to_string(),
            "-p".to_string(),
            prompt.to_string(),
            "--output-format".to_string(),
            "json".to_string(),
        ];

        // Pass hook URL via environment variable (Claude Code uses
        // CLAUDE_CODE_HOOK_URL when --hook-url is not available).
        // We also append --hook-url for newer CLI versions.
        claude_args.push("--hook-url".to_string());
        claude_args.push(hook_url.clone());

        // Build env string for tmux send-keys; only allowlisted vars pass.
        let mut env_exports = String::new();
        for var in SAFE_ENV_VARS {
            if let Ok(val) = std::env::var(var) {
                use std::fmt::Write;
                let _ = write!(env_exports, "{}={} ", var, shell_escape(&val));
            }
        }
        // Pass session metadata via env vars so the hook can correlate events
        use std::fmt::Write;
        let _ = write!(env_exports, "CLAUDE_CODE_SESSION_ID={} ", &session_id);
        if let Some(ref ch) = slack_channel {
            let _ = write!(env_exports, "CLAUDE_CODE_SLACK_CHANNEL={} ", ch);
        }
        let _ = write!(env_exports, "CLAUDE_CODE_HOOK_URL={} ", &hook_url);

        // Create tmux session (detached, cwd set to the validated work dir)
        let create_result = Command::new("tmux")
            .args(["new-session", "-d", "-s", &session_name])
            .arg("-c")
            .arg(work_dir.to_str().unwrap_or("."))
            .output()
            .await;

        match create_result {
            // tmux ran but refused (e.g. duplicate session name).
            Ok(output) if !output.status.success() => {
                let stderr = String::from_utf8_lossy(&output.stderr);
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!("Failed to create tmux session: {stderr}")),
                });
            }
            // tmux binary missing or couldn't be spawned at all.
            Err(e) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "tmux not found or failed to execute: {e}. Install tmux to use claude_code_runner."
                    )),
                });
            }
            _ => {}
        }

        // Send the claude command into the tmux session
        let full_command = format!(
            "{env_exports}{cmd}",
            env_exports = env_exports,
            cmd = claude_args
                .iter()
                .map(|a| shell_escape(a))
                .collect::<Vec<_>>()
                .join(" ")
        );

        let send_result = Command::new("tmux")
            .args(["send-keys", "-t", &session_name, &full_command, "Enter"])
            .output()
            .await;

        if let Err(e) = send_result {
            // Clean up the session we just created
            let _ = Command::new("tmux")
                .args(["kill-session", "-t", &session_name])
                .output()
                .await;
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Failed to send command to tmux session: {e}")),
            });
        }

        // Schedule session TTL cleanup (detached task; fires even if the
        // claude run finished earlier — kill-session errors are ignored).
        let ttl = self.config.session_ttl;
        let cleanup_session = session_name.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_secs(ttl)).await;
            let _ = Command::new("tmux")
                .args(["kill-session", "-t", &cleanup_session])
                .output()
                .await;
            tracing::info!(
                session = cleanup_session,
                "Claude Code runner session TTL expired, cleaned up"
            );
        });

        // Build response
        let mut output_parts = vec![
            format!("Session started: {session_name}"),
            format!("Session ID: {session_id}"),
            format!("Hook URL: {hook_url}"),
        ];

        if let Some(ssh_cmd) = self.ssh_attach_command(&session_name) {
            output_parts.push(format!("SSH attach: {ssh_cmd}"));
        } else {
            output_parts.push(format!(
                "Local attach: tmux attach-session -t {session_name}"
            ));
        }

        if let Some(ref ch) = slack_channel {
            output_parts.push(format!("Slack channel: {ch} (progress updates enabled)"));
        }

        Ok(ToolResult {
            success: true,
            output: output_parts.join("\n"),
            error: None,
        })
    }
}
|
||||
|
||||
/// Minimal shell escaping for values embedded in tmux send-keys.
///
/// Strings made up entirely of shell-safe characters pass through unchanged;
/// anything else is wrapped in single quotes, with embedded single quotes
/// escaped via the standard close-escape-reopen sequence (`'\''`).
fn shell_escape(s: &str) -> String {
    const PUNCT_SAFE: [char; 7] = ['-', '_', '.', '/', ':', '=', '+'];
    let safe = |c: char| c.is_alphanumeric() || PUNCT_SAFE.contains(&c);
    if s.chars().all(safe) {
        return s.to_string();
    }
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    quoted.push_str(&s.replace('\'', "'\\''"));
    quoted.push('\'');
    quoted
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClaudeCodeRunnerConfig;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    /// Standard runner config with an SSH host and a recognizable tmux prefix.
    fn test_config() -> ClaudeCodeRunnerConfig {
        ClaudeCodeRunnerConfig {
            enabled: true,
            ssh_host: Some("dev.example.com".into()),
            tmux_prefix: "zc-test-".into(),
            session_ttl: 3600,
        }
    }

    /// Security policy at the given autonomy level, workspaced in the temp dir.
    fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    #[test]
    fn tool_name() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            test_config(),
            "http://localhost:3000".into(),
        );
        assert_eq!(tool.name(), "claude_code_runner");
    }

    #[test]
    fn tool_schema_has_prompt() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            test_config(),
            "http://localhost:3000".into(),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["prompt"].is_object());
        assert!(schema["required"]
            .as_array()
            .expect("required should be an array")
            .contains(&json!("prompt")));
    }

    #[test]
    fn session_name_uses_prefix() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            test_config(),
            "http://localhost:3000".into(),
        );
        let name = tool.session_name("abc123");
        assert_eq!(name, "zc-test-abc123");
    }

    #[test]
    fn ssh_attach_command_with_host() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            test_config(),
            "http://localhost:3000".into(),
        );
        let cmd = tool.ssh_attach_command("zc-test-abc123");
        assert_eq!(
            cmd.as_deref(),
            Some("ssh -t dev.example.com tmux attach-session -t zc-test-abc123")
        );
    }

    #[test]
    fn ssh_attach_command_without_host() {
        // No ssh_host configured → no attach command.
        let mut config = test_config();
        config.ssh_host = None;
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            config,
            "http://localhost:3000".into(),
        );
        assert!(tool.ssh_attach_command("session").is_none());
    }

    #[tokio::test]
    async fn blocks_rate_limited() {
        // max_actions_per_hour = 0 trips the rate limiter immediately.
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Supervised,
            max_actions_per_hour: 0,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool =
            ClaudeCodeRunnerTool::new(security, test_config(), "http://localhost:3000".into());
        let result = tool
            .execute(json!({"prompt": "hello"}))
            .await
            .expect("rate-limited should return a result");
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
    }

    #[tokio::test]
    async fn blocks_readonly() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::ReadOnly),
            test_config(),
            "http://localhost:3000".into(),
        );
        let result = tool
            .execute(json!({"prompt": "hello"}))
            .await
            .expect("readonly should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("read-only mode"));
    }

    #[tokio::test]
    async fn missing_prompt() {
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Supervised),
            test_config(),
            "http://localhost:3000".into(),
        );
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("prompt"));
    }

    #[tokio::test]
    async fn rejects_path_outside_workspace() {
        // /etc exists but sits outside the temp-dir workspace.
        let tool = ClaudeCodeRunnerTool::new(
            test_security(AutonomyLevel::Full),
            test_config(),
            "http://localhost:3000".into(),
        );
        let result = tool
            .execute(json!({
                "prompt": "hello",
                "working_directory": "/etc"
            }))
            .await
            .expect("should return a result for path validation");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("outside the workspace"));
    }

    #[test]
    fn shell_escape_simple() {
        assert_eq!(shell_escape("hello"), "hello");
        assert_eq!(shell_escape("hello world"), "'hello world'");
        assert_eq!(shell_escape("it's"), "'it'\\''s'");
    }

    #[test]
    fn hook_event_deserialization() {
        let json = r#"{
            "session_id": "abc123",
            "event_type": "tool_use",
            "tool_name": "Edit",
            "summary": "Editing file.rs"
        }"#;
        let event: ClaudeCodeHookEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.session_id, "abc123");
        assert_eq!(event.event_type, "tool_use");
        assert_eq!(event.tool_name.as_deref(), Some("Edit"));
    }
}
|
||||
@ -315,7 +315,13 @@ impl Tool for CronAddTool {
|
||||
.map(str::to_string);
|
||||
let allowed_tools = match args.get("allowed_tools") {
|
||||
Some(v) => match serde_json::from_value::<Vec<String>>(v.clone()) {
|
||||
Ok(v) => Some(v),
|
||||
Ok(v) => {
|
||||
if v.is_empty() {
|
||||
None // Treat empty list same as unset
|
||||
} else {
|
||||
Some(v)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
@ -694,6 +700,32 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_allowed_tools_stored_as_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": { "kind": "cron", "expr": "*/5 * * * *" },
|
||||
"job_type": "agent",
|
||||
"prompt": "check status",
|
||||
"allowed_tools": []
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
|
||||
let jobs = cron::list_jobs(&cfg).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(
|
||||
jobs[0].allowed_tools, None,
|
||||
"empty allowed_tools should be stored as None"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delivery_schema_includes_matrix_channel() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@ -508,6 +508,43 @@ mod tests {
|
||||
assert!(cron::get_job(&cfg, &job.id).unwrap().enabled);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_allowed_tools_patch_stored_as_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let job = cron::add_agent_job(
|
||||
&cfg,
|
||||
None,
|
||||
crate::cron::Schedule::Cron {
|
||||
expr: "*/5 * * * *".into(),
|
||||
tz: None,
|
||||
},
|
||||
"check status",
|
||||
crate::cron::SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
Some(vec!["file_read".into()]),
|
||||
)
|
||||
.unwrap();
|
||||
let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"job_id": job.id,
|
||||
"patch": { "allowed_tools": [] }
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
assert_eq!(
|
||||
cron::get_job(&cfg, &job.id).unwrap().allowed_tools,
|
||||
None,
|
||||
"empty allowed_tools patch should clear to None"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn updates_agent_allowed_tools() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@ -65,10 +65,42 @@ impl GitOperationsTool {
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_git_command(&self, args: &[&str]) -> anyhow::Result<String> {
|
||||
/// Resolve a user-provided path to an absolute path within the workspace.
|
||||
/// Returns the workspace_dir if no path is provided.
|
||||
/// Rejects paths that escape the workspace via traversal.
|
||||
fn resolve_working_dir(&self, path: Option<&str>) -> anyhow::Result<std::path::PathBuf> {
|
||||
let base = match path {
|
||||
Some(p) if !p.is_empty() => {
|
||||
let candidate = if std::path::Path::new(p).is_absolute() {
|
||||
std::path::PathBuf::from(p)
|
||||
} else {
|
||||
self.workspace_dir.join(p)
|
||||
};
|
||||
let resolved = candidate
|
||||
.canonicalize()
|
||||
.map_err(|e| anyhow::anyhow!("Cannot resolve path '{}': {}", p, e))?;
|
||||
let workspace_canonical = self
|
||||
.workspace_dir
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| self.workspace_dir.clone());
|
||||
if !resolved.starts_with(&workspace_canonical) {
|
||||
anyhow::bail!("Path '{}' resolves outside the workspace directory", p);
|
||||
}
|
||||
resolved
|
||||
}
|
||||
_ => self.workspace_dir.clone(),
|
||||
};
|
||||
Ok(base)
|
||||
}
|
||||
|
||||
async fn run_git_command(
|
||||
&self,
|
||||
args: &[&str],
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<String> {
|
||||
let output = tokio::process::Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(&self.workspace_dir)
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
@ -80,9 +112,13 @@ impl GitOperationsTool {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
}
|
||||
|
||||
async fn git_status(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_status(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let output = self
|
||||
.run_git_command(&["status", "--porcelain=2", "--branch"])
|
||||
.run_git_command(&["status", "--porcelain=2", "--branch"], working_dir)
|
||||
.await?;
|
||||
|
||||
// Parse git status output into structured format
|
||||
@ -131,7 +167,11 @@ impl GitOperationsTool {
|
||||
})
|
||||
}
|
||||
|
||||
async fn git_diff(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_diff(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let files = args.get("files").and_then(|v| v.as_str()).unwrap_or(".");
|
||||
let cached = args
|
||||
.get("cached")
|
||||
@ -148,7 +188,7 @@ impl GitOperationsTool {
|
||||
git_args.push("--");
|
||||
git_args.push(files);
|
||||
|
||||
let output = self.run_git_command(&git_args).await?;
|
||||
let output = self.run_git_command(&git_args, working_dir).await?;
|
||||
|
||||
// Parse diff into structured hunks
|
||||
let mut result = serde_json::Map::new();
|
||||
@ -210,18 +250,25 @@ impl GitOperationsTool {
|
||||
})
|
||||
}
|
||||
|
||||
async fn git_log(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_log(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let limit_raw = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10);
|
||||
let limit = usize::try_from(limit_raw).unwrap_or(usize::MAX).min(1000);
|
||||
let limit_str = limit.to_string();
|
||||
|
||||
let output = self
|
||||
.run_git_command(&[
|
||||
"log",
|
||||
&format!("-{limit_str}"),
|
||||
"--pretty=format:%H|%an|%ae|%ad|%s",
|
||||
"--date=iso",
|
||||
])
|
||||
.run_git_command(
|
||||
&[
|
||||
"log",
|
||||
&format!("-{limit_str}"),
|
||||
"--pretty=format:%H|%an|%ae|%ad|%s",
|
||||
"--date=iso",
|
||||
],
|
||||
working_dir,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut commits = Vec::new();
|
||||
@ -247,9 +294,16 @@ impl GitOperationsTool {
|
||||
})
|
||||
}
|
||||
|
||||
async fn git_branch(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_branch(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let output = self
|
||||
.run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"])
|
||||
.run_git_command(
|
||||
&["branch", "--format=%(refname:short)|%(HEAD)"],
|
||||
working_dir,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut branches = Vec::new();
|
||||
@ -287,7 +341,11 @@ impl GitOperationsTool {
|
||||
}
|
||||
}
|
||||
|
||||
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_commit(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let message = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
@ -308,7 +366,9 @@ impl GitOperationsTool {
|
||||
// Limit message length
|
||||
let message = Self::truncate_commit_message(&sanitized);
|
||||
|
||||
let output = self.run_git_command(&["commit", "-m", &message]).await;
|
||||
let output = self
|
||||
.run_git_command(&["commit", "-m", &message], working_dir)
|
||||
.await;
|
||||
|
||||
match output {
|
||||
Ok(_) => Ok(ToolResult {
|
||||
@ -324,7 +384,11 @@ impl GitOperationsTool {
|
||||
}
|
||||
}
|
||||
|
||||
async fn git_add(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_add(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let paths = args
|
||||
.get("paths")
|
||||
.and_then(|v| v.as_str())
|
||||
@ -333,7 +397,9 @@ impl GitOperationsTool {
|
||||
// Validate paths against injection patterns
|
||||
self.sanitize_git_args(paths)?;
|
||||
|
||||
let output = self.run_git_command(&["add", "--", paths]).await;
|
||||
let output = self
|
||||
.run_git_command(&["add", "--", paths], working_dir)
|
||||
.await;
|
||||
|
||||
match output {
|
||||
Ok(_) => Ok(ToolResult {
|
||||
@ -349,7 +415,11 @@ impl GitOperationsTool {
|
||||
}
|
||||
}
|
||||
|
||||
async fn git_checkout(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_checkout(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let branch = args
|
||||
.get("branch")
|
||||
.and_then(|v| v.as_str())
|
||||
@ -369,7 +439,9 @@ impl GitOperationsTool {
|
||||
anyhow::bail!("Branch name contains invalid characters");
|
||||
}
|
||||
|
||||
let output = self.run_git_command(&["checkout", branch_name]).await;
|
||||
let output = self
|
||||
.run_git_command(&["checkout", branch_name], working_dir)
|
||||
.await;
|
||||
|
||||
match output {
|
||||
Ok(_) => Ok(ToolResult {
|
||||
@ -385,7 +457,11 @@ impl GitOperationsTool {
|
||||
}
|
||||
}
|
||||
|
||||
async fn git_stash(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
async fn git_stash(
|
||||
&self,
|
||||
args: serde_json::Value,
|
||||
working_dir: &std::path::Path,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let action = args
|
||||
.get("action")
|
||||
.and_then(|v| v.as_str())
|
||||
@ -393,17 +469,20 @@ impl GitOperationsTool {
|
||||
|
||||
let output = match action {
|
||||
"push" | "save" => {
|
||||
self.run_git_command(&["stash", "push", "-m", "auto-stash"])
|
||||
self.run_git_command(&["stash", "push", "-m", "auto-stash"], working_dir)
|
||||
.await
|
||||
}
|
||||
"pop" => self.run_git_command(&["stash", "pop"]).await,
|
||||
"list" => self.run_git_command(&["stash", "list"]).await,
|
||||
"pop" => self.run_git_command(&["stash", "pop"], working_dir).await,
|
||||
"list" => self.run_git_command(&["stash", "list"], working_dir).await,
|
||||
"drop" => {
|
||||
let index_raw = args.get("index").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let index = i32::try_from(index_raw)
|
||||
.map_err(|_| anyhow::anyhow!("stash index too large: {index_raw}"))?;
|
||||
self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")])
|
||||
.await
|
||||
self.run_git_command(
|
||||
&["stash", "drop", &format!("stash@{{{index}}}")],
|
||||
working_dir,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("Unknown stash action: {action}. Use: push, pop, list, drop"),
|
||||
};
|
||||
@ -474,6 +553,10 @@ impl Tool for GitOperationsTool {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "Stash index (for 'stash' with 'drop' action)"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional subdirectory path within the workspace to run git operations in. Defaults to workspace root."
|
||||
}
|
||||
},
|
||||
"required": ["operation"]
|
||||
@ -492,10 +575,22 @@ impl Tool for GitOperationsTool {
|
||||
}
|
||||
};
|
||||
|
||||
let path = args.get("path").and_then(|v| v.as_str());
|
||||
let working_dir = match self.resolve_working_dir(path) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Invalid path: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Check if we're in a git repository
|
||||
if !self.workspace_dir.join(".git").exists() {
|
||||
if !working_dir.join(".git").exists() {
|
||||
// Try to find .git in parent directories
|
||||
let mut current_dir = self.workspace_dir.as_path();
|
||||
let mut current_dir = working_dir.as_path();
|
||||
let mut found_git = false;
|
||||
while current_dir.parent().is_some() {
|
||||
if current_dir.join(".git").exists() {
|
||||
@ -549,14 +644,14 @@ impl Tool for GitOperationsTool {
|
||||
|
||||
// Execute the requested operation
|
||||
match operation {
|
||||
"status" => self.git_status(args).await,
|
||||
"diff" => self.git_diff(args).await,
|
||||
"log" => self.git_log(args).await,
|
||||
"branch" => self.git_branch(args).await,
|
||||
"commit" => self.git_commit(args).await,
|
||||
"add" => self.git_add(args).await,
|
||||
"checkout" => self.git_checkout(args).await,
|
||||
"stash" => self.git_stash(args).await,
|
||||
"status" => self.git_status(args, &working_dir).await,
|
||||
"diff" => self.git_diff(args, &working_dir).await,
|
||||
"log" => self.git_log(args, &working_dir).await,
|
||||
"branch" => self.git_branch(args, &working_dir).await,
|
||||
"commit" => self.git_commit(args, &working_dir).await,
|
||||
"add" => self.git_add(args, &working_dir).await,
|
||||
"checkout" => self.git_checkout(args, &working_dir).await,
|
||||
"stash" => self.git_stash(args, &working_dir).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
@ -810,4 +905,82 @@ mod tests {
|
||||
|
||||
assert_eq!(truncated.chars().count(), 2000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_working_dir_none_returns_workspace() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
let result = tool.resolve_working_dir(None).unwrap();
|
||||
assert_eq!(result, tmp.path().to_path_buf());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_working_dir_empty_returns_workspace() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
let result = tool.resolve_working_dir(Some("")).unwrap();
|
||||
assert_eq!(result, tmp.path().to_path_buf());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_working_dir_valid_subdir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
std::fs::create_dir(tmp.path().join("subproject")).unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
let result = tool.resolve_working_dir(Some("subproject")).unwrap();
|
||||
let expected = tmp.path().join("subproject").canonicalize().unwrap();
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_working_dir_rejects_traversal() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
let result = tool.resolve_working_dir(Some(".."));
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("resolves outside the workspace"),
|
||||
"Expected traversal rejection, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn git_operations_work_in_subdirectory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sub = tmp.path().join("nested");
|
||||
std::fs::create_dir(&sub).unwrap();
|
||||
std::process::Command::new("git")
|
||||
.args(["init"])
|
||||
.current_dir(&sub)
|
||||
.output()
|
||||
.unwrap();
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.email", "test@test.com"])
|
||||
.current_dir(&sub)
|
||||
.output()
|
||||
.unwrap();
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.name", "Test"])
|
||||
.current_dir(&sub)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({"operation": "status", "path": "nested"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
result.success,
|
||||
"Expected success, got error: {:?}",
|
||||
result.error
|
||||
);
|
||||
assert!(result.output.contains("branch"));
|
||||
}
|
||||
}
|
||||
|
||||
279
src/tools/memory_purge.rs
Normal file
279
src/tools/memory_purge.rs
Normal file
@ -0,0 +1,279 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Let the agent bulk-delete memories by namespace or session
|
||||
pub struct MemoryPurgeTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl MemoryPurgeTool {
|
||||
pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { memory, security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryPurgeTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_purge"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Remove all memories in a namespace (category) or session. Use to bulk-delete conversation context or category-scoped data. Returns the number of deleted entries. WARNING: This operation cannot be undone."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"namespace": {
|
||||
"type": "string",
|
||||
"description": "The namespace (category) to purge. Deletes all memories in this category."
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The session ID to purge. Deletes all memories in this session."
|
||||
}
|
||||
},
|
||||
"minProperties": 1
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let namespace = args.get("namespace").and_then(|v| v.as_str());
|
||||
let session_id = args.get("session_id").and_then(|v| v.as_str());
|
||||
|
||||
if namespace.is_none() && session_id.is_none() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Must provide either 'namespace' or 'session_id' parameter"
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_purge")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let mut total_purged = 0;
|
||||
let mut output_parts = Vec::new();
|
||||
|
||||
if let Some(ns) = namespace {
|
||||
match self.memory.purge_namespace(ns).await {
|
||||
Ok(count) => {
|
||||
total_purged += count;
|
||||
output_parts.push(format!("Purged {count} memories from namespace '{ns}'"));
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to purge namespace: {e}")),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
match self.memory.purge_session(sid).await {
|
||||
Ok(count) => {
|
||||
total_purged += count;
|
||||
output_parts.push(format!("Purged {count} memories from session '{sid}'"));
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to purge session: {e}")),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: if output_parts.is_empty() {
|
||||
format!("Purged {total_purged} memories")
|
||||
} else {
|
||||
output_parts.join("; ")
|
||||
},
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryPurgeTool::new(mem, test_security());
|
||||
assert_eq!(tool.name(), "memory_purge");
|
||||
assert!(tool.parameters_schema()["properties"]["namespace"].is_object());
|
||||
assert!(tool.parameters_schema()["properties"]["session_id"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_namespace_removes_all_memories() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store(
|
||||
"a1",
|
||||
"data1",
|
||||
MemoryCategory::Custom("test_ns".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"a2",
|
||||
"data2",
|
||||
MemoryCategory::Custom("test_ns".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data3", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), test_security());
|
||||
let result = tool.execute(json!({"namespace": "test_ns"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("2 memories"));
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_session_removes_all_memories() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("a1", "data1", MemoryCategory::Core, Some("sess-x"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("a2", "data2", MemoryCategory::Core, Some("sess-x"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b1", "data3", MemoryCategory::Core, Some("sess-y"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), test_security());
|
||||
let result = tool.execute(json!({"session_id": "sess-x"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("2 memories"));
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_namespace_nonexistent_is_noop() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({"namespace": "nonexistent"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("0 memories"));
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_session_nonexistent_is_noop() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("a", "data", MemoryCategory::Core, Some("sess"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "nonexistent"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("0 memories"));
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_missing_parameter() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryPurgeTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_blocked_in_readonly_mode() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("a", "data", MemoryCategory::Custom("test".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), readonly);
|
||||
let result = tool.execute(json!({"namespace": "test"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn purge_blocked_when_rate_limited() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("a", "data", MemoryCategory::Custom("test".into()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryPurgeTool::new(mem.clone(), limited);
|
||||
let result = tool.execute(json!({"namespace": "test"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
}
|
||||
@ -15,6 +15,7 @@
|
||||
//! To add a new tool, implement [`Tool`] in a new submodule and register it in
|
||||
//! [`all_tools_with_runtime`]. See `AGENTS.md` §7.3 for the full change playbook.
|
||||
|
||||
pub mod ask_user;
|
||||
pub mod backup_tool;
|
||||
pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
@ -22,6 +23,7 @@ pub mod browser_open;
|
||||
pub mod calculator;
|
||||
pub mod canvas;
|
||||
pub mod claude_code;
|
||||
pub mod claude_code_runner;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
pub mod cloud_patterns;
|
||||
@ -64,6 +66,7 @@ pub mod mcp_protocol;
|
||||
pub mod mcp_tool;
|
||||
pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_purge;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod microsoft365;
|
||||
@ -104,6 +107,7 @@ mod web_search_provider_routing;
|
||||
pub mod web_search_tool;
|
||||
pub mod workspace_tool;
|
||||
|
||||
pub use ask_user::AskUserTool;
|
||||
pub use backup_tool::BackupTool;
|
||||
pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
#[allow(unused_imports)]
|
||||
@ -112,6 +116,7 @@ pub use browser_open::BrowserOpenTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use canvas::{CanvasStore, CanvasTool};
|
||||
pub use claude_code::ClaudeCodeTool;
|
||||
pub use claude_code_runner::ClaudeCodeRunnerTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
pub use codex_cli::CodexCliTool;
|
||||
@ -153,6 +158,7 @@ pub use mcp_client::McpRegistry;
|
||||
pub use mcp_deferred::{ActivatedToolSet, DeferredMcpToolSet};
|
||||
pub use mcp_tool::McpToolWrapper;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_purge::MemoryPurgeTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use microsoft365::Microsoft365Tool;
|
||||
@ -338,6 +344,7 @@ pub fn all_tools(
|
||||
Option<DelegateParentToolsHandle>,
|
||||
Option<ChannelMapHandle>,
|
||||
ChannelMapHandle,
|
||||
Option<ChannelMapHandle>,
|
||||
) {
|
||||
all_tools_with_runtime(
|
||||
config,
|
||||
@ -383,15 +390,15 @@ pub fn all_tools_with_runtime(
|
||||
Option<DelegateParentToolsHandle>,
|
||||
Option<ChannelMapHandle>,
|
||||
ChannelMapHandle,
|
||||
Option<ChannelMapHandle>,
|
||||
) {
|
||||
let has_shell_access = runtime.has_shell_access();
|
||||
let sandbox = create_sandbox(&root_config.security);
|
||||
let mut tool_arcs: Vec<Arc<dyn Tool>> = vec![
|
||||
Arc::new(ShellTool::new_with_sandbox(
|
||||
security.clone(),
|
||||
runtime,
|
||||
sandbox,
|
||||
)),
|
||||
Arc::new(
|
||||
ShellTool::new_with_sandbox(security.clone(), runtime, sandbox)
|
||||
.with_timeout_secs(root_config.shell_tool.timeout_secs),
|
||||
),
|
||||
Arc::new(FileReadTool::new(security.clone())),
|
||||
Arc::new(FileWriteTool::new(security.clone())),
|
||||
Arc::new(FileEditTool::new(security.clone())),
|
||||
@ -405,7 +412,8 @@ pub fn all_tools_with_runtime(
|
||||
Arc::new(CronRunsTool::new(config.clone())),
|
||||
Arc::new(MemoryStoreTool::new(memory.clone(), security.clone())),
|
||||
Arc::new(MemoryRecallTool::new(memory.clone())),
|
||||
Arc::new(MemoryForgetTool::new(memory, security.clone())),
|
||||
Arc::new(MemoryForgetTool::new(memory.clone(), security.clone())),
|
||||
Arc::new(MemoryPurgeTool::new(memory, security.clone())),
|
||||
Arc::new(ScheduleTool::new(security.clone(), root_config.clone())),
|
||||
Arc::new(ModelRoutingConfigTool::new(
|
||||
config.clone(),
|
||||
@ -674,6 +682,19 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Claude Code task runner with Slack progress and SSH handoff
|
||||
if root_config.claude_code_runner.enabled {
|
||||
let gateway_url = format!(
|
||||
"http://{}:{}",
|
||||
root_config.gateway.host, root_config.gateway.port
|
||||
);
|
||||
tool_arcs.push(Arc::new(ClaudeCodeRunnerTool::new(
|
||||
security.clone(),
|
||||
root_config.claude_code_runner.clone(),
|
||||
gateway_url,
|
||||
)));
|
||||
}
|
||||
|
||||
// Codex CLI delegation tool
|
||||
if root_config.codex_cli.enabled {
|
||||
tool_arcs.push(Arc::new(CodexCliTool::new(
|
||||
@ -772,6 +793,11 @@ pub fn all_tools_with_runtime(
|
||||
let reaction_handle = reaction_tool.channel_map_handle();
|
||||
tool_arcs.push(Arc::new(reaction_tool));
|
||||
|
||||
// Interactive ask_user tool — always registered; channel map populated later by start_channels.
|
||||
let ask_user_tool = AskUserTool::new(security.clone());
|
||||
let ask_user_handle = ask_user_tool.channel_map_handle();
|
||||
tool_arcs.push(Arc::new(ask_user_tool));
|
||||
|
||||
// Microsoft 365 Graph API integration
|
||||
if root_config.microsoft365.enabled {
|
||||
let ms_cfg = &root_config.microsoft365;
|
||||
@ -803,6 +829,7 @@ pub fn all_tools_with_runtime(
|
||||
None,
|
||||
Some(reaction_handle),
|
||||
channel_map_handle,
|
||||
Some(ask_user_handle),
|
||||
);
|
||||
}
|
||||
|
||||
@ -992,6 +1019,7 @@ pub fn all_tools_with_runtime(
|
||||
delegate_handle,
|
||||
Some(reaction_handle),
|
||||
channel_map_handle,
|
||||
Some(ask_user_handle),
|
||||
)
|
||||
}
|
||||
|
||||
@ -1036,7 +1064,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@ -1079,7 +1107,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@ -1233,7 +1261,7 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@ -1267,7 +1295,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@ -1302,7 +1330,7 @@ mod tests {
|
||||
let mut cfg = test_config(&tmp);
|
||||
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Compact;
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(cfg.clone()),
|
||||
&security,
|
||||
mem,
|
||||
@ -1337,7 +1365,7 @@ mod tests {
|
||||
let mut cfg = test_config(&tmp);
|
||||
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Full;
|
||||
|
||||
let (tools, _, _, _) = all_tools(
|
||||
let (tools, _, _, _, _) = all_tools(
|
||||
Arc::new(cfg.clone()),
|
||||
&security,
|
||||
mem,
|
||||
|
||||
@ -8,8 +8,8 @@ use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum shell command execution time before kill.
|
||||
const SHELL_TIMEOUT_SECS: u64 = 60;
|
||||
/// Default maximum shell command execution time before kill.
|
||||
const DEFAULT_SHELL_TIMEOUT_SECS: u64 = 60;
|
||||
/// Maximum output size in bytes (1MB).
|
||||
const MAX_OUTPUT_BYTES: usize = 1_048_576;
|
||||
|
||||
@ -46,6 +46,7 @@ pub struct ShellTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
runtime: Arc<dyn RuntimeAdapter>,
|
||||
sandbox: Arc<dyn Sandbox>,
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl ShellTool {
|
||||
@ -54,6 +55,7 @@ impl ShellTool {
|
||||
security,
|
||||
runtime,
|
||||
sandbox: Arc::new(crate::security::NoopSandbox),
|
||||
timeout_secs: DEFAULT_SHELL_TIMEOUT_SECS,
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,8 +68,15 @@ impl ShellTool {
|
||||
security,
|
||||
runtime,
|
||||
sandbox,
|
||||
timeout_secs: DEFAULT_SHELL_TIMEOUT_SECS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Override the command execution timeout (in seconds).
|
||||
pub fn with_timeout_secs(mut self, secs: u64) -> Self {
|
||||
self.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn is_valid_env_var_name(name: &str) -> bool {
|
||||
@ -203,8 +212,8 @@ impl Tool for ShellTool {
|
||||
}
|
||||
}
|
||||
|
||||
let result =
|
||||
tokio::time::timeout(Duration::from_secs(SHELL_TIMEOUT_SECS), cmd.output()).await;
|
||||
let timeout_secs = self.timeout_secs;
|
||||
let result = tokio::time::timeout(Duration::from_secs(timeout_secs), cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
@ -248,7 +257,7 @@ impl Tool for ShellTool {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {SHELL_TIMEOUT_SECS}s and was killed"
|
||||
"Command timed out after {timeout_secs}s and was killed"
|
||||
)),
|
||||
}),
|
||||
}
|
||||
@ -607,8 +616,18 @@ mod tests {
|
||||
// ── shell timeout enforcement tests ─────────────────
|
||||
|
||||
#[test]
|
||||
fn shell_timeout_constant_is_reasonable() {
|
||||
assert_eq!(SHELL_TIMEOUT_SECS, 60, "shell timeout must be 60 seconds");
|
||||
fn shell_timeout_default_is_reasonable() {
|
||||
assert_eq!(
|
||||
DEFAULT_SHELL_TIMEOUT_SECS, 60,
|
||||
"default shell timeout must be 60 seconds"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_timeout_can_be_overridden() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime())
|
||||
.with_timeout_secs(120);
|
||||
assert_eq!(tool.timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -60,6 +60,11 @@ enum ChannelEvent {
|
||||
channel_id: String,
|
||||
message_id: String,
|
||||
},
|
||||
RedactMessage {
|
||||
channel_id: String,
|
||||
message_id: String,
|
||||
reason: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Full-featured matrix test channel that tracks every trait method invocation.
|
||||
@ -257,6 +262,23 @@ impl Channel for MatrixTestChannel {
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn redact_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
message_id: &str,
|
||||
reason: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.events
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(ChannelEvent::RedactMessage {
|
||||
channel_id: channel_id.to_string(),
|
||||
message_id: message_id.to_string(),
|
||||
reason,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
@ -553,7 +575,44 @@ async fn pin_multiple_messages_in_same_channel() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 6. CHANNEL MESSAGE IDENTITY & FIELD SEMANTICS
|
||||
// 6. MESSAGE REDACTION SUPPORT
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/// Tests that MatrixTestChannel correctly records redaction events.
|
||||
/// This validates the mock contract, not the trait default or real implementation.
|
||||
/// Trait default coverage: `src/channels/traits.rs::default_redact_message_returns_success`
|
||||
/// Real implementation coverage: requires live Matrix integration tests (not in this suite).
|
||||
#[tokio::test]
|
||||
async fn redact_message_lifecycle() {
|
||||
let ch = MatrixTestChannel::new("matrix");
|
||||
|
||||
ch.redact_message("room_1", "msg_1", Some("spam".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
ch.redact_message("room_1", "msg_2", None).await.unwrap();
|
||||
|
||||
let events = ch.events();
|
||||
assert_eq!(events.len(), 2);
|
||||
assert!(matches!(
|
||||
&events[0],
|
||||
ChannelEvent::RedactMessage {
|
||||
channel_id,
|
||||
message_id,
|
||||
reason
|
||||
} if channel_id == "room_1" && message_id == "msg_1" && reason == &Some("spam".to_string())
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[1],
|
||||
ChannelEvent::RedactMessage {
|
||||
channel_id,
|
||||
message_id,
|
||||
reason
|
||||
} if channel_id == "room_1" && message_id == "msg_2" && reason.is_none()
|
||||
));
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 7. CHANNEL MESSAGE IDENTITY & FIELD SEMANTICS
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[test]
|
||||
@ -619,7 +678,7 @@ fn send_message_with_subject_preserves_thread() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 7. CROSS-CHANNEL IDENTITY SEMANTICS PER PLATFORM
|
||||
// 8. CROSS-CHANNEL IDENTITY SEMANTICS PER PLATFORM
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/// Simulates the identity mapping for each platform:
|
||||
@ -921,7 +980,7 @@ fn threaded_platforms_have_thread_ts() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 8. SEND → REPLY ROUNDTRIP CONSISTENCY
|
||||
// 9. SEND → REPLY ROUNDTRIP CONSISTENCY
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
@ -963,7 +1022,7 @@ async fn threaded_reply_preserves_thread_ts() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 9. CONCURRENT OPERATIONS
|
||||
// 10. CONCURRENT OPERATIONS
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
@ -1037,7 +1096,7 @@ async fn concurrent_reactions_all_recorded() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 10. EDGE CASES & BOUNDARY CONDITIONS
|
||||
// 11. EDGE CASES & BOUNDARY CONDITIONS
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
@ -1142,7 +1201,7 @@ fn send_message_empty_subject() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 11. MULTI-CHANNEL SIMULATION (CROSS-CHANNEL ROUTING)
|
||||
// 12. MULTI-CHANNEL SIMULATION (CROSS-CHANNEL ROUTING)
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
@ -1205,7 +1264,7 @@ async fn multi_channel_listen_produces_channel_tagged_messages() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 12. CAPABILITY MATRIX DECLARATIONS
|
||||
// 13. CAPABILITY MATRIX DECLARATIONS
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/// Documents the expected capability matrix for all channels. This test serves
|
||||
@ -1244,7 +1303,7 @@ async fn capability_matrix_spec() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 13. DEFAULT TRAIT METHOD CONTRACT (via dyn dispatch)
|
||||
// 14. DEFAULT TRAIT METHOD CONTRACT (via dyn dispatch)
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/// Minimal channel with ONLY required methods — validates all defaults work.
|
||||
@ -1286,6 +1345,11 @@ async fn minimal_channel_all_defaults_succeed() {
|
||||
assert!(ch.remove_reaction("c", "m", "\u{1F440}").await.is_ok());
|
||||
assert!(ch.pin_message("c", "m").await.is_ok());
|
||||
assert!(ch.unpin_message("c", "m").await.is_ok());
|
||||
assert!(ch
|
||||
.redact_message("c", "m", Some("test".to_string()))
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(ch.redact_message("c", "m", None).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -1307,7 +1371,7 @@ async fn dyn_channel_dispatch_works() {
|
||||
}
|
||||
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// 14. MIXED OPERATION SEQUENCES
|
||||
// 15. MIXED OPERATION SEQUENCES
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user