Compare commits
60 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 70e7910cb9 | |
| | a8868768e8 | |
| | 67293c50df | |
| | 1646079d25 | |
| | 25b639435f | |
| | 77779844e5 | |
| | f658d5806a | |
| | 7134fe0824 | |
| | 263802b3df | |
| | 3c25fddb2a | |
| | a6a46bdd25 | |
| | 235d4d2f1c | |
| | bd1e8c8e1a | |
| | f81807bff6 | |
| | bb7006313c | |
| | 9a49626376 | |
| | 8b978a721f | |
| | 75b4c1d4a4 | |
| | b2087e6065 | |
| | ad8f81ad76 | |
| | c58e1c1fb3 | |
| | cb0779d761 | |
| | daca2d9354 | |
| | 3c1e710c38 | |
| | 0aefde95f2 | |
| | a84aa60554 | |
| | edd4b37325 | |
| | c5f0155061 | |
| | 9ee06ed6fc | |
| | ac6b43e9f4 | |
| | 6c5573ad96 | |
| | 1d57a0d1e5 | |
| | 9780c7d797 | |
| | 35a5451a17 | |
| | 8e81d44d54 | |
| | 86ad0c6a2b | |
| | 6ecf89d6a9 | |
| | 691efa4d8c | |
| | d1e3f435b4 | |
| | 44c3e264ad | |
| | f2b6013329 | |
| | 05d3c51a30 | |
| | 2ceda31ce2 | |
| | 9069bc3c1f | |
| | 9319fe18da | |
| | cc454a86c8 | |
| | 256e8ccebf | |
| | 72c9e6b6ca | |
| | 755a129ca2 | |
| | 8b0d3684c5 | |
| | cdb5ac1471 | |
| | 67acb1a0bb | |
| | 9eac6bafef | |
| | a12f2ff439 | |
| | a38a4d132e | |
| | 48aba73d3a | |
| | a1ab1e1a11 | |
| | f394abf35c | |
| | 52e0271bd5 | |
| | 6c0a48efff | |
@@ -1,97 +0,0 @@
# Mem0 Integration: Dual-Scope Recall + Per-Turn Memory

## Context

Mem0 auto-save works, but the integration is missing key features from mem0 best practices: per-turn recall, multi-level scoping, and proper context injection. As a result, the bot "forgets" on follow-up turns and cannot differentiate users.

## What's Missing (vs mem0 docs)

1. **Per-turn recall** — only the first turn gets memory context; follow-up turns get nothing
2. **Dual-scope** — no sender vs group distinction; all memories use a single hardcoded `user_id`
3. **System prompt injection** — memory is prepended to the user message, which pollutes session history
4. **`agent_id` scoping** — mem0 supports agent-level patterns, but they are not used

## Changes

### 1. `src/memory/mem0.rs` — Use session_id for multi-level scoping

Map zeroclaw's `session_id` param to mem0's `user_id`. This enables per-user and per-group memory namespaces without changing the `Memory` trait.

```rust
// Add helper:
fn effective_user_id(&self, session_id: Option<&str>) -> &str {
    session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
}

// In store():  use effective_user_id(session_id) as the mem0 user_id
// In recall(): use effective_user_id(session_id) as the mem0 user_id
// In list():   use effective_user_id(session_id) as the mem0 user_id
```
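
The fallback rule can be exercised standalone. A minimal sketch, where `Mem0Store` and its `user_id` field are hypothetical stand-ins for the real backend struct:

```rust
// Hypothetical stand-in for the real mem0 backend struct.
struct Mem0Store {
    user_id: String, // configured default scope
}

impl Mem0Store {
    // Same logic as the helper above: a non-empty session id wins,
    // otherwise fall back to the configured user_id.
    fn effective_user_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
        session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
    }
}

fn main() {
    let store = Mem0Store { user_id: "default".to_string() };
    // A non-empty session id becomes the mem0 user_id...
    assert_eq!(store.effective_user_id(Some("user-123")), "user-123");
    // ...while empty or absent ids fall back to the configured scope.
    assert_eq!(store.effective_user_id(Some("")), "default");
    assert_eq!(store.effective_user_id(None), "default");
    println!("ok");
}
```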

### 2. `src/channels/mod.rs` ~line 2229 — Per-turn dual-scope recall

Remove the `if !had_prior_history` gate. Always recall from the sender scope, and also from the group scope for group chats.

```rust
// Detect group chat
let is_group = msg.reply_target.contains("@g.us")
    || msg.reply_target.starts_with("group:");

// Sender-scope recall (always)
let sender_context = build_memory_context(
    ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
    Some(&msg.sender),
).await;

// Group-scope recall (groups only)
let group_context = if is_group {
    build_memory_context(
        ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
        Some(&history_key),
    ).await
} else {
    String::new()
};

// Merge (deduplicate by checking substring overlap)
let memory_context = merge_memory_contexts(&sender_context, &group_context);
```
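
The `merge_memory_contexts` helper is referenced but not defined here. A minimal sketch of one possible implementation, assuming the substring-overlap dedup rule mentioned in the comment (the function body is illustrative, not the actual code):

```rust
// Hypothetical sketch: merge two recalled contexts, dropping group lines
// that already appear (as substrings) in the sender context.
fn merge_memory_contexts(sender: &str, group: &str) -> String {
    let mut merged = sender.trim().to_string();
    for line in group.lines() {
        let line = line.trim();
        // Skip empties and lines the sender scope already covers.
        if line.is_empty() || merged.contains(line) {
            continue;
        }
        if !merged.is_empty() {
            merged.push('\n');
        }
        merged.push_str(line);
    }
    merged
}

fn main() {
    let sender = "Joe likes sushi";
    let group = "Joe likes sushi\nThe group meets on Fridays";
    // The duplicate first line is dropped; the new group fact is appended.
    let merged = merge_memory_contexts(sender, group);
    assert_eq!(merged, "Joe likes sushi\nThe group meets on Fridays");
    println!("{merged}");
}
```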

### 3. `src/channels/mod.rs` ~line 2244 — Inject into system prompt

Move the memory context from the user message into the system prompt. It is re-fetched each turn and does not pollute the session history.

```rust
let mut system_prompt = build_channel_system_prompt(...);
if !memory_context.is_empty() {
    system_prompt.push_str(&format!("\n\n{memory_context}"));
}
let mut history = vec![ChatMessage::system(system_prompt)];
```

### 4. `src/channels/mod.rs` — Dual-scope auto-save

Find the existing auto-save call. For group messages, store twice:

- `store(key, content, category, Some(&msg.sender))` — personal facts
- `store(key, content, category, Some(&history_key))` — group context

Both stores are async and non-blocking. DMs store to the sender scope only.
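
The branching rule above can be sketched synchronously with a mock store; the real code would issue two non-blocking async `store` calls, and `MockStore`/`auto_save` here are hypothetical names for illustration:

```rust
// Mock store that just records which scopes were written.
#[derive(Default)]
struct MockStore {
    writes: Vec<String>,
}

impl MockStore {
    fn store(&mut self, _key: &str, _content: &str, _category: &str, scope: &str) {
        self.writes.push(scope.to_string());
    }
}

// Dual-scope save rule: sender scope always, group scope only for groups.
fn auto_save(store: &mut MockStore, is_group: bool, sender: &str, history_key: &str) {
    store.store("k", "content", "fact", sender);
    if is_group {
        store.store("k", "content", "fact", history_key);
    }
}

fn main() {
    // DM: single write to the sender scope.
    let mut dm = MockStore::default();
    auto_save(&mut dm, false, "alice", "group:123");
    assert_eq!(dm.writes, vec!["alice"]);

    // Group: writes to both scopes.
    let mut group = MockStore::default();
    auto_save(&mut group, true, "alice", "group:123");
    assert_eq!(group.writes, vec!["alice", "group:123"]);
    println!("ok");
}
```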

### 5. `src/memory/mem0.rs` — Add `agent_id` support (optional)

Pass `self.app_name` as the `agent_id` param to the mem0 API for agent-level behavior tracking.

## Files to Modify

1. `src/memory/mem0.rs` — session_id → user_id mapping
2. `src/channels/mod.rs` — per-turn recall, dual-scope, system prompt injection, dual-scope save

## Verification

1. `cargo check --features whatsapp-web,memory-mem0`
2. `cargo test --features whatsapp-web,memory-mem0`
3. Deploy to Synology
4. Test DM: "我鍾意食壽司" ("I like sushi") → next turn "我鍾意食咩" ("What do I like to eat?") → should recall
5. Test group: Joe says "我鍾意食壽司" ("I like sushi") → someone else asks "Joe 鍾意食咩" ("What does Joe like to eat?") → should recall from the group scope
6. Check mem0 server logs: GET with `user_id=sender` AND `user_id=group_key`
7. Check mem0 server logs: POST with both user_ids for group messages
@@ -154,7 +154,7 @@ jobs:
        run: mkdir -p web/dist && touch web/dist/.gitkeep

      - name: Check all features
-       run: cargo check --all-features --locked
+       run: cargo check --features ci-all --locked

  docs-quality:
    name: Docs Quality
@@ -19,6 +19,7 @@ env:
jobs:
  detect-version-change:
    name: Detect Version Bump
    if: github.repository == 'zeroclaw-labs/zeroclaw'
    runs-on: ubuntu-latest
    outputs:
      changed: ${{ steps.check.outputs.changed }}

@@ -102,6 +103,22 @@ jobs:
      - name: Clean web build artifacts
        run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html

      - name: Publish aardvark-sys to crates.io
        shell: bash
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
        run: |
          OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
          echo "$OUTPUT"
          if echo "$OUTPUT" | grep -q 'already exists'; then
            echo "::notice::aardvark-sys already on crates.io — skipping"
            exit 0
          fi
          exit 1

      - name: Wait for aardvark-sys to index
        run: sleep 15

      - name: Publish to crates.io
        shell: bash
        env:
@@ -67,6 +67,24 @@ jobs:
      - name: Clean web build artifacts
        run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html

      - name: Publish aardvark-sys to crates.io
        if: "!inputs.dry_run"
        shell: bash
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
        run: |
          OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
          echo "$OUTPUT"
          if echo "$OUTPUT" | grep -q 'already exists'; then
            echo "::notice::aardvark-sys already on crates.io — skipping"
            exit 0
          fi
          exit 1

      - name: Wait for aardvark-sys to index
        if: "!inputs.dry_run"
        run: sleep 15

      - name: Publish (dry run)
        if: inputs.dry_run
        run: cargo publish --dry-run --locked --allow-dirty --no-verify
@@ -21,6 +21,7 @@ env:
jobs:
  version:
    name: Resolve Version
    if: github.repository == 'zeroclaw-labs/zeroclaw'
    runs-on: ubuntu-latest
    outputs:
      version: ${{ steps.ver.outputs.version }}

@@ -40,6 +41,7 @@ jobs:

  release-notes:
    name: Generate Release Notes
    if: github.repository == 'zeroclaw-labs/zeroclaw'
    runs-on: ubuntu-latest
    outputs:
      notes: ${{ steps.notes.outputs.body }}

@@ -130,6 +132,7 @@ jobs:

  web:
    name: Build Web Dashboard
    if: github.repository == 'zeroclaw-labs/zeroclaw'
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
@@ -323,6 +323,21 @@ jobs:
      - name: Clean web build artifacts
        run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html

      - name: Publish aardvark-sys to crates.io
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
        run: |
          OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
          echo "$OUTPUT"
          if echo "$OUTPUT" | grep -q 'already exists'; then
            echo "::notice::aardvark-sys already on crates.io — skipping"
            exit 0
          fi
          exit 1

      - name: Wait for aardvark-sys to index
        run: sleep 15

      - name: Publish to crates.io
        env:
          CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
Generated · +331 −3
@@ -117,6 +117,28 @@ version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"

[[package]]
name = "alsa"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
dependencies = [
 "alsa-sys",
 "bitflags 2.11.0",
 "cfg-if",
 "libc",
]

[[package]]
name = "alsa-sys"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db8fee663d06c4e303404ef5f40488a53e062f89ba8bfed81f42325aafad1527"
dependencies = [
 "libc",
 "pkg-config",
]

[[package]]
name = "ambient-authority"
version = "0.0.2"

@@ -583,6 +605,24 @@ dependencies = [
 "virtue",
]

[[package]]
name = "bindgen"
version = "0.72.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
dependencies = [
 "bitflags 2.11.0",
 "cexpr",
 "clang-sys",
 "itertools 0.13.0",
 "proc-macro2",
 "quote",
 "regex",
 "rustc-hash",
 "shlex",
 "syn 2.0.117",
]

[[package]]
name = "bip39"
version = "2.2.2"

@@ -878,6 +918,21 @@ dependencies = [
 "shlex",
]

[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"

[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
 "nom 7.1.3",
]

[[package]]
name = "cff-parser"
version = "0.1.0"

@@ -1003,6 +1058,17 @@ dependencies = [
 "zeroize",
]

[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
 "glob",
 "libc",
 "libloading",
]

[[package]]
name = "clap"
version = "4.6.0"

@@ -1086,6 +1152,16 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"

[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
 "bytes",
 "memchr",
]

[[package]]
name = "compression-codecs"
version = "0.4.37"

@@ -1209,6 +1285,49 @@ dependencies = [
 "libm",
]

[[package]]
name = "coreaudio-rs"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace"
dependencies = [
 "bitflags 1.3.2",
 "core-foundation-sys",
 "coreaudio-sys",
]

[[package]]
name = "coreaudio-sys"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ceec7a6067e62d6f931a2baf6f3a751f4a892595bcec1461a3c94ef9949864b6"
dependencies = [
 "bindgen",
]

[[package]]
name = "cpal"
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779"
dependencies = [
 "alsa",
 "core-foundation-sys",
 "coreaudio-rs",
 "dasp_sample",
 "jni",
 "js-sys",
 "libc",
 "mach2 0.4.3",
 "ndk",
 "ndk-context",
 "oboe",
 "wasm-bindgen",
 "wasm-bindgen-futures",
 "web-sys",
 "windows",
]

[[package]]
name = "cpp_demangle"
version = "0.4.5"
@@ -1588,6 +1707,12 @@ dependencies = [
 "parking_lot_core",
]

[[package]]
name = "dasp_sample"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f"

[[package]]
name = "data-encoding"
version = "2.10.0"

@@ -2944,7 +3069,7 @@ dependencies = [
 "js-sys",
 "log",
 "wasm-bindgen",
- "windows-core",
+ "windows-core 0.62.2",
]

[[package]]

@@ -3395,6 +3520,28 @@ dependencies = [
 "serde",
]

[[package]]
name = "jni"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
dependencies = [
 "cesu8",
 "cfg-if",
 "combine",
 "jni-sys",
 "log",
 "thiserror 1.0.69",
 "walkdir",
 "windows-sys 0.45.0",
]

[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"

[[package]]
name = "jobserver"
version = "0.1.34"
@@ -4242,6 +4389,35 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec1bc47d34ae756616f387c11fd0595f86f2cc7e6473bde9e3ded30cb902a1"

[[package]]
name = "ndk"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
dependencies = [
 "bitflags 2.11.0",
 "jni-sys",
 "log",
 "ndk-sys",
 "num_enum",
 "thiserror 1.0.69",
]

[[package]]
name = "ndk-context"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b"

[[package]]
name = "ndk-sys"
version = "0.5.0+25.2.9519653"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691"
dependencies = [
 "jni-sys",
]

[[package]]
name = "negentropy"
version = "0.5.0"

@@ -4421,6 +4597,17 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"

[[package]]
name = "num-derive"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.117",
]

[[package]]
name = "num-traits"
version = "0.2.19"

@@ -4440,6 +4627,28 @@ dependencies = [
 "libc",
]

[[package]]
name = "num_enum"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26"
dependencies = [
 "num_enum_derive",
 "rustversion",
]

[[package]]
name = "num_enum_derive"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8"
dependencies = [
 "proc-macro-crate",
 "proc-macro2",
 "quote",
 "syn 2.0.117",
]

[[package]]
name = "nusb"
version = "0.2.3"

@@ -4519,6 +4728,29 @@ dependencies = [
 "ruzstd",
]

[[package]]
name = "oboe"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb"
dependencies = [
 "jni",
 "ndk",
 "ndk-context",
 "num-derive",
 "num-traits",
 "oboe-sys",
]

[[package]]
name = "oboe-sys"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d"
dependencies = [
 "cc",
]

[[package]]
name = "once_cell"
version = "1.21.4"
@@ -8691,6 +8923,26 @@ dependencies = [
 "wasmtime-internal-math",
]

[[package]]
name = "windows"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49"
dependencies = [
 "windows-core 0.54.0",
 "windows-targets 0.52.6",
]

[[package]]
name = "windows-core"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65"
dependencies = [
 "windows-result 0.1.2",
 "windows-targets 0.52.6",
]

[[package]]
name = "windows-core"
version = "0.62.2"

@@ -8700,7 +8952,7 @@ dependencies = [
 "windows-implement",
 "windows-interface",
 "windows-link",
- "windows-result",
+ "windows-result 0.4.1",
 "windows-strings",
]

@@ -8732,6 +8984,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"

[[package]]
name = "windows-result"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8"
dependencies = [
 "windows-targets 0.52.6",
]

[[package]]
name = "windows-result"
version = "0.4.1"

@@ -8750,6 +9011,15 @@ dependencies = [
 "windows-link",
]

[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
 "windows-targets 0.42.2",
]

[[package]]
name = "windows-sys"
version = "0.52.0"

@@ -8786,6 +9056,21 @@ dependencies = [
 "windows-link",
]

[[package]]
name = "windows-targets"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
 "windows_aarch64_gnullvm 0.42.2",
 "windows_aarch64_msvc 0.42.2",
 "windows_i686_gnu 0.42.2",
 "windows_i686_msvc 0.42.2",
 "windows_x86_64_gnu 0.42.2",
 "windows_x86_64_gnullvm 0.42.2",
 "windows_x86_64_msvc 0.42.2",
]

[[package]]
name = "windows-targets"
version = "0.52.6"

@@ -8819,6 +9104,12 @@ dependencies = [
 "windows_x86_64_msvc 0.53.1",
]

[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"

[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"

@@ -8831,6 +9122,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"

[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"

[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"

@@ -8843,6 +9140,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"

[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"

[[package]]
name = "windows_i686_gnu"
version = "0.52.6"

@@ -8867,6 +9170,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"

[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"

[[package]]
name = "windows_i686_msvc"
version = "0.52.6"

@@ -8879,6 +9188,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"

[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"

[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"

@@ -8891,6 +9206,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"

[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"

[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"

@@ -8903,6 +9224,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"

[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"

[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"

@@ -9203,7 +9530,7 @@ dependencies = [

[[package]]
name = "zeroclawlabs"
-version = "0.5.5"
+version = "0.5.7"
dependencies = [
 "aardvark-sys",
 "anyhow",

@@ -9217,6 +9544,7 @@ dependencies = [
 "clap",
 "clap_complete",
 "console",
 "cpal",
 "criterion",
 "cron",
 "dialoguer",
+27 −4
@@ -4,7 +4,7 @@ resolver = "2"

[package]
name = "zeroclawlabs"
-version = "0.5.5"
+version = "0.5.7"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"

@@ -97,7 +97,7 @@ anyhow = "1.0"
thiserror = "2.0"

# Aardvark I2C/SPI/GPIO USB adapter (Total Phase) — stub when SDK absent
-aardvark-sys = { path = "crates/aardvark-sys" }
+aardvark-sys = { path = "crates/aardvark-sys", version = "0.1.0" }

# UUID generation
uuid = { version = "1.22", default-features = false, features = ["v4", "std"] }

@@ -199,6 +199,9 @@ pdf-extract = { version = "0.10", optional = true }
# WASM plugin runtime (extism)
extism = { version = "1.20", optional = true }

# Cross-platform audio capture for voice wake word detection (optional, enable with --features voice-wake)
cpal = { version = "0.15", optional = true }

# Terminal QR rendering for WhatsApp Web pairing flow.
qrcode = { version = "0.14", optional = true }

@@ -228,8 +231,6 @@ channel-matrix = ["dep:matrix-sdk"]
channel-lark = ["dep:prost"]
channel-feishu = ["channel-lark"] # Alias for Feishu users (Lark and Feishu are the same platform)
memory-postgres = ["dep:postgres"]
# memory-mem0 = Mem0 (OpenMemory) memory backend via REST API
memory-mem0 = []
observability-prometheus = ["dep:prometheus"]
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
peripheral-rpi = ["rppal"]

@@ -252,8 +253,30 @@ rag-pdf = ["dep:pdf-extract"]
skill-creation = []
# whatsapp-web = Native WhatsApp Web client with custom rusqlite storage backend
whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-proto", "dep:wa-rs-ureq-http", "dep:wa-rs-tokio-transport", "dep:serde-big-array", "dep:prost", "dep:qrcode"]
# voice-wake = Voice wake word detection via microphone (cpal)
voice-wake = ["dep:cpal"]
# WASM plugin system (extism-based)
plugins-wasm = ["dep:extism"]
# Meta-feature for CI: all features except those requiring system C libraries
# not available on standard CI runners (e.g., voice-wake needs libasound2-dev).
ci-all = [
  "channel-nostr",
  "hardware",
  "channel-matrix",
  "channel-lark",
  "memory-postgres",
  "observability-prometheus",
  "observability-otel",
  "peripheral-rpi",
  "browser-native",
  "sandbox-landlock",
  "sandbox-bubblewrap",
  "probe",
  "rag-pdf",
  "skill-creation",
  "whatsapp-web",
  "plugins-wasm",
]

[profile.release]
opt-level = "z" # Optimize for size
@@ -27,6 +27,7 @@ COPY Cargo.toml Cargo.lock ./
# Previously we used sed to drop `crates/robot-kit`, which made the manifest disagree
# with the lockfile and caused `cargo --locked` to fail (Cargo refused to rewrite the lock).
COPY crates/robot-kit/ crates/robot-kit/
COPY crates/aardvark-sys/ crates/aardvark-sys/
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
RUN mkdir -p src benches \
    && echo "fn main() {}" > src/main.rs \
@@ -1,80 +0,0 @@
#!/bin/bash
# Start mem0 + reranker GPU container for ZeroClaw memory backend.
#
# Required env vars:
#   MEM0_LLM_API_KEY or ZAI_API_KEY — API key for the LLM used in fact extraction
#
# Optional env vars (with defaults):
#   MEM0_LLM_PROVIDER — mem0 LLM provider (default: "openai", i.e. OpenAI-compatible)
#   MEM0_LLM_MODEL — LLM model for fact extraction (default: "glm-5-turbo")
#   MEM0_LLM_BASE_URL — LLM API base URL (default: "https://api.z.ai/api/coding/paas/v4")
#   MEM0_EMBEDDER_MODEL — embedding model (default: "BAAI/bge-m3")
#   MEM0_EMBEDDER_DIMS — embedding dimensions (default: "1024")
#   MEM0_EMBEDDER_DEVICE — "cuda", "cpu", or "auto" (default: "cuda")
#   MEM0_VECTOR_COLLECTION — Qdrant collection name (default: "zeroclaw_mem0")
#   RERANKER_MODEL — reranker model (default: "BAAI/bge-reranker-v2-m3")
#   RERANKER_DEVICE — "cuda" or "cpu" (default: "cuda")
#   MEM0_PORT — mem0 server port (default: 8765)
#   RERANKER_PORT — reranker server port (default: 8678)
#   CONTAINER_IMAGE — base container image (default: docker.io/kyuz0/amd-strix-halo-comfyui:latest)
#   CONTAINER_NAME — container name (default: mem0-gpu)
#   DATA_DIR — host path for Qdrant data (default: ~/mem0-data)
#   SCRIPT_DIR — host path for server scripts (default: directory of this script)
set -e

# Resolve script directory for mounting server scripts
SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "$0")" && pwd)}"

# API key — accept either name
export MEM0_LLM_API_KEY="${MEM0_LLM_API_KEY:-${ZAI_API_KEY:?MEM0_LLM_API_KEY or ZAI_API_KEY must be set}}"

# Defaults
MEM0_LLM_MODEL="${MEM0_LLM_MODEL:-glm-5-turbo}"
MEM0_LLM_BASE_URL="${MEM0_LLM_BASE_URL:-https://api.z.ai/api/coding/paas/v4}"
MEM0_PORT="${MEM0_PORT:-8765}"
RERANKER_PORT="${RERANKER_PORT:-8678}"
CONTAINER_IMAGE="${CONTAINER_IMAGE:-docker.io/kyuz0/amd-strix-halo-comfyui:latest}"
CONTAINER_NAME="${CONTAINER_NAME:-mem0-gpu}"
DATA_DIR="${DATA_DIR:-$HOME/mem0-data}"

# Stop existing CPU services (if any)
kill -9 $(pgrep -f "mem0-server.py") 2>/dev/null || true
kill -9 $(pgrep -f "reranker-server.py") 2>/dev/null || true

# Stop existing container
podman stop "$CONTAINER_NAME" 2>/dev/null || true
podman rm "$CONTAINER_NAME" 2>/dev/null || true

podman run -d --name "$CONTAINER_NAME" \
  --device /dev/dri --device /dev/kfd \
  --group-add video --group-add render \
  --restart unless-stopped \
  -p "$MEM0_PORT:$MEM0_PORT" -p "$RERANKER_PORT:$RERANKER_PORT" \
  -v "$DATA_DIR":/root/mem0-data:Z \
  -v "$SCRIPT_DIR/mem0-server.py":/app/mem0-server.py:ro,Z \
  -v "$SCRIPT_DIR/reranker-server.py":/app/reranker-server.py:ro,Z \
  -v "$HOME/.cache/huggingface":/root/.cache/huggingface:Z \
  -e MEM0_LLM_API_KEY="$MEM0_LLM_API_KEY" \
  -e ZAI_API_KEY="$MEM0_LLM_API_KEY" \
  -e MEM0_LLM_MODEL="$MEM0_LLM_MODEL" \
  -e MEM0_LLM_BASE_URL="$MEM0_LLM_BASE_URL" \
  ${MEM0_LLM_PROVIDER:+-e MEM0_LLM_PROVIDER="$MEM0_LLM_PROVIDER"} \
  ${MEM0_EMBEDDER_MODEL:+-e MEM0_EMBEDDER_MODEL="$MEM0_EMBEDDER_MODEL"} \
  ${MEM0_EMBEDDER_DIMS:+-e MEM0_EMBEDDER_DIMS="$MEM0_EMBEDDER_DIMS"} \
  ${MEM0_EMBEDDER_DEVICE:+-e MEM0_EMBEDDER_DEVICE="$MEM0_EMBEDDER_DEVICE"} \
  ${MEM0_VECTOR_COLLECTION:+-e MEM0_VECTOR_COLLECTION="$MEM0_VECTOR_COLLECTION"} \
  ${RERANKER_MODEL:+-e RERANKER_MODEL="$RERANKER_MODEL"} \
  ${RERANKER_DEVICE:+-e RERANKER_DEVICE="$RERANKER_DEVICE"} \
  -e RERANKER_PORT="$RERANKER_PORT" \
  -e RERANKER_URL="http://127.0.0.1:$RERANKER_PORT/rerank" \
  -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
  -e HOME=/root \
  "$CONTAINER_IMAGE" \
  bash -c "pip install -q FlagEmbedding mem0ai flask httpx qdrant-client 2>&1 | tail -3; echo '=== Starting reranker (GPU) on :$RERANKER_PORT ==='; python3 /app/reranker-server.py & sleep 3; echo '=== Starting mem0 (GPU) on :$MEM0_PORT ==='; exec python3 /app/mem0-server.py"

echo "Container started, waiting for init..."
sleep 15
echo "=== Container logs ==="
podman logs "$CONTAINER_NAME" 2>&1 | tail -25
echo "=== Port check ==="
ss -tlnp | grep "$MEM0_PORT\|$RERANKER_PORT" || echo "Ports not yet ready, check: podman logs $CONTAINER_NAME"
@@ -1,288 +0,0 @@
"""Minimal OpenMemory-compatible REST server wrapping mem0 Python SDK."""
import asyncio
import json, os, uuid, httpx
from datetime import datetime, timezone
from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import Optional
from mem0 import Memory

app = FastAPI()

RERANKER_URL = os.environ.get("RERANKER_URL", "http://127.0.0.1:8678/rerank")

CUSTOM_EXTRACTION_PROMPT = """You are a memory extraction specialist for a Cantonese/Chinese chat assistant.

Extract ONLY important, persistent facts from the conversation. Rules:
1. Extract personal preferences, habits, relationships, names, locations
2. Extract decisions, plans, and commitments people make
3. SKIP small talk, greetings, reactions ("ok", "哈哈", "係呀")
4. SKIP temporary states ("我依家食緊飯") unless they reveal a habit
5. Keep facts in the ORIGINAL language (Cantonese/Chinese/English)
6. For each fact, note WHO it's about (use their name or identifier if known)
7. Merge/update existing facts rather than creating duplicates

Return a list of facts in JSON format: {"facts": ["fact1", "fact2", ...]}
"""

PROCEDURAL_EXTRACTION_PROMPT = """You are a procedural memory specialist for an AI assistant.

Extract HOW-TO patterns and reusable procedures from the conversation trace. Rules:
1. Identify step-by-step procedures the assistant followed to accomplish a task
2. Extract tool usage patterns: which tools were called, in what order, with what arguments
3. Capture decision points: why the assistant chose one approach over another
4. Note error-recovery patterns: what failed, how it was fixed
5. Keep the procedure generic enough to apply to similar future tasks
6. Preserve technical details (commands, file paths, API calls) that are reusable
7. SKIP greetings, small talk, and conversational filler
8. Format each procedure as: "To [goal]: [step1] -> [step2] -> ... -> [result]"

Return a list of procedures in JSON format: {"facts": ["procedure1", "procedure2", ...]}
"""

# ── Configurable via environment variables ─────────────────────────
# LLM (for fact extraction when infer=true)
MEM0_LLM_PROVIDER = os.environ.get("MEM0_LLM_PROVIDER", "openai")  # "openai" (compatible), "anthropic", etc.
MEM0_LLM_MODEL = os.environ.get("MEM0_LLM_MODEL", "glm-5-turbo")
MEM0_LLM_API_KEY = os.environ.get("MEM0_LLM_API_KEY") or os.environ.get("ZAI_API_KEY", "")
MEM0_LLM_BASE_URL = os.environ.get("MEM0_LLM_BASE_URL", "https://api.z.ai/api/coding/paas/v4")

# Embedder
MEM0_EMBEDDER_PROVIDER = os.environ.get("MEM0_EMBEDDER_PROVIDER", "huggingface")  # "huggingface", "openai", etc.
MEM0_EMBEDDER_MODEL = os.environ.get("MEM0_EMBEDDER_MODEL", "BAAI/bge-m3")
MEM0_EMBEDDER_DIMS = int(os.environ.get("MEM0_EMBEDDER_DIMS", "1024"))
MEM0_EMBEDDER_DEVICE = os.environ.get("MEM0_EMBEDDER_DEVICE", "cuda")  # "cuda", "cpu", "auto"

# Vector store
MEM0_VECTOR_PROVIDER = os.environ.get("MEM0_VECTOR_PROVIDER", "qdrant")  # "qdrant", "chroma", etc.
MEM0_VECTOR_COLLECTION = os.environ.get("MEM0_VECTOR_COLLECTION", "zeroclaw_mem0")
MEM0_VECTOR_PATH = os.environ.get("MEM0_VECTOR_PATH", os.path.expanduser("~/mem0-data/qdrant"))

config = {
    "llm": {
        "provider": MEM0_LLM_PROVIDER,
        "config": {
            "model": MEM0_LLM_MODEL,
            "api_key": MEM0_LLM_API_KEY,
            "openai_base_url": MEM0_LLM_BASE_URL,
        },
    },
    "embedder": {
        "provider": MEM0_EMBEDDER_PROVIDER,
        "config": {
            "model": MEM0_EMBEDDER_MODEL,
            "embedding_dims": MEM0_EMBEDDER_DIMS,
            "model_kwargs": {"device": MEM0_EMBEDDER_DEVICE},
        },
    },
    "vector_store": {
        "provider": MEM0_VECTOR_PROVIDER,
        "config": {
            "collection_name": MEM0_VECTOR_COLLECTION,
            "embedding_model_dims": MEM0_EMBEDDER_DIMS,
            "path": MEM0_VECTOR_PATH,
        },
    },
    "custom_fact_extraction_prompt": CUSTOM_EXTRACTION_PROMPT,
}

m = Memory.from_config(config)


def rerank_results(query: str, items: list, top_k: int = 10) -> list:
    """Rerank search results using bge-reranker-v2-m3."""
    if not items:
        return items
    documents = [item.get("memory", "") for item in items]
    try:
        resp = httpx.post(
            RERANKER_URL,
            json={"query": query, "documents": documents, "top_k": top_k},
            timeout=10.0,
        )
        resp.raise_for_status()
        ranked = resp.json().get("results", [])
        return [items[r["index"]] for r in ranked]
    except Exception as e:
        print(f"Reranker failed, using original order: {e}")
        return items


class AddMemoryRequest(BaseModel):
    user_id: str
    text: str
    metadata: Optional[dict] = None
    infer: bool = True
    app: Optional[str] = None
    custom_instructions: Optional[str] = None


@app.post("/api/v1/memories/")
async def add_memory(req: AddMemoryRequest):
    # Use client-supplied prompt, fall back to server default, then mem0 SDK default
    prompt = req.custom_instructions or CUSTOM_EXTRACTION_PROMPT
    result = await asyncio.to_thread(m.add, req.text, user_id=req.user_id, metadata=req.metadata or {}, prompt=prompt)
    return {"id": str(uuid.uuid4()), "status": "ok", "result": result}


class ProceduralMemoryRequest(BaseModel):
    user_id: str
    messages: list[dict]
    metadata: Optional[dict] = None


@app.post("/api/v1/memories/procedural")
async def add_procedural_memory(req: ProceduralMemoryRequest):
    """Store a conversation trace as procedural memory.

    Accepts a list of messages (role/content dicts) representing a full
    conversation turn including tool calls, then uses mem0's native
    procedural memory extraction to learn reusable "how to" patterns.
    """
    # Build metadata with procedural type marker
    meta = {"type": "procedural"}
    if req.metadata:
        meta.update(req.metadata)

    # Use mem0's native message list support + procedural prompt
    result = await asyncio.to_thread(
        m.add,
        req.messages,
        user_id=req.user_id,
        metadata=meta,
        prompt=PROCEDURAL_EXTRACTION_PROMPT,
    )

    return {"id": str(uuid.uuid4()), "status": "ok", "result": result}


def _parse_mem0_results(raw_results) -> list:
    raw = raw_results.get("results", raw_results) if isinstance(raw_results, dict) else raw_results
    items = []
    for r in raw:
        item = r if isinstance(r, dict) else {"memory": str(r)}
        items.append({
            "id": item.get("id", str(uuid.uuid4())),
            "memory": item.get("memory", item.get("text", "")),
            "created_at": item.get("created_at", datetime.now(timezone.utc).isoformat()),
            "metadata_": item.get("metadata", {}),
        })
    return items


def _parse_iso_timestamp(value: str) -> Optional[datetime]:
    """Parse an ISO 8601 timestamp string, returning None on failure."""
    try:
        dt = datetime.fromisoformat(value)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    except (ValueError, TypeError):
        return None


def _item_created_at(item: dict) -> Optional[datetime]:
    """Extract created_at from an item as a timezone-aware datetime."""
    raw = item.get("created_at")
    if raw is None:
        return None
    if isinstance(raw, datetime):
        if raw.tzinfo is None:
            raw = raw.replace(tzinfo=timezone.utc)
        return raw
    return _parse_iso_timestamp(str(raw))


def _apply_post_filters(
    items: list,
    created_after: Optional[str],
    created_before: Optional[str],
) -> list:
    """Filter items by created_after / created_before timestamps (post-query)."""
    after_dt = _parse_iso_timestamp(created_after) if created_after else None
    before_dt = _parse_iso_timestamp(created_before) if created_before else None
    if after_dt is None and before_dt is None:
        return items
    filtered = []
    for item in items:
        ts = _item_created_at(item)
        if ts is None:
            # Keep items without a parseable timestamp
            filtered.append(item)
            continue
        if after_dt and ts < after_dt:
            continue
        if before_dt and ts > before_dt:
            continue
        filtered.append(item)
    return filtered


@app.get("/api/v1/memories/")
async def list_or_search_memories(
    user_id: str = Query(...),
    search_query: Optional[str] = Query(None),
    size: int = Query(10),
    rerank: bool = Query(True),
    created_after: Optional[str] = Query(None),
    created_before: Optional[str] = Query(None),
    metadata_filter: Optional[str] = Query(None),
):
    # Build mem0 SDK filters dict from metadata_filter JSON param
    sdk_filters = None
    if metadata_filter:
        try:
            sdk_filters = json.loads(metadata_filter)
        except json.JSONDecodeError:
            sdk_filters = None

    if search_query:
        # Fetch more results than needed so reranker has candidates to work with
        fetch_size = min(size * 3, 50)
        results = await asyncio.to_thread(
            m.search,
            search_query,
            user_id=user_id,
            limit=fetch_size,
            filters=sdk_filters,
        )
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        if rerank and items:
            items = rerank_results(search_query, items, top_k=size)
        else:
            items = items[:size]
        return {"items": items, "total": len(items)}
    else:
        results = await asyncio.to_thread(m.get_all, user_id=user_id, filters=sdk_filters)
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        return {"items": items, "total": len(items)}


@app.delete("/api/v1/memories/{memory_id}")
async def delete_memory(memory_id: str):
    try:
        await asyncio.to_thread(m.delete, memory_id)
    except Exception:
        pass
    return {"status": "ok"}


@app.get("/api/v1/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Return the edit history of a specific memory."""
    try:
        history = await asyncio.to_thread(m.history, memory_id)
        # Normalize to list of dicts
        entries = []
        raw = history if isinstance(history, list) else history.get("results", history) if isinstance(history, dict) else [history]
        for h in raw:
            entry = h if isinstance(h, dict) else {"event": str(h)}
            entries.append(entry)
        return {"memory_id": memory_id, "history": entries}
    except Exception as e:
        return {"memory_id": memory_id, "history": [], "error": str(e)}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8765)
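For reference, a client of this server only needs to shape two payloads: the JSON body for `POST /api/v1/memories/` (mirroring `AddMemoryRequest`) and the query parameters for the search path of `GET /api/v1/memories/`. A minimal sketch; `MEM0_URL` and the helper function names are illustrative, not part of the server code:

```python
from typing import Optional

MEM0_URL = "http://localhost:8765"  # assumed server address

def build_add_payload(user_id: str, text: str, metadata: Optional[dict] = None,
                      custom_instructions: Optional[str] = None) -> dict:
    """Body for POST /api/v1/memories/ (mirrors AddMemoryRequest)."""
    payload = {"user_id": user_id, "text": text, "infer": True}
    if metadata:
        payload["metadata"] = metadata
    if custom_instructions:
        payload["custom_instructions"] = custom_instructions
    return payload

def build_search_params(user_id: str, query: str, size: int = 5) -> dict:
    """Query params for GET /api/v1/memories/ (search path with reranking)."""
    return {"user_id": user_id, "search_query": query, "size": size, "rerank": True}

# e.g. httpx.post(f"{MEM0_URL}/api/v1/memories/", json=build_add_payload("alice", "鍾意飲奶茶"))
```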
@@ -1,50 +0,0 @@
from flask import Flask, request, jsonify
from FlagEmbedding import FlagReranker
import os, torch

app = Flask(__name__)
reranker = None

# ── Configurable via environment variables ─────────────────────────
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
RERANKER_DEVICE = os.environ.get("RERANKER_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
RERANKER_PORT = int(os.environ.get("RERANKER_PORT", "8678"))

def get_reranker():
    global reranker
    if reranker is None:
        reranker = FlagReranker(RERANKER_MODEL, use_fp16=True, device=RERANKER_DEVICE)
    return reranker

@app.route('/rerank', methods=['POST'])
def rerank():
    data = request.json
    query = data.get('query', '')
    documents = data.get('documents', [])
    top_k = data.get('top_k', len(documents))

    if not query or not documents:
        return jsonify({'error': 'query and documents required'}), 400

    pairs = [[query, doc] for doc in documents]
    scores = get_reranker().compute_score(pairs)
    if isinstance(scores, float):
        scores = [scores]

    results = sorted(
        [{'index': i, 'document': doc, 'score': score}
         for i, (doc, score) in enumerate(zip(documents, scores))],
        key=lambda x: x['score'], reverse=True
    )[:top_k]

    return jsonify({'results': results})

@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'ok', 'model': RERANKER_MODEL, 'device': RERANKER_DEVICE})

if __name__ == '__main__':
    print(f'Loading reranker model ({RERANKER_MODEL}) on {RERANKER_DEVICE}...')
    get_reranker()
    print(f'Reranker server ready on :{RERANKER_PORT}')
    app.run(host='0.0.0.0', port=RERANKER_PORT)
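The sorting contract of `/rerank` can be isolated from the model: given per-document relevance scores, the endpoint returns the `top_k` documents by descending score, each tagged with its original index. A minimal sketch, where the `scores` argument stands in for `FlagReranker.compute_score`, which needs the model weights:

```python
def rank_documents(documents: list, scores: list, top_k: int) -> list:
    """Mirror of the /rerank sort: highest score first, original index kept."""
    return sorted(
        [{"index": i, "document": doc, "score": score}
         for i, (doc, score) in enumerate(zip(documents, scores))],
        key=lambda x: x["score"], reverse=True,
    )[:top_k]
```

Callers such as `rerank_results()` in mem0-server.py consume only the `index` field, using it to reorder their own result lists.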
Vendored
+2
-2
@@ -1,6 +1,6 @@
pkgbase = zeroclaw
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
pkgver = 0.5.5
pkgver = 0.5.7
pkgrel = 1
url = https://github.com/zeroclaw-labs/zeroclaw
arch = x86_64
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
makedepends = git
depends = gcc-libs
depends = openssl
source = zeroclaw-0.5.5.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.5.tar.gz
source = zeroclaw-0.5.7.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.7.tar.gz
sha256sums = SKIP

pkgname = zeroclaw
Vendored
+1
-1
@@ -1,6 +1,6 @@
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
pkgname=zeroclaw
pkgver=0.5.5
pkgver=0.5.7
pkgrel=1
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
arch=('x86_64')
Vendored
+2
-2
@@ -1,11 +1,11 @@
{
    "version": "0.5.5",
    "version": "0.5.7",
    "description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
    "homepage": "https://github.com/zeroclaw-labs/zeroclaw",
    "license": "MIT|Apache-2.0",
    "architecture": {
        "64bit": {
            "url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.5/zeroclaw-x86_64-pc-windows-msvc.zip",
            "url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.7/zeroclaw-x86_64-pc-windows-msvc.zip",
            "hash": "",
            "bin": "zeroclaw.exe"
        }
@@ -411,30 +411,6 @@ allowed_roots = ["~/Desktop/projects", "/opt/shared-repo"]

- 内存上下文注入忽略旧的 `assistant_resp*` 自动保存键,以防止旧模型生成的摘要被视为事实。

### `[memory.mem0]`

Mem0 (OpenMemory) 后端 — 连接自托管 mem0 服务器,提供基于向量的记忆存储和 LLM 事实提取。构建时需要 `memory-mem0` feature flag,配置需设置 `backend = "mem0"`。

| 键 | 默认值 | 环境变量 | 用途 |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory 服务器地址 |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | 记忆作用域的用户 ID |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | 在 mem0 中注册的应用名称 |
| `infer` | `true` | — | 使用 LLM 从存储文本中提取事实 (`true`) 或原样存储 (`false`) |
| `extraction_prompt` | 未设置 | `MEM0_EXTRACTION_PROMPT` | 自定义 LLM 事实提取提示词(如适用于非英文内容) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "用原始语言提取事实..."
```

服务器部署脚本位于 `deploy/mem0/`。

## `[[model_routes]]` 和 `[[embedding_routes]]`

使用路由提示,以便集成可以在模型 ID 演变时保持稳定的名称。
@@ -508,30 +508,6 @@ Notes:

- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.

### `[memory.mem0]`

Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.

| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```

Server deployment scripts are in `deploy/mem0/`.

## `[[model_routes]]` and `[[embedding_routes]]`

Use route hints so integrations can keep stable names while model IDs evolve.
@@ -337,30 +337,6 @@ Lưu ý:

- Chèn ngữ cảnh memory bỏ qua khóa auto-save `assistant_resp*` kiểu cũ để tránh tóm tắt do model tạo bị coi là sự thật.

### `[memory.mem0]`

Backend Mem0 (OpenMemory) — kết nối đến server mem0 tự host, cung cấp bộ nhớ vector với trích xuất sự kiện bằng LLM. Cần feature flag `memory-mem0` khi build và `backend = "mem0"` trong config.

| Khóa | Mặc định | Biến môi trường | Mục đích |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | URL server OpenMemory |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID để phân vùng memory |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Tên ứng dụng đăng ký trong mem0 |
| `infer` | `true` | — | Dùng LLM trích xuất sự kiện từ text (`true`) hoặc lưu nguyên (`false`) |
| `extraction_prompt` | chưa đặt | `MEM0_EXTRACTION_PROMPT` | Prompt tùy chỉnh cho trích xuất sự kiện LLM (vd: cho nội dung không phải tiếng Anh) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Trích xuất sự kiện bằng ngôn ngữ gốc..."
```

Script triển khai server nằm trong `deploy/mem0/`.

## `[[model_routes]]` và `[[embedding_routes]]`

Route hint giúp tên tích hợp ổn định khi model ID thay đổi.
+79
-25
@@ -569,11 +569,29 @@ MSG
exit 0
fi
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
if ! /usr/bin/xcrun --show-sdk-path >/dev/null 2>&1; then
warn "Xcode license has not been accepted. Run:"
warn "  sudo xcodebuild -license accept"
warn "then re-run this installer."
exit 1
# xcrun --show-sdk-path can succeed even without an accepted license,
# so we test-compile a trivial C file which reliably triggers the error.
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
_xcode_accept_ok=false
if [[ "$(id -u)" -eq 0 ]]; then
xcodebuild -license accept && _xcode_accept_ok=true
elif [[ -c /dev/tty ]] && have_cmd sudo; then
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
fi
if [[ "$_xcode_accept_ok" == true ]]; then
step_ok "Xcode license accepted"
else
error "Could not accept Xcode license. Run manually:"
error "  sudo xcodebuild -license accept"
error "then re-run this installer."
exit 1
fi
else
rm -f "$_xcode_test_file"
fi
if ! have_cmd git; then
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
@@ -1175,6 +1193,43 @@ else
install_system_deps
fi

# Always check Xcode/CLT license on macOS, regardless of --install-system-deps.
# An un-accepted license causes `cc` to exit 69, breaking all Rust builds.
if [[ "$OS_NAME" == "Darwin" ]]; then
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
# Use /dev/tty so sudo can prompt for a password even in a curl|bash pipe.
_xcode_accept_ok=false
if [[ "$(id -u)" -eq 0 ]]; then
xcodebuild -license accept && _xcode_accept_ok=true
elif [[ -c /dev/tty ]] && have_cmd sudo; then
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
fi
if [[ "$_xcode_accept_ok" == true ]]; then
step_ok "Xcode license accepted"
# Re-test compilation to confirm it's fixed.
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
error "C compiler still failing after license accept. Check your Xcode/CLT installation."
exit 1
fi
rm -f "$_xcode_test_file"
else
error "Could not accept Xcode license. Run manually:"
error "  sudo xcodebuild -license accept"
error "then re-run this installer."
exit 1
fi
else
rm -f "$_xcode_test_file"
fi
fi

if [[ "$INSTALL_RUST" == true ]]; then
install_rust_toolchain
fi
@@ -1393,6 +1448,25 @@ else
step_dot "Skipping install"
fi

# --- Build web dashboard ---
if [[ "$SKIP_BUILD" == false && -d "$WORK_DIR/web" ]]; then
if have_cmd node && have_cmd npm; then
step_dot "Building web dashboard"
if (cd "$WORK_DIR/web" && npm ci --ignore-scripts 2>/dev/null && npm run build 2>/dev/null); then
step_ok "Web dashboard built"
else
warn "Web dashboard build failed — dashboard will not be available"
fi
else
warn "node/npm not found — skipping web dashboard build"
warn "Install Node.js (>=18) and re-run, or build manually: cd web && npm ci && npm run build"
fi
else
if [[ "$SKIP_BUILD" == true ]]; then
step_dot "Skipping web dashboard build"
fi
fi

ZEROCLAW_BIN=""
if [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then
ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw"
@@ -1467,25 +1541,6 @@ if [[ -n "$ZEROCLAW_BIN" ]]; then
if "$ZEROCLAW_BIN" service restart 2>/dev/null; then
step_ok "Gateway service restarted"

# Fetch and display pairing code from running gateway
PAIR_CODE=""
for i in 1 2 3 4 5; do
sleep 2
if PAIR_CODE=$("$ZEROCLAW_BIN" gateway get-paircode 2>/dev/null | grep -oE '[0-9]{6}'); then
break
fi
done
if [[ -n "$PAIR_CODE" ]]; then
echo
echo -e "  ${BOLD_BLUE}🔐 Gateway Pairing Code${RESET}"
echo
echo -e "  ${BOLD_BLUE}┌──────────────┐${RESET}"
echo -e "  ${BOLD_BLUE}│${RESET}  ${BOLD}${PAIR_CODE}${RESET}  ${BOLD_BLUE}│${RESET}"
echo -e "  ${BOLD_BLUE}└──────────────┘${RESET}"
echo
echo -e "  ${DIM}Enter this code in the dashboard to pair your device.${RESET}"
echo -e "  ${DIM}Run 'zeroclaw gateway get-paircode --new' anytime to generate a fresh code.${RESET}"
fi
else
step_fail "Gateway service restart failed — re-run with zeroclaw service start"
fi
@@ -1532,7 +1587,6 @@ GATEWAY_PORT=42617
DASHBOARD_URL="http://127.0.0.1:${GATEWAY_PORT}"
echo
echo -e "${BOLD}Dashboard URL:${RESET} ${BLUE}${DASHBOARD_URL}${RESET}"
echo -e "${DIM}  Run 'zeroclaw gateway get-paircode' to get your pairing code.${RESET}"

# --- Copy to clipboard ---
COPIED_TO_CLIPBOARD=false
+2
-1
@@ -359,7 +359,7 @@ impl Agent {
None
};

let (mut tools, delegate_handle) = tools::all_tools_with_runtime(
let (mut tools, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
@@ -373,6 +373,7 @@ impl Agent {
&config.agents,
config.api_key.as_deref(),
config,
None,
);

// ── Wire MCP tools (non-fatal) ─────────────────────────────
+23
-16
@@ -3525,7 +3525,7 @@ pub async fn run(
} else {
(None, None)
};
let (mut tools_registry, delegate_handle) = tools::all_tools_with_runtime(
let (mut tools_registry, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
@@ -3539,6 +3539,7 @@ pub async fn run(
&config.agents,
config.api_key.as_deref(),
&config,
None,
);

let peripheral_tools: Vec<Box<dyn Tool>> =
@@ -3833,6 +3834,8 @@ pub async fn run(
Some(&config.autonomy),
native_tools,
config.skills.prompt_injection_mode,
config.agent.compact_context,
config.agent.max_system_prompt_chars,
);

// Append structured tool-use instructions with schemas (only for non-native providers)
@@ -4282,21 +4285,23 @@ pub async fn process_message(
} else {
(None, None)
};
let (mut tools_registry, delegate_handle_pm) = tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
mem.clone(),
composio_key,
composio_entity_id,
&config.browser,
&config.http_request,
&config.web_fetch,
&config.workspace_dir,
&config.agents,
config.api_key.as_deref(),
&config,
);
let (mut tools_registry, delegate_handle_pm, _reaction_handle_pm) =
tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
mem.clone(),
composio_key,
composio_entity_id,
&config.browser,
&config.http_request,
&config.web_fetch,
&config.workspace_dir,
&config.agents,
config.api_key.as_deref(),
&config,
None,
);
let peripheral_tools: Vec<Box<dyn Tool>> =
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
tools_registry.extend(peripheral_tools);
@@ -4492,6 +4497,8 @@ pub async fn process_message(
Some(&config.autonomy),
native_tools,
config.skills.prompt_injection_mode,
config.agent.compact_context,
config.agent.max_system_prompt_chars,
);
if !native_tools {
system_prompt.push_str(&build_tool_instructions(&tools_registry, Some(&i18n_descs)));

@@ -118,6 +118,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: None,
namespace: "default".into(),
importance: None,
superseded_by: None,
}])
}

@@ -226,6 +229,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: Some(0.95),
namespace: "default".into(),
importance: None,
superseded_by: None,
},
MemoryEntry {
id: "2".into(),
@@ -235,6 +241,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: Some(0.9),
namespace: "default".into(),
importance: None,
superseded_by: None,
},
]),
};
+2
-2
@@ -122,7 +122,7 @@ impl ApprovalManager {
}

// always_ask overrides everything.
if self.always_ask.contains(tool_name) {
if self.always_ask.contains("*") || self.always_ask.contains(tool_name) {
return true;
}

@@ -136,7 +136,7 @@ impl ApprovalManager {
}

// auto_approve skips the prompt.
if self.auto_approve.contains(tool_name) {
if self.auto_approve.contains("*") || self.auto_approve.contains(tool_name) {
return false;
}

@@ -0,0 +1,549 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;

use crate::memory::{Memory, MemoryCategory};

/// Discord History channel — connects via Gateway WebSocket, stores ALL non-bot messages
/// to a dedicated discord.db, and forwards @mention messages to the agent.
pub struct DiscordHistoryChannel {
    bot_token: String,
    guild_id: Option<String>,
    allowed_users: Vec<String>,
    /// Channel IDs to watch. Empty = watch all channels.
    channel_ids: Vec<String>,
    /// Dedicated discord.db memory backend.
    discord_memory: Arc<dyn Memory>,
    typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
    proxy_url: Option<String>,
    /// When false, DM messages are not stored in discord.db.
    store_dms: bool,
    /// When false, @mentions in DMs are not forwarded to the agent.
    respond_to_dms: bool,
}

impl DiscordHistoryChannel {
    pub fn new(
        bot_token: String,
        guild_id: Option<String>,
        allowed_users: Vec<String>,
        channel_ids: Vec<String>,
        discord_memory: Arc<dyn Memory>,
        store_dms: bool,
        respond_to_dms: bool,
    ) -> Self {
        Self {
            bot_token,
            guild_id,
            allowed_users,
            channel_ids,
            discord_memory,
            typing_handles: Mutex::new(HashMap::new()),
            proxy_url: None,
            store_dms,
            respond_to_dms,
        }
    }

    pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
        self.proxy_url = proxy_url;
        self
    }

    fn http_client(&self) -> reqwest::Client {
        crate::config::build_channel_proxy_client(
            "channel.discord_history",
            self.proxy_url.as_deref(),
        )
    }

    fn is_user_allowed(&self, user_id: &str) -> bool {
        if self.allowed_users.is_empty() {
            return true; // default open for logging channel
        }
        self.allowed_users.iter().any(|u| u == "*" || u == user_id)
    }

    fn is_channel_watched(&self, channel_id: &str) -> bool {
        self.channel_ids.is_empty() || self.channel_ids.iter().any(|c| c == channel_id)
    }

    fn bot_user_id_from_token(token: &str) -> Option<String> {
        let part = token.split('.').next()?;
        base64_decode(part)
    }

    async fn resolve_channel_name(&self, channel_id: &str) -> String {
        // 1. Check persistent database (via discord_memory)
        let cache_key = format!("cache:channel_name:{}", channel_id);

        if let Ok(Some(cached_mem)) = self.discord_memory.get(&cache_key).await {
            // Check if it's still fresh (e.g., less than 24 hours old)
            // Note: cached_mem.timestamp is an RFC3339 string
            let is_fresh =
                if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&cached_mem.timestamp) {
                    chrono::Utc::now().signed_duration_since(ts.with_timezone(&chrono::Utc))
                        < chrono::Duration::hours(24)
                } else {
                    false
                };

            if is_fresh {
                return cached_mem.content.clone();
            }
        }

        // 2. Fetch from API (either not in DB or stale)
        let url = format!("https://discord.com/api/v10/channels/{channel_id}");
        let resp = self
            .http_client()
            .get(&url)
            .header("Authorization", format!("Bot {}", self.bot_token))
            .send()
            .await;

        let name = if let Ok(r) = resp {
            if let Ok(json) = r.json::<serde_json::Value>().await {
                json.get("name")
                    .and_then(|n| n.as_str())
                    .map(|s| s.to_string())
                    .or_else(|| {
                        // For DMs, there might not be a 'name', use the recipient's username if available
                        json.get("recipients")
                            .and_then(|r| r.as_array())
                            .and_then(|a| a.first())
                            .and_then(|u| u.get("username"))
                            .and_then(|un| un.as_str())
                            .map(|s| format!("dm-{}", s))
                    })
            } else {
                None
            }
        } else {
            None
        };

        let resolved = name.unwrap_or_else(|| channel_id.to_string());

        // 3. Store in persistent database
        let _ = self
            .discord_memory
            .store(
                &cache_key,
                &resolved,
                crate::memory::MemoryCategory::Custom("channel_cache".to_string()),
                Some(channel_id),
            )
            .await;

        resolved
    }
}

const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

#[allow(clippy::cast_possible_truncation)]
fn base64_decode(input: &str) -> Option<String> {
    let padded = match input.len() % 4 {
        2 => format!("{input}=="),
        3 => format!("{input}="),
        _ => input.to_string(),
    };
    let mut bytes = Vec::new();
    let chars: Vec<u8> = padded.bytes().collect();
    for chunk in chars.chunks(4) {
        if chunk.len() < 4 {
            break;
        }
        let mut v = [0usize; 4];
        for (i, &b) in chunk.iter().enumerate() {
            if b == b'=' {
                v[i] = 0;
            } else {
                v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
            }
        }
        bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
        if chunk[2] != b'=' {
            bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
        }
        if chunk[3] != b'=' {
            bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
        }
    }
    String::from_utf8(bytes).ok()
}

fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
    if bot_user_id.is_empty() {
        return false;
    }
    content.contains(&format!("<@{bot_user_id}>"))
        || content.contains(&format!("<@!{bot_user_id}>"))
}

fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
    let mut result = content.to_string();
    for tag in [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] {
        result = result.replace(&tag, " ");
    }
    result.trim().to_string()
}

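For context on why `base64_decode` above pads its input: Discord bot tokens carry the bot's numeric user id as unpadded base64 in the first dot-separated segment, so an input whose length is 2 or 3 mod 4 must have its `=` padding restored before decoding. A sketch of just that padding step:

```rust
// Restore the '=' padding that Discord strips from token segments.
// A valid base64 length is a multiple of 4; lengths of 2 or 3 mod 4
// need two or one padding characters respectively.
fn restore_padding(input: &str) -> String {
    match input.len() % 4 {
        2 => format!("{input}=="),
        3 => format!("{input}="),
        _ => input.to_string(),
    }
}
```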
#[async_trait]
impl Channel for DiscordHistoryChannel {
    fn name(&self) -> &str {
        "discord_history"
    }

    /// Send a reply back to Discord (used when agent responds to @mention).
    async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
        let content = super::strip_tool_call_tags(&message.content);
        let url = format!(
            "https://discord.com/api/v10/channels/{}/messages",
            message.recipient
        );
        self.http_client()
            .post(&url)
            .header("Authorization", format!("Bot {}", self.bot_token))
            .json(&json!({"content": content}))
            .send()
            .await?;
        Ok(())
    }

    #[allow(clippy::too_many_lines)]
    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
        let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();

        // Get Gateway URL
        let gw_resp: serde_json::Value = self
            .http_client()
            .get("https://discord.com/api/v10/gateway/bot")
            .header("Authorization", format!("Bot {}", self.bot_token))
            .send()
            .await?
            .json()
            .await?;

        let gw_url = gw_resp
            .get("url")
            .and_then(|u| u.as_str())
            .unwrap_or("wss://gateway.discord.gg");

        let ws_url = format!("{gw_url}/?v=10&encoding=json");
        tracing::info!("DiscordHistory: connecting to gateway...");

        let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
        let (mut write, mut read) = ws_stream.split();

        // Read Hello (opcode 10)
        let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
        let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
        let heartbeat_interval = hello_data
            .get("d")
            .and_then(|d| d.get("heartbeat_interval"))
            .and_then(serde_json::Value::as_u64)
            .unwrap_or(41250);

        // Identify with intents for guild + DM messages + message content
        let identify = json!({
            "op": 2,
            "d": {
                "token": self.bot_token,
                "intents": 37377,
                "properties": {
                    "os": "linux",
                    "browser": "zeroclaw",
                    "device": "zeroclaw"
                }
            }
        });
        write
            .send(Message::Text(identify.to_string().into()))
            .await?;

        tracing::info!("DiscordHistory: connected and identified");

        let mut sequence: i64 = -1;

        let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
        tokio::spawn(async move {
            let mut interval =
                tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
            loop {
                interval.tick().await;
                if hb_tx.send(()).await.is_err() {
                    break;
                }
            }
        });

        let guild_filter = self.guild_id.clone();
        let discord_memory = Arc::clone(&self.discord_memory);
        let store_dms = self.store_dms;
        let respond_to_dms = self.respond_to_dms;

        loop {
            tokio::select! {
                _ = hb_rx.recv() => {
                    let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
                    let hb = json!({"op": 1, "d": d});
                    if write.send(Message::Text(hb.to_string().into())).await.is_err() {
                        break;
                    }
                }
                msg = read.next() => {
                    let msg = match msg {
                        Some(Ok(Message::Text(t))) => t,
                        Some(Ok(Message::Ping(payload))) => {
                            if write.send(Message::Pong(payload)).await.is_err() {
                                break;
                            }
                            continue;
                        }
                        Some(Ok(Message::Close(_))) | None => break,
                        Some(Err(e)) => {
                            tracing::warn!("DiscordHistory: websocket error: {e}");
                            break;
                        }
                        _ => continue,
                    };

                    let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
                        Ok(e) => e,
                        Err(_) => continue,
                    };

                    if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
                        sequence = s;
                    }

                    let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
                    match op {
                        1 => {
                            let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
                            let hb = json!({"op": 1, "d": d});
                            if write.send(Message::Text(hb.to_string().into())).await.is_err() {
                                break;
                            }
                            continue;
                        }
                        7 => { tracing::warn!("DiscordHistory: Reconnect (op 7)"); break; }
                        9 => { tracing::warn!("DiscordHistory: Invalid Session (op 9)"); break; }
                        _ => {}
                    }

                    let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
                    if event_type != "MESSAGE_CREATE" {
                        continue;
                    }

                    let Some(d) = event.get("d") else { continue };

                    // Skip messages from the bot itself
                    let author_id = d
                        .get("author")
                        .and_then(|a| a.get("id"))
                        .and_then(|i| i.as_str())
                        .unwrap_or("");
                    let username = d
                        .get("author")
                        .and_then(|a| a.get("username"))
                        .and_then(|i| i.as_str())
                        .unwrap_or(author_id);

                    if author_id == bot_user_id {
                        continue;
                    }

                    // Skip other bots
                    if d.get("author")
                        .and_then(|a| a.get("bot"))
                        .and_then(serde_json::Value::as_bool)
                        .unwrap_or(false)
                    {
                        continue;
                    }

                    let channel_id = d
                        .get("channel_id")
                        .and_then(|c| c.as_str())
                        .unwrap_or("")
                        .to_string();

                    // DM detection: DMs have no guild_id
                    let is_dm_event = d.get("guild_id").and_then(serde_json::Value::as_str).is_none();

                    // Resolve channel name (with cache)
                    let channel_display = if is_dm_event {
                        "dm".to_string()
                    } else {
                        self.resolve_channel_name(&channel_id).await
                    };

                    if is_dm_event && !store_dms && !respond_to_dms {
                        continue;
                    }

                    // Guild filter
                    if let Some(ref gid) = guild_filter {
                        let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str);
                        if let Some(g) = msg_guild {
                            if g != gid {
                                continue;
                            }
                        }
                    }

                    // Channel filter
                    if !self.is_channel_watched(&channel_id) {
                        continue;
                    }

                    if !self.is_user_allowed(author_id) {
                        continue;
                    }

                    let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
                    let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
                    let is_mention = contains_bot_mention(content, &bot_user_id);

                    // Collect attachment URLs
                    let attachments: Vec<String> = d
                        .get("attachments")
                        .and_then(|a| a.as_array())
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|a| a.get("url").and_then(|u| u.as_str()))
                                .map(|u| u.to_string())
                                .collect()
                        })
                        .unwrap_or_default();

                    // Store messages to discord.db (skip DMs if store_dms=false)
                    if (!is_dm_event || store_dms) && (!content.is_empty() || !attachments.is_empty()) {
                        let ts = chrono::Utc::now().to_rfc3339();
                        let mut mem_content = format!(
                            "@{username} in #{channel_display} at {ts}: {content}"
                        );
                        if !attachments.is_empty() {
                            mem_content.push_str(" [attachments: ");
                            mem_content.push_str(&attachments.join(", "));
                            mem_content.push(']');
                        }
                        let mem_key = format!(
                            "discord_{}",
                            if message_id.is_empty() {
                                Uuid::new_v4().to_string()
                            } else {
                                message_id.to_string()
                            }
                        );
                        let channel_id_for_session = if channel_id.is_empty() {
                            None
                        } else {
                            Some(channel_id.as_str())
                        };
                        if let Err(err) = discord_memory
                            .store(
                                &mem_key,
                                &mem_content,
                                MemoryCategory::Custom("discord".to_string()),
                                channel_id_for_session,
                            )
                            .await
                        {
                            tracing::warn!("discord_history: failed to store message: {err}");
                        } else {
                            tracing::debug!(
                                "discord_history: stored message from @{username} in #{channel_display}"
                            );
                        }
                    }

                    // Forward @mention to agent (skip DMs if respond_to_dms=false)
                    if is_mention && (!is_dm_event || respond_to_dms) {
                        let clean_content = strip_bot_mention(content, &bot_user_id);
                        if clean_content.is_empty() {
                            continue;
                        }
                        let channel_msg = ChannelMessage {
                            id: if message_id.is_empty() {
                                Uuid::new_v4().to_string()
                            } else {
                                format!("discord_{message_id}")
                            },
                            sender: author_id.to_string(),
                            reply_target: if channel_id.is_empty() {
                                author_id.to_string()
                            } else {
                                channel_id.clone()
                            },
                            content: clean_content,
                            channel: "discord_history".to_string(),
                            timestamp: std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_secs(),
                            thread_ts: None,
                            interruption_scope_id: None,
                        };
                        if tx.send(channel_msg).await.is_err() {
                            break;
                        }
                    }
                }
            }
        }

        Ok(())
    }

    async fn health_check(&self) -> bool {
        self.http_client()
            .get("https://discord.com/api/v10/users/@me")
            .header("Authorization", format!("Bot {}", self.bot_token))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

    async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> {
        let mut guard = self.typing_handles.lock();
        if let Some(h) = guard.remove(recipient) {
            h.abort();
        }
        let client = self.http_client();
        let token = self.bot_token.clone();
        let channel_id = recipient.to_string();
        let handle = tokio::spawn(async move {
            let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing");
            loop {
                let _ = client
                    .post(&url)
                    .header("Authorization", format!("Bot {token}"))
                    .send()
                    .await;
                tokio::time::sleep(std::time::Duration::from_secs(8)).await;
            }
        });
        guard.insert(recipient.to_string(), handle);
        Ok(())
    }

    async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> {
        let mut guard = self.typing_handles.lock();
        if let Some(handle) = guard.remove(recipient) {
            handle.abort();
        }
        Ok(())
    }
}
File diff suppressed because it is too large
@@ -19,7 +19,9 @@ pub mod clawdtalk;
pub mod cli;
pub mod dingtalk;
pub mod discord;
pub mod discord_history;
pub mod email_channel;
pub mod gmail_push;
pub mod imessage;
pub mod irc;
#[cfg(feature = "channel-lark")]
@@ -45,6 +47,8 @@ pub mod traits;
pub mod transcription;
pub mod tts;
pub mod twitter;
#[cfg(feature = "voice-wake")]
pub mod voice_wake;
pub mod wati;
pub mod webhook;
pub mod wecom;
@@ -59,7 +63,9 @@ pub use clawdtalk::{ClawdTalkChannel, ClawdTalkConfig};
pub use cli::CliChannel;
pub use dingtalk::DingTalkChannel;
pub use discord::DiscordChannel;
pub use discord_history::DiscordHistoryChannel;
pub use email_channel::EmailChannel;
pub use gmail_push::GmailPushChannel;
pub use imessage::IMessageChannel;
pub use irc::IrcChannel;
#[cfg(feature = "channel-lark")]
@@ -82,6 +88,8 @@ pub use traits::{Channel, SendMessage};
#[allow(unused_imports)]
pub use tts::{TtsManager, TtsProvider};
pub use twitter::TwitterChannel;
#[cfg(feature = "voice-wake")]
pub use voice_wake::VoiceWakeChannel;
pub use wati::WatiChannel;
pub use webhook::WebhookChannel;
pub use wecom::WeComChannel;
@@ -3128,9 +3136,12 @@ pub fn build_system_prompt_with_mode(
        Some(&autonomy_cfg),
        native_tools,
        skills_prompt_mode,
        false,
        0,
    )
}

#[allow(clippy::too_many_arguments)]
pub fn build_system_prompt_with_mode_and_autonomy(
    workspace_dir: &std::path::Path,
    model_name: &str,
@@ -3141,6 +3152,8 @@ pub fn build_system_prompt_with_mode_and_autonomy(
    autonomy_config: Option<&crate::config::AutonomyConfig>,
    native_tools: bool,
    skills_prompt_mode: crate::config::SkillsPromptInjectionMode,
    compact_context: bool,
    max_system_prompt_chars: usize,
) -> String {
    use std::fmt::Write;
    let mut prompt = String::with_capacity(8192);
@@ -3167,11 +3180,19 @@ pub fn build_system_prompt_with_mode_and_autonomy(
    // ── 1. Tooling ──────────────────────────────────────────────
    if !tools.is_empty() {
        prompt.push_str("## Tools\n\n");
        prompt.push_str("You have access to the following tools:\n\n");
        for (name, desc) in tools {
            let _ = writeln!(prompt, "- **{name}**: {desc}");
        if compact_context {
            // Compact mode: tool names only, no descriptions/schemas
            prompt.push_str("Available tools: ");
            let names: Vec<&str> = tools.iter().map(|(name, _)| *name).collect();
            prompt.push_str(&names.join(", "));
            prompt.push_str("\n\n");
        } else {
            prompt.push_str("You have access to the following tools:\n\n");
            for (name, desc) in tools {
                let _ = writeln!(prompt, "- **{name}**: {desc}");
            }
            prompt.push('\n');
        }
        prompt.push('\n');
    }

    // ── 1b. Hardware (when gpio/arduino tools present) ───────────
@@ -3315,11 +3336,13 @@ pub fn build_system_prompt_with_mode_and_autonomy(
        std::env::consts::OS,
    );

    // ── 8. Channel Capabilities ─────────────────────────────────────
    prompt.push_str("## Channel Capabilities\n\n");
    prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
    prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
    prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
    // ── 8. Channel Capabilities (skipped in compact_context mode) ──
    if !compact_context {
        prompt.push_str("## Channel Capabilities\n\n");
        prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
        prompt
            .push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
        prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
            Some(crate::security::AutonomyLevel::Full) => {
                "- If the runtime policy already allows a tool, use it directly; do not ask the user for extra approval.\n\
                 - Never pretend you are waiting for a human approval click or confirmation when the runtime policy already permits the action.\n\
@@ -3333,10 +3356,23 @@ pub fn build_system_prompt_with_mode_and_autonomy(
                 - If there is no approval path for this channel or the runtime blocks an action, explain that restriction directly instead of simulating an approval flow.\n"
            }
        });
    prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
    prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
    prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
    prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
        prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
        prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
        prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
        prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
    } // end if !compact_context (Channel Capabilities)

    // ── 9. Truncation (max_system_prompt_chars budget) ──────────
    if max_system_prompt_chars > 0 && prompt.len() > max_system_prompt_chars {
        // Truncate on a char boundary, keeping the top portion (identity + safety).
        let mut end = max_system_prompt_chars;
        // Ensure we don't split a multi-byte UTF-8 character.
        while !prompt.is_char_boundary(end) && end > 0 {
            end -= 1;
        }
        prompt.truncate(end);
        prompt.push_str("\n\n[System prompt truncated to fit context budget]\n");
    }

    if prompt.is_empty() {
        "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
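The truncation step above can be exercised in isolation; a minimal sketch of the same walk-back-to-a-UTF-8-boundary logic (the `truncate_to_boundary` name is hypothetical, not from the PR):

```rust
// Truncate a String to at most `max` bytes without splitting a
// multi-byte UTF-8 character: walk `end` back to a char boundary.
fn truncate_to_boundary(mut s: String, max: usize) -> String {
    if s.len() > max {
        let mut end = max;
        while !s.is_char_boundary(end) && end > 0 {
            end -= 1;
        }
        s.truncate(end);
    }
    s
}
```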
@@ -3747,6 +3783,31 @@ fn collect_configured_channels(
        });
    }

    if let Some(ref dh) = config.channels_config.discord_history {
        match crate::memory::SqliteMemory::new_named(&config.workspace_dir, "discord") {
            Ok(discord_mem) => {
                channels.push(ConfiguredChannel {
                    display_name: "Discord History",
                    channel: Arc::new(
                        DiscordHistoryChannel::new(
                            dh.bot_token.clone(),
                            dh.guild_id.clone(),
                            dh.allowed_users.clone(),
                            dh.channel_ids.clone(),
                            Arc::new(discord_mem),
                            dh.store_dms,
                            dh.respond_to_dms,
                        )
                        .with_proxy_url(dh.proxy_url.clone()),
                    ),
                });
            }
            Err(e) => {
                tracing::error!("discord_history: failed to open discord.db: {e}");
            }
        }
    }

    if let Some(ref sl) = config.channels_config.slack {
        channels.push(ConfiguredChannel {
            display_name: "Slack",
@@ -3938,6 +3999,15 @@ fn collect_configured_channels(
        });
    }

    if let Some(ref gp_cfg) = config.channels_config.gmail_push {
        if gp_cfg.enabled {
            channels.push(ConfiguredChannel {
                display_name: "Gmail Push",
                channel: Arc::new(GmailPushChannel::new(gp_cfg.clone())),
            });
        }
    }

    if let Some(ref irc) = config.channels_config.irc {
        channels.push(ConfiguredChannel {
            display_name: "IRC",
@@ -4114,6 +4184,17 @@ fn collect_configured_channels(
        });
    }

    #[cfg(feature = "voice-wake")]
    if let Some(ref vw) = config.channels_config.voice_wake {
        channels.push(ConfiguredChannel {
            display_name: "VoiceWake",
            channel: Arc::new(VoiceWakeChannel::new(
                vw.clone(),
                config.transcription.clone(),
            )),
        });
    }

    if let Some(ref wh) = config.channels_config.webhook {
        channels.push(ConfiguredChannel {
            display_name: "Webhook",
@@ -4264,22 +4345,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
    };
    // Build system prompt from workspace identity files + skills
    let workspace = config.workspace_dir.clone();
    let (mut built_tools, delegate_handle_ch): (Vec<Box<dyn Tool>>, _) =
        tools::all_tools_with_runtime(
            Arc::new(config.clone()),
            &security,
            runtime,
            Arc::clone(&mem),
            composio_key,
            composio_entity_id,
            &config.browser,
            &config.http_request,
            &config.web_fetch,
            &workspace,
            &config.agents,
            config.api_key.as_deref(),
            &config,
        );
    let (mut built_tools, delegate_handle_ch, reaction_handle_ch) = tools::all_tools_with_runtime(
        Arc::new(config.clone()),
        &security,
        runtime,
        Arc::clone(&mem),
        composio_key,
        composio_entity_id,
        &config.browser,
        &config.http_request,
        &config.web_fetch,
        &workspace,
        &config.agents,
        config.api_key.as_deref(),
        &config,
        None,
    );

    // Wire MCP tools into the registry before freezing — non-fatal.
    // When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
@@ -4452,6 +4533,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
        Some(&config.autonomy),
        native_tools,
        config.skills.prompt_injection_mode,
        config.agent.compact_context,
        config.agent.max_system_prompt_chars,
    );
    if !native_tools {
        system_prompt.push_str(&build_tool_instructions(
@@ -4551,6 +4634,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
            .map(|ch| (ch.name().to_string(), Arc::clone(ch)))
            .collect::<HashMap<_, _>>(),
    );

    // Populate the reaction tool's channel map now that channels are initialized.
    if let Some(ref handle) = reaction_handle_ch {
        let mut map = handle.write();
        for (name, ch) in channels_by_name.as_ref() {
            map.insert(name.clone(), Arc::clone(ch));
        }
    }

    let max_in_flight_messages = compute_max_in_flight_messages(channels.len());

    println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
@@ -6783,6 +6875,9 @@ BTC is currently around $65,000 based on latest tool output."#
            timestamp: "2026-02-20T00:00:00Z".to_string(),
            session_id: None,
            score: Some(0.9),
            namespace: "default".into(),
            importance: None,
            superseded_by: None,
        }])
    }

@@ -7774,6 +7869,8 @@ BTC is currently around $65,000 based on latest tool output."#
        Some(&config),
        false,
        crate::config::SkillsPromptInjectionMode::Full,
        false,
        0,
    );

    assert!(
@@ -7803,6 +7900,8 @@ BTC is currently around $65,000 based on latest tool output."#
        Some(&config),
        false,
        crate::config::SkillsPromptInjectionMode::Full,
        false,
        0,
    );

    assert!(

@@ -0,0 +1,531 @@
//! Voice Wake Word detection channel.
//!
//! Listens on the default microphone via `cpal`, detects a configurable wake
//! word using energy-based VAD followed by transcription-based keyword matching,
//! then captures the subsequent utterance and dispatches it as a channel message.
//!
//! Gated behind the `voice-wake` Cargo feature.

use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use anyhow::{bail, Result};
use async_trait::async_trait;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};

use crate::channels::transcription::transcribe_audio;
use crate::config::schema::VoiceWakeConfig;
use crate::config::TranscriptionConfig;

use super::traits::{Channel, ChannelMessage, SendMessage};

// ── State machine ──────────────────────────────────────────────

/// Internal states for the wake-word detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WakeState {
    /// Passively monitoring microphone energy levels.
    Listening,
    /// Energy spike detected — capturing a short window to check for wake word.
    Triggered,
    /// Wake word confirmed — capturing the full utterance that follows.
    Capturing,
    /// Captured audio is being transcribed.
    Processing,
}

impl std::fmt::Display for WakeState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Listening => write!(f, "Listening"),
            Self::Triggered => write!(f, "Triggered"),
            Self::Capturing => write!(f, "Capturing"),
            Self::Processing => write!(f, "Processing"),
        }
    }
}

// ── Channel implementation ─────────────────────────────────────

/// Voice wake-word channel that activates on a spoken keyword.
pub struct VoiceWakeChannel {
    config: VoiceWakeConfig,
    transcription_config: TranscriptionConfig,
}

impl VoiceWakeChannel {
    /// Create a new `VoiceWakeChannel` from its config sections.
    pub fn new(config: VoiceWakeConfig, transcription_config: TranscriptionConfig) -> Self {
        Self {
            config,
            transcription_config,
        }
    }
}

#[async_trait]
impl Channel for VoiceWakeChannel {
    fn name(&self) -> &str {
        "voice_wake"
    }

    async fn send(&self, _message: &SendMessage) -> Result<()> {
        // Voice wake is input-only; outbound messages are not supported.
        Ok(())
    }

    async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
        let config = self.config.clone();
        let transcription_config = self.transcription_config.clone();

        // Run the blocking audio capture loop on a dedicated thread.
        let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(4);

        let energy_threshold = config.energy_threshold;
        let silence_timeout = Duration::from_millis(u64::from(config.silence_timeout_ms));
        let max_capture = Duration::from_secs(u64::from(config.max_capture_secs));
        let sample_rate: u32;
        let channels_count: u16;

        // ── Initialise cpal stream ────────────────────────────
        {
            use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

            let host = cpal::default_host();
            let device = host
                .default_input_device()
                .ok_or_else(|| anyhow::anyhow!("No default audio input device available"))?;

            let supported = device.default_input_config()?;
            sample_rate = supported.sample_rate().0;
            channels_count = supported.channels();

            info!(
                device = ?device.name().unwrap_or_default(),
                sample_rate,
                channels = channels_count,
                "VoiceWake: opening audio input"
            );

            let stream_config: cpal::StreamConfig = supported.into();
            let audio_tx_clone = audio_tx.clone();

            let stream = device.build_input_stream(
                &stream_config,
                move |data: &[f32], _: &cpal::InputCallbackInfo| {
                    // Non-blocking: try_send and drop if full.
                    let _ = audio_tx_clone.try_send(data.to_vec());
                },
                move |err| {
                    warn!("VoiceWake: audio stream error: {err}");
                },
                None,
            )?;

            stream.play()?;

            // Keep the stream alive for the lifetime of the channel.
            // We leak it intentionally — the channel runs until the daemon shuts down.
            std::mem::forget(stream);
        }

        // Drop the extra sender so the channel closes when the stream sender drops.
        drop(audio_tx);

        // ── Main detection loop ───────────────────────────────
        let wake_word = config.wake_word.to_lowercase();
        let mut state = WakeState::Listening;
        let mut capture_buf: Vec<f32> = Vec::new();
        let mut last_voice_at = Instant::now();
        let mut capture_start = Instant::now();
        let mut msg_counter: u64 = 0;

        info!(wake_word = %wake_word, "VoiceWake: entering listen loop");

        while let Some(chunk) = audio_rx.recv().await {
            let energy = compute_rms_energy(&chunk);

            match state {
                WakeState::Listening => {
                    if energy >= energy_threshold {
                        debug!(
                            energy,
                            "VoiceWake: energy spike — transitioning to Triggered"
                        );
                        state = WakeState::Triggered;
                        capture_buf.clear();
                        capture_buf.extend_from_slice(&chunk);
                        last_voice_at = Instant::now();
                        capture_start = Instant::now();
                    }
                }
                WakeState::Triggered => {
                    capture_buf.extend_from_slice(&chunk);

                    if energy >= energy_threshold {
                        last_voice_at = Instant::now();
                    }

                    let since_voice = last_voice_at.elapsed();
let since_start = capture_start.elapsed();
|
||||
|
||||
// After enough silence or max time, transcribe to check for wake word.
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: Triggered window closed — transcribing for wake word");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "wake_check.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let lower = text.to_lowercase();
|
||||
if lower.contains(&wake_word) {
|
||||
info!(text = %text, "VoiceWake: wake word detected — capturing utterance");
|
||||
state = WakeState::Capturing;
|
||||
capture_buf.clear();
|
||||
last_voice_at = Instant::now();
|
||||
capture_start = Instant::now();
|
||||
} else {
|
||||
debug!(text = %text, "VoiceWake: no wake word — back to Listening");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error during wake check: {e}");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WakeState::Capturing => {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
}
|
||||
|
||||
let since_voice = last_voice_at.elapsed();
|
||||
let since_start = capture_start.elapsed();
|
||||
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: utterance capture complete — transcribing");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "utterance.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
msg_counter += 1;
|
||||
let ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("voice_wake_{msg_counter}"),
|
||||
sender: "voice_user".into(),
|
||||
reply_target: "voice_user".into(),
|
||||
content: trimmed,
|
||||
channel: "voice_wake".into(),
|
||||
timestamp: ts,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
};
|
||||
|
||||
if let Err(e) = tx.send(msg).await {
|
||||
warn!("VoiceWake: failed to dispatch message: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error for utterance: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
                WakeState::Processing => {
                    // No chunks are expected in this state; any that arrive are ignored.
                    // State transitions happen synchronously above after transcription.
                }
            }
        }

        bail!("VoiceWake: audio stream ended unexpectedly");
    }
}

// ── Audio utilities ────────────────────────────────────────────

/// Compute RMS (root-mean-square) energy of an audio chunk.
pub fn compute_rms_energy(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }
    let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
    (sum_sq / samples.len() as f32).sqrt()
}
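As a standalone sanity check of the RMS formula above (a sketch separate from the channel code; `rms` here is a hypothetical copy of `compute_rms_energy`), the RMS of a constant-amplitude signal equals that amplitude:

```rust
// Standalone sketch mirroring `compute_rms_energy` above.
fn rms(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }
    let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
    (sum_sq / samples.len() as f32).sqrt()
}

fn main() {
    // A constant 0.5 signal has RMS exactly 0.5.
    assert!((rms(&[0.5f32; 100]) - 0.5).abs() < 1e-6);
    // Silence (and an empty chunk) has zero energy.
    assert_eq!(rms(&[0.0f32; 16]), 0.0);
    assert_eq!(rms(&[]), 0.0);
    println!("ok");
}
```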

/// Encode raw f32 PCM samples as a WAV byte buffer (16-bit PCM).
///
/// This produces a minimal valid WAV file that Whisper-compatible APIs accept.
pub fn encode_wav_from_f32(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<u8> {
    let bits_per_sample: u16 = 16;
    let byte_rate = u32::from(channels) * sample_rate * u32::from(bits_per_sample) / 8;
    let block_align = channels * bits_per_sample / 8;
    #[allow(clippy::cast_possible_truncation)]
    let data_len = (samples.len() * 2) as u32; // 16-bit = 2 bytes per sample; max ~25 MB
    let file_len = 36 + data_len;

    let mut buf = Vec::with_capacity(file_len as usize + 8);

    // RIFF header
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&file_len.to_le_bytes());
    buf.extend_from_slice(b"WAVE");

    // fmt chunk
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
    buf.extend_from_slice(&channels.to_le_bytes());
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&byte_rate.to_le_bytes());
    buf.extend_from_slice(&block_align.to_le_bytes());
    buf.extend_from_slice(&bits_per_sample.to_le_bytes());

    // data chunk
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_len.to_le_bytes());

    for &sample in samples {
        let clamped = sample.clamp(-1.0, 1.0);
        #[allow(clippy::cast_possible_truncation)]
        let pcm16 = (clamped * 32767.0) as i16; // clamped to [-1,1] so fits i16
        buf.extend_from_slice(&pcm16.to_le_bytes());
    }

    buf
}
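The header written above can be cross-checked against the standard WAV layout with a small standalone sketch (`expected_wav_len` is a hypothetical helper, not part of the codebase):

```rust
// Byte layout of the minimal WAV file produced above:
//   0..4  "RIFF"   4..8  file_len    8..12 "WAVE"
//  12..16 "fmt "  16..20 chunk size (16)
//  20..22 format (1 = PCM)           22..24 channels
//  24..28 sample rate                28..32 byte rate
//  32..34 block align                34..36 bits per sample
//  36..40 "data"  40..44 data_len    44..   PCM samples
fn expected_wav_len(num_samples: usize) -> usize {
    44 + num_samples * 2 // 44-byte header + 2 bytes per 16-bit sample
}

fn main() {
    // 50 samples → 144 bytes, matching the `wav_total_size_correct` test below.
    assert_eq!(expected_wav_len(50), 144);
    assert_eq!(expected_wav_len(100), 244);
    println!("ok");
}
```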

// ── Tests ──────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::traits::ChannelConfig;

    // ── State machine tests ────────────────────────────────

    #[test]
    fn wake_state_display() {
        assert_eq!(WakeState::Listening.to_string(), "Listening");
        assert_eq!(WakeState::Triggered.to_string(), "Triggered");
        assert_eq!(WakeState::Capturing.to_string(), "Capturing");
        assert_eq!(WakeState::Processing.to_string(), "Processing");
    }

    #[test]
    fn wake_state_equality() {
        assert_eq!(WakeState::Listening, WakeState::Listening);
        assert_ne!(WakeState::Listening, WakeState::Triggered);
    }

    // ── Energy computation tests ───────────────────────────

    #[test]
    fn rms_energy_of_silence_is_zero() {
        let silence = vec![0.0f32; 1024];
        assert_eq!(compute_rms_energy(&silence), 0.0);
    }

    #[test]
    fn rms_energy_of_empty_is_zero() {
        assert_eq!(compute_rms_energy(&[]), 0.0);
    }

    #[test]
    fn rms_energy_of_constant_signal() {
        // Constant signal at 0.5 → RMS should be 0.5
        let signal = vec![0.5f32; 100];
        let energy = compute_rms_energy(&signal);
        assert!((energy - 0.5).abs() < 1e-5);
    }

    #[test]
    fn rms_energy_above_threshold() {
        let loud = vec![0.8f32; 256];
        let energy = compute_rms_energy(&loud);
        assert!(energy > 0.01, "Loud signal should exceed default threshold");
    }

    #[test]
    fn rms_energy_below_threshold_for_quiet() {
        let quiet = vec![0.001f32; 256];
        let energy = compute_rms_energy(&quiet);
        assert!(
            energy < 0.01,
            "Very quiet signal should be below default threshold"
        );
    }

    // ── WAV encoding tests ─────────────────────────────────

    #[test]
    fn wav_header_is_valid() {
        let samples = vec![0.0f32; 100];
        let wav = encode_wav_from_f32(&samples, 16000, 1);

        // RIFF header
        assert_eq!(&wav[0..4], b"RIFF");
        assert_eq!(&wav[8..12], b"WAVE");

        // fmt chunk
        assert_eq!(&wav[12..16], b"fmt ");
        let fmt_size = u32::from_le_bytes(wav[16..20].try_into().unwrap());
        assert_eq!(fmt_size, 16);

        // PCM format
        let format = u16::from_le_bytes(wav[20..22].try_into().unwrap());
        assert_eq!(format, 1);

        // Channels
        let channels = u16::from_le_bytes(wav[22..24].try_into().unwrap());
        assert_eq!(channels, 1);

        // Sample rate
        let sr = u32::from_le_bytes(wav[24..28].try_into().unwrap());
        assert_eq!(sr, 16000);

        // data chunk
        assert_eq!(&wav[36..40], b"data");
        let data_size = u32::from_le_bytes(wav[40..44].try_into().unwrap());
        assert_eq!(data_size, 200); // 100 samples * 2 bytes each
    }

    #[test]
    fn wav_total_size_correct() {
        let samples = vec![0.0f32; 50];
        let wav = encode_wav_from_f32(&samples, 44100, 2);
        // header (44 bytes) + data (50 * 2 = 100 bytes)
        assert_eq!(wav.len(), 144);
    }

    #[test]
    fn wav_encodes_clipped_samples() {
        // Samples outside [-1, 1] should be clamped
        let samples = vec![-2.0f32, 2.0, 0.0];
        let wav = encode_wav_from_f32(&samples, 16000, 1);

        let s0 = i16::from_le_bytes(wav[44..46].try_into().unwrap());
        let s1 = i16::from_le_bytes(wav[46..48].try_into().unwrap());
        let s2 = i16::from_le_bytes(wav[48..50].try_into().unwrap());

        assert_eq!(s0, -32767); // clamped to -1.0
        assert_eq!(s1, 32767); // clamped to 1.0
        assert_eq!(s2, 0);
    }

    // ── Config parsing tests ───────────────────────────────

    #[test]
    fn voice_wake_config_defaults() {
        let config = VoiceWakeConfig::default();
        assert_eq!(config.wake_word, "hey zeroclaw");
        assert_eq!(config.silence_timeout_ms, 2000);
        assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
        assert_eq!(config.max_capture_secs, 30);
    }

    #[test]
    fn voice_wake_config_deserialize_partial() {
        let toml_str = r#"
            wake_word = "okay agent"
            max_capture_secs = 60
        "#;
        let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.wake_word, "okay agent");
        assert_eq!(config.max_capture_secs, 60);
        // Defaults preserved for unset fields
        assert_eq!(config.silence_timeout_ms, 2000);
        assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
    }

    #[test]
    fn voice_wake_config_deserialize_all_fields() {
        let toml_str = r#"
            wake_word = "hello bot"
            silence_timeout_ms = 3000
            energy_threshold = 0.05
            max_capture_secs = 15
        "#;
        let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.wake_word, "hello bot");
        assert_eq!(config.silence_timeout_ms, 3000);
        assert!((config.energy_threshold - 0.05).abs() < f32::EPSILON);
        assert_eq!(config.max_capture_secs, 15);
    }

    #[test]
    fn voice_wake_config_channel_config_trait() {
        assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
        assert_eq!(VoiceWakeConfig::desc(), "voice wake word detection");
    }

    // ── State transition logic tests ───────────────────────

    #[test]
    fn energy_threshold_determines_trigger() {
        let threshold = 0.01f32;
        let quiet_energy = compute_rms_energy(&vec![0.005f32; 256]);
        let loud_energy = compute_rms_energy(&vec![0.5f32; 256]);

        assert!(quiet_energy < threshold, "Quiet should not trigger");
        assert!(loud_energy >= threshold, "Loud should trigger");
    }

    #[test]
    fn state_transitions_are_deterministic() {
        // Verify that the state enum values are distinct and copyable
        let states = [
            WakeState::Listening,
            WakeState::Triggered,
            WakeState::Capturing,
            WakeState::Processing,
        ];
        for (i, a) in states.iter().enumerate() {
            for (j, b) in states.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    #[test]
    fn channel_config_impl() {
        // VoiceWakeConfig implements ChannelConfig
        assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
        assert!(!VoiceWakeConfig::desc().is_empty());
    }

    #[test]
    fn voice_wake_channel_name() {
        let config = VoiceWakeConfig::default();
        let transcription_config = TranscriptionConfig::default();
        let channel = VoiceWakeChannel::new(config, transcription_config);
        assert_eq!(channel.name(), "voice_wake");
    }
}

+16 -15
@@ -15,21 +15,22 @@ pub use schema::{
    ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
    GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
    HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
    IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
    ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
    LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
    MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
    ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
    NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
    OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
    ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
    ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
    SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
    SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
    StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
    TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
    TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
    WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
    IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
    ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
    KnowledgeConfig, LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig,
    LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig,
    MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
    NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
    OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PacingConfig,
    PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
    ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
    RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
    SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
    StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
    SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
    TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
    WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
    DEFAULT_GWS_SERVICES,
};

pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {

+358 -94
@@ -357,6 +357,10 @@ pub struct Config {
    #[serde(default)]
    pub linkedin: LinkedInConfig,

    /// Standalone image generation tool configuration (`[image_gen]`).
    #[serde(default)]
    pub image_gen: ImageGenConfig,

    /// Plugin system configuration (`[plugins]`).
    #[serde(default)]
    pub plugins: PluginsConfig,
@@ -1248,6 +1252,12 @@ pub struct AgentConfig {
    /// Default: `[]` (no filtering — all tools included).
    #[serde(default)]
    pub tool_filter_groups: Vec<ToolFilterGroup>,
    /// Maximum characters for the assembled system prompt. When `> 0`, the prompt
    /// is truncated to this limit after assembly (keeping the top portion which
    /// contains identity and safety instructions). `0` means unlimited.
    /// Useful for small-context models (e.g. glm-4.5-air ~8K tokens → set to 8000).
    #[serde(default = "default_max_system_prompt_chars")]
    pub max_system_prompt_chars: usize,
}

fn default_agent_max_tool_iterations() -> usize {
@@ -1266,6 +1276,10 @@ fn default_agent_tool_dispatcher() -> String {
    "auto".into()
}

fn default_max_system_prompt_chars() -> usize {
    0
}

impl Default for AgentConfig {
    fn default() -> Self {
        Self {
@@ -1277,6 +1291,7 @@ impl Default for AgentConfig {
            tool_dispatcher: default_agent_tool_dispatcher(),
            tool_call_dedup_exempt: Vec::new(),
            tool_filter_groups: Vec::new(),
            max_system_prompt_chars: default_max_system_prompt_chars(),
        }
    }
}
@@ -2963,6 +2978,46 @@ impl Default for ImageProviderFluxConfig {
    }
}

// ── Standalone Image Generation ─────────────────────────────────

/// Standalone image generation tool configuration (`[image_gen]`).
///
/// When enabled, registers an `image_gen` tool that generates images via
/// fal.ai's synchronous API (Flux / Nano Banana models) and saves them
/// to the workspace `images/` directory.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ImageGenConfig {
    /// Enable the standalone image generation tool. Default: false.
    #[serde(default)]
    pub enabled: bool,

    /// Default fal.ai model identifier.
    #[serde(default = "default_image_gen_model")]
    pub default_model: String,

    /// Environment variable name holding the fal.ai API key.
    #[serde(default = "default_image_gen_api_key_env")]
    pub api_key_env: String,
}

fn default_image_gen_model() -> String {
    "fal-ai/flux/schnell".into()
}

fn default_image_gen_api_key_env() -> String {
    "FAL_API_KEY".into()
}

impl Default for ImageGenConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            default_model: default_image_gen_model(),
            api_key_env: default_image_gen_api_key_env(),
        }
    }
}

// ── Claude Code ─────────────────────────────────────────────────

/// Claude Code CLI tool configuration (`[claude_code]` section).
@@ -3754,77 +3809,6 @@ impl Default for QdrantConfig {
    }
}

/// Configuration for the mem0 (OpenMemory) memory backend.
///
/// Connects to a self-hosted OpenMemory server via its REST API.
/// Deploy OpenMemory with `docker compose up` from the mem0 repo,
/// then point `url` at the API (default `http://localhost:8765`).
///
/// ```toml
/// [memory]
/// backend = "mem0"
///
/// [memory.mem0]
/// url = "http://localhost:8765"
/// user_id = "zeroclaw"
/// app_name = "zeroclaw"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Mem0Config {
    /// OpenMemory server URL (e.g. `http://localhost:8765`).
    /// Falls back to `MEM0_URL` env var if not set.
    #[serde(default = "default_mem0_url")]
    pub url: String,
    /// User ID for scoping memories within mem0.
    /// Falls back to `MEM0_USER_ID` env var, or default `"zeroclaw"`.
    #[serde(default = "default_mem0_user_id")]
    pub user_id: String,
    /// Application name registered in mem0.
    /// Falls back to `MEM0_APP_NAME` env var, or default `"zeroclaw"`.
    #[serde(default = "default_mem0_app_name")]
    pub app_name: String,
    /// Whether mem0 should use its built-in LLM to extract facts from
    /// stored text (`infer = true`) or store raw text as-is (`false`).
    #[serde(default = "default_mem0_infer")]
    pub infer: bool,
    /// Custom prompt for guiding LLM-based fact extraction when `infer = true`.
    /// Useful for non-English content (e.g. Cantonese/Chinese).
    /// Falls back to `MEM0_EXTRACTION_PROMPT` env var.
    /// If unset, the mem0 server uses its built-in default prompt.
    #[serde(default = "default_mem0_extraction_prompt")]
    pub extraction_prompt: Option<String>,
}

fn default_mem0_url() -> String {
    std::env::var("MEM0_URL").unwrap_or_else(|_| "http://localhost:8765".into())
}
fn default_mem0_user_id() -> String {
    std::env::var("MEM0_USER_ID").unwrap_or_else(|_| "zeroclaw".into())
}
fn default_mem0_app_name() -> String {
    std::env::var("MEM0_APP_NAME").unwrap_or_else(|_| "zeroclaw".into())
}
fn default_mem0_infer() -> bool {
    true
}
fn default_mem0_extraction_prompt() -> Option<String> {
    std::env::var("MEM0_EXTRACTION_PROMPT")
        .ok()
        .filter(|s| !s.trim().is_empty())
}

impl Default for Mem0Config {
    fn default() -> Self {
        Self {
            url: default_mem0_url(),
            user_id: default_mem0_user_id(),
            app_name: default_mem0_app_name(),
            infer: default_mem0_infer(),
            extraction_prompt: default_mem0_extraction_prompt(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[allow(clippy::struct_excessive_bools)]
pub struct MemoryConfig {
@@ -3899,6 +3883,43 @@ pub struct MemoryConfig {
    #[serde(default = "default_true")]
    pub auto_hydrate: bool,

    // ── Retrieval Pipeline ─────────────────────────────────────
    /// Retrieval stages to execute in order. Valid: "cache", "fts", "vector".
    #[serde(default = "default_retrieval_stages")]
    pub retrieval_stages: Vec<String>,
    /// Enable LLM reranking when candidate count exceeds threshold.
    #[serde(default)]
    pub rerank_enabled: bool,
    /// Minimum candidate count to trigger reranking.
    #[serde(default = "default_rerank_threshold")]
    pub rerank_threshold: usize,
    /// FTS score above which to early-return without vector search (0.0–1.0).
    #[serde(default = "default_fts_early_return_score")]
    pub fts_early_return_score: f64,

    // ── Namespace Isolation ─────────────────────────────────────
    /// Default namespace for memory entries.
    #[serde(default = "default_namespace")]
    pub default_namespace: String,

    // ── Conflict Resolution ─────────────────────────────────────
    /// Cosine similarity threshold for conflict detection (0.0–1.0).
    #[serde(default = "default_conflict_threshold")]
    pub conflict_threshold: f64,

    // ── Audit Trail ─────────────────────────────────────────────
    /// Enable audit logging of memory operations.
    #[serde(default)]
    pub audit_enabled: bool,
    /// Retention period for audit entries in days (default: 30).
    #[serde(default = "default_audit_retention_days")]
    pub audit_retention_days: u32,

    // ── Policy Engine ───────────────────────────────────────────
    /// Memory policy configuration.
    #[serde(default)]
    pub policy: MemoryPolicyConfig,

    // ── SQLite backend options ─────────────────────────────────
    /// For sqlite backend: max seconds to wait when opening the DB (e.g. file locked).
    /// None = wait indefinitely (default). Recommended max: 300.
@@ -3910,13 +3931,42 @@ pub struct MemoryConfig {
    /// Only used when `backend = "qdrant"`.
    #[serde(default)]
    pub qdrant: QdrantConfig,
}

// ── Mem0 backend options ─────────────────────────────────
/// Configuration for mem0 (OpenMemory) backend.
/// Only used when `backend = "mem0"`.
/// Requires `--features memory-mem0` at build time.
/// Memory policy configuration (`[memory.policy]` section).
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct MemoryPolicyConfig {
    /// Maximum entries per namespace (0 = unlimited).
    #[serde(default)]
    pub mem0: Mem0Config,
    pub max_entries_per_namespace: usize,
    /// Maximum entries per category (0 = unlimited).
    #[serde(default)]
    pub max_entries_per_category: usize,
    /// Retention days by category (overrides global). Keys: "core", "daily", "conversation".
    #[serde(default)]
    pub retention_days_by_category: std::collections::HashMap<String, u32>,
    /// Namespaces that are read-only (writes are rejected).
    #[serde(default)]
    pub read_only_namespaces: Vec<String>,
}
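For reference, the new `[memory.policy]` section introduced in this hunk can be configured like this (an illustrative sketch: the keys come from the struct above, the values are invented):

```toml
[memory.policy]
max_entries_per_namespace = 10000   # 0 = unlimited
max_entries_per_category = 2000     # 0 = unlimited
read_only_namespaces = ["core"]

[memory.policy.retention_days_by_category]
daily = 7
conversation = 30
```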

fn default_retrieval_stages() -> Vec<String> {
    vec!["cache".into(), "fts".into(), "vector".into()]
}
fn default_rerank_threshold() -> usize {
    5
}
fn default_fts_early_return_score() -> f64 {
    0.85
}
fn default_namespace() -> String {
    "default".into()
}
fn default_conflict_threshold() -> f64 {
    0.85
}
fn default_audit_retention_days() -> u32 {
    30
}

fn default_embedding_provider() -> String {
@@ -3990,9 +4040,17 @@ impl Default for MemoryConfig {
            snapshot_enabled: false,
            snapshot_on_hygiene: false,
            auto_hydrate: true,
            retrieval_stages: default_retrieval_stages(),
            rerank_enabled: false,
            rerank_threshold: default_rerank_threshold(),
            fts_early_return_score: default_fts_early_return_score(),
            default_namespace: default_namespace(),
            conflict_threshold: default_conflict_threshold(),
            audit_enabled: false,
            audit_retention_days: default_audit_retention_days(),
            policy: MemoryPolicyConfig::default(),
            sqlite_open_timeout_secs: None,
            qdrant: QdrantConfig::default(),
            mem0: Mem0Config::default(),
        }
    }
}
@@ -4201,6 +4259,7 @@ fn default_auto_approve() -> Vec<String> {
        "glob_search".into(),
        "content_search".into(),
        "image_info".into(),
        "weather".into(),
    ]
}

@@ -4587,6 +4646,7 @@ pub struct ClassificationRule {

/// Heartbeat configuration for periodic health pings (`[heartbeat]` section).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[allow(clippy::struct_excessive_bools)]
pub struct HeartbeatConfig {
    /// Enable periodic heartbeat pings. Default: `false`.
    pub enabled: bool,
@@ -4633,6 +4693,14 @@ pub struct HeartbeatConfig {
    /// Maximum number of heartbeat run history records to retain. Default: `100`.
    #[serde(default = "default_heartbeat_max_run_history")]
    pub max_run_history: u32,
    /// Load the channel session history before each heartbeat task execution so
    /// the LLM has conversational context. Default: `false`.
    ///
    /// When `true`, the session file for the configured `target`/`to` is passed
    /// to the agent as `session_state_file`, giving it access to the recent
    /// conversation history — just as if the user had sent a message.
    #[serde(default)]
    pub load_session_context: bool,
}

fn default_heartbeat_interval() -> u32 {
@@ -4671,6 +4739,7 @@ impl Default for HeartbeatConfig {
            deadman_channel: None,
            deadman_to: None,
            max_run_history: default_heartbeat_max_run_history(),
            load_session_context: false,
        }
    }
}
@@ -4867,6 +4936,8 @@ pub struct ChannelsConfig {
    pub telegram: Option<TelegramConfig>,
    /// Discord bot channel configuration.
    pub discord: Option<DiscordConfig>,
    /// Discord history channel — logs ALL messages and forwards @mentions to agent.
    pub discord_history: Option<DiscordHistoryConfig>,
    /// Slack bot channel configuration.
    pub slack: Option<SlackConfig>,
    /// Mattermost bot channel configuration.
@@ -4889,6 +4960,8 @@ pub struct ChannelsConfig {
    pub nextcloud_talk: Option<NextcloudTalkConfig>,
    /// Email channel configuration.
    pub email: Option<crate::channels::email_channel::EmailConfig>,
    /// Gmail Pub/Sub push notification channel configuration.
    pub gmail_push: Option<crate::channels::gmail_push::GmailPushConfig>,
    /// IRC channel configuration.
    pub irc: Option<IrcConfig>,
    /// Lark channel configuration.
@@ -4913,6 +4986,9 @@ pub struct ChannelsConfig {
    pub reddit: Option<RedditConfig>,
    /// Bluesky channel configuration (AT Protocol).
    pub bluesky: Option<BlueskyConfig>,
    /// Voice wake word detection channel configuration.
    #[cfg(feature = "voice-wake")]
    pub voice_wake: Option<VoiceWakeConfig>,
    /// Base timeout in seconds for processing a single channel message (LLM + tools).
    /// Runtime uses this as a per-turn budget that scales with tool-loop depth
    /// (up to 4x, capped) so one slow/retried model call does not consume the
@@ -4995,6 +5071,10 @@ impl ChannelsConfig {
                Box::new(ConfigWrapper::new(self.email.as_ref())),
                self.email.is_some(),
            ),
            (
                Box::new(ConfigWrapper::new(self.gmail_push.as_ref())),
                self.gmail_push.is_some(),
            ),
            (
                Box::new(ConfigWrapper::new(self.irc.as_ref())),
                self.irc.is_some()
@@ -5036,6 +5116,11 @@ impl ChannelsConfig {
                Box::new(ConfigWrapper::new(self.bluesky.as_ref())),
                self.bluesky.is_some(),
            ),
            #[cfg(feature = "voice-wake")]
            (
                Box::new(ConfigWrapper::new(self.voice_wake.as_ref())),
                self.voice_wake.is_some(),
            ),
        ]
    }

@@ -5063,6 +5148,7 @@ impl Default for ChannelsConfig {
            cli: true,
            telegram: None,
            discord: None,
            discord_history: None,
            slack: None,
            mattermost: None,
            webhook: None,
@@ -5074,6 +5160,7 @@ impl Default for ChannelsConfig {
            wati: None,
            nextcloud_talk: None,
            email: None,
            gmail_push: None,
            irc: None,
            lark: None,
            feishu: None,
@@ -5087,6 +5174,8 @@ impl Default for ChannelsConfig {
            clawdtalk: None,
            reddit: None,
            bluesky: None,
            #[cfg(feature = "voice-wake")]
            voice_wake: None,
            message_timeout_secs: default_channel_message_timeout_secs(),
            ack_reactions: true,
            show_tool_calls: false,
@@ -5190,6 +5279,39 @@ impl ChannelConfig for DiscordConfig {
    }
}

/// Discord history channel — logs ALL messages to discord.db and forwards @mentions to the agent.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DiscordHistoryConfig {
    /// Discord bot token (from Discord Developer Portal).
    pub bot_token: String,
    /// Optional guild (server) ID to restrict logging to a single guild.
    pub guild_id: Option<String>,
    /// Allowed Discord user IDs. Empty = allow all (open logging).
    #[serde(default)]
    pub allowed_users: Vec<String>,
    /// Discord channel IDs to watch. Empty = watch all channels.
    #[serde(default)]
    pub channel_ids: Vec<String>,
    /// When true (default), store Direct Messages in discord.db.
    #[serde(default = "default_true")]
|
||||
pub store_dms: bool,
|
||||
/// When true (default), respond to @mentions in Direct Messages.
|
||||
#[serde(default = "default_true")]
|
||||
pub respond_to_dms: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DiscordHistoryConfig {
|
||||
fn name() -> &'static str {
|
||||
"Discord History"
|
||||
}
|
||||
fn desc() -> &'static str {
|
||||
"log all messages and forward @mentions"
|
||||
}
|
||||
}
|
||||
|
||||
/// Slack bot channel configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SlackConfig {
|
||||
@@ -6338,6 +6460,74 @@ impl ChannelConfig for BlueskyConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Voice wake word detection channel configuration.
|
||||
///
|
||||
/// Listens on the default microphone for a configurable wake word,
|
||||
/// then captures the following utterance and transcribes it via the
|
||||
/// existing transcription API.
|
||||
#[cfg(feature = "voice-wake")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct VoiceWakeConfig {
|
||||
/// Wake word phrase to listen for (case-insensitive substring match).
|
||||
/// Default: `"hey zeroclaw"`.
|
||||
#[serde(default = "default_voice_wake_word")]
|
||||
pub wake_word: String,
|
||||
/// Silence timeout in milliseconds — how long to wait after the last
|
||||
/// energy spike before finalizing a capture window. Default: `2000`.
|
||||
#[serde(default = "default_voice_wake_silence_timeout_ms")]
|
||||
pub silence_timeout_ms: u32,
|
||||
/// RMS energy threshold for voice activity detection. Samples below
|
||||
/// this level are treated as silence. Default: `0.01`.
|
||||
#[serde(default = "default_voice_wake_energy_threshold")]
|
||||
pub energy_threshold: f32,
|
||||
/// Maximum capture duration in seconds before forcing transcription.
|
||||
/// Default: `30`.
|
||||
#[serde(default = "default_voice_wake_max_capture_secs")]
|
||||
pub max_capture_secs: u32,
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_word() -> String {
|
||||
"hey zeroclaw".into()
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_silence_timeout_ms() -> u32 {
|
||||
2000
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_energy_threshold() -> f32 {
|
||||
0.01
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_max_capture_secs() -> u32 {
|
||||
30
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
impl Default for VoiceWakeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
wake_word: default_voice_wake_word(),
|
||||
silence_timeout_ms: default_voice_wake_silence_timeout_ms(),
|
||||
energy_threshold: default_voice_wake_energy_threshold(),
|
||||
max_capture_secs: default_voice_wake_max_capture_secs(),
|
||||
}
|
||||
}
|
||||
}
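The `wake_word` doc comment above specifies a case-insensitive substring match against the transcript. A minimal std-only sketch of that matching rule (the helper name `matches_wake_word` is hypothetical, not from this codebase):

```rust
/// Case-insensitive substring check, mirroring the documented
/// `wake_word` semantics. Hypothetical helper, not from the diff.
fn matches_wake_word(transcript: &str, wake_word: &str) -> bool {
    transcript
        .to_lowercase()
        .contains(&wake_word.to_lowercase())
}
```

For example, `matches_wake_word("Hey ZeroClaw, status?", "hey zeroclaw")` is true while `matches_wake_word("hello there", "hey zeroclaw")` is false.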

#[cfg(feature = "voice-wake")]
impl ChannelConfig for VoiceWakeConfig {
    fn name() -> &'static str {
        "VoiceWake"
    }
    fn desc() -> &'static str {
        "voice wake word detection"
    }
}

/// Nostr channel configuration (NIP-04 + NIP-17 private messages)
#[cfg(feature = "channel-nostr")]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
@@ -6811,6 +7001,7 @@ impl Default for Config {
            node_transport: NodeTransportConfig::default(),
            knowledge: KnowledgeConfig::default(),
            linkedin: LinkedInConfig::default(),
            image_gen: ImageGenConfig::default(),
            plugins: PluginsConfig::default(),
            locale: None,
            verifiable_intent: VerifiableIntentConfig::default(),
@@ -7339,22 +7530,37 @@ impl Config {
            .await
            .context("Failed to read config file")?;

        // Track ignored/unknown config keys to warn users about silent misconfigurations
        // (e.g., using [providers.ollama] which doesn't exist instead of top-level api_url)
        // Deserialize the config with the standard TOML parser.
        //
        // Previously this used `serde_ignored::deserialize` for both
        // deserialization and unknown-key detection. However,
        // `serde_ignored` silently drops field values inside nested
        // structs that carry `#[serde(default)]` (e.g. the entire
        // `[autonomy]` table), causing user-supplied values to be
        // replaced by defaults. See #4171.
        //
        // We now deserialize with `toml::from_str` (which is correct)
        // and run `serde_ignored` separately just for diagnostics.
        let mut config: Config =
            toml::from_str(&contents).context("Failed to deserialize config file")?;

        // Detect unknown/ignored config keys for diagnostic warnings.
        // This second pass uses serde_ignored but discards the parsed
        // result — only the ignored-path list is kept.
        let mut ignored_paths: Vec<String> = Vec::new();
        let mut config: Config = serde_ignored::deserialize(
            toml::de::Deserializer::parse(&contents).context("Failed to parse config file")?,
        let _: Result<Config, _> = serde_ignored::deserialize(
            toml::de::Deserializer::parse(&contents)
                .unwrap_or_else(|_| unreachable!("already parsed above")),
            |path| {
                ignored_paths.push(path.to_string());
            },
        )
        .context("Failed to deserialize config file")?;
        );

        // Warn about each unknown config key.
        // serde_ignored + #[serde(default)] on nested structs can produce
        // false positives: parent-level fields get re-reported under the
        // nested key (e.g. "memory.mem0.auto_hydrate" even though
        // auto_hydrate belongs to MemoryConfig, not Mem0Config). We
        // nested key (e.g. "memory.qdrant.auto_hydrate" even though
        // auto_hydrate belongs to MemoryConfig, not QdrantConfig). We
        // suppress these by checking whether the leaf key is a known field
        // on the parent struct.
        let known_memory_fields: &[&str] = &[
@@ -7383,7 +7589,7 @@ impl Config {
        ];
        for path in ignored_paths {
            // Skip false positives from nested memory sub-sections
            if path.starts_with("memory.mem0.") || path.starts_with("memory.qdrant.") {
            if path.starts_with("memory.qdrant.") {
                let leaf = path.rsplit('.').next().unwrap_or("");
                if known_memory_fields.contains(&leaf) {
                    continue;
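The false-positive suppression in the hunk above can be isolated as a small predicate. A std-only sketch under the assumption that the parent struct's known field names are passed in (`is_nested_false_positive` is a hypothetical name, not from the diff):

```rust
/// Returns true when an "ignored key" report is a known false positive:
/// the path sits under the nested memory sub-section but its leaf is a
/// legitimate field of the parent MemoryConfig. Hypothetical helper.
fn is_nested_false_positive(path: &str, known_parent_fields: &[&str]) -> bool {
    if !path.starts_with("memory.qdrant.") {
        return false;
    }
    // The leaf is the segment after the last dot, e.g. "auto_hydrate".
    let leaf = path.rsplit('.').next().unwrap_or("");
    known_parent_fields.contains(&leaf)
}
```

With `known_parent_fields = ["auto_hydrate"]`, the path `memory.qdrant.auto_hydrate` is suppressed while a genuinely unknown key such as `memory.qdrant.bogus_key` still warns.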
@@ -7607,6 +7813,13 @@ impl Config {
                "config.channels_config.email.password",
            )?;
        }
        if let Some(ref mut gp) = config.channels_config.gmail_push {
            decrypt_secret(
                &store,
                &mut gp.oauth_token,
                "config.channels_config.gmail_push.oauth_token",
            )?;
        }
        if let Some(ref mut irc) = config.channels_config.irc {
            decrypt_optional_secret(
                &store,
@@ -9038,6 +9251,13 @@ impl Config {
                "config.channels_config.email.password",
            )?;
        }
        if let Some(ref mut gp) = config_to_save.channels_config.gmail_push {
            encrypt_secret(
                &store,
                &mut gp.oauth_token,
                "config.channels_config.gmail_push.oauth_token",
            )?;
        }
        if let Some(ref mut irc) = config_to_save.channels_config.irc {
            encrypt_optional_secret(
                &store,
@@ -9275,7 +9495,6 @@ mod tests {
    use std::os::unix::fs::PermissionsExt;
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex as StdMutex};
    #[cfg(unix)]
    use tempfile::TempDir;
    use tokio::sync::{Mutex, MutexGuard};
    use tokio::test;
@@ -9666,6 +9885,7 @@ default_temperature = 0.7
                proxy_url: None,
            }),
            discord: None,
            discord_history: None,
            slack: None,
            mattermost: None,
            webhook: None,
@@ -9677,6 +9897,7 @@ default_temperature = 0.7
            wati: None,
            nextcloud_talk: None,
            email: None,
            gmail_push: None,
            irc: None,
            lark: None,
            feishu: None,
@@ -9690,6 +9911,8 @@ default_temperature = 0.7
            clawdtalk: None,
            reddit: None,
            bluesky: None,
            #[cfg(feature = "voice-wake")]
            voice_wake: None,
            message_timeout_secs: 300,
            ack_reactions: true,
            show_tool_calls: true,
@@ -9734,6 +9957,7 @@ default_temperature = 0.7
            node_transport: NodeTransportConfig::default(),
            knowledge: KnowledgeConfig::default(),
            linkedin: LinkedInConfig::default(),
            image_gen: ImageGenConfig::default(),
            plugins: PluginsConfig::default(),
            locale: None,
            verifiable_intent: VerifiableIntentConfig::default(),
@@ -9791,6 +10015,37 @@ default_temperature = 0.7
        assert_eq!(parsed.provider_timeout_secs, 120);
    }

    /// Regression test for #4171: the `[autonomy]` section must not be
    /// silently dropped when parsing config TOML.
    #[test]
    async fn autonomy_section_is_not_silently_ignored() {
        let raw = r#"
default_temperature = 0.7

[autonomy]
level = "full"
max_actions_per_hour = 99
auto_approve = ["file_read", "memory_recall", "http_request"]
"#;
        let parsed = parse_test_config(raw);
        assert_eq!(
            parsed.autonomy.level,
            AutonomyLevel::Full,
            "autonomy.level must be parsed from config (was silently defaulting to Supervised)"
        );
        assert_eq!(
            parsed.autonomy.max_actions_per_hour, 99,
            "autonomy.max_actions_per_hour must be parsed from config"
        );
        assert!(
            parsed
                .autonomy
                .auto_approve
                .contains(&"http_request".to_string()),
            "autonomy.auto_approve must include http_request from config"
        );
    }

    #[test]
    async fn provider_timeout_secs_parses_from_toml() {
        let raw = r#"
@@ -10115,6 +10370,7 @@ default_temperature = 0.7
            node_transport: NodeTransportConfig::default(),
            knowledge: KnowledgeConfig::default(),
            linkedin: LinkedInConfig::default(),
            image_gen: ImageGenConfig::default(),
            plugins: PluginsConfig::default(),
            locale: None,
            verifiable_intent: VerifiableIntentConfig::default(),
@@ -10501,6 +10757,7 @@ allowed_users = ["@ops:matrix.org"]
            cli: true,
            telegram: None,
            discord: None,
            discord_history: None,
            slack: None,
            mattermost: None,
            webhook: None,
@@ -10522,6 +10779,7 @@ allowed_users = ["@ops:matrix.org"]
            wati: None,
            nextcloud_talk: None,
            email: None,
            gmail_push: None,
            irc: None,
            lark: None,
            feishu: None,
@@ -10534,6 +10792,8 @@ allowed_users = ["@ops:matrix.org"]
            clawdtalk: None,
            reddit: None,
            bluesky: None,
            #[cfg(feature = "voice-wake")]
            voice_wake: None,
            message_timeout_secs: 300,
            ack_reactions: true,
            show_tool_calls: true,
@@ -10816,6 +11076,7 @@ channel_id = "C123"
            cli: true,
            telegram: None,
            discord: None,
            discord_history: None,
            slack: None,
            mattermost: None,
            webhook: None,
@@ -10841,6 +11102,7 @@ channel_id = "C123"
            wati: None,
            nextcloud_talk: None,
            email: None,
            gmail_push: None,
            irc: None,
            lark: None,
            feishu: None,
@@ -10853,6 +11115,8 @@ channel_id = "C123"
            clawdtalk: None,
            reddit: None,
            bluesky: None,
            #[cfg(feature = "voice-wake")]
            voice_wake: None,
            message_timeout_secs: 300,
            ack_reactions: true,
            show_tool_calls: true,
@@ -13671,12 +13935,12 @@ require_otp_to_resume = true
    async fn ensure_bootstrap_files_creates_missing_files() {
        let tmp = TempDir::new().unwrap();
        let ws = tmp.path().join("workspace");
        tokio::fs::create_dir_all(&ws).await.unwrap();
        let _: () = tokio::fs::create_dir_all(&ws).await.unwrap();

        ensure_bootstrap_files(&ws).await.unwrap();

        let soul = tokio::fs::read_to_string(ws.join("SOUL.md")).await.unwrap();
        let identity = tokio::fs::read_to_string(ws.join("IDENTITY.md"))
        let soul: String = tokio::fs::read_to_string(ws.join("SOUL.md")).await.unwrap();
        let identity: String = tokio::fs::read_to_string(ws.join("IDENTITY.md"))
            .await
            .unwrap();
        assert!(soul.contains("SOUL.md"));
@@ -13687,21 +13951,21 @@ require_otp_to_resume = true
    async fn ensure_bootstrap_files_does_not_overwrite_existing() {
        let tmp = TempDir::new().unwrap();
        let ws = tmp.path().join("workspace");
        tokio::fs::create_dir_all(&ws).await.unwrap();
        let _: () = tokio::fs::create_dir_all(&ws).await.unwrap();

        let custom = "# My custom SOUL";
        tokio::fs::write(ws.join("SOUL.md"), custom).await.unwrap();
        let _: () = tokio::fs::write(ws.join("SOUL.md"), custom).await.unwrap();

        ensure_bootstrap_files(&ws).await.unwrap();

        let soul = tokio::fs::read_to_string(ws.join("SOUL.md")).await.unwrap();
        let soul: String = tokio::fs::read_to_string(ws.join("SOUL.md")).await.unwrap();
        assert_eq!(
            soul, custom,
            "ensure_bootstrap_files must not overwrite existing files"
        );

        // IDENTITY.md should still be created since it was missing
        let identity = tokio::fs::read_to_string(ws.join("IDENTITY.md"))
        let identity: String = tokio::fs::read_to_string(ws.join("IDENTITY.md"))
            .await
            .unwrap();
        assert!(identity.contains("IDENTITY.md"));

+193 -1
@@ -362,10 +362,22 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
    };

    // ── Phase 2: Execute selected tasks ─────────────────────
    // Re-read session context on every tick so we pick up messages
    // that arrived since the daemon started.
    let session_context = if config.heartbeat.load_session_context {
        load_heartbeat_session_context(&config)
    } else {
        None
    };

    let mut tick_had_error = false;
    for task in &tasks_to_run {
        let task_start = std::time::Instant::now();
        let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
        let task_prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
        let prompt = match &session_context {
            Some(ctx) => format!("{ctx}\n\n{task_prompt}"),
            None => task_prompt,
        };
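The context-prepend step above (history preamble, blank line, then the task prompt) can be factored into a tiny pure function; a sketch with a hypothetical name:

```rust
/// Prepend optional session context to a heartbeat task prompt,
/// separated by a blank line. Hypothetical helper, not from the diff.
fn compose_prompt(session_context: Option<&str>, task_prompt: &str) -> String {
    match session_context {
        Some(ctx) => format!("{ctx}\n\n{task_prompt}"),
        None => task_prompt.to_string(),
    }
}
```

So `compose_prompt(Some("history"), "task")` yields `"history\n\ntask"`, and `compose_prompt(None, "task")` passes the task prompt through unchanged.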
        let temp = config.default_temperature;
        match Box::pin(crate::agent::run(
            config.clone(),
@@ -497,6 +509,186 @@ fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)
    }
}

/// Load recent conversation history for the heartbeat's delivery target and
/// format it as a text preamble to inject into the task prompt.
///
/// Scans `{workspace}/sessions/` for JSONL files whose name starts with
/// `{channel}_` and ends with `_{to}.jsonl` (or exactly `{channel}_{to}.jsonl`),
/// then picks the most recently modified match. This handles session key
/// formats such as `telegram_diskiller.jsonl` and
/// `telegram_5673725398_diskiller.jsonl`.
/// Returns `None` when `target`/`to` are not configured or no session exists.
const HEARTBEAT_SESSION_CONTEXT_MESSAGES: usize = 20;

fn load_heartbeat_session_context(config: &Config) -> Option<String> {
    use crate::providers::traits::ChatMessage;

    let channel = config
        .heartbeat
        .target
        .as_deref()
        .map(str::trim)
        .filter(|v| !v.is_empty())?;
    let to = config
        .heartbeat
        .to
        .as_deref()
        .map(str::trim)
        .filter(|v| !v.is_empty())?;

    if channel.contains('/') || channel.contains('\\') || to.contains('/') || to.contains('\\') {
        tracing::warn!("heartbeat session context: channel/to contains path separators, skipping");
        return None;
    }

    let sessions_dir = config.workspace_dir.join("sessions");

    // Find the most recently modified JSONL file that belongs to this target.
    // Matches both `{channel}_{to}.jsonl` and `{channel}_{anything}_{to}.jsonl`.
    let prefix = format!("{channel}_");
    let suffix = format!("_{to}.jsonl");
    let exact = format!("{channel}_{to}.jsonl");
    let mid_prefix = format!("{channel}_{to}_");

    let path = std::fs::read_dir(&sessions_dir)
        .ok()?
        .filter_map(|e| e.ok())
        .filter(|e| {
            let name = e.file_name();
            let name = name.to_string_lossy();
            name.ends_with(".jsonl")
                && (name == exact
                    || (name.starts_with(&prefix) && name.ends_with(&suffix))
                    || name.starts_with(&mid_prefix))
        })
        .max_by_key(|e| {
            e.metadata()
                .and_then(|m| m.modified())
                .unwrap_or(std::time::SystemTime::UNIX_EPOCH)
        })
        .map(|e| e.path())?;
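The filename-matching rule in the filter above is self-contained enough to sketch on its own. A std-only version over plain file names (the function name `session_file_matches` is hypothetical):

```rust
/// True when a sessions/ file name belongs to the given channel/target,
/// matching both `{channel}_{to}.jsonl` and `{channel}_{anything}_{to}.jsonl`
/// plus the `{channel}_{to}_*` mid-prefix form. Hypothetical helper.
fn session_file_matches(name: &str, channel: &str, to: &str) -> bool {
    let exact = format!("{channel}_{to}.jsonl");
    let prefix = format!("{channel}_");
    let suffix = format!("_{to}.jsonl");
    let mid_prefix = format!("{channel}_{to}_");
    name.ends_with(".jsonl")
        && (name == exact
            || (name.starts_with(&prefix) && name.ends_with(&suffix))
            || name.starts_with(&mid_prefix))
}
```

This accepts both session-key formats named in the doc comment, e.g. `telegram_diskiller.jsonl` via the exact branch and `telegram_5673725398_diskiller.jsonl` via the prefix-plus-suffix branch, while rejecting files for other channels.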

    if !path.exists() {
        tracing::debug!("💓 Heartbeat session context: no session file found for {channel}/{to}");
        return None;
    }

    let messages = load_jsonl_messages(&path);
    if messages.is_empty() {
        return None;
    }

    let recent: Vec<&ChatMessage> = messages
        .iter()
        .filter(|m| m.role == "user" || m.role == "assistant")
        .rev()
        .take(HEARTBEAT_SESSION_CONTEXT_MESSAGES)
        .collect::<Vec<_>>()
        .into_iter()
        .rev()
        .collect();

    // Only inject context if there is at least one real user message in the
    // window. If the JSONL contains only assistant messages (e.g. previous
    // heartbeat outputs with no reply yet), skip context to avoid feeding
    // Monika's own messages back to her in a loop.
    let has_user_message = recent.iter().any(|m| m.role == "user");
    if !has_user_message {
        tracing::debug!(
            "💓 Heartbeat session context: no user messages in recent history — skipping"
        );
        return None;
    }

    // Use the session file's mtime as a proxy for when the last message arrived.
    let last_message_age = std::fs::metadata(&path)
        .ok()
        .and_then(|m| m.modified().ok())
        .and_then(|mtime| mtime.elapsed().ok());

    let silence_note = match last_message_age {
        Some(age) => {
            let mins = age.as_secs() / 60;
            if mins < 60 {
                format!("(last message ~{mins} minutes ago)\n")
            } else {
                let hours = mins / 60;
                let rem = mins % 60;
                if rem == 0 {
                    format!("(last message ~{hours}h ago)\n")
                } else {
                    format!("(last message ~{hours}h {rem}m ago)\n")
                }
            }
        }
        None => String::new(),
    };
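The silence-note branching above reduces to a small duration formatter. A std-only sketch taking the age in minutes (`format_age_minutes` is a hypothetical name):

```rust
/// Render an age in minutes the way the silence note does:
/// "~45 minutes ago", "~2h ago", or "~2h 5m ago". Hypothetical helper.
fn format_age_minutes(mins: u64) -> String {
    if mins < 60 {
        format!("~{mins} minutes ago")
    } else {
        let hours = mins / 60;
        let rem = mins % 60;
        if rem == 0 {
            format!("~{hours}h ago")
        } else {
            format!("~{hours}h {rem}m ago")
        }
    }
}
```

For example, 45 minutes renders as `~45 minutes ago`, 120 as `~2h ago`, and 125 as `~2h 5m ago`.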

    tracing::debug!(
        "💓 Heartbeat session context: {} messages from {}, silence: {}",
        recent.len(),
        path.display(),
        silence_note.trim(),
    );

    let mut ctx = format!(
        "[Recent conversation history — use this for context when composing your message] {silence_note}",
    );
    for msg in &recent {
        let label = if msg.role == "user" { "User" } else { "You" };
        // Truncate very long messages to avoid bloating the prompt.
        // Use char_indices to avoid panicking on multi-byte UTF-8 characters.
        let content = if msg.content.len() > 500 {
            let truncate_at = msg
                .content
                .char_indices()
                .map(|(i, _)| i)
                .take_while(|&i| i <= 500)
                .last()
                .unwrap_or(0);
            format!("{}…", &msg.content[..truncate_at])
        } else {
            msg.content.clone()
        };
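The char_indices dance above exists because slicing a `&str` at a byte offset inside a multi-byte character panics. An equivalent boundary-safe sketch using `is_char_boundary` (my own standalone variant, not the diff's code):

```rust
/// Truncate a string to at most `max` bytes without splitting a
/// multi-byte UTF-8 character. Standalone sketch of the same idea
/// as the char_indices logic in the diff.
fn truncate_utf8(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Walk backwards until we land on a valid char boundary.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
```

With `"héllo"` ("é" occupies bytes 1..3), `truncate_utf8("héllo", 2)` backs off to the boundary at byte 1 and returns `"h"` instead of panicking.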
        ctx.push_str(label);
        ctx.push_str(": ");
        ctx.push_str(&content);
        ctx.push('\n');
    }

    Some(ctx)
}

/// Read the last `HEARTBEAT_SESSION_CONTEXT_MESSAGES` `ChatMessage` lines from
/// a JSONL session file using a bounded rolling window so we never hold the
/// entire file in memory.
fn load_jsonl_messages(path: &std::path::Path) -> Vec<crate::providers::traits::ChatMessage> {
    use std::collections::VecDeque;
    use std::io::BufRead;

    let file = match std::fs::File::open(path) {
        Ok(f) => f,
        Err(_) => return Vec::new(),
    };
    let reader = std::io::BufReader::new(file);
    let mut window: VecDeque<crate::providers::traits::ChatMessage> =
        VecDeque::with_capacity(HEARTBEAT_SESSION_CONTEXT_MESSAGES + 1);
    for line in reader.lines() {
        let Ok(line) = line else { continue };
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        if let Ok(msg) = serde_json::from_str::<crate::providers::traits::ChatMessage>(trimmed) {
            window.push_back(msg);
            if window.len() > HEARTBEAT_SESSION_CONTEXT_MESSAGES {
                window.pop_front();
            }
        }
    }
    window.into_iter().collect()
}
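The bounded rolling window above keeps memory flat regardless of file size: each line is pushed onto a `VecDeque` and the front is popped once the cap is exceeded. A simplified std-only sketch over plain strings instead of `ChatMessage` JSON lines:

```rust
use std::collections::VecDeque;

/// Keep only the last `cap` non-empty lines from an iterator using a
/// bounded VecDeque, mirroring load_jsonl_messages' rolling window
/// (simplified: plain &str lines instead of deserialized messages).
fn last_n_lines<'a>(lines: impl Iterator<Item = &'a str>, cap: usize) -> Vec<&'a str> {
    let mut window: VecDeque<&str> = VecDeque::with_capacity(cap + 1);
    for line in lines {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue; // skip blank lines, as the original does
        }
        window.push_back(trimmed);
        if window.len() > cap {
            window.pop_front(); // evict the oldest once over capacity
        }
    }
    window.into_iter().collect()
}
```

For example, `last_n_lines("a\nb\n\nc\nd\ne".lines(), 3)` yields `["c", "d", "e"]`: the blank line is skipped and only the last three survivors remain.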

/// Auto-detect the best channel for heartbeat delivery by checking which
/// channels are configured. Returns the first match in priority order.
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {

+3 -1
@@ -23,7 +23,7 @@ fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
}

/// Verify bearer token against PairingGuard. Returns error response if unauthorized.
fn require_auth(
pub(super) fn require_auth(
    state: &AppState,
    headers: &HeaderMap,
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
@@ -1429,6 +1429,7 @@ mod tests {
            nextcloud_talk: None,
            nextcloud_talk_webhook_secret: None,
            wati: None,
            gmail_push: None,
            observer: Arc::new(crate::observability::NoopObserver),
            tools_registry: Arc::new(Vec::new()),
            cost_tracker: None,
@@ -1439,6 +1440,7 @@ mod tests {
            device_registry: None,
            pending_pairings: None,
            path_prefix: String::new(),
            canvas_store: crate::tools::canvas::CanvasStore::new(),
        }
    }

@@ -0,0 +1,278 @@
//! Live Canvas gateway routes — REST + WebSocket for real-time canvas updates.
//!
//! - `GET /api/canvas/:id` — get current canvas content (JSON)
//! - `POST /api/canvas/:id` — push content programmatically
//! - `GET /api/canvas` — list all active canvases
//! - `WS /ws/canvas/:id` — real-time canvas updates via WebSocket

use super::api::require_auth;
use super::AppState;
use axum::{
    extract::{
        ws::{Message, WebSocket},
        Path, State, WebSocketUpgrade,
    },
    http::{header, HeaderMap, StatusCode},
    response::{IntoResponse, Json},
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;

/// POST /api/canvas/:id request body.
#[derive(Deserialize)]
pub struct CanvasPostBody {
    pub content_type: Option<String>,
    pub content: String,
}

/// GET /api/canvas — list all active canvases.
pub async fn handle_canvas_list(
    State(state): State<AppState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    if let Err(e) = require_auth(&state, &headers) {
        return e.into_response();
    }

    let ids = state.canvas_store.list();
    Json(serde_json::json!({ "canvases": ids })).into_response()
}

/// GET /api/canvas/:id — get current canvas content.
pub async fn handle_canvas_get(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = require_auth(&state, &headers) {
        return e.into_response();
    }

    match state.canvas_store.snapshot(&id) {
        Some(frame) => Json(serde_json::json!({
            "canvas_id": id,
            "frame": frame,
        }))
        .into_response(),
        None => (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({ "error": format!("Canvas '{}' not found", id) })),
        )
            .into_response(),
    }
}

/// GET /api/canvas/:id/history — get canvas frame history.
pub async fn handle_canvas_history(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = require_auth(&state, &headers) {
        return e.into_response();
    }

    let history = state.canvas_store.history(&id);
    Json(serde_json::json!({
        "canvas_id": id,
        "frames": history,
    }))
    .into_response()
}

/// POST /api/canvas/:id — push content to a canvas.
pub async fn handle_canvas_post(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(id): Path<String>,
    Json(body): Json<CanvasPostBody>,
) -> impl IntoResponse {
    if let Err(e) = require_auth(&state, &headers) {
        return e.into_response();
    }

    let content_type = body.content_type.as_deref().unwrap_or("html");

    // Validate content_type against allowed set (prevent injecting "eval" frames via REST).
    if !crate::tools::canvas::ALLOWED_CONTENT_TYPES.contains(&content_type) {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": format!(
                    "Invalid content_type '{}'. Allowed: {:?}",
                    content_type,
                    crate::tools::canvas::ALLOWED_CONTENT_TYPES
                )
            })),
        )
            .into_response();
    }

    // Enforce content size limit (same as tool-side validation).
    if body.content.len() > crate::tools::canvas::MAX_CONTENT_SIZE {
        return (
            StatusCode::PAYLOAD_TOO_LARGE,
            Json(serde_json::json!({
                "error": format!(
                    "Content exceeds maximum size of {} bytes",
                    crate::tools::canvas::MAX_CONTENT_SIZE
                )
            })),
        )
            .into_response();
    }

    match state.canvas_store.render(&id, content_type, &body.content) {
        Some(frame) => (
            StatusCode::CREATED,
            Json(serde_json::json!({
                "canvas_id": id,
                "frame": frame,
            })),
        )
            .into_response(),
        None => (
            StatusCode::TOO_MANY_REQUESTS,
            Json(serde_json::json!({
                "error": "Maximum canvas count reached. Clear unused canvases first."
            })),
        )
            .into_response(),
    }
}
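The POST handler above rejects requests in two steps: an allow-list check on `content_type` (400), then a byte-size cap on the body (413). A std-only mirror of that decision logic, with illustrative values standing in for the crate's real `ALLOWED_CONTENT_TYPES` and `MAX_CONTENT_SIZE` constants:

```rust
/// Standalone mirror of the REST-side canvas POST validation. The
/// allowed list and size limit are illustrative assumptions, not the
/// crate's actual constants; errors are raw HTTP status codes.
fn validate_canvas_post(
    content_type: &str,
    content_len: usize,
    allowed: &[&str],
    max_size: usize,
) -> Result<(), u16> {
    if !allowed.contains(&content_type) {
        return Err(400); // Bad Request: unknown content_type
    }
    if content_len > max_size {
        return Err(413); // Payload Too Large
    }
    Ok(())
}
```

Ordering matters for the reported error: a request that is both oversized and mistyped reports the type error first, matching the handler's early-return sequence.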
|
||||
|
||||
/// DELETE /api/canvas/:id — clear a canvas.
|
||||
pub async fn handle_canvas_clear(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
state.canvas_store.clear(&id);
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"status": "cleared",
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// WS /ws/canvas/:id — real-time canvas updates.
|
||||
pub async fn handle_ws_canvas(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
// Auth check (same pattern as ws::handle_ws_chat)
|
||||
if state.pairing.require_pairing() {
|
||||
let token = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.or_else(|| {
|
||||
// Fallback: check query params in the upgrade request URI
|
||||
headers
|
        .get("sec-websocket-protocol")
        .and_then(|v| v.to_str().ok())
        .and_then(|protos| {
            protos
                .split(',')
                .map(|p| p.trim())
                .find_map(|p| p.strip_prefix("bearer."))
        })
    })
    .unwrap_or("");

        if !state.pairing.is_authenticated(token) {
            return (
                StatusCode::UNAUTHORIZED,
                "Unauthorized — provide Authorization header or Sec-WebSocket-Protocol bearer",
            )
                .into_response();
        }
    }

    ws.on_upgrade(move |socket| handle_canvas_socket(socket, state, id))
        .into_response()
}

async fn handle_canvas_socket(socket: WebSocket, state: AppState, canvas_id: String) {
    let (mut sender, mut receiver) = socket.split();

    // Subscribe to canvas updates
    let mut rx = match state.canvas_store.subscribe(&canvas_id) {
        Some(rx) => rx,
        None => {
            let msg = serde_json::json!({
                "type": "error",
                "error": "Maximum canvas count reached",
            });
            let _ = sender.send(Message::Text(msg.to_string().into())).await;
            return;
        }
    };

    // Send current state immediately if available
    if let Some(frame) = state.canvas_store.snapshot(&canvas_id) {
        let msg = serde_json::json!({
            "type": "frame",
            "canvas_id": canvas_id,
            "frame": frame,
        });
        let _ = sender.send(Message::Text(msg.to_string().into())).await;
    }

    // Send a connected acknowledgement
    let ack = serde_json::json!({
        "type": "connected",
        "canvas_id": canvas_id,
    });
    let _ = sender.send(Message::Text(ack.to_string().into())).await;

    // Spawn a task that forwards broadcast updates to the WebSocket
    let canvas_id_clone = canvas_id.clone();
    let send_task = tokio::spawn(async move {
        loop {
            match rx.recv().await {
                Ok(frame) => {
                    let msg = serde_json::json!({
                        "type": "frame",
                        "canvas_id": canvas_id_clone,
                        "frame": frame,
                    });
                    if sender
                        .send(Message::Text(msg.to_string().into()))
                        .await
                        .is_err()
                    {
                        break;
                    }
                }
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    // Client fell behind — notify and continue rather than disconnecting.
                    let msg = serde_json::json!({
                        "type": "lagged",
                        "canvas_id": canvas_id_clone,
                        "missed_frames": n,
                    });
                    let _ = sender.send(Message::Text(msg.to_string().into())).await;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
            }
        }
    });

    // Read loop: we mostly ignore incoming messages but handle close/ping
    while let Some(msg) = receiver.next().await {
        match msg {
            Ok(Message::Close(_)) | Err(_) => break,
            _ => {} // Ignore all other messages (pings are handled by axum)
        }
    }

    // Abort the send task when the connection is closed
    send_task.abort();
}
+138 -17
@@ -11,14 +11,15 @@ pub mod api;
 pub mod api_pairing;
 #[cfg(feature = "plugins-wasm")]
 pub mod api_plugins;
+pub mod canvas;
 pub mod nodes;
 pub mod sse;
 pub mod static_files;
 pub mod ws;

 use crate::channels::{
-    session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel, LinqChannel,
-    NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
+    session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel,
+    GmailPushChannel, LinqChannel, NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
 };
 use crate::config::Config;
 use crate::cost::CostTracker;
@@ -28,6 +29,7 @@ use crate::runtime;
 use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
 use crate::security::SecurityPolicy;
 use crate::tools;
+use crate::tools::canvas::CanvasStore;
 use crate::tools::traits::ToolSpec;
 use crate::util::truncate_with_ellipsis;
 use anyhow::{Context, Result};
@@ -336,6 +338,8 @@ pub struct AppState {
     /// Nextcloud Talk webhook secret for signature verification
     pub nextcloud_talk_webhook_secret: Option<Arc<str>>,
     pub wati: Option<Arc<WatiChannel>>,
+    /// Gmail Pub/Sub push notification channel
+    pub gmail_push: Option<Arc<GmailPushChannel>>,
     /// Observability backend for metrics scraping
     pub observer: Arc<dyn crate::observability::Observer>,
     /// Registered tool specs (for web dashboard tools page)
@@ -356,6 +360,8 @@ pub struct AppState {
     pub device_registry: Option<Arc<api_pairing::DeviceRegistry>>,
     /// Pending pairing request store
     pub pending_pairings: Option<Arc<api_pairing::PairingStore>>,
+    /// Shared canvas store for Live Canvas (A2UI) system
+    pub canvas_store: CanvasStore,
 }

 /// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
@@ -432,21 +438,25 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         (None, None)
     };

-    let (mut tools_registry_raw, delegate_handle_gw) = tools::all_tools_with_runtime(
-        Arc::new(config.clone()),
-        &security,
-        runtime,
-        Arc::clone(&mem),
-        composio_key,
-        composio_entity_id,
-        &config.browser,
-        &config.http_request,
-        &config.web_fetch,
-        &config.workspace_dir,
-        &config.agents,
-        config.api_key.as_deref(),
-        &config,
-    );
+    let canvas_store = tools::CanvasStore::new();
+
+    let (mut tools_registry_raw, delegate_handle_gw, _reaction_handle_gw) =
+        tools::all_tools_with_runtime(
+            Arc::new(config.clone()),
+            &security,
+            runtime,
+            Arc::clone(&mem),
+            composio_key,
+            composio_entity_id,
+            &config.browser,
+            &config.http_request,
+            &config.web_fetch,
+            &config.workspace_dir,
+            &config.agents,
+            config.api_key.as_deref(),
+            &config,
+            Some(canvas_store.clone()),
+        );

     // ── Wire MCP tools into the gateway tool registry (non-fatal) ───
     // Without this, the `/api/tools` endpoint misses MCP tools.
@@ -629,6 +639,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         })
         .map(Arc::from);

+    // Gmail Push channel (if configured and enabled)
+    let gmail_push_channel: Option<Arc<GmailPushChannel>> = config
+        .channels_config
+        .gmail_push
+        .as_ref()
+        .filter(|gp| gp.enabled)
+        .map(|gp| Arc::new(GmailPushChannel::new(gp.clone())));
+
     // ── Session persistence for WS chat ─────────────────────
     let session_backend: Option<Arc<dyn SessionBackend>> = if config.gateway.session_persistence {
         match SqliteSessionBackend::new(&config.workspace_dir) {
@@ -800,6 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         nextcloud_talk: nextcloud_talk_channel,
         nextcloud_talk_webhook_secret,
         wati: wati_channel,
+        gmail_push: gmail_push_channel,
         observer: broadcast_observer,
         tools_registry,
         cost_tracker,
@@ -810,6 +829,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         device_registry,
         pending_pairings,
         path_prefix: path_prefix.unwrap_or("").to_string(),
+        canvas_store,
     };

     // Config PUT needs larger body limit (1MB)
@@ -834,6 +854,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         .route("/wati", get(handle_wati_verify))
         .route("/wati", post(handle_wati_webhook))
         .route("/nextcloud-talk", post(handle_nextcloud_talk_webhook))
+        .route("/webhook/gmail", post(handle_gmail_push_webhook))
         // ── Web Dashboard API routes ──
         .route("/api/status", get(api::handle_api_status))
         .route("/api/config", get(api::handle_api_config_get))
@@ -874,6 +895,18 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         .route(
             "/api/devices/{id}/token/rotate",
             post(api_pairing::rotate_token),
         )
+        // ── Live Canvas (A2UI) routes ──
+        .route("/api/canvas", get(canvas::handle_canvas_list))
+        .route(
+            "/api/canvas/{id}",
+            get(canvas::handle_canvas_get)
+                .post(canvas::handle_canvas_post)
+                .delete(canvas::handle_canvas_clear),
+        )
+        .route(
+            "/api/canvas/{id}/history",
+            get(canvas::handle_canvas_history),
+        );

 // ── Plugin management API (requires plugins-wasm feature) ──
@@ -888,6 +921,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
         .route("/api/events", get(sse::handle_sse_events))
         // ── WebSocket agent chat ──
         .route("/ws/chat", get(ws::handle_ws_chat))
+        // ── WebSocket canvas updates ──
+        .route("/ws/canvas/{id}", get(canvas::handle_ws_canvas))
         // ── WebSocket node discovery ──
         .route("/ws/nodes", get(nodes::handle_ws_nodes))
         // ── Static assets (web dashboard) ──
@@ -1812,6 +1847,74 @@ async fn handle_nextcloud_talk_webhook(
     (StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
 }

+/// Maximum request body size for the Gmail webhook endpoint (1 MB).
+/// Google Pub/Sub messages are typically under 10 KB.
+const GMAIL_WEBHOOK_MAX_BODY: usize = 1024 * 1024;
+
+/// POST /webhook/gmail — incoming Gmail Pub/Sub push notification
+async fn handle_gmail_push_webhook(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+    body: Bytes,
+) -> impl IntoResponse {
+    let Some(ref gmail_push) = state.gmail_push else {
+        return (
+            StatusCode::NOT_FOUND,
+            Json(serde_json::json!({"error": "Gmail push not configured"})),
+        );
+    };
+
+    // Enforce body size limit.
+    if body.len() > GMAIL_WEBHOOK_MAX_BODY {
+        return (
+            StatusCode::PAYLOAD_TOO_LARGE,
+            Json(serde_json::json!({"error": "Request body too large"})),
+        );
+    }
+
+    // Authenticate the webhook request using a shared secret.
+    let secret = gmail_push.resolve_webhook_secret();
+    if !secret.is_empty() {
+        let provided = headers
+            .get(axum::http::header::AUTHORIZATION)
+            .and_then(|v| v.to_str().ok())
+            .and_then(|auth| auth.strip_prefix("Bearer "))
+            .unwrap_or("");
+
+        if provided != secret {
+            tracing::warn!("Gmail push webhook: unauthorized request");
+            return (
+                StatusCode::UNAUTHORIZED,
+                Json(serde_json::json!({"error": "Unauthorized"})),
+            );
+        }
+    }
+
+    let body_str = String::from_utf8_lossy(&body);
+    let envelope: crate::channels::gmail_push::PubSubEnvelope =
+        match serde_json::from_str(&body_str) {
+            Ok(e) => e,
+            Err(e) => {
+                tracing::warn!("Gmail push webhook: invalid payload: {e}");
+                return (
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({"error": "Invalid Pub/Sub envelope"})),
+                );
+            }
+        };
+
+    // Process the notification asynchronously (non-blocking for the webhook response)
+    let channel = Arc::clone(gmail_push);
+    tokio::spawn(async move {
+        if let Err(e) = channel.handle_notification(&envelope).await {
+            tracing::error!("Gmail push notification processing failed: {e:#}");
+        }
+    });
+
+    // Acknowledge immediately — Google Pub/Sub requires a 2xx within ~10s
+    (StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
+}

// ══════════════════════════════════════════════════════════════════════════════
// ADMIN HANDLERS (for CLI management)
// ══════════════════════════════════════════════════════════════════════════════

@@ -2000,6 +2103,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2010,6 +2114,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let response = handle_metrics(State(state)).await.into_response();
@@ -2056,6 +2161,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer,
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2066,6 +2172,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let response = handle_metrics(State(state)).await.into_response();
@@ -2441,6 +2548,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2451,6 +2559,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let mut headers = HeaderMap::new();
@@ -2511,6 +2620,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2521,6 +2631,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let headers = HeaderMap::new();
@@ -2593,6 +2704,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2603,6 +2715,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let response = handle_webhook(
@@ -2647,6 +2760,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2657,6 +2771,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let mut headers = HeaderMap::new();
@@ -2706,6 +2821,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2716,6 +2832,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let mut headers = HeaderMap::new();
@@ -2770,6 +2887,7 @@ mod tests {
         nextcloud_talk: None,
         nextcloud_talk_webhook_secret: None,
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2780,6 +2898,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let response = Box::pin(handle_nextcloud_talk_webhook(
@@ -2830,6 +2949,7 @@ mod tests {
         nextcloud_talk: Some(channel),
         nextcloud_talk_webhook_secret: Some(Arc::from(secret)),
         wati: None,
+        gmail_push: None,
         observer: Arc::new(crate::observability::NoopObserver),
         tools_registry: Arc::new(Vec::new()),
         cost_tracker: None,
@@ -2840,6 +2960,7 @@ mod tests {
         session_backend: None,
         device_registry: None,
         pending_pairings: None,
+        canvas_store: CanvasStore::new(),
     };

     let mut headers = HeaderMap::new();
@@ -407,7 +407,11 @@ mod tests {
     // Simpler: write a temp script.
     let dir = tempfile::tempdir().unwrap();
     let script_path = dir.path().join("tool.sh");
-    std::fs::write(&script_path, format!("#!/bin/sh\necho '{}'\n", result_json)).unwrap();
+    std::fs::write(
+        &script_path,
+        format!("#!/bin/sh\ncat > /dev/null\necho '{}'\n", result_json),
+    )
+    .unwrap();
     #[cfg(unix)]
     {
         use std::os::unix::fs::PermissionsExt;
@@ -0,0 +1,293 @@
//! Audit trail for memory operations.
//!
//! Provides a decorator `AuditedMemory<M>` that wraps any `Memory` backend
//! and logs all operations to a `memory_audit` table. Opt-in via
//! `[memory] audit_enabled = true`.

use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::path::{Path, PathBuf};
use std::sync::Arc;

/// Audit log entry operations.
#[derive(Debug, Clone, Copy)]
pub enum AuditOp {
    Store,
    Recall,
    Get,
    List,
    Forget,
    StoreProcedural,
}

impl std::fmt::Display for AuditOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Store => write!(f, "store"),
            Self::Recall => write!(f, "recall"),
            Self::Get => write!(f, "get"),
            Self::List => write!(f, "list"),
            Self::Forget => write!(f, "forget"),
            Self::StoreProcedural => write!(f, "store_procedural"),
        }
    }
}

/// Decorator that wraps a `Memory` backend with audit logging.
pub struct AuditedMemory<M: Memory> {
    inner: M,
    audit_conn: Arc<Mutex<Connection>>,
    #[allow(dead_code)]
    db_path: PathBuf,
}

impl<M: Memory> AuditedMemory<M> {
    pub fn new(inner: M, workspace_dir: &Path) -> anyhow::Result<Self> {
        let db_path = workspace_dir.join("memory").join("audit.db");
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }

        let conn = Connection::open(&db_path)?;
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             PRAGMA synchronous = NORMAL;
             CREATE TABLE IF NOT EXISTS memory_audit (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 operation TEXT NOT NULL,
                 key TEXT,
                 namespace TEXT,
                 session_id TEXT,
                 timestamp TEXT NOT NULL,
                 metadata TEXT
             );
             CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON memory_audit(timestamp);
             CREATE INDEX IF NOT EXISTS idx_audit_operation ON memory_audit(operation);",
        )?;

        Ok(Self {
            inner,
            audit_conn: Arc::new(Mutex::new(conn)),
            db_path,
        })
    }

    fn log_audit(
        &self,
        op: AuditOp,
        key: Option<&str>,
        namespace: Option<&str>,
        session_id: Option<&str>,
        metadata: Option<&str>,
    ) {
        let conn = self.audit_conn.lock();
        let now = Local::now().to_rfc3339();
        let op_str = op.to_string();
        let _ = conn.execute(
            "INSERT INTO memory_audit (operation, key, namespace, session_id, timestamp, metadata)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            params![op_str, key, namespace, session_id, now, metadata],
        );
    }

    /// Prune audit entries older than the given number of days.
    pub fn prune_older_than(&self, retention_days: u32) -> anyhow::Result<u64> {
        let conn = self.audit_conn.lock();
        let cutoff =
            (Local::now() - chrono::Duration::days(i64::from(retention_days))).to_rfc3339();
        let affected = conn.execute(
            "DELETE FROM memory_audit WHERE timestamp < ?1",
            params![cutoff],
        )?;
        Ok(u64::try_from(affected).unwrap_or(0))
    }

    /// Count total audit entries.
    pub fn audit_count(&self) -> anyhow::Result<usize> {
        let conn = self.audit_conn.lock();
        let count: i64 =
            conn.query_row("SELECT COUNT(*) FROM memory_audit", [], |row| row.get(0))?;
        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
        Ok(count as usize)
    }
}

#[async_trait]
impl<M: Memory> Memory for AuditedMemory<M> {
    fn name(&self) -> &str {
        self.inner.name()
    }

    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.log_audit(AuditOp::Store, Some(key), None, session_id, None);
        self.inner.store(key, content, category, session_id).await
    }

    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(
            AuditOp::Recall,
            None,
            None,
            session_id,
            Some(&format!("query={query}")),
        );
        self.inner
            .recall(query, limit, session_id, since, until)
            .await
    }

    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
        self.log_audit(AuditOp::Get, Some(key), None, None, None);
        self.inner.get(key).await
    }

    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(AuditOp::List, None, None, session_id, None);
        self.inner.list(category, session_id).await
    }

    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
        self.log_audit(AuditOp::Forget, Some(key), None, None, None);
        self.inner.forget(key).await
    }

    async fn count(&self) -> anyhow::Result<usize> {
        self.inner.count().await
    }

    async fn health_check(&self) -> bool {
        self.inner.health_check().await
    }

    async fn store_procedural(
        &self,
        messages: &[ProceduralMessage],
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.log_audit(
            AuditOp::StoreProcedural,
            None,
            None,
            session_id,
            Some(&format!("messages={}", messages.len())),
        );
        self.inner.store_procedural(messages, session_id).await
    }

    async fn recall_namespaced(
        &self,
        namespace: &str,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        self.log_audit(
            AuditOp::Recall,
            None,
            Some(namespace),
            session_id,
            Some(&format!("query={query}")),
        );
        self.inner
            .recall_namespaced(namespace, query, limit, session_id, since, until)
            .await
    }

    async fn store_with_metadata(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
        namespace: Option<&str>,
        importance: Option<f64>,
    ) -> anyhow::Result<()> {
        self.log_audit(AuditOp::Store, Some(key), namespace, session_id, None);
        self.inner
            .store_with_metadata(key, content, category, session_id, namespace, importance)
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::NoneMemory;
    use tempfile::TempDir;

    #[tokio::test]
    async fn audited_memory_logs_store_operation() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        audited
            .store("test_key", "test_value", MemoryCategory::Core, None)
            .await
            .unwrap();

        assert_eq!(audited.audit_count().unwrap(), 1);
    }

    #[tokio::test]
    async fn audited_memory_logs_recall_operation() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        let _ = audited.recall("query", 10, None, None, None).await;

        assert_eq!(audited.audit_count().unwrap(), 1);
    }

    #[tokio::test]
    async fn audited_memory_prune_works() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        audited
            .store("k1", "v1", MemoryCategory::Core, None)
            .await
            .unwrap();

        // Pruning with 0 days should remove entries
        let pruned = audited.prune_older_than(0).unwrap();
        // Entry was just created, so 0-day retention should remove it
        // Pruning should succeed (pruned is usize, always >= 0)
        let _ = pruned;
    }

    #[tokio::test]
    async fn audited_memory_delegates_correctly() {
        let tmp = TempDir::new().unwrap();
        let inner = NoneMemory::new();
        let audited = AuditedMemory::new(inner, tmp.path()).unwrap();

        assert_eq!(audited.name(), "none");
        assert!(audited.health_check().await);
        assert_eq!(audited.count().await.unwrap(), 0);
    }
}
@@ -4,7 +4,6 @@ pub enum MemoryBackendKind {
     Lucid,
     Postgres,
     Qdrant,
-    Mem0,
     Markdown,
     None,
     Unknown,
@@ -66,15 +65,6 @@ const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
     optional_dependency: false,
 };

-const MEM0_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
-    key: "mem0",
-    label: "Mem0 (OpenMemory) — semantic memory with LLM fact extraction via [memory.mem0]",
-    auto_save_default: true,
-    uses_sqlite_hygiene: false,
-    sqlite_based: false,
-    optional_dependency: true,
-};
-
 const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
     key: "none",
     label: "None — disable persistent memory",
@@ -114,7 +104,6 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
         "lucid" => MemoryBackendKind::Lucid,
         "postgres" => MemoryBackendKind::Postgres,
         "qdrant" => MemoryBackendKind::Qdrant,
-        "mem0" | "openmemory" => MemoryBackendKind::Mem0,
         "markdown" => MemoryBackendKind::Markdown,
         "none" => MemoryBackendKind::None,
         _ => MemoryBackendKind::Unknown,
@@ -127,7 +116,6 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
         MemoryBackendKind::Lucid => LUCID_PROFILE,
         MemoryBackendKind::Postgres => POSTGRES_PROFILE,
         MemoryBackendKind::Qdrant => QDRANT_PROFILE,
-        MemoryBackendKind::Mem0 => MEM0_PROFILE,
         MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
         MemoryBackendKind::None => NONE_PROFILE,
         MemoryBackendKind::Unknown => CUSTOM_PROFILE,
File diff suppressed because it is too large
@@ -0,0 +1,173 @@
//! Conflict resolution for memory entries.
//!
//! Before storing Core memories, performs a semantic similarity check against
//! existing entries. If cosine similarity exceeds a threshold but content
//! differs, the old entry is marked as superseded.

use super::traits::{Memory, MemoryCategory, MemoryEntry};

/// Check for conflicting memories and mark old ones as superseded.
///
/// Returns the list of entry IDs that were superseded.
pub async fn check_and_resolve_conflicts(
    memory: &dyn Memory,
    key: &str,
    content: &str,
    category: &MemoryCategory,
    threshold: f64,
) -> anyhow::Result<Vec<String>> {
    // Only check conflicts for Core memories
    if !matches!(category, MemoryCategory::Core) {
        return Ok(Vec::new());
    }

    // Search for similar existing entries
    let candidates = memory.recall(content, 10, None, None, None).await?;

    let mut superseded = Vec::new();
    for candidate in &candidates {
        if candidate.key == key {
            continue; // Same key = update, not conflict
        }
        if !matches!(candidate.category, MemoryCategory::Core) {
            continue;
        }
        if let Some(score) = candidate.score {
            if score > threshold && candidate.content != content {
                superseded.push(candidate.id.clone());
            }
        }
    }

    Ok(superseded)
}

/// Mark entries as superseded in SQLite by setting their `superseded_by` column.
pub fn mark_superseded(
    conn: &rusqlite::Connection,
    superseded_ids: &[String],
    new_id: &str,
) -> anyhow::Result<()> {
    if superseded_ids.is_empty() {
        return Ok(());
    }

    for id in superseded_ids {
        conn.execute(
            "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
            rusqlite::params![new_id, id],
        )?;
    }

    Ok(())
}

/// Simple text-based conflict detection without embeddings.
///
/// Uses token overlap (Jaccard similarity) as a fast approximation
/// when vector embeddings are unavailable.
pub fn jaccard_similarity(a: &str, b: &str) -> f64 {
    let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
    let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();

    if words_a.is_empty() && words_b.is_empty() {
        return 1.0;
    }
    if words_a.is_empty() || words_b.is_empty() {
        return 0.0;
    }

    let intersection = words_a.intersection(&words_b).count();
    let union = words_a.union(&words_b).count();

    if union == 0 {
        0.0
    } else {
        intersection as f64 / union as f64
    }
}

/// Find potentially conflicting entries using text similarity when embeddings
/// are not available. Returns entries above the threshold.
pub fn find_text_conflicts(
    entries: &[MemoryEntry],
    new_content: &str,
    threshold: f64,
) -> Vec<String> {
    entries
        .iter()
        .filter(|e| {
            matches!(e.category, MemoryCategory::Core)
                && e.superseded_by.is_none()
                && jaccard_similarity(&e.content, new_content) > threshold
                && e.content != new_content
        })
        .map(|e| e.id.clone())
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn jaccard_identical_strings() {
        let sim = jaccard_similarity("hello world", "hello world");
        assert!((sim - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn jaccard_disjoint_strings() {
        let sim = jaccard_similarity("hello world", "foo bar");
        assert!(sim.abs() < f64::EPSILON);
    }

    #[test]
    fn jaccard_partial_overlap() {
        let sim = jaccard_similarity("the quick brown fox", "the slow brown dog");
        // overlap: "the", "brown" = 2; union: "the", "quick", "brown", "fox", "slow", "dog" = 6
        assert!((sim - 2.0 / 6.0).abs() < 0.01);
    }

    #[test]
    fn jaccard_empty_strings() {
        assert!((jaccard_similarity("", "") - 1.0).abs() < f64::EPSILON);
        assert!(jaccard_similarity("hello", "").abs() < f64::EPSILON);
        assert!(jaccard_similarity("", "hello").abs() < f64::EPSILON);
    }

    #[test]
    fn find_text_conflicts_filters_correctly() {
        let entries = vec![
            MemoryEntry {
                id: "1".into(),
                key: "pref".into(),
                content: "User prefers Rust for systems work".into(),
                category: MemoryCategory::Core,
                timestamp: "now".into(),
                session_id: None,
                score: None,
                namespace: "default".into(),
                importance: Some(0.7),
                superseded_by: None,
            },
            MemoryEntry {
                id: "2".into(),
                key: "daily1".into(),
                content: "User prefers Rust for systems work".into(),
                category: MemoryCategory::Daily,
                timestamp: "now".into(),
                session_id: None,
                score: None,
                namespace: "default".into(),
                importance: Some(0.3),
                superseded_by: None,
            },
        ];

        // Only Core entries should be flagged
        let conflicts = find_text_conflicts(&entries, "User now prefers Go for systems work", 0.3);
        assert_eq!(conflicts.len(), 1);
        assert_eq!(conflicts[0], "1");
    }
}
@@ -8,6 +8,8 @@
 //! This two-phase approach replaces the naive raw-message auto-save with
 //! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.

+use crate::memory::conflict;
+use crate::memory::importance;
 use crate::memory::traits::{Memory, MemoryCategory};
 use crate::providers::traits::Provider;

@@ -78,8 +80,33 @@ pub async fn consolidate_turn(
     if let Some(ref update) = result.memory_update {
         if !update.trim().is_empty() {
             let mem_key = format!("core_{}", uuid::Uuid::new_v4());

+            // Compute importance score heuristically.
+            let imp = importance::compute_importance(update, &MemoryCategory::Core);
+
+            // Check for conflicts with existing Core memories.
+            if let Err(e) = conflict::check_and_resolve_conflicts(
+                memory,
+                &mem_key,
+                update,
+                &MemoryCategory::Core,
+                0.85,
+            )
+            .await
+            {
+                tracing::debug!("conflict check skipped: {e}");
+            }
+
+            // Store with importance metadata.
             memory
-                .store(&mem_key, update, MemoryCategory::Core, None)
+                .store_with_metadata(
+                    &mem_key,
+                    update,
+                    MemoryCategory::Core,
+                    None,
+                    None,
+                    Some(imp),
+                )
                 .await?;
         }
     }
|
||||
|
||||
+42 −4
@@ -1,4 +1,5 @@
use crate::config::MemoryConfig;
use crate::memory::policy::PolicyEnforcer;
use anyhow::Result;
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
use rusqlite::{params, Connection};

@@ -47,6 +48,13 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
        return Ok(());
    }

    // Use policy engine for per-category retention overrides.
    let enforcer = PolicyEnforcer::new(&config.policy);
    let conversation_retention = enforcer.retention_days_for_category(
        &crate::memory::traits::MemoryCategory::Conversation,
        config.conversation_retention_days,
    );

    let report = HygieneReport {
        archived_memory_files: archive_daily_memory_files(
            workspace_dir,
@@ -55,12 +63,16 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
        archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
        purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
        purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
        pruned_conversation_rows: prune_conversation_rows(
            workspace_dir,
            config.conversation_retention_days,
        )?,
        pruned_conversation_rows: prune_conversation_rows(workspace_dir, conversation_retention)?,
    };

    // Prune audit entries if audit is enabled.
    if config.audit_enabled {
        if let Err(e) = prune_audit_entries(workspace_dir, config.audit_retention_days) {
            tracing::debug!("audit pruning skipped: {e}");
        }
    }

    write_state(workspace_dir, &report)?;

    if report.total_actions() > 0 {
@@ -318,6 +330,32 @@ fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<
    Ok(u64::try_from(affected).unwrap_or(0))
}

fn prune_audit_entries(workspace_dir: &Path, retention_days: u32) -> Result<()> {
    if retention_days == 0 {
        return Ok(());
    }

    let db_path = workspace_dir.join("memory").join("audit.db");
    if !db_path.exists() {
        return Ok(());
    }

    let conn = Connection::open(db_path)?;
    conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
    let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();

    let affected = conn.execute(
        "DELETE FROM memory_audit WHERE timestamp < ?1",
        params![cutoff],
    )?;

    if affected > 0 {
        tracing::debug!("pruned {affected} audit entries older than {retention_days} days");
    }

    Ok(())
}

fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
    let stem = filename.strip_suffix(".md")?;
    let date_part = stem.split('_').next().unwrap_or(stem);

@@ -0,0 +1,107 @@
//! Heuristic importance scorer for non-LLM paths.
//!
//! Assigns importance scores (0.0–1.0) based on memory category and keyword
//! signals. Used when LLM-based consolidation is unavailable or as a fast
//! first-pass scorer.

use super::traits::MemoryCategory;

/// Base importance by category.
fn category_base_score(category: &MemoryCategory) -> f64 {
    match category {
        MemoryCategory::Core => 0.7,
        MemoryCategory::Daily => 0.3,
        MemoryCategory::Conversation => 0.2,
        MemoryCategory::Custom(_) => 0.4,
    }
}

/// Keyword boost: if the content contains high-signal keywords, bump importance.
fn keyword_boost(content: &str) -> f64 {
    const HIGH_SIGNAL_KEYWORDS: &[&str] = &[
        "decision",
        "always",
        "never",
        "important",
        "critical",
        "must",
        "requirement",
        "policy",
        "rule",
        "principle",
    ];

    let lowered = content.to_ascii_lowercase();
    let matches = HIGH_SIGNAL_KEYWORDS
        .iter()
        .filter(|kw| lowered.contains(**kw))
        .count();

    // Cap at +0.2
    (matches as f64 * 0.1).min(0.2)
}

/// Compute heuristic importance score for a memory entry.
pub fn compute_importance(content: &str, category: &MemoryCategory) -> f64 {
    let base = category_base_score(category);
    let boost = keyword_boost(content);
    (base + boost).min(1.0)
}

/// Compute final retrieval score incorporating importance and recency.
///
/// `hybrid_score`: raw retrieval score from FTS/vector (0.0–1.0)
/// `importance`: importance score (0.0–1.0)
/// `recency_decay`: recency factor (0.0–1.0, 1.0 = very recent)
pub fn weighted_final_score(hybrid_score: f64, importance: f64, recency_decay: f64) -> f64 {
    hybrid_score * 0.7 + importance * 0.2 + recency_decay * 0.1
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn core_category_has_high_base_score() {
        let score = compute_importance("some fact", &MemoryCategory::Core);
        assert!((score - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn conversation_category_has_low_base_score() {
        let score = compute_importance("chat message", &MemoryCategory::Conversation);
        assert!((score - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn keywords_boost_importance() {
        let score = compute_importance(
            "This is a critical decision that must always be followed",
            &MemoryCategory::Core,
        );
        // base 0.7 + boost for "critical", "decision", "must", "always" = 0.7 + 0.2 (capped) = 0.9
        assert!(score > 0.85);
    }

    #[test]
    fn boost_capped_at_point_two() {
        let score = compute_importance(
            "important critical decision rule policy must always never requirement principle",
            &MemoryCategory::Conversation,
        );
        // base 0.2 + max boost 0.2 = 0.4
        assert!((score - 0.4).abs() < f64::EPSILON);
    }

    #[test]
    fn weighted_final_score_formula() {
        let score = weighted_final_score(1.0, 1.0, 1.0);
        assert!((score - 1.0).abs() < f64::EPSILON);

        let score = weighted_final_score(0.0, 0.0, 0.0);
        assert!(score.abs() < f64::EPSILON);

        let score = weighted_final_score(0.5, 0.5, 0.5);
        assert!((score - 0.5).abs() < f64::EPSILON);
    }
}
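The 0.7/0.2/0.1 blend in `weighted_final_score` means retrieval relevance dominates, with importance and recency acting as tie-breakers. A worked example of the formula from the diff above (the `main` wrapper is just for illustration):

```rust
/// Final retrieval score: 70% hybrid retrieval relevance,
/// 20% stored importance, 10% recency decay (same formula as the diff).
fn weighted_final_score(hybrid_score: f64, importance: f64, recency_decay: f64) -> f64 {
    hybrid_score * 0.7 + importance * 0.2 + recency_decay * 0.1
}

fn main() {
    // Strong match (0.9), moderately important (0.5), stale memory (0.1):
    // 0.9*0.7 + 0.5*0.2 + 0.1*0.1 = 0.63 + 0.10 + 0.01 = 0.74
    let s = weighted_final_score(0.9, 0.5, 0.1);
    assert!((s - 0.74).abs() < 1e-9);

    // A weak match on a very important, fresh memory still loses to a
    // strong match: relevance carries the most weight.
    let weak = weighted_final_score(0.3, 1.0, 1.0);
    assert!(weak < s);
}
```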
@@ -226,6 +226,9 @@ impl LucidMemory {
                timestamp: now.clone(),
                session_id: None,
                score: Some((1.0 - rank as f64 * 0.05).max(0.1)),
                namespace: "default".into(),
                importance: None,
                superseded_by: None,
            });
        }

@@ -91,6 +91,9 @@ impl MarkdownMemory {
                timestamp: filename.to_string(),
                session_id: None,
                score: None,
                namespace: "default".into(),
                importance: None,
                superseded_by: None,
            }
        })
        .collect()

@@ -1,635 +0,0 @@
//! Mem0 (OpenMemory) memory backend.
//!
//! Connects to a self-hosted OpenMemory server via its REST API
//! and implements the [`Memory`] trait for seamless integration with
//! ZeroClaw's auto-save, auto-recall, and hygiene lifecycle.
//!
//! Deploy OpenMemory: `docker compose up` from the mem0 repo.
//! Default endpoint: `http://localhost:8765`.

use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
use crate::config::schema::Mem0Config;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Memory backend backed by a mem0 (OpenMemory) REST API.
pub struct Mem0Memory {
    client: Client,
    base_url: String,
    user_id: String,
    app_name: String,
    infer: bool,
    extraction_prompt: Option<String>,
}

// ── mem0 API request/response types ────────────────────────────────

#[derive(Serialize)]
struct AddMemoryRequest<'a> {
    user_id: &'a str,
    text: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<Mem0Metadata<'a>>,
    infer: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    app: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    custom_instructions: Option<&'a str>,
}

#[derive(Serialize)]
struct Mem0Metadata<'a> {
    key: &'a str,
    category: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    session_id: Option<&'a str>,
}

#[derive(Serialize)]
struct AddProceduralRequest<'a> {
    user_id: &'a str,
    messages: &'a [ProceduralMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<serde_json::Value>,
}

#[derive(Serialize)]
struct DeleteMemoriesRequest<'a> {
    memory_ids: Vec<&'a str>,
    user_id: &'a str,
}

#[derive(Deserialize)]
struct Mem0MemoryItem {
    id: String,
    #[serde(alias = "content", alias = "text", default)]
    memory: String,
    #[serde(default)]
    created_at: Option<serde_json::Value>,
    #[serde(default, rename = "metadata_")]
    metadata: Option<Mem0ResponseMetadata>,
    #[serde(alias = "relevance_score", default)]
    score: Option<f64>,
}

#[derive(Deserialize, Default)]
struct Mem0ResponseMetadata {
    #[serde(default)]
    key: Option<String>,
    #[serde(default)]
    category: Option<String>,
    #[serde(default)]
    session_id: Option<String>,
}

#[derive(Deserialize)]
struct Mem0ListResponse {
    #[serde(default)]
    items: Vec<Mem0MemoryItem>,
    #[serde(default)]
    total: usize,
}

// ── Implementation ─────────────────────────────────────────────────

impl Mem0Memory {
    /// Create a new mem0 memory backend from config.
    pub fn new(config: &Mem0Config) -> anyhow::Result<Self> {
        let base_url = config.url.trim_end_matches('/').to_string();
        if base_url.is_empty() {
            anyhow::bail!("mem0 URL is empty; set [memory.mem0] url or MEM0_URL env var");
        }

        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()?;

        Ok(Self {
            client,
            base_url,
            user_id: config.user_id.clone(),
            app_name: config.app_name.clone(),
            infer: config.infer,
            extraction_prompt: config.extraction_prompt.clone(),
        })
    }

    fn api_url(&self, path: &str) -> String {
        format!("{}/api/v1{}", self.base_url, path)
    }

    /// Use `session_id` as the effective mem0 `user_id` when provided,
    /// falling back to the configured default. This enables per-user
    /// and per-group memory scoping via the existing `Memory` trait.
    fn effective_user_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
        session_id
            .filter(|s| !s.trim().is_empty())
            .unwrap_or(&self.user_id)
    }

    /// Recall memories with optional search filters.
    ///
    /// - `created_after` / `created_before`: ISO 8601 timestamps for time-range filtering.
    /// - `metadata_filter`: arbitrary JSON object passed to the mem0 SDK `filters` param.
    pub async fn recall_filtered(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        created_after: Option<&str>,
        created_before: Option<&str>,
        metadata_filter: Option<&serde_json::Value>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        let effective_user = self.effective_user_id(session_id);
        let limit_str = limit.to_string();
        let mut params: Vec<(&str, &str)> = vec![
            ("user_id", effective_user),
            ("search_query", query),
            ("size", &limit_str),
        ];
        if let Some(after) = created_after {
            params.push(("created_after", after));
        }
        if let Some(before) = created_before {
            params.push(("created_before", before));
        }
        let meta_json;
        if let Some(mf) = metadata_filter {
            meta_json = serde_json::to_string(mf)?;
            params.push(("metadata_filter", &meta_json));
        }

        let resp = self
            .client
            .get(self.api_url("/memories/"))
            .query(&params)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 recall failed ({status}): {text}");
        }

        let list: Mem0ListResponse = resp.json().await?;
        Ok(list.items.into_iter().map(|i| self.to_entry(i)).collect())
    }

    fn to_entry(&self, item: Mem0MemoryItem) -> MemoryEntry {
        let meta = item.metadata.unwrap_or_default();
        let timestamp = match item.created_at {
            Some(serde_json::Value::Number(n)) => {
                // Unix timestamp → ISO 8601
                if let Some(ts) = n.as_i64() {
                    chrono::DateTime::from_timestamp(ts, 0)
                        .map(|dt| dt.to_rfc3339())
                        .unwrap_or_default()
                } else {
                    String::new()
                }
            }
            Some(serde_json::Value::String(s)) => s,
            _ => String::new(),
        };

        let category = match meta.category.as_deref() {
            Some("daily") => MemoryCategory::Daily,
            Some("conversation") => MemoryCategory::Conversation,
            Some(other) if other != "core" => MemoryCategory::Custom(other.to_string()),
            // "core" or None → default
            _ => MemoryCategory::Core,
        };

        MemoryEntry {
            id: item.id,
            key: meta.key.unwrap_or_default(),
            content: item.memory,
            category,
            timestamp,
            session_id: meta.session_id,
            score: item.score,
        }
    }

    /// Store a conversation trace as procedural memory.
    ///
    /// Sends the message history (user input, tool calls, assistant response)
    /// to the mem0 procedural endpoint so that "how to" patterns can be
    /// extracted and stored for future recall.
    pub async fn store_procedural(
        &self,
        messages: &[ProceduralMessage],
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        if messages.is_empty() {
            return Ok(());
        }

        let effective_user = self.effective_user_id(session_id);
        let body = AddProceduralRequest {
            user_id: effective_user,
            messages,
            metadata: None,
        };

        let resp = self
            .client
            .post(self.api_url("/memories/procedural"))
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 store_procedural failed ({status}): {text}");
        }

        Ok(())
    }
}

// ── History API types ─────────────────────────────────────────────

#[derive(Deserialize)]
struct Mem0HistoryResponse {
    #[serde(default)]
    history: Vec<serde_json::Value>,
    #[serde(default)]
    error: Option<String>,
}

impl Mem0Memory {
    /// Retrieve the edit history (audit trail) for a specific memory by ID.
    pub async fn history(&self, memory_id: &str) -> anyhow::Result<String> {
        let url = self.api_url(&format!("/memories/{memory_id}/history"));
        let resp = self.client.get(&url).send().await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 history failed ({status}): {text}");
        }

        let body: Mem0HistoryResponse = resp.json().await?;

        if let Some(err) = body.error {
            anyhow::bail!("mem0 history error: {err}");
        }

        if body.history.is_empty() {
            return Ok(format!("No history found for memory {memory_id}."));
        }

        let mut lines = Vec::with_capacity(body.history.len() + 1);
        lines.push(format!("History for memory {memory_id}:"));

        for (i, entry) in body.history.iter().enumerate() {
            let event = entry
                .get("event")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");
            let old_memory = entry
                .get("old_memory")
                .and_then(|v| v.as_str())
                .unwrap_or("-");
            let new_memory = entry
                .get("new_memory")
                .and_then(|v| v.as_str())
                .unwrap_or("-");
            let timestamp = entry
                .get("created_at")
                .or_else(|| entry.get("timestamp"))
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");

            lines.push(format!(
                "  {idx}. [{event}] at {timestamp}\n     old: {old_memory}\n     new: {new_memory}",
                idx = i + 1,
            ));
        }

        Ok(lines.join("\n"))
    }
}

#[async_trait]
impl Memory for Mem0Memory {
    fn name(&self) -> &str {
        "mem0"
    }

    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        let cat_str = category.to_string();
        let effective_user = self.effective_user_id(session_id);
        let body = AddMemoryRequest {
            user_id: effective_user,
            text: content,
            metadata: Some(Mem0Metadata {
                key,
                category: &cat_str,
                session_id,
            }),
            infer: self.infer,
            app: Some(&self.app_name),
            custom_instructions: self.extraction_prompt.as_deref(),
        };

        let resp = self
            .client
            .post(self.api_url("/memories/"))
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 store failed ({status}): {text}");
        }

        Ok(())
    }

    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        _since: Option<&str>,
        _until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        // mem0 handles filtering server-side; since/until are not yet
        // supported by the mem0 API, so we pass them through as no-ops.
        self.recall_filtered(query, limit, session_id, None, None, None)
            .await
    }

    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
        // mem0 doesn't have a get-by-key API, so we search by key in metadata
        let results = self.recall(key, 1, None, None, None).await?;
        Ok(results.into_iter().find(|e| e.key == key))
    }

    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        let effective_user = self.effective_user_id(session_id);
        let resp = self
            .client
            .get(self.api_url("/memories/"))
            .query(&[("user_id", effective_user), ("size", "100")])
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 list failed ({status}): {text}");
        }

        let list: Mem0ListResponse = resp.json().await?;
        let entries: Vec<MemoryEntry> = list.items.into_iter().map(|i| self.to_entry(i)).collect();

        // Client-side category filter (mem0 API doesn't filter by metadata)
        match category {
            Some(cat) => Ok(entries.into_iter().filter(|e| &e.category == cat).collect()),
            None => Ok(entries),
        }
    }

    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
        // Find the memory ID by key first
        let entry = self.get(key).await?;
        let entry = match entry {
            Some(e) => e,
            None => return Ok(false),
        };

        let body = DeleteMemoriesRequest {
            memory_ids: vec![&entry.id],
            user_id: &self.user_id,
        };

        let resp = self
            .client
            .delete(self.api_url("/memories/"))
            .json(&body)
            .send()
            .await?;

        Ok(resp.status().is_success())
    }

    async fn count(&self) -> anyhow::Result<usize> {
        let resp = self
            .client
            .get(self.api_url("/memories/"))
            .query(&[
                ("user_id", self.user_id.as_str()),
                ("size", "1"),
                ("page", "1"),
            ])
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("mem0 count failed ({status}): {text}");
        }

        let list: Mem0ListResponse = resp.json().await?;
        Ok(list.total)
    }

    async fn health_check(&self) -> bool {
        self.client
            .get(self.api_url("/memories/"))
            .query(&[
                ("user_id", self.user_id.as_str()),
                ("size", "1"),
                ("page", "1"),
            ])
            .send()
            .await
            .is_ok_and(|r| r.status().is_success())
    }

    async fn store_procedural(
        &self,
        messages: &[ProceduralMessage],
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        Mem0Memory::store_procedural(self, messages, session_id).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> Mem0Config {
        Mem0Config {
            url: "http://localhost:8765".into(),
            user_id: "test-user".into(),
            app_name: "test-app".into(),
            infer: true,
            extraction_prompt: None,
        }
    }

    #[test]
    fn new_rejects_empty_url() {
        let config = Mem0Config {
            url: String::new(),
            ..test_config()
        };
        assert!(Mem0Memory::new(&config).is_err());
    }

    #[test]
    fn new_trims_trailing_slash() {
        let config = Mem0Config {
            url: "http://localhost:8765/".into(),
            ..test_config()
        };
        let mem = Mem0Memory::new(&config).unwrap();
        assert_eq!(mem.base_url, "http://localhost:8765");
    }

    #[test]
    fn api_url_builds_correct_path() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        assert_eq!(
            mem.api_url("/memories/"),
            "http://localhost:8765/api/v1/memories/"
        );
    }

    #[test]
    fn to_entry_maps_unix_timestamp() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        let item = Mem0MemoryItem {
            id: "id-1".into(),
            memory: "hello".into(),
            created_at: Some(serde_json::json!(1_700_000_000)),
            metadata: Some(Mem0ResponseMetadata {
                key: Some("k1".into()),
                category: Some("core".into()),
                session_id: None,
            }),
            score: Some(0.95),
        };
        let entry = mem.to_entry(item);
        assert_eq!(entry.id, "id-1");
        assert_eq!(entry.key, "k1");
        assert_eq!(entry.category, MemoryCategory::Core);
        assert!(!entry.timestamp.is_empty());
        assert_eq!(entry.score, Some(0.95));
    }

    #[test]
    fn to_entry_maps_string_timestamp() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        let item = Mem0MemoryItem {
            id: "id-2".into(),
            memory: "world".into(),
            created_at: Some(serde_json::json!("2024-01-01T00:00:00Z")),
            metadata: None,
            score: None,
        };
        let entry = mem.to_entry(item);
        assert_eq!(entry.timestamp, "2024-01-01T00:00:00Z");
        assert_eq!(entry.category, MemoryCategory::Core); // default
    }

    #[test]
    fn to_entry_handles_missing_metadata() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        let item = Mem0MemoryItem {
            id: "id-3".into(),
            memory: "bare".into(),
            created_at: None,
            metadata: None,
            score: None,
        };
        let entry = mem.to_entry(item);
        assert_eq!(entry.key, "");
        assert_eq!(entry.category, MemoryCategory::Core);
        assert!(entry.timestamp.is_empty());
        assert_eq!(entry.score, None);
    }

    #[test]
    fn to_entry_custom_category() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        let item = Mem0MemoryItem {
            id: "id-4".into(),
            memory: "custom".into(),
            created_at: None,
            metadata: Some(Mem0ResponseMetadata {
                key: Some("k".into()),
                category: Some("project_notes".into()),
                session_id: Some("s1".into()),
            }),
            score: None,
        };
        let entry = mem.to_entry(item);
        assert_eq!(
            entry.category,
            MemoryCategory::Custom("project_notes".into())
        );
        assert_eq!(entry.session_id.as_deref(), Some("s1"));
    }

    #[test]
    fn name_returns_mem0() {
        let mem = Mem0Memory::new(&test_config()).unwrap();
        assert_eq!(mem.name(), "mem0");
    }

    #[test]
    fn procedural_request_serializes_messages() {
        let messages = vec![
            ProceduralMessage {
                role: "user".into(),
                content: "How do I deploy?".into(),
                name: None,
            },
            ProceduralMessage {
                role: "tool".into(),
                content: "deployment started".into(),
                name: Some("shell".into()),
            },
            ProceduralMessage {
                role: "assistant".into(),
                content: "Deployment complete.".into(),
                name: None,
            },
        ];
        let req = AddProceduralRequest {
            user_id: "test-user",
            messages: &messages,
            metadata: None,
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["user_id"], "test-user");
        let msgs = json["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0]["role"], "user");
        assert_eq!(msgs[1]["name"], "shell");
        // metadata should be absent when None
        assert!(json.get("metadata").is_none());
    }
}
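The dual-scope behavior described in the PR summary hinges on `effective_user_id`: a non-empty `session_id` (a per-sender or per-group scope) overrides the configured default mem0 `user_id`. Extracted as a free function, the fallback logic can be exercised standalone:

```rust
/// Resolve the mem0 user_id for a request. A non-empty, non-whitespace
/// session id wins; otherwise fall back to the configured default.
/// (Same logic as Mem0Memory::effective_user_id in the diff above.)
fn effective_user_id<'a>(session_id: Option<&'a str>, default_user: &'a str) -> &'a str {
    session_id
        .filter(|s| !s.trim().is_empty())
        .unwrap_or(default_user)
}

fn main() {
    // A group-scoped session id takes precedence over the default.
    assert_eq!(effective_user_id(Some("group:42"), "bot"), "group:42");
    // Whitespace-only session ids are treated as absent.
    assert_eq!(effective_user_id(Some("   "), "bot"), "bot");
    assert_eq!(effective_user_id(None, "bot"), "bot");
}
```

The scope identifiers shown (`"group:42"`, `"bot"`) are illustrative; the actual session-id format is whatever the caller passes through the `Memory` trait.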
+15 −27
@@ -1,24 +1,32 @@
pub mod audit;
pub mod backend;
pub mod chunker;
pub mod cli;
pub mod conflict;
pub mod consolidation;
pub mod embeddings;
pub mod hygiene;
pub mod importance;
pub mod knowledge_graph;
pub mod lucid;
pub mod markdown;
#[cfg(feature = "memory-mem0")]
pub mod mem0;
pub mod none;
pub mod policy;
#[cfg(feature = "memory-postgres")]
pub mod postgres;
pub mod qdrant;
pub mod response_cache;
pub mod retrieval;
pub mod snapshot;
pub mod sqlite;
pub mod traits;
pub mod vector;

#[cfg(test)]
mod battle_tests;

#[allow(unused_imports)]
pub use audit::AuditedMemory;
#[allow(unused_imports)]
pub use backend::{
    classify_memory_backend, default_memory_backend_key, memory_backend_profile,
@@ -26,13 +34,15 @@ pub use backend::{
};
pub use lucid::LucidMemory;
pub use markdown::MarkdownMemory;
#[cfg(feature = "memory-mem0")]
pub use mem0::Mem0Memory;
pub use none::NoneMemory;
#[allow(unused_imports)]
pub use policy::PolicyEnforcer;
#[cfg(feature = "memory-postgres")]
pub use postgres::PostgresMemory;
pub use qdrant::QdrantMemory;
pub use response_cache::ResponseCache;
#[allow(unused_imports)]
pub use retrieval::{RetrievalConfig, RetrievalPipeline};
pub use sqlite::SqliteMemory;
pub use traits::Memory;
#[allow(unused_imports)]
@@ -61,7 +71,7 @@ where
            Ok(Box::new(LucidMemory::new(workspace_dir, local)))
        }
        MemoryBackendKind::Postgres => postgres_builder(),
        MemoryBackendKind::Qdrant | MemoryBackendKind::Mem0 | MemoryBackendKind::Markdown => {
        MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
            Ok(Box::new(MarkdownMemory::new(workspace_dir)))
        }
        MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())),
@@ -340,28 +350,6 @@ pub fn create_memory_with_storage_and_routes(
        );
    }

    #[cfg(feature = "memory-mem0")]
    fn build_mem0_memory(config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
        let mem = Mem0Memory::new(&config.mem0)?;
        tracing::info!(
            "📦 Mem0 memory backend configured (url: {}, user: {})",
            config.mem0.url,
            config.mem0.user_id
        );
        Ok(Box::new(mem))
    }

    #[cfg(not(feature = "memory-mem0"))]
    fn build_mem0_memory(_config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
        anyhow::bail!(
            "memory backend 'mem0' requested but this build was compiled without `memory-mem0`; rebuild with `--features memory-mem0`"
        );
    }

    if matches!(backend_kind, MemoryBackendKind::Mem0) {
        return build_mem0_memory(config);
    }

    if matches!(backend_kind, MemoryBackendKind::Qdrant) {
        let url = config
            .qdrant

@@ -0,0 +1,192 @@
+//! Policy engine for memory operations.
+//!
+//! Validates operations against configurable rules before they reach the
+//! backend. Enforces namespace quotas, category limits, read-only namespaces,
+//! and per-category retention rules.
+
+use super::traits::MemoryCategory;
+use crate::config::MemoryPolicyConfig;
+
+/// Policy enforcer that validates memory operations.
+pub struct PolicyEnforcer {
+    config: MemoryPolicyConfig,
+}
+
+impl PolicyEnforcer {
+    pub fn new(config: &MemoryPolicyConfig) -> Self {
+        Self {
+            config: config.clone(),
+        }
+    }
+
+    /// Check if a namespace is read-only.
+    pub fn is_read_only(&self, namespace: &str) -> bool {
+        self.config
+            .read_only_namespaces
+            .iter()
+            .any(|ns| ns == namespace)
+    }
+
+    /// Validate a store operation against policy rules.
+    pub fn validate_store(
+        &self,
+        namespace: &str,
+        _category: &MemoryCategory,
+    ) -> Result<(), PolicyViolation> {
+        if self.is_read_only(namespace) {
+            return Err(PolicyViolation::ReadOnlyNamespace(namespace.to_string()));
+        }
+        Ok(())
+    }
+
+    /// Check if adding an entry would exceed namespace limits.
+    pub fn check_namespace_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
+        if self.config.max_entries_per_namespace > 0
+            && current_count >= self.config.max_entries_per_namespace
+        {
+            return Err(PolicyViolation::NamespaceQuotaExceeded {
+                max: self.config.max_entries_per_namespace,
+                current: current_count,
+            });
+        }
+        Ok(())
+    }
+
+    /// Check if adding an entry would exceed category limits.
+    pub fn check_category_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
+        if self.config.max_entries_per_category > 0
+            && current_count >= self.config.max_entries_per_category
+        {
+            return Err(PolicyViolation::CategoryQuotaExceeded {
+                max: self.config.max_entries_per_category,
+                current: current_count,
+            });
+        }
+        Ok(())
+    }
+
+    /// Get the retention days for a specific category, falling back to the
+    /// provided default if no per-category override exists.
+    pub fn retention_days_for_category(&self, category: &MemoryCategory, default_days: u32) -> u32 {
+        let key = category.to_string();
+        self.config
+            .retention_days_by_category
+            .get(&key)
+            .copied()
+            .unwrap_or(default_days)
+    }
+}
+
+/// Policy violation errors.
+#[derive(Debug, Clone)]
+pub enum PolicyViolation {
+    ReadOnlyNamespace(String),
+    NamespaceQuotaExceeded { max: usize, current: usize },
+    CategoryQuotaExceeded { max: usize, current: usize },
+}
+
+impl std::fmt::Display for PolicyViolation {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::ReadOnlyNamespace(ns) => write!(f, "namespace '{ns}' is read-only"),
+            Self::NamespaceQuotaExceeded { max, current } => {
+                write!(f, "namespace quota exceeded: {current}/{max} entries")
+            }
+            Self::CategoryQuotaExceeded { max, current } => {
+                write!(f, "category quota exceeded: {current}/{max} entries")
+            }
+        }
+    }
+}
+
+impl std::error::Error for PolicyViolation {}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::collections::HashMap;
+
+    fn empty_policy() -> MemoryPolicyConfig {
+        MemoryPolicyConfig::default()
+    }
+
+    #[test]
+    fn default_policy_allows_everything() {
+        let enforcer = PolicyEnforcer::new(&empty_policy());
+        assert!(!enforcer.is_read_only("default"));
+        assert!(enforcer
+            .validate_store("default", &MemoryCategory::Core)
+            .is_ok());
+        assert!(enforcer.check_namespace_limit(100).is_ok());
+        assert!(enforcer.check_category_limit(100).is_ok());
+    }
+
+    #[test]
+    fn read_only_namespace_blocks_writes() {
+        let policy = MemoryPolicyConfig {
+            read_only_namespaces: vec!["archive".into()],
+            ..empty_policy()
+        };
+        let enforcer = PolicyEnforcer::new(&policy);
+
+        assert!(enforcer.is_read_only("archive"));
+        assert!(!enforcer.is_read_only("default"));
+        assert!(enforcer
+            .validate_store("archive", &MemoryCategory::Core)
+            .is_err());
+        assert!(enforcer
+            .validate_store("default", &MemoryCategory::Core)
+            .is_ok());
+    }
+
+    #[test]
+    fn namespace_quota_enforced() {
+        let policy = MemoryPolicyConfig {
+            max_entries_per_namespace: 10,
+            ..empty_policy()
+        };
+        let enforcer = PolicyEnforcer::new(&policy);
+
+        assert!(enforcer.check_namespace_limit(5).is_ok());
+        assert!(enforcer.check_namespace_limit(10).is_err());
+        assert!(enforcer.check_namespace_limit(15).is_err());
+    }
+
+    #[test]
+    fn category_quota_enforced() {
+        let policy = MemoryPolicyConfig {
+            max_entries_per_category: 50,
+            ..empty_policy()
+        };
+        let enforcer = PolicyEnforcer::new(&policy);
+
+        assert!(enforcer.check_category_limit(25).is_ok());
+        assert!(enforcer.check_category_limit(50).is_err());
+    }
+
+    #[test]
+    fn per_category_retention_overrides_default() {
+        let mut retention = HashMap::new();
+        retention.insert("core".into(), 365);
+        retention.insert("conversation".into(), 7);
+
+        let policy = MemoryPolicyConfig {
+            retention_days_by_category: retention,
+            ..empty_policy()
+        };
+        let enforcer = PolicyEnforcer::new(&policy);
+
+        assert_eq!(
+            enforcer.retention_days_for_category(&MemoryCategory::Core, 30),
+            365
+        );
+        assert_eq!(
+            enforcer.retention_days_for_category(&MemoryCategory::Conversation, 30),
+            7
+        );
+        assert_eq!(
+            enforcer.retention_days_for_category(&MemoryCategory::Daily, 30),
+            30
+        );
+    }
+}
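The quota checks in `PolicyEnforcer` above treat a limit of `0` as unlimited and reject once the current count reaches the configured maximum, so the boundary is inclusive. A stand-alone sketch of that rule, with a hypothetical `QuotaPolicy` stub in place of the crate's `MemoryPolicyConfig`:

```rust
// Hypothetical stand-alone mirror of PolicyEnforcer's quota rule:
// a limit of 0 disables the check; otherwise `current >= max` is rejected.
struct QuotaPolicy {
    max_entries_per_namespace: usize,
}

impl QuotaPolicy {
    fn check_namespace_limit(&self, current: usize) -> Result<(), String> {
        if self.max_entries_per_namespace > 0 && current >= self.max_entries_per_namespace {
            return Err(format!(
                "namespace quota exceeded: {current}/{} entries",
                self.max_entries_per_namespace
            ));
        }
        Ok(())
    }
}

fn main() {
    let unlimited = QuotaPolicy { max_entries_per_namespace: 0 };
    let capped = QuotaPolicy { max_entries_per_namespace: 10 };
    assert!(unlimited.check_namespace_limit(1_000_000).is_ok());
    assert!(capped.check_namespace_limit(9).is_ok());
    assert!(capped.check_namespace_limit(10).is_err()); // inclusive boundary
    println!("ok");
}
```

Note the inclusive boundary is what `namespace_quota_enforced` in the tests above pins down: a count equal to the limit is already rejected.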
+12 -3
@@ -100,6 +100,8 @@ impl PostgresMemory {
             CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
             CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
             CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
+            CREATE INDEX IF NOT EXISTS idx_memories_content_fts ON {qualified_table} USING gin(to_tsvector('simple', content));
+            CREATE INDEX IF NOT EXISTS idx_memories_key_fts ON {qualified_table} USING gin(to_tsvector('simple', key));
             "
         ))?;
 
@@ -135,6 +137,9 @@ impl PostgresMemory {
             timestamp: timestamp.to_rfc3339(),
             session_id: row.get(5),
             score: row.try_get(6).ok(),
+            namespace: "default".into(),
+            importance: None,
+            superseded_by: None,
         })
     }
 }
@@ -267,12 +272,16 @@ impl Memory for PostgresMemory {
         "
         SELECT id, key, content, category, created_at, session_id,
             (
-                CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END +
-                CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
+                CASE WHEN to_tsvector('simple', key) @@ plainto_tsquery('simple', $1)
+                    THEN ts_rank_cd(to_tsvector('simple', key), plainto_tsquery('simple', $1)) * 2.0
+                    ELSE 0.0 END +
+                CASE WHEN to_tsvector('simple', content) @@ plainto_tsquery('simple', $1)
+                    THEN ts_rank_cd(to_tsvector('simple', content), plainto_tsquery('simple', $1))
+                    ELSE 0.0 END
             ) AS score
         FROM {qualified_table}
         WHERE ($2::TEXT IS NULL OR session_id = $2)
-            AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
+            AND ($1 = '' OR to_tsvector('simple', key || ' ' || content) @@ plainto_tsquery('simple', $1))
         {time_filter}
         ORDER BY score DESC, updated_at DESC
         LIMIT $3
@@ -373,6 +373,9 @@ impl Memory for QdrantMemory {
                     timestamp: payload.timestamp,
                     session_id: payload.session_id,
                     score: Some(point.score),
+                    namespace: "default".into(),
+                    importance: None,
+                    superseded_by: None,
                 })
             })
             .collect();
@@ -437,6 +440,9 @@ impl Memory for QdrantMemory {
                     timestamp: payload.timestamp,
                     session_id: payload.session_id,
                     score: None,
+                    namespace: "default".into(),
+                    importance: None,
+                    superseded_by: None,
                 })
             });
 
@@ -514,6 +520,9 @@ impl Memory for QdrantMemory {
                     timestamp: payload.timestamp,
                     session_id: payload.session_id,
                     score: None,
+                    namespace: "default".into(),
+                    importance: None,
+                    superseded_by: None,
                 })
             })
             .collect();
@@ -0,0 +1,267 @@
+//! Multi-stage retrieval pipeline.
+//!
+//! Wraps a `Memory` trait object with staged retrieval:
+//! - **Stage 1 (Hot cache):** In-memory LRU of recent recall results.
+//! - **Stage 2 (FTS):** FTS5 keyword search with optional early-return.
+//! - **Stage 3 (Vector):** Vector similarity search + hybrid merge.
+//!
+//! Configurable via `[memory]` settings: `retrieval_stages`, `fts_early_return_score`.
+
+use super::traits::{Memory, MemoryEntry};
+use parking_lot::Mutex;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+/// A cached recall result.
+struct CachedResult {
+    entries: Vec<MemoryEntry>,
+    created_at: Instant,
+}
+
+/// Multi-stage retrieval pipeline configuration.
+#[derive(Debug, Clone)]
+pub struct RetrievalConfig {
+    /// Ordered list of stages: "cache", "fts", "vector".
+    pub stages: Vec<String>,
+    /// FTS score above which to early-return without vector stage.
+    pub fts_early_return_score: f64,
+    /// Max entries in the hot cache.
+    pub cache_max_entries: usize,
+    /// TTL for cached results.
+    pub cache_ttl: Duration,
+}
+
+impl Default for RetrievalConfig {
+    fn default() -> Self {
+        Self {
+            stages: vec!["cache".into(), "fts".into(), "vector".into()],
+            fts_early_return_score: 0.85,
+            cache_max_entries: 256,
+            cache_ttl: Duration::from_secs(300),
+        }
+    }
+}
+
+/// Multi-stage retrieval pipeline wrapping a `Memory` backend.
+pub struct RetrievalPipeline {
+    memory: Arc<dyn Memory>,
+    config: RetrievalConfig,
+    hot_cache: Mutex<HashMap<String, CachedResult>>,
+}
+
+impl RetrievalPipeline {
+    pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
+        Self {
+            memory,
+            config,
+            hot_cache: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Build a cache key from query parameters.
+    fn cache_key(
+        query: &str,
+        limit: usize,
+        session_id: Option<&str>,
+        namespace: Option<&str>,
+    ) -> String {
+        format!(
+            "{}:{}:{}:{}",
+            query,
+            limit,
+            session_id.unwrap_or(""),
+            namespace.unwrap_or("")
+        )
+    }
+
+    /// Check the hot cache for a previous result.
+    fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
+        let cache = self.hot_cache.lock();
+        if let Some(cached) = cache.get(key) {
+            if cached.created_at.elapsed() < self.config.cache_ttl {
+                return Some(cached.entries.clone());
+            }
+        }
+        None
+    }
+
+    /// Store a result in the hot cache with LRU eviction.
+    fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
+        let mut cache = self.hot_cache.lock();
+
+        // LRU eviction: remove oldest entries if at capacity
+        if cache.len() >= self.config.cache_max_entries {
+            let oldest_key = cache
+                .iter()
+                .min_by_key(|(_, v)| v.created_at)
+                .map(|(k, _)| k.clone());
+            if let Some(k) = oldest_key {
+                cache.remove(&k);
+            }
+        }
+
+        cache.insert(
+            key,
+            CachedResult {
+                entries,
+                created_at: Instant::now(),
+            },
+        );
+    }
+
+    /// Execute the multi-stage retrieval pipeline.
+    pub async fn recall(
+        &self,
+        query: &str,
+        limit: usize,
+        session_id: Option<&str>,
+        namespace: Option<&str>,
+        since: Option<&str>,
+        until: Option<&str>,
+    ) -> anyhow::Result<Vec<MemoryEntry>> {
+        let ck = Self::cache_key(query, limit, session_id, namespace);
+
+        for stage in &self.config.stages {
+            match stage.as_str() {
+                "cache" => {
+                    if let Some(cached) = self.check_cache(&ck) {
+                        tracing::debug!("retrieval pipeline: cache hit for '{query}'");
+                        return Ok(cached);
+                    }
+                }
+                "fts" | "vector" => {
+                    // Both FTS and vector are handled by the backend's recall method
+                    // which already does hybrid merge. We delegate to it.
+                    let results = if let Some(ns) = namespace {
+                        self.memory
+                            .recall_namespaced(ns, query, limit, session_id, since, until)
+                            .await?
+                    } else {
+                        self.memory
+                            .recall(query, limit, session_id, since, until)
+                            .await?
+                    };
+
+                    if !results.is_empty() {
+                        // Check for FTS early-return: if top score exceeds threshold
+                        // and we're in the FTS stage, we can skip further stages
+                        if stage == "fts" {
+                            if let Some(top_score) = results.first().and_then(|e| e.score) {
+                                if top_score >= self.config.fts_early_return_score {
+                                    tracing::debug!(
+                                        "retrieval pipeline: FTS early return (score={top_score:.3})"
+                                    );
+                                    self.store_in_cache(ck, results.clone());
+                                    return Ok(results);
+                                }
+                            }
+                        }
+
+                        self.store_in_cache(ck, results.clone());
+                        return Ok(results);
+                    }
+                }
+                other => {
+                    tracing::warn!("retrieval pipeline: unknown stage '{other}', skipping");
+                }
+            }
+        }
+
+        // No results from any stage
+        Ok(Vec::new())
+    }
+
+    /// Invalidate the hot cache (e.g. after a store operation).
+    pub fn invalidate_cache(&self) {
+        self.hot_cache.lock().clear();
+    }
+
+    /// Get the number of entries in the hot cache.
+    pub fn cache_size(&self) -> usize {
+        self.hot_cache.lock().len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::memory::NoneMemory;
+
+    #[tokio::test]
+    async fn pipeline_returns_empty_from_none_backend() {
+        let memory = Arc::new(NoneMemory::new());
+        let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
+
+        let results = pipeline
+            .recall("test", 10, None, None, None, None)
+            .await
+            .unwrap();
+        assert!(results.is_empty());
+    }
+
+    #[tokio::test]
+    async fn pipeline_cache_invalidation() {
+        let memory = Arc::new(NoneMemory::new());
+        let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
+
+        // Force a cache entry
+        let ck = RetrievalPipeline::cache_key("test", 10, None, None);
+        pipeline.store_in_cache(ck, vec![]);
+
+        assert_eq!(pipeline.cache_size(), 1);
+        pipeline.invalidate_cache();
+        assert_eq!(pipeline.cache_size(), 0);
+    }
+
+    #[test]
+    fn cache_key_includes_all_params() {
+        let k1 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
+        let k2 = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
+        let k3 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));
+
+        assert_ne!(k1, k2);
+        assert_ne!(k1, k3);
+    }
+
+    #[tokio::test]
+    async fn pipeline_caches_results() {
+        let memory = Arc::new(NoneMemory::new());
+        let config = RetrievalConfig {
+            stages: vec!["cache".into()],
+            ..Default::default()
+        };
+        let pipeline = RetrievalPipeline::new(memory, config);
+
+        // First call: cache miss, no results
+        let results = pipeline
+            .recall("test", 10, None, None, None, None)
+            .await
+            .unwrap();
+        assert!(results.is_empty());
+
+        // Manually insert a cache entry
+        let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
+        let fake_entry = MemoryEntry {
+            id: "1".into(),
+            key: "k".into(),
+            content: "cached content".into(),
+            category: crate::memory::MemoryCategory::Core,
+            timestamp: "now".into(),
+            session_id: None,
+            score: Some(0.9),
+            namespace: "default".into(),
+            importance: None,
+            superseded_by: None,
+        };
+        pipeline.store_in_cache(ck, vec![fake_entry]);
+
+        // Cache hit
+        let results = pipeline
+            .recall("cached_query", 5, None, None, None, None)
+            .await
+            .unwrap();
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].content, "cached content");
+    }
+}
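The hot cache in `RetrievalPipeline` combines a TTL check on read with oldest-entry eviction on write. A std-only sketch of just that cache behavior (hypothetical `HotCache` type, `String` values standing in for `MemoryEntry`, no `parking_lot`):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Stand-alone sketch of the hot cache: entries expire after `ttl` on read,
// and the entry with the oldest insertion time is evicted when the map is full.
struct HotCache {
    entries: HashMap<String, (Vec<String>, Instant)>,
    max_entries: usize,
    ttl: Duration,
}

impl HotCache {
    fn get(&self, key: &str) -> Option<Vec<String>> {
        self.entries
            .get(key)
            .filter(|(_, at)| at.elapsed() < self.ttl)
            .map(|(v, _)| v.clone())
    }

    fn put(&mut self, key: String, value: Vec<String>) {
        if self.entries.len() >= self.max_entries {
            // Same eviction rule as store_in_cache: drop the oldest entry.
            if let Some(oldest) = self
                .entries
                .iter()
                .min_by_key(|(_, v)| v.1)
                .map(|(k, _)| k.clone())
            {
                self.entries.remove(&oldest);
            }
        }
        self.entries.insert(key, (value, Instant::now()));
    }
}

fn main() {
    let mut cache = HotCache { entries: HashMap::new(), max_entries: 2, ttl: Duration::from_secs(300) };
    cache.put("a".into(), vec!["1".into()]);
    cache.put("b".into(), vec!["2".into()]);
    cache.put("c".into(), vec!["3".into()]); // third insert triggers eviction
    assert_eq!(cache.entries.len(), 2);
    assert!(cache.get("c").is_some());
    println!("ok");
}
```

Unlike a classic LRU, reads do not refresh recency here, in either the sketch or the pipeline above: eviction order is by insertion time, which is simpler and good enough for a short-TTL cache.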
+154 -23
@@ -46,6 +46,31 @@ impl SqliteMemory {
         )
     }
 
+    /// Like `new`, but stores data in `{db_name}.db` instead of `brain.db`.
+    pub fn new_named(workspace_dir: &Path, db_name: &str) -> anyhow::Result<Self> {
+        let db_path = workspace_dir.join("memory").join(format!("{db_name}.db"));
+        if let Some(parent) = db_path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+        let conn = Self::open_connection(&db_path, None)?;
+        conn.execute_batch(
+            "PRAGMA journal_mode = WAL;
+             PRAGMA synchronous = NORMAL;
+             PRAGMA mmap_size = 8388608;
+             PRAGMA cache_size = -2000;
+             PRAGMA temp_store = MEMORY;",
+        )?;
+        Self::init_schema(&conn)?;
+        Ok(Self {
+            conn: Arc::new(Mutex::new(conn)),
+            db_path,
+            embedder: Arc::new(super::embeddings::NoopEmbedding),
+            vector_weight: 0.7,
+            keyword_weight: 0.3,
+            cache_max: 10_000,
+        })
+    }
+
     /// Build SQLite memory with optional open timeout.
     ///
     /// If `open_timeout_secs` is `Some(n)`, opening the database is limited to `n` seconds
@@ -172,17 +197,35 @@ impl SqliteMemory {
         )?;
 
         // Migration: add session_id column if not present (safe to run repeatedly)
-        let has_session_id: bool = conn
+        let schema_sql: String = conn
             .prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
-            .query_row([], |row| row.get::<_, String>(0))?
-            .contains("session_id");
-        if !has_session_id {
+            .query_row([], |row| row.get::<_, String>(0))?;
+
+        if !schema_sql.contains("session_id") {
             conn.execute_batch(
                 "ALTER TABLE memories ADD COLUMN session_id TEXT;
                  CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
             )?;
         }
+
+        // Migration: add namespace column
+        if !schema_sql.contains("namespace") {
+            conn.execute_batch(
+                "ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';
+                 CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);",
+            )?;
+        }
+
+        // Migration: add importance column
+        if !schema_sql.contains("importance") {
+            conn.execute_batch("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;")?;
+        }
+
+        // Migration: add superseded_by column
+        if !schema_sql.contains("superseded_by") {
+            conn.execute_batch("ALTER TABLE memories ADD COLUMN superseded_by TEXT;")?;
+        }
 
         Ok(())
     }
@@ -221,8 +264,13 @@ impl SqliteMemory {
         )
     }
 
+    /// Provide access to the connection for advanced queries (e.g. retrieval pipeline).
+    pub fn connection(&self) -> &Arc<Mutex<Connection>> {
+        &self.conn
+    }
+
     /// Get embedding from cache, or compute + cache it
-    async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
+    pub async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
         if self.embedder.dimensions() == 0 {
             return Ok(None); // Noop embedder
         }
@@ -285,7 +333,7 @@ impl SqliteMemory {
     }
 
     /// FTS5 BM25 keyword search
-    fn fts5_search(
+    pub fn fts5_search(
         conn: &Connection,
         query: &str,
         limit: usize,
@@ -331,7 +379,7 @@ impl SqliteMemory {
     ///
     /// Optional `category` and `session_id` filters reduce full-table scans
     /// when the caller already knows the scope of relevant memories.
-    fn vector_search(
+    pub fn vector_search(
        conn: &Connection,
        query_embedding: &[f32],
        limit: usize,
@@ -448,8 +496,8 @@ impl SqliteMemory {
         let until_ref = until_owned.as_deref();
 
         let mut sql =
-            "SELECT id, key, content, category, created_at, session_id FROM memories \
-             WHERE 1=1"
+            "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories \
+             WHERE superseded_by IS NULL AND 1=1"
                 .to_string();
         let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
         let mut idx = 1;
@@ -485,6 +533,9 @@ impl SqliteMemory {
                 timestamp: row.get(4)?,
                 session_id: row.get(5)?,
                 score: None,
+                namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
+                importance: row.get(7)?,
+                superseded_by: row.get(8)?,
             })
         })?;
 
@@ -529,8 +580,8 @@ impl Memory for SqliteMemory {
         let id = Uuid::new_v4().to_string();
 
         conn.execute(
-            "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
-             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
+            "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
+             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 'default', 0.5)
             ON CONFLICT(key) DO UPDATE SET
                 content = excluded.content,
                 category = excluded.category,
@@ -616,8 +667,8 @@ impl Memory for SqliteMemory {
             .collect::<Vec<_>>()
             .join(", ");
         let sql = format!(
-            "SELECT id, key, content, category, created_at, session_id \
-             FROM memories WHERE id IN ({placeholders})"
+            "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by \
+             FROM memories WHERE superseded_by IS NULL AND id IN ({placeholders})"
         );
         let mut stmt = conn.prepare(&sql)?;
         let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
@@ -634,17 +685,20 @@ impl Memory for SqliteMemory {
                 row.get::<_, String>(3)?,
                 row.get::<_, String>(4)?,
                 row.get::<_, Option<String>>(5)?,
+                row.get::<_, Option<String>>(6)?,
+                row.get::<_, Option<f64>>(7)?,
+                row.get::<_, Option<String>>(8)?,
             ))
         })?;
 
         let mut entry_map = std::collections::HashMap::new();
         for row in rows {
-            let (id, key, content, cat, ts, sid) = row?;
-            entry_map.insert(id, (key, content, cat, ts, sid));
+            let (id, key, content, cat, ts, sid, ns, imp, sup) = row?;
+            entry_map.insert(id, (key, content, cat, ts, sid, ns, imp, sup));
        }
 
        for scored in &merged {
-            if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
+            if let Some((key, content, cat, ts, sid, ns, imp, sup)) = entry_map.remove(&scored.id) {
                if let Some(s) = since_ref {
                    if ts.as_str() < s {
                        continue;
@@ -663,6 +717,9 @@ impl Memory for SqliteMemory {
                     timestamp: ts,
                     session_id: sid,
                     score: Some(f64::from(scored.final_score)),
+                    namespace: ns.unwrap_or_else(|| "default".into()),
+                    importance: imp,
+                    superseded_by: sup,
                 };
                 if let Some(filter_sid) = session_ref {
                     if entry.session_id.as_deref() != Some(filter_sid) {
@@ -702,8 +759,8 @@ impl Memory for SqliteMemory {
             param_idx += 1;
         }
         let sql = format!(
-            "SELECT id, key, content, category, created_at, session_id FROM memories
-             WHERE {where_clause}{time_conditions}
+            "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
+             WHERE superseded_by IS NULL AND ({where_clause}){time_conditions}
             ORDER BY updated_at DESC
             LIMIT ?{param_idx}"
         );
@@ -732,6 +789,9 @@ impl Memory for SqliteMemory {
                 timestamp: row.get(4)?,
                 session_id: row.get(5)?,
                 score: Some(1.0),
+                namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
+                importance: row.get(7)?,
+                superseded_by: row.get(8)?,
             })
         })?;
         for row in rows {
@@ -759,7 +819,7 @@ impl Memory for SqliteMemory {
         tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
             let conn = conn.lock();
             let mut stmt = conn.prepare(
-                "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
+                "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories WHERE key = ?1",
             )?;
 
             let mut rows = stmt.query_map(params![key], |row| {
@@ -771,6 +831,9 @@ impl Memory for SqliteMemory {
                 timestamp: row.get(4)?,
                 session_id: row.get(5)?,
                 score: None,
+                namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
+                importance: row.get(7)?,
+                superseded_by: row.get(8)?,
             })
         })?;
 
@@ -807,14 +870,17 @@ impl Memory for SqliteMemory {
                 timestamp: row.get(4)?,
                 session_id: row.get(5)?,
                 score: None,
+                namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
+                importance: row.get(7)?,
+                superseded_by: row.get(8)?,
             })
         };
 
         if let Some(ref cat) = category {
             let cat_str = Self::category_to_str(cat);
             let mut stmt = conn.prepare(
-                "SELECT id, key, content, category, created_at, session_id FROM memories
-                 WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
+                "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
+                 WHERE superseded_by IS NULL AND category = ?1 ORDER BY updated_at DESC LIMIT ?2",
             )?;
             let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
             for row in rows {
@@ -828,8 +894,8 @@ impl Memory for SqliteMemory {
             }
         } else {
             let mut stmt = conn.prepare(
-                "SELECT id, key, content, category, created_at, session_id FROM memories
-                 ORDER BY updated_at DESC LIMIT ?1",
+                "SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
+                 WHERE superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?1",
            )?;
            let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
            for row in rows {
@@ -879,6 +945,71 @@ impl Memory for SqliteMemory {
             .await
             .unwrap_or(false)
     }
+
+    async fn recall_namespaced(
+        &self,
+        namespace: &str,
+        query: &str,
+        limit: usize,
+        session_id: Option<&str>,
+        since: Option<&str>,
+        until: Option<&str>,
+    ) -> anyhow::Result<Vec<MemoryEntry>> {
+        let entries = self
+            .recall(query, limit * 2, session_id, since, until)
+            .await?;
+        let filtered: Vec<MemoryEntry> = entries
+            .into_iter()
+            .filter(|e| e.namespace == namespace)
+            .take(limit)
+            .collect();
+        Ok(filtered)
+    }
+
+    async fn store_with_metadata(
+        &self,
+        key: &str,
+        content: &str,
+        category: MemoryCategory,
+        session_id: Option<&str>,
+        namespace: Option<&str>,
+        importance: Option<f64>,
+    ) -> anyhow::Result<()> {
+        let embedding_bytes = self
+            .get_or_compute_embedding(content)
+            .await?
+            .map(|emb| vector::vec_to_bytes(&emb));
+
+        let conn = self.conn.clone();
+        let key = key.to_string();
+        let content = content.to_string();
+        let sid = session_id.map(String::from);
+        let ns = namespace.unwrap_or("default").to_string();
+        let imp = importance.unwrap_or(0.5);
+
+        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
+            let conn = conn.lock();
+            let now = Local::now().to_rfc3339();
+            let cat = Self::category_to_str(&category);
+            let id = Uuid::new_v4().to_string();
+
+            conn.execute(
+                "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
+                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
+                 ON CONFLICT(key) DO UPDATE SET
+                     content = excluded.content,
+                     category = excluded.category,
+                     embedding = excluded.embedding,
+                     updated_at = excluded.updated_at,
+                     session_id = excluded.session_id,
+                     namespace = excluded.namespace,
+                     importance = excluded.importance",
+                params![id, key, content, cat, embedding_bytes, now, now, sid, ns, imp],
+            )?;
+            Ok(())
+        })
+        .await?
+    }
 }
 
 #[cfg(test)]
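The migrations in `init_schema` read the table's `CREATE TABLE` SQL from `sqlite_master` once and gate each `ALTER TABLE` on a substring check, which is what makes them safe to re-run. A stand-alone sketch of that gating logic (hypothetical `pending_migrations` helper operating on a schema string, no SQLite connection):

```rust
// Hypothetical helper mirroring the gate above: return only the ALTER
// statements for columns the existing schema SQL does not mention yet.
fn pending_migrations(schema_sql: &str) -> Vec<&'static str> {
    let mut out = Vec::new();
    if !schema_sql.contains("session_id") {
        out.push("ALTER TABLE memories ADD COLUMN session_id TEXT;");
    }
    if !schema_sql.contains("namespace") {
        out.push("ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';");
    }
    if !schema_sql.contains("importance") {
        out.push("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;");
    }
    if !schema_sql.contains("superseded_by") {
        out.push("ALTER TABLE memories ADD COLUMN superseded_by TEXT;");
    }
    out
}

fn main() {
    // Old schema: only session_id exists, so three ALTERs are pending.
    let old = "CREATE TABLE memories (id TEXT, key TEXT, content TEXT, session_id TEXT)";
    assert_eq!(pending_migrations(old).len(), 3);

    // Fully migrated schema: running again is a no-op.
    let new = "CREATE TABLE memories (id TEXT, session_id TEXT, namespace TEXT, \
               importance REAL, superseded_by TEXT)";
    assert!(pending_migrations(new).is_empty());
    println!("ok");
}
```

Fetching the schema once and checking it four times is also why the diff above replaces the per-column `has_session_id` query with a single `schema_sql` read.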
+63
-2
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// A single message in a conversation trace for procedural memory.
|
||||
///
|
||||
/// Used to capture "how to" patterns from tool-calling turns so that
|
||||
/// backends that support procedural storage (e.g. mem0) can learn from them.
|
||||
/// backends that support procedural storage can learn from them.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProceduralMessage {
|
||||
pub role: String,
|
||||
@@ -23,6 +23,19 @@ pub struct MemoryEntry {
|
||||
pub timestamp: String,
|
||||
pub session_id: Option<String>,
|
||||
pub score: Option<f64>,
|
||||
/// Namespace for isolation between agents/contexts.
|
||||
#[serde(default = "default_namespace")]
|
||||
pub namespace: String,
|
||||
/// Importance score (0.0–1.0) for prioritized retrieval.
|
||||
#[serde(default)]
|
||||
pub importance: Option<f64>,
|
||||
/// If this entry was superseded by a newer conflicting entry.
|
||||
#[serde(default)]
|
||||
pub superseded_by: Option<String>,
|
||||
}
|
||||
|
||||
fn default_namespace() -> String {
|
||||
"default".into()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MemoryEntry {
|
||||
@@ -34,6 +47,8 @@ impl std::fmt::Debug for MemoryEntry {
|
||||
.field("category", &self.category)
|
||||
.field("timestamp", &self.timestamp)
|
||||
.field("score", &self.score)
|
||||
.field("namespace", &self.namespace)
|
||||
.field("importance", &self.importance)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -128,7 +143,7 @@ pub trait Memory: Send + Sync {
|
||||
|
||||
/// Store a conversation trace as procedural memory.
|
||||
///
|
||||
/// Backends that support procedural storage (e.g. mem0) override this
|
||||
/// Backends that support procedural storage override this
|
||||
/// to extract "how to" patterns from tool-calling turns. The default
|
||||
/// implementation is a no-op.
|
||||
async fn store_procedural(
|
||||
@@ -138,6 +153,46 @@ pub trait Memory: Send + Sync {
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recall memories scoped to a specific namespace.
|
||||
///
|
||||
/// Default implementation delegates to `recall()` and filters by namespace.
|
||||
/// Backends with native namespace support should override for efficiency.
|
||||
async fn recall_namespaced(
|
||||
&self,
|
||||
namespace: &str,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let entries = self
|
||||
.recall(query, limit * 2, session_id, since, until)
|
||||
.await?;
|
||||
let filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.namespace == namespace)
|
||||
.take(limit)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Store a memory entry with namespace and importance.
|
||||
///
|
||||
/// Default implementation delegates to `store()`. Backends with native
|
||||
/// namespace/importance support should override.
|
||||
async fn store_with_metadata(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
_namespace: Option<&str>,
|
||||
_importance: Option<f64>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.store(key, content, category, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -185,6 +240,9 @@ mod tests {
|
||||
timestamp: "2026-02-16T00:00:00Z".into(),
|
||||
session_id: Some("session-abc".into()),
|
||||
score: Some(0.98),
|
||||
namespace: "default".into(),
|
||||
importance: Some(0.7),
|
||||
superseded_by: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
@@ -196,5 +254,8 @@ mod tests {
|
||||
assert_eq!(parsed.category, MemoryCategory::Core);
|
||||
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
|
||||
assert_eq!(parsed.score, Some(0.98));
|
||||
assert_eq!(parsed.namespace, "default");
|
||||
assert_eq!(parsed.importance, Some(0.7));
|
||||
assert!(parsed.superseded_by.is_none());
|
||||
}
|
||||
}
|
||||
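The `recall_namespaced` default above over-fetches from `recall()` (twice the limit) and then filters to one namespace. A minimal standalone sketch of that delegate-and-filter shape, using a hypothetical `Entry` struct rather than the crate's actual `MemoryEntry`:

```rust
// Hypothetical simplified entry type for illustration only.
#[derive(Clone, Debug, PartialEq)]
pub struct Entry {
    pub namespace: String,
    pub content: String,
}

// Keep only entries in the requested namespace, capped at `limit` —
// the filtering step `recall_namespaced` applies to a wider recall result.
pub fn filter_namespace(entries: Vec<Entry>, namespace: &str, limit: usize) -> Vec<Entry> {
    entries
        .into_iter()
        .filter(|e| e.namespace == namespace)
        .take(limit)
        .collect()
}
```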
@@ -126,6 +126,7 @@ pub fn hybrid_merge(
        b.final_score
            .partial_cmp(&a.final_score)
            .unwrap_or(std::cmp::Ordering::Equal)
+           .then_with(|| a.id.cmp(&b.id))
    });
    results.truncate(limit);
    results
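The tiebreak added to `hybrid_merge` matters because `f64` has no total order: `partial_cmp` returns `None` for NaN, and without a secondary key, equal scores sort in an arbitrary order. A standalone sketch of the same comparator, with a hypothetical `Hit` type:

```rust
// Hypothetical result type for illustration.
#[derive(Debug)]
pub struct Hit {
    pub id: String,
    pub final_score: f64,
}

// Descending score sort with an id tiebreak, mirroring `hybrid_merge`:
// incomparable scores (NaN) fall back to Equal, and `then_with` makes
// equal-score ordering deterministic.
pub fn rank(mut results: Vec<Hit>, limit: usize) -> Vec<Hit> {
    results.sort_by(|a, b| {
        b.final_score
            .partial_cmp(&a.final_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
    results.truncate(limit);
    results
}
```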
+11 -1
@@ -197,6 +197,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
        node_transport: crate::config::NodeTransportConfig::default(),
        knowledge: crate::config::KnowledgeConfig::default(),
        linkedin: crate::config::LinkedInConfig::default(),
        image_gen: crate::config::ImageGenConfig::default(),
        plugins: crate::config::PluginsConfig::default(),
        locale: None,
+       verifiable_intent: crate::config::VerifiableIntentConfig::default(),
@@ -419,9 +420,17 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
        snapshot_enabled: false,
        snapshot_on_hygiene: false,
        auto_hydrate: true,
        retrieval_stages: vec!["cache".into(), "fts".into(), "vector".into()],
        rerank_enabled: false,
        rerank_threshold: 5,
        fts_early_return_score: 0.85,
        default_namespace: "default".into(),
        conflict_threshold: 0.85,
        audit_enabled: false,
        audit_retention_days: 30,
        policy: crate::config::MemoryPolicyConfig::default(),
        sqlite_open_timeout_secs: None,
        qdrant: crate::config::QdrantConfig::default(),
        mem0: crate::config::schema::Mem0Config::default(),
    }
}
@@ -620,6 +629,7 @@ async fn run_quick_setup_with_home(
        node_transport: crate::config::NodeTransportConfig::default(),
        knowledge: crate::config::KnowledgeConfig::default(),
        linkedin: crate::config::LinkedInConfig::default(),
        image_gen: crate::config::ImageGenConfig::default(),
        plugins: crate::config::PluginsConfig::default(),
        locale: None,
+       verifiable_intent: crate::config::VerifiableIntentConfig::default(),
@@ -412,6 +412,7 @@ mod tests {
        ));
        let mut f = std::fs::File::create(&path).unwrap();
        writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
        f.sync_all().unwrap();
        drop(f);
        #[cfg(unix)]
        {
@@ -767,6 +767,12 @@ impl Provider for OpenAiCodexProvider {
#[cfg(test)]
mod tests {
    use super::*;
+   use std::sync::Mutex;

+   /// Mutex that serializes all tests which mutate process-global env vars
+   /// (`std::env::set_var` / `remove_var`). Each such test must hold this
+   /// lock for its entire duration so that parallel test threads don't race.
+   static ENV_MUTEX: Mutex<()> = Mutex::new(());

    struct EnvGuard {
        key: &'static str,
@@ -841,6 +847,7 @@ mod tests {

    #[test]
    fn resolve_responses_url_prefers_explicit_endpoint_env() {
+       let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let _endpoint_guard = EnvGuard::set(
            CODEX_RESPONSES_URL_ENV,
            Some("https://env.example.com/v1/responses"),
@@ -856,6 +863,7 @@ mod tests {

    #[test]
    fn resolve_responses_url_uses_provider_api_url_override() {
+       let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
        let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
@@ -959,6 +967,7 @@ mod tests {

    #[test]
    fn resolve_reasoning_effort_prefers_configured_override() {
+       let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("low"));
        assert_eq!(
            resolve_reasoning_effort("gpt-5-codex", Some("high")),
@@ -968,6 +977,7 @@ mod tests {

    #[test]
    fn resolve_reasoning_effort_uses_legacy_env_when_unconfigured() {
+       let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("minimal"));
        assert_eq!(
            resolve_reasoning_effort("gpt-5-codex", None),
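The `ENV_MUTEX` pattern above relies on `unwrap_or_else(|e| e.into_inner())` to keep working after a panicking test poisons the lock: `into_inner` on the `PoisonError` recovers the guard instead of propagating the poison to every later test. A std-only sketch of that recovery, with hypothetical names:

```rust
use std::sync::Mutex;

// Global lock serializing access to process-wide state, as in the
// ENV_MUTEX pattern. `into_inner` recovers the guard even if a previous
// holder panicked and poisoned the mutex.
static LOCK: Mutex<()> = Mutex::new(());

pub fn with_serialized<T>(f: impl FnOnce() -> T) -> T {
    let _guard = LOCK.lock().unwrap_or_else(|e| e.into_inner());
    f()
}
```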
@@ -108,6 +108,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
        "token limit exceeded",
        "prompt is too long",
        "input is too long",
+       "prompt exceeds max length",
    ];

    hints.iter().any(|hint| lower.contains(hint))
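The hint list drives a simple lowercase substring scan over the provider error message. A standalone sketch of the same shape, with the hint list abbreviated:

```rust
// Case-insensitive substring matching over a hint list, mirroring the
// shape of `is_context_window_exceeded` (abbreviated hints).
pub fn matches_hint(message: &str) -> bool {
    let hints = [
        "token limit exceeded",
        "prompt is too long",
        "input is too long",
    ];
    let lower = message.to_lowercase();
    hints.iter().any(|hint| lower.contains(hint))
}
```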
+17 -2
@@ -97,7 +97,8 @@ pub struct SecurityPolicy {
/// Default allowed commands for Unix platforms.
#[cfg(not(target_os = "windows"))]
fn default_allowed_commands() -> Vec<String> {
-   vec![
+   #[allow(unused_mut)]
+   let mut cmds = vec![
        "git".into(),
        "npm".into(),
        "cargo".into(),
@@ -111,7 +112,16 @@ fn default_allowed_commands() -> Vec<String> {
        "head".into(),
        "tail".into(),
        "date".into(),
-   ]
+       "df".into(),
+       "du".into(),
+       "uname".into(),
+       "uptime".into(),
+       "hostname".into(),
+   ];
+   // `free` is Linux-only; it does not exist on macOS or other BSDs.
+   #[cfg(target_os = "linux")]
+   cmds.push("free".into());
+   cmds
}

/// Default allowed commands for Windows platforms.
@@ -142,6 +152,11 @@ fn default_allowed_commands() -> Vec<String> {
        "wc".into(),
        "head".into(),
        "tail".into(),
+       "df".into(),
+       "du".into(),
+       "uname".into(),
+       "uptime".into(),
+       "hostname".into(),
    ]
}
+3 -3
@@ -1236,7 +1236,7 @@ mod tests {
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn run_capture_reads_stdout() {
-       let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
+       let out = run_capture(Command::new("sh").args(["-c", "echo hello"]))
            .expect("stdout capture should succeed");
        assert_eq!(out.trim(), "hello");
    }
@@ -1244,7 +1244,7 @@ mod tests {
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn run_capture_falls_back_to_stderr() {
-       let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
+       let out = run_capture(Command::new("sh").args(["-c", "echo warn 1>&2"]))
            .expect("stderr capture should succeed");
        assert_eq!(out.trim(), "warn");
    }
@@ -1252,7 +1252,7 @@ mod tests {
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn run_checked_errors_on_non_zero_status() {
-       let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
+       let err = run_checked(Command::new("sh").args(["-c", "exit 17"]))
            .expect_err("non-zero exit should error");
        assert!(err.to_string().contains("Command failed"));
    }
@@ -0,0 +1,636 @@
//! Live Canvas (A2UI) tool — push rendered content to a web canvas in real time.
//!
//! The agent can render HTML/SVG/Markdown to a named canvas, snapshot its
//! current state, clear it, or evaluate a JavaScript expression in the canvas
//! context. Content is stored in a shared [`CanvasStore`] and broadcast to
//! connected WebSocket clients via per-canvas channels.

use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;

/// Maximum content size per canvas frame (256 KB).
pub const MAX_CONTENT_SIZE: usize = 256 * 1024;

/// Maximum number of history frames kept per canvas.
const MAX_HISTORY_FRAMES: usize = 50;

/// Broadcast channel capacity per canvas.
const BROADCAST_CAPACITY: usize = 64;

/// Maximum number of concurrent canvases to prevent memory exhaustion.
const MAX_CANVAS_COUNT: usize = 100;

/// Allowed content types for canvas frames via the REST API.
pub const ALLOWED_CONTENT_TYPES: &[&str] = &["html", "svg", "markdown", "text"];

/// A single canvas frame (one render).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasFrame {
    /// Unique frame identifier.
    pub frame_id: String,
    /// Content type: `html`, `svg`, `markdown`, or `text`.
    pub content_type: String,
    /// The rendered content.
    pub content: String,
    /// ISO-8601 timestamp of when the frame was created.
    pub timestamp: String,
}

/// Per-canvas state: current content + history + broadcast sender.
struct CanvasEntry {
    current: Option<CanvasFrame>,
    history: Vec<CanvasFrame>,
    tx: broadcast::Sender<CanvasFrame>,
}

/// Shared canvas store — holds all active canvases.
///
/// Thread-safe and cheaply cloneable (wraps `Arc`).
#[derive(Clone)]
pub struct CanvasStore {
    inner: Arc<RwLock<HashMap<String, CanvasEntry>>>,
}

impl Default for CanvasStore {
    fn default() -> Self {
        Self::new()
    }
}

impl CanvasStore {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Push a new frame to a canvas. Creates the canvas if it does not exist.
    /// Returns `None` if the maximum canvas count has been reached and this is a new canvas.
    pub fn render(
        &self,
        canvas_id: &str,
        content_type: &str,
        content: &str,
    ) -> Option<CanvasFrame> {
        let frame = CanvasFrame {
            frame_id: uuid::Uuid::new_v4().to_string(),
            content_type: content_type.to_string(),
            content: content.to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
        };

        let mut store = self.inner.write();

        // Enforce canvas count limit for new canvases.
        if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
            return None;
        }

        let entry = store
            .entry(canvas_id.to_string())
            .or_insert_with(|| CanvasEntry {
                current: None,
                history: Vec::new(),
                tx: broadcast::channel(BROADCAST_CAPACITY).0,
            });

        entry.current = Some(frame.clone());
        entry.history.push(frame.clone());
        if entry.history.len() > MAX_HISTORY_FRAMES {
            let excess = entry.history.len() - MAX_HISTORY_FRAMES;
            entry.history.drain(..excess);
        }

        // Best-effort broadcast — ignore errors (no receivers is fine).
        let _ = entry.tx.send(frame.clone());

        Some(frame)
    }

    /// Get the current (most recent) frame for a canvas.
    pub fn snapshot(&self, canvas_id: &str) -> Option<CanvasFrame> {
        let store = self.inner.read();
        store.get(canvas_id).and_then(|entry| entry.current.clone())
    }

    /// Get the frame history for a canvas.
    pub fn history(&self, canvas_id: &str) -> Vec<CanvasFrame> {
        let store = self.inner.read();
        store
            .get(canvas_id)
            .map(|entry| entry.history.clone())
            .unwrap_or_default()
    }

    /// Clear a canvas (removes current content and history).
    pub fn clear(&self, canvas_id: &str) -> bool {
        let mut store = self.inner.write();
        if let Some(entry) = store.get_mut(canvas_id) {
            entry.current = None;
            entry.history.clear();
            // Send an empty frame to signal clear to subscribers.
            let clear_frame = CanvasFrame {
                frame_id: uuid::Uuid::new_v4().to_string(),
                content_type: "clear".to_string(),
                content: String::new(),
                timestamp: chrono::Utc::now().to_rfc3339(),
            };
            let _ = entry.tx.send(clear_frame);
            true
        } else {
            false
        }
    }

    /// Subscribe to real-time updates for a canvas.
    /// Creates the canvas entry if it does not exist (subject to canvas count limit).
    /// Returns `None` if the canvas does not exist and the limit has been reached.
    pub fn subscribe(&self, canvas_id: &str) -> Option<broadcast::Receiver<CanvasFrame>> {
        let mut store = self.inner.write();

        // Enforce canvas count limit for new entries.
        if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
            return None;
        }

        let entry = store
            .entry(canvas_id.to_string())
            .or_insert_with(|| CanvasEntry {
                current: None,
                history: Vec::new(),
                tx: broadcast::channel(BROADCAST_CAPACITY).0,
            });
        Some(entry.tx.subscribe())
    }

    /// List all canvas IDs that currently have content.
    pub fn list(&self) -> Vec<String> {
        let store = self.inner.read();
        store.keys().cloned().collect()
    }
}

/// `CanvasTool` — agent-callable tool for the Live Canvas (A2UI) system.
pub struct CanvasTool {
    store: CanvasStore,
}

impl CanvasTool {
    pub fn new(store: CanvasStore) -> Self {
        Self { store }
    }
}

#[async_trait]
impl Tool for CanvasTool {
    fn name(&self) -> &str {
        "canvas"
    }

    fn description(&self) -> &str {
        "Push rendered content (HTML, SVG, Markdown) to a live web canvas that users can see \
         in real-time. Actions: render (push content), snapshot (get current content), \
         clear (reset canvas), eval (evaluate JS expression in canvas context). \
         Each canvas is identified by a canvas_id string."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "description": "Action to perform on the canvas.",
                    "enum": ["render", "snapshot", "clear", "eval"]
                },
                "canvas_id": {
                    "type": "string",
                    "description": "Unique identifier for the canvas. Defaults to 'default'."
                },
                "content_type": {
                    "type": "string",
                    "description": "Content type for render action: html, svg, markdown, or text.",
                    "enum": ["html", "svg", "markdown", "text"]
                },
                "content": {
                    "type": "string",
                    "description": "Content to render (for render action)."
                },
                "expression": {
                    "type": "string",
                    "description": "JavaScript expression to evaluate (for eval action). \
                        The result is returned as text. Evaluated client-side in the canvas iframe."
                }
            },
            "required": ["action"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let action = match args.get("action").and_then(|v| v.as_str()) {
            Some(a) => a,
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("Missing required parameter: action".to_string()),
                });
            }
        };

        let canvas_id = args
            .get("canvas_id")
            .and_then(|v| v.as_str())
            .unwrap_or("default");

        match action {
            "render" => {
                let content_type = args
                    .get("content_type")
                    .and_then(|v| v.as_str())
                    .unwrap_or("html");

                let content = match args.get("content").and_then(|v| v.as_str()) {
                    Some(c) => c,
                    None => {
                        return Ok(ToolResult {
                            success: false,
                            output: String::new(),
                            error: Some(
                                "Missing required parameter: content (for render action)"
                                    .to_string(),
                            ),
                        });
                    }
                };

                if content.len() > MAX_CONTENT_SIZE {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "Content exceeds maximum size of {} bytes",
                            MAX_CONTENT_SIZE
                        )),
                    });
                }

                match self.store.render(canvas_id, content_type, content) {
                    Some(frame) => Ok(ToolResult {
                        success: true,
                        output: format!(
                            "Rendered {} content to canvas '{}' (frame: {})",
                            content_type, canvas_id, frame.frame_id
                        ),
                        error: None,
                    }),
                    None => Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "Maximum canvas count ({}) reached. Clear unused canvases first.",
                            MAX_CANVAS_COUNT
                        )),
                    }),
                }
            }

            "snapshot" => match self.store.snapshot(canvas_id) {
                Some(frame) => Ok(ToolResult {
                    success: true,
                    output: serde_json::to_string_pretty(&frame)
                        .unwrap_or_else(|_| frame.content.clone()),
                    error: None,
                }),
                None => Ok(ToolResult {
                    success: true,
                    output: format!("Canvas '{}' is empty", canvas_id),
                    error: None,
                }),
            },

            "clear" => {
                let existed = self.store.clear(canvas_id);
                Ok(ToolResult {
                    success: true,
                    output: if existed {
                        format!("Canvas '{}' cleared", canvas_id)
                    } else {
                        format!("Canvas '{}' was already empty", canvas_id)
                    },
                    error: None,
                })
            }

            "eval" => {
                // Eval is handled client-side. We store an eval request as a special frame
                // that the web viewer interprets.
                let expression = match args.get("expression").and_then(|v| v.as_str()) {
                    Some(e) => e,
                    None => {
                        return Ok(ToolResult {
                            success: false,
                            output: String::new(),
                            error: Some(
                                "Missing required parameter: expression (for eval action)"
                                    .to_string(),
                            ),
                        });
                    }
                };

                // Push a special eval frame so connected clients know to evaluate it.
                match self.store.render(canvas_id, "eval", expression) {
                    Some(frame) => Ok(ToolResult {
                        success: true,
                        output: format!(
                            "Eval request sent to canvas '{}' (frame: {}). \
                             Result will be available to connected viewers.",
                            canvas_id, frame.frame_id
                        ),
                        error: None,
                    }),
                    None => Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "Maximum canvas count ({}) reached. Clear unused canvases first.",
                            MAX_CANVAS_COUNT
                        )),
                    }),
                }
            }

            other => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Unknown action: '{}'. Valid actions: render, snapshot, clear, eval",
                    other
                )),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn canvas_store_render_and_snapshot() {
        let store = CanvasStore::new();
        let frame = store.render("test", "html", "<h1>Hello</h1>").unwrap();
        assert_eq!(frame.content_type, "html");
        assert_eq!(frame.content, "<h1>Hello</h1>");

        let snapshot = store.snapshot("test").unwrap();
        assert_eq!(snapshot.frame_id, frame.frame_id);
        assert_eq!(snapshot.content, "<h1>Hello</h1>");
    }

    #[test]
    fn canvas_store_snapshot_empty_returns_none() {
        let store = CanvasStore::new();
        assert!(store.snapshot("nonexistent").is_none());
    }

    #[test]
    fn canvas_store_clear_removes_content() {
        let store = CanvasStore::new();
        store.render("test", "html", "<p>content</p>");
        assert!(store.snapshot("test").is_some());

        let cleared = store.clear("test");
        assert!(cleared);
        assert!(store.snapshot("test").is_none());
    }

    #[test]
    fn canvas_store_clear_nonexistent_returns_false() {
        let store = CanvasStore::new();
        assert!(!store.clear("nonexistent"));
    }

    #[test]
    fn canvas_store_history_tracks_frames() {
        let store = CanvasStore::new();
        store.render("test", "html", "frame1");
        store.render("test", "html", "frame2");
        store.render("test", "html", "frame3");

        let history = store.history("test");
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "frame1");
        assert_eq!(history[2].content, "frame3");
    }

    #[test]
    fn canvas_store_history_limit_enforced() {
        let store = CanvasStore::new();
        for i in 0..60 {
            store.render("test", "html", &format!("frame{i}"));
        }

        let history = store.history("test");
        assert_eq!(history.len(), MAX_HISTORY_FRAMES);
        // Oldest frames should have been dropped
        assert_eq!(history[0].content, "frame10");
    }

    #[test]
    fn canvas_store_list_returns_canvas_ids() {
        let store = CanvasStore::new();
        store.render("alpha", "html", "a");
        store.render("beta", "svg", "b");

        let mut ids = store.list();
        ids.sort();
        assert_eq!(ids, vec!["alpha", "beta"]);
    }

    #[test]
    fn canvas_store_subscribe_receives_updates() {
        let store = CanvasStore::new();
        let mut rx = store.subscribe("test").unwrap();
        store.render("test", "html", "<p>live</p>");

        let frame = rx.try_recv().unwrap();
        assert_eq!(frame.content, "<p>live</p>");
    }

    #[tokio::test]
    async fn canvas_tool_render_action() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store.clone());
        let result = tool
            .execute(json!({
                "action": "render",
                "canvas_id": "test",
                "content_type": "html",
                "content": "<h1>Hello World</h1>"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Rendered html content"));

        let snapshot = store.snapshot("test").unwrap();
        assert_eq!(snapshot.content, "<h1>Hello World</h1>");
    }

    #[tokio::test]
    async fn canvas_tool_snapshot_action() {
        let store = CanvasStore::new();
        store.render("test", "html", "<p>snap</p>");
        let tool = CanvasTool::new(store);
        let result = tool
            .execute(json!({"action": "snapshot", "canvas_id": "test"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("<p>snap</p>"));
    }

    #[tokio::test]
    async fn canvas_tool_snapshot_empty() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let result = tool
            .execute(json!({"action": "snapshot", "canvas_id": "empty"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("empty"));
    }

    #[tokio::test]
    async fn canvas_tool_clear_action() {
        let store = CanvasStore::new();
        store.render("test", "html", "<p>clear me</p>");
        let tool = CanvasTool::new(store.clone());
        let result = tool
            .execute(json!({"action": "clear", "canvas_id": "test"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("cleared"));
        assert!(store.snapshot("test").is_none());
    }

    #[tokio::test]
    async fn canvas_tool_eval_action() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store.clone());
        let result = tool
            .execute(json!({
                "action": "eval",
                "canvas_id": "test",
                "expression": "document.title"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Eval request sent"));

        let snapshot = store.snapshot("test").unwrap();
        assert_eq!(snapshot.content_type, "eval");
        assert_eq!(snapshot.content, "document.title");
    }

    #[tokio::test]
    async fn canvas_tool_unknown_action() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("Unknown action"));
    }

    #[tokio::test]
    async fn canvas_tool_missing_action() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("action"));
    }

    #[tokio::test]
    async fn canvas_tool_render_missing_content() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let result = tool
            .execute(json!({"action": "render", "canvas_id": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("content"));
    }

    #[tokio::test]
    async fn canvas_tool_render_content_too_large() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let big_content = "x".repeat(MAX_CONTENT_SIZE + 1);
        let result = tool
            .execute(json!({
                "action": "render",
                "canvas_id": "test",
                "content": big_content
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("maximum size"));
    }

    #[tokio::test]
    async fn canvas_tool_default_canvas_id() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store.clone());
        let result = tool
            .execute(json!({
                "action": "render",
                "content_type": "html",
                "content": "<p>default</p>"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(store.snapshot("default").is_some());
    }

    #[test]
    fn canvas_store_enforces_max_canvas_count() {
        let store = CanvasStore::new();
        // Create MAX_CANVAS_COUNT canvases
        for i in 0..MAX_CANVAS_COUNT {
            assert!(store
                .render(&format!("canvas_{i}"), "html", "content")
                .is_some());
        }
        // The next new canvas should be rejected
        assert!(store.render("one_too_many", "html", "content").is_none());
        // But rendering to an existing canvas should still work
        assert!(store.render("canvas_0", "html", "updated").is_some());
    }

    #[tokio::test]
    async fn canvas_tool_eval_missing_expression() {
        let store = CanvasStore::new();
        let tool = CanvasTool::new(store);
        let result = tool
            .execute(json!({"action": "eval", "canvas_id": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("expression"));
    }
}
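`CanvasStore::render` caps per-canvas history by draining the oldest frames once the cap is exceeded. A std-only sketch of that bounding step, using plain `String` frames instead of `CanvasFrame`:

```rust
// Bounded frame history, as enforced in `CanvasStore::render`:
// push the new frame, then drain the oldest entries past the cap.
const MAX_HISTORY_FRAMES: usize = 50;

pub fn push_frame(history: &mut Vec<String>, frame: String) {
    history.push(frame);
    if history.len() > MAX_HISTORY_FRAMES {
        let excess = history.len() - MAX_HISTORY_FRAMES;
        history.drain(..excess);
    }
}
```

This mirrors the `canvas_store_history_limit_enforced` test: after 60 pushes, only the newest 50 frames remain.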
@@ -0,0 +1,204 @@
use super::traits::{Tool, ToolResult};
use crate::memory::Memory;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;

/// Search Discord message history stored in discord.db.
pub struct DiscordSearchTool {
    discord_memory: Arc<dyn Memory>,
}

impl DiscordSearchTool {
    pub fn new(discord_memory: Arc<dyn Memory>) -> Self {
        Self { discord_memory }
    }
}

#[async_trait]
impl Tool for DiscordSearchTool {
    fn name(&self) -> &str {
        "discord_search"
    }

    fn description(&self) -> &str {
        "Search Discord message history. Returns messages matching a keyword query, optionally filtered by channel_id, author_id, or time range."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Keywords or phrase to search for in Discord messages (optional if since/until provided)"
                },
                "limit": {
                    "type": "integer",
                    "description": "Max results to return (default: 10)"
                },
                "channel_id": {
                    "type": "string",
                    "description": "Filter results to a specific Discord channel ID"
                },
                "since": {
                    "type": "string",
                    "description": "Filter messages at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
                },
                "until": {
                    "type": "string",
                    "description": "Filter messages at or before this time (RFC 3339)"
                }
            }
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
        let channel_id = args.get("channel_id").and_then(|v| v.as_str());
        let since = args.get("since").and_then(|v| v.as_str());
        let until = args.get("until").and_then(|v| v.as_str());

        if query.trim().is_empty() && since.is_none() && until.is_none() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(
                    "Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
                ),
            });
        }

        if let Some(s) = since {
            if chrono::DateTime::parse_from_rfc3339(s).is_err() {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Invalid 'since' date: {s}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
                    )),
                });
            }
        }
        if let Some(u) = until {
            if chrono::DateTime::parse_from_rfc3339(u).is_err() {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Invalid 'until' date: {u}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
                    )),
                });
            }
        }
        if let (Some(s), Some(u)) = (since, until) {
            if let (Ok(s_dt), Ok(u_dt)) = (
                chrono::DateTime::parse_from_rfc3339(s),
                chrono::DateTime::parse_from_rfc3339(u),
            ) {
                if s_dt >= u_dt {
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some("'since' must be before 'until'".into()),
                    });
                }
            }
        }

        #[allow(clippy::cast_possible_truncation)]
        let limit = args
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .map_or(10, |v| v as usize);

        match self
            .discord_memory
            .recall(query, limit, channel_id, since, until)
            .await
        {
            Ok(entries) if entries.is_empty() => Ok(ToolResult {
                success: true,
                output: "No Discord messages found.".into(),
                error: None,
            }),
            Ok(entries) => {
                let mut output = format!("Found {} Discord messages:\n", entries.len());
                for entry in &entries {
                    let score = entry
                        .score
                        .map_or_else(String::new, |s| format!(" [{s:.0}%]"));
                    let _ = writeln!(output, "- {}{score}", entry.content);
                }
                Ok(ToolResult {
                    success: true,
                    output,
                    error: None,
                })
            }
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Discord search failed: {e}")),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{MemoryCategory, SqliteMemory};
    use tempfile::TempDir;

    fn seeded_discord_mem() -> (TempDir, Arc<dyn Memory>) {
        let tmp = TempDir::new().unwrap();
        let mem = SqliteMemory::new_named(tmp.path(), "discord").unwrap();
        (tmp, Arc::new(mem))
    }

    #[tokio::test]
    async fn search_empty() {
        let (_tmp, mem) = seeded_discord_mem();
        let tool = DiscordSearchTool::new(mem);
        let result = tool.execute(json!({"query": "hello"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No Discord messages found"));
    }

    #[tokio::test]
    async fn search_finds_match() {
        let (_tmp, mem) = seeded_discord_mem();
        mem.store(
            "discord_001",
            "@user1 in #general at 2025-01-01T00:00:00Z: hello world",
            MemoryCategory::Custom("discord".to_string()),
            Some("general"),
        )
        .await
        .unwrap();

        let tool = DiscordSearchTool::new(mem);
        let result = tool.execute(json!({"query": "hello"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("hello"));
    }

    #[tokio::test]
    async fn search_requires_query_or_time() {
        let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("at least"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
assert_eq!(tool.name(), "discord_search");
|
||||
assert!(tool.parameters_schema()["properties"]["query"].is_object());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,494 @@
use super::traits::{Tool, ToolResult};
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;

/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
///
/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
/// calls the fal.ai synchronous endpoint, downloads the resulting image,
/// and saves it to `{workspace}/images/{filename}.png`.
pub struct ImageGenTool {
    security: Arc<SecurityPolicy>,
    workspace_dir: PathBuf,
    default_model: String,
    api_key_env: String,
}

impl ImageGenTool {
    pub fn new(
        security: Arc<SecurityPolicy>,
        workspace_dir: PathBuf,
        default_model: String,
        api_key_env: String,
    ) -> Self {
        Self {
            security,
            workspace_dir,
            default_model,
            api_key_env,
        }
    }

    /// Build a reusable HTTP client with reasonable timeouts.
    fn http_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .build()
            .unwrap_or_default()
    }

    /// Read an API key from the environment.
    fn read_api_key(env_var: &str) -> Result<String, String> {
        std::env::var(env_var)
            .map(|v| v.trim().to_string())
            .ok()
            .filter(|v| !v.is_empty())
            .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
    }

    /// Core generation logic: call fal.ai, download image, save to disk.
    async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // ── Parse parameters ───────────────────────────────────────
        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
            Some(p) if !p.trim().is_empty() => p.trim().to_string(),
            _ => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("Missing required parameter: 'prompt'".into()),
                });
            }
        };

        let filename = args
            .get("filename")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .unwrap_or("generated_image");

        // Sanitize filename — strip path components to prevent traversal.
        let safe_name = PathBuf::from(filename).file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );

        let size = args
            .get("size")
            .and_then(|v| v.as_str())
            .unwrap_or("square_hd");

        // Validate size enum.
        const VALID_SIZES: &[&str] = &[
            "square_hd",
            "landscape_4_3",
            "portrait_4_3",
            "landscape_16_9",
            "portrait_16_9",
        ];
        if !VALID_SIZES.contains(&size) {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Invalid size '{size}'. Valid values: {}",
                    VALID_SIZES.join(", ")
                )),
            });
        }

        let model = args
            .get("model")
            .and_then(|v| v.as_str())
            .filter(|s| !s.trim().is_empty())
            .unwrap_or(&self.default_model);

        // Validate model identifier: must look like a fal.ai model path
        // (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
        // strings, or fragments that could redirect the HTTP request.
        if model.contains("..")
            || model.contains('?')
            || model.contains('#')
            || model.contains('\\')
            || model.starts_with('/')
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Invalid model identifier '{model}'. \
                     Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
                )),
            });
        }

        // ── Read API key ───────────────────────────────────────────
        let api_key = match Self::read_api_key(&self.api_key_env) {
            Ok(k) => k,
            Err(msg) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(msg),
                });
            }
        };

        // ── Call fal.ai ────────────────────────────────────────────
        let client = Self::http_client();
        let url = format!("https://fal.run/{model}");

        let body = json!({
            "prompt": prompt,
            "image_size": size,
            "num_images": 1
        });

        let resp = client
            .post(&url)
            .header("Authorization", format!("Key {api_key}"))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("fal.ai request failed")?;

        let status = resp.status();
        if !status.is_success() {
            let body_text = resp.text().await.unwrap_or_default();
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("fal.ai API error ({status}): {body_text}")),
            });
        }

        let resp_json: serde_json::Value = resp
            .json()
            .await
            .context("Failed to parse fal.ai response as JSON")?;

        let image_url = resp_json
            .pointer("/images/0/url")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;

        // ── Download image ─────────────────────────────────────────
        let img_resp = client
            .get(image_url)
            .send()
            .await
            .context("Failed to download generated image")?;

        if !img_resp.status().is_success() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Failed to download image from {image_url} ({})",
                    img_resp.status()
                )),
            });
        }

        let bytes = img_resp
            .bytes()
            .await
            .context("Failed to read image bytes")?;

        // ── Save to disk ───────────────────────────────────────────
        let images_dir = self.workspace_dir.join("images");
        tokio::fs::create_dir_all(&images_dir)
            .await
            .context("Failed to create images directory")?;

        let output_path = images_dir.join(format!("{safe_name}.png"));
        tokio::fs::write(&output_path, &bytes)
            .await
            .context("Failed to write image file")?;

        let size_kb = bytes.len() / 1024;

        Ok(ToolResult {
            success: true,
            output: format!(
                "Image generated successfully.\n\
                 File: {}\n\
                 Size: {} KB\n\
                 Model: {}\n\
                 Prompt: {}",
                output_path.display(),
                size_kb,
                model,
                prompt,
            ),
            error: None,
        })
    }
}

#[async_trait]
impl Tool for ImageGenTool {
    fn name(&self) -> &str {
        "image_gen"
    }

    fn description(&self) -> &str {
        "Generate an image from a text prompt using fal.ai (Flux models). \
         Saves the result to the workspace images directory and returns the file path."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "required": ["prompt"],
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Text prompt describing the image to generate."
                },
                "filename": {
                    "type": "string",
                    "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
                },
                "size": {
                    "type": "string",
                    "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
                    "description": "Image aspect ratio / size preset (default: 'square_hd')."
                },
                "model": {
                    "type": "string",
                    "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
                }
            }
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Security: image generation is a side-effecting action (HTTP + file write).
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "image_gen")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        self.generate(args).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    fn test_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        )
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": " "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        // Temporarily ensure the env var is unset.
        let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
        std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN");

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_IMAGE_GEN".into(),
        );
        let result = tool
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("FAL_API_KEY_TEST_IMAGE_GEN"));

        // Restore if it was set.
        if let Some(val) = original {
            std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val);
        }
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // Set a dummy key so we get past the key check.
        std::env::set_var("FAL_API_KEY_TEST_SIZE", "dummy_key");

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_SIZE".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));

        std::env::remove_var("FAL_API_KEY_TEST_SIZE");
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        std::env::set_var("FAL_API_KEY_TEST_MODEL", "dummy_key");

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_MODEL".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("Invalid model identifier"));

        std::env::remove_var("FAL_API_KEY_TEST_MODEL");
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("DEFINITELY_NOT_SET_ZC_TEST_12345"));
    }

    #[test]
    fn filename_traversal_is_sanitized() {
        // Verify that path traversal in filenames is stripped to just the final component.
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        // ".." alone has no file_name, falls back to default.
        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123");
        let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_value_123");
        std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY");
    }
}
@@ -20,6 +20,7 @@ pub mod browser;
pub mod browser_delegate;
pub mod browser_open;
pub mod calculator;
pub mod canvas;
pub mod claude_code;
pub mod cli_discovery;
pub mod cloud_ops;
@@ -34,6 +35,7 @@ pub mod cron_runs;
pub mod cron_update;
pub mod data_management;
pub mod delegate;
pub mod discord_search;
pub mod file_edit;
pub mod file_read;
pub mod file_write;
@@ -47,6 +49,7 @@ pub mod hardware_memory_map;
#[cfg(feature = "hardware")]
pub mod hardware_memory_read;
pub mod http_request;
pub mod image_gen;
pub mod image_info;
pub mod jira_tool;
pub mod knowledge_tool;
@@ -69,12 +72,14 @@ pub mod pdf_read;
pub mod project_intel;
pub mod proxy_config;
pub mod pushover;
pub mod reaction;
pub mod read_skill;
pub mod report_templates;
pub mod schedule;
pub mod schema;
pub mod screenshot;
pub mod security_ops;
pub mod sessions;
pub mod shell;
pub mod swarm;
pub mod text_browser;
@@ -93,6 +98,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
pub use browser_open::BrowserOpenTool;
pub use calculator::CalculatorTool;
pub use canvas::{CanvasStore, CanvasTool};
pub use claude_code::ClaudeCodeTool;
pub use cloud_ops::CloudOpsTool;
pub use cloud_patterns::CloudPatternsTool;
@@ -106,6 +112,7 @@ pub use cron_runs::CronRunsTool;
pub use cron_update::CronUpdateTool;
pub use data_management::DataManagementTool;
pub use delegate::DelegateTool;
pub use discord_search::DiscordSearchTool;
pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
@@ -119,6 +126,7 @@ pub use hardware_memory_map::HardwareMemoryMapTool;
#[cfg(feature = "hardware")]
pub use hardware_memory_read::HardwareMemoryReadTool;
pub use http_request::HttpRequestTool;
pub use image_gen::ImageGenTool;
pub use image_info::ImageInfoTool;
pub use jira_tool::JiraTool;
pub use knowledge_tool::KnowledgeTool;
@@ -139,12 +147,14 @@ pub use pdf_read::PdfReadTool;
pub use project_intel::ProjectIntelTool;
pub use proxy_config::ProxyConfigTool;
pub use pushover::PushoverTool;
pub use reaction::{ChannelMapHandle, ReactionTool};
pub use read_skill::ReadSkillTool;
pub use schedule::ScheduleTool;
#[allow(unused_imports)]
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use screenshot::ScreenshotTool;
pub use security_ops::SecurityOpsTool;
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
pub use shell::ShellTool;
pub use swarm::SwarmTool;
pub use text_browser::TextBrowserTool;
@@ -262,7 +272,12 @@ pub fn all_tools(
    agents: &HashMap<String, DelegateAgentConfig>,
    fallback_api_key: Option<&str>,
    root_config: &crate::config::Config,
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
    canvas_store: Option<CanvasStore>,
) -> (
    Vec<Box<dyn Tool>>,
    Option<DelegateParentToolsHandle>,
    Option<ChannelMapHandle>,
) {
    all_tools_with_runtime(
        config,
        security,
@@ -277,6 +292,7 @@ pub fn all_tools(
        agents,
        fallback_api_key,
        root_config,
        canvas_store,
    )
}

@@ -296,7 +312,12 @@ pub fn all_tools_with_runtime(
    agents: &HashMap<String, DelegateAgentConfig>,
    fallback_api_key: Option<&str>,
    root_config: &crate::config::Config,
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
    canvas_store: Option<CanvasStore>,
) -> (
    Vec<Box<dyn Tool>>,
    Option<DelegateParentToolsHandle>,
    Option<ChannelMapHandle>,
) {
    let has_shell_access = runtime.has_shell_access();
    let sandbox = create_sandbox(&root_config.security);
    let mut tool_arcs: Vec<Arc<dyn Tool>> = vec![
@@ -336,8 +357,21 @@ pub fn all_tools_with_runtime(
        )),
        Arc::new(CalculatorTool::new()),
        Arc::new(WeatherTool::new()),
        Arc::new(CanvasTool::new(canvas_store.unwrap_or_default())),
    ];

    // Register discord_search if discord_history channel is configured
    if root_config.channels_config.discord_history.is_some() {
        match crate::memory::SqliteMemory::new_named(workspace_dir, "discord") {
            Ok(discord_mem) => {
                tool_arcs.push(Arc::new(DiscordSearchTool::new(Arc::new(discord_mem))));
            }
            Err(e) => {
                tracing::warn!("discord_search: failed to open discord.db: {e}");
            }
        }
    }

    if matches!(
        root_config.skills.prompt_injection_mode,
        crate::config::SkillsPromptInjectionMode::Compact
@@ -545,6 +579,18 @@ pub fn all_tools_with_runtime(
    tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
    tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));

    // Session-to-session messaging tools (always available when sessions dir exists)
    if let Ok(session_store) = crate::channels::session_store::SessionStore::new(workspace_dir) {
        let backend: Arc<dyn crate::channels::session_backend::SessionBackend> =
            Arc::new(session_store);
        tool_arcs.push(Arc::new(SessionsListTool::new(backend.clone())));
        tool_arcs.push(Arc::new(SessionsHistoryTool::new(
            backend.clone(),
            security.clone(),
        )));
        tool_arcs.push(Arc::new(SessionsSendTool::new(backend, security.clone())));
    }

    // LinkedIn integration (config-gated)
    if root_config.linkedin.enabled {
        tool_arcs.push(Arc::new(LinkedInTool::new(
@@ -556,6 +602,16 @@ pub fn all_tools_with_runtime(
        )));
    }

    // Standalone image generation tool (config-gated)
    if root_config.image_gen.enabled {
        tool_arcs.push(Arc::new(ImageGenTool::new(
            security.clone(),
            workspace_dir.to_path_buf(),
            root_config.image_gen.default_model.clone(),
            root_config.image_gen.api_key_env.clone(),
        )));
    }

    if let Some(key) = composio_key {
        if !key.is_empty() {
            tool_arcs.push(Arc::new(ComposioTool::new(
@@ -566,6 +622,11 @@ pub fn all_tools_with_runtime(
        }
    }

    // Emoji reaction tool — always registered; channel map populated later by start_channels.
    let reaction_tool = ReactionTool::new(security.clone());
    let reaction_handle = reaction_tool.channel_map_handle();
    tool_arcs.push(Arc::new(reaction_tool));

    // Microsoft 365 Graph API integration
    if root_config.microsoft365.enabled {
        let ms_cfg = &root_config.microsoft365;
@@ -592,7 +653,11 @@ pub fn all_tools_with_runtime(
            tracing::error!(
                "microsoft365: client_credentials auth_flow requires a non-empty client_secret"
            );
            return (boxed_registry_from_arcs(tool_arcs), None);
            return (
                boxed_registry_from_arcs(tool_arcs),
                None,
                Some(reaction_handle),
            );
        }

        let resolved = microsoft365::types::Microsoft365ResolvedConfig {
@@ -776,7 +841,11 @@ pub fn all_tools_with_runtime(
        }
    }

    (boxed_registry_from_arcs(tool_arcs), delegate_handle)
    (
        boxed_registry_from_arcs(tool_arcs),
        delegate_handle,
        Some(reaction_handle),
    )
}

#[cfg(test)]
@@ -820,7 +889,7 @@ mod tests {
        let http = crate::config::HttpRequestConfig::default();
        let cfg = test_config(&tmp);

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(Config::default()),
            &security,
            mem,
@@ -833,6 +902,7 @@ mod tests {
            &HashMap::new(),
            None,
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(!names.contains(&"browser_open"));
@@ -862,7 +932,7 @@ mod tests {
        let http = crate::config::HttpRequestConfig::default();
        let cfg = test_config(&tmp);

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(Config::default()),
            &security,
            mem,
@@ -875,6 +945,7 @@ mod tests {
            &HashMap::new(),
            None,
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"browser_open"));
@@ -1015,7 +1086,7 @@ mod tests {
            },
        );

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(Config::default()),
            &security,
            mem,
@@ -1028,6 +1099,7 @@ mod tests {
            &agents,
            Some("delegate-test-credential"),
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"delegate"));
@@ -1048,7 +1120,7 @@ mod tests {
        let http = crate::config::HttpRequestConfig::default();
        let cfg = test_config(&tmp);

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(Config::default()),
            &security,
            mem,
@@ -1061,6 +1133,7 @@ mod tests {
            &HashMap::new(),
            None,
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(!names.contains(&"delegate"));
@@ -1082,7 +1155,7 @@ mod tests {
        let mut cfg = test_config(&tmp);
        cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Compact;

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(cfg.clone()),
            &security,
            mem,
@@ -1095,6 +1168,7 @@ mod tests {
            &HashMap::new(),
            None,
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"read_skill"));
@@ -1116,7 +1190,7 @@ mod tests {
        let mut cfg = test_config(&tmp);
        cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Full;

        let (tools, _) = all_tools(
        let (tools, _, _) = all_tools(
            Arc::new(cfg.clone()),
            &security,
            mem,
@@ -1129,6 +1203,7 @@ mod tests {
            &HashMap::new(),
            None,
            &cfg,
            None,
        );
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(!names.contains(&"read_skill"));
@@ -0,0 +1,546 @@
|
||||
//! Emoji reaction tool for cross-channel message reactions.
|
||||
//!
|
||||
//! Exposes `add_reaction` and `remove_reaction` from the [`Channel`] trait as an
|
||||
//! agent-callable tool. The tool holds a late-binding channel map handle that is
|
||||
//! populated once channels are initialized (after tool construction). This mirrors
|
||||
//! the pattern used by [`DelegateTool`] for its parent-tools handle.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::channels::traits::Channel;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Shared handle to the channel map. Starts empty; populated once channels boot.
|
||||
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
|
||||
|
||||
/// Agent-callable tool for adding or removing emoji reactions on messages.
|
||||
pub struct ReactionTool {
|
||||
channels: ChannelMapHandle,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl ReactionTool {
|
||||
/// Create a new reaction tool with an empty channel map.
|
||||
/// Call [`populate`] or write to the returned [`ChannelMapHandle`] once channels
|
||||
/// are available.
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the shared handle so callers can populate it after channel init.
|
||||
pub fn channel_map_handle(&self) -> ChannelMapHandle {
|
||||
Arc::clone(&self.channels)
|
||||
}
|
||||
|
||||
/// Convenience: populate the channel map from a pre-built map.
|
||||
pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
|
||||
*self.channels.write() = map;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl Tool for ReactionTool {
    fn name(&self) -> &str {
        "reaction"
    }

    fn description(&self) -> &str {
        "Add or remove an emoji reaction on a message in any active channel. \
         Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
         the platform message ID, and the emoji (Unicode character or platform shortcode)."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "channel": {
                    "type": "string",
                    "description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
                },
                "channel_id": {
                    "type": "string",
                    "description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
                },
                "message_id": {
                    "type": "string",
                    "description": "Platform-scoped message identifier to react to"
                },
                "emoji": {
                    "type": "string",
                    "description": "Emoji to react with (Unicode character or platform shortcode)"
                },
                "action": {
                    "type": "string",
                    "enum": ["add", "remove"],
                    "description": "Whether to add or remove the reaction (default: 'add')"
                }
            },
            "required": ["channel", "channel_id", "message_id", "emoji"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Security gate
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "reaction")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        let channel_name = args
            .get("channel")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'channel' parameter"))?;

        let channel_id = args
            .get("channel_id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'channel_id' parameter"))?;

        let message_id = args
            .get("message_id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'message_id' parameter"))?;

        let emoji = args
            .get("emoji")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'emoji' parameter"))?;

        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");

        if action != "add" && action != "remove" {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!(
                    "Invalid action '{action}': must be 'add' or 'remove'"
                )),
            });
        }

        // Read-lock the channel map to find the target channel.
        let channel = {
            let map = self.channels.read();
            if map.is_empty() {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("No channels available yet (channels not initialized)".to_string()),
                });
            }
            match map.get(channel_name) {
                Some(ch) => Arc::clone(ch),
                None => {
                    let available: Vec<String> = map.keys().cloned().collect();
                    return Ok(ToolResult {
                        success: false,
                        output: String::new(),
                        error: Some(format!(
                            "Channel '{channel_name}' not found. Available channels: {}",
                            available.join(", ")
                        )),
                    });
                }
            }
        };

        let result = if action == "add" {
            channel.add_reaction(channel_id, message_id, emoji).await
        } else {
            channel.remove_reaction(channel_id, message_id, emoji).await
        };

        let past_tense = if action == "remove" {
            "removed"
        } else {
            "added"
        };

        match result {
            Ok(()) => Ok(ToolResult {
                success: true,
                output: format!(
                    "Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
                ),
                error: None,
            }),
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Failed to {action} reaction: {e}")),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::traits::{ChannelMessage, SendMessage};
    use std::sync::atomic::{AtomicBool, Ordering};

    struct MockChannel {
        reaction_added: AtomicBool,
        reaction_removed: AtomicBool,
        last_channel_id: parking_lot::Mutex<Option<String>>,
        fail_on_add: bool,
    }

    impl MockChannel {
        fn new() -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_channel_id: parking_lot::Mutex::new(None),
                fail_on_add: false,
            }
        }

        fn failing() -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_channel_id: parking_lot::Mutex::new(None),
                fail_on_add: true,
            }
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            "mock"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            if self.fail_on_add {
                return Err(anyhow::anyhow!("API error: rate limited"));
            }
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_added.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn remove_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_removed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let map: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|(name, ch)| (name.to_string(), ch))
            .collect();
        tool.populate(map);
        tool
    }

    #[test]
    fn tool_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        assert_eq!(tool.name(), "reaction");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["channel"].is_object());
        assert!(schema["properties"]["channel_id"].is_object());
        assert!(schema["properties"]["message_id"].is_object());
        assert!(schema["properties"]["emoji"].is_object());
        assert!(schema["properties"]["action"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "channel"));
        assert!(required.iter().any(|v| v == "channel_id"));
        assert!(required.iter().any(|v| v == "message_id"));
        assert!(required.iter().any(|v| v == "emoji"));
        // action is optional (defaults to "add")
        assert!(!required.iter().any(|v| v == "action"));
    }

    #[tokio::test]
    async fn add_reaction_success() {
        let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", Arc::clone(&mock))]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_123",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("added"));
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn remove_reaction_success() {
        let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
        let tool = make_tool_with_channels(vec![("slack", Arc::clone(&mock))]);

        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123SLACK",
                "message_id": "msg_456",
                "emoji": "\u{1F440}",
                "action": "remove"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("removed"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "discord",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        let result = tool
            .execute(json!({
                "channel": "nonexistent",
                "channel_id": "ch_x",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("not found"));
        assert!(err.contains("discord"));
    }

    #[tokio::test]
    async fn invalid_action_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "discord",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}",
                "action": "toggle"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("toggle"));
    }

    #[tokio::test]
    async fn channel_error_propagated() {
        let mock: Arc<dyn Channel> = Arc::new(MockChannel::failing());
        let tool = make_tool_with_channels(vec![("discord", mock)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("rate limited"));
    }

    #[tokio::test]
    async fn missing_required_params() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        // Missing channel
        let result = tool
            .execute(json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing channel_id
        let result = tool
            .execute(json!({"channel": "test", "message_id": "1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing message_id
        let result = tool
            .execute(json!({"channel": "a", "channel_id": "c1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing emoji
        let result = tool
            .execute(json!({"channel": "a", "channel_id": "c1", "message_id": "1"}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        // No channels populated

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn default_action_is_add() {
        let mock = Arc::new(MockChannel::new());
        let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
        let tool = make_tool_with_channels(vec![("test", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "test",
                "channel_id": "ch_test",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn channel_id_passed_to_trait_not_channel_name() {
        let mock = Arc::new(MockChannel::new());
        let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "123456789",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        // The trait must receive the platform channel_id, not the channel name
        assert_eq!(
            mock.last_channel_id.lock().as_deref(),
            Some("123456789"),
            "add_reaction must receive channel_id, not channel name"
        );
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let handle = tool.channel_map_handle();

        // Initially empty — tool reports not initialized
        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();
        assert!(!result.success);

        // Populate via the handle
        {
            let mut map = handle.write();
            map.insert(
                "slack".to_string(),
                Arc::new(MockChannel::new()) as Arc<dyn Channel>,
            );
        }

        // Now the tool can route to the channel
        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();
        assert!(result.success);
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let spec = tool.spec();
        assert_eq!(spec.name, "reaction");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }
}

@@ -0,0 +1,573 @@
//! Session-to-session messaging tools for inter-agent communication.
//!
//! Provides three tools:
//! - `sessions_list` — list active sessions with metadata
//! - `sessions_history` — read message history from a specific session
//! - `sessions_send` — send a message to a specific session
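//!
//! A minimal usage sketch (hypothetical wiring; the concrete `backend` and
//! `security` values come from the host application and are shown only to
//! illustrate the intended call shape):
//!
//! ```ignore
//! let list = SessionsListTool::new(Arc::clone(&backend));
//! let send = SessionsSendTool::new(Arc::clone(&backend), Arc::clone(&security));
//! list.execute(json!({})).await?;
//! send.execute(json!({"session_id": "telegram__alice", "message": "ping"})).await?;
//! ```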
use super::traits::{Tool, ToolResult};
use crate::channels::session_backend::SessionBackend;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;

/// Validate that a session ID is non-empty and contains at least one
/// alphanumeric character (prevents blank keys after sanitization).
fn validate_session_id(session_id: &str) -> Result<(), ToolResult> {
    let trimmed = session_id.trim();
    if trimmed.is_empty() || !trimmed.chars().any(|c| c.is_alphanumeric()) {
        return Err(ToolResult {
            success: false,
            output: String::new(),
            error: Some(
                "Invalid 'session_id': must be non-empty and contain at least one alphanumeric character.".into(),
            ),
        });
    }
    Ok(())
}

// ── SessionsListTool ────────────────────────────────────────────────

/// Lists active sessions with their channel, last activity time, and message count.
pub struct SessionsListTool {
    backend: Arc<dyn SessionBackend>,
}

impl SessionsListTool {
    pub fn new(backend: Arc<dyn SessionBackend>) -> Self {
        Self { backend }
    }
}

#[async_trait]
impl Tool for SessionsListTool {
    fn name(&self) -> &str {
        "sessions_list"
    }

    fn description(&self) -> &str {
        "List all active conversation sessions with their channel, last activity time, and message count."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "limit": {
                    "type": "integer",
                    "description": "Max sessions to return (default: 50)"
                }
            }
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        #[allow(clippy::cast_possible_truncation)]
        let limit = args
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .map_or(50, |v| v as usize);

        let metadata = self.backend.list_sessions_with_metadata();

        if metadata.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: "No active sessions found.".into(),
                error: None,
            });
        }

        let capped: Vec<_> = metadata.into_iter().take(limit).collect();
        let mut output = format!("Found {} session(s):\n", capped.len());
        for meta in &capped {
            // Extract channel from key (convention: channel__identifier)
            let channel = meta.key.split("__").next().unwrap_or(&meta.key);
            let _ = writeln!(
                output,
                "- {}: channel={}, messages={}, last_activity={}",
                meta.key, channel, meta.message_count, meta.last_activity
            );
        }

        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}

// ── SessionsHistoryTool ─────────────────────────────────────────────

/// Reads the message history of a specific session by ID.
pub struct SessionsHistoryTool {
    backend: Arc<dyn SessionBackend>,
    security: Arc<SecurityPolicy>,
}

impl SessionsHistoryTool {
    pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
        Self { backend, security }
    }
}

#[async_trait]
impl Tool for SessionsHistoryTool {
    fn name(&self) -> &str {
        "sessions_history"
    }

    fn description(&self) -> &str {
        "Read the message history of a specific session by its session ID. Returns the last N messages."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "session_id": {
                    "type": "string",
                    "description": "The session ID to read history from (e.g. telegram__user123)"
                },
                "limit": {
                    "type": "integer",
                    "description": "Max messages to return, from most recent (default: 20)"
                }
            },
            "required": ["session_id"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Read, "sessions_history")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        let session_id = args
            .get("session_id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;

        if let Err(result) = validate_session_id(session_id) {
            return Ok(result);
        }

        #[allow(clippy::cast_possible_truncation)]
        let limit = args
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .map_or(20, |v| v as usize);

        let messages = self.backend.load(session_id);

        if messages.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: format!("No messages found for session '{session_id}'."),
                error: None,
            });
        }

        // Take the last `limit` messages
        let start = messages.len().saturating_sub(limit);
        let tail = &messages[start..];

        let mut output = format!(
            "Session '{}': showing {}/{} messages\n",
            session_id,
            tail.len(),
            messages.len()
        );
        for msg in tail {
            let _ = writeln!(output, "[{}] {}", msg.role, msg.content);
        }

        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}

// ── SessionsSendTool ────────────────────────────────────────────────

/// Sends a message to a specific session, enabling inter-agent communication.
pub struct SessionsSendTool {
    backend: Arc<dyn SessionBackend>,
    security: Arc<SecurityPolicy>,
}

impl SessionsSendTool {
    pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
        Self { backend, security }
    }
}

#[async_trait]
impl Tool for SessionsSendTool {
    fn name(&self) -> &str {
        "sessions_send"
    }

    fn description(&self) -> &str {
        "Send a message to a specific session by its session ID. The message is appended to the session's conversation history as a 'user' message, enabling inter-agent communication."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "session_id": {
                    "type": "string",
                    "description": "The target session ID (e.g. telegram__user123)"
                },
                "message": {
                    "type": "string",
                    "description": "The message content to send"
                }
            },
            "required": ["session_id", "message"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        if let Err(error) = self
            .security
            .enforce_tool_operation(ToolOperation::Act, "sessions_send")
        {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(error),
            });
        }

        let session_id = args
            .get("session_id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;

        if let Err(result) = validate_session_id(session_id) {
            return Ok(result);
        }

        let message = args
            .get("message")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?;

        if message.trim().is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Message content must not be empty.".into()),
            });
        }

        let chat_msg = crate::providers::traits::ChatMessage::user(message);

        match self.backend.append(session_id, &chat_msg) {
            Ok(()) => Ok(ToolResult {
                success: true,
                output: format!("Message sent to session '{session_id}'."),
                error: None,
            }),
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Failed to send message: {e}")),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::session_store::SessionStore;
    use crate::providers::traits::ChatMessage;
    use tempfile::TempDir;

    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::default())
    }

    fn test_backend() -> (TempDir, Arc<dyn SessionBackend>) {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();
        (tmp, Arc::new(store))
    }

    fn seeded_backend() -> (TempDir, Arc<dyn SessionBackend>) {
        let tmp = TempDir::new().unwrap();
        let store = SessionStore::new(tmp.path()).unwrap();
        store
            .append("telegram__alice", &ChatMessage::user("Hello from Alice"))
            .unwrap();
        store
            .append(
                "telegram__alice",
                &ChatMessage::assistant("Hi Alice, how can I help?"),
            )
            .unwrap();
        store
            .append("discord__bob", &ChatMessage::user("Hey from Bob"))
            .unwrap();
        (tmp, Arc::new(store))
    }

    // ── SessionsListTool tests ──────────────────────────────────────

    #[tokio::test]
    async fn list_empty_sessions() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsListTool::new(backend);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No active sessions"));
    }

    #[tokio::test]
    async fn list_sessions_shows_all() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsListTool::new(backend);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("2 session(s)"));
        assert!(result.output.contains("telegram__alice"));
        assert!(result.output.contains("discord__bob"));
    }

    #[tokio::test]
    async fn list_sessions_respects_limit() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsListTool::new(backend);
        let result = tool.execute(json!({"limit": 1})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("1 session(s)"));
    }

    #[tokio::test]
    async fn list_sessions_extracts_channel() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsListTool::new(backend);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.output.contains("channel=telegram"));
        assert!(result.output.contains("channel=discord"));
    }

    #[test]
    fn list_tool_name_and_schema() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsListTool::new(backend);
        assert_eq!(tool.name(), "sessions_list");
        assert!(tool.parameters_schema()["properties"]["limit"].is_object());
    }

    // ── SessionsHistoryTool tests ───────────────────────────────────

    #[tokio::test]
    async fn history_empty_session() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        let result = tool
            .execute(json!({"session_id": "nonexistent"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("No messages found"));
    }

    #[tokio::test]
    async fn history_returns_messages() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        let result = tool
            .execute(json!({"session_id": "telegram__alice"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("showing 2/2 messages"));
        assert!(result.output.contains("[user] Hello from Alice"));
        assert!(result.output.contains("[assistant] Hi Alice"));
    }

    #[tokio::test]
    async fn history_respects_limit() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        let result = tool
            .execute(json!({"session_id": "telegram__alice", "limit": 1}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("showing 1/2 messages"));
        // Should show only the last message
        assert!(result.output.contains("[assistant]"));
        assert!(!result.output.contains("[user] Hello from Alice"));
    }

    #[tokio::test]
    async fn history_missing_session_id() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("session_id"));
    }

    #[tokio::test]
    async fn history_rejects_empty_session_id() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        let result = tool.execute(json!({"session_id": " "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Invalid"));
    }

    #[test]
    fn history_tool_name_and_schema() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsHistoryTool::new(backend, test_security());
        assert_eq!(tool.name(), "sessions_history");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["session_id"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("session_id")));
    }

    // ── SessionsSendTool tests ──────────────────────────────────────

    #[tokio::test]
    async fn send_appends_message() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend.clone(), test_security());
        let result = tool
            .execute(json!({
                "session_id": "telegram__alice",
                "message": "Hello from another agent"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Message sent"));

        // Verify message was appended
        let messages = backend.load("telegram__alice");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "Hello from another agent");
    }

    #[tokio::test]
    async fn send_to_existing_session() {
        let (_tmp, backend) = seeded_backend();
        let tool = SessionsSendTool::new(backend.clone(), test_security());
        let result = tool
            .execute(json!({
                "session_id": "telegram__alice",
                "message": "Inter-agent message"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let messages = backend.load("telegram__alice");
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[2].content, "Inter-agent message");
    }

    #[tokio::test]
    async fn send_rejects_empty_message() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        let result = tool
            .execute(json!({
                "session_id": "telegram__alice",
                "message": " "
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("empty"));
    }

    #[tokio::test]
    async fn send_rejects_empty_session_id() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        let result = tool
            .execute(json!({
                "session_id": "",
                "message": "hello"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Invalid"));
    }

    #[tokio::test]
    async fn send_rejects_non_alphanumeric_session_id() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        let result = tool
            .execute(json!({
                "session_id": "///",
                "message": "hello"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Invalid"));
    }

    #[tokio::test]
    async fn send_missing_session_id() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        let result = tool.execute(json!({"message": "hi"})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("session_id"));
    }

    #[tokio::test]
    async fn send_missing_message() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        let result = tool.execute(json!({"session_id": "telegram__alice"})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("message"));
    }

    #[test]
    fn send_tool_name_and_schema() {
        let (_tmp, backend) = test_backend();
        let tool = SessionsSendTool::new(backend, test_security());
        assert_eq!(tool.name(), "sessions_send");
        let schema = tool.parameters_schema();
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("session_id")));
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("message")));
    }
}

@@ -34,6 +34,7 @@ image_info = "Read image file metadata (format, dimensions, size) and optionally
jira = "Interact with Jira: get tickets with configurable detail level, search issues with JQL, and add comments with mention and formatting support."
knowledge = "Manage a knowledge graph of architecture decisions, solution patterns, lessons learned, and experts. Actions: capture, search, relate, suggest, expert_find, lessons_extract, graph_stats."
linkedin = "Manage LinkedIn: create posts, list your posts, comment, react, delete posts, view engagement, get profile info, and read the configured content strategy. Requires LINKEDIN_* credentials in .env file."
discord_search = "Search Discord message history stored in discord.db. Use to find past messages, summarize channel activity, or look up what users said. Supports keyword search and optional filters: channel_id, since, until."
memory_forget = "Remove a memory by key. Use to delete outdated facts or sensitive data. Returns whether the memory was found and removed."
memory_recall = "Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
memory_store = "Store a fact, preference, or note in long-term memory. Use category 'core' for permanent facts, 'daily' for session notes, 'conversation' for chat context, or a custom category name."

@@ -0,0 +1,62 @@
# คำอธิบายเครื่องมือภาษาไทย (Thai tool descriptions)
#
# แต่ละคีย์ภายใต้ [tools] จะตรงกับค่าที่ส่งกลับจาก name() ของเครื่องมือ
# ค่าคือคำอธิบายที่มนุษย์อ่านได้ซึ่งจะแสดงใน system prompts
[tools]
backup = "สร้าง, ลิสต์, ตรวจสอบ และกู้คืนข้อมูลสำรองของเวิร์กสเปซ"
browser = "การทำงานอัตโนมัติบนเว็บ/เบราว์เซอร์ด้วยแบ็คเอนด์ที่ถอดเปลี่ยนได้ (agent-browser, rust-native, computer_use) รองรับการดำเนินการ DOM และการดำเนินการระดับ OS (ย้ายเมาส์, คลิก, ลาก, พิมพ์คีย์, กดคีย์, จับภาพหน้าจอ) ผ่าน computer-use sidecar ใช้ 'snapshot' เพื่อจับคู่พารามิเตอร์โต้ตอบกับ refs (@e1, @e2) บังคับใช้ browser.allowed_domains สำหรับการเปิดหน้าเว็บ"
browser_delegate = "มอบหมายงานที่ใช้เบราว์เซอร์ให้กับ CLI ที่มีความสามารถด้านเบราว์เซอร์เพื่อโต้ตอบกับเว็บแอปพลิเคชัน เช่น Teams, Outlook, Jira, Confluence"
browser_open = "เปิด URL HTTPS ที่ได้รับอนุญาตในเบราว์เซอร์ของระบบ ข้อจำกัดด้านความปลอดภัย: เฉพาะโดเมนใน allowlist เท่านั้น, ห้ามโฮสต์โลคัล/ส่วนตัว, ห้ามดึงข้อมูล (scraping)"
cloud_ops = "เครื่องมือให้คำปรึกษาด้านการเปลี่ยนแปลงคลาวด์ วิเคราะห์แผน IaC, ประเมินเส้นทางการย้ายระบบ, ตรวจสอบค่าใช้จ่าย และตรวจสอบสถาปัตยกรรมตามหลัก Well-Architected Framework อ่านอย่างเดียว: ไม่สร้างหรือแก้ไขทรัพยากรคลาวด์"
cloud_patterns = "ไลบรารีรูปแบบคลาวด์ แนะนำรูปแบบสถาปัตยกรรม cloud-native ที่เหมาะสม (containerization, serverless, การปรับปรุงฐานข้อมูลให้ทันสมัย ฯลฯ) ตามคำอธิบายภาระงาน"
composio = "รันคำสั่งบนแอปมากกว่า 1,000 แอปผ่าน Composio (Gmail, Notion, GitHub, Slack ฯลฯ) ใช้ action='list' เพื่อดูคำสั่งที่ใช้งานได้ ใช้ action='execute' พร้อม action_name/tool_slug และพารามิเตอร์เพื่อรันคำสั่ง หากไม่แน่ใจพารามิเตอร์ ให้ส่ง 'text' พร้อมคำอธิบายภาษาธรรมชาติแทน ใช้ action='list_accounts' เพื่อดูบัญชีที่เชื่อมต่อ และ action='connect' เพื่อรับ URL OAuth"
content_search = "ค้นหาเนื้อหาไฟล์ด้วยรูปแบบ regex ภายในเวิร์กสเปซ รองรับ ripgrep (rg) พร้อมระบบสำรองเป็น grep โหมดเอาต์พุต: 'content' (บรรทัดที่ตรงกันพร้อมบริบท), 'files_with_matches' (เฉพาะเส้นทางไฟล์), 'count' (จำนวนที่พบต่อไฟล์) ตัวอย่าง: pattern='fn main', include='*.rs', output_mode='content'"
cron_add = """สร้างงานตั้งเวลา cron (shell หรือ agent) รองรับตารางเวลาแบบ cron/at/every ใช้ job_type='agent' พร้อม prompt เพื่อรัน AI agent ตามกำหนดเวลา หากต้องการส่งเอาต์พุตไปยังแชนเนล (Discord, Telegram, Slack, Mattermost, Matrix) ให้ตั้งค่า delivery={"mode":"announce","channel":"discord","to":"<channel_id_or_chat_id>"} นี่เป็นเครื่องมือที่แนะนำสำหรับการส่งข้อความตั้งเวลาหรือหน่วงเวลาไปยังผู้ใช้ผ่านแชนเนล"""
cron_list = "รายการงานตั้งเวลา cron ทั้งหมด"
cron_remove = "ลบงานตั้งเวลาด้วย id"
cron_run = "บังคับรันงานตั้งเวลาทันทีและบันทึกประวัติการรัน"
cron_runs = "รายการประวัติการรันล่าสุดของงานตั้งเวลา"
cron_update = "แก้ไขงานตั้งเวลาที่มีอยู่ (ตารางเวลา, คำสั่ง, prompt, การเปิดใช้งาน, การส่งข้อมูล, โมเดล ฯลฯ)"
data_management = "การเก็บรักษาข้อมูลเวิร์กสเปซ, การล้างข้อมูล และสถิติการจัดเก็บ"
delegate = "มอบหมายงานย่อยให้กับเอเจนต์เฉพาะทาง ใช้เมื่อ: งานจะได้รับประโยชน์จากโมเดลที่ต่างออกไป (เช่น สรุปผลเร็ว, การให้เหตุผลเชิงลึก, การสร้างโค้ด) เอเจนต์ย่อยจะรันหนึ่ง prompt ตามค่าเริ่มต้น หากตั้ง agentic=true จะสามารถทำงานวนซ้ำด้วยเครื่องมือที่จำกัดได้"
discord_search = "ค้นหาประวัติข้อความ Discord ที่เก็บไว้ใน discord.db ใช้เพื่อค้นหาข้อความในอดีต, สรุปกิจกรรมในแชนเนล หรือดูว่าผู้ใช้พูดอะไร รองรับการค้นหาด้วยคีย์เวิร์ดและตัวกรองเสริม: channel_id, since, until"
file_edit = "แก้ไขไฟล์โดยการแทนที่ข้อความที่ตรงกันเป๊ะๆ ด้วยเนื้อหาใหม่"
file_read = "อ่านเนื้อหาไฟล์พร้อมเลขบรรทัด รองรับการอ่านบางส่วนผ่าน offset และ limit ดึงข้อความจาก PDF; ไฟล์ไบนารีอื่นจะถูกอ่านด้วยการแปลง UTF-8 แบบสูญเสียข้อมูล"
file_write = "เขียนเนื้อหาลงในไฟล์ในเวิร์กสเปซ"
git_operations = "รันคำสั่ง Git แบบโครงสร้าง (status, diff, log, branch, commit, add, checkout, stash) ให้เอาต์พุต JSON ที่แยกส่วนแล้ว และรวมเข้ากับนโยบายความปลอดภัยสำหรับการควบคุมตนเอง"
glob_search = "ค้นหาไฟล์ที่ตรงกับรูปแบบ glob ภายในเวิร์กสเปซ ส่งกลับรายการเส้นทางไฟล์ที่ตรงกันเทียบกับรูทของเวิร์กสเปซ ตัวอย่าง: '**/*.rs' (ไฟล์ Rust ทั้งหมด), 'src/**/mod.rs' (mod.rs ทั้งหมดใน src)"
google_workspace = "โต้ตอบกับบริการ Google Workspace (Drive, Gmail, Calendar, Sheets, Docs ฯลฯ) ผ่าน gws CLI ต้องติดตั้งและยืนยันตัวตน gws ก่อน"
hardware_board_info = "ส่งกลับข้อมูลบอร์ดฉบับเต็ม (ชิป, สถาปัตยกรรม, แผนผังหน่วยความจำ) สำหรับฮาร์ดแวร์ที่เชื่อมต่อ ใช้เมื่อ: ผู้ใช้ถามเกี่ยวกับ 'board info', 'ใช้บอร์ดอะไร', 'ฮาร์ดแวร์ที่ต่ออยู่', 'ข้อมูลชิป' หรือ 'memory map'"
hardware_memory_map = "ส่งกลับแผนผังหน่วยความจำ (ช่วงที่อยู่ flash และ RAM) สำหรับฮาร์ดแวร์ที่เชื่อมต่อ ใช้เมื่อ: ผู้ใช้ถามเกี่ยวกับ 'upper and lower memory addresses', 'แผนผังหน่วยความจำ' หรือ 'ที่อยู่ที่อ่านได้'"
hardware_memory_read = "อ่านค่าหน่วยความจำ/รีจิสเตอร์จริงจาก Nucleo ผ่าน USB ใช้เมื่อ: ผู้ใช้ถามให้ 'อ่านค่ารีจิสเตอร์', 'อ่านหน่วยความจำที่แอดเดรส', 'dump memory' ส่งกลับเป็น hex dump ต้องเชื่อมต่อ Nucleo ผ่าน USB พารามิเตอร์: address (hex), length (bytes)"
http_request = "ส่งคำขอ HTTP ไปยัง API ภายนอก รองรับเมธอด GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS ข้อจำกัดด้านความปลอดภัย: เฉพาะโดเมนใน allowlist เท่านั้น, ห้ามโฮสต์โลคัล/ส่วนตัว, ตั้งค่า timeout และจำกัดขนาดการตอบกลับได้"
image_info = "อ่านข้อมูลเมตาของไฟล์รูปภาพ (รูปแบบ, ขนาดกว้างยาว, ขนาดไฟล์) และสามารถเลือกส่งกลับข้อมูลที่เข้ารหัส base64 ได้"
jira = "โต้ตอบกับ Jira: ดึงตั๋วตามระดับรายละเอียดที่กำหนด, ค้นหา issue ด้วย JQL และเพิ่มคอมเมนต์พร้อมรองรับการกล่าวถึง (mention) และการจัดรูปแบบ"
knowledge = "จัดการกราฟความรู้ของการตัดสินใจด้านสถาปัตยกรรม, รูปแบบโซลูชัน, บทเรียนที่ได้รับ และผู้เชี่ยวชาญ การดำเนินการ: capture, search, relate, suggest, expert_find, lessons_extract, graph_stats"
linkedin = "จัดการ LinkedIn: สร้างโพสต์, รายการโพสต์ของคุณ, คอมเมนต์, แสดงความรู้สึก, ลบโพสต์, ดูการมีส่วนร่วม, ดูข้อมูลโปรไฟล์ และอ่านกลยุทธ์เนื้อหาที่กำหนดไว้ ต้องมีข้อมูลยืนยันตัวตน LINKEDIN_* ในไฟล์ .env"
memory_forget = "ลบความจำด้วยคีย์ ใช้เพื่อลบข้อมูลที่ล้าสมัยหรือข้อมูลที่ละเอียดอ่อน ส่งกลับว่าพบและลบความจำหรือไม่"
memory_recall = "ค้นหาความจำระยะยาวสำหรับข้อเท็จจริง ความชอบ หรือบริบทที่เกี่ยวข้อง ส่งกลับผลลัพธ์ที่จัดอันดับตามความเกี่ยวข้อง"
memory_store = "เก็บข้อเท็จจริง ความชอบ หรือบันทึกลงในความจำระยะยาว ใช้หมวดหมู่ 'core' สำหรับข้อมูลถาวร, 'daily' สำหรับบันทึกเซสชัน, 'conversation' สำหรับบริบทการแชท หรือชื่อหมวดหมู่ที่กำหนดเอง"
microsoft365 = "การรวมเข้ากับ Microsoft 365: จัดการอีเมล Outlook, ข้อความ Teams, กิจกรรมปฏิทิน, ไฟล์ OneDrive และการค้นหา SharePoint ผ่าน Microsoft Graph API"
model_routing_config = "จัดการการตั้งค่าโมเดลเริ่มต้น, เส้นทางผู้ให้บริการ/โมเดลตามสถานการณ์, กฎการจำแนกประเภท และโปรไฟล์เอเจนต์ย่อย"
notion = "โต้ตอบกับ Notion: สอบถามฐานข้อมูล, อ่าน/สร้าง/อัปเดตหน้า และค้นหาในเวิร์กสเปซ"
pdf_read = "ดึงข้อความธรรมดาจากไฟล์ PDF ในเวิร์กสเปซ ส่งกลับข้อความที่อ่านได้ทั้งหมด ไฟล์ PDF ที่มีแต่รูปภาพหรือเข้ารหัสจะส่งกลับผลลัพธ์ที่ว่างเปล่า ต้องเปิดฟีเจอร์ 'rag-pdf' ตอน build"
project_intel = "ข้อมูลอัจฉริยะในการส่งมอบโปรเจกต์: สร้างรายงานสถานะ, ตรวจจับความเสี่ยง, ร่างการอัปเดตสำหรับลูกค้า, สรุป sprint และประเมินพยายาม เป็นเครื่องมือวิเคราะห์แบบอ่านอย่างเดียว"
proxy_config = "จัดการการตั้งค่าพร็อกซีของ ZeroClaw (ขอบเขต: environment | zeroclaw | services) รวมถึงการปรับใช้ในขณะรันและใน process environment"
pushover = "ส่งการแจ้งเตือน Pushover ไปยังอุปกรณ์ของคุณ ต้องมี PUSHOVER_TOKEN และ PUSHOVER_USER_KEY ในไฟล์ .env"
schedule = """จัดการงาน shell ที่ตั้งเวลาไว้ การดำเนินการ: create/add/once/list/get/cancel/remove/pause/resume คำเตือน: เครื่องมือนี้สร้างงาน shell ที่เอาต์พุตจะถูกบันทึกใน log เท่านั้น ไม่ส่งไปยังแชนเนลใดๆ หากต้องการส่งข้อความตั้งเวลาไปยัง Discord/Telegram/Slack/Matrix ให้ใช้เครื่องมือ cron_add"""
screenshot = "จับภาพหน้าจอปัจจุบัน ส่งกลับเส้นทางไฟล์และข้อมูล PNG ที่เข้ารหัส base64"
security_ops = "เครื่องมือปฏิบัติการด้านความปลอดภัยสำหรับบริการจัดการความปลอดภัยไซเบอร์ การดำเนินการ: triage_alert, run_playbook, parse_vulnerability, generate_report, list_playbooks, alert_stats"
shell = "รันคำสั่ง shell ในไดเรกทอรีรูทของเวิร์กสเปซ"
sop_advance = "รายงานผลลัพธ์ของขั้นตอน SOP ปัจจุบันและไปยังขั้นตอนถัดไป ระบุ run_id, ขั้นตอนสำเร็จหรือล้มเหลว และสรุปเอาต์พุตสั้นๆ"
sop_approve = "อนุมัติขั้นตอน SOP ที่รอการอนุมัติจากผู้ปฏิบัติงาน ส่งกลับคำสั่งในขั้นตอนที่จะดำเนินการ ใช้ sop_status เพื่อดูว่ามีรายการใดรอยู่"
sop_execute = "สั่งรันขั้นตอนการปฏิบัติงานมาตรฐาน (SOP) ด้วยชื่อด้วยตนเอง ส่งกลับ run ID และคำสั่งขั้นตอนแรก ใช้ sop_list เพื่อดู SOP ที่มี"
sop_list = "รายการขั้นตอนการปฏิบัติงานมาตรฐาน (SOP) ทั้งหมดที่โหลดไว้ พร้อมเงื่อนไขการรัน, ลำดับความสำคัญ, จำนวนขั้นตอน และจำนวนการรันที่ใช้งานอยู่"
sop_status = "สอบถามสถานะการรัน SOP ระบุ run_id สำหรับการรันเฉพาะ หรือ sop_name สำหรับรายการการรันของ SOP นั้น หากไม่มีพารามิเตอร์จะแสดงการรันที่ใช้งานอยู่ทั้งหมด"
swarm = "ประสานงานกลุ่มเอเจนต์เพื่อทำงานร่วมกัน รองรับกลยุทธ์แบบลำดับ (pipeline), แบบขนาน (fan-out/fan-in) และแบบเราเตอร์ (เลือกโดย LLM)"
tool_search = """ดึงข้อมูลโครงสร้าง schema ฉบับเต็มสำหรับเครื่องมือ MCP ที่โหลดแบบหน่วงเวลา (deferred) เพื่อให้สามารถเรียกใช้งานได้ ใช้ "select:name1,name2" สำหรับการจับคู่ที่แน่นอนหรือใช้คีย์เวิร์ดเพื่อค้นหา"""
weather = "ดึงข้อมูลสภาพอากาศปัจจุบันและพยากรณ์อากาศสำหรับสถานที่ใดก็ได้ทั่วโลก รองรับชื่อเมือง (ในภาษาหรือตัวอักษรใดก็ได้), รหัสสนามบิน IATA, พิกัด GPS, รหัสไปรษณีย์ และการระบุตำแหน่งตามโดเมน ส่งกลับอุณหภูมิ, ความรู้สึกจริง, ความชื้น, ความเร็ว/ทิศทางลม, ปริมาณน้ำฝน, ทัศนวิสัย, ความกดอากาศ, ดัชนี UV และเมฆปกคลุม เลือกพยากรณ์อากาศได้ 0–3 วัน หน่วยเริ่มต้นเป็นเมตริก (°C, km/h, mm) แต่สามารถตั้งเป็นอิมพีเรียลได้ ไม่ต้องใช้คีย์ API"
web_fetch = "ดึงข้อมูลหน้าเว็บและส่งกลับเนื้อหาเป็นข้อความธรรมดาที่สะอาด หน้า HTML จะถูกแปลงเป็นข้อความที่อ่านได้โดยอัตโนมัติ คำตอบที่เป็น JSON และข้อความธรรมดาจะถูกส่งกลับตามเดิม เฉพาะคำขอ GET เท่านั้น ปฏิบัติตามการเปลี่ยนเส้นทาง ความปลอดภัย: เฉพาะโดเมนใน allowlist เท่านั้น ห้ามโฮสต์โลคัล/ส่วนตัว"
web_search_tool = "ค้นหาข้อมูลบนเว็บ ส่งกลับผลลัพธ์การค้นหาที่เกี่ยวข้องพร้อมชื่อเรื่อง, URL และคำอธิบาย ใช้เพื่อค้นหาข้อมูลปัจจุบัน ข่าวสาร หรือหัวข้อการวิจัย"
workspace = "จัดการเวิร์กสเปซแบบหลายไคลเอนต์ คำสั่งย่อย: list, switch, create, info, export แต่ละเวิร์กสเปซจะมีการแยกหน่วยความจำ, การตรวจสอบ, ความลับ และข้อจำกัดเครื่องมือออกจากกัน"
@@ -13,6 +13,7 @@ import Cost from './pages/Cost';
import Logs from './pages/Logs';
import Doctor from './pages/Doctor';
import Pairing from './pages/Pairing';
import Canvas from './pages/Canvas';
import { AuthProvider, useAuth } from './hooks/useAuth';
import { DraftContext, useDraftStore } from './hooks/useDraft';
import { setLocale, type Locale } from './lib/i18n';
@@ -234,6 +235,7 @@ function AppContent() {
<Route path="/logs" element={<Logs />} />
<Route path="/doctor" element={<Doctor />} />
<Route path="/pairing" element={<Pairing />} />
<Route path="/canvas" element={<Canvas />} />
<Route path="*" element={<Navigate to="/" replace />} />
</Route>
</Routes>

@@ -11,6 +11,7 @@ import {
DollarSign,
Activity,
Stethoscope,
Monitor,
} from 'lucide-react';
import { t } from '@/lib/i18n';

@@ -25,6 +26,7 @@ const navItems = [
{ to: '/cost', icon: DollarSign, labelKey: 'nav.cost' },
{ to: '/logs', icon: Activity, labelKey: 'nav.logs' },
{ to: '/doctor', icon: Stethoscope, labelKey: 'nav.doctor' },
{ to: '/canvas', icon: Monitor, labelKey: 'nav.canvas' },
];

export default function Sidebar() {

@@ -20,6 +20,7 @@ const translations: Record<Locale, Record<string, string>> = {
'nav.cost': '成本追踪',
'nav.logs': '日志',
'nav.doctor': '诊断',
'nav.canvas': '画布',

// Dashboard
'dashboard.title': '仪表盘',
@@ -350,6 +351,7 @@ const translations: Record<Locale, Record<string, string>> = {
'nav.cost': 'Cost Tracker',
'nav.logs': 'Logs',
'nav.doctor': 'Doctor',
'nav.canvas': 'Canvas',

// Dashboard
'dashboard.title': 'Dashboard',
@@ -680,6 +682,7 @@ const translations: Record<Locale, Record<string, string>> = {
'nav.cost': 'Maliyet Takibi',
'nav.logs': 'Kayıtlar',
'nav.doctor': 'Doktor',
'nav.canvas': 'Tuval',

// Dashboard
'dashboard.title': 'Kontrol Paneli',

@@ -0,0 +1,355 @@
import { useState, useEffect, useRef, useCallback } from 'react';
import { Monitor, Trash2, History, RefreshCw } from 'lucide-react';
import { apiFetch } from '@/lib/api';
import { basePath } from '@/lib/basePath';
import { getToken } from '@/lib/auth';

interface CanvasFrame {
frame_id: string;
content_type: string;
content: string;
timestamp: string;
}

interface WsCanvasMessage {
type: string;
canvas_id: string;
frame?: CanvasFrame;
}

export default function Canvas() {
const [canvasId, setCanvasId] = useState('default');
const [canvasIdInput, setCanvasIdInput] = useState('default');
const [currentFrame, setCurrentFrame] = useState<CanvasFrame | null>(null);
const [history, setHistory] = useState<CanvasFrame[]>([]);
const [connected, setConnected] = useState(false);
const [showHistory, setShowHistory] = useState(false);
const [canvasList, setCanvasList] = useState<string[]>([]);
const wsRef = useRef<WebSocket | null>(null);
const iframeRef = useRef<HTMLIFrameElement>(null);

// Build WebSocket URL for canvas
const getWsUrl = useCallback((id: string) => {
const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
const base = basePath || '';
return `${proto}//${location.host}${base}/ws/canvas/${encodeURIComponent(id)}`;
}, []);

// Connect to canvas WebSocket
const connectWs = useCallback((id: string) => {
if (wsRef.current) {
wsRef.current.close();
}

const token = getToken();
const protocols = token ? ['zeroclaw.v1', `bearer.${token}`] : ['zeroclaw.v1'];
const ws = new WebSocket(getWsUrl(id), protocols);

ws.onopen = () => setConnected(true);
ws.onclose = () => setConnected(false);
ws.onerror = () => setConnected(false);

ws.onmessage = (event) => {
try {
const msg: WsCanvasMessage = JSON.parse(event.data);
if (msg.type === 'frame' && msg.frame) {
if (msg.frame.content_type === 'clear') {
setCurrentFrame(null);
setHistory([]);
} else {
setCurrentFrame(msg.frame);
setHistory((prev) => [...prev.slice(-49), msg.frame!]);
}
}
} catch {
// ignore parse errors
}
};

wsRef.current = ws;
}, [getWsUrl]);

// Connect on mount and when canvasId changes
useEffect(() => {
connectWs(canvasId);
return () => {
wsRef.current?.close();
};
}, [canvasId, connectWs]);

// Fetch canvas list periodically
useEffect(() => {
const fetchList = async () => {
try {
const data = await apiFetch<{ canvases: string[] }>('/api/canvas');
setCanvasList(data.canvases || []);
} catch {
// ignore
}
};
fetchList();
const interval = setInterval(fetchList, 5000);
return () => clearInterval(interval);
}, []);

// Render content into the iframe
useEffect(() => {
if (!iframeRef.current || !currentFrame) return;
if (currentFrame.content_type === 'eval') return; // eval frames are special

const iframe = iframeRef.current;
const doc = iframe.contentDocument;
if (!doc) return;

let html = currentFrame.content;
if (currentFrame.content_type === 'svg') {
html = `<!DOCTYPE html><html><head><style>body{margin:0;display:flex;align-items:center;justify-content:center;min-height:100vh;background:#1a1a2e;}</style></head><body>${currentFrame.content}</body></html>`;
} else if (currentFrame.content_type === 'markdown') {
// Simple markdown-to-HTML: render as preformatted text with basic styling
const escaped = currentFrame.content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
html = `<!DOCTYPE html><html><head><style>body{margin:1rem;font-family:system-ui,sans-serif;color:#e0e0e0;background:#1a1a2e;line-height:1.6;}pre{white-space:pre-wrap;word-wrap:break-word;}</style></head><body><pre>${escaped}</pre></body></html>`;
} else if (currentFrame.content_type === 'text') {
const escaped = currentFrame.content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
html = `<!DOCTYPE html><html><head><style>body{margin:1rem;font-family:monospace;color:#e0e0e0;background:#1a1a2e;white-space:pre-wrap;}</style></head><body>${escaped}</body></html>`;
}

doc.open();
doc.write(html);
doc.close();
}, [currentFrame]);

const handleSwitchCanvas = () => {
if (canvasIdInput.trim()) {
setCanvasId(canvasIdInput.trim());
setCurrentFrame(null);
setHistory([]);
}
};

const handleClear = async () => {
try {
await apiFetch(`/api/canvas/${encodeURIComponent(canvasId)}`, {
method: 'DELETE',
});
setCurrentFrame(null);
setHistory([]);
} catch {
// ignore
}
};

const handleSelectHistoryFrame = (frame: CanvasFrame) => {
setCurrentFrame(frame);
};

return (
<div className="p-6 space-y-4 h-full flex flex-col">
{/* Header */}
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<Monitor className="h-6 w-6" style={{ color: 'var(--pc-accent)' }} />
<h1 className="text-xl font-semibold" style={{ color: 'var(--pc-text-primary)' }}>
Live Canvas
</h1>
<span
className="text-xs px-2 py-0.5 rounded-full font-medium"
style={{
background: connected ? 'rgba(34, 197, 94, 0.15)' : 'rgba(239, 68, 68, 0.15)',
color: connected ? '#22c55e' : '#ef4444',
}}
>
{connected ? 'Connected' : 'Disconnected'}
</span>
</div>
<div className="flex items-center gap-2">
<button
onClick={() => setShowHistory(!showHistory)}
className="p-2 rounded-lg transition-colors hover:opacity-80"
style={{ background: 'var(--pc-bg-elevated)', color: 'var(--pc-text-muted)' }}
title="Toggle history"
>
<History className="h-4 w-4" />
</button>
<button
onClick={handleClear}
className="p-2 rounded-lg transition-colors hover:opacity-80"
style={{ background: 'var(--pc-bg-elevated)', color: 'var(--pc-text-muted)' }}
title="Clear canvas"
>
<Trash2 className="h-4 w-4" />
</button>
<button
onClick={() => connectWs(canvasId)}
className="p-2 rounded-lg transition-colors hover:opacity-80"
style={{ background: 'var(--pc-bg-elevated)', color: 'var(--pc-text-muted)' }}
title="Reconnect"
>
<RefreshCw className="h-4 w-4" />
</button>
</div>
</div>

{/* Canvas selector */}
<div className="flex items-center gap-2">
<input
type="text"
value={canvasIdInput}
onChange={(e) => setCanvasIdInput(e.target.value)}
onKeyDown={(e) => e.key === 'Enter' && handleSwitchCanvas()}
placeholder="Canvas ID"
className="px-3 py-1.5 rounded-lg text-sm border"
style={{
background: 'var(--pc-bg-elevated)',
borderColor: 'var(--pc-border)',
color: 'var(--pc-text-primary)',
}}
/>
<button
onClick={handleSwitchCanvas}
className="px-3 py-1.5 rounded-lg text-sm font-medium"
style={{ background: 'var(--pc-accent)', color: '#fff' }}
>
Switch
</button>
{canvasList.length > 0 && (
<div className="flex items-center gap-1 ml-2">
<span className="text-xs" style={{ color: 'var(--pc-text-muted)' }}>Active:</span>
{canvasList.map((id) => (
<button
key={id}
onClick={() => {
setCanvasIdInput(id);
setCanvasId(id);
setCurrentFrame(null);
setHistory([]);
}}
className="px-2 py-0.5 rounded text-xs font-mono transition-colors"
style={{
background: id === canvasId ? 'var(--pc-accent-dim)' : 'var(--pc-bg-elevated)',
color: id === canvasId ? 'var(--pc-accent)' : 'var(--pc-text-muted)',
borderColor: 'var(--pc-border)',
}}
>
{id}
</button>
))}
</div>
)}
</div>

{/* Main content area */}
<div className="flex-1 flex gap-4 min-h-0">
{/* Canvas viewer */}
<div
className="flex-1 rounded-lg border overflow-hidden"
style={{ borderColor: 'var(--pc-border)', background: '#1a1a2e' }}
>
{currentFrame ? (
<iframe
ref={iframeRef}
sandbox="allow-scripts"
className="w-full h-full border-0"
title={`Canvas: ${canvasId}`}
style={{ background: '#1a1a2e' }}
/>
) : (
<div className="flex items-center justify-center h-full">
<div className="text-center">
<Monitor
className="h-12 w-12 mx-auto mb-3 opacity-30"
style={{ color: 'var(--pc-text-muted)' }}
/>
<p className="text-sm" style={{ color: 'var(--pc-text-muted)' }}>
Waiting for content on canvas <span className="font-mono">"{canvasId}"</span>
</p>
<p className="text-xs mt-1" style={{ color: 'var(--pc-text-muted)', opacity: 0.6 }}>
The agent can push content here using the canvas tool
</p>
</div>
</div>
)}
</div>

{/* History panel */}
{showHistory && (
<div
className="w-64 rounded-lg border overflow-y-auto"
style={{
borderColor: 'var(--pc-border)',
background: 'var(--pc-bg-elevated)',
}}
>
<div
className="px-3 py-2 border-b text-xs font-medium sticky top-0"
style={{
borderColor: 'var(--pc-border)',
background: 'var(--pc-bg-elevated)',
color: 'var(--pc-text-muted)',
}}
>
Frame History ({history.length})
</div>
{history.length === 0 ? (
<p className="p-3 text-xs" style={{ color: 'var(--pc-text-muted)' }}>
No frames yet
</p>
) : (
<div className="space-y-1 p-2">
{[...history].reverse().map((frame) => (
<button
key={frame.frame_id}
onClick={() => handleSelectHistoryFrame(frame)}
className="w-full text-left px-2 py-1.5 rounded text-xs transition-colors"
style={{
background:
currentFrame?.frame_id === frame.frame_id
? 'var(--pc-accent-dim)'
: 'transparent',
color: 'var(--pc-text-primary)',
}}
>
<div className="flex items-center justify-between">
<span className="font-mono truncate" style={{ color: 'var(--pc-accent)' }}>
{frame.content_type}
</span>
<span style={{ color: 'var(--pc-text-muted)' }}>
{new Date(frame.timestamp).toLocaleTimeString()}
</span>
</div>
<div
className="truncate mt-0.5"
style={{ color: 'var(--pc-text-muted)', fontSize: '0.65rem' }}
>
{frame.content.substring(0, 60)}
{frame.content.length > 60 ? '...' : ''}
</div>
</button>
))}
</div>
)}
</div>
)}
</div>

{/* Frame info bar */}
{currentFrame && (
<div
className="flex items-center justify-between px-3 py-1.5 rounded-lg text-xs"
style={{ background: 'var(--pc-bg-elevated)', color: 'var(--pc-text-muted)' }}
>
<span>
Type: <span className="font-mono">{currentFrame.content_type}</span> | Frame:{' '}
<span className="font-mono">{currentFrame.frame_id.substring(0, 8)}</span>
</span>
<span>{new Date(currentFrame.timestamp).toLocaleString()}</span>
</div>
)}
</div>
);
}
+126
-10
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import {
|
||||
Settings,
|
||||
Save,
|
||||
@@ -9,6 +9,91 @@ import {
|
||||
import { getConfig, putConfig } from '@/lib/api';
|
||||
import { t } from '@/lib/i18n';
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Lightweight zero-dependency TOML syntax highlighter.
|
||||
// Produces an HTML string. The <pre> overlay sits behind the <textarea> so
|
||||
// the textarea remains the editable surface; the pre just provides colour.
|
||||
// ---------------------------------------------------------------------------
|
||||
function highlightToml(raw: string): string {
|
||||
const lines = raw.split('\n');
|
||||
const result: string[] = [];
|
||||
|
||||
for (const line of lines) {
|
||||
const escaped = line
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>');
|
||||
|
||||
// Section header [section] or [[array]]
|
||||
if (/^\s*\[/.test(escaped)) {
|
||||
result.push(`<span style="color:#67e8f9;font-weight:600">${escaped}</span>`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Comment line
|
||||
if (/^\s*#/.test(escaped)) {
|
||||
result.push(`<span style="color:#52525b;font-style:italic">${escaped}</span>`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Key = value line
|
||||
const kvMatch = escaped.match(/^(\s*)([A-Za-z0-9_\-.]+)(\s*=\s*)(.*)$/);
|
||||
if (kvMatch) {
|
||||
const [, indent, key, eq, rawValue] = kvMatch;
|
||||
const value = colorValue(rawValue ?? '');
|
||||
result.push(
|
||||
`${indent}<span style="color:#a78bfa">${key}</span>`
|
||||
+ `<span style="color:#71717a">${eq}</span>${value}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
result.push(escaped);
|
||||
}
|
||||
|
||||
return result.join('\n') + '\n';
|
||||
}
|
||||
|
||||
function colorValue(v: string): string {
|
||||
const trimmed = v.trim();
|
||||
const commentIdx = findUnquotedHash(trimmed);
|
||||
if (commentIdx !== -1) {
|
||||
const valueCore = trimmed.slice(0, commentIdx).trimEnd();
|
||||
const comment = `<span style="color:#52525b;font-style:italic">${trimmed.slice(commentIdx)}</span>`;
|
||||
const leading = v.slice(0, v.indexOf(trimmed));
|
||||
return leading + colorScalar(valueCore) + ' ' + comment;
|
||||
}
|
||||
return colorScalar(v);
|
||||
}
|
||||
|
||||
function findUnquotedHash(s: string): number {
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
for (let i = 0; i < s.length; i++) {
|
||||
const c = s[i];
|
||||
if (c === "'" && !inDouble) inSingle = !inSingle;
|
||||
else if (c === '"' && !inSingle) inDouble = !inDouble;
|
||||
else if (c === '#' && !inSingle && !inDouble) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
function colorScalar(v: string): string {
|
||||
const t = v.trim();
|
||||
if (t === 'true' || t === 'false')
|
||||
return `<span style="color:#34d399">${v}</span>`;
|
||||
if (/^-?\d[\d_]*(\.[\d_]*)?([eE][+-]?\d+)?$/.test(t))
|
||||
return `<span style="color:#fbbf24">${v}</span>`;
|
||||
if (t.startsWith('"') || t.startsWith("'"))
|
||||
return `<span style="color:#86efac">${v}</span>`;
|
||||
if (t.startsWith('[') || t.startsWith('{'))
|
||||
return `<span style="color:#e2e8f0">${v}</span>`;
|
||||
if (/^\d{4}-\d{2}-\d{2}/.test(t))
|
||||
return `<span style="color:#fb923c">${v}</span>`;
|
||||
return v;
|
||||
}
|
||||
|
||||
 export default function Config() {
   const [config, setConfig] = useState('');
   const [loading, setLoading] = useState(true);
@@ -16,6 +101,16 @@ export default function Config() {
   const [error, setError] = useState<string | null>(null);
   const [success, setSuccess] = useState<string | null>(null);
 
+  const textareaRef = useRef<HTMLTextAreaElement>(null);
+  const preRef = useRef<HTMLPreElement>(null);
+
+  const syncScroll = useCallback(() => {
+    if (preRef.current && textareaRef.current) {
+      preRef.current.scrollTop = textareaRef.current.scrollTop;
+      preRef.current.scrollLeft = textareaRef.current.scrollLeft;
+    }
+  }, []);
+
   useEffect(() => {
     getConfig()
       .then((data) => { setConfig(typeof data === 'string' ? data : JSON.stringify(data, null, 2)); })
@@ -53,7 +148,7 @@
   }
 
   return (
-    <div className="p-6 space-y-6 animate-fade-in">
+    <div className="flex flex-col h-full p-6 gap-6 animate-fade-in overflow-hidden">
       {/* Header */}
       <div className="flex items-center justify-between">
         <div className="flex items-center gap-2">
@@ -95,7 +190,7 @@
       )}
 
       {/* Config Editor */}
-      <div className="card overflow-hidden rounded-2xl">
+      <div className="card overflow-hidden rounded-2xl flex flex-col flex-1 min-h-0">
         <div className="flex items-center justify-between px-4 py-2.5 border-b" style={{ borderColor: 'var(--pc-border)', background: 'var(--pc-accent-glow)' }}>
           <span className="text-[10px] font-semibold uppercase tracking-wider" style={{ color: 'var(--pc-text-muted)' }}>
             {t('config.toml_label')}
@@ -104,13 +199,34 @@
             {config.split('\n').length} {t('config.lines')}
           </span>
         </div>
-        <textarea
-          value={config}
-          onChange={(e) => setConfig(e.target.value)}
-          spellCheck={false}
-          className="w-full min-h-[500px] text-sm p-4 resize-y focus:outline-none font-mono"
-          style={{ background: 'var(--pc-bg-base)', color: 'var(--pc-text-secondary)', tabSize: 4 }}
-        />
+        <div className="relative flex-1 min-h-0 overflow-hidden">
+          <pre
+            ref={preRef}
+            aria-hidden="true"
+            className="absolute inset-0 text-sm p-4 font-mono overflow-auto whitespace-pre pointer-events-none m-0"
+            style={{ background: 'var(--pc-bg-base)', tabSize: 4 }}
+            dangerouslySetInnerHTML={{ __html: highlightToml(config) }}
+          />
+          <textarea
+            ref={textareaRef}
+            value={config}
+            onChange={(e) => setConfig(e.target.value)}
+            onScroll={syncScroll}
+            onKeyDown={(e) => {
+              if (e.key === 'Tab') {
+                e.preventDefault();
+                const el = e.currentTarget;
+                const start = el.selectionStart;
+                const end = el.selectionEnd;
+                setConfig(config.slice(0, start) + '  ' + config.slice(end));
+                requestAnimationFrame(() => { el.selectionStart = el.selectionEnd = start + 2; });
+              }
+            }}
+            spellCheck={false}
+            className="absolute inset-0 w-full h-full text-sm p-4 resize-none focus:outline-none font-mono caret-white"
+            style={{ background: 'transparent', color: 'transparent', tabSize: 4 }}
+          />
+        </div>
       </div>
     </div>
   );
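The Tab-key handler in the Config editor hunk above splices a two-space indent into the value and repositions the caret by hand. The cursor arithmetic can be sketched as a pure helper (`insertIndent` is a hypothetical name, not part of the component):

```typescript
// Pure sketch of the editor's Tab-key behavior: replace the current
// selection [start, end) with a two-space indent and report where the
// caret should land afterwards.
function insertIndent(
  text: string,
  start: number,
  end: number,
): { text: string; cursor: number } {
  const INDENT = '  '; // two spaces, matching the `start + 2` offset in the diff
  return {
    text: text.slice(0, start) + INDENT + text.slice(end),
    cursor: start + INDENT.length,
  };
}
```

In the component this result feeds `setConfig(...)`, and the caret is restored inside `requestAnimationFrame` because React re-renders the textarea before the browser can apply the new selection.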
@@ -293,7 +293,7 @@ export default function Cron() {
   }
 
   return (
-    <div className="p-6 space-y-6 animate-fade-in">
+    <div className="flex flex-col h-full p-6 gap-6 animate-fade-in overflow-hidden">
       {/* Header */}
       <div className="flex items-center justify-between">
         <div className="flex items-center gap-2">
@@ -415,7 +415,7 @@ export default function Cron() {
             <p style={{ color: 'var(--pc-text-muted)' }}>{t('cron.empty')}</p>
           </div>
         ) : (
-          <div className="card overflow-x-auto rounded-2xl">
+          <div className="card overflow-auto rounded-2xl flex-1 min-h-0">
             <table className="table-electric">
               <thead>
                 <tr>
@@ -236,7 +236,7 @@ export default function Dashboard() {
            className="text-sm font-semibold uppercase tracking-wider"
            style={{ color: "var(--pc-text-primary)" }}
          >
-           {t("dashboard.active_channels")}
+           {t("dashboard.channels")}
          </h2>
          <button
            onClick={() => setShowAllChannels((v) => !v)}
@@ -265,7 +265,7 @@ export default function Dashboard() {
              : t("dashboard.filter_active")}
          </button>
        </div>
-       <div className="space-y-2">
+       <div className="space-y-2 overflow-y-auto max-h-48 pr-1">
          {Object.entries(status.channels).length === 0 ? (
            <p className="text-sm" style={{ color: "var(--pc-text-faint)" }}>
              {t("dashboard.no_channels")}
+39 -11
@@ -24,7 +24,13 @@ function eventTypeStyle(type: string): { color: string; bg: string; border: string } {
       return { color: 'var(--color-status-warning)', bg: 'rgba(255, 170, 0, 0.06)', border: 'rgba(255, 170, 0, 0.2)' };
     case 'tool_call':
     case 'tool_result':
+    case 'tool_call_start':
       return { color: '#a78bfa', bg: 'rgba(167, 139, 250, 0.06)', border: 'rgba(167, 139, 250, 0.2)' };
+    case 'llm_request':
+      return { color: '#38bdf8', bg: 'rgba(56, 189, 248, 0.06)', border: 'rgba(56, 189, 248, 0.2)' };
+    case 'agent_start':
+    case 'agent_end':
+      return { color: '#34d399', bg: 'rgba(52, 211, 153, 0.06)', border: 'rgba(52, 211, 153, 0.2)' };
     case 'message':
     case 'chat':
       return { color: 'var(--pc-accent)', bg: 'var(--pc-accent-glow)', border: 'var(--pc-accent-dim)' };
@@ -43,6 +49,7 @@ export default function Logs() {
   const [paused, setPaused] = useState(false);
   const [connected, setConnected] = useState(false);
   const [autoScroll, setAutoScroll] = useState(true);
+  const [infoDismissed, setInfoDismissed] = useState(false);
   const [typeFilters, setTypeFilters] = useState<Set<string>>(new Set());
   const containerRef = useRef<HTMLDivElement>(null);
   const sseRef = useRef<SSEClient | null>(null);
@@ -121,21 +128,12 @@ export default function Logs() {
   const filteredEntries = typeFilters.size === 0 ? entries : entries.filter((e) => typeFilters.has(e.event.type));
 
   return (
-    <div className="flex flex-col h-[calc(100vh-3.5rem)]">
+    <div className="flex flex-col h-full">
       {/* Toolbar */}
       <div className="flex items-center justify-between px-6 py-3 border-b animate-fade-in" style={{ borderColor: 'var(--pc-border)', background: 'var(--pc-bg-surface)' }}>
         <div className="flex items-center gap-3">
           <Activity className="h-5 w-5" style={{ color: 'var(--pc-accent)' }} />
           <h2 className="text-sm font-semibold uppercase tracking-wider" style={{ color: 'var(--pc-text-primary)' }}>{t('logs.live_logs')}</h2>
-          <div className="flex items-center gap-2 ml-2">
-            <span className="status-dot" style={
-              connected ? { background: 'var(--color-status-success)', boxShadow: '0 0 6px var(--color-status-success)' } : { background: 'var(--color-status-error)', boxShadow: '0 0 6px var(--color-status-error)' }
-            }
-            />
-            <span className="text-[10px]" style={{ color: 'var(--pc-text-faint)' }}>
-              {connected ? t('logs.connected') : t('logs.disconnected')}
-            </span>
-          </div>
           <span className="text-[10px] font-mono ml-2" style={{ color: 'var(--pc-text-faint)' }}>
             {filteredEntries.length} {t('logs.events')}
           </span>
@@ -195,11 +193,32 @@ export default function Logs() {
         </div>
       )}
 
+      {/* Informational banner — what appears here and what does not */}
+      {!infoDismissed && (
+        <div className="flex items-start gap-3 px-6 py-3 border-b flex-shrink-0" style={{ borderColor: 'rgba(56, 189, 248, 0.2)', background: 'rgba(56, 189, 248, 0.05)' }}>
+          <div className="flex-1 text-xs" style={{ color: 'var(--pc-text-secondary)' }}>
+            <span className="font-semibold" style={{ color: '#38bdf8' }}>What appears here: </span>
+            agent activity over SSE — LLM requests, tool calls, agent start/end, and errors.
+            {' '}<span className="font-semibold" style={{ color: 'var(--pc-text-muted)' }}>What does not: </span>
+            daemon stdout and <code>RUST_LOG</code> tracing output go to the terminal or log file, not this stream.
+            {' '}To see tracing logs, run the daemon with <code>RUST_LOG=info zeroclaw</code> and check your terminal.
+          </div>
+          <button
+            onClick={() => setInfoDismissed(true)}
+            className="flex-shrink-0 text-[10px] btn-icon"
+            aria-label="Dismiss"
+            style={{ color: 'var(--pc-text-faint)' }}
+          >
+            ✕
+          </button>
+        </div>
+      )}
+
       {/* Log entries */}
       <div
         ref={containerRef}
         onScroll={handleScroll}
-        className="flex-1 overflow-y-auto p-4 space-y-2"
+        className="flex-1 overflow-y-auto p-4 space-y-2 min-h-0"
       >
         {filteredEntries.length === 0 ? (
           <div className="flex flex-col items-center justify-center h-full animate-fade-in" style={{ color: 'var(--pc-text-muted)' }}>
@@ -249,6 +268,15 @@ export default function Logs() {
           })
         )}
       </div>
+      {/* Footer: connection status */}
+      <div className="flex items-center justify-center gap-2 px-6 py-2 border-t flex-shrink-0" style={{ borderColor: 'var(--pc-border)', background: 'var(--pc-bg-surface)' }}>
+        <span className="status-dot" style={
+          connected ? { background: 'var(--color-status-success)', boxShadow: '0 0 6px var(--color-status-success)' } : { background: 'var(--color-status-error)', boxShadow: '0 0 6px var(--color-status-error)' }
+        } />
+        <span className="text-[10px]" style={{ color: 'var(--pc-text-faint)' }}>
+          {connected ? t('logs.connected') : t('logs.disconnected')}
+        </span>
+      </div>
     </div>
   );
 }
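The `eventTypeStyle` switch extended in the hunk above maps event types to a color triple, with several types sharing one entry. An equivalent table-driven form (a sketch, not the component's actual code; the fallback shown is the `message`/`chat` style, since the switch's real default is not visible in the hunk) would be:

```typescript
// Table-driven variant of the eventTypeStyle switch from the diff.
// Colors are taken verbatim from the hunk; the fallback entry is a
// simplification standing in for the unseen default branch.
type EventStyle = { color: string; bg: string; border: string };

const STYLE_TABLE: Record<string, EventStyle> = {};
const register = (types: string[], style: EventStyle) =>
  types.forEach((t) => { STYLE_TABLE[t] = style; });

register(['tool_call', 'tool_result', 'tool_call_start'],
  { color: '#a78bfa', bg: 'rgba(167, 139, 250, 0.06)', border: 'rgba(167, 139, 250, 0.2)' });
register(['llm_request'],
  { color: '#38bdf8', bg: 'rgba(56, 189, 248, 0.06)', border: 'rgba(56, 189, 248, 0.2)' });
register(['agent_start', 'agent_end'],
  { color: '#34d399', bg: 'rgba(52, 211, 153, 0.06)', border: 'rgba(52, 211, 153, 0.2)' });

function eventTypeStyleLookup(type: string): EventStyle {
  return STYLE_TABLE[type]
    ?? { color: 'var(--pc-accent)', bg: 'var(--pc-accent-glow)', border: 'var(--pc-accent-dim)' };
}
```

The lookup keeps adding a new event type to a one-line `register` call instead of growing the switch.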
+109 -83
@@ -16,6 +16,8 @@ export default function Tools() {
   const [cliTools, setCliTools] = useState<CliTool[]>([]);
   const [search, setSearch] = useState('');
   const [expandedTool, setExpandedTool] = useState<string | null>(null);
+  const [agentSectionOpen, setAgentSectionOpen] = useState(true);
+  const [cliSectionOpen, setCliSectionOpen] = useState(true);
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState<string | null>(null);
@@ -70,104 +72,128 @@
 
       {/* Agent Tools Grid */}
       <div>
-        <div className="flex items-center gap-2 mb-4">
+        <button
+          onClick={() => setAgentSectionOpen((v) => !v)}
+          className="flex items-center gap-2 mb-4 w-full text-left group"
+          style={{ background: 'transparent', border: 'none', cursor: 'pointer', padding: 0 }}
+          aria-expanded={agentSectionOpen}
+          aria-controls="agent-tools-section"
+        >
           <Wrench className="h-5 w-5" style={{ color: 'var(--pc-accent)' }} />
-          <h2 className="text-sm font-semibold uppercase tracking-wider" style={{ color: 'var(--pc-text-primary)' }}>
+          <span className="text-sm font-semibold uppercase tracking-wider flex-1" role="heading" aria-level={2} style={{ color: 'var(--pc-text-primary)' }}>
            {t('tools.agent_tools')} ({filtered.length})
-          </h2>
-        </div>
+          </span>
+          <ChevronDown
+            className="h-4 w-4 opacity-40 group-hover:opacity-100"
+            style={{ color: 'var(--pc-text-muted)', transform: agentSectionOpen ? 'rotate(0deg)' : 'rotate(-90deg)', transition: 'transform 0.2s ease, opacity 0.2s ease' }}
+          />
+        </button>
 
-        {filtered.length === 0 ? (
-          <p className="text-sm" style={{ color: 'var(--pc-text-muted)' }}>{t('tools.empty')}</p>
-        ) : (
-          <div className="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-4 stagger-children">
-            {filtered.map((tool) => {
-              const isExpanded = expandedTool === tool.name;
-              return (
-                <div
-                  key={tool.name}
-                  className="card overflow-hidden animate-slide-in-up"
-                >
-                  <button
-                    onClick={() => setExpandedTool(isExpanded ? null : tool.name)}
-                    className="w-full text-left p-4 transition-all"
-                    style={{ background: 'transparent' }}
-                    onMouseEnter={(e) => { e.currentTarget.style.background = 'var(--pc-hover)'; }}
-                    onMouseLeave={(e) => { e.currentTarget.style.background = 'transparent'; }}
-                  >
-                    <div className="flex items-start justify-between gap-2">
-                      <div className="flex items-center gap-2 min-w-0">
-                        <Package className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-accent)' }} />
-                        <h3 className="text-sm font-semibold truncate" style={{ color: 'var(--pc-text-primary)' }}>{tool.name}</h3>
-                      </div>
-                      {isExpanded
-                        ? <ChevronDown className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-accent)' }} />
-                        : <ChevronRight className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-text-faint)' }} />
-                      }
-                    </div>
-                    <p className="text-sm mt-2 line-clamp-2" style={{ color: 'var(--pc-text-muted)' }}>
-                      {tool.description}
-                    </p>
-                  </button>
-
-                  {isExpanded && tool.parameters && (
-                    <div className="border-t p-4 animate-fade-in" style={{ borderColor: 'var(--pc-border)' }}>
-                      <p className="text-[10px] font-semibold uppercase tracking-wider mb-2" style={{ color: 'var(--pc-text-muted)' }}>
-                        {t('tools.parameter_schema')}
-                      </p>
-                      <pre className="text-xs rounded-xl p-3 overflow-x-auto max-h-64 overflow-y-auto font-mono" style={{ background: 'var(--pc-bg-base)', color: 'var(--pc-text-secondary)' }}>
-                        {JSON.stringify(tool.parameters, null, 2)}
-                      </pre>
-                    </div>
-                  )}
-                </div>
-              );
-            })}
-          </div>
-        )}
+        <div id="agent-tools-section">
+          {agentSectionOpen && (filtered.length === 0 ? (
+            <p className="text-sm" style={{ color: 'var(--pc-text-muted)' }}>{t('tools.empty')}</p>
+          ) : (
+            <div className="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-4 stagger-children">
+              {filtered.map((tool) => {
+                const isExpanded = expandedTool === tool.name;
+                return (
+                  <div
+                    key={tool.name}
+                    className="card overflow-hidden animate-slide-in-up"
+                  >
+                    <button
+                      onClick={() => setExpandedTool(isExpanded ? null : tool.name)}
+                      className="w-full text-left p-4 transition-all"
+                      style={{ background: 'transparent' }}
+                      onMouseEnter={(e) => { e.currentTarget.style.background = 'var(--pc-hover)'; }}
+                      onMouseLeave={(e) => { e.currentTarget.style.background = 'transparent'; }}
+                    >
+                      <div className="flex items-start justify-between gap-2">
+                        <div className="flex items-center gap-2 min-w-0">
+                          <Package className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-accent)' }} />
+                          <h3 className="text-sm font-semibold truncate" style={{ color: 'var(--pc-text-primary)' }}>{tool.name}</h3>
+                        </div>
+                        {isExpanded
+                          ? <ChevronDown className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-accent)' }} />
+                          : <ChevronRight className="h-4 w-4 flex-shrink-0" style={{ color: 'var(--pc-text-faint)' }} />
+                        }
+                      </div>
+                      <p className="text-sm mt-2 line-clamp-2" style={{ color: 'var(--pc-text-muted)' }}>
+                        {tool.description}
+                      </p>
+                    </button>
+
+                    {isExpanded && tool.parameters && (
+                      <div className="border-t p-4 animate-fade-in" style={{ borderColor: 'var(--pc-border)' }}>
+                        <p className="text-[10px] font-semibold uppercase tracking-wider mb-2" style={{ color: 'var(--pc-text-muted)' }}>
+                          {t('tools.parameter_schema')}
+                        </p>
+                        <pre className="text-xs rounded-xl p-3 overflow-x-auto max-h-64 overflow-y-auto font-mono" style={{ background: 'var(--pc-bg-base)', color: 'var(--pc-text-secondary)' }}>
+                          {JSON.stringify(tool.parameters, null, 2)}
+                        </pre>
+                      </div>
+                    )}
+                  </div>
+                );
+              })}
+            </div>
+          ))}
+        </div>
       </div>
 
       {/* CLI Tools Section */}
       {filteredCli.length > 0 && (
         <div className="animate-slide-in-up" style={{ animationDelay: '200ms' }}>
-          <div className="flex items-center gap-2 mb-4">
+          <button
+            onClick={() => setCliSectionOpen((v) => !v)}
+            className="flex items-center gap-2 mb-4 w-full text-left group"
+            style={{ background: 'transparent', border: 'none', cursor: 'pointer', padding: 0 }}
+            aria-expanded={cliSectionOpen}
+            aria-controls="cli-tools-section"
+          >
            <Terminal className="h-5 w-5" style={{ color: 'var(--color-status-success)' }} />
-            <h2 className="text-sm font-semibold uppercase tracking-wider" style={{ color: 'var(--pc-text-primary)' }}>
+            <span className="text-sm font-semibold uppercase tracking-wider flex-1" role="heading" aria-level={2} style={{ color: 'var(--pc-text-primary)' }}>
              {t('tools.cli_tools')} ({filteredCli.length})
-            </h2>
-          </div>
+            </span>
+            <ChevronDown
+              className="h-4 w-4 opacity-40 group-hover:opacity-100"
+              style={{ color: 'var(--pc-text-muted)', transform: cliSectionOpen ? 'rotate(0deg)' : 'rotate(-90deg)', transition: 'transform 0.2s ease, opacity 0.2s ease' }}
+            />
+          </button>
 
-          <div className="card overflow-hidden rounded-2xl">
-            <table className="table-electric">
-              <thead>
-                <tr>
-                  <th>{t('tools.name')}</th>
-                  <th>{t('tools.path')}</th>
-                  <th>{t('tools.version')}</th>
-                  <th>{t('tools.category')}</th>
-                </tr>
-              </thead>
-              <tbody>
-                {filteredCli.map((tool) => (
-                  <tr key={tool.name}>
-                    <td className="font-medium text-sm" style={{ color: 'var(--pc-text-primary)' }}>
-                      {tool.name}
-                    </td>
-                    <td className="font-mono text-xs truncate max-w-[200px]" style={{ color: 'var(--pc-text-muted)' }}>
-                      {tool.path}
-                    </td>
-                    <td style={{ color: 'var(--pc-text-muted)' }}>
-                      {tool.version ?? '-'}
-                    </td>
-                    <td>
-                      <span className="inline-flex items-center px-2.5 py-0.5 rounded-full text-[10px] font-semibold capitalize border" style={{ borderColor: 'var(--pc-border)', color: 'var(--pc-text-secondary)', background: 'var(--pc-accent-glow)' }}>
-                        {tool.category}
-                      </span>
-                    </td>
-                  </tr>
-                ))}
-              </tbody>
-            </table>
-          </div>
+          <div id="cli-tools-section">
+            {cliSectionOpen && <div className="card overflow-hidden rounded-2xl">
+              <table className="table-electric">
+                <thead>
+                  <tr>
+                    <th>{t('tools.name')}</th>
+                    <th>{t('tools.path')}</th>
+                    <th>{t('tools.version')}</th>
+                    <th>{t('tools.category')}</th>
+                  </tr>
+                </thead>
+                <tbody>
+                  {filteredCli.map((tool) => (
+                    <tr key={tool.name}>
+                      <td className="font-medium text-sm" style={{ color: 'var(--pc-text-primary)' }}>
+                        {tool.name}
+                      </td>
+                      <td className="font-mono text-xs truncate max-w-[200px]" style={{ color: 'var(--pc-text-muted)' }}>
+                        {tool.path}
+                      </td>
+                      <td style={{ color: 'var(--pc-text-muted)' }}>
+                        {tool.version ?? '-'}
+                      </td>
+                      <td>
+                        <span className="inline-flex items-center px-2.5 py-0.5 rounded-full text-[10px] font-semibold capitalize border" style={{ borderColor: 'var(--pc-border)', color: 'var(--pc-text-secondary)', background: 'var(--pc-accent-glow)' }}>
+                          {tool.category}
+                        </span>
+                      </td>
+                    </tr>
+                  ))}
+                </tbody>
+              </table>
+            </div>}
+          </div>
         </div>
       )}