Compare commits
74 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e7ad69d69a | |||
| 861dd3e2e9 | |||
| 8a61a283b2 | |||
| dcc0a629ec | |||
| a8c6363cde | |||
| d9ab017df0 | |||
| 249434edb2 | |||
| 62781a8d45 | |||
| 0adec305f9 | |||
| 75701195d7 | |||
| 82fe2e53fd | |||
| 20f25ba108 | |||
| c676dc325e | |||
| 327e2b4c47 | |||
| 5a5d9ae5f9 | |||
| d8f228bd15 | |||
| d5eeaed3d9 | |||
| 429094b049 | |||
| 80213b08ef | |||
| bf67124499 | |||
| cb250dfecf | |||
| fabd35c4ea | |||
| a695ca4b9c | |||
| 811fab3b87 | |||
| 1a5d91fe69 | |||
| 6eec1c81b9 | |||
| 602db8bca1 | |||
| 314e1d3ae8 | |||
| 82be05b1e9 | |||
| 1373659058 | |||
| c7f064e866 | |||
| 9c1d63e109 | |||
| 966edf1553 | |||
| a1af84d992 | |||
| 0ad1965081 | |||
| 70e8e7ebcd | |||
| 2bcb82c5b3 | |||
| e211b5c3e3 | |||
| 8691476577 | |||
| e34a804255 | |||
| 6120b3f705 | |||
| f175261e32 | |||
| fd9f66cad7 | |||
| d928ebc92e | |||
| 9fca9f478a | |||
| 7106632b51 | |||
| b834278754 | |||
| 186f6d9797 | |||
| 6cdc92a256 | |||
| 02599dcd3c | |||
| fe64d7ef7e | |||
| 996dbe95cf | |||
| 45f953be6d | |||
| 82f29bbcb1 | |||
| 93b5a0b824 | |||
| 08a67c4a2d | |||
| c86a0673ba | |||
| cabf99ba07 | |||
| 2d978a6b64 | |||
| 4dbc9266c1 | |||
| ea0b3c8c8c | |||
| 0c56834385 | |||
| caccf0035e | |||
| 627b160f55 | |||
| 6463bc84b0 | |||
| f84f1229af | |||
| f85d21097b | |||
| 306821d6a2 | |||
| 06f9424274 | |||
| fa14ab4ab2 | |||
| 36a0c8eba9 | |||
| f4c82d5797 | |||
| 5edebf4869 | |||
| 613fa79444 |
@@ -1,139 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# sync-readme.sh — Auto-update "What's New" and "Recent Contributors" in all READMEs
|
||||
# Called by the sync-readme GitHub Actions workflow on each release.
|
||||
set -euo pipefail
|
||||
|
||||
# --- Resolve version and ranges ---
|
||||
|
||||
LATEST_TAG=$(git tag --sort=-creatordate | head -1 || echo "")
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No tags found — skipping README sync"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
VERSION="${LATEST_TAG#v}"
|
||||
|
||||
# Find previous stable tag for contributor range
|
||||
PREV_STABLE=$(git tag --sort=-creatordate \
|
||||
| grep -v "^${LATEST_TAG}$" \
|
||||
| grep -vE '\-beta\.' \
|
||||
| head -1 || echo "")
|
||||
|
||||
FEAT_RANGE="${PREV_STABLE:+${PREV_STABLE}..}${LATEST_TAG}"
|
||||
CONTRIB_RANGE="${PREV_STABLE:+${PREV_STABLE}..}${LATEST_TAG}"
|
||||
|
||||
# --- Build "What's New" table rows ---
|
||||
|
||||
FEATURES=$(git log "$FEAT_RANGE" --pretty=format:"%s" --no-merges \
|
||||
| grep -iE '^feat(\(|:)' \
|
||||
| sed 's/^feat(\([^)]*\)): /| \1 | /' \
|
||||
| sed 's/^feat: /| General | /' \
|
||||
| sed 's/ (#[0-9]*)$//' \
|
||||
| sort -uf \
|
||||
| while IFS= read -r line; do echo "${line} |"; done || true)
|
||||
|
||||
if [ -z "$FEATURES" ]; then
|
||||
FEATURES="| General | Incremental improvements and polish |"
|
||||
fi
|
||||
|
||||
MONTH_YEAR=$(date -u +"%B %Y")
|
||||
|
||||
# --- Build contributor list ---
|
||||
|
||||
GIT_AUTHORS=$(git log "$CONTRIB_RANGE" --pretty=format:"%an" --no-merges | sort -uf || true)
|
||||
CO_AUTHORS=$(git log "$CONTRIB_RANGE" --pretty=format:"%b" --no-merges \
|
||||
| grep -ioE 'Co-Authored-By: *[^<]+' \
|
||||
| sed 's/Co-Authored-By: *//i' \
|
||||
| sed 's/ *$//' \
|
||||
| sort -uf || true)
|
||||
|
||||
ALL_CONTRIBUTORS=$(printf "%s\n%s" "$GIT_AUTHORS" "$CO_AUTHORS" \
|
||||
| sort -uf \
|
||||
| grep -v '^$' \
|
||||
| grep -viE '\[bot\]$|^dependabot|^github-actions|^copilot|^ZeroClaw Bot|^ZeroClaw Runner|^ZeroClaw Agent|^blacksmith' \
|
||||
|| true)
|
||||
|
||||
CONTRIBUTOR_COUNT=$(echo "$ALL_CONTRIBUTORS" | grep -c . || echo "0")
|
||||
|
||||
CONTRIBUTOR_LIST=$(echo "$ALL_CONTRIBUTORS" \
|
||||
| while IFS= read -r name; do
|
||||
[ -z "$name" ] && continue
|
||||
echo "- **${name}**"
|
||||
done || true)
|
||||
|
||||
# --- Write temp files for section content ---
|
||||
|
||||
WHATS_NEW_FILE=$(mktemp)
|
||||
cat > "$WHATS_NEW_FILE" <<WHATS_EOF
|
||||
|
||||
### 🚀 What's New in ${LATEST_TAG} (${MONTH_YEAR})
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
${FEATURES}
|
||||
|
||||
WHATS_EOF
|
||||
|
||||
CONTRIBUTORS_FILE=$(mktemp)
|
||||
cat > "$CONTRIBUTORS_FILE" <<CONTRIB_EOF
|
||||
|
||||
### 🌟 Recent Contributors (${LATEST_TAG})
|
||||
|
||||
${CONTRIBUTOR_COUNT} contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
${CONTRIBUTOR_LIST}
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
CONTRIB_EOF
|
||||
|
||||
# --- Replace sections in all README files with markers ---
|
||||
|
||||
README_FILES=$(find . -maxdepth 1 -name 'README*.md' -type f | sort)
|
||||
UPDATED=0
|
||||
|
||||
for readme in $README_FILES; do
|
||||
if ! grep -q 'BEGIN:WHATS_NEW' "$readme"; then
|
||||
continue
|
||||
fi
|
||||
|
||||
python3 - "$readme" "$WHATS_NEW_FILE" "$CONTRIBUTORS_FILE" <<'PYEOF'
|
||||
import sys, re
|
||||
|
||||
readme_path = sys.argv[1]
|
||||
whats_new_path = sys.argv[2]
|
||||
contributors_path = sys.argv[3]
|
||||
|
||||
with open(readme_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
with open(whats_new_path, 'r') as f:
|
||||
whats_new = f.read()
|
||||
|
||||
with open(contributors_path, 'r') as f:
|
||||
contributors = f.read()
|
||||
|
||||
content = re.sub(
|
||||
r'(<!-- BEGIN:WHATS_NEW -->)\n.*?(<!-- END:WHATS_NEW -->)',
|
||||
r'\1\n' + whats_new + r'\2',
|
||||
content,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
|
||||
content = re.sub(
|
||||
r'(<!-- BEGIN:RECENT_CONTRIBUTORS -->)\n.*?(<!-- END:RECENT_CONTRIBUTORS -->)',
|
||||
r'\1\n' + contributors + r'\2',
|
||||
content,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
|
||||
with open(readme_path, 'w') as f:
|
||||
f.write(content)
|
||||
PYEOF
|
||||
|
||||
UPDATED=$((UPDATED + 1))
|
||||
done
|
||||
|
||||
rm -f "$WHATS_NEW_FILE" "$CONTRIBUTORS_FILE"
|
||||
|
||||
echo "README synced: ${LATEST_TAG} — ${CONTRIBUTOR_COUNT} contributors — ${UPDATED} files updated"
|
||||
@@ -155,11 +155,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
@@ -170,6 +172,11 @@ jobs:
|
||||
target: aarch64-apple-darwin
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
target: aarch64-linux-android
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
ndk: true
|
||||
- os: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
artifact: zeroclaw.exe
|
||||
@@ -194,6 +201,10 @@ jobs:
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y ${{ matrix.cross_compiler }}
|
||||
|
||||
- name: Setup Android NDK
|
||||
if: matrix.ndk
|
||||
run: echo "$ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin" >> "$GITHUB_PATH"
|
||||
|
||||
- name: Build release
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -304,3 +315,13 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# ── Post-publish: only run after ALL artifacts are live ──────────────
|
||||
tweet:
|
||||
name: Tweet Release
|
||||
needs: [version, publish, docker, redeploy-website]
|
||||
uses: ./.github/workflows/tweet-release.yml
|
||||
with:
|
||||
release_tag: ${{ needs.version.outputs.tag }}
|
||||
release_url: https://github.com/zeroclaw-labs/zeroclaw/releases/tag/${{ needs.version.outputs.tag }}
|
||||
secrets: inherit
|
||||
|
||||
@@ -156,11 +156,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
@@ -171,6 +173,11 @@ jobs:
|
||||
target: aarch64-apple-darwin
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
target: aarch64-linux-android
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
ndk: true
|
||||
- os: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
artifact: zeroclaw.exe
|
||||
@@ -195,6 +202,10 @@ jobs:
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y ${{ matrix.cross_compiler }}
|
||||
|
||||
- name: Setup Android NDK
|
||||
if: matrix.ndk
|
||||
run: echo "$ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin" >> "$GITHUB_PATH"
|
||||
|
||||
- name: Build release
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -344,3 +355,13 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# ── Post-publish: only run after ALL artifacts are live ──────────────
|
||||
tweet:
|
||||
name: Tweet Release
|
||||
needs: [validate, publish, docker, crates-io, redeploy-website]
|
||||
uses: ./.github/workflows/tweet-release.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
release_url: https://github.com/zeroclaw-labs/zeroclaw/releases/tag/${{ needs.validate.outputs.tag }}
|
||||
secrets: inherit
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
name: Sync README
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.RELEASE_TOKEN }}
|
||||
|
||||
- name: Sync README sections
|
||||
run: bash .github/scripts/sync-readme.sh
|
||||
|
||||
- name: Commit and push
|
||||
run: |
|
||||
git config user.name "ZeroClaw Bot"
|
||||
git config user.email "bot@zeroclawlabs.ai"
|
||||
if git diff --quiet -- 'README*.md'; then
|
||||
echo "No README changes — skipping commit"
|
||||
exit 0
|
||||
fi
|
||||
git add README*.md
|
||||
git commit -m "docs(readme): auto-sync What's New and Contributors"
|
||||
git push origin HEAD:master
|
||||
@@ -1,8 +1,26 @@
|
||||
name: Tweet Release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
# Called by release workflows AFTER all publish steps (docker, crates, website) complete.
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Release tag (e.g. v0.3.0 or v0.3.0-beta.42)"
|
||||
required: true
|
||||
type: string
|
||||
release_url:
|
||||
description: "GitHub Release URL"
|
||||
required: true
|
||||
type: string
|
||||
secrets:
|
||||
TWITTER_CONSUMER_API_KEY:
|
||||
required: false
|
||||
TWITTER_CONSUMER_API_SECRET_KEY:
|
||||
required: false
|
||||
TWITTER_ACCESS_TOKEN:
|
||||
required: false
|
||||
TWITTER_ACCESS_TOKEN_SECRET:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tweet_text:
|
||||
@@ -26,7 +44,7 @@ jobs:
|
||||
id: check
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name || '' }}
|
||||
RELEASE_TAG: ${{ inputs.release_tag || '' }}
|
||||
MANUAL_TEXT: ${{ inputs.tweet_text || '' }}
|
||||
run: |
|
||||
# Manual dispatch always proceeds
|
||||
@@ -62,8 +80,8 @@ jobs:
|
||||
if: steps.check.outputs.skip != 'true'
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name || '' }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url || '' }}
|
||||
RELEASE_TAG: ${{ inputs.release_tag || '' }}
|
||||
RELEASE_URL: ${{ inputs.release_url || '' }}
|
||||
MANUAL_TEXT: ${{ inputs.tweet_text || '' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
+4
-1
@@ -43,4 +43,7 @@ credentials.json
|
||||
lcov.info
|
||||
|
||||
# IDE's stuff
|
||||
.idea
|
||||
.idea
|
||||
|
||||
# Wrangler cache
|
||||
.wrangler/
|
||||
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.0"
|
||||
version = "0.3.4"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.0"
|
||||
version = "0.3.4"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ zeroclaw version # عرض الإصدار ومعلومات البنا
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Postaveno studenty a členy komunit Harvard, MIT a Sundai.Club.
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Stavíme v open source protože nejlepší nápady přicházejí odkudkoliv. Pok
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Hvis ZeroClaw er nyttigt for dig, overvej venligst at købe os en kaffe:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -92,11 +92,11 @@ Erstellt von Studenten und Mitgliedern der Harvard, MIT und Sundai.Club Gemeinsc
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -429,11 +429,13 @@ Wir bauen in Open Source, weil die besten Ideen von überall kommen. Wenn du das
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -56,11 +56,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -189,11 +189,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Construido por estudiantes y miembros de las comunidades de Harvard, MIT y Sunda
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Construimos en código abierto porque las mejores ideas vienen de todas partes.
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Jos ZeroClaw on hyödyllinen sinulle, harkitse kahvin ostamista meille:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -86,11 +86,11 @@ Construit par des étudiants et membres des communautés Harvard, MIT et Sundai.
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -423,11 +423,13 @@ Nous construisons en open source parce que les meilleures idées viennent de par
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -208,11 +208,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Ha a ZeroClaw hasznos az Ön számára, kérjük, fontolja meg, hogy vesz nekün
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Jika ZeroClaw berguna bagi Anda, mohon pertimbangkan untuk membelikan kami kopi:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Costruito da studenti e membri delle comunità Harvard, MIT e Sundai.Club.
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Costruiamo in open source perché le migliori idee vengono da ovunque. Se stai l
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -77,11 +77,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -237,11 +237,13 @@ zeroclaw agent --provider anthropic -m "hello"
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Harvard, MIT, 그리고 Sundai.Club 커뮤니티의 학생들과 멤버들이
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ ZeroClaw가 당신의 작업에 도움이 되었고 지속적인 개발을 지
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
@@ -85,13 +85,6 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
|
||||
<p align="center"><code>Trait-driven architecture · secure-by-default runtime · provider/channel/tool swappable · pluggable everything</code></p>
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
### 📢 Announcements
|
||||
@@ -481,15 +474,6 @@ A heartfelt thank you to the communities and institutions that inspire and fuel
|
||||
We're building in the open because the best ideas come from everywhere. If you're reading this, you're part of it. Welcome. 🦀❤️
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
<!-- END:RECENT_CONTRIBUTORS -->
|
||||
|
||||
## ⚠️ Official Repository & Impersonation Warning
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Hvis ZeroClaw er nyttig for deg, vennligst vurder å kjøpe oss en kaffe:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Gebouwd door studenten en leden van de Harvard, MIT en Sundai.Club gemeenschappe
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ We bouwen in open source omdat de beste ideeën van overal komen. Als je dit lee
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Zbudowany przez studentów i członków społeczności Harvard, MIT i Sundai.Clu
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Budujemy w open source ponieważ najlepsze pomysły przychodzą zewsząd. Jeśli
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Construído por estudantes e membros das comunidades Harvard, MIT e Sundai.Club.
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Construímos em código aberto porque as melhores ideias vêm de todo lugar. Se
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Dacă ZeroClaw îți este util, te rugăm să iei în considerare să ne cumperi
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -77,11 +77,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -237,11 +237,13 @@ zeroclaw agent --provider anthropic -m "hello"
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ Om ZeroClaw är användbart för dig, vänligen överväg att köpa en kaffe til
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Binuo ng mga mag-aaral at miyembro ng Harvard, MIT, at Sundai.Club na komunidad.
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ Kami ay bumubuo sa open source dahil ang mga pinakamahusay na ideya ay nagmumula
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -88,11 +88,11 @@ Harvard, MIT ve Sundai.Club topluluklarının öğrencileri ve üyeleri tarafın
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -425,11 +425,13 @@ En iyi fikirler her yerden geldiği için açık kaynakta inşa ediyoruz. Bunu o
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -190,11 +190,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -59,11 +59,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -208,11 +208,13 @@ channels:
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -86,11 +86,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -468,11 +468,13 @@ Chúng tôi xây dựng công khai vì ý tưởng hay đến từ khắp nơi.
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
+6
-4
@@ -77,11 +77,11 @@
|
||||
|
||||
<!-- BEGIN:WHATS_NEW -->
|
||||
|
||||
### 🚀 What's New in v0.3.0-beta.200 (March 2026)
|
||||
### 🚀 What's New in v0.3.1 (March 2026)
|
||||
|
||||
| Area | Highlights |
|
||||
|---|---|
|
||||
| General | Incremental improvements and polish |
|
||||
| ci | add Termux (aarch64-linux-android) release target |
|
||||
|
||||
<!-- END:WHATS_NEW -->
|
||||
|
||||
@@ -242,11 +242,13 @@ zeroclaw agent --provider anthropic -m "hello"
|
||||
|
||||
<!-- BEGIN:RECENT_CONTRIBUTORS -->
|
||||
|
||||
### 🌟 Recent Contributors (v0.3.0-beta.200)
|
||||
### 🌟 Recent Contributors (v0.3.1)
|
||||
|
||||
1 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
3 contributors shipped features, fixes, and improvements in this release cycle:
|
||||
|
||||
- **Argenis**
|
||||
- **argenis de la rosa**
|
||||
- **Claude Opus 4.6**
|
||||
|
||||
Thank you to everyone who opened issues, reviewed PRs, translated docs, and helped test. Every contribution matters. 🦀
|
||||
|
||||
|
||||
Executable
+261
@@ -0,0 +1,261 @@
|
||||
#!/usr/bin/env bash
|
||||
# Termux release validation script
|
||||
# Validates the aarch64-linux-android release artifact for Termux compatibility.
|
||||
#
|
||||
# Usage:
|
||||
# ./dev/test-termux-release.sh [version]
|
||||
#
|
||||
# Examples:
|
||||
# ./dev/test-termux-release.sh 0.3.1
|
||||
# ./dev/test-termux-release.sh # auto-detects from Cargo.toml
|
||||
#
|
||||
set -euo pipefail
|
||||
|
||||
BLUE='\033[0;34m'
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
YELLOW='\033[0;33m'
|
||||
BOLD='\033[1m'
|
||||
DIM='\033[2m'
|
||||
RESET='\033[0m'
|
||||
|
||||
pass() { echo -e " ${GREEN}✓${RESET} $*"; }
|
||||
fail() { echo -e " ${RED}✗${RESET} $*"; FAILURES=$((FAILURES + 1)); }
|
||||
info() { echo -e "${BLUE}→${RESET} ${BOLD}$*${RESET}"; }
|
||||
warn() { echo -e "${YELLOW}!${RESET} $*"; }
|
||||
|
||||
FAILURES=0
|
||||
TARGET="aarch64-linux-android"
|
||||
VERSION="${1:-}"
|
||||
|
||||
if [[ -z "$VERSION" ]]; then
|
||||
if [[ -f Cargo.toml ]]; then
|
||||
VERSION=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "$VERSION" ]]; then
|
||||
echo "Usage: $0 <version>"
|
||||
echo " e.g. $0 0.3.1"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TAG="v${VERSION}"
|
||||
ASSET_NAME="zeroclaw-${TARGET}.tar.gz"
|
||||
ASSET_URL="https://github.com/zeroclaw-labs/zeroclaw/releases/download/${TAG}/${ASSET_NAME}"
|
||||
TEMP_DIR="$(mktemp -d -t zeroclaw-termux-test-XXXXXX)"
|
||||
|
||||
cleanup() { rm -rf "$TEMP_DIR"; }
|
||||
trap cleanup EXIT
|
||||
|
||||
echo
|
||||
echo -e "${BOLD}Termux Release Validation — ${TAG}${RESET}"
|
||||
echo -e "${DIM}Target: ${TARGET}${RESET}"
|
||||
echo
|
||||
|
||||
# --- Test 1: Release tag exists ---
|
||||
info "Checking release tag ${TAG}"
|
||||
if gh release view "$TAG" >/dev/null 2>&1; then
|
||||
pass "Release ${TAG} exists"
|
||||
else
|
||||
fail "Release ${TAG} not found"
|
||||
echo -e "${RED}Release has not been published yet. Wait for the release workflow to complete.${RESET}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Test 2: Android asset is listed ---
|
||||
info "Checking for ${ASSET_NAME} in release assets"
|
||||
ASSETS=$(gh release view "$TAG" --json assets -q '.assets[].name')
|
||||
if echo "$ASSETS" | grep -q "$ASSET_NAME"; then
|
||||
pass "Asset ${ASSET_NAME} found in release"
|
||||
else
|
||||
fail "Asset ${ASSET_NAME} not found in release"
|
||||
echo "Available assets:"
|
||||
echo "$ASSETS" | sed 's/^/ /'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Test 3: Download the asset ---
|
||||
info "Downloading ${ASSET_NAME}"
|
||||
if curl -fsSL "$ASSET_URL" -o "$TEMP_DIR/$ASSET_NAME"; then
|
||||
FILESIZE=$(wc -c < "$TEMP_DIR/$ASSET_NAME" | tr -d ' ')
|
||||
pass "Downloaded successfully (${FILESIZE} bytes)"
|
||||
else
|
||||
fail "Download failed from ${ASSET_URL}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Test 4: Archive integrity ---
|
||||
info "Verifying archive integrity"
|
||||
if tar -tzf "$TEMP_DIR/$ASSET_NAME" >/dev/null 2>&1; then
|
||||
pass "Archive is a valid gzip tar"
|
||||
else
|
||||
fail "Archive is corrupted or not a valid tar.gz"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Test 5: Contains zeroclaw binary ---
|
||||
info "Checking archive contents"
|
||||
CONTENTS=$(tar -tzf "$TEMP_DIR/$ASSET_NAME")
|
||||
if echo "$CONTENTS" | grep -q "^zeroclaw$"; then
|
||||
pass "Archive contains 'zeroclaw' binary"
|
||||
else
|
||||
fail "Archive does not contain 'zeroclaw' binary"
|
||||
echo "Contents:"
|
||||
echo "$CONTENTS" | sed 's/^/ /'
|
||||
fi
|
||||
|
||||
# --- Test 6: Extract and inspect binary ---
|
||||
info "Extracting and inspecting binary"
|
||||
tar -xzf "$TEMP_DIR/$ASSET_NAME" -C "$TEMP_DIR"
|
||||
BINARY="$TEMP_DIR/zeroclaw"
|
||||
|
||||
if [[ -f "$BINARY" ]]; then
|
||||
pass "Binary extracted"
|
||||
else
|
||||
fail "Binary not found after extraction"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Test 7: ELF format and architecture ---
|
||||
info "Checking binary format"
|
||||
FILE_INFO=$(file "$BINARY")
|
||||
if echo "$FILE_INFO" | grep -q "ELF"; then
|
||||
pass "Binary is ELF format"
|
||||
else
|
||||
fail "Binary is not ELF format: $FILE_INFO"
|
||||
fi
|
||||
|
||||
if echo "$FILE_INFO" | grep -qi "aarch64\|ARM aarch64"; then
|
||||
pass "Binary targets aarch64 architecture"
|
||||
else
|
||||
fail "Binary does not target aarch64: $FILE_INFO"
|
||||
fi
|
||||
|
||||
if echo "$FILE_INFO" | grep -qi "android\|bionic"; then
|
||||
pass "Binary is linked for Android/Bionic"
|
||||
else
|
||||
# Android binaries may not always show "android" in file output,
|
||||
# check with readelf if available
|
||||
if command -v readelf >/dev/null 2>&1; then
|
||||
INTERP=$(readelf -l "$BINARY" 2>/dev/null | grep -o '/[^ ]*linker[^ ]*' || true)
|
||||
if echo "$INTERP" | grep -qi "android\|bionic"; then
|
||||
pass "Binary uses Android linker: $INTERP"
|
||||
else
|
||||
warn "Could not confirm Android linkage (interpreter: ${INTERP:-unknown})"
|
||||
warn "file output: $FILE_INFO"
|
||||
fi
|
||||
else
|
||||
warn "Could not confirm Android linkage (readelf not available)"
|
||||
warn "file output: $FILE_INFO"
|
||||
fi
|
||||
fi
|
||||
|
||||
# --- Test 8: Binary is stripped ---
|
||||
info "Checking binary optimization"
|
||||
if echo "$FILE_INFO" | grep -q "stripped"; then
|
||||
pass "Binary is stripped (release optimized)"
|
||||
else
|
||||
warn "Binary may not be stripped"
|
||||
fi
|
||||
|
||||
# --- Test 9: Binary is not dynamically linked to glibc ---
|
||||
info "Checking for glibc dependencies"
|
||||
if command -v readelf >/dev/null 2>&1; then
|
||||
NEEDED=$(readelf -d "$BINARY" 2>/dev/null | grep NEEDED || true)
|
||||
if echo "$NEEDED" | grep -qi "libc\.so\.\|libpthread\|libdl"; then
|
||||
# Check if it's glibc or bionic
|
||||
if echo "$NEEDED" | grep -qi "libc\.so\.6"; then
|
||||
fail "Binary links against glibc (libc.so.6) — will not work on Termux"
|
||||
else
|
||||
pass "Binary links against libc (likely Bionic)"
|
||||
fi
|
||||
else
|
||||
pass "No glibc dependencies detected"
|
||||
fi
|
||||
else
|
||||
warn "readelf not available — skipping dynamic library check"
|
||||
fi
|
||||
|
||||
# --- Test 10: SHA256 checksum verification ---
|
||||
info "Verifying SHA256 checksum"
|
||||
CHECKSUMS_URL="https://github.com/zeroclaw-labs/zeroclaw/releases/download/${TAG}/SHA256SUMS"
|
||||
if curl -fsSL "$CHECKSUMS_URL" -o "$TEMP_DIR/SHA256SUMS" 2>/dev/null; then
|
||||
EXPECTED=$(grep "$ASSET_NAME" "$TEMP_DIR/SHA256SUMS" | awk '{print $1}')
|
||||
if [[ -n "$EXPECTED" ]]; then
|
||||
if command -v sha256sum >/dev/null 2>&1; then
|
||||
ACTUAL=$(sha256sum "$TEMP_DIR/$ASSET_NAME" | awk '{print $1}')
|
||||
elif command -v shasum >/dev/null 2>&1; then
|
||||
ACTUAL=$(shasum -a 256 "$TEMP_DIR/$ASSET_NAME" | awk '{print $1}')
|
||||
else
|
||||
warn "No sha256sum or shasum available"
|
||||
ACTUAL=""
|
||||
fi
|
||||
|
||||
if [[ -n "$ACTUAL" && "$ACTUAL" == "$EXPECTED" ]]; then
|
||||
pass "SHA256 checksum matches"
|
||||
elif [[ -n "$ACTUAL" ]]; then
|
||||
fail "SHA256 mismatch: expected=$EXPECTED actual=$ACTUAL"
|
||||
fi
|
||||
else
|
||||
warn "No checksum entry for ${ASSET_NAME} in SHA256SUMS"
|
||||
fi
|
||||
else
|
||||
warn "Could not download SHA256SUMS"
|
||||
fi
|
||||
|
||||
# --- Test 11: install.sh Termux detection ---
|
||||
info "Validating install.sh Termux detection"
|
||||
INSTALL_SH="install.sh"
|
||||
if [[ ! -f "$INSTALL_SH" ]]; then
|
||||
INSTALL_SH="$(dirname "$0")/../install.sh"
|
||||
fi
|
||||
|
||||
if [[ -f "$INSTALL_SH" ]]; then
|
||||
if grep -q 'TERMUX_VERSION' "$INSTALL_SH"; then
|
||||
pass "install.sh checks TERMUX_VERSION"
|
||||
else
|
||||
fail "install.sh does not check TERMUX_VERSION"
|
||||
fi
|
||||
|
||||
if grep -q 'aarch64-linux-android' "$INSTALL_SH"; then
|
||||
pass "install.sh maps to aarch64-linux-android target"
|
||||
else
|
||||
fail "install.sh does not map to aarch64-linux-android"
|
||||
fi
|
||||
|
||||
# Simulate Termux detection (mock uname as Linux since we may run on macOS)
|
||||
detect_result=$(
|
||||
bash -c '
|
||||
TERMUX_VERSION="0.118"
|
||||
os="Linux"
|
||||
arch="aarch64"
|
||||
case "$os:$arch" in
|
||||
Linux:aarch64|Linux:arm64)
|
||||
if [[ -n "${TERMUX_VERSION:-}" || -d "/data/data/com.termux" ]]; then
|
||||
echo "aarch64-linux-android"
|
||||
else
|
||||
echo "aarch64-unknown-linux-gnu"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
'
|
||||
)
|
||||
if [[ "$detect_result" == "aarch64-linux-android" ]]; then
|
||||
pass "Termux detection returns correct target (simulated)"
|
||||
else
|
||||
fail "Termux detection returned: $detect_result (expected aarch64-linux-android)"
|
||||
fi
|
||||
else
|
||||
warn "install.sh not found — skipping detection tests"
|
||||
fi
|
||||
|
||||
# --- Summary ---
|
||||
echo
|
||||
if [[ "$FAILURES" -eq 0 ]]; then
|
||||
echo -e "${GREEN}${BOLD}All tests passed!${RESET}"
|
||||
echo -e "${DIM}The Termux release artifact for ${TAG} is valid.${RESET}"
|
||||
else
|
||||
echo -e "${RED}${BOLD}${FAILURES} test(s) failed.${RESET}"
|
||||
exit 1
|
||||
fi
|
||||
+19
-1
@@ -187,7 +187,12 @@ detect_release_target() {
|
||||
echo "x86_64-unknown-linux-gnu"
|
||||
;;
|
||||
Linux:aarch64|Linux:arm64)
|
||||
echo "aarch64-unknown-linux-gnu"
|
||||
# Termux on Android needs the android target, not linux-gnu
|
||||
if [[ -n "${TERMUX_VERSION:-}" || -d "/data/data/com.termux" ]]; then
|
||||
echo "aarch64-linux-android"
|
||||
else
|
||||
echo "aarch64-unknown-linux-gnu"
|
||||
fi
|
||||
;;
|
||||
Linux:armv7l|Linux:armv6l)
|
||||
echo "armv7-unknown-linux-gnueabihf"
|
||||
@@ -534,6 +539,8 @@ install_system_deps() {
|
||||
openssl \
|
||||
perl \
|
||||
ca-certificates
|
||||
elif have_cmd pkg && [[ -n "${TERMUX_VERSION:-}" ]]; then
|
||||
pkg install -y build-essential pkg-config git curl openssl perl
|
||||
else
|
||||
warn "Unsupported Linux distribution. Install compiler toolchain + pkg-config + git + curl + OpenSSL headers + perl manually."
|
||||
fi
|
||||
@@ -1192,6 +1199,17 @@ fi
|
||||
|
||||
if [[ "$SKIP_INSTALL" == false ]]; then
|
||||
step_dot "Installing zeroclaw to cargo bin"
|
||||
|
||||
# Clean up stale cargo install tracking from the old "zeroclaw" package name
|
||||
# (renamed to "zeroclawlabs"). Without this, `cargo install zeroclawlabs` from
|
||||
# crates.io fails with "binary already exists as part of `zeroclaw`".
|
||||
if have_cmd cargo; then
|
||||
if [[ -f "$HOME/.cargo/.crates.toml" ]] && grep -q '^"zeroclaw ' "$HOME/.cargo/.crates.toml" 2>/dev/null; then
|
||||
step_dot "Removing stale cargo tracking for old 'zeroclaw' package name"
|
||||
cargo uninstall zeroclaw 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
|
||||
cargo install --path "$WORK_DIR" --force --locked
|
||||
step_ok "ZeroClaw installed"
|
||||
else
|
||||
|
||||
+78
-1
@@ -37,6 +37,7 @@ pub struct Agent {
|
||||
classification_config: crate::config::QueryClassificationConfig,
|
||||
available_hints: Vec<String>,
|
||||
route_model_by_hint: HashMap<String, String>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
@@ -58,6 +59,7 @@ pub struct AgentBuilder {
|
||||
classification_config: Option<crate::config::QueryClassificationConfig>,
|
||||
available_hints: Option<Vec<String>>,
|
||||
route_model_by_hint: Option<HashMap<String, String>>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
@@ -81,6 +83,7 @@ impl AgentBuilder {
|
||||
classification_config: None,
|
||||
available_hints: None,
|
||||
route_model_by_hint: None,
|
||||
allowed_tools: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,10 +183,19 @@ impl AgentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn allowed_tools(mut self, allowed_tools: Option<Vec<String>>) -> Self {
|
||||
self.allowed_tools = allowed_tools;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let tools = self
|
||||
let mut tools = self
|
||||
.tools
|
||||
.ok_or_else(|| anyhow::anyhow!("tools are required"))?;
|
||||
let allowed = self.allowed_tools.clone();
|
||||
if let Some(ref allow_list) = allowed {
|
||||
tools.retain(|t| allow_list.iter().any(|name| name == t.name()));
|
||||
}
|
||||
let tool_specs = tools.iter().map(|tool| tool.spec()).collect();
|
||||
|
||||
Ok(Agent {
|
||||
@@ -223,6 +235,7 @@ impl AgentBuilder {
|
||||
classification_config: self.classification_config.unwrap_or_default(),
|
||||
available_hints: self.available_hints.unwrap_or_default(),
|
||||
route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
|
||||
allowed_tools: allowed,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -892,4 +905,68 @@ mod tests {
|
||||
let seen = seen_models.lock();
|
||||
assert_eq!(seen.as_slice(), &["hint:fast".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_allowed_tools_none_keeps_all_tools() {
|
||||
let provider = Box::new(MockProvider {
|
||||
responses: Mutex::new(vec![]),
|
||||
});
|
||||
|
||||
let memory_cfg = crate::config::MemoryConfig {
|
||||
backend: "none".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let mem: Arc<dyn Memory> = Arc::from(
|
||||
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
|
||||
.expect("memory creation should succeed with valid config"),
|
||||
);
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||
let agent = Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(vec![Box::new(MockTool)])
|
||||
.memory(mem)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||
.allowed_tools(None)
|
||||
.build()
|
||||
.expect("agent builder should succeed with valid config");
|
||||
|
||||
assert_eq!(agent.tool_specs.len(), 1);
|
||||
assert_eq!(agent.tool_specs[0].name, "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_allowed_tools_some_filters_tools() {
|
||||
let provider = Box::new(MockProvider {
|
||||
responses: Mutex::new(vec![]),
|
||||
});
|
||||
|
||||
let memory_cfg = crate::config::MemoryConfig {
|
||||
backend: "none".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let mem: Arc<dyn Memory> = Arc::from(
|
||||
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
|
||||
.expect("memory creation should succeed with valid config"),
|
||||
);
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||
let agent = Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(vec![Box::new(MockTool)])
|
||||
.memory(mem)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||
.allowed_tools(Some(vec!["nonexistent".to_string()]))
|
||||
.build()
|
||||
.expect("agent builder should succeed with valid config");
|
||||
|
||||
assert!(
|
||||
agent.tool_specs.is_empty(),
|
||||
"No tools should match a non-existent allowlist entry"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+287
-12
@@ -93,6 +93,24 @@ pub(crate) fn filter_tool_specs_for_turn(
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Filters a tool spec list by an optional capability allowlist.
|
||||
///
|
||||
/// When `allowed` is `None`, all specs pass through unchanged.
|
||||
/// When `allowed` is `Some(list)`, only specs whose name appears in the list
|
||||
/// are retained. Unknown names in the allowlist are silently ignored.
|
||||
pub(crate) fn filter_by_allowed_tools(
|
||||
specs: Vec<crate::tools::ToolSpec>,
|
||||
allowed: Option<&[String]>,
|
||||
) -> Vec<crate::tools::ToolSpec> {
|
||||
match allowed {
|
||||
None => specs,
|
||||
Some(list) => specs
|
||||
.into_iter()
|
||||
.filter(|spec| list.iter().any(|name| name == &spec.name))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the list of MCP tool names that should be excluded for a given turn
|
||||
/// based on `tool_filter_groups` and the user message.
|
||||
///
|
||||
@@ -195,6 +213,18 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;
|
||||
/// Max characters retained in stored compaction summary.
|
||||
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
|
||||
|
||||
/// Estimate token count for a message history using ~4 chars/token heuristic.
|
||||
/// Includes a small overhead per message for role/framing tokens.
|
||||
fn estimate_history_tokens(history: &[ChatMessage]) -> usize {
|
||||
history
|
||||
.iter()
|
||||
.map(|m| {
|
||||
// ~4 chars per token + ~4 framing tokens per message (role, delimiters)
|
||||
m.content.len().div_ceil(4) + 4
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Minimum interval between progress sends to avoid flooding the draft channel.
|
||||
pub(crate) const PROGRESS_MIN_INTERVAL_MS: u64 = 500;
|
||||
|
||||
@@ -288,6 +318,7 @@ async fn auto_compact_history(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
max_history: usize,
|
||||
max_context_tokens: usize,
|
||||
) -> Result<bool> {
|
||||
let has_system = history.first().map_or(false, |m| m.role == "system");
|
||||
let non_system_count = if has_system {
|
||||
@@ -296,7 +327,10 @@ async fn auto_compact_history(
|
||||
history.len()
|
||||
};
|
||||
|
||||
if non_system_count <= max_history {
|
||||
let estimated_tokens = estimate_history_tokens(history);
|
||||
|
||||
// Trigger compaction when either token budget OR message count is exceeded.
|
||||
if estimated_tokens <= max_context_tokens && non_system_count <= max_history {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
@@ -307,7 +341,16 @@ async fn auto_compact_history(
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let compact_end = start + compact_count;
|
||||
let mut compact_end = start + compact_count;
|
||||
|
||||
// Snap compact_end to a user-turn boundary so we don't split mid-conversation.
|
||||
while compact_end > start && history.get(compact_end).map_or(false, |m| m.role != "user") {
|
||||
compact_end -= 1;
|
||||
}
|
||||
if compact_end <= start {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
@@ -2635,6 +2678,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!(
|
||||
"\u{274c} {}: {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(&scrub_credentials(&cancelled), 200)
|
||||
))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
call.name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2662,11 +2714,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
arguments: tool_args.clone(),
|
||||
};
|
||||
|
||||
// Only prompt interactively on CLI; auto-approve on other channels.
|
||||
let decision = if channel_name == "cli" {
|
||||
mgr.prompt_cli(&request)
|
||||
// Interactive CLI: prompt the operator.
|
||||
// Non-interactive (channels): auto-deny since no operator
|
||||
// is present to approve.
|
||||
let decision = if mgr.is_non_interactive() {
|
||||
ApprovalResponse::No
|
||||
} else {
|
||||
ApprovalResponse::Yes
|
||||
mgr.prompt_cli(&request)
|
||||
};
|
||||
|
||||
mgr.record_decision(&tool_name, &tool_args, decision, channel_name);
|
||||
@@ -2687,6 +2741,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, denied))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2723,6 +2782,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"deduplicated": true,
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, duplicate))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2825,13 +2889,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
// ── Progress: tool completion ───────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
let secs = outcome.duration.as_secs();
|
||||
let icon = if outcome.success {
|
||||
"\u{2705}"
|
||||
let progress_msg = if outcome.success {
|
||||
format!("\u{2705} {} ({secs}s)\n", call.name)
|
||||
} else if let Some(ref reason) = outcome.error_reason {
|
||||
format!(
|
||||
"\u{274c} {} ({secs}s): {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(reason, 200)
|
||||
)
|
||||
} else {
|
||||
"\u{274c}"
|
||||
format!("\u{274c} {} ({secs}s)\n", call.name)
|
||||
};
|
||||
tracing::debug!(tool = %call.name, secs, "Sending progress complete to draft");
|
||||
let _ = tx.send(format!("{icon} {} ({secs}s)\n", call.name)).await;
|
||||
let _ = tx.send(progress_msg).await;
|
||||
}
|
||||
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
@@ -2942,6 +3012,7 @@ pub async fn run(
|
||||
peripheral_overrides: Vec<String>,
|
||||
interactive: bool,
|
||||
session_state_file: Option<PathBuf>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
) -> Result<String> {
|
||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||
let base_observer = observability::create_observer(&config.observability);
|
||||
@@ -3003,6 +3074,19 @@ pub async fn run(
|
||||
tools_registry.extend(peripheral_tools);
|
||||
}
|
||||
|
||||
// ── Capability-based tool access control ─────────────────────
|
||||
// When `allowed_tools` is `Some(list)`, restrict the tool registry to only
|
||||
// those tools whose name appears in the list. Unknown names are silently
|
||||
// ignored. When `None`, all tools remain available (backward compatible).
|
||||
if let Some(ref allow_list) = allowed_tools {
|
||||
tools_registry.retain(|t| allow_list.iter().any(|name| name == t.name()));
|
||||
tracing::info!(
|
||||
allowed = allow_list.len(),
|
||||
retained = tools_registry.len(),
|
||||
"Applied capability-based tool access filter"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Wire MCP tools (non-fatal) — CLI path ────────────────────
|
||||
// NOTE: MCP tools are injected after built-in tool filtering
|
||||
// (filter_primary_agent_tools_or_fail / agent.allowed_tools / agent.denied_tools).
|
||||
@@ -3508,6 +3592,7 @@ pub async fn run(
|
||||
provider.as_ref(),
|
||||
model_name,
|
||||
config.agent.max_history_messages,
|
||||
config.agent.max_context_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3821,7 +3906,7 @@ mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials() {
|
||||
fn scrub_credentials_redacts_bearer_token() {
|
||||
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
|
||||
@@ -3832,7 +3917,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials_json() {
|
||||
fn scrub_credentials_redacts_json_api_key() {
|
||||
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
|
||||
@@ -4109,6 +4194,52 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that always returns a failure with a given error reason.
|
||||
struct FailingTool {
|
||||
tool_name: String,
|
||||
error_reason: String,
|
||||
}
|
||||
|
||||
impl FailingTool {
|
||||
fn new(name: &str, error_reason: &str) -> Self {
|
||||
Self {
|
||||
tool_name: name.to_string(),
|
||||
error_reason: error_reason.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FailingTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"A tool that always fails for testing failure surfacing"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(self.error_reason.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
@@ -6449,4 +6580,148 @@ Let me check the result."#;
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "BROWSE the site");
|
||||
assert_eq!(result.len(), 1);
|
||||
}
|
||||
|
||||
// ── Token-based compaction tests ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_empty() {
|
||||
assert_eq!(super::estimate_history_tokens(&[]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_single_message() {
|
||||
let history = vec![ChatMessage::user("hello world")]; // 11 chars
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
// 11.div_ceil(4) + 4 = 3 + 4 = 7
|
||||
assert_eq!(tokens, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_multiple_messages() {
|
||||
let history = vec![
|
||||
ChatMessage::system("You are helpful."), // 16 chars → 4 + 4 = 8
|
||||
ChatMessage::user("What is Rust?"), // 13 chars → 4 + 4 = 8
|
||||
ChatMessage::assistant("A language."), // 11 chars → 3 + 4 = 7
|
||||
];
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
assert_eq!(tokens, 23);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_surfaces_tool_failure_reason_in_on_delta() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"failing_shell","arguments":{"command":"rm -rf /"}}
|
||||
</tool_call>"#,
|
||||
"I could not execute that command.",
|
||||
]);
|
||||
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(FailingTool::new(
|
||||
"failing_shell",
|
||||
"Command not allowed by security policy: rm -rf /",
|
||||
))];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("delete everything"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"telegram",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
Some(tx),
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
// Collect all messages sent to the on_delta channel.
|
||||
let mut deltas = Vec::new();
|
||||
while let Ok(msg) = rx.try_recv() {
|
||||
deltas.push(msg);
|
||||
}
|
||||
|
||||
let all_deltas = deltas.join("");
|
||||
|
||||
// The failure reason should appear in the progress messages.
|
||||
assert!(
|
||||
all_deltas.contains("Command not allowed by security policy"),
|
||||
"on_delta messages should include the tool failure reason, got: {all_deltas}"
|
||||
);
|
||||
|
||||
// Should also contain the cross mark (❌) icon to indicate failure.
|
||||
assert!(
|
||||
all_deltas.contains('\u{274c}'),
|
||||
"on_delta messages should include ❌ for failed tool calls, got: {all_deltas}"
|
||||
);
|
||||
|
||||
assert_eq!(result, "I could not execute that command.");
|
||||
}
|
||||
|
||||
// ── filter_by_allowed_tools tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn filter_by_allowed_tools_none_passes_all() {
|
||||
let specs = vec![
|
||||
make_spec("shell"),
|
||||
make_spec("memory_store"),
|
||||
make_spec("file_read"),
|
||||
];
|
||||
let result = filter_by_allowed_tools(specs, None);
|
||||
assert_eq!(result.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_by_allowed_tools_some_restricts_to_listed() {
|
||||
let specs = vec![
|
||||
make_spec("shell"),
|
||||
make_spec("memory_store"),
|
||||
make_spec("file_read"),
|
||||
];
|
||||
let allowed = vec!["shell".to_string(), "memory_store".to_string()];
|
||||
let result = filter_by_allowed_tools(specs, Some(&allowed));
|
||||
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
|
||||
assert_eq!(names.len(), 2);
|
||||
assert!(names.contains(&"shell"));
|
||||
assert!(names.contains(&"memory_store"));
|
||||
assert!(!names.contains(&"file_read"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_by_allowed_tools_unknown_names_silently_ignored() {
|
||||
let specs = vec![make_spec("shell"), make_spec("file_read")];
|
||||
let allowed = vec![
|
||||
"shell".to_string(),
|
||||
"nonexistent_tool".to_string(),
|
||||
"another_missing".to_string(),
|
||||
];
|
||||
let result = filter_by_allowed_tools(specs, Some(&allowed));
|
||||
let names: Vec<&str> = result.iter().map(|s| s.name.as_str()).collect();
|
||||
assert_eq!(names.len(), 1);
|
||||
assert!(names.contains(&"shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_by_allowed_tools_empty_list_excludes_all() {
|
||||
let specs = vec![make_spec("shell"), make_spec("file_read")];
|
||||
let allowed: Vec<String> = vec![];
|
||||
let result = filter_by_allowed_tools(specs, Some(&allowed));
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
+128
-4
@@ -44,11 +44,18 @@ pub struct ApprovalLogEntry {
|
||||
|
||||
// ── ApprovalManager ──────────────────────────────────────────────
|
||||
|
||||
/// Manages the interactive approval workflow.
|
||||
/// Manages the approval workflow for tool calls.
|
||||
///
|
||||
/// - Checks config-level `auto_approve` / `always_ask` lists
|
||||
/// - Maintains a session-scoped "always" allowlist
|
||||
/// - Records an audit trail of all decisions
|
||||
///
|
||||
/// Two modes:
|
||||
/// - **Interactive** (CLI): tools needing approval trigger a stdin prompt.
|
||||
/// - **Non-interactive** (channels): tools needing approval are auto-denied
|
||||
/// because there is no interactive operator to approve them. `auto_approve`
|
||||
/// policy is still enforced, and `always_ask` / supervised-default tools are
|
||||
/// denied rather than silently allowed.
|
||||
pub struct ApprovalManager {
|
||||
/// Tools that never need approval (from config).
|
||||
auto_approve: HashSet<String>,
|
||||
@@ -56,6 +63,9 @@ pub struct ApprovalManager {
|
||||
always_ask: HashSet<String>,
|
||||
/// Autonomy level from config.
|
||||
autonomy_level: AutonomyLevel,
|
||||
/// When `true`, tools that would require interactive approval are
|
||||
/// auto-denied instead. Used for channel-driven (non-CLI) runs.
|
||||
non_interactive: bool,
|
||||
/// Session-scoped allowlist built from "Always" responses.
|
||||
session_allowlist: Mutex<HashSet<String>>,
|
||||
/// Audit trail of approval decisions.
|
||||
@@ -63,17 +73,40 @@ pub struct ApprovalManager {
|
||||
}
|
||||
|
||||
impl ApprovalManager {
|
||||
/// Create from autonomy config.
|
||||
/// Create an interactive (CLI) approval manager from autonomy config.
|
||||
pub fn from_config(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: false,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-interactive approval manager for channel-driven runs.
|
||||
///
|
||||
/// Enforces the same `auto_approve` / `always_ask` / supervised policies
|
||||
/// as the CLI manager, but tools that would require interactive approval
|
||||
/// are auto-denied instead of prompting (since there is no operator).
|
||||
pub fn for_non_interactive(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: true,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when this manager operates in non-interactive mode
|
||||
/// (i.e. for channel-driven runs where no operator can approve).
|
||||
pub fn is_non_interactive(&self) -> bool {
|
||||
self.non_interactive
|
||||
}
|
||||
|
||||
/// Check whether a tool call requires interactive approval.
|
||||
///
|
||||
/// Returns `true` if the call needs a prompt, `false` if it can proceed.
|
||||
@@ -147,8 +180,8 @@ impl ApprovalManager {
|
||||
|
||||
/// Prompt the user on the CLI and return their decision.
|
||||
///
|
||||
/// For non-CLI channels, returns `Yes` automatically (interactive
|
||||
/// approval is only supported on CLI for now).
|
||||
/// Only called for interactive (CLI) managers. Non-interactive managers
|
||||
/// auto-deny in the tool-call loop before reaching this point.
|
||||
pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse {
|
||||
prompt_cli_interactive(request)
|
||||
}
|
||||
@@ -401,6 +434,97 @@ mod tests {
|
||||
assert!(summary.contains("just a string"));
|
||||
}
|
||||
|
||||
// ── non-interactive (channel) mode ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn non_interactive_manager_reports_non_interactive() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_manager_reports_interactive() {
|
||||
let mgr = ApprovalManager::from_config(&supervised_config());
|
||||
assert!(!mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_auto_approve_tools_skip_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// auto_approve tools (file_read, memory_recall) should not need approval.
|
||||
assert!(!mgr.needs_approval("file_read"));
|
||||
assert!(!mgr.needs_approval("memory_recall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_tools_need_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// always_ask tools (shell) still report as needing approval,
|
||||
// so the tool-call loop will auto-deny them in non-interactive mode.
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_unknown_tools_need_approval_in_supervised() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// Unknown tools in supervised mode need approval (will be auto-denied
|
||||
// by the tool-call loop for non-interactive managers).
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
assert!(mgr.needs_approval("http_request"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_full_autonomy_never_needs_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&full_config());
|
||||
// Full autonomy means no approval needed, even in non-interactive mode.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
assert!(!mgr.needs_approval("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_readonly_never_needs_approval() {
|
||||
let config = AutonomyConfig {
|
||||
level: AutonomyLevel::ReadOnly,
|
||||
..AutonomyConfig::default()
|
||||
};
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
// ReadOnly blocks execution elsewhere; approval manager does not prompt.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_session_allowlist_still_works() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
|
||||
// Simulate an "Always" decision (would come from a prior channel run
|
||||
// if the tool was auto-approved somehow, e.g. via config change).
|
||||
mgr.record_decision(
|
||||
"file_write",
|
||||
&serde_json::json!({"path": "test.txt"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_overrides_session_allowlist() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
|
||||
mgr.record_decision(
|
||||
"shell",
|
||||
&serde_json::json!({"command": "ls"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
// shell is in always_ask, so it still needs approval even after "Always".
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
// ── ApprovalResponse serde ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -746,7 +746,7 @@ impl Channel for MatrixChannel {
|
||||
MessageType::Notice(content) => (content.body.clone(), None),
|
||||
MessageType::Image(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
(format!("[image: {}]", content.body), dl)
|
||||
(format!("[IMAGE:{}]", content.body), dl)
|
||||
}
|
||||
MessageType::File(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
@@ -888,7 +888,7 @@ impl Channel for MatrixChannel {
|
||||
sender: sender.clone(),
|
||||
reply_target: format!("{}||{}", sender, room.room_id()),
|
||||
content: body,
|
||||
channel: format!("matrix:{}", room.room_id()),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
|
||||
+673
-2
@@ -30,7 +30,9 @@ pub mod mattermost;
|
||||
pub mod nextcloud_talk;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub mod nostr;
|
||||
pub mod notion;
|
||||
pub mod qq;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
pub mod telegram;
|
||||
@@ -61,6 +63,7 @@ pub use mattermost::MattermostChannel;
|
||||
pub use nextcloud_talk::NextcloudTalkChannel;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub use nostr::NostrChannel;
|
||||
pub use notion::NotionChannel;
|
||||
pub use qq::QQChannel;
|
||||
pub use signal::SignalChannel;
|
||||
pub use slack::SlackChannel;
|
||||
@@ -75,6 +78,7 @@ pub use whatsapp::WhatsAppChannel;
|
||||
pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
@@ -310,8 +314,15 @@ struct ChannelRuntimeContext {
|
||||
non_cli_excluded_tools: Arc<Vec<String>>,
|
||||
tool_call_dedup_exempt: Arc<Vec<String>>,
|
||||
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
|
||||
query_classification: crate::config::QueryClassificationConfig,
|
||||
ack_reactions: bool,
|
||||
show_tool_calls: bool,
|
||||
session_store: Option<Arc<session_store::SessionStore>>,
|
||||
/// Non-interactive approval manager for channel-driven runs.
|
||||
/// Enforces `auto_approve` / `always_ask` / supervised policy from
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -965,6 +976,13 @@ fn proactive_trim_turns(turns: &mut Vec<ChatMessage>, budget: usize) -> usize {
|
||||
}
|
||||
|
||||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||||
// Persist to JSONL before adding to in-memory history.
|
||||
if let Some(ref store) = ctx.session_store {
|
||||
if let Err(e) = store.append(sender_key, &turn) {
|
||||
tracing::warn!("Failed to persist session turn: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
@@ -1777,7 +1795,31 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
let mut route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
|
||||
// ── Query classification: override route when a rule matches ──
|
||||
if let Some(hint) = crate::agent::classifier::classify(&ctx.query_classification, &msg.content)
|
||||
{
|
||||
if let Some(matched_route) = ctx
|
||||
.model_routes
|
||||
.iter()
|
||||
.find(|r| r.hint.eq_ignore_ascii_case(&hint))
|
||||
{
|
||||
tracing::info!(
|
||||
target: "query_classification",
|
||||
hint = hint.as_str(),
|
||||
provider = matched_route.provider.as_str(),
|
||||
model = matched_route.model.as_str(),
|
||||
channel = %msg.channel,
|
||||
"Channel message classified — overriding route"
|
||||
);
|
||||
route = ChannelRouteSelection {
|
||||
provider: matched_route.provider.clone(),
|
||||
model: matched_route.model.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
@@ -2016,7 +2058,7 @@ async fn process_channel_message(
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
None,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
@@ -2186,6 +2228,29 @@ async fn process_channel_message(
|
||||
&history_key,
|
||||
ChatMessage::assistant(&history_response),
|
||||
);
|
||||
|
||||
// Fire-and-forget LLM-driven memory consolidation.
|
||||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let provider = Arc::clone(&ctx.provider);
|
||||
let model = ctx.model.to_string();
|
||||
let memory = Arc::clone(&ctx.memory);
|
||||
let user_msg = msg.content.clone();
|
||||
let assistant_resp = delivered_response.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::memory::consolidation::consolidate_turn(
|
||||
provider.as_ref(),
|
||||
&model,
|
||||
memory.as_ref(),
|
||||
&user_msg,
|
||||
&assistant_resp,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::debug!("Memory consolidation skipped: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
@@ -2918,6 +2983,12 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con
|
||||
channel.name()
|
||||
);
|
||||
}
|
||||
// Notion is a top-level config section, not part of ChannelsConfig
|
||||
{
|
||||
let notion_configured =
|
||||
config.notion.enabled && !config.notion.database_id.trim().is_empty();
|
||||
println!(" {} Notion", if notion_configured { "✅" } else { "❌" });
|
||||
}
|
||||
if !cfg!(feature = "channel-matrix") {
|
||||
println!(
|
||||
" ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
|
||||
@@ -3203,6 +3274,8 @@ fn collect_configured_channels(
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||||
eprintln!(" ⚠ WhatsApp Web is configured but the 'whatsapp-web' feature is not compiled in.");
|
||||
eprintln!(" Rebuild with: cargo build --features whatsapp-web");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -3348,6 +3421,34 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
// Notion database poller channel
|
||||
if config.notion.enabled && !config.notion.database_id.trim().is_empty() {
|
||||
let notion_api_key = if config.notion.api_key.trim().is_empty() {
|
||||
std::env::var("NOTION_API_KEY").unwrap_or_default()
|
||||
} else {
|
||||
config.notion.api_key.trim().to_string()
|
||||
};
|
||||
if notion_api_key.trim().is_empty() {
|
||||
tracing::warn!(
|
||||
"Notion channel enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
|
||||
);
|
||||
} else {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Notion",
|
||||
channel: Arc::new(NotionChannel::new(
|
||||
notion_api_key,
|
||||
config.notion.database_id.clone(),
|
||||
config.notion.poll_interval_secs,
|
||||
config.notion.status_property.clone(),
|
||||
config.notion.input_property.clone(),
|
||||
config.notion.result_property.clone(),
|
||||
config.notion.max_concurrent,
|
||||
config.notion.recover_stale,
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
channels
|
||||
}
|
||||
|
||||
@@ -3803,10 +3904,46 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()),
|
||||
tool_call_dedup_exempt: Arc::new(config.agent.tool_call_dedup_exempt.clone()),
|
||||
model_routes: Arc::new(config.model_routes.clone()),
|
||||
query_classification: config.query_classification.clone(),
|
||||
ack_reactions: config.channels_config.ack_reactions,
|
||||
show_tool_calls: config.channels_config.show_tool_calls,
|
||||
session_store: if config.channels_config.session_persistence {
|
||||
match session_store::SessionStore::new(&config.workspace_dir) {
|
||||
Ok(store) => {
|
||||
tracing::info!("📂 Session persistence enabled");
|
||||
Some(Arc::new(store))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Session persistence disabled: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
if let Some(ref store) = runtime_ctx.session_store {
|
||||
let mut hydrated = 0usize;
|
||||
let mut histories = runtime_ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
for key in store.list_sessions() {
|
||||
let msgs = store.load(&key);
|
||||
if !msgs.is_empty() {
|
||||
hydrated += 1;
|
||||
histories.insert(key, msgs);
|
||||
}
|
||||
}
|
||||
drop(histories);
|
||||
if hydrated > 0 {
|
||||
tracing::info!("📂 Restored {hydrated} session(s) from disk");
|
||||
}
|
||||
}
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||
|
||||
// Wait for all channel tasks
|
||||
@@ -4070,8 +4207,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4173,8 +4315,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4232,8 +4379,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4749,8 +4901,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4816,8 +4973,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4897,8 +5059,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4963,8 +5130,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5039,8 +5211,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5135,8 +5312,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5213,8 +5395,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5306,8 +5493,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5384,8 +5576,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5452,8 +5649,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5631,8 +5833,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -5718,8 +5925,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5817,11 +6029,16 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
},
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5919,8 +6136,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6000,8 +6222,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6066,8 +6293,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6690,8 +6922,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6782,8 +7019,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6874,8 +7116,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7430,8 +7677,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7503,8 +7755,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7584,6 +7841,420 @@ This is an example JSON object for profile settings."#;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Query classification in channel message processing ─────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_applies_query_classification_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Vision provider should have been called instead of the default.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
vision_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["gpt-4-vision".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_disabled_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification is disabled — matching keyword should NOT trigger reroute.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: false,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-disabled".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since classification is disabled.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_no_match_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification enabled with a rule that won't match the message.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-nomatch".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "just a regular text message".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since no classification rule matched.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_priority_selects_highest() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let fast_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let fast_provider: Arc<dyn Provider> = fast_provider_impl.clone();
|
||||
let code_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let code_provider: Arc<dyn Provider> = code_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("fast-provider".to_string(), fast_provider);
|
||||
provider_cache_seed.insert("code-provider".to_string(), code_provider);
|
||||
|
||||
// Both rules match "code" keyword, but "code" rule has higher priority.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 1,
|
||||
..Default::default()
|
||||
},
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "code".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 10,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let model_routes = vec![
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "fast".into(),
|
||||
provider: "fast-provider".into(),
|
||||
model: "fast-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "code".into(),
|
||||
provider: "code-provider".into(),
|
||||
model: "code-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-prio".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "write some code for me".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Higher-priority "code" rule (priority=10) should win over "fast" (priority=1).
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(fast_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(code_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
code_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["code-model".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_unconfigured_telegram_returns_error() {
|
||||
let config = Config::default();
|
||||
|
||||
@@ -0,0 +1,614 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use anyhow::{bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
|
||||
const NOTION_VERSION: &str = "2022-06-28";
|
||||
const MAX_RESULT_LENGTH: usize = 2000;
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_BASE_DELAY_MS: u64 = 2000;
|
||||
/// Maximum number of characters to include from an error response body.
|
||||
const MAX_ERROR_BODY_CHARS: usize = 500;
|
||||
|
||||
/// Find the largest byte index <= `max_bytes` that falls on a UTF-8 char boundary.
|
||||
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
|
||||
if max_bytes >= s.len() {
|
||||
return s.len();
|
||||
}
|
||||
let mut idx = max_bytes;
|
||||
while idx > 0 && !s.is_char_boundary(idx) {
|
||||
idx -= 1;
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
/// Notion channel — polls a Notion database for pending tasks and writes results back.
|
||||
///
|
||||
/// The channel connects to the Notion API, queries a database for rows with a "pending"
|
||||
/// status, dispatches them as channel messages, and writes results back when processing
|
||||
/// completes. It supports crash recovery by resetting stale "running" tasks on startup.
|
||||
pub struct NotionChannel {
|
||||
api_key: String,
|
||||
database_id: String,
|
||||
poll_interval_secs: u64,
|
||||
status_property: String,
|
||||
input_property: String,
|
||||
result_property: String,
|
||||
max_concurrent: usize,
|
||||
status_type: Arc<RwLock<String>>,
|
||||
inflight: Arc<RwLock<HashSet<String>>>,
|
||||
http: reqwest::Client,
|
||||
recover_stale: bool,
|
||||
}
|
||||
|
||||
impl NotionChannel {
|
||||
/// Create a new Notion channel with the given configuration.
|
||||
pub fn new(
|
||||
api_key: String,
|
||||
database_id: String,
|
||||
poll_interval_secs: u64,
|
||||
status_property: String,
|
||||
input_property: String,
|
||||
result_property: String,
|
||||
max_concurrent: usize,
|
||||
recover_stale: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
database_id,
|
||||
poll_interval_secs,
|
||||
status_property,
|
||||
input_property,
|
||||
result_property,
|
||||
max_concurrent,
|
||||
status_type: Arc::new(RwLock::new("select".to_string())),
|
||||
inflight: Arc::new(RwLock::new(HashSet::new())),
|
||||
http: reqwest::Client::new(),
|
||||
recover_stale,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the standard Notion API headers (Authorization, version, content-type).
|
||||
fn headers(&self) -> Result<reqwest::header::HeaderMap> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", self.api_key)
|
||||
.parse()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
|
||||
);
|
||||
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
|
||||
headers.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx).
|
||||
async fn api_call(
|
||||
&self,
|
||||
method: reqwest::Method,
|
||||
url: &str,
|
||||
body: Option<serde_json::Value>,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut last_err = None;
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
let mut req = self
|
||||
.http
|
||||
.request(method.clone(), url)
|
||||
.headers(self.headers()?);
|
||||
if let Some(ref b) = body {
|
||||
req = req.json(b);
|
||||
}
|
||||
match req.send().await {
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
if status.is_success() {
|
||||
return resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
|
||||
}
|
||||
let status_code = status.as_u16();
|
||||
// Only retry on 429 (rate limit) or 5xx (server errors)
|
||||
if status_code != 429 && (400..500).contains(&status_code) {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
let truncated =
|
||||
crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
|
||||
bail!("Notion API error {status_code}: {truncated}");
|
||||
}
|
||||
last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
|
||||
}
|
||||
}
|
||||
let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
|
||||
tracing::warn!(
|
||||
"Notion API call failed (attempt {}/{}), retrying in {}ms",
|
||||
attempt + 1,
|
||||
MAX_RETRIES,
|
||||
delay
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
|
||||
}
|
||||
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
|
||||
}
|
||||
|
||||
/// Query the database schema and detect whether Status uses "select" or "status" type.
|
||||
async fn detect_status_type(&self) -> Result<String> {
|
||||
let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
|
||||
let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
|
||||
let status_type = resp
|
||||
.get("properties")
|
||||
.and_then(|p| p.get(&self.status_property))
|
||||
.and_then(|s| s.get("type"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("select")
|
||||
.to_string();
|
||||
Ok(status_type)
|
||||
}
|
||||
|
||||
/// Query for rows where Status = "pending".
|
||||
async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
|
||||
let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
|
||||
let status_type = self.status_type.read().await.clone();
|
||||
let filter = build_status_filter(&self.status_property, &status_type, "pending");
|
||||
let resp = self
|
||||
.api_call(
|
||||
reqwest::Method::POST,
|
||||
&url,
|
||||
Some(serde_json::json!({ "filter": filter })),
|
||||
)
|
||||
.await?;
|
||||
Ok(resp
|
||||
.get("results")
|
||||
.and_then(|r| r.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Atomically claim a task. Returns true if this caller got it.
|
||||
async fn claim_task(&self, page_id: &str) -> bool {
|
||||
let mut inflight = self.inflight.write().await;
|
||||
if inflight.contains(page_id) {
|
||||
return false;
|
||||
}
|
||||
if inflight.len() >= self.max_concurrent {
|
||||
return false;
|
||||
}
|
||||
inflight.insert(page_id.to_string());
|
||||
true
|
||||
}
|
||||
|
||||
/// Release a task from the inflight set.
|
||||
async fn release_task(&self, page_id: &str) {
|
||||
let mut inflight = self.inflight.write().await;
|
||||
inflight.remove(page_id);
|
||||
}
|
||||
|
||||
/// Update a row's status.
|
||||
async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let status_type = self.status_type.read().await.clone();
|
||||
let payload = serde_json::json!({
|
||||
"properties": {
|
||||
&self.status_property: build_status_payload(&status_type, status_value),
|
||||
}
|
||||
});
|
||||
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write result text to the Result column.
|
||||
async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let payload = serde_json::json!({
|
||||
"properties": {
|
||||
&self.result_property: build_rich_text_payload(result_text),
|
||||
}
|
||||
});
|
||||
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// On startup, reset "running" tasks back to "pending" for crash recovery.
    ///
    /// Queries the database for rows still marked "running" (left over from a
    /// previous process that died mid-task) and patches each one back to
    /// "pending" with an explanatory note in the result column. Individual
    /// patch failures are logged but do not abort the sweep.
    async fn recover_stale(&self) -> Result<()> {
        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
        let status_type = self.status_type.read().await.clone();
        let filter = build_status_filter(&self.status_property, &status_type, "running");
        let resp = self
            .api_call(
                reqwest::Method::POST,
                &url,
                Some(serde_json::json!({ "filter": filter })),
            )
            .await?;
        // Missing/absent "results" is treated as zero stale tasks.
        let stale = resp
            .get("results")
            .and_then(|r| r.as_array())
            .cloned()
            .unwrap_or_default();
        if stale.is_empty() {
            return Ok(());
        }
        tracing::warn!(
            "Found {} stale task(s) in 'running' state, resetting to 'pending'",
            stale.len()
        );
        for task in &stale {
            // Rows without a usable string "id" are silently skipped.
            if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
                let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
                let payload = serde_json::json!({
                    "properties": {
                        &self.status_property: build_status_payload(&status_type, "pending"),
                        &self.result_property: build_rich_text_payload(
                            "Reset: poller restarted while task was running"
                        ),
                    }
                });
                // Log only a short prefix of the page id; clamp to a UTF-8
                // char boundary so the slice cannot panic.
                let short_id_end = floor_utf8_char_boundary(page_id, 8);
                let short_id = &page_id[..short_id_end];
                if let Err(e) = self
                    .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
                    .await
                {
                    tracing::error!("Could not reset stale task {short_id}: {e}");
                } else {
                    tracing::info!("Reset stale task {short_id} to pending");
                }
            }
        }
        Ok(())
    }
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for NotionChannel {
|
||||
    /// Channel identifier used for routing and logging.
    fn name(&self) -> &str {
        "notion"
    }
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> Result<()> {
|
||||
// recipient is the page_id for Notion
|
||||
let page_id = &message.recipient;
|
||||
let status_type = self.status_type.read().await.clone();
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let payload = serde_json::json!({
|
||||
"properties": {
|
||||
&self.status_property: build_status_payload(&status_type, "done"),
|
||||
&self.result_property: build_rich_text_payload(&message.content),
|
||||
}
|
||||
});
|
||||
self.api_call(reqwest::Method::PATCH, &url, Some(payload))
|
||||
.await?;
|
||||
self.release_task(page_id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Main polling loop: detect the status property type, recover stale
    /// tasks, then repeatedly query for "pending" rows, claim each one,
    /// flip it to "running", and dispatch it over `tx`.
    ///
    /// Returns `Ok(())` when the receiver side of `tx` is dropped (shutdown);
    /// errors out only if the initial schema detection fails.
    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
        // Detect status property type; without the schema we cannot build
        // correct filters, so failure here is fatal.
        match self.detect_status_type().await {
            Ok(st) => {
                tracing::info!("Notion status property type: {st}");
                *self.status_type.write().await = st;
            }
            Err(e) => {
                bail!("Failed to detect Notion database schema: {e}");
            }
        }

        // Crash recovery: best-effort, a failure only logs.
        if self.recover_stale {
            if let Err(e) = self.recover_stale().await {
                tracing::error!("Notion stale task recovery failed: {e}");
            }
        }

        // Polling loop
        loop {
            match self.query_pending().await {
                Ok(tasks) => {
                    if !tasks.is_empty() {
                        tracing::info!("Notion: found {} pending task(s)", tasks.len());
                    }
                    for task in tasks {
                        // Rows without a string "id" cannot be addressed; skip.
                        let page_id = match task.get("id").and_then(|v| v.as_str()) {
                            Some(id) => id.to_string(),
                            None => continue,
                        };

                        let input_text = extract_text_from_property(
                            task.get("properties")
                                .and_then(|p| p.get(&self.input_property)),
                        );

                        // Nothing to do for rows with blank input.
                        if input_text.trim().is_empty() {
                            let short_end = floor_utf8_char_boundary(&page_id, 8);
                            tracing::warn!(
                                "Notion: empty input for task {}, skipping",
                                &page_id[..short_end]
                            );
                            continue;
                        }

                        // Already in flight, or at the concurrency cap.
                        if !self.claim_task(&page_id).await {
                            continue;
                        }

                        // Set status to running; on failure, release the claim
                        // so the row can be retried on a later poll.
                        if let Err(e) = self.set_status(&page_id, "running").await {
                            tracing::error!("Notion: failed to set running status: {e}");
                            self.release_task(&page_id).await;
                            continue;
                        }

                        let timestamp = std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs();

                        // A send error means the receiver was dropped — the
                        // daemon is shutting down, so exit cleanly.
                        if tx
                            .send(ChannelMessage {
                                id: page_id.clone(),
                                sender: "notion".into(),
                                reply_target: page_id,
                                content: input_text,
                                channel: "notion".into(),
                                timestamp,
                                thread_ts: None,
                            })
                            .await
                            .is_err()
                        {
                            tracing::info!("Notion channel shutting down");
                            return Ok(());
                        }
                    }
                }
                Err(e) => {
                    // Transient poll errors are logged; the loop keeps going.
                    tracing::error!("Notion poll error: {e}");
                }
            }

            tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
        }
    }
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
|
||||
self.api_call(reqwest::Method::GET, &url, None)
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helper functions ──────────────────────────────────────────────
|
||||
|
||||
/// Build a Notion API filter object for the given status property.
|
||||
fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
|
||||
if status_type == "status" {
|
||||
serde_json::json!({
|
||||
"property": property,
|
||||
"status": { "equals": value }
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"property": property,
|
||||
"select": { "equals": value }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Notion API property-update payload for a status field.
|
||||
fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
|
||||
if status_type == "status" {
|
||||
serde_json::json!({ "status": { "name": value } })
|
||||
} else {
|
||||
serde_json::json!({ "select": { "name": value } })
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Notion API rich-text property payload, truncating if necessary.
|
||||
fn build_rich_text_payload(value: &str) -> serde_json::Value {
|
||||
let truncated = truncate_result(value);
|
||||
serde_json::json!({
|
||||
"rich_text": [{
|
||||
"text": { "content": truncated }
|
||||
}]
|
||||
})
|
||||
}
|
||||
|
||||
/// Truncate result text to fit within the Notion rich-text content limit.
|
||||
fn truncate_result(value: &str) -> String {
|
||||
if value.len() <= MAX_RESULT_LENGTH {
|
||||
return value.to_string();
|
||||
}
|
||||
let cut = MAX_RESULT_LENGTH.saturating_sub(30);
|
||||
// Ensure we cut on a char boundary
|
||||
let end = floor_utf8_char_boundary(value, cut);
|
||||
format!("{}\n\n... [output truncated]", &value[..end])
|
||||
}
|
||||
|
||||
/// Extract plain text from a Notion property (title or rich_text type).
|
||||
fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
|
||||
let Some(prop) = prop else {
|
||||
return String::new();
|
||||
};
|
||||
let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
let array_key = match ptype {
|
||||
"title" => "title",
|
||||
"rich_text" => "rich_text",
|
||||
_ => return String::new(),
|
||||
};
|
||||
prop.get(array_key)
|
||||
.and_then(|arr| arr.as_array())
|
||||
.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Claiming the same page twice must fail until it is released.
    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            4,
            false,
        );

        assert!(channel.claim_task("page-1").await);
        // Second claim for same page should fail
        assert!(!channel.claim_task("page-1").await);
        // Different page should succeed
        assert!(channel.claim_task("page-2").await);

        // After release, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }

    // Text within the limit passes through unchanged.
    #[test]
    fn result_truncation_within_limit() {
        let short = "hello world";
        assert_eq!(truncate_result(short), short);
    }

    // Over-limit text is cut and carries the truncation marker.
    #[test]
    fn result_truncation_over_limit() {
        let long = "a".repeat(MAX_RESULT_LENGTH + 100);
        let truncated = truncate_result(&long);
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }

    // Truncation must never split a multibyte char (would panic on slice).
    #[test]
    fn result_truncation_multibyte_safe() {
        // Build a string that would cut in the middle of a multibyte char
        let mut s = String::new();
        for _ in 0..700 {
            s.push('\u{6E2C}'); // 3-byte UTF-8 char
        }
        let truncated = truncate_result(&s);
        // Should not panic and should be valid UTF-8
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }

    #[test]
    fn status_payload_select_type() {
        let payload = build_status_payload("select", "pending");
        assert_eq!(
            payload,
            serde_json::json!({ "select": { "name": "pending" } })
        );
    }

    #[test]
    fn status_payload_status_type() {
        let payload = build_status_payload("status", "done");
        assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } }));
    }

    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert_eq!(text, "test output");
    }

    #[test]
    fn status_filter_select_type() {
        let filter = build_status_filter("Status", "select", "pending");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "select": { "equals": "pending" }
            })
        );
    }

    #[test]
    fn status_filter_status_type() {
        let filter = build_status_filter("Status", "status", "running");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "status": { "equals": "running" }
            })
        );
    }

    // Title fragments are concatenated in order.
    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }

    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }

    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }

    // Non-text property types yield an empty string, not an error.
    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    // The inflight set caps concurrent claims at max_concurrent.
    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            2, // max_concurrent = 2
            false,
        );

        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        // Third claim should be rejected (at capacity)
        assert!(!channel.claim_task("page-3").await);

        // After releasing one, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}
|
||||
@@ -0,0 +1,200 @@
|
||||
//! JSONL-based session persistence for channel conversations.
|
||||
//!
|
||||
//! Each session (keyed by `channel_sender` or `channel_thread_sender`) is stored
|
||||
//! as an append-only JSONL file in `{workspace}/sessions/`. Messages are appended
|
||||
//! one-per-line as JSON, never modifying old lines. On daemon restart, sessions
|
||||
//! are loaded from disk to restore conversation context.
|
||||
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Append-only JSONL session store for channel conversations.
pub struct SessionStore {
    // Directory holding one `<sanitized-key>.jsonl` file per session.
    sessions_dir: PathBuf,
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Create a new session store, ensuring the sessions directory exists.
|
||||
pub fn new(workspace_dir: &Path) -> std::io::Result<Self> {
|
||||
let sessions_dir = workspace_dir.join("sessions");
|
||||
std::fs::create_dir_all(&sessions_dir)?;
|
||||
Ok(Self { sessions_dir })
|
||||
}
|
||||
|
||||
/// Compute the file path for a session key, sanitizing for filesystem safety.
|
||||
fn session_path(&self, session_key: &str) -> PathBuf {
|
||||
let safe_key: String = session_key
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '_' || c == '-' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.sessions_dir.join(format!("{safe_key}.jsonl"))
|
||||
}
|
||||
|
||||
/// Load all messages for a session from its JSONL file.
|
||||
/// Returns an empty vec if the file does not exist or is unreadable.
|
||||
pub fn load(&self, session_key: &str) -> Vec<ChatMessage> {
|
||||
let path = self.session_path(session_key);
|
||||
let file = match std::fs::File::open(&path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
|
||||
messages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Append a single message to the session JSONL file.
|
||||
pub fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
|
||||
let path = self.session_path(session_key);
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)?;
|
||||
|
||||
let json = serde_json::to_string(message)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
writeln!(file, "{json}")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session keys that have files on disk.
|
||||
pub fn list_sessions(&self) -> Vec<String> {
|
||||
let entries = match std::fs::read_dir(&self.sessions_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
entries
|
||||
.filter_map(|entry| {
|
||||
let entry = entry.ok()?;
|
||||
let name = entry.file_name().into_string().ok()?;
|
||||
name.strip_suffix(".jsonl").map(String::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Messages written via append() come back from load() in order.
    #[test]
    fn round_trip_append_and_load() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();

        store
            .append("telegram_user123", &ChatMessage::user("hello"))
            .unwrap();
        store
            .append("telegram_user123", &ChatMessage::assistant("hi there"))
            .unwrap();

        let messages = store.load("telegram_user123");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
        assert_eq!(messages[1].role, "assistant");
        assert_eq!(messages[1].content, "hi there");
    }

    #[test]
    fn load_nonexistent_session_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();

        let messages = store.load("nonexistent");
        assert!(messages.is_empty());
    }

    // A key with path separators must map to the same sanitized file for
    // both append and load.
    #[test]
    fn key_sanitization() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();

        // Keys with special chars should be sanitized
        store
            .append("slack/thread:123/user", &ChatMessage::user("test"))
            .unwrap();

        let messages = store.load("slack/thread:123/user");
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn list_sessions_returns_keys() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();

        store
            .append("telegram_alice", &ChatMessage::user("hi"))
            .unwrap();
        store
            .append("discord_bob", &ChatMessage::user("hey"))
            .unwrap();

        let mut sessions = store.list_sessions();
        sessions.sort();
        assert_eq!(sessions.len(), 2);
        assert!(sessions.contains(&"discord_bob".to_string()));
        assert!(sessions.contains(&"telegram_alice".to_string()));
    }

    // The on-disk format is one JSON object per line, nothing rewritten.
    #[test]
    fn append_is_truly_append_only() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();
        let key = "test_session";

        store.append(key, &ChatMessage::user("msg1")).unwrap();
        store.append(key, &ChatMessage::user("msg2")).unwrap();

        // Read raw file to verify append-only format
        let path = store.session_path(key);
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.trim().lines().collect();
        assert_eq!(lines.len(), 2);
    }

    // A corrupt line in the middle must not lose the records around it.
    #[test]
    fn handles_corrupt_lines_gracefully() {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();
        let key = "corrupt_test";

        // Write valid message + corrupt line + valid message
        let path = store.session_path(key);
        std::fs::create_dir_all(path.parent().unwrap()).unwrap();
        let mut file = std::fs::File::create(&path).unwrap();
        writeln!(file, r#"{{"role":"user","content":"hello"}}"#).unwrap();
        writeln!(file, "this is not valid json").unwrap();
        writeln!(file, r#"{{"role":"assistant","content":"world"}}"#).unwrap();

        let messages = store.load(key);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "hello");
        assert_eq!(messages[1].content, "world");
    }
}
|
||||
+17
-13
@@ -1,24 +1,28 @@
|
||||
pub mod schema;
|
||||
pub mod traits;
|
||||
pub mod workspace;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig,
|
||||
BuiltinHooksConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config, CostConfig,
|
||||
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BackupConfig, BrowserComputerUseConfig,
|
||||
BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, CloudOpsConfig,
|
||||
ComposioConfig, Config, ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig,
|
||||
PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
|
||||
TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, WorkspaceConfig,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+1410
-6
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,382 @@
|
||||
//! Workspace profile management for multi-client isolation.
|
||||
//!
|
||||
//! Each workspace represents an isolated client engagement with its own
|
||||
//! memory namespace, audit trail, secrets scope, and tool restrictions.
|
||||
//! Profiles are stored under `~/.zeroclaw/workspaces/<client_name>/`.
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// A single client workspace profile.
///
/// Deserialized from `profile.toml`; every field except `name` is optional
/// thanks to `#[serde(default)]`, so a minimal profile needs only a name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkspaceProfile {
    /// Human-readable workspace name (also used as directory name).
    pub name: String,
    /// Allowed domains for network access within this workspace.
    /// Empty means unrestricted (see `is_domain_allowed`).
    #[serde(default)]
    pub allowed_domains: Vec<String>,
    /// Credential profile name scoped to this workspace.
    #[serde(default)]
    pub credential_profile: Option<String>,
    /// Memory namespace prefix for isolation; `None` falls back to `name`.
    #[serde(default)]
    pub memory_namespace: Option<String>,
    /// Audit namespace prefix for isolation; `None` falls back to `name`.
    #[serde(default)]
    pub audit_namespace: Option<String>,
    /// Tool names denied in this workspace (e.g. `["shell"]` to block shell access).
    #[serde(default)]
    pub tool_restrictions: Vec<String>,
}
|
||||
|
||||
impl WorkspaceProfile {
|
||||
/// Effective memory namespace (falls back to workspace name).
|
||||
pub fn effective_memory_namespace(&self) -> &str {
|
||||
self.memory_namespace
|
||||
.as_deref()
|
||||
.unwrap_or(self.name.as_str())
|
||||
}
|
||||
|
||||
/// Effective audit namespace (falls back to workspace name).
|
||||
pub fn effective_audit_namespace(&self) -> &str {
|
||||
self.audit_namespace
|
||||
.as_deref()
|
||||
.unwrap_or(self.name.as_str())
|
||||
}
|
||||
|
||||
/// Returns true if the given tool name is restricted in this workspace.
|
||||
pub fn is_tool_restricted(&self, tool_name: &str) -> bool {
|
||||
self.tool_restrictions
|
||||
.iter()
|
||||
.any(|r| r.eq_ignore_ascii_case(tool_name))
|
||||
}
|
||||
|
||||
/// Returns true if the given domain is allowed for this workspace.
|
||||
/// An empty allowlist means all domains are allowed.
|
||||
pub fn is_domain_allowed(&self, domain: &str) -> bool {
|
||||
if self.allowed_domains.is_empty() {
|
||||
return true;
|
||||
}
|
||||
let domain_lower = domain.to_ascii_lowercase();
|
||||
self.allowed_domains
|
||||
.iter()
|
||||
.any(|d| domain_lower == d.to_ascii_lowercase())
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages loading and switching between client workspace profiles.
///
/// Profiles are discovered by `load_profiles` from subdirectories of
/// `workspaces_dir`; at most one workspace is "active" at a time.
#[derive(Debug, Clone)]
pub struct WorkspaceManager {
    /// Base directory containing all workspace subdirectories.
    workspaces_dir: PathBuf,
    /// Loaded workspace profiles keyed by name.
    profiles: HashMap<String, WorkspaceProfile>,
    /// Currently active workspace name; `None` until `switch` succeeds.
    active: Option<String>,
}
|
||||
|
||||
impl WorkspaceManager {
|
||||
/// Create a new workspace manager rooted at the given directory.
|
||||
pub fn new(workspaces_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
workspaces_dir,
|
||||
profiles: HashMap::new(),
|
||||
active: None,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Load all workspace profiles from disk.
    ///
    /// Each subdirectory of `workspaces_dir` that contains a `profile.toml`
    /// is treated as a workspace. Malformed or unreadable profile files are
    /// logged and skipped; only a missing/unreadable workspaces directory
    /// listing is treated as an error. Previously loaded profiles are
    /// discarded first, so this is safe to call repeatedly.
    pub async fn load_profiles(&mut self) -> Result<()> {
        self.profiles.clear();

        let dir = &self.workspaces_dir;
        // A missing directory simply means "no workspaces yet".
        if !dir.exists() {
            return Ok(());
        }

        let mut entries = tokio::fs::read_dir(dir)
            .await
            .with_context(|| format!("reading workspaces directory: {}", dir.display()))?;

        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if !path.is_dir() {
                continue;
            }
            let profile_path = path.join("profile.toml");
            if !profile_path.exists() {
                continue;
            }
            match tokio::fs::read_to_string(&profile_path).await {
                Ok(contents) => match toml::from_str::<WorkspaceProfile>(&contents) {
                    Ok(profile) => {
                        // Keyed by the profile's own `name` field, not the
                        // directory name.
                        self.profiles.insert(profile.name.clone(), profile);
                    }
                    Err(e) => {
                        tracing::warn!(
                            "skipping malformed workspace profile {}: {e}",
                            profile_path.display()
                        );
                    }
                },
                Err(e) => {
                    tracing::warn!(
                        "skipping unreadable workspace profile {}: {e}",
                        profile_path.display()
                    );
                }
            }
        }

        Ok(())
    }
|
||||
|
||||
/// Switch to the named workspace. Returns an error if it does not exist.
|
||||
pub fn switch(&mut self, name: &str) -> Result<&WorkspaceProfile> {
|
||||
if !self.profiles.contains_key(name) {
|
||||
bail!("workspace '{}' not found", name);
|
||||
}
|
||||
self.active = Some(name.to_string());
|
||||
Ok(&self.profiles[name])
|
||||
}
|
||||
|
||||
/// Get the currently active workspace profile, if any.
|
||||
pub fn active_profile(&self) -> Option<&WorkspaceProfile> {
|
||||
self.active
|
||||
.as_deref()
|
||||
.and_then(|name| self.profiles.get(name))
|
||||
}
|
||||
|
||||
/// Get the active workspace name.
|
||||
pub fn active_name(&self) -> Option<&str> {
|
||||
self.active.as_deref()
|
||||
}
|
||||
|
||||
/// List all loaded workspace names.
|
||||
pub fn list(&self) -> Vec<&str> {
|
||||
let mut names: Vec<&str> = self.profiles.keys().map(String::as_str).collect();
|
||||
names.sort_unstable();
|
||||
names
|
||||
}
|
||||
|
||||
/// Get a workspace profile by name.
|
||||
pub fn get(&self, name: &str) -> Option<&WorkspaceProfile> {
|
||||
self.profiles.get(name)
|
||||
}
|
||||
|
||||
/// Create a new workspace on disk and register it.
|
||||
pub async fn create(&mut self, name: &str) -> Result<&WorkspaceProfile> {
|
||||
if name.is_empty() {
|
||||
bail!("workspace name must not be empty");
|
||||
}
|
||||
// Validate name: alphanumeric, hyphens, underscores only
|
||||
if !name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
|
||||
{
|
||||
bail!(
|
||||
"workspace name must contain only alphanumeric characters, hyphens, or underscores"
|
||||
);
|
||||
}
|
||||
if self.profiles.contains_key(name) {
|
||||
bail!("workspace '{}' already exists", name);
|
||||
}
|
||||
|
||||
let ws_dir = self.workspaces_dir.join(name);
|
||||
tokio::fs::create_dir_all(&ws_dir)
|
||||
.await
|
||||
.with_context(|| format!("creating workspace directory: {}", ws_dir.display()))?;
|
||||
|
||||
let profile = WorkspaceProfile {
|
||||
name: name.to_string(),
|
||||
allowed_domains: Vec::new(),
|
||||
credential_profile: None,
|
||||
memory_namespace: Some(name.to_string()),
|
||||
audit_namespace: Some(name.to_string()),
|
||||
tool_restrictions: Vec::new(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&profile).context("serializing workspace profile")?;
|
||||
let profile_path = ws_dir.join("profile.toml");
|
||||
tokio::fs::write(&profile_path, toml_str)
|
||||
.await
|
||||
.with_context(|| format!("writing workspace profile: {}", profile_path.display()))?;
|
||||
|
||||
self.profiles.insert(name.to_string(), profile);
|
||||
Ok(&self.profiles[name])
|
||||
}
|
||||
|
||||
/// Export a workspace profile as a sanitized TOML string (no secrets).
|
||||
pub fn export(&self, name: &str) -> Result<String> {
|
||||
let profile = self
|
||||
.profiles
|
||||
.get(name)
|
||||
.with_context(|| format!("workspace '{}' not found", name))?;
|
||||
|
||||
// Create an export-safe copy with credential_profile redacted
|
||||
let export = WorkspaceProfile {
|
||||
credential_profile: profile
|
||||
.credential_profile
|
||||
.as_ref()
|
||||
.map(|_| "***".to_string()),
|
||||
..profile.clone()
|
||||
};
|
||||
|
||||
toml::to_string_pretty(&export).context("serializing workspace profile for export")
|
||||
}
|
||||
|
||||
/// Directory for a specific workspace.
|
||||
pub fn workspace_dir(&self, name: &str) -> PathBuf {
|
||||
self.workspaces_dir.join(name)
|
||||
}
|
||||
|
||||
/// Base workspaces directory.
|
||||
pub fn workspaces_dir(&self) -> &Path {
|
||||
&self.workspaces_dir
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_profile(name: &str) -> WorkspaceProfile {
|
||||
WorkspaceProfile {
|
||||
name: name.to_string(),
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
credential_profile: Some("test-creds".to_string()),
|
||||
memory_namespace: Some(format!("{name}_mem")),
|
||||
audit_namespace: Some(format!("{name}_audit")),
|
||||
tool_restrictions: vec!["shell".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_profile_tool_restriction_check() {
|
||||
let profile = sample_profile("client_a");
|
||||
assert!(profile.is_tool_restricted("shell"));
|
||||
assert!(profile.is_tool_restricted("Shell"));
|
||||
assert!(!profile.is_tool_restricted("file_read"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_profile_domain_allowlist_empty_allows_all() {
|
||||
let mut profile = sample_profile("client_a");
|
||||
profile.allowed_domains.clear();
|
||||
assert!(profile.is_domain_allowed("anything.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_profile_domain_allowlist_enforced() {
|
||||
let profile = sample_profile("client_a");
|
||||
assert!(profile.is_domain_allowed("example.com"));
|
||||
assert!(!profile.is_domain_allowed("other.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_profile_effective_namespaces() {
|
||||
let profile = sample_profile("client_a");
|
||||
assert_eq!(profile.effective_memory_namespace(), "client_a_mem");
|
||||
assert_eq!(profile.effective_audit_namespace(), "client_a_audit");
|
||||
|
||||
let fallback = WorkspaceProfile {
|
||||
name: "test_ws".to_string(),
|
||||
memory_namespace: None,
|
||||
audit_namespace: None,
|
||||
..sample_profile("test_ws")
|
||||
};
|
||||
assert_eq!(fallback.effective_memory_namespace(), "test_ws");
|
||||
assert_eq!(fallback.effective_audit_namespace(), "test_ws");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_create_and_list() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
|
||||
mgr.create("client_alpha").await.unwrap();
|
||||
mgr.create("client_beta").await.unwrap();
|
||||
|
||||
let names = mgr.list();
|
||||
assert_eq!(names, vec!["client_alpha", "client_beta"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_create_rejects_duplicate() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
|
||||
mgr.create("client_a").await.unwrap();
|
||||
let result = mgr.create("client_a").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_create_rejects_invalid_name() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
|
||||
assert!(mgr.create("").await.is_err());
|
||||
assert!(mgr.create("bad name").await.is_err());
|
||||
assert!(mgr.create("../escape").await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_switch_and_active() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
|
||||
mgr.create("ws_one").await.unwrap();
|
||||
assert!(mgr.active_profile().is_none());
|
||||
|
||||
mgr.switch("ws_one").unwrap();
|
||||
assert_eq!(mgr.active_name(), Some("ws_one"));
|
||||
assert!(mgr.active_profile().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_manager_switch_nonexistent_fails() {
|
||||
let mgr = WorkspaceManager::new(PathBuf::from("/tmp/nonexistent"));
|
||||
let mut mgr = mgr;
|
||||
assert!(mgr.switch("no_such_ws").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_load_profiles_from_disk() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a workspace via the manager
|
||||
mgr.create("loaded_ws").await.unwrap();
|
||||
|
||||
// Create a fresh manager and load from disk
|
||||
let mut mgr2 = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
mgr2.load_profiles().await.unwrap();
|
||||
|
||||
assert_eq!(mgr2.list(), vec!["loaded_ws"]);
|
||||
let profile = mgr2.get("loaded_ws").unwrap();
|
||||
assert_eq!(profile.name, "loaded_ws");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn workspace_manager_export_redacts_credentials() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut mgr = WorkspaceManager::new(tmp.path().to_path_buf());
|
||||
mgr.create("export_test").await.unwrap();
|
||||
|
||||
// Manually set a credential profile
|
||||
if let Some(profile) = mgr.profiles.get_mut("export_test") {
|
||||
profile.credential_profile = Some("secret-cred-id".to_string());
|
||||
}
|
||||
|
||||
let exported = mgr.export("export_test").unwrap();
|
||||
assert!(exported.contains("***"));
|
||||
assert!(!exported.contains("secret-cred-id"));
|
||||
}
|
||||
}
|
||||
+172
-21
@@ -152,44 +152,122 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<(
|
||||
crate::CronCommands::Add {
|
||||
expression,
|
||||
tz,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Cron {
|
||||
expr: expression,
|
||||
tz,
|
||||
};
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added agent cron job {}", job.id);
|
||||
println!(" Expr : {}", job.expression);
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddAt { at, command } => {
|
||||
crate::CronCommands::AddAt { at, agent, command } => {
|
||||
let at = chrono::DateTime::parse_from_rfc3339(&at)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid RFC3339 timestamp for --at: {e}"))?
|
||||
.with_timezone(&chrono::Utc);
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddEvery { every_ms, command } => {
|
||||
crate::CronCommands::AddEvery {
|
||||
every_ms,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Every { every_ms };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added interval agent cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt : {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Once { delay, command } => {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
crate::CronCommands::Once {
|
||||
delay,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
if agent {
|
||||
let duration = parse_delay(&delay)?;
|
||||
let at = chrono::Utc::now() + duration;
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Update {
|
||||
@@ -686,4 +764,77 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("blocked by security policy"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_creates_agent_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
assert_eq!(
|
||||
jobs[0].prompt.as_deref(),
|
||||
Some("Check server health: disk space, memory, CPU load")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_bypasses_shell_security_validation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut config = test_config(&tmp);
|
||||
config.autonomy.allowed_commands = vec!["echo".into()];
|
||||
config.autonomy.level = crate::security::AutonomyLevel::Supervised;
|
||||
|
||||
// Without --agent, a natural language string would be blocked by shell
|
||||
// security policy. With --agent, it routes to agent job and skips
|
||||
// shell validation entirely.
|
||||
let result = handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_without_agent_flag_defaults_to_shell_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/5 * * * *".into(),
|
||||
tz: None,
|
||||
agent: false,
|
||||
command: "echo ok".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Shell);
|
||||
assert_eq!(jobs[0].command, "echo ok");
|
||||
}
|
||||
}
|
||||
|
||||
+36
-23
@@ -53,7 +53,7 @@ pub async fn run(config: Config) -> Result<()> {
|
||||
|
||||
pub async fn execute_job_now(config: &Config, job: &CronJob) -> (bool, String) {
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
execute_job_with_retry(config, &security, job).await
|
||||
Box::pin(execute_job_with_retry(config, &security, job)).await
|
||||
}
|
||||
|
||||
async fn execute_job_with_retry(
|
||||
@@ -68,7 +68,7 @@ async fn execute_job_with_retry(
|
||||
for attempt in 0..=retries {
|
||||
let (success, output) = match job.job_type {
|
||||
JobType::Shell => run_job_command(config, security, job).await,
|
||||
JobType::Agent => run_agent_job(config, security, job).await,
|
||||
JobType::Agent => Box::pin(run_agent_job(config, security, job)).await,
|
||||
};
|
||||
last_output = output;
|
||||
|
||||
@@ -101,18 +101,21 @@ async fn process_due_jobs(
|
||||
crate::health::mark_component_ok(component);
|
||||
|
||||
let max_concurrent = config.scheduler.max_concurrent.max(1);
|
||||
let mut in_flight =
|
||||
stream::iter(
|
||||
jobs.into_iter().map(|job| {
|
||||
let config = config.clone();
|
||||
let security = Arc::clone(security);
|
||||
let component = component.to_owned();
|
||||
async move {
|
||||
execute_and_persist_job(&config, security.as_ref(), &job, &component).await
|
||||
}
|
||||
}),
|
||||
)
|
||||
.buffer_unordered(max_concurrent);
|
||||
let mut in_flight = stream::iter(jobs.into_iter().map(|job| {
|
||||
let config = config.clone();
|
||||
let security = Arc::clone(security);
|
||||
let component = component.to_owned();
|
||||
async move {
|
||||
Box::pin(execute_and_persist_job(
|
||||
&config,
|
||||
security.as_ref(),
|
||||
&job,
|
||||
&component,
|
||||
))
|
||||
.await
|
||||
}
|
||||
}))
|
||||
.buffer_unordered(max_concurrent);
|
||||
|
||||
while let Some((job_id, success, output)) = in_flight.next().await {
|
||||
if !success {
|
||||
@@ -131,9 +134,17 @@ async fn execute_and_persist_job(
|
||||
warn_if_high_frequency_agent_job(job);
|
||||
|
||||
let started_at = Utc::now();
|
||||
let (success, output) = execute_job_with_retry(config, security, job).await;
|
||||
let (success, output) = Box::pin(execute_job_with_retry(config, security, job)).await;
|
||||
let finished_at = Utc::now();
|
||||
let success = persist_job_result(config, job, success, &output, started_at, finished_at).await;
|
||||
let success = Box::pin(persist_job_result(
|
||||
config,
|
||||
job,
|
||||
success,
|
||||
&output,
|
||||
started_at,
|
||||
finished_at,
|
||||
))
|
||||
.await;
|
||||
|
||||
(job.id.clone(), success, output)
|
||||
}
|
||||
@@ -170,7 +181,7 @@ async fn run_agent_job(
|
||||
|
||||
let run_result = match job.session_target {
|
||||
SessionTarget::Main | SessionTarget::Isolated => {
|
||||
crate::agent::run(
|
||||
Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
Some(prefixed_prompt),
|
||||
None,
|
||||
@@ -179,7 +190,8 @@ async fn run_agent_job(
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
)
|
||||
job.allowed_tools.clone(),
|
||||
))
|
||||
.await
|
||||
}
|
||||
};
|
||||
@@ -557,6 +569,7 @@ mod tests {
|
||||
enabled: true,
|
||||
delivery: DeliveryConfig::default(),
|
||||
delete_after_run: false,
|
||||
allowed_tools: None,
|
||||
created_at: Utc::now(),
|
||||
next_run: Utc::now(),
|
||||
last_run: None,
|
||||
@@ -742,7 +755,7 @@ mod tests {
|
||||
.unwrap();
|
||||
let job = test_job("sh ./retry-once.sh");
|
||||
|
||||
let (success, output) = execute_job_with_retry(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(execute_job_with_retry(&config, &security, &job)).await;
|
||||
assert!(success);
|
||||
assert!(output.contains("recovered"));
|
||||
}
|
||||
@@ -757,7 +770,7 @@ mod tests {
|
||||
|
||||
let job = test_job("ls always_missing_for_retry_test");
|
||||
|
||||
let (success, output) = execute_job_with_retry(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(execute_job_with_retry(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("always_missing_for_retry_test"));
|
||||
}
|
||||
@@ -771,7 +784,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("agent job failed:"));
|
||||
}
|
||||
@@ -786,7 +799,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("blocked by security policy"));
|
||||
assert!(output.contains("read-only"));
|
||||
@@ -802,7 +815,7 @@ mod tests {
|
||||
job.prompt = Some("Say hello".into());
|
||||
let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
|
||||
|
||||
let (success, output) = run_agent_job(&config, &security, &job).await;
|
||||
let (success, output) = Box::pin(run_agent_job(&config, &security, &job)).await;
|
||||
assert!(!success);
|
||||
assert!(output.contains("blocked by security policy"));
|
||||
assert!(output.contains("rate limit exceeded"));
|
||||
|
||||
@@ -453,6 +453,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
},
|
||||
last_status: row.get(15)?,
|
||||
last_output: row.get(16)?,
|
||||
allowed_tools: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -115,6 +115,11 @@ pub struct CronJob {
|
||||
pub enabled: bool,
|
||||
pub delivery: DeliveryConfig,
|
||||
pub delete_after_run: bool,
|
||||
/// Optional allowlist of tool names this cron job may use.
|
||||
/// When `Some(list)`, only tools whose name is in the list are available.
|
||||
/// When `None`, all tools are available (backward compatible default).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub next_run: DateTime<Utc>,
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
@@ -144,6 +149,7 @@ pub struct CronJobPatch {
|
||||
pub model: Option<String>,
|
||||
pub session_target: Option<SessionTarget>,
|
||||
pub delete_after_run: Option<bool>,
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+145
-61
@@ -77,7 +77,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
|
||||
max_backoff,
|
||||
move || {
|
||||
let cfg = channels_cfg.clone();
|
||||
async move { crate::channels::start_channels(cfg).await }
|
||||
async move { Box::pin(crate::channels::start_channels(cfg)).await }
|
||||
},
|
||||
));
|
||||
} else {
|
||||
@@ -203,14 +203,17 @@ where
|
||||
}
|
||||
|
||||
async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
use crate::heartbeat::engine::HeartbeatEngine;
|
||||
|
||||
let observer: std::sync::Arc<dyn crate::observability::Observer> =
|
||||
std::sync::Arc::from(crate::observability::create_observer(&config.observability));
|
||||
let engine = crate::heartbeat::engine::HeartbeatEngine::new(
|
||||
let engine = HeartbeatEngine::new(
|
||||
config.heartbeat.clone(),
|
||||
config.workspace_dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let delivery = heartbeat_delivery_target(&config)?;
|
||||
let delivery = resolve_heartbeat_delivery(&config)?;
|
||||
let two_phase = config.heartbeat.two_phase;
|
||||
|
||||
let interval_mins = config.heartbeat.interval_minutes.max(5);
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
|
||||
@@ -218,16 +221,74 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let file_tasks = engine.collect_tasks().await?;
|
||||
let tasks = heartbeat_tasks_for_tick(file_tasks, config.heartbeat.message.as_deref());
|
||||
// Collect runnable tasks (active only, sorted by priority)
|
||||
let mut tasks = engine.collect_runnable_tasks().await?;
|
||||
if tasks.is_empty() {
|
||||
continue;
|
||||
// Try fallback message
|
||||
if let Some(fallback) = config
|
||||
.heartbeat
|
||||
.message
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
{
|
||||
tasks.push(crate::heartbeat::engine::HeartbeatTask {
|
||||
text: fallback.to_string(),
|
||||
priority: crate::heartbeat::engine::TaskPriority::Medium,
|
||||
status: crate::heartbeat::engine::TaskStatus::Active,
|
||||
});
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let prompt = format!("[Heartbeat Task] {task}");
|
||||
// ── Phase 1: LLM decision (two-phase mode) ──────────────
|
||||
let tasks_to_run = if two_phase {
|
||||
let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks);
|
||||
match Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
Some(decision_prompt),
|
||||
None,
|
||||
None,
|
||||
0.0, // Low temperature for deterministic decision
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
let indices = HeartbeatEngine::parse_decision_response(&response, tasks.len());
|
||||
if indices.is_empty() {
|
||||
tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
continue;
|
||||
}
|
||||
tracing::info!(
|
||||
"💓 Heartbeat Phase 1: run {} of {} tasks",
|
||||
indices.len(),
|
||||
tasks.len()
|
||||
);
|
||||
indices
|
||||
.into_iter()
|
||||
.filter_map(|i| tasks.get(i).cloned())
|
||||
.collect()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("💓 Heartbeat Phase 1 failed, running all tasks: {e}");
|
||||
tasks
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tasks
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
for task in &tasks_to_run {
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let temp = config.default_temperature;
|
||||
match crate::agent::run(
|
||||
match Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
Some(prompt),
|
||||
None,
|
||||
@@ -236,13 +297,14 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
)
|
||||
None,
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
let announcement = if output.trim().is_empty() {
|
||||
"heartbeat task executed".to_string()
|
||||
format!("💓 heartbeat task completed: {}", task.text)
|
||||
} else {
|
||||
output
|
||||
};
|
||||
@@ -272,22 +334,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn heartbeat_tasks_for_tick(
|
||||
file_tasks: Vec<String>,
|
||||
fallback_message: Option<&str>,
|
||||
) -> Vec<String> {
|
||||
if !file_tasks.is_empty() {
|
||||
return file_tasks;
|
||||
}
|
||||
|
||||
fallback_message
|
||||
.map(str::trim)
|
||||
.filter(|message| !message.is_empty())
|
||||
.map(|message| vec![message.to_string()])
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn heartbeat_delivery_target(config: &Config) -> Result<Option<(String, String)>> {
|
||||
/// Resolve delivery target: explicit config > auto-detect first configured channel.
|
||||
fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)>> {
|
||||
let channel = config
|
||||
.heartbeat
|
||||
.target
|
||||
@@ -302,16 +350,45 @@ fn heartbeat_delivery_target(config: &Config) -> Result<Option<(String, String)>
|
||||
.filter(|value| !value.is_empty());
|
||||
|
||||
match (channel, target) {
|
||||
(None, None) => Ok(None),
|
||||
(Some(_), None) => anyhow::bail!("heartbeat.to is required when heartbeat.target is set"),
|
||||
(None, Some(_)) => anyhow::bail!("heartbeat.target is required when heartbeat.to is set"),
|
||||
// Both explicitly set — validate and use.
|
||||
(Some(channel), Some(target)) => {
|
||||
validate_heartbeat_channel_config(config, channel)?;
|
||||
Ok(Some((channel.to_string(), target.to_string())))
|
||||
}
|
||||
// Only one set — error.
|
||||
(Some(_), None) => anyhow::bail!("heartbeat.to is required when heartbeat.target is set"),
|
||||
(None, Some(_)) => anyhow::bail!("heartbeat.target is required when heartbeat.to is set"),
|
||||
// Neither set — try auto-detect the first configured channel.
|
||||
(None, None) => Ok(auto_detect_heartbeat_channel(config)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Auto-detect the best channel for heartbeat delivery by checking which
|
||||
/// channels are configured. Returns the first match in priority order.
|
||||
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {
|
||||
// Priority order: telegram > discord > slack > mattermost
|
||||
if let Some(tg) = &config.channels_config.telegram {
|
||||
// Use the first allowed_user as target, or fall back to empty (broadcast)
|
||||
let target = tg.allowed_users.first().cloned().unwrap_or_default();
|
||||
if !target.is_empty() {
|
||||
return Some(("telegram".to_string(), target));
|
||||
}
|
||||
}
|
||||
if config.channels_config.discord.is_some() {
|
||||
// Discord requires explicit target — can't auto-detect
|
||||
return None;
|
||||
}
|
||||
if config.channels_config.slack.is_some() {
|
||||
// Slack requires explicit target
|
||||
return None;
|
||||
}
|
||||
if config.channels_config.mattermost.is_some() {
|
||||
// Mattermost requires explicit target
|
||||
return None;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn validate_heartbeat_channel_config(config: &Config, channel: &str) -> Result<()> {
|
||||
match channel.to_ascii_lowercase().as_str() {
|
||||
"telegram" => {
|
||||
@@ -487,75 +564,56 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_use_file_tasks_when_available() {
|
||||
let tasks =
|
||||
heartbeat_tasks_for_tick(vec!["From file".to_string()], Some("Fallback from config"));
|
||||
assert_eq!(tasks, vec!["From file".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_fall_back_to_config_message() {
|
||||
let tasks = heartbeat_tasks_for_tick(vec![], Some(" check london time "));
|
||||
assert_eq!(tasks, vec!["check london time".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_ignore_empty_fallback_message() {
|
||||
let tasks = heartbeat_tasks_for_tick(vec![], Some(" "));
|
||||
assert!(tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_none_when_unset() {
|
||||
fn resolve_delivery_none_when_unset() {
|
||||
let config = Config::default();
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert!(target.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_to_field() {
|
||||
fn resolve_delivery_requires_to_field() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("heartbeat.to is required when heartbeat.target is set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_target_field() {
|
||||
fn resolve_delivery_requires_target_field() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("heartbeat.target is required when heartbeat.to is set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_rejects_unsupported_channel() {
|
||||
fn resolve_delivery_rejects_unsupported_channel() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("email".into());
|
||||
config.heartbeat.to = Some("ops@example.com".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("unsupported heartbeat.target channel"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_channel_configuration() {
|
||||
fn resolve_delivery_requires_channel_configuration() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("channels_config.telegram is not configured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_accepts_telegram_configuration() {
|
||||
fn resolve_delivery_accepts_telegram_configuration() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
@@ -568,7 +626,33 @@ mod tests {
|
||||
mention_only: false,
|
||||
});
|
||||
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert_eq!(target, Some(("telegram".to_string(), "123456".to_string())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_detect_telegram_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.telegram = Some(crate::config::TelegramConfig {
|
||||
bot_token: "bot-token".into(),
|
||||
allowed_users: vec!["user123".into()],
|
||||
stream_mode: crate::config::StreamMode::default(),
|
||||
draft_update_interval_ms: 1000,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert_eq!(
|
||||
target,
|
||||
Some(("telegram".to_string(), "user123".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_detect_none_when_no_channels() {
|
||||
let config = Config::default();
|
||||
let target = auto_detect_heartbeat_channel(&config);
|
||||
assert!(target.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
+14
-10
@@ -910,7 +910,7 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
|
||||
/// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk).
|
||||
async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result<String> {
|
||||
let config = state.config.lock().clone();
|
||||
crate::agent::process_message(config, message).await
|
||||
Box::pin(crate::agent::process_message(config, message)).await
|
||||
}
|
||||
|
||||
/// Webhook request body
|
||||
@@ -1238,7 +1238,7 @@ async fn handle_whatsapp_message(
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via WhatsApp
|
||||
if let Err(e) = wa
|
||||
@@ -1346,7 +1346,7 @@ async fn handle_linq_webhook(
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via Linq
|
||||
if let Err(e) = linq
|
||||
@@ -1438,7 +1438,7 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
// Send reply via WATI
|
||||
if let Err(e) = wati
|
||||
@@ -1542,7 +1542,7 @@ async fn handle_nextcloud_talk_webhook(
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_gateway_chat_with_tools(&state, &msg.content).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
Ok(response) => {
|
||||
if let Err(e) = nextcloud_talk
|
||||
.send(&SendMessage::new(response, &msg.reply_target))
|
||||
@@ -2492,11 +2492,11 @@ mod tests {
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
};
|
||||
|
||||
let response = handle_nextcloud_talk_webhook(
|
||||
let response = Box::pin(handle_nextcloud_talk_webhook(
|
||||
State(state),
|
||||
HeaderMap::new(),
|
||||
Bytes::from_static(br#"{"type":"message"}"#),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
@@ -2558,9 +2558,13 @@ mod tests {
|
||||
HeaderValue::from_str(invalid_signature).unwrap(),
|
||||
);
|
||||
|
||||
let response = handle_nextcloud_talk_webhook(State(state), headers, Bytes::from(body))
|
||||
.await
|
||||
.into_response();
|
||||
let response = Box::pin(handle_nextcloud_talk_webhook(
|
||||
State(state),
|
||||
headers,
|
||||
Bytes::from(body),
|
||||
))
|
||||
.await
|
||||
.into_response();
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{Hand, HandContext, HandRun, HandRunStatus};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// Load all hand definitions from TOML files in the given directory.
|
||||
///
|
||||
/// Each `.toml` file in `hands_dir` is expected to deserialize into a [`Hand`].
|
||||
/// Files that fail to parse are logged and skipped.
|
||||
pub fn load_hands(hands_dir: &Path) -> Result<Vec<Hand>> {
|
||||
if !hands_dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut hands = Vec::new();
|
||||
let entries = std::fs::read_dir(hands_dir)
|
||||
.with_context(|| format!("failed to read hands directory: {}", hands_dir.display()))?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("toml") {
|
||||
continue;
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand file: {}", path.display()))?;
|
||||
match toml::from_str::<Hand>(&content) {
|
||||
Ok(hand) => hands.push(hand),
|
||||
Err(e) => {
|
||||
tracing::warn!(path = %path.display(), error = %e, "skipping malformed hand file");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hands)
|
||||
}
|
||||
|
||||
/// Load the rolling context for a hand.
|
||||
///
|
||||
/// Reads from `{hands_dir}/{name}/context.json`. Returns a fresh
|
||||
/// [`HandContext`] if the file does not exist yet.
|
||||
pub fn load_hand_context(hands_dir: &Path, name: &str) -> Result<HandContext> {
|
||||
let path = hands_dir.join(name).join("context.json");
|
||||
if !path.exists() {
|
||||
return Ok(HandContext::new(name));
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand context: {}", path.display()))?;
|
||||
let ctx: HandContext = serde_json::from_str(&content)
|
||||
.with_context(|| format!("failed to parse hand context: {}", path.display()))?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// Persist the rolling context for a hand.
|
||||
///
|
||||
/// Writes to `{hands_dir}/{name}/context.json`, creating the
|
||||
/// directory if it does not exist.
|
||||
pub fn save_hand_context(hands_dir: &Path, context: &HandContext) -> Result<()> {
|
||||
let dir = hands_dir.join(&context.hand_name);
|
||||
std::fs::create_dir_all(&dir)
|
||||
.with_context(|| format!("failed to create hand context dir: {}", dir.display()))?;
|
||||
let path = dir.join("context.json");
|
||||
let json = serde_json::to_string_pretty(context)?;
|
||||
std::fs::write(&path, json)
|
||||
.with_context(|| format!("failed to write hand context: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn write_hand_toml(dir: &Path, filename: &str, content: &str) {
|
||||
std::fs::write(dir.join(filename), content).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_nonexistent_dir() {
|
||||
let hands = load_hands(Path::new("/nonexistent/path/hands")).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_parses_valid_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"scanner.toml",
|
||||
r#"
|
||||
name = "scanner"
|
||||
description = "Market scanner"
|
||||
prompt = "Scan markets."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * *"
|
||||
"#,
|
||||
);
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"digest.toml",
|
||||
r#"
|
||||
name = "digest"
|
||||
description = "News digest"
|
||||
prompt = "Digest news."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_skips_malformed_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(tmp.path(), "bad.toml", "this is not valid toml struct");
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"good.toml",
|
||||
r#"
|
||||
name = "good"
|
||||
description = "A good hand"
|
||||
prompt = "Do good things."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 60000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 1);
|
||||
assert_eq!(hands[0].name, "good");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_ignores_non_toml_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
std::fs::write(tmp.path().join("readme.md"), "# Hands").unwrap();
|
||||
std::fs::write(tmp.path().join("notes.txt"), "some notes").unwrap();
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_roundtrip_through_filesystem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("test-hand");
|
||||
let run = HandRun {
|
||||
hand_name: "test-hand".into(),
|
||||
run_id: "run-001".into(),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec!["found something".into()],
|
||||
knowledge_added: vec!["learned something".into()],
|
||||
duration_ms: Some(500),
|
||||
};
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "test-hand").unwrap();
|
||||
|
||||
assert_eq!(loaded.hand_name, "test-hand");
|
||||
assert_eq!(loaded.total_runs, 1);
|
||||
assert_eq!(loaded.history.len(), 1);
|
||||
assert_eq!(loaded.learned_facts, vec!["learned something"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_context_returns_fresh_when_missing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = load_hand_context(tmp.path(), "nonexistent").unwrap();
|
||||
assert_eq!(ctx.hand_name, "nonexistent");
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_context_creates_directory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = HandContext::new("new-hand");
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
|
||||
assert!(tmp.path().join("new-hand").join("context.json").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_then_load_preserves_multiple_runs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("multi");
|
||||
|
||||
for i in 0..5 {
|
||||
let run = HandRun {
|
||||
hand_name: "multi".into(),
|
||||
run_id: format!("run-{i:03}"),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec![format!("finding-{i}")],
|
||||
knowledge_added: vec![format!("fact-{i}")],
|
||||
duration_ms: Some(100),
|
||||
};
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "multi").unwrap();
|
||||
|
||||
assert_eq!(loaded.total_runs, 5);
|
||||
assert_eq!(loaded.history.len(), 3, "history capped at max_history=3");
|
||||
assert_eq!(loaded.learned_facts.len(), 5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::cron::Schedule;
|
||||
|
||||
// ── Hand ───────────────────────────────────────────────────────
|
||||
|
||||
/// A Hand is an autonomous agent package that runs on a schedule,
|
||||
/// accumulates knowledge over time, and reports results.
|
||||
///
|
||||
/// Hands are defined as TOML files in `~/.zeroclaw/hands/` and each
|
||||
/// maintains a rolling context of findings across runs so the agent
|
||||
/// grows smarter with every execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Hand {
|
||||
/// Unique name (also used as directory/file stem)
|
||||
pub name: String,
|
||||
/// Human-readable description of what this hand does
|
||||
pub description: String,
|
||||
/// The schedule this hand runs on (reuses cron schedule types)
|
||||
pub schedule: Schedule,
|
||||
/// System prompt / execution plan for this hand
|
||||
pub prompt: String,
|
||||
/// Domain knowledge lines to inject into context
|
||||
#[serde(default)]
|
||||
pub knowledge: Vec<String>,
|
||||
/// Tools this hand is allowed to use (None = all available)
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Model override for this hand (None = default provider)
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Whether this hand is currently active
|
||||
#[serde(default = "default_true")]
|
||||
pub active: bool,
|
||||
/// Maximum runs to keep in history
|
||||
#[serde(default = "default_max_runs")]
|
||||
pub max_history: usize,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_max_runs() -> usize {
|
||||
100
|
||||
}
|
||||
|
||||
// ── Hand Run ───────────────────────────────────────────────────
|
||||
|
||||
/// The status of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "status")]
|
||||
pub enum HandRunStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
/// Record of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandRun {
|
||||
/// Name of the hand that produced this run
|
||||
pub hand_name: String,
|
||||
/// Unique identifier for this run
|
||||
pub run_id: String,
|
||||
/// When the run started
|
||||
pub started_at: DateTime<Utc>,
|
||||
/// When the run finished (None if still running)
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
/// Outcome of the run
|
||||
pub status: HandRunStatus,
|
||||
/// Key findings/outputs extracted from this run
|
||||
#[serde(default)]
|
||||
pub findings: Vec<String>,
|
||||
/// New knowledge accumulated and stored to memory
|
||||
#[serde(default)]
|
||||
pub knowledge_added: Vec<String>,
|
||||
/// Wall-clock duration in milliseconds
|
||||
pub duration_ms: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Hand Context ───────────────────────────────────────────────
|
||||
|
||||
/// Rolling context that accumulates across hand runs.
|
||||
///
|
||||
/// Persisted as `~/.zeroclaw/hands/{name}/context.json`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandContext {
|
||||
/// Name of the hand this context belongs to
|
||||
pub hand_name: String,
|
||||
/// Past runs, most-recent first, capped at `Hand::max_history`
|
||||
#[serde(default)]
|
||||
pub history: Vec<HandRun>,
|
||||
/// Persistent facts learned across runs
|
||||
#[serde(default)]
|
||||
pub learned_facts: Vec<String>,
|
||||
/// Timestamp of the last completed run
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
/// Total number of successful runs
|
||||
#[serde(default)]
|
||||
pub total_runs: u64,
|
||||
}
|
||||
|
||||
impl HandContext {
|
||||
/// Create a fresh, empty context for a hand.
|
||||
pub fn new(hand_name: &str) -> Self {
|
||||
Self {
|
||||
hand_name: hand_name.to_string(),
|
||||
history: Vec::new(),
|
||||
learned_facts: Vec::new(),
|
||||
last_run: None,
|
||||
total_runs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed run, updating counters and trimming history.
|
||||
pub fn record_run(&mut self, run: HandRun, max_history: usize) {
|
||||
if run.status == (HandRunStatus::Completed) {
|
||||
self.total_runs += 1;
|
||||
self.last_run = run.finished_at;
|
||||
}
|
||||
|
||||
// Merge new knowledge
|
||||
for fact in &run.knowledge_added {
|
||||
if !self.learned_facts.contains(fact) {
|
||||
self.learned_facts.push(fact.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Insert at the front (most-recent first)
|
||||
self.history.insert(0, run);
|
||||
|
||||
// Cap history length
|
||||
self.history.truncate(max_history);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cron::Schedule;
|
||||
|
||||
fn sample_hand() -> Hand {
|
||||
Hand {
|
||||
name: "market-scanner".into(),
|
||||
description: "Scans market trends and reports findings".into(),
|
||||
schedule: Schedule::Cron {
|
||||
expr: "0 9 * * 1-5".into(),
|
||||
tz: Some("America/New_York".into()),
|
||||
},
|
||||
prompt: "Scan market trends and report key findings.".into(),
|
||||
knowledge: vec!["Focus on tech sector.".into()],
|
||||
allowed_tools: Some(vec!["web_search".into(), "memory".into()]),
|
||||
model: Some("claude-opus-4-6".into()),
|
||||
active: true,
|
||||
max_history: 50,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_run(name: &str, status: HandRunStatus) -> HandRun {
|
||||
let now = Utc::now();
|
||||
HandRun {
|
||||
hand_name: name.into(),
|
||||
run_id: uuid::Uuid::new_v4().to_string(),
|
||||
started_at: now,
|
||||
finished_at: Some(now),
|
||||
status,
|
||||
findings: vec!["finding-1".into()],
|
||||
knowledge_added: vec!["learned-fact-A".into()],
|
||||
duration_ms: Some(1234),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Deserialization ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_from_toml() {
|
||||
let toml_str = r#"
|
||||
name = "market-scanner"
|
||||
description = "Scans market trends"
|
||||
prompt = "Scan trends."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * 1-5"
|
||||
tz = "America/New_York"
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "market-scanner");
|
||||
assert!(hand.active, "active should default to true");
|
||||
assert_eq!(hand.max_history, 100, "max_history should default to 100");
|
||||
assert!(hand.knowledge.is_empty());
|
||||
assert!(hand.allowed_tools.is_none());
|
||||
assert!(hand.model.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_full_toml() {
|
||||
let toml_str = r#"
|
||||
name = "news-digest"
|
||||
description = "Daily news digest"
|
||||
prompt = "Summarize the day's news."
|
||||
knowledge = ["focus on AI", "include funding rounds"]
|
||||
allowed_tools = ["web_search"]
|
||||
model = "claude-opus-4-6"
|
||||
active = false
|
||||
max_history = 25
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "news-digest");
|
||||
assert!(!hand.active);
|
||||
assert_eq!(hand.max_history, 25);
|
||||
assert_eq!(hand.knowledge.len(), 2);
|
||||
assert_eq!(hand.allowed_tools.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(hand.model.as_deref(), Some("claude-opus-4-6"));
|
||||
assert!(matches!(
|
||||
hand.schedule,
|
||||
Schedule::Every {
|
||||
every_ms: 3_600_000
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_roundtrip_json() {
|
||||
let hand = sample_hand();
|
||||
let json = serde_json::to_string(&hand).unwrap();
|
||||
let parsed: Hand = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, hand.name);
|
||||
assert_eq!(parsed.max_history, hand.max_history);
|
||||
}
|
||||
|
||||
// ── HandRunStatus ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_run_status_serde_roundtrip() {
|
||||
let statuses = vec![
|
||||
HandRunStatus::Running,
|
||||
HandRunStatus::Completed,
|
||||
HandRunStatus::Failed {
|
||||
error: "timeout".into(),
|
||||
},
|
||||
];
|
||||
for status in statuses {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
let parsed: HandRunStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, status);
|
||||
}
|
||||
}
|
||||
|
||||
// ── HandContext ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn context_new_is_empty() {
|
||||
let ctx = HandContext::new("test-hand");
|
||||
assert_eq!(ctx.hand_name, "test-hand");
|
||||
assert!(ctx.history.is_empty());
|
||||
assert!(ctx.learned_facts.is_empty());
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_run_increments_counters() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 1);
|
||||
assert!(ctx.last_run.is_some());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
assert_eq!(ctx.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_failed_run_does_not_increment_total() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run(
|
||||
"scanner",
|
||||
HandRunStatus::Failed {
|
||||
error: "boom".into(),
|
||||
},
|
||||
);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_caps_history_at_max() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for _ in 0..10 {
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
assert_eq!(ctx.history.len(), 3);
|
||||
assert_eq!(ctx.total_runs, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_deduplicates_learned_facts() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run1 = sample_run("scanner", HandRunStatus::Completed);
|
||||
let run2 = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run1, 100);
|
||||
ctx.record_run(run2, 100);
|
||||
|
||||
// Both runs add "learned-fact-A" but it should appear only once
|
||||
assert_eq!(ctx.learned_facts.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_json_roundtrip() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
let json = serde_json::to_string_pretty(&ctx).unwrap();
|
||||
let parsed: HandContext = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.hand_name, "scanner");
|
||||
assert_eq!(parsed.total_runs, 1);
|
||||
assert_eq!(parsed.history.len(), 1);
|
||||
assert_eq!(parsed.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn most_recent_run_is_first_in_history() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for i in 0..3 {
|
||||
let mut run = sample_run("scanner", HandRunStatus::Completed);
|
||||
run.findings = vec![format!("finding-{i}")];
|
||||
ctx.record_run(run, 100);
|
||||
}
|
||||
assert_eq!(ctx.history[0].findings[0], "finding-2");
|
||||
assert_eq!(ctx.history[2].findings[0], "finding-0");
|
||||
}
|
||||
}
|
||||
+399
-27
@@ -1,11 +1,75 @@
|
||||
use crate::config::HeartbeatConfig;
|
||||
use crate::observability::{Observer, ObserverEvent};
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{info, warn};
|
||||
|
||||
// ── Structured task types ────────────────────────────────────────
|
||||
|
||||
/// Priority level for a heartbeat task.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TaskPriority {
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskPriority {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Low => write!(f, "low"),
|
||||
Self::Medium => write!(f, "medium"),
|
||||
Self::High => write!(f, "high"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of a heartbeat task.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TaskStatus {
|
||||
Active,
|
||||
Paused,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Active => write!(f, "active"),
|
||||
Self::Paused => write!(f, "paused"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A structured heartbeat task with priority and status metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HeartbeatTask {
|
||||
pub text: String,
|
||||
pub priority: TaskPriority,
|
||||
pub status: TaskStatus,
|
||||
}
|
||||
|
||||
impl HeartbeatTask {
|
||||
pub fn is_runnable(&self) -> bool {
|
||||
self.status == TaskStatus::Active
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for HeartbeatTask {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "[{}] {}", self.priority, self.text)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Engine ───────────────────────────────────────────────────────
|
||||
|
||||
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
|
||||
pub struct HeartbeatEngine {
|
||||
config: HeartbeatConfig,
|
||||
@@ -64,8 +128,8 @@ impl HeartbeatEngine {
|
||||
Ok(self.collect_tasks().await?.len())
|
||||
}
|
||||
|
||||
/// Read HEARTBEAT.md and return all parsed tasks.
|
||||
pub async fn collect_tasks(&self) -> Result<Vec<String>> {
|
||||
/// Read HEARTBEAT.md and return all parsed structured tasks.
|
||||
pub async fn collect_tasks(&self) -> Result<Vec<HeartbeatTask>> {
|
||||
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
|
||||
if !heartbeat_path.exists() {
|
||||
return Ok(Vec::new());
|
||||
@@ -74,13 +138,145 @@ impl HeartbeatEngine {
|
||||
Ok(Self::parse_tasks(&content))
|
||||
}
|
||||
|
||||
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
|
||||
fn parse_tasks(content: &str) -> Vec<String> {
|
||||
/// Collect only runnable (active) tasks, sorted by priority (high first).
|
||||
pub async fn collect_runnable_tasks(&self) -> Result<Vec<HeartbeatTask>> {
|
||||
let mut tasks: Vec<HeartbeatTask> = self
|
||||
.collect_tasks()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(HeartbeatTask::is_runnable)
|
||||
.collect();
|
||||
// Sort by priority descending (High > Medium > Low)
|
||||
tasks.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
/// Parse tasks from HEARTBEAT.md with structured metadata support.
|
||||
///
|
||||
/// Supports both legacy flat format and new structured format:
|
||||
///
|
||||
/// Legacy:
|
||||
/// `- Check email` → medium priority, active status
|
||||
///
|
||||
/// Structured:
|
||||
/// `- [high] Check email` → high priority, active
|
||||
/// `- [low|paused] Review old PRs` → low priority, paused
|
||||
/// `- [completed] Old task` → medium priority, completed
|
||||
fn parse_tasks(content: &str) -> Vec<HeartbeatTask> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
trimmed.strip_prefix("- ").map(ToString::to_string)
|
||||
let text = trimmed.strip_prefix("- ")?;
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(Self::parse_task_line(text))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single task line into a structured `HeartbeatTask`.
|
||||
///
|
||||
/// Format: `[priority|status] task text` or just `task text`.
|
||||
fn parse_task_line(text: &str) -> HeartbeatTask {
|
||||
if let Some(rest) = text.strip_prefix('[') {
|
||||
if let Some((meta, task_text)) = rest.split_once(']') {
|
||||
let task_text = task_text.trim();
|
||||
if !task_text.is_empty() {
|
||||
let (priority, status) = Self::parse_meta(meta);
|
||||
return HeartbeatTask {
|
||||
text: task_text.to_string(),
|
||||
priority,
|
||||
status,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
// No metadata — default to medium/active
|
||||
HeartbeatTask {
|
||||
text: text.to_string(),
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse metadata tags like `high`, `low|paused`, `completed`.
|
||||
fn parse_meta(meta: &str) -> (TaskPriority, TaskStatus) {
|
||||
let mut priority = TaskPriority::Medium;
|
||||
let mut status = TaskStatus::Active;
|
||||
|
||||
for part in meta.split('|') {
|
||||
match part.trim().to_ascii_lowercase().as_str() {
|
||||
"high" => priority = TaskPriority::High,
|
||||
"medium" | "med" => priority = TaskPriority::Medium,
|
||||
"low" => priority = TaskPriority::Low,
|
||||
"active" => status = TaskStatus::Active,
|
||||
"paused" | "pause" => status = TaskStatus::Paused,
|
||||
"completed" | "complete" | "done" => status = TaskStatus::Completed,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
(priority, status)
|
||||
}
|
||||
|
||||
/// Build the Phase 1 LLM decision prompt for two-phase heartbeat.
|
||||
pub fn build_decision_prompt(tasks: &[HeartbeatTask]) -> String {
|
||||
let mut prompt = String::from(
|
||||
"You are a heartbeat scheduler. Review the following periodic tasks and decide \
|
||||
whether any should be executed right now.\n\n\
|
||||
Consider:\n\
|
||||
- Task priority (high tasks are more urgent)\n\
|
||||
- Whether the task is time-sensitive or can wait\n\
|
||||
- Whether running the task now would provide value\n\n\
|
||||
Tasks:\n",
|
||||
);
|
||||
|
||||
for (i, task) in tasks.iter().enumerate() {
|
||||
use std::fmt::Write;
|
||||
let _ = writeln!(prompt, "{}. [{}] {}", i + 1, task.priority, task.text);
|
||||
}
|
||||
|
||||
prompt.push_str(
|
||||
"\nRespond with ONLY one of:\n\
|
||||
- `run: 1,2,3` (comma-separated task numbers to execute)\n\
|
||||
- `skip` (nothing needs to run right now)\n\n\
|
||||
Be conservative — skip if tasks are routine and not time-sensitive.",
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Parse the Phase 1 LLM decision response.
|
||||
///
|
||||
/// Returns indices of tasks to run, or empty vec if skipped.
|
||||
pub fn parse_decision_response(response: &str, task_count: usize) -> Vec<usize> {
|
||||
let trimmed = response.trim().to_ascii_lowercase();
|
||||
|
||||
if trimmed == "skip" || trimmed.starts_with("skip") {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Look for "run: 1,2,3" pattern
|
||||
let numbers_part = if let Some(after_run) = trimmed.strip_prefix("run:") {
|
||||
after_run.trim()
|
||||
} else if let Some(after_run) = trimmed.strip_prefix("run ") {
|
||||
after_run.trim()
|
||||
} else {
|
||||
// Try to parse as bare numbers
|
||||
trimmed.as_str()
|
||||
};
|
||||
|
||||
numbers_part
|
||||
.split(',')
|
||||
.filter_map(|s| {
|
||||
let n: usize = s.trim().parse().ok()?;
|
||||
if n >= 1 && n <= task_count {
|
||||
Some(n - 1) // Convert to 0-indexed
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -93,10 +289,14 @@ impl HeartbeatEngine {
|
||||
# Add tasks below (one per line, starting with `- `)\n\
|
||||
# The agent will check this file on each heartbeat tick.\n\
|
||||
#\n\
|
||||
# Format: - [priority|status] Task description\n\
|
||||
# priority: high, medium (default), low\n\
|
||||
# status: active (default), paused, completed\n\
|
||||
#\n\
|
||||
# Examples:\n\
|
||||
# - Check my email for important messages\n\
|
||||
# - [high] Check my email for important messages\n\
|
||||
# - Review my calendar for upcoming events\n\
|
||||
# - Check the weather forecast\n";
|
||||
# - [low|paused] Check the weather forecast\n";
|
||||
tokio::fs::write(&path, default).await?;
|
||||
}
|
||||
Ok(())
|
||||
@@ -112,9 +312,9 @@ mod tests {
|
||||
let content = "# Tasks\n\n- Check email\n- Review calendar\nNot a task\n- Third task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert_eq!(tasks[0], "Check email");
|
||||
assert_eq!(tasks[1], "Review calendar");
|
||||
assert_eq!(tasks[2], "Third task");
|
||||
assert_eq!(tasks[0].text, "Check email");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -133,26 +333,21 @@ mod tests {
|
||||
let content = " - Indented task\n\t- Tab indented";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Indented task");
|
||||
assert_eq!(tasks[1], "Tab indented");
|
||||
assert_eq!(tasks[0].text, "Indented task");
|
||||
assert_eq!(tasks[1].text, "Tab indented");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_dash_without_space_ignored() {
|
||||
let content = "- Real task\n-\n- Another";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
// "-" trimmed = "-", does NOT start with "- " => skipped
|
||||
// "- Real task" => "Real task"
|
||||
// "- Another" => "Another"
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Real task");
|
||||
assert_eq!(tasks[1], "Another");
|
||||
assert_eq!(tasks[0].text, "Real task");
|
||||
assert_eq!(tasks[1].text, "Another");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_trailing_space_bullet_trimmed_to_dash() {
|
||||
// "- " trimmed becomes "-" (trim removes trailing space)
|
||||
// "-" does NOT start with "- " => skipped
|
||||
let content = "- ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 0);
|
||||
@@ -160,11 +355,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_bullet_with_content_after_spaces() {
|
||||
// "- hello " trimmed becomes "- hello" => starts_with "- " => "hello"
|
||||
let content = "- hello ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "hello");
|
||||
assert_eq!(tasks[0].text, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -172,8 +366,8 @@ mod tests {
|
||||
let content = "- Check email 📧\n- Review calendar 📅\n- 日本語タスク";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert!(tasks[0].contains("📧"));
|
||||
assert!(tasks[2].contains("日本語"));
|
||||
assert!(tasks[0].text.contains('📧'));
|
||||
assert!(tasks[2].text.contains("日本語"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -181,15 +375,15 @@ mod tests {
|
||||
let content = "# Periodic Tasks\n\n## Quick\n- Task A\n\n## Long\n- Task B\n\n* Not a dash bullet\n1. Not numbered";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Task A");
|
||||
assert_eq!(tasks[1], "Task B");
|
||||
assert_eq!(tasks[0].text, "Task A");
|
||||
assert_eq!(tasks[1].text, "Task B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_single_task() {
|
||||
let tasks = HeartbeatEngine::parse_tasks("- Only one");
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "Only one");
|
||||
assert_eq!(tasks[0].text, "Only one");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -201,9 +395,153 @@ mod tests {
|
||||
});
|
||||
let tasks = HeartbeatEngine::parse_tasks(&content);
|
||||
assert_eq!(tasks.len(), 100);
|
||||
assert_eq!(tasks[99], "Task 99");
|
||||
assert_eq!(tasks[99].text, "Task 99");
|
||||
}
|
||||
|
||||
// ── Structured task parsing tests ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_task_with_high_priority() {
|
||||
let content = "- [high] Urgent email check";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Urgent email check");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_with_low_paused() {
|
||||
let content = "- [low|paused] Review old PRs";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Review old PRs");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Low);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Paused);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_completed() {
|
||||
let content = "- [completed] Old task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Completed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_without_metadata_defaults() {
|
||||
let content = "- Plain task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Plain task");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mixed_structured_and_legacy() {
|
||||
let content = "- [high] Urgent\n- Normal task\n- [low|paused] Later";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[1].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[2].priority, TaskPriority::Low);
|
||||
assert_eq!(tasks[2].status, TaskStatus::Paused);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runnable_filters_paused_and_completed() {
|
||||
let content = "- [high] Active\n- [low|paused] Paused\n- [completed] Done";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
let runnable: Vec<_> = tasks
|
||||
.into_iter()
|
||||
.filter(HeartbeatTask::is_runnable)
|
||||
.collect();
|
||||
assert_eq!(runnable.len(), 1);
|
||||
assert_eq!(runnable[0].text, "Active");
|
||||
}
|
||||
|
||||
// ── Two-phase decision tests ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn decision_prompt_includes_all_tasks() {
|
||||
let tasks = vec![
|
||||
HeartbeatTask {
|
||||
text: "Check email".into(),
|
||||
priority: TaskPriority::High,
|
||||
status: TaskStatus::Active,
|
||||
},
|
||||
HeartbeatTask {
|
||||
text: "Review calendar".into(),
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
},
|
||||
];
|
||||
let prompt = HeartbeatEngine::build_decision_prompt(&tasks);
|
||||
assert!(prompt.contains("1. [high] Check email"));
|
||||
assert!(prompt.contains("2. [medium] Review calendar"));
|
||||
assert!(prompt.contains("skip"));
|
||||
assert!(prompt.contains("run:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_skip() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("skip", 3);
|
||||
assert!(indices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_skip_with_reason() {
|
||||
let indices =
|
||||
HeartbeatEngine::parse_decision_response("skip — nothing urgent right now", 3);
|
||||
assert!(indices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_single() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1", 3);
|
||||
assert_eq!(indices, vec![0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_multiple() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1, 3", 3);
|
||||
assert_eq!(indices, vec![0, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_out_of_range_ignored() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1, 5, 2", 3);
|
||||
assert_eq!(indices, vec![0, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_zero_ignored() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 0, 1", 3);
|
||||
assert_eq!(indices, vec![0]);
|
||||
}
|
||||
|
||||
// ── Task display ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn task_display_format() {
|
||||
let task = HeartbeatTask {
|
||||
text: "Check email".into(),
|
||||
priority: TaskPriority::High,
|
||||
status: TaskStatus::Active,
|
||||
};
|
||||
assert_eq!(format!("{task}"), "[high] Check email");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn priority_ordering() {
|
||||
assert!(TaskPriority::High > TaskPriority::Medium);
|
||||
assert!(TaskPriority::Medium > TaskPriority::Low);
|
||||
}
|
||||
|
||||
// ── Async tests ─────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_heartbeat_file_creates_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat");
|
||||
@@ -216,6 +554,7 @@ mod tests {
|
||||
assert!(path.exists());
|
||||
let content = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert!(content.contains("Periodic Tasks"));
|
||||
assert!(content.contains("[high]"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
@@ -301,4 +640,37 @@ mod tests {
|
||||
let result = engine.run().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn collect_runnable_tasks_sorts_by_priority() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_runnable_sort");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
tokio::fs::write(
|
||||
dir.join("HEARTBEAT.md"),
|
||||
"- [low] Low task\n- [high] High task\n- Medium task\n- [low|paused] Skip me",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
|
||||
let engine = HeartbeatEngine::new(
|
||||
HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
..HeartbeatConfig::default()
|
||||
},
|
||||
dir.clone(),
|
||||
observer,
|
||||
);
|
||||
|
||||
let tasks = engine.collect_runnable_tasks().await.unwrap();
|
||||
assert_eq!(tasks.len(), 3); // paused one excluded
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[1].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[2].priority, TaskPriority::Low);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
}
|
||||
|
||||
+21
-6
@@ -48,6 +48,7 @@ pub(crate) mod cron;
|
||||
pub(crate) mod daemon;
|
||||
pub(crate) mod doctor;
|
||||
pub mod gateway;
|
||||
pub mod hands;
|
||||
pub(crate) mod hardware;
|
||||
pub(crate) mod health;
|
||||
pub(crate) mod heartbeat;
|
||||
@@ -57,6 +58,7 @@ pub(crate) mod integrations;
|
||||
pub mod memory;
|
||||
pub(crate) mod migration;
|
||||
pub(crate) mod multimodal;
|
||||
pub mod nodes;
|
||||
pub mod observability;
|
||||
pub(crate) mod onboard;
|
||||
pub mod peripherals;
|
||||
@@ -280,15 +282,19 @@ Times are evaluated in UTC by default; use --tz with an IANA \
|
||||
timezone name to override.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'")]
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'")]
|
||||
Add {
|
||||
/// Cron expression
|
||||
expression: String,
|
||||
/// Optional IANA timezone (e.g. America/Los_Angeles)
|
||||
#[arg(long)]
|
||||
tz: Option<String>,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot scheduled task at an RFC3339 timestamp
|
||||
@@ -303,7 +309,10 @@ Examples:
|
||||
AddAt {
|
||||
/// One-shot timestamp in RFC3339 format
|
||||
at: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a fixed-interval scheduled task
|
||||
@@ -318,7 +327,10 @@ Examples:
|
||||
AddEvery {
|
||||
/// Interval in milliseconds
|
||||
every_ms: u64,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot delayed task (e.g. "30m", "2h", "1d")
|
||||
@@ -335,7 +347,10 @@ Examples:
|
||||
Once {
|
||||
/// Delay duration
|
||||
delay: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Remove a scheduled task
|
||||
|
||||
+41
-14
@@ -37,7 +37,7 @@ use anyhow::{bail, Context, Result};
|
||||
use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use dialoguer::{Input, Password};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
@@ -325,11 +325,12 @@ override with --tz and an IANA timezone name.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron list
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder'
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' --agent
|
||||
zeroclaw cron add-every 60000 'Ping heartbeat'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes' --agent
|
||||
zeroclaw cron pause <task-id>
|
||||
zeroclaw cron update <task-id> --expression '0 8 * * *' --tz Europe/London")]
|
||||
Cron {
|
||||
@@ -718,10 +719,11 @@ async fn main() -> Result<()> {
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
||||
// Onboard runs quick setup by default, or the interactive wizard with --interactive.
|
||||
// The onboard wizard uses reqwest::blocking internally, which creates its own
|
||||
// Tokio runtime. To avoid "Cannot drop a runtime in a context where blocking is
|
||||
// not allowed", we run the wizard on a blocking thread via spawn_blocking.
|
||||
// Onboard auto-detects the environment: if stdin/stdout are a TTY and no
|
||||
// provider flags were given, it runs the full interactive wizard; otherwise
|
||||
// it runs the quick (scriptable) setup. This means `curl … | bash` and
|
||||
// `zeroclaw onboard --api-key …` both take the fast path, while a bare
|
||||
// `zeroclaw onboard` in a terminal launches the wizard.
|
||||
if let Commands::Onboard {
|
||||
force,
|
||||
reinit,
|
||||
@@ -793,8 +795,16 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect: run the interactive wizard when in a TTY with no
|
||||
// provider flags, quick setup otherwise (scriptable path).
|
||||
let has_provider_flags =
|
||||
api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some();
|
||||
let is_tty = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
|
||||
|
||||
let config = if channels_only {
|
||||
Box::pin(onboard::run_channels_repair_wizard()).await
|
||||
} else if is_tty && !has_provider_flags {
|
||||
Box::pin(onboard::run_wizard(force)).await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
@@ -834,7 +844,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Auto-start channels if user said yes during wizard
|
||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||
channels::start_channels(config).await?;
|
||||
Box::pin(channels::start_channels(config)).await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
@@ -870,7 +880,7 @@ async fn main() -> Result<()> {
|
||||
} => {
|
||||
let final_temperature = temperature.unwrap_or(config.default_temperature);
|
||||
|
||||
agent::run(
|
||||
Box::pin(agent::run(
|
||||
config,
|
||||
message,
|
||||
provider,
|
||||
@@ -879,7 +889,8 @@ async fn main() -> Result<()> {
|
||||
peripheral,
|
||||
true,
|
||||
session_state_file,
|
||||
)
|
||||
None,
|
||||
))
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
@@ -1178,8 +1189,8 @@ async fn main() -> Result<()> {
|
||||
},
|
||||
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
|
||||
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
|
||||
other => channels::handle_command(other, &config).await,
|
||||
},
|
||||
|
||||
@@ -2206,6 +2217,22 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_rejects_removed_interactive_flag() {
|
||||
// --interactive was removed; onboard auto-detects TTY instead.
|
||||
assert!(Cli::try_parse_from(["zeroclaw", "onboard", "--interactive"]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_bare_parses() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "onboard"]).expect("bare onboard should parse");
|
||||
|
||||
match cli.command {
|
||||
Commands::Onboard { .. } => {}
|
||||
other => panic!("expected onboard command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_parses_estop_default_engage() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "estop"]).expect("estop command should parse");
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
//! LLM-driven memory consolidation.
|
||||
//!
|
||||
//! After each conversation turn, extracts structured information:
|
||||
//! - `history_entry`: A timestamped summary for the daily conversation log.
|
||||
//! - `memory_update`: New facts, preferences, or decisions worth remembering
|
||||
//! long-term (or `null` if nothing new was learned).
|
||||
//!
|
||||
//! This two-phase approach replaces the naive raw-message auto-save with
|
||||
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
|
||||
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
/// Output of consolidation extraction.
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct ConsolidationResult {
|
||||
/// Brief timestamped summary for the conversation history log.
|
||||
pub history_entry: String,
|
||||
/// New facts/preferences/decisions to store long-term, or None.
|
||||
pub memory_update: Option<String>,
|
||||
}
|
||||
|
||||
const CONSOLIDATION_SYSTEM_PROMPT: &str = r#"You are a memory consolidation engine. Given a conversation turn, extract:
|
||||
1. "history_entry": A brief summary of what happened in this turn (1-2 sentences). Include the key topic or action.
|
||||
2. "memory_update": Any NEW facts, preferences, decisions, or commitments worth remembering long-term. Return null if nothing new was learned.
|
||||
|
||||
Respond ONLY with valid JSON: {"history_entry": "...", "memory_update": "..." or null}
|
||||
Do not include any text outside the JSON object."#;
|
||||
|
||||
/// Run two-phase LLM-driven consolidation on a conversation turn.
|
||||
///
|
||||
/// Phase 1: Write a history entry to the Daily memory category.
|
||||
/// Phase 2: Write a memory update to the Core category (if the LLM identified new facts).
|
||||
///
|
||||
/// This function is designed to be called fire-and-forget via `tokio::spawn`.
|
||||
pub async fn consolidate_turn(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
assistant_response: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let turn_text = format!("User: {user_message}\nAssistant: {assistant_response}");
|
||||
|
||||
// Truncate very long turns to avoid wasting tokens on consolidation.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
|
||||
let truncated = if turn_text.len() > 4000 {
|
||||
let mut end = 4000;
|
||||
while end > 0 && !turn_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &turn_text[..end])
|
||||
} else {
|
||||
turn_text.clone()
|
||||
};
|
||||
|
||||
let raw = provider
|
||||
.chat_with_system(Some(CONSOLIDATION_SYSTEM_PROMPT), &truncated, model, 0.1)
|
||||
.await?;
|
||||
|
||||
let result: ConsolidationResult = parse_consolidation_response(&raw, &turn_text);
|
||||
|
||||
// Phase 1: Write history entry to Daily category.
|
||||
let date = chrono::Local::now().format("%Y-%m-%d").to_string();
|
||||
let history_key = format!("daily_{date}_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(
|
||||
&history_key,
|
||||
&result.history_entry,
|
||||
MemoryCategory::Daily,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Phase 2: Write memory update to Core category (if present).
|
||||
if let Some(ref update) = result.memory_update {
|
||||
if !update.trim().is_empty() {
|
||||
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(&mem_key, update, MemoryCategory::Core, None)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse the LLM's consolidation response, with fallback for malformed JSON.
|
||||
fn parse_consolidation_response(raw: &str, fallback_text: &str) -> ConsolidationResult {
|
||||
// Try to extract JSON from the response (LLM may wrap in markdown code blocks).
|
||||
let cleaned = raw
|
||||
.trim()
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim();
|
||||
|
||||
serde_json::from_str(cleaned).unwrap_or_else(|_| {
|
||||
// Fallback: use truncated turn text as history entry.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let summary = if fallback_text.len() > 200 {
|
||||
let mut end = 200;
|
||||
while end > 0 && !fallback_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &fallback_text[..end])
|
||||
} else {
|
||||
fallback_text.to_string()
|
||||
};
|
||||
ConsolidationResult {
|
||||
history_entry: summary,
|
||||
memory_update: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_valid_json_response() {
|
||||
let raw = r#"{"history_entry": "User asked about Rust.", "memory_update": "User prefers Rust over Go."}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "User asked about Rust.");
|
||||
assert_eq!(
|
||||
result.memory_update.as_deref(),
|
||||
Some("User prefers Rust over Go.")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_with_null_memory() {
|
||||
let raw = r#"{"history_entry": "Routine greeting.", "memory_update": null}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Routine greeting.");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_wrapped_in_code_block() {
|
||||
let raw =
|
||||
"```json\n{\"history_entry\": \"Discussed deployment.\", \"memory_update\": null}\n```";
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Discussed deployment.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_on_malformed_response() {
|
||||
let raw = "I'm sorry, I can't do that.";
|
||||
let result = parse_consolidation_response(raw, "User: hello\nAssistant: hi");
|
||||
assert_eq!(result.history_entry, "User: hello\nAssistant: hi");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_long_text() {
|
||||
let long_text = "x".repeat(500);
|
||||
let result = parse_consolidation_response("invalid", &long_text);
|
||||
// 200 bytes + "…" (3 bytes in UTF-8) = 203
|
||||
assert!(result.history_entry.len() <= 203);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_cjk_text_without_panic() {
|
||||
// Each CJK character is 3 bytes in UTF-8; byte index 200 may land
|
||||
// inside a character. This must not panic.
|
||||
let cjk_text = "二手书项目".repeat(50); // 250 chars = 750 bytes
|
||||
let result = parse_consolidation_response("invalid", &cjk_text);
|
||||
assert!(result
|
||||
.history_entry
|
||||
.is_char_boundary(result.history_entry.len()));
|
||||
assert!(result.history_entry.ends_with('…'));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod consolidation;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod lucid;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
pub mod transport;
|
||||
|
||||
pub use transport::NodeTransport;
|
||||
@@ -0,0 +1,235 @@
|
||||
//! Corporate-friendly secure node transport using standard HTTPS + HMAC-SHA256 authentication.
|
||||
//!
|
||||
//! All inter-node traffic uses plain HTTPS on port 443 — no exotic protocols,
|
||||
//! no custom binary framing, no UDP tunneling. This makes the transport
|
||||
//! compatible with corporate proxies, firewalls, and IT audit expectations.
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use chrono::Utc;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Signs a request payload with HMAC-SHA256.
|
||||
///
|
||||
/// Uses `timestamp` + `nonce` alongside the payload to prevent replay attacks.
|
||||
pub fn sign_request(
|
||||
shared_secret: &str,
|
||||
payload: &[u8],
|
||||
timestamp: i64,
|
||||
nonce: &str,
|
||||
) -> Result<String> {
|
||||
let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes())
|
||||
.map_err(|e| anyhow::anyhow!("HMAC key error: {e}"))?;
|
||||
mac.update(×tamp.to_le_bytes());
|
||||
mac.update(nonce.as_bytes());
|
||||
mac.update(payload);
|
||||
Ok(hex::encode(mac.finalize().into_bytes()))
|
||||
}
|
||||
|
||||
/// Verify a signed request, rejecting stale timestamps for replay protection.
|
||||
pub fn verify_request(
|
||||
shared_secret: &str,
|
||||
payload: &[u8],
|
||||
timestamp: i64,
|
||||
nonce: &str,
|
||||
signature: &str,
|
||||
max_age_secs: i64,
|
||||
) -> Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
if (now - timestamp).abs() > max_age_secs {
|
||||
bail!("Request timestamp too old or too far in future");
|
||||
}
|
||||
|
||||
let expected = sign_request(shared_secret, payload, timestamp, nonce)?;
|
||||
Ok(constant_time_eq(expected.as_bytes(), signature.as_bytes()))
|
||||
}
|
||||
|
||||
/// Constant-time comparison to prevent timing attacks.
|
||||
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.fold(0u8, |acc, (x, y)| acc | (x ^ y))
|
||||
== 0
|
||||
}
|
||||
|
||||
// ── Node transport client ───────────────────────────────────────
|
||||
|
||||
/// Sends authenticated HTTPS requests to peer nodes.
|
||||
///
|
||||
/// Every outgoing request carries three custom headers:
|
||||
/// - `X-ZeroClaw-Timestamp` — unix epoch seconds
|
||||
/// - `X-ZeroClaw-Nonce` — random UUID v4
|
||||
/// - `X-ZeroClaw-Signature` — HMAC-SHA256 hex digest
|
||||
///
|
||||
/// Incoming requests are verified with the same scheme via [`Self::verify_incoming`].
|
||||
pub struct NodeTransport {
|
||||
http: reqwest::Client,
|
||||
shared_secret: String,
|
||||
max_request_age_secs: i64,
|
||||
}
|
||||
|
||||
impl NodeTransport {
|
||||
pub fn new(shared_secret: String) -> Self {
|
||||
Self {
|
||||
http: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("HTTP client build"),
|
||||
shared_secret,
|
||||
max_request_age_secs: 300, // 5 min replay window
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an authenticated request to a peer node.
|
||||
pub async fn send(
|
||||
&self,
|
||||
node_address: &str,
|
||||
endpoint: &str,
|
||||
payload: serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
let body = serde_json::to_vec(&payload)?;
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let nonce = uuid::Uuid::new_v4().to_string();
|
||||
let signature = sign_request(&self.shared_secret, &body, timestamp, &nonce)?;
|
||||
|
||||
let url = format!("https://{node_address}/api/node-control/{endpoint}");
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("X-ZeroClaw-Timestamp", timestamp.to_string())
|
||||
.header("X-ZeroClaw-Nonce", &nonce)
|
||||
.header("X-ZeroClaw-Signature", &signature)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!(
|
||||
"Node request failed: {} {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Verify an incoming request from a peer node.
|
||||
pub fn verify_incoming(
|
||||
&self,
|
||||
payload: &[u8],
|
||||
timestamp_header: &str,
|
||||
nonce_header: &str,
|
||||
signature_header: &str,
|
||||
) -> Result<bool> {
|
||||
let timestamp: i64 = timestamp_header
|
||||
.parse()
|
||||
.map_err(|_| anyhow::anyhow!("Invalid timestamp header"))?;
|
||||
verify_request(
|
||||
&self.shared_secret,
|
||||
payload,
|
||||
timestamp,
|
||||
nonce_header,
|
||||
signature_header,
|
||||
self.max_request_age_secs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const TEST_SECRET: &str = "test-shared-secret-key";
|
||||
|
||||
#[test]
|
||||
fn sign_request_deterministic() {
|
||||
let sig1 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
|
||||
let sig2 = sign_request(TEST_SECRET, b"hello", 1_700_000_000, "nonce-1").unwrap();
|
||||
assert_eq!(sig1, sig2, "Same inputs must produce the same signature");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_request_accepts_valid_signature() {
|
||||
let now = Utc::now().timestamp();
|
||||
let sig = sign_request(TEST_SECRET, b"payload", now, "nonce-a").unwrap();
|
||||
let ok = verify_request(TEST_SECRET, b"payload", now, "nonce-a", &sig, 300).unwrap();
|
||||
assert!(ok, "Valid signature must pass verification");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_request_rejects_tampered_payload() {
|
||||
let now = Utc::now().timestamp();
|
||||
let sig = sign_request(TEST_SECRET, b"original", now, "nonce-b").unwrap();
|
||||
let ok = verify_request(TEST_SECRET, b"tampered", now, "nonce-b", &sig, 300).unwrap();
|
||||
assert!(!ok, "Tampered payload must fail verification");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_request_rejects_expired_timestamp() {
|
||||
let old = Utc::now().timestamp() - 600;
|
||||
let sig = sign_request(TEST_SECRET, b"data", old, "nonce-c").unwrap();
|
||||
let result = verify_request(TEST_SECRET, b"data", old, "nonce-c", &sig, 300);
|
||||
assert!(result.is_err(), "Expired timestamp must be rejected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_request_rejects_wrong_secret() {
|
||||
let now = Utc::now().timestamp();
|
||||
let sig = sign_request(TEST_SECRET, b"data", now, "nonce-d").unwrap();
|
||||
let ok = verify_request("wrong-secret", b"data", now, "nonce-d", &sig, 300).unwrap();
|
||||
assert!(!ok, "Wrong secret must fail verification");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn constant_time_eq_correctness() {
|
||||
assert!(constant_time_eq(b"abc", b"abc"));
|
||||
assert!(!constant_time_eq(b"abc", b"abd"));
|
||||
assert!(!constant_time_eq(b"abc", b"ab"));
|
||||
assert!(!constant_time_eq(b"", b"a"));
|
||||
assert!(constant_time_eq(b"", b""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_transport_construction() {
|
||||
let transport = NodeTransport::new("secret-key".into());
|
||||
assert_eq!(transport.max_request_age_secs, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_transport_verify_incoming_valid() {
|
||||
let transport = NodeTransport::new(TEST_SECRET.into());
|
||||
let now = Utc::now().timestamp();
|
||||
let payload = b"test-body";
|
||||
let nonce = "incoming-nonce";
|
||||
let sig = sign_request(TEST_SECRET, payload, now, nonce).unwrap();
|
||||
|
||||
let ok = transport
|
||||
.verify_incoming(payload, &now.to_string(), nonce, &sig)
|
||||
.unwrap();
|
||||
assert!(ok, "Valid incoming request must pass verification");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_transport_verify_incoming_bad_timestamp_header() {
|
||||
let transport = NodeTransport::new(TEST_SECRET.into());
|
||||
let result = transport.verify_incoming(b"body", "not-a-number", "nonce", "sig");
|
||||
assert!(result.is_err(), "Non-numeric timestamp header must error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sign_request_different_nonce_different_signature() {
|
||||
let sig1 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-1").unwrap();
|
||||
let sig2 = sign_request(TEST_SECRET, b"data", 1_700_000_000, "nonce-2").unwrap();
|
||||
assert_ne!(
|
||||
sig1, sig2,
|
||||
"Different nonces must produce different signatures"
|
||||
);
|
||||
}
|
||||
}
|
||||
+2
-1
@@ -4,7 +4,7 @@ pub mod wizard;
|
||||
#[allow(unused_imports)]
|
||||
pub use wizard::{
|
||||
run_channels_repair_wizard, run_models_list, run_models_refresh, run_models_refresh_all,
|
||||
run_models_set, run_models_status, run_quick_setup,
|
||||
run_models_set, run_models_status, run_quick_setup, run_wizard,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -17,6 +17,7 @@ mod tests {
|
||||
fn wizard_functions_are_reexported() {
|
||||
assert_reexport_exists(run_channels_repair_wizard);
|
||||
assert_reexport_exists(run_quick_setup);
|
||||
assert_reexport_exists(run_wizard);
|
||||
assert_reexport_exists(run_models_refresh);
|
||||
assert_reexport_exists(run_models_list);
|
||||
assert_reexport_exists(run_models_set);
|
||||
|
||||
@@ -143,7 +143,12 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
extra_headers: std::collections::HashMap::new(),
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
backup: crate::config::BackupConfig::default(),
|
||||
data_retention: crate::config::DataRetentionConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
@@ -159,17 +164,20 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
tunnel: tunnel_config,
|
||||
gateway: crate::config::GatewayConfig::default(),
|
||||
composio: composio_config,
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: secrets_config,
|
||||
browser: BrowserConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: hardware_config,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -177,6 +185,9 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
nodes: crate::config::NodesConfig::default(),
|
||||
workspace: crate::config::WorkspaceConfig::default(),
|
||||
notion: crate::config::NotionConfig::default(),
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@@ -500,7 +511,12 @@ async fn run_quick_setup_with_home(
|
||||
extra_headers: std::collections::HashMap::new(),
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
backup: crate::config::BackupConfig::default(),
|
||||
data_retention: crate::config::DataRetentionConfig::default(),
|
||||
cloud_ops: crate::config::CloudOpsConfig::default(),
|
||||
conversational_ai: crate::config::ConversationalAiConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
security_ops: crate::config::SecurityOpsConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
@@ -516,17 +532,20 @@ async fn run_quick_setup_with_home(
|
||||
tunnel: crate::config::TunnelConfig::default(),
|
||||
gateway: crate::config::GatewayConfig::default(),
|
||||
composio: ComposioConfig::default(),
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: crate::config::HardwareConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -534,6 +553,9 @@ async fn run_quick_setup_with_home(
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
nodes: crate::config::NodesConfig::default(),
|
||||
workspace: crate::config::WorkspaceConfig::default(),
|
||||
notion: crate::config::NotionConfig::default(),
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
@@ -4147,6 +4169,23 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
.interact()?;
|
||||
|
||||
if mode_idx == 0 {
|
||||
// Compile-time check: warn early if the feature is not enabled.
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
println!();
|
||||
println!(
|
||||
" {} {}",
|
||||
style("⚠").yellow().bold(),
|
||||
style("The 'whatsapp-web' feature is not compiled in. WhatsApp Web will not work at runtime.").yellow()
|
||||
);
|
||||
println!(
|
||||
" {} Rebuild with: {}",
|
||||
style("→").dim(),
|
||||
style("cargo build --features whatsapp-web").white().bold()
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
println!(" {}", style("Mode: WhatsApp Web").dim());
|
||||
print_bullet("1. Build with --features whatsapp-web");
|
||||
print_bullet(
|
||||
|
||||
@@ -500,19 +500,23 @@ struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
function: Option<Function>,
|
||||
|
||||
// Compatibility: Some providers (e.g., older GLM) may use 'name' directly
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
arguments: Option<String>,
|
||||
|
||||
// Compatibility: DeepSeek sometimes wraps arguments differently
|
||||
#[serde(rename = "parameters", default)]
|
||||
#[serde(
|
||||
rename = "parameters",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -3094,4 +3098,50 @@ mod tests {
|
||||
// Should not panic
|
||||
let _client = p.http_client();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_none_fields_omitted_from_json() {
|
||||
// Ensures providers like Mistral that reject extra fields (e.g. "name": null)
|
||||
// don't receive them when the ToolCall compat fields are None.
|
||||
let tc = ToolCall {
|
||||
id: Some("call_1".to_string()),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert!(!json.as_object().unwrap().contains_key("name"));
|
||||
assert!(!json.as_object().unwrap().contains_key("arguments"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
// Standard fields must be present
|
||||
assert!(json.as_object().unwrap().contains_key("id"));
|
||||
assert!(json.as_object().unwrap().contains_key("type"));
|
||||
assert!(json.as_object().unwrap().contains_key("function"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_with_compat_fields_serializes_them() {
|
||||
// When compat fields are Some, they should appear in the output.
|
||||
let tc = ToolCall {
|
||||
id: None,
|
||||
kind: None,
|
||||
function: None,
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["name"], "shell");
|
||||
assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
|
||||
// None fields should be omitted
|
||||
assert!(!json.as_object().unwrap().contains_key("id"));
|
||||
assert!(!json.as_object().unwrap().contains_key("type"));
|
||||
assert!(!json.as_object().unwrap().contains_key("function"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::multimodal;
|
||||
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
|
||||
use crate::providers::ProviderRuntimeOptions;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -472,8 +473,24 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Read the response body incrementally via `bytes_stream()` to avoid
|
||||
/// buffering the entire SSE payload in memory. The previous implementation
|
||||
/// used `response.text().await?` which holds the HTTP connection open until
|
||||
/// every byte has arrived — on high-latency links the long-lived connection
|
||||
/// often drops mid-read, producing the "error decoding response body" failure
|
||||
/// reported in #3544.
|
||||
async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<String> {
|
||||
let body = response.text().await?;
|
||||
let mut body = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let bytes = chunk
|
||||
.map_err(|err| anyhow::anyhow!("error reading OpenAI Codex response stream: {err}"))?;
|
||||
let text = std::str::from_utf8(&bytes).map_err(|err| {
|
||||
anyhow::anyhow!("OpenAI Codex response contained invalid UTF-8: {err}")
|
||||
})?;
|
||||
body.push_str(text);
|
||||
}
|
||||
|
||||
if let Some(text) = parse_sse_text(&body)? {
|
||||
return Ok(text);
|
||||
|
||||
@@ -0,0 +1,449 @@
|
||||
//! IAM-aware policy enforcement for Nevis role-to-permission mapping.
|
||||
//!
|
||||
//! Evaluates tool and workspace access based on Nevis roles using a
|
||||
//! deny-by-default policy model. All policy decisions are audit-logged.
|
||||
|
||||
use super::nevis::NevisIdentity;
|
||||
use anyhow::{bail, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Maps a single Nevis role to ZeroClaw permissions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoleMapping {
|
||||
/// Nevis role name (case-insensitive matching).
|
||||
pub nevis_role: String,
|
||||
/// Tool names this role can access. Use `"all"` to grant all tools.
|
||||
pub zeroclaw_permissions: Vec<String>,
|
||||
/// Workspace names this role can access. Use `"all"` for unrestricted.
|
||||
#[serde(default)]
|
||||
pub workspace_access: Vec<String>,
|
||||
}
|
||||
|
||||
/// Result of a policy evaluation.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PolicyDecision {
|
||||
/// Access is allowed.
|
||||
Allow,
|
||||
/// Access is denied, with reason.
|
||||
Deny(String),
|
||||
}
|
||||
|
||||
impl PolicyDecision {
|
||||
pub fn is_allowed(&self) -> bool {
|
||||
matches!(self, PolicyDecision::Allow)
|
||||
}
|
||||
}
|
||||
|
||||
/// IAM policy engine that maps Nevis roles to ZeroClaw tool permissions.
|
||||
///
|
||||
/// Deny-by-default: if no role mapping grants access, the request is denied.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IamPolicy {
|
||||
/// Compiled role mappings indexed by lowercase Nevis role name.
|
||||
role_map: HashMap<String, CompiledRole>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CompiledRole {
|
||||
/// Whether this role has access to all tools.
|
||||
all_tools: bool,
|
||||
/// Specific tool names this role can access (lowercase).
|
||||
allowed_tools: Vec<String>,
|
||||
/// Whether this role has access to all workspaces.
|
||||
all_workspaces: bool,
|
||||
/// Specific workspace names this role can access (lowercase).
|
||||
allowed_workspaces: Vec<String>,
|
||||
}
|
||||
|
||||
impl IamPolicy {
|
||||
/// Build a policy from role mappings (typically from config).
|
||||
///
|
||||
/// Returns an error if duplicate normalized role names are detected,
|
||||
/// since silent last-wins overwrites can accidentally broaden or revoke access.
|
||||
pub fn from_mappings(mappings: &[RoleMapping]) -> Result<Self> {
|
||||
let mut role_map = HashMap::new();
|
||||
|
||||
for mapping in mappings {
|
||||
let key = mapping.nevis_role.trim().to_ascii_lowercase();
|
||||
if key.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let all_tools = mapping
|
||||
.zeroclaw_permissions
|
||||
.iter()
|
||||
.any(|p| p.eq_ignore_ascii_case("all"));
|
||||
let allowed_tools: Vec<String> = mapping
|
||||
.zeroclaw_permissions
|
||||
.iter()
|
||||
.filter(|p| !p.eq_ignore_ascii_case("all"))
|
||||
.map(|p| p.trim().to_ascii_lowercase())
|
||||
.collect();
|
||||
|
||||
let all_workspaces = mapping
|
||||
.workspace_access
|
||||
.iter()
|
||||
.any(|w| w.eq_ignore_ascii_case("all"));
|
||||
let allowed_workspaces: Vec<String> = mapping
|
||||
.workspace_access
|
||||
.iter()
|
||||
.filter(|w| !w.eq_ignore_ascii_case("all"))
|
||||
.map(|w| w.trim().to_ascii_lowercase())
|
||||
.collect();
|
||||
|
||||
if role_map.contains_key(&key) {
|
||||
bail!(
|
||||
"IAM policy: duplicate role mapping for normalized key '{}' \
|
||||
(from nevis_role '{}') — remove or merge the duplicate entry",
|
||||
key,
|
||||
mapping.nevis_role
|
||||
);
|
||||
}
|
||||
|
||||
role_map.insert(
|
||||
key,
|
||||
CompiledRole {
|
||||
all_tools,
|
||||
allowed_tools,
|
||||
all_workspaces,
|
||||
allowed_workspaces,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self { role_map })
|
||||
}
|
||||
|
||||
/// Evaluate whether an identity is allowed to use a specific tool.
|
||||
///
|
||||
/// Deny-by-default: returns `Deny` unless at least one of the identity's
|
||||
/// roles grants access to the requested tool.
|
||||
pub fn evaluate_tool_access(
|
||||
&self,
|
||||
identity: &NevisIdentity,
|
||||
tool_name: &str,
|
||||
) -> PolicyDecision {
|
||||
let normalized_tool = tool_name.trim().to_ascii_lowercase();
|
||||
if normalized_tool.is_empty() {
|
||||
return PolicyDecision::Deny("empty tool name".into());
|
||||
}
|
||||
|
||||
for role in &identity.roles {
|
||||
let key = role.trim().to_ascii_lowercase();
|
||||
if let Some(compiled) = self.role_map.get(&key) {
|
||||
if compiled.all_tools
|
||||
|| compiled.allowed_tools.iter().any(|t| t == &normalized_tool)
|
||||
{
|
||||
tracing::info!(
|
||||
user_id = %crate::security::redact(&identity.user_id),
|
||||
role = %key,
|
||||
tool = %normalized_tool,
|
||||
"IAM policy: tool access ALLOWED"
|
||||
);
|
||||
return PolicyDecision::Allow;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let reason = format!(
|
||||
"no role grants access to tool '{normalized_tool}' for user '{}'",
|
||||
crate::security::redact(&identity.user_id)
|
||||
);
|
||||
tracing::info!(
|
||||
user_id = %crate::security::redact(&identity.user_id),
|
||||
tool = %normalized_tool,
|
||||
"IAM policy: tool access DENIED"
|
||||
);
|
||||
PolicyDecision::Deny(reason)
|
||||
}
|
||||
|
||||
/// Evaluate whether an identity is allowed to access a specific workspace.
|
||||
///
|
||||
/// Deny-by-default: returns `Deny` unless at least one of the identity's
|
||||
/// roles grants access to the requested workspace.
|
||||
pub fn evaluate_workspace_access(
|
||||
&self,
|
||||
identity: &NevisIdentity,
|
||||
workspace: &str,
|
||||
) -> PolicyDecision {
|
||||
let normalized_ws = workspace.trim().to_ascii_lowercase();
|
||||
if normalized_ws.is_empty() {
|
||||
return PolicyDecision::Deny("empty workspace name".into());
|
||||
}
|
||||
|
||||
for role in &identity.roles {
|
||||
let key = role.trim().to_ascii_lowercase();
|
||||
if let Some(compiled) = self.role_map.get(&key) {
|
||||
if compiled.all_workspaces
|
||||
|| compiled
|
||||
.allowed_workspaces
|
||||
.iter()
|
||||
.any(|w| w == &normalized_ws)
|
||||
{
|
||||
tracing::info!(
|
||||
user_id = %crate::security::redact(&identity.user_id),
|
||||
role = %key,
|
||||
workspace = %normalized_ws,
|
||||
"IAM policy: workspace access ALLOWED"
|
||||
);
|
||||
return PolicyDecision::Allow;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let reason = format!(
|
||||
"no role grants access to workspace '{normalized_ws}' for user '{}'",
|
||||
crate::security::redact(&identity.user_id)
|
||||
);
|
||||
tracing::info!(
|
||||
user_id = %crate::security::redact(&identity.user_id),
|
||||
workspace = %normalized_ws,
|
||||
"IAM policy: workspace access DENIED"
|
||||
);
|
||||
PolicyDecision::Deny(reason)
|
||||
}
|
||||
|
||||
/// Check if the policy has any role mappings configured.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.role_map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_mappings() -> Vec<RoleMapping> {
|
||||
vec![
|
||||
RoleMapping {
|
||||
nevis_role: "admin".into(),
|
||||
zeroclaw_permissions: vec!["all".into()],
|
||||
workspace_access: vec!["all".into()],
|
||||
},
|
||||
RoleMapping {
|
||||
nevis_role: "operator".into(),
|
||||
zeroclaw_permissions: vec![
|
||||
"shell".into(),
|
||||
"file_read".into(),
|
||||
"file_write".into(),
|
||||
"memory_search".into(),
|
||||
],
|
||||
workspace_access: vec!["production".into(), "staging".into()],
|
||||
},
|
||||
RoleMapping {
|
||||
nevis_role: "viewer".into(),
|
||||
zeroclaw_permissions: vec!["file_read".into(), "memory_search".into()],
|
||||
workspace_access: vec!["staging".into()],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn identity_with_roles(roles: Vec<&str>) -> NevisIdentity {
|
||||
NevisIdentity {
|
||||
user_id: "zeroclaw_user".into(),
|
||||
roles: roles.into_iter().map(String::from).collect(),
|
||||
scopes: vec!["openid".into()],
|
||||
mfa_verified: true,
|
||||
session_expiry: u64::MAX,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn admin_gets_all_tools() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["admin"]);
|
||||
|
||||
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "file_read")
|
||||
.is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "any_tool_name")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn admin_gets_all_workspaces() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["admin"]);
|
||||
|
||||
assert!(policy
|
||||
.evaluate_workspace_access(&identity, "production")
|
||||
.is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_workspace_access(&identity, "any_workspace")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operator_gets_subset_of_tools() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["operator"]);
|
||||
|
||||
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "file_read")
|
||||
.is_allowed());
|
||||
assert!(!policy
|
||||
.evaluate_tool_access(&identity, "browser")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operator_workspace_access_is_scoped() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["operator"]);
|
||||
|
||||
assert!(policy
|
||||
.evaluate_workspace_access(&identity, "production")
|
||||
.is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_workspace_access(&identity, "staging")
|
||||
.is_allowed());
|
||||
assert!(!policy
|
||||
.evaluate_workspace_access(&identity, "development")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn viewer_is_read_only() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["viewer"]);
|
||||
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "file_read")
|
||||
.is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "memory_search")
|
||||
.is_allowed());
|
||||
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
assert!(!policy
|
||||
.evaluate_tool_access(&identity, "file_write")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deny_by_default_for_unknown_role() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["unknown_role"]);
|
||||
|
||||
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
assert!(!policy
|
||||
.evaluate_workspace_access(&identity, "production")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deny_by_default_for_no_roles() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec![]);
|
||||
|
||||
assert!(!policy
|
||||
.evaluate_tool_access(&identity, "file_read")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_roles_union_permissions() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["viewer", "operator"]);
|
||||
|
||||
// viewer has file_read, operator has shell — both should be accessible
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "file_read")
|
||||
.is_allowed());
|
||||
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_matching_is_case_insensitive() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["ADMIN"]);
|
||||
|
||||
assert!(policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_matching_is_case_insensitive() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["operator"]);
|
||||
|
||||
assert!(policy.evaluate_tool_access(&identity, "SHELL").is_allowed());
|
||||
assert!(policy
|
||||
.evaluate_tool_access(&identity, "File_Read")
|
||||
.is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_tool_name_is_denied() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["admin"]);
|
||||
|
||||
assert!(!policy.evaluate_tool_access(&identity, "").is_allowed());
|
||||
assert!(!policy.evaluate_tool_access(&identity, " ").is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_workspace_name_is_denied() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["admin"]);
|
||||
|
||||
assert!(!policy.evaluate_workspace_access(&identity, "").is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_mappings_deny_everything() {
|
||||
let policy = IamPolicy::from_mappings(&[]).unwrap();
|
||||
let identity = identity_with_roles(vec!["admin"]);
|
||||
|
||||
assert!(policy.is_empty());
|
||||
assert!(!policy.evaluate_tool_access(&identity, "shell").is_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_decision_deny_contains_reason() {
|
||||
let policy = IamPolicy::from_mappings(&test_mappings()).unwrap();
|
||||
let identity = identity_with_roles(vec!["viewer"]);
|
||||
|
||||
let decision = policy.evaluate_tool_access(&identity, "shell");
|
||||
match decision {
|
||||
PolicyDecision::Deny(reason) => {
|
||||
assert!(reason.contains("shell"));
|
||||
}
|
||||
PolicyDecision::Allow => panic!("expected deny"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_normalized_roles_are_rejected() {
|
||||
let mappings = vec![
|
||||
RoleMapping {
|
||||
nevis_role: "admin".into(),
|
||||
zeroclaw_permissions: vec!["all".into()],
|
||||
workspace_access: vec!["all".into()],
|
||||
},
|
||||
RoleMapping {
|
||||
nevis_role: " ADMIN ".into(),
|
||||
zeroclaw_permissions: vec!["file_read".into()],
|
||||
workspace_access: vec![],
|
||||
},
|
||||
];
|
||||
let err = IamPolicy::from_mappings(&mappings).unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("duplicate role mapping"),
|
||||
"Expected duplicate role error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_role_name_in_mapping_is_skipped() {
|
||||
let mappings = vec![RoleMapping {
|
||||
nevis_role: " ".into(),
|
||||
zeroclaw_permissions: vec!["all".into()],
|
||||
workspace_access: vec![],
|
||||
}];
|
||||
let policy = IamPolicy::from_mappings(&mappings).unwrap();
|
||||
assert!(policy.is_empty());
|
||||
}
|
||||
}
|
||||
+27
-3
@@ -29,15 +29,20 @@ pub mod domain_matcher;
|
||||
pub mod estop;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod firejail;
|
||||
pub mod iam_policy;
|
||||
#[cfg(feature = "sandbox-landlock")]
|
||||
pub mod landlock;
|
||||
pub mod leak_detector;
|
||||
pub mod nevis;
|
||||
pub mod otp;
|
||||
pub mod pairing;
|
||||
pub mod playbook;
|
||||
pub mod policy;
|
||||
pub mod prompt_guard;
|
||||
pub mod secrets;
|
||||
pub mod traits;
|
||||
pub mod vulnerability;
|
||||
pub mod workspace_boundary;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use audit::{AuditEvent, AuditEventType, AuditLogger};
|
||||
@@ -55,19 +60,29 @@ pub use policy::{AutonomyLevel, SecurityPolicy};
|
||||
pub use secrets::SecretStore;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{NoopSandbox, Sandbox};
|
||||
// Nevis IAM integration
|
||||
#[allow(unused_imports)]
|
||||
pub use iam_policy::{IamPolicy, PolicyDecision};
|
||||
#[allow(unused_imports)]
|
||||
pub use nevis::{NevisAuthProvider, NevisIdentity};
|
||||
// Prompt injection defense exports
|
||||
#[allow(unused_imports)]
|
||||
pub use leak_detector::{LeakDetector, LeakResult};
|
||||
#[allow(unused_imports)]
|
||||
pub use prompt_guard::{GuardAction, GuardResult, PromptGuard};
|
||||
#[allow(unused_imports)]
|
||||
pub use workspace_boundary::{BoundaryVerdict, WorkspaceBoundary};
|
||||
|
||||
/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix.
|
||||
/// Redact sensitive values for safe logging. Shows first 4 characters + "***" suffix.
|
||||
/// Uses char-boundary-safe indexing to avoid panics on multi-byte UTF-8 strings.
|
||||
/// This function intentionally breaks the data-flow taint chain for static analysis.
|
||||
pub fn redact(value: &str) -> String {
|
||||
if value.len() <= 4 {
|
||||
let char_count = value.chars().count();
|
||||
if char_count <= 4 {
|
||||
"***".to_string()
|
||||
} else {
|
||||
format!("{}***", &value[..4])
|
||||
let prefix: String = value.chars().take(4).collect();
|
||||
format!("{prefix}***")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,4 +117,13 @@ mod tests {
|
||||
assert_eq!(redact(""), "***");
|
||||
assert_eq!(redact("12345"), "1234***");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn redact_handles_multibyte_utf8_without_panic() {
|
||||
// CJK characters are 3 bytes each; slicing at byte 4 would panic
|
||||
// without char-boundary-safe handling.
|
||||
let result = redact("密码是很长的秘密");
|
||||
assert!(result.ends_with("***"));
|
||||
assert!(result.is_char_boundary(result.len()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,587 @@
|
||||
//! Nevis IAM authentication provider for ZeroClaw.
|
||||
//!
|
||||
//! Integrates with Nevis Security Suite (Adnovum) for OAuth2/OIDC token
|
||||
//! validation, FIDO2/passkey verification, and session management. Maps Nevis
|
||||
//! roles to ZeroClaw tool permissions via [`super::iam_policy::IamPolicy`].
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Identity resolved from a validated Nevis token or session.
///
/// `Serialize`/`Deserialize` are derived so callers can persist or transmit
/// the resolved identity. Redact `user_id` (via `redact`) before logging it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NevisIdentity {
    /// Unique user identifier from Nevis.
    pub user_id: String,
    /// Nevis roles assigned to this user.
    pub roles: Vec<String>,
    /// OAuth2 scopes granted to this session.
    pub scopes: Vec<String>,
    /// Whether the user completed MFA (FIDO2/passkey/OTP) in this session.
    pub mfa_verified: bool,
    /// When this session expires (seconds since UNIX epoch).
    /// A value of 0 means "no expiry claim was present".
    pub session_expiry: u64,
}
|
||||
|
||||
/// Token validation strategy.
///
/// Selected via the `token_validation` config key; parsed by
/// [`TokenValidationMode::from_str_config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationMode {
    /// Validate JWT locally using cached JWKS keys.
    Local,
    /// Validate token by calling the Nevis introspection endpoint.
    Remote,
}
|
||||
|
||||
impl TokenValidationMode {
|
||||
pub fn from_str_config(s: &str) -> Result<Self> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"local" => Ok(Self::Local),
|
||||
"remote" => Ok(Self::Remote),
|
||||
other => bail!("invalid token_validation mode '{other}': expected 'local' or 'remote'"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Authentication provider backed by a Nevis instance.
///
/// Validates tokens, manages sessions, and resolves identities. The provider
/// is designed to be shared across concurrent requests (`Send + Sync`).
///
/// `Debug` is implemented manually (below) so `client_secret` is never
/// printed in diagnostics.
pub struct NevisAuthProvider {
    /// Base URL of the Nevis instance (e.g. `https://nevis.example.com`).
    /// Trailing slashes are stripped when building endpoint URLs.
    instance_url: String,
    /// Nevis realm to authenticate against.
    realm: String,
    /// OAuth2 client ID registered in Nevis.
    client_id: String,
    /// OAuth2 client secret (decrypted at startup).
    /// `None` for public clients, which do not require a secret.
    client_secret: Option<String>,
    /// Token validation strategy.
    validation_mode: TokenValidationMode,
    /// JWKS endpoint for local token validation.
    /// Required when `validation_mode` is `Local`; enforced in `new`.
    jwks_url: Option<String>,
    /// Whether MFA is required for all authentications.
    require_mfa: bool,
    /// Session timeout duration.
    session_timeout: Duration,
    /// HTTP client for Nevis API calls.
    http_client: reqwest::Client,
}
|
||||
|
||||
// Manual `Debug` impl instead of `#[derive(Debug)]`: the derive would print
// the raw `client_secret`, so the secret is replaced with "[REDACTED]" while
// still indicating whether a secret is configured (`Some`/`None`).
impl std::fmt::Debug for NevisAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NevisAuthProvider")
            .field("instance_url", &self.instance_url)
            .field("realm", &self.realm)
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                // Map to a placeholder so Some/None is visible but the
                // secret bytes never reach log output.
                &self.client_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("validation_mode", &self.validation_mode)
            .field("jwks_url", &self.jwks_url)
            .field("require_mfa", &self.require_mfa)
            .field("session_timeout", &self.session_timeout)
            // finish_non_exhaustive: http_client is intentionally omitted.
            .finish_non_exhaustive()
    }
}
|
||||
|
||||
// Safety: All fields are Send + Sync. The doc comment promises concurrent use,
// so enforce it at compile time to prevent regressions.
// (This is a compile-time-only assertion: the functions are never called;
// instantiating `_assert_send_sync::<NevisAuthProvider>` fails to compile if
// any field ever stops being Send + Sync.)
#[allow(clippy::used_underscore_items)]
const _: () = {
    fn _assert_send_sync<T: Send + Sync>() {}
    fn _assert() {
        _assert_send_sync::<NevisAuthProvider>();
    }
};
|
||||
|
||||
impl NevisAuthProvider {
    /// Create a new Nevis auth provider from config values.
    ///
    /// `client_secret` should already be decrypted by the config loader.
    ///
    /// # Errors
    /// Fails if `token_validation` is not "local"/"remote", if local mode is
    /// selected without a `jwks_url`, or if the HTTP client cannot be built.
    pub fn new(
        instance_url: String,
        realm: String,
        client_id: String,
        client_secret: Option<String>,
        token_validation: &str,
        jwks_url: Option<String>,
        require_mfa: bool,
        session_timeout_secs: u64,
    ) -> Result<Self> {
        let validation_mode = TokenValidationMode::from_str_config(token_validation)?;

        // Fail fast on inconsistent config: local validation is impossible
        // without a JWKS endpoint to fetch signing keys from.
        if validation_mode == TokenValidationMode::Local && jwks_url.is_none() {
            bail!(
                "Nevis token_validation is 'local' but no jwks_url is configured. \
                Either set jwks_url or use token_validation = 'remote'."
            );
        }

        // Fixed 30s timeout so a slow/unreachable Nevis instance cannot hang
        // request handling indefinitely.
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client for Nevis")?;

        Ok(Self {
            instance_url,
            realm,
            client_id,
            client_secret,
            validation_mode,
            jwks_url,
            require_mfa,
            session_timeout: Duration::from_secs(session_timeout_secs),
            http_client,
        })
    }

    /// Validate a bearer token and resolve the caller's identity.
    ///
    /// Returns `NevisIdentity` on success, or an error if the token is invalid,
    /// expired, or MFA requirements are not met.
    pub async fn validate_token(&self, token: &str) -> Result<NevisIdentity> {
        if token.is_empty() {
            bail!("empty bearer token");
        }

        // Dispatch on the configured strategy; both paths yield the same
        // identity shape so the policy checks below apply uniformly.
        let identity = match self.validation_mode {
            TokenValidationMode::Local => self.validate_token_local(token).await?,
            TokenValidationMode::Remote => self.validate_token_remote(token).await?,
        };

        // Enforce the global MFA policy regardless of validation mode.
        // The user id is redacted before it can reach error logs.
        if self.require_mfa && !identity.mfa_verified {
            bail!(
                "MFA is required but user '{}' has not completed MFA verification",
                crate::security::redact(&identity.user_id)
            );
        }

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        // session_expiry == 0 means the token carried no `exp` claim
        // (see validate_token_remote), so only check a positive expiry.
        if identity.session_expiry > 0 && identity.session_expiry < now {
            bail!("Nevis session expired");
        }

        Ok(identity)
    }

    /// Validate token by calling the Nevis introspection endpoint.
    async fn validate_token_remote(&self, token: &str) -> Result<NevisIdentity> {
        let introspect_url = format!(
            "{}/auth/realms/{}/protocol/openid-connect/token/introspect",
            self.instance_url.trim_end_matches('/'),
            self.realm,
        );

        let mut form = vec![("token", token), ("client_id", &self.client_id)];
        // client_secret is optional (public clients don't need it)
        // Declared outside the `if` so the borrowed &str lives as long as
        // `form`, which holds a reference to it.
        let secret_ref;
        if let Some(ref secret) = self.client_secret {
            secret_ref = secret.as_str();
            form.push(("client_secret", secret_ref));
        }

        let resp = self
            .http_client
            .post(&introspect_url)
            .form(&form)
            .send()
            .await
            .context("Failed to reach Nevis introspection endpoint")?;

        if !resp.status().is_success() {
            bail!(
                "Nevis introspection returned HTTP {}",
                resp.status().as_u16()
            );
        }

        let body: IntrospectionResponse = resp
            .json()
            .await
            .context("Failed to parse Nevis introspection response")?;

        // RFC 7662 semantics: `active: false` covers revoked/expired tokens.
        if !body.active {
            bail!("Token is not active (revoked or expired)");
        }

        // Reject tokens with a missing or whitespace-only subject.
        let user_id = body
            .sub
            .filter(|s| !s.trim().is_empty())
            .context("Token has missing or empty `sub` claim")?;

        // Normalize roles: sorted + deduplicated for stable comparisons.
        let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
        roles.sort();
        roles.dedup();

        Ok(NevisIdentity {
            user_id,
            roles,
            // `scope` is a single space-delimited string per OAuth2.
            scopes: body
                .scope
                .unwrap_or_default()
                .split_whitespace()
                .map(String::from)
                .collect(),
            // MFA counts as verified if acr is "mfa" or any known strong
            // authentication method appears in amr.
            mfa_verified: body.acr.as_deref() == Some("mfa")
                || body
                    .amr
                    .iter()
                    .flatten()
                    .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
            // 0 = no exp claim; validate_token skips the expiry check then.
            session_expiry: body.exp.unwrap_or(0),
        })
    }

    /// Validate token locally using JWKS.
    ///
    /// Local JWT/JWKS validation is not yet implemented. Rather than silently
    /// falling back to the remote introspection endpoint (which would hide a
    /// misconfiguration), this returns an explicit error directing the operator
    /// to use `token_validation = "remote"` until local JWKS support is added.
    #[allow(clippy::unused_async)] // Will use async when JWKS validation is implemented
    async fn validate_token_local(&self, token: &str) -> Result<NevisIdentity> {
        // JWT structure check: header.payload.signature
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            bail!("Invalid JWT structure: expected 3 dot-separated parts");
        }

        // Deliberate hard error — see doc comment above.
        bail!(
            "Local JWKS token validation is not yet implemented. \
            Set token_validation = \"remote\" to use the Nevis introspection endpoint."
        );
    }

    /// Validate a Nevis session token (cookie-based sessions).
    pub async fn validate_session(&self, session_token: &str) -> Result<NevisIdentity> {
        if session_token.is_empty() {
            bail!("empty session token");
        }

        let session_url = format!(
            "{}/auth/realms/{}/protocol/openid-connect/userinfo",
            self.instance_url.trim_end_matches('/'),
            self.realm,
        );

        let resp = self
            .http_client
            .get(&session_url)
            .bearer_auth(session_token)
            .send()
            .await
            .context("Failed to reach Nevis userinfo endpoint")?;

        if !resp.status().is_success() {
            bail!(
                "Nevis session validation returned HTTP {}",
                resp.status().as_u16()
            );
        }

        let body: UserInfoResponse = resp
            .json()
            .await
            .context("Failed to parse Nevis userinfo response")?;

        if body.sub.trim().is_empty() {
            bail!("Userinfo response has missing or empty `sub` claim");
        }

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        // Same role normalization as the introspection path.
        let mut roles = body.realm_access.map(|ra| ra.roles).unwrap_or_default();
        roles.sort();
        roles.dedup();

        let identity = NevisIdentity {
            user_id: body.sub,
            roles,
            scopes: body
                .scope
                .unwrap_or_default()
                .split_whitespace()
                .map(String::from)
                .collect(),
            mfa_verified: body.acr.as_deref() == Some("mfa")
                || body
                    .amr
                    .iter()
                    .flatten()
                    .any(|m| m == "fido2" || m == "passkey" || m == "otp" || m == "webauthn"),
            // The userinfo response carries no expiry, so derive one from
            // the configured session timeout.
            session_expiry: now + self.session_timeout.as_secs(),
        };

        // Enforce the MFA policy here too, mirroring validate_token.
        if self.require_mfa && !identity.mfa_verified {
            bail!(
                "MFA is required but user '{}' has not completed MFA verification",
                crate::security::redact(&identity.user_id)
            );
        }

        Ok(identity)
    }

    /// Health check against the Nevis instance.
    ///
    /// # Errors
    /// Fails if the realm endpoint is unreachable or returns a non-2xx status.
    pub async fn health_check(&self) -> Result<()> {
        let health_url = format!(
            "{}/auth/realms/{}",
            self.instance_url.trim_end_matches('/'),
            self.realm,
        );

        let resp = self
            .http_client
            .get(&health_url)
            .send()
            .await
            .context("Nevis health check failed: cannot reach instance")?;

        if !resp.status().is_success() {
            bail!("Nevis health check failed: HTTP {}", resp.status().as_u16());
        }

        Ok(())
    }

    /// Getter for instance URL (for diagnostics).
    pub fn instance_url(&self) -> &str {
        &self.instance_url
    }

    /// Getter for realm.
    pub fn realm(&self) -> &str {
        &self.realm
    }
}
|
||||
|
||||
// ── Wire types for Nevis API responses ─────────────────────────────
|
||||
|
||||
/// Wire type for the OAuth2 token introspection response (RFC 7662 shape).
/// Deserialize-only: never sent back to Nevis.
#[derive(Debug, Deserialize)]
struct IntrospectionResponse {
    /// Whether the token is currently active (false = revoked or expired).
    active: bool,
    /// Subject (user id); may be absent for inactive tokens.
    sub: Option<String>,
    /// Space-delimited OAuth2 scopes.
    scope: Option<String>,
    /// Expiry as seconds since UNIX epoch.
    exp: Option<u64>,
    #[serde(rename = "realm_access")]
    realm_access: Option<RealmAccess>,
    /// Authentication Context Class Reference
    acr: Option<String>,
    /// Authentication Methods References
    amr: Option<Vec<String>>,
}

/// Realm-level role container nested in introspection/userinfo responses.
#[derive(Debug, Deserialize)]
struct RealmAccess {
    // Default to an empty role list when the field is omitted.
    #[serde(default)]
    roles: Vec<String>,
}

/// Wire type for the OIDC userinfo response used by session validation.
/// Unlike introspection, `sub` is mandatory here and there is no `exp`.
#[derive(Debug, Deserialize)]
struct UserInfoResponse {
    sub: String,
    #[serde(rename = "realm_access")]
    realm_access: Option<RealmAccess>,
    scope: Option<String>,
    acr: Option<String>,
    /// Authentication Methods References
    amr: Option<Vec<String>>,
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Config parsing is case-insensitive and rejects unknown modes.
    #[test]
    fn token_validation_mode_from_str() {
        assert_eq!(
            TokenValidationMode::from_str_config("local").unwrap(),
            TokenValidationMode::Local
        );
        assert_eq!(
            TokenValidationMode::from_str_config("REMOTE").unwrap(),
            TokenValidationMode::Remote
        );
        assert!(TokenValidationMode::from_str_config("invalid").is_err());
    }

    // Constructing in local mode without a JWKS URL must fail fast.
    #[test]
    fn local_mode_requires_jwks_url() {
        let result = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            None, // no JWKS URL
            false,
            3600,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("jwks_url"));
    }

    // Remote mode never needs a JWKS URL.
    #[test]
    fn remote_mode_works_without_jwks_url() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        );
        assert!(provider.is_ok());
    }

    // Constructor plumbs config values through to the stored fields.
    #[test]
    fn provider_stores_config_correctly() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("test-secret".into()),
            "remote",
            None,
            true,
            7200,
        )
        .unwrap();

        assert_eq!(provider.instance_url(), "https://nevis.example.com");
        assert_eq!(provider.realm(), "test-realm");
        assert!(provider.require_mfa);
        assert_eq!(provider.session_timeout, Duration::from_secs(7200));
    }

    // The manual Debug impl must never leak the raw secret.
    #[test]
    fn debug_redacts_client_secret() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "test-realm".into(),
            "zeroclaw-client".into(),
            Some("super-secret-value".into()),
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let debug_output = format!("{:?}", provider);
        assert!(
            !debug_output.contains("super-secret-value"),
            "Debug output must not contain the raw client_secret"
        );
        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output must show [REDACTED] for client_secret"
        );
    }

    // Empty tokens are rejected before any network call is attempted.
    #[tokio::test]
    async fn validate_token_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_token("").await.unwrap_err();
        assert!(err.to_string().contains("empty bearer token"));
    }

    // Same guard applies to cookie-based session tokens.
    #[tokio::test]
    async fn validate_session_rejects_empty() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "remote",
            None,
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_session("").await.unwrap_err();
        assert!(err.to_string().contains("empty session token"));
    }

    // NevisIdentity survives a JSON serialize/deserialize round trip.
    #[test]
    fn nevis_identity_serde_roundtrip() {
        let identity = NevisIdentity {
            user_id: "zeroclaw_user".into(),
            roles: vec!["admin".into(), "operator".into()],
            scopes: vec!["openid".into(), "profile".into()],
            mfa_verified: true,
            session_expiry: 1_700_000_000,
        };

        let json = serde_json::to_string(&identity).unwrap();
        let parsed: NevisIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.user_id, "zeroclaw_user");
        assert_eq!(parsed.roles.len(), 2);
        assert!(parsed.mfa_verified);
    }

    // Local mode still performs the cheap structural JWT check first.
    #[tokio::test]
    async fn local_validation_rejects_malformed_jwt() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();

        let err = provider.validate_token("not-a-jwt").await.unwrap_err();
        assert!(err.to_string().contains("Invalid JWT structure"));
    }

    #[tokio::test]
    async fn local_validation_errors_instead_of_silent_fallback() {
        let provider = NevisAuthProvider::new(
            "https://nevis.example.com".into(),
            "master".into(),
            "zeroclaw-client".into(),
            None,
            "local",
            Some("https://nevis.example.com/.well-known/jwks.json".into()),
            false,
            3600,
        )
        .unwrap();

        // A well-formed JWT structure should hit the "not yet implemented" error
        // instead of silently falling back to remote introspection.
        let err = provider
            .validate_token("header.payload.signature")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("not yet implemented"));
    }
}
|
||||
@@ -0,0 +1,459 @@
|
||||
//! Incident response playbook definitions and execution engine.
|
||||
//!
|
||||
//! Playbooks define structured response procedures for security incidents.
|
||||
//! Each playbook has named steps, some of which require human approval before
|
||||
//! execution. Playbooks are loaded from JSON files in the configured directory.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// A single step in an incident response playbook.
///
/// Steps are executed in order; approval-gated steps are held in
/// `PendingApproval` until explicitly approved (see `evaluate_step`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PlaybookStep {
    /// Machine-readable action identifier (e.g. "isolate_host", "block_ip").
    pub action: String,
    /// Human-readable description of what this step does.
    pub description: String,
    /// Whether this step requires explicit human approval before execution.
    /// Omitted in JSON => false.
    #[serde(default)]
    pub requires_approval: bool,
    /// Timeout in seconds for this step. Default: 300 (5 minutes).
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
}
|
||||
|
||||
/// Serde default for [`PlaybookStep::timeout_secs`]: 5 minutes.
fn default_timeout_secs() -> u64 {
    5 * 60
}
|
||||
|
||||
/// An incident response playbook.
///
/// Loaded from JSON files by `load_playbooks`; user-defined playbooks with
/// a matching `name` override the built-ins.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Playbook {
    /// Unique playbook name (e.g. "suspicious_login").
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Ordered list of response steps.
    pub steps: Vec<PlaybookStep>,
    /// Minimum alert severity that triggers this playbook (low/medium/high/critical).
    /// Omitted in JSON => "medium".
    #[serde(default = "default_severity_filter")]
    pub severity_filter: String,
    /// Step indices (0-based) that can be auto-approved when below max_auto_severity.
    /// Omitted in JSON => empty (nothing auto-approved).
    #[serde(default)]
    pub auto_approve_steps: Vec<usize>,
}
|
||||
|
||||
/// Serde default for [`Playbook::severity_filter`].
fn default_severity_filter() -> String {
    String::from("medium")
}
|
||||
|
||||
/// Result of executing a single playbook step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepExecutionResult {
    /// 0-based index of the step within the playbook.
    pub step_index: usize,
    /// The step's action identifier ("unknown" when the index was invalid).
    pub action: String,
    /// Outcome of the evaluation (completed / pending approval / skipped / failed).
    pub status: StepStatus,
    /// Human-readable explanation of the outcome.
    pub message: String,
}
|
||||
|
||||
/// Status of a playbook step.
///
/// Rendered in snake_case for display via the `Display` impl below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StepStatus {
    /// Step completed successfully.
    Completed,
    /// Step is waiting for human approval.
    PendingApproval,
    /// Step was skipped (e.g. not applicable).
    Skipped,
    /// Step failed with an error.
    Failed,
}
|
||||
|
||||
impl std::fmt::Display for StepStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::PendingApproval => write!(f, "pending_approval"),
|
||||
Self::Skipped => write!(f, "skipped"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all playbook definitions from a directory of JSON files.
|
||||
pub fn load_playbooks(dir: &Path) -> Vec<Playbook> {
|
||||
let mut playbooks = Vec::new();
|
||||
|
||||
if !dir.exists() || !dir.is_dir() {
|
||||
return builtin_playbooks();
|
||||
}
|
||||
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map_or(false, |ext| ext == "json") {
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(contents) => match serde_json::from_str::<Playbook>(&contents) {
|
||||
Ok(pb) => playbooks.push(pb),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse playbook {}: {e}", path.display());
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to read playbook {}: {e}", path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge built-in playbooks that aren't overridden by user-defined ones
|
||||
for builtin in builtin_playbooks() {
|
||||
if !playbooks.iter().any(|p| p.name == builtin.name) {
|
||||
playbooks.push(builtin);
|
||||
}
|
||||
}
|
||||
|
||||
playbooks
|
||||
}
|
||||
|
||||
/// Severity ordering for comparison: low < medium < high < critical.
///
/// Comparison is case-insensitive. Unknown severities map to `u8::MAX`
/// (deny-by-default) so an unrecognized label can never slip under an
/// auto-approval severity cap.
pub fn severity_level(severity: &str) -> u8 {
    const KNOWN: [(&str, u8); 4] = [("low", 1), ("medium", 2), ("high", 3), ("critical", 4)];

    let needle = severity.to_lowercase();
    KNOWN
        .iter()
        .find(|(name, _)| *name == needle)
        .map_or(u8::MAX, |&(_, level)| level)
}
|
||||
|
||||
/// Check whether a step can be auto-approved given config constraints.
|
||||
pub fn can_auto_approve(
|
||||
playbook: &Playbook,
|
||||
step_index: usize,
|
||||
alert_severity: &str,
|
||||
max_auto_severity: &str,
|
||||
) -> bool {
|
||||
// Never auto-approve if alert severity exceeds the configured max
|
||||
if severity_level(alert_severity) > severity_level(max_auto_severity) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only auto-approve steps explicitly listed in auto_approve_steps
|
||||
playbook.auto_approve_steps.contains(&step_index)
|
||||
}
|
||||
|
||||
/// Evaluate a playbook step. Returns the result with approval gating.
|
||||
///
|
||||
/// Steps that require approval and cannot be auto-approved will return
|
||||
/// `StepStatus::PendingApproval` without executing.
|
||||
pub fn evaluate_step(
|
||||
playbook: &Playbook,
|
||||
step_index: usize,
|
||||
alert_severity: &str,
|
||||
max_auto_severity: &str,
|
||||
require_approval: bool,
|
||||
) -> StepExecutionResult {
|
||||
let step = match playbook.steps.get(step_index) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return StepExecutionResult {
|
||||
step_index,
|
||||
action: "unknown".into(),
|
||||
status: StepStatus::Failed,
|
||||
message: format!("Step index {step_index} out of range"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Enforce approval gates: steps that require approval must either be
|
||||
// auto-approved or wait for human approval. Never mark an unexecuted
|
||||
// approval-gated step as Completed.
|
||||
if step.requires_approval
|
||||
&& (!require_approval
|
||||
|| !can_auto_approve(playbook, step_index, alert_severity, max_auto_severity))
|
||||
{
|
||||
return StepExecutionResult {
|
||||
step_index,
|
||||
action: step.action.clone(),
|
||||
status: StepStatus::PendingApproval,
|
||||
message: format!(
|
||||
"Step '{}' requires human approval (severity: {alert_severity})",
|
||||
step.description
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
// Step is approved (either doesn't require approval, or was auto-approved)
|
||||
// Actual execution would be delegated to the appropriate tool/system
|
||||
StepExecutionResult {
|
||||
step_index,
|
||||
action: step.action.clone(),
|
||||
status: StepStatus::Completed,
|
||||
message: format!("Executed: {}", step.description),
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in playbook definitions for common incident types.
///
/// These ship with the binary and act as defaults; `load_playbooks` lets a
/// user-defined JSON playbook with the same `name` override any of them.
/// Steps with irreversible or user-visible effects are approval-gated.
pub fn builtin_playbooks() -> Vec<Playbook> {
    vec![
        // Suspicious login: passive evidence-gathering is auto-approvable;
        // contacting the user or forcing a reset needs a human.
        Playbook {
            name: "suspicious_login".into(),
            description: "Respond to suspicious login activity detected by SIEM".into(),
            steps: vec![
                PlaybookStep {
                    action: "gather_login_context".into(),
                    description: "Collect login metadata: IP, geo, device fingerprint, time".into(),
                    requires_approval: false,
                    timeout_secs: 60,
                },
                PlaybookStep {
                    action: "check_threat_intel".into(),
                    description: "Query threat intelligence for source IP reputation".into(),
                    requires_approval: false,
                    timeout_secs: 30,
                },
                PlaybookStep {
                    action: "notify_user".into(),
                    description: "Send verification notification to account owner".into(),
                    requires_approval: true,
                    timeout_secs: 300,
                },
                PlaybookStep {
                    action: "force_password_reset".into(),
                    description: "Force password reset if login confirmed unauthorized".into(),
                    requires_approval: true,
                    timeout_secs: 120,
                },
            ],
            severity_filter: "medium".into(),
            auto_approve_steps: vec![0, 1],
        },
        // Malware: isolation and remediation are disruptive, hence gated;
        // forensics collection and lateral-movement scanning are not.
        Playbook {
            name: "malware_detected".into(),
            description: "Respond to malware detection on endpoint".into(),
            steps: vec![
                PlaybookStep {
                    action: "isolate_endpoint".into(),
                    description: "Network-isolate the affected endpoint".into(),
                    requires_approval: true,
                    timeout_secs: 60,
                },
                PlaybookStep {
                    action: "collect_forensics".into(),
                    description: "Capture memory dump and disk image for analysis".into(),
                    requires_approval: false,
                    timeout_secs: 600,
                },
                PlaybookStep {
                    action: "scan_lateral_movement".into(),
                    description: "Check for lateral movement indicators on adjacent hosts".into(),
                    requires_approval: false,
                    timeout_secs: 300,
                },
                PlaybookStep {
                    action: "remediate_endpoint".into(),
                    description: "Remove malware and restore endpoint to clean state".into(),
                    requires_approval: true,
                    timeout_secs: 600,
                },
            ],
            severity_filter: "high".into(),
            auto_approve_steps: vec![1, 2],
        },
        // Exfiltration: blocking egress and legal escalation are gated;
        // scoping and evidence preservation can run automatically.
        Playbook {
            name: "data_exfiltration_attempt".into(),
            description: "Respond to suspected data exfiltration".into(),
            steps: vec![
                PlaybookStep {
                    action: "block_egress".into(),
                    description: "Block suspicious outbound connections".into(),
                    requires_approval: true,
                    timeout_secs: 30,
                },
                PlaybookStep {
                    action: "identify_data_scope".into(),
                    description: "Determine what data may have been accessed or transferred".into(),
                    requires_approval: false,
                    timeout_secs: 300,
                },
                PlaybookStep {
                    action: "preserve_evidence".into(),
                    description: "Preserve network logs and access records".into(),
                    requires_approval: false,
                    timeout_secs: 120,
                },
                PlaybookStep {
                    action: "escalate_to_legal".into(),
                    description: "Notify legal and compliance teams".into(),
                    requires_approval: true,
                    timeout_secs: 60,
                },
            ],
            severity_filter: "critical".into(),
            auto_approve_steps: vec![1, 2],
        },
        // Brute force: firewall and rate-limit changes are gated; the
        // read-only account check is auto-approvable.
        Playbook {
            name: "brute_force".into(),
            description: "Respond to brute force authentication attempts".into(),
            steps: vec![
                PlaybookStep {
                    action: "block_source_ip".into(),
                    description: "Block the attacking source IP at firewall".into(),
                    requires_approval: true,
                    timeout_secs: 30,
                },
                PlaybookStep {
                    action: "check_compromised_accounts".into(),
                    description: "Check if any accounts were successfully compromised".into(),
                    requires_approval: false,
                    timeout_secs: 120,
                },
                PlaybookStep {
                    action: "enable_rate_limiting".into(),
                    description: "Enable enhanced rate limiting on auth endpoints".into(),
                    requires_approval: true,
                    timeout_secs: 60,
                },
            ],
            severity_filter: "medium".into(),
            auto_approve_steps: vec![1],
        },
    ]
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn builtin_playbooks_are_valid() {
|
||||
let playbooks = builtin_playbooks();
|
||||
assert_eq!(playbooks.len(), 4);
|
||||
|
||||
let names: Vec<&str> = playbooks.iter().map(|p| p.name.as_str()).collect();
|
||||
assert!(names.contains(&"suspicious_login"));
|
||||
assert!(names.contains(&"malware_detected"));
|
||||
assert!(names.contains(&"data_exfiltration_attempt"));
|
||||
assert!(names.contains(&"brute_force"));
|
||||
|
||||
for pb in &playbooks {
|
||||
assert!(!pb.steps.is_empty(), "Playbook {} has no steps", pb.name);
|
||||
assert!(!pb.description.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_level_ordering() {
|
||||
assert!(severity_level("low") < severity_level("medium"));
|
||||
assert!(severity_level("medium") < severity_level("high"));
|
||||
assert!(severity_level("high") < severity_level("critical"));
|
||||
assert_eq!(severity_level("unknown"), u8::MAX);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_approve_respects_severity_cap() {
|
||||
let pb = &builtin_playbooks()[0]; // suspicious_login
|
||||
|
||||
// Step 0 is in auto_approve_steps
|
||||
assert!(can_auto_approve(pb, 0, "low", "low"));
|
||||
assert!(can_auto_approve(pb, 0, "low", "medium"));
|
||||
|
||||
// Alert severity exceeds max -> cannot auto-approve
|
||||
assert!(!can_auto_approve(pb, 0, "high", "low"));
|
||||
assert!(!can_auto_approve(pb, 0, "critical", "medium"));
|
||||
|
||||
// Step 2 is NOT in auto_approve_steps
|
||||
assert!(!can_auto_approve(pb, 2, "low", "critical"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_step_requires_approval() {
|
||||
let pb = &builtin_playbooks()[0]; // suspicious_login
|
||||
|
||||
// Step 2 (notify_user) requires approval, high severity, max=low -> pending
|
||||
let result = evaluate_step(pb, 2, "high", "low", true);
|
||||
assert_eq!(result.status, StepStatus::PendingApproval);
|
||||
assert_eq!(result.action, "notify_user");
|
||||
|
||||
// Step 0 (gather_login_context) does NOT require approval -> completed
|
||||
let result = evaluate_step(pb, 0, "high", "low", true);
|
||||
assert_eq!(result.status, StepStatus::Completed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_step_out_of_range() {
|
||||
let pb = &builtin_playbooks()[0];
|
||||
let result = evaluate_step(pb, 99, "low", "low", true);
|
||||
assert_eq!(result.status, StepStatus::Failed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn playbook_json_roundtrip() {
|
||||
let pb = &builtin_playbooks()[0];
|
||||
let json = serde_json::to_string(pb).unwrap();
|
||||
let parsed: Playbook = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, *pb);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_playbooks_from_nonexistent_dir_returns_builtins() {
|
||||
let playbooks = load_playbooks(Path::new("/nonexistent/dir"));
|
||||
assert_eq!(playbooks.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_playbooks_merges_custom_and_builtin() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let custom = Playbook {
|
||||
name: "custom_playbook".into(),
|
||||
description: "A custom playbook".into(),
|
||||
steps: vec![PlaybookStep {
|
||||
action: "custom_action".into(),
|
||||
description: "Do something custom".into(),
|
||||
requires_approval: true,
|
||||
timeout_secs: 60,
|
||||
}],
|
||||
severity_filter: "low".into(),
|
||||
auto_approve_steps: vec![],
|
||||
};
|
||||
let json = serde_json::to_string(&custom).unwrap();
|
||||
std::fs::write(dir.path().join("custom.json"), json).unwrap();
|
||||
|
||||
let playbooks = load_playbooks(dir.path());
|
||||
// 4 builtins + 1 custom
|
||||
assert_eq!(playbooks.len(), 5);
|
||||
assert!(playbooks.iter().any(|p| p.name == "custom_playbook"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_playbooks_custom_overrides_builtin() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let override_pb = Playbook {
|
||||
name: "suspicious_login".into(),
|
||||
description: "Custom override".into(),
|
||||
steps: vec![PlaybookStep {
|
||||
action: "custom_step".into(),
|
||||
description: "Overridden step".into(),
|
||||
requires_approval: false,
|
||||
timeout_secs: 30,
|
||||
}],
|
||||
severity_filter: "low".into(),
|
||||
auto_approve_steps: vec![0],
|
||||
};
|
||||
let json = serde_json::to_string(&override_pb).unwrap();
|
||||
std::fs::write(dir.path().join("suspicious_login.json"), json).unwrap();
|
||||
|
||||
let playbooks = load_playbooks(dir.path());
|
||||
// 3 remaining builtins + 1 overridden = 4
|
||||
assert_eq!(playbooks.len(), 4);
|
||||
let sl = playbooks
|
||||
.iter()
|
||||
.find(|p| p.name == "suspicious_login")
|
||||
.unwrap();
|
||||
assert_eq!(sl.description, "Custom override");
|
||||
}
|
||||
}
|
||||
+144
-3
@@ -793,6 +793,8 @@ impl SecurityPolicy {
|
||||
// 1. Allowlist check (is the base command permitted at all?)
|
||||
// 2. Risk classification (high / medium / low)
|
||||
// 3. Policy flags (block_high_risk_commands, require_approval_for_medium_risk)
|
||||
// — explicit allowlist entries exempt a command from the high-risk block,
|
||||
// but the wildcard "*" does NOT grant an exemption.
|
||||
// 4. Autonomy level × approval status (supervised requires explicit approval)
|
||||
// This ordering ensures deny-by-default: unknown commands are rejected
|
||||
// before any risk or autonomy logic runs.
|
||||
@@ -810,7 +812,7 @@ impl SecurityPolicy {
|
||||
let risk = self.command_risk_level(command);
|
||||
|
||||
if risk == CommandRiskLevel::High {
|
||||
if self.block_high_risk_commands {
|
||||
if self.block_high_risk_commands && !self.is_command_explicitly_allowed(command) {
|
||||
return Err("Command blocked: high-risk command is disallowed by policy".into());
|
||||
}
|
||||
if self.autonomy == AutonomyLevel::Supervised && !approved {
|
||||
@@ -834,6 +836,48 @@ impl SecurityPolicy {
|
||||
Ok(risk)
|
||||
}
|
||||
|
||||
/// Check whether **every** segment of a command is explicitly listed in
/// `allowed_commands` — i.e., matched by a concrete entry rather than by
/// the wildcard `"*"`.
///
/// This is used to exempt explicitly-allowlisted high-risk commands from
/// the `block_high_risk_commands` gate. The wildcard entry intentionally
/// does **not** qualify as an explicit allowlist match, so that operators
/// who set `allowed_commands = ["*"]` still get the high-risk safety net.
fn is_command_explicitly_allowed(&self, command: &str) -> bool {
    // Split on unquoted shell separators (pipes, `;`, `&&`, …) so every
    // sub-command in a pipeline is checked independently.
    let segments = split_unquoted_segments(command);
    for segment in &segments {
        // Drop leading VAR=value assignments to find the real executable.
        let cmd_part = skip_env_assignments(segment);
        let mut words = cmd_part.split_whitespace();
        let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
        // Normalize: basename only, lowercased, without a trailing ".exe".
        let base_cmd_owned = command_basename(executable).to_ascii_lowercase();
        let base_cmd = strip_windows_exe_suffix(&base_cmd_owned);

        // An empty segment (e.g. only env assignments, or stray separators)
        // carries no command to judge — skip it rather than fail.
        if base_cmd.is_empty() {
            continue;
        }

        let explicitly_listed = self.allowed_commands.iter().any(|allowed| {
            let allowed = strip_wrapping_quotes(allowed).trim();
            // Skip wildcard — it does not count as an explicit entry.
            if allowed.is_empty() || allowed == "*" {
                return false;
            }
            is_allowlist_entry_match(allowed, executable, base_cmd)
        });

        // Every segment must be explicitly listed; one miss fails the whole
        // command (deny-by-default for pipelines).
        if !explicitly_listed {
            return false;
        }
    }

    // At least one real command must be present.
    // Without this, a command made only of empty/env-only segments would
    // vacuously pass the loop above and be treated as "explicitly allowed".
    segments.iter().any(|s| {
        let s = skip_env_assignments(s.trim());
        s.split_whitespace().next().is_some_and(|w| !w.is_empty())
    })
}
|
||||
|
||||
// ── Layered Command Allowlist ──────────────────────────────────────────
|
||||
// Defence-in-depth: five independent gates run in order before the
|
||||
// per-segment allowlist check. Each gate targets a specific bypass
|
||||
@@ -1503,10 +1547,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_high_risk_by_default() {
|
||||
fn validate_command_blocks_high_risk_via_wildcard() {
|
||||
// Wildcard allows the command through is_command_allowed, but
|
||||
// block_high_risk_commands still rejects it because "*" does not
|
||||
// count as an explicit allowlist entry.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
allowed_commands: vec!["*".into()],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
@@ -1515,6 +1562,100 @@ mod tests {
|
||||
assert!(result.unwrap_err().contains("high-risk"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_explicitly_listed_high_risk() {
|
||||
// When a high-risk command is explicitly in allowed_commands, the
|
||||
// block_high_risk_commands gate is bypassed — the operator has made
|
||||
// a deliberate decision to permit it.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com/data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_wget_when_explicitly_listed() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["wget".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result =
|
||||
p.validate_command_execution("wget https://releases.example.com/v1.tar.gz", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_non_listed_high_risk_when_another_is_allowed() {
|
||||
// Allowing curl explicitly should not exempt wget.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("wget https://evil.com", true);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_explicit_rm_bypasses_high_risk_block() {
|
||||
// Operator explicitly listed "rm" — they accept the risk.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("rm -rf /tmp/test", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_high_risk_still_needs_approval_in_supervised() {
|
||||
// Even when explicitly allowed, supervised mode still requires
|
||||
// approval for high-risk commands (the approval gate is separate
|
||||
// from the block gate).
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let denied = p.validate_command_execution("curl https://api.example.com", false);
|
||||
assert!(denied.is_err());
|
||||
assert!(denied.unwrap_err().contains("requires explicit approval"));
|
||||
|
||||
let allowed = p.validate_command_execution("curl https://api.example.com", true);
|
||||
assert_eq!(allowed.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_pipe_needs_all_segments_explicitly_allowed() {
|
||||
// When a pipeline contains a high-risk command, every segment
|
||||
// must be explicitly allowed for the exemption to apply.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into(), "grep".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com | grep data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_full_mode_skips_medium_risk_approval_gate() {
|
||||
let p = SecurityPolicy {
|
||||
|
||||
@@ -0,0 +1,397 @@
|
||||
//! Vulnerability scan result parsing and management.
|
||||
//!
|
||||
//! Parses vulnerability scan outputs from common scanners (Nessus, Qualys, generic
|
||||
//! CVSS JSON) and provides priority scoring with business context adjustments.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// A single vulnerability finding.
///
/// Deserialized from scanner output; fields marked `#[serde(default)]`
/// may be absent in the input JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Finding {
    /// CVE identifier (e.g. "CVE-2024-1234"). May be empty for non-CVE findings.
    #[serde(default)]
    pub cve_id: String,
    /// CVSS base score (0.0 - 10.0). Validated by `parse_vulnerability_json`.
    pub cvss_score: f64,
    /// Severity label: "low", "medium", "high", "critical"
    /// (or "informational", as produced by `cvss_to_severity`).
    pub severity: String,
    /// Affected asset identifier (hostname, IP, or service name).
    pub affected_asset: String,
    /// Description of the vulnerability.
    pub description: String,
    /// Recommended remediation steps.
    #[serde(default)]
    pub remediation: String,
    /// Whether the asset is internet-facing (increases effective priority).
    #[serde(default)]
    pub internet_facing: bool,
    /// Whether the asset is in a production environment.
    /// Defaults to `true` when absent — assume production unless stated otherwise.
    #[serde(default = "default_true")]
    pub production: bool,
}
|
||||
|
||||
/// Serde default for `Finding::production`: an absent field is treated as
/// production (the conservative, higher-priority assumption).
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// A parsed vulnerability scan report.
///
/// Produced by `parse_vulnerability_json` and consumed by `generate_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VulnerabilityReport {
    /// When the scan was performed.
    pub scan_date: DateTime<Utc>,
    /// Scanner that produced the results (e.g. "nessus", "qualys", "generic").
    pub scanner: String,
    /// Individual findings from the scan.
    pub findings: Vec<Finding>,
}
|
||||
|
||||
/// Compute effective priority score for a finding.
|
||||
///
|
||||
/// Base: CVSS score (0-10). Adjustments:
|
||||
/// - Internet-facing: +2.0 (capped at 10.0)
|
||||
/// - Production: +1.0 (capped at 10.0)
|
||||
pub fn effective_priority(finding: &Finding) -> f64 {
|
||||
let mut score = finding.cvss_score;
|
||||
if finding.internet_facing {
|
||||
score += 2.0;
|
||||
}
|
||||
if finding.production {
|
||||
score += 1.0;
|
||||
}
|
||||
score.min(10.0)
|
||||
}
|
||||
|
||||
/// Classify CVSS score into severity label.
///
/// Thresholds follow the CVSS v3 qualitative rating bands:
/// >= 9.0 critical, >= 7.0 high, >= 4.0 medium, > 0.0 low,
/// anything else (including 0.0) informational.
pub fn cvss_to_severity(cvss: f64) -> &'static str {
    if cvss >= 9.0 {
        "critical"
    } else if cvss >= 7.0 {
        "high"
    } else if cvss >= 4.0 {
        "medium"
    } else if cvss > 0.0 {
        "low"
    } else {
        "informational"
    }
}
|
||||
|
||||
/// Parse a generic CVSS JSON vulnerability report.
|
||||
///
|
||||
/// Expects a JSON object with:
|
||||
/// - `scan_date`: ISO 8601 date string
|
||||
/// - `scanner`: string
|
||||
/// - `findings`: array of Finding objects
|
||||
pub fn parse_vulnerability_json(json_str: &str) -> anyhow::Result<VulnerabilityReport> {
|
||||
let report: VulnerabilityReport = serde_json::from_str(json_str)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse vulnerability report: {e}"))?;
|
||||
|
||||
for (i, finding) in report.findings.iter().enumerate() {
|
||||
if !(0.0..=10.0).contains(&finding.cvss_score) {
|
||||
anyhow::bail!(
|
||||
"findings[{}].cvss_score must be between 0.0 and 10.0, got {}",
|
||||
i,
|
||||
finding.cvss_score
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(report)
|
||||
}
|
||||
|
||||
/// Generate a summary of the vulnerability report.
|
||||
pub fn generate_summary(report: &VulnerabilityReport) -> String {
|
||||
if report.findings.is_empty() {
|
||||
return format!(
|
||||
"Vulnerability scan by {} on {}: No findings.",
|
||||
report.scanner,
|
||||
report.scan_date.format("%Y-%m-%d")
|
||||
);
|
||||
}
|
||||
|
||||
let total = report.findings.len();
|
||||
let critical = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("critical"))
|
||||
.count();
|
||||
let high = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("high"))
|
||||
.count();
|
||||
let medium = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("medium"))
|
||||
.count();
|
||||
let low = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("low"))
|
||||
.count();
|
||||
let informational = report
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.eq_ignore_ascii_case("informational"))
|
||||
.count();
|
||||
|
||||
// Sort by effective priority descending
|
||||
let mut sorted: Vec<&Finding> = report.findings.iter().collect();
|
||||
sorted.sort_by(|a, b| {
|
||||
effective_priority(b)
|
||||
.partial_cmp(&effective_priority(a))
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut summary = format!(
|
||||
"## Vulnerability Scan Summary\n\
|
||||
**Scanner:** {} | **Date:** {}\n\
|
||||
**Total findings:** {} (Critical: {}, High: {}, Medium: {}, Low: {}, Informational: {})\n\n",
|
||||
report.scanner,
|
||||
report.scan_date.format("%Y-%m-%d"),
|
||||
total,
|
||||
critical,
|
||||
high,
|
||||
medium,
|
||||
low,
|
||||
informational
|
||||
);
|
||||
|
||||
// Top 10 by effective priority
|
||||
summary.push_str("### Top Findings by Priority\n\n");
|
||||
for (i, finding) in sorted.iter().take(10).enumerate() {
|
||||
let priority = effective_priority(finding);
|
||||
let context = match (finding.internet_facing, finding.production) {
|
||||
(true, true) => " [internet-facing, production]",
|
||||
(true, false) => " [internet-facing]",
|
||||
(false, true) => " [production]",
|
||||
(false, false) => "",
|
||||
};
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"{}. **{}** (CVSS: {:.1}, Priority: {:.1}){}\n Asset: {} | {}",
|
||||
i + 1,
|
||||
if finding.cve_id.is_empty() {
|
||||
"No CVE"
|
||||
} else {
|
||||
&finding.cve_id
|
||||
},
|
||||
finding.cvss_score,
|
||||
priority,
|
||||
context,
|
||||
finding.affected_asset,
|
||||
finding.description
|
||||
);
|
||||
if !finding.remediation.is_empty() {
|
||||
let _ = writeln!(summary, " Remediation: {}", finding.remediation);
|
||||
}
|
||||
summary.push('\n');
|
||||
}
|
||||
|
||||
// Remediation recommendations
|
||||
if critical > 0 || high > 0 {
|
||||
summary.push_str("### Remediation Recommendations\n\n");
|
||||
if critical > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **URGENT:** {} critical findings require immediate remediation",
|
||||
critical
|
||||
);
|
||||
}
|
||||
if high > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **HIGH:** {} high-severity findings should be addressed within 7 days",
|
||||
high
|
||||
);
|
||||
}
|
||||
let internet_facing_critical = sorted
|
||||
.iter()
|
||||
.filter(|f| f.internet_facing && (f.severity == "critical" || f.severity == "high"))
|
||||
.count();
|
||||
if internet_facing_critical > 0 {
|
||||
let _ = writeln!(
|
||||
summary,
|
||||
"- **PRIORITY:** {} critical/high findings on internet-facing assets",
|
||||
internet_facing_critical
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
summary
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Three representative findings: one critical internet-facing production
    /// asset, one high-severity production asset, and one medium finding with
    /// no context bonuses.
    fn sample_findings() -> Vec<Finding> {
        vec![
            Finding {
                cve_id: "CVE-2024-0001".into(),
                cvss_score: 9.8,
                severity: "critical".into(),
                affected_asset: "web-server-01".into(),
                description: "Remote code execution in web framework".into(),
                remediation: "Upgrade to version 2.1.0".into(),
                internet_facing: true,
                production: true,
            },
            Finding {
                cve_id: "CVE-2024-0002".into(),
                cvss_score: 7.5,
                severity: "high".into(),
                affected_asset: "db-server-01".into(),
                description: "SQL injection in query parser".into(),
                remediation: "Apply patch KB-12345".into(),
                internet_facing: false,
                production: true,
            },
            Finding {
                cve_id: "CVE-2024-0003".into(),
                cvss_score: 4.3,
                severity: "medium".into(),
                affected_asset: "staging-app-01".into(),
                description: "Information disclosure via debug endpoint".into(),
                remediation: "Disable debug endpoint in config".into(),
                internet_facing: false,
                production: false,
            },
        ]
    }

    // Bonuses stack (+2 internet-facing, +1 production) and cap at 10.0.
    #[test]
    fn effective_priority_adds_context_bonuses() {
        let mut f = Finding {
            cve_id: String::new(),
            cvss_score: 7.0,
            severity: "high".into(),
            affected_asset: "host".into(),
            description: "test".into(),
            remediation: String::new(),
            internet_facing: false,
            production: false,
        };

        assert!((effective_priority(&f) - 7.0).abs() < f64::EPSILON);

        f.internet_facing = true;
        assert!((effective_priority(&f) - 9.0).abs() < f64::EPSILON);

        f.production = true;
        assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON); // capped

        // High CVSS + both bonuses still caps at 10.0
        f.cvss_score = 9.5;
        assert!((effective_priority(&f) - 10.0).abs() < f64::EPSILON);
    }

    // Exercises every band boundary of the CVSS classification.
    #[test]
    fn cvss_to_severity_classification() {
        assert_eq!(cvss_to_severity(9.8), "critical");
        assert_eq!(cvss_to_severity(9.0), "critical");
        assert_eq!(cvss_to_severity(8.5), "high");
        assert_eq!(cvss_to_severity(7.0), "high");
        assert_eq!(cvss_to_severity(5.0), "medium");
        assert_eq!(cvss_to_severity(4.0), "medium");
        assert_eq!(cvss_to_severity(3.9), "low");
        assert_eq!(cvss_to_severity(0.1), "low");
        assert_eq!(cvss_to_severity(0.0), "informational");
    }

    // serialize -> parse must preserve scanner and findings.
    #[test]
    fn parse_vulnerability_json_roundtrip() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "nessus".into(),
            findings: sample_findings(),
        };

        let json = serde_json::to_string(&report).unwrap();
        let parsed = parse_vulnerability_json(&json).unwrap();

        assert_eq!(parsed.scanner, "nessus");
        assert_eq!(parsed.findings.len(), 3);
        assert_eq!(parsed.findings[0].cve_id, "CVE-2024-0001");
    }

    #[test]
    fn parse_vulnerability_json_rejects_invalid() {
        let result = parse_vulnerability_json("not json");
        assert!(result.is_err());
    }

    // Summary must surface counts, top findings, and urgency markers.
    #[test]
    fn generate_summary_includes_key_sections() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "qualys".into(),
            findings: sample_findings(),
        };

        let summary = generate_summary(&report);

        assert!(summary.contains("qualys"));
        assert!(summary.contains("Total findings:** 3"));
        assert!(summary.contains("Critical: 1"));
        assert!(summary.contains("High: 1"));
        assert!(summary.contains("CVE-2024-0001"));
        assert!(summary.contains("URGENT"));
        assert!(summary.contains("internet-facing"));
    }

    // Scores above 10.0 are rejected with an index-bearing message.
    #[test]
    fn parse_vulnerability_json_rejects_out_of_range_cvss() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "test".into(),
            findings: vec![Finding {
                cve_id: "CVE-2024-9999".into(),
                cvss_score: 11.0,
                severity: "critical".into(),
                affected_asset: "host".into(),
                description: "bad score".into(),
                remediation: String::new(),
                internet_facing: false,
                production: false,
            }],
        };
        let json = serde_json::to_string(&report).unwrap();
        let result = parse_vulnerability_json(&json);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("cvss_score must be between 0.0 and 10.0"));
    }

    // Negative scores are likewise out of the 0.0..=10.0 range.
    #[test]
    fn parse_vulnerability_json_rejects_negative_cvss() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "test".into(),
            findings: vec![Finding {
                cve_id: "CVE-2024-9998".into(),
                cvss_score: -1.0,
                severity: "low".into(),
                affected_asset: "host".into(),
                description: "negative score".into(),
                remediation: String::new(),
                internet_facing: false,
                production: false,
            }],
        };
        let json = serde_json::to_string(&report).unwrap();
        let result = parse_vulnerability_json(&json);
        assert!(result.is_err());
    }

    // The empty-report fast path produces the short one-line summary.
    #[test]
    fn generate_summary_empty_findings() {
        let report = VulnerabilityReport {
            scan_date: Utc::now(),
            scanner: "nessus".into(),
            findings: vec![],
        };

        let summary = generate_summary(&report);
        assert!(summary.contains("No findings"));
    }
}
|
||||
@@ -0,0 +1,211 @@
|
||||
//! Workspace isolation boundary enforcement.
|
||||
//!
|
||||
//! Prevents cross-workspace data access and enforces per-workspace
|
||||
//! domain allowlists and tool restrictions.
|
||||
|
||||
use crate::config::workspace::WorkspaceProfile;
|
||||
use std::path::Path;
|
||||
|
||||
/// Outcome of a workspace boundary check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundaryVerdict {
    /// Access is allowed.
    Allow,
    /// Access is denied; the payload is a human-readable reason suitable
    /// for logging or surfacing to the operator.
    Deny(String),
}
|
||||
|
||||
/// Enforces isolation boundaries for the active workspace.
///
/// With no active profile (`profile == None`) every check returns
/// `BoundaryVerdict::Allow`.
#[derive(Debug, Clone)]
pub struct WorkspaceBoundary {
    /// The active workspace profile (if workspace isolation is active).
    profile: Option<WorkspaceProfile>,
    /// Whether cross-workspace search is allowed (relaxes path checks
    /// against other workspaces' directories).
    cross_workspace_search: bool,
}
|
||||
|
||||
impl WorkspaceBoundary {
    /// Create a boundary enforcer for the given active workspace.
    pub fn new(profile: Option<WorkspaceProfile>, cross_workspace_search: bool) -> Self {
        Self {
            profile,
            cross_workspace_search,
        }
    }

    /// Create a boundary enforcer with no active workspace (no restrictions).
    pub fn inactive() -> Self {
        Self {
            profile: None,
            cross_workspace_search: false,
        }
    }

    /// Check whether a tool is allowed in the current workspace.
    ///
    /// Tools named in the profile's restrictions are denied; with no active
    /// profile every tool is allowed.
    pub fn check_tool_access(&self, tool_name: &str) -> BoundaryVerdict {
        if let Some(profile) = &self.profile {
            if profile.is_tool_restricted(tool_name) {
                return BoundaryVerdict::Deny(format!(
                    "tool '{}' is restricted in workspace '{}'",
                    tool_name, profile.name
                ));
            }
        }
        BoundaryVerdict::Allow
    }

    /// Check whether a domain is allowed in the current workspace.
    ///
    /// Domains not on the profile's allowlist are denied; with no active
    /// profile every domain is allowed.
    pub fn check_domain_access(&self, domain: &str) -> BoundaryVerdict {
        if let Some(profile) = &self.profile {
            if !profile.is_domain_allowed(domain) {
                return BoundaryVerdict::Deny(format!(
                    "domain '{}' is not in the allowlist for workspace '{}'",
                    domain, profile.name
                ));
            }
        }
        BoundaryVerdict::Allow
    }

    /// Check whether accessing a path is allowed given workspace isolation.
    ///
    /// When a workspace is active, paths under `workspaces_base` that belong
    /// to a *different* workspace are denied unless cross-workspace search is
    /// enabled. Paths outside `workspaces_base` are always allowed by this
    /// check.
    pub fn check_path_access(&self, path: &Path, workspaces_base: &Path) -> BoundaryVerdict {
        let profile = match &self.profile {
            Some(p) => p,
            None => return BoundaryVerdict::Allow,
        };

        // If the path is under the workspaces base, verify it belongs to the active workspace.
        // The first path component under the base is taken as the owning
        // workspace's name.
        if let Ok(relative) = path.strip_prefix(workspaces_base) {
            let first_component = relative
                .components()
                .next()
                .and_then(|c| c.as_os_str().to_str());

            if let Some(ws_name) = first_component {
                if ws_name != profile.name {
                    if self.cross_workspace_search {
                        // Cross-workspace search is enabled, so allow.
                        // NOTE(review): no read/write distinction is made
                        // here despite the flag's "search" naming — confirm
                        // write paths are gated elsewhere before relying on
                        // this for isolation.
                        return BoundaryVerdict::Allow;
                    }
                    return BoundaryVerdict::Deny(format!(
                        "access to workspace '{}' is denied from workspace '{}'",
                        ws_name, profile.name
                    ));
                }
            }
        }

        BoundaryVerdict::Allow
    }

    /// Whether workspace isolation is active.
    pub fn is_active(&self) -> bool {
        self.profile.is_some()
    }

    /// Get the active workspace name, if any.
    pub fn active_workspace_name(&self) -> Option<&str> {
        self.profile.as_ref().map(|p| p.name.as_str())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Profile restricting the "shell" tool and allowing one domain.
    fn test_profile() -> WorkspaceProfile {
        WorkspaceProfile {
            name: "client_a".to_string(),
            allowed_domains: vec!["api.example.com".to_string()],
            credential_profile: None,
            memory_namespace: Some("client_a".to_string()),
            audit_namespace: Some("client_a".to_string()),
            tool_restrictions: vec!["shell".to_string()],
        }
    }

    // No active profile -> every check is Allow.
    #[test]
    fn boundary_inactive_allows_everything() {
        let boundary = WorkspaceBoundary::inactive();
        assert_eq!(boundary.check_tool_access("shell"), BoundaryVerdict::Allow);
        assert_eq!(
            boundary.check_domain_access("any.domain"),
            BoundaryVerdict::Allow
        );
        assert!(!boundary.is_active());
    }

    // Restricted tools are denied; unlisted tools pass.
    #[test]
    fn boundary_denies_restricted_tool() {
        let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
        assert!(matches!(
            boundary.check_tool_access("shell"),
            BoundaryVerdict::Deny(_)
        ));
        assert_eq!(
            boundary.check_tool_access("file_read"),
            BoundaryVerdict::Allow
        );
    }

    // Only allowlisted domains pass when a profile is active.
    #[test]
    fn boundary_denies_unlisted_domain() {
        let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
        assert_eq!(
            boundary.check_domain_access("api.example.com"),
            BoundaryVerdict::Allow
        );
        assert!(matches!(
            boundary.check_domain_access("evil.com"),
            BoundaryVerdict::Deny(_)
        ));
    }

    // Own-workspace paths pass; sibling-workspace paths are denied.
    #[test]
    fn boundary_denies_cross_workspace_path_access() {
        let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
        let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");

        // Access to own workspace is allowed
        let own_path = base.join("client_a").join("data.db");
        assert_eq!(
            boundary.check_path_access(&own_path, &base),
            BoundaryVerdict::Allow
        );

        // Access to other workspace is denied
        let other_path = base.join("client_b").join("data.db");
        assert!(matches!(
            boundary.check_path_access(&other_path, &base),
            BoundaryVerdict::Deny(_)
        ));
    }

    // cross_workspace_search=true relaxes the sibling-workspace denial.
    #[test]
    fn boundary_allows_cross_workspace_when_enabled() {
        let boundary = WorkspaceBoundary::new(Some(test_profile()), true);
        let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");
        let other_path = base.join("client_b").join("data.db");

        assert_eq!(
            boundary.check_path_access(&other_path, &base),
            BoundaryVerdict::Allow
        );
    }

    // Paths not under the workspaces base are outside this check's scope.
    #[test]
    fn boundary_allows_paths_outside_workspaces_dir() {
        let boundary = WorkspaceBoundary::new(Some(test_profile()), false);
        let base = PathBuf::from("/home/zeroclaw_user/.zeroclaw/workspaces");
        let outside_path = PathBuf::from("/tmp/something");

        assert_eq!(
            boundary.check_path_access(&outside_path, &base),
            BoundaryVerdict::Allow
        );
    }
}
|
||||
+91
-6
@@ -442,8 +442,24 @@ fn install_linux_systemd(config: &Config) -> Result<()> {
|
||||
|
||||
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
|
||||
let unit = format!(
|
||||
"[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
|
||||
exe.display()
|
||||
"[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart={exe} daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n",
|
||||
exe = exe.display()
|
||||
);
|
||||
|
||||
fs::write(&file, unit)?;
|
||||
@@ -826,8 +842,8 @@ fn generate_openrc_script(exe_path: &Path, config_dir: &Path) -> String {
|
||||
name="zeroclaw"
|
||||
description="ZeroClaw daemon"
|
||||
|
||||
command="{}"
|
||||
command_args="--config-dir {} daemon"
|
||||
command="{exe}"
|
||||
command_args="--config-dir {config_dir} daemon"
|
||||
command_background="yes"
|
||||
command_user="zeroclaw:zeroclaw"
|
||||
pidfile="/run/${{RC_SVCNAME}}.pid"
|
||||
@@ -835,13 +851,21 @@ umask 027
|
||||
output_log="/var/log/zeroclaw/access.log"
|
||||
error_log="/var/log/zeroclaw/error.log"
|
||||
|
||||
# Provide HOME so headless browsers can create profile/cache directories.
|
||||
# Without this, Chromium/Firefox fail with sandbox or profile errors.
|
||||
export HOME="/var/lib/zeroclaw"
|
||||
|
||||
depend() {{
|
||||
need net
|
||||
after firewall
|
||||
}}
|
||||
|
||||
start_pre() {{
|
||||
checkpath --directory --owner zeroclaw:zeroclaw --mode 0750 /var/lib/zeroclaw
|
||||
}}
|
||||
"#,
|
||||
exe_path.display(),
|
||||
config_dir.display()
|
||||
exe = exe_path.display(),
|
||||
config_dir = config_dir.display(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1196,6 +1220,67 @@ mod tests {
|
||||
assert!(script.contains("after firewall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_sets_home_for_browser() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("export HOME=\"/var/lib/zeroclaw\""),
|
||||
"OpenRC script must set HOME for headless browser support"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_creates_home_directory() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("start_pre()"),
|
||||
"OpenRC script must have start_pre to create HOME dir"
|
||||
);
|
||||
assert!(
|
||||
script.contains("checkpath --directory --owner zeroclaw:zeroclaw"),
|
||||
"start_pre must ensure /var/lib/zeroclaw exists with correct ownership"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn systemd_unit_contains_home_and_pass_environment() {
|
||||
let unit = "[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart=/usr/local/bin/zeroclaw daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n"
|
||||
.to_string();
|
||||
|
||||
assert!(
|
||||
unit.contains("Environment=HOME=%h"),
|
||||
"systemd unit must set HOME for headless browser support"
|
||||
);
|
||||
assert!(
|
||||
unit.contains("PassEnvironment=DISPLAY XDG_RUNTIME_DIR"),
|
||||
"systemd unit must pass through display/runtime env vars"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warn_if_binary_in_home_detects_home_path() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -0,0 +1,466 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Workspace backup tool: create, list, verify, and restore timestamped backups
|
||||
/// with SHA-256 manifest integrity checking.
|
||||
pub struct BackupTool {
|
||||
workspace_dir: PathBuf,
|
||||
include_dirs: Vec<String>,
|
||||
max_keep: usize,
|
||||
}
|
||||
|
||||
impl BackupTool {
|
||||
pub fn new(workspace_dir: PathBuf, include_dirs: Vec<String>, max_keep: usize) -> Self {
|
||||
Self {
|
||||
workspace_dir,
|
||||
include_dirs,
|
||||
max_keep,
|
||||
}
|
||||
}
|
||||
|
||||
fn backups_dir(&self) -> PathBuf {
|
||||
self.workspace_dir.join("backups")
|
||||
}
|
||||
|
||||
async fn cmd_create(&self) -> anyhow::Result<ToolResult> {
|
||||
let ts = chrono::Utc::now().format("%Y%m%dT%H%M%SZ");
|
||||
let name = format!("backup-{ts}");
|
||||
let backup_dir = self.backups_dir().join(&name);
|
||||
fs::create_dir_all(&backup_dir).await?;
|
||||
|
||||
for sub in &self.include_dirs {
|
||||
let src = self.workspace_dir.join(sub);
|
||||
if src.is_dir() {
|
||||
let dst = backup_dir.join(sub);
|
||||
copy_dir_recursive(&src, &dst).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let checksums = compute_checksums(&backup_dir).await?;
|
||||
let file_count = checksums.len();
|
||||
let manifest = serde_json::to_string_pretty(&checksums)?;
|
||||
fs::write(backup_dir.join("manifest.json"), &manifest).await?;
|
||||
|
||||
// Enforce max_keep: remove oldest backups beyond the limit.
|
||||
self.enforce_max_keep().await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"backup": name,
|
||||
"file_count": file_count,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn enforce_max_keep(&self) -> anyhow::Result<()> {
|
||||
let mut backups = self.list_backup_dirs().await?;
|
||||
// Sorted newest-first; drop excess from the tail.
|
||||
while backups.len() > self.max_keep {
|
||||
if let Some(old) = backups.pop() {
|
||||
fs::remove_dir_all(old).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_backup_dirs(&self) -> anyhow::Result<Vec<PathBuf>> {
|
||||
let dir = self.backups_dir();
|
||||
if !dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut entries = Vec::new();
|
||||
let mut rd = fs::read_dir(&dir).await?;
|
||||
while let Some(e) = rd.next_entry().await? {
|
||||
let p = e.path();
|
||||
if p.is_dir() && e.file_name().to_string_lossy().starts_with("backup-") {
|
||||
entries.push(p);
|
||||
}
|
||||
}
|
||||
entries.sort();
|
||||
entries.reverse(); // newest first
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
async fn cmd_list(&self) -> anyhow::Result<ToolResult> {
|
||||
let dirs = self.list_backup_dirs().await?;
|
||||
let mut items = Vec::new();
|
||||
for d in &dirs {
|
||||
let name = d
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
let manifest_path = d.join("manifest.json");
|
||||
let file_count = if manifest_path.is_file() {
|
||||
let data = fs::read_to_string(&manifest_path).await?;
|
||||
let map: HashMap<String, String> = serde_json::from_str(&data).unwrap_or_default();
|
||||
map.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let meta = fs::metadata(d).await?;
|
||||
let created = meta
|
||||
.created()
|
||||
.or_else(|_| meta.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
let dt: chrono::DateTime<chrono::Utc> = created.into();
|
||||
items.push(json!({
|
||||
"name": name,
|
||||
"file_count": file_count,
|
||||
"created": dt.to_rfc3339(),
|
||||
}));
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&items)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_verify(&self, backup_name: &str) -> anyhow::Result<ToolResult> {
|
||||
let backup_dir = self.backups_dir().join(backup_name);
|
||||
if !backup_dir.is_dir() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Backup not found: {backup_name}")),
|
||||
});
|
||||
}
|
||||
let manifest_path = backup_dir.join("manifest.json");
|
||||
let data = fs::read_to_string(&manifest_path).await?;
|
||||
let expected: HashMap<String, String> = serde_json::from_str(&data)?;
|
||||
let actual = compute_checksums(&backup_dir).await?;
|
||||
|
||||
let mut mismatches = Vec::new();
|
||||
for (path, expected_hash) in &expected {
|
||||
match actual.get(path) {
|
||||
Some(actual_hash) if actual_hash == expected_hash => {}
|
||||
Some(actual_hash) => mismatches.push(json!({
|
||||
"file": path,
|
||||
"expected": expected_hash,
|
||||
"actual": actual_hash,
|
||||
})),
|
||||
None => mismatches.push(json!({
|
||||
"file": path,
|
||||
"error": "missing",
|
||||
})),
|
||||
}
|
||||
}
|
||||
let pass = mismatches.is_empty();
|
||||
Ok(ToolResult {
|
||||
success: pass,
|
||||
output: json!({
|
||||
"backup": backup_name,
|
||||
"pass": pass,
|
||||
"checked": expected.len(),
|
||||
"mismatches": mismatches,
|
||||
})
|
||||
.to_string(),
|
||||
error: if pass {
|
||||
None
|
||||
} else {
|
||||
Some("Integrity check failed".into())
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_restore(&self, backup_name: &str, confirm: bool) -> anyhow::Result<ToolResult> {
|
||||
let backup_dir = self.backups_dir().join(backup_name);
|
||||
if !backup_dir.is_dir() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Backup not found: {backup_name}")),
|
||||
});
|
||||
}
|
||||
|
||||
// Collect restorable subdirectories (skip manifest.json).
|
||||
let mut restore_items: Vec<String> = Vec::new();
|
||||
let mut rd = fs::read_dir(&backup_dir).await?;
|
||||
while let Some(e) = rd.next_entry().await? {
|
||||
let name = e.file_name().to_string_lossy().to_string();
|
||||
if name == "manifest.json" {
|
||||
continue;
|
||||
}
|
||||
if e.path().is_dir() {
|
||||
restore_items.push(name);
|
||||
}
|
||||
}
|
||||
|
||||
if !confirm {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"dry_run": true,
|
||||
"backup": backup_name,
|
||||
"would_restore": restore_items,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
for sub in &restore_items {
|
||||
let src = backup_dir.join(sub);
|
||||
let dst = self.workspace_dir.join(sub);
|
||||
copy_dir_recursive(&src, &dst).await?;
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"restored": backup_name,
|
||||
"directories": restore_items,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for BackupTool {
|
||||
fn name(&self) -> &str {
|
||||
"backup"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Create, list, verify, and restore workspace backups"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "verify", "restore"],
|
||||
"description": "Backup command to execute"
|
||||
},
|
||||
"backup_name": {
|
||||
"type": "string",
|
||||
"description": "Name of backup (for verify/restore)"
|
||||
},
|
||||
"confirm": {
|
||||
"type": "boolean",
|
||||
"description": "Confirm restore (required for actual restore, default false)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing 'command' parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match command {
|
||||
"create" => self.cmd_create().await,
|
||||
"list" => self.cmd_list().await,
|
||||
"verify" => {
|
||||
let name = args
|
||||
.get("backup_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for verify"))?;
|
||||
self.cmd_verify(name).await
|
||||
}
|
||||
"restore" => {
|
||||
let name = args
|
||||
.get("backup_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'backup_name' for restore"))?;
|
||||
let confirm = args
|
||||
.get("confirm")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
self.cmd_restore(name, confirm).await
|
||||
}
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown command: {other}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Helpers ------------------------------------------------------------------
|
||||
|
||||
async fn copy_dir_recursive(src: &Path, dst: &Path) -> anyhow::Result<()> {
|
||||
fs::create_dir_all(dst).await?;
|
||||
let mut rd = fs::read_dir(src).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let src_path = entry.path();
|
||||
let dst_path = dst.join(entry.file_name());
|
||||
if src_path.is_dir() {
|
||||
Box::pin(copy_dir_recursive(&src_path, &dst_path)).await?;
|
||||
} else {
|
||||
fs::copy(&src_path, &dst_path).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn compute_checksums(dir: &Path) -> anyhow::Result<HashMap<String, String>> {
|
||||
let mut map = HashMap::new();
|
||||
let base = dir.to_path_buf();
|
||||
walk_and_hash(&base, dir, &mut map).await?;
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
async fn walk_and_hash(
|
||||
base: &Path,
|
||||
dir: &Path,
|
||||
map: &mut HashMap<String, String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
Box::pin(walk_and_hash(base, &path, map)).await?;
|
||||
} else {
|
||||
let rel = path
|
||||
.strip_prefix(base)
|
||||
.unwrap_or(&path)
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
if rel == "manifest.json" {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path).await?;
|
||||
let hash = hex::encode(Sha256::digest(&bytes));
|
||||
map.insert(rel, hash);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn make_tool(tmp: &TempDir) -> BackupTool {
|
||||
BackupTool::new(
|
||||
tmp.path().to_path_buf(),
|
||||
vec!["config".into(), "memory".into()],
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_backup_produces_manifest() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
// Seed workspace subdirectories.
|
||||
let cfg_dir = tmp.path().join("config");
|
||||
std::fs::create_dir_all(&cfg_dir).unwrap();
|
||||
std::fs::write(cfg_dir.join("a.toml"), "key = 1").unwrap();
|
||||
|
||||
let tool = make_tool(&tmp);
|
||||
let res = tool.execute(json!({"command": "create"})).await.unwrap();
|
||||
assert!(res.success, "create failed: {:?}", res.error);
|
||||
|
||||
let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
assert_eq!(parsed["file_count"], 1);
|
||||
|
||||
// Manifest should exist inside the backup directory.
|
||||
let backup_name = parsed["backup"].as_str().unwrap();
|
||||
let manifest = tmp
|
||||
.path()
|
||||
.join("backups")
|
||||
.join(backup_name)
|
||||
.join("manifest.json");
|
||||
assert!(manifest.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_backup_detects_corruption() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg_dir = tmp.path().join("config");
|
||||
std::fs::create_dir_all(&cfg_dir).unwrap();
|
||||
std::fs::write(cfg_dir.join("a.toml"), "original").unwrap();
|
||||
|
||||
let tool = make_tool(&tmp);
|
||||
let res = tool.execute(json!({"command": "create"})).await.unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
let name = parsed["backup"].as_str().unwrap();
|
||||
|
||||
// Corrupt a file inside the backup.
|
||||
let backed_up = tmp.path().join("backups").join(name).join("config/a.toml");
|
||||
std::fs::write(&backed_up, "corrupted").unwrap();
|
||||
|
||||
let res = tool
|
||||
.execute(json!({"command": "verify", "backup_name": name}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.success);
|
||||
let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
assert!(!v["mismatches"].as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn restore_requires_confirmation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg_dir = tmp.path().join("config");
|
||||
std::fs::create_dir_all(&cfg_dir).unwrap();
|
||||
std::fs::write(cfg_dir.join("a.toml"), "v1").unwrap();
|
||||
|
||||
let tool = make_tool(&tmp);
|
||||
let res = tool.execute(json!({"command": "create"})).await.unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
let name = parsed["backup"].as_str().unwrap();
|
||||
|
||||
// Without confirm: dry-run.
|
||||
let res = tool
|
||||
.execute(json!({"command": "restore", "backup_name": name}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.success);
|
||||
let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
assert_eq!(v["dry_run"], true);
|
||||
|
||||
// With confirm: actual restore.
|
||||
let res = tool
|
||||
.execute(json!({"command": "restore", "backup_name": name, "confirm": true}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.success);
|
||||
let v: serde_json::Value = serde_json::from_str(&res.output).unwrap();
|
||||
assert!(v.get("restored").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_backups_sorted_newest_first() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg_dir = tmp.path().join("config");
|
||||
std::fs::create_dir_all(&cfg_dir).unwrap();
|
||||
std::fs::write(cfg_dir.join("a.toml"), "v1").unwrap();
|
||||
|
||||
let tool = make_tool(&tmp);
|
||||
tool.execute(json!({"command": "create"})).await.unwrap();
|
||||
// Delay to ensure different second-resolution timestamps.
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
tool.execute(json!({"command": "create"})).await.unwrap();
|
||||
|
||||
let res = tool.execute(json!({"command": "list"})).await.unwrap();
|
||||
assert!(res.success);
|
||||
let items: Vec<serde_json::Value> = serde_json::from_str(&res.output).unwrap();
|
||||
assert_eq!(items.len(), 2);
|
||||
// Newest first by name (ISO8601 names sort lexicographically).
|
||||
assert!(items[0]["name"].as_str().unwrap() >= items[1]["name"].as_str().unwrap());
|
||||
}
|
||||
}
|
||||
@@ -440,6 +440,12 @@ impl BrowserTool {
|
||||
async fn run_command(&self, args: &[&str]) -> anyhow::Result<AgentBrowserResponse> {
|
||||
let mut cmd = Command::new("agent-browser");
|
||||
|
||||
// When running as a service (systemd/OpenRC), the process may lack
|
||||
// HOME which browsers need for profile directories.
|
||||
if is_service_environment() {
|
||||
ensure_browser_env(&mut cmd);
|
||||
}
|
||||
|
||||
// Add session if configured
|
||||
if let Some(ref session) = self.session_name {
|
||||
cmd.arg("--session").arg(session);
|
||||
@@ -1461,6 +1467,14 @@ mod native_backend {
|
||||
args.push(Value::String("--disable-gpu".to_string()));
|
||||
}
|
||||
|
||||
// When running as a service (systemd/OpenRC), the browser sandbox
|
||||
// fails because the process lacks a user namespace / session.
|
||||
// --no-sandbox and --disable-dev-shm-usage are required in this context.
|
||||
if super::is_service_environment() {
|
||||
args.push(Value::String("--no-sandbox".to_string()));
|
||||
args.push(Value::String("--disable-dev-shm-usage".to_string()));
|
||||
}
|
||||
|
||||
if !args.is_empty() {
|
||||
chrome_options.insert("args".to_string(), Value::Array(args));
|
||||
}
|
||||
@@ -2111,6 +2125,44 @@ fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
|
||||
|| v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
|
||||
}
|
||||
|
||||
/// Detect whether the current process is running inside a service environment
|
||||
/// (e.g. systemd, OpenRC, or launchd) where the browser sandbox and
|
||||
/// environment setup may be restricted.
|
||||
fn is_service_environment() -> bool {
|
||||
if std::env::var_os("INVOCATION_ID").is_some() {
|
||||
return true;
|
||||
}
|
||||
if std::env::var_os("JOURNAL_STREAM").is_some() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::path::Path::new("/run/openrc").exists() && std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Ensure environment variables required by headless browsers are present
|
||||
/// when running inside a service context.
|
||||
fn ensure_browser_env(cmd: &mut Command) {
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
cmd.env("HOME", "/tmp");
|
||||
}
|
||||
let existing = std::env::var("CHROMIUM_FLAGS").unwrap_or_default();
|
||||
if !existing.contains("--no-sandbox") {
|
||||
let new_flags = if existing.is_empty() {
|
||||
"--no-sandbox --disable-dev-shm-usage".to_string()
|
||||
} else {
|
||||
format!("{existing} --no-sandbox --disable-dev-shm-usage")
|
||||
};
|
||||
cmd.env("CHROMIUM_FLAGS", new_flags);
|
||||
}
|
||||
}
|
||||
|
||||
fn host_matches_allowlist(host: &str, allowed: &[String]) -> bool {
|
||||
allowed.iter().any(|pattern| {
|
||||
if pattern == "*" {
|
||||
@@ -2492,4 +2544,78 @@ mod tests {
|
||||
state.reset_session().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_home_when_missing() {
|
||||
let original_home = std::env::var_os("HOME");
|
||||
unsafe { std::env::remove_var("HOME") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
// Function completes without panic — HOME and CHROMIUM_FLAGS set on cmd.
|
||||
|
||||
if let Some(home) = original_home {
|
||||
unsafe { std::env::set_var("HOME", home) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_chromium_flags() {
|
||||
let original = std::env::var_os("CHROMIUM_FLAGS");
|
||||
unsafe { std::env::remove_var("CHROMIUM_FLAGS") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("CHROMIUM_FLAGS", val) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_invocation_id() {
|
||||
let original = std::env::var_os("INVOCATION_ID");
|
||||
unsafe { std::env::set_var("INVOCATION_ID", "test-unit-id") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_journal_stream() {
|
||||
let original = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", "8:12345") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_false_in_normal_context() {
|
||||
let inv = std::env::var_os("INVOCATION_ID");
|
||||
let journal = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
|
||||
if std::env::var_os("HOME").is_some() {
|
||||
assert!(!is_service_environment());
|
||||
}
|
||||
|
||||
if let Some(val) = inv {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
}
|
||||
if let Some(val) = journal {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,851 @@
|
||||
//! Cloud operations advisory tool for cloud transformation analysis.
|
||||
//!
|
||||
//! Provides read-only analysis capabilities: IaC review, migration assessment,
|
||||
//! cost analysis, and Well-Architected Framework architecture review.
|
||||
//! This tool does NOT create, modify, or delete cloud resources.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::CloudOpsConfig;
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
/// Read-only cloud operations advisory tool.
|
||||
///
|
||||
/// Actions: `review_iac`, `assess_migration`, `cost_analysis`, `architecture_review`.
|
||||
pub struct CloudOpsTool {
|
||||
config: CloudOpsConfig,
|
||||
}
|
||||
|
||||
impl CloudOpsTool {
|
||||
pub fn new(config: CloudOpsConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CloudOpsTool {
|
||||
fn name(&self) -> &str {
|
||||
"cloud_ops"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Cloud transformation advisory tool. Analyzes IaC plans, assesses migration paths, \
|
||||
reviews costs, and checks architecture against Well-Architected Framework pillars. \
|
||||
Read-only: does not create or modify cloud resources."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["review_iac", "assess_migration", "cost_analysis", "architecture_review"],
|
||||
"description": "The analysis action to perform."
|
||||
},
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "For review_iac: IaC plan text or JSON content to analyze. For assess_migration: current architecture description text. For cost_analysis: billing data as CSV/JSON text. For architecture_review: architecture description text. Note: provide text content directly, not file paths."
|
||||
},
|
||||
"cloud": {
|
||||
"type": "string",
|
||||
"description": "Target cloud provider (aws, azure, gcp). Uses configured default if omitted."
|
||||
}
|
||||
},
|
||||
"required": ["action", "input"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action") {
|
||||
Some(v) => v
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("'action' must be a string, got: {}", v))?,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'action' parameter is required".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let input = match args.get("input") {
|
||||
Some(v) => v
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("'input' must be a string, got: {}", v))?,
|
||||
None => "",
|
||||
};
|
||||
let cloud = match args.get("cloud") {
|
||||
Some(v) => v
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("'cloud' must be a string, got: {}", v))?,
|
||||
None => &self.config.default_cloud,
|
||||
};
|
||||
|
||||
if input.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'input' parameter is required and cannot be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
if !self.config.supported_clouds.contains(&cloud.to_string()) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Cloud provider '{}' is not in supported_clouds: {:?}",
|
||||
cloud, self.config.supported_clouds
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
match action {
|
||||
"review_iac" => self.review_iac(input, cloud).await,
|
||||
"assess_migration" => self.assess_migration(input, cloud).await,
|
||||
"cost_analysis" => self.cost_analysis(input, cloud).await,
|
||||
"architecture_review" => self.architecture_review(input, cloud).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action '{}'. Valid: review_iac, assess_migration, cost_analysis, architecture_review",
|
||||
action
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unused_async)]
|
||||
impl CloudOpsTool {
|
||||
async fn review_iac(&self, input: &str, cloud: &str) -> anyhow::Result<ToolResult> {
|
||||
let mut findings = Vec::new();
|
||||
|
||||
// Detect IaC type from content
|
||||
let iac_type = detect_iac_type(input);
|
||||
|
||||
// Security findings
|
||||
for finding in scan_iac_security(input) {
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Best practice findings
|
||||
for finding in scan_iac_best_practices(input, cloud) {
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Cost implications
|
||||
for finding in scan_iac_cost(input, cloud, self.config.cost_threshold_monthly_usd) {
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
let output = json!({
|
||||
"iac_type": iac_type,
|
||||
"cloud": cloud,
|
||||
"findings_count": findings.len(),
|
||||
"findings": findings,
|
||||
"supported_iac_tools": self.config.iac_tools,
|
||||
});
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&output)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn assess_migration(&self, input: &str, cloud: &str) -> anyhow::Result<ToolResult> {
|
||||
let recommendations = assess_migration_recommendations(input, cloud);
|
||||
|
||||
let output = json!({
|
||||
"cloud": cloud,
|
||||
"source_description": truncate_with_ellipsis(input, 200),
|
||||
"recommendations": recommendations,
|
||||
});
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&output)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cost_analysis(&self, input: &str, cloud: &str) -> anyhow::Result<ToolResult> {
|
||||
let opportunities =
|
||||
analyze_cost_opportunities(input, self.config.cost_threshold_monthly_usd);
|
||||
|
||||
let output = json!({
|
||||
"cloud": cloud,
|
||||
"threshold_usd": self.config.cost_threshold_monthly_usd,
|
||||
"opportunities_count": opportunities.len(),
|
||||
"opportunities": opportunities,
|
||||
});
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&output)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn architecture_review(&self, input: &str, cloud: &str) -> anyhow::Result<ToolResult> {
|
||||
let frameworks = &self.config.well_architected_frameworks;
|
||||
let pillars = review_architecture_pillars(input, cloud, frameworks);
|
||||
|
||||
let output = json!({
|
||||
"cloud": cloud,
|
||||
"frameworks": frameworks,
|
||||
"pillars": pillars,
|
||||
});
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&output)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── Analysis helpers ──────────────────────────────────────────────
|
||||
|
||||
fn detect_iac_type(input: &str) -> &'static str {
|
||||
let lower = input.to_lowercase();
|
||||
if lower.contains("resource \"") || lower.contains("terraform") || lower.contains(".tf") {
|
||||
"terraform"
|
||||
} else if lower.contains("awstemplatebody")
|
||||
|| lower.contains("cloudformation")
|
||||
|| lower.contains("aws::")
|
||||
{
|
||||
"cloudformation"
|
||||
} else if lower.contains("pulumi") {
|
||||
"pulumi"
|
||||
} else {
|
||||
"unknown"
|
||||
}
|
||||
}
|
||||
|
||||
/// Scan IaC content for common security issues.
|
||||
fn scan_iac_security(input: &str) -> Vec<serde_json::Value> {
|
||||
let lower = input.to_lowercase();
|
||||
let mut findings = Vec::new();
|
||||
|
||||
let security_patterns: &[(&str, &str, &str)] = &[
|
||||
(
|
||||
"0.0.0.0/0",
|
||||
"high",
|
||||
"Unrestricted ingress (0.0.0.0/0) detected. Restrict CIDR ranges to known networks.",
|
||||
),
|
||||
(
|
||||
"::/0",
|
||||
"high",
|
||||
"Unrestricted IPv6 ingress (::/0) detected. Restrict CIDR ranges.",
|
||||
),
|
||||
(
|
||||
"public_access",
|
||||
"medium",
|
||||
"Public access setting detected. Verify this is intentional and necessary.",
|
||||
),
|
||||
(
|
||||
"publicly_accessible",
|
||||
"medium",
|
||||
"Resource marked as publicly accessible. Ensure this is required.",
|
||||
),
|
||||
(
|
||||
"encrypted = false",
|
||||
"high",
|
||||
"Encryption explicitly disabled. Enable encryption at rest.",
|
||||
),
|
||||
(
|
||||
"\"*\"",
|
||||
"medium",
|
||||
"Wildcard permission detected. Follow least-privilege principle.",
|
||||
),
|
||||
(
|
||||
"password",
|
||||
"medium",
|
||||
"Hardcoded password reference detected. Use secrets manager instead.",
|
||||
),
|
||||
(
|
||||
"access_key",
|
||||
"high",
|
||||
"Access key reference in IaC. Use IAM roles or secrets manager.",
|
||||
),
|
||||
(
|
||||
"secret_key",
|
||||
"high",
|
||||
"Secret key reference in IaC. Use IAM roles or secrets manager.",
|
||||
),
|
||||
];
|
||||
|
||||
for (pattern, severity, message) in security_patterns {
|
||||
if lower.contains(pattern) {
|
||||
findings.push(json!({
|
||||
"category": "security",
|
||||
"severity": severity,
|
||||
"message": message,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
|
||||
/// Scan for IaC best practice violations.
|
||||
fn scan_iac_best_practices(input: &str, cloud: &str) -> Vec<serde_json::Value> {
|
||||
let lower = input.to_lowercase();
|
||||
let mut findings = Vec::new();
|
||||
|
||||
// Tagging
|
||||
if !lower.contains("tags") && !lower.contains("tag") {
|
||||
findings.push(json!({
|
||||
"category": "best_practice",
|
||||
"severity": "low",
|
||||
"message": "No resource tags detected. Add tags for cost allocation and resource management.",
|
||||
}));
|
||||
}
|
||||
|
||||
// Versioning
|
||||
if lower.contains("s3") && !lower.contains("versioning") {
|
||||
findings.push(json!({
|
||||
"category": "best_practice",
|
||||
"severity": "medium",
|
||||
"message": "S3 bucket without versioning detected. Enable versioning for data protection.",
|
||||
}));
|
||||
}
|
||||
|
||||
// Logging
|
||||
if !lower.contains("logging") && !lower.contains("log_group") && !lower.contains("access_logs")
|
||||
{
|
||||
findings.push(json!({
|
||||
"category": "best_practice",
|
||||
"severity": "low",
|
||||
"message": format!("No logging configuration detected for {}. Enable access logging.", cloud),
|
||||
}));
|
||||
}
|
||||
|
||||
// Backup
|
||||
if lower.contains("rds") && !lower.contains("backup_retention") {
|
||||
findings.push(json!({
|
||||
"category": "best_practice",
|
||||
"severity": "medium",
|
||||
"message": "RDS instance without backup retention configuration. Set backup_retention_period.",
|
||||
}));
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
|
||||
/// Scan for cost-related observations in IaC.
|
||||
///
|
||||
/// Only emits findings for resources whose estimated monthly cost exceeds
|
||||
/// `threshold`. AWS-specific patterns (NAT Gateway, Elastic IP, ALB) are
|
||||
/// gated behind `cloud == "aws"`.
|
||||
fn scan_iac_cost(input: &str, cloud: &str, threshold: f64) -> Vec<serde_json::Value> {
|
||||
let lower = input.to_lowercase();
|
||||
let mut findings = Vec::new();
|
||||
|
||||
// (pattern, message, estimated_monthly_usd, aws_only)
|
||||
let expensive_patterns: &[(&str, &str, f64, bool)] = &[
|
||||
("instance_type", "Review instance sizing. Consider right-sizing or spot/preemptible instances.", 50.0, false),
|
||||
("nat_gateway", "NAT Gateway detected. These incur hourly + data transfer charges. Consider VPC endpoints for AWS services.", 45.0, true),
|
||||
("elastic_ip", "Elastic IP detected. Unused EIPs incur charges.", 5.0, true),
|
||||
("load_balancer", "Load balancer detected. Verify it is needed; consider ALB over NLB/CLB for cost.", 25.0, true),
|
||||
];
|
||||
|
||||
for (pattern, message, estimated_cost, aws_only) in expensive_patterns {
|
||||
if *aws_only && cloud != "aws" {
|
||||
continue;
|
||||
}
|
||||
if *estimated_cost < threshold {
|
||||
continue;
|
||||
}
|
||||
if lower.contains(pattern) {
|
||||
findings.push(json!({
|
||||
"category": "cost",
|
||||
"severity": "info",
|
||||
"message": message,
|
||||
"estimated_monthly_usd": estimated_cost,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
|
||||
/// Generate migration recommendations based on architecture description.
///
/// Scans the lowercased description for trigger keywords and emits one
/// recommendation per matched keyword. Falls back to a single generic entry
/// when nothing matches, so callers always receive at least one item.
fn assess_migration_recommendations(input: &str, cloud: &str) -> Vec<serde_json::Value> {
    let lower = input.to_lowercase();
    let mut recs = Vec::new();

    // (trigger keyword, recommendation, effort estimate, detail).
    // NOTE: the `&format!(...)` entry borrows a temporary String; this relies
    // on temporary lifetime extension through the outer `&[...]` borrow in
    // this `let` binding — do not move this slice out of the binding.
    let migration_patterns: &[(&str, &str, &str, &str)] = &[
        ("monolith", "Decompose into microservices or modular containers.",
            "high", "Consider containerizing with ECS/EKS (AWS), AKS (Azure), or GKE (GCP)."),
        ("vm", "Migrate VMs to containers or serverless where feasible.",
            "medium", "Evaluate lift-and-shift to managed container services."),
        ("on-premises", "Assess workloads for cloud readiness using 6 Rs framework (rehost, replatform, refactor, repurchase, retire, retain).",
            "high", "Start with rehost for quick migration, then optimize."),
        ("database", "Evaluate managed database services for reduced operational overhead.",
            "medium", &format!("Consider managed options: RDS/Aurora (AWS), Azure SQL (Azure), Cloud SQL (GCP) for {}.", cloud)),
        ("batch", "Consider serverless compute for batch workloads.",
            "low", "Evaluate Lambda (AWS), Azure Functions, or Cloud Functions for event-driven batch."),
        ("queue", "Evaluate managed message queue services.",
            "low", "Consider SQS/SNS (AWS), Service Bus (Azure), or Pub/Sub (GCP)."),
        ("storage", "Evaluate tiered object storage for cost optimization.",
            "medium", "Use lifecycle policies for infrequent access data."),
        ("legacy", "Assess modernization path: replatform or refactor.",
            "high", "Legacy systems carry tech debt; prioritize incremental modernization."),
    ];

    for (keyword, recommendation, effort, detail) in migration_patterns {
        if lower.contains(keyword) {
            recs.push(json!({
                "trigger": keyword,
                "recommendation": recommendation,
                "effort_estimate": effort,
                "detail": detail,
                "target_cloud": cloud,
            }));
        }
    }

    // No keyword hit: ask the caller for more architectural detail instead of
    // returning an empty list.
    if recs.is_empty() {
        recs.push(json!({
            "trigger": "general",
            "recommendation": "Provide more detail about current architecture components for targeted recommendations.",
            "effort_estimate": "unknown",
            "detail": "Include details about compute, storage, networking, and data layers.",
            "target_cloud": cloud,
        }));
    }

    recs
}
|
||||
|
||||
/// Analyze billing/cost data for optimization opportunities.
|
||||
fn analyze_cost_opportunities(input: &str, threshold: f64) -> Vec<serde_json::Value> {
|
||||
let lower = input.to_lowercase();
|
||||
let mut opportunities = Vec::new();
|
||||
|
||||
// General cost patterns
|
||||
let cost_patterns: &[(&str, &str, &str)] = &[
|
||||
("reserved", "Review reserved instance utilization. Unused reservations waste budget.", "high"),
|
||||
("on-demand", "On-demand instances detected. Evaluate savings plans or reserved instances for stable workloads.", "high"),
|
||||
("data transfer", "Data transfer costs detected. Use VPC endpoints, CDN, or regional placement to reduce.", "medium"),
|
||||
("storage", "Storage costs detected. Implement lifecycle policies and tiered storage.", "medium"),
|
||||
("idle", "Idle resources detected. Identify and terminate unused resources.", "high"),
|
||||
("unattached", "Unattached resources (volumes, IPs) detected. Clean up to reduce waste.", "medium"),
|
||||
("snapshot", "Snapshot costs detected. Review retention policies and delete stale snapshots.", "low"),
|
||||
];
|
||||
|
||||
for (pattern, suggestion, priority) in cost_patterns {
|
||||
if lower.contains(pattern) {
|
||||
opportunities.push(json!({
|
||||
"pattern": pattern,
|
||||
"suggestion": suggestion,
|
||||
"priority": priority,
|
||||
"threshold_usd": threshold,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if opportunities.is_empty() {
|
||||
opportunities.push(json!({
|
||||
"pattern": "general",
|
||||
"suggestion": "Provide billing CSV/JSON data with service and cost columns for detailed analysis.",
|
||||
"priority": "info",
|
||||
"threshold_usd": threshold,
|
||||
}));
|
||||
}
|
||||
|
||||
opportunities
|
||||
}
|
||||
|
||||
/// Review architecture against Well-Architected Framework pillars.
|
||||
fn review_architecture_pillars(
|
||||
input: &str,
|
||||
cloud: &str,
|
||||
_frameworks: &[String],
|
||||
) -> Vec<serde_json::Value> {
|
||||
let lower = input.to_lowercase();
|
||||
|
||||
let pillars = vec![
|
||||
("security", review_pillar_security(&lower, cloud)),
|
||||
("reliability", review_pillar_reliability(&lower, cloud)),
|
||||
("performance", review_pillar_performance(&lower, cloud)),
|
||||
("cost_optimization", review_pillar_cost(&lower, cloud)),
|
||||
(
|
||||
"operational_excellence",
|
||||
review_pillar_operations(&lower, cloud),
|
||||
),
|
||||
];
|
||||
|
||||
pillars
|
||||
.into_iter()
|
||||
.map(|(name, findings)| {
|
||||
json!({
|
||||
"pillar": name,
|
||||
"findings_count": findings.len(),
|
||||
"findings": findings,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Security pillar review: flags controls that the description never mentions.
fn review_pillar_security(input: &str, _cloud: &str) -> Vec<String> {
    // Each entry: keywords whose presence marks the control as described, and
    // the finding emitted when none of them appear.
    let checks: &[(&[&str], &str)] = &[
        (
            &["iam", "identity"],
            "No IAM/identity layer described. Define identity and access management strategy.",
        ),
        (
            &["encrypt"],
            "No encryption mentioned. Implement encryption at rest and in transit.",
        ),
        (
            &["firewall", "waf", "security group"],
            "No network security controls described. Add WAF, security groups, or firewall rules.",
        ),
        (
            &["audit", "logging"],
            "No audit logging described. Enable CloudTrail/Azure Monitor/Cloud Audit Logs.",
        ),
    ];

    checks
        .iter()
        .filter(|(keywords, _)| keywords.iter().all(|kw| !input.contains(kw)))
        .map(|(_, finding)| (*finding).to_string())
        .collect()
}
|
||||
|
||||
/// Reliability pillar review: redundancy, backups, scaling, health monitoring.
fn review_pillar_reliability(input: &str, _cloud: &str) -> Vec<String> {
    // Keywords that indicate each capability is described, paired with the
    // finding emitted when all of them are absent.
    let checks: &[(&[&str], &str)] = &[
        (
            &["multi-az", "multi-region", "redundan"],
            "No redundancy described. Consider multi-AZ or multi-region deployment.",
        ),
        (
            &["backup"],
            "No backup strategy described. Define RPO/RTO and backup schedules.",
        ),
        (
            &["auto-scal", "autoscal"],
            "No auto-scaling described. Implement scaling policies for variable load.",
        ),
        (
            &["health check", "monitor"],
            "No health monitoring described. Add health checks and alerting.",
        ),
    ];

    checks
        .iter()
        .filter(|(keywords, _)| keywords.iter().all(|kw| !input.contains(kw)))
        .map(|(_, finding)| (*finding).to_string())
        .collect()
}
|
||||
|
||||
/// Performance pillar review: caching, load balancing, and measurement.
fn review_pillar_performance(input: &str, _cloud: &str) -> Vec<String> {
    // Keywords that indicate each capability is described, paired with the
    // finding emitted when all of them are absent.
    let checks: &[(&[&str], &str)] = &[
        (
            &["cache", "cdn"],
            "No caching layer described. Consider CDN and application-level caching.",
        ),
        (
            &["load balanc"],
            "No load balancing described. Add load balancer for distributed traffic.",
        ),
        (
            &["metric", "benchmark"],
            "No performance metrics described. Define SLIs/SLOs and baseline benchmarks.",
        ),
    ];

    checks
        .iter()
        .filter(|(keywords, _)| keywords.iter().all(|kw| !input.contains(kw)))
        .map(|(_, finding)| (*finding).to_string())
        .collect()
}
|
||||
|
||||
/// Cost-optimization pillar review: budgets, commitment discounts, right-sizing.
fn review_pillar_cost(input: &str, _cloud: &str) -> Vec<String> {
    // Keywords that indicate each practice is described, paired with the
    // finding emitted when all of them are absent.
    let checks: &[(&[&str], &str)] = &[
        (
            &["budget", "cost"],
            "No cost controls described. Set budget alerts and cost allocation tags.",
        ),
        (
            &["reserved", "savings plan", "spot"],
            "No cost optimization strategy described. Evaluate RIs, savings plans, or spot instances.",
        ),
        (
            &["rightsiz", "right-siz"],
            "No right-sizing mentioned. Regularly review instance utilization and downsize.",
        ),
    ];

    checks
        .iter()
        .filter(|(keywords, _)| keywords.iter().all(|kw| !input.contains(kw)))
        .map(|(_, finding)| (*finding).to_string())
        .collect()
}
|
||||
|
||||
/// Operational-excellence pillar review: IaC, CI/CD, and incident readiness.
///
/// `input` is lowercased by the caller (`review_architecture_pillars`).
fn review_pillar_operations(input: &str, _cloud: &str) -> Vec<String> {
    let mut findings = Vec::new();
    if !input.contains("iac")
        && !input.contains("terraform")
        && !input.contains("infrastructure as code")
    {
        findings.push(
            "No IaC mentioned. Manage all infrastructure as code for reproducibility.".into(),
        );
    }
    // BUG FIX: a bare substring test for "ci" matched unrelated words such as
    // "pricing" or "capacity", silently suppressing this finding. Accept "ci"
    // only as a whole word, plus the explicit CI/CD spellings.
    let has_cicd = input.contains("ci/cd")
        || input.contains("cicd")
        || input.contains("continuous integration")
        || input
            .split(|c: char| !c.is_ascii_alphanumeric())
            .any(|word| word == "ci");
    if !has_cicd && !input.contains("pipeline") && !input.contains("deploy") {
        findings.push("No CI/CD described. Automate build, test, and deployment pipelines.".into());
    }
    if !input.contains("runbook") && !input.contains("incident") {
        findings.push(
            "No incident response described. Create runbooks and incident procedures.".into(),
        );
    }
    findings
}
|
||||
|
||||
/// Unit and integration tests for the cloud-ops tool: action dispatch,
/// argument validation, and the pure scanning helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Default configuration used by every test in this module.
    fn test_config() -> CloudOpsConfig {
        CloudOpsConfig::default()
    }

    // An open security group (0.0.0.0/0 ingress) must produce a "high" finding.
    #[tokio::test]
    async fn review_iac_detects_security_findings() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": "resource \"aws_security_group\" \"open\" { ingress { cidr_blocks = [\"0.0.0.0/0\"] } }"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Unrestricted ingress"));
        assert!(result.output.contains("high"));
    }

    // HCL `resource` syntax should be classified as terraform in the output JSON.
    #[tokio::test]
    async fn review_iac_detects_terraform_type() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": "resource \"aws_instance\" \"test\" { instance_type = \"t3.micro\" tags = { Name = \"test\" } }"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("\"iac_type\": \"terraform\""));
    }

    // `encrypted = false` is an explicit opt-out and must be flagged.
    #[tokio::test]
    async fn review_iac_detects_encrypted_false() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": "resource \"aws_ebs_volume\" \"vol\" { encrypted = false tags = {} }"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Encryption explicitly disabled"));
    }

    // CSV billing input should surface both on-demand and storage opportunities.
    #[tokio::test]
    async fn cost_analysis_detects_on_demand() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "cost_analysis",
                "input": "service,cost\nEC2 On-Demand,5000\nS3 Storage,200"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("on-demand"));
        assert!(result.output.contains("storage"));
    }

    // Every Well-Architected pillar must appear in the review output.
    #[tokio::test]
    async fn architecture_review_returns_all_pillars() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "architecture_review",
                "input": "Web app with EC2, RDS, S3. No caching layer."
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("security"));
        assert!(result.output.contains("reliability"));
        assert!(result.output.contains("performance"));
        assert!(result.output.contains("cost_optimization"));
        assert!(result.output.contains("operational_excellence"));
    }

    // "monolith" trigger should map to the microservices recommendation.
    #[tokio::test]
    async fn assess_migration_detects_monolith() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "assess_migration",
                "input": "Legacy monolith application running on VMs with on-premises database."
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("monolith"));
        assert!(result.output.contains("microservices"));
    }

    // Empty input is a soft failure: success=false with an error message.
    #[tokio::test]
    async fn empty_input_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": ""
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.is_some());
    }

    // Clouds outside the configured allow-list are rejected with a clear message.
    #[tokio::test]
    async fn unsupported_cloud_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": "some content",
                "cloud": "alibaba"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("not in supported_clouds"));
    }

    // Unrecognized actions are a soft failure, not a panic or Err.
    #[tokio::test]
    async fn unknown_action_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "deploy_everything",
                "input": "some content"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }

    #[test]
    fn detect_iac_type_identifies_cloudformation() {
        assert_eq!(detect_iac_type("AWS::EC2::Instance"), "cloudformation");
    }

    #[test]
    fn detect_iac_type_identifies_pulumi() {
        assert_eq!(detect_iac_type("import pulumi"), "pulumi");
    }

    // A quoted "*" (IAM wildcard) must be flagged by the security scanner.
    #[test]
    fn scan_iac_security_finds_wildcard_permission() {
        let findings = scan_iac_security("Action: \"*\" Effect: Allow");
        assert!(!findings.is_empty());
        let msg = findings[0]["message"].as_str().unwrap();
        assert!(msg.contains("Wildcard permission"));
    }

    #[test]
    fn scan_iac_cost_gates_aws_patterns_for_non_aws() {
        // NAT Gateway / Elastic IP / Load Balancer are AWS-only; should not appear for azure
        let findings = scan_iac_cost(
            "nat_gateway elastic_ip load_balancer instance_type",
            "azure",
            0.0, // threshold 0 so all cost-eligible items pass
        );
        for f in &findings {
            let msg = f["message"].as_str().unwrap();
            assert!(
                !msg.contains("NAT Gateway") && !msg.contains("Elastic IP") && !msg.contains("ALB"),
                "AWS-specific finding leaked for azure: {}",
                msg
            );
        }
        // instance_type is cloud-agnostic and should still appear
        assert!(findings
            .iter()
            .any(|f| f["message"].as_str().unwrap().contains("instance sizing")));
    }

    #[test]
    fn scan_iac_cost_respects_threshold() {
        // With a high threshold, low-cost patterns should be filtered out
        let findings = scan_iac_cost(
            "nat_gateway elastic_ip instance_type",
            "aws",
            200.0, // above all estimated costs
        );
        assert!(
            findings.is_empty(),
            "expected no findings above threshold 200, got {:?}",
            findings
        );
    }

    // Type errors in arguments are hard failures: execute returns Err.
    #[tokio::test]
    async fn non_string_action_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": 42,
                "input": "some content"
            }))
            .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("'action' must be a string"));
    }

    #[tokio::test]
    async fn non_string_input_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": 123
            }))
            .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("'input' must be a string"));
    }

    #[tokio::test]
    async fn non_string_cloud_returns_error() {
        let tool = CloudOpsTool::new(test_config());
        let result = tool
            .execute(json!({
                "action": "review_iac",
                "input": "some content",
                "cloud": true
            }))
            .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("'cloud' must be a string"));
    }
}
|
||||
@@ -0,0 +1,412 @@
|
||||
//! Cloud pattern library for recommending cloud-native architectural patterns.
|
||||
//!
|
||||
//! Provides a built-in set of cloud migration and modernization patterns,
|
||||
//! with pattern matching against workload descriptions.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
/// A cloud architecture pattern with metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudPattern {
    /// Pattern identifier, e.g. "containerization".
    pub name: String,
    /// One-sentence summary of what the pattern does.
    pub description: String,
    /// Provider names the pattern applies to ("aws", "azure", "gcp").
    pub cloud_providers: Vec<String>,
    /// When to reach for this pattern.
    pub use_case: String,
    /// Example IaC snippet (built-in patterns ship AWS Terraform examples).
    pub example_iac: String,
    /// Keywords for matching against workload descriptions.
    keywords: Vec<String>,
}
|
||||
|
||||
/// Tool that suggests cloud patterns given a workload description.
pub struct CloudPatternsTool {
    /// Pattern library searched by 'match' and shown by 'list'; populated
    /// from `built_in_patterns()` by `new`.
    patterns: Vec<CloudPattern>,
}
|
||||
|
||||
impl CloudPatternsTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
patterns: built_in_patterns(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for CloudPatternsTool {
    fn name(&self) -> &str {
        "cloud_patterns"
    }

    fn description(&self) -> &str {
        "Cloud pattern library. Given a workload description, suggests applicable cloud-native \
        architectural patterns (containerization, serverless, database modernization, etc.)."
    }

    /// JSON Schema for the tool's arguments; only "action" is required.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["match", "list"],
                    "description": "Action: 'match' to find patterns for a workload, 'list' to show all patterns."
                },
                "workload": {
                    "type": "string",
                    "description": "Description of the workload to match patterns against (required for 'match')."
                },
                "cloud": {
                    "type": "string",
                    "description": "Filter patterns by cloud provider (aws, azure, gcp). Optional."
                }
            },
            "required": ["action"]
        })
    }

    /// Dispatch on "action": 'list' returns pattern summaries, 'match' scores
    /// patterns against the workload description. Bad input is reported as a
    /// failed `ToolResult`, not as `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Missing or non-string fields degrade to "" / None rather than erroring.
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let workload = args
            .get("workload")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        let cloud_filter = args.get("cloud").and_then(|v| v.as_str());

        match action {
            "list" => {
                // Summaries omit example_iac and keywords to keep output compact.
                let filtered = self.filter_by_cloud(cloud_filter);
                let summaries: Vec<serde_json::Value> = filtered
                    .iter()
                    .map(|p| {
                        json!({
                            "name": p.name,
                            "description": p.description,
                            "cloud_providers": p.cloud_providers,
                            "use_case": p.use_case,
                        })
                    })
                    .collect();

                let output = json!({
                    "patterns_count": summaries.len(),
                    "patterns": summaries,
                });

                Ok(ToolResult {
                    success: true,
                    output: serde_json::to_string_pretty(&output)?,
                    error: None,
                })
            }
            "match" => {
                // 'match' requires a non-blank workload description.
                if workload.trim().is_empty() {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some("'workload' parameter is required for 'match' action".into()),
                    });
                }

                let matched = self.match_patterns(workload, cloud_filter);

                let output = json!({
                    "workload_summary": truncate_with_ellipsis(workload, 200),
                    "matched_count": matched.len(),
                    "matched_patterns": matched,
                });

                Ok(ToolResult {
                    success: true,
                    output: serde_json::to_string_pretty(&output)?,
                    error: None,
                })
            }
            _ => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Unknown action '{}'. Valid: match, list", action)),
            }),
        }
    }
}
|
||||
|
||||
impl CloudPatternsTool {
|
||||
fn filter_by_cloud(&self, cloud: Option<&str>) -> Vec<&CloudPattern> {
|
||||
match cloud {
|
||||
Some(c) => self
|
||||
.patterns
|
||||
.iter()
|
||||
.filter(|p| p.cloud_providers.iter().any(|cp| cp == c))
|
||||
.collect(),
|
||||
None => self.patterns.iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn match_patterns(&self, workload: &str, cloud: Option<&str>) -> Vec<serde_json::Value> {
|
||||
let lower = workload.to_lowercase();
|
||||
let candidates = self.filter_by_cloud(cloud);
|
||||
|
||||
let mut scored: Vec<(&CloudPattern, usize)> = candidates
|
||||
.into_iter()
|
||||
.filter_map(|p| {
|
||||
let score: usize = p
|
||||
.keywords
|
||||
.iter()
|
||||
.filter(|kw| lower.contains(kw.as_str()))
|
||||
.count();
|
||||
if score > 0 {
|
||||
Some((p, score))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
// Built-in IaC examples are AWS Terraform only; include them only when
|
||||
// the cloud filter is unset or explicitly "aws".
|
||||
let include_example = cloud.is_none() || cloud == Some("aws");
|
||||
|
||||
scored
|
||||
.into_iter()
|
||||
.map(|(p, score)| {
|
||||
let mut entry = json!({
|
||||
"name": p.name,
|
||||
"description": p.description,
|
||||
"cloud_providers": p.cloud_providers,
|
||||
"use_case": p.use_case,
|
||||
"relevance_score": score,
|
||||
});
|
||||
if include_example {
|
||||
entry["example_iac"] = json!(p.example_iac);
|
||||
}
|
||||
entry
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// The built-in pattern library: five cloud-native modernization patterns.
/// Example IaC snippets are AWS Terraform; callers gate their inclusion
/// (see `match_patterns`).
fn built_in_patterns() -> Vec<CloudPattern> {
    vec![
        // Containers: the generic modernization path for monoliths.
        CloudPattern {
            name: "containerization".into(),
            description: "Package applications into containers for portability and consistent deployment.".into(),
            cloud_providers: vec!["aws".into(), "azure".into(), "gcp".into()],
            use_case: "Modernizing monolithic applications, improving deployment consistency, enabling microservices.".into(),
            example_iac: r#"# Terraform ECS Fargate example
resource "aws_ecs_cluster" "main" {
  name = "app-cluster"
}
resource "aws_ecs_service" "app" {
  cluster = aws_ecs_cluster.main.id
  task_definition = aws_ecs_task_definition.app.arn
  launch_type = "FARGATE"
  desired_count = 2
}"#.into(),
            keywords: vec!["container".into(), "docker".into(), "monolith".into(), "microservice".into(), "ecs".into(), "aks".into(), "gke".into(), "kubernetes".into(), "k8s".into()],
        },
        // Serverless: event-driven or periodic workloads with variable load.
        CloudPattern {
            name: "serverless_migration".into(),
            description: "Migrate event-driven or periodic workloads to serverless compute.".into(),
            cloud_providers: vec!["aws".into(), "azure".into(), "gcp".into()],
            use_case: "Batch jobs, API backends, event processing, cron tasks with variable load.".into(),
            example_iac: r#"# Terraform Lambda example
resource "aws_lambda_function" "handler" {
  function_name = "event-handler"
  runtime = "python3.12"
  handler = "main.handler"
  filename = "handler.zip"
  memory_size = 256
  timeout = 30
}"#.into(),
            keywords: vec!["serverless".into(), "lambda".into(), "function".into(), "event".into(), "batch".into(), "cron".into(), "api".into(), "webhook".into()],
        },
        // Managed databases: replace self-managed engines with cloud services.
        CloudPattern {
            name: "database_modernization".into(),
            description: "Migrate self-managed databases to cloud-managed services for reduced ops overhead.".into(),
            cloud_providers: vec!["aws".into(), "azure".into(), "gcp".into()],
            use_case: "Self-managed MySQL/PostgreSQL/SQL Server migration, NoSQL adoption, read replica scaling.".into(),
            example_iac: r#"# Terraform RDS example
resource "aws_db_instance" "main" {
  engine = "postgres"
  engine_version = "15"
  instance_class = "db.t3.medium"
  allocated_storage = 100
  multi_az = true
  backup_retention_period = 7
  storage_encrypted = true
}"#.into(),
            keywords: vec!["database".into(), "mysql".into(), "postgres".into(), "sql".into(), "rds".into(), "nosql".into(), "dynamo".into(), "mongodb".into(), "migration".into()],
        },
        // API gateway: centralized routing, auth, and throttling.
        CloudPattern {
            name: "api_gateway".into(),
            description: "Centralize API management with rate limiting, auth, and routing.".into(),
            cloud_providers: vec!["aws".into(), "azure".into(), "gcp".into()],
            use_case: "Public API exposure, microservice routing, API versioning, throttling.".into(),
            example_iac: r#"# Terraform API Gateway example
resource "aws_apigatewayv2_api" "main" {
  name = "app-api"
  protocol_type = "HTTP"
}
resource "aws_apigatewayv2_stage" "prod" {
  api_id = aws_apigatewayv2_api.main.id
  name = "prod"
  auto_deploy = true
}"#.into(),
            keywords: vec!["api".into(), "gateway".into(), "rest".into(), "graphql".into(), "routing".into(), "rate limit".into(), "throttl".into()],
        },
        // Service mesh: inter-service traffic management and security.
        CloudPattern {
            name: "service_mesh".into(),
            description: "Implement service mesh for observability, traffic management, and security between microservices.".into(),
            cloud_providers: vec!["aws".into(), "azure".into(), "gcp".into()],
            use_case: "Microservice communication, mTLS, traffic splitting, canary deployments.".into(),
            example_iac: r#"# AWS App Mesh example
resource "aws_appmesh_mesh" "main" {
  name = "app-mesh"
}
resource "aws_appmesh_virtual_service" "app" {
  name = "app.local"
  mesh_name = aws_appmesh_mesh.main.name
}"#.into(),
            keywords: vec!["mesh".into(), "istio".into(), "envoy".into(), "sidecar".into(), "mtls".into(), "canary".into(), "traffic".into(), "microservice".into()],
        },
    ]
}
|
||||
|
||||
/// Tests for the pattern library: catalog contents, matching, filtering,
/// and argument validation.
#[cfg(test)]
mod tests {
    use super::*;

    // The catalog must contain exactly the five documented patterns.
    #[test]
    fn built_in_patterns_are_populated() {
        let patterns = built_in_patterns();
        assert_eq!(patterns.len(), 5);
        let names: Vec<&str> = patterns.iter().map(|p| p.name.as_str()).collect();
        assert!(names.contains(&"containerization"));
        assert!(names.contains(&"serverless_migration"));
        assert!(names.contains(&"database_modernization"));
        assert!(names.contains(&"api_gateway"));
        assert!(names.contains(&"service_mesh"));
    }

    // "monolith"/"containerize" keywords should surface the container pattern.
    #[tokio::test]
    async fn match_returns_containerization_for_monolith() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "match",
                "workload": "We have a monolith Java application running on VMs that we want to containerize."
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("containerization"));
    }

    // "batch"/"cron"/"event" keywords should surface the serverless pattern.
    #[tokio::test]
    async fn match_returns_serverless_for_batch_workload() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "match",
                "workload": "Batch processing cron jobs that handle event data"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("serverless_migration"));
    }

    // The optional "cloud" argument restricts matches to that provider.
    #[tokio::test]
    async fn match_filters_by_cloud_provider() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "match",
                "workload": "Container deployment with Kubernetes",
                "cloud": "aws"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("containerization"));
    }

    // 'list' with no filter shows the full catalog.
    #[tokio::test]
    async fn list_returns_all_patterns() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "list"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("\"patterns_count\": 5"));
    }

    // 'match' without a workload is a soft failure with an error message.
    #[tokio::test]
    async fn match_with_empty_workload_returns_error() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "match",
                "workload": ""
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn match_database_workload_finds_db_modernization() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "match",
                "workload": "Self-hosted PostgreSQL database needs migration to managed service"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("database_modernization"));
    }

    #[test]
    fn pattern_matching_scores_correctly() {
        let tool = CloudPatternsTool::new();
        let matches =
            tool.match_patterns("microservice container docker kubernetes deployment", None);
        // containerization should rank highest (most keyword matches)
        assert!(!matches.is_empty());
        assert_eq!(matches[0]["name"], "containerization");
    }

    // Unknown actions are rejected with the list of valid actions.
    #[tokio::test]
    async fn unknown_action_returns_error() {
        let tool = CloudPatternsTool::new();
        let result = tool
            .execute(json!({
                "action": "deploy"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }
}
|
||||
@@ -116,7 +116,8 @@ impl Tool for CronRunTool {
|
||||
}
|
||||
|
||||
let started_at = Utc::now();
|
||||
let (success, output) = cron::scheduler::execute_job_now(&self.config, &job).await;
|
||||
let (success, output) =
|
||||
Box::pin(cron::scheduler::execute_job_now(&self.config, &job)).await;
|
||||
let finished_at = Utc::now();
|
||||
let duration_ms = (finished_at - started_at).num_milliseconds();
|
||||
let status = if success { "ok" } else { "error" };
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Workspace data lifecycle tool: retention status, time-based purge, and
|
||||
/// storage statistics.
|
||||
pub struct DataManagementTool {
|
||||
workspace_dir: PathBuf,
|
||||
retention_days: u64,
|
||||
}
|
||||
|
||||
impl DataManagementTool {
|
||||
pub fn new(workspace_dir: PathBuf, retention_days: u64) -> Self {
|
||||
Self {
|
||||
workspace_dir,
|
||||
retention_days,
|
||||
}
|
||||
}
|
||||
|
||||
async fn cmd_retention_status(&self) -> anyhow::Result<ToolResult> {
|
||||
let cutoff = chrono::Utc::now()
|
||||
- chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX));
|
||||
let cutoff_ts = cutoff.timestamp().try_into().unwrap_or(0u64);
|
||||
let count = count_files_older_than(&self.workspace_dir, cutoff_ts).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff": cutoff.to_rfc3339(),
|
||||
"affected_files": count,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_purge(&self, dry_run: bool) -> anyhow::Result<ToolResult> {
|
||||
let cutoff = chrono::Utc::now()
|
||||
- chrono::Duration::days(i64::try_from(self.retention_days).unwrap_or(i64::MAX));
|
||||
let cutoff_ts: u64 = cutoff.timestamp().try_into().unwrap_or(0);
|
||||
let (deleted, bytes) = purge_old_files(&self.workspace_dir, cutoff_ts, dry_run).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"dry_run": dry_run,
|
||||
"files": deleted,
|
||||
"bytes_freed": bytes,
|
||||
"bytes_freed_human": format_bytes(bytes),
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn cmd_stats(&self) -> anyhow::Result<ToolResult> {
|
||||
let (total_files, total_bytes, breakdown) = dir_stats(&self.workspace_dir).await?;
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: json!({
|
||||
"total_files": total_files,
|
||||
"total_size": total_bytes,
|
||||
"total_size_human": format_bytes(total_bytes),
|
||||
"subdirectories": breakdown,
|
||||
})
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DataManagementTool {
|
||||
fn name(&self) -> &str {
|
||||
"data_management"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Workspace data retention, purge, and storage statistics"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"enum": ["retention_status", "purge", "stats"],
|
||||
"description": "Data management command"
|
||||
},
|
||||
"dry_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, purge only lists what would be deleted (default true)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing 'command' parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match command {
|
||||
"retention_status" => self.cmd_retention_status().await,
|
||||
"purge" => {
|
||||
let dry_run = args
|
||||
.get("dry_run")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
self.cmd_purge(dry_run).await
|
||||
}
|
||||
"stats" => self.cmd_stats().await,
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown command: {other}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Helpers ------------------------------------------------------------------
|
||||
|
||||
/// Render a byte count as a human-readable string (base-1024 units, one
/// decimal place for KB and above, plain integer for bytes).
fn format_bytes(bytes: u64) -> String {
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];
    for (scale, unit) in UNITS {
        if bytes >= scale {
            return format!("{:.1} {unit}", bytes as f64 / scale as f64);
        }
    }
    format!("{bytes} B")
}
|
||||
|
||||
async fn count_files_older_than(dir: &Path, cutoff_epoch: u64) -> anyhow::Result<usize> {
|
||||
let mut count = 0;
|
||||
if !dir.is_dir() {
|
||||
return Ok(0);
|
||||
}
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
count += Box::pin(count_files_older_than(&path, cutoff_epoch)).await?;
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
let epoch = modified
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
if epoch < cutoff_epoch {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
async fn purge_old_files(
|
||||
dir: &Path,
|
||||
cutoff_epoch: u64,
|
||||
dry_run: bool,
|
||||
) -> anyhow::Result<(usize, u64)> {
|
||||
let mut deleted = 0usize;
|
||||
let mut bytes = 0u64;
|
||||
if !dir.is_dir() {
|
||||
return Ok((0, 0));
|
||||
}
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let (d, b) = Box::pin(purge_old_files(&path, cutoff_epoch, dry_run)).await?;
|
||||
deleted += d;
|
||||
bytes += b;
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
let modified = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
let epoch = modified
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
if epoch < cutoff_epoch {
|
||||
bytes += meta.len();
|
||||
deleted += 1;
|
||||
if !dry_run {
|
||||
let _ = fs::remove_file(&path).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((deleted, bytes))
|
||||
}
|
||||
|
||||
async fn dir_stats(root: &Path) -> anyhow::Result<(usize, u64, serde_json::Value)> {
|
||||
let mut total_files = 0usize;
|
||||
let mut total_bytes = 0u64;
|
||||
let mut breakdown = serde_json::Map::new();
|
||||
|
||||
if !root.is_dir() {
|
||||
return Ok((0, 0, serde_json::Value::Object(breakdown)));
|
||||
}
|
||||
|
||||
let mut rd = fs::read_dir(root).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let (f, b) = count_dir_contents(&path).await?;
|
||||
total_files += f;
|
||||
total_bytes += b;
|
||||
breakdown.insert(
|
||||
name,
|
||||
json!({"files": f, "size": b, "size_human": format_bytes(b)}),
|
||||
);
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
total_files += 1;
|
||||
total_bytes += meta.len();
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
total_files,
|
||||
total_bytes,
|
||||
serde_json::Value::Object(breakdown),
|
||||
))
|
||||
}
|
||||
|
||||
async fn count_dir_contents(dir: &Path) -> anyhow::Result<(usize, u64)> {
|
||||
let mut files = 0usize;
|
||||
let mut bytes = 0u64;
|
||||
let mut rd = fs::read_dir(dir).await?;
|
||||
while let Some(entry) = rd.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let (f, b) = Box::pin(count_dir_contents(&path)).await?;
|
||||
files += f;
|
||||
bytes += b;
|
||||
} else if let Ok(meta) = fs::metadata(&path).await {
|
||||
files += 1;
|
||||
bytes += meta.len();
|
||||
}
|
||||
}
|
||||
Ok((files, bytes))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a 90-day-retention tool rooted at the temp dir.
    fn make_tool(tmp: &TempDir) -> DataManagementTool {
        DataManagementTool::new(tmp.path().to_path_buf(), 90)
    }

    /// Parse the JSON payload of a tool result.
    fn output_json(res: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&res.output).unwrap()
    }

    #[tokio::test]
    async fn retention_status_reports_correct_cutoff() {
        let tmp = TempDir::new().unwrap();
        let res = make_tool(&tmp)
            .execute(json!({"command": "retention_status"}))
            .await
            .unwrap();
        assert!(res.success);
        let v = output_json(&res);
        assert_eq!(v["retention_days"], 90);
        assert!(v["cutoff"].is_string());
    }

    #[tokio::test]
    async fn purge_dry_run_does_not_delete() {
        let tmp = TempDir::new().unwrap();
        // A freshly written file carries the current mtime, so a 90-day
        // retention window must leave it untouched.
        std::fs::write(tmp.path().join("recent.txt"), "data").unwrap();

        let res = make_tool(&tmp)
            .execute(json!({"command": "purge", "dry_run": true}))
            .await
            .unwrap();
        assert!(res.success);
        let v = output_json(&res);
        assert_eq!(v["dry_run"], true);
        // Recent file should not be counted for purge.
        assert_eq!(v["files"], 0);
        // File still exists.
        assert!(tmp.path().join("recent.txt").exists());
    }

    #[tokio::test]
    async fn stats_counts_files_correctly() {
        let tmp = TempDir::new().unwrap();
        let sub = tmp.path().join("subdir");
        std::fs::create_dir_all(&sub).unwrap();
        std::fs::write(sub.join("a.txt"), "hello").unwrap();
        std::fs::write(sub.join("b.txt"), "world").unwrap();
        std::fs::write(tmp.path().join("root.txt"), "top").unwrap();

        let res = make_tool(&tmp)
            .execute(json!({"command": "stats"}))
            .await
            .unwrap();
        assert!(res.success);
        assert_eq!(output_json(&res)["total_files"], 3);
    }
}
|
||||
@@ -12,6 +12,7 @@ pub struct HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl HttpRequestTool {
|
||||
@@ -20,12 +21,14 @@ impl HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
allowed_domains: normalize_allowed_domains(allowed_domains),
|
||||
max_response_size,
|
||||
timeout_secs,
|
||||
allow_private_hosts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +55,7 @@ impl HttpRequestTool {
|
||||
|
||||
let host = extract_host(url)?;
|
||||
|
||||
if is_private_or_local_host(&host) {
|
||||
if !self.allow_private_hosts && is_private_or_local_host(&host) {
|
||||
anyhow::bail!("Blocked local/private host: {host}");
|
||||
}
|
||||
|
||||
@@ -454,6 +457,13 @@ mod tests {
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
|
||||
test_tool_with_private(allowed_domains, false)
|
||||
}
|
||||
|
||||
fn test_tool_with_private(
|
||||
allowed_domains: Vec<&str>,
|
||||
allow_private_hosts: bool,
|
||||
) -> HttpRequestTool {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
..SecurityPolicy::default()
|
||||
@@ -463,6 +473,7 @@ mod tests {
|
||||
allowed_domains.into_iter().map(String::from).collect(),
|
||||
1_000_000,
|
||||
30,
|
||||
allow_private_hosts,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -570,7 +581,7 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_requires_allowlist() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30, false);
|
||||
let err = tool
|
||||
.validate_url("https://example.com")
|
||||
.unwrap_err()
|
||||
@@ -686,7 +697,7 @@ mod tests {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -701,7 +712,7 @@ mod tests {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -724,6 +735,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
10,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world this is long";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -738,6 +750,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
0, // max_response_size = 0 means no limit
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "a".repeat(10_000_000);
|
||||
assert_eq!(tool.truncate_response(&text), text);
|
||||
@@ -750,6 +763,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
5,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -935,4 +949,70 @@ mod tests {
|
||||
.to_string();
|
||||
assert!(err.contains("IPv6"));
|
||||
}
|
||||
|
||||
// ── allow_private_hosts opt-in tests ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn default_blocks_private_hosts() {
|
||||
let tool = test_tool(vec!["localhost", "192.168.1.5", "*"]);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://10.0.0.1")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_localhost() {
|
||||
let tool = test_tool_with_private(vec!["localhost"], true);
|
||||
assert!(tool.validate_url("https://localhost:8080").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_private_ipv4() {
|
||||
let tool = test_tool_with_private(vec!["192.168.1.5"], true);
|
||||
assert!(tool.validate_url("https://192.168.1.5").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_rfc1918_with_wildcard() {
|
||||
let tool = test_tool_with_private(vec!["*"], true);
|
||||
assert!(tool.validate_url("https://10.0.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://172.16.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://192.168.1.1").is_ok());
|
||||
assert!(tool.validate_url("http://localhost:8123").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_still_requires_allowlist() {
|
||||
let tool = test_tool_with_private(vec!["example.com"], true);
|
||||
let err = tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
assert!(
|
||||
err.contains("allowed_domains"),
|
||||
"Private host should still need allowlist match, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_false_still_blocks() {
|
||||
let tool = test_tool_with_private(vec!["*"], false);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
use anyhow::Context;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Cached OAuth2 token state persisted to disk between runs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CachedTokenState {
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
/// Unix timestamp (seconds) when the access token expires.
|
||||
pub expires_at: i64,
|
||||
}
|
||||
|
||||
impl CachedTokenState {
|
||||
/// Returns `true` when the token is expired or will expire within 60 seconds.
|
||||
pub fn is_expired(&self) -> bool {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
self.expires_at <= now + 60
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe token cache with disk persistence.
|
||||
pub struct TokenCache {
|
||||
inner: RwLock<Option<CachedTokenState>>,
|
||||
/// Serialises the slow acquire/refresh path so only one caller performs the
|
||||
/// network round-trip while others wait and then read the updated cache.
|
||||
acquire_lock: Mutex<()>,
|
||||
config: super::types::Microsoft365ResolvedConfig,
|
||||
cache_path: PathBuf,
|
||||
}
|
||||
|
||||
impl TokenCache {
|
||||
pub fn new(
|
||||
config: super::types::Microsoft365ResolvedConfig,
|
||||
zeroclaw_dir: &std::path::Path,
|
||||
) -> anyhow::Result<Self> {
|
||||
if config.token_cache_encrypted {
|
||||
anyhow::bail!(
|
||||
"microsoft365: token_cache_encrypted is enabled but encryption is not yet \
|
||||
implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \
|
||||
to false or wait for encryption support."
|
||||
);
|
||||
}
|
||||
|
||||
// Scope cache file to (tenant_id, client_id, auth_flow) so config
|
||||
// changes never reuse tokens from a different account/flow.
|
||||
let mut hasher = DefaultHasher::new();
|
||||
config.tenant_id.hash(&mut hasher);
|
||||
config.client_id.hash(&mut hasher);
|
||||
config.auth_flow.hash(&mut hasher);
|
||||
let fingerprint = format!("{:016x}", hasher.finish());
|
||||
|
||||
let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json"));
|
||||
let cached = Self::load_from_disk(&cache_path);
|
||||
Ok(Self {
|
||||
inner: RwLock::new(cached),
|
||||
acquire_lock: Mutex::new(()),
|
||||
config,
|
||||
cache_path,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing or re-authenticating as needed.
|
||||
pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
|
||||
// Fast path: cached and not expired.
|
||||
{
|
||||
let guard = self.inner.read();
|
||||
if let Some(ref state) = *guard {
|
||||
if !state.is_expired() {
|
||||
return Ok(state.access_token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: serialise through a mutex so only one caller performs the
|
||||
// network round-trip while concurrent callers wait and re-check.
|
||||
let _lock = self.acquire_lock.lock().await;
|
||||
|
||||
// Re-check after acquiring the lock — another caller may have refreshed
|
||||
// while we were waiting.
|
||||
{
|
||||
let guard = self.inner.read();
|
||||
if let Some(ref state) = *guard {
|
||||
if !state.is_expired() {
|
||||
return Ok(state.access_token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let new_state = self.acquire_token(client).await?;
|
||||
let token = new_state.access_token.clone();
|
||||
self.persist_to_disk(&new_state);
|
||||
*self.inner.write() = Some(new_state);
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
|
||||
// Try refresh first if we have a refresh token and the flow supports it.
|
||||
// Client credentials flow does not issue refresh tokens, so skip the
|
||||
// attempt entirely to avoid a wasted round-trip.
|
||||
if self.config.auth_flow.as_str() != "client_credentials" {
|
||||
// Clone the token out so the RwLock guard is dropped before the await.
|
||||
let refresh_token_copy = {
|
||||
let guard = self.inner.read();
|
||||
guard.as_ref().and_then(|state| state.refresh_token.clone())
|
||||
};
|
||||
if let Some(refresh_tok) = refresh_token_copy {
|
||||
match self.refresh_token(client, &refresh_tok).await {
|
||||
Ok(new_state) => return Ok(new_state),
|
||||
Err(e) => {
|
||||
tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match self.config.auth_flow.as_str() {
|
||||
"client_credentials" => self.client_credentials_flow(client).await,
|
||||
"device_code" => self.device_code_flow(client).await,
|
||||
other => anyhow::bail!("Unsupported auth flow: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn client_credentials_flow(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
) -> anyhow::Result<CachedTokenState> {
|
||||
let client_secret = self
|
||||
.config
|
||||
.client_secret
|
||||
.as_deref()
|
||||
.context("client_credentials flow requires client_secret")?;
|
||||
|
||||
let token_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
||||
self.config.tenant_id
|
||||
);
|
||||
|
||||
let scope = self.config.scopes.join(" ");
|
||||
|
||||
let resp = client
|
||||
.post(&token_url)
|
||||
.form(&[
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", &self.config.client_id),
|
||||
("client_secret", client_secret),
|
||||
("scope", &scope),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: failed to request client_credentials token")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
tracing::debug!("ms365: client_credentials raw OAuth error: {body}");
|
||||
anyhow::bail!("ms365: client_credentials token request failed ({status})");
|
||||
}
|
||||
|
||||
let token_resp: TokenResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.context("ms365: failed to parse token response")?;
|
||||
|
||||
Ok(CachedTokenState {
|
||||
access_token: token_resp.access_token,
|
||||
refresh_token: token_resp.refresh_token,
|
||||
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
||||
})
|
||||
}
|
||||
|
||||
async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
|
||||
let device_code_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
|
||||
self.config.tenant_id
|
||||
);
|
||||
let scope = self.config.scopes.join(" ");
|
||||
|
||||
let resp = client
|
||||
.post(&device_code_url)
|
||||
.form(&[
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("scope", &scope),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: failed to request device code")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
tracing::debug!("ms365: device_code initiation raw error: {body}");
|
||||
anyhow::bail!("ms365: device code request failed ({status})");
|
||||
}
|
||||
|
||||
let device_resp: DeviceCodeResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.context("ms365: failed to parse device code response")?;
|
||||
|
||||
// Log only a generic prompt; the full device_resp.message may contain
|
||||
// sensitive verification URIs or codes that should not appear in logs.
|
||||
tracing::info!(
|
||||
"ms365: device code auth required — follow the instructions shown to the user"
|
||||
);
|
||||
// Print the user-facing message to stderr so the operator can act on it
|
||||
// without it being captured in structured log sinks.
|
||||
eprintln!("ms365: {}", device_resp.message);
|
||||
|
||||
let token_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
||||
self.config.tenant_id
|
||||
);
|
||||
|
||||
let interval = device_resp.interval.max(5);
|
||||
let max_polls = u32::try_from(
|
||||
(device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1),
|
||||
)
|
||||
.unwrap_or(u32::MAX);
|
||||
|
||||
for _ in 0..max_polls {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
|
||||
|
||||
let poll_resp = client
|
||||
.post(&token_url)
|
||||
.form(&[
|
||||
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("device_code", &device_resp.device_code),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: failed to poll device code token")?;
|
||||
|
||||
if poll_resp.status().is_success() {
|
||||
let token_resp: TokenResponse = poll_resp
|
||||
.json()
|
||||
.await
|
||||
.context("ms365: failed to parse token response")?;
|
||||
return Ok(CachedTokenState {
|
||||
access_token: token_resp.access_token,
|
||||
refresh_token: token_resp.refresh_token,
|
||||
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
let body = poll_resp.text().await.unwrap_or_default();
|
||||
if body.contains("authorization_pending") {
|
||||
continue;
|
||||
}
|
||||
if body.contains("slow_down") {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
tracing::debug!("ms365: device code polling raw error: {body}");
|
||||
anyhow::bail!("ms365: device code polling failed");
|
||||
}
|
||||
|
||||
anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
|
||||
}
|
||||
|
||||
async fn refresh_token(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
refresh_token: &str,
|
||||
) -> anyhow::Result<CachedTokenState> {
|
||||
let token_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
||||
self.config.tenant_id
|
||||
);
|
||||
|
||||
let mut params = vec![
|
||||
("grant_type", "refresh_token"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("refresh_token", refresh_token),
|
||||
];
|
||||
|
||||
let secret_ref;
|
||||
if let Some(ref secret) = self.config.client_secret {
|
||||
secret_ref = secret.as_str();
|
||||
params.push(("client_secret", secret_ref));
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: failed to refresh token")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
tracing::debug!("ms365: token refresh raw error: {body}");
|
||||
anyhow::bail!("ms365: token refresh failed ({status})");
|
||||
}
|
||||
|
||||
let token_resp: TokenResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.context("ms365: failed to parse refresh token response")?;
|
||||
|
||||
Ok(CachedTokenState {
|
||||
access_token: token_resp.access_token,
|
||||
refresh_token: token_resp
|
||||
.refresh_token
|
||||
.or_else(|| Some(refresh_token.to_string())),
|
||||
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
|
||||
let data = std::fs::read_to_string(path).ok()?;
|
||||
serde_json::from_str(&data).ok()
|
||||
}
|
||||
|
||||
fn persist_to_disk(&self, state: &CachedTokenState) {
|
||||
if let Ok(json) = serde_json::to_string_pretty(state) {
|
||||
if let Err(e) = std::fs::write(&self.cache_path, json) {
|
||||
tracing::warn!("ms365: failed to persist token cache: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default = "default_expires_in")]
|
||||
expires_in: i64,
|
||||
}
|
||||
|
||||
fn default_expires_in() -> i64 {
|
||||
3600
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
message: String,
|
||||
#[serde(default = "default_device_interval")]
|
||||
interval: u64,
|
||||
#[serde(default = "default_device_expires_in")]
|
||||
expires_in: i64,
|
||||
}
|
||||
|
||||
fn default_device_interval() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_device_expires_in() -> i64 {
|
||||
900
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal state whose expiry is `secs_from_now` seconds away.
    fn state_expiring_in(secs_from_now: i64) -> CachedTokenState {
        CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + secs_from_now,
        }
    }

    #[test]
    fn token_is_expired_when_past_deadline() {
        assert!(state_expiring_in(-10).is_expired());
    }

    #[test]
    fn token_is_expired_within_buffer() {
        // 30 s remaining is inside the 60 s safety buffer.
        assert!(state_expiring_in(30).is_expired());
    }

    #[test]
    fn token_is_valid_when_far_from_expiry() {
        assert!(!state_expiring_in(3600).is_expired());
    }

    #[test]
    fn load_from_disk_returns_none_for_missing_file() {
        let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
        assert!(TokenCache::load_from_disk(path).is_none());
    }
}
|
||||
@@ -0,0 +1,495 @@
|
||||
use anyhow::Context;
|
||||
|
||||
/// Root URL for all Microsoft Graph v1.0 API requests.
const GRAPH_BASE: &str = "https://graph.microsoft.com/v1.0";
|
||||
|
||||
/// Build the user path segment: `/me` or `/users/{user_id}`.
|
||||
/// The user_id is percent-encoded to prevent path-traversal attacks.
|
||||
fn user_path(user_id: &str) -> String {
|
||||
if user_id == "me" {
|
||||
"/me".to_string()
|
||||
} else {
|
||||
format!("/users/{}", urlencoding::encode(user_id))
|
||||
}
|
||||
}
|
||||
|
||||
/// Percent-encode a single path segment to prevent path-traversal attacks.
|
||||
fn encode_path_segment(segment: &str) -> String {
|
||||
urlencoding::encode(segment).into_owned()
|
||||
}
|
||||
|
||||
/// List mail messages for a user.
|
||||
pub async fn mail_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
folder: Option<&str>,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let path = match folder {
|
||||
Some(f) => format!(
|
||||
"{GRAPH_BASE}{base}/mailFolders/{}/messages",
|
||||
encode_path_segment(f)
|
||||
),
|
||||
None => format!("{GRAPH_BASE}{base}/messages"),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&path)
|
||||
.bearer_auth(token)
|
||||
.query(&[("$top", top.to_string())])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: mail_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "mail_list").await
|
||||
}
|
||||
|
||||
/// Send a mail message.
|
||||
pub async fn mail_send(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
to: &[String],
|
||||
subject: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/sendMail");
|
||||
|
||||
let to_recipients: Vec<serde_json::Value> = to
|
||||
.iter()
|
||||
.map(|addr| {
|
||||
serde_json::json!({
|
||||
"emailAddress": { "address": addr }
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"message": {
|
||||
"subject": subject,
|
||||
"body": {
|
||||
"contentType": "Text",
|
||||
"content": body
|
||||
},
|
||||
"toRecipients": to_recipients
|
||||
}
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: mail_send request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: mail_send raw error body: {body}");
|
||||
anyhow::bail!("ms365: mail_send failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List messages in a Teams channel.
|
||||
pub async fn teams_message_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
team_id: &str,
|
||||
channel_id: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
|
||||
encode_path_segment(team_id),
|
||||
encode_path_segment(channel_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.query(&[("$top", top.to_string())])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: teams_message_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "teams_message_list").await
|
||||
}
|
||||
|
||||
/// Send a message to a Teams channel.
|
||||
pub async fn teams_message_send(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
team_id: &str,
|
||||
channel_id: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}/teams/{}/channels/{}/messages",
|
||||
encode_path_segment(team_id),
|
||||
encode_path_segment(channel_id)
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"body": {
|
||||
"content": body
|
||||
}
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: teams_message_send request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: teams_message_send raw error body: {body}");
|
||||
anyhow::bail!("ms365: teams_message_send failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List calendar events in a date range.
|
||||
pub async fn calendar_events_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
start: &str,
|
||||
end: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/calendarView");
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.query(&[
|
||||
("startDateTime", start.to_string()),
|
||||
("endDateTime", end.to_string()),
|
||||
("$top", top.to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_events_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "calendar_events_list").await
|
||||
}
|
||||
|
||||
/// Create a calendar event.
|
||||
pub async fn calendar_event_create(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
subject: &str,
|
||||
start: &str,
|
||||
end: &str,
|
||||
attendees: &[String],
|
||||
body_text: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!("{GRAPH_BASE}{base}/events");
|
||||
|
||||
let attendee_list: Vec<serde_json::Value> = attendees
|
||||
.iter()
|
||||
.map(|email| {
|
||||
serde_json::json!({
|
||||
"emailAddress": { "address": email },
|
||||
"type": "required"
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut payload = serde_json::json!({
|
||||
"subject": subject,
|
||||
"start": {
|
||||
"dateTime": start,
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"end": {
|
||||
"dateTime": end,
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"attendees": attendee_list
|
||||
});
|
||||
|
||||
if let Some(text) = body_text {
|
||||
payload["body"] = serde_json::json!({
|
||||
"contentType": "Text",
|
||||
"content": text
|
||||
});
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_event_create request failed")?;
|
||||
|
||||
let value = handle_json_response(resp, "calendar_event_create").await?;
|
||||
let event_id = value["id"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
/// Delete a calendar event by ID.
|
||||
pub async fn calendar_event_delete(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
event_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}{base}/events/{}",
|
||||
encode_path_segment(event_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.delete(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: calendar_event_delete request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: calendar_event_delete raw error body: {body}");
|
||||
anyhow::bail!("ms365: calendar_event_delete failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List children of a OneDrive folder.
|
||||
pub async fn onedrive_list(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
path: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let base = user_path(user_id);
|
||||
let url = match path {
|
||||
Some(p) if !p.is_empty() => {
|
||||
let encoded = urlencoding::encode(p);
|
||||
format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children")
|
||||
}
|
||||
_ => format!("{GRAPH_BASE}{base}/drive/root/children"),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: onedrive_list request failed")?;
|
||||
|
||||
handle_json_response(resp, "onedrive_list").await
|
||||
}
|
||||
|
||||
/// Download a OneDrive item by ID, with a maximum size guard.
|
||||
pub async fn onedrive_download(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
user_id: &str,
|
||||
item_id: &str,
|
||||
max_size: usize,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let base = user_path(user_id);
|
||||
let url = format!(
|
||||
"{GRAPH_BASE}{base}/drive/items/{}/content",
|
||||
encode_path_segment(item_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: onedrive_download request failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: onedrive_download raw error body: {body}");
|
||||
anyhow::bail!("ms365: onedrive_download failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
let bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("ms365: failed to read download body")?;
|
||||
if bytes.len() > max_size {
|
||||
anyhow::bail!(
|
||||
"ms365: downloaded file exceeds max_size ({} > {max_size})",
|
||||
bytes.len()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
|
||||
/// Search SharePoint for documents matching a query.
|
||||
pub async fn sharepoint_search(
|
||||
client: &reqwest::Client,
|
||||
token: &str,
|
||||
query: &str,
|
||||
top: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{GRAPH_BASE}/search/query");
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"requests": [{
|
||||
"entityTypes": ["driveItem", "listItem", "site"],
|
||||
"query": {
|
||||
"queryString": query
|
||||
},
|
||||
"from": 0,
|
||||
"size": top
|
||||
}]
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.context("ms365: sharepoint_search request failed")?;
|
||||
|
||||
handle_json_response(resp, "sharepoint_search").await
|
||||
}
|
||||
|
||||
/// Extract a short, safe error code from a Graph API JSON error body.
|
||||
/// Returns `None` when the body is not a recognised Graph error envelope.
|
||||
fn extract_graph_error_code(body: &str) -> Option<String> {
|
||||
let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
|
||||
let code = parsed
|
||||
.get("error")
|
||||
.and_then(|e| e.get("code"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
code
|
||||
}
|
||||
|
||||
/// Parse a JSON response body, returning an error on non-success status.
|
||||
/// Raw Graph API error bodies are not propagated; only the HTTP status and a
|
||||
/// short error code (when available) are surfaced to avoid leaking internal
|
||||
/// API details.
|
||||
async fn handle_json_response(
|
||||
resp: reqwest::Response,
|
||||
operation: &str,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let code = extract_graph_error_code(&body).unwrap_or_else(|| "unknown".to_string());
|
||||
tracing::debug!("ms365: {operation} raw error body: {body}");
|
||||
anyhow::bail!("ms365: {operation} failed ({status}, code={code})");
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.with_context(|| format!("ms365: failed to parse {operation} response"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The literal "me" maps to the /me alias, no encoding applied.
    #[test]
    fn user_path_me() {
        assert_eq!(user_path("me"), "/me");
    }

    // Any other identifier is percent-encoded ('@' -> %40).
    #[test]
    fn user_path_specific_user() {
        assert_eq!(user_path("user@contoso.com"), "/users/user%40contoso.com");
    }

    // Without a folder, mail_list targets the flat /messages endpoint.
    #[test]
    fn mail_list_url_no_folder() {
        let base = user_path("me");
        let url = format!("{GRAPH_BASE}{base}/messages");
        assert_eq!(url, "https://graph.microsoft.com/v1.0/me/messages");
    }

    // With a folder, mail_list scopes to /mailFolders/{id}/messages.
    #[test]
    fn mail_list_url_with_folder() {
        let base = user_path("me");
        let folder = "inbox";
        let url = format!(
            "{GRAPH_BASE}{base}/mailFolders/{}/messages",
            encode_path_segment(folder)
        );
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/mailFolders/inbox/messages"
        );
    }

    // calendarView URL for a specific (encoded) user.
    #[test]
    fn calendar_view_url() {
        let base = user_path("user@example.com");
        let url = format!("{GRAPH_BASE}{base}/calendarView");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/users/user%40example.com/calendarView"
        );
    }

    // Teams URLs interpolate encoded team and channel ids.
    #[test]
    fn teams_message_url() {
        let url = format!(
            "{GRAPH_BASE}/teams/{}/channels/{}/messages",
            encode_path_segment("team-123"),
            encode_path_segment("channel-456")
        );
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/teams/team-123/channels/channel-456/messages"
        );
    }

    // Drive root listing uses the plain /drive/root/children form.
    #[test]
    fn onedrive_root_url() {
        let base = user_path("me");
        let url = format!("{GRAPH_BASE}{base}/drive/root/children");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root/children"
        );
    }

    // Path-addressed listing encodes the whole path as one segment
    // ('/' -> %2F) inside the root:/{path}: syntax.
    #[test]
    fn onedrive_path_url() {
        let base = user_path("me");
        let encoded = urlencoding::encode("Documents/Reports");
        let url = format!("{GRAPH_BASE}{base}/drive/root:/{encoded}:/children");
        assert_eq!(
            url,
            "https://graph.microsoft.com/v1.0/me/drive/root:/Documents%2FReports:/children"
        );
    }

    // SharePoint search posts to the fixed /search/query endpoint.
    #[test]
    fn sharepoint_search_url() {
        let url = format!("{GRAPH_BASE}/search/query");
        assert_eq!(url, "https://graph.microsoft.com/v1.0/search/query");
    }
}
|
||||
@@ -0,0 +1,567 @@
|
||||
//! Microsoft 365 integration tool — Graph API access for Mail, Teams, Calendar,
|
||||
//! OneDrive, and SharePoint via a single action-dispatched tool surface.
|
||||
//!
|
||||
//! Auth is handled through direct HTTP calls to the Microsoft identity platform
|
||||
//! (client credentials or device code flow) with token caching.
|
||||
|
||||
pub mod auth;
|
||||
pub mod graph_client;
|
||||
pub mod types;
|
||||
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Maximum download size for OneDrive files (10 MB). This is a hard upper
/// bound: a caller-supplied `max_size` is clamped down to it.
const MAX_ONEDRIVE_DOWNLOAD_SIZE: usize = 10 * 1024 * 1024;

/// Default number of items to return in list/search operations when the
/// caller does not pass `top`.
const DEFAULT_TOP: u32 = 25;
|
||||
|
||||
/// Action-dispatched Microsoft 365 tool backed by the Graph API.
pub struct Microsoft365Tool {
    /// Resolved tenant/client/auth settings (see `types`).
    config: types::Microsoft365ResolvedConfig,
    /// Policy gate consulted before every read/act operation.
    security: Arc<SecurityPolicy>,
    /// Shared access-token cache (see `auth::TokenCache`).
    token_cache: Arc<auth::TokenCache>,
    /// HTTP client reused across all Graph requests.
    http_client: reqwest::Client,
}
|
||||
|
||||
impl Microsoft365Tool {
|
||||
pub fn new(
|
||||
config: types::Microsoft365ResolvedConfig,
|
||||
security: Arc<SecurityPolicy>,
|
||||
zeroclaw_dir: &std::path::Path,
|
||||
) -> anyhow::Result<Self> {
|
||||
let http_client =
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("tool.microsoft365", 60, 10);
|
||||
let token_cache = Arc::new(auth::TokenCache::new(config.clone(), zeroclaw_dir)?);
|
||||
Ok(Self {
|
||||
config,
|
||||
security,
|
||||
token_cache,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_token(&self) -> anyhow::Result<String> {
|
||||
self.token_cache.get_token(&self.http_client).await
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.config.user_id
|
||||
}
|
||||
|
||||
async fn dispatch(&self, action: &str, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
match action {
|
||||
"mail_list" => self.handle_mail_list(args).await,
|
||||
"mail_send" => self.handle_mail_send(args).await,
|
||||
"teams_message_list" => self.handle_teams_message_list(args).await,
|
||||
"teams_message_send" => self.handle_teams_message_send(args).await,
|
||||
"calendar_events_list" => self.handle_calendar_events_list(args).await,
|
||||
"calendar_event_create" => self.handle_calendar_event_create(args).await,
|
||||
"calendar_event_delete" => self.handle_calendar_event_delete(args).await,
|
||||
"onedrive_list" => self.handle_onedrive_list(args).await,
|
||||
"onedrive_download" => self.handle_onedrive_download(args).await,
|
||||
"sharepoint_search" => self.handle_sharepoint_search(args).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {action}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Read actions ────────────────────────────────────────────────
|
||||
|
||||
async fn handle_mail_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.mail_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let folder = args["folder"].as_str();
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result =
|
||||
graph_client::mail_list(&self.http_client, &token, self.user_id(), folder, top).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_teams_message_list(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.teams_message_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let team_id = args["team_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
|
||||
let channel_id = args["channel_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result =
|
||||
graph_client::teams_message_list(&self.http_client, &token, team_id, channel_id, top)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_events_list(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.calendar_events_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let start = args["start"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
|
||||
let end = args["end"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result = graph_client::calendar_events_list(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
start,
|
||||
end,
|
||||
top,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_onedrive_list(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_list")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let path = args["path"].as_str();
|
||||
|
||||
let result =
|
||||
graph_client::onedrive_list(&self.http_client, &token, self.user_id(), path).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_onedrive_download(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.onedrive_download")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let item_id = args["item_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("item_id is required"))?;
|
||||
let max_size = args["max_size"]
|
||||
.as_u64()
|
||||
.and_then(|v| usize::try_from(v).ok())
|
||||
.unwrap_or(MAX_ONEDRIVE_DOWNLOAD_SIZE)
|
||||
.min(MAX_ONEDRIVE_DOWNLOAD_SIZE);
|
||||
|
||||
let bytes = graph_client::onedrive_download(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
item_id,
|
||||
max_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Return base64-encoded for binary safety.
|
||||
use base64::Engine;
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Downloaded {} bytes (base64 encoded):\n{encoded}",
|
||||
bytes.len()
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_sharepoint_search(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "microsoft365.sharepoint_search")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let query = args["query"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("query is required"))?;
|
||||
let top = u32::try_from(args["top"].as_u64().unwrap_or(u64::from(DEFAULT_TOP)))
|
||||
.unwrap_or(DEFAULT_TOP);
|
||||
|
||||
let result = graph_client::sharepoint_search(&self.http_client, &token, query, top).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&result)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Write actions ───────────────────────────────────────────────
|
||||
|
||||
async fn handle_mail_send(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.mail_send")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let to: Vec<String> = args["to"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("to must be an array of email addresses"))?
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect();
|
||||
|
||||
if to.is_empty() {
|
||||
anyhow::bail!("to must contain at least one email address");
|
||||
}
|
||||
|
||||
let subject = args["subject"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
|
||||
let body = args["body"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
|
||||
|
||||
graph_client::mail_send(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
&to,
|
||||
subject,
|
||||
body,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Email sent to: {}", to.join(", ")),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_teams_message_send(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.teams_message_send")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let team_id = args["team_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("team_id is required"))?;
|
||||
let channel_id = args["channel_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("channel_id is required"))?;
|
||||
let body = args["body"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("body is required"))?;
|
||||
|
||||
graph_client::teams_message_send(&self.http_client, &token, team_id, channel_id, body)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: "Teams message sent".to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_event_create(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_create")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let subject = args["subject"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("subject is required"))?;
|
||||
let start = args["start"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("start datetime is required"))?;
|
||||
let end = args["end"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("end datetime is required"))?;
|
||||
let attendees: Vec<String> = args["attendees"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let body_text = args["body"].as_str();
|
||||
|
||||
let event_id = graph_client::calendar_event_create(
|
||||
&self.http_client,
|
||||
&token,
|
||||
self.user_id(),
|
||||
subject,
|
||||
start,
|
||||
end,
|
||||
&attendees,
|
||||
body_text,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Calendar event created (id: {event_id})"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_calendar_event_delete(
|
||||
&self,
|
||||
args: &serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "microsoft365.calendar_event_delete")
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let token = self.get_token().await?;
|
||||
let event_id = args["event_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("event_id is required"))?;
|
||||
|
||||
graph_client::calendar_event_delete(&self.http_client, &token, self.user_id(), event_id)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Calendar event {event_id} deleted"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for Microsoft365Tool {
    // Stable tool identifier used for registration and dispatch.
    fn name(&self) -> &str {
        "microsoft365"
    }

    fn description(&self) -> &str {
        "Microsoft 365 integration: manage Outlook mail, Teams messages, Calendar events, \
         OneDrive files, and SharePoint search via Microsoft Graph API"
    }

    // JSON schema for tool arguments. Only `action` is required; every other
    // property applies to a subset of actions (noted in each description).
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "mail_list",
                        "mail_send",
                        "teams_message_list",
                        "teams_message_send",
                        "calendar_events_list",
                        "calendar_event_create",
                        "calendar_event_delete",
                        "onedrive_list",
                        "onedrive_download",
                        "sharepoint_search"
                    ],
                    "description": "The Microsoft 365 action to perform"
                },
                "folder": {
                    "type": "string",
                    "description": "Mail folder ID (for mail_list, e.g. 'inbox', 'sentitems')"
                },
                "to": {
                    "type": "array",
                    "items": { "type": "string" },
                    "description": "Recipient email addresses (for mail_send)"
                },
                "subject": {
                    "type": "string",
                    "description": "Email subject or calendar event subject"
                },
                "body": {
                    "type": "string",
                    "description": "Message body text"
                },
                "team_id": {
                    "type": "string",
                    "description": "Teams team ID (for teams_message_list/send)"
                },
                "channel_id": {
                    "type": "string",
                    "description": "Teams channel ID (for teams_message_list/send)"
                },
                "start": {
                    "type": "string",
                    "description": "Start datetime in ISO 8601 format (for calendar actions)"
                },
                "end": {
                    "type": "string",
                    "description": "End datetime in ISO 8601 format (for calendar actions)"
                },
                "attendees": {
                    "type": "array",
                    "items": { "type": "string" },
                    "description": "Attendee email addresses (for calendar_event_create)"
                },
                "event_id": {
                    "type": "string",
                    "description": "Calendar event ID (for calendar_event_delete)"
                },
                "path": {
                    "type": "string",
                    "description": "OneDrive folder path (for onedrive_list)"
                },
                "item_id": {
                    "type": "string",
                    "description": "OneDrive item ID (for onedrive_download)"
                },
                "max_size": {
                    "type": "integer",
                    "description": "Maximum download size in bytes (for onedrive_download, default 10MB)"
                },
                "query": {
                    "type": "string",
                    "description": "Search query (for sharepoint_search)"
                },
                "top": {
                    "type": "integer",
                    "description": "Maximum number of items to return (default 25)"
                }
            }
        })
    }

    // Entry point: validates `action`, then routes through `dispatch`.
    // Both a missing action and a handler error are reported as a failed
    // ToolResult (success=false) rather than an Err, so callers always get
    // a structured result.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = match args["action"].as_str() {
            Some(a) => a.to_string(),
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("'action' parameter is required".to_string()),
                });
            }
        };

        match self.dispatch(&action, &args).await {
            Ok(result) => Ok(result),
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("microsoft365.{action} failed: {e}")),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this test was previously named `tool_name_is_microsoft365`
    // but never touched the tool name — it only parsed a JSON literal. The
    // tool cannot be constructed here without a resolved config and security
    // policy, so the test is renamed to state what it actually checks.
    #[test]
    fn schema_skeleton_is_valid_json() {
        let schema_str = r#"{"type":"object","required":["action"]}"#;
        let _: serde_json::Value = serde_json::from_str(schema_str).unwrap();
    }

    // The action enum must list all ten supported actions.
    #[test]
    fn parameters_schema_has_action_enum() {
        let schema = json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "mail_list",
                        "mail_send",
                        "teams_message_list",
                        "teams_message_send",
                        "calendar_events_list",
                        "calendar_event_create",
                        "calendar_event_delete",
                        "onedrive_list",
                        "onedrive_download",
                        "sharepoint_search"
                    ]
                }
            }
        });

        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
        assert_eq!(actions.len(), 10);
        assert!(actions.contains(&json!("mail_list")));
        assert!(actions.contains(&json!("sharepoint_search")));
    }

    // Sanity check that the dispatch table covers exactly the documented
    // action set and nothing else.
    #[test]
    fn action_dispatch_table_is_exhaustive() {
        let valid_actions = [
            "mail_list",
            "mail_send",
            "teams_message_list",
            "teams_message_send",
            "calendar_events_list",
            "calendar_event_create",
            "calendar_event_delete",
            "onedrive_list",
            "onedrive_download",
            "sharepoint_search",
        ];
        assert_eq!(valid_actions.len(), 10);
        assert!(!valid_actions.contains(&"invalid_action"));
    }
}
|
||||
@@ -0,0 +1,55 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Resolved Microsoft 365 configuration with all secrets decrypted and defaults applied.
///
/// NOTE(review): the manual `Debug` impl below masks `client_secret`, but the
/// derived `Serialize` writes it in plaintext — callers must not log or
/// persist the serialized form unencrypted.
#[derive(Clone, Serialize, Deserialize)]
pub struct Microsoft365ResolvedConfig {
    pub tenant_id: String,
    pub client_id: String,
    // Optional app secret; registration fails fast without it when
    // `auth_flow` is "client_credentials".
    pub client_secret: Option<String>,
    // Auth flow selector (e.g. "client_credentials") — stringly typed and
    // matched against literals at registration time.
    pub auth_flow: String,
    // OAuth scopes requested for tokens (e.g. "https://graph.microsoft.com/.default").
    pub scopes: Vec<String>,
    pub token_cache_encrypted: bool,
    // Target user id; defaults to "me" when not configured.
    pub user_id: String,
}
|
||||
|
||||
impl std::fmt::Debug for Microsoft365ResolvedConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Microsoft365ResolvedConfig")
|
||||
.field("tenant_id", &self.tenant_id)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("client_secret", &self.client_secret.as_ref().map(|_| "***"))
|
||||
.field("auth_flow", &self.auth_flow)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("token_cache_encrypted", &self.token_cache_encrypted)
|
||||
.field("user_id", &self.user_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully-populated config must survive a JSON encode/decode cycle with
    /// every field intact — including the secret, which is masked only in
    /// `Debug` output, not in serde output.
    #[test]
    fn resolved_config_serialization_roundtrip() {
        let original = Microsoft365ResolvedConfig {
            tenant_id: "test-tenant".into(),
            client_id: "test-client".into(),
            client_secret: Some("secret".into()),
            auth_flow: "client_credentials".into(),
            scopes: vec!["https://graph.microsoft.com/.default".into()],
            token_cache_encrypted: false,
            user_id: "me".into(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Microsoft365ResolvedConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.tenant_id, "test-tenant");
        assert_eq!(decoded.client_id, "test-client");
        assert_eq!(decoded.client_secret.as_deref(), Some("secret"));
        assert_eq!(decoded.auth_flow, "client_credentials");
        assert_eq!(decoded.scopes.len(), 1);
        assert_eq!(decoded.user_id, "me");
    }
}
|
||||
+183
-18
@@ -15,9 +15,12 @@
|
||||
//! To add a new tool, implement [`Tool`] in a new submodule and register it in
|
||||
//! [`all_tools_with_runtime`]. See `AGENTS.md` §7.3 for the full change playbook.
|
||||
|
||||
pub mod backup_tool;
|
||||
pub mod browser;
|
||||
pub mod browser_open;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
pub mod cloud_patterns;
|
||||
pub mod composio;
|
||||
pub mod content_search;
|
||||
pub mod cron_add;
|
||||
@@ -26,6 +29,7 @@ pub mod cron_remove;
|
||||
pub mod cron_run;
|
||||
pub mod cron_runs;
|
||||
pub mod cron_update;
|
||||
pub mod data_management;
|
||||
pub mod delegate;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
@@ -48,22 +52,32 @@ pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod microsoft365;
|
||||
pub mod model_routing_config;
|
||||
pub mod node_tool;
|
||||
pub mod notion_tool;
|
||||
pub mod pdf_read;
|
||||
pub mod project_intel;
|
||||
pub mod proxy_config;
|
||||
pub mod pushover;
|
||||
pub mod report_templates;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod tool_search;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
pub mod web_search_tool;
|
||||
pub mod workspace_tool;
|
||||
|
||||
pub use backup_tool::BackupTool;
|
||||
pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
pub use composio::ComposioTool;
|
||||
pub use content_search::ContentSearchTool;
|
||||
pub use cron_add::CronAddTool;
|
||||
@@ -72,6 +86,7 @@ pub use cron_remove::CronRemoveTool;
|
||||
pub use cron_run::CronRunTool;
|
||||
pub use cron_runs::CronRunsTool;
|
||||
pub use cron_update::CronUpdateTool;
|
||||
pub use data_management::DataManagementTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
@@ -92,23 +107,29 @@ pub use mcp_tool::McpToolWrapper;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use microsoft365::Microsoft365Tool;
|
||||
pub use model_routing_config::ModelRoutingConfigTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use node_tool::NodeTool;
|
||||
pub use notion_tool::NotionTool;
|
||||
pub use pdf_read::PdfReadTool;
|
||||
pub use project_intel::ProjectIntelTool;
|
||||
pub use proxy_config::ProxyConfigTool;
|
||||
pub use pushover::PushoverTool;
|
||||
pub use schedule::ScheduleTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
pub use traits::Tool;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{ToolResult, ToolSpec};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
pub use web_search_tool::WebSearchTool;
|
||||
pub use workspace_tool::WorkspaceTool;
|
||||
|
||||
use crate::config::{Config, DelegateAgentConfig};
|
||||
use crate::memory::Memory;
|
||||
@@ -314,6 +335,7 @@ pub fn all_tools_with_runtime(
|
||||
http_config.allowed_domains.clone(),
|
||||
http_config.max_response_size,
|
||||
http_config.timeout_secs,
|
||||
http_config.allow_private_hosts,
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -339,6 +361,60 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Notion API tool (conditionally registered)
|
||||
if root_config.notion.enabled {
|
||||
let notion_api_key = if root_config.notion.api_key.trim().is_empty() {
|
||||
std::env::var("NOTION_API_KEY").unwrap_or_default()
|
||||
} else {
|
||||
root_config.notion.api_key.trim().to_string()
|
||||
};
|
||||
if notion_api_key.trim().is_empty() {
|
||||
tracing::warn!(
|
||||
"Notion tool enabled but no API key found (set notion.api_key or NOTION_API_KEY env var)"
|
||||
);
|
||||
} else {
|
||||
tool_arcs.push(Arc::new(NotionTool::new(notion_api_key, security.clone())));
|
||||
}
|
||||
}
|
||||
|
||||
// Project delivery intelligence
|
||||
if root_config.project_intel.enabled {
|
||||
tool_arcs.push(Arc::new(ProjectIntelTool::new(
|
||||
root_config.project_intel.default_language.clone(),
|
||||
root_config.project_intel.risk_sensitivity.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// MCSS Security Operations
|
||||
if root_config.security_ops.enabled {
|
||||
tool_arcs.push(Arc::new(SecurityOpsTool::new(
|
||||
root_config.security_ops.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// Backup tool (enabled by default)
|
||||
if root_config.backup.enabled {
|
||||
tool_arcs.push(Arc::new(BackupTool::new(
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.backup.include_dirs.clone(),
|
||||
root_config.backup.max_keep,
|
||||
)));
|
||||
}
|
||||
|
||||
// Data management tool (disabled by default)
|
||||
if root_config.data_retention.enabled {
|
||||
tool_arcs.push(Arc::new(DataManagementTool::new(
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.data_retention.retention_days,
|
||||
)));
|
||||
}
|
||||
|
||||
// Cloud operations advisory tools (read-only analysis)
|
||||
if root_config.cloud_ops.enabled {
|
||||
tool_arcs.push(Arc::new(CloudOpsTool::new(root_config.cloud_ops.clone())));
|
||||
tool_arcs.push(Arc::new(CloudPatternsTool::new()));
|
||||
}
|
||||
|
||||
// PDF extraction (feature-gated at compile time via rag-pdf)
|
||||
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
|
||||
|
||||
@@ -356,7 +432,80 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
}
|
||||
|
||||
// Microsoft 365 Graph API integration
|
||||
if root_config.microsoft365.enabled {
|
||||
let ms_cfg = &root_config.microsoft365;
|
||||
let tenant_id = ms_cfg
|
||||
.tenant_id
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_string();
|
||||
let client_id = ms_cfg
|
||||
.client_id
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_string();
|
||||
if !tenant_id.is_empty() && !client_id.is_empty() {
|
||||
// Fail fast: client_credentials flow requires a client_secret at registration time.
|
||||
if ms_cfg.auth_flow.trim() == "client_credentials"
|
||||
&& ms_cfg
|
||||
.client_secret
|
||||
.as_deref()
|
||||
.map_or(true, |s| s.trim().is_empty())
|
||||
{
|
||||
tracing::error!(
|
||||
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
|
||||
);
|
||||
return (boxed_registry_from_arcs(tool_arcs), None);
|
||||
}
|
||||
|
||||
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
|
||||
tenant_id,
|
||||
client_id,
|
||||
client_secret: ms_cfg.client_secret.clone(),
|
||||
auth_flow: ms_cfg.auth_flow.clone(),
|
||||
scopes: ms_cfg.scopes.clone(),
|
||||
token_cache_encrypted: ms_cfg.token_cache_encrypted,
|
||||
user_id: ms_cfg.user_id.as_deref().unwrap_or("me").to_string(),
|
||||
};
|
||||
// Store token cache in the config directory (next to config.toml),
|
||||
// not the workspace directory, to keep bearer tokens out of the
|
||||
// project tree.
|
||||
let cache_dir = root_config.config_path.parent().unwrap_or(workspace_dir);
|
||||
match Microsoft365Tool::new(resolved, security.clone(), cache_dir) {
|
||||
Ok(tool) => tool_arcs.push(Arc::new(tool)),
|
||||
Err(e) => {
|
||||
tracing::error!("microsoft365: failed to initialize tool: {e}");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"microsoft365: skipped registration because tenant_id or client_id is empty"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Add delegation tool when agents are configured
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
};
|
||||
|
||||
let delegate_handle: Option<DelegateParentToolsHandle> = if agents.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -364,28 +513,12 @@ pub fn all_tools_with_runtime(
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let parent_tools = Arc::new(RwLock::new(tool_arcs.clone()));
|
||||
let delegate_tool = DelegateTool::new_with_options(
|
||||
delegate_agents,
|
||||
delegate_fallback_credential,
|
||||
delegate_fallback_credential.clone(),
|
||||
security.clone(),
|
||||
crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
},
|
||||
provider_runtime_options.clone(),
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone());
|
||||
@@ -393,6 +526,38 @@ pub fn all_tools_with_runtime(
|
||||
Some(parent_tools)
|
||||
};
|
||||
|
||||
// Add swarm tool when swarms are configured
|
||||
if !root_config.swarms.is_empty() {
|
||||
let swarm_agents: HashMap<String, DelegateAgentConfig> = agents
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
tool_arcs.push(Arc::new(SwarmTool::new(
|
||||
root_config.swarms.clone(),
|
||||
swarm_agents,
|
||||
delegate_fallback_credential,
|
||||
security.clone(),
|
||||
provider_runtime_options,
|
||||
)));
|
||||
}
|
||||
|
||||
// Workspace management tool (conditionally registered when workspace isolation is enabled)
|
||||
if root_config.workspace.enabled {
|
||||
let workspaces_dir = if root_config.workspace.workspaces_dir.starts_with("~/") {
|
||||
let home = directories::UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||
home.join(&root_config.workspace.workspaces_dir[2..])
|
||||
} else {
|
||||
std::path::PathBuf::from(&root_config.workspace.workspaces_dir)
|
||||
};
|
||||
let ws_manager = crate::config::workspace::WorkspaceManager::new(workspaces_dir);
|
||||
tool_arcs.push(Arc::new(WorkspaceTool::new(
|
||||
Arc::new(tokio::sync::RwLock::new(ws_manager)),
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::{policy::ToolOperation, SecurityPolicy};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Notion REST API base URL and the pinned `Notion-Version` header value.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
const NOTION_VERSION: &str = "2022-06-28";
// Per-request timeout applied to every Notion HTTP call.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;

/// Tool for interacting with the Notion API — query databases, read/create/update pages,
/// and search the workspace. Each action is gated by the appropriate security operation
/// (Read for queries, Act for mutations).
pub struct NotionTool {
    // Token sent as `Authorization: Bearer <api_key>` on every request.
    api_key: String,
    // Single HTTP client reused across all requests.
    http: reqwest::Client,
    // Policy consulted in `execute` before each action runs.
    security: Arc<SecurityPolicy>,
}
|
||||
|
||||
impl NotionTool {
|
||||
/// Create a new Notion tool with the given API key and security policy.
|
||||
pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
http: reqwest::Client::new(),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the standard Notion API headers (Authorization, version, content-type).
|
||||
fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", self.api_key)
|
||||
.parse()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
|
||||
);
|
||||
headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
|
||||
headers.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Query a Notion database with an optional filter.
|
||||
async fn query_database(
|
||||
&self,
|
||||
database_id: &str,
|
||||
filter: Option<&serde_json::Value>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
|
||||
let mut body = json!({});
|
||||
if let Some(f) = filter {
|
||||
body["filter"] = f.clone();
|
||||
}
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion query_database failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Read a single Notion page by ID.
|
||||
async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.headers(self.headers()?)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion read_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Create a new Notion page, optionally within a database.
|
||||
async fn create_page(
|
||||
&self,
|
||||
properties: &serde_json::Value,
|
||||
database_id: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages");
|
||||
let mut body = json!({ "properties": properties });
|
||||
if let Some(db_id) = database_id {
|
||||
body["parent"] = json!({ "database_id": db_id });
|
||||
}
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion create_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Update an existing Notion page's properties.
|
||||
async fn update_page(
|
||||
&self,
|
||||
page_id: &str,
|
||||
properties: &serde_json::Value,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/pages/{page_id}");
|
||||
let body = json!({ "properties": properties });
|
||||
let resp = self
|
||||
.http
|
||||
.patch(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion update_page failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Search the Notion workspace by query string.
|
||||
async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
|
||||
let url = format!("{NOTION_API_BASE}/search");
|
||||
let body = json!({ "query": query });
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.headers(self.headers()?)
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
|
||||
anyhow::bail!("Notion search failed ({status}): {truncated}");
|
||||
}
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for NotionTool {
|
||||
fn name(&self) -> &str {
|
||||
"notion"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Interact with Notion: query databases, read/create/update pages, and search the workspace."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["query_database", "read_page", "create_page", "update_page", "search"],
|
||||
"description": "The Notion API action to perform"
|
||||
},
|
||||
"database_id": {
|
||||
"type": "string",
|
||||
"description": "Database ID (required for query_database, optional for create_page)"
|
||||
},
|
||||
"page_id": {
|
||||
"type": "string",
|
||||
"description": "Page ID (required for read_page and update_page)"
|
||||
},
|
||||
"filter": {
|
||||
"type": "object",
|
||||
"description": "Notion filter object for query_database"
|
||||
},
|
||||
"properties": {
|
||||
"type": "object",
|
||||
"description": "Properties object for create_page and update_page"
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string for the search action"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Enforce granular security: Read for queries, Act for mutations
|
||||
let operation = match action {
|
||||
"query_database" | "read_page" | "search" => ToolOperation::Read,
|
||||
"create_page" | "update_page" => ToolOperation::Act,
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let result = match action {
|
||||
"query_database" => {
|
||||
let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("query_database requires database_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let filter = args.get("filter");
|
||||
self.query_database(database_id, filter).await
|
||||
}
|
||||
"read_page" => {
|
||||
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("read_page requires page_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
self.read_page(page_id).await
|
||||
}
|
||||
"create_page" => {
|
||||
let properties = match args.get("properties") {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("create_page requires properties parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let database_id = args.get("database_id").and_then(|v| v.as_str());
|
||||
self.create_page(properties, database_id).await
|
||||
}
|
||||
"update_page" => {
|
||||
let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("update_page requires page_id parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let properties = match args.get("properties") {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("update_page requires properties parameter".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
self.update_page(page_id, properties).await
|
||||
}
|
||||
"search" => {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
self.search(query).await
|
||||
}
|
||||
_ => unreachable!(), // Already handled above
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(value) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;

    /// Build a tool instance with a dummy key and the default security policy.
    fn test_tool() -> NotionTool {
        NotionTool::new("test-key".into(), Arc::new(SecurityPolicy::default()))
    }

    /// Shared assertion: `execute(args)` must fail with an error mentioning `needle`.
    async fn assert_fails_mentioning(args: serde_json::Value, needle: &str) {
        let result = test_tool().execute(args).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains(needle));
    }

    #[test]
    fn tool_name_is_notion() {
        assert_eq!(test_tool().name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let schema = test_tool().parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("action")));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let schema = test_tool().parameters_schema();
        let declared: Vec<&str> = schema["properties"]["action"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        for expected in ["query_database", "read_page", "create_page", "update_page", "search"] {
            assert!(declared.contains(&expected));
        }
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        assert_fails_mentioning(json!({}), "action").await;
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        assert_fails_mentioning(json!({"action": "invalid"}), "Unknown action").await;
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        assert_fails_mentioning(json!({"action": "query_database"}), "database_id").await;
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        assert_fails_mentioning(json!({"action": "read_page"}), "page_id").await;
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        assert_fails_mentioning(json!({"action": "create_page"}), "properties").await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        assert_fails_mentioning(json!({"action": "update_page", "properties": {}}), "page_id")
            .await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        assert_fails_mentioning(json!({"action": "update_page", "page_id": "test-id"}), "properties")
            .await;
    }
}
|
||||
@@ -0,0 +1,750 @@
|
||||
//! Project delivery intelligence tool.
|
||||
//!
|
||||
//! Provides read-only analysis and generation for project management:
|
||||
//! status reports, risk detection, client communication drafting,
|
||||
//! sprint summaries, and effort estimation.
|
||||
|
||||
use super::report_templates;
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
|
||||
/// Project intelligence tool for consulting project management.
///
/// All actions are read-only analysis/generation; nothing is modified externally.
pub struct ProjectIntelTool {
    // Report language used when a call does not pass an explicit "language" argument.
    default_language: String,
    // Parsed from config at construction; scales the heuristic risk thresholds.
    risk_sensitivity: RiskSensitivity,
}
|
||||
|
||||
/// Risk detection sensitivity level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskSensitivity {
    Low,
    Medium,
    High,
}

impl RiskSensitivity {
    /// Parse a case-insensitive sensitivity name; anything other than
    /// "low" or "high" (including unknown values) falls back to `Medium`.
    fn from_str(s: &str) -> Self {
        let normalized = s.to_lowercase();
        if normalized == "low" {
            Self::Low
        } else if normalized == "high" {
            Self::High
        } else {
            Self::Medium
        }
    }

    /// Threshold multiplier: higher sensitivity means lower thresholds.
    fn threshold_factor(self) -> f64 {
        if self == Self::Low {
            1.5
        } else if self == Self::High {
            0.5
        } else {
            1.0
        }
    }
}
|
||||
|
||||
impl ProjectIntelTool {
|
||||
    /// Build the tool from raw config strings.
    ///
    /// `risk_sensitivity` accepts "low"/"medium"/"high" (case-insensitive);
    /// unrecognized values fall back to medium.
    pub fn new(default_language: String, risk_sensitivity: String) -> Self {
        Self {
            default_language,
            risk_sensitivity: RiskSensitivity::from_str(&risk_sensitivity),
        }
    }
|
||||
|
||||
fn execute_status_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let project_name = args
|
||||
.get("project_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for status_report"))?;
|
||||
let period = args
|
||||
.get("period")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required 'period' for status_report"))?;
|
||||
let lang = args
|
||||
.get("language")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(&self.default_language);
|
||||
let git_log = args
|
||||
.get("git_log")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("No git data provided");
|
||||
let jira_summary = args
|
||||
.get("jira_summary")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("No Jira data provided");
|
||||
let notes = args.get("notes").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
let tpl = report_templates::weekly_status_template(lang);
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("project_name".into(), project_name.to_string());
|
||||
vars.insert("period".into(), period.to_string());
|
||||
vars.insert("completed".into(), git_log.to_string());
|
||||
vars.insert("in_progress".into(), jira_summary.to_string());
|
||||
vars.insert("blocked".into(), notes.to_string());
|
||||
vars.insert("next_steps".into(), "To be determined".into());
|
||||
|
||||
let rendered = tpl.render(&vars);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: rendered,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn execute_risk_scan(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let deadlines = args
|
||||
.get("deadlines")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
let velocity = args
|
||||
.get("velocity")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
let blockers = args
|
||||
.get("blockers")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
let lang = args
|
||||
.get("language")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(&self.default_language);
|
||||
|
||||
let mut risks = Vec::new();
|
||||
|
||||
// Heuristic risk detection based on signals
|
||||
let factor = self.risk_sensitivity.threshold_factor();
|
||||
|
||||
if !blockers.is_empty() {
|
||||
let blocker_count = blockers.lines().filter(|l| !l.trim().is_empty()).count();
|
||||
let severity = if (blocker_count as f64) > 3.0 * factor {
|
||||
"critical"
|
||||
} else if (blocker_count as f64) > 1.0 * factor {
|
||||
"high"
|
||||
} else {
|
||||
"medium"
|
||||
};
|
||||
risks.push(RiskItem {
|
||||
title: "Active blockers detected".into(),
|
||||
severity: severity.into(),
|
||||
detail: format!("{blocker_count} blocker(s) identified"),
|
||||
mitigation: "Escalate blockers, assign owners, set resolution deadlines".into(),
|
||||
});
|
||||
}
|
||||
|
||||
if deadlines.to_lowercase().contains("overdue")
|
||||
|| deadlines.to_lowercase().contains("missed")
|
||||
{
|
||||
risks.push(RiskItem {
|
||||
title: "Deadline risk".into(),
|
||||
severity: "high".into(),
|
||||
detail: "Overdue or missed deadlines detected in project context".into(),
|
||||
mitigation: "Re-prioritize scope, negotiate timeline, add resources".into(),
|
||||
});
|
||||
}
|
||||
|
||||
if velocity.to_lowercase().contains("declining") || velocity.to_lowercase().contains("slow")
|
||||
{
|
||||
risks.push(RiskItem {
|
||||
title: "Velocity degradation".into(),
|
||||
severity: "medium".into(),
|
||||
detail: "Team velocity is declining or below expectations".into(),
|
||||
mitigation: "Identify bottlenecks, reduce WIP, address technical debt".into(),
|
||||
});
|
||||
}
|
||||
|
||||
if risks.is_empty() {
|
||||
risks.push(RiskItem {
|
||||
title: "No significant risks detected".into(),
|
||||
severity: "low".into(),
|
||||
detail: "Current project signals within normal parameters".into(),
|
||||
mitigation: "Continue monitoring".into(),
|
||||
});
|
||||
}
|
||||
|
||||
let tpl = report_templates::risk_register_template(lang);
|
||||
let risks_text = risks
|
||||
.iter()
|
||||
.map(|r| {
|
||||
format!(
|
||||
"- [{}] {}: {}",
|
||||
r.severity.to_uppercase(),
|
||||
r.title,
|
||||
r.detail
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let mitigations_text = risks
|
||||
.iter()
|
||||
.map(|r| format!("- {}: {}", r.title, r.mitigation))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert(
|
||||
"project_name".into(),
|
||||
args.get("project_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("Unknown")
|
||||
.to_string(),
|
||||
);
|
||||
vars.insert("risks".into(), risks_text);
|
||||
vars.insert("mitigations".into(), mitigations_text);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: tpl.render(&vars),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn execute_draft_update(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let project_name = args
|
||||
.get("project_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required 'project_name' for draft_update"))?;
|
||||
let audience = args
|
||||
.get("audience")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("client");
|
||||
let tone = args
|
||||
.get("tone")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("formal");
|
||||
let highlights = args
|
||||
.get("highlights")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required 'highlights' for draft_update"))?;
|
||||
let concerns = args.get("concerns").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
let greeting = match (audience, tone) {
|
||||
("client", "casual") => "Hi there,".to_string(),
|
||||
("client", _) => "Dear valued partner,".to_string(),
|
||||
("internal", "casual") => "Hey team,".to_string(),
|
||||
("internal", _) => "Dear team,".to_string(),
|
||||
(_, "casual") => "Hi,".to_string(),
|
||||
_ => "Dear reader,".to_string(),
|
||||
};
|
||||
|
||||
let closing = match tone {
|
||||
"casual" => "Cheers",
|
||||
_ => "Best regards",
|
||||
};
|
||||
|
||||
let mut body = format!(
|
||||
"{greeting}\n\nHere is an update on {project_name}.\n\n**Highlights:**\n{highlights}"
|
||||
);
|
||||
if !concerns.is_empty() {
|
||||
let _ = write!(body, "\n\n**Items requiring attention:**\n{concerns}");
|
||||
}
|
||||
let _ = write!(
|
||||
body,
|
||||
"\n\nPlease do not hesitate to reach out with any questions.\n\n{closing}"
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: body,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn execute_sprint_summary(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let sprint_dates = args
|
||||
.get("sprint_dates")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("current sprint");
|
||||
let completed = args
|
||||
.get("completed")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("None specified");
|
||||
let in_progress = args
|
||||
.get("in_progress")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("None specified");
|
||||
let blocked = args
|
||||
.get("blocked")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("None");
|
||||
let velocity = args
|
||||
.get("velocity")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("Not calculated");
|
||||
let lang = args
|
||||
.get("language")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(&self.default_language);
|
||||
|
||||
let tpl = report_templates::sprint_review_template(lang);
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("sprint_dates".into(), sprint_dates.to_string());
|
||||
vars.insert("completed".into(), completed.to_string());
|
||||
vars.insert("in_progress".into(), in_progress.to_string());
|
||||
vars.insert("blocked".into(), blocked.to_string());
|
||||
vars.insert("velocity".into(), velocity.to_string());
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: tpl.render(&vars),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn execute_effort_estimate(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let tasks = args.get("tasks").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if tasks.trim().is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("No task descriptions provided".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut estimates = Vec::new();
|
||||
for line in tasks.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (size, rationale) = estimate_task_effort(line);
|
||||
estimates.push(format!("- **{size}** | {line}\n Rationale: {rationale}"));
|
||||
}
|
||||
|
||||
let output = format!(
|
||||
"## Effort Estimates\n\n{}\n\n_Sizes: XS (<2h), S (2-4h), M (4-8h), L (1-3d), XL (3-5d), XXL (>5d)_",
|
||||
estimates.join("\n")
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A single detected risk together with its suggested mitigation.
struct RiskItem {
    /// Short human-readable risk name.
    title: String,
    /// Severity label: "low", "medium", "high", or "critical".
    severity: String,
    /// What was observed that triggered this risk.
    detail: String,
    /// Recommended countermeasure.
    mitigation: String,
}
|
||||
|
||||
/// Heuristic effort estimation from task description text.
///
/// Returns a T-shirt size plus a one-line rationale. Keyword signals pick the
/// base bucket (structural > feature > small fix); the description's word
/// count nudges within the bucket. With no keyword match, length alone is
/// used as a complexity proxy.
fn estimate_task_effort(description: &str) -> (&'static str, &'static str) {
    const COMPLEX_SIGNALS: [&str; 6] = [
        "refactor",
        "rewrite",
        "migrate",
        "redesign",
        "architecture",
        "infrastructure",
    ];
    const MEDIUM_SIGNALS: [&str; 6] = [
        "implement",
        "create",
        "build",
        "integrate",
        "add feature",
        "new module",
    ];
    const SMALL_SIGNALS: [&str; 8] = [
        "fix", "update", "tweak", "adjust", "rename", "typo", "bump", "config",
    ];

    let lower = description.to_lowercase();
    let words = description.split_whitespace().count();
    let has_any = |signals: &[&str]| signals.iter().any(|s| lower.contains(s));

    if has_any(&COMPLEX_SIGNALS) {
        return if words > 15 {
            ("XXL", "Large-scope structural change with extensive description")
        } else {
            ("XL", "Structural change requiring significant effort")
        };
    }

    if has_any(&MEDIUM_SIGNALS) {
        return if words > 12 {
            ("L", "Feature implementation with detailed requirements")
        } else {
            ("M", "Standard feature implementation")
        };
    }

    if has_any(&SMALL_SIGNALS) {
        return if words > 10 {
            ("S", "Small change with additional context")
        } else {
            ("XS", "Minor targeted change")
        };
    }

    // No keyword signal: fall back to description length as a proxy.
    match words {
        w if w > 20 => ("L", "Complex task inferred from detailed description"),
        w if w > 10 => ("M", "Moderate task inferred from description length"),
        _ => ("S", "Simple task inferred from brief description"),
    }
}
|
||||
|
||||
#[async_trait]
impl Tool for ProjectIntelTool {
    /// Stable tool identifier used in tool-call routing.
    fn name(&self) -> &str {
        "project_intel"
    }

    /// One-line capability summary shown to the model/user.
    fn description(&self) -> &str {
        "Project delivery intelligence: generate status reports, detect risks, draft client updates, summarize sprints, and estimate effort. Read-only analysis tool."
    }

    /// JSON Schema for the tool arguments. Only `action` is required; every
    /// other property is consumed by a subset of the actions (noted in each
    /// description).
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["status_report", "risk_scan", "draft_update", "sprint_summary", "effort_estimate"],
                    "description": "The analysis action to perform"
                },
                "project_name": {
                    "type": "string",
                    "description": "Project name (for status_report, risk_scan, draft_update)"
                },
                "period": {
                    "type": "string",
                    "description": "Reporting period: week, sprint, or month (for status_report)"
                },
                "language": {
                    "type": "string",
                    "description": "Report language: en, de, fr, it (default from config)"
                },
                "git_log": {
                    "type": "string",
                    "description": "Git log summary text (for status_report)"
                },
                "jira_summary": {
                    "type": "string",
                    "description": "Jira/issue tracker summary (for status_report)"
                },
                "notes": {
                    "type": "string",
                    "description": "Additional notes or context"
                },
                "deadlines": {
                    "type": "string",
                    "description": "Deadline information (for risk_scan)"
                },
                "velocity": {
                    "type": "string",
                    "description": "Team velocity data (for risk_scan, sprint_summary)"
                },
                "blockers": {
                    "type": "string",
                    "description": "Current blockers (for risk_scan)"
                },
                "audience": {
                    "type": "string",
                    "enum": ["client", "internal"],
                    "description": "Target audience (for draft_update)"
                },
                "tone": {
                    "type": "string",
                    "enum": ["formal", "casual"],
                    "description": "Communication tone (for draft_update)"
                },
                "highlights": {
                    "type": "string",
                    "description": "Key highlights for the update (for draft_update)"
                },
                "concerns": {
                    "type": "string",
                    "description": "Items requiring attention (for draft_update)"
                },
                "sprint_dates": {
                    "type": "string",
                    "description": "Sprint date range (for sprint_summary)"
                },
                "completed": {
                    "type": "string",
                    "description": "Completed items (for sprint_summary)"
                },
                "in_progress": {
                    "type": "string",
                    "description": "In-progress items (for sprint_summary)"
                },
                "blocked": {
                    "type": "string",
                    "description": "Blocked items (for sprint_summary)"
                },
                "tasks": {
                    "type": "string",
                    "description": "Task descriptions, one per line (for effort_estimate)"
                }
            },
            "required": ["action"]
        })
    }

    /// Dispatch on `action`.
    ///
    /// A missing `action` is a hard error (`Err`); an unrecognized action
    /// returns a structured failure `ToolResult` listing the valid actions.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;

        match action {
            "status_report" => self.execute_status_report(&args),
            "risk_scan" => self.execute_risk_scan(&args),
            "draft_update" => self.execute_draft_update(&args),
            "sprint_summary" => self.execute_sprint_summary(&args),
            "effort_estimate" => self.execute_effort_estimate(&args),
            other => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action '{other}'. Valid actions: status_report, risk_scan, draft_update, sprint_summary, effort_estimate"
                )),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default fixture: English output, medium sensitivity.
    fn tool() -> ProjectIntelTool {
        ProjectIntelTool::new("en".into(), "medium".into())
    }

    #[test]
    fn tool_name_and_description() {
        let t = tool();
        assert_eq!(t.name(), "project_intel");
        assert!(!t.description().is_empty());
    }

    // The schema must declare `action` and mark it required.
    #[test]
    fn parameters_schema_has_action() {
        let t = tool();
        let schema = t.parameters_schema();
        assert!(schema["properties"]["action"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::Value::String("action".into())));
    }

    // Happy path: required fields present, git log flows into the report.
    #[tokio::test]
    async fn status_report_renders() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "status_report",
                "project_name": "TestProject",
                "period": "week",
                "git_log": "- feat: added login"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("TestProject"));
        assert!(result.output.contains("added login"));
    }

    // Three blocker lines should surface a blocker risk entry.
    #[tokio::test]
    async fn risk_scan_detects_blockers() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "blockers": "DB migration stuck\nCI pipeline broken\nAPI key expired"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("blocker"));
    }

    // The word "overdue" in deadline context triggers the deadline risk.
    #[tokio::test]
    async fn risk_scan_detects_deadline_risk() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "deadlines": "Sprint deadline overdue by 3 days"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Deadline risk"));
    }

    // With no signals at all, the scan emits the "all clear" entry.
    #[tokio::test]
    async fn risk_scan_no_signals_returns_low_risk() {
        let t = tool();
        let result = t.execute(json!({ "action": "risk_scan" })).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No significant risks"));
    }

    // Formal client tone picks the "Dear valued partner" greeting.
    #[tokio::test]
    async fn draft_update_formal_client() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "draft_update",
                "project_name": "Portal",
                "audience": "client",
                "tone": "formal",
                "highlights": "Phase 1 delivered"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Dear valued partner"));
        assert!(result.output.contains("Portal"));
        assert!(result.output.contains("Phase 1 delivered"));
    }

    // Casual internal tone picks "Hey team" / "Cheers".
    #[tokio::test]
    async fn draft_update_casual_internal() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "draft_update",
                "project_name": "ZeroClaw",
                "audience": "internal",
                "tone": "casual",
                "highlights": "Core loop stabilized"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Hey team"));
        assert!(result.output.contains("Cheers"));
    }

    #[tokio::test]
    async fn sprint_summary_renders() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "sprint_summary",
                "sprint_dates": "2026-03-01 to 2026-03-14",
                "completed": "- Login page\n- API endpoints",
                "in_progress": "- Dashboard",
                "blocked": "- Payment integration"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Login page"));
        assert!(result.output.contains("Dashboard"));
    }

    #[tokio::test]
    async fn effort_estimate_basic() {
        let t = tool();
        let result = t
            .execute(json!({
                "action": "effort_estimate",
                "tasks": "Fix typo in README\nImplement user authentication\nRefactor database layer"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("XS"));
        assert!(result.output.contains("Refactor database layer"));
    }

    // Empty task list is a structured failure, not an Err.
    #[tokio::test]
    async fn effort_estimate_empty_tasks_fails() {
        let t = tool();
        let result = t
            .execute(json!({ "action": "effort_estimate", "tasks": "" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("No task descriptions"));
    }

    // Unrecognized action names come back as a structured error.
    #[tokio::test]
    async fn unknown_action_returns_error() {
        let t = tool();
        let result = t
            .execute(json!({ "action": "invalid_thing" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }

    // A missing action, by contrast, is a hard Err.
    #[tokio::test]
    async fn missing_action_returns_error() {
        let t = tool();
        let result = t.execute(json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn effort_estimate_heuristics_coverage() {
        assert_eq!(estimate_task_effort("Fix typo").0, "XS");
        assert_eq!(estimate_task_effort("Update config values").0, "XS");
        assert_eq!(
            estimate_task_effort("Implement new notification system").0,
            "M"
        );
        assert_eq!(
            estimate_task_effort("Refactor the entire authentication module").0,
            "XL"
        );
        assert_eq!(
            estimate_task_effort("Migrate the database schema to support multi-tenancy with data isolation and proper indexing across all services").0,
            "XXL"
        );
    }

    // Higher sensitivity must strictly lower the threshold factor.
    #[test]
    fn risk_sensitivity_threshold_ordering() {
        assert!(
            RiskSensitivity::High.threshold_factor() < RiskSensitivity::Medium.threshold_factor()
        );
        assert!(
            RiskSensitivity::Medium.threshold_factor() < RiskSensitivity::Low.threshold_factor()
        );
    }

    #[test]
    fn risk_sensitivity_from_str_variants() {
        assert_eq!(RiskSensitivity::from_str("low"), RiskSensitivity::Low);
        assert_eq!(RiskSensitivity::from_str("high"), RiskSensitivity::High);
        assert_eq!(RiskSensitivity::from_str("medium"), RiskSensitivity::Medium);
        assert_eq!(
            RiskSensitivity::from_str("unknown"),
            RiskSensitivity::Medium
        );
    }

    // At high sensitivity (factor 0.5) even one blocker crosses the
    // "high" threshold.
    #[tokio::test]
    async fn high_sensitivity_detects_single_blocker_as_high() {
        let t = ProjectIntelTool::new("en".into(), "high".into());
        let result = t
            .execute(json!({
                "action": "risk_scan",
                "blockers": "Single blocker"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("[HIGH]") || result.output.contains("[CRITICAL]"));
    }
}
|
||||
@@ -0,0 +1,582 @@
|
||||
//! Report template engine for project delivery intelligence.
|
||||
//!
|
||||
//! Provides built-in templates for weekly status, sprint review, risk register,
|
||||
//! and milestone reports with multi-language support (EN, DE, FR, IT).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
|
||||
/// Supported report output formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReportFormat {
    /// Markdown output: each section rendered as a `## heading` block.
    Markdown,
    /// HTML output: sections rendered as `<h2>`/`<p>`, content entity-escaped.
    Html,
}
|
||||
|
||||
/// A named section within a report template.
#[derive(Debug, Clone)]
pub struct TemplateSection {
    /// Section title; may contain `{{key}}` placeholders.
    pub heading: String,
    /// Section content; may contain `{{key}}` placeholders.
    pub body: String,
}
|
||||
|
||||
/// A report template with named sections and variable placeholders.
#[derive(Debug, Clone)]
pub struct ReportTemplate {
    /// Human-readable template name (localized).
    pub name: String,
    /// Sections rendered in order.
    pub sections: Vec<TemplateSection>,
    /// Output format used by `render`.
    pub format: ReportFormat,
}
|
||||
|
||||
/// Escape a string for safe inclusion in HTML output.
///
/// Replaces the five HTML-significant characters with their entity forms.
/// `&` must be replaced first so that entities produced by the later
/// replacements are not themselves re-escaped.
fn escape_html(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#39;")
}
|
||||
|
||||
impl ReportTemplate {
|
||||
/// Render the template by substituting `{{key}}` placeholders with values.
|
||||
pub fn render(&self, vars: &HashMap<String, String>) -> String {
|
||||
let mut out = String::new();
|
||||
for section in &self.sections {
|
||||
let heading = substitute(§ion.heading, vars);
|
||||
let body = substitute(§ion.body, vars);
|
||||
match self.format {
|
||||
ReportFormat::Markdown => {
|
||||
let _ = write!(out, "## {heading}\n\n{body}\n\n");
|
||||
}
|
||||
ReportFormat::Html => {
|
||||
let heading = escape_html(&heading);
|
||||
let body = escape_html(&body);
|
||||
let _ = write!(out, "<h2>{heading}</h2>\n<p>{body}</p>\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
out.trim_end().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Single-pass placeholder substitution.
///
/// Scans `template` left-to-right for `{{key}}` tokens and replaces them with
/// the corresponding value from `vars`. Because the scan is single-pass,
/// values that themselves contain `{{...}}` sequences are emitted literally
/// and never re-expanded, preventing injection of new placeholders. Unknown
/// placeholders and an unterminated `{{` are emitted as-is.
///
/// Operates on `&str` slices throughout (never re-assembling individual
/// bytes), so multi-byte UTF-8 content passes through intact.
fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(start) = rest.find("{{") {
        // Look for the closing `}}` after the opener.
        let Some(close) = rest[start + 2..].find("}}") else {
            // Unterminated opener: everything remaining is literal text.
            break;
        };
        let token_end = start + 2 + close + 2;
        let key = &rest[start + 2..start + 2 + close];

        // Emit the literal text before the token.
        result.push_str(&rest[..start]);
        match vars.get(key) {
            Some(value) => result.push_str(value),
            // Unknown placeholder: emit the token as-is.
            None => result.push_str(&rest[start..token_end]),
        }
        rest = &rest[token_end..];
    }

    result.push_str(rest);
    result
}
|
||||
|
||||
// ── Built-in templates ────────────────────────────────────────────
|
||||
|
||||
/// Return the built-in weekly status template for the given language.
|
||||
pub fn weekly_status_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Wochenstatus",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Zusammenfassung".into(),
|
||||
body: "Projekt: {{project_name}} | Zeitraum: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Erledigt".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Bearbeitung".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blockiert".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Naechste Schritte".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Statut hebdomadaire",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Resume".into(),
|
||||
body: "Projet: {{project_name}} | Periode: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Termine".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "En cours".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloque".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Prochaines etapes".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Stato settimanale",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Riepilogo".into(),
|
||||
body: "Progetto: {{project_name}} | Periodo: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completato".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In corso".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloccato".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Prossimi passi".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Weekly Status",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Summary".into(),
|
||||
body: "Project: {{project_name}} | Period: {{period}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completed".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Progress".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blocked".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Next Steps".into(),
|
||||
body: "{{next_steps}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in sprint review template for the given language.
|
||||
pub fn sprint_review_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Sprint-Uebersicht",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Erledigt".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Bearbeitung".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blockiert".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocity".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Revue de sprint",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Termine".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "En cours".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloque".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocite".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Revisione sprint",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completato".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In corso".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Bloccato".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocita".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Sprint Review",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Sprint".into(),
|
||||
body: "{{sprint_dates}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Completed".into(),
|
||||
body: "{{completed}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "In Progress".into(),
|
||||
body: "{{in_progress}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Blocked".into(),
|
||||
body: "{{blocked}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Velocity".into(),
|
||||
body: "{{velocity}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in risk register template for the given language.
|
||||
pub fn risk_register_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Risikoregister",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projekt".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risiken".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Massnahmen".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Registre des risques",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projet".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risques".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mesures".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Registro dei rischi",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Progetto".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Rischi".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mitigazioni".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Risk Register",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Project".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Risks".into(),
|
||||
body: "{{risks}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Mitigations".into(),
|
||||
body: "{{mitigations}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the built-in milestone report template for the given language.
|
||||
pub fn milestone_report_template(lang: &str) -> ReportTemplate {
|
||||
let (name, sections) = match lang {
|
||||
"de" => (
|
||||
"Meilensteinbericht",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projekt".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Meilensteine".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Status".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"fr" => (
|
||||
"Rapport de jalons",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Projet".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Jalons".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Statut".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
"it" => (
|
||||
"Report milestone",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Progetto".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Milestone".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Stato".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
_ => (
|
||||
"Milestone Report",
|
||||
vec![
|
||||
TemplateSection {
|
||||
heading: "Project".into(),
|
||||
body: "{{project_name}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Milestones".into(),
|
||||
body: "{{milestones}}".into(),
|
||||
},
|
||||
TemplateSection {
|
||||
heading: "Status".into(),
|
||||
body: "{{status}}".into(),
|
||||
},
|
||||
],
|
||||
),
|
||||
};
|
||||
ReportTemplate {
|
||||
name: name.into(),
|
||||
sections,
|
||||
format: ReportFormat::Markdown,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Rendering with a full variable map substitutes every placeholder and
    // keeps the Markdown headings.
    #[test]
    fn weekly_status_renders_with_variables() {
        let tpl = weekly_status_template("en");
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "ZeroClaw".into());
        vars.insert("period".into(), "2026-W10".into());
        vars.insert("completed".into(), "- Task A\n- Task B".into());
        vars.insert("in_progress".into(), "- Task C".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "- Task D".into());

        let rendered = tpl.render(&vars);
        assert!(rendered.contains("Project: ZeroClaw"));
        assert!(rendered.contains("Period: 2026-W10"));
        assert!(rendered.contains("- Task A"));
        assert!(rendered.contains("## Completed"));
    }

    // Localized headings appear even with an empty variable map.
    #[test]
    fn weekly_status_de_renders_german_headings() {
        let tpl = weekly_status_template("de");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Zusammenfassung"));
        assert!(rendered.contains("## Erledigt"));
    }

    // NOTE: headings are intentionally accent-free ("Resume", "Termine").
    #[test]
    fn weekly_status_fr_renders_french_headings() {
        let tpl = weekly_status_template("fr");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Resume"));
        assert!(rendered.contains("## Termine"));
    }

    #[test]
    fn weekly_status_it_renders_italian_headings() {
        let tpl = weekly_status_template("it");
        let vars = HashMap::new();
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("## Riepilogo"));
        assert!(rendered.contains("## Completato"));
    }

    // Switching the format field to Html produces <h2>/<p> markup instead of
    // Markdown headings.
    #[test]
    fn html_format_renders_tags() {
        let mut tpl = weekly_status_template("en");
        tpl.format = ReportFormat::Html;
        let mut vars = HashMap::new();
        vars.insert("project_name".into(), "Test".into());
        vars.insert("period".into(), "W1".into());
        vars.insert("completed".into(), "Done".into());
        vars.insert("in_progress".into(), "WIP".into());
        vars.insert("blocked".into(), "None".into());
        vars.insert("next_steps".into(), "Next".into());

        let rendered = tpl.render(&vars);
        assert!(rendered.contains("<h2>Summary</h2>"));
        assert!(rendered.contains("<p>Project: Test | Period: W1</p>"));
    }

    #[test]
    fn sprint_review_template_has_velocity_section() {
        let tpl = sprint_review_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Velocity"));
    }

    #[test]
    fn risk_register_template_has_risk_sections() {
        let tpl = risk_register_template("en");
        let section_headings: Vec<&str> = tpl.sections.iter().map(|s| s.heading.as_str()).collect();
        assert!(section_headings.contains(&"Risks"));
        assert!(section_headings.contains(&"Mitigations"));
    }

    // Every supported language (plus the English fallback) yields a named
    // template with exactly three sections.
    #[test]
    fn milestone_template_all_languages() {
        for lang in &["en", "de", "fr", "it"] {
            let tpl = milestone_report_template(lang);
            assert!(!tpl.name.is_empty());
            assert_eq!(tpl.sections.len(), 3);
        }
    }

    // Placeholders without a matching variable are left verbatim, not erased.
    #[test]
    fn substitute_leaves_unknown_placeholders() {
        let vars = HashMap::new();
        let result = substitute("Hello {{name}}", &vars);
        assert_eq!(result, "Hello {{name}}");
    }

    #[test]
    fn substitute_replaces_all_occurrences() {
        let mut vars = HashMap::new();
        vars.insert("x".into(), "1".into());
        let result = substitute("{{x}} and {{x}}", &vars);
        assert_eq!(result, "1 and 1");
    }
}
|
||||
@@ -0,0 +1,659 @@
|
||||
//! Security operations tool for managed cybersecurity service (MCSS) workflows.
|
||||
//!
|
||||
//! Provides alert triage, incident response playbook execution, vulnerability
|
||||
//! scan parsing, and security report generation. All actions that modify state
|
||||
//! enforce human approval gates unless explicitly configured otherwise.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::SecurityOpsConfig;
|
||||
use crate::security::playbook::{
|
||||
evaluate_step, load_playbooks, severity_level, Playbook, StepStatus,
|
||||
};
|
||||
use crate::security::vulnerability::{generate_summary, parse_vulnerability_json};
|
||||
|
||||
/// Security operations tool — triage alerts, run playbooks, parse vulns, generate reports.
pub struct SecurityOpsTool {
    /// Tool configuration: approval gating, auto-triage flag, severity caps,
    /// and the playbooks directory.
    config: SecurityOpsConfig,
    /// Playbooks loaded once at construction from `config.playbooks_dir`
    /// (with `~` expanded).
    playbooks: Vec<Playbook>,
}
|
||||
|
||||
impl SecurityOpsTool {
    /// Build the tool from config, eagerly loading playbooks from the
    /// configured directory (with a leading `~/` expanded to the home dir).
    pub fn new(config: SecurityOpsConfig) -> Self {
        let playbooks_dir = expand_tilde(&config.playbooks_dir);
        let playbooks = load_playbooks(&playbooks_dir);
        Self { config, playbooks }
    }

    /// Triage an alert: classify severity and recommend response.
    ///
    /// Requires an `alert` object in `args`; its `type`, `source`,
    /// `severity`, and `description` fields are read with defaults
    /// substituted when absent ("unknown"/"medium"/empty).
    fn triage_alert(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let alert = args
            .get("alert")
            .ok_or_else(|| anyhow::anyhow!("Missing required 'alert' parameter"))?;

        // Extract key fields for classification
        let alert_type = alert
            .get("type")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");
        let source = alert
            .get("source")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");
        let severity = alert
            .get("severity")
            .and_then(|v| v.as_str())
            .unwrap_or("medium");
        let description = alert
            .get("description")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        // Classify and find matching playbooks.
        // A playbook matches when (a) the alert severity is at least the
        // playbook's severity filter (numeric comparison via severity_level),
        // and (b) the names overlap: playbook name contains the alert type,
        // alert type contains the playbook name, or the playbook name (with
        // underscores turned into spaces) appears in the lowercased
        // description.
        let matching_playbooks: Vec<&Playbook> = self
            .playbooks
            .iter()
            .filter(|pb| {
                severity_level(severity) >= severity_level(&pb.severity_filter)
                    && (pb.name.contains(alert_type)
                        || alert_type.contains(&pb.name)
                        || description
                            .to_lowercase()
                            .contains(&pb.name.replace('_', " ")))
            })
            .collect();

        let playbook_names: Vec<&str> =
            matching_playbooks.iter().map(|p| p.name.as_str()).collect();

        let output = json!({
            "classification": {
                "alert_type": alert_type,
                "source": source,
                "severity": severity,
                "severity_level": severity_level(severity),
                // Level >= 3 is escalated to "immediate" — presumably
                // high/critical; confirm the mapping in security::playbook.
                "priority": if severity_level(severity) >= 3 { "immediate" } else { "standard" },
            },
            "recommended_playbooks": playbook_names,
            "recommended_action": if matching_playbooks.is_empty() {
                "Manual investigation required — no matching playbook found"
            } else {
                "Execute recommended playbook(s)"
            },
            "auto_triage": self.config.auto_triage,
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }

    /// Execute a playbook step with approval gating.
    ///
    /// Requires `playbook` (name) and `step` (0-based index) in `args`;
    /// `alert_severity` defaults to "medium". The approval decision itself is
    /// delegated to `evaluate_step`.
    fn run_playbook(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let playbook_name = args
            .get("playbook")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'playbook' parameter"))?;

        // JSON integers arrive as u64; convert to usize for indexing and
        // reject values that overflow on 32-bit targets.
        let step_index =
            usize::try_from(args.get("step").and_then(|v| v.as_u64()).ok_or_else(|| {
                anyhow::anyhow!("Missing required 'step' parameter (0-based index)")
            })?)
            .map_err(|_| anyhow::anyhow!("'step' parameter value too large for this platform"))?;

        let alert_severity = args
            .get("alert_severity")
            .and_then(|v| v.as_str())
            .unwrap_or("medium");

        let playbook = self
            .playbooks
            .iter()
            .find(|p| p.name == playbook_name)
            .ok_or_else(|| anyhow::anyhow!("Playbook '{}' not found", playbook_name))?;

        let result = evaluate_step(
            playbook,
            step_index,
            alert_severity,
            &self.config.max_auto_severity,
            self.config.require_approval_for_actions,
        );

        let output = json!({
            "playbook": playbook_name,
            "step_index": result.step_index,
            "action": result.action,
            "status": result.status.to_string(),
            "message": result.message,
            "requires_manual_approval": result.status == StepStatus::PendingApproval,
        });

        // A step left PendingApproval is still a successful tool call; only
        // StepStatus::Failed surfaces as an error.
        Ok(ToolResult {
            success: result.status != StepStatus::Failed,
            output: serde_json::to_string_pretty(&output)?,
            error: if result.status == StepStatus::Failed {
                Some(result.message)
            } else {
                None
            },
        })
    }

    /// Parse vulnerability scan results.
    ///
    /// Accepts `scan_data` as either a JSON string or an inline JSON object;
    /// either way it is normalized to a string before parsing.
    fn parse_vulnerability(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let scan_data = args
            .get("scan_data")
            .ok_or_else(|| anyhow::anyhow!("Missing required 'scan_data' parameter"))?;

        let json_str = if scan_data.is_string() {
            // unwrap is safe: guarded by is_string() above.
            scan_data.as_str().unwrap().to_string()
        } else {
            serde_json::to_string(scan_data)?
        };

        let report = parse_vulnerability_json(&json_str)?;
        let summary = generate_summary(&report);

        let output = json!({
            "scanner": report.scanner,
            "scan_date": report.scan_date.to_rfc3339(),
            "total_findings": report.findings.len(),
            // Findings with a severity outside these four strings are counted
            // in total_findings but not in any by_severity bucket.
            "by_severity": {
                "critical": report.findings.iter().filter(|f| f.severity == "critical").count(),
                "high": report.findings.iter().filter(|f| f.severity == "high").count(),
                "medium": report.findings.iter().filter(|f| f.severity == "medium").count(),
                "low": report.findings.iter().filter(|f| f.severity == "low").count(),
            },
            "summary": summary,
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }

    /// Generate a client-facing security posture report.
    ///
    /// All inputs are optional: `client_name` defaults to "Client", `period`
    /// to "current", and missing alert/vulnerability data is replaced with
    /// placeholder text. Output is Markdown.
    fn generate_report(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let client_name = args
            .get("client_name")
            .and_then(|v| v.as_str())
            .unwrap_or("Client");
        let period = args
            .get("period")
            .and_then(|v| v.as_str())
            .unwrap_or("current");
        let alert_stats = args.get("alert_stats");
        let vuln_summary = args
            .get("vuln_summary")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        // The positional `{}` args fill, in order: the **Generated:**
        // timestamp, the Alert Summary body, and the Vulnerability
        // Assessment body.
        let report = format!(
            "# Security Posture Report — {client_name}\n\
             **Period:** {period}\n\
             **Generated:** {}\n\n\
             ## Executive Summary\n\n\
             This report provides an overview of the security posture for {client_name} \
             during the {period} period.\n\n\
             ## Alert Summary\n\n\
             {}\n\n\
             ## Vulnerability Assessment\n\n\
             {}\n\n\
             ## Recommendations\n\n\
             1. Address all critical and high-severity findings immediately\n\
             2. Review and update incident response playbooks quarterly\n\
             3. Conduct regular vulnerability scans on all internet-facing assets\n\
             4. Ensure all endpoints have current security patches\n\n\
             ---\n\
             *Report generated by ZeroClaw MCSS Agent*\n",
            chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"),
            alert_stats
                .map(|s| serde_json::to_string_pretty(s).unwrap_or_default())
                .unwrap_or_else(|| "No alert statistics provided.".into()),
            if vuln_summary.is_empty() {
                "No vulnerability data provided."
            } else {
                vuln_summary
            },
        );

        Ok(ToolResult {
            success: true,
            output: report,
            error: None,
        })
    }

    /// List available playbooks.
    ///
    /// Returns a plain message when none are loaded, otherwise a pretty JSON
    /// array of per-playbook metadata.
    fn list_playbooks(&self) -> anyhow::Result<ToolResult> {
        if self.playbooks.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: "No playbooks available.".into(),
                error: None,
            });
        }

        let playbook_list: Vec<serde_json::Value> = self
            .playbooks
            .iter()
            .map(|pb| {
                json!({
                    "name": pb.name,
                    "description": pb.description,
                    "steps": pb.steps.len(),
                    "severity_filter": pb.severity_filter,
                    "auto_approve_steps": pb.auto_approve_steps,
                })
            })
            .collect();

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&playbook_list)?,
            error: None,
        })
    }

    /// Summarize alert volume, categories, and resolution times.
    ///
    /// Requires an `alerts` array in `args`. An alert counts as "resolved"
    /// iff it carries a numeric `resolution_secs` field.
    fn alert_stats(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let alerts = args
            .get("alerts")
            .and_then(|v| v.as_array())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'alerts' array parameter"))?;

        let total = alerts.len();
        let mut by_severity = std::collections::HashMap::new();
        let mut by_category = std::collections::HashMap::new();
        let mut resolved_count = 0u64;
        let mut total_resolution_secs = 0u64;

        // Single pass: tally severity/category histograms and accumulate
        // resolution time for resolved alerts.
        for alert in alerts {
            let severity = alert
                .get("severity")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");
            *by_severity.entry(severity.to_string()).or_insert(0u64) += 1;

            let category = alert
                .get("category")
                .and_then(|v| v.as_str())
                .unwrap_or("uncategorized");
            *by_category.entry(category.to_string()).or_insert(0u64) += 1;

            if let Some(resolution_secs) = alert.get("resolution_secs").and_then(|v| v.as_u64()) {
                resolved_count += 1;
                total_resolution_secs += resolution_secs;
            }
        }

        // Mean resolution time over resolved alerts only; 0.0 when none.
        let avg_resolution = if resolved_count > 0 {
            total_resolution_secs as f64 / resolved_count as f64
        } else {
            0.0
        };

        // Truncating cast is fine for a human-readable duration string.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let avg_resolution_secs_u64 = avg_resolution.max(0.0) as u64;

        let output = json!({
            "total_alerts": total,
            "resolved": resolved_count,
            // resolved_count <= total by construction (one increment per alert).
            "unresolved": total as u64 - resolved_count,
            "by_severity": by_severity,
            "by_category": by_category,
            "avg_resolution_secs": avg_resolution,
            "avg_resolution_human": format_duration_secs(avg_resolution_secs_u64),
        });

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&output)?,
            error: None,
        })
    }
}
|
||||
|
||||
/// Render a duration in seconds as a short human-readable string:
/// `"45s"`, `"2m 5s"`, or `"1h 1m"` depending on magnitude.
fn format_duration_secs(secs: u64) -> String {
    match secs {
        // Under a minute: seconds only.
        s if s < 60 => format!("{s}s"),
        // Under an hour: minutes and leftover seconds.
        s if s < 3600 => {
            let (minutes, seconds) = (s / 60, s % 60);
            format!("{minutes}m {seconds}s")
        }
        // An hour or more: hours and leftover minutes (seconds dropped).
        s => {
            let (hours, minutes) = (s / 3600, (s % 3600) / 60);
            format!("{hours}h {minutes}m")
        }
    }
}
|
||||
|
||||
/// Expand ~ to home directory.
|
||||
fn expand_tilde(path: &str) -> PathBuf {
|
||||
if let Some(rest) = path.strip_prefix("~/") {
|
||||
if let Some(user_dirs) = directories::UserDirs::new() {
|
||||
return user_dirs.home_dir().join(rest);
|
||||
}
|
||||
}
|
||||
PathBuf::from(path)
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for SecurityOpsTool {
    /// Stable tool identifier used for registration and dispatch.
    fn name(&self) -> &str {
        "security_ops"
    }

    /// Human/LLM-facing description enumerating the six supported actions.
    fn description(&self) -> &str {
        "Security operations tool for managed cybersecurity services. Actions: \
         triage_alert (classify/prioritize alerts), run_playbook (execute incident response steps), \
         parse_vulnerability (parse scan results), generate_report (create security posture reports), \
         list_playbooks (list available playbooks), alert_stats (summarize alert metrics)."
    }

    /// JSON Schema for the tool's arguments; only `action` is required —
    /// each action validates its own additional parameters at call time.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "required": ["action"],
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["triage_alert", "run_playbook", "parse_vulnerability", "generate_report", "list_playbooks", "alert_stats"],
                    "description": "The security operation to perform"
                },
                "alert": {
                    "type": "object",
                    "description": "Alert JSON for triage_alert (requires: type, severity; optional: source, description)"
                },
                "playbook": {
                    "type": "string",
                    "description": "Playbook name for run_playbook"
                },
                "step": {
                    "type": "integer",
                    "description": "0-based step index for run_playbook"
                },
                "alert_severity": {
                    "type": "string",
                    "description": "Alert severity context for run_playbook"
                },
                "scan_data": {
                    "description": "Vulnerability scan data (JSON string or object) for parse_vulnerability"
                },
                "client_name": {
                    "type": "string",
                    "description": "Client name for generate_report"
                },
                "period": {
                    "type": "string",
                    "description": "Reporting period for generate_report"
                },
                "alert_stats": {
                    "type": "object",
                    "description": "Alert statistics to include in generate_report"
                },
                "vuln_summary": {
                    "type": "string",
                    "description": "Vulnerability summary to include in generate_report"
                },
                "alerts": {
                    "type": "array",
                    "description": "Array of alert objects for alert_stats"
                }
            }
        })
    }

    /// Dispatch on the required `action` string to the matching handler.
    ///
    /// Missing `action` is a hard error; an unrecognized action is reported
    /// as a failed `ToolResult` (so the caller gets guidance) rather than an
    /// `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing required 'action' parameter"))?;

        match action {
            "triage_alert" => self.triage_alert(&args),
            "run_playbook" => self.run_playbook(&args),
            "parse_vulnerability" => self.parse_vulnerability(&args),
            "generate_report" => self.generate_report(&args),
            "list_playbooks" => self.list_playbooks(),
            "alert_stats" => self.alert_stats(&args),
            _ => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action '{action}'. Valid: triage_alert, run_playbook, \
                     parse_vulnerability, generate_report, list_playbooks, alert_stats"
                )),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Restrictive config: approval required, auto-exec capped at "low".
    // The playbooks_dir is intentionally nonexistent — load_playbooks
    // evidently falls back to built-ins (see list_playbooks_returns_builtins).
    fn test_config() -> SecurityOpsConfig {
        SecurityOpsConfig {
            enabled: true,
            playbooks_dir: "/nonexistent".into(),
            auto_triage: false,
            require_approval_for_actions: true,
            max_auto_severity: "low".into(),
            report_output_dir: "/tmp/reports".into(),
            siem_integration: None,
        }
    }

    fn test_tool() -> SecurityOpsTool {
        SecurityOpsTool::new(test_config())
    }

    #[test]
    fn tool_name_and_schema() {
        let tool = test_tool();
        assert_eq!(tool.name(), "security_ops");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["action"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("action")));
    }

    // High severity must map to "immediate" priority and match the
    // suspicious_login built-in playbook by alert type.
    #[tokio::test]
    async fn triage_alert_classifies_severity() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "triage_alert",
                "alert": {
                    "type": "suspicious_login",
                    "source": "siem",
                    "severity": "high",
                    "description": "Multiple failed login attempts followed by successful login"
                }
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["classification"]["severity"], "high");
        assert_eq!(output["classification"]["priority"], "immediate");
        // Should match suspicious_login playbook
        let playbooks = output["recommended_playbooks"].as_array().unwrap();
        assert!(playbooks.iter().any(|p| p == "suspicious_login"));
    }

    // Omitting the 'alert' object is a hard error, not a failed ToolResult.
    #[tokio::test]
    async fn triage_alert_missing_alert_param() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "triage_alert"})).await;
        assert!(result.is_err());
    }

    // With require_approval_for_actions=true and a high-severity alert,
    // step 2 of suspicious_login must be gated behind manual approval.
    #[tokio::test]
    async fn run_playbook_requires_approval() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "suspicious_login",
                "step": 2,
                "alert_severity": "high"
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["status"], "pending_approval");
        assert_eq!(output["requires_manual_approval"], true);
    }

    // Step 0 is evidently a safe/read-only step that completes without
    // approval even under the restrictive test config.
    #[tokio::test]
    async fn run_playbook_executes_safe_step() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "suspicious_login",
                "step": 0,
                "alert_severity": "medium"
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["status"], "completed");
    }

    #[tokio::test]
    async fn run_playbook_not_found() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "run_playbook",
                "playbook": "nonexistent",
                "step": 0
            }))
            .await;

        assert!(result.is_err());
    }

    // A well-formed scan object (passed inline, not as a string) is parsed
    // and bucketed by severity.
    #[tokio::test]
    async fn parse_vulnerability_valid_report() {
        let tool = test_tool();
        let scan_data = json!({
            "scan_date": "2025-01-15T10:00:00Z",
            "scanner": "nessus",
            "findings": [
                {
                    "cve_id": "CVE-2024-0001",
                    "cvss_score": 9.8,
                    "severity": "critical",
                    "affected_asset": "web-01",
                    "description": "RCE in web framework",
                    "remediation": "Upgrade",
                    "internet_facing": true,
                    "production": true
                }
            ]
        });

        let result = tool
            .execute(json!({
                "action": "parse_vulnerability",
                "scan_data": scan_data
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["total_findings"], 1);
        assert_eq!(output["by_severity"]["critical"], 1);
    }

    #[tokio::test]
    async fn generate_report_produces_markdown() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "generate_report",
                "client_name": "ZeroClaw Corp",
                "period": "Q1 2025"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("ZeroClaw Corp"));
        assert!(result.output.contains("Q1 2025"));
        assert!(result.output.contains("Security Posture Report"));
    }

    // Pins the built-in playbook set: exactly 4, including these two names.
    #[tokio::test]
    async fn list_playbooks_returns_builtins() {
        let tool = test_tool();
        let result = tool
            .execute(json!({"action": "list_playbooks"}))
            .await
            .unwrap();

        assert!(result.success);
        let output: Vec<serde_json::Value> = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.len(), 4);
        let names: Vec<&str> = output.iter().map(|p| p["name"].as_str().unwrap()).collect();
        assert!(names.contains(&"suspicious_login"));
        assert!(names.contains(&"malware_detected"));
    }

    // One alert lacks resolution_secs, so resolved=3 / unresolved=1.
    #[tokio::test]
    async fn alert_stats_computes_summary() {
        let tool = test_tool();
        let result = tool
            .execute(json!({
                "action": "alert_stats",
                "alerts": [
                    {"severity": "critical", "category": "malware", "resolution_secs": 3600},
                    {"severity": "high", "category": "phishing", "resolution_secs": 1800},
                    {"severity": "medium", "category": "malware"},
                    {"severity": "low", "category": "policy_violation", "resolution_secs": 600}
                ]
            }))
            .await
            .unwrap();

        assert!(result.success);
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["total_alerts"], 4);
        assert_eq!(output["resolved"], 3);
        assert_eq!(output["unresolved"], 1);
        assert_eq!(output["by_severity"]["critical"], 1);
        assert_eq!(output["by_category"]["malware"], 2);
    }

    // Unknown actions come back as a failed ToolResult (not Err) so the
    // caller sees the list of valid actions.
    #[tokio::test]
    async fn unknown_action_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"action": "bad_action"})).await.unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown action"));
    }

    // Covers all three formatting branches: seconds, minutes, hours.
    #[test]
    fn format_duration_secs_readable() {
        assert_eq!(format_duration_secs(45), "45s");
        assert_eq!(format_duration_secs(125), "2m 5s");
        assert_eq!(format_duration_secs(3665), "1h 1m");
    }
}
|
||||
@@ -0,0 +1,953 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::{DelegateAgentConfig, SwarmConfig, SwarmStrategy};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default timeout (seconds) for individual agent calls within a swarm.
const SWARM_AGENT_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
/// Tool that orchestrates multiple agents as a swarm. Supports sequential
/// (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies.
pub struct SwarmTool {
    /// Named swarm definitions (strategy, member agents, timeout).
    swarms: Arc<HashMap<String, SwarmConfig>>,
    /// Agent registry the swarms draw members from, keyed by agent name.
    agents: Arc<HashMap<String, DelegateAgentConfig>>,
    /// Security policy shared with the rest of the runtime.
    security: Arc<SecurityPolicy>,
    /// Credential used when an agent config does not carry its own api_key.
    fallback_credential: Option<String>,
    /// Runtime options passed through to provider construction.
    provider_runtime_options: providers::ProviderRuntimeOptions,
}
|
||||
|
||||
impl SwarmTool {
|
||||
pub fn new(
|
||||
swarms: HashMap<String, SwarmConfig>,
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
) -> Self {
|
||||
Self {
|
||||
swarms: Arc::new(swarms),
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
provider_runtime_options,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider_for_agent(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
agent_name: &str,
|
||||
) -> Result<Box<dyn Provider>, ToolResult> {
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
)
|
||||
.map_err(|e| ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider '{}' for agent '{agent_name}': {e}",
|
||||
agent_config.provider
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn call_agent(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
prompt: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<String, String> {
|
||||
let provider = self
|
||||
.create_provider_for_agent(agent_config, agent_name)
|
||||
.map_err(|r| r.error.unwrap_or_default())?;
|
||||
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(response)) => {
|
||||
if response.trim().is_empty() {
|
||||
Ok("[Empty response]".to_string())
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("Agent '{agent_name}' failed: {e}")),
|
||||
Err(_) => Err(format!(
|
||||
"Agent '{agent_name}' timed out after {timeout_secs}s"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_sequential(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let mut current_input = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let per_agent_timeout = swarm_config.timeout_secs / swarm_config.agents.len().max(1) as u64;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (i, agent_name) in swarm_config.agents.iter().enumerate() {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let agent_prompt = if i == 0 {
|
||||
current_input.clone()
|
||||
} else {
|
||||
format!("[Previous agent output]\n{current_input}\n\n[Original task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(agent_name, agent_config, &agent_prompt, per_agent_timeout)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
results.push(format!(
|
||||
"[{agent_name} ({}/{})] {output}",
|
||||
agent_config.provider, agent_config.model
|
||||
));
|
||||
current_input = output;
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: results.join("\n\n"),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm sequential — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Fan-out/fan-in strategy: run every agent concurrently on the same prompt
/// and concatenate all answers (in completion order) into one report.
async fn execute_parallel(
    &self,
    swarm_config: &SwarmConfig,
    prompt: &str,
    context: &str,
) -> anyhow::Result<ToolResult> {
    // Every agent receives the same prompt; context (if any) is prepended once.
    let full_prompt = if context.is_empty() {
        prompt.to_string()
    } else {
        format!("[Context]\n{context}\n\n[Task]\n{prompt}")
    };

    let mut join_set = tokio::task::JoinSet::new();

    for agent_name in &swarm_config.agents {
        // Clone the config so the spawned task can own its data.
        let agent_config = match self.agents.get(agent_name) {
            Some(cfg) => cfg.clone(),
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!("Swarm references unknown agent '{agent_name}'")),
                });
            }
        };

        // Per-agent API key, falling back to the tool-wide credential.
        let credential = agent_config
            .api_key
            .clone()
            .or_else(|| self.fallback_credential.clone());

        // Provider is built inline here (rather than via
        // create_provider_for_agent) because the spawned task needs an
        // owned provider moved into the closure.
        let provider = match providers::create_provider_with_options(
            &agent_config.provider,
            credential.as_deref(),
            &self.provider_runtime_options,
        ) {
            Ok(p) => p,
            Err(e) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Failed to create provider for agent '{agent_name}': {e}"
                    )),
                });
            }
        };

        // Owned copies of everything the async task captures.
        // NOTE: each task gets the full swarm timeout, unlike the
        // sequential strategy which divides the budget.
        let name = agent_name.clone();
        let prompt_clone = full_prompt.clone();
        let timeout = swarm_config.timeout_secs;
        let model = agent_config.model.clone();
        let temperature = agent_config.temperature.unwrap_or(0.7);
        let system_prompt = agent_config.system_prompt.clone();
        let provider_name = agent_config.provider.clone();

        join_set.spawn(async move {
            let result = tokio::time::timeout(
                Duration::from_secs(timeout),
                provider.chat_with_system(
                    system_prompt.as_deref(),
                    &prompt_clone,
                    &model,
                    temperature,
                ),
            )
            .await;

            // Fold errors and timeouts into the output text so one failing
            // agent never sinks the whole parallel run.
            let output = match result {
                Ok(Ok(text)) => {
                    if text.trim().is_empty() {
                        "[Empty response]".to_string()
                    } else {
                        text
                    }
                }
                Ok(Err(e)) => format!("[Error] {e}"),
                Err(_) => format!("[Timed out after {timeout}s]"),
            };

            (name, provider_name, model, output)
        });
    }

    // Collect results as tasks finish (completion order, not spawn order).
    let mut results = Vec::new();
    while let Some(join_result) = join_set.join_next().await {
        match join_result {
            Ok((name, provider_name, model, output)) => {
                results.push(format!("[{name} ({provider_name}/{model})]\n{output}"));
            }
            Err(e) => {
                // Task panicked or was cancelled — record it rather than bail.
                results.push(format!("[join error] {e}"));
            }
        }
    }

    Ok(ToolResult {
        success: true,
        output: format!(
            "[Swarm parallel — {} agents]\n\n{}",
            swarm_config.agents.len(),
            results.join("\n\n---\n\n")
        ),
        error: None,
    })
}
|
||||
|
||||
/// Router strategy: ask an LLM to pick the single best agent for the task,
/// then delegate the full prompt to that agent.
///
/// The first agent's provider/model is used for the routing call itself.
/// If the router's answer doesn't match any configured agent name
/// (case-insensitive), the first agent is used as a fallback.
async fn execute_router(
    &self,
    swarm_config: &SwarmConfig,
    prompt: &str,
    context: &str,
) -> anyhow::Result<ToolResult> {
    if swarm_config.agents.is_empty() {
        return Ok(ToolResult {
            success: false,
            output: String::new(),
            error: Some("Router swarm has no agents to choose from".into()),
        });
    }

    // Build agent descriptions for the router prompt. Agents missing from
    // self.agents are silently skipped here (filter_map); the first agent
    // is still validated explicitly below.
    let agent_descriptions: Vec<String> = swarm_config
        .agents
        .iter()
        .filter_map(|name| {
            self.agents.get(name).map(|cfg| {
                let desc = cfg
                    .system_prompt
                    .as_deref()
                    .unwrap_or("General purpose agent");
                format!(
                    "- {name}: {desc} (provider: {}, model: {})",
                    cfg.provider, cfg.model
                )
            })
        })
        .collect();

    // Use the first agent's provider for routing
    let first_agent_name = &swarm_config.agents[0];
    let first_agent_config = match self.agents.get(first_agent_name) {
        Some(cfg) => cfg,
        None => {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Swarm references unknown agent '{first_agent_name}'"
                )),
            });
        }
    };

    let router_provider = self
        .create_provider_for_agent(first_agent_config, first_agent_name)
        .map_err(|r| anyhow::anyhow!(r.error.unwrap_or_default()))?;

    let base_router_prompt = swarm_config
        .router_prompt
        .as_deref()
        .unwrap_or("Pick the single best agent for this task.");

    let routing_prompt = format!(
        "{base_router_prompt}\n\nAvailable agents:\n{}\n\nUser task: {prompt}\n\n\
         Respond with ONLY the agent name, nothing else.",
        agent_descriptions.join("\n")
    );

    // Routing uses the module-wide SWARM_AGENT_TIMEOUT_SECS budget (not the
    // swarm's own timeout_secs, which is reserved for the delegated call)
    // and temperature 0.0 for a deterministic pick.
    let chosen = tokio::time::timeout(
        Duration::from_secs(SWARM_AGENT_TIMEOUT_SECS),
        router_provider.chat_with_system(
            Some("You are a routing assistant. Respond with only the agent name."),
            &routing_prompt,
            &first_agent_config.model,
            0.0,
        ),
    )
    .await;

    let chosen_name = match chosen {
        Ok(Ok(name)) => name.trim().to_string(),
        Ok(Err(e)) => {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Router LLM call failed: {e}")),
            });
        }
        Err(_) => {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Router LLM call timed out".into()),
            });
        }
    };

    // Case-insensitive matching with fallback to first agent
    let matched_name = swarm_config
        .agents
        .iter()
        .find(|name| name.eq_ignore_ascii_case(&chosen_name))
        .cloned()
        .unwrap_or_else(|| swarm_config.agents[0].clone());

    let agent_config = match self.agents.get(&matched_name) {
        Some(cfg) => cfg,
        None => {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Router selected unknown agent '{matched_name}'")),
            });
        }
    };

    let full_prompt = if context.is_empty() {
        prompt.to_string()
    } else {
        format!("[Context]\n{context}\n\n[Task]\n{prompt}")
    };

    // Delegate the real work to the chosen agent with the full swarm budget.
    match self
        .call_agent(
            &matched_name,
            agent_config,
            &full_prompt,
            swarm_config.timeout_secs,
        )
        .await
    {
        Ok(output) => Ok(ToolResult {
            success: true,
            output: format!(
                "[Swarm router — selected '{matched_name}' ({}/{})]\n{output}",
                agent_config.provider, agent_config.model
            ),
            error: None,
        }),
        Err(e) => Ok(ToolResult {
            success: false,
            output: String::new(),
            error: Some(e),
        }),
    }
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for SwarmTool {
    fn name(&self) -> &str {
        "swarm"
    }

    fn description(&self) -> &str {
        "Orchestrate a swarm of agents to collaboratively handle a task. Supports sequential \
         (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies."
    }

    /// JSON schema for the tool's arguments. The `swarm` property's
    /// description embeds the currently configured swarm names so the
    /// calling model can see what is available.
    fn parameters_schema(&self) -> serde_json::Value {
        let swarm_names: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
        json!({
            "type": "object",
            "additionalProperties": false,
            "properties": {
                "swarm": {
                    "type": "string",
                    "minLength": 1,
                    "description": format!(
                        "Name of the swarm to invoke. Available: {}",
                        if swarm_names.is_empty() {
                            "(none configured)".to_string()
                        } else {
                            swarm_names.join(", ")
                        }
                    )
                },
                "prompt": {
                    "type": "string",
                    "minLength": 1,
                    "description": "The task/prompt to send to the swarm"
                },
                "context": {
                    "type": "string",
                    "description": "Optional context to include (e.g. relevant code, prior findings)"
                }
            },
            "required": ["swarm", "prompt"]
        })
    }

    /// Validate arguments, resolve the named swarm, enforce the security
    /// policy, then dispatch to the strategy-specific executor.
    ///
    /// Missing required parameters are hard errors (`Err`); all other
    /// failures are reported as a non-success `ToolResult`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let swarm_name = args
            .get("swarm")
            .and_then(|v| v.as_str())
            .map(str::trim)
            .ok_or_else(|| anyhow::anyhow!("Missing 'swarm' parameter"))?;

        if swarm_name.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("'swarm' parameter must not be empty".into()),
            });
        }

        let prompt = args
            .get("prompt")
            .and_then(|v| v.as_str())
            .map(str::trim)
            .ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;

        if prompt.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("'prompt' parameter must not be empty".into()),
            });
        }

        // Context is optional; absent means empty.
        let context = args
            .get("context")
            .and_then(|v| v.as_str())
            .map(str::trim)
            .unwrap_or("");

        let swarm_config = match self.swarms.get(swarm_name) {
            Some(cfg) => cfg,
            None => {
                // Unknown swarm: list what IS configured to help the caller.
                let available: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Unknown swarm '{swarm_name}'. Available swarms: {}",
                        if available.is_empty() {
                            "(none configured)".to_string()
                        } else {
                            available.join(", ")
                        }
                    )),
                });
            }
        };

        if swarm_config.agents.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Swarm '{swarm_name}' has no agents configured")),
            });
        }

        // Security gate (autonomy level / rate limit) — checked after input
        // validation so the rate budget is only consumed by viable requests.
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "swarm")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        match swarm_config.strategy {
            SwarmStrategy::Sequential => {
                self.execute_sequential(swarm_config, prompt, context).await
            }
            SwarmStrategy::Parallel => self.execute_parallel(swarm_config, prompt, context).await,
            SwarmStrategy::Router => self.execute_router(swarm_config, prompt, context).await,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    // Default (permissive) security policy for tests.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::default())
    }

    // Two sample agents: a keyless local "researcher" (ollama) and a
    // keyed remote "writer" (openrouter).
    fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
        let mut agents = HashMap::new();
        agents.insert(
            "researcher".to_string(),
            DelegateAgentConfig {
                provider: "ollama".to_string(),
                model: "llama3".to_string(),
                system_prompt: Some("You are a research assistant.".to_string()),
                api_key: None,
                temperature: Some(0.3),
                max_depth: 3,
                agentic: false,
                allowed_tools: Vec::new(),
                max_iterations: 10,
            },
        );
        agents.insert(
            "writer".to_string(),
            DelegateAgentConfig {
                provider: "openrouter".to_string(),
                model: "anthropic/claude-sonnet-4-20250514".to_string(),
                system_prompt: Some("You are a technical writer.".to_string()),
                api_key: Some("test-key".to_string()),
                temperature: Some(0.5),
                max_depth: 3,
                agentic: false,
                allowed_tools: Vec::new(),
                max_iterations: 10,
            },
        );
        agents
    }

    // One swarm per strategy, all over the same two sample agents.
    fn sample_swarms() -> HashMap<String, SwarmConfig> {
        let mut swarms = HashMap::new();
        swarms.insert(
            "pipeline".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: Some("Research then write".to_string()),
                timeout_secs: 300,
            },
        );
        swarms.insert(
            "fanout".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 300,
            },
        );
        swarms.insert(
            "router".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Router,
                router_prompt: Some("Pick the best agent.".to_string()),
                description: None,
                timeout_secs: 300,
            },
        );
        swarms
    }

    // Schema shape: swarm/prompt/context properties, swarm+prompt required,
    // and no additional properties allowed.
    #[test]
    fn name_and_schema() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        assert_eq!(tool.name(), "swarm");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["swarm"].is_object());
        assert!(schema["properties"]["prompt"].is_object());
        assert!(schema["properties"]["context"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("swarm")));
        assert!(required.contains(&json!("prompt")));
        assert_eq!(schema["additionalProperties"], json!(false));
    }

    #[test]
    fn description_not_empty() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        assert!(!tool.description().is_empty());
    }

    // Configured swarm names surface in the `swarm` property description.
    #[test]
    fn schema_lists_swarm_names() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let schema = tool.parameters_schema();
        let desc = schema["properties"]["swarm"]["description"]
            .as_str()
            .unwrap();
        assert!(desc.contains("pipeline") || desc.contains("fanout") || desc.contains("router"));
    }

    // With no swarms configured, the description says so.
    #[test]
    fn empty_swarms_schema() {
        let tool = SwarmTool::new(
            HashMap::new(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let schema = tool.parameters_schema();
        let desc = schema["properties"]["swarm"]["description"]
            .as_str()
            .unwrap();
        assert!(desc.contains("none configured"));
    }

    #[tokio::test]
    async fn unknown_swarm_returns_error() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "nonexistent", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown swarm"));
    }

    // Missing required params are hard Err, not a soft ToolResult failure.
    #[tokio::test]
    async fn missing_swarm_param() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool.execute(json!({"prompt": "test"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn missing_prompt_param() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool.execute(json!({"swarm": "pipeline"})).await;
        assert!(result.is_err());
    }

    // Whitespace-only values are rejected after trimming.
    #[tokio::test]
    async fn blank_swarm_rejected() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "  ", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("must not be empty"));
    }

    #[tokio::test]
    async fn blank_prompt_rejected() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "  "}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("must not be empty"));
    }

    // A swarm naming an agent that isn't configured fails cleanly.
    #[tokio::test]
    async fn swarm_with_missing_agent_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "broken".to_string(),
            SwarmConfig {
                agents: vec!["nonexistent_agent".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "broken", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("unknown agent"));
    }

    #[tokio::test]
    async fn swarm_with_empty_agents_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "empty".to_string(),
            SwarmConfig {
                agents: Vec::new(),
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "empty", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("no agents configured"));
    }

    // Security gate: read-only autonomy blocks the Act operation.
    #[tokio::test]
    async fn swarm_blocked_in_readonly_mode() {
        let readonly = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            ..SecurityPolicy::default()
        });
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            readonly,
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("read-only mode"));
    }

    // Security gate: zero hourly action budget blocks immediately.
    #[tokio::test]
    async fn swarm_blocked_when_rate_limited() {
        let limited = Arc::new(SecurityPolicy {
            max_actions_per_hour: 0,
            ..SecurityPolicy::default()
        });
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            limited,
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("Rate limit exceeded"));
    }

    #[tokio::test]
    async fn sequential_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "seq".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        // researcher uses "ollama" which won't be running in CI
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "seq", "prompt": "test"}))
            .await
            .unwrap();
        // Should fail at provider creation or call level
        assert!(!result.success);
    }

    #[tokio::test]
    async fn parallel_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "par".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "par", "prompt": "test"}))
            .await
            .unwrap();
        // Parallel strategy returns success with error annotations in output
        assert!(result.success || result.error.is_some());
    }

    #[tokio::test]
    async fn router_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "rout".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Router,
                router_prompt: Some("Pick.".to_string()),
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "rout", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
    }
}
|
||||
@@ -0,0 +1,356 @@
|
||||
//! Tool for managing multi-client workspaces.
|
||||
//!
|
||||
//! Provides `workspace` subcommands: list, switch, create, info, export.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::workspace::WorkspaceManager;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Agent-callable tool for workspace management operations.
pub struct WorkspaceTool {
    // Shared workspace manager; read-lock for list/info/export,
    // write-lock for switch/create.
    manager: Arc<RwLock<WorkspaceManager>>,
    // Policy consulted before mutating actions (switch, create).
    security: Arc<SecurityPolicy>,
}

impl WorkspaceTool {
    /// Build a workspace tool over a shared manager and security policy.
    pub fn new(manager: Arc<RwLock<WorkspaceManager>>, security: Arc<SecurityPolicy>) -> Self {
        Self { manager, security }
    }
}
|
||||
|
||||
#[async_trait]
impl Tool for WorkspaceTool {
    fn name(&self) -> &str {
        "workspace"
    }

    fn description(&self) -> &str {
        "Manage multi-client workspaces. Subcommands: list, switch, create, info, export. Each workspace provides isolated memory, audit, secrets, and tool restrictions."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["list", "switch", "create", "info", "export"],
                    "description": "Workspace action to perform"
                },
                "name": {
                    "type": "string",
                    "description": "Workspace name (required for switch, create, export)"
                }
            },
            "required": ["action"]
        })
    }

    /// Dispatch on `action`. Read-only actions (list, info, export) take a
    /// read lock; mutating actions (switch, create) are gated by the
    /// security policy and take a write lock. A missing `action` is a hard
    /// error; everything else is reported via `ToolResult`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = args
            .get("action")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'action' parameter"))?;

        // Optional; several actions fall back to the active workspace.
        let name = args.get("name").and_then(|v| v.as_str());

        match action {
            "list" => {
                let mgr = self.manager.read().await;
                let names = mgr.list();
                let active = mgr.active_name();

                if names.is_empty() {
                    return Ok(ToolResult {
                        success: true,
                        output: "No workspaces configured.".to_string(),
                        error: None,
                    });
                }

                // One line per workspace, marking the active one.
                let mut output = format!("Workspaces ({}):\n", names.len());
                for ws_name in &names {
                    let marker = if Some(*ws_name) == active {
                        " (active)"
                    } else {
                        ""
                    };
                    let _ = writeln!(output, "  - {ws_name}{marker}");
                }
                Ok(ToolResult {
                    success: true,
                    output,
                    error: None,
                })
            }

            "switch" => {
                // Mutating action: enforce security policy first.
                if let Err(error) = self
                    .security
                    .enforce_tool_operation(ToolOperation::Act, "workspace")
                {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(error),
                    });
                }

                let ws_name = name.ok_or_else(|| {
                    anyhow::anyhow!("'name' parameter is required for switch action")
                })?;

                let mut mgr = self.manager.write().await;
                match mgr.switch(ws_name) {
                    Ok(profile) => Ok(ToolResult {
                        success: true,
                        output: format!(
                            "Switched to workspace '{}'. Memory namespace: {}, Audit namespace: {}",
                            profile.name,
                            profile.effective_memory_namespace(),
                            profile.effective_audit_namespace()
                        ),
                        error: None,
                    }),
                    Err(e) => Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(e.to_string()),
                    }),
                }
            }

            "create" => {
                // Mutating action: enforce security policy first.
                if let Err(error) = self
                    .security
                    .enforce_tool_operation(ToolOperation::Act, "workspace")
                {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(error),
                    });
                }

                let ws_name = name.ok_or_else(|| {
                    anyhow::anyhow!("'name' parameter is required for create action")
                })?;

                let mut mgr = self.manager.write().await;
                match mgr.create(ws_name).await {
                    Ok(profile) => {
                        let name = profile.name.clone();
                        let dir = mgr.workspace_dir(ws_name);
                        Ok(ToolResult {
                            success: true,
                            output: format!("Created workspace '{}' at {}", name, dir.display()),
                            error: None,
                        })
                    }
                    Err(e) => Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(e.to_string()),
                    }),
                }
            }

            "info" => {
                let mgr = self.manager.read().await;
                // Explicit name wins; otherwise describe the active workspace.
                let target_name = name.or_else(|| mgr.active_name());

                match target_name {
                    Some(ws_name) => match mgr.get(ws_name) {
                        Some(profile) => {
                            let is_active = mgr.active_name() == Some(ws_name);
                            let mut output = format!("Workspace: {}\n", profile.name);
                            let _ = writeln!(
                                output,
                                "  Status: {}",
                                if is_active { "active" } else { "inactive" }
                            );
                            let _ = writeln!(
                                output,
                                "  Memory namespace: {}",
                                profile.effective_memory_namespace()
                            );
                            let _ = writeln!(
                                output,
                                "  Audit namespace: {}",
                                profile.effective_audit_namespace()
                            );
                            // Optional sections only when non-empty.
                            if !profile.allowed_domains.is_empty() {
                                let _ = writeln!(
                                    output,
                                    "  Allowed domains: {}",
                                    profile.allowed_domains.join(", ")
                                );
                            }
                            if !profile.tool_restrictions.is_empty() {
                                let _ = writeln!(
                                    output,
                                    "  Restricted tools: {}",
                                    profile.tool_restrictions.join(", ")
                                );
                            }
                            Ok(ToolResult {
                                success: true,
                                output,
                                error: None,
                            })
                        }
                        None => Ok(ToolResult {
                            success: false,
                            output: String::new(),
                            error: Some(format!("workspace '{}' not found", ws_name)),
                        }),
                    },
                    None => Ok(ToolResult {
                        success: true,
                        output: "No workspace is currently active. Use 'workspace switch <name>' to activate one.".to_string(),
                        error: None,
                    }),
                }
            }

            "export" => {
                let mgr = self.manager.read().await;
                // Explicit name wins; otherwise export the active workspace.
                let ws_name = name.or_else(|| mgr.active_name()).ok_or_else(|| {
                    anyhow::anyhow!("'name' parameter is required when no workspace is active")
                })?;

                match mgr.export(ws_name) {
                    Ok(toml_str) => Ok(ToolResult {
                        success: true,
                        output: format!(
                            "Exported workspace '{}' config (secrets redacted):\n\n{}",
                            ws_name, toml_str
                        ),
                        error: None,
                    }),
                    Err(e) => Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(e.to_string()),
                    }),
                }
            }

            other => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "unknown workspace action '{}'. Expected: list, switch, create, info, export",
                    other
                )),
            }),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;
    use tempfile::TempDir;

    // Tool backed by a manager rooted at a throwaway temp directory with a
    // permissive default security policy.
    fn test_tool(tmp: &TempDir) -> WorkspaceTool {
        let mgr = WorkspaceManager::new(tmp.path().to_path_buf());
        WorkspaceTool::new(
            Arc::new(RwLock::new(mgr)),
            Arc::new(SecurityPolicy::default()),
        )
    }

    #[tokio::test]
    async fn workspace_tool_list_empty() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);
        let result = tool.execute(json!({"action": "list"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No workspaces"));
    }

    #[tokio::test]
    async fn workspace_tool_create_and_list() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);

        let result = tool
            .execute(json!({"action": "create", "name": "test_client"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("test_client"));

        // The new workspace shows up in a subsequent list.
        let result = tool.execute(json!({"action": "list"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("test_client"));
    }

    #[tokio::test]
    async fn workspace_tool_switch_and_info() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);

        tool.execute(json!({"action": "create", "name": "ws_test"}))
            .await
            .unwrap();

        let result = tool
            .execute(json!({"action": "switch", "name": "ws_test"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Switched to workspace"));

        // `info` with no name describes the now-active workspace.
        let result = tool.execute(json!({"action": "info"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("ws_test"));
        assert!(result.output.contains("active"));
    }

    #[tokio::test]
    async fn workspace_tool_export_redacts() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);

        tool.execute(json!({"action": "create", "name": "export_ws"}))
            .await
            .unwrap();

        let result = tool
            .execute(json!({"action": "export", "name": "export_ws"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("export_ws"));
    }

    #[tokio::test]
    async fn workspace_tool_unknown_action() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);
        let result = tool.execute(json!({"action": "destroy"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("unknown workspace action"));
    }

    #[tokio::test]
    async fn workspace_tool_switch_nonexistent() {
        let tmp = TempDir::new().unwrap();
        let tool = test_tool(&tmp);
        let result = tool
            .execute(json!({"action": "switch", "name": "ghost"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("not found"));
    }
}
|
||||
+59
-2
@@ -2,6 +2,7 @@ mod cloudflare;
|
||||
mod custom;
|
||||
mod ngrok;
|
||||
mod none;
|
||||
mod openvpn;
|
||||
mod tailscale;
|
||||
|
||||
pub use cloudflare::CloudflareTunnel;
|
||||
@@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
|
||||
pub use ngrok::NgrokTunnel;
|
||||
#[allow(unused_imports)]
|
||||
pub use none::NoneTunnel;
|
||||
pub use openvpn::OpenVpnTunnel;
|
||||
pub use tailscale::TailscaleTunnel;
|
||||
|
||||
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
|
||||
@@ -104,6 +106,20 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
"openvpn" => {
|
||||
let ov = config
|
||||
.openvpn
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
|
||||
Ok(Some(Box::new(OpenVpnTunnel::new(
|
||||
ov.config_file.clone(),
|
||||
ov.auth_file.clone(),
|
||||
ov.advertise_address.clone(),
|
||||
ov.connect_timeout_secs,
|
||||
ov.extra_args.clone(),
|
||||
))))
|
||||
}
|
||||
|
||||
"custom" => {
|
||||
let cu = config
|
||||
.custom
|
||||
@@ -116,7 +132,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"),
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +142,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::{
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig,
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
|
||||
TunnelConfig,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
@@ -315,6 +332,46 @@ mod tests {
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[test]
fn factory_openvpn_missing_config_errors() {
    // Provider names openvpn, but the [tunnel.openvpn] table is absent.
    let cfg = TunnelConfig {
        provider: String::from("openvpn"),
        ..Default::default()
    };
    assert_tunnel_err(&cfg, "[tunnel.openvpn]");
}
|
||||
|
||||
#[test]
fn factory_openvpn_with_config_ok() {
    // Minimal valid openvpn section: only the config file is required.
    let openvpn = OpenVpnTunnelConfig {
        config_file: String::from("client.ovpn"),
        auth_file: None,
        advertise_address: None,
        connect_timeout_secs: 30,
        extra_args: Vec::new(),
    };
    let cfg = TunnelConfig {
        provider: String::from("openvpn"),
        openvpn: Some(openvpn),
        ..Default::default()
    };

    let tunnel = create_tunnel(&cfg).unwrap();
    assert!(tunnel.is_some());
    assert_eq!(tunnel.unwrap().name(), "openvpn");
}
|
||||
|
||||
#[test]
fn openvpn_tunnel_name() {
    let tunnel = OpenVpnTunnel::new(String::from("client.ovpn"), None, None, 30, Vec::new());
    assert_eq!(tunnel.name(), "openvpn");
    // No process has been started, so there is no URL to report yet.
    assert!(tunnel.public_url().is_none());
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openvpn_health_false_before_start() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert!(!tunnel.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_shared_no_process_is_ok() {
|
||||
let proc = new_shared_process();
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
|
||||
use anyhow::{bail, Result};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
///
/// Requires the `openvpn` binary installed and accessible. On most systems,
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
///
/// The tunnel exposes the gateway via the VPN network using a configured
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
pub struct OpenVpnTunnel {
    // Path to the `.ovpn` client config, passed to openvpn via `--config`.
    config_file: String,
    // Optional credentials file, forwarded with `--auth-user-pass`.
    auth_file: Option<String>,
    // Optional address advertised as the public URL once connected; when
    // absent, `start` falls back to `http://{local_host}:{local_port}`.
    advertise_address: Option<String>,
    // Seconds `start` waits for OpenVPN's initialization marker.
    connect_timeout_secs: u64,
    // Extra CLI arguments appended verbatim after the built-in ones.
    extra_args: Vec<String>,
    // Shared handle to the spawned child process; `None` until `start`
    // succeeds, cleared again by `stop`.
    proc: SharedProcess,
}
|
||||
|
||||
impl OpenVpnTunnel {
|
||||
/// Create a new OpenVPN tunnel instance.
|
||||
///
|
||||
/// * `config_file` — path to the `.ovpn` configuration file.
|
||||
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
|
||||
/// * `advertise_address` — optional public address to advertise once connected.
|
||||
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
|
||||
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
|
||||
pub fn new(
|
||||
config_file: String,
|
||||
auth_file: Option<String>,
|
||||
advertise_address: Option<String>,
|
||||
connect_timeout_secs: u64,
|
||||
extra_args: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config_file,
|
||||
auth_file,
|
||||
advertise_address,
|
||||
connect_timeout_secs,
|
||||
extra_args,
|
||||
proc: new_shared_process(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the openvpn command arguments.
|
||||
fn build_args(&self) -> Vec<String> {
|
||||
let mut args = vec!["--config".to_string(), self.config_file.clone()];
|
||||
|
||||
if let Some(ref auth) = self.auth_file {
|
||||
args.push("--auth-user-pass".to_string());
|
||||
args.push(auth.clone());
|
||||
}
|
||||
|
||||
args.extend(self.extra_args.iter().cloned());
|
||||
args
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
impl Tunnel for OpenVpnTunnel {
    /// Stable provider identifier, matching the `tunnel.provider` config value.
    fn name(&self) -> &str {
        "openvpn"
    }

    /// Spawn the `openvpn` process and wait for the "Initialization Sequence
    /// Completed" marker on stderr. Returns the public URL on success.
    ///
    /// The returned URL is `advertise_address` when configured, otherwise
    /// `http://{local_host}:{local_port}`.
    ///
    /// # Errors
    /// Fails when the config file is missing, the process cannot be spawned,
    /// the process exits or its output stream errors before the marker is
    /// seen, or `connect_timeout_secs` elapses without seeing the marker.
    async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
        // Validate config file exists before spawning
        if !std::path::Path::new(&self.config_file).exists() {
            bail!("OpenVPN config file not found: {}", self.config_file);
        }

        let args = self.build_args();

        // kill_on_drop guarantees the VPN process dies if this handle is
        // dropped before being stored in `self.proc` (e.g. early error return).
        let mut child = Command::new("openvpn")
            .args(&args)
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::piped())
            .kill_on_drop(true)
            .spawn()?;

        // Wait for "Initialization Sequence Completed" in stderr
        // NOTE(review): openvpn writes its log to stdout by default, so the
        // marker may never appear on stderr and startup could always hit the
        // timeout — confirm against the deployed openvpn invocation (e.g.
        // whether `--log` or output redirection is in effect).
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;

        let mut reader = tokio::io::BufReader::new(stderr).lines();
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_secs(self.connect_timeout_secs);

        // Poll line-by-line with a short per-read timeout so the overall
        // deadline is re-checked even when openvpn emits no output. The
        // deadline can therefore overshoot by at most one 3s read window.
        let mut connected = false;
        while tokio::time::Instant::now() < deadline {
            let line =
                tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;

            match line {
                Ok(Ok(Some(l))) => {
                    tracing::debug!("openvpn: {l}");
                    if l.contains("Initialization Sequence Completed") {
                        connected = true;
                        break;
                    }
                }
                Ok(Ok(None)) => {
                    // EOF: the child closed its stream, i.e. it exited early.
                    bail!("OpenVPN process exited before connection was established");
                }
                Ok(Err(e)) => {
                    bail!("Error reading openvpn output: {e}");
                }
                Err(_) => {
                    // Timeout on individual line read, continue waiting
                }
            }
        }

        if !connected {
            // Best-effort kill; ignore the result (process may already be gone).
            child.kill().await.ok();
            bail!(
                "OpenVPN connection timed out after {}s waiting for initialization",
                self.connect_timeout_secs
            );
        }

        let public_url = self
            .advertise_address
            .clone()
            .unwrap_or_else(|| format!("http://{local_host}:{local_port}"));

        // Drain stderr in background to prevent OS pipe buffer from filling and
        // blocking the openvpn process.
        tokio::spawn(async move {
            while let Ok(Some(line)) = reader.next_line().await {
                tracing::trace!("openvpn: {line}");
            }
        });

        // Record the running child so stop()/health_check()/public_url() can
        // observe it through the shared handle.
        let mut guard = self.proc.lock().await;
        *guard = Some(TunnelProcess {
            child,
            public_url: public_url.clone(),
        });

        Ok(public_url)
    }

    /// Kill the openvpn child process and release its resources.
    async fn stop(&self) -> Result<()> {
        kill_shared(&self.proc).await
    }

    /// Return `true` if the openvpn child process is still running.
    /// (`Child::id()` returns `None` once the process has completed.)
    async fn health_check(&self) -> bool {
        let guard = self.proc.lock().await;
        guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
    }

    /// Return the public URL if the tunnel has been started.
    /// Uses `try_lock` so this synchronous accessor never blocks; yields
    /// `None` when the lock is contended or no process has been stored.
    fn public_url(&self) -> Option<String> {
        self.proc
            .try_lock()
            .ok()
            .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constructor_stores_fields() {
        // Every constructor argument must land in its corresponding field.
        let t = OpenVpnTunnel::new(
            String::from("/etc/openvpn/client.ovpn"),
            Some(String::from("/etc/openvpn/auth.txt")),
            Some(String::from("10.8.0.2:42617")),
            45,
            vec!["--verb".into(), "3".into()],
        );
        assert_eq!(t.config_file, "/etc/openvpn/client.ovpn");
        assert_eq!(t.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
        assert_eq!(t.advertise_address.as_deref(), Some("10.8.0.2:42617"));
        assert_eq!(t.connect_timeout_secs, 45);
        assert_eq!(t.extra_args, vec!["--verb", "3"]);
    }

    #[test]
    fn build_args_basic() {
        // With no auth file and no extras, only --config is emitted.
        let t = OpenVpnTunnel::new(String::from("client.ovpn"), None, None, 30, Vec::new());
        assert_eq!(t.build_args(), vec!["--config", "client.ovpn"]);
    }

    #[test]
    fn build_args_with_auth_and_extras() {
        let t = OpenVpnTunnel::new(
            String::from("client.ovpn"),
            Some(String::from("auth.txt")),
            None,
            30,
            vec!["--verb".into(), "5".into()],
        );
        let expected = vec![
            "--config",
            "client.ovpn",
            "--auth-user-pass",
            "auth.txt",
            "--verb",
            "5",
        ];
        assert_eq!(t.build_args(), expected);
    }

    #[test]
    fn public_url_is_none_before_start() {
        let t = OpenVpnTunnel::new(String::from("client.ovpn"), None, None, 30, Vec::new());
        assert!(t.public_url().is_none());
    }

    #[tokio::test]
    async fn health_check_is_false_before_start() {
        let t = OpenVpnTunnel::new(String::from("client.ovpn"), None, None, 30, Vec::new());
        assert!(!t.health_check().await);
    }

    #[tokio::test]
    async fn stop_without_started_process_is_ok() {
        // Stopping an idle tunnel is a no-op, not an error.
        let t = OpenVpnTunnel::new(String::from("client.ovpn"), None, None, 30, Vec::new());
        assert!(t.stop().await.is_ok());
    }

    #[tokio::test]
    async fn start_with_missing_config_file_errors() {
        let t = OpenVpnTunnel::new(
            String::from("/nonexistent/path/to/client.ovpn"),
            None,
            None,
            30,
            Vec::new(),
        );
        let result = t.start("127.0.0.1", 8080).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("config file not found"));
    }
}
|
||||
Reference in New Issue
Block a user