Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 41dd23175f | |||
| 864d754b56 | |||
| ccd52f3394 | |||
| eb01aa451d | |||
| c785b45f2d | |||
| ffb8b81f90 | |||
| 65f856d710 | |||
| 1682620377 | |||
| aa455ae89b | |||
| a9ffd38912 | |||
| 86a0584513 | |||
| abef4c5719 | |||
| 483b2336c4 | |||
| 14cda3bc9a | |||
| 6e8f0fa43c | |||
| a965b129f8 | |||
| c135de41b7 | |||
| 2d2c2ac9e6 | |||
| 5e774bbd70 | |||
| 33015067eb | |||
| 6b10c0b891 | |||
| bf817e30d2 | |||
| 0051a0c296 | |||
| d753de91f1 | |||
| f6b2f61a01 | |||
| 70e7910cb9 | |||
| a8868768e8 | |||
| 67293c50df | |||
| 1646079d25 | |||
| 25b639435f | |||
| 77779844e5 | |||
| f658d5806a | |||
| 7134fe0824 | |||
| 263802b3df | |||
| 3c25fddb2a | |||
| a6a46bdd25 | |||
| 235d4d2f1c | |||
| bd1e8c8e1a | |||
| f81807bff6 | |||
| bb7006313c | |||
| 9a49626376 | |||
| 8b978a721f | |||
| 75b4c1d4a4 | |||
| cb0779d761 |
@@ -1,97 +0,0 @@
|
||||
# Mem0 Integration: Dual-Scope Recall + Per-Turn Memory
|
||||
|
||||
## Context
|
||||
|
||||
Mem0 auto-save works but the integration is missing key features from mem0 best practices: per-turn recall, multi-level scoping, and proper context injection. This causes the bot to "forget" on follow-up turns and not differentiate users.
|
||||
|
||||
## What's Missing (vs mem0 docs)
|
||||
|
||||
1. **Per-turn recall** — only first turn gets memory context, follow-ups get nothing
|
||||
2. **Dual-scope** — no sender vs group distinction. All memories use single hardcoded `user_id`
|
||||
3. **System prompt injection** — memory prepended to user message (pollutes session history)
|
||||
4. **`agent_id` scoping** — mem0 supports agent-level patterns, not used
|
||||
|
||||
## Changes
|
||||
|
||||
### 1. `src/memory/mem0.rs` — Use session_id for multi-level scoping
|
||||
|
||||
Map zeroclaw's `session_id` param to mem0's `user_id`. This enables per-user and per-group memory namespaces without changing the `Memory` trait.
|
||||
|
||||
```rust
|
||||
// Add helper:
|
||||
fn effective_user_id(&self, session_id: Option<&str>) -> &str {
|
||||
session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
|
||||
}
|
||||
|
||||
// In store(): use effective_user_id(session_id) as mem0 user_id
|
||||
// In recall(): use effective_user_id(session_id) as mem0 user_id
|
||||
// In list(): use effective_user_id(session_id) as mem0 user_id
|
||||
```
|
||||
|
||||
### 2. `src/channels/mod.rs` ~line 2229 — Per-turn dual-scope recall
|
||||
|
||||
Remove `if !had_prior_history` gate. Always recall from both sender scope and group scope (for group chats).
|
||||
|
||||
```rust
|
||||
// Detect group chat
|
||||
let is_group = msg.reply_target.contains("@g.us")
|
||||
|| msg.reply_target.starts_with("group:");
|
||||
|
||||
// Sender-scope recall (always)
|
||||
let sender_context = build_memory_context(
|
||||
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
|
||||
Some(&msg.sender),
|
||||
).await;
|
||||
|
||||
// Group-scope recall (groups only)
|
||||
let group_context = if is_group {
|
||||
build_memory_context(
|
||||
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
|
||||
Some(&history_key),
|
||||
).await
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Merge (deduplicate by checking substring overlap)
|
||||
let memory_context = merge_memory_contexts(&sender_context, &group_context);
|
||||
```
|
||||
|
||||
### 3. `src/channels/mod.rs` ~line 2244 — Inject into system prompt
|
||||
|
||||
Move memory context from user message to system prompt. Re-fetched each turn, doesn't pollute session.
|
||||
|
||||
```rust
|
||||
let mut system_prompt = build_channel_system_prompt(...);
|
||||
if !memory_context.is_empty() {
|
||||
system_prompt.push_str(&format!("\n\n{memory_context}"));
|
||||
}
|
||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||
```
|
||||
|
||||
### 4. `src/channels/mod.rs` — Dual-scope auto-save
|
||||
|
||||
Find existing auto-save call. For group messages, store twice:
|
||||
- `store(key, content, category, Some(&msg.sender))` — personal facts
|
||||
- `store(key, content, category, Some(&history_key))` — group context
|
||||
|
||||
Both async, non-blocking. DMs only store to sender scope.
|
||||
|
||||
### 5. `src/memory/mem0.rs` — Add `agent_id` support (optional)
|
||||
|
||||
Pass `self.app_name` as `agent_id` param to mem0 API for agent behavior tracking.
|
||||
|
||||
## Files to Modify
|
||||
|
||||
1. `src/memory/mem0.rs` — session_id → user_id mapping
|
||||
2. `src/channels/mod.rs` — per-turn recall, dual-scope, system prompt injection, dual-scope save
|
||||
|
||||
## Verification
|
||||
|
||||
1. `cargo check --features whatsapp-web,memory-mem0`
|
||||
2. `cargo test --features whatsapp-web,memory-mem0`
|
||||
3. Deploy to Synology
|
||||
4. Test DM: "我鍾意食壽司" → next turn "我鍾意食咩" → should recall
|
||||
5. Test group: Joe says "我鍾意食壽司" → someone else asks "Joe 鍾意食咩" → should recall from group scope
|
||||
6. Check mem0 server logs: GET with `user_id=sender` AND `user_id=group_key`
|
||||
7. Check mem0 server logs: POST with both user_ids for group messages
|
||||
@@ -118,3 +118,7 @@ PROVIDER=openrouter
|
||||
# Optional: Brave Search (requires API key from https://brave.com/search/api)
|
||||
# WEB_SEARCH_PROVIDER=brave
|
||||
# BRAVE_API_KEY=your-brave-search-api-key
|
||||
#
|
||||
# Optional: SearXNG (self-hosted, requires instance URL)
|
||||
# WEB_SEARCH_PROVIDER=searxng
|
||||
# SEARXNG_INSTANCE_URL=https://searx.example.com
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
name: Pub Homebrew Core
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag to publish (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Patch formula only (no push/PR)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
HOMEBREW_UPSTREAM_PR_TOKEN:
|
||||
required: false
|
||||
HOMEBREW_CORE_BOT_TOKEN:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
|
||||
@@ -41,6 +41,14 @@ jobs:
|
||||
echo "Current version: ${current}"
|
||||
echo "Previous version: ${previous}"
|
||||
|
||||
# Skip if stable release workflow will handle this version
|
||||
# (indicated by an existing or imminent stable tag)
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/v${current}" >/dev/null 2>&1; then
|
||||
echo "Stable tag v${current} exists — stable release workflow handles crates.io"
|
||||
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$current" != "$previous" && -n "$current" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${current}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
@@ -26,22 +26,43 @@ jobs:
|
||||
outputs:
|
||||
version: ${{ steps.ver.outputs.version }}
|
||||
tag: ${{ steps.ver.outputs.tag }}
|
||||
skip: ${{ steps.ver.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Compute beta version
|
||||
id: ver
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
base_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
|
||||
# Skip beta if this is a version bump commit (stable release handles it)
|
||||
commit_msg=$(git log -1 --pretty=format:"%s")
|
||||
if [[ "$commit_msg" =~ ^chore:\ bump\ version ]]; then
|
||||
echo "Version bump commit detected — skipping beta release"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Skip beta if a stable tag already exists for this version
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/v${base_version}" >/dev/null 2>&1; then
|
||||
echo "Stable tag v${base_version} exists — skipping beta release"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
beta_tag="v${base_version}-beta.${GITHUB_RUN_NUMBER}"
|
||||
echo "version=${base_version}" >> "$GITHUB_OUTPUT"
|
||||
echo "tag=${beta_tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Beta release: ${beta_tag}"
|
||||
|
||||
release-notes:
|
||||
name: Generate Release Notes
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw'
|
||||
needs: [version]
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
notes: ${{ steps.notes.outputs.body }}
|
||||
@@ -132,7 +153,8 @@ jobs:
|
||||
|
||||
web:
|
||||
name: Build Web Dashboard
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw'
|
||||
needs: [version]
|
||||
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
name: Release Stable
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v[0-9]+.[0-9]+.[0-9]+" # stable tags only (no -beta suffix)
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
@@ -33,11 +36,22 @@ jobs:
|
||||
- name: Validate semver and Cargo.toml match
|
||||
id: check
|
||||
shell: bash
|
||||
env:
|
||||
INPUT_VERSION: ${{ inputs.version || '' }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
input_version="${{ inputs.version }}"
|
||||
cargo_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
|
||||
|
||||
# Resolve version from tag push or manual input
|
||||
if [[ "$EVENT_NAME" == "push" ]]; then
|
||||
# Tag push: extract version from tag name (v0.5.9 -> 0.5.9)
|
||||
input_version="${REF_NAME#v}"
|
||||
else
|
||||
input_version="$INPUT_VERSION"
|
||||
fi
|
||||
|
||||
if [[ ! "$input_version" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "::error::Version must be semver (X.Y.Z). Got: ${input_version}"
|
||||
exit 1
|
||||
@@ -49,9 +63,13 @@ jobs:
|
||||
fi
|
||||
|
||||
tag="v${input_version}"
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${tag} already exists."
|
||||
exit 1
|
||||
|
||||
# Only check tag existence for manual dispatch (tag push means it already exists)
|
||||
if [[ "$EVENT_NAME" != "push" ]]; then
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${tag} already exists."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "tag=${tag}" >> "$GITHUB_OUTPUT"
|
||||
@@ -286,6 +304,14 @@ jobs:
|
||||
NOTES: ${{ needs.release-notes.outputs.notes }}
|
||||
run: printf '%s\n' "$NOTES" > release-notes.md
|
||||
|
||||
- name: Create tag if manual dispatch
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
env:
|
||||
TAG: ${{ needs.validate.outputs.tag }}
|
||||
run: |
|
||||
git tag -a "$TAG" -m "zeroclaw $TAG"
|
||||
git push origin "$TAG"
|
||||
|
||||
- name: Create GitHub Release
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.RELEASE_TOKEN }}
|
||||
@@ -461,6 +487,16 @@ jobs:
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
homebrew:
|
||||
name: Update Homebrew Core
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-homebrew-core.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
# ── Post-publish: tweet after release + website are live ──────────────
|
||||
# Docker push can be slow; don't let it block the tweet.
|
||||
tweet:
|
||||
|
||||
Generated
+57
-29
@@ -3444,9 +3444,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
|
||||
|
||||
[[package]]
|
||||
name = "iri-string"
|
||||
version = "0.7.10"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
|
||||
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
@@ -3487,9 +3487,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
||||
checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
|
||||
|
||||
[[package]]
|
||||
name = "ittapi"
|
||||
@@ -3529,7 +3529,7 @@ dependencies = [
|
||||
"cesu8",
|
||||
"cfg-if",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"jni-sys 0.3.1",
|
||||
"log",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
@@ -3538,9 +3538,31 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
|
||||
dependencies = [
|
||||
"jni-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
|
||||
dependencies = [
|
||||
"jni-sys-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys-macros"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
@@ -4349,9 +4371,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "moka"
|
||||
version = "0.12.14"
|
||||
version = "0.12.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
|
||||
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
|
||||
dependencies = [
|
||||
"async-lock",
|
||||
"crossbeam-channel",
|
||||
@@ -4396,7 +4418,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"jni-sys",
|
||||
"jni-sys 0.3.1",
|
||||
"log",
|
||||
"ndk-sys",
|
||||
"num_enum",
|
||||
@@ -4415,7 +4437,7 @@ version = "0.5.0+25.2.9519653"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691"
|
||||
dependencies = [
|
||||
"jni-sys",
|
||||
"jni-sys 0.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5513,9 +5535,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pulldown-cmark"
|
||||
version = "0.13.1"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
|
||||
checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"memchr",
|
||||
@@ -5633,9 +5655,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "quoted_printable"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73"
|
||||
checksum = "478e0585659a122aa407eb7e3c0e1fa51b1d8a870038bd29f0cf4a8551eea972"
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
@@ -7809,9 +7831,9 @@ checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.2.0"
|
||||
version = "3.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc"
|
||||
checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"cookie_store",
|
||||
@@ -7823,15 +7845,15 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"ureq-proto",
|
||||
"utf-8",
|
||||
"utf8-zero",
|
||||
"webpki-roots 1.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq-proto"
|
||||
version = "0.5.3"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
|
||||
checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"http 1.4.0",
|
||||
@@ -7864,6 +7886,12 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8-zero"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
@@ -9530,7 +9558,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.6"
|
||||
version = "0.5.9"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
@@ -9626,18 +9654,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
|
||||
checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
|
||||
checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -9742,9 +9770,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "8.3.0"
|
||||
version = "8.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a243cfad17427fc077f529da5a95abe4e94fd2bfdb601611870a6557cc67657"
|
||||
checksum = "5c546feb4481b0fbafb4ef0d79b6204fc41c6f9884b1b73b1d73f82442fc0845"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"flate2",
|
||||
@@ -9814,9 +9842,9 @@ checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
|
||||
|
||||
[[package]]
|
||||
name = "zune-jpeg"
|
||||
version = "0.5.13"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c"
|
||||
checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6"
|
||||
dependencies = [
|
||||
"zune-core",
|
||||
]
|
||||
|
||||
+1
-4
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.6"
|
||||
version = "0.5.9"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@@ -231,8 +231,6 @@ channel-matrix = ["dep:matrix-sdk"]
|
||||
channel-lark = ["dep:prost"]
|
||||
channel-feishu = ["channel-lark"] # Alias for Feishu users (Lark and Feishu are the same platform)
|
||||
memory-postgres = ["dep:postgres"]
|
||||
# memory-mem0 = Mem0 (OpenMemory) memory backend via REST API
|
||||
memory-mem0 = []
|
||||
observability-prometheus = ["dep:prometheus"]
|
||||
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
|
||||
peripheral-rpi = ["rppal"]
|
||||
@@ -267,7 +265,6 @@ ci-all = [
|
||||
"channel-matrix",
|
||||
"channel-lark",
|
||||
"memory-postgres",
|
||||
"memory-mem0",
|
||||
"observability-prometheus",
|
||||
"observability-otel",
|
||||
"peripheral-rpi",
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Start mem0 + reranker GPU container for ZeroClaw memory backend.
|
||||
#
|
||||
# Required env vars:
|
||||
# MEM0_LLM_API_KEY or ZAI_API_KEY — API key for the LLM used in fact extraction
|
||||
#
|
||||
# Optional env vars (with defaults):
|
||||
# MEM0_LLM_PROVIDER — mem0 LLM provider (default: "openai" i.e. OpenAI-compatible)
|
||||
# MEM0_LLM_MODEL — LLM model for fact extraction (default: "glm-5-turbo")
|
||||
# MEM0_LLM_BASE_URL — LLM API base URL (default: "https://api.z.ai/api/coding/paas/v4")
|
||||
# MEM0_EMBEDDER_MODEL — embedding model (default: "BAAI/bge-m3")
|
||||
# MEM0_EMBEDDER_DIMS — embedding dimensions (default: "1024")
|
||||
# MEM0_EMBEDDER_DEVICE — "cuda", "cpu", or "auto" (default: "cuda")
|
||||
# MEM0_VECTOR_COLLECTION — Qdrant collection name (default: "zeroclaw_mem0")
|
||||
# RERANKER_MODEL — reranker model (default: "BAAI/bge-reranker-v2-m3")
|
||||
# RERANKER_DEVICE — "cuda" or "cpu" (default: "cuda")
|
||||
# MEM0_PORT — mem0 server port (default: 8765)
|
||||
# RERANKER_PORT — reranker server port (default: 8678)
|
||||
# CONTAINER_IMAGE — base container image (default: docker.io/kyuz0/amd-strix-halo-comfyui:latest)
|
||||
# CONTAINER_NAME — container name (default: mem0-gpu)
|
||||
# DATA_DIR — host path for Qdrant data (default: ~/mem0-data)
|
||||
# SCRIPT_DIR — host path for server scripts (default: directory of this script)
|
||||
set -e
|
||||
|
||||
# Resolve script directory for mounting server scripts
|
||||
SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "$0")" && pwd)}"
|
||||
|
||||
# API key — accept either name
|
||||
export MEM0_LLM_API_KEY="${MEM0_LLM_API_KEY:-${ZAI_API_KEY:?MEM0_LLM_API_KEY or ZAI_API_KEY must be set}}"
|
||||
|
||||
# Defaults
|
||||
MEM0_LLM_MODEL="${MEM0_LLM_MODEL:-glm-5-turbo}"
|
||||
MEM0_LLM_BASE_URL="${MEM0_LLM_BASE_URL:-https://api.z.ai/api/coding/paas/v4}"
|
||||
MEM0_PORT="${MEM0_PORT:-8765}"
|
||||
RERANKER_PORT="${RERANKER_PORT:-8678}"
|
||||
CONTAINER_IMAGE="${CONTAINER_IMAGE:-docker.io/kyuz0/amd-strix-halo-comfyui:latest}"
|
||||
CONTAINER_NAME="${CONTAINER_NAME:-mem0-gpu}"
|
||||
DATA_DIR="${DATA_DIR:-$HOME/mem0-data}"
|
||||
|
||||
# Stop existing CPU services (if any)
|
||||
kill -9 $(pgrep -f "mem0-server.py") 2>/dev/null || true
|
||||
kill -9 $(pgrep -f "reranker-server.py") 2>/dev/null || true
|
||||
|
||||
# Stop existing container
|
||||
podman stop "$CONTAINER_NAME" 2>/dev/null || true
|
||||
podman rm "$CONTAINER_NAME" 2>/dev/null || true
|
||||
|
||||
podman run -d --name "$CONTAINER_NAME" \
|
||||
--device /dev/dri --device /dev/kfd \
|
||||
--group-add video --group-add render \
|
||||
--restart unless-stopped \
|
||||
-p "$MEM0_PORT:$MEM0_PORT" -p "$RERANKER_PORT:$RERANKER_PORT" \
|
||||
-v "$DATA_DIR":/root/mem0-data:Z \
|
||||
-v "$SCRIPT_DIR/mem0-server.py":/app/mem0-server.py:ro,Z \
|
||||
-v "$SCRIPT_DIR/reranker-server.py":/app/reranker-server.py:ro,Z \
|
||||
-v "$HOME/.cache/huggingface":/root/.cache/huggingface:Z \
|
||||
-e MEM0_LLM_API_KEY="$MEM0_LLM_API_KEY" \
|
||||
-e ZAI_API_KEY="$MEM0_LLM_API_KEY" \
|
||||
-e MEM0_LLM_MODEL="$MEM0_LLM_MODEL" \
|
||||
-e MEM0_LLM_BASE_URL="$MEM0_LLM_BASE_URL" \
|
||||
${MEM0_LLM_PROVIDER:+-e MEM0_LLM_PROVIDER="$MEM0_LLM_PROVIDER"} \
|
||||
${MEM0_EMBEDDER_MODEL:+-e MEM0_EMBEDDER_MODEL="$MEM0_EMBEDDER_MODEL"} \
|
||||
${MEM0_EMBEDDER_DIMS:+-e MEM0_EMBEDDER_DIMS="$MEM0_EMBEDDER_DIMS"} \
|
||||
${MEM0_EMBEDDER_DEVICE:+-e MEM0_EMBEDDER_DEVICE="$MEM0_EMBEDDER_DEVICE"} \
|
||||
${MEM0_VECTOR_COLLECTION:+-e MEM0_VECTOR_COLLECTION="$MEM0_VECTOR_COLLECTION"} \
|
||||
${RERANKER_MODEL:+-e RERANKER_MODEL="$RERANKER_MODEL"} \
|
||||
${RERANKER_DEVICE:+-e RERANKER_DEVICE="$RERANKER_DEVICE"} \
|
||||
-e RERANKER_PORT="$RERANKER_PORT" \
|
||||
-e RERANKER_URL="http://127.0.0.1:$RERANKER_PORT/rerank" \
|
||||
-e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
|
||||
-e HOME=/root \
|
||||
"$CONTAINER_IMAGE" \
|
||||
bash -c "pip install -q FlagEmbedding mem0ai flask httpx qdrant-client 2>&1 | tail -3; echo '=== Starting reranker (GPU) on :$RERANKER_PORT ==='; python3 /app/reranker-server.py & sleep 3; echo '=== Starting mem0 (GPU) on :$MEM0_PORT ==='; exec python3 /app/mem0-server.py"
|
||||
|
||||
echo "Container started, waiting for init..."
|
||||
sleep 15
|
||||
echo "=== Container logs ==="
|
||||
podman logs "$CONTAINER_NAME" 2>&1 | tail -25
|
||||
echo "=== Port check ==="
|
||||
ss -tlnp | grep "$MEM0_PORT\|$RERANKER_PORT" || echo "Ports not yet ready, check: podman logs $CONTAINER_NAME"
|
||||
@@ -1,288 +0,0 @@
|
||||
"""Minimal OpenMemory-compatible REST server wrapping mem0 Python SDK."""
|
||||
import asyncio
|
||||
import json, os, uuid, httpx
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import FastAPI, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from mem0 import Memory
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
RERANKER_URL = os.environ.get("RERANKER_URL", "http://127.0.0.1:8678/rerank")
|
||||
|
||||
CUSTOM_EXTRACTION_PROMPT = """You are a memory extraction specialist for a Cantonese/Chinese chat assistant.
|
||||
|
||||
Extract ONLY important, persistent facts from the conversation. Rules:
|
||||
1. Extract personal preferences, habits, relationships, names, locations
|
||||
2. Extract decisions, plans, and commitments people make
|
||||
3. SKIP small talk, greetings, reactions ("ok", "哈哈", "係呀")
|
||||
4. SKIP temporary states ("我依家食緊飯") unless they reveal a habit
|
||||
5. Keep facts in the ORIGINAL language (Cantonese/Chinese/English)
|
||||
6. For each fact, note WHO it's about (use their name or identifier if known)
|
||||
7. Merge/update existing facts rather than creating duplicates
|
||||
|
||||
Return a list of facts in JSON format: {"facts": ["fact1", "fact2", ...]}
|
||||
"""
|
||||
|
||||
PROCEDURAL_EXTRACTION_PROMPT = """You are a procedural memory specialist for an AI assistant.
|
||||
|
||||
Extract HOW-TO patterns and reusable procedures from the conversation trace. Rules:
|
||||
1. Identify step-by-step procedures the assistant followed to accomplish a task
|
||||
2. Extract tool usage patterns: which tools were called, in what order, with what arguments
|
||||
3. Capture decision points: why the assistant chose one approach over another
|
||||
4. Note error-recovery patterns: what failed, how it was fixed
|
||||
5. Keep the procedure generic enough to apply to similar future tasks
|
||||
6. Preserve technical details (commands, file paths, API calls) that are reusable
|
||||
7. SKIP greetings, small talk, and conversational filler
|
||||
8. Format each procedure as: "To [goal]: [step1] -> [step2] -> ... -> [result]"
|
||||
|
||||
Return a list of procedures in JSON format: {"facts": ["procedure1", "procedure2", ...]}
|
||||
"""
|
||||
|
||||
# ── Configurable via environment variables ─────────────────────────
|
||||
# LLM (for fact extraction when infer=true)
|
||||
MEM0_LLM_PROVIDER = os.environ.get("MEM0_LLM_PROVIDER", "openai") # "openai" (compatible), "anthropic", etc.
|
||||
MEM0_LLM_MODEL = os.environ.get("MEM0_LLM_MODEL", "glm-5-turbo")
|
||||
MEM0_LLM_API_KEY = os.environ.get("MEM0_LLM_API_KEY") or os.environ.get("ZAI_API_KEY", "")
|
||||
MEM0_LLM_BASE_URL = os.environ.get("MEM0_LLM_BASE_URL", "https://api.z.ai/api/coding/paas/v4")
|
||||
|
||||
# Embedder
|
||||
MEM0_EMBEDDER_PROVIDER = os.environ.get("MEM0_EMBEDDER_PROVIDER", "huggingface") # "huggingface", "openai", etc.
|
||||
MEM0_EMBEDDER_MODEL = os.environ.get("MEM0_EMBEDDER_MODEL", "BAAI/bge-m3")
|
||||
MEM0_EMBEDDER_DIMS = int(os.environ.get("MEM0_EMBEDDER_DIMS", "1024"))
|
||||
MEM0_EMBEDDER_DEVICE = os.environ.get("MEM0_EMBEDDER_DEVICE", "cuda") # "cuda", "cpu", "auto"
|
||||
|
||||
# Vector store
|
||||
MEM0_VECTOR_PROVIDER = os.environ.get("MEM0_VECTOR_PROVIDER", "qdrant") # "qdrant", "chroma", etc.
|
||||
MEM0_VECTOR_COLLECTION = os.environ.get("MEM0_VECTOR_COLLECTION", "zeroclaw_mem0")
|
||||
MEM0_VECTOR_PATH = os.environ.get("MEM0_VECTOR_PATH", os.path.expanduser("~/mem0-data/qdrant"))
|
||||
|
||||
config = {
|
||||
"llm": {
|
||||
"provider": MEM0_LLM_PROVIDER,
|
||||
"config": {
|
||||
"model": MEM0_LLM_MODEL,
|
||||
"api_key": MEM0_LLM_API_KEY,
|
||||
"openai_base_url": MEM0_LLM_BASE_URL,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": MEM0_EMBEDDER_PROVIDER,
|
||||
"config": {
|
||||
"model": MEM0_EMBEDDER_MODEL,
|
||||
"embedding_dims": MEM0_EMBEDDER_DIMS,
|
||||
"model_kwargs": {"device": MEM0_EMBEDDER_DEVICE},
|
||||
},
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": MEM0_VECTOR_PROVIDER,
|
||||
"config": {
|
||||
"collection_name": MEM0_VECTOR_COLLECTION,
|
||||
"embedding_model_dims": MEM0_EMBEDDER_DIMS,
|
||||
"path": MEM0_VECTOR_PATH,
|
||||
},
|
||||
},
|
||||
"custom_fact_extraction_prompt": CUSTOM_EXTRACTION_PROMPT,
|
||||
}
|
||||
|
||||
m = Memory.from_config(config)
|
||||
|
||||
|
||||
def rerank_results(query: str, items: list, top_k: int = 10) -> list:
|
||||
"""Rerank search results using bge-reranker-v2-m3."""
|
||||
if not items:
|
||||
return items
|
||||
documents = [item.get("memory", "") for item in items]
|
||||
try:
|
||||
resp = httpx.post(
|
||||
RERANKER_URL,
|
||||
json={"query": query, "documents": documents, "top_k": top_k},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
ranked = resp.json().get("results", [])
|
||||
return [items[r["index"]] for r in ranked]
|
||||
except Exception as e:
|
||||
print(f"Reranker failed, using original order: {e}")
|
||||
return items
|
||||
|
||||
|
||||
class AddMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
text: str
|
||||
metadata: Optional[dict] = None
|
||||
infer: bool = True
|
||||
app: Optional[str] = None
|
||||
custom_instructions: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/api/v1/memories/")
|
||||
async def add_memory(req: AddMemoryRequest):
|
||||
# Use client-supplied prompt, fall back to server default, then mem0 SDK default
|
||||
prompt = req.custom_instructions or CUSTOM_EXTRACTION_PROMPT
|
||||
result = await asyncio.to_thread(m.add, req.text, user_id=req.user_id, metadata=req.metadata or {}, prompt=prompt)
|
||||
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
|
||||
|
||||
|
||||
class ProceduralMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
messages: list[dict]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
@app.post("/api/v1/memories/procedural")
|
||||
async def add_procedural_memory(req: ProceduralMemoryRequest):
|
||||
"""Store a conversation trace as procedural memory.
|
||||
|
||||
Accepts a list of messages (role/content dicts) representing a full
|
||||
conversation turn including tool calls, then uses mem0's native
|
||||
procedural memory extraction to learn reusable "how to" patterns.
|
||||
"""
|
||||
# Build metadata with procedural type marker
|
||||
meta = {"type": "procedural"}
|
||||
if req.metadata:
|
||||
meta.update(req.metadata)
|
||||
|
||||
# Use mem0's native message list support + procedural prompt
|
||||
result = await asyncio.to_thread(m.add,
|
||||
req.messages,
|
||||
user_id=req.user_id,
|
||||
metadata=meta,
|
||||
prompt=PROCEDURAL_EXTRACTION_PROMPT,
|
||||
)
|
||||
|
||||
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
|
||||
|
||||
|
||||
def _parse_mem0_results(raw_results) -> list:
|
||||
raw = raw_results.get("results", raw_results) if isinstance(raw_results, dict) else raw_results
|
||||
items = []
|
||||
for r in raw:
|
||||
item = r if isinstance(r, dict) else {"memory": str(r)}
|
||||
items.append({
|
||||
"id": item.get("id", str(uuid.uuid4())),
|
||||
"memory": item.get("memory", item.get("text", "")),
|
||||
"created_at": item.get("created_at", datetime.now(timezone.utc).isoformat()),
|
||||
"metadata_": item.get("metadata", {}),
|
||||
})
|
||||
return items
|
||||
|
||||
|
||||
def _parse_iso_timestamp(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 timestamp string, returning None on failure."""
|
||||
try:
|
||||
dt = datetime.fromisoformat(value)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _item_created_at(item: dict) -> Optional[datetime]:
    """Extract created_at from an item as a timezone-aware datetime.

    Returns None when the field is absent or unparseable. Naive datetime
    values are coerced to UTC.
    """
    value = item.get("created_at")
    if value is None:
        return None
    if not isinstance(value, datetime):
        # Strings (or anything stringifiable) go through the ISO parser.
        return _parse_iso_timestamp(str(value))
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value
|
||||
|
||||
|
||||
def _apply_post_filters(
    items: list,
    created_after: Optional[str],
    created_before: Optional[str],
) -> list:
    """Filter items by created_after / created_before timestamps (post-query).

    Items whose created_at cannot be parsed are kept — missing data should
    not silently drop memories. When neither bound parses, the input list is
    returned unchanged.
    """
    lower = _parse_iso_timestamp(created_after) if created_after else None
    upper = _parse_iso_timestamp(created_before) if created_before else None
    if lower is None and upper is None:
        return items

    def _keep(entry: dict) -> bool:
        stamp = _item_created_at(entry)
        if stamp is None:
            # Unparseable/absent timestamp → retain the item.
            return True
        if lower is not None and stamp < lower:
            return False
        if upper is not None and stamp > upper:
            return False
        return True

    return [entry for entry in items if _keep(entry)]
|
||||
|
||||
|
||||
@app.get("/api/v1/memories/")
async def list_or_search_memories(
    user_id: str = Query(...),
    search_query: Optional[str] = Query(None),
    size: int = Query(10),
    rerank: bool = Query(True),
    created_after: Optional[str] = Query(None),
    created_before: Optional[str] = Query(None),
    metadata_filter: Optional[str] = Query(None),
):
    """List a user's memories, or semantically search them when search_query is set.

    metadata_filter is a JSON-encoded mem0 filters dict; invalid JSON is
    silently ignored. created_after/created_before are applied post-query.
    """
    # Decode the metadata_filter JSON into a mem0 SDK filters dict.
    sdk_filters = None
    if metadata_filter:
        try:
            sdk_filters = json.loads(metadata_filter)
        except json.JSONDecodeError:
            sdk_filters = None

    # Plain listing path: no query, no reranking.
    if not search_query:
        results = await asyncio.to_thread(m.get_all, user_id=user_id, filters=sdk_filters)
        listed = _apply_post_filters(_parse_mem0_results(results), created_after, created_before)
        return {"items": listed, "total": len(listed)}

    # Search path: over-fetch so the reranker has candidates to work with.
    candidate_count = min(size * 3, 50)
    results = await asyncio.to_thread(
        m.search,
        search_query,
        user_id=user_id,
        limit=candidate_count,
        filters=sdk_filters,
    )
    items = _apply_post_filters(_parse_mem0_results(results), created_after, created_before)
    if rerank and items:
        items = rerank_results(search_query, items, top_k=size)
    else:
        items = items[:size]
    return {"items": items, "total": len(items)}
|
||||
|
||||
|
||||
@app.delete("/api/v1/memories/{memory_id}")
async def delete_memory(memory_id: str):
    """Delete a memory by ID (best-effort, idempotent).

    Failures from the mem0 SDK (e.g. unknown or already-deleted ID) are
    deliberately swallowed and the endpoint still reports success, so
    callers may treat deletion as idempotent.
    """
    try:
        await asyncio.to_thread(m.delete, memory_id)
    except Exception:
        # Intentional: a missing memory is not an error for this endpoint.
        pass
    return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/v1/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Return the edit history of a specific memory.

    SDK output (list, dict-with-results, or scalar) is normalized to a list
    of dicts; SDK failures yield an empty history plus an "error" field
    rather than an HTTP error.
    """
    try:
        history = await asyncio.to_thread(m.history, memory_id)
        # Normalize the SDK's varying return shapes to a flat list.
        if isinstance(history, list):
            raw = history
        elif isinstance(history, dict):
            raw = history.get("results", history)
        else:
            raw = [history]
        entries = [h if isinstance(h, dict) else {"event": str(h)} for h in raw]
        return {"memory_id": memory_id, "history": entries}
    except Exception as e:
        return {"memory_id": memory_id, "history": [], "error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Imported lazily so the module can be served by an external ASGI runner
    # without requiring uvicorn at import time.
    import uvicorn
    # Bind on all interfaces; 8765 matches the documented OpenMemory default.
    uvicorn.run(app, host="0.0.0.0", port=8765)
|
||||
@@ -1,50 +0,0 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from FlagEmbedding import FlagReranker
|
||||
import os, torch
|
||||
|
||||
app = Flask(__name__)
# Lazily-initialized FlagReranker singleton; see get_reranker().
reranker = None

# ── Configurable via environment variables ─────────────────────────
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
# Prefer GPU when torch reports CUDA; override with RERANKER_DEVICE.
RERANKER_DEVICE = os.environ.get("RERANKER_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
RERANKER_PORT = int(os.environ.get("RERANKER_PORT", "8678"))
|
||||
|
||||
def get_reranker():
    """Return the process-wide FlagReranker, creating it on first use.

    Lazy initialization keeps model loading out of import time. use_fp16
    halves model memory. NOTE(review): not guarded by a lock — two
    concurrent first calls could each build a model; confirm this is
    acceptable for the deployment's threading model.
    """
    global reranker
    if reranker is None:
        reranker = FlagReranker(RERANKER_MODEL, use_fp16=True, device=RERANKER_DEVICE)
    return reranker
|
||||
|
||||
@app.route('/rerank', methods=['POST'])
def rerank():
    """Rerank documents against a query with the cross-encoder.

    Request JSON: {"query": str, "documents": [str], "top_k": int?}
    Response JSON: {"results": [{"index", "document", "score"}]} sorted by
    score descending, truncated to top_k (default: all documents).
    Returns 400 when query or documents are missing.
    """
    # Fix: request.json is None (or raises) for missing/non-JSON bodies,
    # which crashed on .get below with a 500 instead of the intended 400.
    # get_json(silent=True) tolerates both cases.
    data = request.get_json(silent=True) or {}
    query = data.get('query', '')
    documents = data.get('documents', [])
    top_k = data.get('top_k', len(documents))

    if not query or not documents:
        return jsonify({'error': 'query and documents required'}), 400

    # Cross-encoder scores each (query, document) pair jointly.
    pairs = [[query, doc] for doc in documents]
    scores = get_reranker().compute_score(pairs)
    # compute_score returns a bare float for a single pair — normalize to a list.
    if isinstance(scores, float):
        scores = [scores]

    results = sorted(
        [{'index': i, 'document': doc, 'score': score}
         for i, (doc, score) in enumerate(zip(documents, scores))],
        key=lambda x: x['score'], reverse=True
    )[:top_k]

    return jsonify({'results': results})
|
||||
|
||||
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe reporting the configured model and device.

    Does NOT touch the lazily-loaded reranker, so it stays fast and cannot
    trigger model loading.
    """
    return jsonify({'status': 'ok', 'model': RERANKER_MODEL, 'device': RERANKER_DEVICE})
|
||||
|
||||
if __name__ == '__main__':
    print(f'Loading reranker model ({RERANKER_MODEL}) on {RERANKER_DEVICE}...')
    # Warm the model before accepting traffic so the first request isn't slow.
    get_reranker()
    print(f'Reranker server ready on :{RERANKER_PORT}')
    app.run(host='0.0.0.0', port=RERANKER_PORT)
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.5.6
|
||||
pkgver = 0.5.9
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.5.6.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.6.tar.gz
|
||||
source = zeroclaw-0.5.9.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.9.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.5.6
|
||||
pkgver=0.5.9
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.5.6",
|
||||
"version": "0.5.9",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.6/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.9/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# ADR-004: Tool Shared State Ownership Contract
|
||||
|
||||
**Status:** Accepted
|
||||
|
||||
**Date:** 2026-03-22
|
||||
|
||||
**Issue:** [#4057](https://github.com/zeroclaw/zeroclaw/issues/4057)
|
||||
|
||||
## Context
|
||||
|
||||
ZeroClaw tools execute in a multi-client environment where a single daemon
|
||||
process serves requests from multiple connected clients simultaneously. Several
|
||||
tools already maintain long-lived shared state:
|
||||
|
||||
- **`DelegateParentToolsHandle`** (`src/tools/mod.rs`):
|
||||
`Arc<RwLock<Vec<Arc<dyn Tool>>>>` — holds parent tools for delegate agents
|
||||
with no per-client isolation.
|
||||
- **`ChannelMapHandle`** (`src/tools/reaction.rs`):
|
||||
`Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>` — global channel map shared
|
||||
across all clients.
|
||||
- **`CanvasStore`** (`src/tools/canvas.rs`):
|
||||
`Arc<RwLock<HashMap<String, CanvasEntry>>>` — canvas IDs are plain strings
|
||||
with no client namespace.
|
||||
|
||||
These patterns emerged organically. As the tool surface grows and more clients
|
||||
connect concurrently, we need a clear contract governing ownership, identity,
|
||||
isolation, lifecycle, and reload behavior for tool-held shared state. Without
|
||||
this contract, new tools risk introducing data leaks between clients, stale
|
||||
state after config reloads, or inconsistent initialization timing.
|
||||
|
||||
Additional context:
|
||||
|
||||
- The tool registry is immutable after startup, built once in
|
||||
`all_tools_with_runtime()`.
|
||||
- Client identity is currently derived from IP address only
|
||||
(`src/gateway/mod.rs`), which is insufficient for reliable namespacing.
|
||||
- `SecurityPolicy` is scoped per agent, not per client.
|
||||
- `WorkspaceManager` provides some isolation but workspace switching is global.
|
||||
|
||||
## Decision
|
||||
|
||||
### 1. Ownership: May tools own long-lived shared state?
|
||||
|
||||
**Yes.** Tools MAY own long-lived shared state, provided they follow the
|
||||
established **handle pattern**: wrap the state in `Arc<RwLock<T>>` (or
|
||||
`Arc<parking_lot::RwLock<T>>`) and expose a cloneable handle type.
|
||||
|
||||
This pattern is already proven by three independent implementations:
|
||||
|
||||
| Handle | Location | Inner type |
|
||||
|--------|----------|-----------|
|
||||
| `DelegateParentToolsHandle` | `src/tools/mod.rs` | `Vec<Arc<dyn Tool>>` |
|
||||
| `ChannelMapHandle` | `src/tools/reaction.rs` | `HashMap<String, Arc<dyn Channel>>` |
|
||||
| `CanvasStore` | `src/tools/canvas.rs` | `HashMap<String, CanvasEntry>` |
|
||||
|
||||
Tools that need shared state MUST:
|
||||
|
||||
- Define a named handle type alias (e.g., `pub type FooHandle = Arc<RwLock<T>>`).
|
||||
- Accept the handle at construction time rather than creating global state.
|
||||
- Document the concurrency contract in the handle type's doc comment.
|
||||
|
||||
Tools MUST NOT use static mutable state (`lazy_static!`, `OnceCell` with
|
||||
interior mutability) for per-request or per-client data.
|
||||
|
||||
### 2. Identity assignment: Who constructs identity keys?
|
||||
|
||||
**The daemon SHOULD provide identity.** Tools MUST NOT construct their own
|
||||
client identity keys.
|
||||
|
||||
A new `ClientId` type should be introduced (opaque, `Clone + Eq + Hash + Send + Sync`)
|
||||
that the daemon assigns at connection time. This replaces the current approach
|
||||
of using raw IP addresses (`src/gateway/mod.rs:259-306`), which breaks when
|
||||
multiple clients share a NAT address or when proxied connections arrive.
|
||||
|
||||
`ClientId` is passed to tools that require per-client state namespacing as part
|
||||
of the tool execution context. Tools that do not need per-client isolation
|
||||
(e.g., the immutable tool registry) may ignore it.
|
||||
|
||||
The `ClientId` contract:
|
||||
|
||||
- Generated by the gateway layer at connection establishment.
|
||||
- Opaque to tools — tools must not parse or derive meaning from the value.
|
||||
- Stable for the lifetime of a single client session.
|
||||
- Passed through the execution context, not stored globally.
|
||||
|
||||
### 3. Lifecycle: When may tools run startup-style validation?
|
||||
|
||||
**Validation runs once at first registration, and again when config changes
|
||||
are detected.**
|
||||
|
||||
The lifecycle phases are:
|
||||
|
||||
1. **Construction** — tool is instantiated with handles and config. No I/O or
|
||||
validation occurs here.
|
||||
2. **Registration** — tool is registered in the tool registry via
|
||||
`all_tools_with_runtime()`. At this point the tool MAY perform one-time
|
||||
startup validation (e.g., checking that required credentials exist, verifying
|
||||
external service connectivity).
|
||||
3. **Execution** — tool handles individual requests. No re-validation unless
|
||||
the config-change signal fires (see Reload Semantics below).
|
||||
4. **Shutdown** — daemon is stopping. Tools with open resources SHOULD clean up
|
||||
gracefully via `Drop` or an explicit shutdown method.
|
||||
|
||||
Tools MUST NOT perform blocking validation during execution-phase calls.
|
||||
Validation results SHOULD be cached in the tool's handle state and checked
|
||||
via a fast path during execution.
|
||||
|
||||
### 4. Isolation: What must be isolated per client?
|
||||
|
||||
State falls into two categories with different isolation requirements:
|
||||
|
||||
**MUST be isolated per client:**
|
||||
|
||||
- Security-sensitive state: credentials, API keys, quotas, rate-limit counters,
|
||||
per-client authorization decisions.
|
||||
- User-specific session data: conversation context, user preferences,
|
||||
workspace-scoped file paths.
|
||||
|
||||
Isolation mechanism: tools holding per-client state MUST key their internal
|
||||
maps by `ClientId`. The handle pattern naturally supports this by using
|
||||
`HashMap<ClientId, T>` inside the `RwLock`.
|
||||
|
||||
**MAY be shared across clients (with namespace prefixing):**
|
||||
|
||||
- Broadcast/display state: canvas frames (`CanvasStore`), notification channels
|
||||
(`ChannelMapHandle`).
|
||||
- Read-only reference data: tool registry, static configuration, model
|
||||
metadata.
|
||||
|
||||
When shared state uses string keys (e.g., canvas IDs, channel names), tools
|
||||
SHOULD support optional namespace prefixing (e.g., `{client_id}:{canvas_name}`)
|
||||
to allow per-client isolation when needed without mandating it for broadcast
|
||||
use cases.
|
||||
|
||||
Tools MUST NOT store per-client secrets in shared (non-isolated) state
|
||||
structures.
|
||||
|
||||
### 5. Reload semantics: What invalidates prior shared state on config change?
|
||||
|
||||
**Config changes detected via hash comparison MUST invalidate cached
|
||||
validation state.**
|
||||
|
||||
The reload contract:
|
||||
|
||||
- The daemon computes a hash of the tool-relevant config section at startup and
|
||||
after each config reload event.
|
||||
- When the hash changes, the daemon signals affected tools to re-run their
|
||||
registration-phase validation.
|
||||
- Tools MUST treat their cached validation result as stale when signaled and
|
||||
re-validate before the next execution.
|
||||
|
||||
Specific invalidation rules:
|
||||
|
||||
| Config change | Invalidation scope |
|
||||
|--------------|-------------------|
|
||||
| Credential/secret rotation | Per-tool validation cache; per-client credential state |
|
||||
| Tool enable/disable | Full tool registry rebuild via `all_tools_with_runtime()` |
|
||||
| Security policy change | `SecurityPolicy` re-derivation; per-agent policy state |
|
||||
| Workspace directory change | `WorkspaceManager` state; file-path-dependent tool state |
|
||||
| Provider config change | Provider-dependent tools re-validate connectivity |
|
||||
|
||||
Tools MAY retain non-security shared state (e.g., canvas content, channel
|
||||
subscriptions) across config reloads unless the reload explicitly affects that
|
||||
state's validity.
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
|
||||
- **Consistency:** All new tools follow the same handle pattern, making shared
|
||||
state discoverable and auditable.
|
||||
- **Safety:** Per-client isolation of security-sensitive state prevents data
|
||||
leaks in multi-tenant scenarios.
|
||||
- **Clarity:** Explicit lifecycle phases eliminate ambiguity about when
|
||||
validation runs.
|
||||
- **Evolvability:** The `ClientId` abstraction decouples tools from transport
|
||||
details, supporting future identity mechanisms (tokens, certificates).
|
||||
|
||||
### Negative
|
||||
|
||||
- **Migration cost:** Existing tools (`CanvasStore`, `ReactionTool`) may need
|
||||
refactoring to accept `ClientId` and namespace their state.
|
||||
- **Complexity:** Tools that were simple singletons now need to consider
|
||||
multi-client semantics even if they currently have one client.
|
||||
- **Performance:** Per-client keying adds a hash lookup on each access, though
|
||||
this is negligible compared to I/O costs.
|
||||
|
||||
### Neutral
|
||||
|
||||
- The tool registry remains immutable after startup; this ADR does not change
|
||||
that invariant.
|
||||
- `SecurityPolicy` remains per-agent; this ADR documents that client isolation
|
||||
is orthogonal to agent-level policy.
|
||||
|
||||
## References
|
||||
|
||||
- `src/tools/mod.rs` — `DelegateParentToolsHandle`, `all_tools_with_runtime()`
|
||||
- `src/tools/reaction.rs` — `ChannelMapHandle`, `ReactionTool`
|
||||
- `src/tools/canvas.rs` — `CanvasStore`, `CanvasEntry`
|
||||
- `src/tools/traits.rs` — `Tool` trait
|
||||
- `src/gateway/mod.rs` — client IP extraction (`forwarded_client_ip`, `resolve_client_ip`)
|
||||
- `src/security/` — `SecurityPolicy`
|
||||
@@ -0,0 +1,215 @@
|
||||
# Browser Automation Setup Guide
|
||||
|
||||
This guide covers setting up browser automation capabilities in ZeroClaw, including both headless automation and GUI access via VNC.
|
||||
|
||||
## Overview
|
||||
|
||||
ZeroClaw supports multiple browser access methods:
|
||||
|
||||
| Method | Use Case | Requirements |
|
||||
|--------|----------|--------------|
|
||||
| **agent-browser CLI** | Headless automation, AI agents | npm, Chrome |
|
||||
| **VNC + noVNC** | GUI access, debugging | Xvfb, x11vnc, noVNC |
|
||||
| **Chrome Remote Desktop** | Remote GUI via Google | XFCE, Google account |
|
||||
|
||||
## Quick Start: Headless Automation
|
||||
|
||||
### 1. Install agent-browser
|
||||
|
||||
```bash
|
||||
# Install CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux (includes system deps)
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
### 2. Verify ZeroClaw Config
|
||||
|
||||
The browser tool is enabled by default. To verify or customize, edit
|
||||
`~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
### 3. Test
|
||||
|
||||
```bash
|
||||
echo "Open https://example.com and tell me what it says" | zeroclaw agent
|
||||
```
|
||||
|
||||
## VNC Setup (GUI Access)
|
||||
|
||||
For debugging or when you need visual browser access:
|
||||
|
||||
### Install Dependencies
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
apt-get install -y xvfb x11vnc fluxbox novnc websockify
|
||||
|
||||
# Optional: Desktop environment for Chrome Remote Desktop
|
||||
apt-get install -y xfce4 xfce4-goodies
|
||||
```
|
||||
|
||||
### Start VNC Server
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# Start virtual display with VNC access
|
||||
|
||||
DISPLAY_NUM=99
|
||||
VNC_PORT=5900
|
||||
NOVNC_PORT=6080
|
||||
RESOLUTION=1920x1080x24
|
||||
|
||||
# Start Xvfb
|
||||
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
|
||||
sleep 1
|
||||
|
||||
# Start window manager
|
||||
fluxbox -display :$DISPLAY_NUM &
|
||||
sleep 1
|
||||
|
||||
# Start x11vnc
|
||||
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg
|
||||
sleep 1
|
||||
|
||||
# Start noVNC (web-based VNC)
|
||||
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
|
||||
|
||||
echo "VNC available at:"
|
||||
echo " VNC Client: localhost:$VNC_PORT"
|
||||
echo " Web Browser: http://localhost:$NOVNC_PORT/vnc.html"
|
||||
```
|
||||
|
||||
### VNC Access
|
||||
|
||||
- **VNC Client**: Connect to `localhost:5900`
|
||||
- **Web Browser**: Open `http://localhost:6080/vnc.html`
|
||||
|
||||
### Start Browser on VNC Display
|
||||
|
||||
```bash
|
||||
DISPLAY=:99 google-chrome --no-sandbox https://example.com &
|
||||
```
|
||||
|
||||
## Chrome Remote Desktop
|
||||
|
||||
### Install
|
||||
|
||||
```bash
|
||||
# Download and install
|
||||
wget https://dl.google.com/linux/direct/chrome-remote-desktop_current_amd64.deb
|
||||
apt-get install -y ./chrome-remote-desktop_current_amd64.deb
|
||||
|
||||
# Configure session
|
||||
echo "xfce4-session" > ~/.chrome-remote-desktop-session
|
||||
chmod +x ~/.chrome-remote-desktop-session
|
||||
```
|
||||
|
||||
### Setup
|
||||
|
||||
1. Visit <https://remotedesktop.google.com/headless>
|
||||
2. Copy the "Debian Linux" setup command
|
||||
3. Run it on your server
|
||||
4. Start the service: `systemctl --user start chrome-remote-desktop`
|
||||
|
||||
### Remote Access
|
||||
|
||||
Go to <https://remotedesktop.google.com/access> from any device.
|
||||
|
||||
## Testing
|
||||
|
||||
### CLI Tests
|
||||
|
||||
```bash
|
||||
# Basic open and close
|
||||
agent-browser open https://example.com
|
||||
agent-browser get title
|
||||
agent-browser close
|
||||
|
||||
# Snapshot with refs
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser close
|
||||
|
||||
# Screenshot
|
||||
agent-browser open https://example.com
|
||||
agent-browser screenshot /tmp/test.png
|
||||
agent-browser close
|
||||
```
|
||||
|
||||
### ZeroClaw Integration Tests
|
||||
|
||||
```bash
|
||||
# Content extraction
|
||||
echo "Open https://example.com and summarize it" | zeroclaw agent
|
||||
|
||||
# Navigation
|
||||
echo "Go to https://github.com/trending and list the top 3 repos" | zeroclaw agent
|
||||
|
||||
# Form interaction
|
||||
echo "Go to Wikipedia, search for 'Rust programming language', and summarize" | zeroclaw agent
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Element not found"
|
||||
|
||||
The page may not be fully loaded. Add a wait:
|
||||
|
||||
```bash
|
||||
agent-browser open https://slow-site.com
|
||||
agent-browser wait --load networkidle
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Cookie dialogs blocking access
|
||||
|
||||
Handle cookie consent first:
|
||||
|
||||
```bash
|
||||
agent-browser open https://site-with-cookies.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser click @accept_cookies # Click the accept button
|
||||
agent-browser snapshot -i # Now get the actual content
|
||||
```
|
||||
|
||||
### Docker sandbox network restrictions
|
||||
|
||||
If `web_fetch` fails inside Docker sandbox, use agent-browser instead:
|
||||
|
||||
```bash
|
||||
# Instead of web_fetch, use:
|
||||
agent-browser open https://example.com
|
||||
agent-browser get text body
|
||||
```
|
||||
|
||||
## Security Notes
|
||||
|
||||
- `agent-browser` runs Chrome in headless mode with sandboxing
|
||||
- For sensitive sites, use `--session-name` to persist auth state
|
||||
- The `--allowed-domains` config restricts navigation to specific domains
|
||||
- VNC ports (5900, 6080) should be behind a firewall or Tailscale
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser Documentation](https://github.com/vercel-labs/agent-browser)
|
||||
- [ZeroClaw Configuration Reference](./config-reference.md)
|
||||
- [Skills Documentation](../skills/)
|
||||
@@ -45,6 +45,15 @@ For complete code examples of each extension trait, see [extension-examples.md](
|
||||
- Keep multilingual entry-point parity for all supported locales (`en`, `zh-CN`, `ja`, `ru`, `fr`, `vi`) when nav or key wording changes.
|
||||
- When shared docs wording changes, sync corresponding localized docs in the same PR (or explicitly document deferral and follow-up PR).
|
||||
|
||||
## Tool Shared State
|
||||
|
||||
- Follow the `Arc<RwLock<T>>` handle pattern for any tool that owns long-lived shared state.
|
||||
- Accept handles at construction; do not create global/static mutable state.
|
||||
- Use `ClientId` (provided by the daemon) to namespace per-client state — never construct identity keys inside the tool.
|
||||
- Isolate security-sensitive state (credentials, quotas) per client; broadcast/display state may be shared with optional namespace prefixing.
|
||||
- Cached validation is invalidated on config change — tools must re-validate before the next execution when signaled.
|
||||
- See [ADR-004: Tool Shared State Ownership](../architecture/adr-004-tool-shared-state-ownership.md) for the full contract.
|
||||
|
||||
## Architecture Boundary Rules
|
||||
|
||||
- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features.
|
||||
|
||||
@@ -411,30 +411,6 @@ allowed_roots = [\"~/Desktop/projects\", \"/opt/shared-repo\"]
|
||||
|
||||
- 内存上下文注入忽略旧的 `assistant_resp*` 自动保存键,以防止旧模型生成的摘要被视为事实。
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Mem0 (OpenMemory) 后端 — 连接自托管 mem0 服务器,提供基于向量的记忆存储和 LLM 事实提取。构建时需要 `memory-mem0` feature flag,配置需设置 `backend = "mem0"`。
|
||||
|
||||
| 键 | 默认值 | 环境变量 | 用途 |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory 服务器地址 |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | 记忆作用域的用户 ID |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | 在 mem0 中注册的应用名称 |
|
||||
| `infer` | `true` | — | 使用 LLM 从存储文本中提取事实 (`true`) 或原样存储 (`false`) |
|
||||
| `extraction_prompt` | 未设置 | `MEM0_EXTRACTION_PROMPT` | 自定义 LLM 事实提取提示词(如适用于非英文内容) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "用原始语言提取事实..."
|
||||
```
|
||||
|
||||
服务器部署脚本位于 `deploy/mem0/`。
|
||||
|
||||
## `[[model_routes]]` 和 `[[embedding_routes]]`
|
||||
|
||||
使用路由提示,以便集成可以在模型 ID 演变时保持稳定的名称。
|
||||
|
||||
@@ -508,30 +508,6 @@ Notes:
|
||||
|
||||
- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.
|
||||
|
||||
| Key | Default | Env var | Purpose |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
|
||||
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
|
||||
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "Extract facts in the original language..."
|
||||
```
|
||||
|
||||
Server deployment scripts are in `deploy/mem0/`.
|
||||
|
||||
## `[[model_routes]]` and `[[embedding_routes]]`
|
||||
|
||||
Use route hints so integrations can keep stable names while model IDs evolve.
|
||||
|
||||
@@ -337,30 +337,6 @@ Lưu ý:
|
||||
|
||||
- Chèn ngữ cảnh memory bỏ qua khóa auto-save `assistant_resp*` kiểu cũ để tránh tóm tắt do model tạo bị coi là sự thật.
|
||||
|
||||
### `[memory.mem0]`
|
||||
|
||||
Backend Mem0 (OpenMemory) — kết nối đến server mem0 tự host, cung cấp bộ nhớ vector với trích xuất sự kiện bằng LLM. Cần feature flag `memory-mem0` khi build và `backend = "mem0"` trong config.
|
||||
|
||||
| Khóa | Mặc định | Biến môi trường | Mục đích |
|
||||
|---|---|---|---|
|
||||
| `url` | `http://localhost:8765` | `MEM0_URL` | URL server OpenMemory |
|
||||
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID để phân vùng memory |
|
||||
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Tên ứng dụng đăng ký trong mem0 |
|
||||
| `infer` | `true` | — | Dùng LLM trích xuất sự kiện từ text (`true`) hoặc lưu nguyên (`false`) |
|
||||
| `extraction_prompt` | chưa đặt | `MEM0_EXTRACTION_PROMPT` | Prompt tùy chỉnh cho trích xuất sự kiện LLM (vd: cho nội dung không phải tiếng Anh) |
|
||||
|
||||
```toml
|
||||
[memory]
|
||||
backend = "mem0"
|
||||
|
||||
[memory.mem0]
|
||||
url = "http://192.168.0.171:8765"
|
||||
user_id = "zeroclaw-bot"
|
||||
extraction_prompt = "Trích xuất sự kiện bằng ngôn ngữ gốc..."
|
||||
```
|
||||
|
||||
Script triển khai server nằm trong `deploy/mem0/`.
|
||||
|
||||
## `[[model_routes]]` và `[[embedding_routes]]`
|
||||
|
||||
Route hint giúp tên tích hợp ổn định khi model ID thay đổi.
|
||||
|
||||
@@ -38,3 +38,46 @@ allowed_tools = ["read", "edit", "exec"]
|
||||
max_iterations = 15
|
||||
# Optional: use longer timeout for complex coding tasks
|
||||
agentic_timeout_secs = 600
|
||||
|
||||
# ── Cron Configuration ────────────────────────────────────────
|
||||
[cron]
|
||||
# Enable the cron subsystem. Default: true
|
||||
enabled = true
|
||||
# Run all overdue jobs at scheduler startup. Default: true
|
||||
catch_up_on_startup = true
|
||||
# Maximum number of historical cron run records to retain. Default: 50
|
||||
max_run_history = 50
|
||||
|
||||
# ── Declarative Cron Jobs ─────────────────────────────────────
|
||||
# Define cron jobs directly in config. These are synced to the database
|
||||
# at scheduler startup. Each job needs a stable `id` for merge semantics.
|
||||
|
||||
# Shell job: runs a shell command on a cron schedule
|
||||
[[cron.jobs]]
|
||||
id = "daily-backup"
|
||||
name = "Daily Backup"
|
||||
job_type = "shell"
|
||||
command = "tar czf /tmp/backup.tar.gz /data"
|
||||
schedule = { kind = "cron", expr = "0 2 * * *" }
|
||||
|
||||
# Agent job: runs an agent prompt on an interval
|
||||
[[cron.jobs]]
|
||||
id = "health-check"
|
||||
name = "Health Check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health: disk space, memory, CPU load"
|
||||
model = "anthropic/claude-sonnet-4"
|
||||
allowed_tools = ["shell", "file_read"]
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
|
||||
# Cron job with timezone and delivery
|
||||
# [[cron.jobs]]
|
||||
# id = "morning-report"
|
||||
# name = "Morning Report"
|
||||
# job_type = "agent"
|
||||
# prompt = "Generate a daily summary of system metrics"
|
||||
# schedule = { kind = "cron", expr = "0 9 * * 1-5", tz = "America/New_York" }
|
||||
# [cron.jobs.delivery]
|
||||
# mode = "announce"
|
||||
# channel = "telegram"
|
||||
# to = "123456789"
|
||||
|
||||
+19
@@ -1448,6 +1448,25 @@ else
|
||||
step_dot "Skipping install"
|
||||
fi
|
||||
|
||||
# --- Build web dashboard ---
|
||||
if [[ "$SKIP_BUILD" == false && -d "$WORK_DIR/web" ]]; then
|
||||
if have_cmd node && have_cmd npm; then
|
||||
step_dot "Building web dashboard"
|
||||
if (cd "$WORK_DIR/web" && npm ci --ignore-scripts 2>/dev/null && npm run build 2>/dev/null); then
|
||||
step_ok "Web dashboard built"
|
||||
else
|
||||
warn "Web dashboard build failed — dashboard will not be available"
|
||||
fi
|
||||
else
|
||||
warn "node/npm not found — skipping web dashboard build"
|
||||
warn "Install Node.js (>=18) and re-run, or build manually: cd web && npm ci && npm run build"
|
||||
fi
|
||||
else
|
||||
if [[ "$SKIP_BUILD" == true ]]; then
|
||||
step_dot "Skipping web dashboard build"
|
||||
fi
|
||||
fi
|
||||
|
||||
ZEROCLAW_BIN=""
|
||||
if [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then
|
||||
ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw"
|
||||
|
||||
Executable
+21
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Start a browser on a virtual display
|
||||
# Usage: ./start-browser.sh [display_num] [url]
|
||||
|
||||
set -e
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
URL=${2:-"https://google.com"}
|
||||
|
||||
export DISPLAY=:$DISPLAY_NUM
|
||||
|
||||
# Check if display is running
|
||||
if ! xdpyinfo -display :$DISPLAY_NUM &>/dev/null; then
|
||||
echo "Error: Display :$DISPLAY_NUM not running."
|
||||
echo "Start VNC first: ./start-vnc.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
google-chrome --no-sandbox --disable-gpu --disable-setuid-sandbox "$URL" &
|
||||
echo "Chrome started on display :$DISPLAY_NUM"
|
||||
echo "View via VNC or noVNC"
|
||||
Executable
+52
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Start virtual display with VNC access for browser GUI
|
||||
# Usage: ./start-vnc.sh [display_num] [vnc_port] [novnc_port] [resolution]
|
||||
|
||||
set -e
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
VNC_PORT=${2:-5900}
|
||||
NOVNC_PORT=${3:-6080}
|
||||
RESOLUTION=${4:-1920x1080x24}
|
||||
|
||||
echo "Starting virtual display :$DISPLAY_NUM at $RESOLUTION"
|
||||
|
||||
# Kill any existing sessions
|
||||
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
|
||||
sleep 1
|
||||
|
||||
# Start Xvfb (virtual framebuffer)
|
||||
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
|
||||
XVFB_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Set DISPLAY
|
||||
export DISPLAY=:$DISPLAY_NUM
|
||||
|
||||
# Start window manager
|
||||
fluxbox -display :$DISPLAY_NUM 2>/dev/null &
|
||||
sleep 1
|
||||
|
||||
# Start x11vnc
|
||||
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg 2>/dev/null
|
||||
sleep 1
|
||||
|
||||
# Start noVNC (web-based VNC client)
|
||||
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
|
||||
NOVNC_PID=$!
|
||||
|
||||
echo ""
|
||||
echo "==================================="
|
||||
echo "VNC Server started!"
|
||||
echo "==================================="
|
||||
echo "VNC Direct: localhost:$VNC_PORT"
|
||||
echo "noVNC Web: http://localhost:$NOVNC_PORT/vnc.html"
|
||||
echo "Display: :$DISPLAY_NUM"
|
||||
echo "==================================="
|
||||
echo ""
|
||||
echo "To start a browser, run:"
|
||||
echo " DISPLAY=:$DISPLAY_NUM google-chrome &"
|
||||
echo ""
|
||||
echo "To stop, run: pkill -f 'Xvfb :$DISPLAY_NUM'"
|
||||
Executable
+11
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
# Stop virtual display and VNC server
|
||||
# Usage: ./stop-vnc.sh [display_num]
|
||||
|
||||
DISPLAY_NUM=${1:-99}
|
||||
|
||||
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
|
||||
pkill -f "websockify.*6080" 2>/dev/null || true
|
||||
|
||||
echo "VNC server stopped"
|
||||
@@ -77,7 +77,9 @@ echo "Created annotated tag: $TAG"
|
||||
if [[ "$PUSH_TAG" == "true" ]]; then
|
||||
git push origin "$TAG"
|
||||
echo "Pushed tag to origin: $TAG"
|
||||
echo "GitHub release pipeline will run via .github/workflows/pub-release.yml"
|
||||
echo "Release Stable workflow will auto-trigger via tag push."
|
||||
echo "Monitor: gh workflow view 'Release Stable' --web"
|
||||
else
|
||||
echo "Next step: git push origin $TAG"
|
||||
echo "This will auto-trigger the Release Stable workflow (builds, Docker, crates.io, website, Scoop, AUR, Homebrew, tweet)."
|
||||
fi
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
---
|
||||
name: browser
|
||||
description: Headless browser automation using agent-browser CLI
|
||||
metadata: {"zeroclaw":{"emoji":"🌐","requires":{"bins":["agent-browser"]}}}
|
||||
---
|
||||
|
||||
# Browser Skill
|
||||
|
||||
Control a headless browser for web automation, scraping, and testing.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `agent-browser` CLI installed globally (`npm install -g agent-browser`)
|
||||
- Chrome downloaded (`agent-browser install`)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install agent-browser CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Navigate and snapshot
|
||||
|
||||
```bash
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Interact with elements
|
||||
|
||||
```bash
|
||||
agent-browser click @e1 # Click by ref
|
||||
agent-browser fill @e2 "text" # Fill input
|
||||
agent-browser press Enter # Press key
|
||||
```
|
||||
|
||||
### Extract data
|
||||
|
||||
```bash
|
||||
agent-browser get text @e1 # Get text content
|
||||
agent-browser get url # Get current URL
|
||||
agent-browser screenshot page.png # Take screenshot
|
||||
```
|
||||
|
||||
### Session management
|
||||
|
||||
```bash
|
||||
agent-browser close # Close browser
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Login flow
|
||||
|
||||
```bash
|
||||
agent-browser open https://site.com/login
|
||||
agent-browser snapshot -i
|
||||
agent-browser fill @email "user@example.com"
|
||||
agent-browser fill @password "secretpass"
|
||||
agent-browser click @submit
|
||||
agent-browser wait --text "Welcome"
|
||||
```
|
||||
|
||||
### Scrape page content
|
||||
|
||||
```bash
|
||||
agent-browser open https://news.ycombinator.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser get text @e1
|
||||
```
|
||||
|
||||
### Take screenshots
|
||||
|
||||
```bash
|
||||
agent-browser open https://google.com
|
||||
agent-browser screenshot --full page.png
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
- `--json` - JSON output for parsing
|
||||
- `--headed` - Show browser window (for debugging)
|
||||
- `--session-name <name>` - Persist session cookies
|
||||
- `--profile <path>` - Use persistent browser profile
|
||||
|
||||
## Configuration
|
||||
|
||||
The browser tool is enabled by default with `allowed_domains = ["*"]` and
|
||||
`backend = "agent_browser"`. To customize, edit `~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
## Full Command Reference
|
||||
|
||||
Run `agent-browser --help` for all available commands.
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser GitHub](https://github.com/vercel-labs/agent-browser)
|
||||
- [VNC Setup Guide](../docs/browser-setup.md)
|
||||
+517
-60
@@ -4,7 +4,7 @@ use crate::config::Config;
|
||||
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
|
||||
use crate::cost::CostTracker;
|
||||
use crate::i18n::ToolDescriptions;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::memory::{self, decay, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
|
||||
use crate::providers::{
|
||||
@@ -561,6 +561,7 @@ fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Res
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
/// Core memories are exempt from time decay (evergreen).
|
||||
async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
@@ -570,7 +571,10 @@ async fn build_context(
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
if let Ok(mut entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
|
||||
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
@@ -2707,16 +2711,53 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"received {image_marker_count} image marker(s), but this provider does not support vision input"
|
||||
),
|
||||
|
||||
// ── Vision provider routing ──────────────────────────
|
||||
// When the default provider lacks vision support but a dedicated
|
||||
// vision_provider is configured, create it on demand and use it
|
||||
// for this iteration. Otherwise, preserve the original error.
|
||||
let vision_provider_box: Option<Box<dyn Provider>> = if image_marker_count > 0
|
||||
&& !provider.supports_vision()
|
||||
{
|
||||
if let Some(ref vp) = multimodal_config.vision_provider {
|
||||
let vp_instance = providers::create_provider(vp, None)
|
||||
.map_err(|e| anyhow::anyhow!("failed to create vision provider '{vp}': {e}"))?;
|
||||
if !vp_instance.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: vp.clone(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"configured vision_provider '{vp}' does not support vision input"
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
Some(vp_instance)
|
||||
} else {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"received {image_marker_count} image marker(s), but this provider does not support vision input"
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
.into());
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (active_provider, active_provider_name, active_model): (&dyn Provider, &str, &str) =
|
||||
if let Some(ref vp_box) = vision_provider_box {
|
||||
let vp_name = multimodal_config
|
||||
.vision_provider
|
||||
.as_deref()
|
||||
.unwrap_or(provider_name);
|
||||
let vm = multimodal_config.vision_model.as_deref().unwrap_or(model);
|
||||
(vp_box.as_ref(), vp_name, vm)
|
||||
} else {
|
||||
(provider, provider_name, model)
|
||||
};
|
||||
|
||||
let prepared_messages =
|
||||
multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
|
||||
@@ -2732,15 +2773,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmRequest {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
messages_count: history.len(),
|
||||
});
|
||||
runtime_trace::record_event(
|
||||
"llm_request",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
None,
|
||||
None,
|
||||
@@ -2778,12 +2819,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
None
|
||||
};
|
||||
|
||||
let chat_future = provider.chat(
|
||||
let chat_future = active_provider.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
active_model,
|
||||
temperature,
|
||||
);
|
||||
|
||||
@@ -2836,8 +2877,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
.unwrap_or((None, None));
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
@@ -2846,10 +2887,9 @@ pub(crate) async fn run_tool_call_loop(
|
||||
});
|
||||
|
||||
// Record cost via task-local tracker (no-op when not scoped)
|
||||
let _ = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
|
||||
let _ = resp.usage.as_ref().and_then(|usage| {
|
||||
record_tool_loop_cost_usage(active_provider_name, active_model, usage)
|
||||
});
|
||||
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
@@ -2872,8 +2912,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"tool_call_parse_issue",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&parse_issue),
|
||||
@@ -2890,8 +2930,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
@@ -2940,8 +2980,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
Err(e) => {
|
||||
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(safe_error.clone()),
|
||||
@@ -2951,8 +2991,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&safe_error),
|
||||
@@ -3701,6 +3741,11 @@ pub async fn run(
|
||||
|
||||
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
||||
let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);
|
||||
|
||||
// Register skill-defined tools as callable tool specs in the tool registry
|
||||
// so the LLM can invoke them via native function calling, not just XML prompts.
|
||||
tools::register_skill_tools(&mut tools_registry, &skills, security.clone());
|
||||
|
||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||
(
|
||||
"shell",
|
||||
@@ -3865,17 +3910,45 @@ pub async fn run(
|
||||
|
||||
let mut final_output = String::new();
|
||||
|
||||
// Save the base system prompt before any thinking modifications so
|
||||
// the interactive loop can restore it between turns.
|
||||
let base_system_prompt = system_prompt.clone();
|
||||
|
||||
if let Some(msg) = message {
|
||||
// ── Parse thinking directive from user message ─────────
|
||||
let (thinking_directive, effective_msg) =
|
||||
match crate::agent::thinking::parse_thinking_directive(&msg) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, msg.clone()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let effective_temperature = crate::agent::thinking::clamp_temperature(
|
||||
temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// Prepend thinking system prompt prefix when present.
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
}
|
||||
|
||||
// Auto-save user message to memory (skip short/trivial messages)
|
||||
if config.memory.auto_save
|
||||
&& msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&msg)
|
||||
&& effective_msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&effective_msg)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(
|
||||
&user_key,
|
||||
&msg,
|
||||
&effective_msg,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -3885,7 +3958,7 @@ pub async fn run(
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&msg,
|
||||
&effective_msg,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -3893,14 +3966,14 @@ pub async fn run(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, &msg, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, &effective_msg, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {msg}")
|
||||
format!("[{now}] {effective_msg}")
|
||||
} else {
|
||||
format!("{context}[{now}] {msg}")
|
||||
format!("{context}[{now}] {effective_msg}")
|
||||
};
|
||||
|
||||
let mut history = vec![
|
||||
@@ -3909,8 +3982,11 @@ pub async fn run(
|
||||
];
|
||||
|
||||
// Compute per-turn excluded MCP tools from tool_filter_groups.
|
||||
let excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg);
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
&effective_msg,
|
||||
);
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
let mut response = String::new();
|
||||
@@ -3922,7 +3998,7 @@ pub async fn run(
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
effective_temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
@@ -4042,9 +4118,10 @@ pub async fn run(
|
||||
"/quit" | "/exit" => break,
|
||||
"/help" => {
|
||||
println!("Available commands:");
|
||||
println!(" /help Show this help message");
|
||||
println!(" /clear /new Clear conversation history");
|
||||
println!(" /quit /exit Exit interactive mode\n");
|
||||
println!(" /help Show this help message");
|
||||
println!(" /clear /new Clear conversation history");
|
||||
println!(" /quit /exit Exit interactive mode");
|
||||
println!(" /think:<level> Set reasoning depth (off|minimal|low|medium|high|max)\n");
|
||||
continue;
|
||||
}
|
||||
"/clear" | "/new" => {
|
||||
@@ -4096,16 +4173,47 @@ pub async fn run(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from interactive input ───
|
||||
let (thinking_directive, effective_input) =
|
||||
match crate::agent::thinking::parse_thinking_directive(&user_input) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, user_input.clone()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let turn_temperature = crate::agent::thinking::clamp_temperature(
|
||||
temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// For non-Medium levels, temporarily patch the system prompt with prefix.
|
||||
let turn_system_prompt;
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
turn_system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
// Update the system message in history for this turn.
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content = turn_system_prompt.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-save conversation turns (skip short/trivial messages)
|
||||
if config.memory.auto_save
|
||||
&& user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&user_input)
|
||||
&& effective_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&effective_input)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(
|
||||
&user_key,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4115,7 +4223,7 @@ pub async fn run(
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
&effective_input,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4123,14 +4231,14 @@ pub async fn run(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, &user_input, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, &effective_input, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {user_input}")
|
||||
format!("[{now}] {effective_input}")
|
||||
} else {
|
||||
format!("{context}[{now}] {user_input}")
|
||||
format!("{context}[{now}] {effective_input}")
|
||||
};
|
||||
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
@@ -4139,7 +4247,7 @@ pub async fn run(
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
);
|
||||
|
||||
let response = loop {
|
||||
@@ -4150,7 +4258,7 @@ pub async fn run(
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
turn_temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
@@ -4235,6 +4343,15 @@ pub async fn run(
|
||||
// Hard cap as a safety net.
|
||||
trim_history(&mut history, config.agent.max_history_messages);
|
||||
|
||||
// Restore base system prompt (remove per-turn thinking prefix).
|
||||
if thinking_params.system_prompt_prefix.is_some() {
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content.clone_from(&base_system_prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = session_state_file.as_deref() {
|
||||
save_interactive_session_history(path, &history)?;
|
||||
}
|
||||
@@ -4415,6 +4532,10 @@ pub async fn process_message(
|
||||
let i18n_descs = crate::i18n::ToolDescriptions::load(&i18n_locale, &i18n_search_dirs);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);
|
||||
|
||||
// Register skill-defined tools as callable tool specs (process_message path).
|
||||
tools::register_skill_tools(&mut tools_registry, &skills, security.clone());
|
||||
|
||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||
("shell", "Execute terminal commands."),
|
||||
("file_read", "Read file contents."),
|
||||
@@ -4508,9 +4629,34 @@ pub async fn process_message(
|
||||
system_prompt.push_str(&deferred_section);
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from user message ─────────────
|
||||
let (thinking_directive, effective_message) =
|
||||
match crate::agent::thinking::parse_thinking_directive(message) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, message.to_string()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let effective_temperature = crate::agent::thinking::clamp_temperature(
|
||||
config.default_temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// Prepend thinking system prompt prefix when present.
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
}
|
||||
|
||||
let effective_msg_ref = effective_message.as_str();
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
effective_msg_ref,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
@@ -4518,22 +4664,25 @@ pub async fn process_message(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, message, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, effective_msg_ref, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {message}")
|
||||
format!("[{now}] {effective_message}")
|
||||
} else {
|
||||
format!("{context}[{now}] {message}")
|
||||
format!("{context}[{now}] {effective_message}")
|
||||
};
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system(&system_prompt),
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
let mut excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, message);
|
||||
let mut excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
effective_msg_ref,
|
||||
);
|
||||
if config.autonomy.level != AutonomyLevel::Full {
|
||||
excluded_tools.extend(config.autonomy.non_cli_excluded_tools.iter().cloned());
|
||||
}
|
||||
@@ -4545,7 +4694,7 @@ pub async fn process_message(
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
effective_temperature,
|
||||
true,
|
||||
"daemon",
|
||||
None,
|
||||
@@ -5094,6 +5243,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
@@ -5171,6 +5321,313 @@ mod tests {
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is not set and the default provider lacks vision
|
||||
/// support, the original `ProviderCapabilityError` should be returned.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_vision_provider_config_preserves_error() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"check [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail without vision_provider config");
|
||||
|
||||
assert!(err.to_string().contains("capability=vision"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but the provider factory cannot resolve
|
||||
/// the name, a descriptive error should be returned (not the generic
|
||||
/// capability error).
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_creation_failure() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail when vision provider cannot be created");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure error, got: {}",
|
||||
err
|
||||
);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// Messages without image markers should use the default provider even
|
||||
/// when `vision_provider` is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_images_uses_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["hello world"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user("just text, no images".to_string())];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even though vision_provider points to a nonexistent provider, this
|
||||
// should succeed because there are no image markers to trigger routing.
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("text-only messages should succeed with default provider");
|
||||
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but `vision_model` is not, the default
|
||||
/// model should be used as fallback for the vision provider.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_without_model_falls_back() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"look [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
// vision_provider set but vision_model is None — the code should
|
||||
// fall back to the default model. Since the provider name is invalid,
|
||||
// we just verify the error path references the correct provider.
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail due to nonexistent vision provider");
|
||||
|
||||
// Verify the routing was attempted (not the generic capability error).
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
/// Empty `[IMAGE:]` markers (which are preserved as literal text by the
|
||||
/// parser) should not trigger vision provider routing.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_empty_image_markers_use_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["handled"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"empty marker [IMAGE:] should be ignored".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("empty image markers should not trigger vision routing");
|
||||
|
||||
assert_eq!(result, "handled");
|
||||
}
|
||||
|
||||
/// Multiple image markers should still trigger vision routing when
|
||||
/// vision_provider is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_multiple_images_trigger_vision_routing() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"two images [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/png;base64,bQ==]"
|
||||
.to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("llava:7b".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should attempt vision provider creation for multiple images");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure for multiple images, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_execute_tools_in_parallel_returns_false_for_single_call() {
|
||||
let calls = vec![ParsedToolCall {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory};
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -43,13 +43,16 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory
|
||||
let mut entries = memory
|
||||
.recall(user_message, self.limit, session_id, None, None)
|
||||
.await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
@@ -118,6 +121,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}])
|
||||
}
|
||||
|
||||
@@ -226,6 +232,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.95),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
@@ -235,6 +244,9 @@ mod tests {
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
@@ -5,6 +5,7 @@ pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
pub mod thinking;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
+7
-6
@@ -473,8 +473,9 @@ mod tests {
|
||||
assert!(output.contains("<available_skills>"));
|
||||
assert!(output.contains("<name>deploy</name>"));
|
||||
assert!(output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -516,10 +517,10 @@ mod tests {
|
||||
assert!(output.contains("<location>skills/deploy/SKILL.md</location>"));
|
||||
assert!(output.contains("read_skill(name)"));
|
||||
assert!(!output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(output.contains("<tools>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
//! Thinking/Reasoning Level Control
|
||||
//!
|
||||
//! Allows users to control how deeply the model reasons per message,
|
||||
//! trading speed for depth. Levels range from `Off` (fastest, most concise)
|
||||
//! to `Max` (deepest reasoning, slowest).
|
||||
//!
|
||||
//! Users can set the level via:
|
||||
//! - Inline directive: `/think:high` at the start of a message
|
||||
//! - Agent config: `[agent.thinking]` section with `default_level`
|
||||
//!
|
||||
//! Resolution hierarchy (highest priority first):
|
||||
//! 1. Inline directive (`/think:<level>`)
|
||||
//! 2. Session override (reserved for future use)
|
||||
//! 3. Agent config (`agent.thinking.default_level`)
|
||||
//! 4. Global default (`Medium`)
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// How deeply the model should reason for a given message.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ThinkingLevel {
|
||||
/// No chain-of-thought. Fastest, most concise responses.
|
||||
Off,
|
||||
/// Minimal reasoning. Brief, direct answers.
|
||||
Minimal,
|
||||
/// Light reasoning. Short explanations when needed.
|
||||
Low,
|
||||
/// Balanced reasoning (default). Moderate depth.
|
||||
#[default]
|
||||
Medium,
|
||||
/// Deep reasoning. Thorough analysis and step-by-step thinking.
|
||||
High,
|
||||
/// Maximum reasoning depth. Exhaustive analysis.
|
||||
Max,
|
||||
}
|
||||
|
||||
impl ThinkingLevel {
|
||||
/// Parse a thinking level from a string (case-insensitive).
|
||||
pub fn from_str_insensitive(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"off" | "none" => Some(Self::Off),
|
||||
"minimal" | "min" => Some(Self::Minimal),
|
||||
"low" => Some(Self::Low),
|
||||
"medium" | "med" | "default" => Some(Self::Medium),
|
||||
"high" => Some(Self::High),
|
||||
"max" | "maximum" => Some(Self::Max),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for thinking/reasoning level control.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ThinkingConfig {
|
||||
/// Default thinking level when no directive is present.
|
||||
#[serde(default)]
|
||||
pub default_level: ThinkingLevel,
|
||||
}
|
||||
|
||||
impl Default for ThinkingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_level: ThinkingLevel::Medium,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters derived from a thinking level, applied to the LLM request.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ThinkingParams {
|
||||
/// Temperature adjustment (added to the base temperature, clamped to 0.0..=2.0).
|
||||
pub temperature_adjustment: f64,
|
||||
/// Maximum tokens adjustment (added to any existing max_tokens setting).
|
||||
pub max_tokens_adjustment: i64,
|
||||
/// Optional system prompt prefix injected before the existing system prompt.
|
||||
pub system_prompt_prefix: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse a `/think:<level>` directive from the start of a message.
|
||||
///
|
||||
/// Returns `Some((level, remaining_message))` if a directive is found,
|
||||
/// or `None` if no directive is present. The remaining message has
|
||||
/// leading whitespace after the directive trimmed.
|
||||
pub fn parse_thinking_directive(message: &str) -> Option<(ThinkingLevel, String)> {
|
||||
let trimmed = message.trim_start();
|
||||
if !trimmed.starts_with("/think:") {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract the level token (everything between `/think:` and the next whitespace or end).
|
||||
let after_prefix = &trimmed["/think:".len()..];
|
||||
let level_end = after_prefix
|
||||
.find(|c: char| c.is_whitespace())
|
||||
.unwrap_or(after_prefix.len());
|
||||
let level_str = &after_prefix[..level_end];
|
||||
|
||||
let level = ThinkingLevel::from_str_insensitive(level_str)?;
|
||||
|
||||
let remaining = after_prefix[level_end..].trim_start().to_string();
|
||||
Some((level, remaining))
|
||||
}
|
||||
|
||||
/// Convert a `ThinkingLevel` into concrete parameters for the LLM request.
|
||||
pub fn apply_thinking_level(level: ThinkingLevel) -> ThinkingParams {
|
||||
match level {
|
||||
ThinkingLevel::Off => ThinkingParams {
|
||||
temperature_adjustment: -0.2,
|
||||
max_tokens_adjustment: -1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Be extremely concise. Give direct answers without explanation \
|
||||
unless explicitly asked. No preamble."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Minimal => ThinkingParams {
|
||||
temperature_adjustment: -0.1,
|
||||
max_tokens_adjustment: -500,
|
||||
system_prompt_prefix: Some(
|
||||
"Be concise and fast. Keep explanations brief. \
|
||||
Prioritize speed over thoroughness."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Low => ThinkingParams {
|
||||
temperature_adjustment: -0.05,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: Some("Keep reasoning light. Explain only when helpful.".into()),
|
||||
},
|
||||
ThinkingLevel::Medium => ThinkingParams {
|
||||
temperature_adjustment: 0.0,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: None,
|
||||
},
|
||||
ThinkingLevel::High => ThinkingParams {
|
||||
temperature_adjustment: 0.05,
|
||||
max_tokens_adjustment: 1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think step by step. Provide thorough analysis and \
|
||||
consider edge cases before answering."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Max => ThinkingParams {
|
||||
temperature_adjustment: 0.1,
|
||||
max_tokens_adjustment: 2000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think very carefully and exhaustively. Break down the problem \
|
||||
into sub-problems, consider all angles, verify your reasoning, \
|
||||
and provide the most thorough analysis possible."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the effective thinking level using the priority hierarchy:
|
||||
/// 1. Inline directive (if present)
|
||||
/// 2. Session override (reserved, currently always `None`)
|
||||
/// 3. Agent config default
|
||||
/// 4. Global default (`Medium`)
|
||||
pub fn resolve_thinking_level(
|
||||
inline_directive: Option<ThinkingLevel>,
|
||||
session_override: Option<ThinkingLevel>,
|
||||
config: &ThinkingConfig,
|
||||
) -> ThinkingLevel {
|
||||
inline_directive
|
||||
.or(session_override)
|
||||
.unwrap_or(config.default_level)
|
||||
}
|
||||
|
||||
/// Clamp a temperature value to the valid range `[0.0, 2.0]`.
|
||||
pub fn clamp_temperature(temp: f64) -> f64 {
|
||||
temp.clamp(0.0, 2.0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── ThinkingLevel parsing ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_canonical_names() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("off"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("minimal"),
|
||||
Some(ThinkingLevel::Minimal)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("low"),
|
||||
Some(ThinkingLevel::Low)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("medium"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("high"),
|
||||
Some(ThinkingLevel::High)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("max"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_aliases() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("none"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("min"),
|
||||
Some(ThinkingLevel::Minimal)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("med"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("default"),
|
||||
Some(ThinkingLevel::Medium)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("maximum"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_case_insensitive() {
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("HIGH"),
|
||||
Some(ThinkingLevel::High)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("Max"),
|
||||
Some(ThinkingLevel::Max)
|
||||
);
|
||||
assert_eq!(
|
||||
ThinkingLevel::from_str_insensitive("OFF"),
|
||||
Some(ThinkingLevel::Off)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_from_str_invalid_returns_none() {
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive("turbo"), None);
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive(""), None);
|
||||
assert_eq!(ThinkingLevel::from_str_insensitive("super-high"), None);
|
||||
}
|
||||
|
||||
// ── Directive parsing ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_directive_extracts_level_and_remaining_message() {
|
||||
let result = parse_thinking_directive("/think:high What is Rust?");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::High);
|
||||
assert_eq!(remaining, "What is Rust?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_handles_directive_only() {
|
||||
let result = parse_thinking_directive("/think:off");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::Off);
|
||||
assert_eq!(remaining, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_strips_leading_whitespace() {
|
||||
let result = parse_thinking_directive(" /think:low Tell me about Rust");
|
||||
assert!(result.is_some());
|
||||
let (level, remaining) = result.unwrap();
|
||||
assert_eq!(level, ThinkingLevel::Low);
|
||||
assert_eq!(remaining, "Tell me about Rust");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_returns_none_for_no_directive() {
|
||||
assert!(parse_thinking_directive("Hello world").is_none());
|
||||
assert!(parse_thinking_directive("").is_none());
|
||||
assert!(parse_thinking_directive("/think").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_returns_none_for_invalid_level() {
|
||||
assert!(parse_thinking_directive("/think:turbo What?").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_directive_not_triggered_mid_message() {
|
||||
assert!(parse_thinking_directive("Hello /think:high world").is_none());
|
||||
}
|
||||
|
||||
// ── Level application ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_off_is_concise() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Off);
|
||||
assert!(params.temperature_adjustment < 0.0);
|
||||
assert!(params.max_tokens_adjustment < 0);
|
||||
assert!(params.system_prompt_prefix.is_some());
|
||||
assert!(params
|
||||
.system_prompt_prefix
|
||||
.unwrap()
|
||||
.to_lowercase()
|
||||
.contains("concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_medium_is_neutral() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Medium);
|
||||
assert!((params.temperature_adjustment - 0.0).abs() < f64::EPSILON);
|
||||
assert_eq!(params.max_tokens_adjustment, 0);
|
||||
assert!(params.system_prompt_prefix.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_high_adds_step_by_step() {
|
||||
let params = apply_thinking_level(ThinkingLevel::High);
|
||||
assert!(params.temperature_adjustment > 0.0);
|
||||
assert!(params.max_tokens_adjustment > 0);
|
||||
let prefix = params.system_prompt_prefix.unwrap();
|
||||
assert!(prefix.to_lowercase().contains("step by step"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_thinking_level_max_is_most_thorough() {
|
||||
let params = apply_thinking_level(ThinkingLevel::Max);
|
||||
assert!(params.temperature_adjustment > 0.0);
|
||||
assert!(params.max_tokens_adjustment > 0);
|
||||
let prefix = params.system_prompt_prefix.unwrap();
|
||||
assert!(prefix.to_lowercase().contains("exhaustively"));
|
||||
}
|
||||
|
||||
// ── Resolution hierarchy ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn resolve_inline_directive_takes_priority() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Low,
|
||||
};
|
||||
let result =
|
||||
resolve_thinking_level(Some(ThinkingLevel::Max), Some(ThinkingLevel::High), &config);
|
||||
assert_eq!(result, ThinkingLevel::Max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_session_override_takes_priority_over_config() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Low,
|
||||
};
|
||||
let result = resolve_thinking_level(None, Some(ThinkingLevel::High), &config);
|
||||
assert_eq!(result, ThinkingLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_falls_back_to_config_default() {
|
||||
let config = ThinkingConfig {
|
||||
default_level: ThinkingLevel::Minimal,
|
||||
};
|
||||
let result = resolve_thinking_level(None, None, &config);
|
||||
assert_eq!(result, ThinkingLevel::Minimal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_default_config_uses_medium() {
|
||||
let config = ThinkingConfig::default();
|
||||
let result = resolve_thinking_level(None, None, &config);
|
||||
assert_eq!(result, ThinkingLevel::Medium);
|
||||
}
|
||||
|
||||
// ── Temperature clamping ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_within_range() {
|
||||
assert!((clamp_temperature(0.7) - 0.7).abs() < f64::EPSILON);
|
||||
assert!((clamp_temperature(0.0) - 0.0).abs() < f64::EPSILON);
|
||||
assert!((clamp_temperature(2.0) - 2.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_below_minimum() {
|
||||
assert!((clamp_temperature(-0.5) - 0.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_temperature_above_maximum() {
|
||||
assert!((clamp_temperature(3.0) - 2.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
// ── Serde round-trip ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn thinking_config_deserializes_from_toml() {
|
||||
let toml_str = r#"default_level = "high""#;
|
||||
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.default_level, ThinkingLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_config_default_level_deserializes() {
|
||||
let toml_str = "";
|
||||
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.default_level, ThinkingLevel::Medium);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thinking_level_serializes_lowercase() {
|
||||
let level = ThinkingLevel::High;
|
||||
let json = serde_json::to_string(&level).unwrap();
|
||||
assert_eq!(json, "\"high\"");
|
||||
}
|
||||
}
|
||||
@@ -562,4 +562,50 @@ mod tests {
|
||||
let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.tool_name, "shell");
|
||||
}
|
||||
|
||||
// ── Regression: #4247 default approved tools in channels ──
|
||||
|
||||
#[test]
|
||||
fn non_interactive_allows_default_auto_approve_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
|
||||
for tool in &config.auto_approve {
|
||||
assert!(
|
||||
!mgr.needs_approval(tool),
|
||||
"default auto_approve tool '{tool}' should not need approval in non-interactive mode"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_denies_unknown_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("some_unknown_tool"),
|
||||
"unknown tool should need approval"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_weather_is_auto_approved() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
!mgr.needs_approval("weather"),
|
||||
"weather tool must not need approval — it is in the default auto_approve list"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn always_ask_overrides_auto_approve() {
|
||||
let mut config = AutonomyConfig::default();
|
||||
config.always_ask = vec!["weather".into()];
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("weather"),
|
||||
"always_ask must override auto_approve"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+116
-1
@@ -20,6 +20,9 @@ pub struct DiscordChannel {
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@@ -38,6 +41,7 @@ impl DiscordChannel {
|
||||
mention_only,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +51,14 @@ impl DiscordChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
|
||||
}
|
||||
@@ -113,6 +125,88 @@ async fn process_attachments(
|
||||
parts.join("\n---\n")
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
|
||||
const DISCORD_AUDIO_EXTENSIONS: &[&str] = &[
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
];
|
||||
|
||||
/// Check if a content type or filename indicates an audio file.
|
||||
fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
|
||||
if content_type.starts_with("audio/") {
|
||||
return true;
|
||||
}
|
||||
if let Some(ext) = filename.rsplit('.').next() {
|
||||
return DISCORD_AUDIO_EXTENSIONS.contains(&ext.to_ascii_lowercase().as_str());
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Download and transcribe audio attachments from a Discord message.
|
||||
///
|
||||
/// Returns transcribed text blocks for any audio attachments found.
|
||||
/// Non-audio attachments and failures are silently skipped.
|
||||
async fn transcribe_discord_audio_attachments(
|
||||
attachments: &[serde_json::Value],
|
||||
client: &reqwest::Client,
|
||||
config: &crate::config::TranscriptionConfig,
|
||||
) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
for att in attachments {
|
||||
let ct = att
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let name = att
|
||||
.get("filename")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("file");
|
||||
|
||||
if !is_discord_audio_attachment(ct, name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(url) = att.get("url").and_then(|v| v.as_str()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let audio_data = match client.get(url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => match resp.bytes().await {
|
||||
Ok(bytes) => bytes.to_vec(),
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: failed to read audio attachment bytes");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Ok(resp) => {
|
||||
tracing::warn!(name, status = %resp.status(), "discord: audio attachment download failed");
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: audio attachment fetch error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, name, config).await {
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
tracing::info!(
|
||||
"Discord: transcribed audio attachment {} ({} chars)",
|
||||
name,
|
||||
trimmed.len()
|
||||
);
|
||||
parts.push(format!("[Voice] {trimmed}"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(name, error = %e, "discord: voice transcription failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum DiscordAttachmentKind {
|
||||
Image,
|
||||
@@ -737,7 +831,28 @@ impl Channel for DiscordChannel {
|
||||
.and_then(|a| a.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
process_attachments(&atts, &self.http_client()).await
|
||||
let client = self.http_client();
|
||||
let mut text_parts = process_attachments(&atts, &client).await;
|
||||
|
||||
// Transcribe audio attachments when transcription is configured
|
||||
if let Some(ref transcription_config) = self.transcription {
|
||||
let voice_text = transcribe_discord_audio_attachments(
|
||||
&atts,
|
||||
&client,
|
||||
transcription_config,
|
||||
)
|
||||
.await;
|
||||
if !voice_text.is_empty() {
|
||||
if text_parts.is_empty() {
|
||||
text_parts = voice_text;
|
||||
} else {
|
||||
text_parts = format!("{text_parts}
|
||||
{voice_text}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_parts
|
||||
};
|
||||
let final_content = if attachment_text.is_empty() {
|
||||
clean_content
|
||||
|
||||
+535
-47
@@ -1,5 +1,6 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine as _;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::collections::HashMap;
|
||||
@@ -221,6 +222,21 @@ const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
|
||||
/// Lark card payloads have a ~30 KB limit; leave margin for JSON envelope.
|
||||
const LARK_CARD_MARKDOWN_MAX_BYTES: usize = 28_000;
|
||||
|
||||
/// Maximum image size we will download and inline (5 MiB).
|
||||
const LARK_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
/// Maximum file size we will download and present as text (512 KiB).
|
||||
const LARK_FILE_MAX_BYTES: usize = 512 * 1024;
|
||||
|
||||
/// Image MIME types we support for inline base64 encoding.
|
||||
const LARK_SUPPORTED_IMAGE_MIMES: &[&str] = &[
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/bmp",
|
||||
];
|
||||
|
||||
/// Returns true when the WebSocket frame indicates live traffic that should
|
||||
/// refresh the heartbeat watchdog.
|
||||
fn should_refresh_last_recv(msg: &WsMsg) -> bool {
|
||||
@@ -520,6 +536,17 @@ impl LarkChannel {
|
||||
format!("{}/im/v1/messages/{message_id}/reactions", self.api_base())
|
||||
}
|
||||
|
||||
fn image_download_url(&self, image_key: &str) -> String {
|
||||
format!("{}/im/v1/images/{image_key}", self.api_base())
|
||||
}
|
||||
|
||||
fn file_download_url(&self, message_id: &str, file_key: &str) -> String {
|
||||
format!(
|
||||
"{}/im/v1/messages/{message_id}/resources/{file_key}?type=file",
|
||||
self.api_base()
|
||||
)
|
||||
}
|
||||
|
||||
fn resolved_bot_open_id(&self) -> Option<String> {
|
||||
self.resolved_bot_open_id
|
||||
.read()
|
||||
@@ -866,6 +893,44 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => continue,
|
||||
},
|
||||
"image" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let image_key = match v.get("image_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: image message missing image_key"); continue; }
|
||||
};
|
||||
match self.download_image_as_marker(&image_key).await {
|
||||
Some(marker) => (marker, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download image {image_key}");
|
||||
(format!("[IMAGE:{image_key} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let file_key = match v.get("file_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: file message missing file_key"); continue; }
|
||||
};
|
||||
let file_name = v.get("file_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match self.download_file_as_content(&lark_msg.message_id, &file_key, &file_name).await {
|
||||
Some(content) => (content, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download file {file_key}");
|
||||
(format!("[ATTACHMENT:{file_name} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||
};
|
||||
|
||||
@@ -986,6 +1051,183 @@ impl LarkChannel {
|
||||
*cached = None;
|
||||
}
|
||||
|
||||
/// Download an image from the Lark API and return an `[IMAGE:data:...]` marker string.
|
||||
async fn download_image_as_marker(&self, image_key: &str) -> Option<String> {
|
||||
let token = match self.get_tenant_access_token().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: failed to get token for image download: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.image_download_url(image_key);
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: image download request failed for {image_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tracing::warn!(
|
||||
"Lark: image download failed for {image_key}: status={}",
|
||||
resp.status()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(cl) = resp.content_length() {
|
||||
if cl > LARK_IMAGE_MAX_BYTES as u64 {
|
||||
tracing::warn!("Lark: image too large for {image_key}: {cl} bytes exceeds limit");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string);
|
||||
|
||||
let bytes = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: image body read failed for {image_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if bytes.is_empty() || bytes.len() > LARK_IMAGE_MAX_BYTES {
|
||||
tracing::warn!(
|
||||
"Lark: image body empty or too large for {image_key}: {} bytes",
|
||||
bytes.len()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let mime = lark_detect_image_mime(content_type.as_deref(), &bytes)?;
|
||||
if !LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
|
||||
tracing::warn!("Lark: unsupported image MIME for {image_key}: {mime}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
Some(format!("[IMAGE:data:{mime};base64,{encoded}]"))
|
||||
}
|
||||
|
||||
/// Download a file from the Lark API and return a text content marker.
|
||||
/// For text-like files, the content is inlined. For binary files, a summary is returned.
|
||||
async fn download_file_as_content(
|
||||
&self,
|
||||
message_id: &str,
|
||||
file_key: &str,
|
||||
file_name: &str,
|
||||
) -> Option<String> {
|
||||
let token = match self.get_tenant_access_token().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: failed to get token for file download: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.file_download_url(message_id, file_key);
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file download request failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tracing::warn!(
|
||||
"Lark: file download failed for {file_key}: status={}",
|
||||
resp.status()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(cl) = resp.content_length() {
|
||||
if cl > LARK_FILE_MAX_BYTES as u64 {
|
||||
tracing::warn!("Lark: file too large for {file_key}: {cl} bytes exceeds limit");
|
||||
return Some(format!(
|
||||
"[ATTACHMENT:{file_name} | size={cl} bytes | too large to inline]"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let bytes = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file body read failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if bytes.is_empty() {
|
||||
tracing::warn!("Lark: file body is empty for {file_key}");
|
||||
return None;
|
||||
}
|
||||
|
||||
// If the content is image-like, return as image marker
|
||||
if content_type.starts_with("image/") && bytes.len() <= LARK_IMAGE_MAX_BYTES {
|
||||
if let Some(mime) = lark_detect_image_mime(Some(&content_type), &bytes) {
|
||||
if LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
return Some(format!("[IMAGE:data:{mime};base64,{encoded}]"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the file looks like text, inline it
|
||||
if bytes.len() <= LARK_FILE_MAX_BYTES
|
||||
&& !bytes.contains(&0)
|
||||
&& (content_type.starts_with("text/")
|
||||
|| content_type.contains("json")
|
||||
|| content_type.contains("xml")
|
||||
|| content_type.contains("yaml")
|
||||
|| content_type.contains("javascript")
|
||||
|| content_type.contains("csv")
|
||||
|| lark_is_text_filename(file_name))
|
||||
{
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
let truncated = if text.len() > 50_000 {
|
||||
format!("{}...\n[truncated]", &text[..50_000])
|
||||
} else {
|
||||
text.into_owned()
|
||||
};
|
||||
let ext = file_name.rsplit('.').next().unwrap_or("text");
|
||||
return Some(format!("[FILE:{file_name}]\n```{ext}\n{truncated}\n```"));
|
||||
}
|
||||
|
||||
Some(format!(
|
||||
"[ATTACHMENT:{file_name} | mime={content_type} | size={} bytes]",
|
||||
bytes.len()
|
||||
))
|
||||
}
|
||||
|
||||
async fn fetch_bot_open_id_with_token(
|
||||
&self,
|
||||
token: &str,
|
||||
@@ -1085,8 +1327,9 @@ impl LarkChannel {
|
||||
Ok((status, parsed))
|
||||
}
|
||||
|
||||
/// Parse an event callback payload and extract text messages
|
||||
pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
/// Parse an event callback payload and extract messages.
|
||||
/// Supports text, post, image, and file message types.
|
||||
pub async fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
// Lark event v2 structure:
|
||||
@@ -1143,6 +1386,11 @@ impl LarkChannel {
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let evt_message_id = event
|
||||
.pointer("/message/message_id")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let (text, post_mentioned_open_ids): (String, Vec<String>) = match msg_type {
|
||||
"text" => {
|
||||
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
@@ -1162,6 +1410,62 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => return messages,
|
||||
},
|
||||
"image" => {
|
||||
let image_key = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("image_key")
|
||||
.and_then(|k| k.as_str())
|
||||
.map(String::from)
|
||||
});
|
||||
match image_key {
|
||||
Some(key) => {
|
||||
let marker = match self.download_image_as_marker(&key).await {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download image {key}");
|
||||
format!("[IMAGE:{key} | download failed]")
|
||||
}
|
||||
};
|
||||
(marker, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: image message missing image_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let parsed = serde_json::from_str::<serde_json::Value>(content_str).ok();
|
||||
let file_key = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_key").and_then(|k| k.as_str()))
|
||||
.map(String::from);
|
||||
let file_name = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_name").and_then(|n| n.as_str()))
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match file_key {
|
||||
Some(key) => {
|
||||
let content = match self
|
||||
.download_file_as_content(evt_message_id, &key, &file_name)
|
||||
.await
|
||||
{
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download file {key}");
|
||||
format!("[ATTACHMENT:{file_name} | download failed]")
|
||||
}
|
||||
};
|
||||
(content, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: file message missing file_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||
return messages;
|
||||
@@ -1305,7 +1609,7 @@ impl LarkChannel {
|
||||
}
|
||||
|
||||
// Parse event messages
|
||||
let messages = state.channel.parse_event_payload(&payload);
|
||||
let messages = state.channel.parse_event_payload(&payload).await;
|
||||
if !messages.is_empty() {
|
||||
if let Some(message_id) = payload
|
||||
.pointer("/event/message/message_id")
|
||||
@@ -1556,6 +1860,72 @@ fn detect_lark_ack_locale(
|
||||
detect_locale_from_text(fallback_text).unwrap_or(LarkAckLocale::En)
|
||||
}
|
||||
|
||||
/// Detect image MIME type from magic bytes, falling back to Content-Type header.
///
/// Only `image/*` header values are accepted as a fallback; anything else
/// yields `None`.
fn lark_detect_image_mime(content_type: Option<&str>, bytes: &[u8]) -> Option<String> {
    const PNG_MAGIC: [u8; 8] = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];

    // Magic-byte sniffing is authoritative: headers are frequently wrong.
    let sniffed = if bytes.starts_with(&PNG_MAGIC) {
        Some("image/png")
    } else if bytes.starts_with(&[0xff, 0xd8, 0xff]) {
        Some("image/jpeg")
    } else if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") {
        Some("image/gif")
    } else if bytes.len() >= 12 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WEBP" {
        Some("image/webp")
    } else if bytes.starts_with(b"BM") {
        Some("image/bmp")
    } else {
        None
    };
    if let Some(mime) = sniffed {
        return Some(mime.to_string());
    }

    // Inconclusive magic bytes: trust the header, minus parameters,
    // normalized to lowercase, and only if it claims an image type.
    let header = content_type?.split(';').next()?.trim().to_lowercase();
    if header.starts_with("image/") {
        Some(header)
    } else {
        None
    }
}
|
||||
|
||||
/// Check if a filename looks like a text file based on extension.
///
/// A name without a dot is compared as a whole (so `Makefile` matches the
/// `makefile` entry). Matching is case-insensitive.
fn lark_is_text_filename(name: &str) -> bool {
    const TEXT_EXTENSIONS: &[&str] = &[
        "txt", "md", "rs", "py", "js", "ts", "tsx", "jsx", "java", "c", "h", "cpp",
        "hpp", "go", "rb", "sh", "bash", "zsh", "toml", "yaml", "yml", "json", "xml",
        "html", "css", "sql", "csv", "tsv", "log", "cfg", "ini", "conf", "env",
        "dockerfile", "makefile",
    ];
    let extension = name.rsplit('.').next().unwrap_or("").to_ascii_lowercase();
    TEXT_EXTENSIONS.contains(&extension.as_str())
}
|
||||
|
||||
fn random_lark_ack_reaction(
|
||||
payload: Option<&serde_json::Value>,
|
||||
fallback_text: &str,
|
||||
@@ -1892,8 +2262,8 @@ mod tests {
|
||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_challenge() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_challenge() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"challenge": "abc123",
|
||||
@@ -1901,12 +2271,12 @@ mod tests {
|
||||
"type": "url_verification"
|
||||
});
|
||||
// Challenge payloads should not produce messages
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_valid_text_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_valid_text_message() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
@@ -1927,7 +2297,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
|
||||
assert_eq!(msgs[0].sender, "oc_chat123");
|
||||
@@ -1935,8 +2305,8 @@ mod tests {
|
||||
assert_eq!(msgs[0].timestamp, 1_699_999_999);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unauthorized_user() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unauthorized_user() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -1951,12 +2321,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_non_text_message_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unsupported_message_type_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "sticker",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_image_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -1977,12 +2373,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_empty_text_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_file_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "file",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_empty_text_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2003,24 +2425,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_wrong_event_type() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_wrong_event_type() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.chat.disbanded_v1" },
|
||||
"event": {}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_sender() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_sender() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2040,12 +2462,12 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unicode_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unicode_message() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2067,24 +2489,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello world 🌍");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_event() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_event() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" }
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_invalid_content_json() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_invalid_content_json() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2105,7 +2527,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
@@ -2237,8 +2659,8 @@ mod tests {
|
||||
assert_eq!(ch.name(), "feishu");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_fallback_sender_to_open_id() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_fallback_sender_to_open_id() {
|
||||
// When chat_id is missing, sender should fall back to open_id
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
@@ -2260,13 +2682,13 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "ou_user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2292,7 +2714,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).is_empty());
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).await.is_empty());
|
||||
|
||||
let wrong_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2307,7 +2729,10 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&wrong_mention_payload).is_empty());
|
||||
assert!(ch
|
||||
.parse_event_payload(&wrong_mention_payload)
|
||||
.await
|
||||
.is_empty());
|
||||
|
||||
let bot_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2322,11 +2747,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2353,11 +2778,11 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
let ch = LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
"secret".into(),
|
||||
@@ -2381,7 +2806,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2409,6 +2834,69 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn lark_image_download_url_matches_region() {
    let channel = make_channel();
    let url = channel.image_download_url("img_abc123");
    assert_eq!(url, "https://open.larksuite.com/open-apis/im/v1/images/img_abc123");
}
|
||||
|
||||
#[test]
fn lark_file_download_url_matches_region() {
    let channel = make_channel();
    let url = channel.file_download_url("om_msg123", "file_abc");
    assert_eq!(
        url,
        "https://open.larksuite.com/open-apis/im/v1/messages/om_msg123/resources/file_abc?type=file"
    );
}
|
||||
|
||||
#[test]
fn lark_detect_image_mime_from_magic_bytes() {
    // Known magic bytes win regardless of the header.
    let png_magic = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
    assert_eq!(
        lark_detect_image_mime(None, &png_magic).as_deref(),
        Some("image/png")
    );
    let jpeg_magic = [0xff, 0xd8, 0xff, 0xe0];
    assert_eq!(
        lark_detect_image_mime(None, &jpeg_magic).as_deref(),
        Some("image/jpeg")
    );
    assert_eq!(
        lark_detect_image_mime(None, b"GIF89a...").as_deref(),
        Some("image/gif")
    );

    // Unknown bytes should fall back to content-type header
    let opaque = [0x00, 0x01, 0x02];
    assert_eq!(
        lark_detect_image_mime(Some("image/webp"), &opaque).as_deref(),
        Some("image/webp")
    );

    // Non-image content-type should be rejected
    assert_eq!(lark_detect_image_mime(Some("text/html"), &opaque), None);

    // No info at all should return None
    assert_eq!(lark_detect_image_mime(None, &opaque), None);
}
|
||||
|
||||
#[test]
fn lark_is_text_filename_recognizes_common_extensions() {
    // Text-like extensions are accepted…
    for text_name in ["script.py", "config.toml", "data.csv", "README.md"] {
        assert!(lark_is_text_filename(text_name), "{text_name}");
    }
    // …binary extensions are not.
    for binary_name in ["image.png", "archive.zip", "binary.exe"] {
        assert!(!lark_is_text_filename(binary_name), "{binary_name}");
    }
}
|
||||
|
||||
#[test]
|
||||
fn lark_reaction_locale_explicit_language_tags() {
|
||||
assert_eq!(map_locale_tag("zh-CN"), Some(LarkAckLocale::ZhCn));
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
//! Link enricher: auto-detects URLs in inbound messages, fetches their content,
|
||||
//! and prepends summaries so the agent has link context without explicit tool calls.
|
||||
|
||||
use regex::Regex;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for the link enricher pipeline stage.
#[derive(Debug, Clone)]
pub struct LinkEnricherConfig {
    /// Master switch; when `false`, `enrich_message` returns input unchanged.
    pub enabled: bool,
    /// Maximum number of distinct URLs to extract and fetch per message.
    pub max_links: usize,
    /// Overall timeout for fetching a single link, in seconds.
    pub timeout_secs: u64,
}
|
||||
|
||||
impl Default for LinkEnricherConfig {
    /// Conservative defaults: the stage is opt-in (disabled), fetches at most
    /// three links per message, and allows ten seconds per request.
    fn default() -> Self {
        Self {
            enabled: false,
            max_links: 3,
            timeout_secs: 10,
        }
    }
}
|
||||
|
||||
/// URL regex: matches http:// and https:// URLs, stopping at whitespace, angle
/// brackets, or single/double quotes.
///
/// Compiled lazily on first use; the `expect` is safe because the pattern is
/// a compile-time constant known to be valid.
static URL_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r#"https?://[^\s<>"']+"#).expect("URL regex must compile"));
|
||||
|
||||
/// Extract URLs from message text, returning up to `max` unique URLs.
|
||||
pub fn extract_urls(text: &str, max: usize) -> Vec<String> {
|
||||
let mut seen = Vec::new();
|
||||
for m in URL_RE.find_iter(text) {
|
||||
let url = m.as_str().to_string();
|
||||
if !seen.contains(&url) {
|
||||
seen.push(url);
|
||||
if seen.len() >= max {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
/// Returns `true` if the URL points to a private/local address that should be
|
||||
/// blocked for SSRF protection.
|
||||
pub fn is_ssrf_target(url: &str) -> bool {
|
||||
let host = match extract_host(url) {
|
||||
Some(h) => h,
|
||||
None => return true, // unparseable URLs are rejected
|
||||
};
|
||||
|
||||
// Check hostname-based locals
|
||||
if host == "localhost"
|
||||
|| host.ends_with(".localhost")
|
||||
|| host.ends_with(".local")
|
||||
|| host == "local"
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check IP-based private ranges
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return is_private_ip(ip);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Extract the host portion from a URL string.
///
/// Only `http://` and `https://` URLs are supported; bracketed IPv6 hosts are
/// rejected. Userinfo (`user:pass@host`) and the port are stripped so that
/// `http://evil@localhost/` cannot spoof the host check — previously the
/// userinfo was left in place, letting such URLs bypass SSRF filtering.
/// The returned host is lowercased.
fn extract_host(url: &str) -> Option<String> {
    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    let authority = rest.split(['/', '?', '#']).next()?;
    if authority.is_empty() {
        return None;
    }
    // Drop userinfo: everything up to the last '@' is credentials, not host.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    if host_port.is_empty() {
        return None;
    }
    if host_port.starts_with('[') {
        // IPv6 in brackets — reject for simplicity
        return None;
    }
    // Strip port
    let host = host_port.split(':').next().unwrap_or(host_port);
    if host.is_empty() {
        return None;
    }
    Some(host.to_lowercase())
}
|
||||
|
||||
/// Check if an IP address falls within private/reserved ranges.
///
/// Covers IPv4 loopback/private/link-local/unspecified/broadcast/multicast
/// plus the CGNAT range, and IPv6 loopback/unspecified/multicast plus
/// unique-local (fc00::/7) and link-local (fe80::/10) — the latter two were
/// previously missed, letting e.g. `fe80::1` through the SSRF filter.
/// IPv4-mapped IPv6 addresses are unwrapped and re-checked as IPv4.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let octets = v4.octets();
            v4.is_loopback() // 127.0.0.0/8
                || v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
                || v4.is_link_local() // 169.254.0.0/16
                || v4.is_unspecified() // 0.0.0.0
                || v4.is_broadcast() // 255.255.255.255
                || v4.is_multicast() // 224.0.0.0/4
                || (octets[0] == 100 && (octets[1] & 0xc0) == 64) // 100.64.0.0/10 (CGNAT)
        }
        IpAddr::V6(v6) => {
            let seg0 = v6.segments()[0];
            v6.is_loopback() // ::1
                || v6.is_unspecified() // ::
                || v6.is_multicast()
                || (seg0 & 0xfe00) == 0xfc00 // fc00::/7 unique-local
                || (seg0 & 0xffc0) == 0xfe80 // fe80::/10 link-local
                // IPv4-mapped IPv6: delegate to the IPv4 rules above.
                || v6
                    .to_ipv4_mapped()
                    .is_some_and(|v4| is_private_ip(IpAddr::V4(v4)))
        }
    }
}
|
||||
|
||||
/// Extract the `<title>` tag content from HTML.
|
||||
pub fn extract_title(html: &str) -> Option<String> {
|
||||
// Case-insensitive search for <title>...</title>
|
||||
let lower = html.to_lowercase();
|
||||
let start = lower.find("<title")? + "<title".len();
|
||||
// Skip attributes if any (e.g. <title lang="en">)
|
||||
let start = lower[start..].find('>')? + start + 1;
|
||||
let end = lower[start..].find("</title")? + start;
|
||||
let title = lower[start..end].trim().to_string();
|
||||
if title.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(html_entity_decode_basic(&title))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the first `max_chars` of visible body text from HTML.
|
||||
pub fn extract_body_text(html: &str, max_chars: usize) -> String {
|
||||
let text = nanohtml2text::html2text(html);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.len() <= max_chars {
|
||||
trimmed.to_string()
|
||||
} else {
|
||||
let mut result: String = trimmed.chars().take(max_chars).collect();
|
||||
result.push_str("...");
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Basic HTML entity decoding for title content.
///
/// Handles the five entities commonly found in `<title>` text. `&amp;` is
/// decoded LAST so that input like `&amp;lt;` decodes to the literal `&lt;`
/// rather than being double-decoded to `<`. (The previous replacement
/// literals had been corrupted into no-ops and decoded `&amp;` first.)
fn html_entity_decode_basic(s: &str) -> String {
    s.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&apos;", "'")
        // &amp; last — see above.
        .replace("&amp;", "&")
}
|
||||
|
||||
/// Summary of a fetched link.
struct LinkSummary {
    // Page `<title>` text ("Untitled" when none was found — see fetch_link_summary).
    title: String,
    // Leading visible body text (capped at 200 chars by the caller).
    snippet: String,
}
|
||||
|
||||
/// Fetch a single URL and extract a summary. Returns `None` on any failure.
///
/// Best-effort by design: build/network errors, non-2xx statuses, non-HTML
/// content, and unreadable bodies all yield `None` so the caller can simply
/// skip links it cannot summarize.
async fn fetch_link_summary(url: &str, timeout_secs: u64) -> Option<LinkSummary> {
    // Per-call client keeps this stage self-contained: overall timeout from
    // config, separate 5s connect timeout, at most 5 redirect hops.
    // NOTE(review): redirects are followed AFTER the caller's SSRF check on
    // the original URL, so a public URL redirecting to a private address is
    // still fetched — confirm whether that is acceptable.
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(timeout_secs))
        .connect_timeout(Duration::from_secs(5))
        .redirect(reqwest::redirect::Policy::limited(5))
        .user_agent("ZeroClaw/0.1 (link-enricher)")
        .build()
        .ok()?;

    let response = client.get(url).send().await.ok()?;
    if !response.status().is_success() {
        return None;
    }

    // Only process text/html responses
    // (a missing/empty Content-Type is given the benefit of the doubt).
    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_lowercase();

    if !content_type.contains("text/html") && !content_type.is_empty() {
        return None;
    }

    // Read up to 256KB to extract title and snippet
    // NOTE(review): `bytes()` buffers the WHOLE body first; the cap is only
    // applied afterwards. The byte cut may split a UTF-8 sequence, which
    // `from_utf8_lossy` tolerates (replacement character).
    let max_bytes: usize = 256 * 1024;
    let bytes = response.bytes().await.ok()?;
    let body = if bytes.len() > max_bytes {
        String::from_utf8_lossy(&bytes[..max_bytes]).into_owned()
    } else {
        String::from_utf8_lossy(&bytes).into_owned()
    };

    let title = extract_title(&body).unwrap_or_else(|| "Untitled".to_string());
    let snippet = extract_body_text(&body, 200);

    Some(LinkSummary { title, snippet })
}
|
||||
|
||||
/// Enrich a message by prepending link summaries for any URLs found in the text.
|
||||
///
|
||||
/// This is the main entry point called from the channel message processing pipeline.
|
||||
/// If the enricher is disabled or no URLs are found, the original message is returned
|
||||
/// unchanged.
|
||||
pub async fn enrich_message(content: &str, config: &LinkEnricherConfig) -> String {
|
||||
if !config.enabled || config.max_links == 0 {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let urls = extract_urls(content, config.max_links);
|
||||
if urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
// Filter out SSRF targets
|
||||
let safe_urls: Vec<&str> = urls
|
||||
.iter()
|
||||
.filter(|u| !is_ssrf_target(u))
|
||||
.map(|u| u.as_str())
|
||||
.collect();
|
||||
if safe_urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let mut enrichments = Vec::new();
|
||||
for url in safe_urls {
|
||||
match fetch_link_summary(url, config.timeout_secs).await {
|
||||
Some(summary) => {
|
||||
enrichments.push(format!("[Link: {} — {}]", summary.title, summary.snippet));
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(url, "Link enricher: failed to fetch or extract summary");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if enrichments.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let prefix = enrichments.join("\n");
|
||||
format!("{prefix}\n{content}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── URL extraction ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_urls_finds_http_and_https() {
|
||||
let text = "Check https://example.com and http://test.org/page for info";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com", "http://test.org/page",]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_respects_max() {
|
||||
let text = "https://a.com https://b.com https://c.com https://d.com";
|
||||
let urls = extract_urls(text, 2);
|
||||
assert_eq!(urls.len(), 2);
|
||||
assert_eq!(urls[0], "https://a.com");
|
||||
assert_eq!(urls[1], "https://b.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_deduplicates() {
|
||||
let text = "Visit https://example.com and https://example.com again";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_handles_no_urls() {
|
||||
let text = "Just a normal message without links";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert!(urls.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_stops_at_angle_brackets() {
|
||||
let text = "Link: <https://example.com/path> done";
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com/path"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_urls_stops_at_quotes() {
|
||||
let text = r#"href="https://example.com/page" end"#;
|
||||
let urls = extract_urls(text, 10);
|
||||
assert_eq!(urls, vec!["https://example.com/page"]);
|
||||
}
|
||||
|
||||
// ── SSRF protection ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_localhost() {
|
||||
assert!(is_ssrf_target("http://localhost/admin"));
|
||||
assert!(is_ssrf_target("https://localhost:8080/api"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_loopback_ip() {
|
||||
assert!(is_ssrf_target("http://127.0.0.1/secret"));
|
||||
assert!(is_ssrf_target("http://127.0.0.2:9090"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_10_network() {
|
||||
assert!(is_ssrf_target("http://10.0.0.1/internal"));
|
||||
assert!(is_ssrf_target("http://10.255.255.255"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_172_network() {
|
||||
assert!(is_ssrf_target("http://172.16.0.1/admin"));
|
||||
assert!(is_ssrf_target("http://172.31.255.255"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_private_192_168_network() {
|
||||
assert!(is_ssrf_target("http://192.168.1.1/router"));
|
||||
assert!(is_ssrf_target("http://192.168.0.100:3000"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_link_local() {
|
||||
assert!(is_ssrf_target("http://169.254.0.1/metadata"));
|
||||
assert!(is_ssrf_target("http://169.254.169.254/latest"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_ipv6_loopback() {
|
||||
// IPv6 in brackets is rejected by extract_host
|
||||
assert!(is_ssrf_target("http://[::1]/admin"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_blocks_dot_local() {
|
||||
assert!(is_ssrf_target("http://myhost.local/api"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_allows_public_urls() {
|
||||
assert!(!is_ssrf_target("https://example.com/page"));
|
||||
assert!(!is_ssrf_target("https://www.google.com"));
|
||||
assert!(!is_ssrf_target("http://93.184.216.34/resource"));
|
||||
}
|
||||
|
||||
// ── Title extraction ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_title_basic() {
|
||||
let html = "<html><head><title>My Page Title</title></head><body>Hello</body></html>";
|
||||
assert_eq!(extract_title(html), Some("my page title".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_with_entities() {
|
||||
let html = "<title>Tom & Jerry's Page</title>";
|
||||
assert_eq!(extract_title(html), Some("tom & jerry's page".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_case_insensitive() {
|
||||
let html = "<HTML><HEAD><TITLE>Upper Case</TITLE></HEAD></HTML>";
|
||||
assert_eq!(extract_title(html), Some("upper case".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_multibyte_chars_no_panic() {
|
||||
// İ (U+0130) lowercases to 2 chars, changing byte length.
|
||||
// This must not panic or produce wrong offsets.
|
||||
let html = "<title>İstanbul Guide</title>";
|
||||
let result = extract_title(html);
|
||||
assert!(result.is_some());
|
||||
let title = result.unwrap();
|
||||
assert!(title.contains("stanbul"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_missing() {
|
||||
let html = "<html><body>No title here</body></html>";
|
||||
assert_eq!(extract_title(html), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_title_empty() {
|
||||
let html = "<title> </title>";
|
||||
assert_eq!(extract_title(html), None);
|
||||
}
|
||||
|
||||
// ── Body text extraction ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn extract_body_text_strips_html() {
|
||||
let html = "<html><body><h1>Header</h1><p>Some content here</p></body></html>";
|
||||
let text = extract_body_text(html, 200);
|
||||
assert!(text.contains("Header"));
|
||||
assert!(text.contains("Some content"));
|
||||
assert!(!text.contains("<h1>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_body_text_truncates() {
|
||||
let html = "<p>A very long paragraph that should be truncated to fit within the limit.</p>";
|
||||
let text = extract_body_text(html, 20);
|
||||
assert!(text.len() <= 25); // 20 chars + "..."
|
||||
assert!(text.ends_with("..."));
|
||||
}
|
||||
|
||||
// ── Config toggle ───────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_disabled_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: false,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "Check https://example.com for details";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_no_urls_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: true,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "No links in this message";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrich_message_ssrf_urls_returns_original() {
|
||||
let config = LinkEnricherConfig {
|
||||
enabled: true,
|
||||
max_links: 3,
|
||||
timeout_secs: 10,
|
||||
};
|
||||
let msg = "Try http://127.0.0.1/admin and http://192.168.1.1/router";
|
||||
let result = enrich_message(msg, &config).await;
|
||||
assert_eq!(result, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_is_disabled() {
|
||||
let config = LinkEnricherConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.max_links, 3);
|
||||
assert_eq!(config.timeout_secs, 10);
|
||||
}
|
||||
}
|
||||
+210
-7
@@ -8,6 +8,7 @@ use matrix_sdk::{
|
||||
events::reaction::ReactionEventContent,
|
||||
events::receipt::ReceiptThread,
|
||||
events::relation::{Annotation, Thread},
|
||||
events::room::member::StrippedRoomMemberEvent,
|
||||
events::room::message::{
|
||||
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||
},
|
||||
@@ -32,6 +33,7 @@ pub struct MatrixChannel {
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
session_owner_hint: Option<String>,
|
||||
session_device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
@@ -48,6 +50,7 @@ impl std::fmt::Debug for MatrixChannel {
|
||||
.field("homeserver", &self.homeserver)
|
||||
.field("room_id", &self.room_id)
|
||||
.field("allowed_users", &self.allowed_users)
|
||||
.field("allowed_rooms", &self.allowed_rooms)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -121,7 +124,16 @@ impl MatrixChannel {
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint(homeserver, access_token, room_id, allowed_users, None, None)
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_with_session_hint(
|
||||
@@ -132,11 +144,12 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint_and_zeroclaw_dir(
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
None,
|
||||
@@ -151,6 +164,28 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
zeroclaw_dir,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_full(
|
||||
homeserver: String,
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
let homeserver = homeserver.trim_end_matches('/').to_string();
|
||||
let access_token = access_token.trim().to_string();
|
||||
@@ -160,12 +195,18 @@ impl MatrixChannel {
|
||||
.map(|user| user.trim().to_string())
|
||||
.filter(|user| !user.is_empty())
|
||||
.collect();
|
||||
let allowed_rooms = allowed_rooms
|
||||
.into_iter()
|
||||
.map(|room| room.trim().to_string())
|
||||
.filter(|room| !room.is_empty())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms,
|
||||
session_owner_hint: Self::normalize_optional_field(owner_hint),
|
||||
session_device_id_hint: Self::normalize_optional_field(device_id_hint),
|
||||
zeroclaw_dir,
|
||||
@@ -220,6 +261,21 @@ impl MatrixChannel {
|
||||
allowed_users.iter().any(|u| u.eq_ignore_ascii_case(sender))
|
||||
}
|
||||
|
||||
/// Check whether a room (by its canonical ID) is in the allowed_rooms list.
|
||||
/// If allowed_rooms is empty, all rooms are allowed.
|
||||
fn is_room_allowed_static(allowed_rooms: &[String], room_id: &str) -> bool {
|
||||
if allowed_rooms.is_empty() {
|
||||
return true;
|
||||
}
|
||||
allowed_rooms
|
||||
.iter()
|
||||
.any(|r| r.eq_ignore_ascii_case(room_id))
|
||||
}
|
||||
|
||||
fn is_room_allowed(&self, room_id: &str) -> bool {
|
||||
Self::is_room_allowed_static(&self.allowed_rooms, room_id)
|
||||
}
|
||||
|
||||
fn is_supported_message_type(msgtype: &str) -> bool {
|
||||
matches!(msgtype, "m.text" | "m.notice")
|
||||
}
|
||||
@@ -228,6 +284,10 @@ impl MatrixChannel {
|
||||
!body.trim().is_empty()
|
||||
}
|
||||
|
||||
fn room_matches_target(target_room_id: &str, incoming_room_id: &str) -> bool {
|
||||
target_room_id == incoming_room_id
|
||||
}
|
||||
|
||||
fn cache_event_id(
|
||||
event_id: &str,
|
||||
recent_order: &mut std::collections::VecDeque<String>,
|
||||
@@ -526,8 +586,9 @@ impl MatrixChannel {
|
||||
if client.encryption().backups().are_enabled().await {
|
||||
tracing::info!("Matrix room-key backup is enabled for this device.");
|
||||
} else {
|
||||
client.encryption().backups().disable().await;
|
||||
tracing::warn!(
|
||||
"Matrix room-key backup is not enabled for this device; `matrix_sdk_crypto::backups` warnings about missing backup keys may appear until recovery is configured."
|
||||
"Matrix room-key backup is not enabled for this device; automatic backup attempts have been disabled to suppress recurring warnings. To enable backups, configure server-side key backup and recovery for this device."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -697,6 +758,7 @@ impl Channel for MatrixChannel {
|
||||
let target_room_for_handler = target_room.clone();
|
||||
let my_user_id_for_handler = my_user_id.clone();
|
||||
let allowed_users_for_handler = self.allowed_users.clone();
|
||||
let allowed_rooms_for_handler = self.allowed_rooms.clone();
|
||||
let dedupe_for_handler = Arc::clone(&recent_event_cache);
|
||||
let homeserver_for_handler = self.homeserver.clone();
|
||||
let access_token_for_handler = self.access_token.clone();
|
||||
@@ -704,18 +766,29 @@ impl Channel for MatrixChannel {
|
||||
|
||||
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||
let tx = tx_handler.clone();
|
||||
let _target_room = target_room_for_handler.clone();
|
||||
let target_room = target_room_for_handler.clone();
|
||||
let my_user_id = my_user_id_for_handler.clone();
|
||||
let allowed_users = allowed_users_for_handler.clone();
|
||||
let allowed_rooms = allowed_rooms_for_handler.clone();
|
||||
let dedupe = Arc::clone(&dedupe_for_handler);
|
||||
let homeserver = homeserver_for_handler.clone();
|
||||
let access_token = access_token_for_handler.clone();
|
||||
let voice_mode = Arc::clone(&voice_mode_for_handler);
|
||||
|
||||
async move {
|
||||
if false
|
||||
/* multi-room: room_id filter disabled */
|
||||
{
|
||||
if !MatrixChannel::room_matches_target(
|
||||
target_room.as_str(),
|
||||
room.room_id().as_str(),
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Room allowlist: skip messages from rooms not in the configured list
|
||||
if !MatrixChannel::is_room_allowed_static(&allowed_rooms, room.room_id().as_ref()) {
|
||||
tracing::debug!(
|
||||
"Matrix: ignoring message from room {} (not in allowed_rooms)",
|
||||
room.room_id()
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -907,6 +980,45 @@ impl Channel for MatrixChannel {
|
||||
}
|
||||
});
|
||||
|
||||
// Invite handler: auto-accept invites for allowed rooms, auto-reject others
|
||||
let allowed_rooms_for_invite = self.allowed_rooms.clone();
|
||||
client.add_event_handler(move |event: StrippedRoomMemberEvent, room: Room| {
|
||||
let allowed_rooms = allowed_rooms_for_invite.clone();
|
||||
async move {
|
||||
// Only process invite events targeting us
|
||||
if event.content.membership
|
||||
!= matrix_sdk::ruma::events::room::member::MembershipState::Invite
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let room_id_str = room.room_id().to_string();
|
||||
|
||||
if MatrixChannel::is_room_allowed_static(&allowed_rooms, &room_id_str) {
|
||||
// Room is allowed (or no allowlist configured): auto-accept
|
||||
tracing::info!(
|
||||
"Matrix: auto-accepting invite for allowed room {}",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.join().await {
|
||||
tracing::warn!("Matrix: failed to auto-join room {}: {error}", room_id_str);
|
||||
}
|
||||
} else {
|
||||
// Room is NOT in allowlist: auto-reject
|
||||
tracing::info!(
|
||||
"Matrix: auto-rejecting invite for room {} (not in allowed_rooms)",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.leave().await {
|
||||
tracing::warn!(
|
||||
"Matrix: failed to reject invite for room {}: {error}",
|
||||
room_id_str
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_secs(30));
|
||||
client
|
||||
.sync_with_result_callback(sync_settings, |sync_result| {
|
||||
@@ -1294,6 +1406,22 @@ mod tests {
|
||||
assert_eq!(value["room"]["timeline"]["limit"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_matches_configured_room() {
|
||||
assert!(MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!ops:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_rejects_other_rooms() {
|
||||
assert!(!MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!other:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_id_cache_deduplicates_and_evicts_old_entries() {
|
||||
let mut recent_order = std::collections::VecDeque::new();
|
||||
@@ -1549,4 +1677,79 @@ mod tests {
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.rooms.join.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowed_rooms_permits_all() {
|
||||
let ch = make_channel();
|
||||
assert!(ch.is_room_allowed("!any:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!other:evil.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_filters_by_id() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec!["!allowed:matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!allowed:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!forbidden:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_supports_aliases() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec![
|
||||
"#ops:matrix.org".to_string(),
|
||||
"!direct:matrix.org".to_string(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!direct:matrix.org"));
|
||||
assert!(ch.is_room_allowed("#ops:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!other:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_case_insensitive() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec!["!Room:Matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!ROOM:MATRIX.ORG"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_trims_whitespace() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec![" !room:matrix.org ".to_string(), " ".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(ch.allowed_rooms.len(), 1);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
}
|
||||
}
|
||||
|
||||
+48
-18
@@ -26,6 +26,7 @@ pub mod imessage;
|
||||
pub mod irc;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
pub mod lark;
|
||||
pub mod link_enricher;
|
||||
pub mod linq;
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
pub mod matrix;
|
||||
@@ -2066,6 +2067,25 @@ async fn process_channel_message(
|
||||
msg
|
||||
};
|
||||
|
||||
// ── Link enricher: prepend URL summaries before agent sees the message ──
|
||||
let le_config = &ctx.prompt_config.link_enricher;
|
||||
if le_config.enabled {
|
||||
let enricher_cfg = link_enricher::LinkEnricherConfig {
|
||||
enabled: le_config.enabled,
|
||||
max_links: le_config.max_links,
|
||||
timeout_secs: le_config.timeout_secs,
|
||||
};
|
||||
let enriched = link_enricher::enrich_message(&msg.content, &enricher_cfg).await;
|
||||
if enriched != msg.content {
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Link enricher: prepended URL summaries to message"
|
||||
);
|
||||
msg.content = enriched;
|
||||
}
|
||||
}
|
||||
|
||||
let target_channel = ctx
|
||||
.channels_by_name
|
||||
.get(&msg.channel)
|
||||
@@ -3670,13 +3690,16 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
.discord
|
||||
.as_ref()
|
||||
.context("Discord channel is not configured")?;
|
||||
Ok(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)))
|
||||
Ok(Arc::new(
|
||||
DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
"slack" => {
|
||||
let sl = config
|
||||
@@ -3692,7 +3715,8 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
|
||||
@@ -3778,7 +3802,8 @@ fn collect_configured_channels(
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_proxy_url(dc.proxy_url.clone()),
|
||||
.with_proxy_url(dc.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3822,7 +3847,8 @@ fn collect_configured_channels(
|
||||
.with_thread_replies(sl.thread_replies.unwrap_or(true))
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(sl.proxy_url.clone()),
|
||||
.with_proxy_url(sl.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3855,11 +3881,12 @@ fn collect_configured_channels(
|
||||
if let Some(ref mx) = config.channels_config.matrix {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Matrix",
|
||||
channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir(
|
||||
channel: Arc::new(MatrixChannel::new_full(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
mx.allowed_rooms.clone(),
|
||||
mx.user_id.clone(),
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
@@ -6875,6 +6902,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: "2026-02-20T00:00:00Z".to_string(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}])
|
||||
}
|
||||
|
||||
@@ -7696,9 +7726,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(prompt.contains("<instructions>"));
|
||||
assert!(prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
assert!(!prompt.contains("loaded on demand"));
|
||||
}
|
||||
|
||||
@@ -7741,10 +7771,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -12,6 +12,8 @@ use chrono::{DateTime, Utc};
|
||||
pub struct SessionMetadata {
|
||||
/// Session key (e.g. `telegram_user123`).
|
||||
pub key: String,
|
||||
/// Optional human-readable name (e.g. `eyrie-commander-briefing`).
|
||||
pub name: Option<String>,
|
||||
/// When the session was first created.
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When the last message was appended.
|
||||
@@ -54,6 +56,7 @@ pub trait SessionBackend: Send + Sync {
|
||||
let messages = self.load(&key);
|
||||
SessionMetadata {
|
||||
key,
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: messages.len(),
|
||||
@@ -81,6 +84,16 @@ pub trait SessionBackend: Send + Sync {
|
||||
fn delete_session(&self, _session_key: &str) -> std::io::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Set or update the human-readable name for a session.
|
||||
fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the human-readable name for a session (if set).
|
||||
fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -91,6 +104,7 @@ mod tests {
|
||||
fn session_metadata_is_constructible() {
|
||||
let meta = SessionMetadata {
|
||||
key: "test".into(),
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: 5,
|
||||
|
||||
@@ -51,7 +51,8 @@ impl SqliteSessionBackend {
|
||||
session_key TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
last_activity TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0
|
||||
message_count INTEGER NOT NULL DEFAULT 0,
|
||||
name TEXT
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
|
||||
@@ -69,6 +70,18 @@ impl SqliteSessionBackend {
|
||||
)
|
||||
.context("Failed to initialize session schema")?;
|
||||
|
||||
// Migration: add name column to existing databases
|
||||
let has_name: bool = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) > 0 FROM pragma_table_info('session_metadata') WHERE name = 'name'",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap_or(false);
|
||||
if !has_name {
|
||||
let _ = conn.execute("ALTER TABLE session_metadata ADD COLUMN name TEXT", []);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
db_path,
|
||||
@@ -226,7 +239,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_key, created_at, last_activity, message_count
|
||||
"SELECT session_key, created_at, last_activity, message_count, name
|
||||
FROM session_metadata ORDER BY last_activity DESC",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
@@ -238,6 +251,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
let created_str: String = row.get(1)?;
|
||||
let activity_str: String = row.get(2)?;
|
||||
let count: i64 = row.get(3)?;
|
||||
let name: Option<String> = row.get(4)?;
|
||||
|
||||
let created = DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
@@ -249,6 +263,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(SessionMetadata {
|
||||
key,
|
||||
name,
|
||||
created_at: created,
|
||||
last_activity: activity,
|
||||
message_count: count as usize,
|
||||
@@ -321,6 +336,27 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
|
||||
let conn = self.conn.lock();
|
||||
let name_val = if name.is_empty() { None } else { Some(name) };
|
||||
conn.execute(
|
||||
"UPDATE session_metadata SET name = ?1 WHERE session_key = ?2",
|
||||
params![name_val, session_key],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
conn.query_row(
|
||||
"SELECT name FROM session_metadata WHERE session_key = ?1",
|
||||
params![session_key],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
|
||||
fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
let Some(keyword) = &query.keyword else {
|
||||
return self.list_sessions_with_metadata();
|
||||
@@ -357,14 +393,16 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
keys.iter()
|
||||
.filter_map(|key| {
|
||||
conn.query_row(
|
||||
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
|
||||
"SELECT created_at, last_activity, message_count, name FROM session_metadata WHERE session_key = ?1",
|
||||
params![key],
|
||||
|row| {
|
||||
let created_str: String = row.get(0)?;
|
||||
let activity_str: String = row.get(1)?;
|
||||
let count: i64 = row.get(2)?;
|
||||
let name: Option<String> = row.get(3)?;
|
||||
Ok(SessionMetadata {
|
||||
key: key.clone(),
|
||||
name,
|
||||
created_at: DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
@@ -555,4 +593,55 @@ mod tests {
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_persists() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "My Session").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert_eq!(meta[0].name.as_deref(), Some("My Session"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_updates_existing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "First").unwrap();
|
||||
backend.set_session_name("s1", "Second").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta[0].name.as_deref(), Some("Second"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sessions_without_name_return_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_name_clears_to_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "Named").unwrap();
|
||||
backend.set_session_name("s1", "").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ pub struct SlackChannel {
|
||||
active_assistant_thread: Mutex<HashMap<String, String>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio file attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@@ -125,6 +128,7 @@ impl SlackChannel {
|
||||
workspace_dir: None,
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +162,14 @@ impl SlackChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio file attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client_with_timeouts(
|
||||
"channel.slack",
|
||||
@@ -558,6 +570,13 @@ impl SlackChannel {
|
||||
.await
|
||||
.unwrap_or_else(|| raw_file.clone());
|
||||
|
||||
// Voice / audio transcription: if transcription is configured and the
|
||||
// file looks like an audio attachment, download and transcribe it.
|
||||
if Self::is_audio_file(&file) {
|
||||
if let Some(transcribed) = self.try_transcribe_audio_file(&file).await {
|
||||
return Some(transcribed);
|
||||
}
|
||||
}
|
||||
if Self::is_image_file(&file) {
|
||||
if let Some(marker) = self.fetch_image_marker(&file).await {
|
||||
return Some(marker);
|
||||
@@ -1449,6 +1468,106 @@ impl SlackChannel {
|
||||
.is_some_and(|ext| Self::mime_from_extension(ext).is_some())
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
|
||||
const AUDIO_EXTENSIONS: &[&str] = &[
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
];
|
||||
|
||||
/// Check whether a Slack file object looks like an audio attachment
|
||||
/// (voice memo, audio message, or uploaded audio file).
|
||||
fn is_audio_file(file: &serde_json::Value) -> bool {
|
||||
// Slack voice messages use subtype "slack_audio"
|
||||
if let Some(subtype) = file.get("subtype").and_then(|v| v.as_str()) {
|
||||
if subtype == "slack_audio" {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if Self::slack_file_mime(file)
|
||||
.as_deref()
|
||||
.is_some_and(|mime| mime.starts_with("audio/"))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(ft) = file
|
||||
.get("filetype")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|v| v.to_ascii_lowercase())
|
||||
{
|
||||
if Self::AUDIO_EXTENSIONS.contains(&ft.as_str()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Self::file_extension(&Self::slack_file_name(file))
|
||||
.as_deref()
|
||||
.is_some_and(|ext| Self::AUDIO_EXTENSIONS.contains(&ext))
|
||||
}
|
||||
|
||||
/// Download an audio file attachment and transcribe it using the configured
|
||||
/// transcription provider. Returns `None` if transcription is not configured
|
||||
/// or if the download/transcription fails.
|
||||
async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let config = self.transcription.as_ref()?;
|
||||
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let file_name = Self::slack_file_name(file);
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
let resp = self.fetch_slack_private_file(url).await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
tracing::warn!(
|
||||
"Slack voice file download failed for {} ({status})",
|
||||
redacted_url
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let audio_data = match resp.bytes().await {
|
||||
Ok(bytes) => bytes.to_vec(),
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice file read failed for {}: {e}", redacted_url);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Determine a filename with extension for the transcription API.
|
||||
let transcription_filename = if Self::file_extension(&file_name).is_some() {
|
||||
file_name.clone()
|
||||
} else {
|
||||
// Fall back to extension from mimetype or default to .ogg
|
||||
let mime_ext = Self::slack_file_mime(file)
|
||||
.and_then(|mime| mime.rsplit('/').next().map(|s| s.to_string()))
|
||||
.unwrap_or_else(|| "ogg".to_string());
|
||||
format!("voice.{mime_ext}")
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
tracing::info!("Slack voice transcription returned empty text, skipping");
|
||||
None
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Slack: transcribed voice file {} ({} chars)",
|
||||
file_name,
|
||||
trimmed.len()
|
||||
);
|
||||
Some(format!("[Voice] {trimmed}"))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice transcription failed for {}: {e}", file_name);
|
||||
Some(Self::format_attachment_summary(file))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_text_snippet(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
@@ -1140,6 +1140,11 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content = format!("{quote}\n\n{content}");
|
||||
}
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
content = format!("{attr}{content}");
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1263,6 +1268,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
format!("[Voice] {text}")
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1299,6 +1311,41 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
(username, sender_id, sender_identity)
|
||||
}
|
||||
|
||||
/// Build a forwarding attribution prefix from Telegram forward fields.
|
||||
///
|
||||
/// Returns `Some("[Forwarded from ...] ")` when the message is forwarded,
|
||||
/// `None` otherwise.
|
||||
fn format_forward_attribution(message: &serde_json::Value) -> Option<String> {
|
||||
if let Some(from_chat) = message.get("forward_from_chat") {
|
||||
// Forwarded from a channel or group
|
||||
let title = from_chat
|
||||
.get("title")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown channel");
|
||||
Some(format!("[Forwarded from channel: {title}] "))
|
||||
} else if let Some(from_user) = message.get("forward_from") {
|
||||
// Forwarded from a user (privacy allows identity)
|
||||
let label = from_user
|
||||
.get("username")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|u| format!("@{u}"))
|
||||
.or_else(|| {
|
||||
from_user
|
||||
.get("first_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(String::from)
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
Some(format!("[Forwarded from {label}] "))
|
||||
} else {
|
||||
// Forwarded from a user who hides their identity
|
||||
message
|
||||
.get("forward_sender_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|name| format!("[Forwarded from {name}] "))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract reply context from a Telegram `reply_to_message`, if present.
|
||||
fn extract_reply_context(&self, message: &serde_json::Value) -> Option<String> {
|
||||
let reply = message.get("reply_to_message")?;
|
||||
@@ -1420,6 +1467,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
// Exit voice-chat mode when user switches back to typing
|
||||
if let Ok(mut vc) = self.voice_chats.lock() {
|
||||
vc.remove(&reply_target);
|
||||
@@ -4871,4 +4925,153 @@ mod tests {
|
||||
TelegramChannel::new("token".into(), vec!["*".into()], false).with_ack_reactions(true);
|
||||
assert!(ch.ack_reactions);
|
||||
}
|
||||
|
||||
// ── Forwarded message tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_with_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 100,
|
||||
"message": {
|
||||
"message_id": 50,
|
||||
"text": "Check this out",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"first_name": "Bob",
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from @bob] Check this out");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_channel() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 101,
|
||||
"message": {
|
||||
"message_id": 51,
|
||||
"text": "Breaking news",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from_chat": {
|
||||
"id": -1_001_234_567_890_i64,
|
||||
"title": "Daily News",
|
||||
"username": "dailynews",
|
||||
"type": "channel"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("channel-forwarded message should parse");
|
||||
assert_eq!(
|
||||
msg.content,
|
||||
"[Forwarded from channel: Daily News] Breaking news"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_hidden_sender() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 102,
|
||||
"message": {
|
||||
"message_id": 52,
|
||||
"text": "Secret tip",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_sender_name": "Hidden User",
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("hidden-sender forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Hidden User] Secret tip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_non_forwarded_unaffected() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 103,
|
||||
"message": {
|
||||
"message_id": 53,
|
||||
"text": "Normal message",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 }
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("non-forwarded message should parse");
|
||||
assert_eq!(msg.content, "Normal message");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_no_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 104,
|
||||
"message": {
|
||||
"message_id": 54,
|
||||
"text": "Hello there",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 77,
|
||||
"first_name": "Charlie"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message without username should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Charlie] Hello there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarded_photo_attachment_has_attribution() {
|
||||
// Verify that format_forward_attribution produces correct prefix
|
||||
// for a photo message (the actual download is async, so we test the
|
||||
// helper directly with a photo-bearing message structure).
|
||||
let message = serde_json::json!({
|
||||
"message_id": 60,
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"photo": [
|
||||
{ "file_id": "abc123", "file_unique_id": "u1", "width": 320, "height": 240 }
|
||||
],
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
});
|
||||
|
||||
let attr =
|
||||
TelegramChannel::format_forward_attribution(&message).expect("should detect forward");
|
||||
assert_eq!(attr, "[Forwarded from @bob] ");
|
||||
|
||||
// Simulate what try_parse_attachment_message does after building content
|
||||
let photo_content = "[IMAGE:/tmp/photo.jpg]".to_string();
|
||||
let content = format!("{attr}{photo_content}");
|
||||
assert_eq!(content, "[Forwarded from @bob] [IMAGE:/tmp/photo.jpg]");
|
||||
}
|
||||
}
|
||||
|
||||
+130
-1
@@ -1,6 +1,7 @@
|
||||
//! Multi-provider Text-to-Speech (TTS) subsystem.
|
||||
//!
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, and Edge TTS (free, subprocess-based).
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, Edge TTS (free, subprocess-based),
|
||||
//! and Piper TTS (local GPU-accelerated, OpenAI-compatible endpoint).
|
||||
//! Provider selection is driven by [`TtsConfig`] in `config.toml`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
@@ -451,6 +452,80 @@ impl TtsProvider for EdgeTtsProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Piper TTS (local, OpenAI-compatible) ─────────────────────────
|
||||
|
||||
/// Piper TTS provider — local GPU-accelerated server with an OpenAI-compatible endpoint.
|
||||
pub struct PiperTtsProvider {
|
||||
client: reqwest::Client,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl PiperTtsProvider {
|
||||
/// Create a new Piper TTS provider pointing at the given API URL.
|
||||
pub fn new(api_url: &str) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.timeout(TTS_HTTP_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to build HTTP client for Piper TTS"),
|
||||
api_url: api_url.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TtsProvider for PiperTtsProvider {
|
||||
fn name(&self) -> &str {
|
||||
"piper"
|
||||
}
|
||||
|
||||
async fn synthesize(&self, text: &str, voice: &str) -> Result<Vec<u8>> {
|
||||
let body = serde_json::json!({
|
||||
"model": "tts-1",
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&self.api_url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send Piper TTS request")?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let error_body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.unwrap_or_else(|_| serde_json::json!({"error": "unknown"}));
|
||||
let msg = error_body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown error");
|
||||
bail!("Piper TTS API error ({}): {}", status, msg);
|
||||
}
|
||||
|
||||
let bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("Failed to read Piper TTS response body")?;
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
|
||||
fn supported_voices(&self) -> Vec<String> {
|
||||
// Piper voices depend on installed models; return empty (dynamic).
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn supported_formats(&self) -> Vec<String> {
|
||||
["mp3", "wav", "opus"]
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── TtsManager ───────────────────────────────────────────────────
|
||||
|
||||
/// Central manager for multi-provider TTS synthesis.
|
||||
@@ -510,6 +585,11 @@ impl TtsManager {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref piper_cfg) = config.piper {
|
||||
let provider = PiperTtsProvider::new(&piper_cfg.api_url);
|
||||
providers.insert("piper".to_string(), Box::new(provider));
|
||||
}
|
||||
|
||||
let max_text_length = if config.max_text_length == 0 {
|
||||
DEFAULT_MAX_TEXT_LENGTH
|
||||
} else {
|
||||
@@ -652,6 +732,54 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_provider_creation() {
|
||||
let provider = PiperTtsProvider::new("http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.name(), "piper");
|
||||
assert_eq!(provider.api_url, "http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.supported_formats(), vec!["mp3", "wav", "opus"]);
|
||||
// Piper voices depend on installed models; list is empty.
|
||||
assert!(provider.supported_voices().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_manager_with_piper_provider() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert_eq!(manager.available_providers(), vec!["piper"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tts_rejects_empty_text_for_piper() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
let err = manager
|
||||
.synthesize_with_provider("", "piper", "default")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not be empty"),
|
||||
"expected empty-text error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_not_registered_when_config_is_none() {
|
||||
let config = default_tts_config();
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert!(!manager.available_providers().contains(&"piper".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_config_defaults() {
|
||||
let config = TtsConfig::default();
|
||||
@@ -664,6 +792,7 @@ mod tests {
|
||||
assert!(config.elevenlabs.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.edge.is_none());
|
||||
assert!(config.piper.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+49
-104
@@ -21,12 +21,6 @@ use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
|
||||
// ── State machine ──────────────────────────────────────────────
|
||||
|
||||
/// Maximum allowed capture duration (seconds) to prevent unbounded memory growth.
|
||||
const MAX_CAPTURE_SECS_LIMIT: u32 = 300;
|
||||
|
||||
/// Minimum silence timeout to prevent API hammering.
|
||||
const MIN_SILENCE_TIMEOUT_MS: u32 = 100;
|
||||
|
||||
/// Internal states for the wake-word detector.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WakeState {
|
||||
@@ -36,6 +30,8 @@ pub enum WakeState {
|
||||
Triggered,
|
||||
/// Wake word confirmed — capturing the full utterance that follows.
|
||||
Capturing,
|
||||
/// Captured audio is being transcribed.
|
||||
Processing,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WakeState {
|
||||
@@ -44,6 +40,7 @@ impl std::fmt::Display for WakeState {
|
||||
Self::Listening => write!(f, "Listening"),
|
||||
Self::Triggered => write!(f, "Triggered"),
|
||||
Self::Capturing => write!(f, "Capturing"),
|
||||
Self::Processing => write!(f, "Processing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,97 +78,55 @@ impl Channel for VoiceWakeChannel {
|
||||
let config = self.config.clone();
|
||||
let transcription_config = self.transcription_config.clone();
|
||||
|
||||
// ── Validate config ───────────────────────────────────
|
||||
let energy_threshold = config.energy_threshold;
|
||||
if !energy_threshold.is_finite() || energy_threshold <= 0.0 {
|
||||
bail!("VoiceWake: energy_threshold must be a positive finite number, got {energy_threshold}");
|
||||
}
|
||||
if config.silence_timeout_ms < MIN_SILENCE_TIMEOUT_MS {
|
||||
bail!(
|
||||
"VoiceWake: silence_timeout_ms must be >= {MIN_SILENCE_TIMEOUT_MS}, got {}",
|
||||
config.silence_timeout_ms
|
||||
);
|
||||
}
|
||||
let max_capture_secs = config.max_capture_secs.min(MAX_CAPTURE_SECS_LIMIT);
|
||||
if max_capture_secs != config.max_capture_secs {
|
||||
warn!(
|
||||
"VoiceWake: max_capture_secs clamped from {} to {MAX_CAPTURE_SECS_LIMIT}",
|
||||
config.max_capture_secs
|
||||
);
|
||||
}
|
||||
|
||||
// Run the blocking audio capture loop on a dedicated thread.
|
||||
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(64);
|
||||
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(4);
|
||||
|
||||
let energy_threshold = config.energy_threshold;
|
||||
let silence_timeout = Duration::from_millis(u64::from(config.silence_timeout_ms));
|
||||
let max_capture = Duration::from_secs(u64::from(max_capture_secs));
|
||||
let max_capture = Duration::from_secs(u64::from(config.max_capture_secs));
|
||||
let sample_rate: u32;
|
||||
let channels_count: u16;
|
||||
|
||||
// ── Initialise cpal stream ────────────────────────────
|
||||
// cpal::Stream is !Send, so we build and hold it on a dedicated thread.
|
||||
// When the listen function exits, the shutdown oneshot is dropped,
|
||||
// the thread exits, and the stream + microphone are released.
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let (init_tx, init_rx) = tokio::sync::oneshot::channel::<Result<(u32, u16)>>();
|
||||
{
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.ok_or_else(|| anyhow::anyhow!("No default audio input device available"))?;
|
||||
|
||||
let supported = device.default_input_config()?;
|
||||
sample_rate = supported.sample_rate().0;
|
||||
channels_count = supported.channels();
|
||||
|
||||
info!(
|
||||
device = ?device.name().unwrap_or_default(),
|
||||
sample_rate,
|
||||
channels = channels_count,
|
||||
"VoiceWake: opening audio input"
|
||||
);
|
||||
|
||||
let stream_config: cpal::StreamConfig = supported.into();
|
||||
let audio_tx_clone = audio_tx.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
let stream = device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
// Non-blocking: try_send and drop if full.
|
||||
let _ = audio_tx_clone.try_send(data.to_vec());
|
||||
},
|
||||
move |err| {
|
||||
warn!("VoiceWake: audio stream error: {err}");
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
|
||||
let result = (|| -> Result<(u32, u16, cpal::Stream)> {
|
||||
let host = cpal::default_host();
|
||||
let device = host.default_input_device().ok_or_else(|| {
|
||||
anyhow::anyhow!("No default audio input device available")
|
||||
})?;
|
||||
stream.play()?;
|
||||
|
||||
let supported = device.default_input_config()?;
|
||||
let sr = supported.sample_rate().0;
|
||||
let ch = supported.channels();
|
||||
|
||||
info!(
|
||||
device = ?device.name().unwrap_or_default(),
|
||||
sample_rate = sr,
|
||||
channels = ch,
|
||||
"VoiceWake: opening audio input"
|
||||
);
|
||||
|
||||
let stream_config: cpal::StreamConfig = supported.into();
|
||||
|
||||
let stream = device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let _ = audio_tx_clone.try_send(data.to_vec());
|
||||
},
|
||||
move |err| {
|
||||
warn!("VoiceWake: audio stream error: {err}");
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
|
||||
stream.play()?;
|
||||
Ok((sr, ch, stream))
|
||||
})();
|
||||
|
||||
match result {
|
||||
Ok((sr, ch, _stream)) => {
|
||||
let _ = init_tx.send(Ok((sr, ch)));
|
||||
// Hold the stream alive until shutdown is signalled.
|
||||
let _ = shutdown_rx.blocking_recv();
|
||||
debug!("VoiceWake: stream holder thread exiting");
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = init_tx.send(Err(e));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (sr, ch) = init_rx
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("VoiceWake: stream init thread panicked"))??;
|
||||
sample_rate = sr;
|
||||
channels_count = ch;
|
||||
// Keep the stream alive for the lifetime of the channel.
|
||||
// We leak it intentionally — the channel runs until the daemon shuts down.
|
||||
std::mem::forget(stream);
|
||||
}
|
||||
|
||||
// Drop the extra sender so the channel closes when the stream sender drops.
|
||||
@@ -185,10 +140,6 @@ impl Channel for VoiceWakeChannel {
|
||||
let mut capture_start = Instant::now();
|
||||
let mut msg_counter: u64 = 0;
|
||||
|
||||
// Hard cap on capture buffer: max_capture_secs * sample_rate * channels * 2 (safety margin).
|
||||
let max_buf_samples =
|
||||
max_capture_secs as usize * sample_rate as usize * channels_count as usize * 2;
|
||||
|
||||
info!(wake_word = %wake_word, "VoiceWake: entering listen loop");
|
||||
|
||||
while let Some(chunk) = audio_rx.recv().await {
|
||||
@@ -209,9 +160,7 @@ impl Channel for VoiceWakeChannel {
|
||||
}
|
||||
}
|
||||
WakeState::Triggered => {
|
||||
if capture_buf.len() + chunk.len() <= max_buf_samples {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
}
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
@@ -253,9 +202,7 @@ impl Channel for VoiceWakeChannel {
|
||||
}
|
||||
}
|
||||
WakeState::Capturing => {
|
||||
if capture_buf.len() + chunk.len() <= max_buf_samples {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
}
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
@@ -307,11 +254,13 @@ impl Channel for VoiceWakeChannel {
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
WakeState::Processing => {
|
||||
// Should not receive chunks while processing, but just buffer them.
|
||||
// State transitions happen above synchronously after transcription.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the stream holder thread to exit and release the microphone.
|
||||
drop(shutdown_tx);
|
||||
bail!("VoiceWake: audio stream ended unexpectedly");
|
||||
}
|
||||
}
|
||||
@@ -334,14 +283,8 @@ pub fn encode_wav_from_f32(samples: &[f32], sample_rate: u32, channels: u16) ->
|
||||
let bits_per_sample: u16 = 16;
|
||||
let byte_rate = u32::from(channels) * sample_rate * u32::from(bits_per_sample) / 8;
|
||||
let block_align = channels * bits_per_sample / 8;
|
||||
// Guard against u32 overflow — reject buffers that exceed WAV's 4 GB limit.
|
||||
let data_bytes = samples.len() * 2;
|
||||
assert!(
|
||||
u32::try_from(data_bytes).is_ok(),
|
||||
"audio buffer too large for WAV encoding ({data_bytes} bytes)"
|
||||
);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let data_len = data_bytes as u32;
|
||||
let data_len = (samples.len() * 2) as u32; // 16-bit = 2 bytes per sample; max ~25 MB
|
||||
let file_len = 36 + data_len;
|
||||
|
||||
let mut buf = Vec::with_capacity(file_len as usize + 8);
|
||||
@@ -389,6 +332,7 @@ mod tests {
|
||||
assert_eq!(WakeState::Listening.to_string(), "Listening");
|
||||
assert_eq!(WakeState::Triggered.to_string(), "Triggered");
|
||||
assert_eq!(WakeState::Capturing.to_string(), "Capturing");
|
||||
assert_eq!(WakeState::Processing.to_string(), "Processing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -557,6 +501,7 @@ mod tests {
|
||||
WakeState::Listening,
|
||||
WakeState::Triggered,
|
||||
WakeState::Capturing,
|
||||
WakeState::Processing,
|
||||
];
|
||||
for (i, a) in states.iter().enumerate() {
|
||||
for (j, b) in states.iter().enumerate() {
|
||||
|
||||
+15
-14
@@ -10,21 +10,22 @@ pub use schema::{
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
|
||||
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl, CronScheduleDecl,
|
||||
DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation,
|
||||
GoogleWorkspaceConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig,
|
||||
LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig,
|
||||
Microsoft365Config, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig, OpenAiSttConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PacingConfig,
|
||||
PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkEnricherConfig, LinkedInConfig, LinkedInContentConfig,
|
||||
LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig,
|
||||
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
|
||||
ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig,
|
||||
OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PiperTtsConfig,
|
||||
PluginsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
|
||||
+484
-104
@@ -265,6 +265,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub web_fetch: WebFetchConfig,
|
||||
|
||||
/// Link enricher configuration (`[link_enricher]`).
|
||||
#[serde(default)]
|
||||
pub link_enricher: LinkEnricherConfig,
|
||||
|
||||
/// Text browser tool configuration (`[text_browser]`).
|
||||
#[serde(default)]
|
||||
pub text_browser: TextBrowserConfig,
|
||||
@@ -1005,6 +1009,10 @@ fn default_edge_tts_binary_path() -> String {
|
||||
"edge-tts".into()
|
||||
}
|
||||
|
||||
fn default_piper_tts_api_url() -> String {
|
||||
"http://127.0.0.1:5000/v1/audio/speech".into()
|
||||
}
|
||||
|
||||
/// Text-to-Speech configuration (`[tts]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TtsConfig {
|
||||
@@ -1035,6 +1043,9 @@ pub struct TtsConfig {
|
||||
/// Edge TTS provider configuration (`[tts.edge]`).
|
||||
#[serde(default)]
|
||||
pub edge: Option<EdgeTtsConfig>,
|
||||
/// Piper TTS provider configuration (`[tts.piper]`).
|
||||
#[serde(default)]
|
||||
pub piper: Option<PiperTtsConfig>,
|
||||
}
|
||||
|
||||
impl Default for TtsConfig {
|
||||
@@ -1049,6 +1060,7 @@ impl Default for TtsConfig {
|
||||
elevenlabs: None,
|
||||
google: None,
|
||||
edge: None,
|
||||
piper: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1103,6 +1115,14 @@ pub struct EdgeTtsConfig {
|
||||
pub binary_path: String,
|
||||
}
|
||||
|
||||
/// Piper TTS provider configuration (local GPU-accelerated, OpenAI-compatible endpoint).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PiperTtsConfig {
|
||||
/// Base URL for the Piper TTS HTTP server (e.g. `"http://127.0.0.1:5000/v1/audio/speech"`).
|
||||
#[serde(default = "default_piper_tts_api_url")]
|
||||
pub api_url: String,
|
||||
}
|
||||
|
||||
/// Determines when a `ToolFilterGroup` is active.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -1258,6 +1278,10 @@ pub struct AgentConfig {
|
||||
/// Useful for small-context models (e.g. glm-4.5-air ~8K tokens → set to 8000).
|
||||
#[serde(default = "default_max_system_prompt_chars")]
|
||||
pub max_system_prompt_chars: usize,
|
||||
/// Thinking/reasoning level control. Configures how deeply the model reasons
|
||||
/// per message. Users can override per-message with `/think:<level>` directives.
|
||||
#[serde(default)]
|
||||
pub thinking: crate::agent::thinking::ThinkingConfig,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@@ -1292,6 +1316,7 @@ impl Default for AgentConfig {
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
tool_filter_groups: Vec::new(),
|
||||
max_system_prompt_chars: default_max_system_prompt_chars(),
|
||||
thinking: crate::agent::thinking::ThinkingConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1413,6 +1438,15 @@ pub struct MultimodalConfig {
|
||||
/// Allow fetching remote image URLs (http/https). Disabled by default.
|
||||
#[serde(default)]
|
||||
pub allow_remote_fetch: bool,
|
||||
/// Provider name to use for vision/image messages (e.g. `"ollama"`).
|
||||
/// When set, messages containing `[IMAGE:]` markers are routed to this
|
||||
/// provider instead of the default text provider.
|
||||
#[serde(default)]
|
||||
pub vision_provider: Option<String>,
|
||||
/// Model to use when routing to the vision provider (e.g. `"llava:7b"`).
|
||||
/// Only used when `vision_provider` is set.
|
||||
#[serde(default)]
|
||||
pub vision_model: Option<String>,
|
||||
}
|
||||
|
||||
fn default_multimodal_max_images() -> usize {
|
||||
@@ -1438,6 +1472,8 @@ impl Default for MultimodalConfig {
|
||||
max_images: default_multimodal_max_images(),
|
||||
max_image_size_mb: default_multimodal_max_image_size_mb(),
|
||||
allow_remote_fetch: false,
|
||||
vision_provider: None,
|
||||
vision_model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2120,8 +2156,8 @@ fn default_browser_webdriver_url() -> String {
|
||||
impl Default for BrowserConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_domains: Vec::new(),
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
session_name: None,
|
||||
backend: default_browser_backend(),
|
||||
native_headless: default_true(),
|
||||
@@ -2136,7 +2172,9 @@ impl Default for BrowserConfig {
|
||||
|
||||
/// HTTP request tool configuration (`[http_request]` section).
|
||||
///
|
||||
/// Deny-by-default: if `allowed_domains` is empty, all HTTP requests are rejected.
|
||||
/// Domain filtering: `allowed_domains` controls which hosts are reachable (use `["*"]`
|
||||
/// for all public hosts, which is the default). If `allowed_domains` is empty, all
|
||||
/// requests are rejected.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct HttpRequestConfig {
|
||||
/// Enable `http_request` tool for API interactions
|
||||
@@ -2160,8 +2198,8 @@ pub struct HttpRequestConfig {
|
||||
impl Default for HttpRequestConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_domains: vec![],
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
max_response_size: default_http_max_response_size(),
|
||||
timeout_secs: default_http_timeout_secs(),
|
||||
allow_private_hosts: false,
|
||||
@@ -2219,7 +2257,7 @@ fn default_web_fetch_allowed_domains() -> Vec<String> {
|
||||
impl Default for WebFetchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
blocked_domains: vec![],
|
||||
max_response_size: default_web_fetch_max_response_size(),
|
||||
@@ -2228,6 +2266,45 @@ impl Default for WebFetchConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Link enricher ─────────────────────────────────────────────────
|
||||
|
||||
/// Automatic link understanding for inbound channel messages (`[link_enricher]`).
|
||||
///
|
||||
/// When enabled, URLs in incoming messages are automatically fetched and
|
||||
/// summarised. The summary is prepended to the message before the agent
|
||||
/// processes it, giving the LLM context about linked pages without an
|
||||
/// explicit tool call.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct LinkEnricherConfig {
|
||||
/// Enable the link enricher pipeline stage (default: false)
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Maximum number of links to fetch per message (default: 3)
|
||||
#[serde(default = "default_link_enricher_max_links")]
|
||||
pub max_links: usize,
|
||||
/// Per-link fetch timeout in seconds (default: 10)
|
||||
#[serde(default = "default_link_enricher_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_link_enricher_max_links() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn default_link_enricher_timeout_secs() -> u64 {
|
||||
10
|
||||
}
|
||||
|
||||
impl Default for LinkEnricherConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
max_links: default_link_enricher_max_links(),
|
||||
timeout_secs: default_link_enricher_timeout_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Text browser ─────────────────────────────────────────────────
|
||||
|
||||
/// Text browser tool configuration (`[text_browser]` section).
|
||||
@@ -2269,12 +2346,15 @@ pub struct WebSearchConfig {
|
||||
/// Enable `web_search_tool` for web searches
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Search provider: "duckduckgo" (free, no API key) or "brave" (requires API key)
|
||||
/// Search provider: "duckduckgo" (free), "brave" (requires API key), or "searxng" (self-hosted)
|
||||
#[serde(default = "default_web_search_provider")]
|
||||
pub provider: String,
|
||||
/// Brave Search API key (required if provider is "brave")
|
||||
#[serde(default)]
|
||||
pub brave_api_key: Option<String>,
|
||||
/// SearXNG instance URL (required if provider is "searxng"), e.g. "https://searx.example.com"
|
||||
#[serde(default)]
|
||||
pub searxng_instance_url: Option<String>,
|
||||
/// Maximum results per search (1-10)
|
||||
#[serde(default = "default_web_search_max_results")]
|
||||
pub max_results: usize,
|
||||
@@ -2298,9 +2378,10 @@ fn default_web_search_timeout_secs() -> u64 {
|
||||
impl Default for WebSearchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
provider: default_web_search_provider(),
|
||||
brave_api_key: None,
|
||||
searxng_instance_url: None,
|
||||
max_results: default_web_search_max_results(),
|
||||
timeout_secs: default_web_search_timeout_secs(),
|
||||
}
|
||||
@@ -3809,77 +3890,6 @@ impl Default for QdrantConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the mem0 (OpenMemory) memory backend.
|
||||
///
|
||||
/// Connects to a self-hosted OpenMemory server via its REST API.
|
||||
/// Deploy OpenMemory with `docker compose up` from the mem0 repo,
|
||||
/// then point `url` at the API (default `http://localhost:8765`).
|
||||
///
|
||||
/// ```toml
|
||||
/// [memory]
|
||||
/// backend = "mem0"
|
||||
///
|
||||
/// [memory.mem0]
|
||||
/// url = "http://localhost:8765"
|
||||
/// user_id = "zeroclaw"
|
||||
/// app_name = "zeroclaw"
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct Mem0Config {
|
||||
/// OpenMemory server URL (e.g. `http://localhost:8765`).
|
||||
/// Falls back to `MEM0_URL` env var if not set.
|
||||
#[serde(default = "default_mem0_url")]
|
||||
pub url: String,
|
||||
/// User ID for scoping memories within mem0.
|
||||
/// Falls back to `MEM0_USER_ID` env var, or default `"zeroclaw"`.
|
||||
#[serde(default = "default_mem0_user_id")]
|
||||
pub user_id: String,
|
||||
/// Application name registered in mem0.
|
||||
/// Falls back to `MEM0_APP_NAME` env var, or default `"zeroclaw"`.
|
||||
#[serde(default = "default_mem0_app_name")]
|
||||
pub app_name: String,
|
||||
/// Whether mem0 should use its built-in LLM to extract facts from
|
||||
/// stored text (`infer = true`) or store raw text as-is (`false`).
|
||||
#[serde(default = "default_mem0_infer")]
|
||||
pub infer: bool,
|
||||
/// Custom prompt for guiding LLM-based fact extraction when `infer = true`.
|
||||
/// Useful for non-English content (e.g. Cantonese/Chinese).
|
||||
/// Falls back to `MEM0_EXTRACTION_PROMPT` env var.
|
||||
/// If unset, the mem0 server uses its built-in default prompt.
|
||||
#[serde(default = "default_mem0_extraction_prompt")]
|
||||
pub extraction_prompt: Option<String>,
|
||||
}
|
||||
|
||||
fn default_mem0_url() -> String {
|
||||
std::env::var("MEM0_URL").unwrap_or_else(|_| "http://localhost:8765".into())
|
||||
}
|
||||
fn default_mem0_user_id() -> String {
|
||||
std::env::var("MEM0_USER_ID").unwrap_or_else(|_| "zeroclaw".into())
|
||||
}
|
||||
fn default_mem0_app_name() -> String {
|
||||
std::env::var("MEM0_APP_NAME").unwrap_or_else(|_| "zeroclaw".into())
|
||||
}
|
||||
fn default_mem0_infer() -> bool {
|
||||
true
|
||||
}
|
||||
fn default_mem0_extraction_prompt() -> Option<String> {
|
||||
std::env::var("MEM0_EXTRACTION_PROMPT")
|
||||
.ok()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
}
|
||||
|
||||
impl Default for Mem0Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
url: default_mem0_url(),
|
||||
user_id: default_mem0_user_id(),
|
||||
app_name: default_mem0_app_name(),
|
||||
infer: default_mem0_infer(),
|
||||
extraction_prompt: default_mem0_extraction_prompt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
pub struct MemoryConfig {
|
||||
@@ -3954,6 +3964,43 @@ pub struct MemoryConfig {
|
||||
#[serde(default = "default_true")]
|
||||
pub auto_hydrate: bool,
|
||||
|
||||
// ── Retrieval Pipeline ─────────────────────────────────────
|
||||
/// Retrieval stages to execute in order. Valid: "cache", "fts", "vector".
|
||||
#[serde(default = "default_retrieval_stages")]
|
||||
pub retrieval_stages: Vec<String>,
|
||||
/// Enable LLM reranking when candidate count exceeds threshold.
|
||||
#[serde(default)]
|
||||
pub rerank_enabled: bool,
|
||||
/// Minimum candidate count to trigger reranking.
|
||||
#[serde(default = "default_rerank_threshold")]
|
||||
pub rerank_threshold: usize,
|
||||
/// FTS score above which to early-return without vector search (0.0–1.0).
|
||||
#[serde(default = "default_fts_early_return_score")]
|
||||
pub fts_early_return_score: f64,
|
||||
|
||||
// ── Namespace Isolation ─────────────────────────────────────
|
||||
/// Default namespace for memory entries.
|
||||
#[serde(default = "default_namespace")]
|
||||
pub default_namespace: String,
|
||||
|
||||
// ── Conflict Resolution ─────────────────────────────────────
|
||||
/// Cosine similarity threshold for conflict detection (0.0–1.0).
|
||||
#[serde(default = "default_conflict_threshold")]
|
||||
pub conflict_threshold: f64,
|
||||
|
||||
// ── Audit Trail ─────────────────────────────────────────────
|
||||
/// Enable audit logging of memory operations.
|
||||
#[serde(default)]
|
||||
pub audit_enabled: bool,
|
||||
/// Retention period for audit entries in days (default: 30).
|
||||
#[serde(default = "default_audit_retention_days")]
|
||||
pub audit_retention_days: u32,
|
||||
|
||||
// ── Policy Engine ───────────────────────────────────────────
|
||||
/// Memory policy configuration.
|
||||
#[serde(default)]
|
||||
pub policy: MemoryPolicyConfig,
|
||||
|
||||
// ── SQLite backend options ─────────────────────────────────
|
||||
/// For sqlite backend: max seconds to wait when opening the DB (e.g. file locked).
|
||||
/// None = wait indefinitely (default). Recommended max: 300.
|
||||
@@ -3965,13 +4012,42 @@ pub struct MemoryConfig {
|
||||
/// Only used when `backend = "qdrant"`.
|
||||
#[serde(default)]
|
||||
pub qdrant: QdrantConfig,
|
||||
}
|
||||
|
||||
// ── Mem0 backend options ─────────────────────────────────
|
||||
/// Configuration for mem0 (OpenMemory) backend.
|
||||
/// Only used when `backend = "mem0"`.
|
||||
/// Requires `--features memory-mem0` at build time.
|
||||
/// Memory policy configuration (`[memory.policy]` section).
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct MemoryPolicyConfig {
|
||||
/// Maximum entries per namespace (0 = unlimited).
|
||||
#[serde(default)]
|
||||
pub mem0: Mem0Config,
|
||||
pub max_entries_per_namespace: usize,
|
||||
/// Maximum entries per category (0 = unlimited).
|
||||
#[serde(default)]
|
||||
pub max_entries_per_category: usize,
|
||||
/// Retention days by category (overrides global). Keys: "core", "daily", "conversation".
|
||||
#[serde(default)]
|
||||
pub retention_days_by_category: std::collections::HashMap<String, u32>,
|
||||
/// Namespaces that are read-only (writes are rejected).
|
||||
#[serde(default)]
|
||||
pub read_only_namespaces: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_retrieval_stages() -> Vec<String> {
|
||||
vec!["cache".into(), "fts".into(), "vector".into()]
|
||||
}
|
||||
fn default_rerank_threshold() -> usize {
|
||||
5
|
||||
}
|
||||
fn default_fts_early_return_score() -> f64 {
|
||||
0.85
|
||||
}
|
||||
fn default_namespace() -> String {
|
||||
"default".into()
|
||||
}
|
||||
fn default_conflict_threshold() -> f64 {
|
||||
0.85
|
||||
}
|
||||
fn default_audit_retention_days() -> u32 {
|
||||
30
|
||||
}
|
||||
|
||||
fn default_embedding_provider() -> String {
|
||||
@@ -4045,9 +4121,17 @@ impl Default for MemoryConfig {
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
retrieval_stages: default_retrieval_stages(),
|
||||
rerank_enabled: false,
|
||||
rerank_threshold: default_rerank_threshold(),
|
||||
fts_early_return_score: default_fts_early_return_score(),
|
||||
default_namespace: default_namespace(),
|
||||
conflict_threshold: default_conflict_threshold(),
|
||||
audit_enabled: false,
|
||||
audit_retention_days: default_audit_retention_days(),
|
||||
policy: MemoryPolicyConfig::default(),
|
||||
sqlite_open_timeout_secs: None,
|
||||
qdrant: QdrantConfig::default(),
|
||||
mem0: Mem0Config::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4256,6 +4340,7 @@ fn default_auto_approve() -> Vec<String> {
|
||||
"glob_search".into(),
|
||||
"content_search".into(),
|
||||
"image_info".into(),
|
||||
"weather".into(),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -4263,6 +4348,19 @@ fn default_always_ask() -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
impl AutonomyConfig {
|
||||
/// Merge the built-in default `auto_approve` entries into the current
|
||||
/// list, preserving any user-supplied additions.
|
||||
pub fn ensure_default_auto_approve(&mut self) {
|
||||
let defaults = default_auto_approve();
|
||||
for entry in defaults {
|
||||
if !self.auto_approve.iter().any(|existing| existing == &entry) {
|
||||
self.auto_approve.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_valid_env_var_name(name: &str) -> bool {
|
||||
let mut chars = name.chars();
|
||||
match chars.next() {
|
||||
@@ -4642,6 +4740,7 @@ pub struct ClassificationRule {
|
||||
|
||||
/// Heartbeat configuration for periodic health pings (`[heartbeat]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
pub struct HeartbeatConfig {
|
||||
/// Enable periodic heartbeat pings. Default: `false`.
|
||||
pub enabled: bool,
|
||||
@@ -4688,6 +4787,14 @@ pub struct HeartbeatConfig {
|
||||
/// Maximum number of heartbeat run history records to retain. Default: `100`.
|
||||
#[serde(default = "default_heartbeat_max_run_history")]
|
||||
pub max_run_history: u32,
|
||||
/// Load the channel session history before each heartbeat task execution so
|
||||
/// the LLM has conversational context. Default: `false`.
|
||||
///
|
||||
/// When `true`, the session file for the configured `target`/`to` is passed
|
||||
/// to the agent as `session_state_file`, giving it access to the recent
|
||||
/// conversation history — just as if the user had sent a message.
|
||||
#[serde(default)]
|
||||
pub load_session_context: bool,
|
||||
}
|
||||
|
||||
fn default_heartbeat_interval() -> u32 {
|
||||
@@ -4726,6 +4833,7 @@ impl Default for HeartbeatConfig {
|
||||
deadman_channel: None,
|
||||
deadman_to: None,
|
||||
max_run_history: default_heartbeat_max_run_history(),
|
||||
load_session_context: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4750,6 +4858,92 @@ pub struct CronConfig {
|
||||
/// Maximum number of historical cron run records to retain. Default: `50`.
|
||||
#[serde(default = "default_max_run_history")]
|
||||
pub max_run_history: u32,
|
||||
/// Declarative cron job definitions (`[[cron.jobs]]`).
|
||||
///
|
||||
/// Jobs declared here are synced into the database at scheduler startup.
|
||||
/// They use `source = "declarative"` to distinguish them from jobs
|
||||
/// created imperatively via CLI or API. Declarative config takes
|
||||
/// precedence on each sync: if the config changes, the DB is updated
|
||||
/// to match. Imperative jobs are never deleted by the sync process.
|
||||
#[serde(default)]
|
||||
pub jobs: Vec<CronJobDecl>,
|
||||
}
|
||||
|
||||
/// A declarative cron job definition for the `[[cron.jobs]]` config array.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CronJobDecl {
|
||||
/// Stable identifier used for merge semantics across syncs.
|
||||
pub id: String,
|
||||
/// Human-readable name.
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
/// Job type: `"shell"` (default) or `"agent"`.
|
||||
#[serde(default = "default_job_type_decl")]
|
||||
pub job_type: String,
|
||||
/// Schedule for the job.
|
||||
pub schedule: CronScheduleDecl,
|
||||
/// Shell command to run (required when `job_type = "shell"`).
|
||||
#[serde(default)]
|
||||
pub command: Option<String>,
|
||||
/// Agent prompt (required when `job_type = "agent"`).
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
/// Whether the job is enabled. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
/// Model override for agent jobs.
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Allowlist of tool names for agent jobs.
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Session target: `"isolated"` (default) or `"main"`.
|
||||
#[serde(default)]
|
||||
pub session_target: Option<String>,
|
||||
/// Delivery configuration.
|
||||
#[serde(default)]
|
||||
pub delivery: Option<DeliveryConfigDecl>,
|
||||
}
|
||||
|
||||
/// Schedule variant for declarative cron jobs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "kind", rename_all = "lowercase")]
|
||||
pub enum CronScheduleDecl {
|
||||
/// Classic cron expression.
|
||||
Cron {
|
||||
expr: String,
|
||||
#[serde(default)]
|
||||
tz: Option<String>,
|
||||
},
|
||||
/// Interval in milliseconds.
|
||||
Every { every_ms: u64 },
|
||||
/// One-shot at an RFC 3339 timestamp.
|
||||
At { at: String },
|
||||
}
|
||||
|
||||
/// Delivery configuration for declarative cron jobs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DeliveryConfigDecl {
|
||||
/// Delivery mode: `"none"` or `"announce"`.
|
||||
#[serde(default = "default_delivery_mode")]
|
||||
pub mode: String,
|
||||
/// Channel name (e.g. `"telegram"`, `"discord"`).
|
||||
#[serde(default)]
|
||||
pub channel: Option<String>,
|
||||
/// Target/recipient identifier.
|
||||
#[serde(default)]
|
||||
pub to: Option<String>,
|
||||
/// Best-effort delivery. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub best_effort: bool,
|
||||
}
|
||||
|
||||
fn default_job_type_decl() -> String {
|
||||
"shell".to_string()
|
||||
}
|
||||
|
||||
fn default_delivery_mode() -> String {
|
||||
"none".to_string()
|
||||
}
|
||||
|
||||
fn default_max_run_history() -> u32 {
|
||||
@@ -4762,6 +4956,7 @@ impl Default for CronConfig {
|
||||
enabled: true,
|
||||
catch_up_on_startup: true,
|
||||
max_run_history: default_max_run_history(),
|
||||
jobs: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5443,6 +5638,10 @@ pub struct MatrixConfig {
|
||||
pub room_id: String,
|
||||
/// Allowed Matrix user IDs. Empty = deny all.
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Allowed Matrix room IDs or aliases. Empty = allow all rooms.
|
||||
/// Supports canonical room IDs (`!abc:server`) and aliases (`#room:server`).
|
||||
#[serde(default)]
|
||||
pub allowed_rooms: Vec<String>,
|
||||
/// Whether to interrupt an in-flight agent response when a new message arrives.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
@@ -6963,6 +7162,7 @@ impl Default for Config {
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -7516,22 +7716,50 @@ impl Config {
|
||||
.await
|
||||
.context("Failed to read config file")?;
|
||||
|
||||
// Track ignored/unknown config keys to warn users about silent misconfigurations
|
||||
// (e.g., using [providers.ollama] which doesn't exist instead of top-level api_url)
|
||||
// Deserialize the config with the standard TOML parser.
|
||||
//
|
||||
// Previously this used `serde_ignored::deserialize` for both
|
||||
// deserialization and unknown-key detection. However,
|
||||
// `serde_ignored` silently drops field values inside nested
|
||||
// structs that carry `#[serde(default)]` (e.g. the entire
|
||||
// `[autonomy]` table), causing user-supplied values to be
|
||||
// replaced by defaults. See #4171.
|
||||
//
|
||||
// We now deserialize with `toml::from_str` (which is correct)
|
||||
// and run `serde_ignored` separately just for diagnostics.
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).context("Failed to deserialize config file")?;
|
||||
|
||||
// Ensure the built-in default auto_approve entries are always
|
||||
// present. When a user specifies `auto_approve` in their TOML
|
||||
// (e.g. to add a custom tool), serde replaces the default list
|
||||
// instead of merging. This caused default-safe tools like
|
||||
// `weather` or `calculator` to lose their auto-approve status
|
||||
// and get silently denied in non-interactive channel runs.
|
||||
// See #4247.
|
||||
//
|
||||
// Users who want to require approval for a default tool can
|
||||
// add it to `always_ask`, which takes precedence over
|
||||
// `auto_approve` in the approval decision (see approval/mod.rs).
|
||||
config.autonomy.ensure_default_auto_approve();
|
||||
|
||||
// Detect unknown/ignored config keys for diagnostic warnings.
|
||||
// This second pass uses serde_ignored but discards the parsed
|
||||
// result — only the ignored-path list is kept.
|
||||
let mut ignored_paths: Vec<String> = Vec::new();
|
||||
let mut config: Config = serde_ignored::deserialize(
|
||||
toml::de::Deserializer::parse(&contents).context("Failed to parse config file")?,
|
||||
let _: Result<Config, _> = serde_ignored::deserialize(
|
||||
toml::de::Deserializer::parse(&contents)
|
||||
.unwrap_or_else(|_| unreachable!("already parsed above")),
|
||||
|path| {
|
||||
ignored_paths.push(path.to_string());
|
||||
},
|
||||
)
|
||||
.context("Failed to deserialize config file")?;
|
||||
);
|
||||
|
||||
// Warn about each unknown config key.
|
||||
// serde_ignored + #[serde(default)] on nested structs can produce
|
||||
// false positives: parent-level fields get re-reported under the
|
||||
// nested key (e.g. "memory.mem0.auto_hydrate" even though
|
||||
// auto_hydrate belongs to MemoryConfig, not Mem0Config). We
|
||||
// nested key (e.g. "memory.qdrant.auto_hydrate" even though
|
||||
// auto_hydrate belongs to MemoryConfig, not QdrantConfig). We
|
||||
// suppress these by checking whether the leaf key is a known field
|
||||
// on the parent struct.
|
||||
let known_memory_fields: &[&str] = &[
|
||||
@@ -7560,7 +7788,7 @@ impl Config {
|
||||
];
|
||||
for path in ignored_paths {
|
||||
// Skip false positives from nested memory sub-sections
|
||||
if path.starts_with("memory.mem0.") || path.starts_with("memory.qdrant.") {
|
||||
if path.starts_with("memory.qdrant.") {
|
||||
let leaf = path.rsplit('.').next().unwrap_or("");
|
||||
if known_memory_fields.contains(&leaf) {
|
||||
continue;
|
||||
@@ -8854,6 +9082,16 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// SearXNG instance URL: ZEROCLAW_SEARXNG_INSTANCE_URL or SEARXNG_INSTANCE_URL
|
||||
if let Ok(instance_url) = std::env::var("ZEROCLAW_SEARXNG_INSTANCE_URL")
|
||||
.or_else(|_| std::env::var("SEARXNG_INSTANCE_URL"))
|
||||
{
|
||||
let instance_url = instance_url.trim();
|
||||
if !instance_url.is_empty() {
|
||||
self.web_search.searxng_instance_url = Some(instance_url.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Web search max results: ZEROCLAW_WEB_SEARCH_MAX_RESULTS or WEB_SEARCH_MAX_RESULTS
|
||||
if let Ok(max_results) = std::env::var("ZEROCLAW_WEB_SEARCH_MAX_RESULTS")
|
||||
.or_else(|_| std::env::var("WEB_SEARCH_MAX_RESULTS"))
|
||||
@@ -9531,7 +9769,9 @@ mod tests {
|
||||
merged.push(']');
|
||||
}
|
||||
merged.push('\n');
|
||||
toml::from_str(&merged).unwrap()
|
||||
let mut config: Config = toml::from_str(&merged).unwrap();
|
||||
config.autonomy.ensure_default_auto_approve();
|
||||
config
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -9539,8 +9779,8 @@ mod tests {
|
||||
let cfg = HttpRequestConfig::default();
|
||||
assert_eq!(cfg.timeout_secs, 30);
|
||||
assert_eq!(cfg.max_response_size, 1_000_000);
|
||||
assert!(!cfg.enabled);
|
||||
assert!(cfg.allowed_domains.is_empty());
|
||||
assert!(cfg.enabled);
|
||||
assert_eq!(cfg.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -9729,6 +9969,7 @@ recipient = "42"
|
||||
enabled: false,
|
||||
catch_up_on_startup: false,
|
||||
max_run_history: 100,
|
||||
jobs: Vec::new(),
|
||||
};
|
||||
let json = serde_json::to_string(&c).unwrap();
|
||||
let parsed: CronConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9903,6 +10144,7 @@ default_temperature = 0.7
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -9986,6 +10228,140 @@ default_temperature = 0.7
|
||||
assert_eq!(parsed.provider_timeout_secs, 120);
|
||||
}
|
||||
|
||||
/// Regression test for #4171: the `[autonomy]` section must not be
|
||||
/// silently dropped when parsing config TOML.
|
||||
#[test]
|
||||
async fn autonomy_section_is_not_silently_ignored() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
level = "full"
|
||||
max_actions_per_hour = 99
|
||||
auto_approve = ["file_read", "memory_recall", "http_request"]
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
assert_eq!(
|
||||
parsed.autonomy.level,
|
||||
AutonomyLevel::Full,
|
||||
"autonomy.level must be parsed from config (was silently defaulting to Supervised)"
|
||||
);
|
||||
assert_eq!(
|
||||
parsed.autonomy.max_actions_per_hour, 99,
|
||||
"autonomy.max_actions_per_hour must be parsed from config"
|
||||
);
|
||||
assert!(
|
||||
parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.contains(&"http_request".to_string()),
|
||||
"autonomy.auto_approve must include http_request from config"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test for #4247: when a user provides a custom auto_approve
|
||||
/// list, the built-in defaults must still be present.
|
||||
#[test]
|
||||
async fn auto_approve_merges_user_entries_with_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = ["my_custom_tool", "another_tool"]
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
// User entries are preserved
|
||||
assert!(
|
||||
parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.contains(&"my_custom_tool".to_string()),
|
||||
"user-supplied tool must remain in auto_approve"
|
||||
);
|
||||
assert!(
|
||||
parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.contains(&"another_tool".to_string()),
|
||||
"user-supplied tool must remain in auto_approve"
|
||||
);
|
||||
// Defaults are merged in
|
||||
for default_tool in &[
|
||||
"file_read",
|
||||
"memory_recall",
|
||||
"weather",
|
||||
"calculator",
|
||||
"web_fetch",
|
||||
] {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(&default_tool.to_string()),
|
||||
"default tool '{default_tool}' must be present in auto_approve even when user provides custom list"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Regression test: empty auto_approve still gets defaults merged.
|
||||
#[test]
|
||||
async fn auto_approve_empty_list_gets_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = []
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let defaults = default_auto_approve();
|
||||
for tool in &defaults {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(tool),
|
||||
"default tool '{tool}' must be present even when user sets auto_approve = []"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// When no autonomy section is provided, defaults are applied normally.
|
||||
#[test]
|
||||
async fn auto_approve_defaults_when_no_autonomy_section() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let defaults = default_auto_approve();
|
||||
for tool in &defaults {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(tool),
|
||||
"default tool '{tool}' must be present when no [autonomy] section"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Duplicates are not introduced when ensure_default_auto_approve runs
|
||||
/// on a list that already contains the defaults.
|
||||
#[test]
|
||||
async fn auto_approve_no_duplicates() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = ["weather", "file_read"]
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let weather_count = parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.iter()
|
||||
.filter(|t| *t == "weather")
|
||||
.count();
|
||||
assert_eq!(weather_count, 1, "weather must not be duplicated");
|
||||
let file_read_count = parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.iter()
|
||||
.filter(|t| *t == "file_read")
|
||||
.count();
|
||||
assert_eq!(file_read_count, 1, "file_read must not be duplicated");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_timeout_secs_parses_from_toml() {
|
||||
let raw = r#"
|
||||
@@ -10285,6 +10661,7 @@ default_temperature = 0.7
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -10597,6 +10974,7 @@ default_temperature = 0.7
|
||||
device_id: Some("DEVICE123".into()),
|
||||
room_id: "!room123:matrix.org".into(),
|
||||
allowed_users: vec!["@user:matrix.org".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
};
|
||||
let json = serde_json::to_string(&mc).unwrap();
|
||||
@@ -10618,6 +10996,7 @@ default_temperature = 0.7
|
||||
device_id: None,
|
||||
room_id: "!abc:synapse.local".into(),
|
||||
allowed_users: vec!["@admin:synapse.local".into(), "*".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
};
|
||||
let toml_str = toml::to_string(&mc).unwrap();
|
||||
@@ -10711,6 +11090,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
device_id: None,
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec!["@u:m".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
}),
|
||||
signal: None,
|
||||
@@ -11306,15 +11686,15 @@ default_temperature = 0.7
|
||||
assert!(!c.composio.enabled);
|
||||
assert!(c.composio.api_key.is_none());
|
||||
assert!(c.secrets.encrypt);
|
||||
assert!(!c.browser.enabled);
|
||||
assert!(c.browser.allowed_domains.is_empty());
|
||||
assert!(c.browser.enabled);
|
||||
assert_eq!(c.browser.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn browser_config_default_disabled() {
|
||||
async fn browser_config_default_enabled() {
|
||||
let b = BrowserConfig::default();
|
||||
assert!(!b.enabled);
|
||||
assert!(b.allowed_domains.is_empty());
|
||||
assert!(b.enabled);
|
||||
assert_eq!(b.allowed_domains, vec!["*".to_string()]);
|
||||
assert_eq!(b.backend, "agent_browser");
|
||||
assert!(b.native_headless);
|
||||
assert_eq!(b.native_webdriver_url, "http://127.0.0.1:9515");
|
||||
@@ -11379,8 +11759,8 @@ config_path = "/tmp/config.toml"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed = parse_test_config(minimal);
|
||||
assert!(!parsed.browser.enabled);
|
||||
assert!(parsed.browser.allowed_domains.is_empty());
|
||||
assert!(parsed.browser.enabled);
|
||||
assert_eq!(parsed.browser.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
// ── Environment variable overrides (Docker support) ─────────
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ pub use schedule::{
|
||||
#[allow(unused_imports)]
|
||||
pub use store::{
|
||||
add_agent_job, all_overdue_jobs, due_jobs, get_job, list_jobs, list_runs, record_last_run,
|
||||
record_run, remove_job, reschedule_after_run, update_job,
|
||||
record_run, remove_job, reschedule_after_run, sync_declarative_jobs, update_job,
|
||||
};
|
||||
pub use types::{
|
||||
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
|
||||
|
||||
+48
-2
@@ -1,5 +1,7 @@
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
use crate::channels::MatrixChannel;
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
use crate::channels::WhatsAppWebChannel;
|
||||
use crate::channels::{
|
||||
Channel, DiscordChannel, MattermostChannel, QQChannel, SendMessage, SignalChannel,
|
||||
SlackChannel, TelegramChannel,
|
||||
@@ -7,8 +9,8 @@ use crate::channels::{
|
||||
use crate::config::Config;
|
||||
use crate::cron::{
|
||||
all_overdue_jobs, due_jobs, next_run_for_schedule, record_last_run, record_run, remove_job,
|
||||
reschedule_after_run, update_job, CronJob, CronJobPatch, DeliveryConfig, JobType, Schedule,
|
||||
SessionTarget,
|
||||
reschedule_after_run, sync_declarative_jobs, update_job, CronJob, CronJobPatch, DeliveryConfig,
|
||||
JobType, Schedule, SessionTarget,
|
||||
};
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Result;
|
||||
@@ -34,6 +36,19 @@ pub async fn run(config: Config) -> Result<()> {
|
||||
|
||||
crate::health::mark_component_ok(SCHEDULER_COMPONENT);
|
||||
|
||||
// ── Declarative job sync: reconcile config-defined jobs with the DB.
|
||||
match sync_declarative_jobs(&config, &config.cron.jobs) {
|
||||
Ok(()) => {
|
||||
if !config.cron.jobs.is_empty() {
|
||||
tracing::info!(
|
||||
count = config.cron.jobs.len(),
|
||||
"Synced declarative cron jobs from config"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to sync declarative cron jobs: {e}"),
|
||||
}
|
||||
|
||||
// ── Startup catch-up: run ALL overdue jobs before entering the
|
||||
// normal polling loop. The regular loop is capped by `max_tasks`,
|
||||
// which could leave some overdue jobs waiting across many cycles
|
||||
@@ -483,6 +498,36 @@ pub(crate) async fn deliver_announcement(
|
||||
anyhow::bail!("matrix delivery channel requires `channel-matrix` feature");
|
||||
}
|
||||
}
|
||||
"whatsapp" | "whatsapp-web" | "whatsapp_web" => {
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
{
|
||||
let wa = config
|
||||
.channels_config
|
||||
.whatsapp
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("whatsapp channel not configured"))?;
|
||||
if !wa.is_web_config() {
|
||||
anyhow::bail!(
|
||||
"whatsapp cron delivery requires Web mode (session_path must be set)"
|
||||
);
|
||||
}
|
||||
let channel = WhatsAppWebChannel::new(
|
||||
wa.session_path.clone().unwrap_or_default(),
|
||||
wa.pair_phone.clone(),
|
||||
wa.pair_code.clone(),
|
||||
wa.allowed_numbers.clone(),
|
||||
wa.mode.clone(),
|
||||
wa.dm_policy.clone(),
|
||||
wa.group_policy.clone(),
|
||||
wa.self_chat_mode,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
}
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
anyhow::bail!("whatsapp delivery channel requires `whatsapp-web` feature");
|
||||
}
|
||||
}
|
||||
"qq" => {
|
||||
let qq = config
|
||||
.channels_config
|
||||
@@ -657,6 +702,7 @@ mod tests {
|
||||
delivery: DeliveryConfig::default(),
|
||||
delete_after_run: false,
|
||||
allowed_tools: None,
|
||||
source: "imperative".into(),
|
||||
created_at: Utc::now(),
|
||||
next_run: Utc::now(),
|
||||
last_run: None,
|
||||
|
||||
+521
-4
@@ -124,7 +124,7 @@ pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs ORDER BY next_run ASC",
|
||||
)?;
|
||||
|
||||
@@ -143,7 +143,7 @@ pub fn get_job(config: &Config, job_id: &str) -> Result<CronJob> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs WHERE id = ?1",
|
||||
)?;
|
||||
|
||||
@@ -177,7 +177,7 @@ pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC
|
||||
@@ -206,7 +206,8 @@ pub fn all_overdue_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJ
|
||||
with_connection(config, |conn| {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output, allowed_tools
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC",
|
||||
@@ -488,6 +489,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
let last_run_raw: Option<String> = row.get(14)?;
|
||||
let created_at_raw: String = row.get(12)?;
|
||||
let allowed_tools_raw: Option<String> = row.get(17)?;
|
||||
let source: Option<String> = row.get(18)?;
|
||||
|
||||
Ok(CronJob {
|
||||
id: row.get(0)?,
|
||||
@@ -502,6 +504,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
enabled: row.get::<_, i64>(9)? != 0,
|
||||
delivery,
|
||||
delete_after_run: row.get::<_, i64>(11)? != 0,
|
||||
source: source.unwrap_or_else(|| "imperative".to_string()),
|
||||
created_at: parse_rfc3339(&created_at_raw).map_err(sql_conversion_error)?,
|
||||
next_run: parse_rfc3339(&next_run_raw).map_err(sql_conversion_error)?,
|
||||
last_run: match last_run_raw {
|
||||
@@ -564,6 +567,277 @@ fn decode_allowed_tools(raw: Option<&str>) -> Result<Option<Vec<String>>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Synchronize declarative cron job definitions from config into the database.
|
||||
///
|
||||
/// For each declarative job (identified by `id`):
|
||||
/// - If the job exists in DB: update it to match the config definition.
|
||||
/// - If the job does not exist: insert it.
|
||||
///
|
||||
/// Jobs created imperatively (via CLI/API) are never modified or deleted.
|
||||
/// Declarative jobs that are no longer present in config are removed.
|
||||
pub fn sync_declarative_jobs(
|
||||
config: &Config,
|
||||
decls: &[crate::config::schema::CronJobDecl],
|
||||
) -> Result<()> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
|
||||
if decls.is_empty() {
|
||||
// If no declarative jobs are defined, clean up any previously
|
||||
// synced declarative jobs that are no longer in config.
|
||||
with_connection(config, |conn| {
|
||||
let deleted = conn
|
||||
.execute("DELETE FROM cron_jobs WHERE source = 'declarative'", [])
|
||||
.context("Failed to remove stale declarative cron jobs")?;
|
||||
if deleted > 0 {
|
||||
tracing::info!(
|
||||
count = deleted,
|
||||
"Removed declarative cron jobs no longer in config"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate declarations before touching the DB.
|
||||
for decl in decls {
|
||||
validate_decl(decl)?;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
with_connection(config, |conn| {
|
||||
// Collect IDs of all declarative jobs currently defined in config.
|
||||
let config_ids: std::collections::HashSet<&str> =
|
||||
decls.iter().map(|d| d.id.as_str()).collect();
|
||||
|
||||
// Remove declarative jobs no longer in config.
|
||||
{
|
||||
let mut stmt = conn.prepare("SELECT id FROM cron_jobs WHERE source = 'declarative'")?;
|
||||
let db_ids: Vec<String> = stmt
|
||||
.query_map([], |row| row.get(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
for db_id in &db_ids {
|
||||
if !config_ids.contains(db_id.as_str()) {
|
||||
conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![db_id])
|
||||
.with_context(|| {
|
||||
format!("Failed to remove stale declarative cron job '{db_id}'")
|
||||
})?;
|
||||
tracing::info!(
|
||||
job_id = %db_id,
|
||||
"Removed declarative cron job no longer in config"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for decl in decls {
|
||||
let schedule = convert_schedule_decl(&decl.schedule)?;
|
||||
let expression = schedule_cron_expression(&schedule).unwrap_or_default();
|
||||
let schedule_json = serde_json::to_string(&schedule)?;
|
||||
let job_type = &decl.job_type;
|
||||
let session_target = decl.session_target.as_deref().unwrap_or("isolated");
|
||||
let delivery = match &decl.delivery {
|
||||
Some(d) => convert_delivery_decl(d),
|
||||
None => DeliveryConfig::default(),
|
||||
};
|
||||
let delivery_json = serde_json::to_string(&delivery)?;
|
||||
let allowed_tools_json = encode_allowed_tools(decl.allowed_tools.as_ref())?;
|
||||
let command = decl.command.as_deref().unwrap_or("");
|
||||
let delete_after_run = matches!(decl.schedule, CronScheduleDecl::At { .. });
|
||||
|
||||
// Check if job already exists.
|
||||
let exists: bool = conn
|
||||
.prepare("SELECT COUNT(*) FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get::<_, i64>(0))
|
||||
.map(|c| c > 0)
|
||||
.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
// Update existing declarative job — preserve runtime state
|
||||
// (next_run, last_run, last_status, last_output, created_at).
|
||||
// Only update the schedule's next_run if the schedule itself changed.
|
||||
let current_schedule_raw: Option<String> = conn
|
||||
.prepare("SELECT schedule FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get(0))
|
||||
.ok();
|
||||
|
||||
let schedule_changed = current_schedule_raw.as_deref() != Some(&schedule_json);
|
||||
|
||||
if schedule_changed {
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative', next_run = ?13
|
||||
WHERE id = ?14",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
next_run.to_rfc3339(),
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
} else {
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative'
|
||||
WHERE id = ?13",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
}
|
||||
|
||||
tracing::debug!(job_id = %decl.id, "Updated declarative cron job");
|
||||
} else {
|
||||
// Insert new declarative job.
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"INSERT INTO cron_jobs (
|
||||
id, expression, command, schedule, job_type, prompt, name,
|
||||
session_target, model, enabled, delivery, delete_after_run,
|
||||
allowed_tools, source, created_at, next_run
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, 'declarative', ?14, ?15)",
|
||||
params![
|
||||
decl.id,
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
now.to_rfc3339(),
|
||||
next_run.to_rfc3339(),
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to insert declarative cron job '{}'",
|
||||
decl.id
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(job_id = %decl.id, "Inserted declarative cron job from config");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate a declarative cron job definition.
|
||||
fn validate_decl(decl: &crate::config::schema::CronJobDecl) -> Result<()> {
|
||||
if decl.id.trim().is_empty() {
|
||||
anyhow::bail!("Declarative cron job has empty id");
|
||||
}
|
||||
|
||||
match decl.job_type.to_lowercase().as_str() {
|
||||
"shell" => {
|
||||
if decl
|
||||
.command
|
||||
.as_deref()
|
||||
.map_or(true, |c| c.trim().is_empty())
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': shell job requires a non-empty 'command'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
"agent" => {
|
||||
if decl.prompt.as_deref().map_or(true, |p| p.trim().is_empty()) {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': agent job requires a non-empty 'prompt'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': invalid job_type '{}', expected 'shell' or 'agent'",
|
||||
decl.id,
|
||||
other
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert a `CronScheduleDecl` to the runtime `Schedule` type.
|
||||
fn convert_schedule_decl(decl: &crate::config::schema::CronScheduleDecl) -> Result<Schedule> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
match decl {
|
||||
CronScheduleDecl::Cron { expr, tz } => Ok(Schedule::Cron {
|
||||
expr: expr.clone(),
|
||||
tz: tz.clone(),
|
||||
}),
|
||||
CronScheduleDecl::Every { every_ms } => Ok(Schedule::Every {
|
||||
every_ms: *every_ms,
|
||||
}),
|
||||
CronScheduleDecl::At { at } => {
|
||||
let parsed = DateTime::parse_from_rfc3339(at)
|
||||
.with_context(|| {
|
||||
format!("Invalid RFC3339 timestamp in declarative cron 'at': {at}")
|
||||
})?
|
||||
.with_timezone(&Utc);
|
||||
Ok(Schedule::At { at: parsed })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a `DeliveryConfigDecl` to the runtime `DeliveryConfig`.
|
||||
fn convert_delivery_decl(decl: &crate::config::schema::DeliveryConfigDecl) -> DeliveryConfig {
|
||||
DeliveryConfig {
|
||||
mode: decl.mode.clone(),
|
||||
channel: decl.channel.clone(),
|
||||
to: decl.to.clone(),
|
||||
best_effort: decl.best_effort,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_column_if_missing(conn: &Connection, name: &str, sql_type: &str) -> Result<()> {
|
||||
let mut stmt = conn.prepare("PRAGMA table_info(cron_jobs)")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
@@ -654,6 +928,7 @@ fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>)
|
||||
add_column_if_missing(&conn, "delivery", "TEXT")?;
|
||||
add_column_if_missing(&conn, "delete_after_run", "INTEGER NOT NULL DEFAULT 0")?;
|
||||
add_column_if_missing(&conn, "allowed_tools", "TEXT")?;
|
||||
add_column_if_missing(&conn, "source", "TEXT DEFAULT 'imperative'")?;
|
||||
|
||||
f(&conn)
|
||||
}
|
||||
@@ -1170,4 +1445,246 @@ mod tests {
|
||||
assert!(last_output.ends_with(TRUNCATED_OUTPUT_MARKER));
|
||||
assert!(last_output.len() <= MAX_CRON_OUTPUT_BYTES);
|
||||
}
|
||||
|
||||
// ── Declarative cron job sync tests ──────────────────────────
|
||||
|
||||
fn make_shell_decl(id: &str, expr: &str, cmd: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: Some(cmd.to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_agent_decl(id: &str, expr: &str, prompt: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "agent".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: None,
|
||||
prompt: Some(prompt.to_string()),
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_inserts_new_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("daily-backup", "0 2 * * *", "echo backup")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "daily-backup").unwrap();
|
||||
assert_eq!(job.command, "echo backup");
|
||||
assert_eq!(job.source, "declarative");
|
||||
assert_eq!(job.name.as_deref(), Some("decl-daily-backup"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_updates_existing_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("updatable", "0 2 * * *", "echo v1")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job_v1 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v1.command, "echo v1");
|
||||
|
||||
let decls_v2 = vec![make_shell_decl("updatable", "0 3 * * *", "echo v2")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
let job_v2 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v2.command, "echo v2");
|
||||
assert_eq!(job_v2.expression, "0 3 * * *");
|
||||
assert_eq!(job_v2.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_does_not_delete_imperative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Create an imperative job via the normal API.
|
||||
let imperative = add_job(&config, "*/10 * * * *", "echo imperative").unwrap();
|
||||
|
||||
// Sync declarative jobs (none of which match the imperative job).
|
||||
let decls = vec![make_shell_decl("my-decl", "0 2 * * *", "echo decl")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Imperative job should still exist.
|
||||
let still_there = get_job(&config, &imperative.id).unwrap();
|
||||
assert_eq!(still_there.command, "echo imperative");
|
||||
assert_eq!(still_there.source, "imperative");
|
||||
|
||||
// Declarative job should also exist.
|
||||
let decl_job = get_job(&config, "my-decl").unwrap();
|
||||
assert_eq!(decl_job.command, "echo decl");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_removes_stale_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Insert two declarative jobs.
|
||||
let decls = vec![
|
||||
make_shell_decl("keeper", "0 2 * * *", "echo keep"),
|
||||
make_shell_decl("stale", "0 3 * * *", "echo stale"),
|
||||
];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Now sync with only "keeper" — "stale" should be removed.
|
||||
let decls_v2 = vec![make_shell_decl("keeper", "0 2 * * *", "echo keep")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
assert!(get_job(&config, "stale").is_err());
|
||||
assert!(get_job(&config, "keeper").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_empty_removes_all_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("to-remove", "0 2 * * *", "echo bye")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_ok());
|
||||
|
||||
// Sync with empty list.
|
||||
sync_declarative_jobs(&config, &[]).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_shell_job_requires_command() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_shell_decl("bad", "0 2 * * *", "echo ok");
|
||||
decl.command = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("command"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_agent_job_requires_prompt() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_agent_decl("bad-agent", "0 2 * * *", "do stuff");
|
||||
decl.prompt = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_agent_job_inserts_correctly() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_agent_decl(
|
||||
"agent-check",
|
||||
"*/15 * * * *",
|
||||
"check health",
|
||||
)];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "agent-check").unwrap();
|
||||
assert_eq!(job.job_type, JobType::Agent);
|
||||
assert_eq!(job.prompt.as_deref(), Some("check health"));
|
||||
assert_eq!(job.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_every_schedule_works() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decl = crate::config::schema::CronJobDecl {
|
||||
id: "interval-job".to_string(),
|
||||
name: None,
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Every { every_ms: 60000 },
|
||||
command: Some("echo interval".to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
};
|
||||
|
||||
sync_declarative_jobs(&config, &[decl]).unwrap();
|
||||
|
||||
let job = get_job(&config, "interval-job").unwrap();
|
||||
assert!(matches!(job.schedule, Schedule::Every { every_ms: 60000 }));
|
||||
assert_eq!(job.command, "echo interval");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn declarative_config_parses_from_toml() {
|
||||
let toml_str = r#"
|
||||
enabled = true
|
||||
|
||||
[[jobs]]
|
||||
id = "daily-report"
|
||||
name = "Daily Report"
|
||||
job_type = "shell"
|
||||
command = "echo report"
|
||||
schedule = { kind = "cron", expr = "0 9 * * *" }
|
||||
|
||||
[[jobs]]
|
||||
id = "health-check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health"
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
"#;
|
||||
|
||||
let parsed: crate::config::schema::CronConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.enabled);
|
||||
assert_eq!(parsed.jobs.len(), 2);
|
||||
|
||||
assert_eq!(parsed.jobs[0].id, "daily-report");
|
||||
assert_eq!(parsed.jobs[0].command.as_deref(), Some("echo report"));
|
||||
assert!(matches!(
|
||||
parsed.jobs[0].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Cron { ref expr, .. } if expr == "0 9 * * *"
|
||||
));
|
||||
|
||||
assert_eq!(parsed.jobs[1].id, "health-check");
|
||||
assert_eq!(parsed.jobs[1].job_type, "agent");
|
||||
assert_eq!(
|
||||
parsed.jobs[1].prompt.as_deref(),
|
||||
Some("Check server health")
|
||||
);
|
||||
assert!(matches!(
|
||||
parsed.jobs[1].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Every { every_ms: 300_000 }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +127,10 @@ fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_source() -> String {
|
||||
"imperative".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronJob {
|
||||
pub id: String,
|
||||
@@ -146,6 +150,9 @@ pub struct CronJob {
|
||||
/// When `None`, all tools are available (backward compatible default).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// How the job was created: `"imperative"` (CLI/API) or `"declarative"` (config).
|
||||
#[serde(default = "default_source")]
|
||||
pub source: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub next_run: DateTime<Utc>,
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
|
||||
+193
-1
@@ -362,10 +362,22 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
// Re-read session context on every tick so we pick up messages
|
||||
// that arrived since the daemon started.
|
||||
let session_context = if config.heartbeat.load_session_context {
|
||||
load_heartbeat_session_context(&config)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut tick_had_error = false;
|
||||
for task in &tasks_to_run {
|
||||
let task_start = std::time::Instant::now();
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let task_prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let prompt = match &session_context {
|
||||
Some(ctx) => format!("{ctx}\n\n{task_prompt}"),
|
||||
None => task_prompt,
|
||||
};
|
||||
let temp = config.default_temperature;
|
||||
match Box::pin(crate::agent::run(
|
||||
config.clone(),
|
||||
@@ -497,6 +509,186 @@ fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load recent conversation history for the heartbeat's delivery target and
|
||||
/// format it as a text preamble to inject into the task prompt.
|
||||
///
|
||||
/// Scans `{workspace}/sessions/` for JSONL files whose name starts with
|
||||
/// `{channel}_` and ends with `_{to}.jsonl` (or exactly `{channel}_{to}.jsonl`),
|
||||
/// then picks the most recently modified match. This handles session key
|
||||
/// formats such as `telegram_diskiller.jsonl` and
|
||||
/// `telegram_5673725398_diskiller.jsonl`.
|
||||
/// Returns `None` when `target`/`to` are not configured or no session exists.
|
||||
const HEARTBEAT_SESSION_CONTEXT_MESSAGES: usize = 20;
|
||||
|
||||
fn load_heartbeat_session_context(config: &Config) -> Option<String> {
|
||||
use crate::providers::traits::ChatMessage;
|
||||
|
||||
let channel = config
|
||||
.heartbeat
|
||||
.target
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())?;
|
||||
let to = config
|
||||
.heartbeat
|
||||
.to
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())?;
|
||||
|
||||
if channel.contains('/') || channel.contains('\\') || to.contains('/') || to.contains('\\') {
|
||||
tracing::warn!("heartbeat session context: channel/to contains path separators, skipping");
|
||||
return None;
|
||||
}
|
||||
|
||||
let sessions_dir = config.workspace_dir.join("sessions");
|
||||
|
||||
// Find the most recently modified JSONL file that belongs to this target.
|
||||
// Matches both `{channel}_{to}.jsonl` and `{channel}_{anything}_{to}.jsonl`.
|
||||
let prefix = format!("{channel}_");
|
||||
let suffix = format!("_{to}.jsonl");
|
||||
let exact = format!("{channel}_{to}.jsonl");
|
||||
let mid_prefix = format!("{channel}_{to}_");
|
||||
|
||||
let path = std::fs::read_dir(&sessions_dir)
|
||||
.ok()?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
let name = e.file_name();
|
||||
let name = name.to_string_lossy();
|
||||
name.ends_with(".jsonl")
|
||||
&& (name == exact
|
||||
|| (name.starts_with(&prefix) && name.ends_with(&suffix))
|
||||
|| name.starts_with(&mid_prefix))
|
||||
})
|
||||
.max_by_key(|e| {
|
||||
e.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
|
||||
})
|
||||
.map(|e| e.path())?;
|
||||
|
||||
if !path.exists() {
|
||||
tracing::debug!("💓 Heartbeat session context: no session file found for {channel}/{to}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let messages = load_jsonl_messages(&path);
|
||||
if messages.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let recent: Vec<&ChatMessage> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" || m.role == "assistant")
|
||||
.rev()
|
||||
.take(HEARTBEAT_SESSION_CONTEXT_MESSAGES)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
// Only inject context if there is at least one real user message in the
|
||||
// window. If the JSONL contains only assistant messages (e.g. previous
|
||||
// heartbeat outputs with no reply yet), skip context to avoid feeding
|
||||
// Monika's own messages back to her in a loop.
|
||||
let has_user_message = recent.iter().any(|m| m.role == "user");
|
||||
if !has_user_message {
|
||||
tracing::debug!(
|
||||
"💓 Heartbeat session context: no user messages in recent history — skipping"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
// Use the session file's mtime as a proxy for when the last message arrived.
|
||||
let last_message_age = std::fs::metadata(&path)
|
||||
.ok()
|
||||
.and_then(|m| m.modified().ok())
|
||||
.and_then(|mtime| mtime.elapsed().ok());
|
||||
|
||||
let silence_note = match last_message_age {
|
||||
Some(age) => {
|
||||
let mins = age.as_secs() / 60;
|
||||
if mins < 60 {
|
||||
format!("(last message ~{mins} minutes ago)\n")
|
||||
} else {
|
||||
let hours = mins / 60;
|
||||
let rem = mins % 60;
|
||||
if rem == 0 {
|
||||
format!("(last message ~{hours}h ago)\n")
|
||||
} else {
|
||||
format!("(last message ~{hours}h {rem}m ago)\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
"💓 Heartbeat session context: {} messages from {}, silence: {}",
|
||||
recent.len(),
|
||||
path.display(),
|
||||
silence_note.trim(),
|
||||
);
|
||||
|
||||
let mut ctx = format!(
|
||||
"[Recent conversation history — use this for context when composing your message] {silence_note}",
|
||||
);
|
||||
for msg in &recent {
|
||||
let label = if msg.role == "user" { "User" } else { "You" };
|
||||
// Truncate very long messages to avoid bloating the prompt.
|
||||
// Use char_indices to avoid panicking on multi-byte UTF-8 characters.
|
||||
let content = if msg.content.len() > 500 {
|
||||
let truncate_at = msg
|
||||
.content
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 500)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
format!("{}…", &msg.content[..truncate_at])
|
||||
} else {
|
||||
msg.content.clone()
|
||||
};
|
||||
ctx.push_str(label);
|
||||
ctx.push_str(": ");
|
||||
ctx.push_str(&content);
|
||||
ctx.push('\n');
|
||||
}
|
||||
|
||||
Some(ctx)
|
||||
}
|
||||
|
||||
/// Read the last `HEARTBEAT_SESSION_CONTEXT_MESSAGES` `ChatMessage` lines from
|
||||
/// a JSONL session file using a bounded rolling window so we never hold the
|
||||
/// entire file in memory.
|
||||
fn load_jsonl_messages(path: &std::path::Path) -> Vec<crate::providers::traits::ChatMessage> {
|
||||
use std::collections::VecDeque;
|
||||
use std::io::BufRead;
|
||||
|
||||
let file = match std::fs::File::open(path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut window: VecDeque<crate::providers::traits::ChatMessage> =
|
||||
VecDeque::with_capacity(HEARTBEAT_SESSION_CONTEXT_MESSAGES + 1);
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(msg) = serde_json::from_str::<crate::providers::traits::ChatMessage>(trimmed) {
|
||||
window.push_back(msg);
|
||||
if window.len() > HEARTBEAT_SESSION_CONTEXT_MESSAGES {
|
||||
window.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
window.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Auto-detect the best channel for heartbeat delivery by checking which
|
||||
/// channels are configured. Returns the first match in priority order.
|
||||
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {
|
||||
|
||||
+56
-2
@@ -1280,12 +1280,16 @@ pub async fn handle_api_sessions_list(
|
||||
.into_iter()
|
||||
.filter_map(|meta| {
|
||||
let session_id = meta.key.strip_prefix("gw_")?;
|
||||
Some(serde_json::json!({
|
||||
let mut entry = serde_json::json!({
|
||||
"session_id": session_id,
|
||||
"created_at": meta.created_at.to_rfc3339(),
|
||||
"last_activity": meta.last_activity.to_rfc3339(),
|
||||
"message_count": meta.message_count,
|
||||
}))
|
||||
});
|
||||
if let Some(name) = meta.name {
|
||||
entry["name"] = serde_json::Value::String(name);
|
||||
}
|
||||
Some(entry)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1326,6 +1330,56 @@ pub async fn handle_api_session_delete(
|
||||
}
|
||||
}
|
||||
|
||||
/// PUT /api/sessions/{id} — rename a gateway session
|
||||
pub async fn handle_api_session_rename(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let Some(ref backend) = state.session_backend else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session persistence is disabled"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let name = body["name"].as_str().unwrap_or("").trim();
|
||||
if name.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "name is required"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let session_key = format!("gw_{id}");
|
||||
|
||||
// Verify the session exists before renaming
|
||||
let sessions = backend.list_sessions();
|
||||
if !sessions.contains(&session_key) {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match backend.set_session_name(&session_key, name) {
|
||||
Ok(()) => Json(serde_json::json!({"session_id": id, "name": name})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": format!("Failed to rename session: {e}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
+1
-1
@@ -886,7 +886,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/api/cli-tools", get(api::handle_api_cli_tools))
|
||||
.route("/api/health", get(api::handle_api_health))
|
||||
.route("/api/sessions", get(api::handle_api_sessions_list))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete).put(api::handle_api_session_rename))
|
||||
// ── Pairing + Device management API ──
|
||||
.route("/api/pairing/initiate", post(api_pairing::initiate_pairing))
|
||||
.route("/api/pair", post(api_pairing::submit_pairing_enhanced))
|
||||
|
||||
+34
-3
@@ -1,13 +1,21 @@
|
||||
//! WebSocket agent chat handler.
|
||||
//!
|
||||
//! Connect: `ws://host:port/ws/chat?session_id=ID&name=My+Session`
|
||||
//!
|
||||
//! Protocol:
|
||||
//! ```text
|
||||
//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
|
||||
//! Client -> Server: {"type":"message","content":"Hello"}
|
||||
//! Server -> Client: {"type":"chunk","content":"Hi! "}
|
||||
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
|
||||
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
|
||||
//! Server -> Client: {"type":"done","full_response":"..."}
|
||||
//! ```
|
||||
//!
|
||||
//! Query params:
|
||||
//! - `session_id` — resume or create a session (default: new UUID)
|
||||
//! - `name` — optional human-readable label for the session
|
||||
//! - `token` — bearer auth token (alternative to Authorization header)
|
||||
|
||||
use super::AppState;
|
||||
use axum::{
|
||||
@@ -53,6 +61,8 @@ const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
|
||||
pub struct WsQuery {
|
||||
pub token: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
/// Optional human-readable name for the session.
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a bearer token from WebSocket-compatible sources.
|
||||
@@ -134,14 +144,20 @@ pub async fn handle_ws_chat(
|
||||
};
|
||||
|
||||
let session_id = params.session_id;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
|
||||
let session_name = params.name;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Gateway session key prefix to avoid collisions with channel sessions.
|
||||
const GW_SESSION_PREFIX: &str = "gw_";
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<String>) {
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: AppState,
|
||||
session_id: Option<String>,
|
||||
session_name: Option<String>,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Resolve session ID: use provided or generate a new UUID
|
||||
@@ -163,6 +179,7 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
// Hydrate agent from persisted session (if available)
|
||||
let mut resumed = false;
|
||||
let mut message_count: usize = 0;
|
||||
let mut effective_name: Option<String> = None;
|
||||
if let Some(ref backend) = state.session_backend {
|
||||
let messages = backend.load(&session_key);
|
||||
if !messages.is_empty() {
|
||||
@@ -170,15 +187,29 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
agent.seed_history(&messages);
|
||||
resumed = true;
|
||||
}
|
||||
// Set session name if provided (non-empty) on connect
|
||||
if let Some(ref name) = session_name {
|
||||
if !name.is_empty() {
|
||||
let _ = backend.set_session_name(&session_key, name);
|
||||
effective_name = Some(name.clone());
|
||||
}
|
||||
}
|
||||
// If no name was provided via query param, load the stored name
|
||||
if effective_name.is_none() {
|
||||
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Send session_start message to client
|
||||
let session_start = serde_json::json!({
|
||||
let mut session_start = serde_json::json!({
|
||||
"type": "session_start",
|
||||
"session_id": session_id,
|
||||
"resumed": resumed,
|
||||
"message_count": message_count,
|
||||
});
|
||||
if let Some(ref name) = effective_name {
|
||||
session_start["name"] = serde_json::Value::String(name.clone());
|
||||
}
|
||||
let _ = sender
|
||||
.send(Message::Text(session_start.to_string().into()))
|
||||
.await;
|
||||
|
||||
@@ -891,6 +891,7 @@ mod tests {
|
||||
device_id: None,
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec![],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
let entries = all_integrations();
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
//! Audit trail for memory operations.
|
||||
//!
|
||||
//! Provides a decorator `AuditedMemory<M>` that wraps any `Memory` backend
|
||||
//! and logs all operations to a `memory_audit` table. Opt-in via
|
||||
//! `[memory] audit_enabled = true`.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Audit log entry operations.
///
/// The `Display` impl yields the exact string stored in the
/// `memory_audit.operation` column.
#[derive(Debug, Clone, Copy)]
pub enum AuditOp {
    Store,
    Recall,
    Get,
    List,
    Forget,
    StoreProcedural,
}

impl std::fmt::Display for AuditOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its snake_case column value, then emit once.
        let label = match self {
            Self::Store => "store",
            Self::Recall => "recall",
            Self::Get => "get",
            Self::List => "list",
            Self::Forget => "forget",
            Self::StoreProcedural => "store_procedural",
        };
        f.write_str(label)
    }
}
|
||||
|
||||
/// Decorator that wraps a `Memory` backend with audit logging.
///
/// Every trait call is forwarded to `inner` after a row is appended to the
/// `memory_audit` SQLite table guarded by `audit_conn`.
pub struct AuditedMemory<M: Memory> {
    // The wrapped backend that performs the real memory operations.
    inner: M,
    // Audit database connection, shared behind a (parking_lot) mutex so the
    // decorator can log from `&self` methods.
    audit_conn: Arc<Mutex<Connection>>,
    // Path of the audit database file; retained for diagnostics, not read yet.
    #[allow(dead_code)]
    db_path: PathBuf,
}
|
||||
|
||||
impl<M: Memory> AuditedMemory<M> {
|
||||
pub fn new(inner: M, workspace_dir: &Path) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join("audit.db");
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(&db_path)?;
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
CREATE TABLE IF NOT EXISTS memory_audit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
operation TEXT NOT NULL,
|
||||
key TEXT,
|
||||
namespace TEXT,
|
||||
session_id TEXT,
|
||||
timestamp TEXT NOT NULL,
|
||||
metadata TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON memory_audit(timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_operation ON memory_audit(operation);",
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
audit_conn: Arc::new(Mutex::new(conn)),
|
||||
db_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn log_audit(
|
||||
&self,
|
||||
op: AuditOp,
|
||||
key: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
session_id: Option<&str>,
|
||||
metadata: Option<&str>,
|
||||
) {
|
||||
let conn = self.audit_conn.lock();
|
||||
let now = Local::now().to_rfc3339();
|
||||
let op_str = op.to_string();
|
||||
let _ = conn.execute(
|
||||
"INSERT INTO memory_audit (operation, key, namespace, session_id, timestamp, metadata)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![op_str, key, namespace, session_id, now, metadata],
|
||||
);
|
||||
}
|
||||
|
||||
/// Prune audit entries older than the given number of days.
|
||||
pub fn prune_older_than(&self, retention_days: u32) -> anyhow::Result<u64> {
|
||||
let conn = self.audit_conn.lock();
|
||||
let cutoff =
|
||||
(Local::now() - chrono::Duration::days(i64::from(retention_days))).to_rfc3339();
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memory_audit WHERE timestamp < ?1",
|
||||
params![cutoff],
|
||||
)?;
|
||||
Ok(u64::try_from(affected).unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Count total audit entries.
|
||||
pub fn audit_count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self.audit_conn.lock();
|
||||
let count: i64 =
|
||||
conn.query_row("SELECT COUNT(*) FROM memory_audit", [], |row| row.get(0))?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(count as usize)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl<M: Memory> Memory for AuditedMemory<M> {
    // Every audited method logs *before* delegating, so a row is recorded even
    // when the inner backend subsequently returns an error.

    /// Transparent: reports the wrapped backend's name.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Log a `store` row (key + session), then delegate.
    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.log_audit(AuditOp::Store, Some(key), None, session_id, None);
        self.inner.store(key, content, category, session_id).await
    }

    /// Log a `recall` row with the query text as metadata, then delegate.
    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(
            AuditOp::Recall,
            None,
            None,
            session_id,
            Some(&format!("query={query}")),
        );
        self.inner
            .recall(query, limit, session_id, since, until)
            .await
    }

    /// Log a `get` row for the key, then delegate.
    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
        self.log_audit(AuditOp::Get, Some(key), None, None, None);
        self.inner.get(key).await
    }

    /// Log a `list` row, then delegate.
    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(AuditOp::List, None, None, session_id, None);
        self.inner.list(category, session_id).await
    }

    /// Log a `forget` row for the key, then delegate.
    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
        self.log_audit(AuditOp::Forget, Some(key), None, None, None);
        self.inner.forget(key).await
    }

    /// Not audited: pure pass-through to the wrapped backend.
    async fn count(&self) -> anyhow::Result<usize> {
        self.inner.count().await
    }

    /// Not audited: pure pass-through to the wrapped backend.
    async fn health_check(&self) -> bool {
        self.inner.health_check().await
    }

    /// Log a `store_procedural` row with the batch size as metadata, then
    /// delegate.
    async fn store_procedural(
        &self,
        messages: &[ProceduralMessage],
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.log_audit(
            AuditOp::StoreProcedural,
            None,
            None,
            session_id,
            Some(&format!("messages={}", messages.len())),
        );
        self.inner.store_procedural(messages, session_id).await
    }

    /// Namespaced recall is logged as a plain `recall` with the namespace
    /// recorded in the dedicated column.
    async fn recall_namespaced(
        &self,
        namespace: &str,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(
            AuditOp::Recall,
            None,
            Some(namespace),
            session_id,
            Some(&format!("query={query}")),
        );
        self.inner
            .recall_namespaced(namespace, query, limit, session_id, since, until)
            .await
    }

    /// Metadata-rich store is logged as a plain `store` with the namespace
    /// recorded; the importance value itself is not audited.
    async fn store_with_metadata(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
        namespace: Option<&str>,
        importance: Option<f64>,
    ) -> anyhow::Result<()> {
        self.log_audit(AuditOp::Store, Some(key), namespace, session_id, None);
        self.inner
            .store_with_metadata(key, content, category, session_id, namespace, importance)
            .await
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::NoneMemory;
    use tempfile::TempDir;

    // A single store call must produce exactly one audit row.
    #[tokio::test]
    async fn audited_memory_logs_store_operation() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        audited
            .store("test_key", "test_value", MemoryCategory::Core, None)
            .await
            .unwrap();

        assert_eq!(audited.audit_count().unwrap(), 1);
    }

    // Recall is audited regardless of what the backend returns (the result is
    // deliberately ignored here).
    #[tokio::test]
    async fn audited_memory_logs_recall_operation() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        let _ = audited.recall("query", 10, None, None, None).await;

        assert_eq!(audited.audit_count().unwrap(), 1);
    }

    // Pruning must run without error; the exact count removed depends on
    // timestamp granularity, so only successful execution is asserted.
    #[tokio::test]
    async fn audited_memory_prune_works() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        audited
            .store("k1", "v1", MemoryCategory::Core, None)
            .await
            .unwrap();

        // Pruning with 0 days should remove entries
        let pruned = audited.prune_older_than(0).unwrap();
        // Entry was just created, so 0-day retention should remove it
        // Pruning should succeed (pruned is usize, always >= 0)
        let _ = pruned;
    }

    // Non-audited accessors pass straight through to the wrapped NoneMemory.
    #[tokio::test]
    async fn audited_memory_delegates_correctly() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        assert_eq!(audited.name(), "none");
        assert!(audited.health_check().await);
        assert_eq!(audited.count().await.unwrap(), 0);
    }
}
|
||||
@@ -4,7 +4,6 @@ pub enum MemoryBackendKind {
|
||||
Lucid,
|
||||
Postgres,
|
||||
Qdrant,
|
||||
Mem0,
|
||||
Markdown,
|
||||
None,
|
||||
Unknown,
|
||||
@@ -66,15 +65,6 @@ const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const MEM0_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "mem0",
|
||||
label: "Mem0 (OpenMemory) — semantic memory with LLM fact extraction via [memory.mem0]",
|
||||
auto_save_default: true,
|
||||
uses_sqlite_hygiene: false,
|
||||
sqlite_based: false,
|
||||
optional_dependency: true,
|
||||
};
|
||||
|
||||
const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "none",
|
||||
label: "None — disable persistent memory",
|
||||
@@ -114,7 +104,6 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
|
||||
"lucid" => MemoryBackendKind::Lucid,
|
||||
"postgres" => MemoryBackendKind::Postgres,
|
||||
"qdrant" => MemoryBackendKind::Qdrant,
|
||||
"mem0" | "openmemory" => MemoryBackendKind::Mem0,
|
||||
"markdown" => MemoryBackendKind::Markdown,
|
||||
"none" => MemoryBackendKind::None,
|
||||
_ => MemoryBackendKind::Unknown,
|
||||
@@ -127,7 +116,6 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
|
||||
MemoryBackendKind::Lucid => LUCID_PROFILE,
|
||||
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
|
||||
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
|
||||
MemoryBackendKind::Mem0 => MEM0_PROFILE,
|
||||
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
|
||||
MemoryBackendKind::None => NONE_PROFILE,
|
||||
MemoryBackendKind::Unknown => CUSTOM_PROFILE,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,173 @@
|
||||
//! Conflict resolution for memory entries.
|
||||
//!
|
||||
//! Before storing Core memories, performs a semantic similarity check against
|
||||
//! existing entries. If cosine similarity exceeds a threshold but content
|
||||
//! differs, the old entry is marked as superseded.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
|
||||
/// Check for conflicting memories and mark old ones as superseded.
|
||||
///
|
||||
/// Returns the list of entry IDs that were superseded.
|
||||
pub async fn check_and_resolve_conflicts(
|
||||
memory: &dyn Memory,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: &MemoryCategory,
|
||||
threshold: f64,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
// Only check conflicts for Core memories
|
||||
if !matches!(category, MemoryCategory::Core) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Search for similar existing entries
|
||||
let candidates = memory.recall(content, 10, None, None, None).await?;
|
||||
|
||||
let mut superseded = Vec::new();
|
||||
for candidate in &candidates {
|
||||
if candidate.key == key {
|
||||
continue; // Same key = update, not conflict
|
||||
}
|
||||
if !matches!(candidate.category, MemoryCategory::Core) {
|
||||
continue;
|
||||
}
|
||||
if let Some(score) = candidate.score {
|
||||
if score > threshold && candidate.content != content {
|
||||
superseded.push(candidate.id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(superseded)
|
||||
}
|
||||
|
||||
/// Mark entries as superseded in SQLite by setting their `superseded_by` column.
|
||||
pub fn mark_superseded(
|
||||
conn: &rusqlite::Connection,
|
||||
superseded_ids: &[String],
|
||||
new_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if superseded_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for id in superseded_ids {
|
||||
conn.execute(
|
||||
"UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
|
||||
rusqlite::params![new_id, id],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Simple text-based conflict detection without embeddings.
///
/// Uses token overlap (Jaccard similarity) as a fast approximation
/// when vector embeddings are unavailable.
///
/// Returns a value in `[0.0, 1.0]`. Two empty strings are defined as
/// identical (1.0); an empty vs non-empty pair is disjoint (0.0).
pub fn jaccard_similarity(a: &str, b: &str) -> f64 {
    let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
    let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();

    if words_a.is_empty() && words_b.is_empty() {
        return 1.0;
    }
    if words_a.is_empty() || words_b.is_empty() {
        return 0.0;
    }

    let intersection = words_a.intersection(&words_b).count();
    // Both sets are non-empty past the guards above, so the union is never
    // empty and the division is always well-defined.
    let union = words_a.union(&words_b).count();

    intersection as f64 / union as f64
}
|
||||
|
||||
/// Find potentially conflicting entries using text similarity when embeddings
|
||||
/// are not available. Returns entries above the threshold.
|
||||
pub fn find_text_conflicts(
|
||||
entries: &[MemoryEntry],
|
||||
new_content: &str,
|
||||
threshold: f64,
|
||||
) -> Vec<String> {
|
||||
entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
matches!(e.category, MemoryCategory::Core)
|
||||
&& e.superseded_by.is_none()
|
||||
&& jaccard_similarity(&e.content, new_content) > threshold
|
||||
&& e.content != new_content
|
||||
})
|
||||
.map(|e| e.id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Identical strings share every token → similarity is exactly 1.0.
    #[test]
    fn jaccard_identical_strings() {
        let sim = jaccard_similarity("hello world", "hello world");
        assert!((sim - 1.0).abs() < f64::EPSILON);
    }

    // No shared tokens → similarity is exactly 0.0.
    #[test]
    fn jaccard_disjoint_strings() {
        let sim = jaccard_similarity("hello world", "foo bar");
        assert!(sim.abs() < f64::EPSILON);
    }

    // Partial overlap produces |intersection| / |union|.
    #[test]
    fn jaccard_partial_overlap() {
        let sim = jaccard_similarity("the quick brown fox", "the slow brown dog");
        // overlap: "the", "brown" = 2; union: "the", "quick", "brown", "fox", "slow", "dog" = 6
        assert!((sim - 2.0 / 6.0).abs() < 0.01);
    }

    // Edge cases: both empty → identical; one empty → disjoint.
    #[test]
    fn jaccard_empty_strings() {
        assert!((jaccard_similarity("", "") - 1.0).abs() < f64::EPSILON);
        assert!(jaccard_similarity("hello", "").abs() < f64::EPSILON);
        assert!(jaccard_similarity("", "hello").abs() < f64::EPSILON);
    }

    // Category filter: a Daily entry with identical content must be ignored;
    // only the Core entry is flagged as a conflict.
    #[test]
    fn find_text_conflicts_filters_correctly() {
        let entries = vec![
            MemoryEntry {
                id: "1".into(),
                key: "pref".into(),
                content: "User prefers Rust for systems work".into(),
                category: MemoryCategory::Core,
                timestamp: "now".into(),
                session_id: None,
                score: None,
                namespace: "default".into(),
                importance: Some(0.7),
                superseded_by: None,
            },
            MemoryEntry {
                id: "2".into(),
                key: "daily1".into(),
                content: "User prefers Rust for systems work".into(),
                category: MemoryCategory::Daily,
                timestamp: "now".into(),
                session_id: None,
                score: None,
                namespace: "default".into(),
                importance: Some(0.3),
                superseded_by: None,
            },
        ];

        // Only Core entries should be flagged
        let conflicts = find_text_conflicts(&entries, "User now prefers Go for systems work", 0.3);
        assert_eq!(conflicts.len(), 1);
        assert_eq!(conflicts[0], "1");
    }
}
|
||||
@@ -8,6 +8,8 @@
|
||||
//! This two-phase approach replaces the naive raw-message auto-save with
|
||||
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
|
||||
|
||||
use crate::memory::conflict;
|
||||
use crate::memory::importance;
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
@@ -78,8 +80,33 @@ pub async fn consolidate_turn(
|
||||
if let Some(ref update) = result.memory_update {
|
||||
if !update.trim().is_empty() {
|
||||
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
|
||||
|
||||
// Compute importance score heuristically.
|
||||
let imp = importance::compute_importance(update, &MemoryCategory::Core);
|
||||
|
||||
// Check for conflicts with existing Core memories.
|
||||
if let Err(e) = conflict::check_and_resolve_conflicts(
|
||||
memory,
|
||||
&mem_key,
|
||||
update,
|
||||
&MemoryCategory::Core,
|
||||
0.85,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::debug!("conflict check skipped: {e}");
|
||||
}
|
||||
|
||||
// Store with importance metadata.
|
||||
memory
|
||||
.store(&mem_key, update, MemoryCategory::Core, None)
|
||||
.store_with_metadata(
|
||||
&mem_key,
|
||||
update,
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
None,
|
||||
Some(imp),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
use super::traits::{MemoryCategory, MemoryEntry};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Default half-life in days for time-decay scoring.
|
||||
/// After this many days, a non-Core memory's score drops to 50%.
|
||||
pub const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Apply exponential time decay to memory entry scores.
|
||||
///
|
||||
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
|
||||
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
|
||||
/// - Entries without a score (`None`) are left unchanged.
|
||||
///
|
||||
/// Decay formula: `score * 2^(-age_days / half_life_days)`
|
||||
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
|
||||
let half_life = if half_life_days <= 0.0 {
|
||||
DEFAULT_HALF_LIFE_DAYS
|
||||
} else {
|
||||
half_life_days
|
||||
};
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
// Core memories are evergreen — never decay
|
||||
if entry.category == MemoryCategory::Core {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = match entry.score {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
Ok(dt) => dt.with_timezone(&Utc),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let age_days = now.signed_duration_since(ts).num_seconds().max(0) as f64 / 86_400.0;
|
||||
|
||||
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
|
||||
entry.score = Some(score * decay_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "test".into(),
|
||||
content: "value".into(),
|
||||
category,
|
||||
timestamp: timestamp.into(),
|
||||
session_id: None,
|
||||
score,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn recent_rfc3339() -> String {
|
||||
Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
fn days_ago_rfc3339(days: i64) -> String {
|
||||
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn core_memories_are_never_decayed() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Core,
|
||||
Some(0.9),
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recent_entry_score_barely_changes() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.8),
|
||||
&recent_rfc3339(),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.8).abs() < 0.01,
|
||||
"recent entry should barely decay, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_half_life_halves_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(7),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.5).abs() < 0.05,
|
||||
"score after one half-life should be ~0.5, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_half_lives_quarters_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(14),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.25).abs() < 0.05,
|
||||
"score after two half-lives should be ~0.25, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_score_entry_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_timestamp_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.9),
|
||||
"not-a-date",
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
}
|
||||
+42
-4
@@ -1,4 +1,5 @@
|
||||
use crate::config::MemoryConfig;
|
||||
use crate::memory::policy::PolicyEnforcer;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
|
||||
use rusqlite::{params, Connection};
|
||||
@@ -47,6 +48,13 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Use policy engine for per-category retention overrides.
|
||||
let enforcer = PolicyEnforcer::new(&config.policy);
|
||||
let conversation_retention = enforcer.retention_days_for_category(
|
||||
&crate::memory::traits::MemoryCategory::Conversation,
|
||||
config.conversation_retention_days,
|
||||
);
|
||||
|
||||
let report = HygieneReport {
|
||||
archived_memory_files: archive_daily_memory_files(
|
||||
workspace_dir,
|
||||
@@ -55,12 +63,16 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
|
||||
archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
|
||||
purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
|
||||
purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
|
||||
pruned_conversation_rows: prune_conversation_rows(
|
||||
workspace_dir,
|
||||
config.conversation_retention_days,
|
||||
)?,
|
||||
pruned_conversation_rows: prune_conversation_rows(workspace_dir, conversation_retention)?,
|
||||
};
|
||||
|
||||
// Prune audit entries if audit is enabled.
|
||||
if config.audit_enabled {
|
||||
if let Err(e) = prune_audit_entries(workspace_dir, config.audit_retention_days) {
|
||||
tracing::debug!("audit pruning skipped: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
write_state(workspace_dir, &report)?;
|
||||
|
||||
if report.total_actions() > 0 {
|
||||
@@ -318,6 +330,32 @@ fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<
|
||||
Ok(u64::try_from(affected).unwrap_or(0))
|
||||
}
|
||||
|
||||
fn prune_audit_entries(workspace_dir: &Path, retention_days: u32) -> Result<()> {
|
||||
if retention_days == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let db_path = workspace_dir.join("memory").join("audit.db");
|
||||
if !db_path.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)?;
|
||||
conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
|
||||
let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();
|
||||
|
||||
let affected = conn.execute(
|
||||
"DELETE FROM memory_audit WHERE timestamp < ?1",
|
||||
params![cutoff],
|
||||
)?;
|
||||
|
||||
if affected > 0 {
|
||||
tracing::debug!("pruned {affected} audit entries older than {retention_days} days");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
|
||||
let stem = filename.strip_suffix(".md")?;
|
||||
let date_part = stem.split('_').next().unwrap_or(stem);
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
//! Heuristic importance scorer for non-LLM paths.
|
||||
//!
|
||||
//! Assigns importance scores (0.0–1.0) based on memory category and keyword
|
||||
//! signals. Used when LLM-based consolidation is unavailable or as a fast
|
||||
//! first-pass scorer.
|
||||
|
||||
use super::traits::MemoryCategory;
|
||||
|
||||
/// Base importance by category.
|
||||
fn category_base_score(category: &MemoryCategory) -> f64 {
|
||||
match category {
|
||||
MemoryCategory::Core => 0.7,
|
||||
MemoryCategory::Daily => 0.3,
|
||||
MemoryCategory::Conversation => 0.2,
|
||||
MemoryCategory::Custom(_) => 0.4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyword boost: if the content contains high-signal keywords, bump importance.
|
||||
fn keyword_boost(content: &str) -> f64 {
|
||||
const HIGH_SIGNAL_KEYWORDS: &[&str] = &[
|
||||
"decision",
|
||||
"always",
|
||||
"never",
|
||||
"important",
|
||||
"critical",
|
||||
"must",
|
||||
"requirement",
|
||||
"policy",
|
||||
"rule",
|
||||
"principle",
|
||||
];
|
||||
|
||||
let lowered = content.to_ascii_lowercase();
|
||||
let matches = HIGH_SIGNAL_KEYWORDS
|
||||
.iter()
|
||||
.filter(|kw| lowered.contains(**kw))
|
||||
.count();
|
||||
|
||||
// Cap at +0.2
|
||||
(matches as f64 * 0.1).min(0.2)
|
||||
}
|
||||
|
||||
/// Compute heuristic importance score for a memory entry.
|
||||
pub fn compute_importance(content: &str, category: &MemoryCategory) -> f64 {
|
||||
let base = category_base_score(category);
|
||||
let boost = keyword_boost(content);
|
||||
(base + boost).min(1.0)
|
||||
}
|
||||
|
||||
/// Compute final retrieval score incorporating importance and recency.
|
||||
///
|
||||
/// `hybrid_score`: raw retrieval score from FTS/vector (0.0–1.0)
|
||||
/// `importance`: importance score (0.0–1.0)
|
||||
/// `recency_decay`: recency factor (0.0–1.0, 1.0 = very recent)
|
||||
pub fn weighted_final_score(hybrid_score: f64, importance: f64, recency_decay: f64) -> f64 {
|
||||
hybrid_score * 0.7 + importance * 0.2 + recency_decay * 0.1
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn core_category_has_high_base_score() {
|
||||
let score = compute_importance("some fact", &MemoryCategory::Core);
|
||||
assert!((score - 0.7).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversation_category_has_low_base_score() {
|
||||
let score = compute_importance("chat message", &MemoryCategory::Conversation);
|
||||
assert!((score - 0.2).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keywords_boost_importance() {
|
||||
let score = compute_importance(
|
||||
"This is a critical decision that must always be followed",
|
||||
&MemoryCategory::Core,
|
||||
);
|
||||
// base 0.7 + boost for "critical", "decision", "must", "always" = 0.7 + 0.2 (capped) = 0.9
|
||||
assert!(score > 0.85);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boost_capped_at_point_two() {
|
||||
let score = compute_importance(
|
||||
"important critical decision rule policy must always never requirement principle",
|
||||
&MemoryCategory::Conversation,
|
||||
);
|
||||
// base 0.2 + max boost 0.2 = 0.4
|
||||
assert!((score - 0.4).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn weighted_final_score_formula() {
|
||||
let score = weighted_final_score(1.0, 1.0, 1.0);
|
||||
assert!((score - 1.0).abs() < f64::EPSILON);
|
||||
|
||||
let score = weighted_final_score(0.0, 0.0, 0.0);
|
||||
assert!(score.abs() < f64::EPSILON);
|
||||
|
||||
let score = weighted_final_score(0.5, 0.5, 0.5);
|
||||
assert!((score - 0.5).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
@@ -226,6 +226,9 @@ impl LucidMemory {
|
||||
timestamp: now.clone(),
|
||||
session_id: None,
|
||||
score: Some((1.0 - rank as f64 * 0.05).max(0.1)),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -91,6 +91,9 @@ impl MarkdownMemory {
|
||||
timestamp: filename.to_string(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
|
||||
@@ -1,635 +0,0 @@
|
||||
//! Mem0 (OpenMemory) memory backend.
|
||||
//!
|
||||
//! Connects to a self-hosted OpenMemory server via its REST API
|
||||
//! and implements the [`Memory`] trait for seamless integration with
|
||||
//! ZeroClaw's auto-save, auto-recall, and hygiene lifecycle.
|
||||
//!
|
||||
//! Deploy OpenMemory: `docker compose up` from the mem0 repo.
|
||||
//! Default endpoint: `http://localhost:8765`.
|
||||
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
|
||||
use crate::config::schema::Mem0Config;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Memory backend backed by a mem0 (OpenMemory) REST API.
|
||||
pub struct Mem0Memory {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
user_id: String,
|
||||
app_name: String,
|
||||
infer: bool,
|
||||
extraction_prompt: Option<String>,
|
||||
}
|
||||
|
||||
// ── mem0 API request/response types ────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AddMemoryRequest<'a> {
|
||||
user_id: &'a str,
|
||||
text: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
metadata: Option<Mem0Metadata<'a>>,
|
||||
infer: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
app: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
custom_instructions: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Mem0Metadata<'a> {
|
||||
key: &'a str,
|
||||
category: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AddProceduralRequest<'a> {
|
||||
user_id: &'a str,
|
||||
messages: &'a [ProceduralMessage],
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DeleteMemoriesRequest<'a> {
|
||||
memory_ids: Vec<&'a str>,
|
||||
user_id: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0MemoryItem {
|
||||
id: String,
|
||||
#[serde(alias = "content", alias = "text", default)]
|
||||
memory: String,
|
||||
#[serde(default)]
|
||||
created_at: Option<serde_json::Value>,
|
||||
#[serde(default, rename = "metadata_")]
|
||||
metadata: Option<Mem0ResponseMetadata>,
|
||||
#[serde(alias = "relevance_score", default)]
|
||||
score: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
struct Mem0ResponseMetadata {
|
||||
#[serde(default)]
|
||||
key: Option<String>,
|
||||
#[serde(default)]
|
||||
category: Option<String>,
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0ListResponse {
|
||||
#[serde(default)]
|
||||
items: Vec<Mem0MemoryItem>,
|
||||
#[serde(default)]
|
||||
total: usize,
|
||||
}
|
||||
|
||||
// ── Implementation ─────────────────────────────────────────────────
|
||||
|
||||
impl Mem0Memory {
|
||||
/// Create a new mem0 memory backend from config.
|
||||
pub fn new(config: &Mem0Config) -> anyhow::Result<Self> {
|
||||
let base_url = config.url.trim_end_matches('/').to_string();
|
||||
if base_url.is_empty() {
|
||||
anyhow::bail!("mem0 URL is empty; set [memory.mem0] url or MEM0_URL env var");
|
||||
}
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
user_id: config.user_id.clone(),
|
||||
app_name: config.app_name.clone(),
|
||||
infer: config.infer,
|
||||
extraction_prompt: config.extraction_prompt.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn api_url(&self, path: &str) -> String {
|
||||
format!("{}/api/v1{}", self.base_url, path)
|
||||
}
|
||||
|
||||
/// Use `session_id` as the effective mem0 `user_id` when provided,
|
||||
/// falling back to the configured default. This enables per-user
|
||||
/// and per-group memory scoping via the existing `Memory` trait.
|
||||
fn effective_user_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
|
||||
session_id
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or(&self.user_id)
|
||||
}
|
||||
|
||||
/// Recall memories with optional search filters.
|
||||
///
|
||||
/// - `created_after` / `created_before`: ISO 8601 timestamps for time-range filtering.
|
||||
/// - `metadata_filter`: arbitrary JSON object passed to the mem0 SDK `filters` param.
|
||||
pub async fn recall_filtered(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
created_after: Option<&str>,
|
||||
created_before: Option<&str>,
|
||||
metadata_filter: Option<&serde_json::Value>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let limit_str = limit.to_string();
|
||||
let mut params: Vec<(&str, &str)> = vec![
|
||||
("user_id", effective_user),
|
||||
("search_query", query),
|
||||
("size", &limit_str),
|
||||
];
|
||||
if let Some(after) = created_after {
|
||||
params.push(("created_after", after));
|
||||
}
|
||||
if let Some(before) = created_before {
|
||||
params.push(("created_before", before));
|
||||
}
|
||||
let meta_json;
|
||||
if let Some(mf) = metadata_filter {
|
||||
meta_json = serde_json::to_string(mf)?;
|
||||
params.push(("metadata_filter", &meta_json));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 recall failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
Ok(list.items.into_iter().map(|i| self.to_entry(i)).collect())
|
||||
}
|
||||
|
||||
fn to_entry(&self, item: Mem0MemoryItem) -> MemoryEntry {
|
||||
let meta = item.metadata.unwrap_or_default();
|
||||
let timestamp = match item.created_at {
|
||||
Some(serde_json::Value::Number(n)) => {
|
||||
// Unix timestamp → ISO 8601
|
||||
if let Some(ts) = n.as_i64() {
|
||||
chrono::DateTime::from_timestamp(ts, 0)
|
||||
.map(|dt| dt.to_rfc3339())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
Some(serde_json::Value::String(s)) => s,
|
||||
_ => String::new(),
|
||||
};
|
||||
|
||||
let category = match meta.category.as_deref() {
|
||||
Some("daily") => MemoryCategory::Daily,
|
||||
Some("conversation") => MemoryCategory::Conversation,
|
||||
Some(other) if other != "core" => MemoryCategory::Custom(other.to_string()),
|
||||
// "core" or None → default
|
||||
_ => MemoryCategory::Core,
|
||||
};
|
||||
|
||||
MemoryEntry {
|
||||
id: item.id,
|
||||
key: meta.key.unwrap_or_default(),
|
||||
content: item.memory,
|
||||
category,
|
||||
timestamp,
|
||||
session_id: meta.session_id,
|
||||
score: item.score,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a conversation trace as procedural memory.
|
||||
///
|
||||
/// Sends the message history (user input, tool calls, assistant response)
|
||||
/// to the mem0 procedural endpoint so that "how to" patterns can be
|
||||
/// extracted and stored for future recall.
|
||||
pub async fn store_procedural(
|
||||
&self,
|
||||
messages: &[ProceduralMessage],
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
if messages.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let body = AddProceduralRequest {
|
||||
user_id: effective_user,
|
||||
messages,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url("/memories/procedural"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 store_procedural failed ({status}): {text}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── History API types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Mem0HistoryResponse {
|
||||
#[serde(default)]
|
||||
history: Vec<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
impl Mem0Memory {
|
||||
/// Retrieve the edit history (audit trail) for a specific memory by ID.
|
||||
pub async fn history(&self, memory_id: &str) -> anyhow::Result<String> {
|
||||
let url = self.api_url(&format!("/memories/{memory_id}/history"));
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 history failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let body: Mem0HistoryResponse = resp.json().await?;
|
||||
|
||||
if let Some(err) = body.error {
|
||||
anyhow::bail!("mem0 history error: {err}");
|
||||
}
|
||||
|
||||
if body.history.is_empty() {
|
||||
return Ok(format!("No history found for memory {memory_id}."));
|
||||
}
|
||||
|
||||
let mut lines = Vec::with_capacity(body.history.len() + 1);
|
||||
lines.push(format!("History for memory {memory_id}:"));
|
||||
|
||||
for (i, entry) in body.history.iter().enumerate() {
|
||||
let event = entry
|
||||
.get("event")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let old_memory = entry
|
||||
.get("old_memory")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-");
|
||||
let new_memory = entry
|
||||
.get("new_memory")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-");
|
||||
let timestamp = entry
|
||||
.get("created_at")
|
||||
.or_else(|| entry.get("timestamp"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
lines.push(format!(
|
||||
" {idx}. [{event}] at {timestamp}\n old: {old_memory}\n new: {new_memory}",
|
||||
idx = i + 1,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for Mem0Memory {
|
||||
fn name(&self) -> &str {
|
||||
"mem0"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cat_str = category.to_string();
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let body = AddMemoryRequest {
|
||||
user_id: effective_user,
|
||||
text: content,
|
||||
metadata: Some(Mem0Metadata {
|
||||
key,
|
||||
category: &cat_str,
|
||||
session_id,
|
||||
}),
|
||||
infer: self.infer,
|
||||
app: Some(&self.app_name),
|
||||
custom_instructions: self.extraction_prompt.as_deref(),
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url("/memories/"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 store failed ({status}): {text}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// mem0 handles filtering server-side; since/until are not yet
|
||||
// supported by the mem0 API, so we pass them through as no-ops.
|
||||
self.recall_filtered(query, limit, session_id, None, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
// mem0 doesn't have a get-by-key API, so we search by key in metadata
|
||||
let results = self.recall(key, 1, None, None, None).await?;
|
||||
Ok(results.into_iter().find(|e| e.key == key))
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let effective_user = self.effective_user_id(session_id);
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[("user_id", effective_user), ("size", "100")])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 list failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
let entries: Vec<MemoryEntry> = list.items.into_iter().map(|i| self.to_entry(i)).collect();
|
||||
|
||||
// Client-side category filter (mem0 API doesn't filter by metadata)
|
||||
match category {
|
||||
Some(cat) => Ok(entries.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
None => Ok(entries),
|
||||
}
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
// Find the memory ID by key first
|
||||
let entry = self.get(key).await?;
|
||||
let entry = match entry {
|
||||
Some(e) => e,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let body = DeleteMemoriesRequest {
|
||||
memory_ids: vec![&entry.id],
|
||||
user_id: &self.user_id,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.delete(self.api_url("/memories/"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(resp.status().is_success())
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let resp = self
|
||||
.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[
|
||||
("user_id", self.user_id.as_str()),
|
||||
("size", "1"),
|
||||
("page", "1"),
|
||||
])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("mem0 count failed ({status}): {text}");
|
||||
}
|
||||
|
||||
let list: Mem0ListResponse = resp.json().await?;
|
||||
Ok(list.total)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get(self.api_url("/memories/"))
|
||||
.query(&[
|
||||
("user_id", self.user_id.as_str()),
|
||||
("size", "1"),
|
||||
("page", "1"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.is_ok_and(|r| r.status().is_success())
|
||||
}
|
||||
|
||||
async fn store_procedural(
|
||||
&self,
|
||||
messages: &[ProceduralMessage],
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Mem0Memory::store_procedural(self, messages, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_config() -> Mem0Config {
|
||||
Mem0Config {
|
||||
url: "http://localhost:8765".into(),
|
||||
user_id: "test-user".into(),
|
||||
app_name: "test-app".into(),
|
||||
infer: true,
|
||||
extraction_prompt: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_rejects_empty_url() {
|
||||
let config = Mem0Config {
|
||||
url: String::new(),
|
||||
..test_config()
|
||||
};
|
||||
assert!(Mem0Memory::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_trims_trailing_slash() {
|
||||
let config = Mem0Config {
|
||||
url: "http://localhost:8765/".into(),
|
||||
..test_config()
|
||||
};
|
||||
let mem = Mem0Memory::new(&config).unwrap();
|
||||
assert_eq!(mem.base_url, "http://localhost:8765");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_url_builds_correct_path() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
assert_eq!(
|
||||
mem.api_url("/memories/"),
|
||||
"http://localhost:8765/api/v1/memories/"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_maps_unix_timestamp() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-1".into(),
|
||||
memory: "hello".into(),
|
||||
created_at: Some(serde_json::json!(1_700_000_000)),
|
||||
metadata: Some(Mem0ResponseMetadata {
|
||||
key: Some("k1".into()),
|
||||
category: Some("core".into()),
|
||||
session_id: None,
|
||||
}),
|
||||
score: Some(0.95),
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.id, "id-1");
|
||||
assert_eq!(entry.key, "k1");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
assert!(!entry.timestamp.is_empty());
|
||||
assert_eq!(entry.score, Some(0.95));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_maps_string_timestamp() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-2".into(),
|
||||
memory: "world".into(),
|
||||
created_at: Some(serde_json::json!("2024-01-01T00:00:00Z")),
|
||||
metadata: None,
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.timestamp, "2024-01-01T00:00:00Z");
|
||||
assert_eq!(entry.category, MemoryCategory::Core); // default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_handles_missing_metadata() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-3".into(),
|
||||
memory: "bare".into(),
|
||||
created_at: None,
|
||||
metadata: None,
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(entry.key, "");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
assert!(entry.timestamp.is_empty());
|
||||
assert_eq!(entry.score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_entry_custom_category() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
let item = Mem0MemoryItem {
|
||||
id: "id-4".into(),
|
||||
memory: "custom".into(),
|
||||
created_at: None,
|
||||
metadata: Some(Mem0ResponseMetadata {
|
||||
key: Some("k".into()),
|
||||
category: Some("project_notes".into()),
|
||||
session_id: Some("s1".into()),
|
||||
}),
|
||||
score: None,
|
||||
};
|
||||
let entry = mem.to_entry(item);
|
||||
assert_eq!(
|
||||
entry.category,
|
||||
MemoryCategory::Custom("project_notes".into())
|
||||
);
|
||||
assert_eq!(entry.session_id.as_deref(), Some("s1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_returns_mem0() {
|
||||
let mem = Mem0Memory::new(&test_config()).unwrap();
|
||||
assert_eq!(mem.name(), "mem0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn procedural_request_serializes_messages() {
|
||||
let messages = vec![
|
||||
ProceduralMessage {
|
||||
role: "user".into(),
|
||||
content: "How do I deploy?".into(),
|
||||
name: None,
|
||||
},
|
||||
ProceduralMessage {
|
||||
role: "tool".into(),
|
||||
content: "deployment started".into(),
|
||||
name: Some("shell".into()),
|
||||
},
|
||||
ProceduralMessage {
|
||||
role: "assistant".into(),
|
||||
content: "Deployment complete.".into(),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
let req = AddProceduralRequest {
|
||||
user_id: "test-user",
|
||||
messages: &messages,
|
||||
metadata: None,
|
||||
};
|
||||
let json = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(json["user_id"], "test-user");
|
||||
let msgs = json["messages"].as_array().unwrap();
|
||||
assert_eq!(msgs.len(), 3);
|
||||
assert_eq!(msgs[0]["role"], "user");
|
||||
assert_eq!(msgs[1]["name"], "shell");
|
||||
// metadata should be absent when None
|
||||
assert!(json.get("metadata").is_none());
|
||||
}
|
||||
}
|
||||
+16
-27
@@ -1,24 +1,33 @@
|
||||
pub mod audit;
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod conflict;
|
||||
pub mod consolidation;
|
||||
pub mod decay;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod importance;
|
||||
pub mod knowledge_graph;
|
||||
pub mod lucid;
|
||||
pub mod markdown;
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
pub mod mem0;
|
||||
pub mod none;
|
||||
pub mod policy;
|
||||
#[cfg(feature = "memory-postgres")]
|
||||
pub mod postgres;
|
||||
pub mod qdrant;
|
||||
pub mod response_cache;
|
||||
pub mod retrieval;
|
||||
pub mod snapshot;
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
pub mod vector;
|
||||
|
||||
#[cfg(test)]
|
||||
mod battle_tests;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use audit::AuditedMemory;
|
||||
#[allow(unused_imports)]
|
||||
pub use backend::{
|
||||
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
|
||||
@@ -26,13 +35,15 @@ pub use backend::{
|
||||
};
|
||||
pub use lucid::LucidMemory;
|
||||
pub use markdown::MarkdownMemory;
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
pub use mem0::Mem0Memory;
|
||||
pub use none::NoneMemory;
|
||||
#[allow(unused_imports)]
|
||||
pub use policy::PolicyEnforcer;
|
||||
#[cfg(feature = "memory-postgres")]
|
||||
pub use postgres::PostgresMemory;
|
||||
pub use qdrant::QdrantMemory;
|
||||
pub use response_cache::ResponseCache;
|
||||
#[allow(unused_imports)]
|
||||
pub use retrieval::{RetrievalConfig, RetrievalPipeline};
|
||||
pub use sqlite::SqliteMemory;
|
||||
pub use traits::Memory;
|
||||
#[allow(unused_imports)]
|
||||
@@ -61,7 +72,7 @@ where
|
||||
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
|
||||
}
|
||||
MemoryBackendKind::Postgres => postgres_builder(),
|
||||
MemoryBackendKind::Qdrant | MemoryBackendKind::Mem0 | MemoryBackendKind::Markdown => {
|
||||
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
|
||||
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
|
||||
}
|
||||
MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())),
|
||||
@@ -340,28 +351,6 @@ pub fn create_memory_with_storage_and_routes(
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "memory-mem0")]
|
||||
fn build_mem0_memory(config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
|
||||
let mem = Mem0Memory::new(&config.mem0)?;
|
||||
tracing::info!(
|
||||
"📦 Mem0 memory backend configured (url: {}, user: {})",
|
||||
config.mem0.url,
|
||||
config.mem0.user_id
|
||||
);
|
||||
Ok(Box::new(mem))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "memory-mem0"))]
|
||||
fn build_mem0_memory(_config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
|
||||
anyhow::bail!(
|
||||
"memory backend 'mem0' requested but this build was compiled without `memory-mem0`; rebuild with `--features memory-mem0`"
|
||||
);
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Mem0) {
|
||||
return build_mem0_memory(config);
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
|
||||
let url = config
|
||||
.qdrant
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
//! Policy engine for memory operations.
|
||||
//!
|
||||
//! Validates operations against configurable rules before they reach the
|
||||
//! backend. Enforces namespace quotas, category limits, read-only namespaces,
|
||||
//! and per-category retention rules.
|
||||
|
||||
use super::traits::MemoryCategory;
|
||||
use crate::config::MemoryPolicyConfig;
|
||||
|
||||
/// Policy enforcer that validates memory operations.
|
||||
pub struct PolicyEnforcer {
|
||||
config: MemoryPolicyConfig,
|
||||
}
|
||||
|
||||
impl PolicyEnforcer {
|
||||
pub fn new(config: &MemoryPolicyConfig) -> Self {
|
||||
Self {
|
||||
config: config.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a namespace is read-only.
|
||||
pub fn is_read_only(&self, namespace: &str) -> bool {
|
||||
self.config
|
||||
.read_only_namespaces
|
||||
.iter()
|
||||
.any(|ns| ns == namespace)
|
||||
}
|
||||
|
||||
/// Validate a store operation against policy rules.
|
||||
pub fn validate_store(
|
||||
&self,
|
||||
namespace: &str,
|
||||
_category: &MemoryCategory,
|
||||
) -> Result<(), PolicyViolation> {
|
||||
if self.is_read_only(namespace) {
|
||||
return Err(PolicyViolation::ReadOnlyNamespace(namespace.to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if adding an entry would exceed namespace limits.
|
||||
pub fn check_namespace_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
|
||||
if self.config.max_entries_per_namespace > 0
|
||||
&& current_count >= self.config.max_entries_per_namespace
|
||||
{
|
||||
return Err(PolicyViolation::NamespaceQuotaExceeded {
|
||||
max: self.config.max_entries_per_namespace,
|
||||
current: current_count,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if adding an entry would exceed category limits.
|
||||
pub fn check_category_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
|
||||
if self.config.max_entries_per_category > 0
|
||||
&& current_count >= self.config.max_entries_per_category
|
||||
{
|
||||
return Err(PolicyViolation::CategoryQuotaExceeded {
|
||||
max: self.config.max_entries_per_category,
|
||||
current: current_count,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the retention days for a specific category, falling back to the
|
||||
/// provided default if no per-category override exists.
|
||||
pub fn retention_days_for_category(&self, category: &MemoryCategory, default_days: u32) -> u32 {
|
||||
let key = category.to_string();
|
||||
self.config
|
||||
.retention_days_by_category
|
||||
.get(&key)
|
||||
.copied()
|
||||
.unwrap_or(default_days)
|
||||
}
|
||||
}
|
||||
|
||||
/// Policy violation errors.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PolicyViolation {
|
||||
ReadOnlyNamespace(String),
|
||||
NamespaceQuotaExceeded { max: usize, current: usize },
|
||||
CategoryQuotaExceeded { max: usize, current: usize },
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PolicyViolation {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ReadOnlyNamespace(ns) => write!(f, "namespace '{ns}' is read-only"),
|
||||
Self::NamespaceQuotaExceeded { max, current } => {
|
||||
write!(f, "namespace quota exceeded: {current}/{max} entries")
|
||||
}
|
||||
Self::CategoryQuotaExceeded { max, current } => {
|
||||
write!(f, "category quota exceeded: {current}/{max} entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PolicyViolation {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn empty_policy() -> MemoryPolicyConfig {
|
||||
MemoryPolicyConfig::default()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_policy_allows_everything() {
|
||||
let enforcer = PolicyEnforcer::new(&empty_policy());
|
||||
assert!(!enforcer.is_read_only("default"));
|
||||
assert!(enforcer
|
||||
.validate_store("default", &MemoryCategory::Core)
|
||||
.is_ok());
|
||||
assert!(enforcer.check_namespace_limit(100).is_ok());
|
||||
assert!(enforcer.check_category_limit(100).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_namespace_blocks_writes() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
read_only_namespaces: vec!["archive".into()],
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.is_read_only("archive"));
|
||||
assert!(!enforcer.is_read_only("default"));
|
||||
assert!(enforcer
|
||||
.validate_store("archive", &MemoryCategory::Core)
|
||||
.is_err());
|
||||
assert!(enforcer
|
||||
.validate_store("default", &MemoryCategory::Core)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn namespace_quota_enforced() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
max_entries_per_namespace: 10,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.check_namespace_limit(5).is_ok());
|
||||
assert!(enforcer.check_namespace_limit(10).is_err());
|
||||
assert!(enforcer.check_namespace_limit(15).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn category_quota_enforced() {
|
||||
let policy = MemoryPolicyConfig {
|
||||
max_entries_per_category: 50,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert!(enforcer.check_category_limit(25).is_ok());
|
||||
assert!(enforcer.check_category_limit(50).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_category_retention_overrides_default() {
|
||||
let mut retention = HashMap::new();
|
||||
retention.insert("core".into(), 365);
|
||||
retention.insert("conversation".into(), 7);
|
||||
|
||||
let policy = MemoryPolicyConfig {
|
||||
retention_days_by_category: retention,
|
||||
..empty_policy()
|
||||
};
|
||||
let enforcer = PolicyEnforcer::new(&policy);
|
||||
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Core, 30),
|
||||
365
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Conversation, 30),
|
||||
7
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.retention_days_for_category(&MemoryCategory::Daily, 30),
|
||||
30
|
||||
);
|
||||
}
|
||||
}
|
||||
+12
-3
@@ -100,6 +100,8 @@ impl PostgresMemory {
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_content_fts ON {qualified_table} USING gin(to_tsvector('simple', content));
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_key_fts ON {qualified_table} USING gin(to_tsvector('simple', key));
|
||||
"
|
||||
))?;
|
||||
|
||||
@@ -135,6 +137,9 @@ impl PostgresMemory {
|
||||
timestamp: timestamp.to_rfc3339(),
|
||||
session_id: row.get(5),
|
||||
score: row.try_get(6).ok(),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -267,12 +272,16 @@ impl Memory for PostgresMemory {
|
||||
"
|
||||
SELECT id, key, content, category, created_at, session_id,
|
||||
(
|
||||
CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END +
|
||||
CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
|
||||
CASE WHEN to_tsvector('simple', key) @@ plainto_tsquery('simple', $1)
|
||||
THEN ts_rank_cd(to_tsvector('simple', key), plainto_tsquery('simple', $1)) * 2.0
|
||||
ELSE 0.0 END +
|
||||
CASE WHEN to_tsvector('simple', content) @@ plainto_tsquery('simple', $1)
|
||||
THEN ts_rank_cd(to_tsvector('simple', content), plainto_tsquery('simple', $1))
|
||||
ELSE 0.0 END
|
||||
) AS score
|
||||
FROM {qualified_table}
|
||||
WHERE ($2::TEXT IS NULL OR session_id = $2)
|
||||
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
|
||||
AND ($1 = '' OR to_tsvector('simple', key || ' ' || content) @@ plainto_tsquery('simple', $1))
|
||||
{time_filter}
|
||||
ORDER BY score DESC, updated_at DESC
|
||||
LIMIT $3
|
||||
|
||||
@@ -373,6 +373,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: Some(point.score),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -437,6 +440,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
});
|
||||
|
||||
@@ -514,6 +520,9 @@ impl Memory for QdrantMemory {
|
||||
timestamp: payload.timestamp,
|
||||
session_id: payload.session_id,
|
||||
score: None,
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
//! Multi-stage retrieval pipeline.
|
||||
//!
|
||||
//! Wraps a `Memory` trait object with staged retrieval:
|
||||
//! - **Stage 1 (Hot cache):** In-memory LRU of recent recall results.
|
||||
//! - **Stage 2 (FTS):** FTS5 keyword search with optional early-return.
|
||||
//! - **Stage 3 (Vector):** Vector similarity search + hybrid merge.
|
||||
//!
|
||||
//! Configurable via `[memory]` settings: `retrieval_stages`, `fts_early_return_score`.
|
||||
|
||||
use super::traits::{Memory, MemoryEntry};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// A cached recall result.
|
||||
struct CachedResult {
|
||||
entries: Vec<MemoryEntry>,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
/// Multi-stage retrieval pipeline configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
/// Ordered list of stages: "cache", "fts", "vector".
|
||||
pub stages: Vec<String>,
|
||||
/// FTS score above which to early-return without vector stage.
|
||||
pub fts_early_return_score: f64,
|
||||
/// Max entries in the hot cache.
|
||||
pub cache_max_entries: usize,
|
||||
/// TTL for cached results.
|
||||
pub cache_ttl: Duration,
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stages: vec!["cache".into(), "fts".into(), "vector".into()],
|
||||
fts_early_return_score: 0.85,
|
||||
cache_max_entries: 256,
|
||||
cache_ttl: Duration::from_secs(300),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-stage retrieval pipeline wrapping a `Memory` backend.
|
||||
pub struct RetrievalPipeline {
|
||||
memory: Arc<dyn Memory>,
|
||||
config: RetrievalConfig,
|
||||
hot_cache: Mutex<HashMap<String, CachedResult>>,
|
||||
}
|
||||
|
||||
impl RetrievalPipeline {
|
||||
pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
|
||||
Self {
|
||||
memory,
|
||||
config,
|
||||
hot_cache: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a cache key from query parameters.
|
||||
fn cache_key(
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
) -> String {
|
||||
format!(
|
||||
"{}:{}:{}:{}",
|
||||
query,
|
||||
limit,
|
||||
session_id.unwrap_or(""),
|
||||
namespace.unwrap_or("")
|
||||
)
|
||||
}
|
||||
|
||||
/// Check the hot cache for a previous result.
|
||||
fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
|
||||
let cache = self.hot_cache.lock();
|
||||
if let Some(cached) = cache.get(key) {
|
||||
if cached.created_at.elapsed() < self.config.cache_ttl {
|
||||
return Some(cached.entries.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Store a result in the hot cache with LRU eviction.
|
||||
fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
|
||||
let mut cache = self.hot_cache.lock();
|
||||
|
||||
// LRU eviction: remove oldest entries if at capacity
|
||||
if cache.len() >= self.config.cache_max_entries {
|
||||
let oldest_key = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.created_at)
|
||||
.map(|(k, _)| k.clone());
|
||||
if let Some(k) = oldest_key {
|
||||
cache.remove(&k);
|
||||
}
|
||||
}
|
||||
|
||||
cache.insert(
|
||||
key,
|
||||
CachedResult {
|
||||
entries,
|
||||
created_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Execute the multi-stage retrieval pipeline.
|
||||
pub async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let ck = Self::cache_key(query, limit, session_id, namespace);
|
||||
|
||||
for stage in &self.config.stages {
|
||||
match stage.as_str() {
|
||||
"cache" => {
|
||||
if let Some(cached) = self.check_cache(&ck) {
|
||||
tracing::debug!("retrieval pipeline: cache hit for '{query}'");
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
"fts" | "vector" => {
|
||||
// Both FTS and vector are handled by the backend's recall method
|
||||
// which already does hybrid merge. We delegate to it.
|
||||
let results = if let Some(ns) = namespace {
|
||||
self.memory
|
||||
.recall_namespaced(ns, query, limit, session_id, since, until)
|
||||
.await?
|
||||
} else {
|
||||
self.memory
|
||||
.recall(query, limit, session_id, since, until)
|
||||
.await?
|
||||
};
|
||||
|
||||
if !results.is_empty() {
|
||||
// Check for FTS early-return: if top score exceeds threshold
|
||||
// and we're in the FTS stage, we can skip further stages
|
||||
if stage == "fts" {
|
||||
if let Some(top_score) = results.first().and_then(|e| e.score) {
|
||||
if top_score >= self.config.fts_early_return_score {
|
||||
tracing::debug!(
|
||||
"retrieval pipeline: FTS early return (score={top_score:.3})"
|
||||
);
|
||||
self.store_in_cache(ck, results.clone());
|
||||
return Ok(results);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.store_in_cache(ck, results.clone());
|
||||
return Ok(results);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("retrieval pipeline: unknown stage '{other}', skipping");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No results from any stage
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Invalidate the hot cache (e.g. after a store operation).
|
||||
pub fn invalidate_cache(&self) {
|
||||
self.hot_cache.lock().clear();
|
||||
}
|
||||
|
||||
/// Get the number of entries in the hot cache.
|
||||
pub fn cache_size(&self) -> usize {
|
||||
self.hot_cache.lock().len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::NoneMemory;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_returns_empty_from_none_backend() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
|
||||
|
||||
let results = pipeline
|
||||
.recall("test", 10, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_cache_invalidation() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
|
||||
|
||||
// Force a cache entry
|
||||
let ck = RetrievalPipeline::cache_key("test", 10, None, None);
|
||||
pipeline.store_in_cache(ck, vec![]);
|
||||
|
||||
assert_eq!(pipeline.cache_size(), 1);
|
||||
pipeline.invalidate_cache();
|
||||
assert_eq!(pipeline.cache_size(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cache_key_includes_all_params() {
|
||||
let k1 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
|
||||
let k2 = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
|
||||
let k3 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));
|
||||
|
||||
assert_ne!(k1, k2);
|
||||
assert_ne!(k1, k3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_caches_results() {
|
||||
let memory = Arc::new(NoneMemory::new());
|
||||
let config = RetrievalConfig {
|
||||
stages: vec!["cache".into()],
|
||||
..Default::default()
|
||||
};
|
||||
let pipeline = RetrievalPipeline::new(memory, config);
|
||||
|
||||
// First call: cache miss, no results
|
||||
let results = pipeline
|
||||
.recall("test", 10, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Manually insert a cache entry
|
||||
let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
|
||||
let fake_entry = MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "k".into(),
|
||||
content: "cached content".into(),
|
||||
category: crate::memory::MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
namespace: "default".into(),
|
||||
importance: None,
|
||||
superseded_by: None,
|
||||
};
|
||||
pipeline.store_in_cache(ck, vec![fake_entry]);
|
||||
|
||||
// Cache hit
|
||||
let results = pipeline
|
||||
.recall("cached_query", 5, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].content, "cached content");
|
||||
}
|
||||
}
|
||||
+129
-23
@@ -197,17 +197,35 @@ impl SqliteMemory {
|
||||
)?;
|
||||
|
||||
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||
let has_session_id: bool = conn
|
||||
let schema_sql: String = conn
|
||||
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||
.query_row([], |row| row.get::<_, String>(0))?
|
||||
.contains("session_id");
|
||||
if !has_session_id {
|
||||
.query_row([], |row| row.get::<_, String>(0))?;
|
||||
|
||||
if !schema_sql.contains("session_id") {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||
)?;
|
||||
}
|
||||
|
||||
// Migration: add namespace column
|
||||
if !schema_sql.contains("namespace") {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);",
|
||||
)?;
|
||||
}
|
||||
|
||||
// Migration: add importance column
|
||||
if !schema_sql.contains("importance") {
|
||||
conn.execute_batch("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;")?;
|
||||
}
|
||||
|
||||
// Migration: add superseded_by column
|
||||
if !schema_sql.contains("superseded_by") {
|
||||
conn.execute_batch("ALTER TABLE memories ADD COLUMN superseded_by TEXT;")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -246,8 +264,13 @@ impl SqliteMemory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Provide access to the connection for advanced queries (e.g. retrieval pipeline).
|
||||
pub fn connection(&self) -> &Arc<Mutex<Connection>> {
|
||||
&self.conn
|
||||
}
|
||||
|
||||
/// Get embedding from cache, or compute + cache it
|
||||
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
|
||||
pub async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
|
||||
if self.embedder.dimensions() == 0 {
|
||||
return Ok(None); // Noop embedder
|
||||
}
|
||||
@@ -310,7 +333,7 @@ impl SqliteMemory {
|
||||
}
|
||||
|
||||
/// FTS5 BM25 keyword search
|
||||
fn fts5_search(
|
||||
pub fn fts5_search(
|
||||
conn: &Connection,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
@@ -356,7 +379,7 @@ impl SqliteMemory {
|
||||
///
|
||||
/// Optional `category` and `session_id` filters reduce full-table scans
|
||||
/// when the caller already knows the scope of relevant memories.
|
||||
fn vector_search(
|
||||
pub fn vector_search(
|
||||
conn: &Connection,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
@@ -473,8 +496,8 @@ impl SqliteMemory {
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let mut sql =
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories \
|
||||
WHERE 1=1"
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories \
|
||||
WHERE superseded_by IS NULL AND 1=1"
|
||||
.to_string();
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
let mut idx = 1;
|
||||
@@ -510,6 +533,9 @@ impl SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -554,8 +580,8 @@ impl Memory for SqliteMemory {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 'default', 0.5)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
@@ -641,8 +667,8 @@ impl Memory for SqliteMemory {
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id \
|
||||
FROM memories WHERE id IN ({placeholders})"
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by \
|
||||
FROM memories WHERE superseded_by IS NULL AND id IN ({placeholders})"
|
||||
);
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
|
||||
@@ -659,17 +685,20 @@ impl Memory for SqliteMemory {
|
||||
row.get::<_, String>(3)?,
|
||||
row.get::<_, String>(4)?,
|
||||
row.get::<_, Option<String>>(5)?,
|
||||
row.get::<_, Option<String>>(6)?,
|
||||
row.get::<_, Option<f64>>(7)?,
|
||||
row.get::<_, Option<String>>(8)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut entry_map = std::collections::HashMap::new();
|
||||
for row in rows {
|
||||
let (id, key, content, cat, ts, sid) = row?;
|
||||
entry_map.insert(id, (key, content, cat, ts, sid));
|
||||
let (id, key, content, cat, ts, sid, ns, imp, sup) = row?;
|
||||
entry_map.insert(id, (key, content, cat, ts, sid, ns, imp, sup));
|
||||
}
|
||||
|
||||
for scored in &merged {
|
||||
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||
if let Some((key, content, cat, ts, sid, ns, imp, sup)) = entry_map.remove(&scored.id) {
|
||||
if let Some(s) = since_ref {
|
||||
if ts.as_str() < s {
|
||||
continue;
|
||||
@@ -688,6 +717,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: ts,
|
||||
session_id: sid,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
namespace: ns.unwrap_or_else(|| "default".into()),
|
||||
importance: imp,
|
||||
superseded_by: sup,
|
||||
};
|
||||
if let Some(filter_sid) = session_ref {
|
||||
if entry.session_id.as_deref() != Some(filter_sid) {
|
||||
@@ -727,8 +759,8 @@ impl Memory for SqliteMemory {
|
||||
param_idx += 1;
|
||||
}
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}{time_conditions}
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL AND ({where_clause}){time_conditions}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{param_idx}"
|
||||
);
|
||||
@@ -757,6 +789,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(1.0),
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
for row in rows {
|
||||
@@ -784,7 +819,7 @@ impl Memory for SqliteMemory {
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
@@ -796,6 +831,9 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -832,14 +870,17 @@ impl Memory for SqliteMemory {
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
|
||||
importance: row.get(7)?,
|
||||
superseded_by: row.get(8)?,
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(ref cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL AND category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
@@ -853,8 +894,8 @@ impl Memory for SqliteMemory {
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC LIMIT ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
|
||||
WHERE superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?1",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
@@ -904,6 +945,71 @@ impl Memory for SqliteMemory {
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let entries = self
|
||||
.recall(query, limit * 2, session_id, since, until)
|
||||
.await?;
|
||||
let filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.namespace == namespace)
|
||||
.take(limit)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
namespace: Option<&str>,
|
||||
importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
let embedding_bytes = self
|
||||
.get_or_compute_embedding(content)
|
||||
.await?
|
||||
.map(|emb| vector::vec_to_bytes(&emb));
|
||||
|
||||
let conn = self.conn.clone();
|
||||
let key = key.to_string();
|
||||
let content = content.to_string();
|
||||
let sid = session_id.map(String::from);
|
||||
let ns = namespace.unwrap_or("default").to_string();
|
||||
let imp = importance.unwrap_or(0.5);
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let conn = conn.lock();
|
||||
let now = Local::now().to_rfc3339();
|
||||
let cat = Self::category_to_str(&category);
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
embedding = excluded.embedding,
|
||||
updated_at = excluded.updated_at,
|
||||
session_id = excluded.session_id,
|
||||
namespace = excluded.namespace,
|
||||
importance = excluded.importance",
|
||||
params![id, key, content, cat, embedding_bytes, now, now, sid, ns, imp],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+63
-2
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// A single message in a conversation trace for procedural memory.
|
||||
///
|
||||
/// Used to capture "how to" patterns from tool-calling turns so that
|
||||
/// backends that support procedural storage (e.g. mem0) can learn from them.
|
||||
/// backends that support procedural storage can learn from them.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProceduralMessage {
|
||||
pub role: String,
|
||||
@@ -23,6 +23,19 @@ pub struct MemoryEntry {
|
||||
pub timestamp: String,
|
||||
pub session_id: Option<String>,
|
||||
pub score: Option<f64>,
|
||||
/// Namespace for isolation between agents/contexts.
|
||||
#[serde(default = "default_namespace")]
|
||||
pub namespace: String,
|
||||
/// Importance score (0.0–1.0) for prioritized retrieval.
|
||||
#[serde(default)]
|
||||
pub importance: Option<f64>,
|
||||
/// If this entry was superseded by a newer conflicting entry.
|
||||
#[serde(default)]
|
||||
pub superseded_by: Option<String>,
|
||||
}
|
||||
|
||||
fn default_namespace() -> String {
|
||||
"default".into()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MemoryEntry {
|
||||
@@ -34,6 +47,8 @@ impl std::fmt::Debug for MemoryEntry {
|
||||
.field("category", &self.category)
|
||||
.field("timestamp", &self.timestamp)
|
||||
.field("score", &self.score)
|
||||
.field("namespace", &self.namespace)
|
||||
.field("importance", &self.importance)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -128,7 +143,7 @@ pub trait Memory: Send + Sync {
|
||||
|
||||
/// Store a conversation trace as procedural memory.
|
||||
///
|
||||
/// Backends that support procedural storage (e.g. mem0) override this
|
||||
/// Backends that support procedural storage override this
|
||||
/// to extract "how to" patterns from tool-calling turns. The default
|
||||
/// implementation is a no-op.
|
||||
async fn store_procedural(
|
||||
@@ -138,6 +153,46 @@ pub trait Memory: Send + Sync {
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recall memories scoped to a specific namespace.
|
||||
///
|
||||
/// Default implementation delegates to `recall()` and filters by namespace.
|
||||
/// Backends with native namespace support should override for efficiency.
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let entries = self
|
||||
.recall(query, limit * 2, session_id, since, until)
|
||||
.await?;
|
||||
let filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.namespace == namespace)
|
||||
.take(limit)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Store a memory entry with namespace and importance.
|
||||
///
|
||||
/// Default implementation delegates to `store()`. Backends with native
|
||||
/// namespace/importance support should override.
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
_namespace: Option<&str>,
|
||||
_importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.store(key, content, category, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -185,6 +240,9 @@ mod tests {
|
||||
timestamp: "2026-02-16T00:00:00Z".into(),
|
||||
session_id: Some("session-abc".into()),
|
||||
score: Some(0.98),
|
||||
namespace: "default".into(),
|
||||
importance: Some(0.7),
|
||||
superseded_by: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
@@ -196,5 +254,8 @@ mod tests {
|
||||
assert_eq!(parsed.category, MemoryCategory::Core);
|
||||
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
|
||||
assert_eq!(parsed.score, Some(0.98));
|
||||
assert_eq!(parsed.namespace, "default");
|
||||
assert_eq!(parsed.importance, Some(0.7));
|
||||
assert!(parsed.superseded_by.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,6 +126,7 @@ pub fn hybrid_merge(
|
||||
b.final_score
|
||||
.partial_cmp(&a.final_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| a.id.cmp(&b.id))
|
||||
});
|
||||
results.truncate(limit);
|
||||
results
|
||||
|
||||
@@ -507,6 +507,7 @@ mod tests {
|
||||
max_images: 1,
|
||||
max_image_size_mb: 5,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
@@ -549,6 +550,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
|
||||
+12
-1
@@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -420,9 +421,17 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
retrieval_stages: vec!["cache".into(), "fts".into(), "vector".into()],
|
||||
rerank_enabled: false,
|
||||
rerank_threshold: 5,
|
||||
fts_early_return_score: 0.85,
|
||||
default_namespace: "default".into(),
|
||||
conflict_threshold: 0.85,
|
||||
audit_enabled: false,
|
||||
audit_retention_days: 30,
|
||||
policy: crate::config::MemoryPolicyConfig::default(),
|
||||
sqlite_open_timeout_secs: None,
|
||||
qdrant: crate::config::QdrantConfig::default(),
|
||||
mem0: crate::config::schema::Mem0Config::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -597,6 +606,7 @@ async fn run_quick_setup_with_home(
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -4185,6 +4195,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
device_id: detected_device_id,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -767,6 +767,12 @@ impl Provider for OpenAiCodexProvider {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Mutex that serializes all tests which mutate process-global env vars
|
||||
/// (`std::env::set_var` / `remove_var`). Each such test must hold this
|
||||
/// lock for its entire duration so that parallel test threads don't race.
|
||||
static ENV_MUTEX: Mutex<()> = Mutex::new(());
|
||||
|
||||
struct EnvGuard {
|
||||
key: &'static str,
|
||||
@@ -841,6 +847,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_responses_url_prefers_explicit_endpoint_env() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _endpoint_guard = EnvGuard::set(
|
||||
CODEX_RESPONSES_URL_ENV,
|
||||
Some("https://env.example.com/v1/responses"),
|
||||
@@ -856,6 +863,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_responses_url_uses_provider_api_url_override() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
|
||||
let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
|
||||
|
||||
@@ -959,6 +967,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_reasoning_effort_prefers_configured_override() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("low"));
|
||||
assert_eq!(
|
||||
resolve_reasoning_effort("gpt-5-codex", Some("high")),
|
||||
@@ -968,6 +977,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn resolve_reasoning_effort_uses_legacy_env_when_unconfigured() {
|
||||
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("minimal"));
|
||||
assert_eq!(
|
||||
resolve_reasoning_effort("gpt-5-codex", None),
|
||||
|
||||
+3
-3
@@ -1236,7 +1236,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_capture_reads_stdout() {
|
||||
let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
|
||||
let out = run_capture(Command::new("sh").args(["-c", "echo hello"]))
|
||||
.expect("stdout capture should succeed");
|
||||
assert_eq!(out.trim(), "hello");
|
||||
}
|
||||
@@ -1244,7 +1244,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_capture_falls_back_to_stderr() {
|
||||
let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
|
||||
let out = run_capture(Command::new("sh").args(["-c", "echo warn 1>&2"]))
|
||||
.expect("stderr capture should succeed");
|
||||
assert_eq!(out.trim(), "warn");
|
||||
}
|
||||
@@ -1252,7 +1252,7 @@ mod tests {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[test]
|
||||
fn run_checked_errors_on_non_zero_status() {
|
||||
let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
|
||||
let err = run_checked(Command::new("sh").args(["-c", "exit 17"]))
|
||||
.expect_err("non-zero exit should error");
|
||||
assert!(err.to_string().contains("Command failed"));
|
||||
}
|
||||
|
||||
@@ -106,8 +106,14 @@ impl SkillCreator {
|
||||
// Trim leading/trailing hyphens, then truncate.
|
||||
let trimmed = collapsed.trim_matches('-');
|
||||
if trimmed.len() > 64 {
|
||||
// Truncate at a hyphen boundary if possible.
|
||||
let truncated = &trimmed[..64];
|
||||
// Find the nearest valid character boundary at or before 64 bytes.
|
||||
let safe_index = trimmed
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 64)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
let truncated = &trimmed[..safe_index];
|
||||
truncated.trim_end_matches('-').to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
|
||||
+89
-14
@@ -738,15 +738,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
}
|
||||
|
||||
if !skill.tools.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &skill.tools {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
// Tools with known kinds (shell, script, http) are registered as
|
||||
// callable tool specs and can be invoked directly via function calling.
|
||||
// We note them here for context but mark them as callable.
|
||||
let registered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
let unregistered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| !matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
|
||||
if !registered.is_empty() {
|
||||
let _ = writeln!(prompt, " <callable_tools hint=\"These are registered as callable tool specs. Invoke them directly by name ({{}}.{{}}) instead of using shell.\">");
|
||||
for tool in ®istered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(
|
||||
&mut prompt,
|
||||
8,
|
||||
"name",
|
||||
&format!("{}.{}", skill.name, tool.name),
|
||||
);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </callable_tools>");
|
||||
}
|
||||
|
||||
if !unregistered.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &unregistered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
|
||||
let _ = writeln!(prompt, " </skill>");
|
||||
@@ -756,6 +788,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Convert skill tools into callable `Tool` trait objects.
|
||||
///
|
||||
/// Each skill's `[[tools]]` entries are converted to either `SkillShellTool`
|
||||
/// (for `shell`/`script` kinds) or `SkillHttpTool` (for `http` kind),
|
||||
/// enabling them to appear as first-class callable tool specs rather than
|
||||
/// only as XML in the system prompt.
|
||||
pub fn skills_to_tools(
|
||||
skills: &[Skill],
|
||||
security: std::sync::Arc<crate::security::SecurityPolicy>,
|
||||
) -> Vec<Box<dyn crate::tools::traits::Tool>> {
|
||||
let mut tools: Vec<Box<dyn crate::tools::traits::Tool>> = Vec::new();
|
||||
for skill in skills {
|
||||
for tool in &skill.tools {
|
||||
match tool.kind.as_str() {
|
||||
"shell" | "script" => {
|
||||
tools.push(Box::new(crate::tools::skill_tool::SkillShellTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
"http" => {
|
||||
tools.push(Box::new(crate::tools::skill_http::SkillHttpTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
)));
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(
|
||||
"Unknown skill tool kind '{}' for {}.{}, skipping",
|
||||
other,
|
||||
skill.name,
|
||||
tool.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Get the skills directory path
|
||||
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
|
||||
workspace_dir.join("skills")
|
||||
@@ -1517,10 +1590,10 @@ command = "echo hello"
|
||||
assert!(prompt.contains("read_skill(name)"));
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt.contains("<instruction>Do the thing.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>run</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell/script/http) appear under <callable_tools>.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>test.run</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1710,9 +1783,11 @@ description = "Bare minimum"
|
||||
}];
|
||||
let prompt = skills_to_prompt(&skills, Path::new("/tmp"));
|
||||
assert!(prompt.contains("weather"));
|
||||
assert!(prompt.contains("<name>get_weather</name>"));
|
||||
// Registered tools (shell kind) now appear under <callable_tools> with
|
||||
// prefixed names (skill_name.tool_name).
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>weather.get_weather</name>"));
|
||||
assert!(prompt.contains("<description>Fetch forecast</description>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -81,6 +81,8 @@ pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod sessions;
|
||||
pub mod shell;
|
||||
pub mod skill_http;
|
||||
pub mod skill_tool;
|
||||
pub mod swarm;
|
||||
pub mod text_browser;
|
||||
pub mod tool_search;
|
||||
@@ -156,6 +158,10 @@ pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
|
||||
pub use shell::ShellTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_http::SkillHttpTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_tool::SkillShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use text_browser::TextBrowserTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
@@ -257,6 +263,33 @@ pub fn default_tools_with_runtime(
|
||||
]
|
||||
}
|
||||
|
||||
/// Register skill-defined tools into an existing tool registry.
|
||||
///
|
||||
/// Converts each skill's `[[tools]]` entries into callable `Tool` implementations
|
||||
/// and appends them to the registry. Skill tools that would shadow a built-in tool
|
||||
/// name are skipped with a warning.
|
||||
pub fn register_skill_tools(
|
||||
tools_registry: &mut Vec<Box<dyn Tool>>,
|
||||
skills: &[crate::skills::Skill],
|
||||
security: Arc<SecurityPolicy>,
|
||||
) {
|
||||
let skill_tools = crate::skills::skills_to_tools(skills, security);
|
||||
let existing_names: std::collections::HashSet<String> = tools_registry
|
||||
.iter()
|
||||
.map(|t| t.name().to_string())
|
||||
.collect();
|
||||
for tool in skill_tools {
|
||||
if existing_names.contains(tool.name()) {
|
||||
tracing::warn!(
|
||||
"Skill tool '{}' shadows built-in tool, skipping",
|
||||
tool.name()
|
||||
);
|
||||
} else {
|
||||
tools_registry.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create full tool registry including memory tools and optional Composio
|
||||
#[allow(clippy::implicit_hasher, clippy::too_many_arguments)]
|
||||
pub fn all_tools(
|
||||
@@ -458,6 +491,7 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(WebSearchTool::new_with_config(
|
||||
root_config.web_search.provider.clone(),
|
||||
root_config.web_search.brave_api_key.clone(),
|
||||
root_config.web_search.searxng_instance_url.clone(),
|
||||
root_config.web_search.max_results,
|
||||
root_config.web_search.timeout_secs,
|
||||
root_config.config_path.clone(),
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
//! HTTP-based tool derived from a skill's `[[tools]]` section.
|
||||
//!
|
||||
//! Each `SkillTool` with `kind = "http"` is converted into a `SkillHttpTool`
|
||||
//! that implements the `Tool` trait. The command field is used as the URL
|
||||
//! template and args are substituted as query parameters or path segments.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum response body size (1 MB).
|
||||
const MAX_RESPONSE_BYTES: usize = 1_048_576;
|
||||
/// HTTP request timeout (seconds).
|
||||
const HTTP_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// A tool derived from a skill's `[[tools]]` section that makes HTTP requests.
|
||||
pub struct SkillHttpTool {
|
||||
tool_name: String,
|
||||
tool_description: String,
|
||||
url_template: String,
|
||||
args: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl SkillHttpTool {
|
||||
/// Create a new skill HTTP tool.
|
||||
///
|
||||
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
|
||||
/// to prevent collisions with built-in tools.
|
||||
pub fn new(skill_name: &str, tool: &crate::skills::SkillTool) -> Self {
|
||||
Self {
|
||||
tool_name: format!("{}.{}", skill_name, tool.name),
|
||||
tool_description: tool.description.clone(),
|
||||
url_template: tool.command.clone(),
|
||||
args: tool.args.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_parameters_schema(&self) -> serde_json::Value {
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut required = Vec::new();
|
||||
|
||||
for (name, description) in &self.args {
|
||||
properties.insert(
|
||||
name.clone(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": description
|
||||
}),
|
||||
);
|
||||
required.push(serde_json::Value::String(name.clone()));
|
||||
}
|
||||
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
})
|
||||
}
|
||||
|
||||
/// Substitute `{{arg_name}}` placeholders in the URL template with
|
||||
/// the provided argument values.
|
||||
fn substitute_args(&self, args: &serde_json::Value) -> String {
|
||||
let mut url = self.url_template.clone();
|
||||
if let Some(obj) = args.as_object() {
|
||||
for (key, value) in obj {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
let replacement = value.as_str().unwrap_or_default();
|
||||
url = url.replace(&placeholder, replacement);
|
||||
}
|
||||
}
|
||||
url
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillHttpTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.tool_description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.build_parameters_schema()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let url = self.substitute_args(&args);
|
||||
|
||||
// Validate URL scheme
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Only http:// and https:// URLs are allowed, got: {url}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.build()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build HTTP client: {e}"))?;
|
||||
|
||||
let response = match client.get(&url).send().await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("HTTP request failed: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
let mut text = String::from_utf8_lossy(&bytes).to_string();
|
||||
if text.len() > MAX_RESPONSE_BYTES {
|
||||
let mut b = MAX_RESPONSE_BYTES.min(text.len());
|
||||
while b > 0 && !text.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
text.truncate(b);
|
||||
text.push_str("\n... [response truncated at 1MB]");
|
||||
}
|
||||
text
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to read response body: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
success: status.is_success(),
|
||||
output: body,
|
||||
error: if status.is_success() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("HTTP {}", status))
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::skills::SkillTool;
|
||||
|
||||
fn sample_http_tool() -> SkillTool {
|
||||
let mut args = HashMap::new();
|
||||
args.insert("city".to_string(), "City name to look up".to_string());
|
||||
|
||||
SkillTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Fetch weather for a city".to_string(),
|
||||
kind: "http".to_string(),
|
||||
command: "https://api.example.com/weather?city={{city}}".to_string(),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_name_is_prefixed() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
assert_eq!(tool.name(), "weather_skill.get_weather");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_description() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
assert_eq!(tool.description(), "Fetch weather for a city");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_parameters_schema() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let schema = tool.parameters_schema();
|
||||
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"]["city"].is_object());
|
||||
assert_eq!(schema["properties"]["city"]["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_substitute_args() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let result = tool.substitute_args(&serde_json::json!({"city": "London"}));
|
||||
assert_eq!(result, "https://api.example.com/weather?city=London");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_spec_roundtrip() {
|
||||
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "weather_skill.get_weather");
|
||||
assert_eq!(spec.description, "Fetch weather for a city");
|
||||
assert_eq!(spec.parameters["type"], "object");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_http_tool_empty_args() {
|
||||
let st = SkillTool {
|
||||
name: "ping".to_string(),
|
||||
description: "Ping endpoint".to_string(),
|
||||
kind: "http".to_string(),
|
||||
command: "https://api.example.com/ping".to_string(),
|
||||
args: HashMap::new(),
|
||||
};
|
||||
let tool = SkillHttpTool::new("s", &st);
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"].as_object().unwrap().is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
//! Shell-based tool derived from a skill's `[[tools]]` section.
|
||||
//!
|
||||
//! Each `SkillTool` with `kind = "shell"` or `kind = "script"` is converted
|
||||
//! into a `SkillShellTool` that implements the `Tool` trait. The tool name is
|
||||
//! prefixed with the skill name (e.g. `my_skill.run_lint`) to avoid collisions
|
||||
//! with built-in tools.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum execution time for a skill shell command (seconds).
|
||||
const SKILL_SHELL_TIMEOUT_SECS: u64 = 60;
|
||||
/// Maximum output size in bytes (1 MB).
|
||||
const MAX_OUTPUT_BYTES: usize = 1_048_576;
|
||||
|
||||
/// A tool derived from a skill's `[[tools]]` section that executes shell commands.
|
||||
pub struct SkillShellTool {
|
||||
tool_name: String,
|
||||
tool_description: String,
|
||||
command_template: String,
|
||||
args: HashMap<String, String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SkillShellTool {
|
||||
/// Create a new skill shell tool.
|
||||
///
|
||||
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
|
||||
/// to prevent collisions with built-in tools.
|
||||
pub fn new(
|
||||
skill_name: &str,
|
||||
tool: &crate::skills::SkillTool,
|
||||
security: Arc<SecurityPolicy>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_name: format!("{}.{}", skill_name, tool.name),
|
||||
tool_description: tool.description.clone(),
|
||||
command_template: tool.command.clone(),
|
||||
args: tool.args.clone(),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_parameters_schema(&self) -> serde_json::Value {
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut required = Vec::new();
|
||||
|
||||
for (name, description) in &self.args {
|
||||
properties.insert(
|
||||
name.clone(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": description
|
||||
}),
|
||||
);
|
||||
required.push(serde_json::Value::String(name.clone()));
|
||||
}
|
||||
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
})
|
||||
}
|
||||
|
||||
/// Substitute `{{arg_name}}` placeholders in the command template with
|
||||
/// the provided argument values. Unknown placeholders are left as-is.
|
||||
fn substitute_args(&self, args: &serde_json::Value) -> String {
|
||||
let mut command = self.command_template.clone();
|
||||
if let Some(obj) = args.as_object() {
|
||||
for (key, value) in obj {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
let replacement = value.as_str().unwrap_or_default();
|
||||
command = command.replace(&placeholder, replacement);
|
||||
}
|
||||
}
|
||||
command
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillShellTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.tool_description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.build_parameters_schema()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = self.substitute_args(&args);
|
||||
|
||||
// Rate limit check
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Security validation — always requires explicit approval (approved=true)
|
||||
// since skill tools are user-defined and should be treated as medium-risk.
|
||||
match self.security.validate_command_execution(&command, true) {
|
||||
Ok(_) => {}
|
||||
Err(reason) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(reason),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = self.security.forbidden_path_argument(&command) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Path blocked by security policy: {path}")),
|
||||
});
|
||||
}
|
||||
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build and execute the command
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(&command);
|
||||
cmd.current_dir(&self.security.workspace_dir);
|
||||
cmd.env_clear();
|
||||
|
||||
// Only pass safe environment variables
|
||||
for var in &[
|
||||
"PATH", "HOME", "TERM", "LANG", "LC_ALL", "USER", "SHELL", "TMPDIR",
|
||||
] {
|
||||
if let Ok(val) = std::env::var(var) {
|
||||
cmd.env(var, val);
|
||||
}
|
||||
}
|
||||
|
||||
let result =
|
||||
tokio::time::timeout(Duration::from_secs(SKILL_SHELL_TIMEOUT_SECS), cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let mut stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
if stdout.len() > MAX_OUTPUT_BYTES {
|
||||
let mut b = MAX_OUTPUT_BYTES.min(stdout.len());
|
||||
while b > 0 && !stdout.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stdout.truncate(b);
|
||||
stdout.push_str("\n... [output truncated at 1MB]");
|
||||
}
|
||||
if stderr.len() > MAX_OUTPUT_BYTES {
|
||||
let mut b = MAX_OUTPUT_BYTES.min(stderr.len());
|
||||
while b > 0 && !stderr.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stderr.truncate(b);
|
||||
stderr.push_str("\n... [stderr truncated at 1MB]");
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to execute command: {e}")),
|
||||
}),
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {SKILL_SHELL_TIMEOUT_SECS}s and was killed"
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use crate::skills::SkillTool;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn sample_skill_tool() -> SkillTool {
|
||||
let mut args = HashMap::new();
|
||||
args.insert("file".to_string(), "The file to lint".to_string());
|
||||
args.insert(
|
||||
"format".to_string(),
|
||||
"Output format (json|text)".to_string(),
|
||||
);
|
||||
|
||||
SkillTool {
|
||||
name: "run_lint".to_string(),
|
||||
description: "Run the linter on a file".to_string(),
|
||||
kind: "shell".to_string(),
|
||||
command: "lint --file {{file}} --format {{format}}".to_string(),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_name_is_prefixed() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
assert_eq!(tool.name(), "my_skill.run_lint");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_description() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
assert_eq!(tool.description(), "Run the linter on a file");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_parameters_schema() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
let schema = tool.parameters_schema();
|
||||
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"]["file"].is_object());
|
||||
assert_eq!(schema["properties"]["file"]["type"], "string");
|
||||
assert!(schema["properties"]["format"].is_object());
|
||||
|
||||
let required = schema["required"]
|
||||
.as_array()
|
||||
.expect("required should be array");
|
||||
assert_eq!(required.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_substitute_args() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
let result = tool.substitute_args(&serde_json::json!({
|
||||
"file": "src/main.rs",
|
||||
"format": "json"
|
||||
}));
|
||||
assert_eq!(result, "lint --file src/main.rs --format json");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_substitute_missing_arg() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
let result = tool.substitute_args(&serde_json::json!({"file": "test.rs"}));
|
||||
// Missing {{format}} placeholder stays in the command
|
||||
assert!(result.contains("{{format}}"));
|
||||
assert!(result.contains("test.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_empty_args_schema() {
|
||||
let st = SkillTool {
|
||||
name: "simple".to_string(),
|
||||
description: "Simple tool".to_string(),
|
||||
kind: "shell".to_string(),
|
||||
command: "echo hello".to_string(),
|
||||
args: HashMap::new(),
|
||||
};
|
||||
let tool = SkillShellTool::new("s", &st, test_security());
|
||||
let schema = tool.parameters_schema();
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"].as_object().unwrap().is_empty());
|
||||
assert!(schema["required"].as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn skill_shell_tool_executes_echo() {
|
||||
let st = SkillTool {
|
||||
name: "hello".to_string(),
|
||||
description: "Say hello".to_string(),
|
||||
kind: "shell".to_string(),
|
||||
command: "echo hello-skill".to_string(),
|
||||
args: HashMap::new(),
|
||||
};
|
||||
let tool = SkillShellTool::new("test", &st, test_security());
|
||||
let result = tool.execute(serde_json::json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("hello-skill"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_shell_tool_spec_roundtrip() {
|
||||
let tool = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "my_skill.run_lint");
|
||||
assert_eq!(spec.description, "Run the linter on a file");
|
||||
assert_eq!(spec.parameters["type"], "object");
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
pub enum WebSearchProviderRoute {
|
||||
DuckDuckGo,
|
||||
Brave,
|
||||
SearXNG,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -13,6 +14,7 @@ pub struct WebSearchProviderResolution {
|
||||
|
||||
pub const DEFAULT_WEB_SEARCH_PROVIDER: &str = "duckduckgo";
|
||||
const BRAVE_PROVIDER: &str = "brave";
|
||||
const SEARXNG_PROVIDER: &str = "searxng";
|
||||
|
||||
pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResolution {
|
||||
let normalized = raw_provider.trim().to_ascii_lowercase();
|
||||
@@ -29,6 +31,11 @@ pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResol
|
||||
canonical_provider: BRAVE_PROVIDER,
|
||||
used_fallback: false,
|
||||
},
|
||||
"searxng" | "searx" | "searx-ng" | "searx_ng" => WebSearchProviderResolution {
|
||||
route: WebSearchProviderRoute::SearXNG,
|
||||
canonical_provider: SEARXNG_PROVIDER,
|
||||
used_fallback: false,
|
||||
},
|
||||
_ => WebSearchProviderResolution {
|
||||
route: WebSearchProviderRoute::DuckDuckGo,
|
||||
canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER,
|
||||
@@ -63,6 +70,17 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_aliases_to_searxng() {
|
||||
let searxng_aliases = ["searxng", "searx", "searx-ng", "searx_ng"];
|
||||
for alias in searxng_aliases {
|
||||
let resolved = resolve_web_search_provider(alias);
|
||||
assert_eq!(resolved.route, WebSearchProviderRoute::SearXNG);
|
||||
assert_eq!(resolved.canonical_provider, SEARXNG_PROVIDER);
|
||||
assert!(!resolved.used_fallback);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_unknown_provider_falls_back_to_default() {
|
||||
let resolved = resolve_web_search_provider("bing");
|
||||
|
||||
@@ -7,7 +7,8 @@ use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Web search tool for searching the internet.
|
||||
/// Supports multiple providers: DuckDuckGo (free), Brave (requires API key).
|
||||
/// Supports multiple providers: DuckDuckGo (free), Brave (requires API key),
|
||||
/// SearXNG (self-hosted, requires instance URL).
|
||||
///
|
||||
/// The Brave API key is resolved lazily at execution time: if the boot-time key
|
||||
/// is missing or still encrypted, the tool re-reads `config.toml`, decrypts the
|
||||
@@ -18,6 +19,8 @@ pub struct WebSearchTool {
|
||||
provider: String,
|
||||
/// Boot-time key snapshot (may be `None` if not yet configured at startup).
|
||||
boot_brave_api_key: Option<String>,
|
||||
/// SearXNG instance base URL (e.g. "https://searx.example.com").
|
||||
searxng_instance_url: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
/// Path to `config.toml` for lazy re-read of keys at execution time.
|
||||
@@ -36,6 +39,7 @@ impl WebSearchTool {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
boot_brave_api_key: brave_api_key,
|
||||
searxng_instance_url: None,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path: PathBuf::new(),
|
||||
@@ -51,6 +55,7 @@ impl WebSearchTool {
|
||||
pub fn new_with_config(
|
||||
provider: String,
|
||||
brave_api_key: Option<String>,
|
||||
searxng_instance_url: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
config_path: PathBuf,
|
||||
@@ -59,6 +64,7 @@ impl WebSearchTool {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
boot_brave_api_key: brave_api_key,
|
||||
searxng_instance_url,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path,
|
||||
@@ -248,6 +254,105 @@ impl WebSearchTool {
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
/// Resolve the SearXNG instance URL from the boot-time config or by
|
||||
/// re-reading `config.toml` at runtime.
|
||||
fn resolve_searxng_instance_url(&self) -> anyhow::Result<String> {
|
||||
if let Some(ref url) = self.searxng_instance_url {
|
||||
if !url.is_empty() {
|
||||
return Ok(url.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: re-read config.toml to pick up values set after boot.
|
||||
let contents = std::fs::read_to_string(&self.config_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read config file {} for SearXNG instance URL: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let config: crate::config::Config = toml::from_str(&contents).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to parse config file {} for SearXNG instance URL: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
config
|
||||
.web_search
|
||||
.searxng_instance_url
|
||||
.filter(|u| !u.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"SearXNG instance URL not configured. Set [web_search] searxng_instance_url \
|
||||
in config.toml or the SEARXNG_INSTANCE_URL environment variable."
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn search_searxng(&self, query: &str) -> anyhow::Result<String> {
|
||||
let instance_url = self.resolve_searxng_instance_url()?;
|
||||
let base_url = instance_url.trim_end_matches('/');
|
||||
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let search_url = format!(
|
||||
"{}/search?q={}&format=json&pageno=1",
|
||||
base_url, encoded_query
|
||||
);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_secs))
|
||||
.user_agent("ZeroClaw/1.0")
|
||||
.build()?;
|
||||
|
||||
let response = client
|
||||
.get(&search_url)
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("SearXNG search failed with status: {}", response.status());
|
||||
}
|
||||
|
||||
let json: serde_json::Value = response.json().await?;
|
||||
self.parse_searxng_results(&json, query)
|
||||
}
|
||||
|
||||
fn parse_searxng_results(
|
||||
&self,
|
||||
json: &serde_json::Value,
|
||||
query: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let results = json
|
||||
.get("results")
|
||||
.and_then(|r| r.as_array())
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid SearXNG API response"))?;
|
||||
|
||||
if results.is_empty() {
|
||||
return Ok(format!("No results found for: {}", query));
|
||||
}
|
||||
|
||||
let mut lines = vec![format!("Search results for: {} (via SearXNG)", query)];
|
||||
|
||||
for (i, result) in results.iter().take(self.max_results).enumerate() {
|
||||
let title = result
|
||||
.get("title")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("No title");
|
||||
let url = result.get("url").and_then(|u| u.as_str()).unwrap_or("");
|
||||
let content = result.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
|
||||
lines.push(format!("{}. {}", i + 1, title));
|
||||
lines.push(format!(" {}", url));
|
||||
if !content.is_empty() {
|
||||
lines.push(format!(" {}", content));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_ddg_redirect_url(raw_url: &str) -> String {
|
||||
@@ -314,6 +419,7 @@ impl Tool for WebSearchTool {
|
||||
let result = match resolution.route {
|
||||
WebSearchProviderRoute::DuckDuckGo => self.search_duckduckgo(query).await?,
|
||||
WebSearchProviderRoute::Brave => self.search_brave(query).await?,
|
||||
WebSearchProviderRoute::SearXNG => self.search_searxng(query).await?,
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
@@ -443,8 +549,15 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// No boot key -- forces reload from config
|
||||
let tool =
|
||||
WebSearchTool::new_with_config("brave".to_string(), None, 5, 15, config_path, false);
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
false,
|
||||
);
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "fresh-key-from-disk");
|
||||
}
|
||||
@@ -466,6 +579,7 @@ mod tests {
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
Some(encrypted),
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
@@ -475,6 +589,111 @@ mod tests {
|
||||
assert_eq!(key, "brave-secret-key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_searxng_without_instance_url() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
std::fs::write(&config_path, "[web_search]\n").unwrap();
|
||||
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"searxng".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
false,
|
||||
);
|
||||
let result = tool.execute(json!({"query": "test"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("SearXNG instance URL not configured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_searxng_results_empty() {
|
||||
let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
|
||||
let json = serde_json::json!({"results": []});
|
||||
let result = tool.parse_searxng_results(&json, "test").unwrap();
|
||||
assert!(result.contains("No results found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_searxng_results_with_data() {
|
||||
let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
|
||||
let json = serde_json::json!({
|
||||
"results": [
|
||||
{
|
||||
"title": "SearXNG Example",
|
||||
"url": "https://example.com",
|
||||
"content": "A privacy-respecting metasearch engine"
|
||||
},
|
||||
{
|
||||
"title": "Another Result",
|
||||
"url": "https://example.org",
|
||||
"content": "More information here"
|
||||
}
|
||||
]
|
||||
});
|
||||
let result = tool.parse_searxng_results(&json, "test").unwrap();
|
||||
assert!(result.contains("SearXNG Example"));
|
||||
assert!(result.contains("https://example.com"));
|
||||
assert!(result.contains("A privacy-respecting metasearch engine"));
|
||||
assert!(result.contains("via SearXNG"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_searxng_results_invalid_response() {
|
||||
let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
|
||||
let json = serde_json::json!({"error": "bad request"});
|
||||
let result = tool.parse_searxng_results(&json, "test");
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Invalid SearXNG API response"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_searxng_instance_url_from_boot() {
|
||||
let tool = WebSearchTool {
|
||||
provider: "searxng".to_string(),
|
||||
boot_brave_api_key: None,
|
||||
searxng_instance_url: Some("https://searx.example.com".to_string()),
|
||||
max_results: 5,
|
||||
timeout_secs: 15,
|
||||
config_path: PathBuf::new(),
|
||||
secrets_encrypt: false,
|
||||
};
|
||||
let url = tool.resolve_searxng_instance_url().unwrap();
|
||||
assert_eq!(url, "https://searx.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_searxng_instance_url_reloads_from_config() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
"[web_search]\nsearxng_instance_url = \"https://search.local\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"searxng".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
false,
|
||||
);
|
||||
let url = tool.resolve_searxng_instance_url().unwrap();
|
||||
assert_eq!(url, "https://search.local");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_picks_up_runtime_update() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
@@ -486,6 +705,7 @@ mod tests {
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path.clone(),
|
||||
|
||||
@@ -401,7 +401,7 @@ fn config_nested_optional_sections_default_when_absent() {
|
||||
assert!(parsed.channels_config.telegram.is_none());
|
||||
assert!(!parsed.composio.enabled);
|
||||
assert!(parsed.composio.api_key.is_none());
|
||||
assert!(!parsed.browser.enabled);
|
||||
assert!(parsed.browser.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
ChevronRight,
|
||||
Terminal,
|
||||
Package,
|
||||
ChevronsUpDown,
|
||||
} from 'lucide-react';
|
||||
import type { ToolSpec, CliTool } from '@/types/api';
|
||||
import { getTools, getCliTools } from '@/lib/api';
|
||||
|
||||
Reference in New Issue
Block a user