Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 603aa88f3c | |||
| 1ca2092ca0 | |||
| 5e3308eaaa | |||
| ec255ad788 | |||
| 7182f659ce | |||
| ae7681209d | |||
| ee3469e912 | |||
| fec81d8e75 | |||
| 9a073fae1a | |||
| f0db63e53c | |||
| df4dfeaf66 | |||
| e4ef25e913 | |||
| c3a3cfc9a6 | |||
| 013fca6ad2 | |||
| 23a0f25b44 | |||
| 2eaa8c45f4 | |||
| 85bf649432 | |||
| 3ea99a7619 | |||
| 14f58c77c1 | |||
| b833eb19f5 | |||
| 5db883b453 | |||
| 0ae515b6b8 | |||
| 2deb91455d | |||
| 595b81be41 | |||
| 4f9d817ddb | |||
| d13f5500e9 | |||
| 1ccfe643ba | |||
| d4d3e03e34 | |||
| aa0f11b0a2 | |||
| 806f8b4020 | |||
| 83803cef5b | |||
| dcb182cdd5 | |||
| 7c36a403b0 | |||
| 058dbc8786 | |||
| aff7a19494 | |||
| 0894429b54 | |||
| 6b03e885fc | |||
| 84470a2dd2 | |||
| 8a890be021 | |||
| 74a5ff78e7 | |||
| 93b16dece5 | |||
| c773170753 | |||
| f210b43977 | |||
| 50bc360bf4 | |||
| fc8ed583a0 | |||
| d593b6b1e4 | |||
| 426faa3923 | |||
| 85429b3657 | |||
| 8adf05f307 | |||
| a5f844d7cc | |||
| 7a9e815948 | |||
| 46378cf8b4 | |||
| c2133e6e62 | |||
| e9b3148e73 | |||
| 3d007f6b55 | |||
| f349de78ed | |||
| cd40051f4c | |||
| 6e4b1ede28 | |||
| cfba009833 | |||
| 45abd27e4a | |||
| 566e3cf35b | |||
| d642b0f3c8 | |||
| 98688c61ff | |||
| 9ba5ba5632 | |||
| 318ed8e9f1 | |||
| 2539bcafe0 | |||
| 37d76f7c42 | |||
| 41b46f23e3 | |||
| b4f3e4f37b | |||
| 6b1fe960e3 | |||
| 7a4ea4cbe9 | |||
| c1e05069ea |
@@ -0,0 +1,169 @@
|
||||
name: Pub AUR Package
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate PKGBUILD only (no push)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
AUR_SSH_KEY:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate PKGBUILD only (no push)"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
concurrency:
|
||||
group: aur-publish-${{ github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
publish-aur:
|
||||
name: Update AUR Package
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ inputs.release_tag }}
|
||||
DRY_RUN: ${{ inputs.dry_run }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Validate and compute metadata
|
||||
id: meta
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ! "$RELEASE_TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "::error::release_tag must be vX.Y.Z format."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
version="${RELEASE_TAG#v}"
|
||||
tarball_url="https://github.com/${GITHUB_REPOSITORY}/archive/refs/tags/${RELEASE_TAG}.tar.gz"
|
||||
tarball_sha="$(curl -fsSL "$tarball_url" | sha256sum | awk '{print $1}')"
|
||||
|
||||
if [[ -z "$tarball_sha" ]]; then
|
||||
echo "::error::Could not compute SHA256 for source tarball."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "tarball_url=$tarball_url"
|
||||
echo "tarball_sha=$tarball_sha"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "### AUR Package Metadata"
|
||||
echo "- version: \`${version}\`"
|
||||
echo "- tarball_url: \`${tarball_url}\`"
|
||||
echo "- tarball_sha: \`${tarball_sha}\`"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Generate PKGBUILD
|
||||
id: pkgbuild
|
||||
shell: bash
|
||||
env:
|
||||
VERSION: ${{ steps.meta.outputs.version }}
|
||||
TARBALL_SHA: ${{ steps.meta.outputs.tarball_sha }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
pkgbuild_file="$(mktemp)"
|
||||
sed -e "s/^pkgver=.*/pkgver=${VERSION}/" \
|
||||
-e "s/^sha256sums=.*/sha256sums=('${TARBALL_SHA}')/" \
|
||||
dist/aur/PKGBUILD > "$pkgbuild_file"
|
||||
|
||||
echo "pkgbuild_file=$pkgbuild_file" >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "### Generated PKGBUILD" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo '```bash' >> "$GITHUB_STEP_SUMMARY"
|
||||
cat "$pkgbuild_file" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo '```' >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Generate .SRCINFO
|
||||
id: srcinfo
|
||||
shell: bash
|
||||
env:
|
||||
VERSION: ${{ steps.meta.outputs.version }}
|
||||
TARBALL_SHA: ${{ steps.meta.outputs.tarball_sha }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
srcinfo_file="$(mktemp)"
|
||||
sed -e "s/pkgver = .*/pkgver = ${VERSION}/" \
|
||||
-e "s/sha256sums = .*/sha256sums = ${TARBALL_SHA}/" \
|
||||
-e "s|zeroclaw-[0-9.]*.tar.gz|zeroclaw-${VERSION}.tar.gz|g" \
|
||||
-e "s|/v[0-9.]*\.tar\.gz|/v${VERSION}.tar.gz|g" \
|
||||
dist/aur/.SRCINFO > "$srcinfo_file"
|
||||
|
||||
echo "srcinfo_file=$srcinfo_file" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Push to AUR
|
||||
if: inputs.dry_run == false
|
||||
shell: bash
|
||||
env:
|
||||
AUR_SSH_KEY: ${{ secrets.AUR_SSH_KEY }}
|
||||
PKGBUILD_FILE: ${{ steps.pkgbuild.outputs.pkgbuild_file }}
|
||||
SRCINFO_FILE: ${{ steps.srcinfo.outputs.srcinfo_file }}
|
||||
VERSION: ${{ steps.meta.outputs.version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ -z "${AUR_SSH_KEY}" ]]; then
|
||||
echo "::error::Secret AUR_SSH_KEY is required for non-dry-run."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p ~/.ssh
|
||||
echo "$AUR_SSH_KEY" > ~/.ssh/aur
|
||||
chmod 600 ~/.ssh/aur
|
||||
cat >> ~/.ssh/config <<SSH_CONFIG
|
||||
Host aur.archlinux.org
|
||||
IdentityFile ~/.ssh/aur
|
||||
User aur
|
||||
StrictHostKeyChecking accept-new
|
||||
SSH_CONFIG
|
||||
|
||||
tmp_dir="$(mktemp -d)"
|
||||
git clone ssh://aur@aur.archlinux.org/zeroclaw.git "$tmp_dir/aur"
|
||||
|
||||
cp "$PKGBUILD_FILE" "$tmp_dir/aur/PKGBUILD"
|
||||
cp "$SRCINFO_FILE" "$tmp_dir/aur/.SRCINFO"
|
||||
|
||||
cd "$tmp_dir/aur"
|
||||
git config user.name "zeroclaw-bot"
|
||||
git config user.email "bot@zeroclaw.dev"
|
||||
git add PKGBUILD .SRCINFO
|
||||
git commit -m "zeroclaw ${VERSION}"
|
||||
git push origin HEAD
|
||||
|
||||
echo "AUR package updated to ${VERSION}"
|
||||
|
||||
- name: Summary
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ "$DRY_RUN" == "true" ]]; then
|
||||
echo "Dry run complete: PKGBUILD generated, no push performed."
|
||||
else
|
||||
echo "Publish complete: AUR package pushed."
|
||||
fi
|
||||
@@ -0,0 +1,206 @@
|
||||
name: Pub Homebrew Core
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag to publish (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Patch formula only (no push/PR)"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
concurrency:
|
||||
group: homebrew-core-${{ github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
publish-homebrew-core:
|
||||
name: Publish Homebrew Core PR
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
UPSTREAM_REPO: Homebrew/homebrew-core
|
||||
FORMULA_PATH: Formula/z/zeroclaw.rb
|
||||
RELEASE_TAG: ${{ inputs.release_tag }}
|
||||
DRY_RUN: ${{ inputs.dry_run }}
|
||||
BOT_FORK_REPO: ${{ vars.HOMEBREW_CORE_BOT_FORK_REPO }}
|
||||
BOT_EMAIL: ${{ vars.HOMEBREW_CORE_BOT_EMAIL }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Validate release tag and version alignment
|
||||
id: release_meta
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
semver_pattern='^v[0-9]+\.[0-9]+\.[0-9]+([.-][0-9A-Za-z.-]+)?$'
|
||||
if [[ ! "$RELEASE_TAG" =~ $semver_pattern ]]; then
|
||||
echo "::error::release_tag must match semver-like format (vX.Y.Z[-suffix])."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! git rev-parse "refs/tags/${RELEASE_TAG}" >/dev/null 2>&1; then
|
||||
git fetch --tags origin
|
||||
fi
|
||||
|
||||
tag_version="${RELEASE_TAG#v}"
|
||||
cargo_version="$(git show "${RELEASE_TAG}:Cargo.toml" \
|
||||
| sed -n 's/^version = "\([^"]*\)"/\1/p' | head -n1)"
|
||||
if [[ -z "$cargo_version" ]]; then
|
||||
echo "::error::Unable to read Cargo.toml version from tag ${RELEASE_TAG}."
|
||||
exit 1
|
||||
fi
|
||||
if [[ "$cargo_version" != "$tag_version" ]]; then
|
||||
echo "::error::Tag ${RELEASE_TAG} does not match Cargo.toml version (${cargo_version})."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
tarball_url="https://github.com/${GITHUB_REPOSITORY}/archive/refs/tags/${RELEASE_TAG}.tar.gz"
|
||||
tarball_sha="$(curl -fsSL "$tarball_url" | sha256sum | awk '{print $1}')"
|
||||
|
||||
{
|
||||
echo "tag_version=$tag_version"
|
||||
echo "tarball_url=$tarball_url"
|
||||
echo "tarball_sha=$tarball_sha"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "### Release Metadata"
|
||||
echo "- release_tag: \`${RELEASE_TAG}\`"
|
||||
echo "- cargo_version: \`${cargo_version}\`"
|
||||
echo "- tarball_sha256: \`${tarball_sha}\`"
|
||||
echo "- dry_run: ${DRY_RUN}"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Patch Homebrew formula
|
||||
id: patch_formula
|
||||
shell: bash
|
||||
env:
|
||||
HOMEBREW_CORE_BOT_TOKEN: ${{ secrets.HOMEBREW_UPSTREAM_PR_TOKEN || secrets.HOMEBREW_CORE_BOT_TOKEN }}
|
||||
GH_TOKEN: ${{ secrets.HOMEBREW_UPSTREAM_PR_TOKEN || secrets.HOMEBREW_CORE_BOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
tmp_repo="$(mktemp -d)"
|
||||
echo "tmp_repo=$tmp_repo" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if [[ "$DRY_RUN" == "true" ]]; then
|
||||
git clone --depth=1 "https://github.com/${UPSTREAM_REPO}.git" "$tmp_repo/homebrew-core"
|
||||
else
|
||||
if [[ -z "${BOT_FORK_REPO}" ]]; then
|
||||
echo "::error::Repository variable HOMEBREW_CORE_BOT_FORK_REPO is required when dry_run=false."
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "${HOMEBREW_CORE_BOT_TOKEN}" ]]; then
|
||||
echo "::error::Repository secret HOMEBREW_CORE_BOT_TOKEN is required when dry_run=false."
|
||||
exit 1
|
||||
fi
|
||||
if [[ "$BOT_FORK_REPO" != */* ]]; then
|
||||
echo "::error::HOMEBREW_CORE_BOT_FORK_REPO must be in owner/repo format."
|
||||
exit 1
|
||||
fi
|
||||
if ! gh api "repos/${BOT_FORK_REPO}" >/dev/null 2>&1; then
|
||||
echo "::error::HOMEBREW_CORE_BOT_TOKEN cannot access ${BOT_FORK_REPO}."
|
||||
exit 1
|
||||
fi
|
||||
gh repo clone "${BOT_FORK_REPO}" "$tmp_repo/homebrew-core" -- --depth=1
|
||||
fi
|
||||
|
||||
repo_dir="$tmp_repo/homebrew-core"
|
||||
formula_file="$repo_dir/$FORMULA_PATH"
|
||||
if [[ ! -f "$formula_file" ]]; then
|
||||
echo "::error::Formula file not found: $FORMULA_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" == "false" ]]; then
|
||||
if git -C "$repo_dir" remote get-url upstream >/dev/null 2>&1; then
|
||||
git -C "$repo_dir" remote set-url upstream "https://github.com/${UPSTREAM_REPO}.git"
|
||||
else
|
||||
git -C "$repo_dir" remote add upstream "https://github.com/${UPSTREAM_REPO}.git"
|
||||
fi
|
||||
if git -C "$repo_dir" ls-remote --exit-code --heads upstream main >/dev/null 2>&1; then
|
||||
upstream_ref="main"
|
||||
else
|
||||
upstream_ref="master"
|
||||
fi
|
||||
git -C "$repo_dir" fetch --depth=1 upstream "$upstream_ref"
|
||||
branch_name="zeroclaw-${RELEASE_TAG}-${GITHUB_RUN_ID}"
|
||||
git -C "$repo_dir" checkout -B "$branch_name" "upstream/$upstream_ref"
|
||||
echo "branch_name=$branch_name" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
tarball_url="$(grep 'tarball_url=' "$GITHUB_OUTPUT" | head -1 | cut -d= -f2-)"
|
||||
tarball_sha="$(grep 'tarball_sha=' "$GITHUB_OUTPUT" | head -1 | cut -d= -f2-)"
|
||||
|
||||
perl -0pi -e "s|^ url \".*\"| url \"${tarball_url}\"|m" "$formula_file"
|
||||
perl -0pi -e "s|^ sha256 \".*\"| sha256 \"${tarball_sha}\"|m" "$formula_file"
|
||||
perl -0pi -e "s|^ license \".*\"| license \"Apache-2.0 OR MIT\"|m" "$formula_file"
|
||||
|
||||
git -C "$repo_dir" diff -- "$FORMULA_PATH" > "$tmp_repo/formula.diff"
|
||||
if [[ ! -s "$tmp_repo/formula.diff" ]]; then
|
||||
echo "::error::No formula changes generated. Nothing to publish."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
{
|
||||
echo "### Formula Diff"
|
||||
echo '```diff'
|
||||
cat "$tmp_repo/formula.diff"
|
||||
echo '```'
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Push branch and open Homebrew PR
|
||||
if: inputs.dry_run == false
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.HOMEBREW_UPSTREAM_PR_TOKEN || secrets.HOMEBREW_CORE_BOT_TOKEN }}
|
||||
TMP_REPO: ${{ steps.patch_formula.outputs.tmp_repo }}
|
||||
BRANCH_NAME: ${{ steps.patch_formula.outputs.branch_name }}
|
||||
TAG_VERSION: ${{ steps.release_meta.outputs.tag_version }}
|
||||
TARBALL_URL: ${{ steps.release_meta.outputs.tarball_url }}
|
||||
TARBALL_SHA: ${{ steps.release_meta.outputs.tarball_sha }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
repo_dir="${TMP_REPO}/homebrew-core"
|
||||
fork_owner="${BOT_FORK_REPO%%/*}"
|
||||
bot_email="${BOT_EMAIL:-${fork_owner}@users.noreply.github.com}"
|
||||
|
||||
git -C "$repo_dir" config user.name "$fork_owner"
|
||||
git -C "$repo_dir" config user.email "$bot_email"
|
||||
git -C "$repo_dir" add "$FORMULA_PATH"
|
||||
git -C "$repo_dir" commit -m "zeroclaw ${TAG_VERSION}"
|
||||
gh auth setup-git
|
||||
git -C "$repo_dir" push --set-upstream origin "$BRANCH_NAME"
|
||||
|
||||
pr_body="Automated formula bump from ZeroClaw release workflow.
|
||||
|
||||
- Release tag: ${RELEASE_TAG}
|
||||
- Source tarball: ${TARBALL_URL}
|
||||
- Source sha256: ${TARBALL_SHA}"
|
||||
|
||||
gh pr create \
|
||||
--repo "$UPSTREAM_REPO" \
|
||||
--base main \
|
||||
--head "${fork_owner}:${BRANCH_NAME}" \
|
||||
--title "zeroclaw ${TAG_VERSION}" \
|
||||
--body "$pr_body"
|
||||
|
||||
- name: Summary
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ "$DRY_RUN" == "true" ]]; then
|
||||
echo "Dry run complete: formula diff generated, no push/PR performed."
|
||||
else
|
||||
echo "Publish complete: branch pushed and PR opened from bot fork."
|
||||
fi
|
||||
@@ -0,0 +1,165 @@
|
||||
name: Pub Scoop Manifest
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate manifest only (no push)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
SCOOP_BUCKET_TOKEN:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate manifest only (no push)"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
concurrency:
|
||||
group: scoop-publish-${{ github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
publish-scoop:
|
||||
name: Update Scoop Manifest
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ inputs.release_tag }}
|
||||
DRY_RUN: ${{ inputs.dry_run }}
|
||||
SCOOP_BUCKET_REPO: ${{ vars.SCOOP_BUCKET_REPO }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Validate and compute metadata
|
||||
id: meta
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ! "$RELEASE_TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "::error::release_tag must be vX.Y.Z format."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
version="${RELEASE_TAG#v}"
|
||||
zip_url="https://github.com/${GITHUB_REPOSITORY}/releases/download/${RELEASE_TAG}/zeroclaw-x86_64-pc-windows-msvc.zip"
|
||||
sums_url="https://github.com/${GITHUB_REPOSITORY}/releases/download/${RELEASE_TAG}/SHA256SUMS"
|
||||
|
||||
sha256="$(curl -fsSL "$sums_url" | grep 'zeroclaw-x86_64-pc-windows-msvc.zip' | awk '{print $1}')"
|
||||
|
||||
if [[ -z "$sha256" ]]; then
|
||||
echo "::error::Could not find Windows binary hash in SHA256SUMS for ${RELEASE_TAG}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "zip_url=$zip_url"
|
||||
echo "sha256=$sha256"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "### Scoop Manifest Metadata"
|
||||
echo "- version: \`${version}\`"
|
||||
echo "- zip_url: \`${zip_url}\`"
|
||||
echo "- sha256: \`${sha256}\`"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Generate manifest
|
||||
id: manifest
|
||||
shell: bash
|
||||
env:
|
||||
VERSION: ${{ steps.meta.outputs.version }}
|
||||
ZIP_URL: ${{ steps.meta.outputs.zip_url }}
|
||||
SHA256: ${{ steps.meta.outputs.sha256 }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
manifest_file="$(mktemp)"
|
||||
cat > "$manifest_file" <<MANIFEST
|
||||
{
|
||||
"version": "${VERSION}",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "${ZIP_URL}",
|
||||
"hash": "${SHA256}",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
},
|
||||
"checkver": {
|
||||
"github": "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
},
|
||||
"autoupdate": {
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v\$version/zeroclaw-x86_64-pc-windows-msvc.zip"
|
||||
}
|
||||
},
|
||||
"hash": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v\$version/SHA256SUMS",
|
||||
"regex": "([a-f0-9]{64})\\\\s+zeroclaw-x86_64-pc-windows-msvc\\\\.zip"
|
||||
}
|
||||
}
|
||||
}
|
||||
MANIFEST
|
||||
|
||||
jq '.' "$manifest_file" > "${manifest_file}.formatted"
|
||||
mv "${manifest_file}.formatted" "$manifest_file"
|
||||
|
||||
echo "manifest_file=$manifest_file" >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "### Generated Manifest" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo '```json' >> "$GITHUB_STEP_SUMMARY"
|
||||
cat "$manifest_file" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo '```' >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Push to Scoop bucket
|
||||
if: inputs.dry_run == false
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.SCOOP_BUCKET_TOKEN }}
|
||||
MANIFEST_FILE: ${{ steps.manifest.outputs.manifest_file }}
|
||||
VERSION: ${{ steps.meta.outputs.version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ -z "${SCOOP_BUCKET_REPO}" ]]; then
|
||||
echo "::error::Repository variable SCOOP_BUCKET_REPO is required (e.g. zeroclaw-labs/scoop-zeroclaw)."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
tmp_dir="$(mktemp -d)"
|
||||
gh repo clone "${SCOOP_BUCKET_REPO}" "$tmp_dir/bucket" -- --depth=1
|
||||
|
||||
mkdir -p "$tmp_dir/bucket/bucket"
|
||||
cp "$MANIFEST_FILE" "$tmp_dir/bucket/bucket/zeroclaw.json"
|
||||
|
||||
cd "$tmp_dir/bucket"
|
||||
git config user.name "zeroclaw-bot"
|
||||
git config user.email "bot@zeroclaw.dev"
|
||||
git add bucket/zeroclaw.json
|
||||
git commit -m "zeroclaw ${VERSION}"
|
||||
gh auth setup-git
|
||||
git push origin HEAD
|
||||
|
||||
echo "Scoop manifest updated to ${VERSION}"
|
||||
@@ -103,9 +103,20 @@ jobs:
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish to crates.io
|
||||
run: cargo publish --locked --allow-dirty --no-verify
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
VERSION: ${{ needs.detect-version-change.outputs.version }}
|
||||
run: |
|
||||
# Publish to crates.io; treat "already exists" as success
|
||||
# (manual publish or stable workflow may have already published)
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::zeroclawlabs@${VERSION} already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Verify published
|
||||
shell: bash
|
||||
|
||||
@@ -75,6 +75,16 @@ jobs:
|
||||
|
||||
- name: Publish to crates.io
|
||||
if: "!inputs.dry_run"
|
||||
run: cargo publish --locked --allow-dirty --no-verify
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
# Publish to crates.io; treat "already exists" as success
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::zeroclawlabs@${VERSION} already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
@@ -5,8 +5,8 @@ on:
|
||||
branches: [master]
|
||||
|
||||
concurrency:
|
||||
group: release
|
||||
cancel-in-progress: false
|
||||
group: release-beta
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -318,10 +318,13 @@ jobs:
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# ── Post-publish: only run after ALL artifacts are live ──────────────
|
||||
# ── Post-publish: tweet after release + website are live ──────────────
|
||||
# Docker is slow (multi-platform) and can be cancelled by concurrency;
|
||||
# don't let it block the tweet.
|
||||
tweet:
|
||||
name: Tweet Release
|
||||
needs: [version, publish, docker, redeploy-website]
|
||||
needs: [version, publish, redeploy-website]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/tweet-release.yml
|
||||
with:
|
||||
release_tag: ${{ needs.version.outputs.tag }}
|
||||
|
||||
@@ -307,13 +307,16 @@ jobs:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
# Skip if this version is already on crates.io (auto-sync may have published it)
|
||||
# Publish to crates.io; treat "already exists" as success
|
||||
# (auto-publish workflow may have already published this version)
|
||||
CRATE_NAME=$(sed -n 's/^name = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
if curl -sfL "https://crates.io/api/v1/crates/${CRATE_NAME}/${VERSION}" | grep -q '"version"'; then
|
||||
echo "::notice::${CRATE_NAME}@${VERSION} already published on crates.io — skipping"
|
||||
else
|
||||
cargo publish --locked --allow-dirty --no-verify
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::${CRATE_NAME}@${VERSION} already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
redeploy-website:
|
||||
name: Trigger Website Redeploy
|
||||
@@ -358,10 +361,33 @@ jobs:
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# ── Post-publish: only run after ALL artifacts are live ──────────────
|
||||
# ── Post-publish: package manager auto-sync ─────────────────────────
|
||||
scoop:
|
||||
name: Update Scoop Manifest
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-scoop.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
aur:
|
||||
name: Update AUR Package
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-aur.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
# ── Post-publish: tweet after release + website are live ──────────────
|
||||
# Docker push can be slow; don't let it block the tweet.
|
||||
tweet:
|
||||
name: Tweet Release
|
||||
needs: [validate, publish, docker, crates-io, redeploy-website]
|
||||
needs: [validate, publish, redeploy-website]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/tweet-release.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
|
||||
@@ -53,7 +53,15 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Find the PREVIOUS release tag (including betas) to check for new features
|
||||
# Stable releases (no -beta suffix) always tweet — they represent
|
||||
# the full release cycle, so skipping them loses visibility.
|
||||
if [[ ! "$RELEASE_TAG" =~ -beta\. ]]; then
|
||||
echo "Stable release ${RELEASE_TAG} — always tweet"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# For betas: find the PREVIOUS release tag to check for new features
|
||||
PREV_TAG=$(git tag --sort=-creatordate \
|
||||
| grep -v "^${RELEASE_TAG}$" \
|
||||
| head -1 || echo "")
|
||||
@@ -63,15 +71,15 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Count new feat() commits since the previous release
|
||||
NEW_FEATS=$(git log "${PREV_TAG}..${RELEASE_TAG}" --pretty=format:"%s" --no-merges \
|
||||
| grep -ciE '^feat(\(|:)' || echo "0")
|
||||
# Count new feat() OR fix() commits since the previous release
|
||||
NEW_CHANGES=$(git log "${PREV_TAG}..${RELEASE_TAG}" --pretty=format:"%s" --no-merges \
|
||||
| grep -ciE '^(feat|fix)(\(|:)' || echo "0")
|
||||
|
||||
if [ "$NEW_FEATS" -eq 0 ]; then
|
||||
echo "No new features since ${PREV_TAG} — skipping tweet"
|
||||
if [ "$NEW_CHANGES" -eq 0 ]; then
|
||||
echo "No new features or fixes since ${PREV_TAG} — skipping tweet"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "${NEW_FEATS} new feature(s) since ${PREV_TAG} — tweeting"
|
||||
echo "${NEW_CHANGES} new change(s) since ${PREV_TAG} — tweeting"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
|
||||
+2
-1
@@ -1,7 +1,8 @@
|
||||
/target
|
||||
/target-*/
|
||||
firmware/*/target
|
||||
web/dist/
|
||||
web/dist/*
|
||||
!web/dist/.gitkeep
|
||||
*.db
|
||||
*.db-journal
|
||||
.DS_Store
|
||||
|
||||
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.4.0"
|
||||
version = "0.4.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.4.0"
|
||||
version = "0.4.3"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
+8
-1
@@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1.7
|
||||
|
||||
# ── Stage 1: Build ────────────────────────────────────────────
|
||||
FROM rust:1.93-slim@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder
|
||||
FROM rust:1.94-slim@sha256:7d3701660d2aa7101811ba0c54920021452aa60e5bae073b79c2b137a432b2f4 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -33,6 +33,7 @@ COPY benches/ benches/
|
||||
COPY crates/ crates/
|
||||
COPY firmware/ firmware/
|
||||
COPY web/ web/
|
||||
COPY *.rs .
|
||||
# Keep release builds resilient when frontend dist assets are not prebuilt in Git.
|
||||
RUN mkdir -p web/dist && \
|
||||
if [ ! -f web/dist/index.html ]; then \
|
||||
@@ -50,12 +51,18 @@ RUN mkdir -p web/dist && \
|
||||
' </body>' \
|
||||
'</html>' > web/dist/index.html; \
|
||||
fi
|
||||
RUN touch src/main.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
rm -rf target/release/.fingerprint/zeroclawlabs-* \
|
||||
target/release/deps/zeroclawlabs-* \
|
||||
target/release/incremental/zeroclawlabs-* && \
|
||||
cargo build --release --locked && \
|
||||
cp target/release/zeroclaw /app/zeroclaw && \
|
||||
strip /app/zeroclaw
|
||||
RUN size=$(stat -c%s /app/zeroclaw 2>/dev/null || stat -f%z /app/zeroclaw) && \
|
||||
if [ "$size" -lt 1000000 ]; then echo "ERROR: binary too small (${size} bytes), likely dummy build artifact" && exit 1; fi
|
||||
|
||||
# Prepare runtime directory structure and default config inline (no extra stage)
|
||||
RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace && \
|
||||
|
||||
+5
-1
@@ -16,7 +16,7 @@
|
||||
# docker compose -f docker-compose.yml -f docker-compose.debian.yml up
|
||||
|
||||
# ── Stage 1: Build (identical to main Dockerfile) ───────────
|
||||
FROM rust:1.93-slim@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder
|
||||
FROM rust:1.94-slim@sha256:7d3701660d2aa7101811ba0c54920021452aa60e5bae073b79c2b137a432b2f4 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -33,6 +33,7 @@ COPY crates/robot-kit/Cargo.toml crates/robot-kit/Cargo.toml
|
||||
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
|
||||
RUN mkdir -p src benches crates/robot-kit/src \
|
||||
&& echo "fn main() {}" > src/main.rs \
|
||||
&& echo "" > src/lib.rs \
|
||||
&& echo "fn main() {}" > benches/agent_benchmarks.rs \
|
||||
&& echo "pub fn placeholder() {}" > crates/robot-kit/src/lib.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
@@ -64,12 +65,15 @@ RUN mkdir -p web/dist && \
|
||||
' </body>' \
|
||||
'</html>' > web/dist/index.html; \
|
||||
fi
|
||||
RUN touch src/main.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
cargo build --release --locked && \
|
||||
cp target/release/zeroclaw /app/zeroclaw && \
|
||||
strip /app/zeroclaw
|
||||
RUN size=$(stat -c%s /app/zeroclaw 2>/dev/null || stat -f%z /app/zeroclaw) && \
|
||||
if [ "$size" -lt 1000000 ]; then echo "ERROR: binary too small (${size} bytes), likely dummy build artifact" && exit 1; fi
|
||||
|
||||
# Prepare runtime directory structure and default config inline (no extra stage)
|
||||
RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace && \
|
||||
|
||||
Vendored
+16
@@ -0,0 +1,16 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.4.3
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
license = MIT
|
||||
license = Apache-2.0
|
||||
makedepends = cargo
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.4.3.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.4.3.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
Vendored
+32
@@ -0,0 +1,32 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.4.3
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
url="https://github.com/zeroclaw-labs/zeroclaw"
|
||||
license=('MIT' 'Apache-2.0')
|
||||
depends=('gcc-libs' 'openssl')
|
||||
makedepends=('cargo' 'git')
|
||||
source=("${pkgname}-${pkgver}.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v${pkgver}.tar.gz")
|
||||
sha256sums=('SKIP')
|
||||
|
||||
prepare() {
|
||||
cd "${pkgname}-${pkgver}"
|
||||
export RUSTUP_TOOLCHAIN=stable
|
||||
cargo fetch --locked --target "$(rustc -vV | sed -n 's/host: //p')"
|
||||
}
|
||||
|
||||
build() {
|
||||
cd "${pkgname}-${pkgver}"
|
||||
export RUSTUP_TOOLCHAIN=stable
|
||||
export CARGO_TARGET_DIR=target
|
||||
cargo build --frozen --release --profile dist
|
||||
}
|
||||
|
||||
package() {
|
||||
cd "${pkgname}-${pkgver}"
|
||||
install -Dm0755 -t "${pkgdir}/usr/bin/" "target/dist/zeroclaw"
|
||||
install -Dm0644 LICENSE-MIT "${pkgdir}/usr/share/licenses/${pkgname}/LICENSE-MIT"
|
||||
install -Dm0644 LICENSE-APACHE "${pkgdir}/usr/share/licenses/${pkgname}/LICENSE-APACHE"
|
||||
}
|
||||
Vendored
+27
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"version": "0.4.3",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.4.3/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
},
|
||||
"checkver": {
|
||||
"github": "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
},
|
||||
"autoupdate": {
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v$version/zeroclaw-x86_64-pc-windows-msvc.zip"
|
||||
}
|
||||
},
|
||||
"hash": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v$version/SHA256SUMS",
|
||||
"regex": "([a-f0-9]{64})\\s+zeroclaw-x86_64-pc-windows-msvc\\.zip"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,12 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||
- `.github/workflows/pub-homebrew-core.yml` (`Pub Homebrew Core`)
|
||||
- Purpose: manual, bot-owned Homebrew core formula bump PR flow for tagged releases
|
||||
- Guardrail: release tag must match `Cargo.toml` version
|
||||
- `.github/workflows/pub-scoop.yml` (`Pub Scoop Manifest`)
|
||||
- Purpose: Scoop bucket manifest update for Windows; auto-called by stable release, also manual dispatch
|
||||
- Guardrail: release tag must be `vX.Y.Z` format; Windows binary hash extracted from `SHA256SUMS`
|
||||
- `.github/workflows/pub-aur.yml` (`Pub AUR Package`)
|
||||
- Purpose: AUR PKGBUILD push for Arch Linux; auto-called by stable release, also manual dispatch
|
||||
- Guardrail: release tag must be `vX.Y.Z` format; source tarball SHA256 computed at publish time
|
||||
- `.github/workflows/pr-label-policy-check.yml` (`Label Policy Sanity`)
|
||||
- Purpose: validate shared contributor-tier policy in `.github/label-policy.json` and ensure label workflows consume that policy
|
||||
- `.github/workflows/test-rust-build.yml` (`Rust Reusable Job`)
|
||||
@@ -75,6 +81,8 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||
- `Docker`: tag push (`v*`) for publish, matching PRs to `master` for smoke build, manual dispatch for smoke only
|
||||
- `Release`: tag push (`v*`), weekly schedule (verification-only), manual dispatch (verification or publish)
|
||||
- `Pub Homebrew Core`: manual dispatch only
|
||||
- `Pub Scoop Manifest`: auto-called by stable release, also manual dispatch
|
||||
- `Pub AUR Package`: auto-called by stable release, also manual dispatch
|
||||
- `Security Audit`: push to `master`, PRs to `master`, weekly schedule
|
||||
- `Sec Vorpal Reviewdog`: manual dispatch only
|
||||
- `Workflow Sanity`: PR/push when `.github/workflows/**`, `.github/*.yml`, or `.github/*.yaml` change
|
||||
@@ -92,12 +100,14 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||
2. Docker failures on PRs: inspect `.github/workflows/pub-docker-img.yml` `pr-smoke` job.
|
||||
3. Release failures (tag/manual/scheduled): inspect `.github/workflows/pub-release.yml` and the `prepare` job outputs.
|
||||
4. Homebrew formula publish failures: inspect `.github/workflows/pub-homebrew-core.yml` summary output and bot token/fork variables.
|
||||
5. Security failures: inspect `.github/workflows/sec-audit.yml` and `deny.toml`.
|
||||
6. Workflow syntax/lint failures: inspect `.github/workflows/workflow-sanity.yml`.
|
||||
7. PR intake failures: inspect `.github/workflows/pr-intake-checks.yml` sticky comment and run logs.
|
||||
8. Label policy parity failures: inspect `.github/workflows/pr-label-policy-check.yml`.
|
||||
9. Docs failures in CI: inspect `docs-quality` job logs in `.github/workflows/ci-run.yml`.
|
||||
10. Strict delta lint failures in CI: inspect `lint-strict-delta` job logs and compare with `BASE_SHA` diff scope.
|
||||
5. Scoop manifest publish failures: inspect `.github/workflows/pub-scoop.yml` summary output and `SCOOP_BUCKET_REPO`/`SCOOP_BUCKET_TOKEN` settings.
|
||||
6. AUR package publish failures: inspect `.github/workflows/pub-aur.yml` summary output and `AUR_SSH_KEY` secret.
|
||||
7. Security failures: inspect `.github/workflows/sec-audit.yml` and `deny.toml`.
|
||||
8. Workflow syntax/lint failures: inspect `.github/workflows/workflow-sanity.yml`.
|
||||
9. PR intake failures: inspect `.github/workflows/pr-intake-checks.yml` sticky comment and run logs.
|
||||
10. Label policy parity failures: inspect `.github/workflows/pr-label-policy-check.yml`.
|
||||
11. Docs failures in CI: inspect `docs-quality` job logs in `.github/workflows/ci-run.yml`.
|
||||
12. Strict delta lint failures in CI: inspect `lint-strict-delta` job logs and compare with `BASE_SHA` diff scope.
|
||||
|
||||
## Maintenance Rules
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ Release automation lives in:
|
||||
|
||||
- `.github/workflows/pub-release.yml`
|
||||
- `.github/workflows/pub-homebrew-core.yml` (manual Homebrew formula PR, bot-owned)
|
||||
- `.github/workflows/pub-scoop.yml` (manual Scoop bucket manifest update)
|
||||
- `.github/workflows/pub-aur.yml` (manual AUR PKGBUILD push)
|
||||
|
||||
Modes:
|
||||
|
||||
@@ -115,6 +117,41 @@ Workflow guardrails:
|
||||
- formula license is normalized to `Apache-2.0 OR MIT`
|
||||
- PR is opened from the bot fork into `Homebrew/homebrew-core:master`
|
||||
|
||||
### 7) Publish Scoop manifest (Windows)
|
||||
|
||||
Run `Pub Scoop Manifest` manually:
|
||||
|
||||
- `release_tag`: `vX.Y.Z`
|
||||
- `dry_run`: `true` first, then `false`
|
||||
|
||||
Required repository settings for non-dry-run:
|
||||
|
||||
- secret: `SCOOP_BUCKET_TOKEN` (PAT with push access to the bucket repo)
|
||||
- variable: `SCOOP_BUCKET_REPO` (for example `zeroclaw-labs/scoop-zeroclaw`)
|
||||
|
||||
Workflow guardrails:
|
||||
|
||||
- release tag must be `vX.Y.Z` format
|
||||
- Windows binary SHA256 extracted from `SHA256SUMS` release asset
|
||||
- manifest pushed to `bucket/zeroclaw.json` in the Scoop bucket repo
|
||||
|
||||
### 8) Publish AUR package (Arch Linux)
|
||||
|
||||
Run `Pub AUR Package` manually:
|
||||
|
||||
- `release_tag`: `vX.Y.Z`
|
||||
- `dry_run`: `true` first, then `false`
|
||||
|
||||
Required repository settings for non-dry-run:
|
||||
|
||||
- secret: `AUR_SSH_KEY` (SSH private key registered with AUR)
|
||||
|
||||
Workflow guardrails:
|
||||
|
||||
- release tag must be `vX.Y.Z` format
|
||||
- source tarball SHA256 computed from the tagged release
|
||||
- PKGBUILD and .SRCINFO pushed to AUR `zeroclaw` package
|
||||
|
||||
## Emergency / Recovery Path
|
||||
|
||||
If tag-push release fails after artifacts are validated:
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
# OpenAI Temperature Compatibility Reference
|
||||
|
||||
This document provides empirical evidence for temperature parameter compatibility across OpenAI models.
|
||||
|
||||
## Summary
|
||||
|
||||
Different OpenAI model families have different temperature requirements:
|
||||
|
||||
- **Reasoning models** (o-series, gpt-5 base variants): Only accept `temperature=1.0`
|
||||
- **Search models**: Do not accept temperature parameter (must be omitted)
|
||||
- **Standard models** (gpt-3.5, gpt-4, gpt-4o): Accept flexible temperature values (0.0-2.0)
|
||||
|
||||
## Tested Models
|
||||
|
||||
### Models Requiring temperature=1.0
|
||||
|
||||
| Model | Accepts 0.7 | Accepts 1.0 | Recommendation |
|
||||
|-------|-------------|-------------|----------------|
|
||||
| o1 | ❌ | ✅ | USE_1.0 |
|
||||
| o1-2024-12-17 | ❌ | ✅ | USE_1.0 |
|
||||
| o3 | ❌ | ✅ | USE_1.0 |
|
||||
| o3-2025-04-16 | ❌ | ✅ | USE_1.0 |
|
||||
| o3-mini | ❌ | ✅ | USE_1.0 |
|
||||
| o3-mini-2025-01-31 | ❌ | ✅ | USE_1.0 |
|
||||
| o4-mini | ❌ | ✅ | USE_1.0 |
|
||||
| o4-mini-2025-04-16 | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5 | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5-2025-08-07 | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5-mini | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5-mini-2025-08-07 | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5-nano | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5-nano-2025-08-07 | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5.1-chat-latest | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5.2-chat-latest | ❌ | ✅ | USE_1.0 |
|
||||
| gpt-5.3-chat-latest | ❌ | ✅ | USE_1.0 |
|
||||
|
||||
### Models Accepting Flexible Temperature (0.7 works)
|
||||
|
||||
All standard GPT models accept flexible temperature values:
|
||||
- gpt-3.5-turbo (all variants)
|
||||
- gpt-4 (all variants)
|
||||
- gpt-4-turbo (all variants)
|
||||
- gpt-4o (all variants)
|
||||
- gpt-4o-mini (all variants)
|
||||
- gpt-4.1 (all variants)
|
||||
- gpt-5-chat-latest
|
||||
- gpt-5.2, gpt-5.2-2025-12-11
|
||||
- gpt-5.4, gpt-5.4-2026-03-05
|
||||
|
||||
### Models Requiring Temperature Omission
|
||||
|
||||
Search-preview models do not accept temperature parameter:
|
||||
- gpt-4o-mini-search-preview
|
||||
- gpt-4o-search-preview
|
||||
- gpt-5-search-api
|
||||
|
||||
## Implementation
|
||||
|
||||
The `adjust_temperature_for_model()` function in `src/providers/openai.rs` automatically adjusts temperature to 1.0 for reasoning models while preserving user-specified values for standard models.
|
||||
|
||||
## Testing Methodology
|
||||
|
||||
Models were tested with:
|
||||
1. No temperature parameter (baseline)
|
||||
2. temperature=0.7 (common default)
|
||||
3. temperature=1.0 (reasoning model requirement)
|
||||
|
||||
Results were validated against actual OpenAI API responses.
|
||||
|
||||
## References
|
||||
|
||||
- OpenAI API Documentation: https://platform.openai.com/docs/api-reference/chat
|
||||
- Related Issue: Temperature errors with o1/o3/gpt-5 models
|
||||
@@ -22,6 +22,64 @@ For first-time installation, start from [one-click-bootstrap.md](../setup-guides
|
||||
| Foreground runtime | `zeroclaw daemon` | local debugging, short-lived sessions |
|
||||
| Foreground gateway only | `zeroclaw gateway` | webhook endpoint testing |
|
||||
| User service | `zeroclaw service install && zeroclaw service start` | persistent operator-managed runtime |
|
||||
| Docker / Podman | `docker compose up -d` | containerized deployment |
|
||||
|
||||
## Docker / Podman Runtime
|
||||
|
||||
If you installed via `./install.sh --docker`, the container exits after onboarding. To run
|
||||
ZeroClaw as a long-lived container, use the repository `docker-compose.yml` or start a
|
||||
container manually against the persisted data directory.
|
||||
|
||||
### Recommended: docker-compose
|
||||
|
||||
```bash
|
||||
# Start (detached, auto-restarts on reboot)
|
||||
docker compose up -d
|
||||
|
||||
# Stop
|
||||
docker compose down
|
||||
|
||||
# Restart
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
Replace `docker` with `podman` if using Podman.
|
||||
|
||||
### Manual container lifecycle
|
||||
|
||||
```bash
|
||||
# Start a new container from the bootstrap image
|
||||
docker run -d --name zeroclaw \
|
||||
--restart unless-stopped \
|
||||
-v "$PWD/.zeroclaw-docker/.zeroclaw:/zeroclaw-data/.zeroclaw" \
|
||||
-v "$PWD/.zeroclaw-docker/workspace:/zeroclaw-data/workspace" \
|
||||
-e HOME=/zeroclaw-data \
|
||||
-e ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace \
|
||||
-p 42617:42617 \
|
||||
zeroclaw-bootstrap:local \
|
||||
gateway
|
||||
|
||||
# Stop (preserves config and workspace)
|
||||
docker stop zeroclaw
|
||||
|
||||
# Restart a stopped container
|
||||
docker start zeroclaw
|
||||
|
||||
# View logs
|
||||
docker logs -f zeroclaw
|
||||
|
||||
# Health check
|
||||
docker exec zeroclaw zeroclaw status
|
||||
```
|
||||
|
||||
For Podman, add `--userns keep-id --user "$(id -u):$(id -g)"` and append `:Z` to volume mounts.
|
||||
|
||||
### Key detail: do not re-run install.sh to restart
|
||||
|
||||
Re-running `install.sh --docker` rebuilds the image and re-runs onboarding. To simply
|
||||
restart, use `docker start`, `docker compose up -d`, or `podman start`.
|
||||
|
||||
For full setup instructions, see [one-click-bootstrap.md](../setup-guides/one-click-bootstrap.md#stopping-and-restarting-a-dockerpodman-container).
|
||||
|
||||
## Baseline Operator Checklist
|
||||
|
||||
|
||||
@@ -98,6 +98,103 @@ If you add `--skip-build`, the installer skips local image build. It first tries
|
||||
Docker tag (`ZEROCLAW_DOCKER_IMAGE`, default: `zeroclaw-bootstrap:local`); if missing,
|
||||
it pulls `ghcr.io/zeroclaw-labs/zeroclaw:latest` and tags it locally before running.
|
||||
|
||||
### Stopping and restarting a Docker/Podman container
|
||||
|
||||
After `./install.sh --docker` finishes, the container exits. Your config and workspace
|
||||
are persisted in the data directory (default: `./.zeroclaw-docker`, or `~/.zeroclaw-docker`
|
||||
when bootstrapping via `curl | bash`). You can override this path with `ZEROCLAW_DOCKER_DATA_DIR`.
|
||||
|
||||
**Do not re-run `install.sh`** to restart -- it will rebuild the image and re-run onboarding.
|
||||
Instead, start a new container from the existing image and mount the persisted data directory.
|
||||
|
||||
#### Using the repository docker-compose.yml
|
||||
|
||||
The simplest way to run ZeroClaw long-term in Docker/Podman is with the provided
|
||||
`docker-compose.yml` at the repository root. It uses a named volume (`zeroclaw-data`)
|
||||
and sets `restart: unless-stopped` so the container survives reboots.
|
||||
|
||||
```bash
|
||||
# Start (detached)
|
||||
docker compose up -d
|
||||
|
||||
# Stop
|
||||
docker compose down
|
||||
|
||||
# Restart after stopping
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
Replace `docker` with `podman` if you use Podman.
|
||||
|
||||
#### Manual container run (using install.sh data directory)
|
||||
|
||||
If you installed via `./install.sh --docker` and want to reuse the `.zeroclaw-docker`
|
||||
data directory without compose:
|
||||
|
||||
```bash
|
||||
# Docker
|
||||
docker run -d --name zeroclaw \
|
||||
--restart unless-stopped \
|
||||
-v "$PWD/.zeroclaw-docker/.zeroclaw:/zeroclaw-data/.zeroclaw" \
|
||||
-v "$PWD/.zeroclaw-docker/workspace:/zeroclaw-data/workspace" \
|
||||
-e HOME=/zeroclaw-data \
|
||||
-e ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace \
|
||||
-p 42617:42617 \
|
||||
zeroclaw-bootstrap:local \
|
||||
gateway
|
||||
|
||||
# Podman (add --userns keep-id and :Z volume labels)
|
||||
podman run -d --name zeroclaw \
|
||||
--restart unless-stopped \
|
||||
--userns keep-id \
|
||||
--user "$(id -u):$(id -g)" \
|
||||
-v "$PWD/.zeroclaw-docker/.zeroclaw:/zeroclaw-data/.zeroclaw:Z" \
|
||||
-v "$PWD/.zeroclaw-docker/workspace:/zeroclaw-data/workspace:Z" \
|
||||
-e HOME=/zeroclaw-data \
|
||||
-e ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace \
|
||||
-p 42617:42617 \
|
||||
zeroclaw-bootstrap:local \
|
||||
gateway
|
||||
```
|
||||
|
||||
#### Common lifecycle commands
|
||||
|
||||
```bash
|
||||
# Stop the container (preserves data)
|
||||
docker stop zeroclaw
|
||||
|
||||
# Start a stopped container (config and workspace are intact)
|
||||
docker start zeroclaw
|
||||
|
||||
# View logs
|
||||
docker logs -f zeroclaw
|
||||
|
||||
# Remove the container (data in volumes/.zeroclaw-docker is preserved)
|
||||
docker rm zeroclaw
|
||||
|
||||
# Check health
|
||||
docker exec zeroclaw zeroclaw status
|
||||
```
|
||||
|
||||
#### Environment variables
|
||||
|
||||
When running manually, pass provider configuration as environment variables
|
||||
or ensure they are already saved in the persisted `config.toml`:
|
||||
|
||||
```bash
|
||||
docker run -d --name zeroclaw \
|
||||
-e API_KEY="sk-..." \
|
||||
-e PROVIDER="openrouter" \
|
||||
-v "$PWD/.zeroclaw-docker/.zeroclaw:/zeroclaw-data/.zeroclaw" \
|
||||
-v "$PWD/.zeroclaw-docker/workspace:/zeroclaw-data/workspace" \
|
||||
-p 42617:42617 \
|
||||
zeroclaw-bootstrap:local \
|
||||
gateway
|
||||
```
|
||||
|
||||
If you already ran `onboard` during the initial install, your API key and provider are
|
||||
saved in `.zeroclaw-docker/.zeroclaw/config.toml` and do not need to be passed again.
|
||||
|
||||
### Quick onboarding (non-interactive)
|
||||
|
||||
```bash
|
||||
|
||||
+1
-1
@@ -517,7 +517,7 @@ install_system_deps() {
|
||||
fi
|
||||
elif have_cmd apt-get; then
|
||||
run_privileged apt-get update -qq
|
||||
run_privileged apt-get install -y build-essential pkg-config git curl
|
||||
run_privileged apt-get install -y build-essential pkg-config git curl libssl-dev
|
||||
elif have_cmd dnf; then
|
||||
run_privileged dnf install -y \
|
||||
gcc \
|
||||
|
||||
+102
-2
@@ -33,11 +33,13 @@ pub struct Agent {
|
||||
skills: Vec<crate::skills::Skill>,
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode,
|
||||
auto_save: bool,
|
||||
memory_session_id: Option<String>,
|
||||
history: Vec<ConversationMessage>,
|
||||
classification_config: crate::config::QueryClassificationConfig,
|
||||
available_hints: Vec<String>,
|
||||
route_model_by_hint: HashMap<String, String>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
@@ -56,10 +58,12 @@ pub struct AgentBuilder {
|
||||
skills: Option<Vec<crate::skills::Skill>>,
|
||||
skills_prompt_mode: Option<crate::config::SkillsPromptInjectionMode>,
|
||||
auto_save: Option<bool>,
|
||||
memory_session_id: Option<String>,
|
||||
classification_config: Option<crate::config::QueryClassificationConfig>,
|
||||
available_hints: Option<Vec<String>>,
|
||||
route_model_by_hint: Option<HashMap<String, String>>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
@@ -80,10 +84,12 @@ impl AgentBuilder {
|
||||
skills: None,
|
||||
skills_prompt_mode: None,
|
||||
auto_save: None,
|
||||
memory_session_id: None,
|
||||
classification_config: None,
|
||||
available_hints: None,
|
||||
route_model_by_hint: None,
|
||||
allowed_tools: None,
|
||||
response_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,6 +171,11 @@ impl AgentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn memory_session_id(mut self, memory_session_id: Option<String>) -> Self {
|
||||
self.memory_session_id = memory_session_id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn classification_config(
|
||||
mut self,
|
||||
classification_config: crate::config::QueryClassificationConfig,
|
||||
@@ -188,6 +199,14 @@ impl AgentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_cache(
|
||||
mut self,
|
||||
cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
) -> Self {
|
||||
self.response_cache = cache;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let mut tools = self
|
||||
.tools
|
||||
@@ -231,11 +250,13 @@ impl AgentBuilder {
|
||||
skills: self.skills.unwrap_or_default(),
|
||||
skills_prompt_mode: self.skills_prompt_mode.unwrap_or_default(),
|
||||
auto_save: self.auto_save.unwrap_or(false),
|
||||
memory_session_id: self.memory_session_id,
|
||||
history: Vec::new(),
|
||||
classification_config: self.classification_config.unwrap_or_default(),
|
||||
available_hints: self.available_hints.unwrap_or_default(),
|
||||
route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
|
||||
allowed_tools: allowed,
|
||||
response_cache: self.response_cache,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -253,6 +274,10 @@ impl Agent {
|
||||
self.history.clear();
|
||||
}
|
||||
|
||||
pub fn set_memory_session_id(&mut self, session_id: Option<String>) {
|
||||
self.memory_session_id = session_id;
|
||||
}
|
||||
|
||||
pub fn from_config(config: &Config) -> Result<Self> {
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
@@ -330,11 +355,25 @@ impl Agent {
|
||||
.collect();
|
||||
let available_hints: Vec<String> = route_model_by_hint.keys().cloned().collect();
|
||||
|
||||
let response_cache = if config.memory.response_cache_enabled {
|
||||
crate::memory::response_cache::ResponseCache::with_hot_cache(
|
||||
&config.workspace_dir,
|
||||
config.memory.response_cache_ttl_minutes,
|
||||
config.memory.response_cache_max_entries,
|
||||
config.memory.response_cache_hot_entries,
|
||||
)
|
||||
.ok()
|
||||
.map(Arc::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(tools)
|
||||
.memory(memory)
|
||||
.observer(observer)
|
||||
.response_cache(response_cache)
|
||||
.tool_dispatcher(tool_dispatcher)
|
||||
.memory_loader(Box::new(DefaultMemoryLoader::new(
|
||||
5,
|
||||
@@ -489,13 +528,22 @@ impl Agent {
|
||||
if self.auto_save {
|
||||
let _ = self
|
||||
.memory
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
"user_msg",
|
||||
user_message,
|
||||
MemoryCategory::Conversation,
|
||||
self.memory_session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let context = self
|
||||
.memory_loader
|
||||
.load_context(self.memory.as_ref(), user_message)
|
||||
.load_context(
|
||||
self.memory.as_ref(),
|
||||
user_message,
|
||||
self.memory_session_id.as_deref(),
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
@@ -513,6 +561,47 @@ impl Agent {
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
|
||||
// Response cache: check before LLM call (only for deterministic, text-only prompts)
|
||||
let cache_key = if self.temperature == 0.0 {
|
||||
self.response_cache.as_ref().map(|_| {
|
||||
let last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.as_str())
|
||||
.unwrap_or("");
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
crate::memory::response_cache::ResponseCache::cache_key(
|
||||
&effective_model,
|
||||
system,
|
||||
last_user,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
if let Ok(Some(cached)) = cache.get(key) {
|
||||
self.observer.record_event(&ObserverEvent::CacheHit {
|
||||
cache_type: "response".into(),
|
||||
tokens_saved: 0,
|
||||
});
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
cached.clone(),
|
||||
)));
|
||||
self.trim_history();
|
||||
return Ok(cached);
|
||||
}
|
||||
self.observer.record_event(&ObserverEvent::CacheMiss {
|
||||
cache_type: "response".into(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.provider
|
||||
.chat(
|
||||
@@ -541,6 +630,17 @@ impl Agent {
|
||||
text
|
||||
};
|
||||
|
||||
// Store in response cache (text-only, no tool calls)
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
let token_count = response
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|u| u.output_tokens)
|
||||
.unwrap_or(0);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let _ = cache.put(key, &effective_model, &final_text, token_count as u32);
|
||||
}
|
||||
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
final_text.clone(),
|
||||
|
||||
+1
-12
@@ -128,7 +128,7 @@ impl ToolDispatcher for XmlToolDispatcher {
|
||||
ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
|
||||
}
|
||||
|
||||
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
|
||||
fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
|
||||
let mut instructions = String::new();
|
||||
instructions.push_str("## Tool Use Protocol\n\n");
|
||||
instructions
|
||||
@@ -136,17 +136,6 @@ impl ToolDispatcher for XmlToolDispatcher {
|
||||
instructions.push_str(
|
||||
"```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
|
||||
);
|
||||
instructions.push_str("### Available Tools\n\n");
|
||||
|
||||
for tool in tools {
|
||||
let _ = writeln!(
|
||||
instructions,
|
||||
"- **{}**: {}\n Parameters: `{}`",
|
||||
tool.name(),
|
||||
tool.description(),
|
||||
tool.parameters_schema()
|
||||
);
|
||||
}
|
||||
|
||||
instructions
|
||||
}
|
||||
|
||||
+125
-34
@@ -269,6 +269,15 @@ fn autosave_memory_key(prefix: &str) -> String {
|
||||
format!("{prefix}_{}", Uuid::new_v4())
|
||||
}
|
||||
|
||||
fn memory_session_id_from_state_file(path: &Path) -> Option<String> {
|
||||
let raw = path.to_string_lossy().trim().to_string();
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(format!("cli:{raw}"))
|
||||
}
|
||||
|
||||
/// Trim conversation history to prevent unbounded growth.
|
||||
/// Preserves the system prompt (first message if role=system) and the most recent messages.
|
||||
fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
|
||||
@@ -419,11 +428,16 @@ fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Res
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f64) -> String {
|
||||
async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
min_relevance_score: f64,
|
||||
session_id: Option<&str>,
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
@@ -438,6 +452,9 @@ async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f6
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
continue;
|
||||
}
|
||||
if memory::should_skip_autosave_content(&entry.content) {
|
||||
continue;
|
||||
}
|
||||
// Skip entries containing tool_result blocks — they can leak
|
||||
// stale tool output from previous heartbeat ticks into new
|
||||
// sessions, presenting the LLM with orphan tool_result data.
|
||||
@@ -1281,15 +1298,6 @@ fn parse_glm_style_tool_calls(text: &str) -> Vec<(String, serde_json::Value, Opt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plain URL
|
||||
if let Some(command) = build_curl_command(line) {
|
||||
calls.push((
|
||||
"shell".to_string(),
|
||||
serde_json::json!({ "command": command }),
|
||||
Some(line.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
calls
|
||||
@@ -2144,6 +2152,7 @@ pub(crate) async fn agent_turn(
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2152,6 +2161,7 @@ async fn execute_one_tool(
|
||||
call_name: &str,
|
||||
call_arguments: serde_json::Value,
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<ToolExecutionOutcome> {
|
||||
@@ -2162,7 +2172,13 @@ async fn execute_one_tool(
|
||||
});
|
||||
let start = Instant::now();
|
||||
|
||||
let Some(tool) = find_tool(tools_registry, call_name) else {
|
||||
let static_tool = find_tool(tools_registry, call_name);
|
||||
let activated_arc = if static_tool.is_none() {
|
||||
activated_tools.and_then(|at| at.lock().unwrap().get(call_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let Some(tool) = static_tool.or(activated_arc.as_deref()) else {
|
||||
let reason = format!("Unknown tool: {call_name}");
|
||||
let duration = start.elapsed();
|
||||
observer.record_event(&ObserverEvent::ToolCall {
|
||||
@@ -2260,6 +2276,7 @@ fn should_execute_tools_in_parallel(
|
||||
async fn execute_tools_parallel(
|
||||
tool_calls: &[ParsedToolCall],
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<Vec<ToolExecutionOutcome>> {
|
||||
@@ -2270,6 +2287,7 @@ async fn execute_tools_parallel(
|
||||
&call.name,
|
||||
call.arguments.clone(),
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token,
|
||||
)
|
||||
@@ -2283,6 +2301,7 @@ async fn execute_tools_parallel(
|
||||
async fn execute_tools_sequential(
|
||||
tool_calls: &[ParsedToolCall],
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<Vec<ToolExecutionOutcome>> {
|
||||
@@ -2294,6 +2313,7 @@ async fn execute_tools_sequential(
|
||||
&call.name,
|
||||
call.arguments.clone(),
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token,
|
||||
)
|
||||
@@ -2337,6 +2357,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2344,12 +2365,6 @@ pub(crate) async fn run_tool_call_loop(
|
||||
max_tool_iterations
|
||||
};
|
||||
|
||||
let tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
|
||||
.iter()
|
||||
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
|
||||
.map(|tool| tool.spec())
|
||||
.collect();
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
|
||||
@@ -2361,6 +2376,21 @@ pub(crate) async fn run_tool_call_loop(
|
||||
return Err(ToolLoopCancelled.into());
|
||||
}
|
||||
|
||||
// Rebuild tool_specs each iteration so newly activated deferred tools appear.
|
||||
let mut tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
|
||||
.iter()
|
||||
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
|
||||
.map(|tool| tool.spec())
|
||||
.collect();
|
||||
if let Some(at) = activated_tools {
|
||||
for spec in at.lock().unwrap().tool_specs() {
|
||||
if !excluded_tools.iter().any(|ex| ex == &spec.name) {
|
||||
tool_specs.push(spec);
|
||||
}
|
||||
}
|
||||
}
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
@@ -2839,6 +2869,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
execute_tools_parallel(
|
||||
&executable_calls,
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
@@ -2847,6 +2878,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
execute_tools_sequential(
|
||||
&executable_calls,
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
@@ -3098,6 +3130,9 @@ pub async fn run(
|
||||
// eagerly. Instead, a `tool_search` built-in is registered so the LLM can
|
||||
// fetch schemas on demand. This reduces context window waste.
|
||||
let mut deferred_section = String::new();
|
||||
let mut activated_handle: Option<
|
||||
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
|
||||
> = None;
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
@@ -3122,6 +3157,7 @@ pub async fn run(
|
||||
let activated = std::sync::Arc::new(std::sync::Mutex::new(
|
||||
crate::tools::ActivatedToolSet::new(),
|
||||
));
|
||||
activated_handle = Some(std::sync::Arc::clone(&activated));
|
||||
tools_registry.push(Box::new(crate::tools::ToolSearchTool::new(
|
||||
deferred_set,
|
||||
activated,
|
||||
@@ -3360,6 +3396,9 @@ pub async fn run(
|
||||
None
|
||||
};
|
||||
let channel_name = if interactive { "cli" } else { "daemon" };
|
||||
let memory_session_id = session_state_file
|
||||
.as_deref()
|
||||
.and_then(memory_session_id_from_state_file);
|
||||
|
||||
// ── Execute ──────────────────────────────────────────────────
|
||||
let start = Instant::now();
|
||||
@@ -3368,16 +3407,29 @@ pub async fn run(
|
||||
|
||||
if let Some(msg) = message {
|
||||
// Auto-save user message to memory (skip short/trivial messages)
|
||||
if config.memory.auto_save && msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
if config.memory.auto_save
|
||||
&& msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&msg)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&user_key,
|
||||
&msg,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&msg,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -3418,6 +3470,7 @@ pub async fn run(
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
@@ -3516,16 +3569,29 @@ pub async fn run(
|
||||
}
|
||||
|
||||
// Auto-save conversation turns (skip short/trivial messages)
|
||||
if config.memory.auto_save && user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
if config.memory.auto_save
|
||||
&& user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&user_input)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &user_input, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&user_key,
|
||||
&user_input,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -3566,6 +3632,7 @@ pub async fn run(
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3624,7 +3691,11 @@ pub async fn run(
|
||||
|
||||
/// Process a single message through the full agent (with tools, peripherals, memory).
|
||||
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
|
||||
pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
pub async fn process_message(
|
||||
config: Config,
|
||||
message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
@@ -3817,7 +3888,13 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
||||
}
|
||||
|
||||
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -3935,7 +4012,8 @@ mod tests {
|
||||
.expect("should produce a sample whose byte index 300 is not a char boundary");
|
||||
|
||||
let observer = NoopObserver;
|
||||
let result = execute_one_tool("unknown_tool", call_arguments, &[], &observer, None).await;
|
||||
let result =
|
||||
execute_one_tool("unknown_tool", call_arguments, &[], None, &observer, None).await;
|
||||
assert!(result.is_ok(), "execute_one_tool should not panic or error");
|
||||
|
||||
let outcome = result.unwrap();
|
||||
@@ -3977,6 +4055,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4271,6 +4350,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -4318,6 +4398,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -4359,6 +4440,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -4486,6 +4568,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -4556,6 +4639,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -4618,6 +4702,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
@@ -4695,6 +4780,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
@@ -4749,6 +4835,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -5489,7 +5576,7 @@ Tail"#;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_context(&mem, "status updates", 0.0).await;
|
||||
let context = build_context(&mem, "status updates", 0.0, None).await;
|
||||
assert!(context.contains("user_msg_real"));
|
||||
assert!(!context.contains("assistant_resp_poisoned"));
|
||||
assert!(!context.contains("fabricated event"));
|
||||
@@ -5915,12 +6002,15 @@ Final answer."#;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_glm_style_plain_url() {
|
||||
fn parse_glm_style_ignores_plain_url() {
|
||||
// A bare URL should NOT be interpreted as a tool call — this was
|
||||
// causing false positives when LLMs included URLs in normal text.
|
||||
let response = "https://example.com/api";
|
||||
let calls = parse_glm_style_tool_calls(response);
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].0, "shell");
|
||||
assert!(calls[0].1["command"].as_str().unwrap().contains("curl"));
|
||||
assert!(
|
||||
calls.is_empty(),
|
||||
"plain URL must not be parsed as tool call"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -6647,6 +6737,7 @@ Let me check the result."#;
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
@@ -4,8 +4,12 @@ use std::fmt::Write;
|
||||
|
||||
#[async_trait]
|
||||
pub trait MemoryLoader: Send + Sync {
|
||||
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
|
||||
-> anyhow::Result<String>;
|
||||
async fn load_context(
|
||||
&self,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String>;
|
||||
}
|
||||
|
||||
pub struct DefaultMemoryLoader {
|
||||
@@ -37,8 +41,9 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
&self,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||
let entries = memory.recall(user_message, self.limit, session_id).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
@@ -48,6 +53,9 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
continue;
|
||||
}
|
||||
if memory::should_skip_autosave_content(&entry.content) {
|
||||
continue;
|
||||
}
|
||||
if let Some(score) = entry.score {
|
||||
if score < self.min_relevance_score {
|
||||
continue;
|
||||
@@ -191,7 +199,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn default_loader_formats_context() {
|
||||
let loader = DefaultMemoryLoader::default();
|
||||
let context = loader.load_context(&MockMemory, "hello").await.unwrap();
|
||||
let context = loader
|
||||
.load_context(&MockMemory, "hello", None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(context.contains("[Memory context]"));
|
||||
assert!(context.contains("- k: v"));
|
||||
}
|
||||
@@ -222,7 +233,10 @@ mod tests {
|
||||
]),
|
||||
};
|
||||
|
||||
let context = loader.load_context(&memory, "answer style").await.unwrap();
|
||||
let context = loader
|
||||
.load_context(&memory, "answer style", None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(context.contains("user_fact"));
|
||||
assert!(!context.contains("assistant_resp_legacy"));
|
||||
assert!(!context.contains("fabricated detail"));
|
||||
|
||||
+6
-2
@@ -1282,8 +1282,12 @@ fn xml_dispatcher_generates_tool_instructions() {
|
||||
|
||||
assert!(instructions.contains("## Tool Use Protocol"));
|
||||
assert!(instructions.contains("<tool_call>"));
|
||||
assert!(instructions.contains("echo"));
|
||||
assert!(instructions.contains("Echoes the input"));
|
||||
// Tool listing is handled by ToolsSection in prompt.rs, not by the
|
||||
// dispatcher. prompt_instructions() must only emit the protocol envelope.
|
||||
assert!(
|
||||
!instructions.contains("echo"),
|
||||
"dispatcher should not duplicate tool listing"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+41
-1
@@ -711,8 +711,13 @@ impl Channel for DiscordChannel {
|
||||
}
|
||||
|
||||
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
// DMs carry no guild_id in the Discord gateway payload. They are
|
||||
// inherently private and implicitly addressed to the bot, so bypass
|
||||
// the mention gate — requiring a @mention in a DM is never correct.
|
||||
let is_dm = d.get("guild_id").is_none();
|
||||
let effective_mention_only = self.mention_only && !is_dm;
|
||||
let Some(clean_content) =
|
||||
normalize_incoming_content(content, self.mention_only, &bot_user_id)
|
||||
normalize_incoming_content(content, effective_mention_only, &bot_user_id)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
@@ -1027,6 +1032,41 @@ mod tests {
|
||||
assert!(cleaned.is_none());
|
||||
}
|
||||
|
||||
// mention_only DM-bypass tests
|
||||
|
||||
#[test]
|
||||
fn mention_only_dm_bypasses_mention_gate() {
|
||||
// DMs (no guild_id) must pass through even when mention_only is true
|
||||
// and the message contains no @mention. Mirrors the listen call-site logic.
|
||||
let mention_only = true;
|
||||
let is_dm = true;
|
||||
let effective = mention_only && !is_dm;
|
||||
let cleaned = normalize_incoming_content("hello without mention", effective, "12345");
|
||||
assert_eq!(cleaned.as_deref(), Some("hello without mention"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mention_only_guild_message_without_mention_is_rejected() {
|
||||
// Guild messages (has guild_id, so is_dm = false) must still be rejected
|
||||
// when mention_only is true and the message contains no @mention.
|
||||
let mention_only = true;
|
||||
let is_dm = false;
|
||||
let effective = mention_only && !is_dm;
|
||||
let cleaned = normalize_incoming_content("hello without mention", effective, "12345");
|
||||
assert!(cleaned.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mention_only_guild_message_with_mention_passes_and_strips() {
|
||||
// Guild messages that do carry a @mention pass through and have the
|
||||
// mention tag stripped, consistent with pre-existing behaviour.
|
||||
let mention_only = true;
|
||||
let is_dm = false;
|
||||
let effective = mention_only && !is_dm;
|
||||
let cleaned = normalize_incoming_content("<@12345> run status", effective, "12345");
|
||||
assert_eq!(cleaned.as_deref(), Some("run status"));
|
||||
}
|
||||
|
||||
// Message splitting tests
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Deduplication set capacity — evict half of entries when full.
|
||||
const DEDUP_CAPACITY: usize = 10_000;
|
||||
|
||||
/// Mochat customer service channel.
|
||||
///
|
||||
/// Integrates with the Mochat open-source customer service platform API
|
||||
/// for receiving and sending messages through its HTTP endpoints.
|
||||
pub struct MochatChannel {
|
||||
api_url: String,
|
||||
api_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
poll_interval_secs: u64,
|
||||
/// Message deduplication set.
|
||||
dedup: Arc<RwLock<HashSet<String>>>,
|
||||
}
|
||||
|
||||
impl MochatChannel {
|
||||
pub fn new(
|
||||
api_url: String,
|
||||
api_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
poll_interval_secs: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_url: api_url.trim_end_matches('/').to_string(),
|
||||
api_token,
|
||||
allowed_users,
|
||||
poll_interval_secs,
|
||||
dedup: Arc::new(RwLock::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.mochat")
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||
}
|
||||
|
||||
/// Check and insert message ID for deduplication.
|
||||
async fn is_duplicate(&self, msg_id: &str) -> bool {
|
||||
if msg_id.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut dedup = self.dedup.write().await;
|
||||
|
||||
if dedup.contains(msg_id) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if dedup.len() >= DEDUP_CAPACITY {
|
||||
let to_remove: Vec<String> = dedup.iter().take(DEDUP_CAPACITY / 2).cloned().collect();
|
||||
for key in to_remove {
|
||||
dedup.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
dedup.insert(msg_id.to_string());
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for MochatChannel {
|
||||
fn name(&self) -> &str {
|
||||
"mochat"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
let body = json!({
|
||||
"toUserId": message.recipient,
|
||||
"msgType": "text",
|
||||
"content": {
|
||||
"text": message.content,
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http_client()
|
||||
.post(format!("{}/api/message/send", self.api_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_token))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Mochat send message failed ({status}): {err}");
|
||||
}
|
||||
|
||||
let result: serde_json::Value = resp.json().await?;
|
||||
let code = result.get("code").and_then(|v| v.as_i64()).unwrap_or(-1);
|
||||
if code != 0 && code != 200 {
|
||||
let msg = result
|
||||
.get("msg")
|
||||
.or_else(|| result.get("message"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown error");
|
||||
anyhow::bail!("Mochat API error (code={code}): {msg}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
tracing::info!("Mochat: starting message poller");
|
||||
|
||||
let poll_interval = std::time::Duration::from_secs(self.poll_interval_secs);
|
||||
let mut last_message_id: Option<String> = None;
|
||||
|
||||
loop {
|
||||
let mut url = format!("{}/api/message/receive", self.api_url);
|
||||
if let Some(ref id) = last_message_id {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(url, "?since_id={id}");
|
||||
}
|
||||
|
||||
match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_token))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let data: serde_json::Value = match resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!("Mochat: failed to parse response: {e}");
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let messages = data
|
||||
.get("data")
|
||||
.or_else(|| data.get("messages"))
|
||||
.and_then(|d| d.as_array());
|
||||
|
||||
if let Some(messages) = messages {
|
||||
for msg in messages {
|
||||
let msg_id = msg
|
||||
.get("messageId")
|
||||
.or_else(|| msg.get("id"))
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if self.is_duplicate(msg_id).await {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sender = msg
|
||||
.get("fromUserId")
|
||||
.or_else(|| msg.get("sender"))
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
if !self.is_user_allowed(sender) {
|
||||
tracing::debug!(
|
||||
"Mochat: ignoring message from unauthorized user: {sender}"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = msg
|
||||
.get("content")
|
||||
.and_then(|c| {
|
||||
c.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.or_else(|| c.as_str())
|
||||
})
|
||||
.unwrap_or("")
|
||||
.trim();
|
||||
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: sender.to_string(),
|
||||
reply_target: sender.to_string(),
|
||||
content: content.to_string(),
|
||||
channel: "mochat".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
tracing::warn!("Mochat: message channel closed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !msg_id.is_empty() {
|
||||
last_message_id = Some(msg_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let err = resp.text().await.unwrap_or_default();
|
||||
tracing::warn!("Mochat: poll request failed ({status}): {err}");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Mochat: poll request error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
let resp = self
|
||||
.http_client()
|
||||
.get(format!("{}/api/health", self.api_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_token))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) => r.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_name() {
|
||||
let ch = MochatChannel::new("https://mochat.example.com".into(), "tok".into(), vec![], 5);
|
||||
assert_eq!(ch.name(), "mochat");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_url_trailing_slash_stripped() {
|
||||
let ch = MochatChannel::new(
|
||||
"https://mochat.example.com/".into(),
|
||||
"tok".into(),
|
||||
vec![],
|
||||
5,
|
||||
);
|
||||
assert_eq!(ch.api_url, "https://mochat.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_allowed_wildcard() {
|
||||
let ch = MochatChannel::new("https://m.test".into(), "tok".into(), vec!["*".into()], 5);
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_allowed_specific() {
|
||||
let ch = MochatChannel::new(
|
||||
"https://m.test".into(),
|
||||
"tok".into(),
|
||||
vec!["user123".into()],
|
||||
5,
|
||||
);
|
||||
assert!(ch.is_user_allowed("user123"));
|
||||
assert!(!ch.is_user_allowed("other"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_denied_empty() {
|
||||
let ch = MochatChannel::new("https://m.test".into(), "tok".into(), vec![], 5);
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dedup() {
|
||||
let ch = MochatChannel::new("https://m.test".into(), "tok".into(), vec![], 5);
|
||||
assert!(!ch.is_duplicate("msg1").await);
|
||||
assert!(ch.is_duplicate("msg1").await);
|
||||
assert!(!ch.is_duplicate("msg2").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dedup_empty_id() {
|
||||
let ch = MochatChannel::new("https://m.test".into(), "tok".into(), vec![], 5);
|
||||
assert!(!ch.is_duplicate("").await);
|
||||
assert!(!ch.is_duplicate("").await);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serde() {
|
||||
let toml_str = r#"
|
||||
api_url = "https://mochat.example.com"
|
||||
api_token = "secret"
|
||||
allowed_users = ["user1"]
|
||||
"#;
|
||||
let config: crate::config::schema::MochatConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.api_url, "https://mochat.example.com");
|
||||
assert_eq!(config.api_token, "secret");
|
||||
assert_eq!(config.allowed_users, vec!["user1"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serde_defaults() {
|
||||
let toml_str = r#"
|
||||
api_url = "https://mochat.example.com"
|
||||
api_token = "secret"
|
||||
"#;
|
||||
let config: crate::config::schema::MochatConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(config.allowed_users.is_empty());
|
||||
assert_eq!(config.poll_interval_secs, 5);
|
||||
}
|
||||
}
|
||||
+240
-15
@@ -27,11 +27,14 @@ pub mod linq;
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
pub mod matrix;
|
||||
pub mod mattermost;
|
||||
pub mod mochat;
|
||||
pub mod nextcloud_talk;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub mod nostr;
|
||||
pub mod notion;
|
||||
pub mod qq;
|
||||
pub mod session_backend;
|
||||
pub mod session_sqlite;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
@@ -39,6 +42,7 @@ pub mod telegram;
|
||||
pub mod traits;
|
||||
pub mod transcription;
|
||||
pub mod tts;
|
||||
pub mod twitter;
|
||||
pub mod wati;
|
||||
pub mod wecom;
|
||||
pub mod whatsapp;
|
||||
@@ -60,6 +64,7 @@ pub use linq::LinqChannel;
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
pub use matrix::MatrixChannel;
|
||||
pub use mattermost::MattermostChannel;
|
||||
pub use mochat::MochatChannel;
|
||||
pub use nextcloud_talk::NextcloudTalkChannel;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub use nostr::NostrChannel;
|
||||
@@ -71,6 +76,7 @@ pub use telegram::TelegramChannel;
|
||||
pub use traits::{Channel, SendMessage};
|
||||
#[allow(unused_imports)]
|
||||
pub use tts::{TtsManager, TtsProvider};
|
||||
pub use twitter::TwitterChannel;
|
||||
pub use wati::WatiChannel;
|
||||
pub use wecom::WeComChannel;
|
||||
pub use whatsapp::WhatsAppChannel;
|
||||
@@ -323,6 +329,7 @@ struct ChannelRuntimeContext {
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -839,9 +846,17 @@ async fn maybe_apply_runtime_config_update(ctx: &ChannelRuntimeContext) -> Resul
|
||||
let next_default_provider: Arc<dyn Provider> = Arc::from(next_default_provider);
|
||||
|
||||
if let Err(err) = next_default_provider.warmup().await {
|
||||
if crate::providers::reliable::is_non_retryable(&err) {
|
||||
tracing::warn!(
|
||||
provider = %next_defaults.default_provider,
|
||||
model = %next_defaults.model,
|
||||
"Rejecting config reload: model not available (non-retryable): {err}"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
tracing::warn!(
|
||||
provider = %next_defaults.default_provider,
|
||||
"Provider warmup failed after config reload: {err}"
|
||||
"Provider warmup failed after config reload (retryable, applying anyway): {err}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1018,6 +1033,15 @@ fn rollback_orphan_user_turn(
|
||||
if turns.is_empty() {
|
||||
histories.remove(sender_key);
|
||||
}
|
||||
|
||||
// Also remove the orphan turn from the persisted JSONL session store so
|
||||
// it doesn't resurface after a daemon restart (fixes #3674).
|
||||
if let Some(ref store) = ctx.session_store {
|
||||
if let Err(e) = store.remove_last(sender_key) {
|
||||
tracing::warn!("Failed to rollback session store entry: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
@@ -1026,6 +1050,10 @@ fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
if memory::should_skip_autosave_content(content) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if key.trim().to_ascii_lowercase().ends_with("_history") {
|
||||
return true;
|
||||
}
|
||||
@@ -1317,10 +1345,11 @@ async fn build_memory_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
min_relevance_score: f64,
|
||||
session_id: Option<&str>,
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let mut included = 0usize;
|
||||
let mut used_chars = 0usize;
|
||||
|
||||
@@ -1786,7 +1815,17 @@ async fn process_channel_message(
|
||||
msg
|
||||
};
|
||||
|
||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
let target_channel = ctx
|
||||
.channels_by_name
|
||||
.get(&msg.channel)
|
||||
.or_else(|| {
|
||||
// Multi-room channels use "name:qualifier" format (e.g. "matrix:!roomId");
|
||||
// fall back to base channel name for routing.
|
||||
msg.channel
|
||||
.split_once(':')
|
||||
.and_then(|(base, _)| ctx.channels_by_name.get(base))
|
||||
})
|
||||
.cloned();
|
||||
if let Err(err) = maybe_apply_runtime_config_update(ctx.as_ref()).await {
|
||||
tracing::warn!("Failed to apply runtime config update: {err}");
|
||||
}
|
||||
@@ -1840,7 +1879,10 @@ async fn process_channel_message(
|
||||
return;
|
||||
}
|
||||
};
|
||||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
if ctx.auto_save_memory
|
||||
&& msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&msg.content)
|
||||
{
|
||||
let autosave_key = conversation_memory_key(&msg);
|
||||
let _ = ctx
|
||||
.memory
|
||||
@@ -1848,7 +1890,7 @@ async fn process_channel_message(
|
||||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
None,
|
||||
Some(&history_key),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1885,6 +1927,29 @@ async fn process_channel_message(
|
||||
}
|
||||
}
|
||||
|
||||
// Strip [IMAGE:] markers from *older* history messages when the active
|
||||
// provider does not support vision. This prevents "history poisoning"
|
||||
// where a previously-sent image marker gets reloaded from the JSONL
|
||||
// session file and permanently breaks the conversation (fixes #3674).
|
||||
// We skip the last turn (the current message) so the vision check can
|
||||
// still reject fresh image sends with a proper error.
|
||||
if !active_provider.supports_vision() && prior_turns.len() > 1 {
|
||||
let last_idx = prior_turns.len() - 1;
|
||||
for turn in &mut prior_turns[..last_idx] {
|
||||
if turn.content.contains("[IMAGE:") {
|
||||
let (cleaned, _refs) = crate::multimodal::parse_image_markers(&turn.content);
|
||||
turn.content = cleaned;
|
||||
}
|
||||
}
|
||||
// Drop older turns that became empty after marker removal (e.g. image-only messages).
|
||||
// Keep the last turn (current message) intact.
|
||||
let current = prior_turns.pop();
|
||||
prior_turns.retain(|turn| !turn.content.trim().is_empty());
|
||||
if let Some(current) = current {
|
||||
prior_turns.push(current);
|
||||
}
|
||||
}
|
||||
|
||||
// Proactively trim conversation history before sending to the provider
|
||||
// to prevent context-window-exceeded errors (bug #3460).
|
||||
let dropped = proactive_trim_turns(&mut prior_turns, PROACTIVE_CONTEXT_BUDGET_CHARS);
|
||||
@@ -1901,8 +1966,13 @@ async fn process_channel_message(
|
||||
// Only enrich with memory context when there is no prior conversation
|
||||
// history. Follow-up turns already include context from previous messages.
|
||||
if !had_prior_history {
|
||||
let memory_context =
|
||||
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
|
||||
let memory_context = build_memory_context(
|
||||
ctx.memory.as_ref(),
|
||||
&msg.content,
|
||||
ctx.min_relevance_score,
|
||||
Some(&history_key),
|
||||
)
|
||||
.await;
|
||||
if let Some(last_turn) = prior_turns.last_mut() {
|
||||
if last_turn.role == "user" && !memory_context.is_empty() {
|
||||
last_turn.content = format!("{memory_context}{}", msg.content);
|
||||
@@ -2071,6 +2141,7 @@ async fn process_channel_message(
|
||||
ctx.non_cli_excluded_tools.as_ref()
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -3167,6 +3238,7 @@ fn collect_configured_channels(
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
),
|
||||
});
|
||||
@@ -3261,12 +3333,15 @@ fn collect_configured_channels(
|
||||
if wa.is_web_config() {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WhatsApp",
|
||||
channel: Arc::new(WhatsAppWebChannel::new(
|
||||
wa.session_path.clone().unwrap_or_default(),
|
||||
wa.pair_phone.clone(),
|
||||
wa.pair_code.clone(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
WhatsAppWebChannel::new(
|
||||
wa.session_path.clone().unwrap_or_default(),
|
||||
wa.pair_phone.clone(),
|
||||
wa.pair_code.clone(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
} else {
|
||||
tracing::warn!("WhatsApp Web configured but session_path not set");
|
||||
@@ -3404,6 +3479,28 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref tw) = config.channels_config.twitter {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "X/Twitter",
|
||||
channel: Arc::new(TwitterChannel::new(
|
||||
tw.bearer_token.clone(),
|
||||
tw.allowed_users.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref mc) = config.channels_config.mochat {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Mochat",
|
||||
channel: Arc::new(MochatChannel::new(
|
||||
mc.api_url.clone(),
|
||||
mc.api_token.clone(),
|
||||
mc.allowed_users.clone(),
|
||||
mc.poll_interval_secs,
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref wc) = config.channels_config.wecom {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WeCom",
|
||||
@@ -3605,6 +3702,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
|
||||
// Instead, a `tool_search` built-in is registered for on-demand loading.
|
||||
let mut deferred_section = String::new();
|
||||
let mut ch_activated_handle: Option<
|
||||
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
|
||||
> = None;
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
@@ -3628,6 +3728,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let activated = std::sync::Arc::new(std::sync::Mutex::new(
|
||||
crate::tools::ActivatedToolSet::new(),
|
||||
));
|
||||
ch_activated_handle = Some(std::sync::Arc::clone(&activated));
|
||||
built_tools.push(Box::new(crate::tools::ToolSearchTool::new(
|
||||
deferred_set,
|
||||
activated,
|
||||
@@ -3922,6 +4023,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
activated_tools: ch_activated_handle,
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4214,6 +4316,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4322,6 +4425,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4386,6 +4490,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4402,6 +4507,101 @@ mod tests {
|
||||
assert_eq!(turns[1].content, "ok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rollback_orphan_user_turn_also_removes_from_session_store() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let store = Arc::new(session_store::SessionStore::new(tmp.path()).unwrap());
|
||||
|
||||
let sender = "telegram_u4".to_string();
|
||||
|
||||
// Pre-populate the session store with the same turns.
|
||||
store.append(&sender, &ChatMessage::user("first")).unwrap();
|
||||
store
|
||||
.append(&sender, &ChatMessage::assistant("ok"))
|
||||
.unwrap();
|
||||
store
|
||||
.append(
|
||||
&sender,
|
||||
&ChatMessage::user("[IMAGE:/tmp/photo.jpg]\n\nDescribe this"),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut histories = HashMap::new();
|
||||
histories.insert(
|
||||
sender.clone(),
|
||||
vec![
|
||||
ChatMessage::user("first"),
|
||||
ChatMessage::assistant("ok"),
|
||||
ChatMessage::user("[IMAGE:/tmp/photo.jpg]\n\nDescribe this"),
|
||||
],
|
||||
);
|
||||
|
||||
let ctx = ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(HashMap::new()),
|
||||
provider: Arc::new(DummyProvider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("system".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: Some(Arc::clone(&store)),
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(
|
||||
&ctx,
|
||||
&sender,
|
||||
"[IMAGE:/tmp/photo.jpg]\n\nDescribe this"
|
||||
));
|
||||
|
||||
// In-memory history should have 2 turns remaining.
|
||||
let locked = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let turns = locked.get(&sender).expect("history should remain");
|
||||
assert_eq!(turns.len(), 2);
|
||||
|
||||
// Session store should also have only 2 entries.
|
||||
let persisted = store.load(&sender);
|
||||
assert_eq!(
|
||||
persisted.len(),
|
||||
2,
|
||||
"session store should also lose the rolled-back turn"
|
||||
);
|
||||
assert_eq!(persisted[0].content, "first");
|
||||
assert_eq!(persisted[1].content, "ok");
|
||||
}
|
||||
|
||||
struct DummyProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -4908,6 +5108,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4980,6 +5181,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5066,6 +5268,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5137,6 +5340,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5218,6 +5422,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5319,6 +5524,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5402,6 +5608,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5500,6 +5707,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5583,6 +5791,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5656,6 +5865,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5840,6 +6050,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -5932,6 +6143,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6038,6 +6250,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
});
|
||||
|
||||
@@ -6143,6 +6356,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6229,6 +6443,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6300,6 +6515,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6835,7 +7051,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_memory_context(&mem, "age", 0.0).await;
|
||||
let context = build_memory_context(&mem, "age", 0.0, None).await;
|
||||
assert!(context.contains("[Memory context]"));
|
||||
assert!(context.contains("Age is 45"));
|
||||
}
|
||||
@@ -6867,7 +7083,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_memory_context(&mem, "screenshot", 0.0).await;
|
||||
let context = build_memory_context(&mem, "screenshot", 0.0, None).await;
|
||||
|
||||
// The image-marker entry must be excluded to prevent duplication.
|
||||
assert!(
|
||||
@@ -6929,6 +7145,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7026,6 +7243,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7123,6 +7341,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7684,6 +7903,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7762,6 +7982,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7914,6 +8135,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8016,6 +8238,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8110,6 +8333,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8224,6 +8448,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
||||
+211
-32
@@ -62,24 +62,146 @@ impl NextcloudTalkChannel {
|
||||
|
||||
/// Parse a Nextcloud Talk webhook payload into channel messages.
|
||||
///
|
||||
/// Relevant payload fields:
|
||||
/// - `type` (accepts `message` or `Create`)
|
||||
/// - `object.token` (room token for reply routing)
|
||||
/// - `message.actorType`, `message.actorId`, `message.message`, `message.timestamp`
|
||||
/// Two payload formats are supported:
|
||||
///
|
||||
/// **Format A — legacy/custom** (`type: "message"`):
|
||||
/// ```json
|
||||
/// {
|
||||
/// "type": "message",
|
||||
/// "object": { "token": "<room>" },
|
||||
/// "message": { "actorId": "...", "message": "...", ... }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// **Format B — Activity Streams 2.0** (`type: "Create"`):
|
||||
/// This is the format actually sent by Nextcloud Talk bot webhooks.
|
||||
/// ```json
|
||||
/// {
|
||||
/// "type": "Create",
|
||||
/// "actor": { "type": "Person", "id": "users/alice", "name": "Alice" },
|
||||
/// "object": { "type": "Note", "id": "177", "content": "{\"message\":\"hi\",\"parameters\":[]}", "mediaType": "text/markdown" },
|
||||
/// "target": { "type": "Collection", "id": "<room_token>", "name": "Room Name" }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn parse_webhook_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let messages = Vec::new();
|
||||
|
||||
let event_type = match payload.get("type").and_then(|v| v.as_str()) {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
};
|
||||
|
||||
// Activity Streams 2.0 format sent by Nextcloud Talk bot webhooks.
|
||||
if event_type.eq_ignore_ascii_case("create") {
|
||||
return self.parse_as2_payload(payload);
|
||||
}
|
||||
|
||||
// Legacy/custom format.
|
||||
if !event_type.eq_ignore_ascii_case("message") {
|
||||
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
self.parse_message_payload(payload)
|
||||
}
|
||||
|
||||
/// Parse Activity Streams 2.0 `Create` payload (real Nextcloud Talk bot webhook format).
|
||||
fn parse_as2_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(event_type) = payload.get("type").and_then(|v| v.as_str()) {
|
||||
// Nextcloud Talk bot webhooks send "Create" for new chat messages,
|
||||
// but some setups may use "message". Accept both.
|
||||
let is_message_event = event_type.eq_ignore_ascii_case("message")
|
||||
|| event_type.eq_ignore_ascii_case("create");
|
||||
if !is_message_event {
|
||||
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
|
||||
return messages;
|
||||
}
|
||||
let obj = match payload.get("object") {
|
||||
Some(o) => o,
|
||||
None => return messages,
|
||||
};
|
||||
|
||||
// Only handle Note objects (= chat messages). Ignore reactions, etc.
|
||||
let object_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if !object_type.eq_ignore_ascii_case("note") {
|
||||
tracing::debug!("Nextcloud Talk: skipping AS2 Create with object.type={object_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Room token is in target.id.
|
||||
let room_token = payload
|
||||
.get("target")
|
||||
.and_then(|t| t.get("id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|t| !t.is_empty());
|
||||
|
||||
let Some(room_token) = room_token else {
|
||||
tracing::warn!("Nextcloud Talk: missing target.id (room token) in AS2 payload");
|
||||
return messages;
|
||||
};
|
||||
|
||||
// Actor — skip bot-originated messages to prevent feedback loops.
|
||||
let actor = payload.get("actor").cloned().unwrap_or_default();
|
||||
let actor_type = actor.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if actor_type.eq_ignore_ascii_case("application") {
|
||||
tracing::debug!("Nextcloud Talk: skipping bot-originated AS2 message");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// actor.id is "users/<id>" — strip the prefix.
|
||||
let actor_id = actor
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|id| id.trim_start_matches("users/").trim())
|
||||
.filter(|id| !id.is_empty());
|
||||
|
||||
let Some(actor_id) = actor_id else {
|
||||
tracing::warn!("Nextcloud Talk: missing actor.id in AS2 payload");
|
||||
return messages;
|
||||
};
|
||||
|
||||
if !self.is_user_allowed(actor_id) {
|
||||
tracing::warn!(
|
||||
"Nextcloud Talk: ignoring message from unauthorized actor: {actor_id}. \
|
||||
Add to channels.nextcloud_talk.allowed_users in config.toml, \
|
||||
or run `zeroclaw onboard --channels-only` to configure interactively."
|
||||
);
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Message text is JSON-encoded inside object.content.
|
||||
// e.g. content = "{\"message\":\"hello\",\"parameters\":[]}"
|
||||
let content = obj
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(|s| serde_json::from_str::<serde_json::Value>(s).ok())
|
||||
.and_then(|v| {
|
||||
v.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(str::trim)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
let Some(content) = content else {
|
||||
tracing::debug!("Nextcloud Talk: empty or unparseable AS2 message content");
|
||||
return messages;
|
||||
};
|
||||
|
||||
let message_id =
|
||||
Self::value_to_string(obj.get("id")).unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
|
||||
messages.push(ChannelMessage {
|
||||
id: message_id,
|
||||
reply_target: room_token.to_string(),
|
||||
sender: actor_id.to_string(),
|
||||
content,
|
||||
channel: "nextcloud_talk".to_string(),
|
||||
timestamp: Self::now_unix_secs(),
|
||||
thread_ts: None,
|
||||
});
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Parse legacy `type: "message"` payload format.
|
||||
fn parse_message_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
let Some(message_obj) = payload.get("message") else {
|
||||
return messages;
|
||||
};
|
||||
@@ -343,33 +465,90 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_create_event_type() {
|
||||
let channel = make_channel();
|
||||
fn nextcloud_talk_parse_as2_create_payload() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
// Real payload format sent by Nextcloud Talk bot webhooks.
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"object": {
|
||||
"id": "42",
|
||||
"token": "room-token-123",
|
||||
"name": "Team Room",
|
||||
"type": "room"
|
||||
"actor": {
|
||||
"type": "Person",
|
||||
"id": "users/user_a",
|
||||
"name": "User A",
|
||||
"talkParticipantType": "1"
|
||||
},
|
||||
"message": {
|
||||
"id": 88,
|
||||
"token": "room-token-123",
|
||||
"actorType": "users",
|
||||
"actorId": "user_a",
|
||||
"actorDisplayName": "User A",
|
||||
"timestamp": 1_735_701_300,
|
||||
"messageType": "comment",
|
||||
"systemMessage": "",
|
||||
"message": "Hello via Create event"
|
||||
"object": {
|
||||
"type": "Note",
|
||||
"id": "177",
|
||||
"name": "message",
|
||||
"content": "{\"message\":\"hallo, bist du da?\",\"parameters\":[]}",
|
||||
"mediaType": "text/markdown"
|
||||
},
|
||||
"target": {
|
||||
"type": "Collection",
|
||||
"id": "room-token-123",
|
||||
"name": "HOME"
|
||||
}
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].id, "88");
|
||||
assert_eq!(messages[0].content, "Hello via Create event");
|
||||
assert_eq!(messages[0].reply_target, "room-token-123");
|
||||
assert_eq!(messages[0].sender, "user_a");
|
||||
assert_eq!(messages[0].content, "hallo, bist du da?");
|
||||
assert_eq!(messages[0].channel, "nextcloud_talk");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_as2_skips_bot_originated() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"actor": {
|
||||
"type": "Application",
|
||||
"id": "bots/jarvis",
|
||||
"name": "jarvis"
|
||||
},
|
||||
"object": {
|
||||
"type": "Note",
|
||||
"id": "178",
|
||||
"content": "{\"message\":\"I am the bot\",\"parameters\":[]}",
|
||||
"mediaType": "text/markdown"
|
||||
},
|
||||
"target": {
|
||||
"type": "Collection",
|
||||
"id": "room-token-123",
|
||||
"name": "HOME"
|
||||
}
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_as2_skips_non_note_objects() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"actor": { "type": "Person", "id": "users/user_a" },
|
||||
"object": { "type": "Reaction", "id": "5" },
|
||||
"target": { "type": "Collection", "id": "room-token-123" }
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+39
-4
@@ -257,8 +257,10 @@ impl Channel for QQChannel {
|
||||
(
|
||||
format!("{QQ_API_BASE}/v2/groups/{group_id}/messages"),
|
||||
json!({
|
||||
"content": &message.content,
|
||||
"msg_type": 0,
|
||||
"markdown": {
|
||||
"content": &message.content,
|
||||
},
|
||||
"msg_type": 2,
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
@@ -273,8 +275,10 @@ impl Channel for QQChannel {
|
||||
(
|
||||
format!("{QQ_API_BASE}/v2/users/{user_id}/messages"),
|
||||
json!({
|
||||
"content": &message.content,
|
||||
"msg_type": 0,
|
||||
"markdown": {
|
||||
"content": &message.content,
|
||||
},
|
||||
"msg_type": 2,
|
||||
}),
|
||||
)
|
||||
};
|
||||
@@ -667,4 +671,35 @@ allowed_users = ["user1"]
|
||||
|
||||
assert_eq!(compose_message_content(&payload), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_body_uses_markdown_msg_type() {
|
||||
// Verify the expected JSON shape for both group and user send paths.
|
||||
// msg_type 2 with a nested markdown object is required by the QQ API
|
||||
// for markdown rendering; msg_type 0 (plain text) causes markdown
|
||||
// syntax to appear literally in the client.
|
||||
let content = "**bold** and `code`";
|
||||
|
||||
let group_body = json!({
|
||||
"markdown": { "content": content },
|
||||
"msg_type": 2,
|
||||
});
|
||||
assert_eq!(group_body["msg_type"], 2);
|
||||
assert_eq!(group_body["markdown"]["content"], content);
|
||||
assert!(
|
||||
group_body.get("content").is_none(),
|
||||
"top-level 'content' must not be present"
|
||||
);
|
||||
|
||||
let user_body = json!({
|
||||
"markdown": { "content": content },
|
||||
"msg_type": 2,
|
||||
});
|
||||
assert_eq!(user_body["msg_type"], 2);
|
||||
assert_eq!(user_body["markdown"]["content"], content);
|
||||
assert!(
|
||||
user_body.get("content").is_none(),
|
||||
"top-level 'content' must not be present"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
//! Trait abstraction for session persistence backends.
|
||||
//!
|
||||
//! Backends store per-sender conversation histories. The trait is intentionally
|
||||
//! minimal — load, append, remove_last, list — so that JSONL and SQLite (and
|
||||
//! future backends) share a common interface.
|
||||
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Metadata about a persisted session.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionMetadata {
|
||||
/// Session key (e.g. `telegram_user123`).
|
||||
pub key: String,
|
||||
/// When the session was first created.
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When the last message was appended.
|
||||
pub last_activity: DateTime<Utc>,
|
||||
/// Total number of messages in the session.
|
||||
pub message_count: usize,
|
||||
}
|
||||
|
||||
/// Query parameters for listing sessions.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SessionQuery {
|
||||
/// Keyword to search in session messages (FTS5 if available).
|
||||
pub keyword: Option<String>,
|
||||
/// Maximum number of sessions to return.
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
/// Trait for session persistence backends.
|
||||
///
|
||||
/// Implementations must be `Send + Sync` for sharing across async tasks.
|
||||
pub trait SessionBackend: Send + Sync {
|
||||
/// Load all messages for a session. Returns empty vec if session doesn't exist.
|
||||
fn load(&self, session_key: &str) -> Vec<ChatMessage>;
|
||||
|
||||
/// Append a single message to a session.
|
||||
fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()>;
|
||||
|
||||
/// Remove the last message from a session. Returns `true` if a message was removed.
|
||||
fn remove_last(&self, session_key: &str) -> std::io::Result<bool>;
|
||||
|
||||
/// List all session keys.
|
||||
fn list_sessions(&self) -> Vec<String>;
|
||||
|
||||
/// List sessions with metadata.
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
// Default: construct metadata from messages (backends can override for efficiency)
|
||||
self.list_sessions()
|
||||
.into_iter()
|
||||
.map(|key| {
|
||||
let messages = self.load(&key);
|
||||
SessionMetadata {
|
||||
key,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: messages.len(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compact a session file (remove duplicates/corruption). No-op by default.
|
||||
fn compact(&self, _session_key: &str) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove sessions that haven't been active within the given TTL hours.
|
||||
fn cleanup_stale(&self, _ttl_hours: u32) -> std::io::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Search sessions by keyword. Default returns empty (backends with FTS override).
|
||||
fn search(&self, _query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn session_metadata_is_constructible() {
|
||||
let meta = SessionMetadata {
|
||||
key: "test".into(),
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: 5,
|
||||
};
|
||||
assert_eq!(meta.key, "test");
|
||||
assert_eq!(meta.message_count, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_query_defaults() {
|
||||
let q = SessionQuery::default();
|
||||
assert!(q.keyword.is_none());
|
||||
assert!(q.limit.is_none());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,503 @@
|
||||
//! SQLite-backed session persistence with FTS5 search.
|
||||
//!
|
||||
//! Stores sessions in `{workspace}/sessions/sessions.db` using WAL mode.
|
||||
//! Provides full-text search via FTS5 and automatic TTL-based cleanup.
|
||||
//! Designed as the default backend, replacing JSONL for new installations.
|
||||
|
||||
use crate::channels::session_backend::{SessionBackend, SessionMetadata, SessionQuery};
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// SQLite-backed session store with FTS5 and WAL mode.
|
||||
pub struct SqliteSessionBackend {
|
||||
conn: Mutex<Connection>,
|
||||
#[allow(dead_code)]
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl SqliteSessionBackend {
    /// Open or create the sessions database.
    ///
    /// Creates `{workspace}/sessions/sessions.db`, applies connection
    /// pragmas, and ensures the schema exists: a message log table, a
    /// per-session metadata table, and an external-content FTS5 index kept
    /// in sync by INSERT/DELETE triggers.
    ///
    /// # Errors
    /// Fails if the sessions directory cannot be created, the database
    /// cannot be opened, or any schema statement fails.
    pub fn new(workspace_dir: &Path) -> Result<Self> {
        let sessions_dir = workspace_dir.join("sessions");
        std::fs::create_dir_all(&sessions_dir).context("Failed to create sessions directory")?;
        let db_path = sessions_dir.join("sessions.db");

        let conn = Connection::open(&db_path)
            .with_context(|| format!("Failed to open session DB: {}", db_path.display()))?;

        // WAL so readers don't block the writer; NORMAL sync is the usual
        // WAL durability trade-off; temp tables in memory; 4 MiB mmap window.
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
            PRAGMA synchronous = NORMAL;
            PRAGMA temp_store = MEMORY;
            PRAGMA mmap_size = 4194304;",
        )?;

        // Schema overview:
        // - `sessions`: append-only message log ordered by rowid `id`.
        // - `session_metadata`: one row per session key with timestamps and a count.
        // - `sessions_fts`: external-content FTS5 table over `sessions`;
        //   the AFTER INSERT / AFTER DELETE triggers implement the documented
        //   external-content sync idiom (the special 'delete' insert).
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_key TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                created_at TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_sessions_key ON sessions(session_key);
            CREATE INDEX IF NOT EXISTS idx_sessions_key_id ON sessions(session_key, id);

            CREATE TABLE IF NOT EXISTS session_metadata (
                session_key TEXT PRIMARY KEY,
                created_at TEXT NOT NULL,
                last_activity TEXT NOT NULL,
                message_count INTEGER NOT NULL DEFAULT 0
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
                session_key, content, content=sessions, content_rowid=id
            );

            CREATE TRIGGER IF NOT EXISTS sessions_ai AFTER INSERT ON sessions BEGIN
                INSERT INTO sessions_fts(rowid, session_key, content)
                VALUES (new.id, new.session_key, new.content);
            END;
            CREATE TRIGGER IF NOT EXISTS sessions_ad AFTER DELETE ON sessions BEGIN
                INSERT INTO sessions_fts(sessions_fts, rowid, session_key, content)
                VALUES ('delete', old.id, old.session_key, old.content);
            END;",
        )
        .context("Failed to initialize session schema")?;

        Ok(Self {
            conn: Mutex::new(conn),
            db_path,
        })
    }

    /// Migrate JSONL session files into SQLite. Renames migrated files to `.jsonl.migrated`.
    ///
    /// Best-effort by design: an unreadable directory, file, or line is
    /// skipped rather than aborting the whole migration. Returns the number
    /// of files that had at least one message imported (only those are renamed,
    /// so a re-run does not double-import them).
    ///
    /// NOTE(review): importing via `append` stamps every message with the
    /// current time, so original per-message timestamps are not preserved —
    /// confirm this is acceptable for migrated history.
    pub fn migrate_from_jsonl(&self, workspace_dir: &Path) -> Result<usize> {
        let sessions_dir = workspace_dir.join("sessions");
        // A missing sessions directory simply means nothing to migrate.
        let entries = match std::fs::read_dir(&sessions_dir) {
            Ok(e) => e,
            Err(_) => return Ok(0),
        };

        let mut migrated = 0;
        for entry in entries {
            let entry = match entry {
                Ok(e) => e,
                Err(_) => continue,
            };
            let name = match entry.file_name().into_string() {
                Ok(n) => n,
                Err(_) => continue,
            };
            // Only `{key}.jsonl` files are candidates; the stem is the session key.
            let Some(key) = name.strip_suffix(".jsonl") else {
                continue;
            };

            let path = entry.path();
            let file = match std::fs::File::open(&path) {
                Ok(f) => f,
                Err(_) => continue,
            };

            let reader = std::io::BufReader::new(file);
            let mut count = 0;
            for line in std::io::BufRead::lines(reader) {
                let Ok(line) = line else { continue };
                let trimmed = line.trim();
                if trimmed.is_empty() {
                    continue;
                }
                // Corrupt lines are skipped, mirroring the JSONL store's tolerance.
                if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
                    if self.append(key, &msg).is_ok() {
                        count += 1;
                    }
                }
            }

            if count > 0 {
                // Rename so a later migration pass ignores this file.
                let migrated_path = path.with_extension("jsonl.migrated");
                let _ = std::fs::rename(&path, &migrated_path);
                migrated += 1;
            }
        }

        Ok(migrated)
    }
}
|
||||
|
||||
impl SessionBackend for SqliteSessionBackend {
|
||||
fn load(&self, session_key: &str) -> Vec<ChatMessage> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn
|
||||
.prepare("SELECT role, content FROM sessions WHERE session_key = ?1 ORDER BY id ASC")
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let rows = match stmt.query_map(params![session_key], |row| {
|
||||
Ok(ChatMessage {
|
||||
role: row.get(0)?,
|
||||
content: row.get(1)?,
|
||||
})
|
||||
}) {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
rows.filter_map(|r| r.ok()).collect()
|
||||
}
|
||||
|
||||
fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
|
||||
let conn = self.conn.lock();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO sessions (session_key, role, content, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
params![session_key, message.role, message.content, now],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
// Upsert metadata
|
||||
conn.execute(
|
||||
"INSERT INTO session_metadata (session_key, created_at, last_activity, message_count)
|
||||
VALUES (?1, ?2, ?3, 1)
|
||||
ON CONFLICT(session_key) DO UPDATE SET
|
||||
last_activity = excluded.last_activity,
|
||||
message_count = message_count + 1",
|
||||
params![session_key, now, now],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let last_id: Option<i64> = conn
|
||||
.query_row(
|
||||
"SELECT id FROM sessions WHERE session_key = ?1 ORDER BY id DESC LIMIT 1",
|
||||
params![session_key],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.ok();
|
||||
|
||||
let Some(id) = last_id else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
// Update metadata count
|
||||
conn.execute(
|
||||
"UPDATE session_metadata SET message_count = MAX(0, message_count - 1)
|
||||
WHERE session_key = ?1",
|
||||
params![session_key],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn list_sessions(&self) -> Vec<String> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn
|
||||
.prepare("SELECT session_key FROM session_metadata ORDER BY last_activity DESC")
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let rows = match stmt.query_map([], |row| row.get(0)) {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
rows.filter_map(|r| r.ok()).collect()
|
||||
}
|
||||
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_key, created_at, last_activity, message_count
|
||||
FROM session_metadata ORDER BY last_activity DESC",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let rows = match stmt.query_map([], |row| {
|
||||
let key: String = row.get(0)?;
|
||||
let created_str: String = row.get(1)?;
|
||||
let activity_str: String = row.get(2)?;
|
||||
let count: i64 = row.get(3)?;
|
||||
|
||||
let created = DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now());
|
||||
let activity = DateTime::parse_from_rfc3339(&activity_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now());
|
||||
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(SessionMetadata {
|
||||
key,
|
||||
created_at: created,
|
||||
last_activity: activity,
|
||||
message_count: count as usize,
|
||||
})
|
||||
}) {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
rows.filter_map(|r| r.ok()).collect()
|
||||
}
|
||||
|
||||
fn cleanup_stale(&self, ttl_hours: u32) -> std::io::Result<usize> {
|
||||
let conn = self.conn.lock();
|
||||
let cutoff = (Utc::now() - Duration::hours(i64::from(ttl_hours))).to_rfc3339();
|
||||
|
||||
// Find stale sessions
|
||||
let stale_keys: Vec<String> = {
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT session_key FROM session_metadata WHERE last_activity < ?1")
|
||||
.map_err(std::io::Error::other)?;
|
||||
let rows = stmt
|
||||
.query_map(params![cutoff], |row| row.get(0))
|
||||
.map_err(std::io::Error::other)?;
|
||||
rows.filter_map(|r| r.ok()).collect()
|
||||
};
|
||||
|
||||
let count = stale_keys.len();
|
||||
for key in &stale_keys {
|
||||
let _ = conn.execute("DELETE FROM sessions WHERE session_key = ?1", params![key]);
|
||||
let _ = conn.execute(
|
||||
"DELETE FROM session_metadata WHERE session_key = ?1",
|
||||
params![key],
|
||||
);
|
||||
}
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
let Some(keyword) = &query.keyword else {
|
||||
return self.list_sessions_with_metadata();
|
||||
};
|
||||
|
||||
let conn = self.conn.lock();
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit = query.limit.unwrap_or(50) as i64;
|
||||
|
||||
// FTS5 search
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT DISTINCT f.session_key
|
||||
FROM sessions_fts f
|
||||
WHERE sessions_fts MATCH ?1
|
||||
LIMIT ?2",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
// Quote each word for FTS5
|
||||
let fts_query: String = keyword
|
||||
.split_whitespace()
|
||||
.map(|w| format!("\"{w}\""))
|
||||
.collect::<Vec<_>>()
|
||||
.join(" OR ");
|
||||
|
||||
let keys: Vec<String> = match stmt.query_map(params![fts_query, limit], |row| row.get(0)) {
|
||||
Ok(r) => r.filter_map(|r| r.ok()).collect(),
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
// Look up metadata for matched sessions
|
||||
keys.iter()
|
||||
.filter_map(|key| {
|
||||
conn.query_row(
|
||||
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
|
||||
params![key],
|
||||
|row| {
|
||||
let created_str: String = row.get(0)?;
|
||||
let activity_str: String = row.get(1)?;
|
||||
let count: i64 = row.get(2)?;
|
||||
Ok(SessionMetadata {
|
||||
key: key.clone(),
|
||||
created_at: DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
last_activity: DateTime::parse_from_rfc3339(&activity_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
message_count: count as usize,
|
||||
})
|
||||
},
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Messages appended under one key come back in insertion order with roles intact.
    #[test]
    fn round_trip_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend
            .append("user1", &ChatMessage::user("hello"))
            .unwrap();
        backend
            .append("user1", &ChatMessage::assistant("hi"))
            .unwrap();

        let msgs = backend.load("user1");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, "user");
        assert_eq!(msgs[1].role, "assistant");
    }

    // remove_last deletes only the newest row for the key.
    #[test]
    fn remove_last_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("u", &ChatMessage::user("a")).unwrap();
        backend.append("u", &ChatMessage::user("b")).unwrap();

        assert!(backend.remove_last("u").unwrap());
        let msgs = backend.load("u");
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "a");
    }

    // Removing from an unknown session is a no-op that reports false.
    #[test]
    fn remove_last_empty_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
        assert!(!backend.remove_last("nonexistent").unwrap());
    }

    // Each distinct key appears once in the session listing.
    #[test]
    fn list_sessions_sqlite() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("a", &ChatMessage::user("hi")).unwrap();
        backend.append("b", &ChatMessage::user("hey")).unwrap();

        let sessions = backend.list_sessions();
        assert_eq!(sessions.len(), 2);
    }

    // The metadata upsert keeps message_count in sync with appends.
    #[test]
    fn metadata_tracks_counts() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend.append("s1", &ChatMessage::user("a")).unwrap();
        backend.append("s1", &ChatMessage::user("b")).unwrap();
        backend.append("s1", &ChatMessage::user("c")).unwrap();

        let meta = backend.list_sessions_with_metadata();
        assert_eq!(meta.len(), 1);
        assert_eq!(meta[0].message_count, 3);
    }

    // FTS5 search matches message content and returns only matching sessions.
    #[test]
    fn fts5_search_finds_content() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        backend
            .append(
                "code_chat",
                &ChatMessage::user("How do I parse JSON in Rust?"),
            )
            .unwrap();
        backend
            .append("weather", &ChatMessage::user("What's the weather today?"))
            .unwrap();

        let results = backend.search(&SessionQuery {
            keyword: Some("Rust".into()),
            limit: Some(10),
        });
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "code_chat");
    }

    // Sessions idle past the TTL are deleted; fresh ones survive.
    #[test]
    fn cleanup_stale_removes_old_sessions() {
        let tmp = TempDir::new().unwrap();
        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

        // Insert a session with old timestamp directly, bypassing append()
        // (which would stamp it with the current time).
        {
            let conn = backend.conn.lock();
            let old_time = (Utc::now() - Duration::hours(100)).to_rfc3339();
            conn.execute(
                "INSERT INTO sessions (session_key, role, content, created_at) VALUES (?1, ?2, ?3, ?4)",
                params!["old_session", "user", "ancient", old_time],
            ).unwrap();
            conn.execute(
                "INSERT INTO session_metadata (session_key, created_at, last_activity, message_count) VALUES (?1, ?2, ?3, 1)",
                params!["old_session", old_time, old_time],
            ).unwrap();
        }

        backend
            .append("new_session", &ChatMessage::user("fresh"))
            .unwrap();

        let cleaned = backend.cleanup_stale(48).unwrap(); // 48h TTL
        assert_eq!(cleaned, 1);

        let sessions = backend.list_sessions();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0], "new_session");
    }

    // JSONL files are imported into SQLite and renamed so they are not re-imported.
    #[test]
    fn migrate_from_jsonl_imports_and_renames() {
        let tmp = TempDir::new().unwrap();
        let sessions_dir = tmp.path().join("sessions");
        std::fs::create_dir_all(&sessions_dir).unwrap();

        // Create a JSONL file
        let jsonl_path = sessions_dir.join("test_user.jsonl");
        std::fs::write(
            &jsonl_path,
            "{\"role\":\"user\",\"content\":\"hello\"}\n{\"role\":\"assistant\",\"content\":\"hi\"}\n",
        )
        .unwrap();

        let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
        let migrated = backend.migrate_from_jsonl(tmp.path()).unwrap();
        assert_eq!(migrated, 1);

        // JSONL should be renamed
        assert!(!jsonl_path.exists());
        assert!(sessions_dir.join("test_user.jsonl.migrated").exists());

        // Messages should be in SQLite
        let msgs = backend.load("test_user");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content, "hello");
    }
}
|
||||
@@ -5,6 +5,7 @@
|
||||
//! one-per-line as JSON, never modifying old lines. On daemon restart, sessions
|
||||
//! are loaded from disk to restore conversation context.
|
||||
|
||||
use crate::channels::session_backend::SessionBackend;
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -78,6 +79,37 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the last message from a session's JSONL file.
|
||||
///
|
||||
/// Rewrite approach: load all messages, drop the last, rewrite. This is
|
||||
/// O(n) but rollbacks are rare.
|
||||
pub fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
|
||||
let mut messages = self.load(session_key);
|
||||
if messages.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
messages.pop();
|
||||
self.rewrite(session_key, &messages)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Compact a session file by rewriting only valid messages (removes corrupt lines).
|
||||
pub fn compact(&self, session_key: &str) -> std::io::Result<()> {
|
||||
let messages = self.load(session_key);
|
||||
self.rewrite(session_key, &messages)
|
||||
}
|
||||
|
||||
fn rewrite(&self, session_key: &str, messages: &[ChatMessage]) -> std::io::Result<()> {
|
||||
let path = self.session_path(session_key);
|
||||
let mut file = std::fs::File::create(&path)?;
|
||||
for msg in messages {
|
||||
let json = serde_json::to_string(msg)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
writeln!(file, "{json}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session keys that have files on disk.
|
||||
pub fn list_sessions(&self) -> Vec<String> {
|
||||
let entries = match std::fs::read_dir(&self.sessions_dir) {
|
||||
@@ -95,6 +127,28 @@ impl SessionStore {
|
||||
}
|
||||
}
|
||||
|
||||
// Trait adapter: forwards every `SessionBackend` method to the inherent
// `SessionStore` implementation, so the JSONL store can sit behind the same
// `dyn SessionBackend` as other backends.
impl SessionBackend for SessionStore {
    fn load(&self, session_key: &str) -> Vec<ChatMessage> {
        self.load(session_key)
    }

    fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
        self.append(session_key, message)
    }

    fn remove_last(&self, session_key: &str) -> std::io::Result<bool> {
        self.remove_last(session_key)
    }

    fn list_sessions(&self) -> Vec<String> {
        self.list_sessions()
    }

    fn compact(&self, session_key: &str) -> std::io::Result<()> {
        self.compact(session_key)
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -178,6 +232,63 @@ mod tests {
|
||||
assert_eq!(lines.len(), 2);
|
||||
}
|
||||
|
||||
// remove_last should delete exactly the newest entry and keep earlier ones.
#[test]
fn remove_last_drops_final_message() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();

    for text in ["first", "second"] {
        store.append("rm_test", &ChatMessage::user(text)).unwrap();
    }

    assert!(store.remove_last("rm_test").unwrap());
    let remaining = store.load("rm_test");
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].content, "first");
}
|
||||
|
||||
// Removing from a session that does not exist must report false, not error.
#[test]
fn remove_last_empty_returns_false() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();
    let removed = store.remove_last("nonexistent").unwrap();
    assert!(!removed);
}
|
||||
|
||||
// compact() should silently drop non-JSON lines, leaving only parseable messages.
#[test]
fn compact_removes_corrupt_lines() {
    let tmp = TempDir::new().unwrap();
    let store = SessionStore::new(tmp.path()).unwrap();
    let key = "compact_test";

    // Hand-write a session file with one garbage line between two valid ones.
    let path = store.session_path(key);
    std::fs::create_dir_all(path.parent().unwrap()).unwrap();
    let mut file = std::fs::File::create(&path).unwrap();
    writeln!(file, r#"{{"role":"user","content":"ok"}}"#).unwrap();
    writeln!(file, "corrupt line").unwrap();
    writeln!(file, r#"{{"role":"assistant","content":"hi"}}"#).unwrap();

    store.compact(key).unwrap();

    // Only the two valid JSON lines survive compaction.
    let raw = std::fs::read_to_string(&path).unwrap();
    assert_eq!(raw.trim().lines().count(), 2);
}
|
||||
|
||||
// The store must be usable through a `dyn SessionBackend` trait object.
#[test]
fn session_backend_trait_works_via_dyn() {
    let dir = TempDir::new().unwrap();
    let store = SessionStore::new(dir.path()).unwrap();
    let backend: &dyn SessionBackend = &store;

    backend
        .append("trait_test", &ChatMessage::user("hello"))
        .unwrap();
    assert_eq!(backend.load("trait_test").len(), 1);
}
|
||||
|
||||
#[test]
|
||||
fn handles_corrupt_lines_gracefully() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@@ -334,6 +334,13 @@ pub struct TelegramChannel {
|
||||
workspace_dir: Option<std::path::PathBuf>,
|
||||
}
|
||||
|
||||
// Outcome of a Telegram `editMessageText` call, used to decide whether a
// retry or fallback is needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EditMessageResult {
    // Edit succeeded (2xx response).
    Success,
    // Telegram rejected the edit because the text was unchanged; treated as success by callers.
    NotModified,
    // Any other failure, carrying the HTTP status for logging.
    Failed(reqwest::StatusCode),
}
|
||||
|
||||
impl TelegramChannel {
|
||||
pub fn new(bot_token: String, allowed_users: Vec<String>, mention_only: bool) -> Self {
|
||||
let normalized_allowed = Self::normalize_allowed_users(allowed_users);
|
||||
@@ -540,6 +547,20 @@ impl TelegramChannel {
|
||||
format!("{}/bot{}/{method}", self.api_base, self.bot_token)
|
||||
}
|
||||
|
||||
async fn classify_edit_message_response(resp: reqwest::Response) -> EditMessageResult {
|
||||
if resp.status().is_success() {
|
||||
return EditMessageResult::Success;
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
if body.contains("message is not modified") {
|
||||
return EditMessageResult::NotModified;
|
||||
}
|
||||
|
||||
EditMessageResult::Failed(status)
|
||||
}
|
||||
|
||||
async fn fetch_bot_username(&self) -> anyhow::Result<String> {
|
||||
let resp = self.http_client().get(self.api_url("getMe")).send().await?;
|
||||
|
||||
@@ -2374,11 +2395,17 @@ impl Channel for TelegramChannel {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
match Self::classify_edit_message_response(resp).await {
|
||||
EditMessageResult::Success | EditMessageResult::NotModified => return Ok(()),
|
||||
EditMessageResult::Failed(status) => {
|
||||
tracing::debug!(
|
||||
status = ?status,
|
||||
"Telegram finalize_draft HTML edit failed; retrying without parse_mode"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Markdown failed — retry without parse_mode
|
||||
// HTML failed — retry without parse_mode
|
||||
let plain_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"message_id": id,
|
||||
@@ -2392,14 +2419,45 @@ impl Channel for TelegramChannel {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
match Self::classify_edit_message_response(resp).await {
|
||||
EditMessageResult::Success | EditMessageResult::NotModified => return Ok(()),
|
||||
EditMessageResult::Failed(status) => {
|
||||
tracing::warn!(
|
||||
status = ?status,
|
||||
"Telegram finalize_draft plain edit failed; attempting delete+send fallback"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Edit failed entirely — fall back to new message
|
||||
tracing::warn!("Telegram finalize_draft edit failed; falling back to sendMessage");
|
||||
self.send_text_chunks(text, &chat_id, thread_id.as_deref())
|
||||
.await
|
||||
let delete_resp = self
|
||||
.client
|
||||
.post(self.api_url("deleteMessage"))
|
||||
.json(&serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"message_id": id,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match delete_resp {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
self.send_text_chunks(text, &chat_id, thread_id.as_deref())
|
||||
.await
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::warn!(
|
||||
status = ?resp.status(),
|
||||
"Telegram finalize_draft delete failed; skipping sendMessage to avoid duplicate"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Telegram finalize_draft delete request failed: {err}; skipping sendMessage to avoid duplicate"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn cancel_draft(&self, recipient: &str, message_id: &str) -> anyhow::Result<()> {
|
||||
|
||||
@@ -78,6 +78,10 @@ pub async fn transcribe_audio(
|
||||
form = form.text("language", lang.clone());
|
||||
}
|
||||
|
||||
if let Some(ref prompt) = config.initial_prompt {
|
||||
form = form.text("prompt", prompt.clone());
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&config.api_url)
|
||||
.bearer_auth(&api_key)
|
||||
|
||||
@@ -0,0 +1,485 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
const TWITTER_API_BASE: &str = "https://api.x.com/2";
|
||||
|
||||
/// X/Twitter channel — uses the Twitter API v2 with OAuth 2.0 Bearer Token
/// for sending tweets/DMs and filtered stream for receiving mentions.
pub struct TwitterChannel {
    // OAuth 2.0 bearer token attached to every API v2 request.
    bearer_token: String,
    // Usernames or user IDs permitted to interact; "*" allows everyone.
    allowed_users: Vec<String>,
    /// Message deduplication set.
    dedup: Arc<RwLock<HashSet<String>>>,
}
|
||||
|
||||
/// Deduplication set capacity — evict half of entries when full.
|
||||
const DEDUP_CAPACITY: usize = 10_000;
|
||||
|
||||
impl TwitterChannel {
    /// Build a channel with an empty deduplication set.
    pub fn new(bearer_token: String, allowed_users: Vec<String>) -> Self {
        Self {
            bearer_token,
            allowed_users,
            dedup: Arc::new(RwLock::new(HashSet::new())),
        }
    }

    // Fresh client per call, configured with the runtime proxy settings.
    fn http_client(&self) -> reqwest::Client {
        crate::config::build_runtime_proxy_client("channel.twitter")
    }

    // "*" in the allow-list grants access to everyone; otherwise exact match.
    fn is_user_allowed(&self, user_id: &str) -> bool {
        self.allowed_users.iter().any(|u| u == "*" || u == user_id)
    }

    /// Check and insert tweet ID for deduplication.
    ///
    /// Returns true if `tweet_id` was seen before; otherwise records it and
    /// returns false. Empty IDs are never treated as duplicates.
    async fn is_duplicate(&self, tweet_id: &str) -> bool {
        if tweet_id.is_empty() {
            return false;
        }

        let mut dedup = self.dedup.write().await;

        if dedup.contains(tweet_id) {
            return true;
        }

        // Bound memory: when full, evict half the entries.
        // NOTE(review): HashSet iteration order is arbitrary, so eviction may
        // discard recent IDs and allow an occasional re-processed mention —
        // confirm this best-effort behavior is acceptable.
        if dedup.len() >= DEDUP_CAPACITY {
            let to_remove: Vec<String> = dedup.iter().take(DEDUP_CAPACITY / 2).cloned().collect();
            for key in to_remove {
                dedup.remove(&key);
            }
        }

        dedup.insert(tweet_id.to_string());
        false
    }

    /// Get the authenticated user's ID for filtered stream rules.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses (status and body included
    /// in the error), or a response missing `data.id`.
    async fn get_authenticated_user_id(&self) -> anyhow::Result<String> {
        let resp = self
            .http_client()
            .get(format!("{TWITTER_API_BASE}/users/me"))
            .bearer_auth(&self.bearer_token)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let err = resp.text().await.unwrap_or_default();
            anyhow::bail!("Twitter users/me failed ({status}): {err}");
        }

        let data: serde_json::Value = resp.json().await?;
        let user_id = data
            .get("data")
            .and_then(|d| d.get("id"))
            .and_then(|id| id.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing user id in Twitter response"))?
            .to_string();

        Ok(user_id)
    }

    /// Send a reply tweet.
    ///
    /// When `reply_tweet_id` is set the tweet is posted as a reply to it.
    /// Returns the new tweet's ID.
    /// NOTE(review): if the response lacks `data.id` this returns an empty
    /// string instead of an error — callers threading replies should verify.
    async fn create_tweet(
        &self,
        text: &str,
        reply_tweet_id: Option<&str>,
    ) -> anyhow::Result<String> {
        let mut body = json!({ "text": text });

        if let Some(reply_id) = reply_tweet_id {
            body["reply"] = json!({ "in_reply_to_tweet_id": reply_id });
        }

        let resp = self
            .http_client()
            .post(format!("{TWITTER_API_BASE}/tweets"))
            .bearer_auth(&self.bearer_token)
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let err = resp.text().await.unwrap_or_default();
            anyhow::bail!("Twitter create tweet failed ({status}): {err}");
        }

        let data: serde_json::Value = resp.json().await?;
        let tweet_id = data
            .get("data")
            .and_then(|d| d.get("id"))
            .and_then(|id| id.as_str())
            .unwrap_or("")
            .to_string();

        Ok(tweet_id)
    }

    /// Send a DM to a user.
    ///
    /// # Errors
    /// Fails on transport errors or non-2xx responses (status and body
    /// included in the error).
    async fn send_dm(&self, recipient_id: &str, text: &str) -> anyhow::Result<()> {
        let body = json!({
            "text": text,
        });

        let resp = self
            .http_client()
            .post(format!(
                "{TWITTER_API_BASE}/dm_conversations/with/{recipient_id}/messages"
            ))
            .bearer_auth(&self.bearer_token)
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let err = resp.text().await.unwrap_or_default();
            anyhow::bail!("Twitter DM send failed ({status}): {err}");
        }

        Ok(())
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for TwitterChannel {
|
||||
/// Stable channel identifier used for routing and logging.
fn name(&self) -> &str {
    "twitter"
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
// recipient format: "dm:{user_id}" for DMs, "tweet:{tweet_id}" for replies
|
||||
if let Some(user_id) = message.recipient.strip_prefix("dm:") {
|
||||
// Twitter API enforces a 280 char limit on tweets but DMs can be up to 10000.
|
||||
self.send_dm(user_id, &message.content).await
|
||||
} else if let Some(tweet_id) = message.recipient.strip_prefix("tweet:") {
|
||||
// Split long replies into tweet threads (280 char limit).
|
||||
let chunks = split_tweet_text(&message.content, 280);
|
||||
let mut reply_to = tweet_id.to_string();
|
||||
for chunk in chunks {
|
||||
reply_to = self.create_tweet(&chunk, Some(&reply_to)).await?;
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
// Default: treat as tweet reply
|
||||
let chunks = split_tweet_text(&message.content, 280);
|
||||
let mut reply_to = message.recipient.clone();
|
||||
for chunk in chunks {
|
||||
reply_to = self.create_tweet(&chunk, Some(&reply_to)).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
tracing::info!("Twitter: authenticating...");
|
||||
let bot_user_id = self.get_authenticated_user_id().await?;
|
||||
tracing::info!("Twitter: authenticated as user {bot_user_id}");
|
||||
|
||||
// Poll mentions timeline (filtered stream requires elevated access).
|
||||
// Using mentions timeline polling as a more accessible approach.
|
||||
let mut since_id: Option<String> = None;
|
||||
let poll_interval = std::time::Duration::from_secs(15);
|
||||
|
||||
loop {
|
||||
let mut url = format!(
|
||||
"{TWITTER_API_BASE}/users/{bot_user_id}/mentions?tweet.fields=author_id,conversation_id,created_at&expansions=author_id&max_results=20"
|
||||
);
|
||||
|
||||
if let Some(ref id) = since_id {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(url, "&since_id={id}");
|
||||
}
|
||||
|
||||
match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.bearer_auth(&self.bearer_token)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let data: serde_json::Value = match resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!("Twitter: failed to parse mentions response: {e}");
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(tweets) = data.get("data").and_then(|d| d.as_array()) {
|
||||
// Build user lookup map from includes
|
||||
let user_map: std::collections::HashMap<String, String> = data
|
||||
.get("includes")
|
||||
.and_then(|i| i.get("users"))
|
||||
.and_then(|u| u.as_array())
|
||||
.map(|users| {
|
||||
users
|
||||
.iter()
|
||||
.filter_map(|u| {
|
||||
let id = u.get("id")?.as_str()?.to_string();
|
||||
let username = u.get("username")?.as_str()?.to_string();
|
||||
Some((id, username))
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Process tweets in chronological order (oldest first)
|
||||
for tweet in tweets.iter().rev() {
|
||||
let tweet_id = tweet.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let author_id = tweet
|
||||
.get("author_id")
|
||||
.and_then(|a| a.as_str())
|
||||
.unwrap_or("");
|
||||
let text = tweet.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
|
||||
// Skip own tweets
|
||||
if author_id == bot_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.is_duplicate(tweet_id).await {
|
||||
continue;
|
||||
}
|
||||
|
||||
let username = user_map
|
||||
.get(author_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| author_id.to_string());
|
||||
|
||||
if !self.is_user_allowed(&username) && !self.is_user_allowed(author_id)
|
||||
{
|
||||
tracing::debug!(
|
||||
"Twitter: ignoring mention from unauthorized user: {username}"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Strip the @mention from the text
|
||||
let clean_text = strip_at_mention(text, &bot_user_id);
|
||||
|
||||
if clean_text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let reply_target = format!("tweet:{tweet_id}");
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: username,
|
||||
reply_target,
|
||||
content: clean_text,
|
||||
channel: "twitter".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
thread_ts: tweet
|
||||
.get("conversation_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
tracing::warn!("Twitter: message channel closed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Track newest ID for pagination
|
||||
if since_id.as_deref().map_or(true, |s| tweet_id > s) {
|
||||
since_id = Some(tweet_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update newest_id from meta
|
||||
if let Some(newest) = data
|
||||
.get("meta")
|
||||
.and_then(|m| m.get("newest_id"))
|
||||
.and_then(|n| n.as_str())
|
||||
{
|
||||
since_id = Some(newest.to_string());
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
if status.as_u16() == 429 {
|
||||
// Rate limited — back off
|
||||
tracing::warn!("Twitter: rate limited, backing off 60s");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
|
||||
continue;
|
||||
}
|
||||
let err = resp.text().await.unwrap_or_default();
|
||||
tracing::warn!("Twitter: mentions request failed ({status}): {err}");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Twitter: mentions request error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.get_authenticated_user_id().await.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove every leading `@mention` from a tweet's text.
///
/// Twitter prefixes reply text with `@bot_name` (and sometimes other
/// handles); this drops all of them until the first non-mention token.
/// The bot user id parameter is currently unused because *all* leading
/// mentions are stripped, not only the bot's own.
fn strip_at_mention(text: &str, _bot_user_id: &str) -> String {
    let mut remainder = text;
    loop {
        match remainder.strip_prefix('@') {
            // No more leading mentions: the rest is the message body.
            None => return remainder.to_string(),
            Some(after_at) => {
                // Skip the handle itself, i.e. everything up to the
                // next whitespace character.
                let Some(ws) = after_at.find(char::is_whitespace) else {
                    // The entire text was a single mention.
                    return String::new();
                };
                remainder = after_at[ws..].trim_start();
            }
        }
    }
}
|
||||
|
||||
/// Split text into tweet-sized chunks, breaking at word boundaries.
///
/// `max_len` is measured in bytes (matching the original behavior).
/// Each chunk is cut at the last space within the limit when one
/// exists; otherwise it is hard-cut at the largest UTF-8 character
/// boundary that fits, so multi-byte text can never cause a slice
/// panic (the previous version sliced at a raw byte index and would
/// panic mid-character).
fn split_tweet_text(text: &str, max_len: usize) -> Vec<String> {
    if text.len() <= max_len {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut remaining = text;

    while !remaining.is_empty() {
        if remaining.len() <= max_len {
            chunks.push(remaining.to_string());
            break;
        }

        // Largest char boundary not exceeding max_len; slicing at a raw
        // byte index inside a multi-byte character would panic.
        let mut boundary = max_len;
        while boundary > 0 && !remaining.is_char_boundary(boundary) {
            boundary -= 1;
        }
        if boundary == 0 {
            // max_len is smaller than the first character's encoding;
            // take that character whole rather than looping forever.
            boundary = remaining
                .chars()
                .next()
                .map_or(remaining.len(), char::len_utf8);
        }

        // Prefer the last space within the limit; otherwise hard-cut.
        let split_at = remaining[..boundary].rfind(' ').unwrap_or(boundary);

        chunks.push(remaining[..split_at].to_string());
        remaining = remaining[split_at..].trim_start();
    }

    chunks
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `name()` reports the static channel identifier.
    #[test]
    fn test_name() {
        let ch = TwitterChannel::new("token".into(), vec![]);
        assert_eq!(ch.name(), "twitter");
    }

    /// A `"*"` entry in the allowlist admits any user.
    #[test]
    fn test_user_allowed_wildcard() {
        let ch = TwitterChannel::new("token".into(), vec!["*".into()]);
        assert!(ch.is_user_allowed("anyone"));
    }

    /// Explicit allowlist entries admit only the listed user.
    #[test]
    fn test_user_allowed_specific() {
        let ch = TwitterChannel::new("token".into(), vec!["user123".into()]);
        assert!(ch.is_user_allowed("user123"));
        assert!(!ch.is_user_allowed("other"));
    }

    /// An empty allowlist denies everyone (fail-closed).
    #[test]
    fn test_user_denied_empty() {
        let ch = TwitterChannel::new("token".into(), vec![]);
        assert!(!ch.is_user_allowed("anyone"));
    }

    /// First sighting of a tweet id is not a duplicate; the second is.
    /// A different id is independent of earlier ones.
    #[tokio::test]
    async fn test_dedup() {
        let ch = TwitterChannel::new("token".into(), vec![]);
        assert!(!ch.is_duplicate("tweet1").await);
        assert!(ch.is_duplicate("tweet1").await);
        assert!(!ch.is_duplicate("tweet2").await);
    }

    /// Empty ids are never tracked, so they never count as duplicates.
    #[tokio::test]
    async fn test_dedup_empty_id() {
        let ch = TwitterChannel::new("token".into(), vec![]);
        assert!(!ch.is_duplicate("").await);
        assert!(!ch.is_duplicate("").await);
    }

    /// A single leading mention is removed from the text.
    #[test]
    fn test_strip_at_mention_single() {
        assert_eq!(strip_at_mention("@bot hello world", "123"), "hello world");
    }

    /// Every leading mention is removed, not only the first one.
    #[test]
    fn test_strip_at_mention_multiple() {
        assert_eq!(strip_at_mention("@bot @other hello", "123"), "hello");
    }

    /// A mention-only tweet strips down to the empty string.
    #[test]
    fn test_strip_at_mention_only() {
        assert_eq!(strip_at_mention("@bot", "123"), "");
    }

    /// Text without a leading mention passes through unchanged.
    #[test]
    fn test_strip_at_mention_no_mention() {
        assert_eq!(strip_at_mention("hello world", "123"), "hello world");
    }

    /// Short text fits into a single chunk.
    #[test]
    fn test_split_tweet_text_short() {
        let chunks = split_tweet_text("hello", 280);
        assert_eq!(chunks, vec!["hello"]);
    }

    /// Long text is split into multiple chunks, each within the limit.
    #[test]
    fn test_split_tweet_text_long() {
        let text = "a ".repeat(200);
        let chunks = split_tweet_text(text.trim(), 280);
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.len() <= 280);
        }
    }

    /// With no spaces available, text is hard-cut at the byte limit.
    #[test]
    fn test_split_tweet_text_no_spaces() {
        let text = "a".repeat(300);
        let chunks = split_tweet_text(&text, 280);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 280);
    }

    /// Deserializes a full `TwitterConfig` TOML section.
    #[test]
    fn test_config_serde() {
        let toml_str = r#"
            bearer_token = "AAAA"
            allowed_users = ["user1"]
        "#;
        let config: crate::config::schema::TwitterConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.bearer_token, "AAAA");
        assert_eq!(config.allowed_users, vec!["user1"]);
    }

    /// An omitted `allowed_users` field defaults to an empty list.
    #[test]
    fn test_config_serde_defaults() {
        let toml_str = r#"
            bearer_token = "tok"
        "#;
        let config: crate::config::schema::TwitterConfig = toml::from_str(toml_str).unwrap();
        assert!(config.allowed_users.is_empty());
    }
}
|
||||
@@ -64,6 +64,8 @@ pub struct WhatsAppWebChannel {
|
||||
client: Arc<Mutex<Option<Arc<wa_rs::Client>>>>,
|
||||
/// Message sender channel
|
||||
tx: Arc<Mutex<Option<tokio::sync::mpsc::Sender<ChannelMessage>>>>,
|
||||
/// Voice transcription configuration
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
impl WhatsAppWebChannel {
|
||||
@@ -90,9 +92,19 @@ impl WhatsAppWebChannel {
|
||||
bot_handle: Arc::new(Mutex::new(None)),
|
||||
client: Arc::new(Mutex::new(None)),
|
||||
tx: Arc::new(Mutex::new(None)),
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure voice transcription.
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if a phone number is allowed (E.164 format: +1234567890)
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
fn is_number_allowed(&self, phone: &str) -> bool {
|
||||
@@ -380,17 +392,19 @@ impl Channel for WhatsAppWebChannel {
|
||||
let logout_tx_clone = logout_tx.clone();
|
||||
let retry_count_clone = retry_count.clone();
|
||||
let session_revoked_clone = session_revoked.clone();
|
||||
let transcription_config = self.transcription.clone();
|
||||
|
||||
let mut builder = Bot::builder()
|
||||
.with_backend(backend)
|
||||
.with_transport_factory(transport_factory)
|
||||
.with_http_client(http_client)
|
||||
.on_event(move |event, _client| {
|
||||
.on_event(move |event, client| {
|
||||
let tx_inner = tx_clone.clone();
|
||||
let allowed_numbers = allowed_numbers.clone();
|
||||
let logout_tx = logout_tx_clone.clone();
|
||||
let retry_count = retry_count_clone.clone();
|
||||
let session_revoked = session_revoked_clone.clone();
|
||||
let transcription_config = transcription_config.clone();
|
||||
async move {
|
||||
match event {
|
||||
Event::Message(msg, info) => {
|
||||
@@ -413,7 +427,7 @@ impl Channel for WhatsAppWebChannel {
|
||||
);
|
||||
|
||||
let mapped_phone = if sender_jid.is_lid() {
|
||||
_client.get_phone_number_from_lid(&sender_jid.user).await
|
||||
client.get_phone_number_from_lid(&sender_jid.user).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -430,14 +444,65 @@ impl Channel for WhatsAppWebChannel {
|
||||
})
|
||||
.cloned()
|
||||
{
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
tracing::debug!(
|
||||
"WhatsApp Web: ignoring empty or non-text message from {}",
|
||||
normalized
|
||||
let content = if !text.trim().is_empty() {
|
||||
text.trim().to_string()
|
||||
} else if let Some(ref audio) = msg.get_base_message().audio_message {
|
||||
let duration = audio.seconds.unwrap_or(0);
|
||||
tracing::info!(
|
||||
"WhatsApp Web audio from {} ({}s, ptt={})",
|
||||
normalized, duration, audio.ptt.unwrap_or(false)
|
||||
);
|
||||
|
||||
let config = match transcription_config.as_ref() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::debug!("WhatsApp Web: transcription disabled, ignoring audio");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if u64::from(duration) > config.max_duration_secs {
|
||||
tracing::info!(
|
||||
"WhatsApp Web: skipping audio ({}s > {}s limit)",
|
||||
duration, config.max_duration_secs
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let audio_data = match client.download(audio.as_ref()).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!("WhatsApp Web: failed to download audio: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let file_name = match audio.mimetype.as_deref() {
|
||||
Some(m) if m.contains("ogg") => "voice.ogg",
|
||||
Some(m) if m.contains("opus") => "voice.opus",
|
||||
Some(m) if m.contains("mp4") || m.contains("m4a") => "voice.m4a",
|
||||
Some(m) if m.contains("webm") => "voice.webm",
|
||||
_ => "voice.ogg",
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, file_name, config).await {
|
||||
Ok(t) if !t.trim().is_empty() => {
|
||||
tracing::info!("WhatsApp Web: transcribed audio from {}: {}", normalized, t.trim());
|
||||
t.trim().to_string()
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::info!("WhatsApp Web: transcription returned empty text");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("WhatsApp Web: transcription failed: {e}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("WhatsApp Web: ignoring non-text/non-audio message from {}", normalized);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = tx_inner
|
||||
.send(ChannelMessage {
|
||||
@@ -446,7 +511,7 @@ impl Channel for WhatsAppWebChannel {
|
||||
sender: normalized.clone(),
|
||||
// Reply to the originating chat JID (DM or group).
|
||||
reply_target: chat,
|
||||
content: trimmed.to_string(),
|
||||
content,
|
||||
timestamp: chrono::Utc::now().timestamp() as u64,
|
||||
thread_ts: None,
|
||||
})
|
||||
@@ -695,6 +760,10 @@ impl WhatsAppWebChannel {
|
||||
) -> Self {
|
||||
Self { _private: () }
|
||||
}
|
||||
|
||||
pub fn with_transcription(self, _config: crate::config::TranscriptionConfig) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
@@ -936,6 +1005,24 @@ mod tests {
|
||||
assert!(WhatsAppWebChannel::should_purge_session(&flag));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
fn with_transcription_sets_config_when_enabled() {
|
||||
let mut tc = crate::config::TranscriptionConfig::default();
|
||||
tc.enabled = true;
|
||||
|
||||
let ch = make_channel().with_transcription(tc);
|
||||
assert!(ch.transcription.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
fn with_transcription_ignores_when_disabled() {
|
||||
let tc = crate::config::TranscriptionConfig::default(); // enabled = false
|
||||
let ch = make_channel().with_transcription(tc);
|
||||
assert!(ch.transcription.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
fn session_file_paths_includes_wal_and_shm() {
|
||||
|
||||
+7
-6
@@ -11,12 +11,13 @@ pub use schema::{
|
||||
ComposioConfig, Config, ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig,
|
||||
PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
GoogleWorkspaceConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig,
|
||||
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
|
||||
ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
|
||||
|
||||
+299
-3
@@ -220,6 +220,30 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub browser: BrowserConfig,
|
||||
|
||||
/// Browser delegation configuration (`[browser_delegate]`).
|
||||
///
|
||||
/// Delegates browser-based tasks to a browser-capable CLI subprocess (e.g.
|
||||
/// Claude Code with `claude-in-chrome` MCP tools). Useful for interacting
|
||||
/// with corporate web apps (Teams, Outlook, Jira, Confluence) that lack
|
||||
/// direct API access. A persistent Chrome profile can be configured so SSO
|
||||
/// sessions survive across invocations.
|
||||
///
|
||||
/// Fields:
|
||||
/// - `enabled` (`bool`, default `false`) — enable the browser delegation tool.
|
||||
/// - `cli_binary` (`String`, default `"claude"`) — CLI binary to spawn for browser tasks.
|
||||
/// - `chrome_profile_dir` (`String`, default `""`) — Chrome user-data directory for
|
||||
/// persistent SSO sessions. When empty, a fresh profile is used each invocation.
|
||||
/// - `allowed_domains` (`Vec<String>`, default `[]`) — allowlist of domains the browser
|
||||
/// may navigate to. Empty means all non-blocked domains are permitted.
|
||||
/// - `blocked_domains` (`Vec<String>`, default `[]`) — denylist of domains. Blocked
|
||||
/// domains take precedence over allowed domains.
|
||||
/// - `task_timeout_secs` (`u64`, default `120`) — per-task timeout in seconds.
|
||||
///
|
||||
/// Compatibility: additive and disabled by default; existing configs remain valid when omitted.
|
||||
/// Rollback/migration: remove `[browser_delegate]` or keep `enabled = false` to disable.
|
||||
#[serde(default)]
|
||||
pub browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig,
|
||||
|
||||
/// HTTP request tool configuration (`[http_request]`).
|
||||
#[serde(default)]
|
||||
pub http_request: HttpRequestConfig,
|
||||
@@ -236,6 +260,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub web_search: WebSearchConfig,
|
||||
|
||||
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]`).
|
||||
#[serde(default)]
|
||||
pub google_workspace: GoogleWorkspaceConfig,
|
||||
|
||||
/// Project delivery intelligence configuration (`[project_intel]`).
|
||||
#[serde(default)]
|
||||
pub project_intel: ProjectIntelConfig,
|
||||
@@ -595,6 +623,11 @@ pub struct TranscriptionConfig {
|
||||
/// Optional language hint (ISO-639-1, e.g. "en", "ru").
|
||||
#[serde(default)]
|
||||
pub language: Option<String>,
|
||||
/// Optional initial prompt to bias transcription toward expected vocabulary
|
||||
/// (proper nouns, technical terms, etc.). Sent as the `prompt` field in the
|
||||
/// Whisper API request.
|
||||
#[serde(default)]
|
||||
pub initial_prompt: Option<String>,
|
||||
/// Maximum voice duration in seconds (messages longer than this are skipped).
|
||||
#[serde(default = "default_transcription_max_duration_secs")]
|
||||
pub max_duration_secs: u64,
|
||||
@@ -607,6 +640,7 @@ impl Default for TranscriptionConfig {
|
||||
api_url: default_transcription_api_url(),
|
||||
model: default_transcription_model(),
|
||||
language: None,
|
||||
initial_prompt: None,
|
||||
max_duration_secs: default_transcription_max_duration_secs(),
|
||||
}
|
||||
}
|
||||
@@ -1809,6 +1843,94 @@ impl Default for WebSearchConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Google Workspace ─────────────────────────────────────────────
|
||||
|
||||
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]` section).
///
/// ## Defaults
/// - `enabled`: `false` (tool is not registered unless explicitly opted-in).
/// - `allowed_services`: empty vector, which grants access to the full default
///   service set: `drive`, `sheets`, `gmail`, `calendar`, `docs`, `slides`,
///   `tasks`, `people`, `chat`, `classroom`, `forms`, `keep`, `meet`, `events`.
/// - `credentials_path`: `None` (uses default `gws` credential discovery).
/// - `default_account`: `None` (uses the `gws` active account).
/// - `rate_limit_per_minute`: `60`.
/// - `timeout_secs`: `30`.
/// - `audit_log`: `false`.
///
/// ## Compatibility
/// Configs that omit the `[google_workspace]` section entirely are treated as
/// `GoogleWorkspaceConfig::default()` (disabled, all defaults allowed). Adding
/// the section is purely opt-in and does not affect other config sections.
///
/// ## Rollback / Migration
/// To revert, remove the `[google_workspace]` section from the config file (or
/// set `enabled = false`). No data migration is required; the tool simply stops
/// being registered.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GoogleWorkspaceConfig {
    /// Enable the `google_workspace` tool. Default: `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Restrict which Google Workspace services the agent can access.
    ///
    /// When empty (the default), the full default service set is allowed (see
    /// struct-level docs). When non-empty, only the listed service IDs are
    /// permitted. Each entry must be non-empty, lowercase alphanumeric with
    /// optional underscores/hyphens, and unique.
    #[serde(default)]
    pub allowed_services: Vec<String>,
    /// Path to service account JSON or OAuth client credentials file.
    ///
    /// When `None`, the tool relies on the default `gws` credential discovery
    /// (`gws auth login`). Set this to point at a service-account key or an
    /// OAuth client-secrets JSON for headless / CI environments.
    #[serde(default)]
    pub credentials_path: Option<String>,
    /// Default Google account email to pass to `gws --account`.
    ///
    /// When `None`, the currently active `gws` account is used.
    #[serde(default)]
    pub default_account: Option<String>,
    /// Maximum number of `gws` API calls allowed per minute. Default: `60`.
    #[serde(default = "default_gws_rate_limit")]
    pub rate_limit_per_minute: u32,
    /// Command execution timeout in seconds. Default: `30`.
    #[serde(default = "default_gws_timeout_secs")]
    pub timeout_secs: u64,
    /// Enable audit logging of every `gws` invocation (service, resource,
    /// method, timestamp). Default: `false`.
    #[serde(default)]
    pub audit_log: bool,
}
|
||||
|
||||
fn default_gws_rate_limit() -> u32 {
|
||||
60
|
||||
}
|
||||
|
||||
fn default_gws_timeout_secs() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
impl Default for GoogleWorkspaceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_services: Vec::new(),
|
||||
credentials_path: None,
|
||||
default_account: None,
|
||||
rate_limit_per_minute: default_gws_rate_limit(),
|
||||
timeout_secs: default_gws_timeout_secs(),
|
||||
audit_log: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Project Intelligence ────────────────────────────────────────
|
||||
|
||||
/// Project delivery intelligence configuration (`[project_intel]` section).
|
||||
@@ -2318,10 +2440,10 @@ fn validate_proxy_url(field: &str, url: &str) -> Result<()> {
|
||||
.with_context(|| format!("Invalid {field} URL: '{url}' is not a valid URL"))?;
|
||||
|
||||
match parsed.scheme() {
|
||||
"http" | "https" | "socks5" | "socks5h" => {}
|
||||
"http" | "https" | "socks5" | "socks5h" | "socks" => {}
|
||||
scheme => {
|
||||
anyhow::bail!(
|
||||
"Invalid {field} URL scheme '{scheme}'. Allowed: http, https, socks5, socks5h"
|
||||
"Invalid {field} URL scheme '{scheme}'. Allowed: http, https, socks5, socks5h, socks"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -2650,6 +2772,9 @@ pub struct MemoryConfig {
|
||||
/// Max number of cached responses before LRU eviction (default: 5000)
|
||||
#[serde(default = "default_response_cache_max")]
|
||||
pub response_cache_max_entries: usize,
|
||||
/// Max in-memory hot cache entries for the two-tier response cache (default: 256)
|
||||
#[serde(default = "default_response_cache_hot_entries")]
|
||||
pub response_cache_hot_entries: usize,
|
||||
|
||||
// ── Memory Snapshot (soul backup to Markdown) ─────────────
|
||||
/// Enable periodic export of core memories to MEMORY_SNAPSHOT.md
|
||||
@@ -2718,6 +2843,10 @@ fn default_response_cache_max() -> usize {
|
||||
5_000
|
||||
}
|
||||
|
||||
fn default_response_cache_hot_entries() -> usize {
|
||||
256
|
||||
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -2738,6 +2867,7 @@ impl Default for MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: default_response_cache_ttl(),
|
||||
response_cache_max_entries: default_response_cache_max(),
|
||||
response_cache_hot_entries: default_response_cache_hot_entries(),
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
@@ -3344,12 +3474,48 @@ pub struct HeartbeatConfig {
|
||||
/// explicitly set).
|
||||
#[serde(default, alias = "recipient")]
|
||||
pub to: Option<String>,
|
||||
/// Enable adaptive intervals that back off on failures and speed up for
|
||||
/// high-priority tasks. Default: `false`.
|
||||
#[serde(default)]
|
||||
pub adaptive: bool,
|
||||
/// Minimum interval in minutes when adaptive mode is enabled. Default: `5`.
|
||||
#[serde(default = "default_heartbeat_min_interval")]
|
||||
pub min_interval_minutes: u32,
|
||||
/// Maximum interval in minutes when adaptive mode backs off. Default: `120`.
|
||||
#[serde(default = "default_heartbeat_max_interval")]
|
||||
pub max_interval_minutes: u32,
|
||||
/// Dead-man's switch timeout in minutes. If the heartbeat has not ticked
|
||||
/// within this window, an alert is sent. `0` disables. Default: `0`.
|
||||
#[serde(default)]
|
||||
pub deadman_timeout_minutes: u32,
|
||||
/// Channel for dead-man's switch alerts (e.g. `telegram`). Falls back to
|
||||
/// the heartbeat delivery channel.
|
||||
#[serde(default)]
|
||||
pub deadman_channel: Option<String>,
|
||||
/// Recipient for dead-man's switch alerts. Falls back to `to`.
|
||||
#[serde(default)]
|
||||
pub deadman_to: Option<String>,
|
||||
/// Maximum number of heartbeat run history records to retain. Default: `100`.
|
||||
#[serde(default = "default_heartbeat_max_run_history")]
|
||||
pub max_run_history: u32,
|
||||
}
|
||||
|
||||
fn default_two_phase() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_heartbeat_min_interval() -> u32 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_heartbeat_max_interval() -> u32 {
|
||||
120
|
||||
}
|
||||
|
||||
fn default_heartbeat_max_run_history() -> u32 {
|
||||
100
|
||||
}
|
||||
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -3359,6 +3525,13 @@ impl Default for HeartbeatConfig {
|
||||
message: None,
|
||||
target: None,
|
||||
to: None,
|
||||
adaptive: false,
|
||||
min_interval_minutes: default_heartbeat_min_interval(),
|
||||
max_interval_minutes: default_heartbeat_max_interval(),
|
||||
deadman_timeout_minutes: 0,
|
||||
deadman_channel: None,
|
||||
deadman_to: None,
|
||||
max_run_history: default_heartbeat_max_run_history(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3524,6 +3697,7 @@ impl<T: ChannelConfig> crate::config::traits::ConfigHandle for ConfigWrapper<T>
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ChannelsConfig {
|
||||
/// Enable the CLI interactive channel. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub cli: bool,
|
||||
/// Telegram bot channel configuration.
|
||||
pub telegram: Option<TelegramConfig>,
|
||||
@@ -3563,6 +3737,10 @@ pub struct ChannelsConfig {
|
||||
pub wecom: Option<WeComConfig>,
|
||||
/// QQ Official Bot channel configuration.
|
||||
pub qq: Option<QQConfig>,
|
||||
/// X/Twitter channel configuration.
|
||||
pub twitter: Option<TwitterConfig>,
|
||||
/// Mochat customer service channel configuration.
|
||||
pub mochat: Option<MochatConfig>,
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub nostr: Option<NostrConfig>,
|
||||
/// ClawdTalk voice channel configuration.
|
||||
@@ -3587,6 +3765,13 @@ pub struct ChannelsConfig {
|
||||
/// daemon restarts. Files are stored in `{workspace}/sessions/`. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub session_persistence: bool,
|
||||
/// Session persistence backend: `"jsonl"` (legacy) or `"sqlite"` (new default).
|
||||
/// SQLite provides FTS5 search, metadata tracking, and TTL cleanup.
|
||||
#[serde(default = "default_session_backend")]
|
||||
pub session_backend: String,
|
||||
/// Auto-archive stale sessions older than this many hours. `0` disables. Default: `0`.
|
||||
#[serde(default)]
|
||||
pub session_ttl_hours: u32,
|
||||
}
|
||||
|
||||
impl ChannelsConfig {
|
||||
@@ -3692,6 +3877,10 @@ fn default_channel_message_timeout_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
fn default_session_backend() -> String {
|
||||
"sqlite".into()
|
||||
}
|
||||
|
||||
impl Default for ChannelsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -3715,6 +3904,8 @@ impl Default for ChannelsConfig {
|
||||
dingtalk: None,
|
||||
wecom: None,
|
||||
qq: None,
|
||||
twitter: None,
|
||||
mochat: None,
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
nostr: None,
|
||||
clawdtalk: None,
|
||||
@@ -3722,6 +3913,8 @@ impl Default for ChannelsConfig {
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3819,6 +4012,10 @@ pub struct SlackConfig {
|
||||
/// cancels the in-flight request and starts a fresh response with preserved history.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
/// When true, only respond to messages that @-mention the bot in groups.
|
||||
/// Direct messages remain allowed.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SlackConfig {
|
||||
@@ -4733,6 +4930,53 @@ impl ChannelConfig for QQConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// X/Twitter channel configuration (Twitter API v2).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TwitterConfig {
    /// Twitter API v2 Bearer Token (OAuth 2.0).
    pub bearer_token: String,
    /// Allowed usernames or numeric user IDs. An empty list denies all
    /// users (fail-closed); a single `"*"` entry allows everyone.
    #[serde(default)]
    pub allowed_users: Vec<String>,
}

impl ChannelConfig for TwitterConfig {
    /// Human-readable channel name.
    fn name() -> &'static str {
        "X/Twitter"
    }
    /// Short channel description.
    fn desc() -> &'static str {
        "X/Twitter Bot via API v2"
    }
}
|
||||
|
||||
/// Mochat channel configuration (Mochat customer service API).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MochatConfig {
    /// Mochat API base URL.
    pub api_url: String,
    /// Mochat API token used to authenticate requests.
    pub api_token: String,
    /// Allowed user IDs. An empty list denies all users; a single `"*"`
    /// entry allows everyone.
    #[serde(default)]
    pub allowed_users: Vec<String>,
    /// Poll interval in seconds for new messages. Default: `5`.
    #[serde(default = "default_mochat_poll_interval")]
    pub poll_interval_secs: u64,
}

/// Default Mochat polling interval (seconds).
fn default_mochat_poll_interval() -> u64 {
    5
}

impl ChannelConfig for MochatConfig {
    /// Human-readable channel name.
    fn name() -> &'static str {
        "Mochat"
    }
    /// Short channel description.
    fn desc() -> &'static str {
        "Mochat Customer Service"
    }
}
|
||||
|
||||
/// Nostr channel configuration (NIP-04 + NIP-17 private messages)
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -5106,10 +5350,12 @@ impl Default for Config {
|
||||
microsoft365: Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig::default(),
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
@@ -6268,6 +6514,28 @@ impl Config {
|
||||
validate_mcp_config(&self.mcp)?;
|
||||
}
|
||||
|
||||
// Google Workspace allowed_services validation
|
||||
let mut seen_gws_services = std::collections::HashSet::new();
|
||||
for (i, service) in self.google_workspace.allowed_services.iter().enumerate() {
|
||||
let normalized = service.trim();
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!("google_workspace.allowed_services[{i}] must not be empty");
|
||||
}
|
||||
if !normalized
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_' || c == '-')
|
||||
{
|
||||
anyhow::bail!(
|
||||
"google_workspace.allowed_services[{i}] contains invalid characters: {normalized}"
|
||||
);
|
||||
}
|
||||
if !seen_gws_services.insert(normalized.to_string()) {
|
||||
anyhow::bail!(
|
||||
"google_workspace.allowed_services contains duplicate entry: {normalized}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Project intelligence
|
||||
if self.project_intel.enabled {
|
||||
let lang = &self.project_intel.default_language;
|
||||
@@ -6320,7 +6588,6 @@ impl Config {
|
||||
if let Err(msg) = self.security.nevis.validate() {
|
||||
anyhow::bail!("security.nevis: {msg}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -7358,6 +7625,7 @@ default_temperature = 0.7
|
||||
message: Some("Check London time".into()),
|
||||
target: Some("telegram".into()),
|
||||
to: Some("123456".into()),
|
||||
..HeartbeatConfig::default()
|
||||
},
|
||||
cron: CronConfig::default(),
|
||||
channels_config: ChannelsConfig {
|
||||
@@ -7388,6 +7656,8 @@ default_temperature = 0.7
|
||||
dingtalk: None,
|
||||
wecom: None,
|
||||
qq: None,
|
||||
twitter: None,
|
||||
mochat: None,
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
nostr: None,
|
||||
clawdtalk: None,
|
||||
@@ -7395,6 +7665,8 @@ default_temperature = 0.7
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
},
|
||||
memory: MemoryConfig::default(),
|
||||
storage: StorageConfig::default(),
|
||||
@@ -7404,10 +7676,12 @@ default_temperature = 0.7
|
||||
microsoft365: Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig::default(),
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
@@ -7706,10 +7980,12 @@ tool_dispatcher = "xml"
|
||||
microsoft365: Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig::default(),
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
@@ -8121,12 +8397,16 @@ allowed_users = ["@ops:matrix.org"]
|
||||
dingtalk: None,
|
||||
wecom: None,
|
||||
qq: None,
|
||||
twitter: None,
|
||||
mochat: None,
|
||||
nostr: None,
|
||||
clawdtalk: None,
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -8166,6 +8446,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8174,6 +8455,15 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(parsed.allowed_users, vec!["U111"]);
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn slack_config_deserializes_with_mention_only() {
|
||||
let json = r#"{"bot_token":"xoxb-tok","mention_only":true}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.mention_only);
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8181,6 +8471,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let json = r#"{"bot_token":"xoxb-tok","interrupt_on_new_message":true}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8203,6 +8494,7 @@ channel_id = "C123"
|
||||
let parsed: SlackConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
assert_eq!(parsed.channel_id.as_deref(), Some("C123"));
|
||||
}
|
||||
|
||||
@@ -8349,12 +8641,16 @@ channel_id = "C123"
|
||||
dingtalk: None,
|
||||
wecom: None,
|
||||
qq: None,
|
||||
twitter: None,
|
||||
mochat: None,
|
||||
nostr: None,
|
||||
clawdtalk: None,
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
session_backend: default_session_backend(),
|
||||
session_ttl_hours: 0,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
|
||||
+168
-16
@@ -8,7 +8,8 @@ use tokio::time::Duration;
|
||||
|
||||
const STATUS_FLUSH_SECONDS: u64 = 5;
|
||||
|
||||
/// Wait for shutdown signal (SIGINT or SIGTERM)
|
||||
/// Wait for shutdown signal (SIGINT or SIGTERM).
|
||||
/// SIGHUP is explicitly ignored so the daemon survives terminal/SSH disconnects.
|
||||
async fn wait_for_shutdown_signal() -> Result<()> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
@@ -16,13 +17,21 @@ async fn wait_for_shutdown_signal() -> Result<()> {
|
||||
|
||||
let mut sigint = signal(SignalKind::interrupt())?;
|
||||
let mut sigterm = signal(SignalKind::terminate())?;
|
||||
let mut sighup = signal(SignalKind::hangup())?;
|
||||
|
||||
tokio::select! {
|
||||
_ = sigint.recv() => {
|
||||
tracing::info!("Received SIGINT, shutting down...");
|
||||
}
|
||||
_ = sigterm.recv() => {
|
||||
tracing::info!("Received SIGTERM, shutting down...");
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = sigint.recv() => {
|
||||
tracing::info!("Received SIGINT, shutting down...");
|
||||
break;
|
||||
}
|
||||
_ = sigterm.recv() => {
|
||||
tracing::info!("Received SIGTERM, shutting down...");
|
||||
break;
|
||||
}
|
||||
_ = sighup.recv() => {
|
||||
tracing::info!("Received SIGHUP, ignoring (daemon stays running)");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,7 +212,10 @@ where
|
||||
}
|
||||
|
||||
async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
use crate::heartbeat::engine::HeartbeatEngine;
|
||||
use crate::heartbeat::engine::{
|
||||
compute_adaptive_interval, HeartbeatEngine, HeartbeatTask, TaskPriority, TaskStatus,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
let observer: std::sync::Arc<dyn crate::observability::Observer> =
|
||||
std::sync::Arc::from(crate::observability::create_observer(&config.observability));
|
||||
@@ -212,19 +224,72 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
config.workspace_dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let metrics = engine.metrics();
|
||||
let delivery = resolve_heartbeat_delivery(&config)?;
|
||||
let two_phase = config.heartbeat.two_phase;
|
||||
let adaptive = config.heartbeat.adaptive;
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let interval_mins = config.heartbeat.interval_minutes.max(5);
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
|
||||
// ── Deadman watcher ──────────────────────────────────────────
|
||||
let deadman_timeout = config.heartbeat.deadman_timeout_minutes;
|
||||
if deadman_timeout > 0 {
|
||||
let dm_metrics = Arc::clone(&metrics);
|
||||
let dm_config = config.clone();
|
||||
let dm_delivery = delivery.clone();
|
||||
tokio::spawn(async move {
|
||||
let check_interval = Duration::from_secs(60);
|
||||
let timeout = chrono::Duration::minutes(i64::from(deadman_timeout));
|
||||
loop {
|
||||
tokio::time::sleep(check_interval).await;
|
||||
let last_tick = dm_metrics.lock().last_tick_at;
|
||||
if let Some(last) = last_tick {
|
||||
if chrono::Utc::now() - last > timeout {
|
||||
let alert = format!(
|
||||
"⚠️ Heartbeat dead-man's switch: no tick in {deadman_timeout} minutes"
|
||||
);
|
||||
let (channel, target) =
|
||||
if let Some(ch) = &dm_config.heartbeat.deadman_channel {
|
||||
let to = dm_config
|
||||
.heartbeat
|
||||
.deadman_to
|
||||
.as_deref()
|
||||
.or(dm_config.heartbeat.to.as_deref())
|
||||
.unwrap_or_default();
|
||||
(ch.clone(), to.to_string())
|
||||
} else if let Some((ch, to)) = &dm_delivery {
|
||||
(ch.clone(), to.clone())
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let _ = crate::cron::scheduler::deliver_announcement(
|
||||
&dm_config, &channel, &target, &alert,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let base_interval = config.heartbeat.interval_minutes.max(5);
|
||||
let mut sleep_mins = base_interval;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
tokio::time::sleep(Duration::from_secs(u64::from(sleep_mins) * 60)).await;
|
||||
|
||||
// Update uptime
|
||||
{
|
||||
let mut m = metrics.lock();
|
||||
m.uptime_secs = start_time.elapsed().as_secs();
|
||||
}
|
||||
|
||||
let tick_start = std::time::Instant::now();
|
||||
|
||||
// Collect runnable tasks (active only, sorted by priority)
|
||||
let mut tasks = engine.collect_runnable_tasks().await?;
|
||||
let has_high_priority = tasks.iter().any(|t| t.priority == TaskPriority::High);
|
||||
|
||||
if tasks.is_empty() {
|
||||
// Try fallback message
|
||||
if let Some(fallback) = config
|
||||
.heartbeat
|
||||
.message
|
||||
@@ -232,12 +297,15 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
{
|
||||
tasks.push(crate::heartbeat::engine::HeartbeatTask {
|
||||
tasks.push(HeartbeatTask {
|
||||
text: fallback.to_string(),
|
||||
priority: crate::heartbeat::engine::TaskPriority::Medium,
|
||||
status: crate::heartbeat::engine::TaskStatus::Active,
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
});
|
||||
} else {
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
metrics.lock().record_success(elapsed);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -250,7 +318,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
Some(decision_prompt),
|
||||
None,
|
||||
None,
|
||||
0.0, // Low temperature for deterministic decision
|
||||
0.0,
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
@@ -263,6 +331,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
if indices.is_empty() {
|
||||
tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
metrics.lock().record_success(elapsed);
|
||||
continue;
|
||||
}
|
||||
tracing::info!(
|
||||
@@ -285,7 +356,9 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
let mut tick_had_error = false;
|
||||
for task in &tasks_to_run {
|
||||
let task_start = std::time::Instant::now();
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let temp = config.default_temperature;
|
||||
match Box::pin(crate::agent::run(
|
||||
@@ -303,6 +376,20 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
{
|
||||
Ok(output) => {
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let duration_ms = task_start.elapsed().as_millis() as i64;
|
||||
let now = chrono::Utc::now();
|
||||
let _ = crate::heartbeat::store::record_run(
|
||||
&config.workspace_dir,
|
||||
&task.text,
|
||||
&task.priority.to_string(),
|
||||
now - chrono::Duration::milliseconds(duration_ms),
|
||||
now,
|
||||
"ok",
|
||||
Some(output.as_str()),
|
||||
duration_ms,
|
||||
config.heartbeat.max_run_history,
|
||||
);
|
||||
let announcement = if output.trim().is_empty() {
|
||||
format!("💓 heartbeat task completed: {}", task.text)
|
||||
} else {
|
||||
@@ -326,11 +413,52 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tick_had_error = true;
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let duration_ms = task_start.elapsed().as_millis() as i64;
|
||||
let now = chrono::Utc::now();
|
||||
let _ = crate::heartbeat::store::record_run(
|
||||
&config.workspace_dir,
|
||||
&task.text,
|
||||
&task.priority.to_string(),
|
||||
now - chrono::Duration::milliseconds(duration_ms),
|
||||
now,
|
||||
"error",
|
||||
Some(&e.to_string()),
|
||||
duration_ms,
|
||||
config.heartbeat.max_run_history,
|
||||
);
|
||||
crate::health::mark_component_error("heartbeat", e.to_string());
|
||||
tracing::warn!("Heartbeat task failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let tick_elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
{
|
||||
let mut m = metrics.lock();
|
||||
if tick_had_error {
|
||||
m.record_failure(tick_elapsed);
|
||||
} else {
|
||||
m.record_success(tick_elapsed);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute next sleep interval
|
||||
if adaptive {
|
||||
let failures = metrics.lock().consecutive_failures;
|
||||
sleep_mins = compute_adaptive_interval(
|
||||
base_interval,
|
||||
config.heartbeat.min_interval_minutes,
|
||||
config.heartbeat.max_interval_minutes,
|
||||
failures,
|
||||
has_high_priority,
|
||||
);
|
||||
} else {
|
||||
sleep_mins = base_interval;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,4 +783,28 @@ mod tests {
|
||||
let target = auto_detect_heartbeat_channel(&config);
|
||||
assert!(target.is_none());
|
||||
}
|
||||
|
||||
/// Verify that SIGHUP does not cause shutdown — the daemon should ignore it
|
||||
/// and only terminate on SIGINT or SIGTERM.
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn sighup_does_not_shut_down_daemon() {
|
||||
use libc;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
let handle = tokio::spawn(wait_for_shutdown_signal());
|
||||
|
||||
// Give the signal handler time to register
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Send SIGHUP to ourselves — should be ignored by the handler
|
||||
unsafe { libc::raise(libc::SIGHUP) };
|
||||
|
||||
// The future should NOT complete within a short window
|
||||
let result = timeout(Duration::from_millis(200), handle).await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"wait_for_shutdown_signal should not return after SIGHUP"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+90
-16
@@ -75,6 +75,22 @@ fn nextcloud_talk_memory_key(msg: &crate::channels::traits::ChannelMessage) -> S
|
||||
format!("nextcloud_talk_{}_{}", msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn sender_session_id(channel: &str, msg: &crate::channels::traits::ChannelMessage) -> String {
|
||||
match &msg.thread_ts {
|
||||
Some(thread_id) => format!("{channel}_{thread_id}_{}", msg.sender),
|
||||
None => format!("{channel}_{}", msg.sender),
|
||||
}
|
||||
}
|
||||
|
||||
fn webhook_session_id(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("X-Session-Id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_owned)
|
||||
}
|
||||
|
||||
fn hash_webhook_secret(value: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
@@ -908,9 +924,13 @@ async fn run_gateway_chat_simple(state: &AppState, message: &str) -> anyhow::Res
|
||||
}
|
||||
|
||||
/// Full-featured chat with tools for channel handlers (WhatsApp, Linq, Nextcloud Talk).
|
||||
async fn run_gateway_chat_with_tools(state: &AppState, message: &str) -> anyhow::Result<String> {
|
||||
async fn run_gateway_chat_with_tools(
|
||||
state: &AppState,
|
||||
message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let config = state.config.lock().clone();
|
||||
Box::pin(crate::agent::process_message(config, message)).await
|
||||
Box::pin(crate::agent::process_message(config, message, session_id)).await
|
||||
}
|
||||
|
||||
/// Webhook request body
|
||||
@@ -1002,12 +1022,18 @@ async fn handle_webhook(
|
||||
}
|
||||
|
||||
let message = &webhook_body.message;
|
||||
let session_id = webhook_session_id(&headers);
|
||||
|
||||
if state.auto_save {
|
||||
if state.auto_save && !memory::should_skip_autosave_content(message) {
|
||||
let key = webhook_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
message,
|
||||
MemoryCategory::Conversation,
|
||||
session_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1228,17 +1254,29 @@ async fn handle_whatsapp_message(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = sender_session_id("whatsapp", msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
if state.auto_save && !memory::should_skip_autosave_content(&msg.content) {
|
||||
let key = whatsapp_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(
|
||||
&state,
|
||||
&msg.content,
|
||||
Some(&session_id),
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
// Send reply via WhatsApp
|
||||
if let Err(e) = wa
|
||||
@@ -1335,18 +1373,30 @@ async fn handle_linq_webhook(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = sender_session_id("linq", msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
if state.auto_save && !memory::should_skip_autosave_content(&msg.content) {
|
||||
let key = linq_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(
|
||||
&state,
|
||||
&msg.content,
|
||||
Some(&session_id),
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
// Send reply via Linq
|
||||
if let Err(e) = linq
|
||||
@@ -1427,18 +1477,30 @@ async fn handle_wati_webhook(State(state): State<AppState>, body: Bytes) -> impl
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = sender_session_id("wati", msg);
|
||||
|
||||
// Auto-save to memory
|
||||
if state.auto_save {
|
||||
if state.auto_save && !memory::should_skip_autosave_content(&msg.content) {
|
||||
let key = wati_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Call the LLM
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(
|
||||
&state,
|
||||
&msg.content,
|
||||
Some(&session_id),
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
// Send reply via WATI
|
||||
if let Err(e) = wati
|
||||
@@ -1533,16 +1595,28 @@ async fn handle_nextcloud_talk_webhook(
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 50)
|
||||
);
|
||||
let session_id = sender_session_id("nextcloud_talk", msg);
|
||||
|
||||
if state.auto_save {
|
||||
if state.auto_save && !memory::should_skip_autosave_content(&msg.content) {
|
||||
let key = nextcloud_talk_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.store(
|
||||
&key,
|
||||
&msg.content,
|
||||
MemoryCategory::Conversation,
|
||||
Some(&session_id),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
match Box::pin(run_gateway_chat_with_tools(&state, &msg.content)).await {
|
||||
match Box::pin(run_gateway_chat_with_tools(
|
||||
&state,
|
||||
&msg.content,
|
||||
Some(&session_id),
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
if let Err(e) = nextcloud_talk
|
||||
.send(&SendMessage::new(response, &msg.reply_target))
|
||||
|
||||
+2
-1
@@ -116,7 +116,7 @@ pub async fn handle_ws_chat(
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, _session_id: Option<String>) {
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<String>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Build a persistent Agent for this connection so history is maintained across turns.
|
||||
@@ -129,6 +129,7 @@ async fn handle_socket(socket: WebSocket, state: AppState, _session_id: Option<S
|
||||
return;
|
||||
}
|
||||
};
|
||||
agent.set_memory_session_id(session_id.clone());
|
||||
|
||||
while let Some(msg) = receiver.next().await {
|
||||
let msg = match msg {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::config::HeartbeatConfig;
|
||||
use crate::observability::{Observer, ObserverEvent};
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use parking_lot::Mutex as ParkingMutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
@@ -68,6 +70,99 @@ impl fmt::Display for HeartbeatTask {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Health Metrics ───────────────────────────────────────────────
|
||||
|
||||
/// Live health metrics for the heartbeat subsystem.
///
/// Shared via `Arc<ParkingMutex<HeartbeatMetrics>>` between the heartbeat
/// worker, the dead-man watcher, and API consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatMetrics {
    /// Monotonic uptime since the heartbeat loop started, in seconds.
    pub uptime_secs: u64,
    /// Consecutive successful ticks (reset to 0 by a failure).
    pub consecutive_successes: u64,
    /// Consecutive failed ticks (reset to 0 by a success).
    pub consecutive_failures: u64,
    /// Timestamp of the most recent tick (UTC RFC 3339); `None` until the
    /// first tick is recorded.
    pub last_tick_at: Option<DateTime<Utc>>,
    /// Exponential moving average of tick durations in milliseconds.
    pub avg_tick_duration_ms: f64,
    /// Total number of ticks executed since startup (successes + failures).
    pub total_ticks: u64,
}
|
||||
|
||||
impl Default for HeartbeatMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
uptime_secs: 0,
|
||||
consecutive_successes: 0,
|
||||
consecutive_failures: 0,
|
||||
last_tick_at: None,
|
||||
avg_tick_duration_ms: 0.0,
|
||||
total_ticks: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HeartbeatMetrics {
|
||||
/// Record a successful tick with the given duration.
|
||||
pub fn record_success(&mut self, duration_ms: f64) {
|
||||
self.consecutive_successes += 1;
|
||||
self.consecutive_failures = 0;
|
||||
self.last_tick_at = Some(Utc::now());
|
||||
self.total_ticks += 1;
|
||||
self.update_avg_duration(duration_ms);
|
||||
}
|
||||
|
||||
/// Record a failed tick with the given duration.
|
||||
pub fn record_failure(&mut self, duration_ms: f64) {
|
||||
self.consecutive_failures += 1;
|
||||
self.consecutive_successes = 0;
|
||||
self.last_tick_at = Some(Utc::now());
|
||||
self.total_ticks += 1;
|
||||
self.update_avg_duration(duration_ms);
|
||||
}
|
||||
|
||||
fn update_avg_duration(&mut self, duration_ms: f64) {
|
||||
const ALPHA: f64 = 0.3; // EMA smoothing factor
|
||||
if self.total_ticks == 1 {
|
||||
self.avg_tick_duration_ms = duration_ms;
|
||||
} else {
|
||||
self.avg_tick_duration_ms =
|
||||
ALPHA * duration_ms + (1.0 - ALPHA) * self.avg_tick_duration_ms;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the adaptive interval (in minutes) for the next heartbeat tick.
///
/// Strategy:
/// - On failures: exponential back-off `base * 2^failures` (exponent capped
///   at 10), bounded above by `max_minutes` then raised to `min_minutes`.
/// - When high-priority tasks are present: react quickly using
///   `min_minutes`, but never below 5 minutes.
/// - Otherwise: `base_minutes` clamped into `[min_minutes, max_minutes]`.
pub fn compute_adaptive_interval(
    base_minutes: u32,
    min_minutes: u32,
    max_minutes: u32,
    consecutive_failures: u64,
    has_high_priority_tasks: bool,
) -> u32 {
    if consecutive_failures > 0 {
        // 2^failures with the exponent capped so the shift is always valid.
        let exponent = consecutive_failures.min(10) as u32;
        let factor = 1u32.checked_shl(exponent).unwrap_or(u32::MAX);
        let backoff = base_minutes.saturating_mul(factor);
        backoff.min(max_minutes).max(min_minutes)
    } else if has_high_priority_tasks {
        min_minutes.max(5) // never go below 5 minutes
    } else {
        base_minutes.clamp(min_minutes, max_minutes)
    }
}
|
||||
|
||||
// ── Engine ───────────────────────────────────────────────────────
|
||||
|
||||
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
|
||||
@@ -75,6 +170,7 @@ pub struct HeartbeatEngine {
|
||||
config: HeartbeatConfig,
|
||||
workspace_dir: std::path::PathBuf,
|
||||
observer: Arc<dyn Observer>,
|
||||
metrics: Arc<ParkingMutex<HeartbeatMetrics>>,
|
||||
}
|
||||
|
||||
impl HeartbeatEngine {
|
||||
@@ -87,9 +183,15 @@ impl HeartbeatEngine {
|
||||
config,
|
||||
workspace_dir,
|
||||
observer,
|
||||
metrics: Arc::new(ParkingMutex::new(HeartbeatMetrics::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a shared handle to the live heartbeat metrics.
|
||||
pub fn metrics(&self) -> Arc<ParkingMutex<HeartbeatMetrics>> {
|
||||
Arc::clone(&self.metrics)
|
||||
}
|
||||
|
||||
/// Start the heartbeat loop (runs until cancelled)
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
if !self.config.enabled {
|
||||
@@ -673,4 +775,79 @@ mod tests {
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
// ── HeartbeatMetrics tests ───────────────────────────────────
|
||||
|
||||
#[test]
fn metrics_record_success_updates_fields() {
    // One success: streak starts, failures stay 0, tick is stamped,
    // and the first duration seeds the average exactly.
    let mut m = HeartbeatMetrics::default();
    m.record_success(100.0);
    assert_eq!(m.consecutive_successes, 1);
    assert_eq!(m.consecutive_failures, 0);
    assert_eq!(m.total_ticks, 1);
    assert!(m.last_tick_at.is_some());
    assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
}
|
||||
|
||||
#[test]
fn metrics_record_failure_resets_successes() {
    // A failure zeroes the success streak while total_ticks keeps counting.
    let mut m = HeartbeatMetrics::default();
    m.record_success(50.0);
    m.record_success(50.0);
    m.record_failure(200.0);
    assert_eq!(m.consecutive_successes, 0);
    assert_eq!(m.consecutive_failures, 1);
    assert_eq!(m.total_ticks, 3);
}
|
||||
|
||||
#[test]
fn metrics_ema_smoothing() {
    // First tick seeds the average; later ticks blend with alpha = 0.3.
    let mut m = HeartbeatMetrics::default();
    m.record_success(100.0);
    assert!((m.avg_tick_duration_ms - 100.0).abs() < f64::EPSILON);
    m.record_success(200.0);
    // EMA: 0.3 * 200 + 0.7 * 100 = 130
    assert!((m.avg_tick_duration_ms - 130.0).abs() < f64::EPSILON);
}
|
||||
|
||||
// ── Adaptive interval tests ─────────────────────────────────
|
||||
|
||||
#[test]
fn adaptive_uses_base_when_no_failures() {
    // Healthy, no urgent tasks: the base interval is used unchanged.
    let result = compute_adaptive_interval(30, 5, 120, 0, false);
    assert_eq!(result, 30);
}
|
||||
|
||||
#[test]
fn adaptive_uses_min_for_high_priority() {
    // High-priority tasks pending: drop to the minimum interval.
    let result = compute_adaptive_interval(30, 5, 120, 0, true);
    assert_eq!(result, 5);
}
|
||||
|
||||
#[test]
fn adaptive_backs_off_on_failures() {
    // Exponential back-off, capped at max_minutes.
    // 1 failure: 30 * 2 = 60
    assert_eq!(compute_adaptive_interval(30, 5, 120, 1, false), 60);
    // 2 failures: 30 * 4 = 120 (capped at max)
    assert_eq!(compute_adaptive_interval(30, 5, 120, 2, false), 120);
    // 3 failures: 30 * 8 = 240 → capped at 120
    assert_eq!(compute_adaptive_interval(30, 5, 120, 3, false), 120);
}
|
||||
|
||||
#[test]
fn adaptive_backoff_respects_min() {
    // Even with failures, must be >= min
    assert!(compute_adaptive_interval(5, 10, 120, 0, false) >= 10);
}
|
||||
|
||||
// ── Engine metrics accessor ─────────────────────────────────
|
||||
|
||||
#[test]
fn engine_exposes_shared_metrics() {
    // A fresh engine hands out a live metrics handle with zeroed counters.
    let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
    let engine =
        HeartbeatEngine::new(HeartbeatConfig::default(), std::env::temp_dir(), observer);
    let metrics = engine.metrics();
    assert_eq!(metrics.lock().total_ticks, 0);
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod engine;
|
||||
pub mod store;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
//! SQLite persistence for heartbeat task execution history.
|
||||
//!
|
||||
//! Mirrors the `cron/store.rs` pattern: fresh connection per call, schema
|
||||
//! auto-created, output truncated, history pruned to a configurable limit.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Hard cap on stored task output, in bytes; longer output is truncated.
const MAX_OUTPUT_BYTES: usize = 16 * 1024;
/// Marker appended to output that was cut short (presumably by
/// `truncate_output` — defined elsewhere in this file).
const TRUNCATED_MARKER: &str = "\n...[truncated]";
|
||||
|
||||
/// A single heartbeat task execution record.
#[derive(Debug, Clone)]
pub struct HeartbeatRun {
    /// Row id assigned by SQLite.
    pub id: i64,
    /// Text of the executed task.
    pub task_text: String,
    /// Task priority rendered as a string.
    pub task_priority: String,
    /// When the task started (UTC).
    pub started_at: DateTime<Utc>,
    /// When the task finished (UTC).
    pub finished_at: DateTime<Utc>,
    /// Outcome of the run.
    pub status: String, // "ok" or "error"
    /// Captured output, if any; truncated before storage.
    pub output: Option<String>,
    /// Wall-clock execution duration in milliseconds.
    pub duration_ms: i64,
}
|
||||
|
||||
/// Record a heartbeat task execution and prune old entries.
///
/// Inserts one row into `heartbeat_runs` and, within the same transaction,
/// deletes everything except the newest `max_history` rows (at least one is
/// always kept). Output is bounded before storage. Returns an error if the
/// insert, prune, or commit fails.
pub fn record_run(
    workspace_dir: &Path,
    task_text: &str,
    task_priority: &str,
    started_at: DateTime<Utc>,
    finished_at: DateTime<Utc>,
    status: &str,
    output: Option<&str>,
    duration_ms: i64,
    max_history: u32,
) -> Result<()> {
    // Bound the stored output before opening the transaction.
    let bounded_output = output.map(truncate_output);
    with_connection(workspace_dir, |conn| {
        // unchecked_transaction: the connection is fresh and single-use here.
        let tx = conn.unchecked_transaction()?;

        // Timestamps are stored as RFC 3339 strings so ORDER BY sorts them
        // chronologically.
        tx.execute(
            "INSERT INTO heartbeat_runs
             (task_text, task_priority, started_at, finished_at, status, output, duration_ms)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            params![
                task_text,
                task_priority,
                started_at.to_rfc3339(),
                finished_at.to_rfc3339(),
                status,
                bounded_output.as_deref(),
                duration_ms,
            ],
        )
        .context("Failed to insert heartbeat run")?;

        // Keep only the newest `max_history` rows (minimum 1), so the just
        // inserted row is never pruned away by a zero limit.
        let keep = i64::from(max_history.max(1));
        tx.execute(
            "DELETE FROM heartbeat_runs
             WHERE id NOT IN (
                 SELECT id FROM heartbeat_runs
                 ORDER BY started_at DESC, id DESC
                 LIMIT ?1
             )",
            params![keep],
        )
        .context("Failed to prune heartbeat run history")?;

        tx.commit()
            .context("Failed to commit heartbeat run transaction")?;
        Ok(())
    })
}
|
||||
|
||||
/// List the most recent heartbeat runs.
|
||||
pub fn list_runs(workspace_dir: &Path, limit: usize) -> Result<Vec<HeartbeatRun>> {
|
||||
with_connection(workspace_dir, |conn| {
|
||||
let lim = i64::try_from(limit.max(1)).context("Run history limit overflow")?;
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, task_text, task_priority, started_at, finished_at, status, output, duration_ms
|
||||
FROM heartbeat_runs
|
||||
ORDER BY started_at DESC, id DESC
|
||||
LIMIT ?1",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![lim], |row| {
|
||||
Ok(HeartbeatRun {
|
||||
id: row.get(0)?,
|
||||
task_text: row.get(1)?,
|
||||
task_priority: row.get(2)?,
|
||||
started_at: parse_rfc3339(&row.get::<_, String>(3)?).map_err(sql_err)?,
|
||||
finished_at: parse_rfc3339(&row.get::<_, String>(4)?).map_err(sql_err)?,
|
||||
status: row.get(5)?,
|
||||
output: row.get(6)?,
|
||||
duration_ms: row.get(7)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut runs = Vec::new();
|
||||
for row in rows {
|
||||
runs.push(row?);
|
||||
}
|
||||
Ok(runs)
|
||||
})
|
||||
}
|
||||
|
||||
/// Get aggregate stats: (total_runs, total_ok, total_error).
|
||||
pub fn run_stats(workspace_dir: &Path) -> Result<(u64, u64, u64)> {
|
||||
with_connection(workspace_dir, |conn| {
|
||||
let total: i64 = conn.query_row("SELECT COUNT(*) FROM heartbeat_runs", [], |r| r.get(0))?;
|
||||
let ok: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'ok'",
|
||||
[],
|
||||
|r| r.get(0),
|
||||
)?;
|
||||
let err: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM heartbeat_runs WHERE status = 'error'",
|
||||
[],
|
||||
|r| r.get(0),
|
||||
)?;
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
Ok((total as u64, ok as u64, err as u64))
|
||||
})
|
||||
}
|
||||
|
||||
/// Location of the history database: `<workspace>/heartbeat/history.db`.
fn db_path(workspace_dir: &Path) -> PathBuf {
    [workspace_dir, Path::new("heartbeat"), Path::new("history.db")]
        .iter()
        .collect()
}
|
||||
|
||||
/// Open the history DB (creating its parent directory and schema on first
/// use) and run `f` against a fresh connection.
///
/// A new connection per call keeps the store free of shared state; the
/// pragma/schema batch is idempotent, so it is safe to run on every open.
fn with_connection<T>(workspace_dir: &Path, f: impl FnOnce(&Connection) -> Result<T>) -> Result<T> {
    let path = db_path(workspace_dir);
    // Ensure `<workspace>/heartbeat/` exists before SQLite tries to create the file.
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).with_context(|| {
            format!("Failed to create heartbeat directory: {}", parent.display())
        })?;
    }

    let conn = Connection::open(&path)
        .with_context(|| format!("Failed to open heartbeat history DB: {}", path.display()))?;

    // WAL + NORMAL sync trades a little durability for fewer fsyncs;
    // CREATE ... IF NOT EXISTS makes the schema setup re-runnable.
    conn.execute_batch(
        "PRAGMA journal_mode = WAL;
         PRAGMA synchronous = NORMAL;
         PRAGMA temp_store = MEMORY;

         CREATE TABLE IF NOT EXISTS heartbeat_runs (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             task_text TEXT NOT NULL,
             task_priority TEXT NOT NULL,
             started_at TEXT NOT NULL,
             finished_at TEXT NOT NULL,
             status TEXT NOT NULL,
             output TEXT,
             duration_ms INTEGER
         );
         CREATE INDEX IF NOT EXISTS idx_hb_runs_started ON heartbeat_runs(started_at);
         CREATE INDEX IF NOT EXISTS idx_hb_runs_task ON heartbeat_runs(task_text);",
    )
    .context("Failed to initialize heartbeat history schema")?;

    f(&conn)
}
|
||||
|
||||
fn truncate_output(output: &str) -> String {
|
||||
if output.len() <= MAX_OUTPUT_BYTES {
|
||||
return output.to_string();
|
||||
}
|
||||
|
||||
if MAX_OUTPUT_BYTES <= TRUNCATED_MARKER.len() {
|
||||
return TRUNCATED_MARKER.to_string();
|
||||
}
|
||||
|
||||
let mut cutoff = MAX_OUTPUT_BYTES - TRUNCATED_MARKER.len();
|
||||
while cutoff > 0 && !output.is_char_boundary(cutoff) {
|
||||
cutoff -= 1;
|
||||
}
|
||||
|
||||
let mut truncated = output[..cutoff].to_string();
|
||||
truncated.push_str(TRUNCATED_MARKER);
|
||||
truncated
|
||||
}
|
||||
|
||||
fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> {
|
||||
let parsed = DateTime::parse_from_rfc3339(raw)
|
||||
.with_context(|| format!("Invalid RFC3339 timestamp in heartbeat DB: {raw}"))?;
|
||||
Ok(parsed.with_timezone(&Utc))
|
||||
}
|
||||
|
||||
/// Adapt an `anyhow::Error` into a `rusqlite::Error` so row-mapping
/// closures (which must return `rusqlite::Result`) can surface timestamp
/// parse failures from `parse_rfc3339`.
fn sql_err(err: anyhow::Error) -> rusqlite::Error {
    rusqlite::Error::ToSqlConversionFailure(err.into())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration as ChronoDuration;
    use tempfile::TempDir;

    // Happy path: runs come back newest-first and nothing is lost while
    // the row count stays under the history limit.
    #[test]
    fn record_and_list_runs() {
        let tmp = TempDir::new().unwrap();
        let base = Utc::now();

        // Three runs, one second apart, well within max_history = 50.
        for i in 0..3 {
            let start = base + ChronoDuration::seconds(i);
            let end = start + ChronoDuration::milliseconds(100);
            record_run(
                tmp.path(),
                &format!("Task {i}"),
                "medium",
                start,
                end,
                "ok",
                Some("done"),
                100,
                50,
            )
            .unwrap();
        }

        let runs = list_runs(tmp.path(), 10).unwrap();
        assert_eq!(runs.len(), 3);
        // Most recent first
        assert!(runs[0].task_text.contains('2'));
    }

    // The max_history argument (2 here) caps the table on every insert.
    #[test]
    fn prunes_old_runs() {
        let tmp = TempDir::new().unwrap();
        let base = Utc::now();

        for i in 0..5 {
            let start = base + ChronoDuration::seconds(i);
            let end = start + ChronoDuration::milliseconds(50);
            record_run(
                tmp.path(),
                "Task",
                "high",
                start,
                end,
                "ok",
                None,
                50,
                2, // keep only 2
            )
            .unwrap();
        }

        let runs = list_runs(tmp.path(), 10).unwrap();
        assert_eq!(runs.len(), 2);
    }

    // run_stats splits totals by the status column: 2 ok + 1 error = 3.
    #[test]
    fn run_stats_counts_correctly() {
        let tmp = TempDir::new().unwrap();
        let now = Utc::now();

        record_run(tmp.path(), "A", "high", now, now, "ok", None, 10, 50).unwrap();
        record_run(
            tmp.path(),
            "B",
            "low",
            now,
            now,
            "error",
            Some("fail"),
            20,
            50,
        )
        .unwrap();
        record_run(tmp.path(), "C", "medium", now, now, "ok", None, 15, 50).unwrap();

        let (total, ok, err) = run_stats(tmp.path()).unwrap();
        assert_eq!(total, 3);
        assert_eq!(ok, 2);
        assert_eq!(err, 1);
    }

    // Oversized output is stored truncated: ends with the marker and never
    // exceeds MAX_OUTPUT_BYTES after truncation.
    #[test]
    fn truncates_large_output() {
        let tmp = TempDir::new().unwrap();
        let now = Utc::now();
        let big = "x".repeat(MAX_OUTPUT_BYTES + 512);

        record_run(
            tmp.path(),
            "T",
            "medium",
            now,
            now,
            "ok",
            Some(&big),
            10,
            50,
        )
        .unwrap();

        let runs = list_runs(tmp.path(), 1).unwrap();
        let stored = runs[0].output.as_deref().unwrap_or_default();
        assert!(stored.ends_with(TRUNCATED_MARKER));
        assert!(stored.len() <= MAX_OUTPUT_BYTES);
    }
}
|
||||
@@ -509,6 +509,18 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
},
|
||||
},
|
||||
// ── Productivity ────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Google Workspace",
|
||||
description: "Drive, Gmail, Calendar, Sheets, Docs via gws CLI",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |c| {
|
||||
if c.google_workspace.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "GitHub",
|
||||
description: "Code, issues, PRs",
|
||||
@@ -606,7 +618,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
name: "Browser",
|
||||
description: "Chrome/Chromium control",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
status_fn: |c| {
|
||||
if c.browser.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Shell",
|
||||
@@ -624,7 +642,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
name: "Cron",
|
||||
description: "Scheduled tasks",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
status_fn: |c| {
|
||||
if c.cron.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Voice",
|
||||
@@ -917,6 +941,54 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_active_when_enabled() {
|
||||
let mut config = Config::default();
|
||||
config.cron.enabled = true;
|
||||
let entries = all_integrations();
|
||||
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
|
||||
assert!(matches!(
|
||||
(cron.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_available_when_disabled() {
|
||||
let mut config = Config::default();
|
||||
config.cron.enabled = false;
|
||||
let entries = all_integrations();
|
||||
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
|
||||
assert!(matches!(
|
||||
(cron.status_fn)(&config),
|
||||
IntegrationStatus::Available
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn browser_active_when_enabled() {
|
||||
let mut config = Config::default();
|
||||
config.browser.enabled = true;
|
||||
let entries = all_integrations();
|
||||
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
|
||||
assert!(matches!(
|
||||
(browser.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn browser_available_when_disabled() {
|
||||
let mut config = Config::default();
|
||||
config.browser.enabled = false;
|
||||
let entries = all_integrations();
|
||||
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
|
||||
assert!(matches!(
|
||||
(browser.status_fn)(&config),
|
||||
IntegrationStatus::Available
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_and_filesystem_always_active() {
|
||||
let config = Config::default();
|
||||
|
||||
@@ -45,10 +45,12 @@ pub async fn consolidate_turn(
|
||||
// Truncate very long turns to avoid wasting tokens on consolidation.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
|
||||
let truncated = if turn_text.len() > 4000 {
|
||||
let mut end = 4000;
|
||||
while end > 0 && !turn_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
let end = turn_text
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 4000)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
format!("{}…", &turn_text[..end])
|
||||
} else {
|
||||
turn_text.clone()
|
||||
@@ -99,10 +101,12 @@ fn parse_consolidation_response(raw: &str, fallback_text: &str) -> Consolidation
|
||||
// Fallback: use truncated turn text as history entry.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let summary = if fallback_text.len() > 200 {
|
||||
let mut end = 200;
|
||||
while end > 0 && !fallback_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
let end = fallback_text
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 200)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
format!("{}…", &fallback_text[..end])
|
||||
} else {
|
||||
fallback_text.to_string()
|
||||
|
||||
@@ -90,6 +90,20 @@ pub fn is_assistant_autosave_key(key: &str) -> bool {
|
||||
normalized == "assistant_resp" || normalized.starts_with("assistant_resp_")
|
||||
}
|
||||
|
||||
/// Filter known synthetic autosave noise patterns that should not be
|
||||
/// persisted as user conversation memories.
|
||||
pub fn should_skip_autosave_content(content: &str) -> bool {
|
||||
let normalized = content.trim();
|
||||
if normalized.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let lowered = normalized.to_ascii_lowercase();
|
||||
lowered.starts_with("[cron:")
|
||||
|| lowered.starts_with("[distilled_")
|
||||
|| lowered.contains("distilled_index_sig:")
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
struct ResolvedEmbeddingConfig {
|
||||
provider: String,
|
||||
@@ -450,6 +464,17 @@ mod tests {
|
||||
assert!(!is_assistant_autosave_key("user_msg_1234"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn autosave_content_filter_drops_cron_and_distilled_noise() {
|
||||
assert!(should_skip_autosave_content("[cron:auto] patrol check"));
|
||||
assert!(should_skip_autosave_content(
|
||||
"[DISTILLED_MEMORY_CHUNK 1/2] DISTILLED_INDEX_SIG:abc123"
|
||||
));
|
||||
assert!(!should_skip_autosave_content(
|
||||
"User prefers concise answers."
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
+130
-27
@@ -10,23 +10,45 @@ use chrono::{Duration, Local};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Response cache backed by a dedicated SQLite database.
|
||||
/// An in-memory hot cache entry for the two-tier response cache.
|
||||
struct InMemoryEntry {
|
||||
response: String,
|
||||
token_count: u32,
|
||||
created_at: std::time::Instant,
|
||||
accessed_at: std::time::Instant,
|
||||
}
|
||||
|
||||
/// Two-tier response cache: in-memory LRU (hot) + SQLite (warm).
|
||||
///
|
||||
/// Lives alongside `brain.db` as `response_cache.db` so it can be
|
||||
/// independently wiped without touching memories.
|
||||
/// The hot cache avoids SQLite round-trips for frequently repeated prompts.
|
||||
/// On miss from hot cache, falls through to SQLite. On hit from SQLite,
|
||||
/// the entry is promoted to the hot cache.
|
||||
pub struct ResponseCache {
|
||||
conn: Mutex<Connection>,
|
||||
#[allow(dead_code)]
|
||||
db_path: PathBuf,
|
||||
ttl_minutes: i64,
|
||||
max_entries: usize,
|
||||
hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
|
||||
hot_max_entries: usize,
|
||||
}
|
||||
|
||||
impl ResponseCache {
|
||||
/// Open (or create) the response cache database.
|
||||
pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
|
||||
Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
|
||||
}
|
||||
|
||||
/// Open (or create) the response cache database with a custom hot cache size.
|
||||
pub fn with_hot_cache(
|
||||
workspace_dir: &Path,
|
||||
ttl_minutes: u32,
|
||||
max_entries: usize,
|
||||
hot_max_entries: usize,
|
||||
) -> Result<Self> {
|
||||
let db_dir = workspace_dir.join("memory");
|
||||
std::fs::create_dir_all(&db_dir)?;
|
||||
let db_path = db_dir.join("response_cache.db");
|
||||
@@ -58,6 +80,8 @@ impl ResponseCache {
|
||||
db_path,
|
||||
ttl_minutes: i64::from(ttl_minutes),
|
||||
max_entries,
|
||||
hot_cache: Mutex::new(HashMap::new()),
|
||||
hot_max_entries,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -76,35 +100,77 @@ impl ResponseCache {
|
||||
}
|
||||
|
||||
/// Look up a cached response. Returns `None` on miss or expired entry.
|
||||
///
|
||||
/// Two-tier lookup: checks the in-memory hot cache first, then falls
|
||||
/// through to SQLite. On a SQLite hit the entry is promoted to hot cache.
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<String> = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok();
|
||||
|
||||
if result.is_some() {
|
||||
// Bump hit count and accessed_at
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
// Tier 1: hot cache (with TTL check)
|
||||
{
|
||||
let mut hot = self.hot_cache.lock();
|
||||
if let Some(entry) = hot.get_mut(key) {
|
||||
let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
|
||||
if entry.created_at.elapsed() > ttl {
|
||||
hot.remove(key);
|
||||
} else {
|
||||
entry.accessed_at = std::time::Instant::now();
|
||||
let response = entry.response.clone();
|
||||
drop(hot);
|
||||
// Still bump SQLite hit count for accurate stats
|
||||
let conn = self.conn.lock();
|
||||
let now_str = Local::now().to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
return Ok(Some(response));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
// Tier 2: SQLite (warm)
|
||||
let result: Option<(String, u32)> = {
|
||||
let conn = self.conn.lock();
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response, token_count FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<(String, u32)> = stmt
|
||||
.query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.ok();
|
||||
|
||||
if result.is_some() {
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
if let Some((ref response, token_count)) = result {
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
}
|
||||
|
||||
Ok(result.map(|(r, _)| r))
|
||||
}
|
||||
|
||||
/// Store a response in the cache.
|
||||
/// Store a response in the cache (both hot and warm tiers).
|
||||
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
||||
// Write to hot cache
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
|
||||
// Write to SQLite (warm)
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now().to_rfc3339();
|
||||
@@ -138,6 +204,43 @@ impl ResponseCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Promote an entry to the in-memory hot cache, evicting the oldest if full.
    fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
        let mut hot = self.hot_cache.lock();

        // If already present, just update (keep original created_at for TTL)
        if let Some(entry) = hot.get_mut(key) {
            entry.response = response.to_string();
            entry.token_count = token_count;
            entry.accessed_at = std::time::Instant::now();
            return;
        }

        // Evict oldest entry if at capacity.
        // NOTE(review): linear scan for the least-recently-accessed entry —
        // fine at small capacities, revisit if hot_max_entries grows large.
        if self.hot_max_entries > 0 && hot.len() >= self.hot_max_entries {
            if let Some(oldest_key) = hot
                .iter()
                .min_by_key(|(_, v)| v.accessed_at)
                .map(|(k, _)| k.clone())
            {
                hot.remove(&oldest_key);
            }
        }

        // hot_max_entries == 0 disables the hot tier entirely.
        if self.hot_max_entries > 0 {
            let now = std::time::Instant::now();
            hot.insert(
                key.to_string(),
                InMemoryEntry {
                    response: response.to_string(),
                    token_count,
                    created_at: now,
                    accessed_at: now,
                },
            );
        }
    }
|
||||
|
||||
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
||||
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
||||
let conn = self.conn.lock();
|
||||
@@ -163,8 +266,8 @@ impl ResponseCache {
|
||||
|
||||
    /// Wipe the entire cache (useful for `zeroclaw cache clear`).
    ///
    /// Empties both tiers. The returned count covers only SQLite rows;
    /// hot-tier entries are not included in it.
    pub fn clear(&self) -> Result<usize> {
        // Drop the hot tier first so nothing stale survives the wipe.
        self.hot_cache.lock().clear();
        let conn = self.conn.lock();

        let affected = conn.execute("DELETE FROM response_cache", [])?;
        Ok(affected)
    }
|
||||
|
||||
+28
-2
@@ -27,8 +27,7 @@ impl std::fmt::Debug for MemoryEntry {
|
||||
}
|
||||
|
||||
/// Memory categories for organization
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum MemoryCategory {
|
||||
/// Long-term facts, preferences, decisions
|
||||
Core,
|
||||
@@ -40,6 +39,24 @@ pub enum MemoryCategory {
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
/// Serialize as the bare category string (e.g. `"conversation"`), delegating
/// to Display so the wire format stays in lock-step with it and `Custom`
/// variants emit their inner string directly.
impl serde::Serialize for MemoryCategory {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(&self.to_string())
    }
}

/// Inverse of the Serialize impl: known names map to the built-in variants,
/// anything else round-trips losslessly as `Custom`.
impl<'de> serde::Deserialize<'de> for MemoryCategory {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let s = String::deserialize(deserializer)?;
        Ok(match s.as_str() {
            "core" => Self::Core,
            "daily" => Self::Daily,
            "conversation" => Self::Conversation,
            // Unknown strings are preserved rather than rejected.
            _ => Self::Custom(s),
        })
    }
}
|
||||
|
||||
impl std::fmt::Display for MemoryCategory {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
@@ -120,6 +137,15 @@ mod tests {
|
||||
assert_eq!(conversation, "\"conversation\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn memory_category_custom_roundtrip() {
|
||||
let custom = MemoryCategory::Custom("project_notes".into());
|
||||
let json = serde_json::to_string(&custom).unwrap();
|
||||
assert_eq!(json, "\"project_notes\"");
|
||||
let parsed: MemoryCategory = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, custom);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn memory_entry_roundtrip_preserves_optional_fields() {
|
||||
let entry = MemoryEntry {
|
||||
|
||||
@@ -566,4 +566,28 @@ mod tests {
|
||||
.expect("payload should be extracted");
|
||||
assert_eq!(payload, "abcd==");
|
||||
}
|
||||
|
||||
/// Stripping `[IMAGE:]` markers from history messages leaves only the text
|
||||
/// portion, which is the behaviour needed for non-vision providers (#3674).
|
||||
#[test]
|
||||
fn parse_image_markers_strips_markers_leaving_caption() {
|
||||
let input = "[IMAGE:/tmp/photo.jpg]\n\nDescribe this screenshot";
|
||||
let (cleaned, refs) = parse_image_markers(input);
|
||||
assert_eq!(cleaned, "Describe this screenshot");
|
||||
assert_eq!(refs.len(), 1);
|
||||
assert_eq!(refs[0], "/tmp/photo.jpg");
|
||||
}
|
||||
|
||||
/// An image-only message (no caption) should produce an empty string after
|
||||
/// marker stripping, so callers can drop it from history.
|
||||
#[test]
|
||||
fn parse_image_markers_image_only_message_becomes_empty() {
|
||||
let input = "[IMAGE:/tmp/photo.jpg]";
|
||||
let (cleaned, refs) = parse_image_markers(input);
|
||||
assert!(
|
||||
cleaned.is_empty(),
|
||||
"expected empty string, got: {cleaned:?}"
|
||||
);
|
||||
assert_eq!(refs.len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,15 @@ impl Observer for LogObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
info!("heartbeat.tick");
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
info!(cache_type = %cache_type, tokens_saved = tokens_saved, "cache.hit");
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
info!(cache_type = %cache_type, "cache.miss");
|
||||
}
|
||||
ObserverEvent::Error { component, message } => {
|
||||
info!(component = %component, error = %message, "error");
|
||||
}
|
||||
@@ -83,6 +92,23 @@ impl Observer for LogObserver {
|
||||
"llm.response"
|
||||
);
|
||||
}
|
||||
ObserverEvent::HandStarted { hand_name } => {
|
||||
info!(hand = %hand_name, "hand.started");
|
||||
}
|
||||
ObserverEvent::HandCompleted {
|
||||
hand_name,
|
||||
duration_ms,
|
||||
findings_count,
|
||||
} => {
|
||||
info!(hand = %hand_name, duration_ms = duration_ms, findings = findings_count, "hand.completed");
|
||||
}
|
||||
ObserverEvent::HandFailed {
|
||||
hand_name,
|
||||
error,
|
||||
duration_ms,
|
||||
} => {
|
||||
info!(hand = %hand_name, error = %error, duration_ms = duration_ms, "hand.failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +127,19 @@ impl Observer for LogObserver {
|
||||
ObserverMetric::QueueDepth(d) => {
|
||||
info!(depth = d, "metric.queue_depth");
|
||||
}
|
||||
ObserverMetric::HandRunDuration {
|
||||
hand_name,
|
||||
duration,
|
||||
} => {
|
||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(hand = %hand_name, duration_ms = ms, "metric.hand_run_duration");
|
||||
}
|
||||
ObserverMetric::HandFindingsCount { hand_name, count } => {
|
||||
info!(hand = %hand_name, count = count, "metric.hand_findings_count");
|
||||
}
|
||||
ObserverMetric::HandSuccessRate { hand_name, success } => {
|
||||
info!(hand = %hand_name, success = success, "metric.hand_success_rate");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,4 +226,39 @@ mod tests {
|
||||
obs.record_metric(&ObserverMetric::ActiveSessions(1));
|
||||
obs.record_metric(&ObserverMetric::QueueDepth(999));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_observer_hand_events_no_panic() {
|
||||
let obs = LogObserver::new();
|
||||
obs.record_event(&ObserverEvent::HandStarted {
|
||||
hand_name: "review".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_observer_hand_metrics_no_panic() {
|
||||
let obs = LogObserver::new();
|
||||
obs.record_metric(&ObserverMetric::HandRunDuration {
|
||||
hand_name: "review".into(),
|
||||
duration: Duration::from_millis(1500),
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandFindingsCount {
|
||||
hand_name: "review".into(),
|
||||
count: 5,
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "review".into(),
|
||||
success: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,4 +80,39 @@ mod tests {
|
||||
fn noop_flush_does_not_panic() {
|
||||
NoopObserver.flush();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noop_hand_events_do_not_panic() {
|
||||
let obs = NoopObserver;
|
||||
obs.record_event(&ObserverEvent::HandStarted {
|
||||
hand_name: "review".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noop_hand_metrics_do_not_panic() {
|
||||
let obs = NoopObserver;
|
||||
obs.record_metric(&ObserverMetric::HandRunDuration {
|
||||
hand_name: "review".into(),
|
||||
duration: Duration::from_millis(1500),
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandFindingsCount {
|
||||
hand_name: "review".into(),
|
||||
count: 5,
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "review".into(),
|
||||
success: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,9 @@ pub struct OtelObserver {
|
||||
tokens_used: Counter<u64>,
|
||||
active_sessions: Gauge<u64>,
|
||||
queue_depth: Gauge<u64>,
|
||||
hand_runs: Counter<u64>,
|
||||
hand_duration: Histogram<f64>,
|
||||
hand_findings: Counter<u64>,
|
||||
}
|
||||
|
||||
impl OtelObserver {
|
||||
@@ -152,6 +155,22 @@ impl OtelObserver {
|
||||
.with_description("Current message queue depth")
|
||||
.build();
|
||||
|
||||
let hand_runs = meter
|
||||
.u64_counter("zeroclaw.hand.runs")
|
||||
.with_description("Total hand runs")
|
||||
.build();
|
||||
|
||||
let hand_duration = meter
|
||||
.f64_histogram("zeroclaw.hand.duration")
|
||||
.with_description("Hand run duration in seconds")
|
||||
.with_unit("s")
|
||||
.build();
|
||||
|
||||
let hand_findings = meter
|
||||
.u64_counter("zeroclaw.hand.findings")
|
||||
.with_description("Total findings produced by hand runs")
|
||||
.build();
|
||||
|
||||
Ok(Self {
|
||||
tracer_provider,
|
||||
meter_provider: meter_provider_clone,
|
||||
@@ -168,6 +187,9 @@ impl OtelObserver {
|
||||
tokens_used,
|
||||
active_sessions,
|
||||
queue_depth,
|
||||
hand_runs,
|
||||
hand_duration,
|
||||
hand_findings,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -335,6 +357,77 @@ impl Observer for OtelObserver {
|
||||
self.errors
|
||||
.add(1, &[KeyValue::new("component", component.clone())]);
|
||||
}
|
||||
ObserverEvent::HandStarted { .. } => {}
|
||||
ObserverEvent::HandCompleted {
|
||||
hand_name,
|
||||
duration_ms,
|
||||
findings_count,
|
||||
} => {
|
||||
let secs = *duration_ms as f64 / 1000.0;
|
||||
let duration = std::time::Duration::from_millis(*duration_ms);
|
||||
let start_time = SystemTime::now()
|
||||
.checked_sub(duration)
|
||||
.unwrap_or(SystemTime::now());
|
||||
|
||||
let mut span = tracer.build(
|
||||
opentelemetry::trace::SpanBuilder::from_name("hand.run")
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(start_time)
|
||||
.with_attributes(vec![
|
||||
KeyValue::new("hand.name", hand_name.clone()),
|
||||
KeyValue::new("hand.success", true),
|
||||
KeyValue::new("hand.findings", *findings_count as i64),
|
||||
KeyValue::new("duration_s", secs),
|
||||
]),
|
||||
);
|
||||
span.set_status(Status::Ok);
|
||||
span.end();
|
||||
|
||||
let attrs = [
|
||||
KeyValue::new("hand", hand_name.clone()),
|
||||
KeyValue::new("success", "true"),
|
||||
];
|
||||
self.hand_runs.add(1, &attrs);
|
||||
self.hand_duration
|
||||
.record(secs, &[KeyValue::new("hand", hand_name.clone())]);
|
||||
self.hand_findings.add(
|
||||
*findings_count as u64,
|
||||
&[KeyValue::new("hand", hand_name.clone())],
|
||||
);
|
||||
}
|
||||
ObserverEvent::HandFailed {
|
||||
hand_name,
|
||||
error,
|
||||
duration_ms,
|
||||
} => {
|
||||
let secs = *duration_ms as f64 / 1000.0;
|
||||
let duration = std::time::Duration::from_millis(*duration_ms);
|
||||
let start_time = SystemTime::now()
|
||||
.checked_sub(duration)
|
||||
.unwrap_or(SystemTime::now());
|
||||
|
||||
let mut span = tracer.build(
|
||||
opentelemetry::trace::SpanBuilder::from_name("hand.run")
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(start_time)
|
||||
.with_attributes(vec![
|
||||
KeyValue::new("hand.name", hand_name.clone()),
|
||||
KeyValue::new("hand.success", false),
|
||||
KeyValue::new("error.message", error.clone()),
|
||||
KeyValue::new("duration_s", secs),
|
||||
]),
|
||||
);
|
||||
span.set_status(Status::error(error.clone()));
|
||||
span.end();
|
||||
|
||||
let attrs = [
|
||||
KeyValue::new("hand", hand_name.clone()),
|
||||
KeyValue::new("success", "false"),
|
||||
];
|
||||
self.hand_runs.add(1, &attrs);
|
||||
self.hand_duration
|
||||
.record(secs, &[KeyValue::new("hand", hand_name.clone())]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,6 +445,29 @@ impl Observer for OtelObserver {
|
||||
ObserverMetric::QueueDepth(d) => {
|
||||
self.queue_depth.record(*d as u64, &[]);
|
||||
}
|
||||
ObserverMetric::HandRunDuration {
|
||||
hand_name,
|
||||
duration,
|
||||
} => {
|
||||
self.hand_duration.record(
|
||||
duration.as_secs_f64(),
|
||||
&[KeyValue::new("hand", hand_name.clone())],
|
||||
);
|
||||
}
|
||||
ObserverMetric::HandFindingsCount { hand_name, count } => {
|
||||
self.hand_findings
|
||||
.add(*count, &[KeyValue::new("hand", hand_name.clone())]);
|
||||
}
|
||||
ObserverMetric::HandSuccessRate { hand_name, success } => {
|
||||
let success_str = if *success { "true" } else { "false" };
|
||||
self.hand_runs.add(
|
||||
1,
|
||||
&[
|
||||
KeyValue::new("hand", hand_name.clone()),
|
||||
KeyValue::new("success", success_str),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -519,6 +635,41 @@ mod tests {
|
||||
obs.record_metric(&ObserverMetric::QueueDepth(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn otel_hand_events_do_not_panic() {
|
||||
let obs = test_observer();
|
||||
obs.record_event(&ObserverEvent::HandStarted {
|
||||
hand_name: "review".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn otel_hand_metrics_do_not_panic() {
|
||||
let obs = test_observer();
|
||||
obs.record_metric(&ObserverMetric::HandRunDuration {
|
||||
hand_name: "review".into(),
|
||||
duration: Duration::from_millis(1500),
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandFindingsCount {
|
||||
hand_name: "review".into(),
|
||||
count: 5,
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "review".into(),
|
||||
success: true,
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn otel_observer_creation_with_valid_endpoint_succeeds() {
|
||||
// Even though endpoint is unreachable, creation should succeed
|
||||
|
||||
@@ -16,6 +16,9 @@ pub struct PrometheusObserver {
|
||||
channel_messages: IntCounterVec,
|
||||
heartbeat_ticks: prometheus::IntCounter,
|
||||
errors: IntCounterVec,
|
||||
cache_hits: IntCounterVec,
|
||||
cache_misses: IntCounterVec,
|
||||
cache_tokens_saved: IntCounterVec,
|
||||
|
||||
// Histograms
|
||||
agent_duration: HistogramVec,
|
||||
@@ -26,6 +29,11 @@ pub struct PrometheusObserver {
|
||||
tokens_used: prometheus::IntGauge,
|
||||
active_sessions: GaugeVec,
|
||||
queue_depth: GaugeVec,
|
||||
|
||||
// Hands
|
||||
hand_runs: IntCounterVec,
|
||||
hand_duration: HistogramVec,
|
||||
hand_findings: IntCounterVec,
|
||||
}
|
||||
|
||||
impl PrometheusObserver {
|
||||
@@ -81,6 +89,27 @@ impl PrometheusObserver {
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_hits = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_hits_total", "Total response cache hits"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_misses = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_misses_total", "Total response cache misses"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_tokens_saved = IntCounterVec::new(
|
||||
prometheus::Opts::new(
|
||||
"zeroclaw_cache_tokens_saved_total",
|
||||
"Total tokens saved by response cache",
|
||||
),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let agent_duration = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"zeroclaw_agent_duration_seconds",
|
||||
@@ -128,6 +157,31 @@ impl PrometheusObserver {
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let hand_runs = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_hand_runs_total", "Total hand runs by outcome"),
|
||||
&["hand", "success"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let hand_duration = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"zeroclaw_hand_duration_seconds",
|
||||
"Hand run duration in seconds",
|
||||
)
|
||||
.buckets(vec![0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0]),
|
||||
&["hand"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let hand_findings = IntCounterVec::new(
|
||||
prometheus::Opts::new(
|
||||
"zeroclaw_hand_findings_total",
|
||||
"Total findings produced by hand runs",
|
||||
),
|
||||
&["hand"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
// Register all metrics
|
||||
registry.register(Box::new(agent_starts.clone())).ok();
|
||||
registry.register(Box::new(llm_requests.clone())).ok();
|
||||
@@ -139,12 +193,18 @@ impl PrometheusObserver {
|
||||
registry.register(Box::new(channel_messages.clone())).ok();
|
||||
registry.register(Box::new(heartbeat_ticks.clone())).ok();
|
||||
registry.register(Box::new(errors.clone())).ok();
|
||||
registry.register(Box::new(cache_hits.clone())).ok();
|
||||
registry.register(Box::new(cache_misses.clone())).ok();
|
||||
registry.register(Box::new(cache_tokens_saved.clone())).ok();
|
||||
registry.register(Box::new(agent_duration.clone())).ok();
|
||||
registry.register(Box::new(tool_duration.clone())).ok();
|
||||
registry.register(Box::new(request_latency.clone())).ok();
|
||||
registry.register(Box::new(tokens_used.clone())).ok();
|
||||
registry.register(Box::new(active_sessions.clone())).ok();
|
||||
registry.register(Box::new(queue_depth.clone())).ok();
|
||||
registry.register(Box::new(hand_runs.clone())).ok();
|
||||
registry.register(Box::new(hand_duration.clone())).ok();
|
||||
registry.register(Box::new(hand_findings.clone())).ok();
|
||||
|
||||
Self {
|
||||
registry,
|
||||
@@ -156,12 +216,18 @@ impl PrometheusObserver {
|
||||
channel_messages,
|
||||
heartbeat_ticks,
|
||||
errors,
|
||||
cache_hits,
|
||||
cache_misses,
|
||||
cache_tokens_saved,
|
||||
agent_duration,
|
||||
tool_duration,
|
||||
request_latency,
|
||||
tokens_used,
|
||||
active_sessions,
|
||||
queue_depth,
|
||||
hand_runs,
|
||||
hand_duration,
|
||||
hand_findings,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,12 +311,56 @@ impl Observer for PrometheusObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
self.heartbeat_ticks.inc();
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
self.cache_hits.with_label_values(&[cache_type]).inc();
|
||||
self.cache_tokens_saved
|
||||
.with_label_values(&[cache_type])
|
||||
.inc_by(*tokens_saved);
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
self.cache_misses.with_label_values(&[cache_type]).inc();
|
||||
}
|
||||
ObserverEvent::Error {
|
||||
component,
|
||||
message: _,
|
||||
} => {
|
||||
self.errors.with_label_values(&[component]).inc();
|
||||
}
|
||||
ObserverEvent::HandStarted { hand_name } => {
|
||||
self.hand_runs
|
||||
.with_label_values(&[hand_name.as_str(), "true"])
|
||||
.inc_by(0); // touch the series so it appears in output
|
||||
}
|
||||
ObserverEvent::HandCompleted {
|
||||
hand_name,
|
||||
duration_ms,
|
||||
findings_count,
|
||||
} => {
|
||||
self.hand_runs
|
||||
.with_label_values(&[hand_name.as_str(), "true"])
|
||||
.inc();
|
||||
self.hand_duration
|
||||
.with_label_values(&[hand_name.as_str()])
|
||||
.observe(*duration_ms as f64 / 1000.0);
|
||||
self.hand_findings
|
||||
.with_label_values(&[hand_name.as_str()])
|
||||
.inc_by(*findings_count as u64);
|
||||
}
|
||||
ObserverEvent::HandFailed {
|
||||
hand_name,
|
||||
duration_ms,
|
||||
..
|
||||
} => {
|
||||
self.hand_runs
|
||||
.with_label_values(&[hand_name.as_str(), "false"])
|
||||
.inc();
|
||||
self.hand_duration
|
||||
.with_label_values(&[hand_name.as_str()])
|
||||
.observe(*duration_ms as f64 / 1000.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,6 +382,25 @@ impl Observer for PrometheusObserver {
|
||||
.with_label_values(&[] as &[&str])
|
||||
.set(*d as f64);
|
||||
}
|
||||
ObserverMetric::HandRunDuration {
|
||||
hand_name,
|
||||
duration,
|
||||
} => {
|
||||
self.hand_duration
|
||||
.with_label_values(&[hand_name.as_str()])
|
||||
.observe(duration.as_secs_f64());
|
||||
}
|
||||
ObserverMetric::HandFindingsCount { hand_name, count } => {
|
||||
self.hand_findings
|
||||
.with_label_values(&[hand_name.as_str()])
|
||||
.inc_by(*count);
|
||||
}
|
||||
ObserverMetric::HandSuccessRate { hand_name, success } => {
|
||||
let success_str = if *success { "true" } else { "false" };
|
||||
self.hand_runs
|
||||
.with_label_values(&[hand_name.as_str(), success_str])
|
||||
.inc();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -471,6 +600,61 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_events_track_runs_and_duration() {
|
||||
let obs = PrometheusObserver::new();
|
||||
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 2000,
|
||||
findings_count: 1,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
|
||||
let output = obs.encode();
|
||||
assert!(output.contains(r#"zeroclaw_hand_runs_total{hand="review",success="true"} 2"#));
|
||||
assert!(output.contains(r#"zeroclaw_hand_runs_total{hand="review",success="false"} 1"#));
|
||||
assert!(output.contains(r#"zeroclaw_hand_findings_total{hand="review"} 4"#));
|
||||
assert!(output.contains("zeroclaw_hand_duration_seconds"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_metrics_record_duration_and_findings() {
|
||||
let obs = PrometheusObserver::new();
|
||||
|
||||
obs.record_metric(&ObserverMetric::HandRunDuration {
|
||||
hand_name: "scan".into(),
|
||||
duration: Duration::from_millis(800),
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandFindingsCount {
|
||||
hand_name: "scan".into(),
|
||||
count: 5,
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "scan".into(),
|
||||
success: true,
|
||||
});
|
||||
obs.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "scan".into(),
|
||||
success: false,
|
||||
});
|
||||
|
||||
let output = obs.encode();
|
||||
assert!(output.contains("zeroclaw_hand_duration_seconds"));
|
||||
assert!(output.contains(r#"zeroclaw_hand_findings_total{hand="scan"} 5"#));
|
||||
assert!(output.contains(r#"zeroclaw_hand_runs_total{hand="scan",success="true"} 1"#));
|
||||
assert!(output.contains(r#"zeroclaw_hand_runs_total{hand="scan",success="false"} 1"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_response_without_tokens_increments_request_only() {
|
||||
let obs = PrometheusObserver::new();
|
||||
|
||||
@@ -61,6 +61,18 @@ pub enum ObserverEvent {
|
||||
},
|
||||
/// Periodic heartbeat tick from the runtime keep-alive loop.
|
||||
HeartbeatTick,
|
||||
/// Response cache hit — an LLM call was avoided.
|
||||
CacheHit {
|
||||
/// `"hot"` (in-memory) or `"warm"` (SQLite).
|
||||
cache_type: String,
|
||||
/// Estimated tokens saved by this cache hit.
|
||||
tokens_saved: u64,
|
||||
},
|
||||
/// Response cache miss — the prompt was not found in cache.
|
||||
CacheMiss {
|
||||
/// `"response"` cache layer that was checked.
|
||||
cache_type: String,
|
||||
},
|
||||
/// An error occurred in a named component.
|
||||
Error {
|
||||
/// Subsystem where the error originated (e.g., `"provider"`, `"gateway"`).
|
||||
@@ -68,6 +80,20 @@ pub enum ObserverEvent {
|
||||
/// Human-readable error description. Must not contain secrets or tokens.
|
||||
message: String,
|
||||
},
|
||||
/// A hand has started execution.
|
||||
HandStarted { hand_name: String },
|
||||
/// A hand has completed execution successfully.
|
||||
HandCompleted {
|
||||
hand_name: String,
|
||||
duration_ms: u64,
|
||||
findings_count: usize,
|
||||
},
|
||||
/// A hand has failed during execution.
|
||||
HandFailed {
|
||||
hand_name: String,
|
||||
error: String,
|
||||
duration_ms: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// Numeric metrics emitted by the agent runtime.
|
||||
@@ -84,6 +110,15 @@ pub enum ObserverMetric {
|
||||
ActiveSessions(u64),
|
||||
/// Current depth of the inbound message queue.
|
||||
QueueDepth(u64),
|
||||
/// Duration of a single hand run.
|
||||
HandRunDuration {
|
||||
hand_name: String,
|
||||
duration: Duration,
|
||||
},
|
||||
/// Number of findings produced by a hand run.
|
||||
HandFindingsCount { hand_name: String, count: u64 },
|
||||
/// Records a hand run outcome for success-rate tracking.
|
||||
HandSuccessRate { hand_name: String, success: bool },
|
||||
}
|
||||
|
||||
/// Core observability trait for recording agent runtime telemetry.
|
||||
@@ -200,4 +235,67 @@ mod tests {
|
||||
assert!(matches!(cloned_event, ObserverEvent::ToolCall { .. }));
|
||||
assert!(matches!(cloned_metric, ObserverMetric::RequestLatency(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_events_recordable() {
|
||||
let observer = DummyObserver::default();
|
||||
|
||||
observer.record_event(&ObserverEvent::HandStarted {
|
||||
hand_name: "review".into(),
|
||||
});
|
||||
observer.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
observer.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
|
||||
assert_eq!(*observer.events.lock(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_metrics_recordable() {
|
||||
let observer = DummyObserver::default();
|
||||
|
||||
observer.record_metric(&ObserverMetric::HandRunDuration {
|
||||
hand_name: "review".into(),
|
||||
duration: Duration::from_millis(1500),
|
||||
});
|
||||
observer.record_metric(&ObserverMetric::HandFindingsCount {
|
||||
hand_name: "review".into(),
|
||||
count: 3,
|
||||
});
|
||||
observer.record_metric(&ObserverMetric::HandSuccessRate {
|
||||
hand_name: "review".into(),
|
||||
success: true,
|
||||
});
|
||||
|
||||
assert_eq!(*observer.metrics.lock(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_event_and_metric_are_cloneable() {
|
||||
let event = ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 500,
|
||||
findings_count: 2,
|
||||
};
|
||||
let metric = ObserverMetric::HandRunDuration {
|
||||
hand_name: "review".into(),
|
||||
duration: Duration::from_millis(500),
|
||||
};
|
||||
|
||||
let cloned_event = event.clone();
|
||||
let cloned_metric = metric.clone();
|
||||
|
||||
assert!(matches!(cloned_event, ObserverEvent::HandCompleted { .. }));
|
||||
assert!(matches!(
|
||||
cloned_metric,
|
||||
ObserverMetric::HandRunDuration { .. }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,4 +101,22 @@ mod tests {
|
||||
});
|
||||
obs.record_event(&ObserverEvent::TurnComplete);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verbose_hand_events_do_not_panic() {
|
||||
let obs = VerboseObserver::new();
|
||||
obs.record_event(&ObserverEvent::HandStarted {
|
||||
hand_name: "review".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandCompleted {
|
||||
hand_name: "review".into(),
|
||||
duration_ms: 1500,
|
||||
findings_count: 3,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HandFailed {
|
||||
hand_name: "review".into(),
|
||||
error: "timeout".into(),
|
||||
duration_ms: 5000,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,10 +167,12 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: secrets_config,
|
||||
browser: BrowserConfig::default(),
|
||||
browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
@@ -402,6 +404,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: 60,
|
||||
response_cache_max_entries: 5_000,
|
||||
response_cache_hot_entries: 256,
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
@@ -535,10 +538,12 @@ async fn run_quick_setup_with_home(
|
||||
microsoft365: crate::config::Microsoft365Config::default(),
|
||||
secrets: SecretsConfig::default(),
|
||||
browser: BrowserConfig::default(),
|
||||
browser_delegate: crate::tools::browser_delegate::BrowserDelegateConfig::default(),
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
@@ -3899,6 +3904,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
},
|
||||
allowed_users,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::IMessage => {
|
||||
|
||||
@@ -149,6 +149,10 @@ struct AnthropicUsage {
|
||||
input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
output_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -475,6 +479,7 @@ impl AnthropicProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: u.cache_read_input_tokens,
|
||||
});
|
||||
|
||||
for block in response.content {
|
||||
@@ -614,6 +619,7 @@ impl Provider for AnthropicProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -312,6 +312,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,6 +432,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@@ -491,6 +493,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@@ -832,6 +832,7 @@ impl BedrockProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
if let Some(output) = response.output {
|
||||
@@ -967,6 +968,7 @@ impl Provider for BedrockProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
//! Claude Code headless CLI provider.
|
||||
//!
|
||||
//! Integrates with the Claude Code CLI, spawning the `claude` binary
|
||||
//! as a subprocess for each inference request. This allows using Claude's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `claude` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `CLAUDE_CODE_PATH` environment variable.
|
||||
//!
|
||||
//! Claude Code is invoked as:
|
||||
//! ```text
|
||||
//! claude --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by Claude Code itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `claude` binary.
|
||||
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";
|
||||
|
||||
/// Default `claude` binary name (resolved via `PATH`).
|
||||
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";
|
||||
|
||||
/// Model name used to signal "use the provider's own default model".
|
||||
const DEFAULT_MODEL_MARKER: &str = "default";
|
||||
/// Claude Code requests are bounded to avoid hung subprocesses.
|
||||
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
/// Avoid leaking oversized stderr payloads.
|
||||
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
|
||||
/// The CLI does not support sampling controls; allow only baseline defaults.
|
||||
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
|
||||
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the Claude Code CLI as a subprocess.
|
||||
///
|
||||
/// Each inference request spawns a fresh `claude` process. This is the
|
||||
/// non-interactive approach: the process handles the prompt and exits.
|
||||
pub struct ClaudeCodeProvider {
|
||||
/// Path to the `claude` binary.
|
||||
binary_path: PathBuf,
|
||||
}
|
||||
|
||||
impl ClaudeCodeProvider {
|
||||
/// Create a new `ClaudeCodeProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
|
||||
/// otherwise defaults to `"claude"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
CLAUDE_CODE_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("Claude Code provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by Claude Code CLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the claude binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn Claude Code binary at {}: {err}. \
|
||||
Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin.write_all(message.as_bytes()).await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
|
||||
})?;
|
||||
stdin.shutdown().await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
|
||||
})?;
|
||||
}
|
||||
|
||||
let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Claude Code request timed out after {:?} (binary: {})",
|
||||
CLAUDE_CODE_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"Claude Code exited with non-zero status {code}. \
|
||||
Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClaudeCodeProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ClaudeCodeProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.expect("env lock poisoned")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_env_override() {
|
||||
let _guard = env_lock();
|
||||
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
|
||||
std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude");
|
||||
let provider = ClaudeCodeProvider::new();
|
||||
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
|
||||
match orig {
|
||||
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
|
||||
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_defaults_to_claude() {
|
||||
let _guard = env_lock();
|
||||
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
|
||||
std::env::remove_var(CLAUDE_CODE_PATH_ENV);
|
||||
let provider = ClaudeCodeProvider::new();
|
||||
assert_eq!(provider.binary_path, PathBuf::from("claude"));
|
||||
if let Some(v) = orig {
|
||||
std::env::set_var(CLAUDE_CODE_PATH_ENV, v);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_ignores_blank_env_override() {
|
||||
let _guard = env_lock();
|
||||
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
|
||||
std::env::set_var(CLAUDE_CODE_PATH_ENV, " ");
|
||||
let provider = ClaudeCodeProvider::new();
|
||||
assert_eq!(provider.binary_path, PathBuf::from("claude"));
|
||||
match orig {
|
||||
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
|
||||
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_forward_model_standard() {
|
||||
assert!(ClaudeCodeProvider::should_forward_model(
|
||||
"claude-sonnet-4-20250514"
|
||||
));
|
||||
assert!(ClaudeCodeProvider::should_forward_model(
|
||||
"claude-3.5-sonnet"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_not_forward_default_model() {
|
||||
assert!(!ClaudeCodeProvider::should_forward_model(
|
||||
DEFAULT_MODEL_MARKER
|
||||
));
|
||||
assert!(!ClaudeCodeProvider::should_forward_model(""));
|
||||
assert!(!ClaudeCodeProvider::should_forward_model(" "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_temperature_allows_defaults() {
|
||||
assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
|
||||
assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_temperature_rejects_custom_value() {
|
||||
let err = ClaudeCodeProvider::validate_temperature(0.2).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("temperature unsupported by Claude Code CLI"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invoke_missing_binary_returns_error() {
|
||||
let provider = ClaudeCodeProvider {
|
||||
binary_path: PathBuf::from("/nonexistent/path/to/claude"),
|
||||
};
|
||||
let result = provider.invoke_cli("hello", "default").await;
|
||||
assert!(result.is_err());
|
||||
let msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
msg.contains("Failed to spawn Claude Code binary"),
|
||||
"unexpected error message: {msg}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1193,6 +1193,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
crate::providers::traits::ProviderCapabilities {
|
||||
native_tool_calling: self.native_tool_calling,
|
||||
vision: self.supports_vision,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1514,6 +1515,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = chat_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = chat_response
|
||||
.choices
|
||||
@@ -1657,6 +1659,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@@ -353,6 +353,7 @@ impl CopilotProvider {
|
||||
let usage = api_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = api_response
|
||||
.choices
|
||||
|
||||
@@ -1128,6 +1128,7 @@ impl GeminiProvider {
|
||||
let usage = result.usage_metadata.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_token_count,
|
||||
output_tokens: u.candidates_token_count,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
let text = result
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
//! Gemini CLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the Gemini CLI, spawning the `gemini` binary
|
||||
//! as a subprocess for each inference request. This allows using Google's
|
||||
//! Gemini models via the CLI without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `gemini` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `GEMINI_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! Gemini CLI is invoked as:
|
||||
//! ```text
|
||||
//! gemini --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by the Gemini CLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `GEMINI_CLI_PATH` — override the path to the `gemini` binary (default: `"gemini"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `gemini` binary.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";

/// Default `gemini` binary name (resolved via `PATH`).
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Gemini CLI requests are bounded to avoid hung subprocesses.
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Maximum number of stderr characters surfaced in error messages; the rest
/// is clipped to avoid leaking oversized stderr payloads.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Tolerance when comparing a requested temperature against the supported
/// values, to absorb floating-point representation error.
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the Gemini CLI as a subprocess.
///
/// Each inference request spawns a fresh `gemini` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct GeminiCliProvider {
    /// Path to the `gemini` binary. Resolved once at construction time
    /// (from `GEMINI_CLI_PATH`, or the `"gemini"` default looked up in `PATH`).
    binary_path: PathBuf,
}
|
||||
|
||||
impl GeminiCliProvider {
|
||||
/// Create a new `GeminiCliProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `GEMINI_CLI_PATH` env var if set,
|
||||
/// otherwise defaults to `"gemini"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
GEMINI_CLI_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("Gemini CLI provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by Gemini CLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the gemini binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn Gemini CLI binary at {}: {err}. \
|
||||
Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin.write_all(message.as_bytes()).await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
|
||||
})?;
|
||||
stdin.shutdown().await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
|
||||
})?;
|
||||
}
|
||||
|
||||
let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Gemini CLI request timed out after {:?} (binary: {})",
|
||||
GEMINI_CLI_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"Gemini CLI exited with non-zero status {code}. \
|
||||
Check that Gemini CLI is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GeminiCliProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for GeminiCliProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialize tests that mutate `GEMINI_CLI_PATH`: environment variables
    /// are process-global, so concurrent test threads would otherwise race.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    // A set env override must win over the built-in default.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save the pre-test value so it can be restored afterwards.
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::set_var(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
        // Restore the pre-test environment.
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    // With the env var unset, the bare binary name is used (PATH lookup).
    #[test]
    fn new_defaults_to_gemini() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::remove_var(GEMINI_CLI_PATH_ENV);
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        // Only restore if there was a previous value; the var is already unset.
        if let Some(v) = orig {
            std::env::set_var(GEMINI_CLI_PATH_ENV, v);
        }
    }

    // A whitespace-only override is ignored, falling back to the default.
    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::set_var(GEMINI_CLI_PATH_ENV, "   ");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    // Ordinary model names are forwarded to the CLI as `--model`.
    #[test]
    fn should_forward_model_standard() {
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
    }

    // The "default" marker and blank names mean "use the CLI's own default".
    #[test]
    fn should_not_forward_default_model() {
        assert!(!GeminiCliProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!GeminiCliProvider::should_forward_model(""));
        assert!(!GeminiCliProvider::should_forward_model("   "));
    }

    // Only the baseline defaults (0.7 / 1.0) are accepted.
    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
        assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
    }

    // Custom values are rejected with an explicit, descriptive error.
    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by Gemini CLI"));
    }

    // Spawn failure surfaces a descriptive error rather than a panic.
    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        let provider = GeminiCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Gemini CLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
@@ -0,0 +1,326 @@
|
||||
//! KiloCLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the KiloCLI tool, spawning the `kilo` binary
|
||||
//! as a subprocess for each inference request. This allows using KiloCLI's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `kilo` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `KILO_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! KiloCLI is invoked as:
|
||||
//! ```text
|
||||
//! kilo --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by KiloCLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `KILO_CLI_PATH` — override the path to the `kilo` binary (default: `"kilo"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `kilo` binary.
pub const KILO_CLI_PATH_ENV: &str = "KILO_CLI_PATH";

/// Default `kilo` binary name (resolved via `PATH`).
const DEFAULT_KILO_CLI_BINARY: &str = "kilo";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// KiloCLI requests are bounded to avoid hung subprocesses.
const KILO_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Maximum number of stderr characters surfaced in error messages; the rest
/// is clipped to avoid leaking oversized stderr payloads.
const MAX_KILO_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const KILO_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Tolerance when comparing a requested temperature against the supported
/// values, to absorb floating-point representation error.
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the KiloCLI as a subprocess.
///
/// Each inference request spawns a fresh `kilo` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct KiloCliProvider {
    /// Path to the `kilo` binary. Resolved once at construction time
    /// (from `KILO_CLI_PATH`, or the `"kilo"` default looked up in `PATH`).
    binary_path: PathBuf,
}
|
||||
|
||||
impl KiloCliProvider {
|
||||
/// Create a new `KiloCliProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `KILO_CLI_PATH` env var if set,
|
||||
/// otherwise defaults to `"kilo"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(KILO_CLI_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_KILO_CLI_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
KILO_CLI_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("KiloCLI provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by KiloCLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_KILO_CLI_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_KILO_CLI_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the kilo binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn KiloCLI binary at {}: {err}. \
|
||||
Ensure `kilo` is installed and in PATH, or set KILO_CLI_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin
|
||||
.write_all(message.as_bytes())
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("Failed to write prompt to KiloCLI stdin: {err}"))?;
|
||||
stdin
|
||||
.shutdown()
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("Failed to finalize KiloCLI stdin stream: {err}"))?;
|
||||
}
|
||||
|
||||
let output = timeout(KILO_CLI_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"KiloCLI request timed out after {:?} (binary: {})",
|
||||
KILO_CLI_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("KiloCLI process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"KiloCLI exited with non-zero status {code}. \
|
||||
Check that KiloCLI is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("KiloCLI produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KiloCliProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for KiloCliProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialize tests that mutate `KILO_CLI_PATH`: environment variables
    /// are process-global, so concurrent test threads would otherwise race.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    // A set env override must win over the built-in default.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save the pre-test value so it can be restored afterwards.
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::set_var(KILO_CLI_PATH_ENV, "/usr/local/bin/kilo");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/kilo"));
        // Restore the pre-test environment.
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    // With the env var unset, the bare binary name is used (PATH lookup).
    #[test]
    fn new_defaults_to_kilo() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::remove_var(KILO_CLI_PATH_ENV);
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        // Only restore if there was a previous value; the var is already unset.
        if let Some(v) = orig {
            std::env::set_var(KILO_CLI_PATH_ENV, v);
        }
    }

    // A whitespace-only override is ignored, falling back to the default.
    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::set_var(KILO_CLI_PATH_ENV, "   ");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    // Ordinary model names are forwarded to the CLI as `--model`.
    #[test]
    fn should_forward_model_standard() {
        assert!(KiloCliProvider::should_forward_model("some-model"));
        assert!(KiloCliProvider::should_forward_model("gpt-4o"));
    }

    // The "default" marker and blank names mean "use the CLI's own default".
    #[test]
    fn should_not_forward_default_model() {
        assert!(!KiloCliProvider::should_forward_model(DEFAULT_MODEL_MARKER));
        assert!(!KiloCliProvider::should_forward_model(""));
        assert!(!KiloCliProvider::should_forward_model("   "));
    }

    // Only the baseline defaults (0.7 / 1.0) are accepted.
    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(KiloCliProvider::validate_temperature(0.7).is_ok());
        assert!(KiloCliProvider::validate_temperature(1.0).is_ok());
    }

    // Custom values are rejected with an explicit, descriptive error.
    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = KiloCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by KiloCLI"));
    }

    // Spawn failure surfaces a descriptive error rather than a panic.
    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        let provider = KiloCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/kilo"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn KiloCLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
+94
-1
@@ -19,9 +19,12 @@
|
||||
pub mod anthropic;
|
||||
pub mod azure_openai;
|
||||
pub mod bedrock;
|
||||
pub mod claude_code;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
pub mod gemini;
|
||||
pub mod gemini_cli;
|
||||
pub mod kilocli;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openai_codex;
|
||||
@@ -846,7 +849,9 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
|
||||
// not a single API key. Credential resolution happens inside BedrockProvider.
|
||||
"bedrock" | "aws-bedrock" => return None,
|
||||
name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"],
|
||||
name if is_doubao_alias(name) => vec!["ARK_API_KEY", "DOUBAO_API_KEY"],
|
||||
name if is_doubao_alias(name) => {
|
||||
vec!["ARK_API_KEY", "VOLCENGINE_API_KEY", "DOUBAO_API_KEY"]
|
||||
}
|
||||
name if is_qwen_alias(name) => vec!["DASHSCOPE_API_KEY"],
|
||||
name if is_zai_alias(name) => vec!["ZAI_API_KEY"],
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec!["NVIDIA_API_KEY"],
|
||||
@@ -860,6 +865,8 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
|
||||
"llamacpp" | "llama.cpp" => vec!["LLAMACPP_API_KEY"],
|
||||
"sglang" => vec!["SGLANG_API_KEY"],
|
||||
"vllm" => vec!["VLLM_API_KEY"],
|
||||
"aihubmix" => vec!["AIHUBMIX_API_KEY"],
|
||||
"siliconflow" | "silicon-flow" => vec!["SILICONFLOW_API_KEY"],
|
||||
"osaurus" => vec!["OSAURUS_API_KEY"],
|
||||
"telnyx" => vec!["TELNYX_API_KEY"],
|
||||
"azure_openai" | "azure-openai" | "azure" => vec!["AZURE_OPENAI_API_KEY"],
|
||||
@@ -1247,6 +1254,9 @@ fn create_provider_with_url_and_options(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(copilot::CopilotProvider::new(key))),
|
||||
"claude-code" => Ok(Box::new(claude_code::ClaudeCodeProvider::new())),
|
||||
"gemini-cli" => Ok(Box::new(gemini_cli::GeminiCliProvider::new())),
|
||||
"kilocli" | "kilo" => Ok(Box::new(kilocli::KiloCliProvider::new())),
|
||||
"lmstudio" | "lm-studio" => {
|
||||
let lm_studio_key = key
|
||||
.map(str::trim)
|
||||
@@ -1898,6 +1908,24 @@ pub fn list_providers() -> Vec<ProviderInfo> {
|
||||
aliases: &["github-copilot"],
|
||||
local: false,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "claude-code",
|
||||
display_name: "Claude Code (CLI)",
|
||||
aliases: &[],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "gemini-cli",
|
||||
display_name: "Gemini CLI",
|
||||
aliases: &[],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "kilocli",
|
||||
display_name: "KiloCLI",
|
||||
aliases: &["kilo"],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "lmstudio",
|
||||
display_name: "LM Studio",
|
||||
@@ -2607,6 +2635,52 @@ mod tests {
|
||||
assert_eq!(resolved, Some("osaurus-test-key".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_provider_credential_volcengine_env() {
|
||||
let _env_lock = env_lock();
|
||||
let _guard = EnvGuard::set("VOLCENGINE_API_KEY", Some("volc-test-key"));
|
||||
let resolved = resolve_provider_credential("volcengine", None);
|
||||
assert_eq!(resolved, Some("volc-test-key".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_provider_credential_aihubmix_env() {
|
||||
let _env_lock = env_lock();
|
||||
let _guard = EnvGuard::set("AIHUBMIX_API_KEY", Some("aihubmix-test-key"));
|
||||
let resolved = resolve_provider_credential("aihubmix", None);
|
||||
assert_eq!(resolved, Some("aihubmix-test-key".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_provider_credential_siliconflow_env() {
|
||||
let _env_lock = env_lock();
|
||||
let _guard = EnvGuard::set("SILICONFLOW_API_KEY", Some("sf-test-key"));
|
||||
let resolved = resolve_provider_credential("siliconflow", None);
|
||||
assert_eq!(resolved, Some("sf-test-key".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_aihubmix() {
|
||||
assert!(create_provider("aihubmix", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_siliconflow() {
|
||||
assert!(create_provider("siliconflow", Some("key")).is_ok());
|
||||
assert!(create_provider("silicon-flow", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_codex_oauth_aliases() {
|
||||
let options = ProviderRuntimeOptions::default();
|
||||
for alias in &["codex", "openai-codex", "openai_codex"] {
|
||||
assert!(
|
||||
create_provider_with_options(alias, None, &options).is_ok(),
|
||||
"codex alias '{alias}' should produce a provider"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Extended ecosystem ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
@@ -2670,6 +2744,22 @@ mod tests {
|
||||
assert!(create_provider("github-copilot", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_claude_code() {
|
||||
assert!(create_provider("claude-code", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_gemini_cli() {
|
||||
assert!(create_provider("gemini-cli", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_kilocli() {
|
||||
assert!(create_provider("kilocli", None).is_ok());
|
||||
assert!(create_provider("kilo", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_nvidia() {
|
||||
assert!(create_provider("nvidia", Some("nvapi-test")).is_ok());
|
||||
@@ -3003,6 +3093,9 @@ mod tests {
|
||||
"perplexity",
|
||||
"cohere",
|
||||
"copilot",
|
||||
"claude-code",
|
||||
"gemini-cli",
|
||||
"kilocli",
|
||||
"nvidia",
|
||||
"astrai",
|
||||
"ovhcloud",
|
||||
|
||||
@@ -632,6 +632,7 @@ impl Provider for OllamaProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -764,6 +765,7 @@ impl Provider for OllamaProvider {
|
||||
Some(TokenUsage {
|
||||
input_tokens: response.prompt_eval_count,
|
||||
output_tokens: response.eval_count,
|
||||
cached_input_tokens: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
||||
+172
-3
@@ -135,6 +135,14 @@ struct UsageInfo {
|
||||
prompt_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
completion_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
prompt_tokens_details: Option<PromptTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PromptTokensDetails {
|
||||
#[serde(default)]
|
||||
cached_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -178,6 +186,38 @@ impl OpenAiProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjust temperature for models that have specific requirements.
|
||||
/// Some OpenAI models (like gpt-5-mini, o1, o3, etc) only accept temperature=1.0.
|
||||
fn adjust_temperature_for_model(model: &str, requested_temperature: f64) -> f64 {
|
||||
// Models that require temperature=1.0
|
||||
let requires_1_0 = matches!(
|
||||
model,
|
||||
"gpt-5"
|
||||
| "gpt-5-2025-08-07"
|
||||
| "gpt-5-mini"
|
||||
| "gpt-5-mini-2025-08-07"
|
||||
| "gpt-5-nano"
|
||||
| "gpt-5-nano-2025-08-07"
|
||||
| "gpt-5.1-chat-latest"
|
||||
| "gpt-5.2-chat-latest"
|
||||
| "gpt-5.3-chat-latest"
|
||||
| "o1"
|
||||
| "o1-2024-12-17"
|
||||
| "o3"
|
||||
| "o3-2025-04-16"
|
||||
| "o3-mini"
|
||||
| "o3-mini-2025-01-31"
|
||||
| "o4-mini"
|
||||
| "o4-mini-2025-04-16"
|
||||
);
|
||||
|
||||
if requires_1_0 {
|
||||
1.0
|
||||
} else {
|
||||
requested_temperature
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
@@ -308,6 +348,8 @@ impl Provider for OpenAiProvider {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
let adjusted_temperature = Self::adjust_temperature_for_model(model, temperature);
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
@@ -325,7 +367,7 @@ impl Provider for OpenAiProvider {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
temperature: adjusted_temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
@@ -360,11 +402,13 @@ impl Provider for OpenAiProvider {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
let adjusted_temperature = Self::adjust_temperature_for_model(model, temperature);
|
||||
|
||||
let tools = Self::convert_tools(request.tools);
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: Self::convert_messages(request.messages),
|
||||
temperature,
|
||||
temperature: adjusted_temperature,
|
||||
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools,
|
||||
};
|
||||
@@ -385,6 +429,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@@ -412,6 +457,8 @@ impl Provider for OpenAiProvider {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
let adjusted_temperature = Self::adjust_temperature_for_model(model, temperature);
|
||||
|
||||
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -427,7 +474,7 @@ impl Provider for OpenAiProvider {
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: Self::convert_messages(messages),
|
||||
temperature,
|
||||
temperature: adjusted_temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
@@ -448,6 +495,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@@ -828,4 +876,125 @@ mod tests {
|
||||
assert!(json.contains("reasoning_content"));
|
||||
assert!(json.contains("thinking..."));
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// Temperature adjustment tests
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_for_o1_models() {
|
||||
assert_eq!(OpenAiProvider::adjust_temperature_for_model("o1", 0.7), 1.0);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o1-2024-12-17", 0.5),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_for_o3_models() {
|
||||
assert_eq!(OpenAiProvider::adjust_temperature_for_model("o3", 0.7), 1.0);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o3-2025-04-16", 0.5),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o3-mini", 0.3),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o3-mini-2025-01-31", 0.8),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_for_o4_models() {
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o4-mini", 0.7),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("o4-mini-2025-04-16", 0.5),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_for_gpt5_models() {
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5", 0.7),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5-2025-08-07", 0.5),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5-mini", 0.3),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5-mini-2025-08-07", 0.8),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5-nano", 0.6),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5-nano-2025-08-07", 0.4),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_for_gpt5_chat_latest_models() {
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5.1-chat-latest", 0.7),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5.2-chat-latest", 0.5),
|
||||
1.0
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-5.3-chat-latest", 0.3),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_preserves_for_standard_models() {
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-4o", 0.7),
|
||||
0.7
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-4-turbo", 0.5),
|
||||
0.5
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-3.5-turbo", 0.3),
|
||||
0.3
|
||||
);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-4", 1.0),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_temperature_handles_edge_cases() {
|
||||
// Temperature 0.0 should be preserved for standard models
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-4o", 0.0),
|
||||
0.0
|
||||
);
|
||||
// Temperature 1.0 should be preserved for all models
|
||||
assert_eq!(OpenAiProvider::adjust_temperature_for_model("o1", 1.0), 1.0);
|
||||
assert_eq!(
|
||||
OpenAiProvider::adjust_temperature_for_model("gpt-4o", 1.0),
|
||||
1.0
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,6 +473,75 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
fn append_utf8_stream_chunk(
|
||||
body: &mut String,
|
||||
pending: &mut Vec<u8>,
|
||||
chunk: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
if pending.is_empty() {
|
||||
if let Ok(text) = std::str::from_utf8(chunk) {
|
||||
body.push_str(text);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
if !chunk.is_empty() {
|
||||
pending.extend_from_slice(chunk);
|
||||
}
|
||||
if pending.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match std::str::from_utf8(pending) {
|
||||
Ok(text) => {
|
||||
body.push_str(text);
|
||||
pending.clear();
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let valid_up_to = err.valid_up_to();
|
||||
if valid_up_to > 0 {
|
||||
// SAFETY: `valid_up_to` always points to the end of a valid UTF-8 prefix.
|
||||
let prefix = std::str::from_utf8(&pending[..valid_up_to])
|
||||
.expect("valid UTF-8 prefix from Utf8Error::valid_up_to");
|
||||
body.push_str(prefix);
|
||||
pending.drain(..valid_up_to);
|
||||
}
|
||||
|
||||
if err.error_len().is_some() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"OpenAI Codex response contained invalid UTF-8: {err}"
|
||||
));
|
||||
}
|
||||
|
||||
// `error_len == None` means we have a valid prefix and an incomplete
|
||||
// multi-byte sequence at the end; keep it buffered until next chunk.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_utf8_stream_chunks<'a, I>(chunks: I) -> anyhow::Result<String>
|
||||
where
|
||||
I: IntoIterator<Item = &'a [u8]>,
|
||||
{
|
||||
let mut body = String::new();
|
||||
let mut pending = Vec::new();
|
||||
|
||||
for chunk in chunks {
|
||||
append_utf8_stream_chunk(&mut body, &mut pending, chunk)?;
|
||||
}
|
||||
|
||||
if !pending.is_empty() {
|
||||
let err = std::str::from_utf8(&pending).expect_err("pending bytes should be invalid UTF-8");
|
||||
return Err(anyhow::anyhow!(
|
||||
"OpenAI Codex response ended with incomplete UTF-8: {err}"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
/// Read the response body incrementally via `bytes_stream()` to avoid
|
||||
/// buffering the entire SSE payload in memory. The previous implementation
|
||||
/// used `response.text().await?` which holds the HTTP connection open until
|
||||
@@ -481,15 +550,21 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
|
||||
/// reported in #3544.
|
||||
async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<String> {
|
||||
let mut body = String::new();
|
||||
let mut pending_utf8 = Vec::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let bytes = chunk
|
||||
.map_err(|err| anyhow::anyhow!("error reading OpenAI Codex response stream: {err}"))?;
|
||||
let text = std::str::from_utf8(&bytes).map_err(|err| {
|
||||
anyhow::anyhow!("OpenAI Codex response contained invalid UTF-8: {err}")
|
||||
})?;
|
||||
body.push_str(text);
|
||||
append_utf8_stream_chunk(&mut body, &mut pending_utf8, &bytes)?;
|
||||
}
|
||||
|
||||
if !pending_utf8.is_empty() {
|
||||
let err = std::str::from_utf8(&pending_utf8)
|
||||
.expect_err("pending bytes should be invalid UTF-8 at end of stream");
|
||||
return Err(anyhow::anyhow!(
|
||||
"OpenAI Codex response ended with incomplete UTF-8: {err}"
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(text) = parse_sse_text(&body)? {
|
||||
@@ -640,6 +715,7 @@ impl Provider for OpenAiCodexProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -900,6 +976,21 @@ data: [DONE]
|
||||
assert_eq!(parse_sse_text(payload).unwrap().as_deref(), Some("Done"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_utf8_stream_chunks_handles_multibyte_split_across_chunks() {
|
||||
let payload =
|
||||
"data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello 世\"}\n\ndata: [DONE]\n";
|
||||
let bytes = payload.as_bytes();
|
||||
let split_at = payload.find('世').unwrap() + 1;
|
||||
|
||||
let decoded = decode_utf8_stream_chunks([&bytes[..split_at], &bytes[split_at..]]).unwrap();
|
||||
assert_eq!(decoded, payload);
|
||||
assert_eq!(
|
||||
parse_sse_text(&decoded).unwrap().as_deref(),
|
||||
Some("Hello 世")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_responses_input_maps_content_types_by_role() {
|
||||
let messages = vec![
|
||||
|
||||
@@ -306,6 +306,7 @@ impl Provider for OpenRouterProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,6 +464,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@@ -554,6 +556,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@@ -15,7 +15,7 @@ use std::time::Duration;
|
||||
// immediately — avoiding wasted latency on errors that cannot self-heal.
|
||||
|
||||
/// Check if an error is non-retryable (client errors that won't resolve with retries).
|
||||
fn is_non_retryable(err: &anyhow::Error) -> bool {
|
||||
pub fn is_non_retryable(err: &anyhow::Error) -> bool {
|
||||
if is_context_window_exceeded(err) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ pub struct ToolCall {
|
||||
pub struct TokenUsage {
|
||||
pub input_tokens: Option<u64>,
|
||||
pub output_tokens: Option<u64>,
|
||||
/// Tokens served from the provider's prompt cache (Anthropic `cache_read_input_tokens`,
|
||||
/// OpenAI `prompt_tokens_details.cached_tokens`).
|
||||
pub cached_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// An LLM response that may contain text, tool calls, or both.
|
||||
@@ -233,6 +236,9 @@ pub struct ProviderCapabilities {
|
||||
pub native_tool_calling: bool,
|
||||
/// Whether the provider supports vision / image inputs.
|
||||
pub vision: bool,
|
||||
/// Whether the provider supports prompt caching (Anthropic cache_control,
|
||||
/// OpenAI automatic prompt caching).
|
||||
pub prompt_caching: bool,
|
||||
}
|
||||
|
||||
/// Provider-specific tool payload formats.
|
||||
@@ -498,6 +504,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,6 +575,7 @@ mod tests {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(100),
|
||||
output_tokens: Some(50),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
};
|
||||
@@ -613,14 +621,17 @@ mod tests {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
|
||||
+355
-8
@@ -1,15 +1,22 @@
|
||||
//! Audit logging for security events
|
||||
//!
|
||||
//! Each audit entry is chained via a Merkle hash: `entry_hash = SHA-256(prev_hash || canonical_json)`.
|
||||
//! This makes the trail tamper-evident — modifying any entry invalidates all subsequent hashes.
|
||||
|
||||
use crate::config::AuditConfig;
|
||||
use anyhow::Result;
|
||||
use anyhow::{bail, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Well-known seed for the genesis entry's `prev_hash`.
|
||||
const GENESIS_PREV_HASH: &str = "0000000000000000000000000000000000000000000000000000000000000000";
|
||||
|
||||
/// Audit event types
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -57,7 +64,7 @@ pub struct SecurityContext {
|
||||
pub sandbox_backend: Option<String>,
|
||||
}
|
||||
|
||||
/// Complete audit event
|
||||
/// Complete audit event with Merkle hash-chain fields.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuditEvent {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
@@ -67,6 +74,16 @@ pub struct AuditEvent {
|
||||
pub action: Option<Action>,
|
||||
pub result: Option<ExecutionResult>,
|
||||
pub security: SecurityContext,
|
||||
|
||||
/// Monotonically increasing sequence number.
|
||||
#[serde(default)]
|
||||
pub sequence: u64,
|
||||
/// SHA-256 hash of the previous entry (genesis uses [`GENESIS_PREV_HASH`]).
|
||||
#[serde(default)]
|
||||
pub prev_hash: String,
|
||||
/// SHA-256 hash of (`prev_hash` || canonical JSON of this entry's content fields).
|
||||
#[serde(default)]
|
||||
pub entry_hash: String,
|
||||
}
|
||||
|
||||
impl AuditEvent {
|
||||
@@ -84,6 +101,9 @@ impl AuditEvent {
|
||||
rate_limit_remaining: None,
|
||||
sandbox_backend: None,
|
||||
},
|
||||
sequence: 0,
|
||||
prev_hash: String::new(),
|
||||
entry_hash: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,11 +163,42 @@ impl AuditEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the SHA-256 entry hash: `H(prev_hash || content_json)`.
|
||||
///
|
||||
/// `content_json` is the canonical JSON of the event *without* the chain fields
|
||||
/// (`sequence`, `prev_hash`, `entry_hash`), so the hash covers only the payload.
|
||||
fn compute_entry_hash(prev_hash: &str, event: &AuditEvent) -> String {
|
||||
// Build a canonical representation of the content fields only.
|
||||
let content = serde_json::json!({
|
||||
"timestamp": event.timestamp,
|
||||
"event_id": event.event_id,
|
||||
"event_type": event.event_type,
|
||||
"actor": event.actor,
|
||||
"action": event.action,
|
||||
"result": event.result,
|
||||
"security": event.security,
|
||||
"sequence": event.sequence,
|
||||
});
|
||||
let content_json = serde_json::to_string(&content).expect("serialize canonical content");
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(prev_hash.as_bytes());
|
||||
hasher.update(content_json.as_bytes());
|
||||
hex::encode(hasher.finalize())
|
||||
}
|
||||
|
||||
/// Internal chain state tracked across writes.
|
||||
struct ChainState {
|
||||
prev_hash: String,
|
||||
sequence: u64,
|
||||
}
|
||||
|
||||
/// Audit logger
|
||||
pub struct AuditLogger {
|
||||
log_path: PathBuf,
|
||||
config: AuditConfig,
|
||||
buffer: Mutex<Vec<AuditEvent>>,
|
||||
chain: Mutex<ChainState>,
|
||||
}
|
||||
|
||||
/// Structured command execution details for audit logging.
|
||||
@@ -163,13 +214,18 @@ pub struct CommandExecutionLog<'a> {
|
||||
}
|
||||
|
||||
impl AuditLogger {
|
||||
/// Create a new audit logger
|
||||
/// Create a new audit logger.
|
||||
///
|
||||
/// If the log file already exists, the chain state is recovered from the last
|
||||
/// entry so that new writes continue the existing hash chain.
|
||||
pub fn new(config: AuditConfig, zeroclaw_dir: PathBuf) -> Result<Self> {
|
||||
let log_path = zeroclaw_dir.join(&config.log_path);
|
||||
let chain_state = recover_chain_state(&log_path);
|
||||
Ok(Self {
|
||||
log_path,
|
||||
config,
|
||||
buffer: Mutex::new(Vec::new()),
|
||||
chain: Mutex::new(chain_state),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,8 +238,19 @@ impl AuditLogger {
|
||||
// Check log size and rotate if needed
|
||||
self.rotate_if_needed()?;
|
||||
|
||||
// Populate chain fields under the lock
|
||||
let mut chained = event.clone();
|
||||
{
|
||||
let mut state = self.chain.lock();
|
||||
chained.sequence = state.sequence;
|
||||
chained.prev_hash = state.prev_hash.clone();
|
||||
chained.entry_hash = compute_entry_hash(&state.prev_hash, &chained);
|
||||
state.prev_hash = chained.entry_hash.clone();
|
||||
state.sequence += 1;
|
||||
}
|
||||
|
||||
// Serialize and write
|
||||
let line = serde_json::to_string(event)?;
|
||||
let line = serde_json::to_string(&chained)?;
|
||||
let mut file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
@@ -258,6 +325,102 @@ impl AuditLogger {
|
||||
}
|
||||
}
|
||||
|
||||
/// Recover chain state from an existing log file.
|
||||
///
|
||||
/// Returns the genesis state if the file does not exist or is empty.
|
||||
fn recover_chain_state(log_path: &Path) -> ChainState {
|
||||
let file = match std::fs::File::open(log_path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => {
|
||||
return ChainState {
|
||||
prev_hash: GENESIS_PREV_HASH.to_string(),
|
||||
sequence: 0,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(file);
|
||||
let mut last_entry: Option<AuditEvent> = None;
|
||||
for l in reader.lines().map_while(Result::ok) {
|
||||
if let Ok(entry) = serde_json::from_str::<AuditEvent>(&l) {
|
||||
last_entry = Some(entry);
|
||||
}
|
||||
}
|
||||
|
||||
match last_entry {
|
||||
Some(entry) => ChainState {
|
||||
prev_hash: entry.entry_hash,
|
||||
sequence: entry.sequence + 1,
|
||||
},
|
||||
None => ChainState {
|
||||
prev_hash: GENESIS_PREV_HASH.to_string(),
|
||||
sequence: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the integrity of an audit log's Merkle hash chain.
|
||||
///
|
||||
/// Reads every entry from the log file and checks:
|
||||
/// - Each `entry_hash` matches the recomputed `SHA-256(prev_hash || content)`.
|
||||
/// - `prev_hash` links to the preceding entry (or the genesis seed for the first).
|
||||
/// - Sequence numbers are contiguous starting from 0.
|
||||
///
|
||||
/// Returns `Ok(entry_count)` on success, or an error describing the first violation.
|
||||
pub fn verify_chain(log_path: &Path) -> Result<u64> {
|
||||
let file = std::fs::File::open(log_path)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let mut expected_prev_hash = GENESIS_PREV_HASH.to_string();
|
||||
let mut expected_sequence: u64 = 0;
|
||||
|
||||
for (line_idx, line) in reader.lines().enumerate() {
|
||||
let line = line?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let entry: AuditEvent = serde_json::from_str(&line)?;
|
||||
|
||||
// Check sequence continuity
|
||||
if entry.sequence != expected_sequence {
|
||||
bail!(
|
||||
"sequence gap at line {}: expected {}, got {}",
|
||||
line_idx + 1,
|
||||
expected_sequence,
|
||||
entry.sequence
|
||||
);
|
||||
}
|
||||
|
||||
// Check prev_hash linkage
|
||||
if entry.prev_hash != expected_prev_hash {
|
||||
bail!(
|
||||
"prev_hash mismatch at line {} (sequence {}): expected {}, got {}",
|
||||
line_idx + 1,
|
||||
entry.sequence,
|
||||
expected_prev_hash,
|
||||
entry.prev_hash
|
||||
);
|
||||
}
|
||||
|
||||
// Recompute and verify entry_hash
|
||||
let recomputed = compute_entry_hash(&entry.prev_hash, &entry);
|
||||
if entry.entry_hash != recomputed {
|
||||
bail!(
|
||||
"entry_hash mismatch at line {} (sequence {}): expected {}, got {}",
|
||||
line_idx + 1,
|
||||
entry.sequence,
|
||||
recomputed,
|
||||
entry.entry_hash
|
||||
);
|
||||
}
|
||||
|
||||
expected_prev_hash = entry.entry_hash.clone();
|
||||
expected_sequence += 1;
|
||||
}
|
||||
|
||||
Ok(expected_sequence)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -275,14 +438,14 @@ mod tests {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_actor(
|
||||
"telegram".to_string(),
|
||||
Some("123".to_string()),
|
||||
Some("@alice".to_string()),
|
||||
Some("@zeroclaw_user".to_string()),
|
||||
);
|
||||
|
||||
assert!(event.actor.is_some());
|
||||
let actor = event.actor.as_ref().unwrap();
|
||||
assert_eq!(actor.channel, "telegram");
|
||||
assert_eq!(actor.user_id, Some("123".to_string()));
|
||||
assert_eq!(actor.username, Some("@alice".to_string()));
|
||||
assert_eq!(actor.username, Some("@zeroclaw_user".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -420,4 +583,188 @@ mod tests {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Merkle hash-chain tests ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn merkle_chain_genesis_uses_well_known_seed() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
|
||||
let event = AuditEvent::new(AuditEventType::SecurityEvent);
|
||||
logger.log(&event)?;
|
||||
|
||||
let log_path = tmp.path().join("audit.log");
|
||||
let content = std::fs::read_to_string(&log_path)?;
|
||||
let parsed: AuditEvent = serde_json::from_str(content.trim())?;
|
||||
|
||||
assert_eq!(parsed.sequence, 0);
|
||||
assert_eq!(parsed.prev_hash, GENESIS_PREV_HASH);
|
||||
assert!(!parsed.entry_hash.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merkle_chain_multiple_entries_verify() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
|
||||
// Write several events
|
||||
for i in 0..5 {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
format!("cmd-{}", i),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
logger.log(&event)?;
|
||||
}
|
||||
|
||||
let log_path = tmp.path().join("audit.log");
|
||||
let count = verify_chain(&log_path)?;
|
||||
assert_eq!(count, 5);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merkle_chain_detects_tampered_entry() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
|
||||
for i in 0..3 {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
format!("cmd-{}", i),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
logger.log(&event)?;
|
||||
}
|
||||
|
||||
// Tamper with the second entry (change the command text)
|
||||
let log_path = tmp.path().join("audit.log");
|
||||
let content = std::fs::read_to_string(&log_path)?;
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
assert_eq!(lines.len(), 3);
|
||||
|
||||
let mut entry: serde_json::Value = serde_json::from_str(lines[1])?;
|
||||
entry["action"]["command"] = serde_json::Value::String("TAMPERED".to_string());
|
||||
let tampered_line = serde_json::to_string(&entry)?;
|
||||
|
||||
let tampered_content = format!("{}\n{}\n{}\n", lines[0], tampered_line, lines[2]);
|
||||
std::fs::write(&log_path, tampered_content)?;
|
||||
|
||||
// Verification must fail
|
||||
let result = verify_chain(&log_path);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("entry_hash mismatch"),
|
||||
"expected entry_hash mismatch, got: {}",
|
||||
err_msg
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merkle_chain_detects_sequence_gap() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
|
||||
for i in 0..3 {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
format!("cmd-{}", i),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
logger.log(&event)?;
|
||||
}
|
||||
|
||||
// Remove the second entry to create a sequence gap
|
||||
let log_path = tmp.path().join("audit.log");
|
||||
let content = std::fs::read_to_string(&log_path)?;
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let gapped_content = format!("{}\n{}\n", lines[0], lines[2]);
|
||||
std::fs::write(&log_path, gapped_content)?;
|
||||
|
||||
let result = verify_chain(&log_path);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("sequence gap"),
|
||||
"expected sequence gap, got: {}",
|
||||
err_msg
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merkle_chain_recovery_continues_after_restart() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let log_path = tmp.path().join("audit.log");
|
||||
|
||||
// First logger writes 2 entries
|
||||
{
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
for i in 0..2 {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
format!("batch1-{}", i),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
logger.log(&event)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Second logger (simulating restart) continues the chain
|
||||
{
|
||||
let config = AuditConfig {
|
||||
enabled: true,
|
||||
max_size_mb: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let logger = AuditLogger::new(config, tmp.path().to_path_buf())?;
|
||||
for i in 0..2 {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
format!("batch2-{}", i),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
logger.log(&event)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Full chain should verify (4 entries, sequences 0..3)
|
||||
let count = verify_chain(&log_path)?;
|
||||
assert_eq!(count, 4);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,757 @@
|
||||
//! Browser delegation tool.
|
||||
//!
|
||||
//! Delegates browser-based tasks to a browser-capable CLI subprocess (e.g.
|
||||
//! Claude Code with `claude-in-chrome` MCP tools) for interacting with
|
||||
//! corporate web applications (Teams, Outlook, Jira, Confluence) that lack
|
||||
//! direct API access.
|
||||
//!
|
||||
//! The tool spawns the configured CLI binary in non-interactive mode, passing
|
||||
//! a structured prompt that instructs it to use browser automation. A
|
||||
//! persistent Chrome profile can be configured so SSO sessions survive across
|
||||
//! invocations.
|
||||
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Configuration for browser delegation (`[browser_delegate]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct BrowserDelegateConfig {
|
||||
/// Enable browser delegation tool.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// CLI binary to use for browser tasks (default: `"claude"`).
|
||||
#[serde(default = "default_browser_cli")]
|
||||
pub cli_binary: String,
|
||||
/// Chrome profile directory for persistent SSO sessions.
|
||||
#[serde(default)]
|
||||
pub chrome_profile_dir: String,
|
||||
/// Allowed domains for browser navigation (empty = allow all non-blocked).
|
||||
#[serde(default)]
|
||||
pub allowed_domains: Vec<String>,
|
||||
/// Blocked domains for browser navigation.
|
||||
#[serde(default)]
|
||||
pub blocked_domains: Vec<String>,
|
||||
/// Task timeout in seconds.
|
||||
#[serde(default = "default_browser_task_timeout")]
|
||||
pub task_timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// Default CLI binary for browser delegation.
|
||||
fn default_browser_cli() -> String {
|
||||
"claude".into()
|
||||
}
|
||||
|
||||
/// Default task timeout in seconds (2 minutes).
|
||||
fn default_browser_task_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
impl Default for BrowserDelegateConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
cli_binary: default_browser_cli(),
|
||||
chrome_profile_dir: String::new(),
|
||||
allowed_domains: Vec::new(),
|
||||
blocked_domains: Vec::new(),
|
||||
task_timeout_secs: default_browser_task_timeout(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool that delegates browser-based tasks to a browser-capable CLI subprocess.
|
||||
pub struct BrowserDelegateTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
config: BrowserDelegateConfig,
|
||||
}
|
||||
|
||||
impl BrowserDelegateTool {
|
||||
/// Create a new `BrowserDelegateTool` with the given security policy and config.
|
||||
pub fn new(security: Arc<SecurityPolicy>, config: BrowserDelegateConfig) -> Self {
|
||||
Self { security, config }
|
||||
}
|
||||
|
||||
/// Build the CLI command for a browser task.
|
||||
///
|
||||
/// Constructs a `tokio::process::Command` with the configured CLI binary,
|
||||
/// `--print` flag for non-interactive mode, and optional Chrome profile env.
|
||||
fn build_command(&self, task: &str, url: Option<&str>) -> tokio::process::Command {
|
||||
let mut cmd = tokio::process::Command::new(&self.config.cli_binary);
|
||||
|
||||
// Claude Code non-interactive mode
|
||||
cmd.arg("--print");
|
||||
|
||||
let prompt = if let Some(url) = url {
|
||||
format!(
|
||||
"Use your browser tools to navigate to {} and perform the following task: {}",
|
||||
url, task
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Use your browser tools to perform the following task: {}",
|
||||
task
|
||||
)
|
||||
};
|
||||
|
||||
cmd.arg(&prompt);
|
||||
|
||||
// Set Chrome profile if configured for persistent SSO sessions
|
||||
if !self.config.chrome_profile_dir.is_empty() {
|
||||
cmd.env("CHROME_USER_DATA_DIR", &self.config.chrome_profile_dir);
|
||||
}
|
||||
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
cmd
|
||||
}
|
||||
|
||||
/// Extract URLs from free-form text and validate each against domain policy.
|
||||
///
|
||||
/// Prevents policy bypass by embedding blocked URLs in the `task` text,
|
||||
/// which is forwarded verbatim to the browser CLI subprocess.
|
||||
fn validate_task_urls(&self, task: &str) -> anyhow::Result<()> {
|
||||
let url_re = Regex::new(r#"https?://[^\s\)\]\},\"'`<>]+"#).expect("valid regex");
|
||||
for m in url_re.find_iter(task) {
|
||||
self.validate_url(m.as_str())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate URL against allowed/blocked domain lists and scheme restrictions.
|
||||
///
|
||||
/// Only `http` and `https` schemes are permitted. Blocked domains take
|
||||
/// precedence over allowed domains when both lists contain the same entry.
|
||||
fn validate_url(&self, url: &str) -> anyhow::Result<()> {
|
||||
let parsed = url
|
||||
.parse::<reqwest::Url>()
|
||||
.map_err(|e| anyhow::anyhow!("invalid URL '{}': {}", url, e))?;
|
||||
|
||||
// Only allow http/https schemes
|
||||
let scheme = parsed.scheme();
|
||||
if scheme != "http" && scheme != "https" {
|
||||
anyhow::bail!("unsupported URL scheme: {}", scheme);
|
||||
}
|
||||
|
||||
let domain = parsed.host_str().unwrap_or("").to_string();
|
||||
|
||||
if domain.is_empty() {
|
||||
anyhow::bail!("URL has no host: {}", url);
|
||||
}
|
||||
|
||||
// Check blocked domains first (deny takes precedence)
|
||||
for blocked in &self.config.blocked_domains {
|
||||
if domain_matches(&domain, blocked) {
|
||||
anyhow::bail!("domain '{}' is blocked by browser_delegate policy", domain);
|
||||
}
|
||||
}
|
||||
|
||||
// If allowed_domains is non-empty, it acts as an allowlist
|
||||
if !self.config.allowed_domains.is_empty() {
|
||||
let allowed = self
|
||||
.config
|
||||
.allowed_domains
|
||||
.iter()
|
||||
.any(|d| domain_matches(&domain, d));
|
||||
if !allowed {
|
||||
anyhow::bail!(
|
||||
"domain '{}' is not in browser_delegate allowed_domains",
|
||||
domain
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether `domain` matches a pattern, case-insensitively: either an
/// exact match, or `domain` is a subdomain of the pattern (suffix match on
/// ".pattern", so "notexample.com" does not match "example.com").
fn domain_matches(domain: &str, pattern: &str) -> bool {
    let host = domain.to_lowercase();
    let pat = pattern.to_lowercase();
    if host == pat {
        return true;
    }
    let dotted_suffix = format!(".{}", pat);
    host.ends_with(&dotted_suffix)
}
|
||||
|
||||
/// Maximum number of stderr characters captured from the subprocess.
/// Applied via `chars().take(...)` in `execute`, so this is a count of
/// Unicode scalar values, not bytes.
const MAX_STDERR_CHARS: usize = 512;

/// Supported values for the `extract_format` parameter of `browser_delegate`.
const VALID_EXTRACT_FORMATS: &[&str] = &["text", "json", "summary"];
|
||||
|
||||
#[async_trait]
impl Tool for BrowserDelegateTool {
    fn name(&self) -> &str {
        "browser_delegate"
    }

    fn description(&self) -> &str {
        "Delegate browser-based tasks to a browser-capable CLI for interacting with web applications like Teams, Outlook, Jira, Confluence"
    }

    /// JSON Schema for tool arguments: `task` (required), optional `url`,
    /// and optional `extract_format` limited to the VALID_EXTRACT_FORMATS set.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "task": {
                    "type": "string",
                    "description": "Description of the browser task to perform"
                },
                "url": {
                    "type": "string",
                    "description": "Optional URL to navigate to before performing the task"
                },
                "extract_format": {
                    "type": "string",
                    "enum": ["text", "json", "summary"],
                    "description": "Desired output format (default: text)"
                }
            },
            "required": ["task"]
        })
    }

    /// Run a browser task: gate on security policy, validate inputs and all
    /// URLs (explicit and embedded in the task text), then spawn the browser
    /// CLI subprocess with a timeout and map its outcome to a `ToolResult`.
    ///
    /// NOTE: the ordering of checks below (security gate → rate limit →
    /// task presence → url → embedded urls → extract_format) is relied on
    /// by callers/tests that match on the exact error string; keep it stable.
    ///
    /// NOTE(review): `config.enabled` is not consulted here — presumably the
    /// tool is only registered when enabled; confirm at the registration site.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Security gate
        if !self.security.can_act() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("browser_delegate tool is denied by security policy".into()),
            });
        }
        // Rate limit: record_action() returns false when the action budget
        // is exhausted.
        if !self.security.record_action() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("browser_delegate action rate-limited".into()),
            });
        }

        let task = args
            .get("task")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("")
            .trim();

        if task.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("'task' parameter is required and cannot be empty".into()),
            });
        }

        // Blank or whitespace-only urls are treated as absent.
        let url = args
            .get("url")
            .and_then(serde_json::Value::as_str)
            .map(str::trim)
            .filter(|u| !u.is_empty());

        // Validate URL if provided
        if let Some(url) = url {
            if let Err(e) = self.validate_url(url) {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!("URL validation failed: {e}")),
                });
            }
        }

        // Scan task text for embedded URLs and validate against domain policy.
        // This prevents bypassing domain restrictions by embedding blocked URLs
        // in the task text, which is forwarded verbatim to the browser CLI.
        if let Err(e) = self.validate_task_urls(task) {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("task text contains a disallowed URL: {e}")),
            });
        }

        let extract_format = args
            .get("extract_format")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("text");

        // Validate extract_format against allowed enum values
        if !VALID_EXTRACT_FORMATS.contains(&extract_format) {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "unsupported extract_format '{}': allowed values are 'text', 'json', 'summary'",
                    extract_format
                )),
            });
        }

        // Append format instruction to the task ("text" passes through as-is).
        let full_task = match extract_format {
            "json" => format!("{task}. Return the result as structured JSON."),
            "summary" => format!("{task}. Return a concise summary."),
            _ => task.to_string(),
        };

        let mut cmd = self.build_command(&full_task, url);
        // Ensure the subprocess is killed when the future is dropped (e.g. on timeout)
        cmd.kill_on_drop(true);

        let deadline = Duration::from_secs(self.config.task_timeout_secs);
        let result = timeout(deadline, cmd.output()).await;

        // Three outcomes: process completed (success or non-zero exit),
        // spawn failure, or timeout.
        match result {
            Ok(Ok(output)) => {
                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
                let stderr = String::from_utf8_lossy(&output.stderr);
                // Cap stderr so a chatty CLI cannot flood the tool result.
                let stderr_truncated: String = stderr.chars().take(MAX_STDERR_CHARS).collect();

                if output.status.success() {
                    Ok(ToolResult {
                        success: true,
                        output: stdout,
                        // Even on success, surface any stderr noise as a warning.
                        error: if stderr_truncated.is_empty() {
                            None
                        } else {
                            Some(stderr_truncated)
                        },
                    })
                } else {
                    Ok(ToolResult {
                        success: false,
                        output: stdout,
                        error: Some(format!(
                            "CLI exited with status {}: {}",
                            output.status, stderr_truncated
                        )),
                    })
                }
            }
            Ok(Err(e)) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("failed to spawn browser CLI: {e}")),
            }),
            Err(_) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "browser task timed out after {}s",
                    self.config.task_timeout_secs
                )),
            }),
        }
    }
}
|
||||
|
||||
/// Pre-built task templates for common corporate tools.
pub struct BrowserTaskTemplates;

impl BrowserTaskTemplates {
    /// Read messages from a Microsoft Teams channel.
    pub fn read_teams_messages(channel: &str, count: usize) -> String {
        format!("Open Microsoft Teams, navigate to the '{channel}' channel, read the last {count} messages, and return them as a structured summary with sender, timestamp, and message content.")
    }

    /// Read emails from the Outlook Web inbox.
    pub fn read_outlook_inbox(count: usize) -> String {
        format!("Open Outlook Web (outlook.office.com), go to the inbox, read the last {count} emails, and return a summary of each with sender, subject, date, and first 2 lines of body.")
    }

    /// Read Jira board for a project.
    pub fn read_jira_board(project: &str) -> String {
        format!("Open Jira, navigate to the '{project}' project board, and return the current sprint tickets with their status, assignee, and title.")
    }

    /// Read a Confluence page.
    pub fn read_confluence_page(url: &str) -> String {
        format!("Open the Confluence page at {url}, read the full content, and return a structured summary.")
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with all defaults: disabled, "claude" binary, no restrictions.
    fn default_test_config() -> BrowserDelegateConfig {
        BrowserDelegateConfig::default()
    }

    /// Enabled config with the given allow/block lists; all else default.
    fn config_with_domains(allowed: Vec<String>, blocked: Vec<String>) -> BrowserDelegateConfig {
        BrowserDelegateConfig {
            enabled: true,
            allowed_domains: allowed,
            blocked_domains: blocked,
            ..BrowserDelegateConfig::default()
        }
    }

    /// Build a tool instance with a default security policy and the given config.
    fn test_tool(config: BrowserDelegateConfig) -> BrowserDelegateTool {
        BrowserDelegateTool::new(Arc::new(SecurityPolicy::default()), config)
    }

    // ── Config defaults ─────────────────────────────────────────────

    #[test]
    fn config_defaults_are_sensible() {
        let cfg = default_test_config();
        assert!(!cfg.enabled);
        assert_eq!(cfg.cli_binary, "claude");
        assert!(cfg.chrome_profile_dir.is_empty());
        assert!(cfg.allowed_domains.is_empty());
        assert!(cfg.blocked_domains.is_empty());
        assert_eq!(cfg.task_timeout_secs, 120);
    }

    #[test]
    fn config_serde_roundtrip() {
        let cfg = BrowserDelegateConfig {
            enabled: true,
            cli_binary: "my-cli".into(),
            chrome_profile_dir: "/tmp/profile".into(),
            allowed_domains: vec!["example.com".into()],
            blocked_domains: vec!["evil.com".into()],
            task_timeout_secs: 60,
        };
        // Round-trip through TOML to confirm serialize/deserialize symmetry.
        let toml_str = toml::to_string(&cfg).unwrap();
        let parsed: BrowserDelegateConfig = toml::from_str(&toml_str).unwrap();
        assert!(parsed.enabled);
        assert_eq!(parsed.cli_binary, "my-cli");
        assert_eq!(parsed.chrome_profile_dir, "/tmp/profile");
        assert_eq!(parsed.allowed_domains, vec!["example.com"]);
        assert_eq!(parsed.blocked_domains, vec!["evil.com"]);
        assert_eq!(parsed.task_timeout_secs, 60);
    }

    // ── URL validation ──────────────────────────────────────────────

    #[test]
    fn validate_url_allows_when_no_restrictions() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        assert!(tool.validate_url("https://example.com/page").is_ok());
    }

    #[test]
    fn validate_url_rejects_blocked_domain() {
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        let result = tool.validate_url("https://evil.com/phish");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn validate_url_rejects_blocked_subdomain() {
        // Block list entries match subdomains too (suffix match).
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        assert!(tool.validate_url("https://sub.evil.com/phish").is_err());
    }

    #[test]
    fn validate_url_allows_listed_domain() {
        let tool = test_tool(config_with_domains(vec!["corp.example.com".into()], vec![]));
        assert!(tool.validate_url("https://corp.example.com/page").is_ok());
    }

    #[test]
    fn validate_url_rejects_unlisted_domain_with_allowlist() {
        let tool = test_tool(config_with_domains(vec!["corp.example.com".into()], vec![]));
        let result = tool.validate_url("https://other.example.com/page");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not in"));
    }

    #[test]
    fn validate_url_blocked_takes_precedence_over_allowed() {
        // Same domain in both lists: deny must win.
        let tool = test_tool(config_with_domains(
            vec!["example.com".into()],
            vec!["example.com".into()],
        ));
        let result = tool.validate_url("https://example.com/page");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn validate_url_rejects_invalid_url() {
        let tool = test_tool(default_test_config());
        assert!(tool.validate_url("not-a-url").is_err());
    }

    // ── Command building ────────────────────────────────────────────

    #[test]
    fn build_command_uses_configured_binary() {
        let config = BrowserDelegateConfig {
            cli_binary: "my-browser-cli".into(),
            ..BrowserDelegateConfig::default()
        };
        let tool = test_tool(config);
        let cmd = tool.build_command("read inbox", None);
        assert_eq!(cmd.as_std().get_program(), "my-browser-cli");
    }

    #[test]
    fn build_command_includes_print_flag() {
        let tool = test_tool(default_test_config());
        let cmd = tool.build_command("read inbox", None);
        let args: Vec<&std::ffi::OsStr> = cmd.as_std().get_args().collect();
        assert!(args.contains(&std::ffi::OsStr::new("--print")));
    }

    #[test]
    fn build_command_includes_url_in_prompt() {
        let tool = test_tool(default_test_config());
        let cmd = tool.build_command("read page", Some("https://example.com"));
        let args: Vec<String> = cmd
            .as_std()
            .get_args()
            .map(|a| a.to_string_lossy().to_string())
            .collect();
        // The prompt is the final positional argument passed to the CLI.
        let prompt = args.last().unwrap();
        assert!(prompt.contains("https://example.com"));
        assert!(prompt.contains("read page"));
    }

    #[test]
    fn build_command_sets_chrome_profile_env() {
        let config = BrowserDelegateConfig {
            chrome_profile_dir: "/tmp/chrome-profile".into(),
            ..BrowserDelegateConfig::default()
        };
        let tool = test_tool(config);
        let cmd = tool.build_command("task", None);
        let envs: Vec<_> = cmd.as_std().get_envs().collect();
        let chrome_env = envs
            .iter()
            .find(|(k, _)| k == &std::ffi::OsStr::new("CHROME_USER_DATA_DIR"));
        assert!(chrome_env.is_some());
        assert_eq!(
            chrome_env.unwrap().1,
            Some(std::ffi::OsStr::new("/tmp/chrome-profile"))
        );
    }

    // ── Task templates ──────────────────────────────────────────────

    #[test]
    fn template_teams_includes_channel_and_count() {
        let t = BrowserTaskTemplates::read_teams_messages("engineering", 10);
        assert!(t.contains("engineering"));
        assert!(t.contains("10"));
        assert!(t.contains("Teams"));
    }

    #[test]
    fn template_outlook_includes_count() {
        let t = BrowserTaskTemplates::read_outlook_inbox(5);
        assert!(t.contains('5'));
        assert!(t.contains("Outlook"));
    }

    #[test]
    fn template_jira_includes_project() {
        let t = BrowserTaskTemplates::read_jira_board("PROJ-X");
        assert!(t.contains("PROJ-X"));
        assert!(t.contains("Jira"));
    }

    #[test]
    fn template_confluence_includes_url() {
        let t = BrowserTaskTemplates::read_confluence_page("https://wiki.example.com/page/123");
        assert!(t.contains("https://wiki.example.com/page/123"));
        assert!(t.contains("Confluence"));
    }

    // ── Domain matching ─────────────────────────────────────────────

    #[test]
    fn domain_matches_exact() {
        assert!(domain_matches("example.com", "example.com"));
    }

    #[test]
    fn domain_matches_subdomain() {
        assert!(domain_matches("sub.example.com", "example.com"));
    }

    #[test]
    fn domain_matches_case_insensitive() {
        assert!(domain_matches("Example.COM", "example.com"));
    }

    #[test]
    fn domain_does_not_match_partial() {
        // Suffix match requires a dot boundary: "notexample.com" must not match.
        assert!(!domain_matches("notexample.com", "example.com"));
    }

    // ── Execute edge cases ──────────────────────────────────────────

    #[tokio::test]
    async fn execute_rejects_empty_task() {
        let tool = test_tool(default_test_config());
        let result = tool
            .execute(serde_json::json!({ "task": "" }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("required"));
    }

    #[tokio::test]
    async fn execute_rejects_blocked_url() {
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        let result = tool
            .execute(serde_json::json!({
                "task": "read page",
                "url": "https://evil.com/page"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("blocked"));
    }

    // ── URL scheme validation ──────────────────────────────────────

    #[test]
    fn validate_url_rejects_ftp_scheme() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        let result = tool.validate_url("ftp://example.com/file");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported URL scheme"));
    }

    #[test]
    fn validate_url_rejects_file_scheme() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        let result = tool.validate_url("file:///etc/passwd");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported URL scheme"));
    }

    #[test]
    fn validate_url_rejects_javascript_scheme() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        let result = tool.validate_url("javascript:alert(1)");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported URL scheme"));
    }

    #[test]
    fn validate_url_rejects_data_scheme() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        let result = tool.validate_url("data:text/html,<h1>hi</h1>");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported URL scheme"));
    }

    #[test]
    fn validate_url_allows_http_scheme() {
        let tool = test_tool(config_with_domains(vec![], vec![]));
        assert!(tool.validate_url("http://example.com/page").is_ok());
    }

    // ── Task text URL scanning ──────────────────────────────────────

    #[test]
    fn validate_task_urls_blocks_embedded_blocked_url() {
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        let result = tool.validate_task_urls("go to https://evil.com/steal and read it");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn validate_task_urls_blocks_embedded_url_not_in_allowlist() {
        let tool = test_tool(config_with_domains(vec!["corp.example.com".into()], vec![]));
        let result =
            tool.validate_task_urls("navigate to https://attacker.com/page and extract data");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not in"));
    }

    #[test]
    fn validate_task_urls_allows_permitted_embedded_url() {
        let tool = test_tool(config_with_domains(vec!["corp.example.com".into()], vec![]));
        assert!(tool
            .validate_task_urls("read https://corp.example.com/page and summarize")
            .is_ok());
    }

    #[test]
    fn validate_task_urls_allows_text_without_urls() {
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        assert!(tool
            .validate_task_urls("read the last 10 messages from engineering channel")
            .is_ok());
    }

    #[tokio::test]
    async fn execute_rejects_blocked_url_in_task_text() {
        // Embedding a blocked URL in the free-form task must also be rejected.
        let tool = test_tool(config_with_domains(vec![], vec!["evil.com".into()]));
        let result = tool
            .execute(serde_json::json!({
                "task": "navigate to https://evil.com/phish and extract credentials"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("disallowed URL"));
    }

    // ── extract_format validation ──────────────────────────────────

    #[tokio::test]
    async fn execute_rejects_invalid_extract_format() {
        let tool = test_tool(default_test_config());
        let result = tool
            .execute(serde_json::json!({
                "task": "read page",
                "extract_format": "xml"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("unsupported extract_format"));
        assert!(result.error.as_deref().unwrap().contains("xml"));
    }
}
|
||||
@@ -12,6 +12,8 @@ pub enum CliCategory {
|
||||
Container,
|
||||
Build,
|
||||
Cloud,
|
||||
Productivity,
|
||||
AiAgent,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CliCategory {
|
||||
@@ -23,6 +25,8 @@ impl std::fmt::Display for CliCategory {
|
||||
Self::Container => write!(f, "Container"),
|
||||
Self::Build => write!(f, "Build"),
|
||||
Self::Cloud => write!(f, "Cloud"),
|
||||
Self::Productivity => write!(f, "Productivity"),
|
||||
Self::AiAgent => write!(f, "AI Agent"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,6 +108,26 @@ const KNOWN_CLIS: &[KnownCli] = &[
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::Language,
|
||||
},
|
||||
KnownCli {
|
||||
name: "gws",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::Productivity,
|
||||
},
|
||||
KnownCli {
|
||||
name: "claude",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
KnownCli {
|
||||
name: "gemini",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
KnownCli {
|
||||
name: "kilo",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
];
|
||||
|
||||
/// Discover available CLI tools on the system.
|
||||
@@ -235,5 +259,7 @@ mod tests {
|
||||
assert_eq!(CliCategory::Container.to_string(), "Container");
|
||||
assert_eq!(CliCategory::Build.to_string(), "Build");
|
||||
assert_eq!(CliCategory::Cloud.to_string(), "Cloud");
|
||||
assert_eq!(CliCategory::Productivity.to_string(), "Productivity");
|
||||
assert_eq!(CliCategory::AiAgent.to_string(), "AI Agent");
|
||||
}
|
||||
}
|
||||
|
||||
+152
-15
@@ -65,27 +65,97 @@ impl Tool for CronAddTool {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"schedule": {
|
||||
"type": "object",
|
||||
"description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}"
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-readable name for the job"
|
||||
},
|
||||
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
|
||||
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
|
||||
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
|
||||
// tool specs are sent. See src/tools/schema.rs.
|
||||
"schedule": {
|
||||
"description": "When to run the job. Exactly one of three forms must be used.",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["cron"] },
|
||||
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
|
||||
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
|
||||
},
|
||||
"required": ["kind", "expr"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["at"] },
|
||||
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
|
||||
},
|
||||
"required": ["kind", "at"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["every"] },
|
||||
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
|
||||
},
|
||||
"required": ["kind", "every_ms"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"job_type": {
|
||||
"type": "string",
|
||||
"enum": ["shell", "agent"],
|
||||
"description": "Type of job: 'shell' runs a command, 'agent' runs the AI agent with a prompt"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command to run (required when job_type is 'shell')"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Agent prompt to run on schedule (required when job_type is 'agent')"
|
||||
},
|
||||
"session_target": {
|
||||
"type": "string",
|
||||
"enum": ["isolated", "main"],
|
||||
"description": "Agent session context: 'isolated' starts a fresh session each run, 'main' reuses the primary session"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Optional model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
|
||||
},
|
||||
"job_type": { "type": "string", "enum": ["shell", "agent"] },
|
||||
"command": { "type": "string" },
|
||||
"prompt": { "type": "string" },
|
||||
"session_target": { "type": "string", "enum": ["isolated", "main"] },
|
||||
"model": { "type": "string" },
|
||||
"delivery": {
|
||||
"type": "object",
|
||||
"description": "Delivery config to send job output to a channel. Example: {\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id>\"}",
|
||||
"description": "Optional delivery config to send job output to a channel after each run. When provided, all three of mode, channel, and to are expected.",
|
||||
"properties": {
|
||||
"mode": { "type": "string", "enum": ["none", "announce"], "description": "Set to 'announce' to deliver output to a channel" },
|
||||
"channel": { "type": "string", "enum": ["telegram", "discord", "slack", "mattermost", "matrix"], "description": "Channel type to deliver to" },
|
||||
"to": { "type": "string", "description": "Target: Discord channel ID, Telegram chat ID, Slack channel, etc." },
|
||||
"best_effort": { "type": "boolean", "description": "If true, delivery failure does not fail the job" }
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["none", "announce"],
|
||||
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
|
||||
"description": "Channel type to deliver output to"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
|
||||
},
|
||||
"best_effort": {
|
||||
"type": "boolean",
|
||||
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete_after_run": { "type": "boolean" },
|
||||
"delete_after_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, the job is automatically deleted after its first successful run. Defaults to true for 'at' schedules."
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean",
|
||||
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
|
||||
@@ -497,4 +567,71 @@ mod tests {
|
||||
|
||||
assert!(values.iter().any(|value| value == "matrix"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schedule_schema_is_oneof_with_cron_at_every_variants() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let cfg = Arc::new(Config {
|
||||
workspace_dir: tmp.path().join("workspace"),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
});
|
||||
let security = Arc::new(SecurityPolicy::from_config(
|
||||
&cfg.autonomy,
|
||||
&cfg.workspace_dir,
|
||||
));
|
||||
let tool = CronAddTool::new(cfg, security);
|
||||
let schema = tool.parameters_schema();
|
||||
|
||||
// Top-level: schedule is required
|
||||
let top_required = schema["required"].as_array().expect("top-level required");
|
||||
assert!(top_required.iter().any(|v| v == "schedule"));
|
||||
|
||||
// schedule is a oneOf with exactly 3 variants: cron, at, every
|
||||
let one_of = schema["properties"]["schedule"]["oneOf"]
|
||||
.as_array()
|
||||
.expect("schedule.oneOf must be an array");
|
||||
assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");
|
||||
|
||||
let kinds: Vec<&str> = one_of
|
||||
.iter()
|
||||
.filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
|
||||
.collect();
|
||||
assert!(kinds.contains(&"cron"), "missing cron variant");
|
||||
assert!(kinds.contains(&"at"), "missing at variant");
|
||||
assert!(kinds.contains(&"every"), "missing every variant");
|
||||
|
||||
// Each variant declares its required fields and every_ms is typed integer
|
||||
for variant in one_of {
|
||||
let kind = variant["properties"]["kind"]["enum"][0]
|
||||
.as_str()
|
||||
.expect("variant kind");
|
||||
let req: Vec<&str> = variant["required"]
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("{kind} variant must have required"))
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect();
|
||||
assert!(
|
||||
req.contains(&"kind"),
|
||||
"{kind} variant missing 'kind' in required"
|
||||
);
|
||||
match kind {
|
||||
"cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
|
||||
"at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
|
||||
"every" => {
|
||||
assert!(
|
||||
req.contains(&"every_ms"),
|
||||
"every variant missing 'every_ms'"
|
||||
);
|
||||
assert_eq!(
|
||||
variant["properties"]["every_ms"]["type"].as_str(),
|
||||
Some("integer"),
|
||||
"every_ms must be typed as integer"
|
||||
);
|
||||
}
|
||||
_ => panic!("unexpected kind: {kind}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+200
-2
@@ -61,8 +61,106 @@ impl Tool for CronUpdateTool {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": { "type": "string" },
|
||||
"patch": { "type": "object" },
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the cron job to update, as returned by cron_add or cron_list"
|
||||
},
|
||||
"patch": {
|
||||
"type": "object",
|
||||
"description": "Fields to update. Only include fields you want to change; omitted fields are left as-is.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "New human-readable name for the job"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Enable or disable the job without deleting it"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "New shell command (for shell jobs)"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "New agent prompt (for agent jobs)"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
|
||||
},
|
||||
"session_target": {
|
||||
"type": "string",
|
||||
"enum": ["isolated", "main"],
|
||||
"description": "Agent session context: 'isolated' starts fresh each run, 'main' reuses the primary session"
|
||||
},
|
||||
"delete_after_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, delete the job automatically after its first successful run"
|
||||
},
|
||||
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
|
||||
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
|
||||
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
|
||||
// tool specs are sent. See src/tools/schema.rs.
|
||||
"schedule": {
|
||||
"description": "New schedule for the job. Exactly one of three forms must be used.",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["cron"] },
|
||||
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
|
||||
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
|
||||
},
|
||||
"required": ["kind", "expr"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["at"] },
|
||||
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
|
||||
},
|
||||
"required": ["kind", "at"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["every"] },
|
||||
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
|
||||
},
|
||||
"required": ["kind", "every_ms"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"delivery": {
|
||||
"type": "object",
|
||||
"description": "Delivery config to send job output to a channel after each run. When provided, mode, channel, and to are all expected.",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["none", "announce"],
|
||||
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
|
||||
"description": "Channel type to deliver output to"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
|
||||
},
|
||||
"best_effort": {
|
||||
"type": "boolean",
|
||||
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean",
|
||||
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
|
||||
@@ -274,6 +372,106 @@ mod tests {
|
||||
assert!(approved.success, "{:?}", approved.error);
|
||||
}
|
||||
|
||||
// Schema contract test: the cron_update tool's patch schema must expose every
// CronJobPatch field, model `schedule` as a strict oneOf (cron / at / every),
// and enumerate every supported delivery channel.
#[test]
fn patch_schema_covers_all_cronjobpatch_fields_and_schedule_is_oneof() {
    // Build the tool against a throwaway workspace/config in a temp dir.
    let tmp = TempDir::new().unwrap();
    let cfg = Arc::new(Config {
        workspace_dir: tmp.path().join("workspace"),
        config_path: tmp.path().join("config.toml"),
        ..Config::default()
    });
    let security = Arc::new(SecurityPolicy::from_config(
        &cfg.autonomy,
        &cfg.workspace_dir,
    ));
    let tool = CronUpdateTool::new(cfg, security);
    let schema = tool.parameters_schema();

    // Top-level: job_id and patch are required
    let top_required = schema["required"].as_array().expect("top-level required");
    let top_req_strs: Vec<&str> = top_required.iter().filter_map(|v| v.as_str()).collect();
    assert!(top_req_strs.contains(&"job_id"));
    assert!(top_req_strs.contains(&"patch"));

    // patch exposes all CronJobPatch fields
    let patch_props = schema["properties"]["patch"]["properties"]
        .as_object()
        .expect("patch must have a properties object");
    for field in &[
        "name",
        "enabled",
        "command",
        "prompt",
        "model",
        "session_target",
        "delete_after_run",
        "schedule",
        "delivery",
    ] {
        assert!(
            patch_props.contains_key(*field),
            "patch schema missing field: {field}"
        );
    }

    // patch.schedule is a oneOf with exactly 3 variants: cron, at, every
    let one_of = schema["properties"]["patch"]["properties"]["schedule"]["oneOf"]
        .as_array()
        .expect("patch.schedule.oneOf must be an array");
    assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");

    // Collect each variant's discriminator: properties.kind.enum[0].
    let kinds: Vec<&str> = one_of
        .iter()
        .filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
        .collect();
    assert!(kinds.contains(&"cron"), "missing cron variant");
    assert!(kinds.contains(&"at"), "missing at variant");
    assert!(kinds.contains(&"every"), "missing every variant");

    // Each variant declares its required fields and every_ms is typed integer
    for variant in one_of {
        let kind = variant["properties"]["kind"]["enum"][0]
            .as_str()
            .expect("variant kind");
        // A variant without a `required` array is a schema bug: fail loudly.
        let req: Vec<&str> = variant["required"]
            .as_array()
            .unwrap_or_else(|| panic!("{kind} variant must have required"))
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        assert!(
            req.contains(&"kind"),
            "{kind} variant missing 'kind' in required"
        );
        match kind {
            "cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
            "at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
            "every" => {
                assert!(
                    req.contains(&"every_ms"),
                    "every variant missing 'every_ms'"
                );
                assert_eq!(
                    variant["properties"]["every_ms"]["type"].as_str(),
                    Some("integer"),
                    "every_ms must be typed as integer"
                );
            }
            _ => panic!("unexpected schedule kind: {kind}"),
        }
    }

    // patch.delivery.channel enum covers all supported channels
    let channel_enum = schema["properties"]["patch"]["properties"]["delivery"]["properties"]
        ["channel"]["enum"]
        .as_array()
        .expect("patch.delivery.channel must have an enum");
    let channel_strs: Vec<&str> = channel_enum.iter().filter_map(|v| v.as_str()).collect();
    for ch in &["telegram", "discord", "slack", "mattermost", "matrix"] {
        assert!(channel_strs.contains(ch), "delivery.channel missing: {ch}");
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blocks_update_when_rate_limited() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@@ -421,6 +421,7 @@ impl DelegateTool {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -0,0 +1,716 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default `gws` command execution time before kill (overridden by config).
// NOTE(review): only referenced by the `gws_timeout_is_reasonable` test in
// this file; the effective timeout comes from the `timeout_secs` field.
const DEFAULT_GWS_TIMEOUT_SECS: u64 = 30;
/// Maximum output size in bytes (1MB). stdout and stderr beyond this are
/// truncated at a UTF-8 char boundary with a "[truncated]" note appended.
const MAX_OUTPUT_BYTES: usize = 1_048_576;

/// Allowed Google Workspace services that gws can target.
/// Used as the allowlist whenever the caller supplies an empty list to
/// `GoogleWorkspaceTool::new`.
const DEFAULT_ALLOWED_SERVICES: &[&str] = &[
    "drive",
    "sheets",
    "gmail",
    "calendar",
    "docs",
    "slides",
    "tasks",
    "people",
    "chat",
    "classroom",
    "forms",
    "keep",
    "meet",
    "events",
];
|
||||
|
||||
/// Google Workspace CLI (`gws`) integration tool.
///
/// Wraps the `gws` CLI binary to give the agent structured access to
/// Google Workspace services (Drive, Gmail, Calendar, Sheets, etc.).
/// Requires `gws` to be installed and authenticated (`gws auth login`).
pub struct GoogleWorkspaceTool {
    // Security policy consulted for rate limiting and action budgeting.
    security: Arc<SecurityPolicy>,
    // Allowlist of service names this tool may target; populated from config
    // or from DEFAULT_ALLOWED_SERVICES when the config list is empty.
    allowed_services: Vec<String>,
    // Optional path exported as GOOGLE_APPLICATION_CREDENTIALS to gws.
    credentials_path: Option<String>,
    // Optional account identity passed to gws via `--account`.
    default_account: Option<String>,
    // Configured per-minute limit. NOTE(review): not read anywhere in this
    // file — rate limiting visibly goes through `security`; confirm intent.
    rate_limit_per_minute: u32,
    // Seconds a gws invocation may run before it is timed out.
    timeout_secs: u64,
    // When true, each API call is logged via `tracing` before execution.
    audit_log: bool,
}
|
||||
|
||||
impl GoogleWorkspaceTool {
|
||||
/// Create a new `GoogleWorkspaceTool`.
|
||||
///
|
||||
/// If `allowed_services` is empty, the default service set is used.
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
allowed_services: Vec<String>,
|
||||
credentials_path: Option<String>,
|
||||
default_account: Option<String>,
|
||||
rate_limit_per_minute: u32,
|
||||
timeout_secs: u64,
|
||||
audit_log: bool,
|
||||
) -> Self {
|
||||
let services = if allowed_services.is_empty() {
|
||||
DEFAULT_ALLOWED_SERVICES
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect()
|
||||
} else {
|
||||
allowed_services
|
||||
};
|
||||
Self {
|
||||
security,
|
||||
allowed_services: services,
|
||||
credentials_path,
|
||||
default_account,
|
||||
rate_limit_per_minute,
|
||||
timeout_secs,
|
||||
audit_log,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for GoogleWorkspaceTool {
|
||||
fn name(&self) -> &str {
|
||||
"google_workspace"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Interact with Google Workspace services (Drive, Gmail, Calendar, Sheets, Docs, etc.) \
|
||||
via the gws CLI. Requires gws to be installed and authenticated."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"service": {
|
||||
"type": "string",
|
||||
"description": "Google Workspace service (e.g. drive, gmail, calendar, sheets, docs, slides, tasks, people, chat, forms, keep, meet)"
|
||||
},
|
||||
"resource": {
|
||||
"type": "string",
|
||||
"description": "Service resource (e.g. files, messages, events, spreadsheets)"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "Method to call on the resource (e.g. list, get, create, update, delete)"
|
||||
},
|
||||
"sub_resource": {
|
||||
"type": "string",
|
||||
"description": "Optional sub-resource for nested operations"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "URL/query parameters as key-value pairs (passed as --params JSON)"
|
||||
},
|
||||
"body": {
|
||||
"type": "object",
|
||||
"description": "Request body for POST/PATCH/PUT operations (passed as --json JSON)"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["json", "table", "yaml", "csv"],
|
||||
"description": "Output format (default: json)"
|
||||
},
|
||||
"page_all": {
|
||||
"type": "boolean",
|
||||
"description": "Auto-paginate through all results"
|
||||
},
|
||||
"page_limit": {
|
||||
"type": "integer",
|
||||
"description": "Max pages to fetch when using page_all (default: 10)"
|
||||
}
|
||||
},
|
||||
"required": ["service", "resource", "method"]
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a Google Workspace CLI command with input validation and security enforcement.
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let service = args
|
||||
.get("service")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'service' parameter"))?;
|
||||
let resource = args
|
||||
.get("resource")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'resource' parameter"))?;
|
||||
let method = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'method' parameter"))?;
|
||||
|
||||
// Security checks
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate service is in the allowlist
|
||||
if !self.allowed_services.iter().any(|s| s == service) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Service '{service}' is not in the allowed services list. \
|
||||
Allowed: {}",
|
||||
self.allowed_services.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate inputs contain no shell metacharacters
|
||||
for (label, value) in [
|
||||
("service", service),
|
||||
("resource", resource),
|
||||
("method", method),
|
||||
] {
|
||||
if !value
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid characters in '{label}': only alphanumeric, underscore, and hyphen are allowed"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Build the gws command — validate all optional fields before consuming budget
|
||||
let mut cmd_args: Vec<String> = vec![service.to_string(), resource.to_string()];
|
||||
|
||||
if let Some(sub_resource_value) = args.get("sub_resource") {
|
||||
let sub_resource = match sub_resource_value.as_str() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'sub_resource' must be a string".into()),
|
||||
})
|
||||
}
|
||||
};
|
||||
if !sub_resource
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Invalid characters in 'sub_resource': only alphanumeric, underscore, and hyphen are allowed"
|
||||
.into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
cmd_args.push(sub_resource.to_string());
|
||||
}
|
||||
|
||||
cmd_args.push(method.to_string());
|
||||
|
||||
if let Some(params) = args.get("params") {
|
||||
if !params.is_object() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'params' must be an object".into()),
|
||||
});
|
||||
}
|
||||
cmd_args.push("--params".into());
|
||||
cmd_args.push(params.to_string());
|
||||
}
|
||||
|
||||
if let Some(body) = args.get("body") {
|
||||
if !body.is_object() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'body' must be an object".into()),
|
||||
});
|
||||
}
|
||||
cmd_args.push("--json".into());
|
||||
cmd_args.push(body.to_string());
|
||||
}
|
||||
|
||||
if let Some(format_value) = args.get("format") {
|
||||
let format = match format_value.as_str() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'format' must be a string".into()),
|
||||
})
|
||||
}
|
||||
};
|
||||
match format {
|
||||
"json" | "table" | "yaml" | "csv" => {
|
||||
cmd_args.push("--format".into());
|
||||
cmd_args.push(format.to_string());
|
||||
}
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid format '{format}': must be json, table, yaml, or csv"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let page_all = match args.get("page_all") {
|
||||
Some(v) => match v.as_bool() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'page_all' must be a boolean".into()),
|
||||
})
|
||||
}
|
||||
},
|
||||
None => false,
|
||||
};
|
||||
if page_all {
|
||||
cmd_args.push("--page-all".into());
|
||||
}
|
||||
|
||||
let page_limit = match args.get("page_limit") {
|
||||
Some(v) => match v.as_u64() {
|
||||
Some(n) => Some(n),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'page_limit' must be a non-negative integer".into()),
|
||||
})
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
if page_all || page_limit.is_some() {
|
||||
cmd_args.push("--page-limit".into());
|
||||
cmd_args.push(page_limit.unwrap_or(10).to_string());
|
||||
}
|
||||
|
||||
// Charge action budget only after all validation passes
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut cmd = tokio::process::Command::new("gws");
|
||||
cmd.args(&cmd_args);
|
||||
cmd.env_clear();
|
||||
// gws needs PATH to find itself and HOME/APPDATA for credential storage
|
||||
for key in &["PATH", "HOME", "APPDATA", "USERPROFILE", "LANG", "TERM"] {
|
||||
if let Ok(val) = std::env::var(key) {
|
||||
cmd.env(key, val);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply credential path if configured
|
||||
if let Some(ref creds) = self.credentials_path {
|
||||
cmd.env("GOOGLE_APPLICATION_CREDENTIALS", creds);
|
||||
}
|
||||
|
||||
// Apply default account if configured
|
||||
if let Some(ref account) = self.default_account {
|
||||
cmd.args(["--account", account]);
|
||||
}
|
||||
|
||||
if self.audit_log {
|
||||
tracing::info!(
|
||||
tool = "google_workspace",
|
||||
service = service,
|
||||
resource = resource,
|
||||
method = method,
|
||||
"gws audit: executing API call"
|
||||
);
|
||||
}
|
||||
|
||||
// Apply credential path if configured
|
||||
if let Some(ref creds) = self.credentials_path {
|
||||
cmd.env("GOOGLE_APPLICATION_CREDENTIALS", creds);
|
||||
}
|
||||
|
||||
// Apply default account if configured
|
||||
if let Some(ref account) = self.default_account {
|
||||
cmd.args(["--account", account]);
|
||||
}
|
||||
|
||||
if self.audit_log {
|
||||
tracing::info!(
|
||||
tool = "google_workspace",
|
||||
service = service,
|
||||
resource = resource,
|
||||
method = method,
|
||||
"gws audit: executing API call"
|
||||
);
|
||||
}
|
||||
|
||||
let result =
|
||||
tokio::time::timeout(Duration::from_secs(self.timeout_secs), cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let mut stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
if stdout.len() > MAX_OUTPUT_BYTES {
|
||||
// Find a valid char boundary at or before MAX_OUTPUT_BYTES
|
||||
let mut boundary = MAX_OUTPUT_BYTES;
|
||||
while boundary > 0 && !stdout.is_char_boundary(boundary) {
|
||||
boundary -= 1;
|
||||
}
|
||||
stdout.truncate(boundary);
|
||||
stdout.push_str("\n... [output truncated at 1MB]");
|
||||
}
|
||||
if stderr.len() > MAX_OUTPUT_BYTES {
|
||||
let mut boundary = MAX_OUTPUT_BYTES;
|
||||
while boundary > 0 && !stderr.is_char_boundary(boundary) {
|
||||
boundary -= 1;
|
||||
}
|
||||
stderr.truncate(boundary);
|
||||
stderr.push_str("\n... [stderr truncated at 1MB]");
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to execute gws: {e}. Is gws installed? Run: npm install -g @googleworkspace/cli"
|
||||
)),
|
||||
}),
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"gws command timed out after {}s and was killed", self.timeout_secs
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    // Permissive policy (full autonomy, temp-dir workspace) so the tests
    // exercise the tool's own validation rather than the security layer.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    // The tool registers under its canonical name.
    #[test]
    fn tool_name() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        assert_eq!(tool.name(), "google_workspace");
    }

    // Description must be present for LLM tool listings.
    #[test]
    fn tool_description_non_empty() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        assert!(!tool.description().is_empty());
    }

    // Schema declares service/resource/method and marks all three required.
    #[test]
    fn tool_schema_has_required_fields() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["service"].is_object());
        assert!(schema["properties"]["resource"].is_object());
        assert!(schema["properties"]["method"].is_object());
        let required = schema["required"]
            .as_array()
            .expect("required should be an array");
        assert!(required.contains(&json!("service")));
        assert!(required.contains(&json!("resource")));
        assert!(required.contains(&json!("method")));
    }

    // An empty allowlist in the constructor selects the built-in defaults.
    #[test]
    fn default_allowed_services_populated() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        assert!(!tool.allowed_services.is_empty());
        assert!(tool.allowed_services.contains(&"drive".to_string()));
        assert!(tool.allowed_services.contains(&"gmail".to_string()));
        assert!(tool.allowed_services.contains(&"calendar".to_string()));
    }

    // A non-empty allowlist replaces (not extends) the defaults.
    #[test]
    fn custom_allowed_services_override_defaults() {
        let tool = GoogleWorkspaceTool::new(
            test_security(),
            vec!["drive".into(), "sheets".into()],
            None,
            None,
            60,
            30,
            false,
        );
        assert_eq!(tool.allowed_services.len(), 2);
        assert!(tool.allowed_services.contains(&"drive".to_string()));
        assert!(tool.allowed_services.contains(&"sheets".to_string()));
        assert!(!tool.allowed_services.contains(&"gmail".to_string()));
    }

    // A service outside the allowlist is rejected before any execution.
    #[tokio::test]
    async fn rejects_disallowed_service() {
        let tool = GoogleWorkspaceTool::new(
            test_security(),
            vec!["drive".into()],
            None,
            None,
            60,
            30,
            false,
        );
        let result = tool
            .execute(json!({
                "service": "gmail",
                "resource": "users",
                "method": "list"
            }))
            .await
            .expect("disallowed service should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("not in the allowed"));
    }

    // Even an allowlisted value with shell metacharacters must be rejected
    // by the character validation (defense in depth against bad config).
    #[tokio::test]
    async fn rejects_shell_injection_in_service() {
        let tool = GoogleWorkspaceTool::new(
            test_security(),
            vec!["drive; rm -rf /".into()],
            None,
            None,
            60,
            30,
            false,
        );
        let result = tool
            .execute(json!({
                "service": "drive; rm -rf /",
                "resource": "files",
                "method": "list"
            }))
            .await
            .expect("shell injection should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("Invalid characters"));
    }

    // Command-substitution characters in resource are rejected.
    #[tokio::test]
    async fn rejects_shell_injection_in_resource() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files$(whoami)",
                "method": "list"
            }))
            .await
            .expect("shell injection should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("Invalid characters"));
    }

    // Only json/table/yaml/csv are accepted for `format`.
    #[tokio::test]
    async fn rejects_invalid_format() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "format": "xml"
            }))
            .await
            .expect("invalid format should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("Invalid format"));
    }

    // `params` must be a JSON object, not a string.
    #[tokio::test]
    async fn rejects_wrong_type_params() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "params": "not_an_object"
            }))
            .await
            .expect("wrong type params should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("'params' must be an object"));
    }

    // `body` must be a JSON object, not a string.
    #[tokio::test]
    async fn rejects_wrong_type_body() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "create",
                "body": "not_an_object"
            }))
            .await
            .expect("wrong type body should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("'body' must be an object"));
    }

    // `page_all` must be a boolean, not a truthy string.
    #[tokio::test]
    async fn rejects_wrong_type_page_all() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "page_all": "yes"
            }))
            .await
            .expect("wrong type page_all should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("'page_all' must be a boolean"));
    }

    // `page_limit` must be a non-negative integer.
    #[tokio::test]
    async fn rejects_wrong_type_page_limit() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "page_limit": "ten"
            }))
            .await
            .expect("wrong type page_limit should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("'page_limit' must be a non-negative integer"));
    }

    // `sub_resource` must be a string when provided.
    #[tokio::test]
    async fn rejects_wrong_type_sub_resource() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "sub_resource": 123
            }))
            .await
            .expect("wrong type sub_resource should return a result");
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("'sub_resource' must be a string"));
    }

    // Missing required params are the one case reported as Err, not as a
    // failed ToolResult.
    #[tokio::test]
    async fn missing_required_param_returns_error() {
        let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
        let result = tool.execute(json!({"service": "drive"})).await;
        assert!(result.is_err());
    }

    // A zero-action budget makes the security layer reject the call.
    #[tokio::test]
    async fn rate_limited_returns_error() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            max_actions_per_hour: 0,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = GoogleWorkspaceTool::new(security, vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list"
            }))
            .await
            .expect("rate-limited should return a result");
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
    }

    // Guard against accidental changes to the documented default timeout.
    #[test]
    fn gws_timeout_is_reasonable() {
        assert_eq!(DEFAULT_GWS_TIMEOUT_SECS, 30);
    }
}
|
||||
@@ -161,8 +161,7 @@ impl DeferredMcpToolSet {
|
||||
/// The agent loop consults this each iteration to decide which tool_specs
|
||||
/// to include in the LLM request.
|
||||
pub struct ActivatedToolSet {
|
||||
/// name -> activated Tool
|
||||
tools: HashMap<String, Box<dyn Tool>>,
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
impl ActivatedToolSet {
|
||||
@@ -172,27 +171,23 @@ impl ActivatedToolSet {
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a tool as activated, storing its live wrapper.
|
||||
pub fn activate(&mut self, name: String, tool: Box<dyn Tool>) {
|
||||
pub fn activate(&mut self, name: String, tool: Arc<dyn Tool>) {
|
||||
self.tools.insert(name, tool);
|
||||
}
|
||||
|
||||
/// Whether a tool has been activated.
|
||||
pub fn is_activated(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get an activated tool for execution.
|
||||
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
|
||||
self.tools.get(name).map(|t| t.as_ref())
|
||||
/// Clone the Arc so the caller can drop the mutex guard before awaiting.
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||
self.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// All currently activated tool specs (to include in LLM requests).
|
||||
pub fn tool_specs(&self) -> Vec<ToolSpec> {
|
||||
self.tools.values().map(|t| t.spec()).collect()
|
||||
}
|
||||
|
||||
/// All activated tools for execution dispatch.
|
||||
pub fn tool_names(&self) -> Vec<&str> {
|
||||
self.tools.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
@@ -280,7 +275,7 @@ mod tests {
|
||||
|
||||
let mut set = ActivatedToolSet::new();
|
||||
assert!(!set.is_activated("fake"));
|
||||
set.activate("fake".into(), Box::new(FakeTool));
|
||||
set.activate("fake".into(), Arc::new(FakeTool));
|
||||
assert!(set.is_activated("fake"));
|
||||
assert!(set.get("fake").is_some());
|
||||
assert_eq!(set.tool_specs().len(), 1);
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
pub mod backup_tool;
|
||||
pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
pub mod browser_open;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
@@ -36,6 +37,7 @@ pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod git_operations;
|
||||
pub mod glob_search;
|
||||
pub mod google_workspace;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub mod hardware_board_info;
|
||||
#[cfg(feature = "hardware")]
|
||||
@@ -75,6 +77,8 @@ pub mod workspace_tool;
|
||||
|
||||
pub use backup_tool::BackupTool;
|
||||
pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
#[allow(unused_imports)]
|
||||
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
@@ -93,6 +97,7 @@ pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use git_operations::GitOperationsTool;
|
||||
pub use glob_search::GlobSearchTool;
|
||||
pub use google_workspace::GoogleWorkspaceTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub use hardware_board_info::HardwareBoardInfoTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
@@ -270,6 +275,7 @@ pub fn all_tools_with_runtime(
|
||||
fallback_api_key: Option<&str>,
|
||||
root_config: &crate::config::Config,
|
||||
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
|
||||
let has_shell_access = runtime.has_shell_access();
|
||||
let mut tool_arcs: Vec<Arc<dyn Tool>> = vec![
|
||||
Arc::new(ShellTool::new(security.clone(), runtime)),
|
||||
Arc::new(FileReadTool::new(security.clone())),
|
||||
@@ -329,6 +335,20 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Browser delegation tool (conditionally registered; requires shell access)
|
||||
if root_config.browser_delegate.enabled {
|
||||
if has_shell_access {
|
||||
tool_arcs.push(Arc::new(BrowserDelegateTool::new(
|
||||
security.clone(),
|
||||
root_config.browser_delegate.clone(),
|
||||
)));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"browser_delegate: skipped registration because the current runtime does not allow shell access"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if http_config.enabled {
|
||||
tool_arcs.push(Arc::new(HttpRequestTool::new(
|
||||
security.clone(),
|
||||
@@ -361,6 +381,23 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Google Workspace CLI (gws) integration — requires shell access
|
||||
if root_config.google_workspace.enabled && has_shell_access {
|
||||
tool_arcs.push(Arc::new(GoogleWorkspaceTool::new(
|
||||
security.clone(),
|
||||
root_config.google_workspace.allowed_services.clone(),
|
||||
root_config.google_workspace.credentials_path.clone(),
|
||||
root_config.google_workspace.default_account.clone(),
|
||||
root_config.google_workspace.rate_limit_per_minute,
|
||||
root_config.google_workspace.timeout_secs,
|
||||
root_config.google_workspace.audit_log,
|
||||
)));
|
||||
} else if root_config.google_workspace.enabled {
|
||||
tracing::warn!(
|
||||
"google_workspace: skipped registration because shell access is unavailable"
|
||||
);
|
||||
}
|
||||
|
||||
// Notion API tool (conditionally registered)
|
||||
if root_config.notion.enabled {
|
||||
let notion_api_key = if root_config.notion.api_key.trim().is_empty() {
|
||||
@@ -413,6 +450,7 @@ pub fn all_tools_with_runtime(
|
||||
if root_config.cloud_ops.enabled {
|
||||
tool_arcs.push(Arc::new(CloudOpsTool::new(root_config.cloud_ops.clone())));
|
||||
tool_arcs.push(Arc::new(CloudPatternsTool::new()));
|
||||
|
||||
}
|
||||
|
||||
// PDF extraction (feature-gated at compile time via rag-pdf)
|
||||
|
||||
@@ -389,6 +389,11 @@ impl ModelRoutingConfigTool {
|
||||
|
||||
let mut cfg = self.load_config_without_env()?;
|
||||
|
||||
// Capture previous values for rollback on probe failure.
|
||||
let previous_provider = cfg.default_provider.clone();
|
||||
let previous_model = cfg.default_model.clone();
|
||||
let previous_temperature = cfg.default_temperature;
|
||||
|
||||
match provider_update {
|
||||
MaybeSet::Set(provider) => cfg.default_provider = Some(provider),
|
||||
MaybeSet::Null => cfg.default_provider = None,
|
||||
@@ -416,6 +421,38 @@ impl ModelRoutingConfigTool {
|
||||
|
||||
cfg.save().await?;
|
||||
|
||||
// Probe the new model with a minimal API call to catch invalid model IDs
|
||||
// before the channel hot-reload picks up the change.
|
||||
if let (Some(provider_name), Some(model_name)) =
|
||||
(cfg.default_provider.clone(), cfg.default_model.clone())
|
||||
{
|
||||
if let Err(probe_err) = self.probe_model(&provider_name, &model_name).await {
|
||||
if crate::providers::reliable::is_non_retryable(&probe_err) {
|
||||
let reverted_model = previous_model.as_deref().unwrap_or("(none)").to_string();
|
||||
|
||||
// Rollback to previous config.
|
||||
cfg.default_provider = previous_provider;
|
||||
cfg.default_model = previous_model;
|
||||
cfg.default_temperature = previous_temperature;
|
||||
cfg.save().await?;
|
||||
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: format!(
|
||||
"Model '{model_name}' is not available: {probe_err}. Reverted to '{reverted_model}'.",
|
||||
),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
// Retryable errors (e.g. transient network issues) — keep the
|
||||
// new config and let the resilient wrapper handle retries.
|
||||
tracing::warn!(
|
||||
model = %model_name,
|
||||
"Model probe returned retryable error (keeping new config): {probe_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&json!({
|
||||
@@ -426,6 +463,36 @@ impl ModelRoutingConfigTool {
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a minimal 1-token chat request to verify the model is accessible.
|
||||
/// Returns `Ok(())` if the probe succeeds **or** if no API key is available
|
||||
/// (the probe would fail with an auth error unrelated to model validity).
|
||||
/// Provider construction failures are also treated as non-fatal.
|
||||
async fn probe_model(&self, provider_name: &str, model: &str) -> anyhow::Result<()> {
|
||||
use crate::providers;
|
||||
|
||||
// Use the runtime config's API key (which includes env-sourced keys),
|
||||
// not the on-disk config (which may have no key at all).
|
||||
let api_key = self.config.api_key.as_deref();
|
||||
if api_key.is_none_or(|k| k.trim().is_empty()) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = match providers::create_provider_with_url(
|
||||
provider_name,
|
||||
api_key,
|
||||
self.config.api_url.as_deref(),
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Ok(()),
|
||||
};
|
||||
|
||||
provider
|
||||
.chat_with_system(Some("Respond with OK."), "ping", model, 0.0)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_upsert_scenario(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||
let hint = Self::parse_non_empty_string(args, "hint")?;
|
||||
let provider = Self::parse_non_empty_string(args, "provider")?;
|
||||
@@ -1082,4 +1149,52 @@ mod tests {
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap_or_default().contains("read-only"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_default_skips_probe_without_api_key() {
|
||||
// When no API key is configured (test_config has none), the probe is
|
||||
// skipped and any model string is accepted. This verifies the probe-
|
||||
// skip path doesn't accidentally reject valid config changes.
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ModelRoutingConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "set_default",
|
||||
"provider": "anthropic",
|
||||
"model": "totally-fake-model-12345"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
let output: Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(
|
||||
output["config"]["default"]["model"].as_str(),
|
||||
Some("totally-fake-model-12345")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_default_temperature_only_skips_probe() {
|
||||
// Temperature-only changes don't set a new model, so the probe should
|
||||
// not fire at all (no provider/model to probe).
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ModelRoutingConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "set_default",
|
||||
"temperature": 1.5
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
let output: Value = serde_json::from_str(&result.output).unwrap();
|
||||
assert_eq!(
|
||||
output["config"]["default"]["temperature"].as_f64(),
|
||||
Some(1.5)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,9 +88,6 @@ impl Tool for ScheduleTool {
|
||||
self.handle_get(id)
|
||||
}
|
||||
"create" | "add" | "once" => {
|
||||
if let Some(blocked) = self.enforce_mutation_allowed(action) {
|
||||
return Ok(blocked);
|
||||
}
|
||||
let approved = args
|
||||
.get("approved")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
@@ -301,6 +298,12 @@ impl ScheduleTool {
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce rate-limiting AFTER command/args validation so that invalid
|
||||
// requests do not consume the action budget. (Fixes #3699)
|
||||
if let Some(blocked) = self.enforce_mutation_allowed(action) {
|
||||
return Ok(blocked);
|
||||
}
|
||||
|
||||
// All job creation routes through validated cron helpers, which enforce
|
||||
// the full security policy (allowlist + risk gate) before persistence.
|
||||
if let Some(value) = expression {
|
||||
|
||||
@@ -107,7 +107,7 @@ impl Tool for ToolSearchTool {
|
||||
if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
|
||||
if !guard.is_activated(&stub.prefixed_name) {
|
||||
if let Some(tool) = self.deferred.activate(&stub.prefixed_name) {
|
||||
guard.activate(stub.prefixed_name.clone(), tool);
|
||||
guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
|
||||
activated_count += 1;
|
||||
}
|
||||
}
|
||||
@@ -152,7 +152,7 @@ impl ToolSearchTool {
|
||||
Some(spec) => {
|
||||
if !guard.is_activated(name) {
|
||||
if let Some(tool) = self.deferred.activate(name) {
|
||||
guard.activate(name.to_string(), tool);
|
||||
guard.activate(name.to_string(), Arc::from(tool));
|
||||
activated_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,9 +288,6 @@ fn config_multiple_channels_coexist() {
|
||||
let toml_str = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[channels_config]
|
||||
cli = true
|
||||
|
||||
[channels_config.telegram]
|
||||
bot_token = "test_token"
|
||||
allowed_users = ["zeroclaw_user"]
|
||||
@@ -345,3 +342,23 @@ fn config_memory_defaults_when_section_absent() {
|
||||
"vector + keyword weights should sum to ~1.0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_channels_without_cli_field() {
|
||||
let toml_str = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[channels_config.matrix]
|
||||
homeserver = "https://matrix.example.com"
|
||||
access_token = "syt_test_token"
|
||||
room_id = "!abc123:example.com"
|
||||
allowed_users = ["@user:example.com"]
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(toml_str)
|
||||
.expect("channels_config with only a Matrix section (no explicit cli field) should parse");
|
||||
assert!(
|
||||
parsed.channels_config.cli,
|
||||
"cli should default to true when omitted"
|
||||
);
|
||||
assert!(parsed.channels_config.matrix.is_some());
|
||||
}
|
||||
|
||||
@@ -6,3 +6,4 @@ mod hooks;
|
||||
mod memory_comparison;
|
||||
mod memory_restart;
|
||||
mod telegram_attachment_fallback;
|
||||
mod telegram_finalize_draft;
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::{body_partial_json, method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
use zeroclaw::channels::telegram::TelegramChannel;
|
||||
use zeroclaw::channels::traits::Channel;
|
||||
|
||||
fn test_channel(mock_url: &str) -> TelegramChannel {
|
||||
TelegramChannel::new("TEST_TOKEN".into(), vec!["*".into()], false)
|
||||
.with_api_base(mock_url.to_string())
|
||||
}
|
||||
|
||||
fn telegram_ok_response(message_id: i64) -> serde_json::Value {
|
||||
json!({
|
||||
"ok": true,
|
||||
"result": {
|
||||
"message_id": message_id,
|
||||
"chat": {"id": 123},
|
||||
"text": "ok"
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn telegram_error_response(description: &str) -> serde_json::Value {
|
||||
json!({
|
||||
"ok": false,
|
||||
"error_code": 400,
|
||||
"description": description,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalize_draft_treats_not_modified_as_success() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/editMessageText"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400).set_body_json(telegram_error_response(
|
||||
"Bad Request: message is not modified",
|
||||
)),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let channel = test_channel(&server.uri());
|
||||
let result = channel.finalize_draft("123", "42", "final text").await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"not modified should be treated as success, got: {result:?}"
|
||||
);
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("requests should be captured");
|
||||
assert_eq!(requests.len(), 1, "should stop after first edit response");
|
||||
assert_eq!(requests[0].url.path(), "/botTEST_TOKEN/editMessageText");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalize_draft_plain_retry_treats_not_modified_as_success() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/editMessageText"))
|
||||
.and(body_partial_json(json!({
|
||||
"chat_id": "123",
|
||||
"message_id": 42,
|
||||
"parse_mode": "HTML",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400)
|
||||
.set_body_json(telegram_error_response("Bad Request: can't parse entities")),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/editMessageText"))
|
||||
.and(body_partial_json(json!({
|
||||
"chat_id": "123",
|
||||
"message_id": 42,
|
||||
"text": "Use **bold**",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400).set_body_json(telegram_error_response(
|
||||
"Bad Request: message is not modified",
|
||||
)),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let channel = test_channel(&server.uri());
|
||||
let result = channel.finalize_draft("123", "42", "Use **bold**").await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"plain retry should accept not modified, got: {result:?}"
|
||||
);
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("requests should be captured");
|
||||
assert_eq!(requests.len(), 2, "should only attempt the two edit calls");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalize_draft_skips_send_message_when_delete_fails() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/editMessageText"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400).set_body_json(telegram_error_response(
|
||||
"Bad Request: message cannot be edited",
|
||||
)),
|
||||
)
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/deleteMessage"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400).set_body_json(telegram_error_response(
|
||||
"Bad Request: message to delete not found",
|
||||
)),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let channel = test_channel(&server.uri());
|
||||
let result = channel.finalize_draft("123", "42", "final text").await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"delete failure should skip sendMessage instead of erroring, got: {result:?}"
|
||||
);
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("requests should be captured");
|
||||
assert_eq!(
|
||||
requests
|
||||
.iter()
|
||||
.filter(|req| req.url.path() == "/botTEST_TOKEN/sendMessage")
|
||||
.count(),
|
||||
0,
|
||||
"sendMessage should be skipped when deleteMessage fails"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finalize_draft_sends_fresh_message_after_successful_delete() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/editMessageText"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(400).set_body_json(telegram_error_response(
|
||||
"Bad Request: message cannot be edited",
|
||||
)),
|
||||
)
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/deleteMessage"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(telegram_ok_response(42)))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/botTEST_TOKEN/sendMessage"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(telegram_ok_response(43)))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let channel = test_channel(&server.uri());
|
||||
let result = channel.finalize_draft("123", "42", "final text").await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"successful delete should allow safe sendMessage fallback, got: {result:?}"
|
||||
);
|
||||
|
||||
let requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("requests should be captured");
|
||||
assert_eq!(
|
||||
requests
|
||||
.iter()
|
||||
.filter(|req| req.url.path() == "/botTEST_TOKEN/sendMessage")
|
||||
.count(),
|
||||
1,
|
||||
"sendMessage should be attempted exactly once after delete succeeds"
|
||||
);
|
||||
}
|
||||
@@ -131,7 +131,12 @@ impl StaticMemoryLoader {
|
||||
|
||||
#[async_trait]
|
||||
impl MemoryLoader for StaticMemoryLoader {
|
||||
async fn load_context(&self, _memory: &dyn Memory, _user_message: &str) -> Result<String> {
|
||||
async fn load_context(
|
||||
&self,
|
||||
_memory: &dyn Memory,
|
||||
_user_message: &str,
|
||||
_session_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
Ok(self.context.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}),
|
||||
@@ -188,6 +189,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
})
|
||||
|
||||
+2
-1
@@ -1,2 +1,3 @@
|
||||
node_modules/
|
||||
dist/
|
||||
dist/*
|
||||
!dist/.gitkeep
|
||||
|
||||
Vendored
+3
@@ -0,0 +1,3 @@
|
||||
|
||||
|
||||
""
|
||||
+57
-2
@@ -1,5 +1,6 @@
|
||||
import { Routes, Route, Navigate } from 'react-router-dom';
|
||||
import { useState, useEffect, createContext, useContext } from 'react';
|
||||
import { useState, useEffect, createContext, useContext, Component } from 'react';
|
||||
import type { ReactNode, ErrorInfo } from 'react';
|
||||
import Layout from './components/layout/Layout';
|
||||
import Dashboard from './pages/Dashboard';
|
||||
import AgentChat from './pages/AgentChat';
|
||||
@@ -28,6 +29,60 @@ export const LocaleContext = createContext<LocaleContextType>({
|
||||
|
||||
export const useLocaleContext = () => useContext(LocaleContext);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error boundary — catches render crashes and shows a recoverable message
|
||||
// instead of a black screen
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface ErrorBoundaryState {
|
||||
error: Error | null;
|
||||
}
|
||||
|
||||
export class ErrorBoundary extends Component<
|
||||
{ children: ReactNode },
|
||||
ErrorBoundaryState
|
||||
> {
|
||||
constructor(props: { children: ReactNode }) {
|
||||
super(props);
|
||||
this.state = { error: null };
|
||||
}
|
||||
|
||||
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
|
||||
return { error };
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, info: ErrorInfo) {
|
||||
console.error('[ZeroClaw] Render error:', error, info.componentStack);
|
||||
}
|
||||
|
||||
render() {
|
||||
if (this.state.error) {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<div className="bg-gray-900 border border-red-700 rounded-xl p-6 w-full max-w-lg">
|
||||
<h2 className="text-lg font-semibold text-red-400 mb-2">
|
||||
Something went wrong
|
||||
</h2>
|
||||
<p className="text-gray-400 text-sm mb-4">
|
||||
A render error occurred. Check the browser console for details.
|
||||
</p>
|
||||
<pre className="text-xs text-red-300 bg-gray-800 rounded p-3 overflow-x-auto whitespace-pre-wrap break-all">
|
||||
{this.state.error.message}
|
||||
</pre>
|
||||
<button
|
||||
onClick={() => this.setState({ error: null })}
|
||||
className="mt-6 px-4 py-2 bg-blue-600 hover:bg-blue-700 text-white text-sm font-medium rounded-lg transition-colors"
|
||||
>
|
||||
Try again
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return this.props.children;
|
||||
}
|
||||
}
|
||||
|
||||
// Pairing dialog component
|
||||
function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> }) {
|
||||
const [code, setCode] = useState('');
|
||||
@@ -77,7 +132,7 @@ function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> })
|
||||
autoFocus
|
||||
/>
|
||||
{error && (
|
||||
<p className="text-[#ff4466] text-sm mb-4 text-center animate-fade-in">{error}</p>
|
||||
<p className="text-[#ff4466] text-sm mb-4 text-center animate-fade-in" aria-live="polite">{error}</p>
|
||||
)}
|
||||
<button
|
||||
type="submit"
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { Outlet } from 'react-router-dom';
|
||||
import { Outlet, useLocation } from 'react-router-dom';
|
||||
import Sidebar from '@/components/layout/Sidebar';
|
||||
import Header from '@/components/layout/Header';
|
||||
import { ErrorBoundary } from '@/App';
|
||||
|
||||
export default function Layout() {
|
||||
const { pathname } = useLocation();
|
||||
|
||||
return (
|
||||
<div className="min-h-screen text-white" style={{ background: 'linear-gradient(135deg, #050510 0%, #080818 50%, #050510 100%)' }}>
|
||||
{/* Fixed sidebar */}
|
||||
@@ -12,9 +15,12 @@ export default function Layout() {
|
||||
<div className="ml-60 flex flex-col min-h-screen">
|
||||
<Header />
|
||||
|
||||
{/* Page content */}
|
||||
{/* Page content — ErrorBoundary keyed by pathname so the nav shell
|
||||
survives a page crash and the boundary resets on route change */}
|
||||
<main className="flex-1 overflow-y-auto">
|
||||
<Outlet />
|
||||
<ErrorBoundary key={pathname}>
|
||||
<Outlet />
|
||||
</ErrorBoundary>
|
||||
</main>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
+3
-1
@@ -71,7 +71,9 @@ export class WebSocketClient {
|
||||
params.set('session_id', sessionId);
|
||||
const url = `${this.baseUrl}/ws/chat?${params.toString()}`;
|
||||
|
||||
this.ws = new WebSocket(url, ['zeroclaw.v1']);
|
||||
const protocols: string[] = ['zeroclaw.v1'];
|
||||
if (token) protocols.push(`bearer.${token}`);
|
||||
this.ws = new WebSocket(url, protocols);
|
||||
|
||||
this.ws.onopen = () => {
|
||||
this.currentDelay = this.reconnectDelay;
|
||||
|
||||
Reference in New Issue
Block a user