Compare commits
89 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 41dd23175f | |||
| 864d754b56 | |||
| ccd52f3394 | |||
| eb01aa451d | |||
| c785b45f2d | |||
| ffb8b81f90 | |||
| 65f856d710 | |||
| 1682620377 | |||
| aa455ae89b | |||
| a9ffd38912 | |||
| 86a0584513 | |||
| abef4c5719 | |||
| 483b2336c4 | |||
| 14cda3bc9a | |||
| 6e8f0fa43c | |||
| a965b129f8 | |||
| c135de41b7 | |||
| 2d2c2ac9e6 | |||
| 5e774bbd70 | |||
| 33015067eb | |||
| 6b10c0b891 | |||
| bf817e30d2 | |||
| 0051a0c296 | |||
| d753de91f1 | |||
| f6b2f61a01 | |||
| 70e7910cb9 | |||
| a8868768e8 | |||
| 67293c50df | |||
| 1646079d25 | |||
| 25b639435f | |||
| 77779844e5 | |||
| f658d5806a | |||
| 7134fe0824 | |||
| 263802b3df | |||
| 3c25fddb2a | |||
| a6a46bdd25 | |||
| 235d4d2f1c | |||
| bd1e8c8e1a | |||
| f81807bff6 | |||
| bb7006313c | |||
| 9a49626376 | |||
| 8b978a721f | |||
| 75b4c1d4a4 | |||
| b2087e6065 | |||
| ad8f81ad76 | |||
| c58e1c1fb3 | |||
| cb0779d761 | |||
| daca2d9354 | |||
| 3c1e710c38 | |||
| 0aefde95f2 | |||
| a84aa60554 | |||
| edd4b37325 | |||
| c5f0155061 | |||
| 9ee06ed6fc | |||
| ac6b43e9f4 | |||
| 6c5573ad96 | |||
| 1d57a0d1e5 | |||
| 9780c7d797 | |||
| 35a5451a17 | |||
| 8e81d44d54 | |||
| 86ad0c6a2b | |||
| 6ecf89d6a9 | |||
| 691efa4d8c | |||
| d1e3f435b4 | |||
| 44c3e264ad | |||
| f2b6013329 | |||
| 05d3c51a30 | |||
| 2ceda31ce2 | |||
| 9069bc3c1f | |||
| 9319fe18da | |||
| cc454a86c8 | |||
| 256e8ccebf | |||
| 72c9e6b6ca | |||
| 755a129ca2 | |||
| 8b0d3684c5 | |||
| cdb5ac1471 | |||
| 67acb1a0bb | |||
| 9eac6bafef | |||
| a12f2ff439 | |||
| a38a4d132e | |||
| 48aba73d3a | |||
| a1ab1e1a11 | |||
| f394abf35c | |||
| 52e0271bd5 | |||
| 6c0a48efff | |||
| 87b5bca449 | |||
| be40c0c5a5 | |||
| 6527871928 | |||
| 0bda80de9c |
@@ -1,97 +0,0 @@
|
||||
# Mem0 Integration: Dual-Scope Recall + Per-Turn Memory
|
||||
|
||||
## Context
|
||||
|
||||
Mem0 auto-save works but the integration is missing key features from mem0 best practices: per-turn recall, multi-level scoping, and proper context injection. This causes the bot to "forget" on follow-up turns and not differentiate users.
|
||||
|
||||
## What's Missing (vs mem0 docs)
|
||||
|
||||
1. **Per-turn recall** — only first turn gets memory context, follow-ups get nothing
|
||||
2. **Dual-scope** — no sender vs group distinction. All memories use single hardcoded `user_id`
|
||||
3. **System prompt injection** — memory prepended to user message (pollutes session history)
|
||||
4. **`agent_id` scoping** — mem0 supports agent-level patterns, not used
|
||||
|
||||
## Changes
|
||||
|
||||
### 1. `src/memory/mem0.rs` — Use session_id for multi-level scoping
|
||||
|
||||
Map zeroclaw's `session_id` param to mem0's `user_id`. This enables per-user and per-group memory namespaces without changing the `Memory` trait.
|
||||
|
||||
```rust
|
||||
// Add helper:
|
||||
fn effective_user_id(&self, session_id: Option<&str>) -> &str {
|
||||
session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
|
||||
}
|
||||
|
||||
// In store(): use effective_user_id(session_id) as mem0 user_id
|
||||
// In recall(): use effective_user_id(session_id) as mem0 user_id
|
||||
// In list(): use effective_user_id(session_id) as mem0 user_id
|
||||
```
|
||||
|
||||
### 2. `src/channels/mod.rs` ~line 2229 — Per-turn dual-scope recall
|
||||
|
||||
Remove `if !had_prior_history` gate. Always recall from both sender scope and group scope (for group chats).
|
||||
|
||||
```rust
|
||||
// Detect group chat
|
||||
let is_group = msg.reply_target.contains("@g.us")
|
||||
|| msg.reply_target.starts_with("group:");
|
||||
|
||||
// Sender-scope recall (always)
|
||||
let sender_context = build_memory_context(
|
||||
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
|
||||
Some(&msg.sender),
|
||||
).await;
|
||||
|
||||
// Group-scope recall (groups only)
|
||||
let group_context = if is_group {
|
||||
build_memory_context(
|
||||
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
|
||||
Some(&history_key),
|
||||
).await
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Merge (deduplicate by checking substring overlap)
|
||||
let memory_context = merge_memory_contexts(&sender_context, &group_context);
|
||||
```
|
||||
|
||||
### 3. `src/channels/mod.rs` ~line 2244 — Inject into system prompt
|
||||
|
||||
Move memory context from user message to system prompt. Re-fetched each turn, doesn't pollute session.
|
||||
|
||||
```rust
|
||||
let mut system_prompt = build_channel_system_prompt(...);
|
||||
if !memory_context.is_empty() {
|
||||
system_prompt.push_str(&format!("\n\n{memory_context}"));
|
||||
}
|
||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||
```
|
||||
|
||||
### 4. `src/channels/mod.rs` — Dual-scope auto-save
|
||||
|
||||
Find existing auto-save call. For group messages, store twice:
|
||||
- `store(key, content, category, Some(&msg.sender))` — personal facts
|
||||
- `store(key, content, category, Some(&history_key))` — group context
|
||||
|
||||
Both async, non-blocking. DMs only store to sender scope.
|
||||
|
||||
### 5. `src/memory/mem0.rs` — Add `agent_id` support (optional)
|
||||
|
||||
Pass `self.app_name` as `agent_id` param to mem0 API for agent behavior tracking.
|
||||
|
||||
## Files to Modify
|
||||
|
||||
1. `src/memory/mem0.rs` — session_id → user_id mapping
|
||||
2. `src/channels/mod.rs` — per-turn recall, dual-scope, system prompt injection, dual-scope save
|
||||
|
||||
## Verification
|
||||
|
||||
1. `cargo check --features whatsapp-web,memory-mem0`
|
||||
2. `cargo test --features whatsapp-web,memory-mem0`
|
||||
3. Deploy to Synology
|
||||
4. Test DM: "我鍾意食壽司" → next turn "我鍾意食咩" → should recall
|
||||
5. Test group: Joe says "我鍾意食壽司" → someone else asks "Joe 鍾意食咩" → should recall from group scope
|
||||
6. Check mem0 server logs: GET with `user_id=sender` AND `user_id=group_key`
|
||||
7. Check mem0 server logs: POST with both user_ids for group messages
|
||||
@@ -118,3 +118,7 @@ PROVIDER=openrouter
|
||||
# Optional: Brave Search (requires API key from https://brave.com/search/api)
|
||||
# WEB_SEARCH_PROVIDER=brave
|
||||
# BRAVE_API_KEY=your-brave-search-api-key
|
||||
#
|
||||
# Optional: SearXNG (self-hosted, requires instance URL)
|
||||
# WEB_SEARCH_PROVIDER=searxng
|
||||
# SEARXNG_INSTANCE_URL=https://searx.example.com
|
||||
|
||||
@@ -154,7 +154,7 @@ jobs:
|
||||
run: mkdir -p web/dist && touch web/dist/.gitkeep
|
||||
|
||||
- name: Check all features
|
||||
run: cargo check --all-features --locked
|
||||
run: cargo check --features ci-all --locked
|
||||
|
||||
docs-quality:
|
||||
name: Docs Quality
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
name: Pub Homebrew Core
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag to publish (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Patch formula only (no push/PR)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
HOMEBREW_UPSTREAM_PR_TOKEN:
|
||||
required: false
|
||||
HOMEBREW_CORE_BOT_TOKEN:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
|
||||
@@ -19,6 +19,7 @@ env:
|
||||
jobs:
|
||||
detect-version-change:
|
||||
name: Detect Version Bump
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw'
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
@@ -40,6 +41,14 @@ jobs:
|
||||
echo "Current version: ${current}"
|
||||
echo "Previous version: ${previous}"
|
||||
|
||||
# Skip if stable release workflow will handle this version
|
||||
# (indicated by an existing or imminent stable tag)
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/v${current}" >/dev/null 2>&1; then
|
||||
echo "Stable tag v${current} exists — stable release workflow handles crates.io"
|
||||
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$current" != "$previous" && -n "$current" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${current}" >> "$GITHUB_OUTPUT"
|
||||
@@ -102,6 +111,22 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish to crates.io
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
@@ -67,6 +67,24 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
if: "!inputs.dry_run"
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
if: "!inputs.dry_run"
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish (dry run)
|
||||
if: inputs.dry_run
|
||||
run: cargo publish --dry-run --locked --allow-dirty --no-verify
|
||||
|
||||
@@ -21,25 +21,48 @@ env:
|
||||
jobs:
|
||||
version:
|
||||
name: Resolve Version
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw'
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
version: ${{ steps.ver.outputs.version }}
|
||||
tag: ${{ steps.ver.outputs.tag }}
|
||||
skip: ${{ steps.ver.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Compute beta version
|
||||
id: ver
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
base_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
|
||||
# Skip beta if this is a version bump commit (stable release handles it)
|
||||
commit_msg=$(git log -1 --pretty=format:"%s")
|
||||
if [[ "$commit_msg" =~ ^chore:\ bump\ version ]]; then
|
||||
echo "Version bump commit detected — skipping beta release"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Skip beta if a stable tag already exists for this version
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/v${base_version}" >/dev/null 2>&1; then
|
||||
echo "Stable tag v${base_version} exists — skipping beta release"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
beta_tag="v${base_version}-beta.${GITHUB_RUN_NUMBER}"
|
||||
echo "version=${base_version}" >> "$GITHUB_OUTPUT"
|
||||
echo "tag=${beta_tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Beta release: ${beta_tag}"
|
||||
|
||||
release-notes:
|
||||
name: Generate Release Notes
|
||||
needs: [version]
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
notes: ${{ steps.notes.outputs.body }}
|
||||
@@ -130,6 +153,8 @@ jobs:
|
||||
|
||||
web:
|
||||
name: Build Web Dashboard
|
||||
needs: [version]
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
name: Release Stable
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v[0-9]+.[0-9]+.[0-9]+" # stable tags only (no -beta suffix)
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
@@ -33,11 +36,22 @@ jobs:
|
||||
- name: Validate semver and Cargo.toml match
|
||||
id: check
|
||||
shell: bash
|
||||
env:
|
||||
INPUT_VERSION: ${{ inputs.version || '' }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
input_version="${{ inputs.version }}"
|
||||
cargo_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
|
||||
# Resolve version from tag push or manual input
|
||||
if [[ "$EVENT_NAME" == "push" ]]; then
|
||||
# Tag push: extract version from tag name (v0.5.9 -> 0.5.9)
|
||||
input_version="${REF_NAME#v}"
|
||||
else
|
||||
input_version="$INPUT_VERSION"
|
||||
fi
|
||||
|
||||
if [[ ! "$input_version" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "::error::Version must be semver (X.Y.Z). Got: ${input_version}"
|
||||
exit 1
|
||||
@@ -49,9 +63,13 @@ jobs:
|
||||
fi
|
||||
|
||||
tag="v${input_version}"
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${tag} already exists."
|
||||
exit 1
|
||||
|
||||
# Only check tag existence for manual dispatch (tag push means it already exists)
|
||||
if [[ "$EVENT_NAME" != "push" ]]; then
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${tag} already exists."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "tag=${tag}" >> "$GITHUB_OUTPUT"
|
||||
@@ -286,6 +304,14 @@ jobs:
|
||||
NOTES: ${{ needs.release-notes.outputs.notes }}
|
||||
run: printf '%s\n' "$NOTES" > release-notes.md
|
||||
|
||||
- name: Create tag if manual dispatch
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
env:
|
||||
TAG: ${{ needs.validate.outputs.tag }}
|
||||
run: |
|
||||
git tag -a "$TAG" -m "zeroclaw $TAG"
|
||||
git push origin "$TAG"
|
||||
|
||||
- name: Create GitHub Release
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.RELEASE_TOKEN }}
|
||||
@@ -323,6 +349,21 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish to crates.io
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
@@ -446,6 +487,16 @@ jobs:
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
homebrew:
|
||||
name: Update Homebrew Core
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-homebrew-core.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
# ── Post-publish: tweet after release + website are live ──────────────
|
||||
# Docker push can be slow; don't let it block the tweet.
|
||||
tweet:
|
||||
|
||||
Generated
+382
-26
@@ -117,6 +117,28 @@ version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
||||
|
||||
[[package]]
|
||||
name = "alsa"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
|
||||
dependencies = [
|
||||
"alsa-sys",
|
||||
"bitflags 2.11.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "alsa-sys"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db8fee663d06c4e303404ef5f40488a53e062f89ba8bfed81f42325aafad1527"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ambient-authority"
|
||||
version = "0.0.2"
|
||||
@@ -583,6 +605,24 @@ dependencies = [
|
||||
"virtue",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.72.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bip39"
|
||||
version = "2.2.2"
|
||||
@@ -878,6 +918,21 @@ dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cesu8"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||
dependencies = [
|
||||
"nom 7.1.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cff-parser"
|
||||
version = "0.1.0"
|
||||
@@ -1003,6 +1058,17 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"libc",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.6.0"
|
||||
@@ -1086,6 +1152,16 @@ version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "compression-codecs"
|
||||
version = "0.4.37"
|
||||
@@ -1209,6 +1285,49 @@ dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coreaudio-rs"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation-sys",
|
||||
"coreaudio-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coreaudio-sys"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ceec7a6067e62d6f931a2baf6f3a751f4a892595bcec1461a3c94ef9949864b6"
|
||||
dependencies = [
|
||||
"bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cpal"
|
||||
version = "0.15.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779"
|
||||
dependencies = [
|
||||
"alsa",
|
||||
"core-foundation-sys",
|
||||
"coreaudio-rs",
|
||||
"dasp_sample",
|
||||
"jni",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"mach2 0.4.3",
|
||||
"ndk",
|
||||
"ndk-context",
|
||||
"oboe",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
"windows",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cpp_demangle"
|
||||
version = "0.4.5"
|
||||
@@ -1588,6 +1707,12 @@ dependencies = [
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dasp_sample"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f"
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.10.0"
|
||||
@@ -2944,7 +3069,7 @@ dependencies = [
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
"windows-core 0.62.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3319,9 +3444,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
|
||||
|
||||
[[package]]
|
||||
name = "iri-string"
|
||||
version = "0.7.10"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
|
||||
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
@@ -3362,9 +3487,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
||||
checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
|
||||
|
||||
[[package]]
|
||||
name = "ittapi"
|
||||
@@ -3395,6 +3520,50 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
|
||||
dependencies = [
|
||||
"cesu8",
|
||||
"cfg-if",
|
||||
"combine",
|
||||
"jni-sys 0.3.1",
|
||||
"log",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
|
||||
dependencies = [
|
||||
"jni-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
|
||||
dependencies = [
|
||||
"jni-sys-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys-macros"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.34"
|
||||
@@ -4202,9 +4371,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "moka"
|
||||
version = "0.12.14"
|
||||
version = "0.12.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
|
||||
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
|
||||
dependencies = [
|
||||
"async-lock",
|
||||
"crossbeam-channel",
|
||||
@@ -4242,6 +4411,35 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11ec1bc47d34ae756616f387c11fd0595f86f2cc7e6473bde9e3ded30cb902a1"
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"jni-sys 0.3.1",
|
||||
"log",
|
||||
"ndk-sys",
|
||||
"num_enum",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk-context"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b"
|
||||
|
||||
[[package]]
|
||||
name = "ndk-sys"
|
||||
version = "0.5.0+25.2.9519653"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691"
|
||||
dependencies = [
|
||||
"jni-sys 0.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "negentropy"
|
||||
version = "0.5.0"
|
||||
@@ -4421,6 +4619,17 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -4440,6 +4649,28 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_enum"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26"
|
||||
dependencies = [
|
||||
"num_enum_derive",
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_enum_derive"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8"
|
||||
dependencies = [
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nusb"
|
||||
version = "0.2.3"
|
||||
@@ -4519,6 +4750,29 @@ dependencies = [
|
||||
"ruzstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oboe"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb"
|
||||
dependencies = [
|
||||
"jni",
|
||||
"ndk",
|
||||
"ndk-context",
|
||||
"num-derive",
|
||||
"num-traits",
|
||||
"oboe-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oboe-sys"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.4"
|
||||
@@ -5281,9 +5535,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pulldown-cmark"
|
||||
version = "0.13.1"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
|
||||
checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"memchr",
|
||||
@@ -5401,9 +5655,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "quoted_printable"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73"
|
||||
checksum = "478e0585659a122aa407eb7e3c0e1fa51b1d8a870038bd29f0cf4a8551eea972"
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
@@ -7577,9 +7831,9 @@ checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.2.0"
|
||||
version = "3.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc"
|
||||
checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"cookie_store",
|
||||
@@ -7591,15 +7845,15 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"ureq-proto",
|
||||
"utf-8",
|
||||
"utf8-zero",
|
||||
"webpki-roots 1.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq-proto"
|
||||
version = "0.5.3"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
|
||||
checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"http 1.4.0",
|
||||
@@ -7632,6 +7886,12 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8-zero"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
@@ -8691,6 +8951,26 @@ dependencies = [
|
||||
"wasmtime-internal-math",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows"
|
||||
version = "0.54.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49"
|
||||
dependencies = [
|
||||
"windows-core 0.54.0",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.54.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65"
|
||||
dependencies = [
|
||||
"windows-result 0.1.2",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.2"
|
||||
@@ -8700,7 +8980,7 @@ dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-result 0.4.1",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
@@ -8732,6 +9012,15 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8"
|
||||
dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.4.1"
|
||||
@@ -8750,6 +9039,15 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
|
||||
dependencies = [
|
||||
"windows-targets 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
@@ -8786,6 +9084,21 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
@@ -8819,6 +9132,12 @@ dependencies = [
|
||||
"windows_x86_64_msvc 0.53.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
@@ -8831,6 +9150,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -8843,6 +9168,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
@@ -8867,6 +9198,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -8879,6 +9216,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
@@ -8891,6 +9234,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
@@ -8903,6 +9252,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
@@ -9203,7 +9558,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.5"
|
||||
version = "0.5.9"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
@@ -9217,6 +9572,7 @@ dependencies = [
|
||||
"clap",
|
||||
"clap_complete",
|
||||
"console",
|
||||
"cpal",
|
||||
"criterion",
|
||||
"cron",
|
||||
"dialoguer",
|
||||
@@ -9298,18 +9654,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
|
||||
checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
|
||||
checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -9414,9 +9770,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "8.3.0"
|
||||
version = "8.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a243cfad17427fc077f529da5a95abe4e94fd2bfdb601611870a6557cc67657"
|
||||
checksum = "5c546feb4481b0fbafb4ef0d79b6204fc41c6f9884b1b73b1d73f82442fc0845"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"flate2",
|
||||
@@ -9486,9 +9842,9 @@ checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
|
||||
|
||||
[[package]]
|
||||
name = "zune-jpeg"
|
||||
version = "0.5.13"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c"
|
||||
checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6"
|
||||
dependencies = [
|
||||
"zune-core",
|
||||
]
|
||||
|
||||
+27
-4
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.5"
|
||||
version = "0.5.9"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@@ -97,7 +97,7 @@ anyhow = "1.0"
|
||||
thiserror = "2.0"
|
||||
|
||||
# Aardvark I2C/SPI/GPIO USB adapter (Total Phase) — stub when SDK absent
|
||||
aardvark-sys = { path = "crates/aardvark-sys" }
|
||||
aardvark-sys = { path = "crates/aardvark-sys", version = "0.1.0" }
|
||||
|
||||
# UUID generation
|
||||
uuid = { version = "1.22", default-features = false, features = ["v4", "std"] }
|
||||
@@ -199,6 +199,9 @@ pdf-extract = { version = "0.10", optional = true }
|
||||
# WASM plugin runtime (extism)
|
||||
extism = { version = "1.20", optional = true }
|
||||
|
||||
# Cross-platform audio capture for voice wake word detection (optional, enable with --features voice-wake)
|
||||
cpal = { version = "0.15", optional = true }
|
||||
|
||||
# Terminal QR rendering for WhatsApp Web pairing flow.
|
||||
qrcode = { version = "0.14", optional = true }
|
||||
|
||||
@@ -228,8 +231,6 @@ channel-matrix = ["dep:matrix-sdk"]
|
||||
channel-lark = ["dep:prost"]
|
||||
channel-feishu = ["channel-lark"] # Alias for Feishu users (Lark and Feishu are the same platform)
|
||||
memory-postgres = ["dep:postgres"]
|
||||
# memory-mem0 = Mem0 (OpenMemory) memory backend via REST API
|
||||
memory-mem0 = []
|
||||
observability-prometheus = ["dep:prometheus"]
|
||||
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
|
||||
peripheral-rpi = ["rppal"]
|
||||
@@ -252,8 +253,30 @@ rag-pdf = ["dep:pdf-extract"]
|
||||
skill-creation = []
|
||||
# whatsapp-web = Native WhatsApp Web client with custom rusqlite storage backend
|
||||
whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-proto", "dep:wa-rs-ureq-http", "dep:wa-rs-tokio-transport", "dep:serde-big-array", "dep:prost", "dep:qrcode"]
|
||||
# voice-wake = Voice wake word detection via microphone (cpal)
|
||||
voice-wake = ["dep:cpal"]
|
||||
# WASM plugin system (extism-based)
|
||||
plugins-wasm = ["dep:extism"]
|
||||
# Meta-feature for CI: all features except those requiring system C libraries
|
||||
# not available on standard CI runners (e.g., voice-wake needs libasound2-dev).
|
||||
ci-all = [
|
||||
"channel-nostr",
|
||||
"hardware",
|
||||
"channel-matrix",
|
||||
"channel-lark",
|
||||
"memory-postgres",
|
||||
"observability-prometheus",
|
||||
"observability-otel",
|
||||
"peripheral-rpi",
|
||||
"browser-native",
|
||||
"sandbox-landlock",
|
||||
"sandbox-bubblewrap",
|
||||
"probe",
|
||||
"rag-pdf",
|
||||
"skill-creation",
|
||||
"whatsapp-web",
|
||||
"plugins-wasm",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
|
||||
@@ -27,6 +27,7 @@ COPY Cargo.toml Cargo.lock ./
|
||||
# Previously we used sed to drop `crates/robot-kit`, which made the manifest disagree
|
||||
# with the lockfile and caused `cargo --locked` to fail (Cargo refused to rewrite the lock).
|
||||
COPY crates/robot-kit/ crates/robot-kit/
|
||||
COPY crates/aardvark-sys/ crates/aardvark-sys/
|
||||
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
|
||||
RUN mkdir -p src benches \
|
||||
&& echo "fn main() {}" > src/main.rs \
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Start mem0 + reranker GPU container for ZeroClaw memory backend.
|
||||
#
|
||||
# Required env vars:
|
||||
# MEM0_LLM_API_KEY or ZAI_API_KEY — API key for the LLM used in fact extraction
|
||||
#
|
||||
# Optional env vars (with defaults):
|
||||
# MEM0_LLM_PROVIDER — mem0 LLM provider (default: "openai" i.e. OpenAI-compatible)
|
||||
# MEM0_LLM_MODEL — LLM model for fact extraction (default: "glm-5-turbo")
|
||||
# MEM0_LLM_BASE_URL — LLM API base URL (default: "https://api.z.ai/api/coding/paas/v4")
|
||||
# MEM0_EMBEDDER_MODEL — embedding model (default: "BAAI/bge-m3")
|
||||
# MEM0_EMBEDDER_DIMS — embedding dimensions (default: "1024")
|
||||
# MEM0_EMBEDDER_DEVICE — "cuda", "cpu", or "auto" (default: "cuda")
|
||||
# MEM0_VECTOR_COLLECTION — Qdrant collection name (default: "zeroclaw_mem0")
|
||||
# RERANKER_MODEL — reranker model (default: "BAAI/bge-reranker-v2-m3")
|
||||
# RERANKER_DEVICE — "cuda" or "cpu" (default: "cuda")
|
||||
# MEM0_PORT — mem0 server port (default: 8765)
|
||||
# RERANKER_PORT — reranker server port (default: 8678)
|
||||
# CONTAINER_IMAGE — base container image (default: docker.io/kyuz0/amd-strix-halo-comfyui:latest)
|
||||
# CONTAINER_NAME — container name (default: mem0-gpu)
|
||||
# DATA_DIR — host path for Qdrant data (default: ~/mem0-data)
|
||||
# SCRIPT_DIR — host path for server scripts (default: directory of this script)
|
||||
set -e
|
||||
|
||||
# Resolve script directory for mounting server scripts
|
||||
SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "$0")" && pwd)}"
|
||||
|
||||
# API key — accept either name
|
||||
export MEM0_LLM_API_KEY="${MEM0_LLM_API_KEY:-${ZAI_API_KEY:?MEM0_LLM_API_KEY or ZAI_API_KEY must be set}}"
|
||||
|
||||
# Defaults
|
||||
MEM0_LLM_MODEL="${MEM0_LLM_MODEL:-glm-5-turbo}"
|
||||
MEM0_LLM_BASE_URL="${MEM0_LLM_BASE_URL:-https://api.z.ai/api/coding/paas/v4}"
|
||||
MEM0_PORT="${MEM0_PORT:-8765}"
|
||||
RERANKER_PORT="${RERANKER_PORT:-8678}"
|
||||
CONTAINER_IMAGE="${CONTAINER_IMAGE:-docker.io/kyuz0/amd-strix-halo-comfyui:latest}"
|
||||
CONTAINER_NAME="${CONTAINER_NAME:-mem0-gpu}"
|
||||
DATA_DIR="${DATA_DIR:-$HOME/mem0-data}"
|
||||
|
||||
# Stop existing CPU services (if any)
|
||||
kill -9 $(pgrep -f "mem0-server.py") 2>/dev/null || true
|
||||
kill -9 $(pgrep -f "reranker-server.py") 2>/dev/null || true
|
||||
|
||||
# Stop existing container
|
||||
podman stop "$CONTAINER_NAME" 2>/dev/null || true
|
||||
podman rm "$CONTAINER_NAME" 2>/dev/null || true
|
||||
|
||||
podman run -d --name "$CONTAINER_NAME" \
|
||||
--device /dev/dri --device /dev/kfd \
|
||||
--group-add video --group-add render \
|
||||
--restart unless-stopped \
|
||||
-p "$MEM0_PORT:$MEM0_PORT" -p "$RERANKER_PORT:$RERANKER_PORT" \
|
||||
-v "$DATA_DIR":/root/mem0-data:Z \
|
||||
-v "$SCRIPT_DIR/mem0-server.py":/app/mem0-server.py:ro,Z \
|
||||
-v "$SCRIPT_DIR/reranker-server.py":/app/reranker-server.py:ro,Z \
|
||||
-v "$HOME/.cache/huggingface":/root/.cache/huggingface:Z \
|
||||
-e MEM0_LLM_API_KEY="$MEM0_LLM_API_KEY" \
|
||||
-e ZAI_API_KEY="$MEM0_LLM_API_KEY" \
|
||||
-e MEM0_LLM_MODEL="$MEM0_LLM_MODEL" \
|
||||
-e MEM0_LLM_BASE_URL="$MEM0_LLM_BASE_URL" \
|
||||
${MEM0_LLM_PROVIDER:+-e MEM0_LLM_PROVIDER="$MEM0_LLM_PROVIDER"} \
|
||||
${MEM0_EMBEDDER_MODEL:+-e MEM0_EMBEDDER_MODEL="$MEM0_EMBEDDER_MODEL"} \
|
||||
${MEM0_EMBEDDER_DIMS:+-e MEM0_EMBEDDER_DIMS="$MEM0_EMBEDDER_DIMS"} \
|
||||
${MEM0_EMBEDDER_DEVICE:+-e MEM0_EMBEDDER_DEVICE="$MEM0_EMBEDDER_DEVICE"} \
|
||||
${MEM0_VECTOR_COLLECTION:+-e MEM0_VECTOR_COLLECTION="$MEM0_VECTOR_COLLECTION"} \
|
||||
${RERANKER_MODEL:+-e RERANKER_MODEL="$RERANKER_MODEL"} \
|
||||
${RERANKER_DEVICE:+-e RERANKER_DEVICE="$RERANKER_DEVICE"} \
|
||||
-e RERANKER_PORT="$RERANKER_PORT" \
|
||||
-e RERANKER_URL="http://127.0.0.1:$RERANKER_PORT/rerank" \
|
||||
-e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
|
||||
-e HOME=/root \
|
||||
"$CONTAINER_IMAGE" \
|
||||
bash -c "pip install -q FlagEmbedding mem0ai flask httpx qdrant-client 2>&1 | tail -3; echo '=== Starting reranker (GPU) on :$RERANKER_PORT ==='; python3 /app/reranker-server.py & sleep 3; echo '=== Starting mem0 (GPU) on :$MEM0_PORT ==='; exec python3 /app/mem0-server.py"
|
||||
|
||||
echo "Container started, waiting for init..."
|
||||
sleep 15
|
||||
echo "=== Container logs ==="
|
||||
podman logs "$CONTAINER_NAME" 2>&1 | tail -25
|
||||
echo "=== Port check ==="
|
||||
ss -tlnp | grep "$MEM0_PORT\|$RERANKER_PORT" || echo "Ports not yet ready, check: podman logs $CONTAINER_NAME"
|
||||
@@ -1,288 +0,0 @@
|
||||
"""Minimal OpenMemory-compatible REST server wrapping mem0 Python SDK."""
|
||||
import asyncio
|
||||
import json, os, uuid, httpx
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import FastAPI, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from mem0 import Memory
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
RERANKER_URL = os.environ.get("RERANKER_URL", "http://127.0.0.1:8678/rerank")
|
||||
|
||||
CUSTOM_EXTRACTION_PROMPT = """You are a memory extraction specialist for a Cantonese/Chinese chat assistant.
|
||||
|
||||
Extract ONLY important, persistent facts from the conversation. Rules:
|
||||
1. Extract personal preferences, habits, relationships, names, locations
|
||||
2. Extract decisions, plans, and commitments people make
|
||||
3. SKIP small talk, greetings, reactions ("ok", "哈哈", "係呀")
|
||||
4. SKIP temporary states ("我依家食緊飯") unless they reveal a habit
|
||||
5. Keep facts in the ORIGINAL language (Cantonese/Chinese/English)
|
||||
6. For each fact, note WHO it's about (use their name or identifier if known)
|
||||
7. Merge/update existing facts rather than creating duplicates
|
||||
|
||||
Return a list of facts in JSON format: {"facts": ["fact1", "fact2", ...]}
|
||||
"""
|
||||
|
||||
PROCEDURAL_EXTRACTION_PROMPT = """You are a procedural memory specialist for an AI assistant.
|
||||
|
||||
Extract HOW-TO patterns and reusable procedures from the conversation trace. Rules:
|
||||
1. Identify step-by-step procedures the assistant followed to accomplish a task
|
||||
2. Extract tool usage patterns: which tools were called, in what order, with what arguments
|
||||
3. Capture decision points: why the assistant chose one approach over another
|
||||
4. Note error-recovery patterns: what failed, how it was fixed
|
||||
5. Keep the procedure generic enough to apply to similar future tasks
|
||||
6. Preserve technical details (commands, file paths, API calls) that are reusable
|
||||
7. SKIP greetings, small talk, and conversational filler
|
||||
8. Format each procedure as: "To [goal]: [step1] -> [step2] -> ... -> [result]"
|
||||
|
||||
Return a list of procedures in JSON format: {"facts": ["procedure1", "procedure2", ...]}
|
||||
"""
|
||||
|
||||
# ── Configurable via environment variables ─────────────────────────
|
||||
# LLM (for fact extraction when infer=true)
|
||||
MEM0_LLM_PROVIDER = os.environ.get("MEM0_LLM_PROVIDER", "openai") # "openai" (compatible), "anthropic", etc.
|
||||
MEM0_LLM_MODEL = os.environ.get("MEM0_LLM_MODEL", "glm-5-turbo")
|
||||
MEM0_LLM_API_KEY = os.environ.get("MEM0_LLM_API_KEY") or os.environ.get("ZAI_API_KEY", "")
|
||||
MEM0_LLM_BASE_URL = os.environ.get("MEM0_LLM_BASE_URL", "https://api.z.ai/api/coding/paas/v4")
|
||||
|
||||
# Embedder
|
||||
MEM0_EMBEDDER_PROVIDER = os.environ.get("MEM0_EMBEDDER_PROVIDER", "huggingface") # "huggingface", "openai", etc.
|
||||
MEM0_EMBEDDER_MODEL = os.environ.get("MEM0_EMBEDDER_MODEL", "BAAI/bge-m3")
|
||||
MEM0_EMBEDDER_DIMS = int(os.environ.get("MEM0_EMBEDDER_DIMS", "1024"))
|
||||
MEM0_EMBEDDER_DEVICE = os.environ.get("MEM0_EMBEDDER_DEVICE", "cuda") # "cuda", "cpu", "auto"
|
||||
|
||||
# Vector store
|
||||
MEM0_VECTOR_PROVIDER = os.environ.get("MEM0_VECTOR_PROVIDER", "qdrant") # "qdrant", "chroma", etc.
|
||||
MEM0_VECTOR_COLLECTION = os.environ.get("MEM0_VECTOR_COLLECTION", "zeroclaw_mem0")
|
||||
MEM0_VECTOR_PATH = os.environ.get("MEM0_VECTOR_PATH", os.path.expanduser("~/mem0-data/qdrant"))
|
||||
|
||||
config = {
|
||||
"llm": {
|
||||
"provider": MEM0_LLM_PROVIDER,
|
||||
"config": {
|
||||
"model": MEM0_LLM_MODEL,
|
||||
"api_key": MEM0_LLM_API_KEY,
|
||||
"openai_base_url": MEM0_LLM_BASE_URL,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": MEM0_EMBEDDER_PROVIDER,
|
||||
"config": {
|
||||
"model": MEM0_EMBEDDER_MODEL,
|
||||
"embedding_dims": MEM0_EMBEDDER_DIMS,
|
||||
"model_kwargs": {"device": MEM0_EMBEDDER_DEVICE},
|
||||
},
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": MEM0_VECTOR_PROVIDER,
|
||||
"config": {
|
||||
"collection_name": MEM0_VECTOR_COLLECTION,
|
||||
"embedding_model_dims": MEM0_EMBEDDER_DIMS,
|
||||
"path": MEM0_VECTOR_PATH,
|
||||
},
|
||||
},
|
||||
"custom_fact_extraction_prompt": CUSTOM_EXTRACTION_PROMPT,
|
||||
}
|
||||
|
||||
m = Memory.from_config(config)
|
||||
|
||||
|
||||
def rerank_results(query: str, items: list, top_k: int = 10) -> list:
|
||||
"""Rerank search results using bge-reranker-v2-m3."""
|
||||
if not items:
|
||||
return items
|
||||
documents = [item.get("memory", "") for item in items]
|
||||
try:
|
||||
resp = httpx.post(
|
||||
RERANKER_URL,
|
||||
json={"query": query, "documents": documents, "top_k": top_k},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
ranked = resp.json().get("results", [])
|
||||
return [items[r["index"]] for r in ranked]
|
||||
except Exception as e:
|
||||
print(f"Reranker failed, using original order: {e}")
|
||||
return items
|
||||
|
||||
|
||||
class AddMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
text: str
|
||||
metadata: Optional[dict] = None
|
||||
infer: bool = True
|
||||
app: Optional[str] = None
|
||||
custom_instructions: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/api/v1/memories/")
|
||||
async def add_memory(req: AddMemoryRequest):
|
||||
# Use client-supplied prompt, fall back to server default, then mem0 SDK default
|
||||
prompt = req.custom_instructions or CUSTOM_EXTRACTION_PROMPT
|
||||
result = await asyncio.to_thread(m.add, req.text, user_id=req.user_id, metadata=req.metadata or {}, prompt=prompt)
|
||||
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
|
||||
|
||||
|
||||
class ProceduralMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
messages: list[dict]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
@app.post("/api/v1/memories/procedural")
|
||||
async def add_procedural_memory(req: ProceduralMemoryRequest):
|
||||
"""Store a conversation trace as procedural memory.
|
||||
|
||||
Accepts a list of messages (role/content dicts) representing a full
|
||||
conversation turn including tool calls, then uses mem0's native
|
||||
procedural memory extraction to learn reusable "how to" patterns.
|
||||
"""
|
||||
# Build metadata with procedural type marker
|
||||
meta = {"type": "procedural"}
|
||||
if req.metadata:
|
||||
meta.update(req.metadata)
|
||||
|
||||
# Use mem0's native message list support + procedural prompt
|
||||
result = await asyncio.to_thread(m.add,
|
||||
req.messages,
|
||||
user_id=req.user_id,
|
||||
metadata=meta,
|
||||
prompt=PROCEDURAL_EXTRACTION_PROMPT,
|
||||
)
|
||||
|
||||
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
|
||||
|
||||
|
||||
def _parse_mem0_results(raw_results) -> list:
|
||||
raw = raw_results.get("results", raw_results) if isinstance(raw_results, dict) else raw_results
|
||||
items = []
|
||||
for r in raw:
|
||||
item = r if isinstance(r, dict) else {"memory": str(r)}
|
||||
items.append({
|
||||
"id": item.get("id", str(uuid.uuid4())),
|
||||
"memory": item.get("memory", item.get("text", "")),
|
||||
"created_at": item.get("created_at", datetime.now(timezone.utc).isoformat()),
|
||||
"metadata_": item.get("metadata", {}),
|
||||
})
|
||||
return items
|
||||
|
||||
|
||||
def _parse_iso_timestamp(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 timestamp string, returning None on failure."""
|
||||
try:
|
||||
dt = datetime.fromisoformat(value)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _item_created_at(item: dict) -> Optional[datetime]:
|
||||
"""Extract created_at from an item as a timezone-aware datetime."""
|
||||
raw = item.get("created_at")
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, datetime):
|
||||
if raw.tzinfo is None:
|
||||
raw = raw.replace(tzinfo=timezone.utc)
|
||||
return raw
|
||||
return _parse_iso_timestamp(str(raw))
|
||||
|
||||
|
||||
def _apply_post_filters(
|
||||
items: list,
|
||||
created_after: Optional[str],
|
||||
created_before: Optional[str],
|
||||
) -> list:
|
||||
"""Filter items by created_after / created_before timestamps (post-query)."""
|
||||
after_dt = _parse_iso_timestamp(created_after) if created_after else None
|
||||
before_dt = _parse_iso_timestamp(created_before) if created_before else None
|
||||
if after_dt is None and before_dt is None:
|
||||
return items
|
||||
filtered = []
|
||||
for item in items:
|
||||
ts = _item_created_at(item)
|
||||
if ts is None:
|
||||
# Keep items without a parseable timestamp
|
||||
filtered.append(item)
|
||||
continue
|
||||
if after_dt and ts < after_dt:
|
||||
continue
|
||||
if before_dt and ts > before_dt:
|
||||
continue
|
||||
filtered.append(item)
|
||||
return filtered
|
||||
|
||||
|
||||
@app.get("/api/v1/memories/")
|
||||
async def list_or_search_memories(
|
||||
user_id: str = Query(...),
|
||||
search_query: Optional[str] = Query(None),
|
||||
size: int = Query(10),
|
||||
rerank: bool = Query(True),
|
||||
created_after: Optional[str] = Query(None),
|
||||
created_before: Optional[str] = Query(None),
|
||||
metadata_filter: Optional[str] = Query(None),
|
||||
):
|
||||
# Build mem0 SDK filters dict from metadata_filter JSON param
|
||||
sdk_filters = None
|
||||
if metadata_filter:
|
||||
try:
|
||||
sdk_filters = json.loads(metadata_filter)
|
||||
except json.JSONDecodeError:
|
||||
sdk_filters = None
|
||||
|
||||
if search_query:
|
||||
# Fetch more results than needed so reranker has candidates to work with
|
||||
fetch_size = min(size * 3, 50)
|
||||
results = await asyncio.to_thread(m.search,
|
||||
search_query,
|
||||
user_id=user_id,
|
||||
limit=fetch_size,
|
||||
filters=sdk_filters,
|
||||
)
|
||||
items = _parse_mem0_results(results)
|
||||
items = _apply_post_filters(items, created_after, created_before)
|
||||
if rerank and items:
|
||||
items = rerank_results(search_query, items, top_k=size)
|
||||
else:
|
||||
items = items[:size]
|
||||
return {"items": items, "total": len(items)}
|
||||
else:
|
||||
results = await asyncio.to_thread(m.get_all,user_id=user_id, filters=sdk_filters)
|
||||
items = _parse_mem0_results(results)
|
||||
items = _apply_post_filters(items, created_after, created_before)
|
||||
return {"items": items, "total": len(items)}
|
||||
|
||||
|
||||
@app.delete("/api/v1/memories/{memory_id}")
|
||||
async def delete_memory(memory_id: str):
|
||||
try:
|
||||
await asyncio.to_thread(m.delete, memory_id)
|
||||
except Exception:
|
||||
pass
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/v1/memories/{memory_id}/history")
|
||||
async def get_memory_history(memory_id: str):
|
||||
"""Return the edit history of a specific memory."""
|
||||
try:
|
||||
history = await asyncio.to_thread(m.history, memory_id)
|
||||
# Normalize to list of dicts
|
||||
entries = []
|
||||
raw = history if isinstance(history, list) else history.get("results", history) if isinstance(history, dict) else [history]
|
||||
for h in raw:
|
||||
entry = h if isinstance(h, dict) else {"event": str(h)}
|
||||
entries.append(entry)
|
||||
return {"memory_id": memory_id, "history": entries}
|
||||
except Exception as e:
|
||||
return {"memory_id": memory_id, "history": [], "error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8765)
|
||||
@@ -1,50 +0,0 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from FlagEmbedding import FlagReranker
|
||||
import os, torch
|
||||
|
||||
app = Flask(__name__)
|
||||
reranker = None
|
||||
|
||||
# ── Configurable via environment variables ─────────────────────────
|
||||
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||
RERANKER_DEVICE = os.environ.get("RERANKER_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
RERANKER_PORT = int(os.environ.get("RERANKER_PORT", "8678"))
|
||||
|
||||
def get_reranker():
|
||||
global reranker
|
||||
if reranker is None:
|
||||
reranker = FlagReranker(RERANKER_MODEL, use_fp16=True, device=RERANKER_DEVICE)
|
||||
return reranker
|
||||
|
||||
@app.route('/rerank', methods=['POST'])
|
||||
def rerank():
|
||||
data = request.json
|
||||
query = data.get('query', '')
|
||||
documents = data.get('documents', [])
|
||||
top_k = data.get('top_k', len(documents))
|
||||
|
||||
if not query or not documents:
|
||||
return jsonify({'error': 'query and documents required'}), 400
|
||||
|
||||
pairs = [[query, doc] for doc in documents]
|
||||
scores = get_reranker().compute_score(pairs)
|
||||
if isinstance(scores, float):
|
||||
scores = [scores]
|
||||
|
||||
results = sorted(
|
||||
[{'index': i, 'document': doc, 'score': score}
|
||||
for i, (doc, score) in enumerate(zip(documents, scores))],
|
||||
key=lambda x: x['score'], reverse=True
|
||||
)[:top_k]
|
||||
|
||||
return jsonify({'results': results})
|
||||
|
||||
@app.route('/health', methods=['GET'])
|
||||
def health():
|
||||
return jsonify({'status': 'ok', 'model': RERANKER_MODEL, 'device': RERANKER_DEVICE})
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(f'Loading reranker model ({RERANKER_MODEL}) on {RERANKER_DEVICE}...')
|
||||
get_reranker()
|
||||
print(f'Reranker server ready on :{RERANKER_PORT}')
|
||||
app.run(host='0.0.0.0', port=RERANKER_PORT)
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.5.5
|
||||
pkgver = 0.5.9
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.5.5.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.5.tar.gz
|
||||
source = zeroclaw-0.5.9.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.9.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.5.5
|
||||
pkgver=0.5.9
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.5.5",
|
||||
"version": "0.5.9",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.5/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.9/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# ADR-004: Tool Shared State Ownership Contract
|
||||
|
||||
**Status:** Accepted
|
||||
|
||||
**Date:** 2026-03-22
|
||||
|
||||
**Issue:** [#4057](https://github.com/zeroclaw/zeroclaw/issues/4057)
|
||||
|
||||
## Context
|
||||
|
||||
ZeroClaw tools execute in a multi-client environment where a single daemon
|
||||
process serves requests from multiple connected clients simultaneously. Several
|
||||
tools already maintain long-lived shared state:
|
||||
|
||||
- **`DelegateParentToolsHandle`** (`src/tools/mod.rs`):
|
||||
`Arc<RwLock<Vec<Arc<dyn Tool>>>>` — holds parent tools for delegate agents
|
||||
with no per-client isolation.
|
||||
- **`ChannelMapHandle`** (`src/tools/reaction.rs`):
|
||||
`Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>` — global channel map shared
|
||||
across all clients.
|
||||
- **`CanvasStore`** (`src/tools/canvas.rs`):
|
||||
`Arc<RwLock<HashMap<String, CanvasEntry>>>` — canvas IDs are plain strings
|
||||
with no client namespace.
|
||||
|
||||
These patterns emerged organically. As the tool surface grows and more clients
|
||||
connect concurrently, we need a clear contract governing ownership, identity,
|
||||
isolation, lifecycle, and reload behavior for tool-held shared state. Without
|
||||
this contract, new tools risk introducing data leaks between clients, stale
|
||||
state after config reloads, or inconsistent initialization timing.
|
||||
|
||||
Additional context:
|
||||
|
||||
- The tool registry is immutable after startup, built once in
|
||||
`all_tools_with_runtime()`.
|
||||
- Client identity is currently derived from IP address only
|
||||
(`src/gateway/mod.rs`), which is insufficient for reliable namespacing.
|
||||
- `SecurityPolicy` is scoped per agent, not per client.
|
||||
- `WorkspaceManager` provides some isolation but workspace switching is global.
|
||||
|
||||
## Decision
|
||||
|
||||
### 1. Ownership: May tools own long-lived shared state?
|
||||
|
||||
**Yes.** Tools MAY own long-lived shared state, provided they follow the
|
||||
established **handle pattern**: wrap the state in `Arc<RwLock<T>>` (or
|
||||
`Arc<parking_lot::RwLock<T>>`) and expose a cloneable handle type.
|
||||
|
||||
This pattern is already proven by three independent implementations:
|
||||
|
||||
| Handle | Location | Inner type |
|
||||
|--------|----------|-----------|
|
||||
| `DelegateParentToolsHandle` | `src/tools/mod.rs` | `Vec<Arc<dyn Tool>>` |
|
||||
| `ChannelMapHandle` | `src/tools/reaction.rs` | `HashMap<String, Arc<dyn Channel>>` |
|
||||
| `CanvasStore` | `src/tools/canvas.rs` | `HashMap<String, CanvasEntry>` |
|
||||
|
||||
Tools that need shared state MUST:
|
||||
|
||||
- Define a named handle type alias (e.g., `pub type FooHandle = Arc<RwLock<T>>`).
|
||||
- Accept the handle at construction time rather than creating global state.
|
||||
- Document the concurrency contract in the handle type's doc comment.
|
||||
|
||||
Tools MUST NOT use static mutable state (`lazy_static!`, `OnceCell` with
|
||||
interior mutability) for per-request or per-client data.
|
||||
|
||||
### 2. Identity assignment: Who constructs identity keys?
|
||||
|
||||
**The daemon SHOULD provide identity.** Tools MUST NOT construct their own
|
||||
client identity keys.
|
||||
|
||||
A new `ClientId` type should be introduced (opaque, `Clone + Eq + Hash + Send + Sync`)
|
||||
that the daemon assigns at connection time. This replaces the current approach
|
||||
of using raw IP addresses (`src/gateway/mod.rs:259-306`), which breaks when
|
||||
multiple clients share a NAT address or when proxied connections arrive.
|
||||
|
||||
`ClientId` is passed to tools that require per-client state namespacing as part
|
||||
of the tool execution context. Tools that do not need per-client isolation
|
||||
(e.g., the immutable tool registry) may ignore it.
|
||||
|
||||
The `ClientId` contract:
|
||||
|
||||
- Generated by the gateway layer at connection establishment.
|
||||
- Opaque to tools — tools must not parse or derive meaning from the value.
|
||||
- Stable for the lifetime of a single client session.
|
||||
- Passed through the execution context, not stored globally.
|
||||
|
||||
### 3. Lifecycle: When may tools run startup-style validation?
|
||||
|
||||
**Validation runs once at first registration, and again when config changes
|
||||
are detected.**
|
||||
|
||||
The lifecycle phases are:
|
||||
|
||||
1. **Construction** — tool is instantiated with handles and config. No I/O or
|
||||
validation occurs here.
|
||||
2. **Registration** — tool is registered in the tool registry via
|
||||
`all_tools_with_runtime()`. At this point the tool MAY perform one-time
|
||||
startup validation (e.g., checking that required credentials exist, verifying
|
||||
external service connectivity).
|
||||
3. **Execution** — tool handles individual requests. No re-validation unless
|
||||
the config-change signal fires (see Reload Semantics below).
|
||||
4. **Shutdown** — daemon is stopping. Tools with open resources SHOULD clean up
|
||||
gracefully via `Drop` or an explicit shutdown method.
|
||||
|
||||
Tools MUST NOT perform blocking validation during execution-phase calls.
|
||||
Validation results SHOULD be cached in the tool's handle state and checked
|
||||
via a fast path during execution.
|
||||
|
||||
### 4. Isolation: What must be isolated per client?
|
||||
|
||||
State falls into two categories with different isolation requirements:
|
||||
|
||||
**MUST be isolated per client:**
|
||||
|
||||
- Security-sensitive state: credentials, API keys, quotas, rate-limit counters,
|
||||
per-client authorization decisions.
|
||||
- User-specific session data: conversation context, user preferences,
|
||||
workspace-scoped file paths.
|
||||
|
||||
Isolation mechanism: tools holding per-client state MUST key their internal
|
||||
maps by `ClientId`. The handle pattern naturally supports this by using
|
||||
`HashMap<ClientId, T>` inside the `RwLock`.
|
||||
|
||||
**MAY be shared across clients (with namespace prefixing):**
|
||||
|
||||
- Broadcast/display state: canvas frames (`CanvasStore`), notification channels
|
||||
(`ChannelMapHandle`).
|
||||
- Read-only reference data: tool registry, static configuration, model
|
||||
metadata.
|
||||
|
||||
When shared state uses string keys (e.g., canvas IDs, channel names), tools
|
||||
SHOULD support optional namespace prefixing (e.g., `{client_id}:{canvas_name}`)
|
||||
to allow per-client isolation when needed without mandating it for broadcast
|
||||
use cases.
|
||||
|
||||
Tools MUST NOT store per-client secrets in shared (non-isolated) state
|
||||
structures.
|
||||
|
||||
### 5. Reload semantics: What invalidates prior shared state on config change?
|
||||
|
||||
**Config changes detected via hash comparison MUST invalidate cached
|
||||
validation state.**
|
||||
|
||||
The reload contract:
|
||||
|
||||
- The daemon computes a hash of the tool-relevant config section at startup and
|
||||
after each config reload event.
|
||||
- When the hash changes, the daemon signals affected tools to re-run their
|
||||
registration-phase validation.
|
||||
- Tools MUST treat their cached validation result as stale when signaled and
|
||||
re-validate before the next execution.
|
||||
|
||||
Specific invalidation rules:
|
||||
|
||||
| Config change | Invalidation scope |
|
||||
|--------------|-------------------|
|
||||
| Credential/secret rotation | Per-tool validation cache; per-client credential state |
|
||||
| Tool enable/disable | Full tool registry rebuild via `all_tools_with_runtime()` |
|
||||
| Security policy change | `SecurityPolicy` re-derivation; per-agent policy state |
|
||||
| Workspace directory change | `WorkspaceManager` state; file-path-dependent tool state |
|
||||
| Provider config change | Provider-dependent tools re-validate connectivity |
|
||||
|
||||
Tools MAY retain non-security shared state (e.g., canvas content, channel
|
||||
subscriptions) across config reloads unless the reload explicitly affects that
|
||||
state's validity.
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
|
||||
- **Consistency:** All new tools follow the same handle pattern, making shared
|
||||
state discoverable and auditable.
|
||||
- **Safety:** Per-client isolation of security-sensitive state prevents data
|
||||
leaks in multi-tenant scenarios.
|
||||
- **Clarity:** Explicit lifecycle phases eliminate ambiguity about when
|
||||
validation runs.
|
||||
- **Evolvability:** The `ClientId` abstraction decouples tools from transport
|
||||
details, supporting future identity mechanisms (tokens, certificates).
|
||||
|
||||
### Negative
|
||||
|
||||
- **Migration cost:** Existing tools (`CanvasStore`, `ReactionTool`) may need
|
||||
refactoring to accept `ClientId` and namespace their state.
|
||||
- **Complexity:** Tools that were simple singletons now need to consider
|
||||
multi-client semantics even if they currently have one client.
|
||||
- **Performance:** Per-client keying adds a hash lookup on each access, though
|
||||
this is negligible compared to I/O costs.
|
||||
|
||||
### Neutral
|
||||
|
||||
- The tool registry remains immutable after startup; this ADR does not change
|
||||
that invariant.
|
||||
- `SecurityPolicy` remains per-agent; this ADR documents that client isolation
|
||||
is orthogonal to agent-level policy.
|
||||
|
||||
## References
|
||||
|
||||
- `src/tools/mod.rs` — `DelegateParentToolsHandle`, `all_tools_with_runtime()`
|
||||
- `src/tools/reaction.rs` — `ChannelMapHandle`, `ReactionTool`
|
||||
- `src/tools/canvas.rs` — `CanvasStore`, `CanvasEntry`
|
||||
- `src/tools/traits.rs` — `Tool` trait
|
||||
- `src/gateway/mod.rs` — client IP extraction (`forwarded_client_ip`, `resolve_client_ip`)
|
||||
- `src/security/` — `SecurityPolicy`
|
||||
@@ -0,0 +1,215 @@
|
||||
# Browser Automation Setup Guide
|
||||
|
||||
This guide covers setting up browser automation capabilities in ZeroClaw, including both headless automation and GUI access via VNC.
|
||||
|
||||
## Overview
|
||||
|
||||
ZeroClaw supports multiple browser access methods:
|
||||
|
||||
| Method | Use Case | Requirements |
|
||||
|--------|----------|--------------|
|
||||
| **agent-browser CLI** | Headless automation, AI agents | npm, Chrome |
|
||||
| **VNC + noVNC** | GUI access, debugging | Xvfb, x11vnc, noVNC |
|
||||
| **Chrome Remote Desktop** | Remote GUI via Google | XFCE, Google account |
|
||||
|
||||
## Quick Start: Headless Automation
|
||||
|
||||
### 1. Install agent-browser
|
||||
|
||||
```bash
|
||||
# Install CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux (includes system deps)
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
### 2. Verify ZeroClaw Config
|
||||
|
||||
The browser tool is enabled by default. To verify or customize, edit
|
||||
`~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
### 3. Test
|
||||
|
||||
```bash
|
||||
echo "Open https://example.com and tell me what it says" | zeroclaw agent
|
||||
```
|
||||
|
||||
## VNC Setup (GUI Access)
|
||||
|
||||
For debugging or when you need visual browser access:
|
||||
|
||||
### Install Dependencies
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
apt-get install -y xvfb x11vnc fluxbox novnc websockify
|
||||
|
||||
# Optional: Desktop environment for Chrome Remote Desktop
|
||||
apt-get install -y xfce4 xfce4-goodies
|
||||
```
|
||||
|
||||
### Start VNC Server
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# Start virtual display with VNC access
|
||||
|
||||
DISPLAY_NUM=99
|
||||
VNC_PORT=5900
|
||||
NOVNC_PORT=6080
|
||||
RESOLUTION=1920x1080x24
|
||||
|
||||
# Start Xvfb
|
||||
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
|
||||
sleep 1
|
||||
|
||||
# Start window manager
|
||||
fluxbox -display :$DISPLAY_NUM &
|
||||
sleep 1
|
||||
|
||||
# Start x11vnc
|
||||
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg
|
||||
sleep 1
|
||||
|
||||
# Start noVNC (web-based VNC)
|
||||
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
|
||||
|
||||
echo "VNC available at:"
|
||||
echo " VNC Client: localhost:$VNC_PORT"
|
||||
echo " Web Browser: http://localhost:$NOVNC_PORT/vnc.html"
|
||||
```
|
||||
|
||||
### VNC Access
|
||||
|
||||
- **VNC Client**: Connect to `localhost:5900`
|
||||
- **Web Browser**: Open `http://localhost:6080/vnc.html`
|
||||
|
||||
### Start Browser on VNC Display
|
||||
|
||||
```bash
|
||||
DISPLAY=:99 google-chrome --no-sandbox https://example.com &
|
||||
```
|
||||
|
||||
## Chrome Remote Desktop
|
||||
|
||||
### Install
|
||||
|
||||
```bash
|
||||
# Download and install
|
||||
wget https://dl.google.com/linux/direct/chrome-remote-desktop_current_amd64.deb
|
||||
apt-get install -y ./chrome-remote-desktop_current_amd64.deb
|
||||
|
||||
# Configure session
|
||||
echo "xfce4-session" > ~/.chrome-remote-desktop-session
|
||||
chmod +x ~/.chrome-remote-desktop-session
|
||||
```
|
||||
|
||||
### Setup
|
||||
|
||||
1. Visit <https://remotedesktop.google.com/headless>
|
||||
2. Copy the "Debian Linux" setup command
|
||||
3. Run it on your server
|
||||
4. Start the service: `systemctl --user start chrome-remote-desktop`
|
||||
|
||||
### Remote Access
|
||||
|
||||
Go to <https://remotedesktop.google.com/access> from any device.
|
||||
|
||||
## Testing
|
||||
|
||||
### CLI Tests
|
||||
|
||||
```bash
|
||||
# Basic open and close
|
||||
agent-browser open https://example.com
|
||||
agent-browser get title
|
||||
agent-browser close
|
||||
|
||||
# Snapshot with refs
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser close
|
||||
|
||||
# Screenshot
|
||||
agent-browser open https://example.com
|
||||
agent-browser screenshot /tmp/test.png
|
||||
agent-browser close
|
||||
```
|
||||
|
||||
### ZeroClaw Integration Tests
|
||||
|
||||
```bash
|
||||
# Content extraction
|
||||
echo "Open https://example.com and summarize it" | zeroclaw agent
|
||||
|
||||
# Navigation
|
||||
echo "Go to https://github.com/trending and list the top 3 repos" | zeroclaw agent
|
||||
|
||||
# Form interaction
|
||||
echo "Go to Wikipedia, search for 'Rust programming language', and summarize" | zeroclaw agent
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Element not found"
|
||||
|
||||
The page may not be fully loaded. Add a wait:
|
||||
|
||||
```bash
|
||||
agent-browser open https://slow-site.com
|
||||
agent-browser wait --load networkidle
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Cookie dialogs blocking access
|
||||
|
||||
Handle cookie consent first:
|
||||
|
||||
```bash
|
||||
agent-browser open https://site-with-cookies.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser click @accept_cookies # Click the accept button
|
||||
agent-browser snapshot -i # Now get the actual content
|
||||
```
|
||||
|
||||
### Docker sandbox network restrictions
|
||||
|
||||
If `web_fetch` fails inside Docker sandbox, use agent-browser instead:
|
||||
|
||||
```bash
|
||||
# Instead of web_fetch, use:
|
||||
agent-browser open https://example.com
|
||||
agent-browser get text body
|
||||
```
|
||||
|
||||
## Security Notes
|
||||
|
||||
- `agent-browser` runs Chrome in headless mode with sandboxing
|
||||
- For sensitive sites, use `--session-name` to persist auth state
|
||||
- The `--allowed-domains` config restricts navigation to specific domains
|
||||
- VNC ports (5900, 6080) should be behind a firewall or Tailscale
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser Documentation](https://github.com/vercel-labs/agent-browser)
|
||||
- [ZeroClaw Configuration Reference](./config-reference.md)
|
||||
- [Skills Documentation](../skills/)
|
||||
@@ -45,6 +45,15 @@ For complete code examples of each extension trait, see [extension-examples.md](
|
||||
- Keep multilingual entry-point parity for all supported locales (`en`, `zh-CN`, `ja`, `ru`, `fr`, `vi`) when nav or key wording changes.
|
||||
- When shared docs wording changes, sync corresponding localized docs in the same PR (or explicitly document deferral and follow-up PR).
|
||||
|
||||
## Tool Shared State
|
||||
|
||||
- Follow the `Arc<RwLock<T>>` handle pattern for any tool that owns long-lived shared state.
|
||||
- Accept handles at construction; do not create global/static mutable state.
|
||||
- Use `ClientId` (provided by the daemon) to namespace per-client state — never construct identity keys inside the tool.
|
||||
- Isolate security-sensitive state (credentials, quotas) per client; broadcast/display state may be shared with optional namespace prefixing.
|
||||
- Cached validation is invalidated on config change — tools must re-validate before the next execution when signaled.
|
||||
- See [ADR-004: Tool Shared State Ownership](../architecture/adr-004-tool-shared-state-ownership.md) for the full contract.
|
||||
|
||||
## Architecture Boundary Rules
|
||||
|
||||
- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features.
|
||||
|
||||
@@ -411,30 +411,6 @@ allowed_roots = [\"~/Desktop/projects\", \"/opt/shared-repo\"]
|
||||
|
||||
- 内存上下文注入忽略旧的 `assistant_resp*` 自动保存键,以防止旧模型生成的摘要被视为事实。
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Mem0 (OpenMemory) 后端 — 连接自托管 mem0 服务器,提供基于向量的记忆存储和 LLM 事实提取。构建时需要 `memory-mem0` feature flag,配置需设置 `backend = "mem0"`。
|
||||
|
||||
| 键 | 默认值 | 环境变量 | 用途 |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory 服务器地址 |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | 记忆作用域的用户 ID |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | 在 mem0 中注册的应用名称 |
|
||||
| `infer` | `true` | — | 使用 LLM 从存储文本中提取事实 (`true`) 或原样存储 (`false`) |
|
||||
| `extraction_prompt` | 未设置 | `MEM0_EXTRACTION_PROMPT` | 自定义 LLM 事实提取提示词(如适用于非英文内容) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "用原始语言提取事实..."
|
||||
```
|
||||
|
||||
服务器部署脚本位于 `deploy/mem0/`。
|
||||
|
||||
## `[[model_routes]]` 和 `[[embedding_routes]]`
|
||||
|
||||
使用路由提示,以便集成可以在模型 ID 演变时保持稳定的名称。
|
||||
|
||||
@@ -122,6 +122,34 @@ tools = ["mcp_browser_*"]
|
||||
keywords = ["browse", "navigate", "open url", "screenshot"]
|
||||
```
|
||||
|
||||
## `[pacing]`
|
||||
|
||||
Pacing controls for slow/local LLM workloads (Ollama, llama.cpp, vLLM). All keys are optional; when absent, existing behavior is preserved.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `step_timeout_secs` | _none_ | Per-step timeout: maximum seconds for a single LLM inference turn. Catches a truly hung model without terminating the overall task loop |
|
||||
| `loop_detection_min_elapsed_secs` | _none_ | Minimum elapsed seconds before loop detection activates. Tasks completing under this threshold get aggressive loop protection; longer-running tasks receive a grace period |
|
||||
| `loop_ignore_tools` | `[]` | Tool names excluded from identical-output loop detection. Useful for browser workflows where `browser_screenshot` structurally resembles a loop |
|
||||
| `message_timeout_scale_max` | `4` | Override for the hardcoded timeout scaling cap. The channel message timeout budget is `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)` |
|
||||
|
||||
Notes:
|
||||
|
||||
- These settings are intended for local/slow LLM deployments. Cloud-provider users typically do not need them.
|
||||
- `step_timeout_secs` operates independently of the total channel message timeout budget. A step timeout abort does not consume the overall budget; the loop simply stops.
|
||||
- `loop_detection_min_elapsed_secs` delays loop-detection counting, not the task itself. Loop protection remains fully active for short tasks (the default).
|
||||
- `loop_ignore_tools` only suppresses tool-output-based loop detection for the listed tools. Other safety features (max iterations, overall timeout) remain active.
|
||||
- `message_timeout_scale_max` must be >= 1. Setting it higher than `max_tool_iterations` has no additional effect (the formula uses `min()`).
|
||||
- Example configuration for a slow local Ollama deployment:
|
||||
|
||||
```toml
|
||||
[pacing]
|
||||
step_timeout_secs = 120
|
||||
loop_detection_min_elapsed_secs = 60
|
||||
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
|
||||
message_timeout_scale_max = 8
|
||||
```
|
||||
|
||||
## `[security.otp]`
|
||||
|
||||
| Key | Default | Purpose |
|
||||
@@ -425,6 +453,12 @@ Notes:
|
||||
| `port` | `42617` | gateway listen port |
|
||||
| `require_pairing` | `true` | require pairing before bearer auth |
|
||||
| `allow_public_bind` | `false` | block accidental public exposure |
|
||||
| `path_prefix` | _(none)_ | URL path prefix for reverse-proxy deployments (e.g. `"/zeroclaw"`) |
|
||||
|
||||
When deploying behind a reverse proxy that maps ZeroClaw to a sub-path,
|
||||
set `path_prefix` to that sub-path (e.g. `"/zeroclaw"`). All gateway
|
||||
routes will be served under this prefix. The value must start with `/`
|
||||
and must not end with `/`.
|
||||
|
||||
## `[autonomy]`
|
||||
|
||||
@@ -474,30 +508,6 @@ Notes:
|
||||
|
||||
- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.
|
||||
|
||||
| Key | Default | Env var | Purpose |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
|
||||
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
|
||||
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "Extract facts in the original language..."
|
||||
```
|
||||
|
||||
Server deployment scripts are in `deploy/mem0/`.
|
||||
|
||||
## `[[model_routes]]` and `[[embedding_routes]]`
|
||||
|
||||
Use route hints so integrations can keep stable names while model IDs evolve.
|
||||
@@ -597,7 +607,7 @@ Top-level channel options are configured under `channels_config`.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) |
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x, overridable via `[pacing].message_timeout_scale_max`) |
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -612,7 +622,7 @@ Examples:
|
||||
Notes:
|
||||
|
||||
- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, cap)` and a minimum of `1`. The default cap is `4`; override with `[pacing].message_timeout_scale_max`.
|
||||
- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete.
|
||||
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
|
||||
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
|
||||
|
||||
@@ -337,30 +337,6 @@ Lưu ý:
|
||||
|
||||
- Chèn ngữ cảnh memory bỏ qua khóa auto-save `assistant_resp*` kiểu cũ để tránh tóm tắt do model tạo bị coi là sự thật.
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Backend Mem0 (OpenMemory) — kết nối đến server mem0 tự host, cung cấp bộ nhớ vector với trích xuất sự kiện bằng LLM. Cần feature flag `memory-mem0` khi build và `backend = "mem0"` trong config.
|
||||
|
||||
| Khóa | Mặc định | Biến môi trường | Mục đích |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | URL server OpenMemory |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID để phân vùng memory |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Tên ứng dụng đăng ký trong mem0 |
|
||||
| `infer` | `true` | — | Dùng LLM trích xuất sự kiện từ text (`true`) hoặc lưu nguyên (`false`) |
|
||||
| `extraction_prompt` | chưa đặt | `MEM0_EXTRACTION_PROMPT` | Prompt tùy chỉnh cho trích xuất sự kiện LLM (vd: cho nội dung không phải tiếng Anh) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "Trích xuất sự kiện bằng ngôn ngữ gốc..."
|
||||
```
|
||||
|
||||
Script triển khai server nằm trong `deploy/mem0/`.
|
||||
|
||||
## `[[model_routes]]` và `[[embedding_routes]]`
|
||||
|
||||
Route hint giúp tên tích hợp ổn định khi model ID thay đổi.
|
||||
|
||||
@@ -38,3 +38,46 @@ allowed_tools = ["read", "edit", "exec"]
|
||||
max_iterations = 15
|
||||
# Optional: use longer timeout for complex coding tasks
|
||||
agentic_timeout_secs = 600
|
||||
|
||||
# ── Cron Configuration ────────────────────────────────────────
|
||||
[cron]
|
||||
# Enable the cron subsystem. Default: true
|
||||
enabled = true
|
||||
# Run all overdue jobs at scheduler startup. Default: true
|
||||
catch_up_on_startup = true
|
||||
# Maximum number of historical cron run records to retain. Default: 50
|
||||
max_run_history = 50
|
||||
|
||||
# ── Declarative Cron Jobs ─────────────────────────────────────
|
||||
# Define cron jobs directly in config. These are synced to the database
|
||||
# at scheduler startup. Each job needs a stable `id` for merge semantics.
|
||||
|
||||
# Shell job: runs a shell command on a cron schedule
|
||||
[[cron.jobs]]
|
||||
id = "daily-backup"
|
||||
name = "Daily Backup"
|
||||
job_type = "shell"
|
||||
command = "tar czf /tmp/backup.tar.gz /data"
|
||||
schedule = { kind = "cron", expr = "0 2 * * *" }
|
||||
|
||||
# Agent job: runs an agent prompt on an interval
|
||||
[[cron.jobs]]
|
||||
id = "health-check"
|
||||
name = "Health Check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health: disk space, memory, CPU load"
|
||||
model = "anthropic/claude-sonnet-4"
|
||||
allowed_tools = ["shell", "file_read"]
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
|
||||
# Cron job with timezone and delivery
|
||||
# [[cron.jobs]]
|
||||
# id = "morning-report"
|
||||
# name = "Morning Report"
|
||||
# job_type = "agent"
|
||||
# prompt = "Generate a daily summary of system metrics"
|
||||
# schedule = { kind = "cron", expr = "0 9 * * 1-5", tz = "America/New_York" }
|
||||
# [cron.jobs.delivery]
|
||||
# mode = "announce"
|
||||
# channel = "telegram"
|
||||
# to = "123456789"
|
||||
|
||||
+79
-25
@@ -569,11 +569,29 @@ MSG
|
||||
exit 0
|
||||
fi
|
||||
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
|
||||
if ! /usr/bin/xcrun --show-sdk-path >/dev/null 2>&1; then
|
||||
warn "Xcode license has not been accepted. Run:"
|
||||
warn " sudo xcodebuild -license accept"
|
||||
warn "then re-run this installer."
|
||||
exit 1
|
||||
# xcrun --show-sdk-path can succeed even without an accepted license,
|
||||
# so we test-compile a trivial C file which reliably triggers the error.
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
|
||||
_xcode_accept_ok=false
|
||||
if [[ "$(id -u)" -eq 0 ]]; then
|
||||
xcodebuild -license accept && _xcode_accept_ok=true
|
||||
elif [[ -c /dev/tty ]] && have_cmd sudo; then
|
||||
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
|
||||
fi
|
||||
if [[ "$_xcode_accept_ok" == true ]]; then
|
||||
step_ok "Xcode license accepted"
|
||||
else
|
||||
error "Could not accept Xcode license. Run manually:"
|
||||
error " sudo xcodebuild -license accept"
|
||||
error "then re-run this installer."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
rm -f "$_xcode_test_file"
|
||||
fi
|
||||
if ! have_cmd git; then
|
||||
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
|
||||
@@ -1175,6 +1193,43 @@ else
|
||||
install_system_deps
|
||||
fi
|
||||
|
||||
# Always check Xcode/CLT license on macOS, regardless of --install-system-deps.
|
||||
# An un-accepted license causes `cc` to exit 69, breaking all Rust builds.
|
||||
if [[ "$OS_NAME" == "Darwin" ]]; then
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
|
||||
# Use /dev/tty so sudo can prompt for a password even in a curl|bash pipe.
|
||||
_xcode_accept_ok=false
|
||||
if [[ "$(id -u)" -eq 0 ]]; then
|
||||
xcodebuild -license accept && _xcode_accept_ok=true
|
||||
elif [[ -c /dev/tty ]] && have_cmd sudo; then
|
||||
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
|
||||
fi
|
||||
if [[ "$_xcode_accept_ok" == true ]]; then
|
||||
step_ok "Xcode license accepted"
|
||||
# Re-test compilation to confirm it's fixed.
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
error "C compiler still failing after license accept. Check your Xcode/CLT installation."
|
||||
exit 1
|
||||
fi
|
||||
rm -f "$_xcode_test_file"
|
||||
else
|
||||
error "Could not accept Xcode license. Run manually:"
|
||||
error " sudo xcodebuild -license accept"
|
||||
error "then re-run this installer."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
rm -f "$_xcode_test_file"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$INSTALL_RUST" == true ]]; then
|
||||
install_rust_toolchain
|
||||
fi
|
||||
@@ -1393,6 +1448,25 @@ else
|
||||
step_dot "Skipping install"
|
||||
fi
|
||||
|
||||
# --- Build web dashboard ---
|
||||
if [[ "$SKIP_BUILD" == false && -d "$WORK_DIR/web" ]]; then
|
||||
if have_cmd node && have_cmd npm; then
|
||||
step_dot "Building web dashboard"
|
||||
if (cd "$WORK_DIR/web" && npm ci --ignore-scripts 2>/dev/null && npm run build 2>/dev/null); then
|
||||
step_ok "Web dashboard built"
|
||||
else
|
||||
warn "Web dashboard build failed — dashboard will not be available"
|
||||
fi
|
||||
else
|
||||
warn "node/npm not found — skipping web dashboard build"
|
||||
warn "Install Node.js (>=18) and re-run, or build manually: cd web && npm ci && npm run build"
|
||||
fi
|
||||
else
|
||||
if [[ "$SKIP_BUILD" == true ]]; then
|
||||
step_dot "Skipping web dashboard build"
|
||||
fi
|
||||
fi
|
||||
|
||||
ZEROCLAW_BIN=""
|
||||
if [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then
|
||||
ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw"
|
||||
@@ -1467,25 +1541,6 @@ if [[ -n "$ZEROCLAW_BIN" ]]; then
|
||||
if "$ZEROCLAW_BIN" service restart 2>/dev/null; then
|
||||
step_ok "Gateway service restarted"
|
||||
|
||||
# Fetch and display pairing code from running gateway
|
||||
PAIR_CODE=""
|
||||
for i in 1 2 3 4 5; do
|
||||
sleep 2
|
||||
if PAIR_CODE=$("$ZEROCLAW_BIN" gateway get-paircode 2>/dev/null | grep -oE '[0-9]{6}'); then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [[ -n "$PAIR_CODE" ]]; then
|
||||
echo
|
||||
echo -e " ${BOLD_BLUE}🔐 Gateway Pairing Code${RESET}"
|
||||
echo
|
||||
echo -e " ${BOLD_BLUE}┌──────────────┐${RESET}"
|
||||
echo -e " ${BOLD_BLUE}│${RESET} ${BOLD}${PAIR_CODE}${RESET} ${BOLD_BLUE}│${RESET}"
|
||||
echo -e " ${BOLD_BLUE}└──────────────┘${RESET}"
|
||||
echo
|
||||
echo -e " ${DIM}Enter this code in the dashboard to pair your device.${RESET}"
|
||||
echo -e " ${DIM}Run 'zeroclaw gateway get-paircode --new' anytime to generate a fresh code.${RESET}"
|
||||
fi
|
||||
else
|
||||
step_fail "Gateway service restart failed — re-run with zeroclaw service start"
|
||||
fi
|
||||
@@ -1532,7 +1587,6 @@ GATEWAY_PORT=42617
|
||||
DASHBOARD_URL="http://127.0.0.1:${GATEWAY_PORT}"
|
||||
echo
|
||||
echo -e "${BOLD}Dashboard URL:${RESET} ${BLUE}${DASHBOARD_URL}${RESET}"
|
||||
echo -e "${DIM} Run 'zeroclaw gateway get-paircode' to get your pairing code.${RESET}"
|
||||
|
||||
# --- Copy to clipboard ---
|
||||
COPIED_TO_CLIPBOARD=false
|
||||
|
||||
Executable
+21
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Start a browser on a virtual display
|
||||
# Usage: ./start-browser.sh [display_num] [url]
|
||||
|
||||
set -e
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
URL=${2:-"https://google.com"}
|
||||
|
||||
export DISPLAY=:$DISPLAY_NUM
|
||||
|
||||
# Check if display is running
|
||||
if ! xdpyinfo -display :$DISPLAY_NUM &>/dev/null; then
|
||||
echo "Error: Display :$DISPLAY_NUM not running."
|
||||
echo "Start VNC first: ./start-vnc.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
google-chrome --no-sandbox --disable-gpu --disable-setuid-sandbox "$URL" &
|
||||
echo "Chrome started on display :$DISPLAY_NUM"
|
||||
echo "View via VNC or noVNC"
|
||||
Executable
+52
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Start virtual display with VNC access for browser GUI
|
||||
# Usage: ./start-vnc.sh [display_num] [vnc_port] [novnc_port] [resolution]
|
||||
|
||||
set -e
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
VNC_PORT=${2:-5900}
|
||||
NOVNC_PORT=${3:-6080}
|
||||
RESOLUTION=${4:-1920x1080x24}
|
||||
|
||||
echo "Starting virtual display :$DISPLAY_NUM at $RESOLUTION"
|
||||
|
||||
# Kill any existing sessions
|
||||
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
|
||||
sleep 1
|
||||
|
||||
# Start Xvfb (virtual framebuffer)
|
||||
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
|
||||
XVFB_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Set DISPLAY
|
||||
export DISPLAY=:$DISPLAY_NUM
|
||||
|
||||
# Start window manager
|
||||
fluxbox -display :$DISPLAY_NUM 2>/dev/null &
|
||||
sleep 1
|
||||
|
||||
# Start x11vnc
|
||||
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg 2>/dev/null
|
||||
sleep 1
|
||||
|
||||
# Start noVNC (web-based VNC client)
|
||||
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
|
||||
NOVNC_PID=$!
|
||||
|
||||
echo ""
|
||||
echo "==================================="
|
||||
echo "VNC Server started!"
|
||||
echo "==================================="
|
||||
echo "VNC Direct: localhost:$VNC_PORT"
|
||||
echo "noVNC Web: http://localhost:$NOVNC_PORT/vnc.html"
|
||||
echo "Display: :$DISPLAY_NUM"
|
||||
echo "==================================="
|
||||
echo ""
|
||||
echo "To start a browser, run:"
|
||||
echo " DISPLAY=:$DISPLAY_NUM google-chrome &"
|
||||
echo ""
|
||||
echo "To stop, run: pkill -f 'Xvfb :$DISPLAY_NUM'"
|
||||
Executable
+11
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
# Stop virtual display and VNC server
|
||||
# Usage: ./stop-vnc.sh [display_num]
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
|
||||
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "websockify.*6080" 2>/dev/null || true
|
||||
|
||||
echo "VNC server stopped"
|
||||
@@ -77,7 +77,9 @@ echo "Created annotated tag: $TAG"
|
||||
if [[ "$PUSH_TAG" == "true" ]]; then
|
||||
git push origin "$TAG"
|
||||
echo "Pushed tag to origin: $TAG"
|
||||
echo "GitHub release pipeline will run via .github/workflows/pub-release.yml"
|
||||
echo "Release Stable workflow will auto-trigger via tag push."
|
||||
echo "Monitor: gh workflow view 'Release Stable' --web"
|
||||
else
|
||||
echo "Next step: git push origin $TAG"
|
||||
echo "This will auto-trigger the Release Stable workflow (builds, Docker, crates.io, website, Scoop, AUR, Homebrew, tweet)."
|
||||
fi
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
---
|
||||
name: browser
|
||||
description: Headless browser automation using agent-browser CLI
|
||||
metadata: {"zeroclaw":{"emoji":"🌐","requires":{"bins":["agent-browser"]}}}
|
||||
---
|
||||
|
||||
# Browser Skill
|
||||
|
||||
Control a headless browser for web automation, scraping, and testing.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `agent-browser` CLI installed globally (`npm install -g agent-browser`)
|
||||
- Chrome downloaded (`agent-browser install`)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install agent-browser CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Navigate and snapshot
|
||||
|
||||
```bash
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Interact with elements
|
||||
|
||||
```bash
|
||||
agent-browser click @e1 # Click by ref
|
||||
agent-browser fill @e2 "text" # Fill input
|
||||
agent-browser press Enter # Press key
|
||||
```
|
||||
|
||||
### Extract data
|
||||
|
||||
```bash
|
||||
agent-browser get text @e1 # Get text content
|
||||
agent-browser get url # Get current URL
|
||||
agent-browser screenshot page.png # Take screenshot
|
||||
```
|
||||
|
||||
### Session management
|
||||
|
||||
```bash
|
||||
agent-browser close # Close browser
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Login flow
|
||||
|
||||
```bash
|
||||
agent-browser open https://site.com/login
|
||||
agent-browser snapshot -i
|
||||
agent-browser fill @email "user@example.com"
|
||||
agent-browser fill @password "secretpass"
|
||||
agent-browser click @submit
|
||||
agent-browser wait --text "Welcome"
|
||||
```
|
||||
|
||||
### Scrape page content
|
||||
|
||||
```bash
|
||||
agent-browser open https://news.ycombinator.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser get text @e1
|
||||
```
|
||||
|
||||
### Take screenshots
|
||||
|
||||
```bash
|
||||
agent-browser open https://google.com
|
||||
agent-browser screenshot --full page.png
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
- `--json` - JSON output for parsing
|
||||
- `--headed` - Show browser window (for debugging)
|
||||
- `--session-name <name>` - Persist session cookies
|
||||
- `--profile <path>` - Use persistent browser profile
|
||||
|
||||
## Configuration
|
||||
|
||||
The browser tool is enabled by default with `allowed_domains = ["*"]` and
|
||||
`backend = "agent_browser"`. To customize, edit `~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
## Full Command Reference
|
||||
|
||||
Run `agent-browser --help` for all available commands.
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser GitHub](https://github.com/vercel-labs/agent-browser)
|
||||
- [VNC Setup Guide](../docs/browser-setup.md)
|
||||
+2
-1
@@ -359,7 +359,7 @@ impl Agent {
|
||||
None
|
||||
};
|
||||
|
||||
let (mut tools, delegate_handle) = tools::all_tools_with_runtime(
|
||||
let (mut tools, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
@@ -373,6 +373,7 @@ impl Agent {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
// ── Wire MCP tools (non-fatal) ─────────────────────────────
|
||||
|
||||
+657
-82
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory};
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -43,13 +43,16 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory
|
||||
let mut entries = memory
|
||||
.recall(user_message, self.limit, session_id, None, None)
|
||||
.await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
@@ -118,6 +121,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}])
|
||||
}
|
||||
|
||||
@@ -226,6 +232,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.95),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
@@ -235,6 +244,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
@@ -5,6 +5,7 @@ pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
pub mod thinking;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
+7
-6
@@ -473,8 +473,9 @@ mod tests {
|
||||
assert!(output.contains("<available_skills>"));
|
||||
assert!(output.contains("<name>deploy</name>"));
|
||||
assert!(output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -516,10 +517,10 @@ mod tests {
|
||||
assert!(output.contains("<location>skills/deploy/SKILL.md</location>"));
|
||||
assert!(output.contains("read_skill(name)"));
|
||||
assert!(!output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(output.contains("<tools>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
//! Thinking/Reasoning Level Control
|
||||
//!
|
||||
//! Allows users to control how deeply the model reasons per message,
|
||||
//! trading speed for depth. Levels range from `Off` (fastest, most concise)
|
||||
//! to `Max` (deepest reasoning, slowest).
|
||||
//!
|
||||
//! Users can set the level via:
|
||||
//! - Inline directive: `/think:high` at the start of a message
|
||||
//! - Agent config: `[agent.thinking]` section with `default_level`
|
||||
//!
|
||||
//! Resolution hierarchy (highest priority first):
|
||||
//! 1. Inline directive (`/think:<level>`)
|
||||
//! 2. Session override (reserved for future use)
|
||||
//! 3. Agent config (`agent.thinking.default_level`)
|
||||
//! 4. Global default (`Medium`)
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// How deeply the model should reason for a given message.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ThinkingLevel {
|
||||
/// No chain-of-thought. Fastest, most concise responses.
|
||||
Off,
|
||||
/// Minimal reasoning. Brief, direct answers.
|
||||
Minimal,
|
||||
/// Light reasoning. Short explanations when needed.
|
||||
Low,
|
||||
/// Balanced reasoning (default). Moderate depth.
|
||||
#[default]
|
||||
Medium,
|
||||
/// Deep reasoning. Thorough analysis and step-by-step thinking.
|
||||
High,
|
||||
/// Maximum reasoning depth. Exhaustive analysis.
|
||||
Max,
|
||||
}
|
||||
|
||||
impl ThinkingLevel {
|
||||
/// Parse a thinking level from a string (case-insensitive).
|
||||
pub fn from_str_insensitive(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"off" | "none" => Some(Self::Off),
|
||||
"minimal" | "min" => Some(Self::Minimal),
|
||||
"low" => Some(Self::Low),
|
||||
"medium" | "med" | "default" => Some(Self::Medium),
|
||||
"high" => Some(Self::High),
|
||||
"max" | "maximum" => Some(Self::Max),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for thinking/reasoning level control.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ThinkingConfig {
|
||||
/// Default thinking level when no directive is present.
|
||||
#[serde(default)]
|
||||
pub default_level: ThinkingLevel,
|
||||
}
|
||||
|
||||
impl Default for ThinkingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_level: ThinkingLevel::Medium,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters derived from a thinking level, applied to the LLM request.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ThinkingParams {
|
||||
/// Temperature adjustment (added to the base temperature, clamped to 0.0..=2.0).
|
||||
pub temperature_adjustment: f64,
|
||||
/// Maximum tokens adjustment (added to any existing max_tokens setting).
|
||||
pub max_tokens_adjustment: i64,
|
||||
/// Optional system prompt prefix injected before the existing system prompt.
|
||||
pub system_prompt_prefix: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse a `/think:<level>` directive from the start of a message.
|
||||
///
|
||||
/// Returns `Some((level, remaining_message))` if a directive is found,
|
||||
/// or `None` if no directive is present. The remaining message has
|
||||
/// leading whitespace after the directive trimmed.
|
||||
pub fn parse_thinking_directive(message: &str) -> Option<(ThinkingLevel, String)> {
|
||||
let trimmed = message.trim_start();
|
||||
if !trimmed.starts_with("/think:") {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract the level token (everything between `/think:` and the next whitespace or end).
|
||||
let after_prefix = &trimmed["/think:".len()..];
|
||||
let level_end = after_prefix
|
||||
.find(|c: char| c.is_whitespace())
|
||||
.unwrap_or(after_prefix.len());
|
||||
let level_str = &after_prefix[..level_end];
|
||||
|
||||
let level = ThinkingLevel::from_str_insensitive(level_str)?;
|
||||
|
||||
let remaining = after_prefix[level_end..].trim_start().to_string();
|
||||
Some((level, remaining))
|
||||
}
|
||||
|
||||
/// Convert a `ThinkingLevel` into concrete parameters for the LLM request.
|
||||
pub fn apply_thinking_level(level: ThinkingLevel) -> ThinkingParams {
|
||||
match level {
|
||||
ThinkingLevel::Off => ThinkingParams {
|
||||
temperature_adjustment: -0.2,
|
||||
max_tokens_adjustment: -1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Be extremely concise. Give direct answers without explanation \
|
||||
unless explicitly asked. No preamble."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Minimal => ThinkingParams {
|
||||
temperature_adjustment: -0.1,
|
||||
max_tokens_adjustment: -500,
|
||||
system_prompt_prefix: Some(
|
||||
"Be concise and fast. Keep explanations brief. \
|
||||
Prioritize speed over thoroughness."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Low => ThinkingParams {
|
||||
temperature_adjustment: -0.05,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: Some("Keep reasoning light. Explain only when helpful.".into()),
|
||||
},
|
||||
ThinkingLevel::Medium => ThinkingParams {
|
||||
temperature_adjustment: 0.0,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: None,
|
||||
},
|
||||
ThinkingLevel::High => ThinkingParams {
|
||||
temperature_adjustment: 0.05,
|
||||
max_tokens_adjustment: 1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think step by step. Provide thorough analysis and \
|
||||
consider edge cases before answering."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Max => ThinkingParams {
|
||||
temperature_adjustment: 0.1,
|
||||
max_tokens_adjustment: 2000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think very carefully and exhaustively. Break down the problem \
|
||||
into sub-problems, consider all angles, verify your reasoning, \
|
||||
and provide the most thorough analysis possible."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the effective thinking level using the priority hierarchy:
|
||||
/// 1. Inline directive (if present)
|
||||
/// 2. Session override (reserved, currently always `None`)
|
||||
/// 3. Agent config default
|
||||
/// 4. Global default (`Medium`)
|
||||
pub fn resolve_thinking_level(
|
||||
inline_directive: Option<ThinkingLevel>,
|
||||
session_override: Option<ThinkingLevel>,
|
||||
config: &ThinkingConfig,
|
||||
) -> ThinkingLevel {
|
||||
inline_directive
|
||||
.or(session_override)
|
||||
.unwrap_or(config.default_level)
|
||||
}
|
||||
|
||||
/// Clamp a temperature value to the valid range `[0.0, 2.0]`.
|
||||
pub fn clamp_temperature(temp: f64) -> f64 {
|
||||
temp.clamp(0.0, 2.0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── ThinkingLevel parsing ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_canonical_names() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("off"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("minimal"),
|
||||
Some(ThinkingLevel::Minimal)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("low"),
|
||||
Some(ThinkingLevel::Low)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("medium"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("high"),
|
||||
Some(ThinkingLevel::High)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("max"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_aliases() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("none"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("min"),
|
||||
Some(ThinkingLevel::Minimal)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("med"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("default"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("maximum"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_case_insensitive() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("HIGH"),
|
||||
Some(ThinkingLevel::High)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("Max"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("OFF"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_invalid_returns_none() {
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive("turbo"), None);
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive(""), None);
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive("super-high"), None);
|
||||
}
|
||||
|
||||
// ── Directive parsing ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_directive_extracts_level_and_remaining_message() {
|
||||
let result = parse_thinking_directive("/think:high What is Rust?");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::High);
|
||||
assert_eq!(remaining, "What is Rust?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_handles_directive_only() {
|
||||
let result = parse_thinking_directive("/think:off");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::Off);
|
||||
assert_eq!(remaining, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_strips_leading_whitespace() {
|
||||
let result = parse_thinking_directive(" /think:low Tell me about Rust");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::Low);
|
||||
assert_eq!(remaining, "Tell me about Rust");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_returns_none_for_no_directive() {
|
||||
assert!(parse_thinking_directive("Hello world").is_none());
|
||||
assert!(parse_thinking_directive("").is_none());
|
||||
assert!(parse_thinking_directive("/think").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_returns_none_for_invalid_level() {
|
||||
assert!(parse_thinking_directive("/think:turbo What?").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_not_triggered_mid_message() {
|
||||
assert!(parse_thinking_directive("Hello /think:high world").is_none());
|
||||
}
|
||||
|
||||
// ── Level application ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_off_is_concise() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Off);
|
||||
assert!(params.temperature_adjustment < 0.0);
|
||||
assert!(params.max_tokens_adjustment < 0);
|
||||
assert!(params.system_prompt_prefix.is_some());
|
||||
assert!(params
|
||||
.system_prompt_prefix
|
||||
.unwrap()
|
||||
.to_lowercase()
|
||||
.contains("concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_medium_is_neutral() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Medium);
|
||||
assert!((params.temperature_adjustment - 0.0).abs() < f64::EPSILON);
|
||||
assert_eq!(params.max_tokens_adjustment, 0);
|
||||
assert!(params.system_prompt_prefix.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_high_adds_step_by_step() {
|
||||
let params = apply_thinking_level(ThinkingLevel::High);
|
||||
assert!(params.temperature_adjustment > 0.0);
|
||||
assert!(params.max_tokens_adjustment > 0);
|
||||
let prefix = params.system_prompt_prefix.unwrap();
|
||||
assert!(prefix.to_lowercase().contains("step by step"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_max_is_most_thorough() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Max);
|
||||
assert!(params.temperature_adjustment > 0.0);
|
||||
assert!(params.max_tokens_adjustment > 0);
|
||||
let prefix = params.system_prompt_prefix.unwrap();
|
||||
assert!(prefix.to_lowercase().contains("exhaustively"));
|
||||
}
|
||||
|
||||
// ── Resolution hierarchy ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn resolve_inline_directive_takes_priority() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Low,
|
||||
};
|
||||
let result =
|
||||
resolve_thinking_level(Some(ThinkingLevel::Max), Some(ThinkingLevel::High), &config);
|
||||
assert_eq!(result, ThinkingLevel::Max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_session_override_takes_priority_over_config() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Low,
|
||||
};
|
||||
let result = resolve_thinking_level(None, Some(ThinkingLevel::High), &config);
|
||||
assert_eq!(result, ThinkingLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_falls_back_to_config_default() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Minimal,
|
||||
};
|
||||
let result = resolve_thinking_level(None, None, &config);
|
||||
assert_eq!(result, ThinkingLevel::Minimal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_default_config_uses_medium() {
|
||||
let config = ThinkingConfig::default();
|
||||
let result = resolve_thinking_level(None, None, &config);
|
||||
assert_eq!(result, ThinkingLevel::Medium);
|
||||
}
|
||||
|
||||
// ── Temperature clamping ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_within_range() {
|
||||
assert!((clamp_temperature(0.7) - 0.7).abs() < f64::EPSILON);
|
||||
assert!((clamp_temperature(0.0) - 0.0).abs() < f64::EPSILON);
|
||||
assert!((clamp_temperature(2.0) - 2.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_below_minimum() {
|
||||
assert!((clamp_temperature(-0.5) - 0.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_above_maximum() {
|
||||
assert!((clamp_temperature(3.0) - 2.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
// ── Serde round-trip ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn thinking_config_deserializes_from_toml() {
|
||||
let toml_str = r#"default_level = "high""#;
|
||||
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.default_level, ThinkingLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_config_default_level_deserializes() {
|
||||
let toml_str = "";
|
||||
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.default_level, ThinkingLevel::Medium);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_serializes_lowercase() {
|
||||
let level = ThinkingLevel::High;
|
||||
let json = serde_json::to_string(&level).unwrap();
|
||||
assert_eq!(json, "\"high\"");
|
||||
}
|
||||
}
|
||||
+48
-2
@@ -122,7 +122,7 @@ impl ApprovalManager {
|
||||
}
|
||||
|
||||
// always_ask overrides everything.
|
||||
if self.always_ask.contains(tool_name) {
|
||||
if self.always_ask.contains("*") || self.always_ask.contains(tool_name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ impl ApprovalManager {
|
||||
}
|
||||
|
||||
// auto_approve skips the prompt.
|
||||
if self.auto_approve.contains(tool_name) {
|
||||
if self.auto_approve.contains("*") || self.auto_approve.contains(tool_name) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -562,4 +562,50 @@ mod tests {
|
||||
let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.tool_name, "shell");
|
||||
}
|
||||
|
||||
// ── Regression: #4247 default approved tools in channels ──
|
||||
|
||||
#[test]
|
||||
fn non_interactive_allows_default_auto_approve_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
|
||||
for tool in &config.auto_approve {
|
||||
assert!(
|
||||
!mgr.needs_approval(tool),
|
||||
"default auto_approve tool '{tool}' should not need approval in non-interactive mode"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_denies_unknown_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("some_unknown_tool"),
|
||||
"unknown tool should need approval"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_weather_is_auto_approved() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
!mgr.needs_approval("weather"),
|
||||
"weather tool must not need approval — it is in the default auto_approve list"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn always_ask_overrides_auto_approve() {
|
||||
let mut config = AutonomyConfig::default();
|
||||
config.always_ask = vec!["weather".into()];
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("weather"),
|
||||
"always_ask must override auto_approve"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+116
-1
@@ -20,6 +20,9 @@ pub struct DiscordChannel {
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@@ -38,6 +41,7 @@ impl DiscordChannel {
|
||||
mention_only,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +51,14 @@ impl DiscordChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
|
||||
}
|
||||
@@ -113,6 +125,88 @@ async fn process_attachments(
|
||||
parts.join("\n---\n")
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
|
||||
const DISCORD_AUDIO_EXTENSIONS: &[&str] = &[
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
];
|
||||
|
||||
/// Check if a content type or filename indicates an audio file.
|
||||
fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
|
||||
if content_type.starts_with("audio/") {
|
||||
return true;
|
||||
}
|
||||
if let Some(ext) = filename.rsplit('.').next() {
|
||||
return DISCORD_AUDIO_EXTENSIONS.contains(&ext.to_ascii_lowercase().as_str());
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Download and transcribe audio attachments from a Discord message.
|
||||
///
|
||||
/// Returns transcribed text blocks for any audio attachments found.
|
||||
/// Non-audio attachments and failures are silently skipped.
|
||||
async fn transcribe_discord_audio_attachments(
|
||||
attachments: &[serde_json::Value],
|
||||
client: &reqwest::Client,
|
||||
config: &crate::config::TranscriptionConfig,
|
||||
) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
for att in attachments {
|
||||
let ct = att
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let name = att
|
||||
.get("filename")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("file");
|
||||
|
||||
if !is_discord_audio_attachment(ct, name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(url) = att.get("url").and_then(|v| v.as_str()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let audio_data = match client.get(url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => match resp.bytes().await {
|
||||
Ok(bytes) => bytes.to_vec(),
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: failed to read audio attachment bytes");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Ok(resp) => {
|
||||
tracing::warn!(name, status = %resp.status(), "discord: audio attachment download failed");
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: audio attachment fetch error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, name, config).await {
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
tracing::info!(
|
||||
"Discord: transcribed audio attachment {} ({} chars)",
|
||||
name,
|
||||
trimmed.len()
|
||||
);
|
||||
parts.push(format!("[Voice] {trimmed}"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: voice transcription failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum DiscordAttachmentKind {
|
||||
Image,
|
||||
@@ -737,7 +831,28 @@ impl Channel for DiscordChannel {
|
||||
.and_then(|a| a.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
process_attachments(&atts, &self.http_client()).await
|
||||
let client = self.http_client();
|
||||
let mut text_parts = process_attachments(&atts, &client).await;
|
||||
|
||||
// Transcribe audio attachments when transcription is configured
|
||||
if let Some(ref transcription_config) = self.transcription {
|
||||
let voice_text = transcribe_discord_audio_attachments(
|
||||
&atts,
|
||||
&client,
|
||||
transcription_config,
|
||||
)
|
||||
.await;
|
||||
if !voice_text.is_empty() {
|
||||
if text_parts.is_empty() {
|
||||
text_parts = voice_text;
|
||||
} else {
|
||||
text_parts = format!("{text_parts}
|
||||
{voice_text}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_parts
|
||||
};
|
||||
let final_content = if attachment_text.is_empty() {
|
||||
clean_content
|
||||
|
||||
@@ -0,0 +1,549 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::memory::{Memory, MemoryCategory};
|
||||
|
||||
/// Discord History channel — connects via Gateway WebSocket, stores ALL non-bot messages
|
||||
/// to a dedicated discord.db, and forwards @mention messages to the agent.
|
||||
pub struct DiscordHistoryChannel {
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
/// Channel IDs to watch. Empty = watch all channels.
|
||||
channel_ids: Vec<String>,
|
||||
/// Dedicated discord.db memory backend.
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
proxy_url: Option<String>,
|
||||
/// When false, DM messages are not stored in discord.db.
|
||||
store_dms: bool,
|
||||
/// When false, @mentions in DMs are not forwarded to the agent.
|
||||
respond_to_dms: bool,
|
||||
}
|
||||
|
||||
impl DiscordHistoryChannel {
|
||||
pub fn new(
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
channel_ids: Vec<String>,
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
store_dms: bool,
|
||||
respond_to_dms: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
allowed_users,
|
||||
channel_ids,
|
||||
discord_memory,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
store_dms,
|
||||
respond_to_dms,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client(
|
||||
"channel.discord_history",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
if self.allowed_users.is_empty() {
|
||||
return true; // default open for logging channel
|
||||
}
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||
}
|
||||
|
||||
fn is_channel_watched(&self, channel_id: &str) -> bool {
|
||||
self.channel_ids.is_empty() || self.channel_ids.iter().any(|c| c == channel_id)
|
||||
}
|
||||
|
||||
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
||||
let part = token.split('.').next()?;
|
||||
base64_decode(part)
|
||||
}
|
||||
|
||||
async fn resolve_channel_name(&self, channel_id: &str) -> String {
|
||||
// 1. Check persistent database (via discord_memory)
|
||||
let cache_key = format!("cache:channel_name:{}", channel_id);
|
||||
|
||||
if let Ok(Some(cached_mem)) = self.discord_memory.get(&cache_key).await {
|
||||
// Check if it's still fresh (e.g., less than 24 hours old)
|
||||
// Note: cached_mem.timestamp is an RFC3339 string
|
||||
let is_fresh =
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&cached_mem.timestamp) {
|
||||
chrono::Utc::now().signed_duration_since(ts.with_timezone(&chrono::Utc))
|
||||
< chrono::Duration::hours(24)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if is_fresh {
|
||||
return cached_mem.content.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Fetch from API (either not in DB or stale)
|
||||
let url = format!("https://discord.com/api/v10/channels/{channel_id}");
|
||||
let resp = self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let name = if let Ok(r) = resp {
|
||||
if let Ok(json) = r.json::<serde_json::Value>().await {
|
||||
json.get("name")
|
||||
.and_then(|n| n.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
// For DMs, there might not be a 'name', use the recipient's username if available
|
||||
json.get("recipients")
|
||||
.and_then(|r| r.as_array())
|
||||
.and_then(|a| a.first())
|
||||
.and_then(|u| u.get("username"))
|
||||
.and_then(|un| un.as_str())
|
||||
.map(|s| format!("dm-{}", s))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let resolved = name.unwrap_or_else(|| channel_id.to_string());
|
||||
|
||||
// 3. Store in persistent database
|
||||
let _ = self
|
||||
.discord_memory
|
||||
.store(
|
||||
&cache_key,
|
||||
&resolved,
|
||||
crate::memory::MemoryCategory::Custom("channel_cache".to_string()),
|
||||
Some(channel_id),
|
||||
)
|
||||
.await;
|
||||
|
||||
resolved
|
||||
}
|
||||
}
|
||||
|
||||
const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
fn base64_decode(input: &str) -> Option<String> {
|
||||
let padded = match input.len() % 4 {
|
||||
2 => format!("{input}=="),
|
||||
3 => format!("{input}="),
|
||||
_ => input.to_string(),
|
||||
};
|
||||
let mut bytes = Vec::new();
|
||||
let chars: Vec<u8> = padded.bytes().collect();
|
||||
for chunk in chars.chunks(4) {
|
||||
if chunk.len() < 4 {
|
||||
break;
|
||||
}
|
||||
let mut v = [0usize; 4];
|
||||
for (i, &b) in chunk.iter().enumerate() {
|
||||
if b == b'=' {
|
||||
v[i] = 0;
|
||||
} else {
|
||||
v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
|
||||
}
|
||||
}
|
||||
bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
|
||||
if chunk[2] != b'=' {
|
||||
bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
|
||||
}
|
||||
if chunk[3] != b'=' {
|
||||
bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
|
||||
}
|
||||
}
|
||||
String::from_utf8(bytes).ok()
|
||||
}
|
||||
|
||||
fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
|
||||
if bot_user_id.is_empty() {
|
||||
return false;
|
||||
}
|
||||
content.contains(&format!("<@{bot_user_id}>"))
|
||||
|| content.contains(&format!("<@!{bot_user_id}>"))
|
||||
}
|
||||
|
||||
fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
|
||||
let mut result = content.to_string();
|
||||
for tag in [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] {
|
||||
result = result.replace(&tag, " ");
|
||||
}
|
||||
result.trim().to_string()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for DiscordHistoryChannel {
|
||||
fn name(&self) -> &str {
|
||||
"discord_history"
|
||||
}
|
||||
|
||||
/// Send a reply back to Discord (used when agent responds to @mention).
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
let content = super::strip_tool_call_tags(&message.content);
|
||||
let url = format!(
|
||||
"https://discord.com/api/v10/channels/{}/messages",
|
||||
message.recipient
|
||||
);
|
||||
self.http_client()
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.json(&json!({"content": content}))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
|
||||
|
||||
// Get Gateway URL
|
||||
let gw_resp: serde_json::Value = self
|
||||
.http_client()
|
||||
.get("https://discord.com/api/v10/gateway/bot")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let gw_url = gw_resp
|
||||
.get("url")
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("wss://gateway.discord.gg");
|
||||
|
||||
let ws_url = format!("{gw_url}/?v=10&encoding=json");
|
||||
tracing::info!("DiscordHistory: connecting to gateway...");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
|
||||
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
|
||||
let heartbeat_interval = hello_data
|
||||
.get("d")
|
||||
.and_then(|d| d.get("heartbeat_interval"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(41250);
|
||||
|
||||
// Identify with intents for guild + DM messages + message content
|
||||
let identify = json!({
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.bot_token,
|
||||
"intents": 37377,
|
||||
"properties": {
|
||||
"os": "linux",
|
||||
"browser": "zeroclaw",
|
||||
"device": "zeroclaw"
|
||||
}
|
||||
}
|
||||
});
|
||||
write
|
||||
.send(Message::Text(identify.to_string().into()))
|
||||
.await?;
|
||||
|
||||
tracing::info!("DiscordHistory: connected and identified");
|
||||
|
||||
let mut sequence: i64 = -1;
|
||||
|
||||
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
|
||||
tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if hb_tx.send(()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let guild_filter = self.guild_id.clone();
|
||||
let discord_memory = Arc::clone(&self.discord_memory);
|
||||
let store_dms = self.store_dms;
|
||||
let respond_to_dms = self.respond_to_dms;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = hb_rx.recv() => {
|
||||
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
||||
let hb = json!({"op": 1, "d": d});
|
||||
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = read.next() => {
|
||||
let msg = match msg {
|
||||
Some(Ok(Message::Text(t))) => t,
|
||||
Some(Ok(Message::Ping(payload))) => {
|
||||
if write.send(Message::Pong(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Err(e)) => {
|
||||
tracing::warn!("DiscordHistory: websocket error: {e}");
|
||||
break;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
|
||||
sequence = s;
|
||||
}
|
||||
|
||||
let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
|
||||
match op {
|
||||
1 => {
|
||||
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
||||
let hb = json!({"op": 1, "d": d});
|
||||
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
7 => { tracing::warn!("DiscordHistory: Reconnect (op 7)"); break; }
|
||||
9 => { tracing::warn!("DiscordHistory: Invalid Session (op 9)"); break; }
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
|
||||
if event_type != "MESSAGE_CREATE" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(d) = event.get("d") else { continue };
|
||||
|
||||
// Skip messages from the bot itself
|
||||
let author_id = d
|
||||
.get("author")
|
||||
.and_then(|a| a.get("id"))
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("");
|
||||
let username = d
|
||||
.get("author")
|
||||
.and_then(|a| a.get("username"))
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or(author_id);
|
||||
|
||||
if author_id == bot_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip other bots
|
||||
if d.get("author")
|
||||
.and_then(|a| a.get("bot"))
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_id = d
|
||||
.get("channel_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// DM detection: DMs have no guild_id
|
||||
let is_dm_event = d.get("guild_id").and_then(serde_json::Value::as_str).is_none();
|
||||
|
||||
// Resolve channel name (with cache)
|
||||
let channel_display = if is_dm_event {
|
||||
"dm".to_string()
|
||||
} else {
|
||||
self.resolve_channel_name(&channel_id).await
|
||||
};
|
||||
|
||||
if is_dm_event && !store_dms && !respond_to_dms {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Guild filter
|
||||
if let Some(ref gid) = guild_filter {
|
||||
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str);
|
||||
if let Some(g) = msg_guild {
|
||||
if g != gid {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Channel filter
|
||||
if !self.is_channel_watched(&channel_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !self.is_user_allowed(author_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let is_mention = contains_bot_mention(content, &bot_user_id);
|
||||
|
||||
// Collect attachment URLs
|
||||
let attachments: Vec<String> = d
|
||||
.get("attachments")
|
||||
.and_then(|a| a.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|a| a.get("url").and_then(|u| u.as_str()))
|
||||
.map(|u| u.to_string())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Store messages to discord.db (skip DMs if store_dms=false)
|
||||
if (!is_dm_event || store_dms) && (!content.is_empty() || !attachments.is_empty()) {
|
||||
let ts = chrono::Utc::now().to_rfc3339();
|
||||
let mut mem_content = format!(
|
||||
"@{username} in #{channel_display} at {ts}: {content}"
|
||||
);
|
||||
if !attachments.is_empty() {
|
||||
mem_content.push_str(" [attachments: ");
|
||||
mem_content.push_str(&attachments.join(", "));
|
||||
mem_content.push(']');
|
||||
}
|
||||
let mem_key = format!(
|
||||
"discord_{}",
|
||||
if message_id.is_empty() {
|
||||
Uuid::new_v4().to_string()
|
||||
} else {
|
||||
message_id.to_string()
|
||||
}
|
||||
);
|
||||
let channel_id_for_session = if channel_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(channel_id.as_str())
|
||||
};
|
||||
if let Err(err) = discord_memory
|
||||
.store(
|
||||
&mem_key,
|
||||
&mem_content,
|
||||
MemoryCategory::Custom("discord".to_string()),
|
||||
channel_id_for_session,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("discord_history: failed to store message: {err}");
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"discord_history: stored message from @{username} in #{channel_display}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Forward @mention to agent (skip DMs if respond_to_dms=false)
|
||||
if is_mention && (!is_dm_event || respond_to_dms) {
|
||||
let clean_content = strip_bot_mention(content, &bot_user_id);
|
||||
if clean_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let channel_msg = ChannelMessage {
|
||||
id: if message_id.is_empty() {
|
||||
Uuid::new_v4().to_string()
|
||||
} else {
|
||||
format!("discord_{message_id}")
|
||||
},
|
||||
sender: author_id.to_string(),
|
||||
reply_target: if channel_id.is_empty() {
|
||||
author_id.to_string()
|
||||
} else {
|
||||
channel_id.clone()
|
||||
},
|
||||
content: clean_content,
|
||||
channel: "discord_history".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
};
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.http_client()
|
||||
.get("https://discord.com/api/v10/users/@me")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> {
|
||||
let mut guard = self.typing_handles.lock();
|
||||
if let Some(h) = guard.remove(recipient) {
|
||||
h.abort();
|
||||
}
|
||||
let client = self.http_client();
|
||||
let token = self.bot_token.clone();
|
||||
let channel_id = recipient.to_string();
|
||||
let handle = tokio::spawn(async move {
|
||||
let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing");
|
||||
loop {
|
||||
let _ = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {token}"))
|
||||
.send()
|
||||
.await;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(8)).await;
|
||||
}
|
||||
});
|
||||
guard.insert(recipient.to_string(), handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> {
|
||||
let mut guard = self.typing_handles.lock();
|
||||
if let Some(handle) = guard.remove(recipient) {
|
||||
handle.abort();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+535
-47
@@ -1,5 +1,6 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine as _;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::collections::HashMap;
|
||||
@@ -221,6 +222,21 @@ const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
|
||||
/// Lark card payloads have a ~30 KB limit; leave margin for JSON envelope.
|
||||
const LARK_CARD_MARKDOWN_MAX_BYTES: usize = 28_000;
|
||||
|
||||
/// Maximum image size we will download and inline (5 MiB).
|
||||
const LARK_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
/// Maximum file size we will download and present as text (512 KiB).
|
||||
const LARK_FILE_MAX_BYTES: usize = 512 * 1024;
|
||||
|
||||
/// Image MIME types we support for inline base64 encoding.
|
||||
const LARK_SUPPORTED_IMAGE_MIMES: &[&str] = &[
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/bmp",
|
||||
];
|
||||
|
||||
/// Returns true when the WebSocket frame indicates live traffic that should
|
||||
/// refresh the heartbeat watchdog.
|
||||
fn should_refresh_last_recv(msg: &WsMsg) -> bool {
|
||||
@@ -520,6 +536,17 @@ impl LarkChannel {
|
||||
format!("{}/im/v1/messages/{message_id}/reactions", self.api_base())
|
||||
}
|
||||
|
||||
fn image_download_url(&self, image_key: &str) -> String {
|
||||
format!("{}/im/v1/images/{image_key}", self.api_base())
|
||||
}
|
||||
|
||||
fn file_download_url(&self, message_id: &str, file_key: &str) -> String {
|
||||
format!(
|
||||
"{}/im/v1/messages/{message_id}/resources/{file_key}?type=file",
|
||||
self.api_base()
|
||||
)
|
||||
}
|
||||
|
||||
fn resolved_bot_open_id(&self) -> Option<String> {
|
||||
self.resolved_bot_open_id
|
||||
.read()
|
||||
@@ -866,6 +893,44 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => continue,
|
||||
},
|
||||
"image" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let image_key = match v.get("image_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: image message missing image_key"); continue; }
|
||||
};
|
||||
match self.download_image_as_marker(&image_key).await {
|
||||
Some(marker) => (marker, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download image {image_key}");
|
||||
(format!("[IMAGE:{image_key} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let file_key = match v.get("file_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: file message missing file_key"); continue; }
|
||||
};
|
||||
let file_name = v.get("file_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match self.download_file_as_content(&lark_msg.message_id, &file_key, &file_name).await {
|
||||
Some(content) => (content, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download file {file_key}");
|
||||
(format!("[ATTACHMENT:{file_name} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||
};
|
||||
|
||||
@@ -986,6 +1051,183 @@ impl LarkChannel {
|
||||
*cached = None;
|
||||
}
|
||||
|
||||
/// Download an image from the Lark API and return an `[IMAGE:data:...]` marker string.
|
||||
async fn download_image_as_marker(&self, image_key: &str) -> Option<String> {
|
||||
let token = match self.get_tenant_access_token().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: failed to get token for image download: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.image_download_url(image_key);
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: image download request failed for {image_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tracing::warn!(
|
||||
"Lark: image download failed for {image_key}: status={}",
|
||||
resp.status()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(cl) = resp.content_length() {
|
||||
if cl > LARK_IMAGE_MAX_BYTES as u64 {
|
||||
tracing::warn!("Lark: image too large for {image_key}: {cl} bytes exceeds limit");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string);
|
||||
|
||||
let bytes = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: image body read failed for {image_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if bytes.is_empty() || bytes.len() > LARK_IMAGE_MAX_BYTES {
|
||||
tracing::warn!(
|
||||
"Lark: image body empty or too large for {image_key}: {} bytes",
|
||||
bytes.len()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let mime = lark_detect_image_mime(content_type.as_deref(), &bytes)?;
|
||||
if !LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
|
||||
tracing::warn!("Lark: unsupported image MIME for {image_key}: {mime}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
Some(format!("[IMAGE:data:{mime};base64,{encoded}]"))
|
||||
}
|
||||
|
||||
/// Download a file from the Lark API and return a text content marker.
|
||||
/// For text-like files, the content is inlined. For binary files, a summary is returned.
|
||||
async fn download_file_as_content(
|
||||
&self,
|
||||
message_id: &str,
|
||||
file_key: &str,
|
||||
file_name: &str,
|
||||
) -> Option<String> {
|
||||
let token = match self.get_tenant_access_token().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: failed to get token for file download: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.file_download_url(message_id, file_key);
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file download request failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tracing::warn!(
|
||||
"Lark: file download failed for {file_key}: status={}",
|
||||
resp.status()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(cl) = resp.content_length() {
|
||||
if cl > LARK_FILE_MAX_BYTES as u64 {
|
||||
tracing::warn!("Lark: file too large for {file_key}: {cl} bytes exceeds limit");
|
||||
return Some(format!(
|
||||
"[ATTACHMENT:{file_name} | size={cl} bytes | too large to inline]"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let bytes = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file body read failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if bytes.is_empty() {
|
||||
tracing::warn!("Lark: file body is empty for {file_key}");
|
||||
return None;
|
||||
}
|
||||
|
||||
// If the content is image-like, return as image marker
|
||||
if content_type.starts_with("image/") && bytes.len() <= LARK_IMAGE_MAX_BYTES {
|
||||
if let Some(mime) = lark_detect_image_mime(Some(&content_type), &bytes) {
|
||||
if LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
return Some(format!("[IMAGE:data:{mime};base64,{encoded}]"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the file looks like text, inline it
|
||||
if bytes.len() <= LARK_FILE_MAX_BYTES
|
||||
&& !bytes.contains(&0)
|
||||
&& (content_type.starts_with("text/")
|
||||
|| content_type.contains("json")
|
||||
|| content_type.contains("xml")
|
||||
|| content_type.contains("yaml")
|
||||
|| content_type.contains("javascript")
|
||||
|| content_type.contains("csv")
|
||||
|| lark_is_text_filename(file_name))
|
||||
{
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
let truncated = if text.len() > 50_000 {
|
||||
format!("{}...\n[truncated]", &text[..50_000])
|
||||
} else {
|
||||
text.into_owned()
|
||||
};
|
||||
let ext = file_name.rsplit('.').next().unwrap_or("text");
|
||||
return Some(format!("[FILE:{file_name}]\n```{ext}\n{truncated}\n```"));
|
||||
}
|
||||
|
||||
Some(format!(
|
||||
"[ATTACHMENT:{file_name} | mime={content_type} | size={} bytes]",
|
||||
bytes.len()
|
||||
))
|
||||
}
|
||||
|
||||
async fn fetch_bot_open_id_with_token(
|
||||
&self,
|
||||
token: &str,
|
||||
@@ -1085,8 +1327,9 @@ impl LarkChannel {
|
||||
Ok((status, parsed))
|
||||
}
|
||||
|
||||
/// Parse an event callback payload and extract text messages
|
||||
pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
/// Parse an event callback payload and extract messages.
|
||||
/// Supports text, post, image, and file message types.
|
||||
pub async fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
// Lark event v2 structure:
|
||||
@@ -1143,6 +1386,11 @@ impl LarkChannel {
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let evt_message_id = event
|
||||
.pointer("/message/message_id")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let (text, post_mentioned_open_ids): (String, Vec<String>) = match msg_type {
|
||||
"text" => {
|
||||
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
@@ -1162,6 +1410,62 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => return messages,
|
||||
},
|
||||
"image" => {
|
||||
let image_key = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("image_key")
|
||||
.and_then(|k| k.as_str())
|
||||
.map(String::from)
|
||||
});
|
||||
match image_key {
|
||||
Some(key) => {
|
||||
let marker = match self.download_image_as_marker(&key).await {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download image {key}");
|
||||
format!("[IMAGE:{key} | download failed]")
|
||||
}
|
||||
};
|
||||
(marker, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: image message missing image_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let parsed = serde_json::from_str::<serde_json::Value>(content_str).ok();
|
||||
let file_key = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_key").and_then(|k| k.as_str()))
|
||||
.map(String::from);
|
||||
let file_name = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_name").and_then(|n| n.as_str()))
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match file_key {
|
||||
Some(key) => {
|
||||
let content = match self
|
||||
.download_file_as_content(evt_message_id, &key, &file_name)
|
||||
.await
|
||||
{
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download file {key}");
|
||||
format!("[ATTACHMENT:{file_name} | download failed]")
|
||||
}
|
||||
};
|
||||
(content, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: file message missing file_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||
return messages;
|
||||
@@ -1305,7 +1609,7 @@ impl LarkChannel {
|
||||
}
|
||||
|
||||
// Parse event messages
|
||||
let messages = state.channel.parse_event_payload(&payload);
|
||||
let messages = state.channel.parse_event_payload(&payload).await;
|
||||
if !messages.is_empty() {
|
||||
if let Some(message_id) = payload
|
||||
.pointer("/event/message/message_id")
|
||||
@@ -1556,6 +1860,72 @@ fn detect_lark_ack_locale(
|
||||
detect_locale_from_text(fallback_text).unwrap_or(LarkAckLocale::En)
|
||||
}
|
||||
|
||||
/// Detect image MIME type from magic bytes, falling back to Content-Type header.
|
||||
fn lark_detect_image_mime(content_type: Option<&str>, bytes: &[u8]) -> Option<String> {
|
||||
if bytes.len() >= 8 && bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
|
||||
return Some("image/png".to_string());
|
||||
}
|
||||
if bytes.len() >= 3 && bytes.starts_with(&[0xff, 0xd8, 0xff]) {
|
||||
return Some("image/jpeg".to_string());
|
||||
}
|
||||
if bytes.len() >= 6 && (bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a")) {
|
||||
return Some("image/gif".to_string());
|
||||
}
|
||||
if bytes.len() >= 12 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WEBP" {
|
||||
return Some("image/webp".to_string());
|
||||
}
|
||||
if bytes.len() >= 2 && bytes.starts_with(b"BM") {
|
||||
return Some("image/bmp".to_string());
|
||||
}
|
||||
content_type
|
||||
.and_then(|ct| ct.split(';').next())
|
||||
.map(|ct| ct.trim().to_lowercase())
|
||||
.filter(|ct| ct.starts_with("image/"))
|
||||
}
|
||||
|
||||
/// Check if a filename looks like a text file based on extension.
|
||||
fn lark_is_text_filename(name: &str) -> bool {
|
||||
let ext = name.rsplit('.').next().unwrap_or("").to_ascii_lowercase();
|
||||
matches!(
|
||||
ext.as_str(),
|
||||
"txt"
|
||||
| "md"
|
||||
| "rs"
|
||||
| "py"
|
||||
| "js"
|
||||
| "ts"
|
||||
| "tsx"
|
||||
| "jsx"
|
||||
| "java"
|
||||
| "c"
|
||||
| "h"
|
||||
| "cpp"
|
||||
| "hpp"
|
||||
| "go"
|
||||
| "rb"
|
||||
| "sh"
|
||||
| "bash"
|
||||
| "zsh"
|
||||
| "toml"
|
||||
| "yaml"
|
||||
| "yml"
|
||||
| "json"
|
||||
| "xml"
|
||||
| "html"
|
||||
| "css"
|
||||
| "sql"
|
||||
| "csv"
|
||||
| "tsv"
|
||||
| "log"
|
||||
| "cfg"
|
||||
| "ini"
|
||||
| "conf"
|
||||
| "env"
|
||||
| "dockerfile"
|
||||
| "makefile"
|
||||
)
|
||||
}
|
||||
|
||||
fn random_lark_ack_reaction(
|
||||
payload: Option<&serde_json::Value>,
|
||||
fallback_text: &str,
|
||||
@@ -1892,8 +2262,8 @@ mod tests {
|
||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_challenge() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_challenge() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"challenge": "abc123",
|
||||
@@ -1901,12 +2271,12 @@ mod tests {
|
||||
"type": "url_verification"
|
||||
});
|
||||
// Challenge payloads should not produce messages
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_valid_text_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_valid_text_message() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
@@ -1927,7 +2297,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
|
||||
assert_eq!(msgs[0].sender, "oc_chat123");
|
||||
@@ -1935,8 +2305,8 @@ mod tests {
|
||||
assert_eq!(msgs[0].timestamp, 1_699_999_999);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unauthorized_user() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unauthorized_user() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -1951,12 +2321,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_non_text_message_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unsupported_message_type_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "sticker",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_image_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -1977,12 +2373,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_empty_text_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_file_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "file",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_empty_text_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2003,24 +2425,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_wrong_event_type() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_wrong_event_type() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.chat.disbanded_v1" },
|
||||
"event": {}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_sender() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_sender() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2040,12 +2462,12 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unicode_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unicode_message() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2067,24 +2489,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello world 🌍");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_event() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_event() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" }
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_invalid_content_json() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_invalid_content_json() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2105,7 +2527,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
@@ -2237,8 +2659,8 @@ mod tests {
|
||||
assert_eq!(ch.name(), "feishu");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_fallback_sender_to_open_id() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_fallback_sender_to_open_id() {
|
||||
// When chat_id is missing, sender should fall back to open_id
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
@@ -2260,13 +2682,13 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "ou_user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2292,7 +2714,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).is_empty());
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).await.is_empty());
|
||||
|
||||
let wrong_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2307,7 +2729,10 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&wrong_mention_payload).is_empty());
|
||||
assert!(ch
|
||||
.parse_event_payload(&wrong_mention_payload)
|
||||
.await
|
||||
.is_empty());
|
||||
|
||||
let bot_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2322,11 +2747,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2353,11 +2778,11 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
let ch = LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
"secret".into(),
|
||||
@@ -2381,7 +2806,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2409,6 +2834,69 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_image_download_url_matches_region() {
|
||||
let ch = make_channel();
|
||||
assert_eq!(
|
||||
ch.image_download_url("img_abc123"),
|
||||
"https://open.larksuite.com/open-apis/im/v1/images/img_abc123"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_file_download_url_matches_region() {
|
||||
let ch = make_channel();
|
||||
assert_eq!(
|
||||
ch.file_download_url("om_msg123", "file_abc"),
|
||||
"https://open.larksuite.com/open-apis/im/v1/messages/om_msg123/resources/file_abc?type=file"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_detect_image_mime_from_magic_bytes() {
|
||||
let png = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
|
||||
assert_eq!(
|
||||
lark_detect_image_mime(None, &png).as_deref(),
|
||||
Some("image/png")
|
||||
);
|
||||
|
||||
let jpeg = [0xff, 0xd8, 0xff, 0xe0];
|
||||
assert_eq!(
|
||||
lark_detect_image_mime(None, &jpeg).as_deref(),
|
||||
Some("image/jpeg")
|
||||
);
|
||||
|
||||
let gif = b"GIF89a...";
|
||||
assert_eq!(
|
||||
lark_detect_image_mime(None, gif).as_deref(),
|
||||
Some("image/gif")
|
||||
);
|
||||
|
||||
// Unknown bytes should fall back to content-type header
|
||||
let unknown = [0x00, 0x01, 0x02];
|
||||
assert_eq!(
|
||||
lark_detect_image_mime(Some("image/webp"), &unknown).as_deref(),
|
||||
Some("image/webp")
|
||||
);
|
||||
|
||||
// Non-image content-type should be rejected
|
||||
assert_eq!(lark_detect_image_mime(Some("text/html"), &unknown), None);
|
||||
|
||||
// No info at all should return None
|
||||
assert_eq!(lark_detect_image_mime(None, &unknown), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_is_text_filename_recognizes_common_extensions() {
|
||||
assert!(lark_is_text_filename("script.py"));
|
||||
assert!(lark_is_text_filename("config.toml"));
|
||||
assert!(lark_is_text_filename("data.csv"));
|
||||
assert!(lark_is_text_filename("README.md"));
|
||||
assert!(!lark_is_text_filename("image.png"));
|
||||
assert!(!lark_is_text_filename("archive.zip"));
|
||||
assert!(!lark_is_text_filename("binary.exe"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_reaction_locale_explicit_language_tags() {
|
||||
assert_eq!(map_locale_tag("zh-CN"), Some(LarkAckLocale::ZhCn));
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
//! Link enricher: auto-detects URLs in inbound messages, fetches their content,
|
||||
//! and prepends summaries so the agent has link context without explicit tool calls.
|
||||
|
||||
use regex::Regex;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for the link enricher pipeline stage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinkEnricherConfig {
|
||||
pub enabled: bool,
|
||||
pub max_links: usize,
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for LinkEnricherConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// URL regex: matches http:// and https:// URLs, stopping at whitespace, angle
|
||||
/// brackets, or double-quotes.
|
||||
static URL_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r#"https?://[^\s<>"']+"#).expect("URL regex must compile"));
|
||||
|
||||
/// Extract URLs from message text, returning up to `max` unique URLs.
|
||||
pub fn extract_urls(text: &str, max: usize) -> Vec<String> {
|
||||
let mut seen = Vec::new();
|
||||
for m in URL_RE.find_iter(text) {
|
||||
let url = m.as_str().to_string();
|
||||
if !seen.contains(&url) {
|
||||
seen.push(url);
|
||||
if seen.len() >= max {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
/// Returns `true` if the URL points to a private/local address that should be
|
||||
/// blocked for SSRF protection.
|
||||
pub fn is_ssrf_target(url: &str) -> bool {
|
||||
let host = match extract_host(url) {
|
||||
Some(h) => h,
|
||||
None => return true, // unparseable URLs are rejected
|
||||
};
|
||||
|
||||
// Check hostname-based locals
|
||||
if host == "localhost"
|
||||
|| host.ends_with(".localhost")
|
||||
|| host.ends_with(".local")
|
||||
|| host == "local"
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check IP-based private ranges
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return is_private_ip(ip);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Extract the host portion from a URL string.
|
||||
fn extract_host(url: &str) -> Option<String> {
|
||||
let rest = url
|
||||
.strip_prefix("https://")
|
||||
.or_else(|| url.strip_prefix("http://"))?;
|
||||
let authority = rest.split(['/', '?', '#']).next()?;
|
||||
if authority.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Strip port
|
||||
let host = if authority.starts_with('[') {
|
||||
// IPv6 in brackets — reject for simplicity
|
||||
return None;
|
||||
} else {
|
||||
authority.split(':').next().unwrap_or(authority)
|
||||
};
|
||||
Some(host.to_lowercase())
|
||||
}
|
||||
|
||||
/// Check if an IP address falls within private/reserved ranges.
|
||||
fn is_private_ip(ip: IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
v4.is_loopback() // 127.0.0.0/8
|
||||
|| v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
|
||||
|| v4.is_link_local() // 169.254.0.0/16
|
||||
|| v4.is_unspecified() // 0.0.0.0
|
||||
|| v4.is_broadcast() // 255.255.255.255
|
||||
|| v4.is_multicast() // 224.0.0.0/4
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
v6.is_loopback() // ::1
|
||||
|| v6.is_unspecified() // ::
|
||||
|| v6.is_multicast()
|
||||
// Check for IPv4-mapped IPv6 addresses
|
||||
|| v6.to_ipv4_mapped().is_some_and(|v4| {
|
||||
v4.is_loopback()
|
||||
|| v4.is_private()
|
||||
|| v4.is_link_local()
|
||||
|| v4.is_unspecified()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the `<title>` tag content from HTML.
|
||||
pub fn extract_title(html: &str) -> Option<String> {
|
||||
// Case-insensitive search for <title>...</title>
|
||||
let lower = html.to_lowercase();
|
||||
let start = lower.find("<title")? + "<title".len();
|
||||
// Skip attributes if any (e.g. <title lang="en">)
|
||||
let start = lower[start..].find('>')? + start + 1;
|
||||
let end = lower[start..].find("</title")? + start;
|
||||
let title = lower[start..end].trim().to_string();
|
||||
if title.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(html_entity_decode_basic(&title))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the first `max_chars` of visible body text from HTML.
|
||||
pub fn extract_body_text(html: &str, max_chars: usize) -> String {
|
||||
let text = nanohtml2text::html2text(html);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.len() <= max_chars {
|
||||
trimmed.to_string()
|
||||
} else {
|
||||
let mut result: String = trimmed.chars().take(max_chars).collect();
|
||||
result.push_str("...");
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Basic HTML entity decoding for title content.
|
||||
fn html_entity_decode_basic(s: &str) -> String {
|
||||
s.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace(""", "\"")
|
||||
.replace("'", "'")
|
||||
.replace("'", "'")
|
||||
}
|
||||
|
||||
/// Summary of a fetched link.
|
||||
struct LinkSummary {
|
||||
title: String,
|
||||
snippet: String,
|
||||
}
|
||||
|
||||
/// Fetch a single URL and extract a summary. Returns `None` on any failure.
|
||||
async fn fetch_link_summary(url: &str, timeout_secs: u64) -> Option<LinkSummary> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.redirect(reqwest::redirect::Policy::limited(5))
|
||||
.user_agent("ZeroClaw/0.1 (link-enricher)")
|
||||
.build()
|
||||
.ok()?;
|
||||
|
||||
let response = client.get(url).send().await.ok()?;
|
||||
if !response.status().is_success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Only process text/html responses
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_lowercase();
|
||||
|
||||
if !content_type.contains("text/html") && !content_type.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Read up to 256KB to extract title and snippet
|
||||
let max_bytes: usize = 256 * 1024;
|
||||
let bytes = response.bytes().await.ok()?;
|
||||
let body = if bytes.len() > max_bytes {
|
||||
String::from_utf8_lossy(&bytes[..max_bytes]).into_owned()
|
||||
} else {
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
};
|
||||
|
||||
let title = extract_title(&body).unwrap_or_else(|| "Untitled".to_string());
|
||||
let snippet = extract_body_text(&body, 200);
|
||||
|
||||
Some(LinkSummary { title, snippet })
|
||||
}
|
||||
|
||||
/// Enrich a message by prepending link summaries for any URLs found in the text.
|
||||
///
|
||||
/// This is the main entry point called from the channel message processing pipeline.
|
||||
/// If the enricher is disabled or no URLs are found, the original message is returned
|
||||
/// unchanged.
|
||||
pub async fn enrich_message(content: &str, config: &LinkEnricherConfig) -> String {
|
||||
if !config.enabled || config.max_links == 0 {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let urls = extract_urls(content, config.max_links);
|
||||
if urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
// Filter out SSRF targets
|
||||
let safe_urls: Vec<&str> = urls
|
||||
.iter()
|
||||
.filter(|u| !is_ssrf_target(u))
|
||||
.map(|u| u.as_str())
|
||||
.collect();
|
||||
if safe_urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let mut enrichments = Vec::new();
|
||||
for url in safe_urls {
|
||||
match fetch_link_summary(url, config.timeout_secs).await {
|
||||
Some(summary) => {
|
||||
enrichments.push(format!("[Link: {} — {}]", summary.title, summary.snippet));
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(url, "Link enricher: failed to fetch or extract summary");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if enrichments.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let prefix = enrichments.join("\n");
|
||||
format!("{prefix}\n{content}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── URL extraction ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_urls_finds_http_and_https() {
|
||||
let text = "Check https://example.com and http://test.org/page for info";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com", "http://test.org/page",]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_respects_max() {
|
||||
let text = "https://a.com https://b.com https://c.com https://d.com";
|
||||
let urls = extract_urls(text, 2);
|
||||
assert_eq!(urls.len(), 2);
|
||||
assert_eq!(urls[0], "https://a.com");
|
||||
assert_eq!(urls[1], "https://b.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_deduplicates() {
|
||||
let text = "Visit https://example.com and https://example.com again";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_handles_no_urls() {
|
||||
let text = "Just a normal message without links";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert!(urls.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_stops_at_angle_brackets() {
|
||||
let text = "Link: <https://example.com/path> done";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com/path"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_stops_at_quotes() {
|
||||
let text = r#"href="https://example.com/page" end"#;
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com/page"]);
|
||||
}
|
||||
|
||||
// ── SSRF protection ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_localhost() {
|
||||
assert!(is_ssrf_target("http://localhost/admin"));
|
||||
assert!(is_ssrf_target("https://localhost:8080/api"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_loopback_ip() {
|
||||
assert!(is_ssrf_target("http://127.0.0.1/secret"));
|
||||
assert!(is_ssrf_target("http://127.0.0.2:9090"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_10_network() {
|
||||
assert!(is_ssrf_target("http://10.0.0.1/internal"));
|
||||
assert!(is_ssrf_target("http://10.255.255.255"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_172_network() {
|
||||
assert!(is_ssrf_target("http://172.16.0.1/admin"));
|
||||
assert!(is_ssrf_target("http://172.31.255.255"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_192_168_network() {
|
||||
assert!(is_ssrf_target("http://192.168.1.1/router"));
|
||||
assert!(is_ssrf_target("http://192.168.0.100:3000"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_link_local() {
|
||||
assert!(is_ssrf_target("http://169.254.0.1/metadata"));
|
||||
assert!(is_ssrf_target("http://169.254.169.254/latest"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_ipv6_loopback() {
|
||||
// IPv6 in brackets is rejected by extract_host
|
||||
assert!(is_ssrf_target("http://[::1]/admin"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_dot_local() {
|
||||
assert!(is_ssrf_target("http://myhost.local/api"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_allows_public_urls() {
|
||||
assert!(!is_ssrf_target("https://example.com/page"));
|
||||
assert!(!is_ssrf_target("https://www.google.com"));
|
||||
assert!(!is_ssrf_target("http://93.184.216.34/resource"));
|
||||
}
|
||||
|
||||
// ── Title extraction ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_title_basic() {
|
||||
let html = "<html><head><title>My Page Title</title></head><body>Hello</body></html>";
|
||||
assert_eq!(extract_title(html), Some("my page title".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_with_entities() {
|
||||
let html = "<title>Tom & Jerry's Page</title>";
|
||||
assert_eq!(extract_title(html), Some("tom & jerry's page".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_case_insensitive() {
|
||||
let html = "<HTML><HEAD><TITLE>Upper Case</TITLE></HEAD></HTML>";
|
||||
assert_eq!(extract_title(html), Some("upper case".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_multibyte_chars_no_panic() {
|
||||
// İ (U+0130) lowercases to 2 chars, changing byte length.
|
||||
// This must not panic or produce wrong offsets.
|
||||
let html = "<title>İstanbul Guide</title>";
|
||||
let result = extract_title(html);
|
||||
assert!(result.is_some());
|
||||
let title = result.unwrap();
|
||||
assert!(title.contains("stanbul"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_missing() {
|
||||
let html = "<html><body>No title here</body></html>";
|
||||
assert_eq!(extract_title(html), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_empty() {
|
||||
let html = "<title> </title>";
|
||||
assert_eq!(extract_title(html), None);
|
||||
}
|
||||
|
||||
// ── Body text extraction ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_body_text_strips_html() {
|
||||
let html = "<html><body><h1>Header</h1><p>Some content here</p></body></html>";
|
||||
let text = extract_body_text(html, 200);
|
||||
assert!(text.contains("Header"));
|
||||
assert!(text.contains("Some content"));
|
||||
assert!(!text.contains("<h1>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_body_text_truncates() {
|
||||
let html = "<p>A very long paragraph that should be truncated to fit within the limit.</p>";
|
||||
let text = extract_body_text(html, 20);
|
||||
assert!(text.len() <= 25); // 20 chars + "..."
|
||||
assert!(text.ends_with("..."));
|
||||
}
|
||||
|
||||
// ── Config toggle ───────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_disabled_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: false,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "Check https://example.com for details";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_no_urls_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: true,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "No links in this message";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_ssrf_urls_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: true,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "Try http://127.0.0.1/admin and http://192.168.1.1/router";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_is_disabled() {
|
||||
let config = LinkEnricherConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.max_links, 3);
|
||||
assert_eq!(config.timeout_secs, 10);
|
||||
}
|
||||
}
|
||||
+210
-7
@@ -8,6 +8,7 @@ use matrix_sdk::{
|
||||
events::reaction::ReactionEventContent,
|
||||
events::receipt::ReceiptThread,
|
||||
events::relation::{Annotation, Thread},
|
||||
events::room::member::StrippedRoomMemberEvent,
|
||||
events::room::message::{
|
||||
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||
},
|
||||
@@ -32,6 +33,7 @@ pub struct MatrixChannel {
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
session_owner_hint: Option<String>,
|
||||
session_device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
@@ -48,6 +50,7 @@ impl std::fmt::Debug for MatrixChannel {
|
||||
.field("homeserver", &self.homeserver)
|
||||
.field("room_id", &self.room_id)
|
||||
.field("allowed_users", &self.allowed_users)
|
||||
.field("allowed_rooms", &self.allowed_rooms)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -121,7 +124,16 @@ impl MatrixChannel {
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint(homeserver, access_token, room_id, allowed_users, None, None)
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_with_session_hint(
|
||||
@@ -132,11 +144,12 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint_and_zeroclaw_dir(
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
None,
|
||||
@@ -151,6 +164,28 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
zeroclaw_dir,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_full(
|
||||
homeserver: String,
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
let homeserver = homeserver.trim_end_matches('/').to_string();
|
||||
let access_token = access_token.trim().to_string();
|
||||
@@ -160,12 +195,18 @@ impl MatrixChannel {
|
||||
.map(|user| user.trim().to_string())
|
||||
.filter(|user| !user.is_empty())
|
||||
.collect();
|
||||
let allowed_rooms = allowed_rooms
|
||||
.into_iter()
|
||||
.map(|room| room.trim().to_string())
|
||||
.filter(|room| !room.is_empty())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms,
|
||||
session_owner_hint: Self::normalize_optional_field(owner_hint),
|
||||
session_device_id_hint: Self::normalize_optional_field(device_id_hint),
|
||||
zeroclaw_dir,
|
||||
@@ -220,6 +261,21 @@ impl MatrixChannel {
|
||||
allowed_users.iter().any(|u| u.eq_ignore_ascii_case(sender))
|
||||
}
|
||||
|
||||
/// Check whether a room (by its canonical ID) is in the allowed_rooms list.
|
||||
/// If allowed_rooms is empty, all rooms are allowed.
|
||||
fn is_room_allowed_static(allowed_rooms: &[String], room_id: &str) -> bool {
|
||||
if allowed_rooms.is_empty() {
|
||||
return true;
|
||||
}
|
||||
allowed_rooms
|
||||
.iter()
|
||||
.any(|r| r.eq_ignore_ascii_case(room_id))
|
||||
}
|
||||
|
||||
fn is_room_allowed(&self, room_id: &str) -> bool {
|
||||
Self::is_room_allowed_static(&self.allowed_rooms, room_id)
|
||||
}
|
||||
|
||||
fn is_supported_message_type(msgtype: &str) -> bool {
|
||||
matches!(msgtype, "m.text" | "m.notice")
|
||||
}
|
||||
@@ -228,6 +284,10 @@ impl MatrixChannel {
|
||||
!body.trim().is_empty()
|
||||
}
|
||||
|
||||
fn room_matches_target(target_room_id: &str, incoming_room_id: &str) -> bool {
|
||||
target_room_id == incoming_room_id
|
||||
}
|
||||
|
||||
fn cache_event_id(
|
||||
event_id: &str,
|
||||
recent_order: &mut std::collections::VecDeque<String>,
|
||||
@@ -526,8 +586,9 @@ impl MatrixChannel {
|
||||
if client.encryption().backups().are_enabled().await {
|
||||
tracing::info!("Matrix room-key backup is enabled for this device.");
|
||||
} else {
|
||||
client.encryption().backups().disable().await;
|
||||
tracing::warn!(
|
||||
"Matrix room-key backup is not enabled for this device; `matrix_sdk_crypto::backups` warnings about missing backup keys may appear until recovery is configured."
|
||||
"Matrix room-key backup is not enabled for this device; automatic backup attempts have been disabled to suppress recurring warnings. To enable backups, configure server-side key backup and recovery for this device."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -697,6 +758,7 @@ impl Channel for MatrixChannel {
|
||||
let target_room_for_handler = target_room.clone();
|
||||
let my_user_id_for_handler = my_user_id.clone();
|
||||
let allowed_users_for_handler = self.allowed_users.clone();
|
||||
let allowed_rooms_for_handler = self.allowed_rooms.clone();
|
||||
let dedupe_for_handler = Arc::clone(&recent_event_cache);
|
||||
let homeserver_for_handler = self.homeserver.clone();
|
||||
let access_token_for_handler = self.access_token.clone();
|
||||
@@ -704,18 +766,29 @@ impl Channel for MatrixChannel {
|
||||
|
||||
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||
let tx = tx_handler.clone();
|
||||
let _target_room = target_room_for_handler.clone();
|
||||
let target_room = target_room_for_handler.clone();
|
||||
let my_user_id = my_user_id_for_handler.clone();
|
||||
let allowed_users = allowed_users_for_handler.clone();
|
||||
let allowed_rooms = allowed_rooms_for_handler.clone();
|
||||
let dedupe = Arc::clone(&dedupe_for_handler);
|
||||
let homeserver = homeserver_for_handler.clone();
|
||||
let access_token = access_token_for_handler.clone();
|
||||
let voice_mode = Arc::clone(&voice_mode_for_handler);
|
||||
|
||||
async move {
|
||||
if false
|
||||
/* multi-room: room_id filter disabled */
|
||||
{
|
||||
if !MatrixChannel::room_matches_target(
|
||||
target_room.as_str(),
|
||||
room.room_id().as_str(),
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Room allowlist: skip messages from rooms not in the configured list
|
||||
if !MatrixChannel::is_room_allowed_static(&allowed_rooms, room.room_id().as_ref()) {
|
||||
tracing::debug!(
|
||||
"Matrix: ignoring message from room {} (not in allowed_rooms)",
|
||||
room.room_id()
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -907,6 +980,45 @@ impl Channel for MatrixChannel {
|
||||
}
|
||||
});
|
||||
|
||||
// Invite handler: auto-accept invites for allowed rooms, auto-reject others
|
||||
let allowed_rooms_for_invite = self.allowed_rooms.clone();
|
||||
client.add_event_handler(move |event: StrippedRoomMemberEvent, room: Room| {
|
||||
let allowed_rooms = allowed_rooms_for_invite.clone();
|
||||
async move {
|
||||
// Only process invite events targeting us
|
||||
if event.content.membership
|
||||
!= matrix_sdk::ruma::events::room::member::MembershipState::Invite
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let room_id_str = room.room_id().to_string();
|
||||
|
||||
if MatrixChannel::is_room_allowed_static(&allowed_rooms, &room_id_str) {
|
||||
// Room is allowed (or no allowlist configured): auto-accept
|
||||
tracing::info!(
|
||||
"Matrix: auto-accepting invite for allowed room {}",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.join().await {
|
||||
tracing::warn!("Matrix: failed to auto-join room {}: {error}", room_id_str);
|
||||
}
|
||||
} else {
|
||||
// Room is NOT in allowlist: auto-reject
|
||||
tracing::info!(
|
||||
"Matrix: auto-rejecting invite for room {} (not in allowed_rooms)",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.leave().await {
|
||||
tracing::warn!(
|
||||
"Matrix: failed to reject invite for room {}: {error}",
|
||||
room_id_str
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_secs(30));
|
||||
client
|
||||
.sync_with_result_callback(sync_settings, |sync_result| {
|
||||
@@ -1294,6 +1406,22 @@ mod tests {
|
||||
assert_eq!(value["room"]["timeline"]["limit"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_matches_configured_room() {
|
||||
assert!(MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!ops:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_rejects_other_rooms() {
|
||||
assert!(!MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!other:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_id_cache_deduplicates_and_evicts_old_entries() {
|
||||
let mut recent_order = std::collections::VecDeque::new();
|
||||
@@ -1549,4 +1677,79 @@ mod tests {
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.rooms.join.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowed_rooms_permits_all() {
|
||||
let ch = make_channel();
|
||||
assert!(ch.is_room_allowed("!any:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!other:evil.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_filters_by_id() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec!["!allowed:matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!allowed:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!forbidden:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_supports_aliases() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec![
|
||||
"#ops:matrix.org".to_string(),
|
||||
"!direct:matrix.org".to_string(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!direct:matrix.org"));
|
||||
assert!(ch.is_room_allowed("#ops:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!other:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_case_insensitive() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec!["!Room:Matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!ROOM:MATRIX.ORG"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_trims_whitespace() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec![" !room:matrix.org ".to_string(), " ".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(ch.allowed_rooms.len(), 1);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
}
|
||||
}
|
||||
|
||||
+272
-50
@@ -19,11 +19,14 @@ pub mod clawdtalk;
|
||||
pub mod cli;
|
||||
pub mod dingtalk;
|
||||
pub mod discord;
|
||||
pub mod discord_history;
|
||||
pub mod email_channel;
|
||||
pub mod gmail_push;
|
||||
pub mod imessage;
|
||||
pub mod irc;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
pub mod lark;
|
||||
pub mod link_enricher;
|
||||
pub mod linq;
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
pub mod matrix;
|
||||
@@ -45,6 +48,8 @@ pub mod traits;
|
||||
pub mod transcription;
|
||||
pub mod tts;
|
||||
pub mod twitter;
|
||||
#[cfg(feature = "voice-wake")]
|
||||
pub mod voice_wake;
|
||||
pub mod wati;
|
||||
pub mod webhook;
|
||||
pub mod wecom;
|
||||
@@ -59,7 +64,9 @@ pub use clawdtalk::{ClawdTalkChannel, ClawdTalkConfig};
|
||||
pub use cli::CliChannel;
|
||||
pub use dingtalk::DingTalkChannel;
|
||||
pub use discord::DiscordChannel;
|
||||
pub use discord_history::DiscordHistoryChannel;
|
||||
pub use email_channel::EmailChannel;
|
||||
pub use gmail_push::GmailPushChannel;
|
||||
pub use imessage::IMessageChannel;
|
||||
pub use irc::IrcChannel;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
@@ -82,6 +89,8 @@ pub use traits::{Channel, SendMessage};
|
||||
#[allow(unused_imports)]
|
||||
pub use tts::{TtsManager, TtsProvider};
|
||||
pub use twitter::TwitterChannel;
|
||||
#[cfg(feature = "voice-wake")]
|
||||
pub use voice_wake::VoiceWakeChannel;
|
||||
pub use wati::WatiChannel;
|
||||
pub use webhook::WebhookChannel;
|
||||
pub use wecom::WeComChannel;
|
||||
@@ -222,9 +231,21 @@ fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||||
fn channel_message_timeout_budget_secs(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
) -> u64 {
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs,
|
||||
max_tool_iterations,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP,
|
||||
)
|
||||
}
|
||||
|
||||
fn channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
scale_cap: u64,
|
||||
) -> u64 {
|
||||
let iterations = max_tool_iterations.max(1) as u64;
|
||||
let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let scale = iterations.min(scale_cap);
|
||||
message_timeout_secs.saturating_mul(scale)
|
||||
}
|
||||
|
||||
@@ -362,6 +383,7 @@ struct ChannelRuntimeContext {
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
cost_tracking: Option<ChannelCostTrackingState>,
|
||||
pacing: crate::config::PacingConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -2045,6 +2067,25 @@ async fn process_channel_message(
|
||||
msg
|
||||
};
|
||||
|
||||
// ── Link enricher: prepend URL summaries before agent sees the message ──
|
||||
let le_config = &ctx.prompt_config.link_enricher;
|
||||
if le_config.enabled {
|
||||
let enricher_cfg = link_enricher::LinkEnricherConfig {
|
||||
enabled: le_config.enabled,
|
||||
max_links: le_config.max_links,
|
||||
timeout_secs: le_config.timeout_secs,
|
||||
};
|
||||
let enriched = link_enricher::enrich_message(&msg.content, &enricher_cfg).await;
|
||||
if enriched != msg.content {
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Link enricher: prepended URL summaries to message"
|
||||
);
|
||||
msg.content = enriched;
|
||||
}
|
||||
}
|
||||
|
||||
let target_channel = ctx
|
||||
.channels_by_name
|
||||
.get(&msg.channel)
|
||||
@@ -2402,8 +2443,15 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let model_switch_callback = get_model_switch_state();
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let scale_cap = ctx
|
||||
.pacing
|
||||
.message_timeout_scale_max
|
||||
.unwrap_or(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let timeout_budget_secs = channel_message_timeout_budget_secs_with_cap(
|
||||
ctx.message_timeout_secs,
|
||||
ctx.max_tool_iterations,
|
||||
scale_cap,
|
||||
);
|
||||
let cost_tracking_context = ctx.cost_tracking.clone().map(|state| {
|
||||
crate::agent::loop_::ToolLoopCostTrackingContext::new(state.tracker, state.prices)
|
||||
});
|
||||
@@ -2445,6 +2493,7 @@ async fn process_channel_message(
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&ctx.pacing,
|
||||
),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
@@ -3107,9 +3156,12 @@ pub fn build_system_prompt_with_mode(
|
||||
Some(&autonomy_cfg),
|
||||
native_tools,
|
||||
skills_prompt_mode,
|
||||
false,
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
workspace_dir: &std::path::Path,
|
||||
model_name: &str,
|
||||
@@ -3120,6 +3172,8 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
autonomy_config: Option<&crate::config::AutonomyConfig>,
|
||||
native_tools: bool,
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode,
|
||||
compact_context: bool,
|
||||
max_system_prompt_chars: usize,
|
||||
) -> String {
|
||||
use std::fmt::Write;
|
||||
let mut prompt = String::with_capacity(8192);
|
||||
@@ -3146,11 +3200,19 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
// ── 1. Tooling ──────────────────────────────────────────────
|
||||
if !tools.is_empty() {
|
||||
prompt.push_str("## Tools\n\n");
|
||||
prompt.push_str("You have access to the following tools:\n\n");
|
||||
for (name, desc) in tools {
|
||||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||||
if compact_context {
|
||||
// Compact mode: tool names only, no descriptions/schemas
|
||||
prompt.push_str("Available tools: ");
|
||||
let names: Vec<&str> = tools.iter().map(|(name, _)| *name).collect();
|
||||
prompt.push_str(&names.join(", "));
|
||||
prompt.push_str("\n\n");
|
||||
} else {
|
||||
prompt.push_str("You have access to the following tools:\n\n");
|
||||
for (name, desc) in tools {
|
||||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
// ── 1b. Hardware (when gpio/arduino tools present) ───────────
|
||||
@@ -3294,11 +3356,13 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
std::env::consts::OS,
|
||||
);
|
||||
|
||||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
|
||||
// ── 8. Channel Capabilities (skipped in compact_context mode) ──
|
||||
if !compact_context {
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||||
prompt
|
||||
.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
|
||||
Some(crate::security::AutonomyLevel::Full) => {
|
||||
"- If the runtime policy already allows a tool, use it directly; do not ask the user for extra approval.\n\
|
||||
- Never pretend you are waiting for a human approval click or confirmation when the runtime policy already permits the action.\n\
|
||||
@@ -3312,10 +3376,23 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
- If there is no approval path for this channel or the runtime blocks an action, explain that restriction directly instead of simulating an approval flow.\n"
|
||||
}
|
||||
});
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
|
||||
prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
|
||||
prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
|
||||
prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
|
||||
prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
|
||||
} // end if !compact_context (Channel Capabilities)
|
||||
|
||||
// ── 9. Truncation (max_system_prompt_chars budget) ──────────
|
||||
if max_system_prompt_chars > 0 && prompt.len() > max_system_prompt_chars {
|
||||
// Truncate on a char boundary, keeping the top portion (identity + safety).
|
||||
let mut end = max_system_prompt_chars;
|
||||
// Ensure we don't split a multi-byte UTF-8 character.
|
||||
while !prompt.is_char_boundary(end) && end > 0 {
|
||||
end -= 1;
|
||||
}
|
||||
prompt.truncate(end);
|
||||
prompt.push_str("\n\n[System prompt truncated to fit context budget]\n");
|
||||
}
|
||||
|
||||
if prompt.is_empty() {
|
||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
|
||||
@@ -3613,13 +3690,16 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
.discord
|
||||
.as_ref()
|
||||
.context("Discord channel is not configured")?;
|
||||
Ok(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)))
|
||||
Ok(Arc::new(
|
||||
DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
"slack" => {
|
||||
let sl = config
|
||||
@@ -3635,7 +3715,8 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
|
||||
@@ -3721,11 +3802,37 @@ fn collect_configured_channels(
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_proxy_url(dc.proxy_url.clone()),
|
||||
.with_proxy_url(dc.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref dh) = config.channels_config.discord_history {
|
||||
match crate::memory::SqliteMemory::new_named(&config.workspace_dir, "discord") {
|
||||
Ok(discord_mem) => {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Discord History",
|
||||
channel: Arc::new(
|
||||
DiscordHistoryChannel::new(
|
||||
dh.bot_token.clone(),
|
||||
dh.guild_id.clone(),
|
||||
dh.allowed_users.clone(),
|
||||
dh.channel_ids.clone(),
|
||||
Arc::new(discord_mem),
|
||||
dh.store_dms,
|
||||
dh.respond_to_dms,
|
||||
)
|
||||
.with_proxy_url(dh.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("discord_history: failed to open discord.db: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sl) = config.channels_config.slack {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Slack",
|
||||
@@ -3740,7 +3847,8 @@ fn collect_configured_channels(
|
||||
.with_thread_replies(sl.thread_replies.unwrap_or(true))
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(sl.proxy_url.clone()),
|
||||
.with_proxy_url(sl.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3773,11 +3881,12 @@ fn collect_configured_channels(
|
||||
if let Some(ref mx) = config.channels_config.matrix {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Matrix",
|
||||
channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir(
|
||||
channel: Arc::new(MatrixChannel::new_full(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
mx.allowed_rooms.clone(),
|
||||
mx.user_id.clone(),
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
@@ -3917,6 +4026,15 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref gp_cfg) = config.channels_config.gmail_push {
|
||||
if gp_cfg.enabled {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Gmail Push",
|
||||
channel: Arc::new(GmailPushChannel::new(gp_cfg.clone())),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "IRC",
|
||||
@@ -4093,6 +4211,17 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
if let Some(ref vw) = config.channels_config.voice_wake {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "VoiceWake",
|
||||
channel: Arc::new(VoiceWakeChannel::new(
|
||||
vw.clone(),
|
||||
config.transcription.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref wh) = config.channels_config.webhook {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Webhook",
|
||||
@@ -4243,22 +4372,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
};
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let (mut built_tools, delegate_handle_ch): (Vec<Box<dyn Tool>>, _) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&workspace,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
);
|
||||
let (mut built_tools, delegate_handle_ch, reaction_handle_ch) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&workspace,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
|
||||
@@ -4431,6 +4560,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
Some(&config.autonomy),
|
||||
native_tools,
|
||||
config.skills.prompt_injection_mode,
|
||||
config.agent.compact_context,
|
||||
config.agent.max_system_prompt_chars,
|
||||
);
|
||||
if !native_tools {
|
||||
system_prompt.push_str(&build_tool_instructions(
|
||||
@@ -4530,6 +4661,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
|
||||
.collect::<HashMap<_, _>>(),
|
||||
);
|
||||
|
||||
// Populate the reaction tool's channel map now that channels are initialized.
|
||||
if let Some(ref handle) = reaction_handle_ch {
|
||||
let mut map = handle.write();
|
||||
for (name, ch) in channels_by_name.as_ref() {
|
||||
map.insert(name.clone(), Arc::clone(ch));
|
||||
}
|
||||
}
|
||||
|
||||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||||
|
||||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||||
@@ -4641,6 +4781,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
tracker,
|
||||
prices: Arc::new(config.cost.prices.clone()),
|
||||
}),
|
||||
pacing: config.pacing.clone(),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4737,6 +4878,49 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_message_timeout_budget_with_custom_scale_cap() {
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 8, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 20, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 1),
|
||||
300
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_config_defaults_preserve_existing_behavior() {
|
||||
let pacing = crate::config::PacingConfig::default();
|
||||
assert!(pacing.step_timeout_secs.is_none());
|
||||
assert!(pacing.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(pacing.loop_ignore_tools.is_empty());
|
||||
assert!(pacing.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_message_timeout_scale_max_overrides_default_cap() {
|
||||
// Custom cap of 8 scales budget proportionally
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 8),
|
||||
300 * 8
|
||||
);
|
||||
// Default cap produces the standard behavior
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
300,
|
||||
10,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
),
|
||||
300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_window_overflow_error_detector_matches_known_messages() {
|
||||
let overflow_err = anyhow::anyhow!(
|
||||
@@ -4941,6 +5125,7 @@ mod tests {
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -5057,6 +5242,7 @@ mod tests {
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -5129,6 +5315,7 @@ mod tests {
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -5220,6 +5407,7 @@ mod tests {
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(
|
||||
@@ -5761,6 +5949,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5842,6 +6031,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5937,6 +6127,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6017,6 +6208,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6107,6 +6299,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6218,6 +6411,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6310,6 +6504,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6417,6 +6612,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6509,6 +6705,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6591,6 +6788,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6704,6 +6902,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: "2026-02-20T00:00:00Z".to_string(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}])
|
||||
}
|
||||
|
||||
@@ -6788,6 +6989,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -6890,6 +7092,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7007,6 +7210,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7121,6 +7325,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7217,6 +7422,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7297,6 +7503,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7519,9 +7726,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(prompt.contains("<instructions>"));
|
||||
assert!(prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
assert!(!prompt.contains("loaded on demand"));
|
||||
}
|
||||
|
||||
@@ -7564,10 +7771,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -7689,6 +7896,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
Some(&config),
|
||||
false,
|
||||
crate::config::SkillsPromptInjectionMode::Full,
|
||||
false,
|
||||
0,
|
||||
);
|
||||
|
||||
assert!(
|
||||
@@ -7718,6 +7927,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
Some(&config),
|
||||
false,
|
||||
crate::config::SkillsPromptInjectionMode::Full,
|
||||
false,
|
||||
0,
|
||||
);
|
||||
|
||||
assert!(
|
||||
@@ -8063,6 +8274,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8194,6 +8406,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8365,6 +8578,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8473,6 +8687,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9045,6 +9260,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -9132,6 +9348,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9294,6 +9511,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9405,6 +9623,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9508,6 +9727,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9631,6 +9851,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9892,6 +10113,7 @@ This is an example JSON object for profile settings."#;
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
|
||||
@@ -12,6 +12,8 @@ use chrono::{DateTime, Utc};
|
||||
pub struct SessionMetadata {
|
||||
/// Session key (e.g. `telegram_user123`).
|
||||
pub key: String,
|
||||
/// Optional human-readable name (e.g. `eyrie-commander-briefing`).
|
||||
pub name: Option<String>,
|
||||
/// When the session was first created.
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When the last message was appended.
|
||||
@@ -54,6 +56,7 @@ pub trait SessionBackend: Send + Sync {
|
||||
let messages = self.load(&key);
|
||||
SessionMetadata {
|
||||
key,
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: messages.len(),
|
||||
@@ -81,6 +84,16 @@ pub trait SessionBackend: Send + Sync {
|
||||
fn delete_session(&self, _session_key: &str) -> std::io::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Set or update the human-readable name for a session.
|
||||
fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the human-readable name for a session (if set).
|
||||
fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -91,6 +104,7 @@ mod tests {
|
||||
fn session_metadata_is_constructible() {
|
||||
let meta = SessionMetadata {
|
||||
key: "test".into(),
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: 5,
|
||||
|
||||
@@ -51,7 +51,8 @@ impl SqliteSessionBackend {
|
||||
session_key TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
last_activity TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0
|
||||
message_count INTEGER NOT NULL DEFAULT 0,
|
||||
name TEXT
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
|
||||
@@ -69,6 +70,18 @@ impl SqliteSessionBackend {
|
||||
)
|
||||
.context("Failed to initialize session schema")?;
|
||||
|
||||
// Migration: add name column to existing databases
|
||||
let has_name: bool = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) > 0 FROM pragma_table_info('session_metadata') WHERE name = 'name'",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap_or(false);
|
||||
if !has_name {
|
||||
let _ = conn.execute("ALTER TABLE session_metadata ADD COLUMN name TEXT", []);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
db_path,
|
||||
@@ -226,7 +239,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_key, created_at, last_activity, message_count
|
||||
"SELECT session_key, created_at, last_activity, message_count, name
|
||||
FROM session_metadata ORDER BY last_activity DESC",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
@@ -238,6 +251,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
let created_str: String = row.get(1)?;
|
||||
let activity_str: String = row.get(2)?;
|
||||
let count: i64 = row.get(3)?;
|
||||
let name: Option<String> = row.get(4)?;
|
||||
|
||||
let created = DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
@@ -249,6 +263,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(SessionMetadata {
|
||||
key,
|
||||
name,
|
||||
created_at: created,
|
||||
last_activity: activity,
|
||||
message_count: count as usize,
|
||||
@@ -321,6 +336,27 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
|
||||
let conn = self.conn.lock();
|
||||
let name_val = if name.is_empty() { None } else { Some(name) };
|
||||
conn.execute(
|
||||
"UPDATE session_metadata SET name = ?1 WHERE session_key = ?2",
|
||||
params![name_val, session_key],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
conn.query_row(
|
||||
"SELECT name FROM session_metadata WHERE session_key = ?1",
|
||||
params![session_key],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
|
||||
fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
let Some(keyword) = &query.keyword else {
|
||||
return self.list_sessions_with_metadata();
|
||||
@@ -357,14 +393,16 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
keys.iter()
|
||||
.filter_map(|key| {
|
||||
conn.query_row(
|
||||
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
|
||||
"SELECT created_at, last_activity, message_count, name FROM session_metadata WHERE session_key = ?1",
|
||||
params![key],
|
||||
|row| {
|
||||
let created_str: String = row.get(0)?;
|
||||
let activity_str: String = row.get(1)?;
|
||||
let count: i64 = row.get(2)?;
|
||||
let name: Option<String> = row.get(3)?;
|
||||
Ok(SessionMetadata {
|
||||
key: key.clone(),
|
||||
name,
|
||||
created_at: DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
@@ -555,4 +593,55 @@ mod tests {
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_persists() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "My Session").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert_eq!(meta[0].name.as_deref(), Some("My Session"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_updates_existing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "First").unwrap();
|
||||
backend.set_session_name("s1", "Second").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta[0].name.as_deref(), Some("Second"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sessions_without_name_return_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_name_clears_to_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "Named").unwrap();
|
||||
backend.set_session_name("s1", "").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ pub struct SlackChannel {
|
||||
active_assistant_thread: Mutex<HashMap<String, String>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio file attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@@ -125,6 +128,7 @@ impl SlackChannel {
|
||||
workspace_dir: None,
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +162,14 @@ impl SlackChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio file attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client_with_timeouts(
|
||||
"channel.slack",
|
||||
@@ -558,6 +570,13 @@ impl SlackChannel {
|
||||
.await
|
||||
.unwrap_or_else(|| raw_file.clone());
|
||||
|
||||
// Voice / audio transcription: if transcription is configured and the
|
||||
// file looks like an audio attachment, download and transcribe it.
|
||||
if Self::is_audio_file(&file) {
|
||||
if let Some(transcribed) = self.try_transcribe_audio_file(&file).await {
|
||||
return Some(transcribed);
|
||||
}
|
||||
}
|
||||
if Self::is_image_file(&file) {
|
||||
if let Some(marker) = self.fetch_image_marker(&file).await {
|
||||
return Some(marker);
|
||||
@@ -1449,6 +1468,106 @@ impl SlackChannel {
|
||||
.is_some_and(|ext| Self::mime_from_extension(ext).is_some())
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
|
||||
const AUDIO_EXTENSIONS: &[&str] = &[
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
];
|
||||
|
||||
/// Check whether a Slack file object looks like an audio attachment
|
||||
/// (voice memo, audio message, or uploaded audio file).
|
||||
fn is_audio_file(file: &serde_json::Value) -> bool {
|
||||
// Slack voice messages use subtype "slack_audio"
|
||||
if let Some(subtype) = file.get("subtype").and_then(|v| v.as_str()) {
|
||||
if subtype == "slack_audio" {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if Self::slack_file_mime(file)
|
||||
.as_deref()
|
||||
.is_some_and(|mime| mime.starts_with("audio/"))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(ft) = file
|
||||
.get("filetype")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|v| v.to_ascii_lowercase())
|
||||
{
|
||||
if Self::AUDIO_EXTENSIONS.contains(&ft.as_str()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Self::file_extension(&Self::slack_file_name(file))
|
||||
.as_deref()
|
||||
.is_some_and(|ext| Self::AUDIO_EXTENSIONS.contains(&ext))
|
||||
}
|
||||
|
||||
/// Download an audio file attachment and transcribe it using the configured
|
||||
/// transcription provider. Returns `None` if transcription is not configured
|
||||
/// or if the download/transcription fails.
|
||||
async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let config = self.transcription.as_ref()?;
|
||||
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let file_name = Self::slack_file_name(file);
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
let resp = self.fetch_slack_private_file(url).await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
tracing::warn!(
|
||||
"Slack voice file download failed for {} ({status})",
|
||||
redacted_url
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let audio_data = match resp.bytes().await {
|
||||
Ok(bytes) => bytes.to_vec(),
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice file read failed for {}: {e}", redacted_url);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Determine a filename with extension for the transcription API.
|
||||
let transcription_filename = if Self::file_extension(&file_name).is_some() {
|
||||
file_name.clone()
|
||||
} else {
|
||||
// Fall back to extension from mimetype or default to .ogg
|
||||
let mime_ext = Self::slack_file_mime(file)
|
||||
.and_then(|mime| mime.rsplit('/').next().map(|s| s.to_string()))
|
||||
.unwrap_or_else(|| "ogg".to_string());
|
||||
format!("voice.{mime_ext}")
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
tracing::info!("Slack voice transcription returned empty text, skipping");
|
||||
None
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Slack: transcribed voice file {} ({} chars)",
|
||||
file_name,
|
||||
trimmed.len()
|
||||
);
|
||||
Some(format!("[Voice] {trimmed}"))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice transcription failed for {}: {e}", file_name);
|
||||
Some(Self::format_attachment_summary(file))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_text_snippet(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
@@ -1140,6 +1140,11 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content = format!("{quote}\n\n{content}");
|
||||
}
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
content = format!("{attr}{content}");
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1263,6 +1268,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
format!("[Voice] {text}")
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1299,6 +1311,41 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
(username, sender_id, sender_identity)
|
||||
}
|
||||
|
||||
/// Build a forwarding attribution prefix from Telegram forward fields.
|
||||
///
|
||||
/// Returns `Some("[Forwarded from ...] ")` when the message is forwarded,
|
||||
/// `None` otherwise.
|
||||
fn format_forward_attribution(message: &serde_json::Value) -> Option<String> {
|
||||
if let Some(from_chat) = message.get("forward_from_chat") {
|
||||
// Forwarded from a channel or group
|
||||
let title = from_chat
|
||||
.get("title")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown channel");
|
||||
Some(format!("[Forwarded from channel: {title}] "))
|
||||
} else if let Some(from_user) = message.get("forward_from") {
|
||||
// Forwarded from a user (privacy allows identity)
|
||||
let label = from_user
|
||||
.get("username")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|u| format!("@{u}"))
|
||||
.or_else(|| {
|
||||
from_user
|
||||
.get("first_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(String::from)
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
Some(format!("[Forwarded from {label}] "))
|
||||
} else {
|
||||
// Forwarded from a user who hides their identity
|
||||
message
|
||||
.get("forward_sender_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|name| format!("[Forwarded from {name}] "))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract reply context from a Telegram `reply_to_message`, if present.
|
||||
fn extract_reply_context(&self, message: &serde_json::Value) -> Option<String> {
|
||||
let reply = message.get("reply_to_message")?;
|
||||
@@ -1420,6 +1467,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
// Exit voice-chat mode when user switches back to typing
|
||||
if let Ok(mut vc) = self.voice_chats.lock() {
|
||||
vc.remove(&reply_target);
|
||||
@@ -4871,4 +4925,153 @@ mod tests {
|
||||
TelegramChannel::new("token".into(), vec!["*".into()], false).with_ack_reactions(true);
|
||||
assert!(ch.ack_reactions);
|
||||
}
|
||||
|
||||
// ── Forwarded message tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_with_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 100,
|
||||
"message": {
|
||||
"message_id": 50,
|
||||
"text": "Check this out",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"first_name": "Bob",
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from @bob] Check this out");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_channel() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 101,
|
||||
"message": {
|
||||
"message_id": 51,
|
||||
"text": "Breaking news",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from_chat": {
|
||||
"id": -1_001_234_567_890_i64,
|
||||
"title": "Daily News",
|
||||
"username": "dailynews",
|
||||
"type": "channel"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("channel-forwarded message should parse");
|
||||
assert_eq!(
|
||||
msg.content,
|
||||
"[Forwarded from channel: Daily News] Breaking news"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_hidden_sender() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 102,
|
||||
"message": {
|
||||
"message_id": 52,
|
||||
"text": "Secret tip",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_sender_name": "Hidden User",
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("hidden-sender forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Hidden User] Secret tip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_non_forwarded_unaffected() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 103,
|
||||
"message": {
|
||||
"message_id": 53,
|
||||
"text": "Normal message",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 }
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("non-forwarded message should parse");
|
||||
assert_eq!(msg.content, "Normal message");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_no_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 104,
|
||||
"message": {
|
||||
"message_id": 54,
|
||||
"text": "Hello there",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 77,
|
||||
"first_name": "Charlie"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message without username should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Charlie] Hello there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarded_photo_attachment_has_attribution() {
|
||||
// Verify that format_forward_attribution produces correct prefix
|
||||
// for a photo message (the actual download is async, so we test the
|
||||
// helper directly with a photo-bearing message structure).
|
||||
let message = serde_json::json!({
|
||||
"message_id": 60,
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"photo": [
|
||||
{ "file_id": "abc123", "file_unique_id": "u1", "width": 320, "height": 240 }
|
||||
],
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
});
|
||||
|
||||
let attr =
|
||||
TelegramChannel::format_forward_attribution(&message).expect("should detect forward");
|
||||
assert_eq!(attr, "[Forwarded from @bob] ");
|
||||
|
||||
// Simulate what try_parse_attachment_message does after building content
|
||||
let photo_content = "[IMAGE:/tmp/photo.jpg]".to_string();
|
||||
let content = format!("{attr}{photo_content}");
|
||||
assert_eq!(content, "[Forwarded from @bob] [IMAGE:/tmp/photo.jpg]");
|
||||
}
|
||||
}
|
||||
|
||||
+130
-1
@@ -1,6 +1,7 @@
|
||||
//! Multi-provider Text-to-Speech (TTS) subsystem.
|
||||
//!
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, and Edge TTS (free, subprocess-based).
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, Edge TTS (free, subprocess-based),
|
||||
//! and Piper TTS (local GPU-accelerated, OpenAI-compatible endpoint).
|
||||
//! Provider selection is driven by [`TtsConfig`] in `config.toml`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
@@ -451,6 +452,80 @@ impl TtsProvider for EdgeTtsProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Piper TTS (local, OpenAI-compatible) ─────────────────────────
|
||||
|
||||
/// Piper TTS provider — local GPU-accelerated server with an OpenAI-compatible endpoint.
|
||||
pub struct PiperTtsProvider {
|
||||
client: reqwest::Client,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl PiperTtsProvider {
|
||||
/// Create a new Piper TTS provider pointing at the given API URL.
|
||||
pub fn new(api_url: &str) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.timeout(TTS_HTTP_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to build HTTP client for Piper TTS"),
|
||||
api_url: api_url.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TtsProvider for PiperTtsProvider {
|
||||
fn name(&self) -> &str {
|
||||
"piper"
|
||||
}
|
||||
|
||||
async fn synthesize(&self, text: &str, voice: &str) -> Result<Vec<u8>> {
|
||||
let body = serde_json::json!({
|
||||
"model": "tts-1",
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&self.api_url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send Piper TTS request")?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let error_body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.unwrap_or_else(|_| serde_json::json!({"error": "unknown"}));
|
||||
let msg = error_body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown error");
|
||||
bail!("Piper TTS API error ({}): {}", status, msg);
|
||||
}
|
||||
|
||||
let bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("Failed to read Piper TTS response body")?;
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
|
||||
fn supported_voices(&self) -> Vec<String> {
|
||||
// Piper voices depend on installed models; return empty (dynamic).
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn supported_formats(&self) -> Vec<String> {
|
||||
["mp3", "wav", "opus"]
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── TtsManager ───────────────────────────────────────────────────
|
||||
|
||||
/// Central manager for multi-provider TTS synthesis.
|
||||
@@ -510,6 +585,11 @@ impl TtsManager {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref piper_cfg) = config.piper {
|
||||
let provider = PiperTtsProvider::new(&piper_cfg.api_url);
|
||||
providers.insert("piper".to_string(), Box::new(provider));
|
||||
}
|
||||
|
||||
let max_text_length = if config.max_text_length == 0 {
|
||||
DEFAULT_MAX_TEXT_LENGTH
|
||||
} else {
|
||||
@@ -652,6 +732,54 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_provider_creation() {
|
||||
let provider = PiperTtsProvider::new("http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.name(), "piper");
|
||||
assert_eq!(provider.api_url, "http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.supported_formats(), vec!["mp3", "wav", "opus"]);
|
||||
// Piper voices depend on installed models; list is empty.
|
||||
assert!(provider.supported_voices().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_manager_with_piper_provider() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert_eq!(manager.available_providers(), vec!["piper"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tts_rejects_empty_text_for_piper() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
let err = manager
|
||||
.synthesize_with_provider("", "piper", "default")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not be empty"),
|
||||
"expected empty-text error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_not_registered_when_config_is_none() {
|
||||
let config = default_tts_config();
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert!(!manager.available_providers().contains(&"piper".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_config_defaults() {
|
||||
let config = TtsConfig::default();
|
||||
@@ -664,6 +792,7 @@ mod tests {
|
||||
assert!(config.elevenlabs.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.edge.is_none());
|
||||
assert!(config.piper.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,531 @@
|
||||
//! Voice Wake Word detection channel.
|
||||
//!
|
||||
//! Listens on the default microphone via `cpal`, detects a configurable wake
|
||||
//! word using energy-based VAD followed by transcription-based keyword matching,
|
||||
//! then captures the subsequent utterance and dispatches it as a channel message.
|
||||
//!
|
||||
//! Gated behind the `voice-wake` Cargo feature.
|
||||
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::channels::transcription::transcribe_audio;
|
||||
use crate::config::schema::VoiceWakeConfig;
|
||||
use crate::config::TranscriptionConfig;
|
||||
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
|
||||
// ── State machine ──────────────────────────────────────────────
|
||||
|
||||
/// Internal states for the wake-word detector.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WakeState {
|
||||
/// Passively monitoring microphone energy levels.
|
||||
Listening,
|
||||
/// Energy spike detected — capturing a short window to check for wake word.
|
||||
Triggered,
|
||||
/// Wake word confirmed — capturing the full utterance that follows.
|
||||
Capturing,
|
||||
/// Captured audio is being transcribed.
|
||||
Processing,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WakeState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Listening => write!(f, "Listening"),
|
||||
Self::Triggered => write!(f, "Triggered"),
|
||||
Self::Capturing => write!(f, "Capturing"),
|
||||
Self::Processing => write!(f, "Processing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Channel implementation ─────────────────────────────────────
|
||||
|
||||
/// Voice wake-word channel that activates on a spoken keyword.
|
||||
pub struct VoiceWakeChannel {
|
||||
config: VoiceWakeConfig,
|
||||
transcription_config: TranscriptionConfig,
|
||||
}
|
||||
|
||||
impl VoiceWakeChannel {
|
||||
/// Create a new `VoiceWakeChannel` from its config sections.
|
||||
pub fn new(config: VoiceWakeConfig, transcription_config: TranscriptionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
transcription_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for VoiceWakeChannel {
|
||||
fn name(&self) -> &str {
|
||||
"voice_wake"
|
||||
}
|
||||
|
||||
async fn send(&self, _message: &SendMessage) -> Result<()> {
|
||||
// Voice wake is input-only; outbound messages are not supported.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
|
||||
let config = self.config.clone();
|
||||
let transcription_config = self.transcription_config.clone();
|
||||
|
||||
// Run the blocking audio capture loop on a dedicated thread.
|
||||
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(4);
|
||||
|
||||
let energy_threshold = config.energy_threshold;
|
||||
let silence_timeout = Duration::from_millis(u64::from(config.silence_timeout_ms));
|
||||
let max_capture = Duration::from_secs(u64::from(config.max_capture_secs));
|
||||
let sample_rate: u32;
|
||||
let channels_count: u16;
|
||||
|
||||
// ── Initialise cpal stream ────────────────────────────
|
||||
{
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.ok_or_else(|| anyhow::anyhow!("No default audio input device available"))?;
|
||||
|
||||
let supported = device.default_input_config()?;
|
||||
sample_rate = supported.sample_rate().0;
|
||||
channels_count = supported.channels();
|
||||
|
||||
info!(
|
||||
device = ?device.name().unwrap_or_default(),
|
||||
sample_rate,
|
||||
channels = channels_count,
|
||||
"VoiceWake: opening audio input"
|
||||
);
|
||||
|
||||
let stream_config: cpal::StreamConfig = supported.into();
|
||||
let audio_tx_clone = audio_tx.clone();
|
||||
|
||||
let stream = device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
// Non-blocking: try_send and drop if full.
|
||||
let _ = audio_tx_clone.try_send(data.to_vec());
|
||||
},
|
||||
move |err| {
|
||||
warn!("VoiceWake: audio stream error: {err}");
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
|
||||
stream.play()?;
|
||||
|
||||
// Keep the stream alive for the lifetime of the channel.
|
||||
// We leak it intentionally — the channel runs until the daemon shuts down.
|
||||
std::mem::forget(stream);
|
||||
}
|
||||
|
||||
// Drop the extra sender so the channel closes when the stream sender drops.
|
||||
drop(audio_tx);
|
||||
|
||||
// ── Main detection loop ───────────────────────────────
|
||||
let wake_word = config.wake_word.to_lowercase();
|
||||
let mut state = WakeState::Listening;
|
||||
let mut capture_buf: Vec<f32> = Vec::new();
|
||||
let mut last_voice_at = Instant::now();
|
||||
let mut capture_start = Instant::now();
|
||||
let mut msg_counter: u64 = 0;
|
||||
|
||||
info!(wake_word = %wake_word, "VoiceWake: entering listen loop");
|
||||
|
||||
while let Some(chunk) = audio_rx.recv().await {
|
||||
let energy = compute_rms_energy(&chunk);
|
||||
|
||||
match state {
|
||||
WakeState::Listening => {
|
||||
if energy >= energy_threshold {
|
||||
debug!(
|
||||
energy,
|
||||
"VoiceWake: energy spike — transitioning to Triggered"
|
||||
);
|
||||
state = WakeState::Triggered;
|
||||
capture_buf.clear();
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
last_voice_at = Instant::now();
|
||||
capture_start = Instant::now();
|
||||
}
|
||||
}
|
||||
WakeState::Triggered => {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
}
|
||||
|
||||
let since_voice = last_voice_at.elapsed();
|
||||
let since_start = capture_start.elapsed();
|
||||
|
||||
// After enough silence or max time, transcribe to check for wake word.
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: Triggered window closed — transcribing for wake word");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "wake_check.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let lower = text.to_lowercase();
|
||||
if lower.contains(&wake_word) {
|
||||
info!(text = %text, "VoiceWake: wake word detected — capturing utterance");
|
||||
state = WakeState::Capturing;
|
||||
capture_buf.clear();
|
||||
last_voice_at = Instant::now();
|
||||
capture_start = Instant::now();
|
||||
} else {
|
||||
debug!(text = %text, "VoiceWake: no wake word — back to Listening");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error during wake check: {e}");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WakeState::Capturing => {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
}
|
||||
|
||||
let since_voice = last_voice_at.elapsed();
|
||||
let since_start = capture_start.elapsed();
|
||||
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: utterance capture complete — transcribing");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "utterance.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
msg_counter += 1;
|
||||
let ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("voice_wake_{msg_counter}"),
|
||||
sender: "voice_user".into(),
|
||||
reply_target: "voice_user".into(),
|
||||
content: trimmed,
|
||||
channel: "voice_wake".into(),
|
||||
timestamp: ts,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
};
|
||||
|
||||
if let Err(e) = tx.send(msg).await {
|
||||
warn!("VoiceWake: failed to dispatch message: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error for utterance: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
WakeState::Processing => {
|
||||
// Should not receive chunks while processing, but just buffer them.
|
||||
// State transitions happen above synchronously after transcription.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bail!("VoiceWake: audio stream ended unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Audio utilities ────────────────────────────────────────────
|
||||
|
||||
/// Compute RMS (root-mean-square) energy of an audio chunk.
|
||||
pub fn compute_rms_energy(samples: &[f32]) -> f32 {
|
||||
if samples.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
|
||||
(sum_sq / samples.len() as f32).sqrt()
|
||||
}
|
||||
|
||||
/// Encode raw f32 PCM samples as a WAV byte buffer (16-bit PCM).
|
||||
///
|
||||
/// This produces a minimal valid WAV file that Whisper-compatible APIs accept.
|
||||
pub fn encode_wav_from_f32(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<u8> {
|
||||
let bits_per_sample: u16 = 16;
|
||||
let byte_rate = u32::from(channels) * sample_rate * u32::from(bits_per_sample) / 8;
|
||||
let block_align = channels * bits_per_sample / 8;
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let data_len = (samples.len() * 2) as u32; // 16-bit = 2 bytes per sample; max ~25 MB
|
||||
let file_len = 36 + data_len;
|
||||
|
||||
let mut buf = Vec::with_capacity(file_len as usize + 8);
|
||||
|
||||
// RIFF header
|
||||
buf.extend_from_slice(b"RIFF");
|
||||
buf.extend_from_slice(&file_len.to_le_bytes());
|
||||
buf.extend_from_slice(b"WAVE");
|
||||
|
||||
// fmt chunk
|
||||
buf.extend_from_slice(b"fmt ");
|
||||
buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
|
||||
buf.extend_from_slice(&channels.to_le_bytes());
|
||||
buf.extend_from_slice(&sample_rate.to_le_bytes());
|
||||
buf.extend_from_slice(&byte_rate.to_le_bytes());
|
||||
buf.extend_from_slice(&block_align.to_le_bytes());
|
||||
buf.extend_from_slice(&bits_per_sample.to_le_bytes());
|
||||
|
||||
// data chunk
|
||||
buf.extend_from_slice(b"data");
|
||||
buf.extend_from_slice(&data_len.to_le_bytes());
|
||||
|
||||
for &sample in samples {
|
||||
let clamped = sample.clamp(-1.0, 1.0);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let pcm16 = (clamped * 32767.0) as i16; // clamped to [-1,1] so fits i16
|
||||
buf.extend_from_slice(&pcm16.to_le_bytes());
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::traits::ChannelConfig;
|
||||
|
||||
// ── State machine tests ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wake_state_display() {
|
||||
assert_eq!(WakeState::Listening.to_string(), "Listening");
|
||||
assert_eq!(WakeState::Triggered.to_string(), "Triggered");
|
||||
assert_eq!(WakeState::Capturing.to_string(), "Capturing");
|
||||
assert_eq!(WakeState::Processing.to_string(), "Processing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wake_state_equality() {
|
||||
assert_eq!(WakeState::Listening, WakeState::Listening);
|
||||
assert_ne!(WakeState::Listening, WakeState::Triggered);
|
||||
}
|
||||
|
||||
// ── Energy computation tests ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_silence_is_zero() {
|
||||
let silence = vec![0.0f32; 1024];
|
||||
assert_eq!(compute_rms_energy(&silence), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_empty_is_zero() {
|
||||
assert_eq!(compute_rms_energy(&[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_constant_signal() {
|
||||
// Constant signal at 0.5 → RMS should be 0.5
|
||||
let signal = vec![0.5f32; 100];
|
||||
let energy = compute_rms_energy(&signal);
|
||||
assert!((energy - 0.5).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_above_threshold() {
|
||||
let loud = vec![0.8f32; 256];
|
||||
let energy = compute_rms_energy(&loud);
|
||||
assert!(energy > 0.01, "Loud signal should exceed default threshold");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_below_threshold_for_quiet() {
|
||||
let quiet = vec![0.001f32; 256];
|
||||
let energy = compute_rms_energy(&quiet);
|
||||
assert!(
|
||||
energy < 0.01,
|
||||
"Very quiet signal should be below default threshold"
|
||||
);
|
||||
}
|
||||
|
||||
// ── WAV encoding tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wav_header_is_valid() {
|
||||
let samples = vec![0.0f32; 100];
|
||||
let wav = encode_wav_from_f32(&samples, 16000, 1);
|
||||
|
||||
// RIFF header
|
||||
assert_eq!(&wav[0..4], b"RIFF");
|
||||
assert_eq!(&wav[8..12], b"WAVE");
|
||||
|
||||
// fmt chunk
|
||||
assert_eq!(&wav[12..16], b"fmt ");
|
||||
let fmt_size = u32::from_le_bytes(wav[16..20].try_into().unwrap());
|
||||
assert_eq!(fmt_size, 16);
|
||||
|
||||
// PCM format
|
||||
let format = u16::from_le_bytes(wav[20..22].try_into().unwrap());
|
||||
assert_eq!(format, 1);
|
||||
|
||||
// Channels
|
||||
let channels = u16::from_le_bytes(wav[22..24].try_into().unwrap());
|
||||
assert_eq!(channels, 1);
|
||||
|
||||
// Sample rate
|
||||
let sr = u32::from_le_bytes(wav[24..28].try_into().unwrap());
|
||||
assert_eq!(sr, 16000);
|
||||
|
||||
// data chunk
|
||||
assert_eq!(&wav[36..40], b"data");
|
||||
let data_size = u32::from_le_bytes(wav[40..44].try_into().unwrap());
|
||||
assert_eq!(data_size, 200); // 100 samples * 2 bytes each
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wav_total_size_correct() {
|
||||
let samples = vec![0.0f32; 50];
|
||||
let wav = encode_wav_from_f32(&samples, 44100, 2);
|
||||
// header (44 bytes) + data (50 * 2 = 100 bytes)
|
||||
assert_eq!(wav.len(), 144);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wav_encodes_clipped_samples() {
|
||||
// Samples outside [-1, 1] should be clamped
|
||||
let samples = vec![-2.0f32, 2.0, 0.0];
|
||||
let wav = encode_wav_from_f32(&samples, 16000, 1);
|
||||
|
||||
let s0 = i16::from_le_bytes(wav[44..46].try_into().unwrap());
|
||||
let s1 = i16::from_le_bytes(wav[46..48].try_into().unwrap());
|
||||
let s2 = i16::from_le_bytes(wav[48..50].try_into().unwrap());
|
||||
|
||||
assert_eq!(s0, -32767); // clamped to -1.0
|
||||
assert_eq!(s1, 32767); // clamped to 1.0
|
||||
assert_eq!(s2, 0);
|
||||
}
|
||||
|
||||
// ── Config parsing tests ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_defaults() {
|
||||
let config = VoiceWakeConfig::default();
|
||||
assert_eq!(config.wake_word, "hey zeroclaw");
|
||||
assert_eq!(config.silence_timeout_ms, 2000);
|
||||
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
|
||||
assert_eq!(config.max_capture_secs, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_deserialize_partial() {
|
||||
let toml_str = r#"
|
||||
wake_word = "okay agent"
|
||||
max_capture_secs = 60
|
||||
"#;
|
||||
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.wake_word, "okay agent");
|
||||
assert_eq!(config.max_capture_secs, 60);
|
||||
// Defaults preserved for unset fields
|
||||
assert_eq!(config.silence_timeout_ms, 2000);
|
||||
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_deserialize_all_fields() {
|
||||
let toml_str = r#"
|
||||
wake_word = "hello bot"
|
||||
silence_timeout_ms = 3000
|
||||
energy_threshold = 0.05
|
||||
max_capture_secs = 15
|
||||
"#;
|
||||
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.wake_word, "hello bot");
|
||||
assert_eq!(config.silence_timeout_ms, 3000);
|
||||
assert!((config.energy_threshold - 0.05).abs() < f32::EPSILON);
|
||||
assert_eq!(config.max_capture_secs, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_channel_config_trait() {
|
||||
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
|
||||
assert_eq!(VoiceWakeConfig::desc(), "voice wake word detection");
|
||||
}
|
||||
|
||||
// ── State transition logic tests ───────────────────────
|
||||
|
||||
#[test]
|
||||
fn energy_threshold_determines_trigger() {
|
||||
let threshold = 0.01f32;
|
||||
let quiet_energy = compute_rms_energy(&vec![0.005f32; 256]);
|
||||
let loud_energy = compute_rms_energy(&vec![0.5f32; 256]);
|
||||
|
||||
assert!(quiet_energy < threshold, "Quiet should not trigger");
|
||||
assert!(loud_energy >= threshold, "Loud should trigger");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_transitions_are_deterministic() {
|
||||
// Verify that the state enum values are distinct and copyable
|
||||
let states = [
|
||||
WakeState::Listening,
|
||||
WakeState::Triggered,
|
||||
WakeState::Capturing,
|
||||
WakeState::Processing,
|
||||
];
|
||||
for (i, a) in states.iter().enumerate() {
|
||||
for (j, b) in states.iter().enumerate() {
|
||||
if i == j {
|
||||
assert_eq!(a, b);
|
||||
} else {
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_config_impl() {
|
||||
// VoiceWakeConfig implements ChannelConfig
|
||||
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
|
||||
assert!(!VoiceWakeConfig::desc().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_channel_name() {
|
||||
let config = VoiceWakeConfig::default();
|
||||
let transcription_config = TranscriptionConfig::default();
|
||||
let channel = VoiceWakeChannel::new(config, transcription_config);
|
||||
assert_eq!(channel.name(), "voice_wake");
|
||||
}
|
||||
}
|
||||
+22
-20
@@ -10,26 +10,28 @@ pub use schema::{
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
|
||||
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
|
||||
ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
|
||||
LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
|
||||
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
|
||||
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
|
||||
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl, CronScheduleDecl,
|
||||
DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation,
|
||||
GoogleWorkspaceConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkEnricherConfig, LinkedInConfig, LinkedInContentConfig,
|
||||
LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig,
|
||||
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
|
||||
ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig,
|
||||
OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PiperTtsConfig,
|
||||
PluginsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
|
||||
WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
|
||||
DEFAULT_GWS_SERVICES,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+781
-112
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -15,7 +15,7 @@ pub use schedule::{
|
||||
#[allow(unused_imports)]
|
||||
pub use store::{
|
||||
add_agent_job, all_overdue_jobs, due_jobs, get_job, list_jobs, list_runs, record_last_run,
|
||||
record_run, remove_job, reschedule_after_run, update_job,
|
||||
record_run, remove_job, reschedule_after_run, sync_declarative_jobs, update_job,
|
||||
};
|
||||
pub use types::{
|
||||
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
|
||||
|
||||
+48
-2
@@ -1,5 +1,7 @@
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
use crate::channels::MatrixChannel;
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
use crate::channels::WhatsAppWebChannel;
|
||||
use crate::channels::{
|
||||
Channel, DiscordChannel, MattermostChannel, QQChannel, SendMessage, SignalChannel,
|
||||
SlackChannel, TelegramChannel,
|
||||
@@ -7,8 +9,8 @@ use crate::channels::{
|
||||
use crate::config::Config;
|
||||
use crate::cron::{
|
||||
all_overdue_jobs, due_jobs, next_run_for_schedule, record_last_run, record_run, remove_job,
|
||||
reschedule_after_run, update_job, CronJob, CronJobPatch, DeliveryConfig, JobType, Schedule,
|
||||
SessionTarget,
|
||||
reschedule_after_run, sync_declarative_jobs, update_job, CronJob, CronJobPatch, DeliveryConfig,
|
||||
JobType, Schedule, SessionTarget,
|
||||
};
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Result;
|
||||
@@ -34,6 +36,19 @@ pub async fn run(config: Config) -> Result<()> {
|
||||
|
||||
crate::health::mark_component_ok(SCHEDULER_COMPONENT);
|
||||
|
||||
// ── Declarative job sync: reconcile config-defined jobs with the DB.
|
||||
match sync_declarative_jobs(&config, &config.cron.jobs) {
|
||||
Ok(()) => {
|
||||
if !config.cron.jobs.is_empty() {
|
||||
tracing::info!(
|
||||
count = config.cron.jobs.len(),
|
||||
"Synced declarative cron jobs from config"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to sync declarative cron jobs: {e}"),
|
||||
}
|
||||
|
||||
// ── Startup catch-up: run ALL overdue jobs before entering the
|
||||
// normal polling loop. The regular loop is capped by `max_tasks`,
|
||||
// which could leave some overdue jobs waiting across many cycles
|
||||
@@ -483,6 +498,36 @@ pub(crate) async fn deliver_announcement(
|
||||
anyhow::bail!("matrix delivery channel requires `channel-matrix` feature");
|
||||
}
|
||||
}
|
||||
"whatsapp" | "whatsapp-web" | "whatsapp_web" => {
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
{
|
||||
let wa = config
|
||||
.channels_config
|
||||
.whatsapp
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("whatsapp channel not configured"))?;
|
||||
if !wa.is_web_config() {
|
||||
anyhow::bail!(
|
||||
"whatsapp cron delivery requires Web mode (session_path must be set)"
|
||||
);
|
||||
}
|
||||
let channel = WhatsAppWebChannel::new(
|
||||
wa.session_path.clone().unwrap_or_default(),
|
||||
wa.pair_phone.clone(),
|
||||
wa.pair_code.clone(),
|
||||
wa.allowed_numbers.clone(),
|
||||
wa.mode.clone(),
|
||||
wa.dm_policy.clone(),
|
||||
wa.group_policy.clone(),
|
||||
wa.self_chat_mode,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
}
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
anyhow::bail!("whatsapp delivery channel requires `whatsapp-web` feature");
|
||||
}
|
||||
}
|
||||
"qq" => {
|
||||
let qq = config
|
||||
.channels_config
|
||||
@@ -657,6 +702,7 @@ mod tests {
|
||||
delivery: DeliveryConfig::default(),
|
||||
delete_after_run: false,
|
||||
allowed_tools: None,
|
||||
source: "imperative".into(),
|
||||
created_at: Utc::now(),
|
||||
next_run: Utc::now(),
|
||||
last_run: None,
|
||||
|
||||
+521
-4
@@ -124,7 +124,7 @@ pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs ORDER BY next_run ASC",
|
||||
)?;
|
||||
|
||||
@@ -143,7 +143,7 @@ pub fn get_job(config: &Config, job_id: &str) -> Result<CronJob> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs WHERE id = ?1",
|
||||
)?;
|
||||
|
||||
@@ -177,7 +177,7 @@ pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC
|
||||
@@ -206,7 +206,8 @@ pub fn all_overdue_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJ
|
||||
with_connection(config, |conn| {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output, allowed_tools
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC",
|
||||
@@ -488,6 +489,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
let last_run_raw: Option<String> = row.get(14)?;
|
||||
let created_at_raw: String = row.get(12)?;
|
||||
let allowed_tools_raw: Option<String> = row.get(17)?;
|
||||
let source: Option<String> = row.get(18)?;
|
||||
|
||||
Ok(CronJob {
|
||||
id: row.get(0)?,
|
||||
@@ -502,6 +504,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
enabled: row.get::<_, i64>(9)? != 0,
|
||||
delivery,
|
||||
delete_after_run: row.get::<_, i64>(11)? != 0,
|
||||
source: source.unwrap_or_else(|| "imperative".to_string()),
|
||||
created_at: parse_rfc3339(&created_at_raw).map_err(sql_conversion_error)?,
|
||||
next_run: parse_rfc3339(&next_run_raw).map_err(sql_conversion_error)?,
|
||||
last_run: match last_run_raw {
|
||||
@@ -564,6 +567,277 @@ fn decode_allowed_tools(raw: Option<&str>) -> Result<Option<Vec<String>>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Synchronize declarative cron job definitions from config into the database.
|
||||
///
|
||||
/// For each declarative job (identified by `id`):
|
||||
/// - If the job exists in DB: update it to match the config definition.
|
||||
/// - If the job does not exist: insert it.
|
||||
///
|
||||
/// Jobs created imperatively (via CLI/API) are never modified or deleted.
|
||||
/// Declarative jobs that are no longer present in config are removed.
|
||||
pub fn sync_declarative_jobs(
|
||||
config: &Config,
|
||||
decls: &[crate::config::schema::CronJobDecl],
|
||||
) -> Result<()> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
|
||||
if decls.is_empty() {
|
||||
// If no declarative jobs are defined, clean up any previously
|
||||
// synced declarative jobs that are no longer in config.
|
||||
with_connection(config, |conn| {
|
||||
let deleted = conn
|
||||
.execute("DELETE FROM cron_jobs WHERE source = 'declarative'", [])
|
||||
.context("Failed to remove stale declarative cron jobs")?;
|
||||
if deleted > 0 {
|
||||
tracing::info!(
|
||||
count = deleted,
|
||||
"Removed declarative cron jobs no longer in config"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate declarations before touching the DB.
|
||||
for decl in decls {
|
||||
validate_decl(decl)?;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
with_connection(config, |conn| {
|
||||
// Collect IDs of all declarative jobs currently defined in config.
|
||||
let config_ids: std::collections::HashSet<&str> =
|
||||
decls.iter().map(|d| d.id.as_str()).collect();
|
||||
|
||||
// Remove declarative jobs no longer in config.
|
||||
{
|
||||
let mut stmt = conn.prepare("SELECT id FROM cron_jobs WHERE source = 'declarative'")?;
|
||||
let db_ids: Vec<String> = stmt
|
||||
.query_map([], |row| row.get(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
for db_id in &db_ids {
|
||||
if !config_ids.contains(db_id.as_str()) {
|
||||
conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![db_id])
|
||||
.with_context(|| {
|
||||
format!("Failed to remove stale declarative cron job '{db_id}'")
|
||||
})?;
|
||||
tracing::info!(
|
||||
job_id = %db_id,
|
||||
"Removed declarative cron job no longer in config"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for decl in decls {
|
||||
let schedule = convert_schedule_decl(&decl.schedule)?;
|
||||
let expression = schedule_cron_expression(&schedule).unwrap_or_default();
|
||||
let schedule_json = serde_json::to_string(&schedule)?;
|
||||
let job_type = &decl.job_type;
|
||||
let session_target = decl.session_target.as_deref().unwrap_or("isolated");
|
||||
let delivery = match &decl.delivery {
|
||||
Some(d) => convert_delivery_decl(d),
|
||||
None => DeliveryConfig::default(),
|
||||
};
|
||||
let delivery_json = serde_json::to_string(&delivery)?;
|
||||
let allowed_tools_json = encode_allowed_tools(decl.allowed_tools.as_ref())?;
|
||||
let command = decl.command.as_deref().unwrap_or("");
|
||||
let delete_after_run = matches!(decl.schedule, CronScheduleDecl::At { .. });
|
||||
|
||||
// Check if job already exists.
|
||||
let exists: bool = conn
|
||||
.prepare("SELECT COUNT(*) FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get::<_, i64>(0))
|
||||
.map(|c| c > 0)
|
||||
.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
// Update existing declarative job — preserve runtime state
|
||||
// (next_run, last_run, last_status, last_output, created_at).
|
||||
// Only update the schedule's next_run if the schedule itself changed.
|
||||
let current_schedule_raw: Option<String> = conn
|
||||
.prepare("SELECT schedule FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get(0))
|
||||
.ok();
|
||||
|
||||
let schedule_changed = current_schedule_raw.as_deref() != Some(&schedule_json);
|
||||
|
||||
if schedule_changed {
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative', next_run = ?13
|
||||
WHERE id = ?14",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
next_run.to_rfc3339(),
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
} else {
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative'
|
||||
WHERE id = ?13",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
}
|
||||
|
||||
tracing::debug!(job_id = %decl.id, "Updated declarative cron job");
|
||||
} else {
|
||||
// Insert new declarative job.
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"INSERT INTO cron_jobs (
|
||||
id, expression, command, schedule, job_type, prompt, name,
|
||||
session_target, model, enabled, delivery, delete_after_run,
|
||||
allowed_tools, source, created_at, next_run
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, 'declarative', ?14, ?15)",
|
||||
params![
|
||||
decl.id,
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
now.to_rfc3339(),
|
||||
next_run.to_rfc3339(),
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to insert declarative cron job '{}'",
|
||||
decl.id
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(job_id = %decl.id, "Inserted declarative cron job from config");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate a declarative cron job definition.
|
||||
fn validate_decl(decl: &crate::config::schema::CronJobDecl) -> Result<()> {
|
||||
if decl.id.trim().is_empty() {
|
||||
anyhow::bail!("Declarative cron job has empty id");
|
||||
}
|
||||
|
||||
match decl.job_type.to_lowercase().as_str() {
|
||||
"shell" => {
|
||||
if decl
|
||||
.command
|
||||
.as_deref()
|
||||
.map_or(true, |c| c.trim().is_empty())
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': shell job requires a non-empty 'command'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
"agent" => {
|
||||
if decl.prompt.as_deref().map_or(true, |p| p.trim().is_empty()) {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': agent job requires a non-empty 'prompt'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': invalid job_type '{}', expected 'shell' or 'agent'",
|
||||
decl.id,
|
||||
other
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert a `CronScheduleDecl` to the runtime `Schedule` type.
|
||||
fn convert_schedule_decl(decl: &crate::config::schema::CronScheduleDecl) -> Result<Schedule> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
match decl {
|
||||
CronScheduleDecl::Cron { expr, tz } => Ok(Schedule::Cron {
|
||||
expr: expr.clone(),
|
||||
tz: tz.clone(),
|
||||
}),
|
||||
CronScheduleDecl::Every { every_ms } => Ok(Schedule::Every {
|
||||
every_ms: *every_ms,
|
||||
}),
|
||||
CronScheduleDecl::At { at } => {
|
||||
let parsed = DateTime::parse_from_rfc3339(at)
|
||||
.with_context(|| {
|
||||
format!("Invalid RFC3339 timestamp in declarative cron 'at': {at}")
|
||||
})?
|
||||
.with_timezone(&Utc);
|
||||
Ok(Schedule::At { at: parsed })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a `DeliveryConfigDecl` to the runtime `DeliveryConfig`.
|
||||
fn convert_delivery_decl(decl: &crate::config::schema::DeliveryConfigDecl) -> DeliveryConfig {
|
||||
DeliveryConfig {
|
||||
mode: decl.mode.clone(),
|
||||
channel: decl.channel.clone(),
|
||||
to: decl.to.clone(),
|
||||
best_effort: decl.best_effort,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_column_if_missing(conn: &Connection, name: &str, sql_type: &str) -> Result<()> {
|
||||
let mut stmt = conn.prepare("PRAGMA table_info(cron_jobs)")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
@@ -654,6 +928,7 @@ fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>)
|
||||
add_column_if_missing(&conn, "delivery", "TEXT")?;
|
||||
add_column_if_missing(&conn, "delete_after_run", "INTEGER NOT NULL DEFAULT 0")?;
|
||||
add_column_if_missing(&conn, "allowed_tools", "TEXT")?;
|
||||
add_column_if_missing(&conn, "source", "TEXT DEFAULT 'imperative'")?;
|
||||
|
||||
f(&conn)
|
||||
}
|
||||
@@ -1170,4 +1445,246 @@ mod tests {
|
||||
assert!(last_output.ends_with(TRUNCATED_OUTPUT_MARKER));
|
||||
assert!(last_output.len() <= MAX_CRON_OUTPUT_BYTES);
|
||||
}
|
||||
|
||||
// ── Declarative cron job sync tests ──────────────────────────
|
||||
|
||||
fn make_shell_decl(id: &str, expr: &str, cmd: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: Some(cmd.to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_agent_decl(id: &str, expr: &str, prompt: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "agent".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: None,
|
||||
prompt: Some(prompt.to_string()),
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_inserts_new_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("daily-backup", "0 2 * * *", "echo backup")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "daily-backup").unwrap();
|
||||
assert_eq!(job.command, "echo backup");
|
||||
assert_eq!(job.source, "declarative");
|
||||
assert_eq!(job.name.as_deref(), Some("decl-daily-backup"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_updates_existing_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("updatable", "0 2 * * *", "echo v1")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job_v1 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v1.command, "echo v1");
|
||||
|
||||
let decls_v2 = vec![make_shell_decl("updatable", "0 3 * * *", "echo v2")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
let job_v2 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v2.command, "echo v2");
|
||||
assert_eq!(job_v2.expression, "0 3 * * *");
|
||||
assert_eq!(job_v2.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_does_not_delete_imperative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Create an imperative job via the normal API.
|
||||
let imperative = add_job(&config, "*/10 * * * *", "echo imperative").unwrap();
|
||||
|
||||
// Sync declarative jobs (none of which match the imperative job).
|
||||
let decls = vec![make_shell_decl("my-decl", "0 2 * * *", "echo decl")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Imperative job should still exist.
|
||||
let still_there = get_job(&config, &imperative.id).unwrap();
|
||||
assert_eq!(still_there.command, "echo imperative");
|
||||
assert_eq!(still_there.source, "imperative");
|
||||
|
||||
// Declarative job should also exist.
|
||||
let decl_job = get_job(&config, "my-decl").unwrap();
|
||||
assert_eq!(decl_job.command, "echo decl");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_removes_stale_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Insert two declarative jobs.
|
||||
let decls = vec![
|
||||
make_shell_decl("keeper", "0 2 * * *", "echo keep"),
|
||||
make_shell_decl("stale", "0 3 * * *", "echo stale"),
|
||||
];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Now sync with only "keeper" — "stale" should be removed.
|
||||
let decls_v2 = vec![make_shell_decl("keeper", "0 2 * * *", "echo keep")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
assert!(get_job(&config, "stale").is_err());
|
||||
assert!(get_job(&config, "keeper").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_empty_removes_all_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("to-remove", "0 2 * * *", "echo bye")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_ok());
|
||||
|
||||
// Sync with empty list.
|
||||
sync_declarative_jobs(&config, &[]).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_shell_job_requires_command() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_shell_decl("bad", "0 2 * * *", "echo ok");
|
||||
decl.command = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("command"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_agent_job_requires_prompt() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_agent_decl("bad-agent", "0 2 * * *", "do stuff");
|
||||
decl.prompt = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_agent_job_inserts_correctly() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_agent_decl(
|
||||
"agent-check",
|
||||
"*/15 * * * *",
|
||||
"check health",
|
||||
)];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "agent-check").unwrap();
|
||||
assert_eq!(job.job_type, JobType::Agent);
|
||||
assert_eq!(job.prompt.as_deref(), Some("check health"));
|
||||
assert_eq!(job.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_every_schedule_works() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decl = crate::config::schema::CronJobDecl {
|
||||
id: "interval-job".to_string(),
|
||||
name: None,
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Every { every_ms: 60000 },
|
||||
command: Some("echo interval".to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
};
|
||||
|
||||
sync_declarative_jobs(&config, &[decl]).unwrap();
|
||||
|
||||
let job = get_job(&config, "interval-job").unwrap();
|
||||
assert!(matches!(job.schedule, Schedule::Every { every_ms: 60000 }));
|
||||
assert_eq!(job.command, "echo interval");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn declarative_config_parses_from_toml() {
|
||||
let toml_str = r#"
|
||||
enabled = true
|
||||
|
||||
[[jobs]]
|
||||
id = "daily-report"
|
||||
name = "Daily Report"
|
||||
job_type = "shell"
|
||||
command = "echo report"
|
||||
schedule = { kind = "cron", expr = "0 9 * * *" }
|
||||
|
||||
[[jobs]]
|
||||
id = "health-check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health"
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
"#;
|
||||
|
||||
let parsed: crate::config::schema::CronConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.enabled);
|
||||
assert_eq!(parsed.jobs.len(), 2);
|
||||
|
||||
assert_eq!(parsed.jobs[0].id, "daily-report");
|
||||
assert_eq!(parsed.jobs[0].command.as_deref(), Some("echo report"));
|
||||
assert!(matches!(
|
||||
parsed.jobs[0].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Cron { ref expr, .. } if expr == "0 9 * * *"
|
||||
));
|
||||
|
||||
assert_eq!(parsed.jobs[1].id, "health-check");
|
||||
assert_eq!(parsed.jobs[1].job_type, "agent");
|
||||
assert_eq!(
|
||||
parsed.jobs[1].prompt.as_deref(),
|
||||
Some("Check server health")
|
||||
);
|
||||
assert!(matches!(
|
||||
parsed.jobs[1].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Every { every_ms: 300_000 }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +127,10 @@ fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_source() -> String {
|
||||
"imperative".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronJob {
|
||||
pub id: String,
|
||||
@@ -146,6 +150,9 @@ pub struct CronJob {
|
||||
/// When `None`, all tools are available (backward compatible default).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// How the job was created: `"imperative"` (CLI/API) or `"declarative"` (config).
|
||||
#[serde(default = "default_source")]
|
||||
pub source: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub next_run: DateTime<Utc>,
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
|
||||
+193
-1
@@ -362,10 +362,22 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
// Re-read session context on every tick so we pick up messages
|
||||
// that arrived since the daemon started.
|
||||
let session_context = if config.heartbeat.load_session_context {
|
||||
load_heartbeat_session_context(&config)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut tick_had_error = false;
|
||||
for task in &tasks_to_run {
|
||||
let task_start = std::time::Instant::now();
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let task_prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let prompt = match &session_context {
|
||||
Some(ctx) => format!("{ctx}\n\n{task_prompt}"),
|
||||
None => task_prompt,
|
||||
};
|
||||
let temp = config.default_temperature;
|
||||
match Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
@@ -497,6 +509,186 @@ fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load recent conversation history for the heartbeat's delivery target and
|
||||
/// format it as a text preamble to inject into the task prompt.
|
||||
///
|
||||
/// Scans `{workspace}/sessions/` for JSONL files whose name starts with
|
||||
/// `{channel}_` and ends with `_{to}.jsonl` (or exactly `{channel}_{to}.jsonl`),
|
||||
/// then picks the most recently modified match. This handles session key
|
||||
/// formats such as `telegram_diskiller.jsonl` and
|
||||
/// `telegram_5673725398_diskiller.jsonl`.
|
||||
/// Returns `None` when `target`/`to` are not configured or no session exists.
|
||||
const HEARTBEAT_SESSION_CONTEXT_MESSAGES: usize = 20;
|
||||
|
||||
fn load_heartbeat_session_context(config: &Config) -> Option<String> {
|
||||
use crate::providers::traits::ChatMessage;
|
||||
|
||||
let channel = config
|
||||
.heartbeat
|
||||
.target
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())?;
|
||||
let to = config
|
||||
.heartbeat
|
||||
.to
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())?;
|
||||
|
||||
if channel.contains('/') || channel.contains('\\') || to.contains('/') || to.contains('\\') {
|
||||
tracing::warn!("heartbeat session context: channel/to contains path separators, skipping");
|
||||
return None;
|
||||
}
|
||||
|
||||
let sessions_dir = config.workspace_dir.join("sessions");
|
||||
|
||||
// Find the most recently modified JSONL file that belongs to this target.
|
||||
// Matches both `{channel}_{to}.jsonl` and `{channel}_{anything}_{to}.jsonl`.
|
||||
let prefix = format!("{channel}_");
|
||||
let suffix = format!("_{to}.jsonl");
|
||||
let exact = format!("{channel}_{to}.jsonl");
|
||||
let mid_prefix = format!("{channel}_{to}_");
|
||||
|
||||
let path = std::fs::read_dir(&sessions_dir)
|
||||
.ok()?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
let name = e.file_name();
|
||||
let name = name.to_string_lossy();
|
||||
name.ends_with(".jsonl")
|
||||
&& (name == exact
|
||||
|| (name.starts_with(&prefix) && name.ends_with(&suffix))
|
||||
|| name.starts_with(&mid_prefix))
|
||||
})
|
||||
.max_by_key(|e| {
|
||||
e.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
|
||||
})
|
||||
.map(|e| e.path())?;
|
||||
|
||||
if !path.exists() {
|
||||
tracing::debug!("💓 Heartbeat session context: no session file found for {channel}/{to}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let messages = load_jsonl_messages(&path);
|
||||
if messages.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let recent: Vec<&ChatMessage> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" || m.role == "assistant")
|
||||
.rev()
|
||||
.take(HEARTBEAT_SESSION_CONTEXT_MESSAGES)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
// Only inject context if there is at least one real user message in the
|
||||
// window. If the JSONL contains only assistant messages (e.g. previous
|
||||
// heartbeat outputs with no reply yet), skip context to avoid feeding
|
||||
// Monika's own messages back to her in a loop.
|
||||
let has_user_message = recent.iter().any(|m| m.role == "user");
|
||||
if !has_user_message {
|
||||
tracing::debug!(
|
||||
"💓 Heartbeat session context: no user messages in recent history — skipping"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
// Use the session file's mtime as a proxy for when the last message arrived.
|
||||
let last_message_age = std::fs::metadata(&path)
|
||||
.ok()
|
||||
.and_then(|m| m.modified().ok())
|
||||
.and_then(|mtime| mtime.elapsed().ok());
|
||||
|
||||
let silence_note = match last_message_age {
|
||||
Some(age) => {
|
||||
let mins = age.as_secs() / 60;
|
||||
if mins < 60 {
|
||||
format!("(last message ~{mins} minutes ago)\n")
|
||||
} else {
|
||||
let hours = mins / 60;
|
||||
let rem = mins % 60;
|
||||
if rem == 0 {
|
||||
format!("(last message ~{hours}h ago)\n")
|
||||
} else {
|
||||
format!("(last message ~{hours}h {rem}m ago)\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
"💓 Heartbeat session context: {} messages from {}, silence: {}",
|
||||
recent.len(),
|
||||
path.display(),
|
||||
silence_note.trim(),
|
||||
);
|
||||
|
||||
let mut ctx = format!(
|
||||
"[Recent conversation history — use this for context when composing your message] {silence_note}",
|
||||
);
|
||||
for msg in &recent {
|
||||
let label = if msg.role == "user" { "User" } else { "You" };
|
||||
// Truncate very long messages to avoid bloating the prompt.
|
||||
// Use char_indices to avoid panicking on multi-byte UTF-8 characters.
|
||||
let content = if msg.content.len() > 500 {
|
||||
let truncate_at = msg
|
||||
.content
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 500)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
format!("{}…", &msg.content[..truncate_at])
|
||||
} else {
|
||||
msg.content.clone()
|
||||
};
|
||||
ctx.push_str(label);
|
||||
ctx.push_str(": ");
|
||||
ctx.push_str(&content);
|
||||
ctx.push('\n');
|
||||
}
|
||||
|
||||
Some(ctx)
|
||||
}
|
||||
|
||||
/// Read the last `HEARTBEAT_SESSION_CONTEXT_MESSAGES` `ChatMessage` lines from
|
||||
/// a JSONL session file using a bounded rolling window so we never hold the
|
||||
/// entire file in memory.
|
||||
fn load_jsonl_messages(path: &std::path::Path) -> Vec<crate::providers::traits::ChatMessage> {
|
||||
use std::collections::VecDeque;
|
||||
use std::io::BufRead;
|
||||
|
||||
let file = match std::fs::File::open(path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut window: VecDeque<crate::providers::traits::ChatMessage> =
|
||||
VecDeque::with_capacity(HEARTBEAT_SESSION_CONTEXT_MESSAGES + 1);
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(msg) = serde_json::from_str::<crate::providers::traits::ChatMessage>(trimmed) {
|
||||
window.push_back(msg);
|
||||
if window.len() > HEARTBEAT_SESSION_CONTEXT_MESSAGES {
|
||||
window.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
window.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Auto-detect the best channel for heartbeat delivery by checking which
|
||||
/// channels are configured. Returns the first match in priority order.
|
||||
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {
|
||||
|
||||
+60
-3
@@ -23,7 +23,7 @@ fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
|
||||
}
|
||||
|
||||
/// Verify bearer token against PairingGuard. Returns error response if unauthorized.
|
||||
fn require_auth(
|
||||
pub(super) fn require_auth(
|
||||
state: &AppState,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
|
||||
@@ -1280,12 +1280,16 @@ pub async fn handle_api_sessions_list(
|
||||
.into_iter()
|
||||
.filter_map(|meta| {
|
||||
let session_id = meta.key.strip_prefix("gw_")?;
|
||||
Some(serde_json::json!({
|
||||
let mut entry = serde_json::json!({
|
||||
"session_id": session_id,
|
||||
"created_at": meta.created_at.to_rfc3339(),
|
||||
"last_activity": meta.last_activity.to_rfc3339(),
|
||||
"message_count": meta.message_count,
|
||||
}))
|
||||
});
|
||||
if let Some(name) = meta.name {
|
||||
entry["name"] = serde_json::Value::String(name);
|
||||
}
|
||||
Some(entry)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1326,6 +1330,56 @@ pub async fn handle_api_session_delete(
|
||||
}
|
||||
}
|
||||
|
||||
/// PUT /api/sessions/{id} — rename a gateway session
|
||||
pub async fn handle_api_session_rename(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let Some(ref backend) = state.session_backend else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session persistence is disabled"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let name = body["name"].as_str().unwrap_or("").trim();
|
||||
if name.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "name is required"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let session_key = format!("gw_{id}");
|
||||
|
||||
// Verify the session exists before renaming
|
||||
let sessions = backend.list_sessions();
|
||||
if !sessions.contains(&session_key) {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match backend.set_session_name(&session_key, name) {
|
||||
Ok(()) => Json(serde_json::json!({"session_id": id, "name": name})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": format!("Failed to rename session: {e}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -1429,6 +1483,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -1438,6 +1493,8 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
path_prefix: String::new(),
|
||||
canvas_store: crate::tools::canvas::CanvasStore::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
//! Live Canvas gateway routes — REST + WebSocket for real-time canvas updates.
|
||||
//!
|
||||
//! - `GET /api/canvas/:id` — get current canvas content (JSON)
|
||||
//! - `POST /api/canvas/:id` — push content programmatically
|
||||
//! - `GET /api/canvas` — list all active canvases
|
||||
//! - `WS /ws/canvas/:id` — real-time canvas updates via WebSocket
|
||||
|
||||
use super::api::require_auth;
|
||||
use super::AppState;
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
Path, State, WebSocketUpgrade,
|
||||
},
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// POST /api/canvas/:id request body.
|
||||
#[derive(Deserialize)]
|
||||
pub struct CanvasPostBody {
|
||||
pub content_type: Option<String>,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// GET /api/canvas — list all active canvases.
|
||||
pub async fn handle_canvas_list(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let ids = state.canvas_store.list();
|
||||
Json(serde_json::json!({ "canvases": ids })).into_response()
|
||||
}
|
||||
|
||||
/// GET /api/canvas/:id — get current canvas content.
|
||||
pub async fn handle_canvas_get(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
match state.canvas_store.snapshot(&id) {
|
||||
Some(frame) => Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frame": frame,
|
||||
}))
|
||||
.into_response(),
|
||||
None => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({ "error": format!("Canvas '{}' not found", id) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/canvas/:id/history — get canvas frame history.
|
||||
pub async fn handle_canvas_history(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let history = state.canvas_store.history(&id);
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frames": history,
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// POST /api/canvas/:id — push content to a canvas.
|
||||
pub async fn handle_canvas_post(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<CanvasPostBody>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let content_type = body.content_type.as_deref().unwrap_or("html");
|
||||
|
||||
// Validate content_type against allowed set (prevent injecting "eval" frames via REST).
|
||||
if !crate::tools::canvas::ALLOWED_CONTENT_TYPES.contains(&content_type) {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": format!(
|
||||
"Invalid content_type '{}'. Allowed: {:?}",
|
||||
content_type,
|
||||
crate::tools::canvas::ALLOWED_CONTENT_TYPES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// Enforce content size limit (same as tool-side validation).
|
||||
if body.content.len() > crate::tools::canvas::MAX_CONTENT_SIZE {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Json(serde_json::json!({
|
||||
"error": format!(
|
||||
"Content exceeds maximum size of {} bytes",
|
||||
crate::tools::canvas::MAX_CONTENT_SIZE
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match state.canvas_store.render(&id, content_type, &body.content) {
|
||||
Some(frame) => (
|
||||
StatusCode::CREATED,
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frame": frame,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
None => (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(serde_json::json!({
|
||||
"error": "Maximum canvas count reached. Clear unused canvases first."
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// DELETE /api/canvas/:id — clear a canvas.
|
||||
pub async fn handle_canvas_clear(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
state.canvas_store.clear(&id);
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"status": "cleared",
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// WS /ws/canvas/:id — real-time canvas updates.
|
||||
pub async fn handle_ws_canvas(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
// Auth check (same pattern as ws::handle_ws_chat)
|
||||
if state.pairing.require_pairing() {
|
||||
let token = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.or_else(|| {
|
||||
// Fallback: check query params in the upgrade request URI
|
||||
headers
|
||||
.get("sec-websocket-protocol")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|protos| {
|
||||
protos
|
||||
.split(',')
|
||||
.map(|p| p.trim())
|
||||
.find_map(|p| p.strip_prefix("bearer."))
|
||||
})
|
||||
})
|
||||
.unwrap_or("");
|
||||
|
||||
if !state.pairing.is_authenticated(token) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Unauthorized — provide Authorization header or Sec-WebSocket-Protocol bearer",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
ws.on_upgrade(move |socket| handle_canvas_socket(socket, state, id))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn handle_canvas_socket(socket: WebSocket, state: AppState, canvas_id: String) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Subscribe to canvas updates
|
||||
let mut rx = match state.canvas_store.subscribe(&canvas_id) {
|
||||
Some(rx) => rx,
|
||||
None => {
|
||||
let msg = serde_json::json!({
|
||||
"type": "error",
|
||||
"error": "Maximum canvas count reached",
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Send current state immediately if available
|
||||
if let Some(frame) = state.canvas_store.snapshot(&canvas_id) {
|
||||
let msg = serde_json::json!({
|
||||
"type": "frame",
|
||||
"canvas_id": canvas_id,
|
||||
"frame": frame,
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
}
|
||||
|
||||
// Send a connected acknowledgement
|
||||
let ack = serde_json::json!({
|
||||
"type": "connected",
|
||||
"canvas_id": canvas_id,
|
||||
});
|
||||
let _ = sender.send(Message::Text(ack.to_string().into())).await;
|
||||
|
||||
// Spawn a task that forwards broadcast updates to the WebSocket
|
||||
let canvas_id_clone = canvas_id.clone();
|
||||
let send_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(frame) => {
|
||||
let msg = serde_json::json!({
|
||||
"type": "frame",
|
||||
"canvas_id": canvas_id_clone,
|
||||
"frame": frame,
|
||||
});
|
||||
if sender
|
||||
.send(Message::Text(msg.to_string().into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
// Client fell behind — notify and continue rather than disconnecting.
|
||||
let msg = serde_json::json!({
|
||||
"type": "lagged",
|
||||
"canvas_id": canvas_id_clone,
|
||||
"missed_frames": n,
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Read loop: we mostly ignore incoming messages but handle close/ping
|
||||
while let Some(msg) = receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Close(_)) | Err(_) => break,
|
||||
_ => {} // Ignore all other messages (pings are handled by axum)
|
||||
}
|
||||
}
|
||||
|
||||
// Abort the send task when the connection is closed
|
||||
send_task.abort();
|
||||
}
|
||||
+194
-40
@@ -11,14 +11,15 @@ pub mod api;
|
||||
pub mod api_pairing;
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
pub mod api_plugins;
|
||||
pub mod canvas;
|
||||
pub mod nodes;
|
||||
pub mod sse;
|
||||
pub mod static_files;
|
||||
pub mod ws;
|
||||
|
||||
use crate::channels::{
|
||||
session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel, LinqChannel,
|
||||
NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
|
||||
session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel,
|
||||
GmailPushChannel, LinqChannel, NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
|
||||
};
|
||||
use crate::config::Config;
|
||||
use crate::cost::CostTracker;
|
||||
@@ -28,6 +29,7 @@ use crate::runtime;
|
||||
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools;
|
||||
use crate::tools::canvas::CanvasStore;
|
||||
use crate::tools::traits::ToolSpec;
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::{Context, Result};
|
||||
@@ -336,6 +338,8 @@ pub struct AppState {
|
||||
/// Nextcloud Talk webhook secret for signature verification
|
||||
pub nextcloud_talk_webhook_secret: Option<Arc<str>>,
|
||||
pub wati: Option<Arc<WatiChannel>>,
|
||||
/// Gmail Pub/Sub push notification channel
|
||||
pub gmail_push: Option<Arc<GmailPushChannel>>,
|
||||
/// Observability backend for metrics scraping
|
||||
pub observer: Arc<dyn crate::observability::Observer>,
|
||||
/// Registered tool specs (for web dashboard tools page)
|
||||
@@ -348,12 +352,16 @@ pub struct AppState {
|
||||
pub shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
/// Registry of dynamically connected nodes
|
||||
pub node_registry: Arc<nodes::NodeRegistry>,
|
||||
/// Path prefix for reverse-proxy deployments (empty string = no prefix)
|
||||
pub path_prefix: String,
|
||||
/// Session backend for persisting gateway WS chat sessions
|
||||
pub session_backend: Option<Arc<dyn SessionBackend>>,
|
||||
/// Device registry for paired device management
|
||||
pub device_registry: Option<Arc<api_pairing::DeviceRegistry>>,
|
||||
/// Pending pairing request store
|
||||
pub pending_pairings: Option<Arc<api_pairing::PairingStore>>,
|
||||
/// Shared canvas store for Live Canvas (A2UI) system
|
||||
pub canvas_store: CanvasStore,
|
||||
}
|
||||
|
||||
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
|
||||
@@ -430,21 +438,25 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let (mut tools_registry_raw, delegate_handle_gw) = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
);
|
||||
let canvas_store = tools::CanvasStore::new();
|
||||
|
||||
let (mut tools_registry_raw, delegate_handle_gw, _reaction_handle_gw) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
composio_entity_id,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.web_fetch,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
Some(canvas_store.clone()),
|
||||
);
|
||||
|
||||
// ── Wire MCP tools into the gateway tool registry (non-fatal) ───
|
||||
// Without this, the `/api/tools` endpoint misses MCP tools.
|
||||
@@ -627,6 +639,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
})
|
||||
.map(Arc::from);
|
||||
|
||||
// Gmail Push channel (if configured and enabled)
|
||||
let gmail_push_channel: Option<Arc<GmailPushChannel>> = config
|
||||
.channels_config
|
||||
.gmail_push
|
||||
.as_ref()
|
||||
.filter(|gp| gp.enabled)
|
||||
.map(|gp| Arc::new(GmailPushChannel::new(gp.clone())));
|
||||
|
||||
// ── Session persistence for WS chat ─────────────────────
|
||||
let session_backend: Option<Arc<dyn SessionBackend>> = if config.gateway.session_persistence {
|
||||
match SqliteSessionBackend::new(&config.workspace_dir) {
|
||||
@@ -673,6 +693,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
idempotency_max_keys,
|
||||
));
|
||||
|
||||
// Resolve optional path prefix for reverse-proxy deployments.
|
||||
let path_prefix: Option<&str> = config
|
||||
.gateway
|
||||
.path_prefix
|
||||
.as_deref()
|
||||
.filter(|p| !p.is_empty());
|
||||
|
||||
// ── Tunnel ────────────────────────────────────────────────
|
||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||
let mut tunnel_url: Option<String> = None;
|
||||
@@ -691,18 +718,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
|
||||
let pfx = path_prefix.unwrap_or("");
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}{pfx}");
|
||||
if let Some(ref url) = tunnel_url {
|
||||
println!(" 🌐 Public URL: {url}");
|
||||
}
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}/");
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}{pfx}/");
|
||||
if let Some(code) = pairing.pairing_code() {
|
||||
println!();
|
||||
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
|
||||
println!(" ┌──────────────┐");
|
||||
println!(" │ {code} │");
|
||||
println!(" └──────────────┘");
|
||||
println!();
|
||||
println!(" Send: POST {pfx}/pair with header X-Pairing-Code: {code}");
|
||||
} else if pairing.require_pairing() {
|
||||
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
|
||||
println!(" To pair a new device: zeroclaw gateway get-paircode --new");
|
||||
@@ -711,29 +739,29 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||
println!();
|
||||
}
|
||||
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" POST {pfx}/pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST {pfx}/webhook — {{\"message\": \"your prompt\"}}");
|
||||
if whatsapp_channel.is_some() {
|
||||
println!(" GET /whatsapp — Meta webhook verification");
|
||||
println!(" POST /whatsapp — WhatsApp message webhook");
|
||||
println!(" GET {pfx}/whatsapp — Meta webhook verification");
|
||||
println!(" POST {pfx}/whatsapp — WhatsApp message webhook");
|
||||
}
|
||||
if linq_channel.is_some() {
|
||||
println!(" POST /linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
println!(" POST {pfx}/linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
}
|
||||
if wati_channel.is_some() {
|
||||
println!(" GET /wati — WATI webhook verification");
|
||||
println!(" POST /wati — WATI message webhook");
|
||||
println!(" GET {pfx}/wati — WATI webhook verification");
|
||||
println!(" POST {pfx}/wati — WATI message webhook");
|
||||
}
|
||||
if nextcloud_talk_channel.is_some() {
|
||||
println!(" POST /nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
println!(" POST {pfx}/nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
}
|
||||
println!(" GET /api/* — REST API (bearer token required)");
|
||||
println!(" GET /ws/chat — WebSocket agent chat");
|
||||
println!(" GET {pfx}/api/* — REST API (bearer token required)");
|
||||
println!(" GET {pfx}/ws/chat — WebSocket agent chat");
|
||||
if config.nodes.enabled {
|
||||
println!(" GET /ws/nodes — WebSocket node discovery");
|
||||
println!(" GET {pfx}/ws/nodes — WebSocket node discovery");
|
||||
}
|
||||
println!(" GET /health — health check");
|
||||
println!(" GET /metrics — Prometheus metrics");
|
||||
println!(" GET {pfx}/health — health check");
|
||||
println!(" GET {pfx}/metrics — Prometheus metrics");
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
crate::health::mark_component_ok("gateway");
|
||||
@@ -790,6 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
nextcloud_talk: nextcloud_talk_channel,
|
||||
nextcloud_talk_webhook_secret,
|
||||
wati: wati_channel,
|
||||
gmail_push: gmail_push_channel,
|
||||
observer: broadcast_observer,
|
||||
tools_registry,
|
||||
cost_tracker,
|
||||
@@ -799,6 +828,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
session_backend,
|
||||
device_registry,
|
||||
pending_pairings,
|
||||
path_prefix: path_prefix.unwrap_or("").to_string(),
|
||||
canvas_store,
|
||||
};
|
||||
|
||||
// Config PUT needs larger body limit (1MB)
|
||||
@@ -807,7 +838,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.layer(RequestBodyLimitLayer::new(1_048_576));
|
||||
|
||||
// Build router with middleware
|
||||
let app = Router::new()
|
||||
let inner = Router::new()
|
||||
// ── Admin routes (for CLI management) ──
|
||||
.route("/admin/shutdown", post(handle_admin_shutdown))
|
||||
.route("/admin/paircode", get(handle_admin_paircode))
|
||||
@@ -823,6 +854,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/wati", get(handle_wati_verify))
|
||||
.route("/wati", post(handle_wati_webhook))
|
||||
.route("/nextcloud-talk", post(handle_nextcloud_talk_webhook))
|
||||
.route("/webhook/gmail", post(handle_gmail_push_webhook))
|
||||
// ── Web Dashboard API routes ──
|
||||
.route("/api/status", get(api::handle_api_status))
|
||||
.route("/api/config", get(api::handle_api_config_get))
|
||||
@@ -854,7 +886,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/api/cli-tools", get(api::handle_api_cli_tools))
|
||||
.route("/api/health", get(api::handle_api_health))
|
||||
.route("/api/sessions", get(api::handle_api_sessions_list))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete).put(api::handle_api_session_rename))
|
||||
// ── Pairing + Device management API ──
|
||||
.route("/api/pairing/initiate", post(api_pairing::initiate_pairing))
|
||||
.route("/api/pair", post(api_pairing::submit_pairing_enhanced))
|
||||
@@ -863,34 +895,61 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route(
|
||||
"/api/devices/{id}/token/rotate",
|
||||
post(api_pairing::rotate_token),
|
||||
)
|
||||
// ── Live Canvas (A2UI) routes ──
|
||||
.route("/api/canvas", get(canvas::handle_canvas_list))
|
||||
.route(
|
||||
"/api/canvas/{id}",
|
||||
get(canvas::handle_canvas_get)
|
||||
.post(canvas::handle_canvas_post)
|
||||
.delete(canvas::handle_canvas_clear),
|
||||
)
|
||||
.route(
|
||||
"/api/canvas/{id}/history",
|
||||
get(canvas::handle_canvas_history),
|
||||
);
|
||||
|
||||
// ── Plugin management API (requires plugins-wasm feature) ──
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
let app = app.route(
|
||||
let inner = inner.route(
|
||||
"/api/plugins",
|
||||
get(api_plugins::plugin_routes::list_plugins),
|
||||
);
|
||||
|
||||
let app = app
|
||||
let inner = inner
|
||||
// ── SSE event stream ──
|
||||
.route("/api/events", get(sse::handle_sse_events))
|
||||
// ── WebSocket agent chat ──
|
||||
.route("/ws/chat", get(ws::handle_ws_chat))
|
||||
// ── WebSocket canvas updates ──
|
||||
.route("/ws/canvas/{id}", get(canvas::handle_ws_canvas))
|
||||
// ── WebSocket node discovery ──
|
||||
.route("/ws/nodes", get(nodes::handle_ws_nodes))
|
||||
// ── Static assets (web dashboard) ──
|
||||
.route("/_app/{*path}", get(static_files::handle_static))
|
||||
// ── Config PUT with larger body limit ──
|
||||
.merge(config_put_router)
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback))
|
||||
.with_state(state)
|
||||
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
|
||||
.layer(TimeoutLayer::with_status_code(
|
||||
StatusCode::REQUEST_TIMEOUT,
|
||||
Duration::from_secs(gateway_request_timeout_secs()),
|
||||
))
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback));
|
||||
));
|
||||
|
||||
// Nest under path prefix when configured (axum strips prefix before routing).
|
||||
// nest() at "/prefix" handles both "/prefix" and "/prefix/*" but not "/prefix/"
|
||||
// with a trailing slash, so we add a fallback redirect for that case.
|
||||
let app = if let Some(prefix) = path_prefix {
|
||||
let redirect_target = prefix.to_string();
|
||||
Router::new().nest(prefix, inner).route(
|
||||
&format!("{prefix}/"),
|
||||
get(|| async move { axum::response::Redirect::permanent(&redirect_target) }),
|
||||
)
|
||||
} else {
|
||||
inner
|
||||
};
|
||||
|
||||
// Run the server with graceful shutdown
|
||||
axum::serve(
|
||||
@@ -1788,6 +1847,74 @@ async fn handle_nextcloud_talk_webhook(
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
/// Maximum request body size for the Gmail webhook endpoint (1 MB).
|
||||
/// Google Pub/Sub messages are typically under 10 KB.
|
||||
const GMAIL_WEBHOOK_MAX_BODY: usize = 1024 * 1024;
|
||||
|
||||
/// POST /webhook/gmail — incoming Gmail Pub/Sub push notification
|
||||
async fn handle_gmail_push_webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> impl IntoResponse {
|
||||
let Some(ref gmail_push) = state.gmail_push else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Gmail push not configured"})),
|
||||
);
|
||||
};
|
||||
|
||||
// Enforce body size limit.
|
||||
if body.len() > GMAIL_WEBHOOK_MAX_BODY {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Json(serde_json::json!({"error": "Request body too large"})),
|
||||
);
|
||||
}
|
||||
|
||||
// Authenticate the webhook request using a shared secret.
|
||||
let secret = gmail_push.resolve_webhook_secret();
|
||||
if !secret.is_empty() {
|
||||
let provided = headers
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.unwrap_or("");
|
||||
|
||||
if provided != secret {
|
||||
tracing::warn!("Gmail push webhook: unauthorized request");
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"error": "Unauthorized"})),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let body_str = String::from_utf8_lossy(&body);
|
||||
let envelope: crate::channels::gmail_push::PubSubEnvelope =
|
||||
match serde_json::from_str(&body_str) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!("Gmail push webhook: invalid payload: {e}");
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Invalid Pub/Sub envelope"})),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Process the notification asynchronously (non-blocking for the webhook response)
|
||||
let channel = Arc::clone(gmail_push);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = channel.handle_notification(&envelope).await {
|
||||
tracing::error!("Gmail push notification processing failed: {e:#}");
|
||||
}
|
||||
});
|
||||
|
||||
// Acknowledge immediately — Google Pub/Sub requires a 2xx within ~10s
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// ADMIN HANDLERS (for CLI management)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -1976,15 +2103,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_metrics(State(state)).await.into_response();
|
||||
@@ -2031,15 +2161,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer,
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_metrics(State(state)).await.into_response();
|
||||
@@ -2415,15 +2548,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2484,15 +2620,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
@@ -2565,15 +2704,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_webhook(
|
||||
@@ -2618,15 +2760,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2676,15 +2821,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2739,15 +2887,18 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = Box::pin(handle_nextcloud_talk_webhook(
|
||||
@@ -2798,15 +2949,18 @@ mod tests {
|
||||
nextcloud_talk: Some(channel),
|
||||
nextcloud_talk_webhook_secret: Some(Arc::from(secret)),
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
//! Uses `rust-embed` to bundle the `web/dist/` directory into the binary at compile time.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{header, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rust_embed::Embed;
|
||||
|
||||
use super::AppState;
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "web/dist/"]
|
||||
struct WebAssets;
|
||||
@@ -23,16 +26,41 @@ pub async fn handle_static(uri: Uri) -> Response {
|
||||
serve_embedded_file(path)
|
||||
}
|
||||
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request
|
||||
pub async fn handle_spa_fallback() -> Response {
|
||||
if WebAssets::get("index.html").is_none() {
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request.
|
||||
/// Injects `window.__ZEROCLAW_BASE__` so the frontend knows the path prefix.
|
||||
pub async fn handle_spa_fallback(State(state): State<AppState>) -> Response {
|
||||
let Some(content) = WebAssets::get("index.html") else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"Web dashboard not available. Build it with: cd web && npm ci && npm run build",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
serve_embedded_file("index.html")
|
||||
};
|
||||
|
||||
let html = String::from_utf8_lossy(&content.data);
|
||||
|
||||
// Inject path prefix for the SPA and rewrite asset paths in the HTML
|
||||
let html = if state.path_prefix.is_empty() {
|
||||
html.into_owned()
|
||||
} else {
|
||||
let pfx = &state.path_prefix;
|
||||
// JSON-encode the prefix to safely embed in a <script> block
|
||||
let json_pfx = serde_json::to_string(pfx).unwrap_or_else(|_| "\"\"".to_string());
|
||||
let script = format!("<script>window.__ZEROCLAW_BASE__={json_pfx};</script>");
|
||||
// Rewrite absolute /_app/ references so the browser requests {prefix}/_app/...
|
||||
html.replace("/_app/", &format!("{pfx}/_app/"))
|
||||
.replace("<head>", &format!("<head>{script}"))
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, "text/html; charset=utf-8".to_string()),
|
||||
(header::CACHE_CONTROL, "no-cache".to_string()),
|
||||
],
|
||||
html,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn serve_embedded_file(path: &str) -> Response {
|
||||
|
||||
+34
-3
@@ -1,13 +1,21 @@
|
||||
//! WebSocket agent chat handler.
|
||||
//!
|
||||
//! Connect: `ws://host:port/ws/chat?session_id=ID&name=My+Session`
|
||||
//!
|
||||
//! Protocol:
|
||||
//! ```text
|
||||
//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
|
||||
//! Client -> Server: {"type":"message","content":"Hello"}
|
||||
//! Server -> Client: {"type":"chunk","content":"Hi! "}
|
||||
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
|
||||
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
|
||||
//! Server -> Client: {"type":"done","full_response":"..."}
|
||||
//! ```
|
||||
//!
|
||||
//! Query params:
|
||||
//! - `session_id` — resume or create a session (default: new UUID)
|
||||
//! - `name` — optional human-readable label for the session
|
||||
//! - `token` — bearer auth token (alternative to Authorization header)
|
||||
|
||||
use super::AppState;
|
||||
use axum::{
|
||||
@@ -53,6 +61,8 @@ const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
|
||||
pub struct WsQuery {
|
||||
pub token: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
/// Optional human-readable name for the session.
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a bearer token from WebSocket-compatible sources.
|
||||
@@ -134,14 +144,20 @@ pub async fn handle_ws_chat(
|
||||
};
|
||||
|
||||
let session_id = params.session_id;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
|
||||
let session_name = params.name;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Gateway session key prefix to avoid collisions with channel sessions.
|
||||
const GW_SESSION_PREFIX: &str = "gw_";
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<String>) {
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: AppState,
|
||||
session_id: Option<String>,
|
||||
session_name: Option<String>,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Resolve session ID: use provided or generate a new UUID
|
||||
@@ -163,6 +179,7 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
// Hydrate agent from persisted session (if available)
|
||||
let mut resumed = false;
|
||||
let mut message_count: usize = 0;
|
||||
let mut effective_name: Option<String> = None;
|
||||
if let Some(ref backend) = state.session_backend {
|
||||
let messages = backend.load(&session_key);
|
||||
if !messages.is_empty() {
|
||||
@@ -170,15 +187,29 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
agent.seed_history(&messages);
|
||||
resumed = true;
|
||||
}
|
||||
// Set session name if provided (non-empty) on connect
|
||||
if let Some(ref name) = session_name {
|
||||
if !name.is_empty() {
|
||||
let _ = backend.set_session_name(&session_key, name);
|
||||
effective_name = Some(name.clone());
|
||||
}
|
||||
}
|
||||
// If no name was provided via query param, load the stored name
|
||||
if effective_name.is_none() {
|
||||
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Send session_start message to client
|
||||
let session_start = serde_json::json!({
|
||||
let mut session_start = serde_json::json!({
|
||||
"type": "session_start",
|
||||
"session_id": session_id,
|
||||
"resumed": resumed,
|
||||
"message_count": message_count,
|
||||
});
|
||||
if let Some(ref name) = effective_name {
|
||||
session_start["name"] = serde_json::Value::String(name.clone());
|
||||
}
|
||||
let _ = sender
|
||||
.send(Message::Text(session_start.to_string().into()))
|
||||
.await;
|
||||
|
||||
@@ -407,7 +407,11 @@ mod tests {
|
||||
// Simpler: write a temp script.
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let script_path = dir.path().join("tool.sh");
|
||||
std::fs::write(&script_path, format!("#!/bin/sh\necho '{}'\n", result_json)).unwrap();
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
format!("#!/bin/sh\ncat > /dev/null\necho '{}'\n", result_json),
|
||||
)
|
||||
.unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
@@ -891,6 +891,7 @@ mod tests {
|
||||
device_id: None,
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec![],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
let entries = all_integrations();
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
//! Audit trail for memory operations.
|
||||
//!
|
||||
//! Provides a decorator `AuditedMemory<M>` that wraps any `Memory` backend
|
||||
//! and logs all operations to a `memory_audit` table. Opt-in via
|
||||
//! `[memory] audit_enabled = true`.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Audit log entry operations.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum AuditOp {
|
||||
Store,
|
||||
Recall,
|
||||
Get,
|
||||
List,
|
||||
Forget,
|
||||
StoreProcedural,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AuditOp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Store => write!(f, "store"),
|
||||
Self::Recall => write!(f, "recall"),
|
||||
Self::Get => write!(f, "get"),
|
||||
Self::List => write!(f, "list"),
|
||||
Self::Forget => write!(f, "forget"),
|
||||
Self::StoreProcedural => write!(f, "store_procedural"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decorator that wraps a `Memory` backend with audit logging.
|
||||
pub struct AuditedMemory<M: Memory> {
|
||||
inner: M,
|
||||
audit_conn: Arc<Mutex<Connection>>,
|
||||
#[allow(dead_code)]
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl<M: Memory> AuditedMemory<M> {
|
||||
pub fn new(inner: M, workspace_dir: &Path) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join("audit.db");
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(&db_path)?;
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
CREATE TABLE IF NOT EXISTS memory_audit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
operation TEXT NOT NULL,
|
||||
key TEXT,
|
||||
namespace TEXT,
|
||||
session_id TEXT,
|
||||
timestamp TEXT NOT NULL,
|
||||
metadata TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON memory_audit(timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_operation ON memory_audit(operation);",
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
audit_conn: Arc::new(Mutex::new(conn)),
|
||||
db_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn log_audit(
|
||||
&self,
|
||||
op: AuditOp,
|
||||
key: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
session_id: Option<&str>,
|
||||
metadata: Option<&str>,
|
||||
) {
|
||||
let conn = self.audit_conn.lock();
|
||||
let now = Local::now().to_rfc3339();
|
||||
let op_str = op.to_string();
|
||||
let _ = conn.execute(
|
||||
"INSERT INTO memory_audit (operation, key, namespace, session_id, timestamp, metadata)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![op_str, key, namespace, session_id, now, metadata],
|
||||
);
|
||||
}
|
||||
|
||||
/// Prune audit entries older than the given number of days.
|
||||
pub fn prune_older_than(&self, retention_days: u32) -> anyhow::Result<u64> {
|
||||
let conn = self.audit_conn.lock();
|
||||
let cutoff =
|
||||
(Local::now() - chrono::Duration::days(i64::from(retention_days))).to_rfc3339();
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memory_audit WHERE timestamp < ?1",
|
||||
params![cutoff],
|
||||
)?;
|
||||
Ok(u64::try_from(affected).unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Count total audit entries.
|
||||
pub fn audit_count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self.audit_conn.lock();
|
||||
let count: i64 =
|
||||
conn.query_row("SELECT COUNT(*) FROM memory_audit", [], |row| row.get(0))?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(count as usize)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<M: Memory> Memory for AuditedMemory<M> {
|
||||
fn name(&self) -> &str {
|
||||
self.inner.name()
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.log_audit(AuditOp::Store, Some(key), None, session_id, None);
|
||||
self.inner.store(key, content, category, session_id).await
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.log_audit(
|
||||
AuditOp::Recall,
|
||||
None,
|
||||
None,
|
||||
session_id,
|
||||
Some(&format!("query={query}")),
|
||||
);
|
||||
self.inner
|
||||
.recall(query, limit, session_id, since, until)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
self.log_audit(AuditOp::Get, Some(key), None, None, None);
|
||||
self.inner.get(key).await
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.log_audit(AuditOp::List, None, None, session_id, None);
|
||||
self.inner.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
self.log_audit(AuditOp::Forget, Some(key), None, None, None);
|
||||
self.inner.forget(key).await
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
self.inner.count().await
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.inner.health_check().await
|
||||
}
|
||||
|
||||
async fn store_procedural(
|
||||
&self,
|
||||
messages: &[ProceduralMessage],
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.log_audit(
|
||||
AuditOp::StoreProcedural,
|
||||
None,
|
||||
None,
|
||||
session_id,
|
||||
Some(&format!("messages={}", messages.len())),
|
||||
);
|
||||
self.inner.store_procedural(messages, session_id).await
|
||||
}
|
||||
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.log_audit(
|
||||
AuditOp::Recall,
|
||||
None,
|
||||
Some(namespace),
|
||||
session_id,
|
||||
Some(&format!("query={query}")),
|
||||
);
|
||||
self.inner
|
||||
.recall_namespaced(namespace, query, limit, session_id, since, until)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.log_audit(AuditOp::Store, Some(key), namespace, session_id, None);
|
||||
self.inner
|
||||
.store_with_metadata(key, content, category, session_id, namespace, importance)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::NoneMemory;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn audited_memory_logs_store_operation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let inner = NoneMemory::new();
|
||||
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
|
||||
|
||||
audited
|
||||
.store("test_key", "test_value", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(audited.audit_count().unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn audited_memory_logs_recall_operation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let inner = NoneMemory::new();
|
||||
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
|
||||
|
||||
let _ = audited.recall("query", 10, None, None, None).await;
|
||||
|
||||
assert_eq!(audited.audit_count().unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn audited_memory_prune_works() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let inner = NoneMemory::new();
|
||||
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
|
||||
|
||||
audited
|
||||
.store("k1", "v1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Pruning with 0 days should remove entries
|
||||
let pruned = audited.prune_older_than(0).unwrap();
|
||||
// Entry was just created, so 0-day retention should remove it
|
||||
// Pruning should succeed (pruned is usize, always >= 0)
|
||||
let _ = pruned;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn audited_memory_delegates_correctly() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let inner = NoneMemory::new();
|
||||
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
|
||||
|
||||
assert_eq!(audited.name(), "none");
|
||||
assert!(audited.health_check().await);
|
||||
assert_eq!(audited.count().await.unwrap(), 0);
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ pub enum MemoryBackendKind {
|
||||
Lucid,
|
||||
Postgres,
|
||||
Qdrant,
|
||||
Mem0,
|
||||
Markdown,
|
||||
None,
|
||||
Unknown,
|
||||
@@ -66,15 +65,6 @@ const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const MEM0_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "mem0",
|
||||
label: "Mem0 (OpenMemory) — semantic memory with LLM fact extraction via [memory.mem0]",
|
||||
auto_save_default: true,
|
||||
uses_sqlite_hygiene: false,
|
||||
sqlite_based: false,
|
||||
optional_dependency: true,
|
||||
};
|
||||
|
||||
const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "none",
|
||||
label: "None — disable persistent memory",
|
||||
@@ -114,7 +104,6 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
|
||||
"lucid" => MemoryBackendKind::Lucid,
|
||||
"postgres" => MemoryBackendKind::Postgres,
|
||||
"qdrant" => MemoryBackendKind::Qdrant,
|
||||
"mem0" | "openmemory" => MemoryBackendKind::Mem0,
|
||||
"markdown" => MemoryBackendKind::Markdown,
|
||||
"none" => MemoryBackendKind::None,
|
||||
_ => MemoryBackendKind::Unknown,
|
||||
@@ -127,7 +116,6 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
|
||||
MemoryBackendKind::Lucid => LUCID_PROFILE,
|
||||
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
|
||||
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
|
||||
MemoryBackendKind::Mem0 => MEM0_PROFILE,
|
||||
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
|
||||
MemoryBackendKind::None => NONE_PROFILE,
|
||||
MemoryBackendKind::Unknown => CUSTOM_PROFILE,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,173 @@
|
||||
//! Conflict resolution for memory entries.
|
||||
//!
|
||||
//! Before storing Core memories, performs a semantic similarity check against
|
||||
//! existing entries. If cosine similarity exceeds a threshold but content
|
||||
//! differs, the old entry is marked as superseded.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
|
||||
/// Check for conflicting memories and mark old ones as superseded.
|
||||
///
|
||||
/// Returns the list of entry IDs that were superseded.
|
||||
pub async fn check_and_resolve_conflicts(
|
||||
memory: &dyn Memory,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: &MemoryCategory,
|
||||
threshold: f64,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
// Only check conflicts for Core memories
|
||||
if !matches!(category, MemoryCategory::Core) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Search for similar existing entries
|
||||
let candidates = memory.recall(content, 10, None, None, None).await?;
|
||||
|
||||
let mut superseded = Vec::new();
|
||||
for candidate in &candidates {
|
||||
if candidate.key == key {
|
||||
continue; // Same key = update, not conflict
|
||||
}
|
||||
if !matches!(candidate.category, MemoryCategory::Core) {
|
||||
continue;
|
||||
}
|
||||
if let Some(score) = candidate.score {
|
||||
if score > threshold && candidate.content != content {
|
||||
superseded.push(candidate.id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(superseded)
|
||||
}
|
||||
|
||||
/// Mark entries as superseded in SQLite by setting their `superseded_by` column.
|
||||
pub fn mark_superseded(
|
||||
conn: &rusqlite::Connection,
|
||||
superseded_ids: &[String],
|
||||
new_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if superseded_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for id in superseded_ids {
|
||||
conn.execute(
|
||||
"UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
|
||||
rusqlite::params![new_id, id],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Simple text-based conflict detection without embeddings.
|
||||
///
|
||||
/// Uses token overlap (Jaccard similarity) as a fast approximation
|
||||
/// when vector embeddings are unavailable.
|
||||
pub fn jaccard_similarity(a: &str, b: &str) -> f64 {
|
||||
let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
|
||||
let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();
|
||||
|
||||
if words_a.is_empty() && words_b.is_empty() {
|
||||
return 1.0;
|
||||
}
|
||||
if words_a.is_empty() || words_b.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let intersection = words_a.intersection(&words_b).count();
|
||||
let union = words_a.union(&words_b).count();
|
||||
|
||||
if union == 0 {
|
||||
0.0
|
||||
} else {
|
||||
intersection as f64 / union as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Find potentially conflicting entries using text similarity when embeddings
|
||||
/// are not available. Returns entries above the threshold.
|
||||
pub fn find_text_conflicts(
|
||||
entries: &[MemoryEntry],
|
||||
new_content: &str,
|
||||
threshold: f64,
|
||||
) -> Vec<String> {
|
||||
entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
matches!(e.category, MemoryCategory::Core)
|
||||
&& e.superseded_by.is_none()
|
||||
&& jaccard_similarity(&e.content, new_content) > threshold
|
||||
&& e.content != new_content
|
||||
})
|
||||
.map(|e| e.id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn jaccard_identical_strings() {
|
||||
let sim = jaccard_similarity("hello world", "hello world");
|
||||
assert!((sim - 1.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jaccard_disjoint_strings() {
|
||||
let sim = jaccard_similarity("hello world", "foo bar");
|
||||
assert!(sim.abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jaccard_partial_overlap() {
|
||||
let sim = jaccard_similarity("the quick brown fox", "the slow brown dog");
|
||||
// overlap: "the", "brown" = 2; union: "the", "quick", "brown", "fox", "slow", "dog" = 6
|
||||
assert!((sim - 2.0 / 6.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jaccard_empty_strings() {
|
||||
assert!((jaccard_similarity("", "") - 1.0).abs() < f64::EPSILON);
|
||||
assert!(jaccard_similarity("hello", "").abs() < f64::EPSILON);
|
||||
assert!(jaccard_similarity("", "hello").abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_text_conflicts_filters_correctly() {
|
||||
let entries = vec![
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "pref".into(),
|
||||
content: "User prefers Rust for systems work".into(),
|
||||
category: MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: Some(0.7),
|
||||
superseded_by: None,
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
key: "daily1".into(),
|
||||
content: "User prefers Rust for systems work".into(),
|
||||
category: MemoryCategory::Daily,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: Some(0.3),
|
||||
superseded_by: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Only Core entries should be flagged
|
||||
let conflicts = find_text_conflicts(&entries, "User now prefers Go for systems work", 0.3);
|
||||
assert_eq!(conflicts.len(), 1);
|
||||
assert_eq!(conflicts[0], "1");
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@
|
||||
//! This two-phase approach replaces the naive raw-message auto-save with
|
||||
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
|
||||
|
||||
use crate::memory::conflict;
|
||||
use crate::memory::importance;
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
@@ -78,8 +80,33 @@ pub async fn consolidate_turn(
|
||||
if let Some(ref update) = result.memory_update {
|
||||
if !update.trim().is_empty() {
|
||||
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
|
||||
|
||||
// Compute importance score heuristically.
|
||||
let imp = importance::compute_importance(update, &MemoryCategory::Core);
|
||||
|
||||
// Check for conflicts with existing Core memories.
|
||||
if let Err(e) = conflict::check_and_resolve_conflicts(
|
||||
memory,
|
||||
&mem_key,
|
||||
update,
|
||||
&MemoryCategory::Core,
|
||||
0.85,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::debug!("conflict check skipped: {e}");
|
||||
}
|
||||
|
||||
// Store with importance metadata.
|
||||
memory
|
||||
.store(&mem_key, update, MemoryCategory::Core, None)
|
||||
.store_with_metadata(
|
||||
&mem_key,
|
||||
update,
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
None,
|
||||
Some(imp),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
use super::traits::{MemoryCategory, MemoryEntry};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Default half-life in days for time-decay scoring.
|
||||
/// After this many days, a non-Core memory's score drops to 50%.
|
||||
pub const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Apply exponential time decay to memory entry scores.
|
||||
///
|
||||
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
|
||||
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
|
||||
/// - Entries without a score (`None`) are left unchanged.
|
||||
///
|
||||
/// Decay formula: `score * 2^(-age_days / half_life_days)`
|
||||
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
|
||||
let half_life = if half_life_days <= 0.0 {
|
||||
DEFAULT_HALF_LIFE_DAYS
|
||||
} else {
|
||||
half_life_days
|
||||
};
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
// Core memories are evergreen — never decay
|
||||
if entry.category == MemoryCategory::Core {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = match entry.score {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
Ok(dt) => dt.with_timezone(&Utc),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let age_days = now.signed_duration_since(ts).num_seconds().max(0) as f64 / 86_400.0;
|
||||
|
||||
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
|
||||
entry.score = Some(score * decay_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "test".into(),
|
||||
content: "value".into(),
|
||||
category,
|
||||
timestamp: timestamp.into(),
|
||||
session_id: None,
|
||||
score,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn recent_rfc3339() -> String {
|
||||
Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
fn days_ago_rfc3339(days: i64) -> String {
|
||||
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn core_memories_are_never_decayed() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Core,
|
||||
Some(0.9),
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recent_entry_score_barely_changes() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.8),
|
||||
&recent_rfc3339(),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.8).abs() < 0.01,
|
||||
"recent entry should barely decay, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_half_life_halves_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(7),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.5).abs() < 0.05,
|
||||
"score after one half-life should be ~0.5, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_half_lives_quarters_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(14),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.25).abs() < 0.05,
|
||||
"score after two half-lives should be ~0.25, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_score_entry_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_timestamp_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.9),
|
||||
"not-a-date",
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
}
|
||||
+42
-4
@@ -1,4 +1,5 @@
|
||||
use crate::config::MemoryConfig;
|
||||
use crate::memory::policy::PolicyEnforcer;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
|
||||
use rusqlite::{params, Connection};
|
||||
@@ -47,6 +48,13 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Use policy engine for per-category retention overrides.
|
||||
let enforcer = PolicyEnforcer::new(&config.policy);
|
||||
let conversation_retention = enforcer.retention_days_for_category(
|
||||
&crate::memory::traits::MemoryCategory::Conversation,
|
||||
config.conversation_retention_days,
|
||||
);
|
||||
|
||||
let report = HygieneReport {
|
||||
archived_memory_files: archive_daily_memory_files(
|
||||
workspace_dir,
|
||||
@@ -55,12 +63,16 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
|
||||
archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
|
||||
purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
|
||||
purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
|
||||
pruned_conversation_rows: prune_conversation_rows(
|
||||
workspace_dir,
|
||||
config.conversation_retention_days,
|
||||
)?,
|
||||
pruned_conversation_rows: prune_conversation_rows(workspace_dir, conversation_retention)?,
|
||||
};
|
||||
|
||||
// Prune audit entries if audit is enabled.
|
||||
if config.audit_enabled {
|
||||
if let Err(e) = prune_audit_entries(workspace_dir, config.audit_retention_days) {
|
||||
tracing::debug!("audit pruning skipped: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
write_state(workspace_dir, &report)?;
|
||||
|
||||
if report.total_actions() > 0 {
|
||||
@@ -318,6 +330,32 @@ fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<
|
||||
Ok(u64::try_from(affected).unwrap_or(0))
|
||||
}
|
||||
|
||||
fn prune_audit_entries(workspace_dir: &Path, retention_days: u32) -> Result<()> {
|
||||
if retention_days == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let db_path = workspace_dir.join("memory").join("audit.db");
|
||||
if !db_path.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)?;
|
||||
conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
|
||||
let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();
|
||||
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memory_audit WHERE timestamp < ?1",
|
||||
params![cutoff],
|
||||
)?;
|
||||
|
||||
if affected > 0 {
|
||||
tracing::debug!("pruned {affected} audit entries older than {retention_days} days");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
|
||||
let stem = filename.strip_suffix(".md")?;
|
||||
let date_part = stem.split('_').next().unwrap_or(stem);
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
//! Heuristic importance scorer for non-LLM paths.
|
||||
//!
|
||||
//! Assigns importance scores (0.0–1.0) based on memory category and keyword
|
||||
//! signals. Used when LLM-based consolidation is unavailable or as a fast
|
||||
//! first-pass scorer.
|
||||
|
||||
use super::traits::MemoryCategory;
|
||||
|
||||
/// Base importance by category.
|
||||
fn category_base_score(category: &MemoryCategory) -> f64 {
|
||||
match category {
|
||||
MemoryCategory::Core => 0.7,
|
||||
MemoryCategory::Daily => 0.3,
|
||||
MemoryCategory::Conversation => 0.2,
|
||||
MemoryCategory::Custom(_) => 0.4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyword boost: if the content contains high-signal keywords, bump importance.
|
||||
fn keyword_boost(content: &str) -> f64 {
|
||||
const HIGH_SIGNAL_KEYWORDS: &[&str] = &[
|
||||
"decision",
|
||||
"always",
|
||||
"never",
|
||||
"important",
|
||||
"critical",
|
||||
"must",
|
||||
"requirement",
|
||||
"policy",
|
||||
"rule",
|
||||
"principle",
|
||||
];
|
||||
|
||||
let lowered = content.to_ascii_lowercase();
|
||||
let matches = HIGH_SIGNAL_KEYWORDS
|
||||
.iter()
|
||||
.filter(|kw| lowered.contains(**kw))
|
||||
.count();
|
||||
|
||||
// Cap at +0.2
|
||||
(matches as f64 * 0.1).min(0.2)
|
||||
}
|
||||
|
||||
/// Compute heuristic importance score for a memory entry.
|
||||
pub fn compute_importance(content: &str, category: &MemoryCategory) -> f64 {
|
||||
let base = category_base_score(category);
|
||||
let boost = keyword_boost(content);
|
||||
(base + boost).min(1.0)
|
||||
}
|
||||
|
||||
/// Compute final retrieval score incorporating importance and recency.
|
||||
///
|
||||
/// `hybrid_score`: raw retrieval score from FTS/vector (0.0–1.0)
|
||||
/// `importance`: importance score (0.0–1.0)
|
||||
/// `recency_decay`: recency factor (0.0–1.0, 1.0 = very recent)
|
||||
pub fn weighted_final_score(hybrid_score: f64, importance: f64, recency_decay: f64) -> f64 {
|
||||
hybrid_score * 0.7 + importance * 0.2 + recency_decay * 0.1
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn core_category_has_high_base_score() {
|
||||
let score = compute_importance("some fact", &MemoryCategory::Core);
|
||||
assert!((score - 0.7).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversation_category_has_low_base_score() {
|
||||
let score = compute_importance("chat message", &MemoryCategory::Conversation);
|
||||
assert!((score - 0.2).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keywords_boost_importance() {
|
||||
let score = compute_importance(
|
||||
"This is a critical decision that must always be followed",
|
||||
&MemoryCategory::Core,
|
||||
);
|
||||
// base 0.7 + boost for "critical", "decision", "must", "always" = 0.7 + 0.2 (capped) = 0.9
|
||||
assert!(score > 0.85);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boost_capped_at_point_two() {
|
||||
let score = compute_importance(
|
||||
"important critical decision rule policy must always never requirement principle",
|
||||
&MemoryCategory::Conversation,
|
||||
);
|
||||
// base 0.2 + max boost 0.2 = 0.4
|
||||
assert!((score - 0.4).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn weighted_final_score_formula() {
|
||||
let score = weighted_final_score(1.0, 1.0, 1.0);
|
||||
assert!((score - 1.0).abs() < f64::EPSILON);
|
||||
|
||||
let score = weighted_final_score(0.0, 0.0, 0.0);
|
||||
assert!(score.abs() < f64::EPSILON);
|
||||
|
||||
let score = weighted_final_score(0.5, 0.5, 0.5);
|
||||
assert!((score - 0.5).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
@@ -226,6 +226,9 @@ impl LucidMemory {
|
||||
timestamp: now.clone(),
|
||||
session_id: None,
|
||||
score: Some((1.0 - rank as f64 * 0.05).max(0.1)),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -91,6 +91,9 @@ impl MarkdownMemory {
|
||||
timestamp: filename.to_string(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
|
||||
@@ -1,635 +0,0 @@
|
||||
//! Mem0 (OpenMemory) memory backend.
|
||||
//!
|
||||
//! Connects to a self-hosted OpenMemory server via its REST API
|
||||
//! and implements the [`Memory`] trait for seamless integration with
|
||||
//! ZeroClaw's auto-save, auto-recall, and hygiene lifecycle.
|
||||
//!
|
||||
//! Deploy OpenMemory: `docker compose up` from the mem0 repo.
|
||||
//! Default endpoint: `http://localhost:8765`.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
|
||||
use crate::config::schema::Mem0Config;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Memory backend backed by a mem0 (OpenMemory) REST API.
|
||||
pub struct Mem0Memory {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
user_id: String,
|
||||
app_name: String,
|
||||
infer: bool,
|
||||
extraction_prompt: Option<String>,
|
||||
}
|
||||
|
||||
// ── mem0 API request/response types ────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AddMemoryRequest<'a> {
|
||||
user_id: &'a str,
|
||||
text: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
metadata: Option<Mem0Metadata<'a>>,
|
||||
infer: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
app: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
custom_instructions: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Mem0Metadata<'a> {
|
||||
key: &'a str,
|
||||
category: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AddProceduralRequest<'a> {
|
||||
user_id: &'a str,
|
||||
messages: &'a [ProceduralMessage],
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DeleteMemoriesRequest<'a> {
|
||||
memory_ids: Vec<&'a str>,
|
||||
user_id: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0MemoryItem {
|
||||
id: String,
|
||||
#[serde(alias = "content", alias = "text", default)]
|
||||
memory: String,
|
||||
#[serde(default)]
|
||||
created_at: Option<serde_json::Value>,
|
||||
#[serde(default, rename = "metadata_")]
|
||||
metadata: Option<Mem0ResponseMetadata>,
|
||||
#[serde(alias = "relevance_score", default)]
|
||||
score: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
struct Mem0ResponseMetadata {
|
||||
#[serde(default)]
|
||||
key: Option<String>,
|
||||
#[serde(default)]
|
||||
category: Option<String>,
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0ListResponse {
|
||||
#[serde(default)]
|
||||
items: Vec<Mem0MemoryItem>,
|
||||
#[serde(default)]
|
||||
total: usize,
|
||||
}
|
||||
|
||||
// ── Implementation ─────────────────────────────────────────────────
|
||||
|
||||
impl Mem0Memory {
|
||||
/// Create a new mem0 memory backend from config.
|
||||
pub fn new(config: &Mem0Config) -> anyhow::Result<Self> {
|
||||
let base_url = config.url.trim_end_matches('/').to_string();
|
||||
if base_url.is_empty() {
|
||||
anyhow::bail!("mem0 URL is empty; set [memory.mem0] url or MEM0_URL env var");
|
||||
}
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
user_id: config.user_id.clone(),
|
||||
app_name: config.app_name.clone(),
|
||||
infer: config.infer,
|
||||
extraction_prompt: config.extraction_prompt.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn api_url(&self, path: &str) -> String {
|
||||
format!("{}/api/v1{}", self.base_url, path)
|
||||
}
|
||||
|
||||
/// Use `session_id` as the effective mem0 `user_id` when provided,
|
||||
/// falling back to the configured default. This enables per-user
|
||||
/// and per-group memory scoping via the existing `Memory` trait.
|
||||
fn effective_user_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
|
||||
session_id
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or(&self.user_id)
|
||||
}
|
||||
|
||||
/// Recall memories with optional search filters.
|
||||
///
|
||||
/// - `created_after` / `created_before`: ISO 8601 timestamps for time-range filtering.
|
||||
/// - `metadata_filter`: arbitrary JSON object passed to the mem0 SDK `filters` param.
|
||||
pub async fn recall_filtered(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
created_after: Option<&str>,
|
||||
created_before: Option<&str>,
|
||||
metadata_filter: Option<&serde_json::Value>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let limit_str = limit.to_string();
|
||||
let mut params: Vec<(&str, &str)> = vec![
|
||||
("user_id", effective_user),
|
||||
("search_query", query),
|
||||
("size", &limit_str),
|
||||
];
|
||||
if let Some(after) = created_after {
|
||||
params.push(("created_after", after));
|
||||
}
|
||||
if let Some(before) = created_before {
|
||||
params.push(("created_before", before));
|
||||
}
|
||||
let meta_json;
|
||||
if let Some(mf) = metadata_filter {
|
||||
meta_json = serde_json::to_string(mf)?;
|
||||
params.push(("metadata_filter", &meta_json));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 recall failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
Ok(list.items.into_iter().map(|i| self.to_entry(i)).collect())
|
||||
}
|
||||
|
||||
fn to_entry(&self, item: Mem0MemoryItem) -> MemoryEntry {
|
||||
let meta = item.metadata.unwrap_or_default();
|
||||
let timestamp = match item.created_at {
|
||||
Some(serde_json::Value::Number(n)) => {
|
||||
// Unix timestamp → ISO 8601
|
||||
if let Some(ts) = n.as_i64() {
|
||||
chrono::DateTime::from_timestamp(ts, 0)
|
||||
.map(|dt| dt.to_rfc3339())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
Some(serde_json::Value::String(s)) => s,
|
||||
_ => String::new(),
|
||||
};
|
||||
|
||||
let category = match meta.category.as_deref() {
|
||||
Some("daily") => MemoryCategory::Daily,
|
||||
Some("conversation") => MemoryCategory::Conversation,
|
||||
Some(other) if other != "core" => MemoryCategory::Custom(other.to_string()),
|
||||
// "core" or None → default
|
||||
_ => MemoryCategory::Core,
|
||||
};
|
||||
|
||||
MemoryEntry {
|
||||
id: item.id,
|
||||
key: meta.key.unwrap_or_default(),
|
||||
content: item.memory,
|
||||
category,
|
||||
timestamp,
|
||||
session_id: meta.session_id,
|
||||
score: item.score,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a conversation trace as procedural memory.
|
||||
///
|
||||
/// Sends the message history (user input, tool calls, assistant response)
|
||||
/// to the mem0 procedural endpoint so that "how to" patterns can be
|
||||
/// extracted and stored for future recall.
|
||||
pub async fn store_procedural(
|
||||
&self,
|
||||
messages: &[ProceduralMessage],
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
if messages.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let body = AddProceduralRequest {
|
||||
user_id: effective_user,
|
||||
messages,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url("/memories/procedural"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 store_procedural failed ({status}): {text}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── History API types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0HistoryResponse {
|
||||
#[serde(default)]
|
||||
history: Vec<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
impl Mem0Memory {
|
||||
/// Retrieve the edit history (audit trail) for a specific memory by ID.
|
||||
pub async fn history(&self, memory_id: &str) -> anyhow::Result<String> {
|
||||
let url = self.api_url(&format!("/memories/{memory_id}/history"));
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 history failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let body: Mem0HistoryResponse = resp.json().await?;
|
||||
|
||||
if let Some(err) = body.error {
|
||||
anyhow::bail!("mem0 history error: {err}");
|
||||
}
|
||||
|
||||
if body.history.is_empty() {
|
||||
return Ok(format!("No history found for memory {memory_id}."));
|
||||
}
|
||||
|
||||
let mut lines = Vec::with_capacity(body.history.len() + 1);
|
||||
lines.push(format!("History for memory {memory_id}:"));
|
||||
|
||||
for (i, entry) in body.history.iter().enumerate() {
|
||||
let event = entry
|
||||
.get("event")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let old_memory = entry
|
||||
.get("old_memory")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-");
|
||||
let new_memory = entry
|
||||
.get("new_memory")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-");
|
||||
let timestamp = entry
|
||||
.get("created_at")
|
||||
.or_else(|| entry.get("timestamp"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
lines.push(format!(
|
||||
" {idx}. [{event}] at {timestamp}\n old: {old_memory}\n new: {new_memory}",
|
||||
idx = i + 1,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for Mem0Memory {
|
||||
fn name(&self) -> &str {
|
||||
"mem0"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cat_str = category.to_string();
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let body = AddMemoryRequest {
|
||||
user_id: effective_user,
|
||||
text: content,
|
||||
metadata: Some(Mem0Metadata {
|
||||
key,
|
||||
category: &cat_str,
|
||||
session_id,
|
||||
}),
|
||||
infer: self.infer,
|
||||
app: Some(&self.app_name),
|
||||
custom_instructions: self.extraction_prompt.as_deref(),
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url("/memories/"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 store failed ({status}): {text}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// mem0 handles filtering server-side; since/until are not yet
|
||||
// supported by the mem0 API, so we pass them through as no-ops.
|
||||
self.recall_filtered(query, limit, session_id, None, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
// mem0 doesn't have a get-by-key API, so we search by key in metadata
|
||||
let results = self.recall(key, 1, None, None, None).await?;
|
||||
Ok(results.into_iter().find(|e| e.key == key))
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[("user_id", effective_user), ("size", "100")])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 list failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
let entries: Vec<MemoryEntry> = list.items.into_iter().map(|i| self.to_entry(i)).collect();
|
||||
|
||||
// Client-side category filter (mem0 API doesn't filter by metadata)
|
||||
match category {
|
||||
Some(cat) => Ok(entries.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
None => Ok(entries),
|
||||
}
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
// Find the memory ID by key first
|
||||
let entry = self.get(key).await?;
|
||||
let entry = match entry {
|
||||
Some(e) => e,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let body = DeleteMemoriesRequest {
|
||||
memory_ids: vec![&entry.id],
|
||||
user_id: &self.user_id,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.delete(self.api_url("/memories/"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(resp.status().is_success())
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[
|
||||
("user_id", self.user_id.as_str()),
|
||||
("size", "1"),
|
||||
("page", "1"),
|
||||
])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 count failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
Ok(list.total)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[
|
||||
("user_id", self.user_id.as_str()),
|
||||
("size", "1"),
|
||||
("page", "1"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.is_ok_and(|r| r.status().is_success())
|
||||
}
|
||||
|
||||
async fn store_procedural(
|
||||
&self,
|
||||
messages: &[ProceduralMessage],
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Mem0Memory::store_procedural(self, messages, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> Mem0Config {
|
||||
Mem0Config {
|
||||
url: "http://localhost:8765".into(),
|
||||
user_id: "test-user".into(),
|
||||
app_name: "test-app".into(),
|
||||
infer: true,
|
||||
extraction_prompt: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_rejects_empty_url() {
|
||||
let config = Mem0Config {
|
||||
url: String::new(),
|
||||
..test_config()
|
||||
};
|
||||
assert!(Mem0Memory::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_trims_trailing_slash() {
|
||||
let config = Mem0Config {
|
||||
url: "http://localhost:8765/".into(),
|
||||
..test_config()
|
||||
};
|
||||
let mem = Mem0Memory::new(&config).unwrap();
|
||||
assert_eq!(mem.base_url, "http://localhost:8765");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_url_builds_correct_path() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
assert_eq!(
|
||||
mem.api_url("/memories/"),
|
||||
"http://localhost:8765/api/v1/memories/"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_maps_unix_timestamp() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-1".into(),
|
||||
memory: "hello".into(),
|
||||
created_at: Some(serde_json::json!(1_700_000_000)),
|
||||
metadata: Some(Mem0ResponseMetadata {
|
||||
key: Some("k1".into()),
|
||||
category: Some("core".into()),
|
||||
session_id: None,
|
||||
}),
|
||||
score: Some(0.95),
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.id, "id-1");
|
||||
assert_eq!(entry.key, "k1");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
assert!(!entry.timestamp.is_empty());
|
||||
assert_eq!(entry.score, Some(0.95));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_maps_string_timestamp() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-2".into(),
|
||||
memory: "world".into(),
|
||||
created_at: Some(serde_json::json!("2024-01-01T00:00:00Z")),
|
||||
metadata: None,
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.timestamp, "2024-01-01T00:00:00Z");
|
||||
assert_eq!(entry.category, MemoryCategory::Core); // default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_handles_missing_metadata() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-3".into(),
|
||||
memory: "bare".into(),
|
||||
created_at: None,
|
||||
metadata: None,
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.key, "");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
assert!(entry.timestamp.is_empty());
|
||||
assert_eq!(entry.score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_custom_category() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-4".into(),
|
||||
memory: "custom".into(),
|
||||
created_at: None,
|
||||
metadata: Some(Mem0ResponseMetadata {
|
||||
key: Some("k".into()),
|
||||
category: Some("project_notes".into()),
|
||||
session_id: Some("s1".into()),
|
||||
}),
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(
|
||||
entry.category,
|
||||
MemoryCategory::Custom("project_notes".into())
|
||||
);
|
||||
assert_eq!(entry.session_id.as_deref(), Some("s1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_returns_mem0() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
assert_eq!(mem.name(), "mem0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn procedural_request_serializes_messages() {
|
||||
let messages = vec![
|
||||
ProceduralMessage {
|
||||
role: "user".into(),
|
||||
content: "How do I deploy?".into(),
|
||||
name: None,
|
||||
},
|
||||
ProceduralMessage {
|
||||
role: "tool".into(),
|
||||
content: "deployment started".into(),
|
||||
name: Some("shell".into()),
|
||||
},
|
||||
ProceduralMessage {
|
||||
role: "assistant".into(),
|
||||
content: "Deployment complete.".into(),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
let req = AddProceduralRequest {
|
||||
user_id: "test-user",
|
||||
messages: &messages,
|
||||
metadata: None,
|
||||
};
|
||||
let json = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(json["user_id"], "test-user");
|
||||
let msgs = json["messages"].as_array().unwrap();
|
||||
assert_eq!(msgs.len(), 3);
|
||||
assert_eq!(msgs[0]["role"], "user");
|
||||
assert_eq!(msgs[1]["name"], "shell");
|
||||
// metadata should be absent when None
|
||||
assert!(json.get("metadata").is_none());
|
||||
}
|
||||
}
|
||||
+16
-27
@@ -1,24 +1,33 @@
|
||||
pub mod audit;
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod conflict;
|
||||
pub mod consolidation;
|
||||
pub mod decay;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod importance;
|
||||
pub mod knowledge_graph;
|
||||
pub mod lucid;
|
||||
pub mod markdown;
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
pub mod mem0;
|
||||
pub mod none;
|
||||
pub mod policy;
|
||||
#[cfg(feature = "memory-postgres")]
|
||||
pub mod postgres;
|
||||
pub mod qdrant;
|
||||
pub mod response_cache;
|
||||
pub mod retrieval;
|
||||
pub mod snapshot;
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
pub mod vector;
|
||||
|
||||
#[cfg(test)]
|
||||
mod battle_tests;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use audit::AuditedMemory;
|
||||
#[allow(unused_imports)]
|
||||
pub use backend::{
|
||||
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
|
||||
@@ -26,13 +35,15 @@ pub use backend::{
|
||||
};
|
||||
pub use lucid::LucidMemory;
|
||||
pub use markdown::MarkdownMemory;
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
pub use mem0::Mem0Memory;
|
||||
pub use none::NoneMemory;
|
||||
#[allow(unused_imports)]
|
||||
pub use policy::PolicyEnforcer;
|
||||
#[cfg(feature = "memory-postgres")]
|
||||
pub use postgres::PostgresMemory;
|
||||
pub use qdrant::QdrantMemory;
|
||||
pub use response_cache::ResponseCache;
|
||||
#[allow(unused_imports)]
|
||||
pub use retrieval::{RetrievalConfig, RetrievalPipeline};
|
||||
pub use sqlite::SqliteMemory;
|
||||
pub use traits::Memory;
|
||||
#[allow(unused_imports)]
|
||||
@@ -61,7 +72,7 @@ where
|
||||
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
|
||||
}
|
||||
MemoryBackendKind::Postgres => postgres_builder(),
|
||||
MemoryBackendKind::Qdrant | MemoryBackendKind::Mem0 | MemoryBackendKind::Markdown => {
|
||||
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
|
||||
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
|
||||
}
|
||||
MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())),
|
||||
@@ -340,28 +351,6 @@ pub fn create_memory_with_storage_and_routes(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
fn build_mem0_memory(config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
|
||||
let mem = Mem0Memory::new(&config.mem0)?;
|
||||
tracing::info!(
|
||||
"📦 Mem0 memory backend configured (url: {}, user: {})",
|
||||
config.mem0.url,
|
||||
config.mem0.user_id
|
||||
);
|
||||
Ok(Box::new(mem))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "memory-mem0"))]
|
||||
fn build_mem0_memory(_config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
|
||||
anyhow::bail!(
|
||||
"memory backend 'mem0' requested but this build was compiled without `memory-mem0`; rebuild with `--features memory-mem0`"
|
||||
);
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Mem0) {
|
||||
return build_mem0_memory(config);
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
|
||||
let url = config
|
||||
.qdrant
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
//! Policy engine for memory operations.
|
||||
//!
|
||||
//! Validates operations against configurable rules before they reach the
|
||||
//! backend. Enforces namespace quotas, category limits, read-only namespaces,
|
||||
//! and per-category retention rules.
|
||||
|
||||
use super::traits::MemoryCategory;
|
||||
use crate::config::MemoryPolicyConfig;
|
||||
|
||||
/// Policy enforcer that validates memory operations.
|
||||
pub struct PolicyEnforcer {
|
||||
config: MemoryPolicyConfig,
|
||||
}
|
||||
|
||||
impl PolicyEnforcer {
|
||||
pub fn new(config: &MemoryPolicyConfig) -> Self {
|
||||
Self {
|
||||
config: config.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a namespace is read-only.
|
||||
pub fn is_read_only(&self, namespace: &str) -> bool {
|
||||
self.config
|
||||
.read_only_namespaces
|
||||
.iter()
|
||||
.any(|ns| ns == namespace)
|
||||
}
|
||||
|
||||
/// Validate a store operation against policy rules.
|
||||
pub fn validate_store(
|
||||
&self,
|
||||
namespace: &str,
|
||||
_category: &MemoryCategory,
|
||||
) -> Result<(), PolicyViolation> {
|
||||
if self.is_read_only(namespace) {
|
||||
return Err(PolicyViolation::ReadOnlyNamespace(namespace.to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if adding an entry would exceed namespace limits.
|
||||
pub fn check_namespace_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
|
||||
if self.config.max_entries_per_namespace > 0
|
||||
&& current_count >= self.config.max_entries_per_namespace
|
||||
{
|
||||
return Err(PolicyViolation::NamespaceQuotaExceeded {
|
||||
max: self.config.max_entries_per_namespace,
|
||||
current: current_count,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if adding an entry would exceed category limits.
|
||||
pub fn check_category_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
|
||||
if self.config.max_entries_per_category > 0
|
||||
&& current_count >= self.config.max_entries_per_category
|
||||
{
|
||||
return Err(PolicyViolation::CategoryQuotaExceeded {
|
||||
max: self.config.max_entries_per_category,
|
||||
current: current_count,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the retention days for a specific category, falling back to the
|
||||
/// provided default if no per-category override exists.
|
||||
pub fn retention_days_for_category(&self, category: &MemoryCategory, default_days: u32) -> u32 {
|
||||
let key = category.to_string();
|
||||
self.config
|
||||
.retention_days_by_category
|
||||
.get(&key)
|
||||
.copied()
|
||||
.unwrap_or(default_days)
|
||||
}
|
||||
}
|
||||
|
||||
/// Policy violation errors.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PolicyViolation {
|
||||
ReadOnlyNamespace(String),
|
||||
NamespaceQuotaExceeded { max: usize, current: usize },
|
||||
CategoryQuotaExceeded { max: usize, current: usize },
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PolicyViolation {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ReadOnlyNamespace(ns) => write!(f, "namespace '{ns}' is read-only"),
|
||||
Self::NamespaceQuotaExceeded { max, current } => {
|
||||
write!(f, "namespace quota exceeded: {current}/{max} entries")
|
||||
}
|
||||
Self::CategoryQuotaExceeded { max, current } => {
|
||||
write!(f, "category quota exceeded: {current}/{max} entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PolicyViolation {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn empty_policy() -> MemoryPolicyConfig {
|
||||
MemoryPolicyConfig::default()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_policy_allows_everything() {
|
||||
let enforcer = PolicyEnforcer::new(&empty_policy());
|
||||
assert!(!enforcer.is_read_only("default"));
|
||||
assert!(enforcer
|
||||
.validate_store("default", &MemoryCategory::Core)
|
||||
.is_ok());
|
||||
assert!(enforcer.check_namespace_limit(100).is_ok());
|
||||
assert!(enforcer.check_category_limit(100).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_namespace_blocks_writes() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
read_only_namespaces: vec!["archive".into()],
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.is_read_only("archive"));
|
||||
assert!(!enforcer.is_read_only("default"));
|
||||
assert!(enforcer
|
||||
.validate_store("archive", &MemoryCategory::Core)
|
||||
.is_err());
|
||||
assert!(enforcer
|
||||
.validate_store("default", &MemoryCategory::Core)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn namespace_quota_enforced() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
max_entries_per_namespace: 10,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.check_namespace_limit(5).is_ok());
|
||||
assert!(enforcer.check_namespace_limit(10).is_err());
|
||||
assert!(enforcer.check_namespace_limit(15).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn category_quota_enforced() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
max_entries_per_category: 50,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.check_category_limit(25).is_ok());
|
||||
assert!(enforcer.check_category_limit(50).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_category_retention_overrides_default() {
|
||||
let mut retention = HashMap::new();
|
||||
retention.insert("core".into(), 365);
|
||||
retention.insert("conversation".into(), 7);
|
||||
|
||||
let policy = MemoryPolicyConfig {
|
||||
retention_days_by_category: retention,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Core, 30),
|
||||
365
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Conversation, 30),
|
||||
7
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Daily, 30),
|
||||
30
|
||||
);
|
||||
}
|
||||
}
|
||||
+12
-3
@@ -100,6 +100,8 @@ impl PostgresMemory {
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_content_fts ON {qualified_table} USING gin(to_tsvector('simple', content));
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_key_fts ON {qualified_table} USING gin(to_tsvector('simple', key));
|
||||
"
|
||||
))?;
|
||||
|
||||
@@ -135,6 +137,9 @@ impl PostgresMemory {
|
||||
timestamp: timestamp.to_rfc3339(),
|
||||
session_id: row.get(5),
|
||||
score: row.try_get(6).ok(),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -267,12 +272,16 @@ impl Memory for PostgresMemory {
|
||||
"
|
||||
SELECT id, key, content, category, created_at, session_id,
|
||||
(
|
||||
CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END +
|
||||
CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
|
||||
CASE WHEN to_tsvector('simple', key) @@ plainto_tsquery('simple', $1)
|
||||
THEN ts_rank_cd(to_tsvector('simple', key), plainto_tsquery('simple', $1)) * 2.0
|
||||
ELSE 0.0 END +
|
||||
CASE WHEN to_tsvector('simple', content) @@ plainto_tsquery('simple', $1)
|
||||
THEN ts_rank_cd(to_tsvector('simple', content), plainto_tsquery('simple', $1))
|
||||
ELSE 0.0 END
|
||||
) AS score
|
||||
FROM {qualified_table}
|
||||
WHERE ($2::TEXT IS NULL OR session_id = $2)
|
||||
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
|
||||
AND ($1 = '' OR to_tsvector('simple', key || ' ' || content) @@ plainto_tsquery('simple', $1))
|
||||
{time_filter}
|
||||
ORDER BY score DESC, updated_at DESC
|
||||
LIMIT $3
|
||||
|
||||
@@ -373,6 +373,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: Some(point.score),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -437,6 +440,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
});
|
||||
|
||||
@@ -514,6 +520,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
//! Multi-stage retrieval pipeline.
|
||||
//!
|
||||
//! Wraps a `Memory` trait object with staged retrieval:
|
||||
//! - **Stage 1 (Hot cache):** In-memory LRU of recent recall results.
|
||||
//! - **Stage 2 (FTS):** FTS5 keyword search with optional early-return.
|
||||
//! - **Stage 3 (Vector):** Vector similarity search + hybrid merge.
|
||||
//!
|
||||
//! Configurable via `[memory]` settings: `retrieval_stages`, `fts_early_return_score`.
|
||||
|
||||
use super::traits::{Memory, MemoryEntry};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// A cached recall result.
|
||||
struct CachedResult {
|
||||
entries: Vec<MemoryEntry>,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
/// Multi-stage retrieval pipeline configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
/// Ordered list of stages: "cache", "fts", "vector".
|
||||
pub stages: Vec<String>,
|
||||
/// FTS score above which to early-return without vector stage.
|
||||
pub fts_early_return_score: f64,
|
||||
/// Max entries in the hot cache.
|
||||
pub cache_max_entries: usize,
|
||||
/// TTL for cached results.
|
||||
pub cache_ttl: Duration,
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stages: vec!["cache".into(), "fts".into(), "vector".into()],
|
||||
fts_early_return_score: 0.85,
|
||||
cache_max_entries: 256,
|
||||
cache_ttl: Duration::from_secs(300),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-stage retrieval pipeline wrapping a `Memory` backend.
|
||||
pub struct RetrievalPipeline {
|
||||
memory: Arc<dyn Memory>,
|
||||
config: RetrievalConfig,
|
||||
hot_cache: Mutex<HashMap<String, CachedResult>>,
|
||||
}
|
||||
|
||||
impl RetrievalPipeline {
|
||||
pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
|
||||
Self {
|
||||
memory,
|
||||
config,
|
||||
hot_cache: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a cache key from query parameters.
|
||||
fn cache_key(
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
) -> String {
|
||||
format!(
|
||||
"{}:{}:{}:{}",
|
||||
query,
|
||||
limit,
|
||||
session_id.unwrap_or(""),
|
||||
namespace.unwrap_or("")
|
||||
)
|
||||
}
|
||||
|
||||
/// Check the hot cache for a previous result.
|
||||
fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
|
||||
let cache = self.hot_cache.lock();
|
||||
if let Some(cached) = cache.get(key) {
|
||||
if cached.created_at.elapsed() < self.config.cache_ttl {
|
||||
return Some(cached.entries.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Store a result in the hot cache with LRU eviction.
|
||||
fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
|
||||
let mut cache = self.hot_cache.lock();
|
||||
|
||||
// LRU eviction: remove oldest entries if at capacity
|
||||
if cache.len() >= self.config.cache_max_entries {
|
||||
let oldest_key = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.created_at)
|
||||
.map(|(k, _)| k.clone());
|
||||
if let Some(k) = oldest_key {
|
||||
cache.remove(&k);
|
||||
}
|
||||
}
|
||||
|
||||
cache.insert(
|
||||
key,
|
||||
CachedResult {
|
||||
entries,
|
||||
created_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Execute the multi-stage retrieval pipeline.
|
||||
pub async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let ck = Self::cache_key(query, limit, session_id, namespace);
|
||||
|
||||
for stage in &self.config.stages {
|
||||
match stage.as_str() {
|
||||
"cache" => {
|
||||
if let Some(cached) = self.check_cache(&ck) {
|
||||
tracing::debug!("retrieval pipeline: cache hit for '{query}'");
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
"fts" | "vector" => {
|
||||
// Both FTS and vector are handled by the backend's recall method
|
||||
// which already does hybrid merge. We delegate to it.
|
||||
let results = if let Some(ns) = namespace {
|
||||
self.memory
|
||||
.recall_namespaced(ns, query, limit, session_id, since, until)
|
||||
.await?
|
||||
} else {
|
||||
self.memory
|
||||
.recall(query, limit, session_id, since, until)
|
||||
.await?
|
||||
};
|
||||
|
||||
if !results.is_empty() {
|
||||
// Check for FTS early-return: if top score exceeds threshold
|
||||
// and we're in the FTS stage, we can skip further stages
|
||||
if stage == "fts" {
|
||||
if let Some(top_score) = results.first().and_then(|e| e.score) {
|
||||
if top_score >= self.config.fts_early_return_score {
|
||||
tracing::debug!(
|
||||
"retrieval pipeline: FTS early return (score={top_score:.3})"
|
||||
);
|
||||
self.store_in_cache(ck, results.clone());
|
||||
return Ok(results);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.store_in_cache(ck, results.clone());
|
||||
return Ok(results);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("retrieval pipeline: unknown stage '{other}', skipping");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No results from any stage
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Invalidate the hot cache (e.g. after a store operation).
|
||||
pub fn invalidate_cache(&self) {
|
||||
self.hot_cache.lock().clear();
|
||||
}
|
||||
|
||||
/// Get the number of entries in the hot cache.
|
||||
pub fn cache_size(&self) -> usize {
|
||||
self.hot_cache.lock().len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::NoneMemory;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_returns_empty_from_none_backend() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
|
||||
|
||||
let results = pipeline
|
||||
.recall("test", 10, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_cache_invalidation() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
|
||||
|
||||
// Force a cache entry
|
||||
let ck = RetrievalPipeline::cache_key("test", 10, None, None);
|
||||
pipeline.store_in_cache(ck, vec![]);
|
||||
|
||||
assert_eq!(pipeline.cache_size(), 1);
|
||||
pipeline.invalidate_cache();
|
||||
assert_eq!(pipeline.cache_size(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cache_key_includes_all_params() {
|
||||
let k1 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
|
||||
let k2 = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
|
||||
let k3 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));
|
||||
|
||||
assert_ne!(k1, k2);
|
||||
assert_ne!(k1, k3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_caches_results() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let config = RetrievalConfig {
|
||||
stages: vec!["cache".into()],
|
||||
..Default::default()
|
||||
};
|
||||
let pipeline = RetrievalPipeline::new(memory, config);
|
||||
|
||||
// First call: cache miss, no results
|
||||
let results = pipeline
|
||||
.recall("test", 10, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Manually insert a cache entry
|
||||
let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
|
||||
let fake_entry = MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "k".into(),
|
||||
content: "cached content".into(),
|
||||
category: crate::memory::MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
};
|
||||
pipeline.store_in_cache(ck, vec![fake_entry]);
|
||||
|
||||
// Cache hit
|
||||
let results = pipeline
|
||||
.recall("cached_query", 5, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].content, "cached content");
|
||||
}
|
||||
}
|
||||
+154
-23
@@ -46,6 +46,31 @@ impl SqliteMemory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Like `new`, but stores data in `{db_name}.db` instead of `brain.db`.
|
||||
pub fn new_named(workspace_dir: &Path, db_name: &str) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join(format!("{db_name}.db"));
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let conn = Self::open_connection(&db_path, None)?;
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
PRAGMA mmap_size = 8388608;
|
||||
PRAGMA cache_size = -2000;
|
||||
PRAGMA temp_store = MEMORY;",
|
||||
)?;
|
||||
Self::init_schema(&conn)?;
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
db_path,
|
||||
embedder: Arc::new(super::embeddings::NoopEmbedding),
|
||||
vector_weight: 0.7,
|
||||
keyword_weight: 0.3,
|
||||
cache_max: 10_000,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build SQLite memory with optional open timeout.
|
||||
///
|
||||
/// If `open_timeout_secs` is `Some(n)`, opening the database is limited to `n` seconds
|
||||
@@ -172,17 +197,35 @@ impl SqliteMemory {
|
||||
)?;
|
||||
|
||||
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||
let has_session_id: bool = conn
|
||||
let schema_sql: String = conn
|
||||
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||
.query_row([], |row| row.get::<_, String>(0))?
|
||||
.contains("session_id");
|
||||
if !has_session_id {
|
||||
.query_row([], |row| row.get::<_, String>(0))?;
|
||||
|
||||
if !schema_sql.contains("session_id") {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||
)?;
|
||||
}
|
||||
|
||||
// Migration: add namespace column
|
||||
if !schema_sql.contains("namespace") {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);",
|
||||
)?;
|
||||
}
|
||||
|
||||
// Migration: add importance column
|
||||
if !schema_sql.contains("importance") {
|
||||
conn.execute_batch("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;")?;
|
||||
}
|
||||
|
||||
// Migration: add superseded_by column
|
||||
if !schema_sql.contains("superseded_by") {
|
||||
conn.execute_batch("ALTER TABLE memories ADD COLUMN superseded_by TEXT;")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -221,8 +264,13 @@ impl SqliteMemory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Provide access to the connection for advanced queries (e.g. retrieval pipeline).
|
||||
pub fn connection(&self) -> &Arc<Mutex<Connection>> {
|
||||
&self.conn
|
||||
}
|
||||
|
||||
/// Get embedding from cache, or compute + cache it
|
||||
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
|
||||
pub async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
|
||||
if self.embedder.dimensions() == 0 {
|
||||
return Ok(None); // Noop embedder
|
||||
}
|
||||
@@ -285,7 +333,7 @@ impl SqliteMemory {
|
||||
}
|
||||
|
||||
/// FTS5 BM25 keyword search
|
||||
fn fts5_search(
|
||||
pub fn fts5_search(
|
||||
conn: &Connection,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
@@ -331,7 +379,7 @@ impl SqliteMemory {
|
||||
///
|
||||
/// Optional `category` and `session_id` filters reduce full-table scans
|
||||
/// when the caller already knows the scope of relevant memories.
|
||||
fn vector_search(
|
||||
pub fn vector_search(
|
||||
conn: &Connection,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
@@ -448,8 +496,8 @@ impl SqliteMemory {
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let mut sql =
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories \
|
||||
WHERE 1=1"
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories \
|
||||
WHERE superseded_by IS NULL AND 1=1"
|
||||
.to_string();
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
let mut idx = 1;
|
||||
@@ -485,6 +533,9 @@ impl SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -529,8 +580,8 @@ impl Memory for SqliteMemory {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 'default', 0.5)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
@@ -616,8 +667,8 @@ impl Memory for SqliteMemory {
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id \
|
||||
FROM memories WHERE id IN ({placeholders})"
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by \
|
||||
FROM memories WHERE superseded_by IS NULL AND id IN ({placeholders})"
|
||||
);
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
|
||||
@@ -634,17 +685,20 @@ impl Memory for SqliteMemory {
|
||||
row.get::<_, String>(3)?,
|
||||
row.get::<_, String>(4)?,
|
||||
row.get::<_, Option<String>>(5)?,
|
||||
row.get::<_, Option<String>>(6)?,
|
||||
row.get::<_, Option<f64>>(7)?,
|
||||
row.get::<_, Option<String>>(8)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut entry_map = std::collections::HashMap::new();
|
||||
for row in rows {
|
||||
let (id, key, content, cat, ts, sid) = row?;
|
||||
entry_map.insert(id, (key, content, cat, ts, sid));
|
||||
let (id, key, content, cat, ts, sid, ns, imp, sup) = row?;
|
||||
entry_map.insert(id, (key, content, cat, ts, sid, ns, imp, sup));
|
||||
}
|
||||
|
||||
for scored in &merged {
|
||||
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||
if let Some((key, content, cat, ts, sid, ns, imp, sup)) = entry_map.remove(&scored.id) {
|
||||
if let Some(s) = since_ref {
|
||||
if ts.as_str() < s {
|
||||
continue;
|
||||
@@ -663,6 +717,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: ts,
|
||||
session_id: sid,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
namespace: ns.unwrap_or_else(|| "default".into()),
|
||||
importance: imp,
|
||||
superseded_by: sup,
|
||||
};
|
||||
if let Some(filter_sid) = session_ref {
|
||||
if entry.session_id.as_deref() != Some(filter_sid) {
|
||||
@@ -702,8 +759,8 @@ impl Memory for SqliteMemory {
|
||||
param_idx += 1;
|
||||
}
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}{time_conditions}
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL AND ({where_clause}){time_conditions}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{param_idx}"
|
||||
);
|
||||
@@ -732,6 +789,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(1.0),
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
for row in rows {
|
||||
@@ -759,7 +819,7 @@ impl Memory for SqliteMemory {
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
@@ -771,6 +831,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -807,14 +870,17 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(ref cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL AND category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
@@ -828,8 +894,8 @@ impl Memory for SqliteMemory {
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC LIMIT ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?1",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
@@ -879,6 +945,71 @@ impl Memory for SqliteMemory {
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let entries = self
|
||||
.recall(query, limit * 2, session_id, since, until)
|
||||
.await?;
|
||||
let filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.namespace == namespace)
|
||||
.take(limit)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
let embedding_bytes = self
|
||||
.get_or_compute_embedding(content)
|
||||
.await?
|
||||
.map(|emb| vector::vec_to_bytes(&emb));
|
||||
|
||||
let conn = self.conn.clone();
|
||||
let key = key.to_string();
|
||||
let content = content.to_string();
|
||||
let sid = session_id.map(String::from);
|
||||
let ns = namespace.unwrap_or("default").to_string();
|
||||
let imp = importance.unwrap_or(0.5);
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let conn = conn.lock();
|
||||
let now = Local::now().to_rfc3339();
|
||||
let cat = Self::category_to_str(&category);
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
embedding = excluded.embedding,
|
||||
updated_at = excluded.updated_at,
|
||||
session_id = excluded.session_id,
|
||||
namespace = excluded.namespace,
|
||||
importance = excluded.importance",
|
||||
params![id, key, content, cat, embedding_bytes, now, now, sid, ns, imp],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+63
-2
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// A single message in a conversation trace for procedural memory.
|
||||
///
|
||||
/// Used to capture "how to" patterns from tool-calling turns so that
|
||||
/// backends that support procedural storage (e.g. mem0) can learn from them.
|
||||
/// backends that support procedural storage can learn from them.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProceduralMessage {
|
||||
pub role: String,
|
||||
@@ -23,6 +23,19 @@ pub struct MemoryEntry {
|
||||
pub timestamp: String,
|
||||
pub session_id: Option<String>,
|
||||
pub score: Option<f64>,
|
||||
/// Namespace for isolation between agents/contexts.
|
||||
#[serde(default = "default_namespace")]
|
||||
pub namespace: String,
|
||||
/// Importance score (0.0–1.0) for prioritized retrieval.
|
||||
#[serde(default)]
|
||||
pub importance: Option<f64>,
|
||||
/// If this entry was superseded by a newer conflicting entry.
|
||||
#[serde(default)]
|
||||
pub superseded_by: Option<String>,
|
||||
}
|
||||
|
||||
fn default_namespace() -> String {
|
||||
"default".into()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MemoryEntry {
|
||||
@@ -34,6 +47,8 @@ impl std::fmt::Debug for MemoryEntry {
|
||||
.field("category", &self.category)
|
||||
.field("timestamp", &self.timestamp)
|
||||
.field("score", &self.score)
|
||||
.field("namespace", &self.namespace)
|
||||
.field("importance", &self.importance)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -128,7 +143,7 @@ pub trait Memory: Send + Sync {
|
||||
|
||||
/// Store a conversation trace as procedural memory.
|
||||
///
|
||||
/// Backends that support procedural storage (e.g. mem0) override this
|
||||
/// Backends that support procedural storage override this
|
||||
/// to extract "how to" patterns from tool-calling turns. The default
|
||||
/// implementation is a no-op.
|
||||
async fn store_procedural(
|
||||
@@ -138,6 +153,46 @@ pub trait Memory: Send + Sync {
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recall memories scoped to a specific namespace.
|
||||
///
|
||||
/// Default implementation delegates to `recall()` and filters by namespace.
|
||||
/// Backends with native namespace support should override for efficiency.
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let entries = self
|
||||
.recall(query, limit * 2, session_id, since, until)
|
||||
.await?;
|
||||
let filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.namespace == namespace)
|
||||
.take(limit)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Store a memory entry with namespace and importance.
|
||||
///
|
||||
/// Default implementation delegates to `store()`. Backends with native
|
||||
/// namespace/importance support should override.
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
_namespace: Option<&str>,
|
||||
_importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.store(key, content, category, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -185,6 +240,9 @@ mod tests {
|
||||
timestamp: "2026-02-16T00:00:00Z".into(),
|
||||
session_id: Some("session-abc".into()),
|
||||
score: Some(0.98),
|
||||
namespace: "default".into(),
|
||||
importance: Some(0.7),
|
||||
superseded_by: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
@@ -196,5 +254,8 @@ mod tests {
|
||||
assert_eq!(parsed.category, MemoryCategory::Core);
|
||||
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
|
||||
assert_eq!(parsed.score, Some(0.98));
|
||||
assert_eq!(parsed.namespace, "default");
|
||||
assert_eq!(parsed.importance, Some(0.7));
|
||||
assert!(parsed.superseded_by.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,6 +126,7 @@ pub fn hybrid_merge(
|
||||
b.final_score
|
||||
.partial_cmp(&a.final_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| a.id.cmp(&b.id))
|
||||
});
|
||||
results.truncate(limit);
|
||||
results
|
||||
|
||||
@@ -507,6 +507,7 @@ mod tests {
|
||||
max_images: 1,
|
||||
max_image_size_mb: 5,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
@@ -549,6 +550,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
|
||||
+16
-1
@@ -154,6 +154,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -172,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -196,6 +198,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
knowledge: crate::config::KnowledgeConfig::default(),
|
||||
linkedin: crate::config::LinkedInConfig::default(),
|
||||
image_gen: crate::config::ImageGenConfig::default(),
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
@@ -418,9 +421,17 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
retrieval_stages: vec!["cache".into(), "fts".into(), "vector".into()],
|
||||
rerank_enabled: false,
|
||||
rerank_threshold: 5,
|
||||
fts_early_return_score: 0.85,
|
||||
default_namespace: "default".into(),
|
||||
conflict_threshold: 0.85,
|
||||
audit_enabled: false,
|
||||
audit_retention_days: 30,
|
||||
policy: crate::config::MemoryPolicyConfig::default(),
|
||||
sqlite_open_timeout_secs: None,
|
||||
qdrant: crate::config::QdrantConfig::default(),
|
||||
mem0: crate::config::schema::Mem0Config::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -576,6 +587,7 @@ async fn run_quick_setup_with_home(
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -594,6 +606,7 @@ async fn run_quick_setup_with_home(
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -618,6 +631,7 @@ async fn run_quick_setup_with_home(
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
knowledge: crate::config::KnowledgeConfig::default(),
|
||||
linkedin: crate::config::LinkedInConfig::default(),
|
||||
image_gen: crate::config::ImageGenConfig::default(),
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
@@ -4181,6 +4195,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
device_id: detected_device_id,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -412,6 +412,7 @@ mod tests {
|
||||
));
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
|
||||
f.sync_all().unwrap();
|
||||
drop(f);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
|
||||
@@ -767,6 +767,12 @@ impl Provider for OpenAiCodexProvider {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Mutex that serializes all tests which mutate process-global env vars
|
||||
/// (`std::env::set_var` / `remove_var`). Each such test must hold this
|
||||
/// lock for its entire duration so that parallel test threads don't race.
|
||||
static ENV_MUTEX: Mutex<()> = Mutex::new(());
|
||||
|
||||
struct EnvGuard {
|
||||
key: &'static str,
|
||||
@@ -841,6 +847,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_responses_url_prefers_explicit_endpoint_env() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _endpoint_guard = EnvGuard::set(
|
||||
CODEX_RESPONSES_URL_ENV,
|
||||
Some("https://env.example.com/v1/responses"),
|
||||
@@ -856,6 +863,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_responses_url_uses_provider_api_url_override() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
|
||||
let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
|
||||
|
||||
@@ -959,6 +967,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_reasoning_effort_prefers_configured_override() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("low"));
|
||||
assert_eq!(
|
||||
resolve_reasoning_effort("gpt-5-codex", Some("high")),
|
||||
@@ -968,6 +977,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_reasoning_effort_uses_legacy_env_when_unconfigured() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("minimal"));
|
||||
assert_eq!(
|
||||
resolve_reasoning_effort("gpt-5-codex", None),
|
||||
|
||||
@@ -108,6 +108,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
|
||||
"token limit exceeded",
|
||||
"prompt is too long",
|
||||
"input is too long",
|
||||
"prompt exceeds max length",
|
||||
];
|
||||
|
||||
hints.iter().any(|hint| lower.contains(hint))
|
||||
|
||||
+17
-2
@@ -97,7 +97,8 @@ pub struct SecurityPolicy {
|
||||
/// Default allowed commands for Unix platforms.
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn default_allowed_commands() -> Vec<String> {
|
||||
vec![
|
||||
#[allow(unused_mut)]
|
||||
let mut cmds = vec![
|
||||
"git".into(),
|
||||
"npm".into(),
|
||||
"cargo".into(),
|
||||
@@ -111,7 +112,16 @@ fn default_allowed_commands() -> Vec<String> {
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
"date".into(),
|
||||
]
|
||||
"df".into(),
|
||||
"du".into(),
|
||||
"uname".into(),
|
||||
"uptime".into(),
|
||||
"hostname".into(),
|
||||
];
|
||||
// `free` is Linux-only; it does not exist on macOS or other BSDs.
|
||||
#[cfg(target_os = "linux")]
|
||||
cmds.push("free".into());
|
||||
cmds
|
||||
}
|
||||
|
||||
/// Default allowed commands for Windows platforms.
|
||||
@@ -142,6 +152,11 @@ fn default_allowed_commands() -> Vec<String> {
|
||||
"wc".into(),
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
"df".into(),
|
||||
"du".into(),
|
||||
"uname".into(),
|
||||
"uptime".into(),
|
||||
"hostname".into(),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
+3
-3
@@ -1236,7 +1236,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_capture_reads_stdout() {
|
||||
let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
|
||||
let out = run_capture(Command::new("sh").args(["-c", "echo hello"]))
|
||||
.expect("stdout capture should succeed");
|
||||
assert_eq!(out.trim(), "hello");
|
||||
}
|
||||
@@ -1244,7 +1244,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_capture_falls_back_to_stderr() {
|
||||
let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
|
||||
let out = run_capture(Command::new("sh").args(["-c", "echo warn 1>&2"]))
|
||||
.expect("stderr capture should succeed");
|
||||
assert_eq!(out.trim(), "warn");
|
||||
}
|
||||
@@ -1252,7 +1252,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_checked_errors_on_non_zero_status() {
|
||||
let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
|
||||
let err = run_checked(Command::new("sh").args(["-c", "exit 17"]))
|
||||
.expect_err("non-zero exit should error");
|
||||
assert!(err.to_string().contains("Command failed"));
|
||||
}
|
||||
|
||||
@@ -106,8 +106,14 @@ impl SkillCreator {
|
||||
// Trim leading/trailing hyphens, then truncate.
|
||||
let trimmed = collapsed.trim_matches('-');
|
||||
if trimmed.len() > 64 {
|
||||
// Truncate at a hyphen boundary if possible.
|
||||
let truncated = &trimmed[..64];
|
||||
// Find the nearest valid character boundary at or before 64 bytes.
|
||||
let safe_index = trimmed
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 64)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
let truncated = &trimmed[..safe_index];
|
||||
truncated.trim_end_matches('-').to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
|
||||
+89
-14
@@ -738,15 +738,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
}
|
||||
|
||||
if !skill.tools.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &skill.tools {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
// Tools with known kinds (shell, script, http) are registered as
|
||||
// callable tool specs and can be invoked directly via function calling.
|
||||
// We note them here for context but mark them as callable.
|
||||
let registered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
let unregistered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| !matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
|
||||
if !registered.is_empty() {
|
||||
let _ = writeln!(prompt, " <callable_tools hint=\"These are registered as callable tool specs. Invoke them directly by name ({{}}.{{}}) instead of using shell.\">");
|
||||
for tool in ®istered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(
|
||||
&mut prompt,
|
||||
8,
|
||||
"name",
|
||||
&format!("{}.{}", skill.name, tool.name),
|
||||
);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </callable_tools>");
|
||||
}
|
||||
|
||||
if !unregistered.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &unregistered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
|
||||
let _ = writeln!(prompt, " </skill>");
|
||||
@@ -756,6 +788,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Convert skill tools into callable `Tool` trait objects.
|
||||
///
|
||||
/// Each skill's `[[tools]]` entries are converted to either `SkillShellTool`
|
||||
/// (for `shell`/`script` kinds) or `SkillHttpTool` (for `http` kind),
|
||||
/// enabling them to appear as first-class callable tool specs rather than
|
||||
/// only as XML in the system prompt.
|
||||
pub fn skills_to_tools(
|
||||
skills: &[Skill],
|
||||
security: std::sync::Arc<crate::security::SecurityPolicy>,
|
||||
) -> Vec<Box<dyn crate::tools::traits::Tool>> {
|
||||
let mut tools: Vec<Box<dyn crate::tools::traits::Tool>> = Vec::new();
|
||||
for skill in skills {
|
||||
for tool in &skill.tools {
|
||||
match tool.kind.as_str() {
|
||||
"shell" | "script" => {
|
||||
tools.push(Box::new(crate::tools::skill_tool::SkillShellTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
"http" => {
|
||||
tools.push(Box::new(crate::tools::skill_http::SkillHttpTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
)));
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(
|
||||
"Unknown skill tool kind '{}' for {}.{}, skipping",
|
||||
other,
|
||||
skill.name,
|
||||
tool.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Get the skills directory path
|
||||
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
|
||||
workspace_dir.join("skills")
|
||||
@@ -1517,10 +1590,10 @@ command = "echo hello"
|
||||
assert!(prompt.contains("read_skill(name)"));
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt.contains("<instruction>Do the thing.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>run</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell/script/http) appear under <callable_tools>.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>test.run</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1710,9 +1783,11 @@ description = "Bare minimum"
|
||||
}];
|
||||
let prompt = skills_to_prompt(&skills, Path::new("/tmp"));
|
||||
assert!(prompt.contains("weather"));
|
||||
assert!(prompt.contains("<name>get_weather</name>"));
|
||||
// Registered tools (shell kind) now appear under <callable_tools> with
|
||||
// prefixed names (skill_name.tool_name).
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>weather.get_weather</name>"));
|
||||
assert!(prompt.contains("<description>Fetch forecast</description>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,636 @@
|
||||
//! Live Canvas (A2UI) tool — push rendered content to a web canvas in real time.
|
||||
//!
|
||||
//! The agent can render HTML/SVG/Markdown to a named canvas, snapshot its
|
||||
//! current state, clear it, or evaluate a JavaScript expression in the canvas
|
||||
//! context. Content is stored in a shared [`CanvasStore`] and broadcast to
|
||||
//! connected WebSocket clients via per-canvas channels.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Maximum content size per canvas frame (256 KB).
|
||||
pub const MAX_CONTENT_SIZE: usize = 256 * 1024;
|
||||
|
||||
/// Maximum number of history frames kept per canvas.
|
||||
const MAX_HISTORY_FRAMES: usize = 50;
|
||||
|
||||
/// Broadcast channel capacity per canvas.
|
||||
const BROADCAST_CAPACITY: usize = 64;
|
||||
|
||||
/// Maximum number of concurrent canvases to prevent memory exhaustion.
|
||||
const MAX_CANVAS_COUNT: usize = 100;
|
||||
|
||||
/// Allowed content types for canvas frames via the REST API.
|
||||
pub const ALLOWED_CONTENT_TYPES: &[&str] = &["html", "svg", "markdown", "text"];
|
||||
|
||||
/// A single canvas frame (one render).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CanvasFrame {
|
||||
/// Unique frame identifier.
|
||||
pub frame_id: String,
|
||||
/// Content type: `html`, `svg`, `markdown`, or `text`.
|
||||
pub content_type: String,
|
||||
/// The rendered content.
|
||||
pub content: String,
|
||||
/// ISO-8601 timestamp of when the frame was created.
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Per-canvas state: current content + history + broadcast sender.
|
||||
struct CanvasEntry {
|
||||
current: Option<CanvasFrame>,
|
||||
history: Vec<CanvasFrame>,
|
||||
tx: broadcast::Sender<CanvasFrame>,
|
||||
}
|
||||
|
||||
/// Shared canvas store — holds all active canvases.
|
||||
///
|
||||
/// Thread-safe and cheaply cloneable (wraps `Arc`).
|
||||
#[derive(Clone)]
|
||||
pub struct CanvasStore {
|
||||
inner: Arc<RwLock<HashMap<String, CanvasEntry>>>,
|
||||
}
|
||||
|
||||
impl Default for CanvasStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CanvasStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a new frame to a canvas. Creates the canvas if it does not exist.
|
||||
/// Returns `None` if the maximum canvas count has been reached and this is a new canvas.
|
||||
pub fn render(
|
||||
&self,
|
||||
canvas_id: &str,
|
||||
content_type: &str,
|
||||
content: &str,
|
||||
) -> Option<CanvasFrame> {
|
||||
let frame = CanvasFrame {
|
||||
frame_id: uuid::Uuid::new_v4().to_string(),
|
||||
content_type: content_type.to_string(),
|
||||
content: content.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
let mut store = self.inner.write();
|
||||
|
||||
// Enforce canvas count limit for new canvases.
|
||||
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = store
|
||||
.entry(canvas_id.to_string())
|
||||
.or_insert_with(|| CanvasEntry {
|
||||
current: None,
|
||||
history: Vec::new(),
|
||||
tx: broadcast::channel(BROADCAST_CAPACITY).0,
|
||||
});
|
||||
|
||||
entry.current = Some(frame.clone());
|
||||
entry.history.push(frame.clone());
|
||||
if entry.history.len() > MAX_HISTORY_FRAMES {
|
||||
let excess = entry.history.len() - MAX_HISTORY_FRAMES;
|
||||
entry.history.drain(..excess);
|
||||
}
|
||||
|
||||
// Best-effort broadcast — ignore errors (no receivers is fine).
|
||||
let _ = entry.tx.send(frame.clone());
|
||||
|
||||
Some(frame)
|
||||
}
|
||||
|
||||
/// Get the current (most recent) frame for a canvas.
|
||||
pub fn snapshot(&self, canvas_id: &str) -> Option<CanvasFrame> {
|
||||
let store = self.inner.read();
|
||||
store.get(canvas_id).and_then(|entry| entry.current.clone())
|
||||
}
|
||||
|
||||
/// Get the frame history for a canvas.
|
||||
pub fn history(&self, canvas_id: &str) -> Vec<CanvasFrame> {
|
||||
let store = self.inner.read();
|
||||
store
|
||||
.get(canvas_id)
|
||||
.map(|entry| entry.history.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Clear a canvas (removes current content and history).
|
||||
pub fn clear(&self, canvas_id: &str) -> bool {
|
||||
let mut store = self.inner.write();
|
||||
if let Some(entry) = store.get_mut(canvas_id) {
|
||||
entry.current = None;
|
||||
entry.history.clear();
|
||||
// Send an empty frame to signal clear to subscribers.
|
||||
let clear_frame = CanvasFrame {
|
||||
frame_id: uuid::Uuid::new_v4().to_string(),
|
||||
content_type: "clear".to_string(),
|
||||
content: String::new(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
let _ = entry.tx.send(clear_frame);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to real-time updates for a canvas.
|
||||
/// Creates the canvas entry if it does not exist (subject to canvas count limit).
|
||||
/// Returns `None` if the canvas does not exist and the limit has been reached.
|
||||
pub fn subscribe(&self, canvas_id: &str) -> Option<broadcast::Receiver<CanvasFrame>> {
|
||||
let mut store = self.inner.write();
|
||||
|
||||
// Enforce canvas count limit for new entries.
|
||||
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = store
|
||||
.entry(canvas_id.to_string())
|
||||
.or_insert_with(|| CanvasEntry {
|
||||
current: None,
|
||||
history: Vec::new(),
|
||||
tx: broadcast::channel(BROADCAST_CAPACITY).0,
|
||||
});
|
||||
Some(entry.tx.subscribe())
|
||||
}
|
||||
|
||||
/// List all canvas IDs that currently have content.
|
||||
pub fn list(&self) -> Vec<String> {
|
||||
let store = self.inner.read();
|
||||
store.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// `CanvasTool` — agent-callable tool for the Live Canvas (A2UI) system.
|
||||
pub struct CanvasTool {
|
||||
store: CanvasStore,
|
||||
}
|
||||
|
||||
impl CanvasTool {
|
||||
pub fn new(store: CanvasStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CanvasTool {
|
||||
fn name(&self) -> &str {
|
||||
"canvas"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Push rendered content (HTML, SVG, Markdown) to a live web canvas that users can see \
|
||||
in real-time. Actions: render (push content), snapshot (get current content), \
|
||||
clear (reset canvas), eval (evaluate JS expression in canvas context). \
|
||||
Each canvas is identified by a canvas_id string."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform on the canvas.",
|
||||
"enum": ["render", "snapshot", "clear", "eval"]
|
||||
},
|
||||
"canvas_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the canvas. Defaults to 'default'."
|
||||
},
|
||||
"content_type": {
|
||||
"type": "string",
|
||||
"description": "Content type for render action: html, svg, markdown, or text.",
|
||||
"enum": ["html", "svg", "markdown", "text"]
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to render (for render action)."
|
||||
},
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "JavaScript expression to evaluate (for eval action). \
|
||||
The result is returned as text. Evaluated client-side in the canvas iframe."
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let canvas_id = args
|
||||
.get("canvas_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
|
||||
match action {
|
||||
"render" => {
|
||||
let content_type = args
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("html");
|
||||
|
||||
let content = match args.get("content").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Missing required parameter: content (for render action)"
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if content.len() > MAX_CONTENT_SIZE {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Content exceeds maximum size of {} bytes",
|
||||
MAX_CONTENT_SIZE
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
match self.store.render(canvas_id, content_type, content) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Rendered {} content to canvas '{}' (frame: {})",
|
||||
content_type, canvas_id, frame.frame_id
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Maximum canvas count ({}) reached. Clear unused canvases first.",
|
||||
MAX_CANVAS_COUNT
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
"snapshot" => match self.store.snapshot(canvas_id) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&frame)
|
||||
.unwrap_or_else(|_| frame.content.clone()),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Canvas '{}' is empty", canvas_id),
|
||||
error: None,
|
||||
}),
|
||||
},
|
||||
|
||||
"clear" => {
|
||||
let existed = self.store.clear(canvas_id);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: if existed {
|
||||
format!("Canvas '{}' cleared", canvas_id)
|
||||
} else {
|
||||
format!("Canvas '{}' was already empty", canvas_id)
|
||||
},
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
"eval" => {
|
||||
// Eval is handled client-side. We store an eval request as a special frame
|
||||
// that the web viewer interprets.
|
||||
let expression = match args.get("expression").and_then(|v| v.as_str()) {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Missing required parameter: expression (for eval action)"
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Push a special eval frame so connected clients know to evaluate it.
|
||||
match self.store.render(canvas_id, "eval", expression) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Eval request sent to canvas '{}' (frame: {}). \
|
||||
Result will be available to connected viewers.",
|
||||
canvas_id, frame.frame_id
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Maximum canvas count ({}) reached. Clear unused canvases first.",
|
||||
MAX_CANVAS_COUNT
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: '{}'. Valid actions: render, snapshot, clear, eval",
|
||||
other
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn canvas_store_render_and_snapshot() {
|
||||
let store = CanvasStore::new();
|
||||
let frame = store.render("test", "html", "<h1>Hello</h1>").unwrap();
|
||||
assert_eq!(frame.content_type, "html");
|
||||
assert_eq!(frame.content, "<h1>Hello</h1>");
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.frame_id, frame.frame_id);
|
||||
assert_eq!(snapshot.content, "<h1>Hello</h1>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_snapshot_empty_returns_none() {
|
||||
let store = CanvasStore::new();
|
||||
assert!(store.snapshot("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_clear_removes_content() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>content</p>");
|
||||
assert!(store.snapshot("test").is_some());
|
||||
|
||||
let cleared = store.clear("test");
|
||||
assert!(cleared);
|
||||
assert!(store.snapshot("test").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_clear_nonexistent_returns_false() {
|
||||
let store = CanvasStore::new();
|
||||
assert!(!store.clear("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_history_tracks_frames() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "frame1");
|
||||
store.render("test", "html", "frame2");
|
||||
store.render("test", "html", "frame3");
|
||||
|
||||
let history = store.history("test");
|
||||
assert_eq!(history.len(), 3);
|
||||
assert_eq!(history[0].content, "frame1");
|
||||
assert_eq!(history[2].content, "frame3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_history_limit_enforced() {
|
||||
let store = CanvasStore::new();
|
||||
for i in 0..60 {
|
||||
store.render("test", "html", &format!("frame{i}"));
|
||||
}
|
||||
|
||||
let history = store.history("test");
|
||||
assert_eq!(history.len(), MAX_HISTORY_FRAMES);
|
||||
// Oldest frames should have been dropped
|
||||
assert_eq!(history[0].content, "frame10");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_list_returns_canvas_ids() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("alpha", "html", "a");
|
||||
store.render("beta", "svg", "b");
|
||||
|
||||
let mut ids = store.list();
|
||||
ids.sort();
|
||||
assert_eq!(ids, vec!["alpha", "beta"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_subscribe_receives_updates() {
|
||||
let store = CanvasStore::new();
|
||||
let mut rx = store.subscribe("test").unwrap();
|
||||
store.render("test", "html", "<p>live</p>");
|
||||
|
||||
let frame = rx.try_recv().unwrap();
|
||||
assert_eq!(frame.content, "<p>live</p>");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"canvas_id": "test",
|
||||
"content_type": "html",
|
||||
"content": "<h1>Hello World</h1>"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Rendered html content"));
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.content, "<h1>Hello World</h1>");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_snapshot_action() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>snap</p>");
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "snapshot", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("<p>snap</p>"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_snapshot_empty() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "snapshot", "canvas_id": "empty"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_clear_action() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>clear me</p>");
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({"action": "clear", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("cleared"));
|
||||
assert!(store.snapshot("test").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_eval_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "eval",
|
||||
"canvas_id": "test",
|
||||
"expression": "document.title"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Eval request sent"));
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.content_type, "eval");
|
||||
assert_eq!(snapshot.content, "document.title");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_unknown_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("Unknown action"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_missing_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("action"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_missing_content() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "render", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("content"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_content_too_large() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let big_content = "x".repeat(MAX_CONTENT_SIZE + 1);
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"canvas_id": "test",
|
||||
"content": big_content
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("maximum size"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_default_canvas_id() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"content_type": "html",
|
||||
"content": "<p>default</p>"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(store.snapshot("default").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_enforces_max_canvas_count() {
|
||||
let store = CanvasStore::new();
|
||||
// Create MAX_CANVAS_COUNT canvases
|
||||
for i in 0..MAX_CANVAS_COUNT {
|
||||
assert!(store
|
||||
.render(&format!("canvas_{i}"), "html", "content")
|
||||
.is_some());
|
||||
}
|
||||
// The next new canvas should be rejected
|
||||
assert!(store.render("one_too_many", "html", "content").is_none());
|
||||
// But rendering to an existing canvas should still work
|
||||
assert!(store.render("canvas_0", "html", "updated").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_eval_missing_expression() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "eval", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("expression"));
|
||||
}
|
||||
}
|
||||
@@ -530,6 +530,7 @@ impl DelegateTool {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Search Discord message history stored in discord.db.
|
||||
pub struct DiscordSearchTool {
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl DiscordSearchTool {
|
||||
pub fn new(discord_memory: Arc<dyn Memory>) -> Self {
|
||||
Self { discord_memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DiscordSearchTool {
|
||||
fn name(&self) -> &str {
|
||||
"discord_search"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search Discord message history. Returns messages matching a keyword query, optionally filtered by channel_id, author_id, or time range."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Keywords or phrase to search for in Discord messages (optional if since/until provided)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default: 10)"
|
||||
},
|
||||
"channel_id": {
|
||||
"type": "string",
|
||||
"description": "Filter results to a specific Discord channel ID"
|
||||
},
|
||||
"since": {
|
||||
"type": "string",
|
||||
"description": "Filter messages at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
|
||||
},
|
||||
"until": {
|
||||
"type": "string",
|
||||
"description": "Filter messages at or before this time (RFC 3339)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let channel_id = args.get("channel_id").and_then(|v| v.as_str());
|
||||
let since = args.get("since").and_then(|v| v.as_str());
|
||||
let until = args.get("until").and_then(|v| v.as_str());
|
||||
|
||||
if query.trim().is_empty() && since.is_none() && until.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(s) = since {
|
||||
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'since' date: {s}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(u) = until {
|
||||
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'until' date: {u}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let (Some(s), Some(u)) = (since, until) {
|
||||
if let (Ok(s_dt), Ok(u_dt)) = (
|
||||
chrono::DateTime::parse_from_rfc3339(s),
|
||||
chrono::DateTime::parse_from_rfc3339(u),
|
||||
) {
|
||||
if s_dt >= u_dt {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'since' must be before 'until'".into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(10, |v| v as usize);
|
||||
|
||||
match self
|
||||
.discord_memory
|
||||
.recall(query, limit, channel_id, since, until)
|
||||
.await
|
||||
{
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No Discord messages found.".into(),
|
||||
error: None,
|
||||
}),
|
||||
Ok(entries) => {
|
||||
let mut output = format!("Found {} Discord messages:\n", entries.len());
|
||||
for entry in &entries {
|
||||
let score = entry
|
||||
.score
|
||||
.map_or_else(String::new, |s| format!(" [{s:.0}%]"));
|
||||
let _ = writeln!(output, "- {}{score}", entry.content);
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Discord search failed: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn seeded_discord_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new_named(tmp.path(), "discord").unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_empty() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No Discord messages found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_finds_match() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
mem.store(
|
||||
"discord_001",
|
||||
"@user1 in #general at 2025-01-01T00:00:00Z: hello world",
|
||||
MemoryCategory::Custom("discord".to_string()),
|
||||
Some("general"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("hello"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_requires_query_or_time() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("at least"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
assert_eq!(tool.name(), "discord_search");
|
||||
assert!(tool.parameters_schema()["properties"]["query"].is_object());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,494 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
|
||||
///
|
||||
/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
|
||||
/// calls the fal.ai synchronous endpoint, downloads the resulting image,
|
||||
/// and saves it to `{workspace}/images/{filename}.png`.
|
||||
pub struct ImageGenTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
default_model: String,
|
||||
api_key_env: String,
|
||||
}
|
||||
|
||||
impl ImageGenTool {
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
default_model: String,
|
||||
api_key_env: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
workspace_dir,
|
||||
default_model,
|
||||
api_key_env,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a reusable HTTP client with reasonable timeouts.
|
||||
fn http_client() -> reqwest::Client {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.build()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Read an API key from the environment.
|
||||
fn read_api_key(env_var: &str) -> Result<String, String> {
|
||||
std::env::var(env_var)
|
||||
.map(|v| v.trim().to_string())
|
||||
.ok()
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
|
||||
}
|
||||
|
||||
/// Core generation logic: call fal.ai, download image, save to disk.
|
||||
async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// ── Parse parameters ───────────────────────────────────────
|
||||
let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
|
||||
Some(p) if !p.trim().is_empty() => p.trim().to_string(),
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: 'prompt'".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let filename = args
|
||||
.get("filename")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or("generated_image");
|
||||
|
||||
// Sanitize filename — strip path components to prevent traversal.
|
||||
let safe_name = PathBuf::from(filename).file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
|
||||
let size = args
|
||||
.get("size")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("square_hd");
|
||||
|
||||
// Validate size enum.
|
||||
const VALID_SIZES: &[&str] = &[
|
||||
"square_hd",
|
||||
"landscape_4_3",
|
||||
"portrait_4_3",
|
||||
"landscape_16_9",
|
||||
"portrait_16_9",
|
||||
];
|
||||
if !VALID_SIZES.contains(&size) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid size '{size}'. Valid values: {}",
|
||||
VALID_SIZES.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let model = args
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or(&self.default_model);
|
||||
|
||||
// Validate model identifier: must look like a fal.ai model path
|
||||
// (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
|
||||
// strings, or fragments that could redirect the HTTP request.
|
||||
if model.contains("..")
|
||||
|| model.contains('?')
|
||||
|| model.contains('#')
|
||||
|| model.contains('\\')
|
||||
|| model.starts_with('/')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid model identifier '{model}'. \
|
||||
Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
// ── Read API key ───────────────────────────────────────────
|
||||
let api_key = match Self::read_api_key(&self.api_key_env) {
|
||||
Ok(k) => k,
|
||||
Err(msg) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(msg),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// ── Call fal.ai ────────────────────────────────────────────
|
||||
let client = Self::http_client();
|
||||
let url = format!("https://fal.run/{model}");
|
||||
|
||||
let body = json!({
|
||||
"prompt": prompt,
|
||||
"image_size": size,
|
||||
"num_images": 1
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Key {api_key}"))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("fal.ai request failed")?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("fal.ai API error ({status}): {body_text}")),
|
||||
});
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse fal.ai response as JSON")?;
|
||||
|
||||
let image_url = resp_json
|
||||
.pointer("/images/0/url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;
|
||||
|
||||
// ── Download image ─────────────────────────────────────────
|
||||
let img_resp = client
|
||||
.get(image_url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to download generated image")?;
|
||||
|
||||
if !img_resp.status().is_success() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to download image from {image_url} ({})",
|
||||
img_resp.status()
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let bytes = img_resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("Failed to read image bytes")?;
|
||||
|
||||
// ── Save to disk ───────────────────────────────────────────
|
||||
let images_dir = self.workspace_dir.join("images");
|
||||
tokio::fs::create_dir_all(&images_dir)
|
||||
.await
|
||||
.context("Failed to create images directory")?;
|
||||
|
||||
let output_path = images_dir.join(format!("{safe_name}.png"));
|
||||
tokio::fs::write(&output_path, &bytes)
|
||||
.await
|
||||
.context("Failed to write image file")?;
|
||||
|
||||
let size_kb = bytes.len() / 1024;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Image generated successfully.\n\
|
||||
File: {}\n\
|
||||
Size: {} KB\n\
|
||||
Model: {}\n\
|
||||
Prompt: {}",
|
||||
output_path.display(),
|
||||
size_kb,
|
||||
model,
|
||||
prompt,
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ImageGenTool {
|
||||
fn name(&self) -> &str {
|
||||
"image_gen"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Generate an image from a text prompt using fal.ai (Flux models). \
|
||||
Saves the result to the workspace images directory and returns the file path."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["prompt"],
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text prompt describing the image to generate."
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
|
||||
},
|
||||
"size": {
|
||||
"type": "string",
|
||||
"enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
|
||||
"description": "Image aspect ratio / size preset (default: 'square_hd')."
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Security: image generation is a side-effecting action (HTTP + file write).
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "image_gen")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
self.generate(args).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn test_tool() -> ImageGenTool {
|
||||
ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY".into(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_name() {
|
||||
let tool = test_tool();
|
||||
assert_eq!(tool.name(), "image_gen");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_description_is_nonempty() {
|
||||
let tool = test_tool();
|
||||
assert!(!tool.description().is_empty());
|
||||
assert!(tool.description().contains("image"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_schema_has_required_prompt() {
|
||||
let tool = test_tool();
|
||||
let schema = tool.parameters_schema();
|
||||
assert_eq!(schema["required"], json!(["prompt"]));
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_schema_has_optional_params() {
|
||||
let tool = test_tool();
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["filename"].is_object());
|
||||
assert!(schema["properties"]["size"].is_object());
|
||||
assert!(schema["properties"]["model"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_spec_roundtrip() {
|
||||
let tool = test_tool();
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "image_gen");
|
||||
assert!(spec.parameters.is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_prompt_returns_error() {
|
||||
let tool = test_tool();
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_prompt_returns_error() {
|
||||
let tool = test_tool();
|
||||
let result = tool.execute(json!({"prompt": " "})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_api_key_returns_error() {
|
||||
// Temporarily ensure the env var is unset.
|
||||
let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
|
||||
std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_IMAGE_GEN".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "a sunset over the ocean"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.contains("FAL_API_KEY_TEST_IMAGE_GEN"));
|
||||
|
||||
// Restore if it was set.
|
||||
if let Some(val) = original {
|
||||
std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_size_returns_error() {
|
||||
// Set a dummy key so we get past the key check.
|
||||
std::env::set_var("FAL_API_KEY_TEST_SIZE", "dummy_key");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_SIZE".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "test", "size": "invalid_size"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("Invalid size"));
|
||||
|
||||
std::env::remove_var("FAL_API_KEY_TEST_SIZE");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_only_autonomy_blocks_execution() {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ImageGenTool::new(
|
||||
security,
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY".into(),
|
||||
);
|
||||
let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
let err = result.error.as_deref().unwrap();
|
||||
assert!(
|
||||
err.contains("read-only") || err.contains("image_gen"),
|
||||
"expected read-only or image_gen in error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_model_with_traversal_returns_error() {
|
||||
std::env::set_var("FAL_API_KEY_TEST_MODEL", "dummy_key");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_MODEL".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.contains("Invalid model identifier"));
|
||||
|
||||
std::env::remove_var("FAL_API_KEY_TEST_MODEL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_api_key_missing() {
|
||||
let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.contains("DEFINITELY_NOT_SET_ZC_TEST_12345"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filename_traversal_is_sanitized() {
|
||||
// Verify that path traversal in filenames is stripped to just the final component.
|
||||
let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
assert_eq!(sanitized, "passwd");
|
||||
|
||||
// ".." alone has no file_name, falls back to default.
|
||||
let sanitized = PathBuf::from("..").file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
assert_eq!(sanitized, "generated_image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_api_key_present() {
|
||||
std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123");
|
||||
let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "test_value_123");
|
||||
std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY");
|
||||
}
|
||||
}
|
||||
+119
-10
@@ -20,6 +20,7 @@ pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
pub mod browser_open;
|
||||
pub mod calculator;
|
||||
pub mod canvas;
|
||||
pub mod claude_code;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
@@ -34,6 +35,7 @@ pub mod cron_runs;
|
||||
pub mod cron_update;
|
||||
pub mod data_management;
|
||||
pub mod delegate;
|
||||
pub mod discord_search;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
@@ -47,6 +49,7 @@ pub mod hardware_memory_map;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub mod hardware_memory_read;
|
||||
pub mod http_request;
|
||||
pub mod image_gen;
|
||||
pub mod image_info;
|
||||
pub mod jira_tool;
|
||||
pub mod knowledge_tool;
|
||||
@@ -69,13 +72,17 @@ pub mod pdf_read;
|
||||
pub mod project_intel;
|
||||
pub mod proxy_config;
|
||||
pub mod pushover;
|
||||
pub mod reaction;
|
||||
pub mod read_skill;
|
||||
pub mod report_templates;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod sessions;
|
||||
pub mod shell;
|
||||
pub mod skill_http;
|
||||
pub mod skill_tool;
|
||||
pub mod swarm;
|
||||
pub mod text_browser;
|
||||
pub mod tool_search;
|
||||
@@ -93,6 +100,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use canvas::{CanvasStore, CanvasTool};
|
||||
pub use claude_code::ClaudeCodeTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
@@ -106,6 +114,7 @@ pub use cron_runs::CronRunsTool;
|
||||
pub use cron_update::CronUpdateTool;
|
||||
pub use data_management::DataManagementTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use discord_search::DiscordSearchTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
@@ -119,6 +128,7 @@ pub use hardware_memory_map::HardwareMemoryMapTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub use hardware_memory_read::HardwareMemoryReadTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use image_gen::ImageGenTool;
|
||||
pub use image_info::ImageInfoTool;
|
||||
pub use jira_tool::JiraTool;
|
||||
pub use knowledge_tool::KnowledgeTool;
|
||||
@@ -139,13 +149,19 @@ pub use pdf_read::PdfReadTool;
|
||||
pub use project_intel::ProjectIntelTool;
|
||||
pub use proxy_config::ProxyConfigTool;
|
||||
pub use pushover::PushoverTool;
|
||||
pub use reaction::{ChannelMapHandle, ReactionTool};
|
||||
pub use read_skill::ReadSkillTool;
|
||||
pub use schedule::ScheduleTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
|
||||
pub use shell::ShellTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_http::SkillHttpTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_tool::SkillShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use text_browser::TextBrowserTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
@@ -247,6 +263,33 @@ pub fn default_tools_with_runtime(
|
||||
]
|
||||
}
|
||||
|
||||
/// Register skill-defined tools into an existing tool registry.
|
||||
///
|
||||
/// Converts each skill's `[[tools]]` entries into callable `Tool` implementations
|
||||
/// and appends them to the registry. Skill tools that would shadow a built-in tool
|
||||
/// name are skipped with a warning.
|
||||
pub fn register_skill_tools(
|
||||
tools_registry: &mut Vec<Box<dyn Tool>>,
|
||||
skills: &[crate::skills::Skill],
|
||||
security: Arc<SecurityPolicy>,
|
||||
) {
|
||||
let skill_tools = crate::skills::skills_to_tools(skills, security);
|
||||
let existing_names: std::collections::HashSet<String> = tools_registry
|
||||
.iter()
|
||||
.map(|t| t.name().to_string())
|
||||
.collect();
|
||||
for tool in skill_tools {
|
||||
if existing_names.contains(tool.name()) {
|
||||
tracing::warn!(
|
||||
"Skill tool '{}' shadows built-in tool, skipping",
|
||||
tool.name()
|
||||
);
|
||||
} else {
|
||||
tools_registry.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create full tool registry including memory tools and optional Composio
|
||||
#[allow(clippy::implicit_hasher, clippy::too_many_arguments)]
|
||||
pub fn all_tools(
|
||||
@@ -262,7 +305,12 @@ pub fn all_tools(
|
||||
agents: &HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<&str>,
|
||||
root_config: &crate::config::Config,
|
||||
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
|
||||
canvas_store: Option<CanvasStore>,
|
||||
) -> (
|
||||
Vec<Box<dyn Tool>>,
|
||||
Option<DelegateParentToolsHandle>,
|
||||
Option<ChannelMapHandle>,
|
||||
) {
|
||||
all_tools_with_runtime(
|
||||
config,
|
||||
security,
|
||||
@@ -277,6 +325,7 @@ pub fn all_tools(
|
||||
agents,
|
||||
fallback_api_key,
|
||||
root_config,
|
||||
canvas_store,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -296,7 +345,12 @@ pub fn all_tools_with_runtime(
|
||||
agents: &HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<&str>,
|
||||
root_config: &crate::config::Config,
|
||||
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
|
||||
canvas_store: Option<CanvasStore>,
|
||||
) -> (
|
||||
Vec<Box<dyn Tool>>,
|
||||
Option<DelegateParentToolsHandle>,
|
||||
Option<ChannelMapHandle>,
|
||||
) {
|
||||
let has_shell_access = runtime.has_shell_access();
|
||||
let sandbox = create_sandbox(&root_config.security);
|
||||
let mut tool_arcs: Vec<Arc<dyn Tool>> = vec![
|
||||
@@ -336,8 +390,21 @@ pub fn all_tools_with_runtime(
|
||||
)),
|
||||
Arc::new(CalculatorTool::new()),
|
||||
Arc::new(WeatherTool::new()),
|
||||
Arc::new(CanvasTool::new(canvas_store.unwrap_or_default())),
|
||||
];
|
||||
|
||||
// Register discord_search if discord_history channel is configured
|
||||
if root_config.channels_config.discord_history.is_some() {
|
||||
match crate::memory::SqliteMemory::new_named(workspace_dir, "discord") {
|
||||
Ok(discord_mem) => {
|
||||
tool_arcs.push(Arc::new(DiscordSearchTool::new(Arc::new(discord_mem))));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("discord_search: failed to open discord.db: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(
|
||||
root_config.skills.prompt_injection_mode,
|
||||
crate::config::SkillsPromptInjectionMode::Compact
|
||||
@@ -424,6 +491,7 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(WebSearchTool::new_with_config(
|
||||
root_config.web_search.provider.clone(),
|
||||
root_config.web_search.brave_api_key.clone(),
|
||||
root_config.web_search.searxng_instance_url.clone(),
|
||||
root_config.web_search.max_results,
|
||||
root_config.web_search.timeout_secs,
|
||||
root_config.config_path.clone(),
|
||||
@@ -545,6 +613,18 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
|
||||
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));
|
||||
|
||||
// Session-to-session messaging tools (always available when sessions dir exists)
|
||||
if let Ok(session_store) = crate::channels::session_store::SessionStore::new(workspace_dir) {
|
||||
let backend: Arc<dyn crate::channels::session_backend::SessionBackend> =
|
||||
Arc::new(session_store);
|
||||
tool_arcs.push(Arc::new(SessionsListTool::new(backend.clone())));
|
||||
tool_arcs.push(Arc::new(SessionsHistoryTool::new(
|
||||
backend.clone(),
|
||||
security.clone(),
|
||||
)));
|
||||
tool_arcs.push(Arc::new(SessionsSendTool::new(backend, security.clone())));
|
||||
}
|
||||
|
||||
// LinkedIn integration (config-gated)
|
||||
if root_config.linkedin.enabled {
|
||||
tool_arcs.push(Arc::new(LinkedInTool::new(
|
||||
@@ -556,6 +636,16 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Standalone image generation tool (config-gated)
|
||||
if root_config.image_gen.enabled {
|
||||
tool_arcs.push(Arc::new(ImageGenTool::new(
|
||||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.image_gen.default_model.clone(),
|
||||
root_config.image_gen.api_key_env.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(key) = composio_key {
|
||||
if !key.is_empty() {
|
||||
tool_arcs.push(Arc::new(ComposioTool::new(
|
||||
@@ -566,6 +656,11 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
}
|
||||
|
||||
// Emoji reaction tool — always registered; channel map populated later by start_channels.
|
||||
let reaction_tool = ReactionTool::new(security.clone());
|
||||
let reaction_handle = reaction_tool.channel_map_handle();
|
||||
tool_arcs.push(Arc::new(reaction_tool));
|
||||
|
||||
// Microsoft 365 Graph API integration
|
||||
if root_config.microsoft365.enabled {
|
||||
let ms_cfg = &root_config.microsoft365;
|
||||
@@ -592,7 +687,11 @@ pub fn all_tools_with_runtime(
|
||||
tracing::error!(
|
||||
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
|
||||
);
|
||||
return (boxed_registry_from_arcs(tool_arcs), None);
|
||||
return (
|
||||
boxed_registry_from_arcs(tool_arcs),
|
||||
None,
|
||||
Some(reaction_handle),
|
||||
);
|
||||
}
|
||||
|
||||
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
|
||||
@@ -776,7 +875,11 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
}
|
||||
|
||||
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
|
||||
(
|
||||
boxed_registry_from_arcs(tool_arcs),
|
||||
delegate_handle,
|
||||
Some(reaction_handle),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -820,7 +923,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -833,6 +936,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"browser_open"));
|
||||
@@ -862,7 +966,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -875,6 +979,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"browser_open"));
|
||||
@@ -1015,7 +1120,7 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -1028,6 +1133,7 @@ mod tests {
|
||||
&agents,
|
||||
Some("delegate-test-credential"),
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"delegate"));
|
||||
@@ -1048,7 +1154,7 @@ mod tests {
|
||||
let http = crate::config::HttpRequestConfig::default();
|
||||
let cfg = test_config(&tmp);
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(Config::default()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -1061,6 +1167,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"delegate"));
|
||||
@@ -1082,7 +1189,7 @@ mod tests {
|
||||
let mut cfg = test_config(&tmp);
|
||||
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Compact;
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(cfg.clone()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -1095,6 +1202,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"read_skill"));
|
||||
@@ -1116,7 +1224,7 @@ mod tests {
|
||||
let mut cfg = test_config(&tmp);
|
||||
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Full;
|
||||
|
||||
let (tools, _) = all_tools(
|
||||
let (tools, _, _) = all_tools(
|
||||
Arc::new(cfg.clone()),
|
||||
&security,
|
||||
mem,
|
||||
@@ -1129,6 +1237,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"read_skill"));
|
||||
|
||||
@@ -0,0 +1,546 @@
|
||||
//! Emoji reaction tool for cross-channel message reactions.
|
||||
//!
|
||||
//! Exposes `add_reaction` and `remove_reaction` from the [`Channel`] trait as an
|
||||
//! agent-callable tool. The tool holds a late-binding channel map handle that is
|
||||
//! populated once channels are initialized (after tool construction). This mirrors
|
||||
//! the pattern used by [`DelegateTool`] for its parent-tools handle.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::channels::traits::Channel;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Shared handle to the channel map. Starts empty; populated once channels boot.
|
||||
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
|
||||
|
||||
/// Agent-callable tool for adding or removing emoji reactions on messages.
|
||||
pub struct ReactionTool {
|
||||
channels: ChannelMapHandle,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl ReactionTool {
|
||||
/// Create a new reaction tool with an empty channel map.
|
||||
/// Call [`populate`] or write to the returned [`ChannelMapHandle`] once channels
|
||||
/// are available.
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the shared handle so callers can populate it after channel init.
|
||||
pub fn channel_map_handle(&self) -> ChannelMapHandle {
|
||||
Arc::clone(&self.channels)
|
||||
}
|
||||
|
||||
/// Convenience: populate the channel map from a pre-built map.
|
||||
pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
|
||||
*self.channels.write() = map;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ReactionTool {
|
||||
fn name(&self) -> &str {
|
||||
"reaction"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Add or remove an emoji reaction on a message in any active channel. \
|
||||
Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
|
||||
the platform message ID, and the emoji (Unicode character or platform shortcode)."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
|
||||
},
|
||||
"channel_id": {
|
||||
"type": "string",
|
||||
"description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "string",
|
||||
"description": "Platform-scoped message identifier to react to"
|
||||
},
|
||||
"emoji": {
|
||||
"type": "string",
|
||||
"description": "Emoji to react with (Unicode character or platform shortcode)"
|
||||
},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "remove"],
|
||||
"description": "Whether to add or remove the reaction (default: 'add')"
|
||||
}
|
||||
},
|
||||
"required": ["channel", "channel_id", "message_id", "emoji"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Security gate
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "reaction")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let channel_name = args
|
||||
.get("channel")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'channel' parameter"))?;
|
||||
|
||||
let channel_id = args
|
||||
.get("channel_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'channel_id' parameter"))?;
|
||||
|
||||
let message_id = args
|
||||
.get("message_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'message_id' parameter"))?;
|
||||
|
||||
let emoji = args
|
||||
.get("emoji")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'emoji' parameter"))?;
|
||||
|
||||
let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
|
||||
|
||||
if action != "add" && action != "remove" {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid action '{action}': must be 'add' or 'remove'"
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
// Read-lock the channel map to find the target channel.
|
||||
let channel = {
|
||||
let map = self.channels.read();
|
||||
if map.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("No channels available yet (channels not initialized)".to_string()),
|
||||
});
|
||||
}
|
||||
match map.get(channel_name) {
|
||||
Some(ch) => Arc::clone(ch),
|
||||
None => {
|
||||
let available: Vec<String> = map.keys().cloned().collect();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Channel '{channel_name}' not found. Available channels: {}",
|
||||
available.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let result = if action == "add" {
|
||||
channel.add_reaction(channel_id, message_id, emoji).await
|
||||
} else {
|
||||
channel.remove_reaction(channel_id, message_id, emoji).await
|
||||
};
|
||||
|
||||
let past_tense = if action == "remove" {
|
||||
"removed"
|
||||
} else {
|
||||
"added"
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to {action} reaction: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::channels::traits::{ChannelMessage, SendMessage};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
struct MockChannel {
|
||||
reaction_added: AtomicBool,
|
||||
reaction_removed: AtomicBool,
|
||||
last_channel_id: parking_lot::Mutex<Option<String>>,
|
||||
fail_on_add: bool,
|
||||
}
|
||||
|
||||
impl MockChannel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
reaction_added: AtomicBool::new(false),
|
||||
reaction_removed: AtomicBool::new(false),
|
||||
last_channel_id: parking_lot::Mutex::new(None),
|
||||
fail_on_add: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn failing() -> Self {
|
||||
Self {
|
||||
reaction_added: AtomicBool::new(false),
|
||||
reaction_removed: AtomicBool::new(false),
|
||||
last_channel_id: parking_lot::Mutex::new(None),
|
||||
fail_on_add: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for MockChannel {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
_tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_reaction(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
_message_id: &str,
|
||||
_emoji: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if self.fail_on_add {
|
||||
return Err(anyhow::anyhow!("API error: rate limited"));
|
||||
}
|
||||
*self.last_channel_id.lock() = Some(channel_id.to_string());
|
||||
self.reaction_added.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_reaction(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
_message_id: &str,
|
||||
_emoji: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
*self.last_channel_id.lock() = Some(channel_id.to_string());
|
||||
self.reaction_removed.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
|
||||
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
|
||||
let map: HashMap<String, Arc<dyn Channel>> = channels
|
||||
.into_iter()
|
||||
.map(|(name, ch)| (name.to_string(), ch))
|
||||
.collect();
|
||||
tool.populate(map);
|
||||
tool
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_metadata() {
|
||||
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
|
||||
assert_eq!(tool.name(), "reaction");
|
||||
assert!(!tool.description().is_empty());
|
||||
let schema = tool.parameters_schema();
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"]["channel"].is_object());
|
||||
assert!(schema["properties"]["channel_id"].is_object());
|
||||
assert!(schema["properties"]["message_id"].is_object());
|
||||
assert!(schema["properties"]["emoji"].is_object());
|
||||
assert!(schema["properties"]["action"].is_object());
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.iter().any(|v| v == "channel"));
|
||||
assert!(required.iter().any(|v| v == "channel_id"));
|
||||
assert!(required.iter().any(|v| v == "message_id"));
|
||||
assert!(required.iter().any(|v| v == "emoji"));
|
||||
// action is optional (defaults to "add")
|
||||
assert!(!required.iter().any(|v| v == "action"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_reaction_success() {
|
||||
let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
|
||||
let tool = make_tool_with_channels(vec![("discord", Arc::clone(&mock))]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "discord",
|
||||
"channel_id": "ch_001",
|
||||
"message_id": "msg_123",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("added"));
|
||||
assert!(result.error.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_reaction_success() {
|
||||
let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
|
||||
let tool = make_tool_with_channels(vec![("slack", Arc::clone(&mock))]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "slack",
|
||||
"channel_id": "C0123SLACK",
|
||||
"message_id": "msg_456",
|
||||
"emoji": "\u{1F440}",
|
||||
"action": "remove"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("removed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_channel_returns_error() {
|
||||
let tool = make_tool_with_channels(vec![(
|
||||
"discord",
|
||||
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
|
||||
)]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "nonexistent",
|
||||
"channel_id": "ch_x",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
let err = result.error.as_deref().unwrap();
|
||||
assert!(err.contains("not found"));
|
||||
assert!(err.contains("discord"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_action_returns_error() {
|
||||
let tool = make_tool_with_channels(vec![(
|
||||
"discord",
|
||||
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
|
||||
)]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "discord",
|
||||
"channel_id": "ch_001",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}",
|
||||
"action": "toggle"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("toggle"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_error_propagated() {
|
||||
let mock: Arc<dyn Channel> = Arc::new(MockChannel::failing());
|
||||
let tool = make_tool_with_channels(vec![("discord", mock)]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "discord",
|
||||
"channel_id": "ch_001",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("rate limited"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_required_params() {
|
||||
let tool = make_tool_with_channels(vec![(
|
||||
"test",
|
||||
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
|
||||
)]);
|
||||
|
||||
// Missing channel
|
||||
let result = tool
|
||||
.execute(json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Missing channel_id
|
||||
let result = tool
|
||||
.execute(json!({"channel": "test", "message_id": "1", "emoji": "x"}))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Missing message_id
|
||||
let result = tool
|
||||
.execute(json!({"channel": "a", "channel_id": "c1", "emoji": "x"}))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Missing emoji
|
||||
let result = tool
|
||||
.execute(json!({"channel": "a", "channel_id": "c1", "message_id": "1"}))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_channels_returns_not_initialized() {
|
||||
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
|
||||
// No channels populated
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "discord",
|
||||
"channel_id": "ch_001",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("not initialized"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_action_is_add() {
|
||||
let mock = Arc::new(MockChannel::new());
|
||||
let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
|
||||
let tool = make_tool_with_channels(vec![("test", mock_ch)]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "test",
|
||||
"channel_id": "ch_test",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(mock.reaction_added.load(Ordering::SeqCst));
|
||||
assert!(!mock.reaction_removed.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_id_passed_to_trait_not_channel_name() {
|
||||
let mock = Arc::new(MockChannel::new());
|
||||
let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
|
||||
let tool = make_tool_with_channels(vec![("discord", mock_ch)]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "discord",
|
||||
"channel_id": "123456789",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
// The trait must receive the platform channel_id, not the channel name
|
||||
assert_eq!(
|
||||
mock.last_channel_id.lock().as_deref(),
|
||||
Some("123456789"),
|
||||
"add_reaction must receive channel_id, not channel name"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_map_handle_allows_late_binding() {
|
||||
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
|
||||
let handle = tool.channel_map_handle();
|
||||
|
||||
// Initially empty — tool reports not initialized
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "slack",
|
||||
"channel_id": "C0123",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
|
||||
// Populate via the handle
|
||||
{
|
||||
let mut map = handle.write();
|
||||
map.insert(
|
||||
"slack".to_string(),
|
||||
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
|
||||
);
|
||||
}
|
||||
|
||||
// Now the tool can route to the channel
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"channel": "slack",
|
||||
"channel_id": "C0123",
|
||||
"message_id": "msg_1",
|
||||
"emoji": "\u{2705}"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_matches_metadata() {
|
||||
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "reaction");
|
||||
assert_eq!(spec.description, tool.description());
|
||||
assert!(spec.parameters["required"].is_array());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
//! Session-to-session messaging tools for inter-agent communication.
|
||||
//!
|
||||
//! Provides three tools:
|
||||
//! - `sessions_list` — list active sessions with metadata
|
||||
//! - `sessions_history` — read message history from a specific session
|
||||
//! - `sessions_send` — send a message to a specific session
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::channels::session_backend::SessionBackend;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Validate that a session ID is non-empty and contains at least one
|
||||
/// alphanumeric character (prevents blank keys after sanitization).
|
||||
fn validate_session_id(session_id: &str) -> Result<(), ToolResult> {
|
||||
let trimmed = session_id.trim();
|
||||
if trimmed.is_empty() || !trimmed.chars().any(|c| c.is_alphanumeric()) {
|
||||
return Err(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Invalid 'session_id': must be non-empty and contain at least one alphanumeric character.".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── SessionsListTool ────────────────────────────────────────────────
|
||||
|
||||
/// Lists active sessions with their channel, last activity time, and message count.
|
||||
pub struct SessionsListTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
}
|
||||
|
||||
impl SessionsListTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>) -> Self {
|
||||
Self { backend }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsListTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_list"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"List all active conversation sessions with their channel, last activity time, and message count."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max sessions to return (default: 50)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(50, |v| v as usize);
|
||||
|
||||
let metadata = self.backend.list_sessions_with_metadata();
|
||||
|
||||
if metadata.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No active sessions found.".into(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let capped: Vec<_> = metadata.into_iter().take(limit).collect();
|
||||
let mut output = format!("Found {} session(s):\n", capped.len());
|
||||
for meta in &capped {
|
||||
// Extract channel from key (convention: channel__identifier)
|
||||
let channel = meta.key.split("__").next().unwrap_or(&meta.key);
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"- {}: channel={}, messages={}, last_activity={}",
|
||||
meta.key, channel, meta.message_count, meta.last_activity
|
||||
);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionsHistoryTool ─────────────────────────────────────────────
|
||||
|
||||
/// Reads the message history of a specific session by ID.
|
||||
pub struct SessionsHistoryTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SessionsHistoryTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { backend, security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsHistoryTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_history"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the message history of a specific session by its session ID. Returns the last N messages."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The session ID to read history from (e.g. telegram__user123)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max messages to return, from most recent (default: 20)"
|
||||
}
|
||||
},
|
||||
"required": ["session_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "sessions_history")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let session_id = args
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
|
||||
|
||||
if let Err(result) = validate_session_id(session_id) {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(20, |v| v as usize);
|
||||
|
||||
let messages = self.backend.load(session_id);
|
||||
|
||||
if messages.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("No messages found for session '{session_id}'."),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Take the last `limit` messages
|
||||
let start = messages.len().saturating_sub(limit);
|
||||
let tail = &messages[start..];
|
||||
|
||||
let mut output = format!(
|
||||
"Session '{}': showing {}/{} messages\n",
|
||||
session_id,
|
||||
tail.len(),
|
||||
messages.len()
|
||||
);
|
||||
for msg in tail {
|
||||
let _ = writeln!(output, "[{}] {}", msg.role, msg.content);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionsSendTool ────────────────────────────────────────────────
|
||||
|
||||
/// Sends a message to a specific session, enabling inter-agent communication.
|
||||
pub struct SessionsSendTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SessionsSendTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { backend, security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsSendTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_send"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Send a message to a specific session by its session ID. The message is appended to the session's conversation history as a 'user' message, enabling inter-agent communication."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The target session ID (e.g. telegram__user123)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message content to send"
|
||||
}
|
||||
},
|
||||
"required": ["session_id", "message"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "sessions_send")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let session_id = args
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
|
||||
|
||||
if let Err(result) = validate_session_id(session_id) {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let message = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?;
|
||||
|
||||
if message.trim().is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Message content must not be empty.".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let chat_msg = crate::providers::traits::ChatMessage::user(message);
|
||||
|
||||
match self.backend.append(session_id, &chat_msg) {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Message sent to session '{session_id}'."),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to send message: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::channels::session_store::SessionStore;
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_backend() -> (TempDir, Arc<dyn SessionBackend>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(store))
|
||||
}
|
||||
|
||||
fn seeded_backend() -> (TempDir, Arc<dyn SessionBackend>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
store
|
||||
.append("telegram__alice", &ChatMessage::user("Hello from Alice"))
|
||||
.unwrap();
|
||||
store
|
||||
.append(
|
||||
"telegram__alice",
|
||||
&ChatMessage::assistant("Hi Alice, how can I help?"),
|
||||
)
|
||||
.unwrap();
|
||||
store
|
||||
.append("discord__bob", &ChatMessage::user("Hey from Bob"))
|
||||
.unwrap();
|
||||
(tmp, Arc::new(store))
|
||||
}
|
||||
|
||||
// ── SessionsListTool tests ──────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_empty_sessions() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No active sessions"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_shows_all() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("2 session(s)"));
|
||||
assert!(result.output.contains("telegram__alice"));
|
||||
assert!(result.output.contains("discord__bob"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_respects_limit() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({"limit": 1})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("1 session(s)"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_extracts_channel() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.output.contains("channel=telegram"));
|
||||
assert!(result.output.contains("channel=discord"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
assert_eq!(tool.name(), "sessions_list");
|
||||
assert!(tool.parameters_schema()["properties"]["limit"].is_object());
|
||||
}
|
||||
|
||||
// ── SessionsHistoryTool tests ───────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_empty_session() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "nonexistent"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No messages found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_returns_messages() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "telegram__alice"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("showing 2/2 messages"));
|
||||
assert!(result.output.contains("[user] Hello from Alice"));
|
||||
assert!(result.output.contains("[assistant] Hi Alice"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_respects_limit() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "telegram__alice", "limit": 1}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("showing 1/2 messages"));
|
||||
// Should show only the last message
|
||||
assert!(result.output.contains("[assistant]"));
|
||||
assert!(!result.output.contains("[user] Hello from Alice"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_missing_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("session_id"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_rejects_empty_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"session_id": " "})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
assert_eq!(tool.name(), "sessions_history");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["session_id"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("session_id")));
|
||||
}
|
||||
|
||||
// ── SessionsSendTool tests ──────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_appends_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": "Hello from another agent"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Message sent"));
|
||||
|
||||
// Verify message was appended
|
||||
let messages = backend.load("telegram__alice");
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, "user");
|
||||
assert_eq!(messages[0].content, "Hello from another agent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_to_existing_session() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsSendTool::new(backend.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": "Inter-agent message"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let messages = backend.load("telegram__alice");
|
||||
assert_eq!(messages.len(), 3);
|
||||
assert_eq!(messages[2].content, "Inter-agent message");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_empty_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": " "
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_empty_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "",
|
||||
"message": "hello"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_non_alphanumeric_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "///",
|
||||
"message": "hello"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_missing_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"message": "hi"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("session_id"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_missing_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"session_id": "telegram__alice"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("message"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
assert_eq!(tool.name(), "sessions_send");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("session_id")));
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("message")));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
//! HTTP-based tool derived from a skill's `[[tools]]` section.
|
||||
//!
|
||||
//! Each `SkillTool` with `kind = "http"` is converted into a `SkillHttpTool`
|
||||
//! that implements the `Tool` trait. The command field is used as the URL
|
||||
//! template and args are substituted as query parameters or path segments.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum response body size (1 MB).
|
||||
const MAX_RESPONSE_BYTES: usize = 1_048_576;
|
||||
/// HTTP request timeout (seconds).
|
||||
const HTTP_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// A tool derived from a skill's `[[tools]]` section that makes HTTP requests.
|
||||
pub struct SkillHttpTool {
|
||||
tool_name: String,
|
||||
tool_description: String,
|
||||
url_template: String,
|
||||
args: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl SkillHttpTool {
|
||||
/// Create a new skill HTTP tool.
|
||||
///
|
||||
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
|
||||
/// to prevent collisions with built-in tools.
|
||||
pub fn new(skill_name: &str, tool: &crate::skills::SkillTool) -> Self {
|
||||
Self {
|
||||
tool_name: format!("{}.{}", skill_name, tool.name),
|
||||
tool_description: tool.description.clone(),
|
||||
url_template: tool.command.clone(),
|
||||
args: tool.args.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_parameters_schema(&self) -> serde_json::Value {
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut required = Vec::new();
|
||||
|
||||
for (name, description) in &self.args {
|
||||
properties.insert(
|
||||
name.clone(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": description
|
||||
}),
|
||||
);
|
||||
required.push(serde_json::Value::String(name.clone()));
|
||||
}
|
||||
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
})
|
||||
}
|
||||
|
||||
/// Substitute `{{arg_name}}` placeholders in the URL template with
|
||||
/// the provided argument values.
|
||||
fn substitute_args(&self, args: &serde_json::Value) -> String {
|
||||
let mut url = self.url_template.clone();
|
||||
if let Some(obj) = args.as_object() {
|
||||
for (key, value) in obj {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
let replacement = value.as_str().unwrap_or_default();
|
||||
url = url.replace(&placeholder, replacement);
|
||||
}
|
||||
}
|
||||
url
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillHttpTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.tool_description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.build_parameters_schema()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let url = self.substitute_args(&args);
|
||||
|
||||
// Validate URL scheme
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Only http:// and https:// URLs are allowed, got: {url}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.build()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build HTTP client: {e}"))?;
|
||||
|
||||
let response = match client.get(&url).send().await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("HTTP request failed: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
let mut text = String::from_utf8_lossy(&bytes).to_string();
|
||||
if text.len() > MAX_RESPONSE_BYTES {
|
||||
let mut b = MAX_RESPONSE_BYTES.min(text.len());
|
||||
while b > 0 && !text.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
text.truncate(b);
|
||||
text.push_str("\n... [response truncated at 1MB]");
|
||||
}
|
||||
text
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to read response body: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
success: status.is_success(),
|
||||
output: body,
|
||||
error: if status.is_success() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("HTTP {}", status))
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::skills::SkillTool;
|
||||
|
||||
fn sample_http_tool() -> SkillTool {
|
||||
let mut args = HashMap::new();
|
||||
args.insert("city".to_string(), "City name to look up".to_string());
|
||||
|
||||
SkillTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Fetch weather for a city".to_string(),
|
||||
kind: "http".to_string(),
|
||||
command: "https://api.example.com/weather?city={{city}}".to_string(),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_name_is_prefixed() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
assert_eq!(tool.name(), "weather_skill.get_weather");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_description() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
assert_eq!(tool.description(), "Fetch weather for a city");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_parameters_schema() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let schema = tool.parameters_schema();
|
||||
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"]["city"].is_object());
|
||||
assert_eq!(schema["properties"]["city"]["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_substitute_args() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let result = tool.substitute_args(&serde_json::json!({"city": "London"}));
|
||||
assert_eq!(result, "https://api.example.com/weather?city=London");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_spec_roundtrip() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "weather_skill.get_weather");
|
||||
assert_eq!(spec.description, "Fetch weather for a city");
|
||||
assert_eq!(spec.parameters["type"], "object");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_empty_args() {
|
||||
let st = SkillTool {
|
||||
name: "ping".to_string(),
|
||||
description: "Ping endpoint".to_string(),
|
||||
kind: "http".to_string(),
|
||||
command: "https://api.example.com/ping".to_string(),
|
||||
args: HashMap::new(),
|
||||
};
|
||||
let tool = SkillHttpTool::new("s", &st);
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"].as_object().unwrap().is_empty());
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user