Compare commits
77 Commits
| SHA1 |
|---|
| 9c312180a2 |
| a433c37c53 |
| af8e805016 |
| 9f127f896d |
| b98971c635 |
| 2ee0229740 |
| 0d2b57ee2e |
| b85a445955 |
| dbd8c77519 |
| 34db67428f |
| 01d0c6b23a |
| 79f0a5ae30 |
| 5bdeeba213 |
| b5447175ff |
| 0dc05771ba |
| 10f9ea3454 |
| 3ec532bc29 |
| f1688c5910 |
| fd9f140268 |
| b2dccf86eb |
| f0c106a938 |
| e276e66c05 |
| 2300f21315 |
| 2575edb1d2 |
| d31f2c2d97 |
| 9d7f6c5aaf |
| f9081fcfa7 |
| a4d95dec0e |
| 31508b8ec7 |
| 3d9069552c |
| 3b074041bf |
| 9a95318b85 |
| ccd572b827 |
| 41dd23175f |
| 864d754b56 |
| ccd52f3394 |
| eb01aa451d |
| c785b45f2d |
| ffb8b81f90 |
| 65f856d710 |
| 1682620377 |
| aa455ae89b |
| a9ffd38912 |
| 86a0584513 |
| abef4c5719 |
| 483b2336c4 |
| 14cda3bc9a |
| 6e8f0fa43c |
| a965b129f8 |
| c135de41b7 |
| 2d2c2ac9e6 |
| 5e774bbd70 |
| 33015067eb |
| 6b10c0b891 |
| bf817e30d2 |
| 0051a0c296 |
| d753de91f1 |
| f6b2f61a01 |
| 70e7910cb9 |
| a8868768e8 |
| 67293c50df |
| 1646079d25 |
| 25b639435f |
| 77779844e5 |
| f658d5806a |
| 7134fe0824 |
| 263802b3df |
| 3c25fddb2a |
| a6a46bdd25 |
| 235d4d2f1c |
| bd1e8c8e1a |
| f81807bff6 |
| bb7006313c |
| 9a49626376 |
| 8b978a721f |
| 75b4c1d4a4 |
| cb0779d761 |
@@ -1,97 +0,0 @@
# Mem0 Integration: Dual-Scope Recall + Per-Turn Memory

## Context

Mem0 auto-save works, but the integration is missing key features from mem0 best practices: per-turn recall, multi-level scoping, and proper context injection. This causes the bot to "forget" on follow-up turns and fail to differentiate users.

## What's Missing (vs mem0 docs)

1. **Per-turn recall** — only the first turn gets memory context; follow-ups get nothing
2. **Dual-scope** — no sender vs group distinction; all memories use a single hardcoded `user_id`
3. **System prompt injection** — memory is prepended to the user message (pollutes session history)
4. **`agent_id` scoping** — mem0 supports agent-level patterns, not used

## Changes

### 1. `src/memory/mem0.rs` — Use session_id for multi-level scoping

Map zeroclaw's `session_id` param to mem0's `user_id`. This enables per-user and per-group memory namespaces without changing the `Memory` trait.

```rust
// Add helper:
fn effective_user_id(&self, session_id: Option<&str>) -> &str {
    session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
}

// In store():  use effective_user_id(session_id) as mem0 user_id
// In recall(): use effective_user_id(session_id) as mem0 user_id
// In list():   use effective_user_id(session_id) as mem0 user_id
```
### 2. `src/channels/mod.rs` ~line 2229 — Per-turn dual-scope recall

Remove the `if !had_prior_history` gate. Always recall from the sender scope, and additionally from the group scope for group chats.

```rust
// Detect group chat
let is_group = msg.reply_target.contains("@g.us")
    || msg.reply_target.starts_with("group:");

// Sender-scope recall (always)
let sender_context = build_memory_context(
    ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
    Some(&msg.sender),
).await;

// Group-scope recall (groups only)
let group_context = if is_group {
    build_memory_context(
        ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
        Some(&history_key),
    ).await
} else {
    String::new()
};

// Merge (deduplicate by checking substring overlap)
let memory_context = merge_memory_contexts(&sender_context, &group_context);
```
### 3. `src/channels/mod.rs` ~line 2244 — Inject into system prompt

Move the memory context from the user message into the system prompt. It is re-fetched each turn and never pollutes the session history.

```rust
let mut system_prompt = build_channel_system_prompt(...);
if !memory_context.is_empty() {
    system_prompt.push_str(&format!("\n\n{memory_context}"));
}
let mut history = vec![ChatMessage::system(system_prompt)];
```

### 4. `src/channels/mod.rs` — Dual-scope auto-save

Find the existing auto-save call. For group messages, store twice:

- `store(key, content, category, Some(&msg.sender))` — personal facts
- `store(key, content, category, Some(&history_key))` — group context

Both stores are async and non-blocking. DMs store only to the sender scope.
### 5. `src/memory/mem0.rs` — Add `agent_id` support (optional)

Pass `self.app_name` as the `agent_id` param to the mem0 API for agent behavior tracking.

## Files to Modify

1. `src/memory/mem0.rs` — session_id → user_id mapping
2. `src/channels/mod.rs` — per-turn recall, dual-scope, system prompt injection, dual-scope save

## Verification

1. `cargo check --features whatsapp-web,memory-mem0`
2. `cargo test --features whatsapp-web,memory-mem0`
3. Deploy to Synology
4. Test DM: "我鍾意食壽司" ("I like eating sushi") → next turn "我鍾意食咩" ("What do I like to eat?") → should recall
5. Test group: Joe says "我鍾意食壽司" ("I like eating sushi") → someone else asks "Joe 鍾意食咩" ("What does Joe like to eat?") → should recall from group scope
6. Check mem0 server logs: GET with `user_id=sender` AND `user_id=group_key`
7. Check mem0 server logs: POST with both user_ids for group messages
@@ -118,3 +118,7 @@ PROVIDER=openrouter
# Optional: Brave Search (requires API key from https://brave.com/search/api)
# WEB_SEARCH_PROVIDER=brave
# BRAVE_API_KEY=your-brave-search-api-key
#
# Optional: SearXNG (self-hosted, requires instance URL)
# WEB_SEARCH_PROVIDER=searxng
# SEARXNG_INSTANCE_URL=https://searx.example.com
@@ -1,6 +1,22 @@
name: Pub Homebrew Core

on:
  workflow_call:
    inputs:
      release_tag:
        description: "Existing release tag to publish (vX.Y.Z)"
        required: true
        type: string
      dry_run:
        description: "Patch formula only (no push/PR)"
        required: false
        default: false
        type: boolean
    secrets:
      HOMEBREW_UPSTREAM_PR_TOKEN:
        required: false
      HOMEBREW_CORE_BOT_TOKEN:
        required: false
  workflow_dispatch:
    inputs:
      release_tag:
@@ -41,6 +41,14 @@ jobs:
echo "Current version: ${current}"
echo "Previous version: ${previous}"

# Skip if stable release workflow will handle this version
# (indicated by an existing or imminent stable tag)
if git ls-remote --exit-code --tags origin "refs/tags/v${current}" >/dev/null 2>&1; then
  echo "Stable tag v${current} exists — stable release workflow handles crates.io"
  echo "changed=false" >> "$GITHUB_OUTPUT"
  exit 0
fi

if [[ "$current" != "$previous" && -n "$current" ]]; then
  echo "changed=true" >> "$GITHUB_OUTPUT"
  echo "version=${current}" >> "$GITHUB_OUTPUT"
@@ -26,22 +26,43 @@ jobs:
outputs:
  version: ${{ steps.ver.outputs.version }}
  tag: ${{ steps.ver.outputs.tag }}
  skip: ${{ steps.ver.outputs.skip }}
steps:
  - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
    with:
      fetch-depth: 2
  - name: Compute beta version
    id: ver
    shell: bash
    run: |
      set -euo pipefail
      base_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)

      # Skip beta if this is a version bump commit (stable release handles it)
      commit_msg=$(git log -1 --pretty=format:"%s")
      if [[ "$commit_msg" =~ ^chore:\ bump\ version ]]; then
        echo "Version bump commit detected — skipping beta release"
        echo "skip=true" >> "$GITHUB_OUTPUT"
        exit 0
      fi

      # Skip beta if a stable tag already exists for this version
      if git ls-remote --exit-code --tags origin "refs/tags/v${base_version}" >/dev/null 2>&1; then
        echo "Stable tag v${base_version} exists — skipping beta release"
        echo "skip=true" >> "$GITHUB_OUTPUT"
        exit 0
      fi

      beta_tag="v${base_version}-beta.${GITHUB_RUN_NUMBER}"
      echo "version=${base_version}" >> "$GITHUB_OUTPUT"
      echo "tag=${beta_tag}" >> "$GITHUB_OUTPUT"
      echo "skip=false" >> "$GITHUB_OUTPUT"
      echo "Beta release: ${beta_tag}"

release-notes:
  name: Generate Release Notes
  if: github.repository == 'zeroclaw-labs/zeroclaw'
  needs: [version]
  if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
  runs-on: ubuntu-latest
  outputs:
    notes: ${{ steps.notes.outputs.body }}
@@ -132,7 +153,8 @@ jobs:
web:
  name: Build Web Dashboard
  if: github.repository == 'zeroclaw-labs/zeroclaw'
  needs: [version]
  if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
  runs-on: ubuntu-latest
  timeout-minutes: 10
  steps:
@@ -1,6 +1,9 @@
name: Release Stable

on:
  push:
    tags:
      - "v[0-9]+.[0-9]+.[0-9]+" # stable tags only (no -beta suffix)
  workflow_dispatch:
    inputs:
      version:
@@ -33,11 +36,22 @@ jobs:
- name: Validate semver and Cargo.toml match
  id: check
  shell: bash
  env:
    INPUT_VERSION: ${{ inputs.version || '' }}
    REF_NAME: ${{ github.ref_name }}
    EVENT_NAME: ${{ github.event_name }}
  run: |
    set -euo pipefail
    input_version="${{ inputs.version }}"
    cargo_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)

    # Resolve version from tag push or manual input
    if [[ "$EVENT_NAME" == "push" ]]; then
      # Tag push: extract version from tag name (v0.5.9 -> 0.5.9)
      input_version="${REF_NAME#v}"
    else
      input_version="$INPUT_VERSION"
    fi

    if [[ ! "$input_version" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
      echo "::error::Version must be semver (X.Y.Z). Got: ${input_version}"
      exit 1
@@ -49,9 +63,13 @@ jobs:
fi

tag="v${input_version}"
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
  echo "::error::Tag ${tag} already exists."
  exit 1

# Only check tag existence for manual dispatch (tag push means it already exists)
if [[ "$EVENT_NAME" != "push" ]]; then
  if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
    echo "::error::Tag ${tag} already exists."
    exit 1
  fi
fi

echo "tag=${tag}" >> "$GITHUB_OUTPUT"
@@ -286,6 +304,14 @@ jobs:
    NOTES: ${{ needs.release-notes.outputs.notes }}
  run: printf '%s\n' "$NOTES" > release-notes.md

- name: Create tag if manual dispatch
  if: github.event_name == 'workflow_dispatch'
  env:
    TAG: ${{ needs.validate.outputs.tag }}
  run: |
    git tag -a "$TAG" -m "zeroclaw $TAG"
    git push origin "$TAG"

- name: Create GitHub Release
  env:
    GH_TOKEN: ${{ secrets.RELEASE_TOKEN }}
@@ -461,6 +487,16 @@ jobs:
    dry_run: false
  secrets: inherit

homebrew:
  name: Update Homebrew Core
  needs: [validate, publish]
  if: ${{ !cancelled() && needs.publish.result == 'success' }}
  uses: ./.github/workflows/pub-homebrew-core.yml
  with:
    release_tag: ${{ needs.validate.outputs.tag }}
    dry_run: false
  secrets: inherit

# ── Post-publish: tweet after release + website are live ──────────────
# Docker push can be slow; don't let it block the tweet.
tweet:
@@ -1,10 +1,10 @@
[workspace]
members = [".", "crates/robot-kit", "crates/aardvark-sys"]
members = [".", "crates/robot-kit", "crates/aardvark-sys", "apps/tauri"]
resolver = "2"

[package]
name = "zeroclawlabs"
version = "0.5.6"
version = "0.6.0"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"
@@ -231,8 +231,6 @@ channel-matrix = ["dep:matrix-sdk"]
channel-lark = ["dep:prost"]
channel-feishu = ["channel-lark"] # Alias for Feishu users (Lark and Feishu are the same platform)
memory-postgres = ["dep:postgres"]
# memory-mem0 = Mem0 (OpenMemory) memory backend via REST API
memory-mem0 = []
observability-prometheus = ["dep:prometheus"]
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
peripheral-rpi = ["rppal"]
@@ -267,7 +265,6 @@ ci-all = [
    "channel-matrix",
    "channel-lark",
    "memory-postgres",
    "memory-mem0",
    "observability-prometheus",
    "observability-otel",
    "peripheral-rpi",
@@ -0,0 +1,29 @@
[package]
name = "zeroclaw-desktop"
version = "0.1.0"
edition = "2021"
description = "ZeroClaw Desktop — Tauri-powered system tray app"
publish = false

[build-dependencies]
tauri-build = { version = "2.0", features = [] }

[dependencies]
tauri = { version = "2.0", features = ["tray-icon", "image-png"] }
tauri-plugin-shell = "2.0"
tauri-plugin-store = "2.0"
tauri-plugin-single-instance = "2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
tokio = { version = "1.50", features = ["rt-multi-thread", "macros", "sync", "time"] }
anyhow = "1.0"

[target.'cfg(target_os = "macos")'.dependencies]
objc2 = "0.6"
objc2-app-kit = { version = "0.3", features = ["NSApplication", "NSImage", "NSRunningApplication"] }
objc2-foundation = { version = "0.3", features = ["NSData"] }

[features]
default = ["custom-protocol"]
custom-protocol = ["tauri/custom-protocol"]
@@ -0,0 +1,3 @@
fn main() {
    tauri_build::build();
}
@@ -0,0 +1,14 @@
{
  "$schema": "../gen/schemas/desktop-schema.json",
  "identifier": "default",
  "description": "Default capability set for ZeroClaw Desktop",
  "windows": ["main"],
  "permissions": [
    "core:default",
    "shell:allow-open",
    "store:allow-get",
    "store:allow-set",
    "store:allow-save",
    "store:allow-load"
  ]
}
@@ -0,0 +1,14 @@
{
  "identifier": "desktop",
  "description": "Desktop-specific permissions for ZeroClaw",
  "windows": ["main"],
  "permissions": [
    "core:default",
    "shell:allow-open",
    "shell:allow-execute",
    "store:allow-get",
    "store:allow-set",
    "store:allow-save",
    "store:allow-load"
  ]
}
@@ -0,0 +1,8 @@
{
  "identifier": "mobile",
  "description": "Mobile-specific permissions for ZeroClaw",
  "windows": ["main"],
  "permissions": [
    "core:default"
  ]
}
(binary icon PNGs added)
@@ -0,0 +1,4 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 128 128">
  <rect width="128" height="128" rx="16" fill="#DC322F"/>
  <text x="64" y="80" font-size="64" font-family="monospace" font-weight="bold" fill="white" text-anchor="middle">Z</text>
</svg>
(binary icon PNGs added)
@@ -0,0 +1,17 @@
use crate::gateway_client::GatewayClient;
use crate::state::SharedState;
use tauri::State;

#[tauri::command]
pub async fn send_message(
    state: State<'_, SharedState>,
    message: String,
) -> Result<serde_json::Value, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client
        .send_webhook_message(&message)
        .await
        .map_err(|e| e.to_string())
}
@@ -0,0 +1,11 @@
use crate::gateway_client::GatewayClient;
use crate::state::SharedState;
use tauri::State;

#[tauri::command]
pub async fn list_channels(state: State<'_, SharedState>) -> Result<serde_json::Value, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client.get_status().await.map_err(|e| e.to_string())
}
@@ -0,0 +1,19 @@
use crate::gateway_client::GatewayClient;
use crate::state::SharedState;
use tauri::State;

#[tauri::command]
pub async fn get_status(state: State<'_, SharedState>) -> Result<serde_json::Value, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client.get_status().await.map_err(|e| e.to_string())
}

#[tauri::command]
pub async fn get_health(state: State<'_, SharedState>) -> Result<bool, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client.get_health().await.map_err(|e| e.to_string())
}
@@ -0,0 +1,4 @@
pub mod agent;
pub mod channels;
pub mod gateway;
pub mod pairing;
@@ -0,0 +1,19 @@
use crate::gateway_client::GatewayClient;
use crate::state::SharedState;
use tauri::State;

#[tauri::command]
pub async fn initiate_pairing(state: State<'_, SharedState>) -> Result<serde_json::Value, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client.initiate_pairing().await.map_err(|e| e.to_string())
}

#[tauri::command]
pub async fn get_devices(state: State<'_, SharedState>) -> Result<serde_json::Value, String> {
    let s = state.read().await;
    let client = GatewayClient::new(&s.gateway_url, s.token.as_deref());
    drop(s);
    client.get_devices().await.map_err(|e| e.to_string())
}
@@ -0,0 +1,213 @@
//! HTTP client for communicating with the ZeroClaw gateway.

use anyhow::{Context, Result};

pub struct GatewayClient {
    pub(crate) base_url: String,
    pub(crate) token: Option<String>,
    client: reqwest::Client,
}

impl GatewayClient {
    pub fn new(base_url: &str, token: Option<&str>) -> Self {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap_or_default();
        Self {
            base_url: base_url.to_string(),
            token: token.map(String::from),
            client,
        }
    }

    pub(crate) fn auth_header(&self) -> Option<String> {
        self.token.as_ref().map(|t| format!("Bearer {t}"))
    }

    pub async fn get_status(&self) -> Result<serde_json::Value> {
        let mut req = self.client.get(format!("{}/api/status", self.base_url));
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        let resp = req.send().await.context("status request failed")?;
        Ok(resp.json().await?)
    }

    pub async fn get_health(&self) -> Result<bool> {
        match self
            .client
            .get(format!("{}/health", self.base_url))
            .send()
            .await
        {
            Ok(resp) => Ok(resp.status().is_success()),
            Err(_) => Ok(false),
        }
    }

    pub async fn get_devices(&self) -> Result<serde_json::Value> {
        let mut req = self.client.get(format!("{}/api/devices", self.base_url));
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        let resp = req.send().await.context("devices request failed")?;
        Ok(resp.json().await?)
    }

    pub async fn initiate_pairing(&self) -> Result<serde_json::Value> {
        let mut req = self
            .client
            .post(format!("{}/api/pairing/initiate", self.base_url));
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        let resp = req.send().await.context("pairing request failed")?;
        Ok(resp.json().await?)
    }

    /// Check whether the gateway requires pairing.
    pub async fn requires_pairing(&self) -> Result<bool> {
        let resp = self
            .client
            .get(format!("{}/health", self.base_url))
            .send()
            .await
            .context("health request failed")?;
        let body: serde_json::Value = resp.json().await?;
        Ok(body["require_pairing"].as_bool().unwrap_or(false))
    }

    /// Request a new pairing code from the gateway (localhost-only admin endpoint).
    pub async fn request_new_paircode(&self) -> Result<String> {
        let resp = self
            .client
            .post(format!("{}/admin/paircode/new", self.base_url))
            .send()
            .await
            .context("paircode request failed")?;
        let body: serde_json::Value = resp.json().await?;
        body["pairing_code"]
            .as_str()
            .map(String::from)
            .context("no pairing_code in response")
    }

    /// Exchange a pairing code for a bearer token.
    pub async fn pair_with_code(&self, code: &str) -> Result<String> {
        let resp = self
            .client
            .post(format!("{}/pair", self.base_url))
            .header("X-Pairing-Code", code)
            .send()
            .await
            .context("pair request failed")?;
        if !resp.status().is_success() {
            anyhow::bail!("pair request returned {}", resp.status());
        }
        let body: serde_json::Value = resp.json().await?;
        body["token"]
            .as_str()
            .map(String::from)
            .context("no token in pair response")
    }

    /// Validate an existing token by calling a protected endpoint.
    pub async fn validate_token(&self) -> Result<bool> {
        let mut req = self.client.get(format!("{}/api/status", self.base_url));
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        match req.send().await {
            Ok(resp) => Ok(resp.status().is_success()),
            Err(_) => Ok(false),
        }
    }

    /// Auto-pair with the gateway: request a new code and exchange it for a token.
    pub async fn auto_pair(&self) -> Result<String> {
        let code = self.request_new_paircode().await?;
        self.pair_with_code(&code).await
    }

    pub async fn send_webhook_message(&self, message: &str) -> Result<serde_json::Value> {
        let mut req = self
            .client
            .post(format!("{}/webhook", self.base_url))
            .json(&serde_json::json!({ "message": message }));
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        let resp = req.send().await.context("webhook request failed")?;
        Ok(resp.json().await?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_creation_no_token() {
        let client = GatewayClient::new("http://127.0.0.1:42617", None);
        assert_eq!(client.base_url, "http://127.0.0.1:42617");
        assert!(client.token.is_none());
        assert!(client.auth_header().is_none());
    }

    #[test]
    fn client_creation_with_token() {
        let client = GatewayClient::new("http://localhost:8080", Some("test-token"));
        assert_eq!(client.base_url, "http://localhost:8080");
        assert_eq!(client.token.as_deref(), Some("test-token"));
        assert_eq!(client.auth_header().unwrap(), "Bearer test-token");
    }

    #[test]
    fn client_custom_url() {
        let client = GatewayClient::new("https://zeroclaw.example.com:9999", None);
        assert_eq!(client.base_url, "https://zeroclaw.example.com:9999");
    }

    #[test]
    fn auth_header_format() {
        let client = GatewayClient::new("http://localhost", Some("zc_abc123"));
        assert_eq!(client.auth_header().unwrap(), "Bearer zc_abc123");
    }

    #[tokio::test]
    async fn health_returns_false_for_unreachable_host() {
        // Connect to a port that should not be listening.
        let client = GatewayClient::new("http://127.0.0.1:1", None);
        let result = client.get_health().await.unwrap();
        assert!(!result, "health should be false for unreachable host");
    }

    #[tokio::test]
    async fn status_fails_for_unreachable_host() {
        let client = GatewayClient::new("http://127.0.0.1:1", None);
        let result = client.get_status().await;
        assert!(result.is_err(), "status should fail for unreachable host");
    }

    #[tokio::test]
    async fn devices_fails_for_unreachable_host() {
        let client = GatewayClient::new("http://127.0.0.1:1", None);
        let result = client.get_devices().await;
        assert!(result.is_err(), "devices should fail for unreachable host");
    }

    #[tokio::test]
    async fn pairing_fails_for_unreachable_host() {
        let client = GatewayClient::new("http://127.0.0.1:1", None);
        let result = client.initiate_pairing().await;
        assert!(result.is_err(), "pairing should fail for unreachable host");
    }

    #[tokio::test]
    async fn webhook_fails_for_unreachable_host() {
        let client = GatewayClient::new("http://127.0.0.1:1", None);
        let result = client.send_webhook_message("hello").await;
        assert!(result.is_err(), "webhook should fail for unreachable host");
    }
}
@@ -0,0 +1,40 @@
//! Background health polling for the ZeroClaw gateway.

use crate::gateway_client::GatewayClient;
use crate::state::SharedState;
use crate::tray::icon;
use std::time::Duration;
use tauri::{AppHandle, Emitter, Runtime};

const POLL_INTERVAL: Duration = Duration::from_secs(5);

/// Spawn a background task that polls gateway health and updates state + tray.
pub fn spawn_health_poller<R: Runtime>(app: AppHandle<R>, state: SharedState) {
    tauri::async_runtime::spawn(async move {
        loop {
            let (url, token) = {
                let s = state.read().await;
                (s.gateway_url.clone(), s.token.clone())
            };

            let client = GatewayClient::new(&url, token.as_deref());
            let healthy = client.get_health().await.unwrap_or(false);

            let (connected, agent_status) = {
                let mut s = state.write().await;
                s.connected = healthy;
                (s.connected, s.agent_status)
            };

            // Update the tray icon and tooltip to reflect current state.
            if let Some(tray) = app.tray_by_id("main") {
                let _ = tray.set_icon(Some(icon::icon_for_state(connected, agent_status)));
                let _ = tray.set_tooltip(Some(icon::tooltip_for_state(connected, agent_status)));
            }

            let _ = app.emit("zeroclaw://status-changed", healthy);

            tokio::time::sleep(POLL_INTERVAL).await;
        }
    });
}
@@ -0,0 +1,136 @@
//! ZeroClaw Desktop — Tauri application library.

pub mod commands;
pub mod gateway_client;
pub mod health;
pub mod state;
pub mod tray;

use gateway_client::GatewayClient;
use state::shared_state;
use tauri::{Manager, RunEvent};

/// Attempt to auto-pair with the gateway so the WebView has a valid token
/// before the React frontend mounts. Runs on localhost so the admin endpoints
/// are accessible without auth.
async fn auto_pair(state: &state::SharedState) -> Option<String> {
    let url = {
        let s = state.read().await;
        s.gateway_url.clone()
    };

    let client = GatewayClient::new(&url, None);

    // Check if gateway is reachable and requires pairing.
    if !client.requires_pairing().await.unwrap_or(false) {
        return None; // Pairing disabled — no token needed.
    }

    // Check if we already have a valid token in state.
    {
        let s = state.read().await;
        if let Some(ref token) = s.token {
            let authed = GatewayClient::new(&url, Some(token));
            if authed.validate_token().await.unwrap_or(false) {
                return Some(token.clone()); // Existing token is valid.
            }
        }
    }

    // No valid token — auto-pair by requesting a new code and exchanging it.
    let client = GatewayClient::new(&url, None);
    match client.auto_pair().await {
        Ok(token) => {
            let mut s = state.write().await;
            s.token = Some(token.clone());
            Some(token)
        }
        Err(_) => None, // Gateway may not be ready yet; health poller will retry.
    }
}

/// Inject a bearer token into the WebView's localStorage so the React app
/// skips the pairing dialog. Uses Tauri's WebviewWindow scripting API.
fn inject_token_into_webview<R: tauri::Runtime>(window: &tauri::WebviewWindow<R>, token: &str) {
    let escaped = token.replace('\\', "\\\\").replace('\'', "\\'");
    let script = format!("localStorage.setItem('zeroclaw_token', '{escaped}')");
    // WebviewWindow scripting is the standard Tauri API for running JS in the WebView.
    let _ = window.eval(&script);
}

/// Set the macOS dock icon programmatically so it shows even in dev builds
/// (which don't have a proper .app bundle).
#[cfg(target_os = "macos")]
fn set_dock_icon() {
    use objc2::{AnyThread, MainThreadMarker};
    use objc2_app_kit::NSApplication;
    use objc2_app_kit::NSImage;
    use objc2_foundation::NSData;

    let icon_bytes = include_bytes!("../icons/128x128.png");
    // Safety: setup() runs on the main thread in Tauri.
    let mtm = unsafe { MainThreadMarker::new_unchecked() };
    let data = NSData::with_bytes(icon_bytes);
    if let Some(image) = NSImage::initWithData(NSImage::alloc(), &data) {
        let app = NSApplication::sharedApplication(mtm);
        unsafe { app.setApplicationIconImage(Some(&image)) };
    }
}

/// Configure and run the Tauri application.
pub fn run() {
    let shared = shared_state();

    tauri::Builder::default()
        .plugin(tauri_plugin_shell::init())
        .plugin(tauri_plugin_store::Builder::default().build())
        .plugin(tauri_plugin_single_instance::init(|app, _args, _cwd| {
            // When a second instance launches, focus the existing window.
            if let Some(window) = app.get_webview_window("main") {
                let _ = window.show();
                let _ = window.set_focus();
            }
        }))
        .manage(shared.clone())
        .invoke_handler(tauri::generate_handler![
            commands::gateway::get_status,
            commands::gateway::get_health,
            commands::channels::list_channels,
            commands::pairing::initiate_pairing,
            commands::pairing::get_devices,
            commands::agent::send_message,
        ])
        .setup(move |app| {
            // Set macOS dock icon (needed for dev builds without .app bundle).
            #[cfg(target_os = "macos")]
            set_dock_icon();

            // Set up the system tray.
            let _ = tray::setup_tray(app);

            // Auto-pair with gateway and inject token into the WebView.
            let app_handle = app.handle().clone();
            let pair_state = shared.clone();
            tauri::async_runtime::spawn(async move {
                if let Some(token) = auto_pair(&pair_state).await {
                    if let Some(window) = app_handle.get_webview_window("main") {
                        inject_token_into_webview(&window, &token);
                    }
                }
            });

            // Start background health polling.
            health::spawn_health_poller(app.handle().clone(), shared.clone());

            Ok(())
        })
        .build(tauri::generate_context!())
        .expect("error while building tauri application")
        .run(|_app, event| {
            // Keep the app running in the background when all windows are closed.
            // This is the standard pattern for menu bar / tray apps.
            if let RunEvent::ExitRequested { api, .. } = event {
                api.prevent_exit();
            }
        });
}
@@ -0,0 +1,8 @@
//! ZeroClaw Desktop — main entry point.
//!
//! Prevents an additional console window on Windows in release.
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

fn main() {
    zeroclaw_desktop::run();
}
@@ -0,0 +1,6 @@
//! Mobile entry point for ZeroClaw Desktop (iOS/Android).

#[tauri::mobile_entry_point]
fn main() {
    zeroclaw_desktop::run();
}
@@ -0,0 +1,99 @@
//! Shared application state for Tauri.

use std::sync::Arc;
use tokio::sync::RwLock;

/// Agent status as reported by the gateway.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentStatus {
    Idle,
    Working,
    Error,
}

/// Shared application state behind an `Arc<RwLock<_>>`.
#[derive(Debug, Clone)]
pub struct AppState {
    pub gateway_url: String,
    pub token: Option<String>,
    pub connected: bool,
    pub agent_status: AgentStatus,
}

impl Default for AppState {
    fn default() -> Self {
        Self {
            gateway_url: "http://127.0.0.1:42617".to_string(),
            token: None,
            connected: false,
            agent_status: AgentStatus::Idle,
        }
    }
}

/// Thread-safe wrapper around `AppState`.
pub type SharedState = Arc<RwLock<AppState>>;

/// Create the default shared state.
pub fn shared_state() -> SharedState {
    Arc::new(RwLock::new(AppState::default()))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_state() {
        let state = AppState::default();
        assert_eq!(state.gateway_url, "http://127.0.0.1:42617");
        assert!(state.token.is_none());
        assert!(!state.connected);
        assert_eq!(state.agent_status, AgentStatus::Idle);
    }

    #[test]
    fn shared_state_is_cloneable() {
        let s1 = shared_state();
        let s2 = s1.clone();
        // Both references point to the same allocation.
        assert!(Arc::ptr_eq(&s1, &s2));
    }

    #[tokio::test]
    async fn shared_state_concurrent_read_write() {
        let state = shared_state();

        // Write from one handle.
        {
            let mut s = state.write().await;
            s.connected = true;
            s.agent_status = AgentStatus::Working;
            s.token = Some("zc_test".to_string());
        }

        // Read from cloned handle.
        let state2 = state.clone();
        let s = state2.read().await;
        assert!(s.connected);
        assert_eq!(s.agent_status, AgentStatus::Working);
        assert_eq!(s.token.as_deref(), Some("zc_test"));
    }

    #[test]
    fn agent_status_serialization() {
        assert_eq!(serde_json::to_string(&AgentStatus::Idle).unwrap(), "\"idle\"");
        assert_eq!(serde_json::to_string(&AgentStatus::Working).unwrap(), "\"working\"");
        assert_eq!(serde_json::to_string(&AgentStatus::Error).unwrap(), "\"error\"");
    }
}
@@ -0,0 +1,25 @@
//! Tray menu event handling.

use tauri::{menu::MenuEvent, AppHandle, Manager, Runtime};

pub fn handle_menu_event<R: Runtime>(app: &AppHandle<R>, event: MenuEvent) {
    match event.id().as_ref() {
        "show" => show_main_window(app, None),
        "chat" => show_main_window(app, Some("/agent")),
        "quit" => {
            app.exit(0);
        }
        _ => {}
    }
}

fn show_main_window<R: Runtime>(app: &AppHandle<R>, navigate_to: Option<&str>) {
    if let Some(window) = app.get_webview_window("main") {
        let _ = window.show();
        let _ = window.set_focus();
        if let Some(path) = navigate_to {
            let script = format!("window.location.hash = '{path}'");
            let _ = window.eval(&script);
        }
    }
}
@@ -0,0 +1,105 @@
//! Tray icon management — swap icon based on connection/agent status.

use crate::state::AgentStatus;
use tauri::image::Image;

/// Embedded tray icon PNGs (22x22, RGBA).
const ICON_IDLE: &[u8] = include_bytes!("../../icons/tray-idle.png");
const ICON_WORKING: &[u8] = include_bytes!("../../icons/tray-working.png");
const ICON_ERROR: &[u8] = include_bytes!("../../icons/tray-error.png");
const ICON_DISCONNECTED: &[u8] = include_bytes!("../../icons/tray-disconnected.png");

/// Select the appropriate tray icon for the current state.
pub fn icon_for_state(connected: bool, status: AgentStatus) -> Image<'static> {
    let bytes: &[u8] = if !connected {
        ICON_DISCONNECTED
    } else {
        match status {
            AgentStatus::Idle => ICON_IDLE,
            AgentStatus::Working => ICON_WORKING,
            AgentStatus::Error => ICON_ERROR,
        }
    };
    Image::from_bytes(bytes).expect("embedded tray icon is a valid PNG")
}

/// Tooltip text for the current state.
pub fn tooltip_for_state(connected: bool, status: AgentStatus) -> &'static str {
    if !connected {
        return "ZeroClaw — Disconnected";
    }
    match status {
        AgentStatus::Idle => "ZeroClaw — Idle",
        AgentStatus::Working => "ZeroClaw — Working",
        AgentStatus::Error => "ZeroClaw — Error",
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn icon_disconnected_when_not_connected() {
        // Should not panic — icon bytes are valid PNGs.
        let _img = icon_for_state(false, AgentStatus::Idle);
        let _img = icon_for_state(false, AgentStatus::Working);
        let _img = icon_for_state(false, AgentStatus::Error);
    }

    #[test]
    fn icon_connected_variants() {
        let _idle = icon_for_state(true, AgentStatus::Idle);
        let _working = icon_for_state(true, AgentStatus::Working);
        let _error = icon_for_state(true, AgentStatus::Error);
    }

    #[test]
    fn tooltip_disconnected() {
        assert_eq!(tooltip_for_state(false, AgentStatus::Idle), "ZeroClaw — Disconnected");
        // Agent status is irrelevant when disconnected.
        assert_eq!(tooltip_for_state(false, AgentStatus::Working), "ZeroClaw — Disconnected");
        assert_eq!(tooltip_for_state(false, AgentStatus::Error), "ZeroClaw — Disconnected");
    }

    #[test]
    fn tooltip_connected_variants() {
        assert_eq!(tooltip_for_state(true, AgentStatus::Idle), "ZeroClaw — Idle");
        assert_eq!(tooltip_for_state(true, AgentStatus::Working), "ZeroClaw — Working");
        assert_eq!(tooltip_for_state(true, AgentStatus::Error), "ZeroClaw — Error");
    }

    #[test]
    fn embedded_icons_are_valid_png() {
        // Verify the PNG signature (first 8 bytes) of each embedded icon.
        let png_sig: &[u8] = &[0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A];
        assert!(ICON_IDLE.starts_with(png_sig), "idle icon not valid PNG");
        assert!(ICON_WORKING.starts_with(png_sig), "working icon not valid PNG");
        assert!(ICON_ERROR.starts_with(png_sig), "error icon not valid PNG");
        assert!(ICON_DISCONNECTED.starts_with(png_sig), "disconnected icon not valid PNG");
    }
}
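The signature check in that last test relies on the fact that every PNG file begins with the same fixed eight bytes. The same check can be reproduced standalone (illustrative Python, not part of the codebase):

```python
# The fixed 8-byte PNG signature: 0x89, 'P', 'N', 'G', CR, LF, SUB, LF.
PNG_SIG = bytes([0x89, ord("P"), ord("N"), ord("G"), 0x0D, 0x0A, 0x1A, 0x0A])

def looks_like_png(data: bytes) -> bool:
    # A cheap sanity check that embedded asset bytes are actually a PNG.
    return data.startswith(PNG_SIG)

print(looks_like_png(PNG_SIG + b"...chunks..."))  # prints: True
print(looks_like_png(b"GIF89a"))                  # prints: False
```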
@@ -0,0 +1,19 @@
//! Tray menu construction.

use tauri::{
    menu::{Menu, MenuItemBuilder, PredefinedMenuItem},
    App, Runtime,
};

pub fn create_tray_menu<R: Runtime>(app: &App<R>) -> Result<Menu<R>, tauri::Error> {
    let show = MenuItemBuilder::with_id("show", "Show Dashboard").build(app)?;
    let chat = MenuItemBuilder::with_id("chat", "Agent Chat").build(app)?;
    let sep1 = PredefinedMenuItem::separator(app)?;
    let status = MenuItemBuilder::with_id("status", "Status: Checking...")
        .enabled(false)
        .build(app)?;
    let sep2 = PredefinedMenuItem::separator(app)?;
    let quit = MenuItemBuilder::with_id("quit", "Quit ZeroClaw").build(app)?;

    Menu::with_items(app, &[&show, &chat, &sep1, &status, &sep2, &quit])
}
@@ -0,0 +1,34 @@
//! System tray integration for ZeroClaw Desktop.

pub mod events;
pub mod icon;
pub mod menu;

use tauri::{
    tray::{TrayIcon, TrayIconBuilder, TrayIconEvent},
    App, Manager, Runtime,
};

/// Set up the system tray icon and menu.
pub fn setup_tray<R: Runtime>(app: &App<R>) -> Result<TrayIcon<R>, tauri::Error> {
    let menu = menu::create_tray_menu(app)?;

    TrayIconBuilder::with_id("main")
        .tooltip("ZeroClaw — Disconnected")
        .icon(icon::icon_for_state(false, crate::state::AgentStatus::Idle))
        .menu(&menu)
        .show_menu_on_left_click(false)
        .on_menu_event(events::handle_menu_event)
        .on_tray_icon_event(|tray, event| {
            if let TrayIconEvent::Click { button, .. } = event {
                if button == tauri::tray::MouseButton::Left {
                    let app = tray.app_handle();
                    if let Some(window) = app.get_webview_window("main") {
                        let _ = window.show();
                        let _ = window.set_focus();
                    }
                }
            }
        })
        .build(app)
}
@@ -0,0 +1,35 @@
{
  "$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/dev/crates/tauri-cli/config.schema.json",
  "productName": "ZeroClaw",
  "version": "0.6.0",
  "identifier": "ai.zeroclawlabs.desktop",
  "build": {
    "devUrl": "http://127.0.0.1:42617/_app/",
    "frontendDist": "http://127.0.0.1:42617/_app/"
  },
  "app": {
    "windows": [
      {
        "title": "ZeroClaw",
        "width": 1200,
        "height": 800,
        "resizable": true,
        "fullscreen": false,
        "visible": false
      }
    ],
    "security": {
      "csp": "default-src 'self' http://127.0.0.1:* ws://127.0.0.1:*; connect-src 'self' http://127.0.0.1:* ws://127.0.0.1:*; script-src 'self' 'unsafe-inline' http://127.0.0.1:*; style-src 'self' 'unsafe-inline' http://127.0.0.1:*; img-src 'self' http://127.0.0.1:* data:"
    }
  },
  "bundle": {
    "active": true,
    "targets": "all",
    "icon": [
      "icons/32x32.png",
      "icons/128x128.png",
      "icons/icon.icns",
      "icons/icon.ico"
    ]
  }
}
@@ -1,80 +0,0 @@
#!/bin/bash
# Start mem0 + reranker GPU container for ZeroClaw memory backend.
#
# Required env vars:
#   MEM0_LLM_API_KEY or ZAI_API_KEY — API key for the LLM used in fact extraction
#
# Optional env vars (with defaults):
#   MEM0_LLM_PROVIDER — mem0 LLM provider (default: "openai", i.e. OpenAI-compatible)
#   MEM0_LLM_MODEL — LLM model for fact extraction (default: "glm-5-turbo")
#   MEM0_LLM_BASE_URL — LLM API base URL (default: "https://api.z.ai/api/coding/paas/v4")
#   MEM0_EMBEDDER_MODEL — embedding model (default: "BAAI/bge-m3")
#   MEM0_EMBEDDER_DIMS — embedding dimensions (default: "1024")
#   MEM0_EMBEDDER_DEVICE — "cuda", "cpu", or "auto" (default: "cuda")
#   MEM0_VECTOR_COLLECTION — Qdrant collection name (default: "zeroclaw_mem0")
#   RERANKER_MODEL — reranker model (default: "BAAI/bge-reranker-v2-m3")
#   RERANKER_DEVICE — "cuda" or "cpu" (default: "cuda")
#   MEM0_PORT — mem0 server port (default: 8765)
#   RERANKER_PORT — reranker server port (default: 8678)
#   CONTAINER_IMAGE — base container image (default: docker.io/kyuz0/amd-strix-halo-comfyui:latest)
#   CONTAINER_NAME — container name (default: mem0-gpu)
#   DATA_DIR — host path for Qdrant data (default: ~/mem0-data)
#   SCRIPT_DIR — host path for server scripts (default: directory of this script)
set -e

# Resolve script directory for mounting server scripts
SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "$0")" && pwd)}"

# API key — accept either name
export MEM0_LLM_API_KEY="${MEM0_LLM_API_KEY:-${ZAI_API_KEY:?MEM0_LLM_API_KEY or ZAI_API_KEY must be set}}"

# Defaults
MEM0_LLM_MODEL="${MEM0_LLM_MODEL:-glm-5-turbo}"
MEM0_LLM_BASE_URL="${MEM0_LLM_BASE_URL:-https://api.z.ai/api/coding/paas/v4}"
MEM0_PORT="${MEM0_PORT:-8765}"
RERANKER_PORT="${RERANKER_PORT:-8678}"
CONTAINER_IMAGE="${CONTAINER_IMAGE:-docker.io/kyuz0/amd-strix-halo-comfyui:latest}"
CONTAINER_NAME="${CONTAINER_NAME:-mem0-gpu}"
DATA_DIR="${DATA_DIR:-$HOME/mem0-data}"

# Stop existing CPU services (if any)
kill -9 $(pgrep -f "mem0-server.py") 2>/dev/null || true
kill -9 $(pgrep -f "reranker-server.py") 2>/dev/null || true

# Stop existing container
podman stop "$CONTAINER_NAME" 2>/dev/null || true
podman rm "$CONTAINER_NAME" 2>/dev/null || true

podman run -d --name "$CONTAINER_NAME" \
  --device /dev/dri --device /dev/kfd \
  --group-add video --group-add render \
  --restart unless-stopped \
  -p "$MEM0_PORT:$MEM0_PORT" -p "$RERANKER_PORT:$RERANKER_PORT" \
  -v "$DATA_DIR":/root/mem0-data:Z \
  -v "$SCRIPT_DIR/mem0-server.py":/app/mem0-server.py:ro,Z \
  -v "$SCRIPT_DIR/reranker-server.py":/app/reranker-server.py:ro,Z \
  -v "$HOME/.cache/huggingface":/root/.cache/huggingface:Z \
  -e MEM0_LLM_API_KEY="$MEM0_LLM_API_KEY" \
  -e ZAI_API_KEY="$MEM0_LLM_API_KEY" \
  -e MEM0_LLM_MODEL="$MEM0_LLM_MODEL" \
  -e MEM0_LLM_BASE_URL="$MEM0_LLM_BASE_URL" \
  ${MEM0_LLM_PROVIDER:+-e MEM0_LLM_PROVIDER="$MEM0_LLM_PROVIDER"} \
  ${MEM0_EMBEDDER_MODEL:+-e MEM0_EMBEDDER_MODEL="$MEM0_EMBEDDER_MODEL"} \
  ${MEM0_EMBEDDER_DIMS:+-e MEM0_EMBEDDER_DIMS="$MEM0_EMBEDDER_DIMS"} \
  ${MEM0_EMBEDDER_DEVICE:+-e MEM0_EMBEDDER_DEVICE="$MEM0_EMBEDDER_DEVICE"} \
  ${MEM0_VECTOR_COLLECTION:+-e MEM0_VECTOR_COLLECTION="$MEM0_VECTOR_COLLECTION"} \
  ${RERANKER_MODEL:+-e RERANKER_MODEL="$RERANKER_MODEL"} \
  ${RERANKER_DEVICE:+-e RERANKER_DEVICE="$RERANKER_DEVICE"} \
  -e RERANKER_PORT="$RERANKER_PORT" \
  -e RERANKER_URL="http://127.0.0.1:$RERANKER_PORT/rerank" \
  -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
  -e HOME=/root \
  "$CONTAINER_IMAGE" \
  bash -c "pip install -q FlagEmbedding mem0ai flask httpx qdrant-client 2>&1 | tail -3; echo '=== Starting reranker (GPU) on :$RERANKER_PORT ==='; python3 /app/reranker-server.py & sleep 3; echo '=== Starting mem0 (GPU) on :$MEM0_PORT ==='; exec python3 /app/mem0-server.py"

echo "Container started, waiting for init..."
sleep 15
echo "=== Container logs ==="
podman logs "$CONTAINER_NAME" 2>&1 | tail -25
echo "=== Port check ==="
ss -tlnp | grep "$MEM0_PORT\|$RERANKER_PORT" || echo "Ports not yet ready, check: podman logs $CONTAINER_NAME"
@@ -1,288 +0,0 @@
"""Minimal OpenMemory-compatible REST server wrapping mem0 Python SDK."""
import asyncio
import json, os, uuid, httpx
from datetime import datetime, timezone
from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import Optional
from mem0 import Memory

app = FastAPI()

RERANKER_URL = os.environ.get("RERANKER_URL", "http://127.0.0.1:8678/rerank")

CUSTOM_EXTRACTION_PROMPT = """You are a memory extraction specialist for a Cantonese/Chinese chat assistant.

Extract ONLY important, persistent facts from the conversation. Rules:
1. Extract personal preferences, habits, relationships, names, locations
2. Extract decisions, plans, and commitments people make
3. SKIP small talk, greetings, reactions ("ok", "哈哈", "係呀")
4. SKIP temporary states ("我依家食緊飯") unless they reveal a habit
5. Keep facts in the ORIGINAL language (Cantonese/Chinese/English)
6. For each fact, note WHO it's about (use their name or identifier if known)
7. Merge/update existing facts rather than creating duplicates

Return a list of facts in JSON format: {"facts": ["fact1", "fact2", ...]}
"""

PROCEDURAL_EXTRACTION_PROMPT = """You are a procedural memory specialist for an AI assistant.

Extract HOW-TO patterns and reusable procedures from the conversation trace. Rules:
1. Identify step-by-step procedures the assistant followed to accomplish a task
2. Extract tool usage patterns: which tools were called, in what order, with what arguments
3. Capture decision points: why the assistant chose one approach over another
4. Note error-recovery patterns: what failed, how it was fixed
5. Keep the procedure generic enough to apply to similar future tasks
6. Preserve technical details (commands, file paths, API calls) that are reusable
7. SKIP greetings, small talk, and conversational filler
8. Format each procedure as: "To [goal]: [step1] -> [step2] -> ... -> [result]"

Return a list of procedures in JSON format: {"facts": ["procedure1", "procedure2", ...]}
"""

# ── Configurable via environment variables ─────────────────────────
# LLM (for fact extraction when infer=true)
MEM0_LLM_PROVIDER = os.environ.get("MEM0_LLM_PROVIDER", "openai")  # "openai" (compatible), "anthropic", etc.
MEM0_LLM_MODEL = os.environ.get("MEM0_LLM_MODEL", "glm-5-turbo")
MEM0_LLM_API_KEY = os.environ.get("MEM0_LLM_API_KEY") or os.environ.get("ZAI_API_KEY", "")
MEM0_LLM_BASE_URL = os.environ.get("MEM0_LLM_BASE_URL", "https://api.z.ai/api/coding/paas/v4")

# Embedder
MEM0_EMBEDDER_PROVIDER = os.environ.get("MEM0_EMBEDDER_PROVIDER", "huggingface")  # "huggingface", "openai", etc.
MEM0_EMBEDDER_MODEL = os.environ.get("MEM0_EMBEDDER_MODEL", "BAAI/bge-m3")
MEM0_EMBEDDER_DIMS = int(os.environ.get("MEM0_EMBEDDER_DIMS", "1024"))
MEM0_EMBEDDER_DEVICE = os.environ.get("MEM0_EMBEDDER_DEVICE", "cuda")  # "cuda", "cpu", "auto"

# Vector store
MEM0_VECTOR_PROVIDER = os.environ.get("MEM0_VECTOR_PROVIDER", "qdrant")  # "qdrant", "chroma", etc.
MEM0_VECTOR_COLLECTION = os.environ.get("MEM0_VECTOR_COLLECTION", "zeroclaw_mem0")
MEM0_VECTOR_PATH = os.environ.get("MEM0_VECTOR_PATH", os.path.expanduser("~/mem0-data/qdrant"))

config = {
    "llm": {
        "provider": MEM0_LLM_PROVIDER,
        "config": {
            "model": MEM0_LLM_MODEL,
            "api_key": MEM0_LLM_API_KEY,
            "openai_base_url": MEM0_LLM_BASE_URL,
        },
    },
    "embedder": {
        "provider": MEM0_EMBEDDER_PROVIDER,
        "config": {
            "model": MEM0_EMBEDDER_MODEL,
            "embedding_dims": MEM0_EMBEDDER_DIMS,
            "model_kwargs": {"device": MEM0_EMBEDDER_DEVICE},
        },
    },
    "vector_store": {
        "provider": MEM0_VECTOR_PROVIDER,
        "config": {
            "collection_name": MEM0_VECTOR_COLLECTION,
            "embedding_model_dims": MEM0_EMBEDDER_DIMS,
            "path": MEM0_VECTOR_PATH,
        },
    },
    "custom_fact_extraction_prompt": CUSTOM_EXTRACTION_PROMPT,
}

m = Memory.from_config(config)


def rerank_results(query: str, items: list, top_k: int = 10) -> list:
    """Rerank search results using bge-reranker-v2-m3."""
    if not items:
        return items
    documents = [item.get("memory", "") for item in items]
    try:
        resp = httpx.post(
            RERANKER_URL,
            json={"query": query, "documents": documents, "top_k": top_k},
            timeout=10.0,
        )
        resp.raise_for_status()
        ranked = resp.json().get("results", [])
        return [items[r["index"]] for r in ranked]
    except Exception as e:
        print(f"Reranker failed, using original order: {e}")
        return items

class AddMemoryRequest(BaseModel):
    user_id: str
    text: str
    metadata: Optional[dict] = None
    infer: bool = True
    app: Optional[str] = None
    custom_instructions: Optional[str] = None


@app.post("/api/v1/memories/")
async def add_memory(req: AddMemoryRequest):
    # Use the client-supplied extraction prompt if given, else the server default
    prompt = req.custom_instructions or CUSTOM_EXTRACTION_PROMPT
    result = await asyncio.to_thread(
        m.add, req.text, user_id=req.user_id, metadata=req.metadata or {}, prompt=prompt
    )
    return {"id": str(uuid.uuid4()), "status": "ok", "result": result}


class ProceduralMemoryRequest(BaseModel):
    user_id: str
    messages: list[dict]
    metadata: Optional[dict] = None


@app.post("/api/v1/memories/procedural")
async def add_procedural_memory(req: ProceduralMemoryRequest):
    """Store a conversation trace as procedural memory.

    Accepts a list of messages (role/content dicts) representing a full
    conversation turn including tool calls, then uses mem0's native
    procedural memory extraction to learn reusable "how to" patterns.
    """
    # Build metadata with procedural type marker
    meta = {"type": "procedural"}
    if req.metadata:
        meta.update(req.metadata)

    # Use mem0's native message list support + procedural prompt
    result = await asyncio.to_thread(
        m.add,
        req.messages,
        user_id=req.user_id,
        metadata=meta,
        prompt=PROCEDURAL_EXTRACTION_PROMPT,
    )

    return {"id": str(uuid.uuid4()), "status": "ok", "result": result}

def _parse_mem0_results(raw_results) -> list:
    raw = raw_results.get("results", raw_results) if isinstance(raw_results, dict) else raw_results
    items = []
    for r in raw:
        item = r if isinstance(r, dict) else {"memory": str(r)}
        items.append({
            "id": item.get("id", str(uuid.uuid4())),
            "memory": item.get("memory", item.get("text", "")),
            "created_at": item.get("created_at", datetime.now(timezone.utc).isoformat()),
            "metadata_": item.get("metadata", {}),
        })
    return items


def _parse_iso_timestamp(value: str) -> Optional[datetime]:
    """Parse an ISO 8601 timestamp string, returning None on failure."""
    try:
        dt = datetime.fromisoformat(value)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    except (ValueError, TypeError):
        return None


def _item_created_at(item: dict) -> Optional[datetime]:
    """Extract created_at from an item as a timezone-aware datetime."""
    raw = item.get("created_at")
    if raw is None:
        return None
    if isinstance(raw, datetime):
        if raw.tzinfo is None:
            raw = raw.replace(tzinfo=timezone.utc)
        return raw
    return _parse_iso_timestamp(str(raw))


def _apply_post_filters(
    items: list,
    created_after: Optional[str],
    created_before: Optional[str],
) -> list:
    """Filter items by created_after / created_before timestamps (post-query)."""
    after_dt = _parse_iso_timestamp(created_after) if created_after else None
    before_dt = _parse_iso_timestamp(created_before) if created_before else None
    if after_dt is None and before_dt is None:
        return items
    filtered = []
    for item in items:
        ts = _item_created_at(item)
        if ts is None:
            # Keep items without a parseable timestamp
            filtered.append(item)
            continue
        if after_dt and ts < after_dt:
            continue
        if before_dt and ts > before_dt:
            continue
        filtered.append(item)
    return filtered

@app.get("/api/v1/memories/")
async def list_or_search_memories(
    user_id: str = Query(...),
    search_query: Optional[str] = Query(None),
    size: int = Query(10),
    rerank: bool = Query(True),
    created_after: Optional[str] = Query(None),
    created_before: Optional[str] = Query(None),
    metadata_filter: Optional[str] = Query(None),
):
    # Build mem0 SDK filters dict from metadata_filter JSON param
    sdk_filters = None
    if metadata_filter:
        try:
            sdk_filters = json.loads(metadata_filter)
        except json.JSONDecodeError:
            sdk_filters = None

    if search_query:
        # Fetch more results than needed so reranker has candidates to work with
        fetch_size = min(size * 3, 50)
        results = await asyncio.to_thread(
            m.search,
            search_query,
            user_id=user_id,
            limit=fetch_size,
            filters=sdk_filters,
        )
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        if rerank and items:
            items = rerank_results(search_query, items, top_k=size)
        else:
            items = items[:size]
        return {"items": items, "total": len(items)}
    else:
        results = await asyncio.to_thread(m.get_all, user_id=user_id, filters=sdk_filters)
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        return {"items": items, "total": len(items)}


@app.delete("/api/v1/memories/{memory_id}")
async def delete_memory(memory_id: str):
    try:
        await asyncio.to_thread(m.delete, memory_id)
    except Exception:
        pass
    return {"status": "ok"}


@app.get("/api/v1/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Return the edit history of a specific memory."""
    try:
        history = await asyncio.to_thread(m.history, memory_id)
        # Normalize to a list of dicts regardless of SDK return shape
        if isinstance(history, list):
            raw = history
        elif isinstance(history, dict):
            raw = history.get("results", history)
        else:
            raw = [history]
        entries = []
        for h in raw:
            entry = h if isinstance(h, dict) else {"event": str(h)}
            entries.append(entry)
        return {"memory_id": memory_id, "history": entries}
    except Exception as e:
        return {"memory_id": memory_id, "history": [], "error": str(e)}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8765)
@@ -1,50 +0,0 @@
from flask import Flask, request, jsonify
from FlagEmbedding import FlagReranker
import os, torch

app = Flask(__name__)
reranker = None

# ── Configurable via environment variables ─────────────────────────
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
RERANKER_DEVICE = os.environ.get("RERANKER_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
RERANKER_PORT = int(os.environ.get("RERANKER_PORT", "8678"))


def get_reranker():
    global reranker
    if reranker is None:
        reranker = FlagReranker(RERANKER_MODEL, use_fp16=True, device=RERANKER_DEVICE)
    return reranker


@app.route('/rerank', methods=['POST'])
def rerank():
    data = request.json
    query = data.get('query', '')
    documents = data.get('documents', [])
    top_k = data.get('top_k', len(documents))

    if not query or not documents:
        return jsonify({'error': 'query and documents required'}), 400

    pairs = [[query, doc] for doc in documents]
    scores = get_reranker().compute_score(pairs)
    if isinstance(scores, float):
        scores = [scores]

    results = sorted(
        [{'index': i, 'document': doc, 'score': score}
         for i, (doc, score) in enumerate(zip(documents, scores))],
        key=lambda x: x['score'], reverse=True
    )[:top_k]

    return jsonify({'results': results})


@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'ok', 'model': RERANKER_MODEL, 'device': RERANKER_DEVICE})


if __name__ == '__main__':
    print(f'Loading reranker model ({RERANKER_MODEL}) on {RERANKER_DEVICE}...')
    get_reranker()
    print(f'Reranker server ready on :{RERANKER_PORT}')
    app.run(host='0.0.0.0', port=RERANKER_PORT)
@@ -10,3 +10,22 @@ default_temperature = 0.7
port = 42617
host = "[::]"
allow_public_bind = true

# Cost tracking and budget enforcement configuration
# Enable to track API usage costs and enforce spending limits
[cost]
enabled = false
daily_limit_usd = 10.0
monthly_limit_usd = 100.0
warn_at_percent = 80
allow_override = false

# Per-model pricing (USD per 1M tokens)
# Uncomment and customize to override default pricing
# [cost.prices."anthropic/claude-sonnet-4-20250514"]
# input = 3.0
# output = 15.0
#
# [cost.prices."openai/gpt-4o"]
# input = 5.0
# output = 15.0
@@ -1,6 +1,6 @@
 pkgbase = zeroclaw
 	pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
-	pkgver = 0.5.6
+	pkgver = 0.5.9
 	pkgrel = 1
 	url = https://github.com/zeroclaw-labs/zeroclaw
 	arch = x86_64
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
 	makedepends = git
 	depends = gcc-libs
 	depends = openssl
-	source = zeroclaw-0.5.6.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.6.tar.gz
+	source = zeroclaw-0.5.9.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.9.tar.gz
 	sha256sums = SKIP

 pkgname = zeroclaw
@@ -1,6 +1,6 @@
 # Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
 pkgname=zeroclaw
-pkgver=0.5.6
+pkgver=0.5.9
 pkgrel=1
 pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
 arch=('x86_64')
@@ -1,11 +1,11 @@
 {
-    "version": "0.5.6",
+    "version": "0.5.9",
     "description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
     "homepage": "https://github.com/zeroclaw-labs/zeroclaw",
     "license": "MIT|Apache-2.0",
     "architecture": {
         "64bit": {
-            "url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.6/zeroclaw-x86_64-pc-windows-msvc.zip",
+            "url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.9/zeroclaw-x86_64-pc-windows-msvc.zip",
             "hash": "",
             "bin": "zeroclaw.exe"
         }
@@ -0,0 +1,202 @@
# ADR-004: Tool Shared State Ownership Contract

**Status:** Accepted

**Date:** 2026-03-22

**Issue:** [#4057](https://github.com/zeroclaw/zeroclaw/issues/4057)

## Context

ZeroClaw tools execute in a multi-client environment where a single daemon
process serves requests from multiple connected clients simultaneously. Several
tools already maintain long-lived shared state:

- **`DelegateParentToolsHandle`** (`src/tools/mod.rs`):
  `Arc<RwLock<Vec<Arc<dyn Tool>>>>` — holds parent tools for delegate agents
  with no per-client isolation.
- **`ChannelMapHandle`** (`src/tools/reaction.rs`):
  `Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>` — global channel map shared
  across all clients.
- **`CanvasStore`** (`src/tools/canvas.rs`):
  `Arc<RwLock<HashMap<String, CanvasEntry>>>` — canvas IDs are plain strings
  with no client namespace.

These patterns emerged organically. As the tool surface grows and more clients
connect concurrently, we need a clear contract governing ownership, identity,
isolation, lifecycle, and reload behavior for tool-held shared state. Without
this contract, new tools risk introducing data leaks between clients, stale
state after config reloads, or inconsistent initialization timing.

Additional context:

- The tool registry is immutable after startup, built once in
  `all_tools_with_runtime()`.
- Client identity is currently derived from IP address only
  (`src/gateway/mod.rs`), which is insufficient for reliable namespacing.
- `SecurityPolicy` is scoped per agent, not per client.
- `WorkspaceManager` provides some isolation but workspace switching is global.

## Decision

### 1. Ownership: May tools own long-lived shared state?

**Yes.** Tools MAY own long-lived shared state, provided they follow the
established **handle pattern**: wrap the state in `Arc<RwLock<T>>` (or
`Arc<parking_lot::RwLock<T>>`) and expose a cloneable handle type.

This pattern is already proven by three independent implementations:

| Handle | Location | Inner type |
|--------|----------|-----------|
| `DelegateParentToolsHandle` | `src/tools/mod.rs` | `Vec<Arc<dyn Tool>>` |
| `ChannelMapHandle` | `src/tools/reaction.rs` | `HashMap<String, Arc<dyn Channel>>` |
| `CanvasStore` | `src/tools/canvas.rs` | `HashMap<String, CanvasEntry>` |

Tools that need shared state MUST:

- Define a named handle type alias (e.g., `pub type FooHandle = Arc<RwLock<T>>`).
- Accept the handle at construction time rather than creating global state.
- Document the concurrency contract in the handle type's doc comment.

Tools MUST NOT use static mutable state (`lazy_static!`, `OnceCell` with
interior mutability) for per-request or per-client data.
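A minimal sketch of the handle pattern (the `QuotaTool` name and its token-counting logic are illustrative, not part of ZeroClaw):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Hypothetical per-client token counter. Concurrency contract: take
/// short write locks only; never hold the guard across long operations.
pub type QuotaHandle = Arc<RwLock<HashMap<String, u64>>>;

pub struct QuotaTool {
    // The handle is accepted at construction — never created as a global.
    quotas: QuotaHandle,
}

impl QuotaTool {
    pub fn new(quotas: QuotaHandle) -> Self {
        Self { quotas }
    }

    /// Record token usage for a client and return its running total.
    pub fn record(&self, client: &str, tokens: u64) -> u64 {
        let mut map = self.quotas.write().unwrap();
        let total = map.entry(client.to_string()).or_insert(0);
        *total += tokens;
        *total
    }
}
```

Because the handle is `Clone`, the daemon can hand the same state to several tool instances while each one stays free of global statics.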
### 2. Identity assignment: Who constructs identity keys?

**The daemon SHOULD provide identity.** Tools MUST NOT construct their own
client identity keys.

A new `ClientId` type should be introduced (opaque, `Clone + Eq + Hash + Send + Sync`)
that the daemon assigns at connection time. This replaces the current approach
of using raw IP addresses (`src/gateway/mod.rs:259-306`), which breaks when
multiple clients share a NAT address or when proxied connections arrive.

`ClientId` is passed to tools that require per-client state namespacing as part
of the tool execution context. Tools that do not need per-client isolation
(e.g., the immutable tool registry) may ignore it.

The `ClientId` contract:

- Generated by the gateway layer at connection establishment.
- Opaque to tools — tools must not parse or derive meaning from the value.
- Stable for the lifetime of a single client session.
- Passed through the execution context, not stored globally.
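One possible shape for the type, as a sketch only — the `mint` constructor and its counter are assumptions; the real assignment mechanism belongs to the gateway:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Opaque client identity. Tools may clone, compare, and hash it,
/// but must not parse or derive meaning from the inner value.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ClientId(u64);

static NEXT_CLIENT_ID: AtomicU64 = AtomicU64::new(1);

impl ClientId {
    /// Minted by the gateway at connection establishment; stable for
    /// the lifetime of the session. Tools never call this themselves.
    pub fn mint() -> Self {
        ClientId(NEXT_CLIENT_ID.fetch_add(1, Ordering::Relaxed))
    }
}
```

Keeping the inner field private enforces opacity at the type level: a tool physically cannot derive meaning from the value.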
### 3. Lifecycle: When may tools run startup-style validation?

**Validation runs once at first registration, and again when config changes
are detected.**

The lifecycle phases are:

1. **Construction** — tool is instantiated with handles and config. No I/O or
   validation occurs here.
2. **Registration** — tool is registered in the tool registry via
   `all_tools_with_runtime()`. At this point the tool MAY perform one-time
   startup validation (e.g., checking that required credentials exist, verifying
   external service connectivity).
3. **Execution** — tool handles individual requests. No re-validation unless
   the config-change signal fires (see Reload Semantics below).
4. **Shutdown** — daemon is stopping. Tools with open resources SHOULD clean up
   gracefully via `Drop` or an explicit shutdown method.

Tools MUST NOT perform blocking validation during execution-phase calls.
Validation results SHOULD be cached in the tool's handle state and checked
via a fast path during execution.

### 4. Isolation: What must be isolated per client?

State falls into two categories with different isolation requirements:

**MUST be isolated per client:**

- Security-sensitive state: credentials, API keys, quotas, rate-limit counters,
  per-client authorization decisions.
- User-specific session data: conversation context, user preferences,
  workspace-scoped file paths.

Isolation mechanism: tools holding per-client state MUST key their internal
maps by `ClientId`. The handle pattern naturally supports this by using
`HashMap<ClientId, T>` inside the `RwLock`.

**MAY be shared across clients (with namespace prefixing):**

- Broadcast/display state: canvas frames (`CanvasStore`), notification channels
  (`ChannelMapHandle`).
- Read-only reference data: tool registry, static configuration, model
  metadata.

When shared state uses string keys (e.g., canvas IDs, channel names), tools
SHOULD support optional namespace prefixing (e.g., `{client_id}:{canvas_name}`)
to allow per-client isolation when needed without mandating it for broadcast
use cases.

Tools MUST NOT store per-client secrets in shared (non-isolated) state
structures.
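The optional prefixing rule can be sketched as a small helper (the function name and signature are illustrative):

```rust
/// Build a store key for shared string-keyed state. With a client id the
/// key is namespaced ("{client_id}:{name}"); without one it stays global,
/// which suits broadcast use cases.
fn namespaced_key(client_id: Option<&str>, name: &str) -> String {
    match client_id {
        Some(id) => format!("{id}:{name}"),
        None => name.to_string(),
    }
}
```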
### 5. Reload semantics: What invalidates prior shared state on config change?

**Config changes detected via hash comparison MUST invalidate cached
validation state.**

The reload contract:

- The daemon computes a hash of the tool-relevant config section at startup and
  after each config reload event.
- When the hash changes, the daemon signals affected tools to re-run their
  registration-phase validation.
- Tools MUST treat their cached validation result as stale when signaled and
  re-validate before the next execution.

Specific invalidation rules:

| Config change | Invalidation scope |
|--------------|-------------------|
| Credential/secret rotation | Per-tool validation cache; per-client credential state |
| Tool enable/disable | Full tool registry rebuild via `all_tools_with_runtime()` |
| Security policy change | `SecurityPolicy` re-derivation; per-agent policy state |
| Workspace directory change | `WorkspaceManager` state; file-path-dependent tool state |
| Provider config change | Provider-dependent tools re-validate connectivity |

Tools MAY retain non-security shared state (e.g., canvas content, channel
subscriptions) across config reloads unless the reload explicitly affects that
state's validity.
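A minimal sketch of the hash-based staleness check described above (struct and method names are illustrative; the daemon's signal plumbing is out of scope):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hash the tool-relevant config section; `DefaultHasher::new()` is
// deterministic within a process, which is all the comparison needs.
fn config_hash(section: &str) -> u64 {
    let mut h = DefaultHasher::new();
    section.hash(&mut h);
    h.finish()
}

/// Cached registration-phase validation result, keyed by the hash of
/// the tool-relevant config section.
struct ValidationCache {
    hash: u64,
    validated: bool,
}

impl ValidationCache {
    fn new() -> Self {
        Self { hash: 0, validated: false }
    }

    /// Fast path checked during execution: stale if never validated or
    /// if the config section's hash changed since the last validation.
    fn is_stale(&self, current_section: &str) -> bool {
        !self.validated || self.hash != config_hash(current_section)
    }

    fn mark_validated(&mut self, section: &str) {
        self.hash = config_hash(section);
        self.validated = true;
    }
}
```

The execution path only compares two integers in the common case; the expensive re-validation happens solely when the hash has actually moved.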
## Consequences

### Positive

- **Consistency:** All new tools follow the same handle pattern, making shared
  state discoverable and auditable.
- **Safety:** Per-client isolation of security-sensitive state prevents data
  leaks in multi-tenant scenarios.
- **Clarity:** Explicit lifecycle phases eliminate ambiguity about when
  validation runs.
- **Evolvability:** The `ClientId` abstraction decouples tools from transport
  details, supporting future identity mechanisms (tokens, certificates).

### Negative

- **Migration cost:** Existing tools (`CanvasStore`, `ReactionTool`) may need
  refactoring to accept `ClientId` and namespace their state.
- **Complexity:** Tools that were simple singletons now need to consider
  multi-client semantics even if they currently have one client.
- **Performance:** Per-client keying adds a hash lookup on each access, though
  this is negligible compared to I/O costs.

### Neutral

- The tool registry remains immutable after startup; this ADR does not change
  that invariant.
- `SecurityPolicy` remains per-agent; this ADR documents that client isolation
  is orthogonal to agent-level policy.

## References

- `src/tools/mod.rs` — `DelegateParentToolsHandle`, `all_tools_with_runtime()`
- `src/tools/reaction.rs` — `ChannelMapHandle`, `ReactionTool`
- `src/tools/canvas.rs` — `CanvasStore`, `CanvasEntry`
- `src/tools/traits.rs` — `Tool` trait
- `src/gateway/mod.rs` — client IP extraction (`forwarded_client_ip`, `resolve_client_ip`)
- `src/security/` — `SecurityPolicy`
@@ -0,0 +1,215 @@
# Browser Automation Setup Guide

This guide covers setting up browser automation capabilities in ZeroClaw, including both headless automation and GUI access via VNC.

## Overview

ZeroClaw supports multiple browser access methods:

| Method | Use Case | Requirements |
|--------|----------|--------------|
| **agent-browser CLI** | Headless automation, AI agents | npm, Chrome |
| **VNC + noVNC** | GUI access, debugging | Xvfb, x11vnc, noVNC |
| **Chrome Remote Desktop** | Remote GUI via Google | XFCE, Google account |

## Quick Start: Headless Automation

### 1. Install agent-browser

```bash
# Install CLI
npm install -g agent-browser

# Download Chrome for Testing
agent-browser install --with-deps  # Linux (includes system deps)
agent-browser install              # macOS/Windows
```

### 2. Verify ZeroClaw Config

The browser tool is enabled by default. To verify or customize, edit
`~/.zeroclaw/config.toml`:

```toml
[browser]
enabled = true             # default: true
allowed_domains = ["*"]    # default: ["*"] (all public hosts)
backend = "agent_browser"  # default: "agent_browser"
native_headless = true     # default: true
```

To restrict domains or disable the browser tool:

```toml
[browser]
enabled = false  # disable entirely
# or restrict to specific domains:
allowed_domains = ["example.com", "docs.example.com"]
```

### 3. Test

```bash
echo "Open https://example.com and tell me what it says" | zeroclaw agent
```

## VNC Setup (GUI Access)

For debugging or when you need visual browser access:

### Install Dependencies

```bash
# Ubuntu/Debian
apt-get install -y xvfb x11vnc fluxbox novnc websockify

# Optional: Desktop environment for Chrome Remote Desktop
apt-get install -y xfce4 xfce4-goodies
```

### Start VNC Server

```bash
#!/bin/bash
# Start virtual display with VNC access

DISPLAY_NUM=99
VNC_PORT=5900
NOVNC_PORT=6080
RESOLUTION=1920x1080x24

# Start Xvfb
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
sleep 1

# Start window manager
fluxbox -display :$DISPLAY_NUM &
sleep 1

# Start x11vnc
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg
sleep 1

# Start noVNC (web-based VNC)
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &

echo "VNC available at:"
echo "  VNC Client:  localhost:$VNC_PORT"
echo "  Web Browser: http://localhost:$NOVNC_PORT/vnc.html"
```

### VNC Access

- **VNC Client**: Connect to `localhost:5900`
- **Web Browser**: Open `http://localhost:6080/vnc.html`

### Start Browser on VNC Display

```bash
DISPLAY=:99 google-chrome --no-sandbox https://example.com &
```

## Chrome Remote Desktop

### Install

```bash
# Download and install
wget https://dl.google.com/linux/direct/chrome-remote-desktop_current_amd64.deb
apt-get install -y ./chrome-remote-desktop_current_amd64.deb

# Configure session
echo "xfce4-session" > ~/.chrome-remote-desktop-session
chmod +x ~/.chrome-remote-desktop-session
```

### Setup

1. Visit <https://remotedesktop.google.com/headless>
2. Copy the "Debian Linux" setup command
3. Run it on your server
4. Start the service: `systemctl --user start chrome-remote-desktop`

### Remote Access

Go to <https://remotedesktop.google.com/access> from any device.

## Testing

### CLI Tests

```bash
# Basic open and close
agent-browser open https://example.com
agent-browser get title
agent-browser close

# Snapshot with refs
agent-browser open https://example.com
agent-browser snapshot -i
agent-browser close

# Screenshot
agent-browser open https://example.com
agent-browser screenshot /tmp/test.png
agent-browser close
```

### ZeroClaw Integration Tests

```bash
# Content extraction
echo "Open https://example.com and summarize it" | zeroclaw agent

# Navigation
echo "Go to https://github.com/trending and list the top 3 repos" | zeroclaw agent

# Form interaction
echo "Go to Wikipedia, search for 'Rust programming language', and summarize" | zeroclaw agent
```

## Troubleshooting

### "Element not found"

The page may not be fully loaded. Add a wait:

```bash
agent-browser open https://slow-site.com
agent-browser wait --load networkidle
agent-browser snapshot -i
```

### Cookie dialogs blocking access

Handle cookie consent first:

```bash
agent-browser open https://site-with-cookies.com
agent-browser snapshot -i
agent-browser click @accept_cookies  # Click the accept button
agent-browser snapshot -i            # Now get the actual content
```

### Docker sandbox network restrictions

If `web_fetch` fails inside the Docker sandbox, use agent-browser instead:

```bash
# Instead of web_fetch, use:
agent-browser open https://example.com
agent-browser get text body
```

## Security Notes

- `agent-browser` runs Chrome in headless mode with sandboxing
- For sensitive sites, use `--session-name` to persist auth state
- The `--allowed-domains` config restricts navigation to specific domains
- VNC ports (5900, 6080) should be behind a firewall or Tailscale

## Related

- [agent-browser Documentation](https://github.com/vercel-labs/agent-browser)
- [ZeroClaw Configuration Reference](./config-reference.md)
- [Skills Documentation](../skills/)
@@ -45,6 +45,15 @@ For complete code examples of each extension trait, see [extension-examples.md](
- Keep multilingual entry-point parity for all supported locales (`en`, `zh-CN`, `ja`, `ru`, `fr`, `vi`) when nav or key wording changes.
- When shared docs wording changes, sync corresponding localized docs in the same PR (or explicitly document deferral and follow-up PR).

## Tool Shared State

- Follow the `Arc<RwLock<T>>` handle pattern for any tool that owns long-lived shared state.
- Accept handles at construction; do not create global/static mutable state.
- Use `ClientId` (provided by the daemon) to namespace per-client state — never construct identity keys inside the tool.
- Isolate security-sensitive state (credentials, quotas) per client; broadcast/display state may be shared with optional namespace prefixing.
- Cached validation is invalidated on config change — tools must re-validate before the next execution when signaled.
- See [ADR-004: Tool Shared State Ownership](../architecture/adr-004-tool-shared-state-ownership.md) for the full contract.

## Architecture Boundary Rules

- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features.

@@ -411,30 +411,6 @@ allowed_roots = [\"~/Desktop/projects\", \"/opt/shared-repo\"]

- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.

### `[memory.mem0]`

Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.

| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```

Server deployment scripts are in `deploy/mem0/`.

## `[[model_routes]]` and `[[embedding_routes]]`

Use route hints so integrations can keep stable names while model IDs evolve.

@@ -12,8 +12,6 @@ SOP audit entries are persisted via `SopAuditLogger` to the configured memory backend's `so
- `sop_step_{run_id}_{step_number}`: per-step result
- `sop_approval_{run_id}_{step_number}`: operator approval record
- `sop_timeout_approve_{run_id}_{step_number}`: timeout auto-approval record
- `sop_gate_decision_{gate_id}_{timestamp_ms}`: gate evaluator decision record (when `ampersona-gates` is enabled)
- `sop_phase_state`: persisted trust-phase state snapshot (when `ampersona-gates` is enabled)

## 2. Inspection Paths


@@ -508,30 +508,6 @@ Notes:

- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.

### `[memory.mem0]`

Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.

| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```

Server deployment scripts are in `deploy/mem0/`.

## `[[model_routes]]` and `[[embedding_routes]]`

Use route hints so integrations can keep stable names while model IDs evolve.

@@ -12,8 +12,6 @@ Common key patterns:
- `sop_step_{run_id}_{step_number}`: per-step result
- `sop_approval_{run_id}_{step_number}`: operator approval record
- `sop_timeout_approve_{run_id}_{step_number}`: timeout auto-approval record
- `sop_gate_decision_{gate_id}_{timestamp_ms}`: gate evaluator decision record (when `ampersona-gates` is enabled)
- `sop_phase_state`: persisted trust-phase state snapshot (when `ampersona-gates` is enabled)

## 2. Inspection Paths


@@ -337,30 +337,6 @@ Notes:

- Memory context injection ignores legacy `assistant_resp*` auto-save keys to avoid model-generated summaries being treated as facts.

### `[memory.mem0]`

Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server, providing vector-based memory with LLM fact extraction. Requires the `memory-mem0` feature flag at build time and `backend = "mem0"` in config.

| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |

```toml
[memory]
backend = "mem0"

[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```

Server deployment scripts are in `deploy/mem0/`.

## `[[model_routes]]` and `[[embedding_routes]]`

Route hints keep integration names stable as model IDs change.

@@ -38,3 +38,74 @@ allowed_tools = ["read", "edit", "exec"]
max_iterations = 15
# Optional: use longer timeout for complex coding tasks
agentic_timeout_secs = 600

# ── Cron Configuration ────────────────────────────────────────
[cron]
# Enable the cron subsystem. Default: true
enabled = true
# Run all overdue jobs at scheduler startup. Default: true
catch_up_on_startup = true
# Maximum number of historical cron run records to retain. Default: 50
max_run_history = 50

# ── Declarative Cron Jobs ─────────────────────────────────────
# Define cron jobs directly in config. These are synced to the database
# at scheduler startup. Each job needs a stable `id` for merge semantics.

# Shell job: runs a shell command on a cron schedule
[[cron.jobs]]
id = "daily-backup"
name = "Daily Backup"
job_type = "shell"
command = "tar czf /tmp/backup.tar.gz /data"
schedule = { kind = "cron", expr = "0 2 * * *" }

# Agent job: runs an agent prompt on an interval
[[cron.jobs]]
id = "health-check"
name = "Health Check"
job_type = "agent"
prompt = "Check server health: disk space, memory, CPU load"
model = "anthropic/claude-sonnet-4"
allowed_tools = ["shell", "file_read"]
schedule = { kind = "every", every_ms = 300000 }

# Cron job with timezone and delivery
# [[cron.jobs]]
# id = "morning-report"
# name = "Morning Report"
# job_type = "agent"
# prompt = "Generate a daily summary of system metrics"
# schedule = { kind = "cron", expr = "0 9 * * 1-5", tz = "America/New_York" }
# [cron.jobs.delivery]
# mode = "announce"
# channel = "telegram"
# to = "123456789"

# ── Cost Tracking Configuration ────────────────────────────────
[cost]
# Enable cost tracking and budget enforcement. Default: false
enabled = false
# Daily spending limit in USD. Default: 10.0
daily_limit_usd = 10.0
# Monthly spending limit in USD. Default: 100.0
monthly_limit_usd = 100.0
# Warn when spending reaches this percentage of limit. Default: 80
warn_at_percent = 80
# Allow requests to exceed budget with --override flag. Default: false
allow_override = false

# Per-model pricing (USD per 1M tokens).
# Built-in defaults exist for popular models; add overrides here.
# [cost.prices."anthropic/claude-opus-4-20250514"]
# input = 15.0
# output = 75.0
# [cost.prices."anthropic/claude-sonnet-4-20250514"]
# input = 3.0
# output = 15.0
# [cost.prices."openai/gpt-4o"]
# input = 5.0
# output = 15.0
# [cost.prices."openai/gpt-4o-mini"]
# input = 0.15
# output = 0.60

@@ -1416,8 +1416,20 @@ if [[ "$SKIP_BUILD" == false ]]; then
        step_dot "Cleaning stale build cache (upgrade detected)"
        cargo clean --release 2>/dev/null || true
    fi

    # Determine cargo feature flags — disable prometheus on 32-bit targets
    # (prometheus crate requires AtomicU64, unavailable on armv7l/armv6l)
    CARGO_FEATURE_FLAGS=""
    _build_arch="$(uname -m)"
    case "$_build_arch" in
        armv7l|armv6l|armhf)
            step_dot "32-bit ARM detected ($_build_arch) — disabling prometheus (requires 64-bit atomics)"
            CARGO_FEATURE_FLAGS="--no-default-features --features channel-nostr,skill-creation"
            ;;
    esac

    step_dot "Building release binary"
    cargo build --release --locked
    cargo build --release --locked $CARGO_FEATURE_FLAGS
    step_ok "Release binary built"
else
    step_dot "Skipping build"
@@ -1436,7 +1448,7 @@ if [[ "$SKIP_INSTALL" == false ]]; then
        fi
    fi

    cargo install --path "$WORK_DIR" --force --locked
    cargo install --path "$WORK_DIR" --force --locked $CARGO_FEATURE_FLAGS
    step_ok "ZeroClaw installed"

    # Sync binary to ~/.local/bin so PATH lookups find the fresh version
@@ -1448,6 +1460,84 @@ else
    step_dot "Skipping install"
fi

# --- Build web dashboard ---
if [[ "$SKIP_BUILD" == false && -d "$WORK_DIR/web" ]]; then
    if have_cmd node && have_cmd npm; then
        step_dot "Building web dashboard"
        if (cd "$WORK_DIR/web" && npm ci --ignore-scripts 2>/dev/null && npm run build 2>/dev/null); then
            step_ok "Web dashboard built"
        else
            warn "Web dashboard build failed — dashboard will not be available"
        fi
    else
        warn "node/npm not found — skipping web dashboard build"
        warn "Install Node.js (>=18) and re-run, or build manually: cd web && npm ci && npm run build"
    fi
else
    if [[ "$SKIP_BUILD" == true ]]; then
        step_dot "Skipping web dashboard build"
    fi
fi

# --- Build desktop app (macOS only) ---
if [[ "$SKIP_BUILD" == false && "$OS_NAME" == "Darwin" && -d "$WORK_DIR/apps/tauri" ]]; then
    echo
    echo -e "${BOLD}Desktop app preflight${RESET}"

    _desktop_ok=true

    # Check Rust toolchain
    if have_cmd cargo && have_cmd rustc; then
        step_ok "Rust $(rustc --version | awk '{print $2}') found"
    else
        step_fail "Rust toolchain not found — required for desktop app"
        _desktop_ok=false
    fi

    # Check Xcode CLT (needed for linking native frameworks)
    if xcode-select -p >/dev/null 2>&1; then
        step_ok "Xcode Command Line Tools installed"
    else
        step_fail "Xcode Command Line Tools not found — run: xcode-select --install"
        _desktop_ok=false
    fi

    # Check that the Tauri CLI is available (cargo-tauri or tauri-cli)
    if have_cmd cargo-tauri; then
        step_ok "cargo-tauri $(cargo tauri --version 2>/dev/null | awk '{print $NF}') found"
    else
        step_dot "cargo-tauri not found — installing"
        if cargo install tauri-cli --locked 2>/dev/null; then
            step_ok "cargo-tauri installed"
        else
            warn "Failed to install cargo-tauri — desktop app build may fail"
        fi
    fi

    # Check node/npm (needed for web frontend that Tauri embeds)
    if have_cmd node && have_cmd npm; then
        step_ok "Node.js $(node --version) found"
    else
        warn "node/npm not found — desktop app needs the web dashboard built first"
    fi

    if [[ "$_desktop_ok" == true ]]; then
        step_dot "Building desktop app (zeroclaw-desktop)"
        if cargo build -p zeroclaw-desktop --release --locked 2>/dev/null; then
            step_ok "Desktop app built"
            # Copy binary to cargo bin for easy access
            if [[ -x "$WORK_DIR/target/release/zeroclaw-desktop" ]]; then
                cp -f "$WORK_DIR/target/release/zeroclaw-desktop" "$HOME/.cargo/bin/zeroclaw-desktop" 2>/dev/null && \
                    step_ok "zeroclaw-desktop installed to ~/.cargo/bin" || true
            fi
        else
            warn "Desktop app build failed — you can build later with: cargo build -p zeroclaw-desktop --release"
        fi
    else
        warn "Skipping desktop app build — fix missing dependencies above and re-run"
    fi
fi

ZEROCLAW_BIN=""
if [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then
    ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw"
@@ -1614,6 +1704,9 @@ echo -e "${BOLD}Next steps:${RESET}"
echo -e "  ${DIM}zeroclaw status${RESET}"
echo -e "  ${DIM}zeroclaw agent -m \"Hello, ZeroClaw!\"${RESET}"
echo -e "  ${DIM}zeroclaw gateway${RESET}"
if [[ "$OS_NAME" == "Darwin" ]] && have_cmd zeroclaw-desktop; then
    echo -e "  ${DIM}zeroclaw-desktop${RESET}  ${DIM}# Launch the menu bar app${RESET}"
fi
echo
echo -e "${BOLD}Docs:${RESET} ${BLUE}https://www.zeroclawlabs.ai/docs${RESET}"
echo

@@ -0,0 +1,21 @@
#!/bin/bash
# Start a browser on a virtual display
# Usage: ./start-browser.sh [display_num] [url]

set -e

DISPLAY_NUM=${1:-99}
URL=${2:-"https://google.com"}

export DISPLAY=:$DISPLAY_NUM

# Check if display is running
if ! xdpyinfo -display :$DISPLAY_NUM &>/dev/null; then
    echo "Error: Display :$DISPLAY_NUM not running."
    echo "Start VNC first: ./start-vnc.sh"
    exit 1
fi

google-chrome --no-sandbox --disable-gpu --disable-setuid-sandbox "$URL" &
echo "Chrome started on display :$DISPLAY_NUM"
echo "View via VNC or noVNC"
@@ -0,0 +1,52 @@
#!/bin/bash
# Start virtual display with VNC access for browser GUI
# Usage: ./start-vnc.sh [display_num] [vnc_port] [novnc_port] [resolution]

set -e

DISPLAY_NUM=${1:-99}
VNC_PORT=${2:-5900}
NOVNC_PORT=${3:-6080}
RESOLUTION=${4:-1920x1080x24}

echo "Starting virtual display :$DISPLAY_NUM at $RESOLUTION"

# Kill any existing sessions
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
sleep 1

# Start Xvfb (virtual framebuffer)
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
XVFB_PID=$!
sleep 1

# Set DISPLAY
export DISPLAY=:$DISPLAY_NUM

# Start window manager
fluxbox -display :$DISPLAY_NUM 2>/dev/null &
sleep 1

# Start x11vnc
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg 2>/dev/null
sleep 1

# Start noVNC (web-based VNC client)
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
NOVNC_PID=$!

echo ""
echo "==================================="
echo "VNC Server started!"
echo "==================================="
echo "VNC Direct: localhost:$VNC_PORT"
echo "noVNC Web:  http://localhost:$NOVNC_PORT/vnc.html"
echo "Display:    :$DISPLAY_NUM"
echo "==================================="
echo ""
echo "To start a browser, run:"
echo "  DISPLAY=:$DISPLAY_NUM google-chrome &"
echo ""
echo "To stop, run: pkill -f 'Xvfb :$DISPLAY_NUM'"
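The VNC scripts above pace startup with fixed `sleep 1` calls, which can race on slow machines. A more robust pattern is to poll the port until it accepts connections. A std-only Rust sketch of the idea (the function name and retry counts are illustrative, not part of this repo):

```rust
use std::net::{SocketAddr, TcpStream};
use std::thread::sleep;
use std::time::Duration;

/// Poll 127.0.0.1:<port> until it accepts a TCP connection, or give up.
fn wait_for_port(port: u16, tries: u32) -> bool {
    let addr: SocketAddr = ([127, 0, 0, 1], port).into();
    for _ in 0..tries {
        if TcpStream::connect_timeout(&addr, Duration::from_millis(250)).is_ok() {
            return true;
        }
        sleep(Duration::from_millis(250));
    }
    false
}

fn main() {
    // Port 1 is privileged and almost never open; a single try fails fast.
    println!("{}", wait_for_port(1, 1));
}
```

The same polling loop could replace each `sleep 1` between Xvfb, x11vnc, and websockify startup.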
@@ -0,0 +1,11 @@
#!/bin/bash
# Stop virtual display and VNC server
# Usage: ./stop-vnc.sh [display_num]

DISPLAY_NUM=${1:-99}

pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*6080" 2>/dev/null || true

echo "VNC server stopped"
@@ -77,7 +77,9 @@ echo "Created annotated tag: $TAG"
if [[ "$PUSH_TAG" == "true" ]]; then
    git push origin "$TAG"
    echo "Pushed tag to origin: $TAG"
    echo "GitHub release pipeline will run via .github/workflows/pub-release.yml"
    echo "Release Stable workflow will auto-trigger via tag push."
    echo "Monitor: gh workflow view 'Release Stable' --web"
else
    echo "Next step: git push origin $TAG"
    echo "This will auto-trigger the Release Stable workflow (builds, Docker, crates.io, website, Scoop, AUR, Homebrew, tweet)."
fi

@@ -0,0 +1,122 @@
---
name: browser
description: Headless browser automation using agent-browser CLI
metadata: {"zeroclaw":{"emoji":"🌐","requires":{"bins":["agent-browser"]}}}
---

# Browser Skill

Control a headless browser for web automation, scraping, and testing.

## Prerequisites

- `agent-browser` CLI installed globally (`npm install -g agent-browser`)
- Chrome downloaded (`agent-browser install`)

## Installation

```bash
# Install agent-browser CLI
npm install -g agent-browser

# Download Chrome for Testing
agent-browser install --with-deps   # Linux
agent-browser install               # macOS/Windows
```

## Usage

### Navigate and snapshot

```bash
agent-browser open https://example.com
agent-browser snapshot -i
```

### Interact with elements

```bash
agent-browser click @e1          # Click by ref
agent-browser fill @e2 "text"    # Fill input
agent-browser press Enter        # Press key
```

### Extract data

```bash
agent-browser get text @e1           # Get text content
agent-browser get url                # Get current URL
agent-browser screenshot page.png    # Take screenshot
```

### Session management

```bash
agent-browser close    # Close browser
```

## Common Workflows

### Login flow

```bash
agent-browser open https://site.com/login
agent-browser snapshot -i
agent-browser fill @email "user@example.com"
agent-browser fill @password "secretpass"
agent-browser click @submit
agent-browser wait --text "Welcome"
```

### Scrape page content

```bash
agent-browser open https://news.ycombinator.com
agent-browser snapshot -i
agent-browser get text @e1
```

### Take screenshots

```bash
agent-browser open https://google.com
agent-browser screenshot --full page.png
```

## Options

- `--json` - JSON output for parsing
- `--headed` - Show browser window (for debugging)
- `--session-name <name>` - Persist session cookies
- `--profile <path>` - Use persistent browser profile

## Configuration

The browser tool is enabled by default with `allowed_domains = ["*"]` and
`backend = "agent_browser"`. To customize, edit `~/.zeroclaw/config.toml`:

```toml
[browser]
enabled = true             # default: true
allowed_domains = ["*"]    # default: ["*"] (all public hosts)
backend = "agent_browser"  # default: "agent_browser"
native_headless = true     # default: true
```

To restrict domains or disable the browser tool:

```toml
[browser]
enabled = false  # disable entirely
# or restrict to specific domains:
allowed_domains = ["example.com", "docs.example.com"]
```

## Full Command Reference

Run `agent-browser --help` for all available commands.

## Related

- [agent-browser GitHub](https://github.com/vercel-labs/agent-browser)
- [VNC Setup Guide](../docs/browser-setup.md)
@@ -12,11 +12,29 @@ use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool, ToolSpec};
use anyhow::Result;
use chrono::{Datelike, Timelike};
use std::collections::HashMap;
use std::io::Write as IoWrite;
use std::sync::Arc;
use std::time::Instant;

/// Events emitted during a streamed agent turn.
///
/// Consumers receive these through a `tokio::sync::mpsc::Sender<TurnEvent>`
/// passed to [`Agent::turn_streamed`].
#[derive(Debug, Clone)]
pub enum TurnEvent {
    /// A text chunk from the LLM response (may arrive many times).
    Chunk { delta: String },
    /// The agent is invoking a tool.
    ToolCall {
        name: String,
        args: serde_json::Value,
    },
    /// A tool has returned a result.
    ToolResult { name: String, output: String },
}

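As a quick illustration of how a gateway might drain these events into the final assistant text, here is a self-contained sketch (std `mpsc` stands in for the `tokio` channel, and `args` is simplified to `String`; it mirrors the enum above but is not the repo's code):

```rust
use std::sync::mpsc;

// Trimmed mirror of the TurnEvent enum so this sketch compiles on its own
// (the real `args` field is a serde_json::Value).
#[derive(Debug, Clone)]
pub enum TurnEvent {
    Chunk { delta: String },
    ToolCall { name: String, args: String },
    ToolResult { name: String, output: String },
}

// Fold a stream of events into the final assistant text, logging tool use.
pub fn collect_text(events: impl IntoIterator<Item = TurnEvent>) -> String {
    let mut text = String::new();
    for ev in events {
        match ev {
            TurnEvent::Chunk { delta } => text.push_str(&delta),
            TurnEvent::ToolCall { name, args } => eprintln!("tool call: {name}({args})"),
            TurnEvent::ToolResult { name, .. } => eprintln!("tool done: {name}"),
        }
    }
    text
}

fn main() {
    // The real gateway hands the sender to Agent::turn_streamed and drains
    // the receiver concurrently; a plain channel suffices to show the shape.
    let (tx, rx) = mpsc::channel();
    tx.send(TurnEvent::Chunk { delta: "Hel".into() }).unwrap();
    tx.send(TurnEvent::Chunk { delta: "lo".into() }).unwrap();
    drop(tx);
    assert_eq!(collect_text(rx), "Hello");
}
```
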
pub struct Agent {
    provider: Box<dyn Provider>,
    tools: Vec<Box<dyn Tool>>,
@@ -359,22 +377,23 @@ impl Agent {
            None
        };

        let (mut tools, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
            Arc::new(config.clone()),
            &security,
            runtime,
            memory.clone(),
            composio_key,
            composio_entity_id,
            &config.browser,
            &config.http_request,
            &config.web_fetch,
            &config.workspace_dir,
            &config.agents,
            config.api_key.as_deref(),
            config,
            None,
        );
        let (mut tools, delegate_handle, _reaction_handle, _channel_map_handle) =
            tools::all_tools_with_runtime(
                Arc::new(config.clone()),
                &security,
                runtime,
                memory.clone(),
                composio_key,
                composio_entity_id,
                &config.browser,
                &config.http_request,
                &config.web_fetch,
                &config.workspace_dir,
                &config.agents,
                config.api_key.as_deref(),
                config,
                None,
            );

        // ── Wire MCP tools (non-fatal) ─────────────────────────────
        // Replicates the same MCP initialization logic used in the CLI
@@ -668,11 +687,17 @@ impl Agent {
                .await;
        }

        let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
        let now = chrono::Local::now();
        let (year, month, day) = (now.year(), now.month(), now.day());
        let (hour, minute, second) = (now.hour(), now.minute(), now.second());
        let tz = now.format("%Z");
        let date_str =
            format!("{year:04}-{month:02}-{day:02} {hour:02}:{minute:02}:{second:02} {tz}");

        let enriched = if context.is_empty() {
            format!("[{now}] {user_message}")
            format!("[CURRENT DATE & TIME: {date_str}]\n\n{user_message}")
        } else {
            format!("{context}[{now}] {user_message}")
            format!("[CURRENT DATE & TIME: {date_str}]\n\n{context}\n\n{user_message}")
        };

        self.history
@@ -798,6 +823,254 @@ impl Agent {
        )
    }

    /// Execute a single agent turn while streaming intermediate events.
    ///
    /// Behaves identically to [`turn`](Self::turn) but forwards [`TurnEvent`]s
    /// through the provided channel so callers (e.g. the WebSocket gateway)
    /// can relay incremental updates to clients.
    ///
    /// The returned `String` is the final, complete assistant response — the
    /// same value that `turn` would return.
    pub async fn turn_streamed(
        &mut self,
        user_message: &str,
        event_tx: tokio::sync::mpsc::Sender<TurnEvent>,
    ) -> Result<String> {
        // ── Preamble (identical to turn) ───────────────────────────────
        if self.history.is_empty() {
            let system_prompt = self.build_system_prompt()?;
            self.history
                .push(ConversationMessage::Chat(ChatMessage::system(
                    system_prompt,
                )));
        }

        let context = self
            .memory_loader
            .load_context(
                self.memory.as_ref(),
                user_message,
                self.memory_session_id.as_deref(),
            )
            .await
            .unwrap_or_default();

        if self.auto_save {
            let _ = self
                .memory
                .store(
                    "user_msg",
                    user_message,
                    MemoryCategory::Conversation,
                    self.memory_session_id.as_deref(),
                )
                .await;
        }

        let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
        let enriched = if context.is_empty() {
            format!("[{now}] {user_message}")
        } else {
            format!("{context}[{now}] {user_message}")
        };

        self.history
            .push(ConversationMessage::Chat(ChatMessage::user(enriched)));

        let effective_model = self.classify_model(user_message);

        // ── Turn loop ──────────────────────────────────────────────────
        for _ in 0..self.config.max_tool_iterations {
            let messages = self.tool_dispatcher.to_provider_messages(&self.history);

            // Response cache check (same as turn)
            let cache_key = if self.temperature == 0.0 {
                self.response_cache.as_ref().map(|_| {
                    let last_user = messages
                        .iter()
                        .rfind(|m| m.role == "user")
                        .map(|m| m.content.as_str())
                        .unwrap_or("");
                    let system = messages
                        .iter()
                        .find(|m| m.role == "system")
                        .map(|m| m.content.as_str());
                    crate::memory::response_cache::ResponseCache::cache_key(
                        &effective_model,
                        system,
                        last_user,
                    )
                })
            } else {
                None
            };

            if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
                if let Ok(Some(cached)) = cache.get(key) {
                    self.observer.record_event(&ObserverEvent::CacheHit {
                        cache_type: "response".into(),
                        tokens_saved: 0,
                    });
                    self.history
                        .push(ConversationMessage::Chat(ChatMessage::assistant(
                            cached.clone(),
                        )));
                    self.trim_history();
                    return Ok(cached);
                }
                self.observer.record_event(&ObserverEvent::CacheMiss {
                    cache_type: "response".into(),
                });
            }

            // ── Streaming LLM call ────────────────────────────────────
            // Try streaming first; if the provider returns content we
            // forward deltas. Otherwise fall back to non-streaming chat.
            use futures_util::StreamExt;

            let stream_opts = crate::providers::traits::StreamOptions::new(true);
            let mut stream = self.provider.stream_chat_with_history(
                &messages,
                &effective_model,
                self.temperature,
                stream_opts,
            );

            let mut streamed_text = String::new();
            let mut got_stream = false;

            while let Some(item) = stream.next().await {
                match item {
                    Ok(chunk) => {
                        if !chunk.delta.is_empty() {
                            got_stream = true;
                            streamed_text.push_str(&chunk.delta);
                            let _ = event_tx.send(TurnEvent::Chunk { delta: chunk.delta }).await;
                        }
                    }
                    Err(_) => break,
                }
            }
            // Drop the stream so we release the borrow on provider.
            drop(stream);

            // If streaming produced text, use it as the response and
            // check for tool calls via the dispatcher.
            let response = if got_stream {
                // Build a synthetic ChatResponse from streamed text
                crate::providers::ChatResponse {
                    text: Some(streamed_text),
                    tool_calls: Vec::new(),
                    usage: None,
                    reasoning_content: None,
                }
            } else {
                // Fall back to non-streaming chat
                match self
                    .provider
                    .chat(
                        ChatRequest {
                            messages: &messages,
                            tools: if self.tool_dispatcher.should_send_tool_specs() {
                                Some(&self.tool_specs)
                            } else {
                                None
                            },
                        },
                        &effective_model,
                        self.temperature,
                    )
                    .await
                {
                    Ok(resp) => resp,
                    Err(err) => return Err(err),
                }
            };

            let (text, calls) = self.tool_dispatcher.parse_response(&response);
            if calls.is_empty() {
                let final_text = if text.is_empty() {
                    response.text.unwrap_or_default()
                } else {
                    text
                };

                // Store in response cache
                if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
                    let token_count = response
                        .usage
                        .as_ref()
                        .and_then(|u| u.output_tokens)
                        .unwrap_or(0);
                    #[allow(clippy::cast_possible_truncation)]
                    let _ = cache.put(key, &effective_model, &final_text, token_count as u32);
                }

                // If we didn't stream, send the full response as a single chunk
                if !got_stream && !final_text.is_empty() {
                    let _ = event_tx
                        .send(TurnEvent::Chunk {
                            delta: final_text.clone(),
                        })
                        .await;
                }

                self.history
                    .push(ConversationMessage::Chat(ChatMessage::assistant(
                        final_text.clone(),
                    )));
                self.trim_history();

                return Ok(final_text);
            }

            // ── Tool calls ─────────────────────────────────────────────
            if !text.is_empty() {
                self.history
                    .push(ConversationMessage::Chat(ChatMessage::assistant(
                        text.clone(),
                    )));
            }

            self.history.push(ConversationMessage::AssistantToolCalls {
                text: response.text.clone(),
                tool_calls: response.tool_calls.clone(),
                reasoning_content: response.reasoning_content.clone(),
            });

            // Notify about each tool call
            for call in &calls {
                let _ = event_tx
                    .send(TurnEvent::ToolCall {
                        name: call.name.clone(),
                        args: call.arguments.clone(),
                    })
                    .await;
            }

            let results = self.execute_tools(&calls).await;

            // Notify about each tool result
            for result in &results {
                let _ = event_tx
                    .send(TurnEvent::ToolResult {
                        name: result.name.clone(),
                        output: result.output.clone(),
                    })
                    .await;
            }

            let formatted = self.tool_dispatcher.format_results(&results);
            self.history.push(formatted);
            self.trim_history();
        }

        anyhow::bail!(
            "Agent exceeded maximum tool iterations ({})",
            self.config.max_tool_iterations
        )
    }

    pub async fn run_single(&mut self, message: &str) -> Result<String> {
        self.turn(message).await
    }

@@ -4,7 +4,7 @@ use crate::config::Config;
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
use crate::cost::CostTracker;
use crate::i18n::ToolDescriptions;
use crate::memory::{self, Memory, MemoryCategory};
use crate::memory::{self, decay, Memory, MemoryCategory};
use crate::multimodal;
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
use crate::providers::{
@@ -561,6 +561,7 @@ fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Res
/// Build context preamble by searching memory for relevant entries.
/// Entries with a hybrid score below `min_relevance_score` are dropped to
/// prevent unrelated memories from bleeding into the conversation.
/// Core memories are exempt from time decay (evergreen).
async fn build_context(
    mem: &dyn Memory,
    user_msg: &str,
@@ -570,7 +571,10 @@ async fn build_context(
    let mut context = String::new();

    // Pull relevant memories for this message
    if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
    if let Ok(mut entries) = mem.recall(user_msg, 5, session_id, None, None).await {
        // Apply time decay: older non-Core memories score lower
        decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);

        let relevant: Vec<_> = entries
            .iter()
            .filter(|e| match e.score {
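The decay call above down-weights older memories per half-life. The usual exponential form — shown here as an illustrative sketch, not the repo's `decay` module API — halves a score every `half_life_days`:

```rust
// Sketch of exponential time decay (assumption: zeroclaw's decay module
// works similarly; `decayed_score` is an illustrative name, not the real API).
fn decayed_score(score: f64, age_days: f64, half_life_days: f64) -> f64 {
    score * 0.5f64.powf(age_days / half_life_days)
}

fn main() {
    // A memory exactly one half-life old keeps half its score.
    let s = decayed_score(1.0, 30.0, 30.0);
    assert!((s - 0.5).abs() < 1e-9);
}
```

Core memories skip this step entirely, which is why they stay "evergreen" regardless of age.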
@@ -2659,6 +2663,14 @@ pub(crate) async fn run_tool_call_loop(
    let mut consecutive_identical_outputs: usize = 0;
    let mut last_tool_output_hash: Option<u64> = None;

    let mut loop_detector = crate::agent::loop_detector::LoopDetector::new(
        crate::agent::loop_detector::LoopDetectorConfig {
            enabled: pacing.loop_detection_enabled,
            window_size: pacing.loop_detection_window_size,
            max_repeats: pacing.loop_detection_max_repeats,
        },
    );

    for iteration in 0..max_iterations {
        let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();

@@ -2707,16 +2719,53 @@ pub(crate) async fn run_tool_call_loop(
        let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();

        let image_marker_count = multimodal::count_image_markers(history);
        if image_marker_count > 0 && !provider.supports_vision() {
            return Err(ProviderCapabilityError {
                provider: provider_name.to_string(),
                capability: "vision".to_string(),
                message: format!(
                    "received {image_marker_count} image marker(s), but this provider does not support vision input"
                ),

        // ── Vision provider routing ──────────────────────────
        // When the default provider lacks vision support but a dedicated
        // vision_provider is configured, create it on demand and use it
        // for this iteration. Otherwise, preserve the original error.
        let vision_provider_box: Option<Box<dyn Provider>> = if image_marker_count > 0
            && !provider.supports_vision()
        {
            if let Some(ref vp) = multimodal_config.vision_provider {
                let vp_instance = providers::create_provider(vp, None)
                    .map_err(|e| anyhow::anyhow!("failed to create vision provider '{vp}': {e}"))?;
                if !vp_instance.supports_vision() {
                    return Err(ProviderCapabilityError {
                        provider: vp.clone(),
                        capability: "vision".to_string(),
                        message: format!(
                            "configured vision_provider '{vp}' does not support vision input"
                        ),
                    }
                    .into());
                }
                Some(vp_instance)
            } else {
                return Err(ProviderCapabilityError {
                    provider: provider_name.to_string(),
                    capability: "vision".to_string(),
                    message: format!(
                        "received {image_marker_count} image marker(s), but this provider does not support vision input"
                    ),
                }
                .into());
            }
            .into());
        }
        } else {
            None
        };

        let (active_provider, active_provider_name, active_model): (&dyn Provider, &str, &str) =
            if let Some(ref vp_box) = vision_provider_box {
                let vp_name = multimodal_config
                    .vision_provider
                    .as_deref()
                    .unwrap_or(provider_name);
                let vm = multimodal_config.vision_model.as_deref().unwrap_or(model);
                (vp_box.as_ref(), vp_name, vm)
            } else {
                (provider, provider_name, model)
            };

        let prepared_messages =
            multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
@@ -2732,15 +2781,15 @@ pub(crate) async fn run_tool_call_loop(
        }

        observer.record_event(&ObserverEvent::LlmRequest {
            provider: provider_name.to_string(),
            model: model.to_string(),
            provider: active_provider_name.to_string(),
            model: active_model.to_string(),
            messages_count: history.len(),
        });
        runtime_trace::record_event(
            "llm_request",
            Some(channel_name),
            Some(provider_name),
            Some(model),
            Some(active_provider_name),
            Some(active_model),
            Some(&turn_id),
            None,
            None,
@@ -2778,12 +2827,12 @@ pub(crate) async fn run_tool_call_loop(
            None
        };

        let chat_future = provider.chat(
        let chat_future = active_provider.chat(
            ChatRequest {
                messages: &prepared_messages.messages,
                tools: request_tools,
            },
            model,
            active_model,
            temperature,
        );

@@ -2836,8 +2885,8 @@ pub(crate) async fn run_tool_call_loop(
                    .unwrap_or((None, None));

                observer.record_event(&ObserverEvent::LlmResponse {
                    provider: provider_name.to_string(),
                    model: model.to_string(),
                    provider: active_provider_name.to_string(),
                    model: active_model.to_string(),
                    duration: llm_started_at.elapsed(),
                    success: true,
                    error_message: None,
@@ -2846,10 +2895,9 @@ pub(crate) async fn run_tool_call_loop(
                });

                // Record cost via task-local tracker (no-op when not scoped)
                let _ = resp
                    .usage
                    .as_ref()
                    .and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
                let _ = resp.usage.as_ref().and_then(|usage| {
                    record_tool_loop_cost_usage(active_provider_name, active_model, usage)
                });

                let response_text = resp.text_or_empty().to_string();
                // First try native structured tool calls (OpenAI-format).
@@ -2872,8 +2920,8 @@ pub(crate) async fn run_tool_call_loop(
                    runtime_trace::record_event(
                        "tool_call_parse_issue",
                        Some(channel_name),
                        Some(provider_name),
                        Some(model),
                        Some(active_provider_name),
                        Some(active_model),
                        Some(&turn_id),
                        Some(false),
                        Some(&parse_issue),
@@ -2890,8 +2938,8 @@ pub(crate) async fn run_tool_call_loop(
                runtime_trace::record_event(
                    "llm_response",
                    Some(channel_name),
                    Some(provider_name),
                    Some(model),
                    Some(active_provider_name),
                    Some(active_model),
                    Some(&turn_id),
                    Some(true),
                    None,
@@ -2940,8 +2988,8 @@ pub(crate) async fn run_tool_call_loop(
            Err(e) => {
                let safe_error = crate::providers::sanitize_api_error(&e.to_string());
                observer.record_event(&ObserverEvent::LlmResponse {
                    provider: provider_name.to_string(),
                    model: model.to_string(),
                    provider: active_provider_name.to_string(),
                    model: active_model.to_string(),
                    duration: llm_started_at.elapsed(),
                    success: false,
                    error_message: Some(safe_error.clone()),
@@ -2951,8 +2999,8 @@ pub(crate) async fn run_tool_call_loop(
                runtime_trace::record_event(
                    "llm_response",
                    Some(channel_name),
                    Some(provider_name),
                    Some(model),
                    Some(active_provider_name),
                    Some(active_model),
                    Some(&turn_id),
                    Some(false),
                    Some(&safe_error),
@@ -3036,7 +3084,11 @@ pub(crate) async fn run_tool_call_loop(
        if !display_text.is_empty() {
            if !native_tool_calls.is_empty() {
                if let Some(ref tx) = on_delta {
                    let _ = tx.send(display_text.clone()).await;
                    let mut narration = display_text.clone();
                    if !narration.ends_with('\n') {
                        narration.push('\n');
                    }
                    let _ = tx.send(narration).await;
                }
            }
            if !silent {
@@ -3325,9 +3377,54 @@ pub(crate) async fn run_tool_call_loop(
        // Collect tool results and build per-tool output for loop detection.
        // Only non-ignored tool outputs contribute to the identical-output hash.
        let mut detection_relevant_output = String::new();
        for (tool_name, tool_call_id, outcome) in ordered_results.into_iter().flatten() {
        // Use enumerate *before* filter_map so result_index stays aligned with
        // tool_calls even when some ordered_results entries are None.
        for (result_index, (tool_name, tool_call_id, outcome)) in ordered_results
            .into_iter()
            .enumerate()
            .filter_map(|(i, opt)| opt.map(|v| (i, v)))
        {
            if !loop_ignore_tools.contains(tool_name.as_str()) {
                detection_relevant_output.push_str(&outcome.output);

                // Feed the pattern-based loop detector with name + args + result.
                let args = tool_calls
                    .get(result_index)
                    .map(|c| &c.arguments)
                    .unwrap_or(&serde_json::Value::Null);
                let det_result = loop_detector.record(&tool_name, args, &outcome.output);
                match det_result {
                    crate::agent::loop_detector::LoopDetectionResult::Ok => {}
                    crate::agent::loop_detector::LoopDetectionResult::Warning(ref msg) => {
                        tracing::warn!(tool = %tool_name, %msg, "loop detector warning");
                        // Inject a system nudge so the LLM adjusts strategy.
                        history.push(ChatMessage::system(format!("[Loop Detection] {msg}")));
                    }
                    crate::agent::loop_detector::LoopDetectionResult::Block(ref msg) => {
                        tracing::warn!(tool = %tool_name, %msg, "loop detector blocked tool call");
                        // Replace the tool output with the block message.
                        // We still continue the loop so the LLM sees the block feedback.
                        history.push(ChatMessage::system(format!(
                            "[Loop Detection — BLOCKED] {msg}"
                        )));
                    }
                    crate::agent::loop_detector::LoopDetectionResult::Break(msg) => {
                        runtime_trace::record_event(
                            "loop_detector_circuit_breaker",
                            Some(channel_name),
                            Some(provider_name),
                            Some(model),
                            Some(&turn_id),
                            Some(false),
                            Some(&msg),
                            serde_json::json!({
                                "iteration": iteration + 1,
                                "tool": tool_name,
                            }),
                        );
                        anyhow::bail!("Agent loop aborted by loop detector: {msg}");
                    }
                }
            }
            individual_results.push((tool_call_id, outcome.output.clone()));
            let _ = writeln!(
@@ -3525,22 +3622,23 @@ pub async fn run(
    } else {
        (None, None)
    };
    let (mut tools_registry, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
        Arc::new(config.clone()),
        &security,
        runtime,
        mem.clone(),
        composio_key,
        composio_entity_id,
        &config.browser,
        &config.http_request,
        &config.web_fetch,
        &config.workspace_dir,
        &config.agents,
        config.api_key.as_deref(),
        &config,
        None,
    );
    let (mut tools_registry, delegate_handle, _reaction_handle, _channel_map_handle) =
        tools::all_tools_with_runtime(
            Arc::new(config.clone()),
            &security,
            runtime,
            mem.clone(),
            composio_key,
            composio_entity_id,
            &config.browser,
            &config.http_request,
            &config.web_fetch,
            &config.workspace_dir,
            &config.agents,
            config.api_key.as_deref(),
            &config,
            None,
        );

    let peripheral_tools: Vec<Box<dyn Tool>> =
        crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
@@ -3701,6 +3799,11 @@ pub async fn run(

    // ── Build system prompt from workspace MD files (OpenClaw framework) ──
    let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);

    // Register skill-defined tools as callable tool specs in the tool registry
    // so the LLM can invoke them via native function calling, not just XML prompts.
    tools::register_skill_tools(&mut tools_registry, &skills, security.clone());

    let mut tool_descs: Vec<(&str, &str)> = vec![
        (
            "shell",
@@ -3865,17 +3968,45 @@ pub async fn run(

    let mut final_output = String::new();

    // Save the base system prompt before any thinking modifications so
    // the interactive loop can restore it between turns.
    let base_system_prompt = system_prompt.clone();

    if let Some(msg) = message {
        // ── Parse thinking directive from user message ─────────
        let (thinking_directive, effective_msg) =
            match crate::agent::thinking::parse_thinking_directive(&msg) {
                Some((level, remaining)) => {
                    tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
                    (Some(level), remaining)
                }
                None => (None, msg.clone()),
            };
        let thinking_level = crate::agent::thinking::resolve_thinking_level(
            thinking_directive,
            None,
            &config.agent.thinking,
        );
        let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
        let effective_temperature = crate::agent::thinking::clamp_temperature(
            temperature + thinking_params.temperature_adjustment,
        );

        // Prepend thinking system prompt prefix when present.
        if let Some(ref prefix) = thinking_params.system_prompt_prefix {
            system_prompt = format!("{prefix}\n\n{system_prompt}");
        }

        // Auto-save user message to memory (skip short/trivial messages)
        if config.memory.auto_save
            && msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
            && !memory::should_skip_autosave_content(&msg)
            && effective_msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
            && !memory::should_skip_autosave_content(&effective_msg)
        {
            let user_key = autosave_memory_key("user_msg");
            let _ = mem
                .store(
                    &user_key,
                    &msg,
                    &effective_msg,
                    MemoryCategory::Conversation,
                    memory_session_id.as_deref(),
                )
@@ -3885,7 +4016,7 @@ pub async fn run(

        // Inject memory + hardware RAG context into user message
        let mem_context = build_context(
            mem.as_ref(),
            &msg,
            &effective_msg,
            config.memory.min_relevance_score,
            memory_session_id.as_deref(),
        )
@@ -3893,14 +4024,14 @@ pub async fn run(
        let rag_limit = if config.agent.compact_context { 2 } else { 5 };
        let hw_context = hardware_rag
            .as_ref()
            .map(|r| build_hardware_context(r, &msg, &board_names, rag_limit))
            .map(|r| build_hardware_context(r, &effective_msg, &board_names, rag_limit))
            .unwrap_or_default();
        let context = format!("{mem_context}{hw_context}");
        let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
        let enriched = if context.is_empty() {
            format!("[{now}] {msg}")
            format!("[{now}] {effective_msg}")
        } else {
            format!("{context}[{now}] {msg}")
            format!("{context}[{now}] {effective_msg}")
        };

        let mut history = vec![
@@ -3909,8 +4040,11 @@ pub async fn run(
        ];

        // Compute per-turn excluded MCP tools from tool_filter_groups.
        let excluded_tools =
            compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg);
        let excluded_tools = compute_excluded_mcp_tools(
            &tools_registry,
            &config.agent.tool_filter_groups,
            &effective_msg,
        );

        #[allow(unused_assignments)]
        let mut response = String::new();
@@ -3922,7 +4056,7 @@ pub async fn run(
            observer.as_ref(),
            &provider_name,
            &model_name,
            temperature,
            effective_temperature,
            false,
            approval_manager.as_ref(),
            channel_name,
@@ -4042,9 +4176,10 @@ pub async fn run(
            "/quit" | "/exit" => break,
            "/help" => {
                println!("Available commands:");
                println!("  /help        Show this help message");
                println!("  /clear /new  Clear conversation history");
                println!("  /quit /exit  Exit interactive mode\n");
                println!("  /help           Show this help message");
                println!("  /clear /new     Clear conversation history");
                println!("  /quit /exit     Exit interactive mode");
                println!("  /think:<level>  Set reasoning depth (off|minimal|low|medium|high|max)\n");
                continue;
            }
            "/clear" | "/new" => {
||||
@@ -4096,16 +4231,47 @@ pub async fn run(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from interactive input ───
|
||||
let (thinking_directive, effective_input) =
|
||||
match crate::agent::thinking::parse_thinking_directive(&user_input) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, user_input.clone()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let turn_temperature = crate::agent::thinking::clamp_temperature(
|
||||
temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// For non-Medium levels, temporarily patch the system prompt with prefix.
|
||||
let turn_system_prompt;
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
turn_system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
// Update the system message in history for this turn.
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content = turn_system_prompt.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-save conversation turns (skip short/trivial messages)
|
||||
if config.memory.auto_save
|
||||
&& user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&user_input)
|
||||
&& effective_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&effective_input)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(
|
||||
&user_key,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4115,7 +4281,7 @@ pub async fn run(
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
&effective_input,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4123,14 +4289,14 @@ pub async fn run(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, &user_input, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, &effective_input, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {user_input}")
|
||||
format!("[{now}] {effective_input}")
|
||||
} else {
|
||||
format!("{context}[{now}] {user_input}")
|
||||
format!("{context}[{now}] {effective_input}")
|
||||
};
|
||||
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
@@ -4139,7 +4305,7 @@ pub async fn run(
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
);
|
||||
|
||||
let response = loop {
|
||||
@@ -4150,7 +4316,7 @@ pub async fn run(
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
turn_temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
@@ -4235,6 +4401,15 @@ pub async fn run(
|
||||
// Hard cap as a safety net.
|
||||
trim_history(&mut history, config.agent.max_history_messages);
|
||||
|
||||
// Restore base system prompt (remove per-turn thinking prefix).
|
||||
if thinking_params.system_prompt_prefix.is_some() {
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content.clone_from(&base_system_prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = session_state_file.as_deref() {
|
||||
save_interactive_session_history(path, &history)?;
|
||||
}
|
||||
@@ -4285,7 +4460,7 @@ pub async fn process_message(
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (mut tools_registry, delegate_handle_pm, _reaction_handle_pm) =
|
||||
let (mut tools_registry, delegate_handle_pm, _reaction_handle_pm, _channel_map_handle_pm) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
@@ -4415,6 +4590,10 @@ pub async fn process_message(
|
||||
let i18n_descs = crate::i18n::ToolDescriptions::load(&i18n_locale, &i18n_search_dirs);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);
|
||||
|
||||
// Register skill-defined tools as callable tool specs (process_message path).
|
||||
tools::register_skill_tools(&mut tools_registry, &skills, security.clone());
|
||||
|
||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||
("shell", "Execute terminal commands."),
|
||||
("file_read", "Read file contents."),
|
||||
@@ -4508,9 +4687,34 @@ pub async fn process_message(
|
||||
system_prompt.push_str(&deferred_section);
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from user message ─────────────
|
||||
let (thinking_directive, effective_message) =
|
||||
match crate::agent::thinking::parse_thinking_directive(message) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, message.to_string()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let effective_temperature = crate::agent::thinking::clamp_temperature(
|
||||
config.default_temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// Prepend thinking system prompt prefix when present.
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
}
|
||||
|
||||
let effective_msg_ref = effective_message.as_str();
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
effective_msg_ref,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
@@ -4518,22 +4722,25 @@ pub async fn process_message(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, message, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, effective_msg_ref, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {message}")
|
||||
format!("[{now}] {effective_message}")
|
||||
} else {
|
||||
format!("{context}[{now}] {message}")
|
||||
format!("{context}[{now}] {effective_message}")
|
||||
};
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system(&system_prompt),
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
let mut excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, message);
|
||||
let mut excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
effective_msg_ref,
|
||||
);
|
||||
if config.autonomy.level != AutonomyLevel::Full {
|
||||
excluded_tools.extend(config.autonomy.non_cli_excluded_tools.iter().cloned());
|
||||
}
|
||||
@@ -4545,7 +4752,7 @@ pub async fn process_message(
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
effective_temperature,
|
||||
true,
|
||||
"daemon",
|
||||
None,
|
||||
@@ -5094,6 +5301,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
@@ -5171,6 +5379,313 @@ mod tests {
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is not set and the default provider lacks vision
|
||||
/// support, the original `ProviderCapabilityError` should be returned.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_vision_provider_config_preserves_error() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"check [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail without vision_provider config");
|
||||
|
||||
assert!(err.to_string().contains("capability=vision"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but the provider factory cannot resolve
|
||||
/// the name, a descriptive error should be returned (not the generic
|
||||
/// capability error).
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_creation_failure() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail when vision provider cannot be created");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure error, got: {}",
|
||||
err
|
||||
);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// Messages without image markers should use the default provider even
|
||||
/// when `vision_provider` is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_images_uses_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["hello world"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user("just text, no images".to_string())];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even though vision_provider points to a nonexistent provider, this
|
||||
// should succeed because there are no image markers to trigger routing.
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("text-only messages should succeed with default provider");
|
||||
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but `vision_model` is not, the default
|
||||
/// model should be used as fallback for the vision provider.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_without_model_falls_back() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"look [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
// vision_provider set but vision_model is None — the code should
|
||||
// fall back to the default model. Since the provider name is invalid,
|
||||
// we just verify the error path references the correct provider.
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail due to nonexistent vision provider");
|
||||
|
||||
// Verify the routing was attempted (not the generic capability error).
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
/// Empty `[IMAGE:]` markers (which are preserved as literal text by the
|
||||
/// parser) should not trigger vision provider routing.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_empty_image_markers_use_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["handled"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"empty marker [IMAGE:] should be ignored".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("empty image markers should not trigger vision routing");
|
||||
|
||||
assert_eq!(result, "handled");
|
||||
}
|
||||
|
||||
/// Multiple image markers should still trigger vision routing when
|
||||
/// vision_provider is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_multiple_images_trigger_vision_routing() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"two images [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/png;base64,bQ==]"
|
||||
.to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("llava:7b".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should attempt vision provider creation for multiple images");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure for multiple images, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_execute_tools_in_parallel_returns_false_for_single_call() {
|
||||
let calls = vec![ParsedToolCall {
|
||||
@@ -5849,7 +6364,7 @@ mod tests {
|
||||
|
||||
let explanation_idx = deltas
|
||||
.iter()
|
||||
.position(|delta| delta == "Task started. Waiting 30 seconds before checking status.")
|
||||
.position(|delta| delta == "Task started. Waiting 30 seconds before checking status.\n")
|
||||
.expect("native assistant text should be relayed to on_delta");
|
||||
let clear_idx = deltas
|
||||
.iter()
|
||||
|
||||
@@ -0,0 +1,696 @@
|
||||
//! Loop detection guardrail for the agent tool-call loop.
|
||||
//!
|
||||
//! Monitors a sliding window of recent tool calls and their results to detect
|
||||
//! three repetitive patterns that indicate the agent is stuck:
|
||||
//!
|
||||
//! 1. **Exact repeat** — same tool + args called 3+ times consecutively.
|
||||
//! 2. **Ping-pong** — two tools alternating (A->B->A->B) for 4+ cycles.
|
||||
//! 3. **No progress** — same tool called 5+ times with different args but
|
||||
//! identical result hash each time.
|
||||
//!
|
||||
//! Detection triggers escalating responses: `Warning` -> `Block` -> `Break`.
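The escalation ladder described in the module doc can be sketched in isolation. This is a std-only illustration of the threshold logic (the names `Escalation` and `escalate` are illustrative, not part of the module's API; the real escalation lives inside `detect_exact_repeat` below):

```rust
// Sketch of the Warning -> Block -> Break ladder for consecutive
// identical calls, with thresholds mirroring max_repeats = 3.
#[derive(Debug, PartialEq)]
enum Escalation {
    Ok,
    Warning,
    Block,
    Break,
}

fn escalate(consecutive: usize, max_repeats: usize) -> Escalation {
    if consecutive >= max_repeats + 2 {
        Escalation::Break // circuit breaker: end the turn
    } else if consecutive > max_repeats {
        Escalation::Block // refuse the call
    } else if consecutive >= max_repeats {
        Escalation::Warning // nudge the model
    } else {
        Escalation::Ok
    }
}

fn main() {
    assert_eq!(escalate(2, 3), Escalation::Ok);
    assert_eq!(escalate(3, 3), Escalation::Warning);
    assert_eq!(escalate(4, 3), Escalation::Block);
    assert_eq!(escalate(5, 3), Escalation::Break);
}
```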

use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};

// ── Configuration ────────────────────────────────────────────────

/// Configuration for the loop detector, typically derived from
/// `PacingConfig` fields at the call site.
#[derive(Debug, Clone)]
pub(crate) struct LoopDetectorConfig {
    /// Master switch. When `false`, `record` always returns `Ok`.
    pub enabled: bool,
    /// Number of recent calls retained for pattern analysis.
    pub window_size: usize,
    /// How many consecutive exact-repeat calls before escalation starts.
    pub max_repeats: usize,
}

impl Default for LoopDetectorConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            window_size: 20,
            max_repeats: 3,
        }
    }
}

// ── Result enum ──────────────────────────────────────────────────

/// Outcome of a loop-detection check after recording a tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum LoopDetectionResult {
    /// No pattern detected — continue normally.
    Ok,
    /// A suspicious pattern was detected; the caller should inject a
    /// system-level nudge message into the conversation.
    Warning(String),
    /// The tool call should be refused (output replaced with an error).
    Block(String),
    /// The agent turn should be terminated immediately.
    Break(String),
}

// ── Internal types ───────────────────────────────────────────────

/// A single recorded tool invocation inside the sliding window.
#[derive(Debug, Clone)]
struct ToolCallRecord {
    /// Tool name.
    name: String,
    /// Hash of the serialised arguments.
    args_hash: u64,
    /// Hash of the tool's output/result.
    result_hash: u64,
}

/// Produce a deterministic hash for a JSON value by recursively sorting
/// object keys before serialisation. This ensures `{"a":1,"b":2}` and
/// `{"b":2,"a":1}` hash identically.
fn hash_value(value: &serde_json::Value) -> u64 {
    let mut hasher = DefaultHasher::new();
    let canonical = serde_json::to_string(&canonicalise(value)).unwrap_or_default();
    canonical.hash(&mut hasher);
    hasher.finish()
}

/// Return a clone of `value` with all object keys sorted recursively.
fn canonicalise(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(map) => {
            let mut sorted: Vec<(&String, &serde_json::Value)> = map.iter().collect();
            sorted.sort_by_key(|(k, _)| *k);
            let new_map: serde_json::Map<String, serde_json::Value> = sorted
                .into_iter()
                .map(|(k, v)| (k.clone(), canonicalise(v)))
                .collect();
            serde_json::Value::Object(new_map)
        }
        serde_json::Value::Array(arr) => {
            serde_json::Value::Array(arr.iter().map(canonicalise).collect())
        }
        other => other.clone(),
    }
}
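The key-sorting idea behind `canonicalise` can be shown with the standard library alone: a `BTreeMap` iterates its keys in sorted order, so two maps built with different insertion orders hash identically. This is a sketch of the property, not the serde_json-based implementation above (`hash_map` is an illustrative name):

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};

// BTreeMap yields entries in sorted key order regardless of insertion
// order, giving the same order-insensitive hashing that canonicalise()
// provides for JSON objects.
fn hash_map(m: &BTreeMap<&str, i64>) -> u64 {
    let mut h = DefaultHasher::new();
    for (k, v) in m {
        k.hash(&mut h);
        v.hash(&mut h);
    }
    h.finish()
}

fn main() {
    let ab: BTreeMap<_, _> = [("a", 1), ("b", 2)].into_iter().collect();
    let ba: BTreeMap<_, _> = [("b", 2), ("a", 1)].into_iter().collect();
    // Different insertion order, identical hash.
    assert_eq!(hash_map(&ab), hash_map(&ba));
}
```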

fn hash_str(s: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    s.hash(&mut hasher);
    hasher.finish()
}

// ── Detector ─────────────────────────────────────────────────────

/// Stateful loop detector that lives for the duration of a single
/// `run_tool_call_loop` invocation.
pub(crate) struct LoopDetector {
    config: LoopDetectorConfig,
    window: VecDeque<ToolCallRecord>,
}

impl LoopDetector {
    pub fn new(config: LoopDetectorConfig) -> Self {
        Self {
            window: VecDeque::with_capacity(config.window_size),
            config,
        }
    }

    /// Record a completed tool call and check for loop patterns.
    ///
    /// * `name` — tool name (e.g. `"shell"`, `"file_read"`).
    /// * `args` — the arguments JSON value sent to the tool.
    /// * `result` — the tool's textual output.
    pub fn record(
        &mut self,
        name: &str,
        args: &serde_json::Value,
        result: &str,
    ) -> LoopDetectionResult {
        if !self.config.enabled {
            return LoopDetectionResult::Ok;
        }

        let record = ToolCallRecord {
            name: name.to_string(),
            args_hash: hash_value(args),
            result_hash: hash_str(result),
        };

        // Maintain sliding window.
        if self.window.len() >= self.config.window_size {
            self.window.pop_front();
        }
        self.window.push_back(record);

        // Run detectors in escalation order (most severe first).
        if let Some(result) = self.detect_exact_repeat() {
            return result;
        }
        if let Some(result) = self.detect_ping_pong() {
            return result;
        }
        if let Some(result) = self.detect_no_progress() {
            return result;
        }

        LoopDetectionResult::Ok
    }

    /// Pattern 1: Same tool + same args called N+ times consecutively.
    ///
    /// Escalation:
    /// - N == max_repeats     -> Warning
    /// - N == max_repeats + 1 -> Block
    /// - N >= max_repeats + 2 -> Break (circuit breaker)
    fn detect_exact_repeat(&self) -> Option<LoopDetectionResult> {
        let max = self.config.max_repeats;
        if self.window.len() < max {
            return None;
        }

        let last = self.window.back()?;
        let consecutive = self
            .window
            .iter()
            .rev()
            .take_while(|r| r.name == last.name && r.args_hash == last.args_hash)
            .count();

        if consecutive >= max + 2 {
            Some(LoopDetectionResult::Break(format!(
                "Circuit breaker: tool '{}' called {} times consecutively with identical arguments",
                last.name, consecutive
            )))
        } else if consecutive > max {
            Some(LoopDetectionResult::Block(format!(
                "Blocked: tool '{}' called {} times consecutively with identical arguments",
                last.name, consecutive
            )))
        } else if consecutive >= max {
            Some(LoopDetectionResult::Warning(format!(
                "Warning: tool '{}' has been called {} times consecutively with identical arguments. \
                 Try a different approach.",
                last.name, consecutive
            )))
        } else {
            None
        }
    }

    /// Pattern 2: Two tools alternating (A->B->A->B) for 4+ full cycles
    /// (i.e. 8 consecutive entries following the pattern).
    fn detect_ping_pong(&self) -> Option<LoopDetectionResult> {
        const MIN_CYCLES: usize = 4;
        let needed = MIN_CYCLES * 2; // each cycle = 2 calls

        if self.window.len() < needed {
            return None;
        }

        let tail: Vec<&ToolCallRecord> = self.window.iter().rev().take(needed).collect();
        // tail[0] is most recent; pattern: A, B, A, B, ...
        let a_name = &tail[0].name;
        let b_name = &tail[1].name;

        if a_name == b_name {
            return None;
        }

        let is_ping_pong = tail.iter().enumerate().all(|(i, r)| {
            if i % 2 == 0 {
                &r.name == a_name
            } else {
                &r.name == b_name
            }
        });

        if !is_ping_pong {
            return None;
        }

        // Count total alternating length for escalation.
        let mut cycles = MIN_CYCLES;
        let extended: Vec<&ToolCallRecord> = self.window.iter().rev().collect();
        for extra_pair in extended.chunks(2).skip(MIN_CYCLES) {
            if extra_pair.len() == 2
                && &extra_pair[0].name == a_name
                && &extra_pair[1].name == b_name
            {
                cycles += 1;
            } else {
                break;
            }
        }

        if cycles >= MIN_CYCLES + 2 {
            Some(LoopDetectionResult::Break(format!(
                "Circuit breaker: tools '{}' and '{}' have been alternating for {} cycles",
                a_name, b_name, cycles
            )))
        } else if cycles > MIN_CYCLES {
            Some(LoopDetectionResult::Block(format!(
                "Blocked: tools '{}' and '{}' have been alternating for {} cycles",
                a_name, b_name, cycles
            )))
        } else {
            Some(LoopDetectionResult::Warning(format!(
                "Warning: tools '{}' and '{}' appear to be alternating ({} cycles). \
                 Consider a different strategy.",
                a_name, b_name, cycles
            )))
        }
    }

    /// Pattern 3: Same tool called 5+ times (with different args each time)
    /// but producing the exact same result hash every time.
    fn detect_no_progress(&self) -> Option<LoopDetectionResult> {
        const MIN_CALLS: usize = 5;

        if self.window.len() < MIN_CALLS {
            return None;
        }

        let last = self.window.back()?;
        let same_tool_same_result: Vec<&ToolCallRecord> = self
            .window
            .iter()
            .rev()
            .take_while(|r| r.name == last.name && r.result_hash == last.result_hash)
            .collect();

        let count = same_tool_same_result.len();
        if count < MIN_CALLS {
            return None;
        }

        // Verify they have *different* args (otherwise exact_repeat handles it).
        let unique_args: std::collections::HashSet<u64> =
            same_tool_same_result.iter().map(|r| r.args_hash).collect();
        if unique_args.len() < 2 {
            // All same args — this is exact-repeat territory, not no-progress.
            return None;
        }

        if count >= MIN_CALLS + 2 {
            Some(LoopDetectionResult::Break(format!(
                "Circuit breaker: tool '{}' called {} times with different arguments but identical results — no progress",
                last.name, count
            )))
        } else if count > MIN_CALLS {
            Some(LoopDetectionResult::Block(format!(
                "Blocked: tool '{}' called {} times with different arguments but identical results",
                last.name, count
            )))
        } else {
            Some(LoopDetectionResult::Warning(format!(
                "Warning: tool '{}' called {} times with different arguments but identical results. \
                 The current approach may not be making progress.",
                last.name, count
            )))
        }
    }
}
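The bounded-window maintenance inside `record` can be sketched on its own: evict the oldest entry before pushing when the window is at capacity. A std-only illustration (`push_bounded` is an illustrative helper name, not part of the module):

```rust
use std::collections::VecDeque;

// Keep at most `cap` most-recent items, evicting the oldest first —
// the same bounded-window policy record() applies before analysis.
fn push_bounded<T>(window: &mut VecDeque<T>, cap: usize, item: T) {
    if window.len() >= cap {
        window.pop_front();
    }
    window.push_back(item);
}

fn main() {
    let mut w = VecDeque::new();
    for i in 0..5 {
        push_bounded(&mut w, 3, i);
    }
    // Only the three most recent items survive.
    assert_eq!(w, VecDeque::from(vec![2, 3, 4]));
}
```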

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn default_config() -> LoopDetectorConfig {
        LoopDetectorConfig::default()
    }

    fn config_with_repeats(max_repeats: usize) -> LoopDetectorConfig {
        LoopDetectorConfig {
            enabled: true,
            window_size: 20,
            max_repeats,
        }
    }

    // ── Exact repeat tests ───────────────────────────────────────

    #[test]
    fn exact_repeat_warning_at_threshold() {
        let mut det = LoopDetector::new(config_with_repeats(3));
        let args = json!({"path": "/tmp/foo"});

        assert_eq!(
            det.record("file_read", &args, "contents"),
            LoopDetectionResult::Ok
        );
        assert_eq!(
            det.record("file_read", &args, "contents"),
            LoopDetectionResult::Ok
        );
        // 3rd consecutive = warning
        match det.record("file_read", &args, "contents") {
            LoopDetectionResult::Warning(msg) => {
                assert!(msg.contains("file_read"));
                assert!(msg.contains("3 times"));
            }
            other => panic!("expected Warning, got {other:?}"),
        }
    }

    #[test]
    fn exact_repeat_block_at_threshold_plus_one() {
        let mut det = LoopDetector::new(config_with_repeats(3));
        let args = json!({"cmd": "ls"});

        for _ in 0..3 {
            det.record("shell", &args, "output");
        }
        match det.record("shell", &args, "output") {
            LoopDetectionResult::Block(msg) => {
                assert!(msg.contains("shell"));
                assert!(msg.contains("4 times"));
            }
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[test]
    fn exact_repeat_break_at_threshold_plus_two() {
        let mut det = LoopDetector::new(config_with_repeats(3));
        let args = json!({"q": "test"});

        for _ in 0..4 {
            det.record("search", &args, "no results");
        }
        match det.record("search", &args, "no results") {
            LoopDetectionResult::Break(msg) => {
                assert!(msg.contains("Circuit breaker"));
                assert!(msg.contains("search"));
            }
            other => panic!("expected Break, got {other:?}"),
        }
    }

    #[test]
    fn exact_repeat_resets_on_different_call() {
        let mut det = LoopDetector::new(config_with_repeats(3));
        let args = json!({"x": 1});

        det.record("tool_a", &args, "r1");
        det.record("tool_a", &args, "r1");
        // Interject a different tool — resets the streak.
        det.record("tool_b", &json!({}), "r2");
        det.record("tool_a", &args, "r1");
        det.record("tool_a", &args, "r1");
        // Only 2 consecutive now, should be Ok.
        assert_eq!(
            det.record("tool_a", &json!({"x": 999}), "r1"),
            LoopDetectionResult::Ok
        );
    }

    // ── Ping-pong tests ──────────────────────────────────────────

    #[test]
    fn ping_pong_warning_at_four_cycles() {
        let mut det = LoopDetector::new(default_config());
        let args = json!({});
|
||||
|
||||
// 4 full cycles = 8 calls: A B A B A B A B
|
||||
for i in 0..8 {
|
||||
let name = if i % 2 == 0 { "read" } else { "write" };
|
||||
let result = det.record(name, &args, &format!("r{i}"));
|
||||
if i < 7 {
|
||||
assert_eq!(result, LoopDetectionResult::Ok, "iteration {i}");
|
||||
} else {
|
||||
match result {
|
||||
LoopDetectionResult::Warning(msg) => {
|
||||
assert!(msg.contains("read"));
|
||||
assert!(msg.contains("write"));
|
||||
assert!(msg.contains("4 cycles"));
|
||||
}
|
||||
other => panic!("expected Warning at cycle 4, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ping_pong_escalates_with_more_cycles() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
let args = json!({});
|
||||
|
||||
// 5 cycles = 10 calls. The 10th call (completing cycle 5) triggers Block.
|
||||
for i in 0..10 {
|
||||
let name = if i % 2 == 0 { "fetch" } else { "parse" };
|
||||
det.record(name, &args, &format!("r{i}"));
|
||||
}
|
||||
// 11th call extends to 5.5 cycles; detector still counts 5 full -> Block.
|
||||
let r = det.record("fetch", &args, "r10");
|
||||
match r {
|
||||
LoopDetectionResult::Block(msg) => {
|
||||
assert!(msg.contains("fetch"));
|
||||
assert!(msg.contains("parse"));
|
||||
assert!(msg.contains("5 cycles"));
|
||||
}
|
||||
other => panic!("expected Block at 5 cycles, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ping_pong_not_triggered_for_same_tool() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
let args = json!({});
|
||||
|
||||
// Same tool repeated is not ping-pong.
|
||||
for _ in 0..10 {
|
||||
det.record("read", &args, "data");
|
||||
}
|
||||
// The exact_repeat detector fires, not ping_pong.
|
||||
// Verify by checking message content doesn't mention "alternating".
|
||||
let r = det.record("read", &args, "data");
|
||||
if let LoopDetectionResult::Break(msg) | LoopDetectionResult::Block(msg) = r {
|
||||
assert!(
|
||||
!msg.contains("alternating"),
|
||||
"should be exact-repeat, not ping-pong"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── No-progress tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn no_progress_warning_at_five_different_args_same_result() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
|
||||
for i in 0..5 {
|
||||
let args = json!({"query": format!("attempt_{i}")});
|
||||
let result = det.record("search", &args, "no results found");
|
||||
if i < 4 {
|
||||
assert_eq!(result, LoopDetectionResult::Ok, "iteration {i}");
|
||||
} else {
|
||||
match result {
|
||||
LoopDetectionResult::Warning(msg) => {
|
||||
assert!(msg.contains("search"));
|
||||
assert!(msg.contains("identical results"));
|
||||
}
|
||||
other => panic!("expected Warning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_progress_escalates_to_block_and_break() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
|
||||
// 6 calls with different args, same result.
|
||||
for i in 0..6 {
|
||||
let args = json!({"q": format!("v{i}")});
|
||||
det.record("web_fetch", &args, "timeout");
|
||||
}
|
||||
// 7th call: count=7 which is >= MIN_CALLS(5)+2 -> Break.
|
||||
let r7 = det.record("web_fetch", &json!({"q": "v6"}), "timeout");
|
||||
match r7 {
|
||||
LoopDetectionResult::Break(msg) => {
|
||||
assert!(msg.contains("web_fetch"));
|
||||
assert!(msg.contains("7 times"));
|
||||
assert!(msg.contains("no progress"));
|
||||
}
|
||||
other => panic!("expected Break at 7 calls, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_progress_not_triggered_when_results_differ() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
|
||||
for i in 0..8 {
|
||||
let args = json!({"q": format!("v{i}")});
|
||||
let result = det.record("search", &args, &format!("result_{i}"));
|
||||
assert_eq!(result, LoopDetectionResult::Ok, "iteration {i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_progress_not_triggered_when_all_args_identical() {
|
||||
// If args are all the same, exact_repeat should fire, not no_progress.
|
||||
let mut det = LoopDetector::new(config_with_repeats(6));
|
||||
let args = json!({"q": "same"});
|
||||
|
||||
for _ in 0..5 {
|
||||
det.record("search", &args, "no results");
|
||||
}
|
||||
// 6th call = exact repeat at threshold (max_repeats=6) -> Warning.
|
||||
// no_progress requires >=2 unique args, so it must NOT fire.
|
||||
let r = det.record("search", &args, "no results");
|
||||
match r {
|
||||
LoopDetectionResult::Warning(msg) => {
|
||||
assert!(
|
||||
msg.contains("identical arguments"),
|
||||
"should be exact-repeat Warning, got: {msg}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected exact-repeat Warning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Disabled / config tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn disabled_detector_always_returns_ok() {
|
||||
let config = LoopDetectorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
};
|
||||
let mut det = LoopDetector::new(config);
|
||||
let args = json!({"x": 1});
|
||||
|
||||
for _ in 0..20 {
|
||||
assert_eq!(det.record("tool", &args, "same"), LoopDetectionResult::Ok);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_size_limits_memory() {
|
||||
let config = LoopDetectorConfig {
|
||||
enabled: true,
|
||||
window_size: 5,
|
||||
max_repeats: 3,
|
||||
};
|
||||
let mut det = LoopDetector::new(config);
|
||||
let args = json!({"x": 1});
|
||||
|
||||
// Fill window with 5 different tools.
|
||||
for i in 0..5 {
|
||||
det.record(&format!("tool_{i}"), &args, "result");
|
||||
}
|
||||
assert_eq!(det.window.len(), 5);
|
||||
|
||||
// Adding one more evicts the oldest.
|
||||
det.record("tool_5", &args, "result");
|
||||
assert_eq!(det.window.len(), 5);
|
||||
assert_eq!(det.window.front().unwrap().name, "tool_1");
|
||||
}
|
||||
|
||||
// ── Ping-pong with varying args ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ping_pong_detects_alternation_with_varying_args() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
|
||||
// A->B->A->B with different args each time — ping-pong cares only
|
||||
// about tool names, not argument equality.
|
||||
for i in 0..8 {
|
||||
let name = if i % 2 == 0 { "read" } else { "write" };
|
||||
let args = json!({"attempt": i});
|
||||
let result = det.record(name, &args, &format!("r{i}"));
|
||||
if i < 7 {
|
||||
assert_eq!(result, LoopDetectionResult::Ok, "iteration {i}");
|
||||
} else {
|
||||
match result {
|
||||
LoopDetectionResult::Warning(msg) => {
|
||||
assert!(msg.contains("read"));
|
||||
assert!(msg.contains("write"));
|
||||
assert!(msg.contains("4 cycles"));
|
||||
}
|
||||
other => panic!("expected Warning at cycle 4, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Window eviction test ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn window_eviction_prevents_stale_pattern_detection() {
|
||||
let config = LoopDetectorConfig {
|
||||
enabled: true,
|
||||
window_size: 6,
|
||||
max_repeats: 3,
|
||||
};
|
||||
let mut det = LoopDetector::new(config);
|
||||
let args = json!({"x": 1});
|
||||
|
||||
// 2 consecutive calls of "tool_a".
|
||||
det.record("tool_a", &args, "r");
|
||||
det.record("tool_a", &args, "r");
|
||||
|
||||
// Fill the rest of the window with different tools (evicting the
|
||||
// first "tool_a" calls as the window is only 6).
|
||||
for i in 0..5 {
|
||||
det.record(&format!("other_{i}"), &json!({}), "ok");
|
||||
}
|
||||
|
||||
// Now "tool_a" again — only 1 consecutive, not 3.
|
||||
let r = det.record("tool_a", &args, "r");
|
||||
assert_eq!(
|
||||
r,
|
||||
LoopDetectionResult::Ok,
|
||||
"stale entries should be evicted"
|
||||
);
|
||||
}
|
||||
|
||||
// ── hash_value key-order independence ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hash_value_is_key_order_independent() {
|
||||
let a = json!({"alpha": 1, "beta": 2});
|
||||
let b = json!({"beta": 2, "alpha": 1});
|
||||
assert_eq!(
|
||||
hash_value(&a),
|
||||
hash_value(&b),
|
||||
"hash_value must produce identical hashes regardless of JSON key order"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hash_value_nested_key_order_independent() {
|
||||
let a = json!({"outer": {"x": 1, "y": 2}, "z": [1, 2]});
|
||||
let b = json!({"z": [1, 2], "outer": {"y": 2, "x": 1}});
|
||||
assert_eq!(
|
||||
hash_value(&a),
|
||||
hash_value(&b),
|
||||
"nested objects must also be key-order independent"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Escalation order tests ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn exact_repeat_takes_priority_over_no_progress() {
|
||||
// If tool+args are identical, exact_repeat fires before no_progress.
|
||||
let mut det = LoopDetector::new(config_with_repeats(3));
|
||||
let args = json!({"q": "same"});
|
||||
|
||||
det.record("s", &args, "r");
|
||||
det.record("s", &args, "r");
|
||||
let r = det.record("s", &args, "r");
|
||||
match r {
|
||||
LoopDetectionResult::Warning(msg) => {
|
||||
assert!(msg.contains("identical arguments"));
|
||||
}
|
||||
other => panic!("expected exact-repeat Warning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
-use crate::memory::{self, Memory};
+use crate::memory::{self, decay, Memory};
 use async_trait::async_trait;
 use std::fmt::Write;

@@ -43,13 +43,16 @@ impl MemoryLoader for DefaultMemoryLoader {
         user_message: &str,
         session_id: Option<&str>,
     ) -> anyhow::Result<String> {
-        let entries = memory
+        let mut entries = memory
            .recall(user_message, self.limit, session_id, None, None)
            .await?;
         if entries.is_empty() {
             return Ok(String::new());
         }

+        // Apply time decay: older non-Core memories score lower
+        decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
+
         let mut context = String::from("[Memory context]\n");
         for entry in entries {
             if memory::is_assistant_autosave_key(&entry.key) {
@@ -118,6 +121,9 @@ mod tests {
             timestamp: "now".into(),
             session_id: None,
             score: None,
+            namespace: "default".into(),
+            importance: None,
+            superseded_by: None,
         }])
     }

@@ -226,6 +232,9 @@ mod tests {
                 timestamp: "now".into(),
                 session_id: None,
                 score: Some(0.95),
+                namespace: "default".into(),
+                importance: None,
+                superseded_by: None,
             },
             MemoryEntry {
                 id: "2".into(),
@@ -235,6 +244,9 @@ mod tests {
                 timestamp: "now".into(),
                 session_id: None,
                 score: Some(0.9),
+                namespace: "default".into(),
+                importance: None,
+                superseded_by: None,
             },
         ]),
    };

@@ -3,13 +3,15 @@ pub mod agent;
 pub mod classifier;
 pub mod dispatcher;
 pub mod loop_;
+pub mod loop_detector;
 pub mod memory_loader;
 pub mod prompt;
+pub mod thinking;

 #[cfg(test)]
 mod tests;

 #[allow(unused_imports)]
-pub use agent::{Agent, AgentBuilder};
+pub use agent::{Agent, AgentBuilder, TurnEvent};
 #[allow(unused_imports)]
 pub use loop_::{process_message, run};

@@ -5,7 +5,7 @@ use crate::security::AutonomyLevel;
 use crate::skills::Skill;
 use crate::tools::Tool;
 use anyhow::Result;
-use chrono::Local;
+use chrono::{Datelike, Local, Timelike};
 use std::fmt::Write;
 use std::path::Path;

@@ -47,13 +47,13 @@ impl SystemPromptBuilder {
     pub fn with_defaults() -> Self {
         Self {
             sections: vec![
+                Box::new(DateTimeSection),
                 Box::new(IdentitySection),
                 Box::new(ToolHonestySection),
                 Box::new(ToolsSection),
                 Box::new(SafetySection),
                 Box::new(SkillsSection),
                 Box::new(WorkspaceSection),
-                Box::new(DateTimeSection),
                 Box::new(RuntimeSection),
                 Box::new(ChannelMediaSection),
             ],
@@ -278,10 +278,19 @@ impl PromptSection for DateTimeSection {

     fn build(&self, _ctx: &PromptContext<'_>) -> Result<String> {
         let now = Local::now();
+        // Force Gregorian year to avoid confusion with local calendars (e.g. Buddhist calendar).
+        let (year, month, day) = (now.year(), now.month(), now.day());
+        let (hour, minute, second) = (now.hour(), now.minute(), now.second());
+        let tz = now.format("%Z");

         Ok(format!(
-            "## Current Date & Time\n\n{} ({})",
-            now.format("%Y-%m-%d %H:%M:%S"),
-            now.format("%Z")
+            "## CRITICAL CONTEXT: CURRENT DATE & TIME\n\n\
+             The following is the ABSOLUTE TRUTH regarding the current date and time. \
+             Use this for all relative time calculations (e.g. \"last 7 days\").\n\n\
+             Date: {year:04}-{month:02}-{day:02}\n\
+             Time: {hour:02}:{minute:02}:{second:02} ({tz})\n\
+             ISO 8601: {year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}{}",
+            now.format("%:z")
         ))
     }
 }
@@ -473,8 +482,9 @@ mod tests {
         assert!(output.contains("<available_skills>"));
         assert!(output.contains("<name>deploy</name>"));
         assert!(output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
-        assert!(output.contains("<name>release_checklist</name>"));
-        assert!(output.contains("<kind>shell</kind>"));
+        // Registered tools (shell kind) appear under <callable_tools> with prefixed names
+        assert!(output.contains("<callable_tools"));
+        assert!(output.contains("<name>deploy.release_checklist</name>"));
     }

     #[test]
@@ -516,10 +526,10 @@ mod tests {
         assert!(output.contains("<location>skills/deploy/SKILL.md</location>"));
         assert!(output.contains("read_skill(name)"));
         assert!(!output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
-        // Compact mode should still include tools so the LLM knows about them
-        assert!(output.contains("<tools>"));
-        assert!(output.contains("<name>release_checklist</name>"));
-        assert!(output.contains("<kind>shell</kind>"));
+        // Compact mode should still include tools so the LLM knows about them.
+        // Registered tools (shell kind) appear under <callable_tools> with prefixed names.
+        assert!(output.contains("<callable_tools"));
+        assert!(output.contains("<name>deploy.release_checklist</name>"));
     }

     #[test]
@@ -539,12 +549,12 @@ mod tests {
         };

         let rendered = DateTimeSection.build(&ctx).unwrap();
-        assert!(rendered.starts_with("## Current Date & Time\n\n"));
+        assert!(rendered.starts_with("## CRITICAL CONTEXT: CURRENT DATE & TIME\n\n"));

-        let payload = rendered.trim_start_matches("## Current Date & Time\n\n");
+        let payload = rendered.trim_start_matches("## CRITICAL CONTEXT: CURRENT DATE & TIME\n\n");
         assert!(payload.chars().any(|c| c.is_ascii_digit()));
-        assert!(payload.contains(" ("));
-        assert!(payload.ends_with(')'));
+        assert!(payload.contains("Date:"));
+        assert!(payload.contains("Time:"));
     }

     #[test]

@@ -0,0 +1,424 @@
//! Thinking/Reasoning Level Control
//!
//! Allows users to control how deeply the model reasons per message,
//! trading speed for depth. Levels range from `Off` (fastest, most concise)
//! to `Max` (deepest reasoning, slowest).
//!
//! Users can set the level via:
//! - Inline directive: `/think:high` at the start of a message
//! - Agent config: `[agent.thinking]` section with `default_level`
//!
//! Resolution hierarchy (highest priority first):
//! 1. Inline directive (`/think:<level>`)
//! 2. Session override (reserved for future use)
//! 3. Agent config (`agent.thinking.default_level`)
//! 4. Global default (`Medium`)

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

/// How deeply the model should reason for a given message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingLevel {
    /// No chain-of-thought. Fastest, most concise responses.
    Off,
    /// Minimal reasoning. Brief, direct answers.
    Minimal,
    /// Light reasoning. Short explanations when needed.
    Low,
    /// Balanced reasoning (default). Moderate depth.
    #[default]
    Medium,
    /// Deep reasoning. Thorough analysis and step-by-step thinking.
    High,
    /// Maximum reasoning depth. Exhaustive analysis.
    Max,
}

impl ThinkingLevel {
    /// Parse a thinking level from a string (case-insensitive).
    pub fn from_str_insensitive(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "off" | "none" => Some(Self::Off),
            "minimal" | "min" => Some(Self::Minimal),
            "low" => Some(Self::Low),
            "medium" | "med" | "default" => Some(Self::Medium),
            "high" => Some(Self::High),
            "max" | "maximum" => Some(Self::Max),
            _ => None,
        }
    }
}

/// Configuration for thinking/reasoning level control.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ThinkingConfig {
    /// Default thinking level when no directive is present.
    #[serde(default)]
    pub default_level: ThinkingLevel,
}

impl Default for ThinkingConfig {
    fn default() -> Self {
        Self {
            default_level: ThinkingLevel::Medium,
        }
    }
}

/// Parameters derived from a thinking level, applied to the LLM request.
#[derive(Debug, Clone, PartialEq)]
pub struct ThinkingParams {
    /// Temperature adjustment (added to the base temperature, clamped to 0.0..=2.0).
    pub temperature_adjustment: f64,
    /// Maximum tokens adjustment (added to any existing max_tokens setting).
    pub max_tokens_adjustment: i64,
    /// Optional system prompt prefix injected before the existing system prompt.
    pub system_prompt_prefix: Option<String>,
}

/// Parse a `/think:<level>` directive from the start of a message.
///
/// Returns `Some((level, remaining_message))` if a directive is found,
/// or `None` if no directive is present. The remaining message has
/// leading whitespace after the directive trimmed.
pub fn parse_thinking_directive(message: &str) -> Option<(ThinkingLevel, String)> {
    let trimmed = message.trim_start();
    if !trimmed.starts_with("/think:") {
        return None;
    }

    // Extract the level token (everything between `/think:` and the next whitespace or end).
    let after_prefix = &trimmed["/think:".len()..];
    let level_end = after_prefix
        .find(|c: char| c.is_whitespace())
        .unwrap_or(after_prefix.len());
    let level_str = &after_prefix[..level_end];

    let level = ThinkingLevel::from_str_insensitive(level_str)?;

    let remaining = after_prefix[level_end..].trim_start().to_string();
    Some((level, remaining))
}

/// Convert a `ThinkingLevel` into concrete parameters for the LLM request.
pub fn apply_thinking_level(level: ThinkingLevel) -> ThinkingParams {
    match level {
        ThinkingLevel::Off => ThinkingParams {
            temperature_adjustment: -0.2,
            max_tokens_adjustment: -1000,
            system_prompt_prefix: Some(
                "Be extremely concise. Give direct answers without explanation \
                 unless explicitly asked. No preamble."
                    .into(),
            ),
        },
        ThinkingLevel::Minimal => ThinkingParams {
            temperature_adjustment: -0.1,
            max_tokens_adjustment: -500,
            system_prompt_prefix: Some(
                "Be concise and fast. Keep explanations brief. \
                 Prioritize speed over thoroughness."
                    .into(),
            ),
        },
        ThinkingLevel::Low => ThinkingParams {
            temperature_adjustment: -0.05,
            max_tokens_adjustment: 0,
            system_prompt_prefix: Some("Keep reasoning light. Explain only when helpful.".into()),
        },
        ThinkingLevel::Medium => ThinkingParams {
            temperature_adjustment: 0.0,
            max_tokens_adjustment: 0,
            system_prompt_prefix: None,
        },
        ThinkingLevel::High => ThinkingParams {
            temperature_adjustment: 0.05,
            max_tokens_adjustment: 1000,
            system_prompt_prefix: Some(
                "Think step by step. Provide thorough analysis and \
                 consider edge cases before answering."
                    .into(),
            ),
        },
        ThinkingLevel::Max => ThinkingParams {
            temperature_adjustment: 0.1,
            max_tokens_adjustment: 2000,
            system_prompt_prefix: Some(
                "Think very carefully and exhaustively. Break down the problem \
                 into sub-problems, consider all angles, verify your reasoning, \
                 and provide the most thorough analysis possible."
                    .into(),
            ),
        },
    }
}

/// Resolve the effective thinking level using the priority hierarchy:
/// 1. Inline directive (if present)
/// 2. Session override (reserved, currently always `None`)
/// 3. Agent config default
/// 4. Global default (`Medium`)
pub fn resolve_thinking_level(
    inline_directive: Option<ThinkingLevel>,
    session_override: Option<ThinkingLevel>,
    config: &ThinkingConfig,
) -> ThinkingLevel {
    inline_directive
        .or(session_override)
        .unwrap_or(config.default_level)
}

/// Clamp a temperature value to the valid range `[0.0, 2.0]`.
pub fn clamp_temperature(temp: f64) -> f64 {
    temp.clamp(0.0, 2.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    // ── ThinkingLevel parsing ────────────────────────────────────

    #[test]
    fn thinking_level_from_str_canonical_names() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("off"),
            Some(ThinkingLevel::Off)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("minimal"),
            Some(ThinkingLevel::Minimal)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("low"),
            Some(ThinkingLevel::Low)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("medium"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("high"),
            Some(ThinkingLevel::High)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("max"),
            Some(ThinkingLevel::Max)
        );
    }

    #[test]
    fn thinking_level_from_str_aliases() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("none"),
            Some(ThinkingLevel::Off)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("min"),
            Some(ThinkingLevel::Minimal)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("med"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("default"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("maximum"),
            Some(ThinkingLevel::Max)
        );
    }

    #[test]
    fn thinking_level_from_str_case_insensitive() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("HIGH"),
            Some(ThinkingLevel::High)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("Max"),
            Some(ThinkingLevel::Max)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("OFF"),
            Some(ThinkingLevel::Off)
        );
    }

    #[test]
    fn thinking_level_from_str_invalid_returns_none() {
        assert_eq!(ThinkingLevel::from_str_insensitive("turbo"), None);
        assert_eq!(ThinkingLevel::from_str_insensitive(""), None);
        assert_eq!(ThinkingLevel::from_str_insensitive("super-high"), None);
    }

    // ── Directive parsing ────────────────────────────────────────

    #[test]
    fn parse_directive_extracts_level_and_remaining_message() {
        let result = parse_thinking_directive("/think:high What is Rust?");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::High);
        assert_eq!(remaining, "What is Rust?");
    }

    #[test]
    fn parse_directive_handles_directive_only() {
        let result = parse_thinking_directive("/think:off");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::Off);
        assert_eq!(remaining, "");
    }

    #[test]
    fn parse_directive_strips_leading_whitespace() {
        let result = parse_thinking_directive("  /think:low Tell me about Rust");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::Low);
        assert_eq!(remaining, "Tell me about Rust");
    }

    #[test]
    fn parse_directive_returns_none_for_no_directive() {
        assert!(parse_thinking_directive("Hello world").is_none());
        assert!(parse_thinking_directive("").is_none());
        assert!(parse_thinking_directive("/think").is_none());
    }

    #[test]
    fn parse_directive_returns_none_for_invalid_level() {
        assert!(parse_thinking_directive("/think:turbo What?").is_none());
    }

    #[test]
    fn parse_directive_not_triggered_mid_message() {
        assert!(parse_thinking_directive("Hello /think:high world").is_none());
    }

    // ── Level application ────────────────────────────────────────

    #[test]
    fn apply_thinking_level_off_is_concise() {
        let params = apply_thinking_level(ThinkingLevel::Off);
        assert!(params.temperature_adjustment < 0.0);
        assert!(params.max_tokens_adjustment < 0);
        assert!(params.system_prompt_prefix.is_some());
        assert!(params
            .system_prompt_prefix
            .unwrap()
            .to_lowercase()
            .contains("concise"));
    }

    #[test]
    fn apply_thinking_level_medium_is_neutral() {
        let params = apply_thinking_level(ThinkingLevel::Medium);
        assert!((params.temperature_adjustment - 0.0).abs() < f64::EPSILON);
        assert_eq!(params.max_tokens_adjustment, 0);
        assert!(params.system_prompt_prefix.is_none());
    }

    #[test]
    fn apply_thinking_level_high_adds_step_by_step() {
        let params = apply_thinking_level(ThinkingLevel::High);
        assert!(params.temperature_adjustment > 0.0);
        assert!(params.max_tokens_adjustment > 0);
        let prefix = params.system_prompt_prefix.unwrap();
        assert!(prefix.to_lowercase().contains("step by step"));
    }

    #[test]
    fn apply_thinking_level_max_is_most_thorough() {
        let params = apply_thinking_level(ThinkingLevel::Max);
        assert!(params.temperature_adjustment > 0.0);
        assert!(params.max_tokens_adjustment > 0);
        let prefix = params.system_prompt_prefix.unwrap();
        assert!(prefix.to_lowercase().contains("exhaustively"));
    }

    // ── Resolution hierarchy ─────────────────────────────────────

    #[test]
    fn resolve_inline_directive_takes_priority() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Low,
        };
        let result =
            resolve_thinking_level(Some(ThinkingLevel::Max), Some(ThinkingLevel::High), &config);
        assert_eq!(result, ThinkingLevel::Max);
    }

    #[test]
    fn resolve_session_override_takes_priority_over_config() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Low,
        };
        let result = resolve_thinking_level(None, Some(ThinkingLevel::High), &config);
        assert_eq!(result, ThinkingLevel::High);
    }

    #[test]
    fn resolve_falls_back_to_config_default() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Minimal,
        };
        let result = resolve_thinking_level(None, None, &config);
        assert_eq!(result, ThinkingLevel::Minimal);
    }

    #[test]
    fn resolve_default_config_uses_medium() {
        let config = ThinkingConfig::default();
        let result = resolve_thinking_level(None, None, &config);
        assert_eq!(result, ThinkingLevel::Medium);
    }

    // ── Temperature clamping ─────────────────────────────────────

    #[test]
    fn clamp_temperature_within_range() {
        assert!((clamp_temperature(0.7) - 0.7).abs() < f64::EPSILON);
        assert!((clamp_temperature(0.0) - 0.0).abs() < f64::EPSILON);
        assert!((clamp_temperature(2.0) - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn clamp_temperature_below_minimum() {
        assert!((clamp_temperature(-0.5) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn clamp_temperature_above_maximum() {
        assert!((clamp_temperature(3.0) - 2.0).abs() < f64::EPSILON);
    }

    // ── Serde round-trip ─────────────────────────────────────────

    #[test]
    fn thinking_config_deserializes_from_toml() {
        let toml_str = r#"default_level = "high""#;
        let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.default_level, ThinkingLevel::High);
    }

    #[test]
    fn thinking_config_default_level_deserializes() {
        let toml_str = "";
        let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.default_level, ThinkingLevel::Medium);
    }

    #[test]
    fn thinking_level_serializes_lowercase() {
        let level = ThinkingLevel::High;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"high\"");
    }
}
@@ -562,4 +562,50 @@ mod tests {
        let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.tool_name, "shell");
    }

    // ── Regression: #4247 default approved tools in channels ──

    #[test]
    fn non_interactive_allows_default_auto_approve_tools() {
        let config = AutonomyConfig::default();
        let mgr = ApprovalManager::for_non_interactive(&config);

        for tool in &config.auto_approve {
            assert!(
                !mgr.needs_approval(tool),
                "default auto_approve tool '{tool}' should not need approval in non-interactive mode"
            );
        }
    }

    #[test]
    fn non_interactive_denies_unknown_tools() {
        let config = AutonomyConfig::default();
        let mgr = ApprovalManager::for_non_interactive(&config);
        assert!(
            mgr.needs_approval("some_unknown_tool"),
            "unknown tool should need approval"
        );
    }

    #[test]
    fn non_interactive_weather_is_auto_approved() {
        let config = AutonomyConfig::default();
        let mgr = ApprovalManager::for_non_interactive(&config);
        assert!(
            !mgr.needs_approval("weather"),
            "weather tool must not need approval — it is in the default auto_approve list"
        );
    }

    #[test]
    fn always_ask_overrides_auto_approve() {
        let mut config = AutonomyConfig::default();
        config.always_ask = vec!["weather".into()];
        let mgr = ApprovalManager::for_non_interactive(&config);
        assert!(
            mgr.needs_approval("weather"),
            "always_ask must override auto_approve"
        );
    }
}
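The four tests above encode the approval rules: defaults are auto-approved, unknown tools require approval, and `always_ask` wins over `auto_approve`. A hypothetical standalone mirror of that logic (struct, field, and method names are illustrative, not the crate's `ApprovalManager` API):

```rust
use std::collections::HashSet;

// Illustrative mirror of the semantics exercised by the tests above:
// `always_ask` overrides `auto_approve`; anything unknown needs approval.
struct Approvals {
    auto_approve: HashSet<String>,
    always_ask: HashSet<String>,
}

impl Approvals {
    fn needs_approval(&self, tool: &str) -> bool {
        self.always_ask.contains(tool) || !self.auto_approve.contains(tool)
    }
}

fn main() {
    let mut mgr = Approvals {
        auto_approve: ["weather".to_string()].into_iter().collect(),
        always_ask: HashSet::new(),
    };
    assert!(!mgr.needs_approval("weather")); // auto-approved
    assert!(mgr.needs_approval("some_unknown_tool")); // unknown => ask

    mgr.always_ask.insert("weather".to_string());
    assert!(mgr.needs_approval("weather")); // always_ask wins
}
```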

@@ -252,6 +252,7 @@ impl BlueskyChannel {
                timestamp,
                thread_ts: Some(notif.uri.clone()),
                interruption_scope_id: None,
                attachments: vec![],
            })
        }

@@ -49,6 +49,7 @@ impl Channel for CliChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(msg).await.is_err() {
@@ -113,6 +114,7 @@ mod tests {
            timestamp: 1_234_567_890,
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };
        assert_eq!(msg.id, "test-id");
        assert_eq!(msg.sender, "user");
@@ -133,6 +135,7 @@ mod tests {
            timestamp: 0,
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };
        let cloned = msg.clone();
        assert_eq!(cloned.id, msg.id);

@@ -285,6 +285,7 @@ impl Channel for DingTalkChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(channel_msg).await.is_err() {

@@ -20,6 +20,9 @@ pub struct DiscordChannel {
    typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
    /// Per-channel proxy URL override.
    proxy_url: Option<String>,
    /// Voice transcription config — when set, audio attachments are
    /// downloaded, transcribed, and their text inlined into the message.
    transcription: Option<crate::config::TranscriptionConfig>,
}

impl DiscordChannel {
@@ -38,6 +41,7 @@ impl DiscordChannel {
            mention_only,
            typing_handles: Mutex::new(HashMap::new()),
            proxy_url: None,
            transcription: None,
        }
    }

@@ -47,6 +51,14 @@ impl DiscordChannel {
        self
    }

    /// Configure voice transcription for audio attachments.
    pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
        if config.enabled {
            self.transcription = Some(config);
        }
        self
    }

    fn http_client(&self) -> reqwest::Client {
        crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
    }
@@ -113,6 +125,88 @@ async fn process_attachments(
    parts.join("\n---\n")
}

/// Audio file extensions accepted for voice transcription.
const DISCORD_AUDIO_EXTENSIONS: &[&str] = &[
    "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];

/// Check if a content type or filename indicates an audio file.
fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
    if content_type.starts_with("audio/") {
        return true;
    }
    if let Some(ext) = filename.rsplit('.').next() {
        return DISCORD_AUDIO_EXTENSIONS.contains(&ext.to_ascii_lowercase().as_str());
    }
    false
}
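The MIME-first, extension-fallback detection above can be exercised standalone; this sketch uses a reduced extension list and an illustrative function name, not the function shipped in the diff:

```rust
// Reduced sketch of the content-type / extension check above; EXTS is a
// subset of the real DISCORD_AUDIO_EXTENSIONS list (illustrative only).
fn is_audio(content_type: &str, filename: &str) -> bool {
    const EXTS: &[&str] = &["flac", "mp3", "ogg", "opus", "wav"];
    if content_type.starts_with("audio/") {
        return true;
    }
    filename
        .rsplit('.')
        .next()
        .is_some_and(|ext| EXTS.contains(&ext.to_ascii_lowercase().as_str()))
}

fn main() {
    assert!(is_audio("audio/ogg", "clip.bin")); // MIME type wins
    assert!(is_audio("", "Voice.MP3")); // extension match is case-insensitive
    assert!(!is_audio("image/png", "pic.png"));
}
```

Note the `rsplit('.').next()` fallback: on a filename with no dot it yields the whole name, which simply fails the list lookup.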

/// Download and transcribe audio attachments from a Discord message.
///
/// Returns transcribed text blocks for any audio attachments found.
/// Non-audio attachments and failures are silently skipped.
async fn transcribe_discord_audio_attachments(
    attachments: &[serde_json::Value],
    client: &reqwest::Client,
    config: &crate::config::TranscriptionConfig,
) -> String {
    let mut parts: Vec<String> = Vec::new();
    for att in attachments {
        let ct = att
            .get("content_type")
            .and_then(|v| v.as_str())
            .unwrap_or("");
        let name = att
            .get("filename")
            .and_then(|v| v.as_str())
            .unwrap_or("file");

        if !is_discord_audio_attachment(ct, name) {
            continue;
        }

        let Some(url) = att.get("url").and_then(|v| v.as_str()) else {
            continue;
        };

        let audio_data = match client.get(url).send().await {
            Ok(resp) if resp.status().is_success() => match resp.bytes().await {
                Ok(bytes) => bytes.to_vec(),
                Err(e) => {
                    tracing::warn!(name, error = %e, "discord: failed to read audio attachment bytes");
                    continue;
                }
            },
            Ok(resp) => {
                tracing::warn!(name, status = %resp.status(), "discord: audio attachment download failed");
                continue;
            }
            Err(e) => {
                tracing::warn!(name, error = %e, "discord: audio attachment fetch error");
                continue;
            }
        };

        match super::transcription::transcribe_audio(audio_data, name, config).await {
            Ok(text) => {
                let trimmed = text.trim();
                if !trimmed.is_empty() {
                    tracing::info!(
                        "Discord: transcribed audio attachment {} ({} chars)",
                        name,
                        trimmed.len()
                    );
                    parts.push(format!("[Voice] {trimmed}"));
                }
            }
            Err(e) => {
                tracing::warn!(name, error = %e, "discord: voice transcription failed");
            }
        }
    }
    parts.join("\n")
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum DiscordAttachmentKind {
    Image,
@@ -737,7 +831,28 @@ impl Channel for DiscordChannel {
                .and_then(|a| a.as_array())
                .cloned()
                .unwrap_or_default();
            process_attachments(&atts, &self.http_client()).await
            let client = self.http_client();
            let mut text_parts = process_attachments(&atts, &client).await;

            // Transcribe audio attachments when transcription is configured
            if let Some(ref transcription_config) = self.transcription {
                let voice_text = transcribe_discord_audio_attachments(
                    &atts,
                    &client,
                    transcription_config,
                )
                .await;
                if !voice_text.is_empty() {
                    if text_parts.is_empty() {
                        text_parts = voice_text;
                    } else {
                        text_parts = format!("{text_parts}\n{voice_text}");
                    }
                }
            }

            text_parts
        };
        let final_content = if attachment_text.is_empty() {
            clean_content
@@ -799,6 +914,7 @@ impl Channel for DiscordChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(channel_msg).await.is_err() {

@@ -494,6 +494,7 @@ impl Channel for DiscordHistoryChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: Vec::new(),
        };
        if tx.send(channel_msg).await.is_err() {
            break;

@@ -468,6 +468,7 @@ impl EmailChannel {
            timestamp: email.timestamp,
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(msg).await.is_err() {

@@ -494,6 +494,7 @@ impl GmailPushChannel {
            timestamp,
            thread_ts: Some(gmail_msg.thread_id),
            interruption_scope_id: None,
            attachments: Vec::new(),
        };

        if tx.send(channel_msg).await.is_err() {

@@ -295,6 +295,7 @@ end tell"#
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(msg).await.is_err() {

@@ -581,6 +581,7 @@ impl Channel for IrcChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(channel_msg).await.is_err() {

@@ -0,0 +1,462 @@
//! Link enricher: auto-detects URLs in inbound messages, fetches their content,
//! and prepends summaries so the agent has link context without explicit tool calls.

use regex::Regex;
use std::net::IpAddr;
use std::sync::LazyLock;
use std::time::Duration;

/// Configuration for the link enricher pipeline stage.
#[derive(Debug, Clone)]
pub struct LinkEnricherConfig {
    pub enabled: bool,
    pub max_links: usize,
    pub timeout_secs: u64,
}

impl Default for LinkEnricherConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            max_links: 3,
            timeout_secs: 10,
        }
    }
}

/// URL regex: matches http:// and https:// URLs, stopping at whitespace, angle
/// brackets, or quotes.
static URL_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r#"https?://[^\s<>"']+"#).expect("URL regex must compile"));

/// Extract URLs from message text, returning up to `max` unique URLs.
pub fn extract_urls(text: &str, max: usize) -> Vec<String> {
    let mut seen = Vec::new();
    for m in URL_RE.find_iter(text) {
        let url = m.as_str().to_string();
        if !seen.contains(&url) {
            seen.push(url);
            if seen.len() >= max {
                break;
            }
        }
    }
    seen
}
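The extract-dedupe-cap behavior above can be sketched without the `regex` dependency; `extract_urls_simple` is an illustrative stand-in, not the crate's API:

```rust
// Regex-free sketch of the behavior above: scan whitespace-separated
// tokens, keep the first `max` unique http(s) URLs in order of appearance.
fn extract_urls_simple(text: &str, max: usize) -> Vec<String> {
    let mut seen: Vec<String> = Vec::new();
    for tok in text.split_whitespace() {
        // Trim the delimiters the real regex excludes: <, >, quotes.
        let tok = tok.trim_matches(|c: char| matches!(c, '<' | '>' | '"' | '\''));
        if !(tok.starts_with("http://") || tok.starts_with("https://")) {
            continue;
        }
        let url = tok.to_string();
        if !seen.contains(&url) {
            seen.push(url);
            if seen.len() >= max {
                break;
            }
        }
    }
    seen
}

fn main() {
    let urls = extract_urls_simple("see https://a.com then https://a.com and https://b.com", 2);
    assert_eq!(urls, vec!["https://a.com", "https://b.com"]); // deduped, capped at 2
}
```

Unlike the regex version, this sketch cannot split a URL out of the middle of a token, which is why the real code reaches for `regex`.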

/// Returns `true` if the URL points to a private/local address that should be
/// blocked for SSRF protection.
pub fn is_ssrf_target(url: &str) -> bool {
    let host = match extract_host(url) {
        Some(h) => h,
        None => return true, // unparseable URLs are rejected
    };

    // Check hostname-based locals
    if host == "localhost"
        || host.ends_with(".localhost")
        || host.ends_with(".local")
        || host == "local"
    {
        return true;
    }

    // Check IP-based private ranges
    if let Ok(ip) = host.parse::<IpAddr>() {
        return is_private_ip(ip);
    }

    false
}

/// Extract the host portion from a URL string.
fn extract_host(url: &str) -> Option<String> {
    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    let authority = rest.split(['/', '?', '#']).next()?;
    if authority.is_empty() {
        return None;
    }
    // Strip port
    let host = if authority.starts_with('[') {
        // IPv6 in brackets — reject for simplicity
        return None;
    } else {
        authority.split(':').next().unwrap_or(authority)
    };
    Some(host.to_lowercase())
}

/// Check if an IP address falls within private/reserved ranges.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback() // 127.0.0.0/8
                || v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
                || v4.is_link_local() // 169.254.0.0/16
                || v4.is_unspecified() // 0.0.0.0
                || v4.is_broadcast() // 255.255.255.255
                || v4.is_multicast() // 224.0.0.0/4
        }
        IpAddr::V6(v6) => {
            v6.is_loopback() // ::1
                || v6.is_unspecified() // ::
                || v6.is_multicast()
                // Check for IPv4-mapped IPv6 addresses
                || v6.to_ipv4_mapped().is_some_and(|v4| {
                    v4.is_loopback()
                        || v4.is_private()
                        || v4.is_link_local()
                        || v4.is_unspecified()
                })
        }
    }
}
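The predicates above come straight from `std::net`; a trimmed-down standalone check shows how they classify typical SSRF targets (`looks_private` is a stand-in name, not the function in this file):

```rust
use std::net::IpAddr;

// Standalone illustration of the std predicates the SSRF check relies on.
fn looks_private(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
        }
        IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified(),
    }
}

fn main() {
    assert!(looks_private("127.0.0.1".parse().unwrap())); // loopback
    assert!(looks_private("10.1.2.3".parse().unwrap())); // RFC 1918
    assert!(looks_private("169.254.169.254".parse().unwrap())); // link-local (cloud metadata)
    assert!(!looks_private("93.184.216.34".parse().unwrap())); // public address
}
```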

/// Extract the `<title>` tag content from HTML.
pub fn extract_title(html: &str) -> Option<String> {
    // Case-insensitive search for <title>...</title>
    let lower = html.to_lowercase();
    let start = lower.find("<title")? + "<title".len();
    // Skip attributes if any (e.g. <title lang="en">)
    let start = lower[start..].find('>')? + start + 1;
    let end = lower[start..].find("</title")? + start;
    let title = lower[start..end].trim().to_string();
    if title.is_empty() {
        None
    } else {
        Some(html_entity_decode_basic(&title))
    }
}

/// Extract the first `max_chars` of visible body text from HTML.
pub fn extract_body_text(html: &str, max_chars: usize) -> String {
    let text = nanohtml2text::html2text(html);
    let trimmed = text.trim();
    if trimmed.len() <= max_chars {
        trimmed.to_string()
    } else {
        let mut result: String = trimmed.chars().take(max_chars).collect();
        result.push_str("...");
        result
    }
}
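The char-based truncation above matters for UTF-8 safety: byte slicing with `&s[..n]` panics if `n` lands inside a multi-byte character, while `chars().take(n)` cannot. A minimal illustration (the helper name is illustrative):

```rust
// Truncate to at most `max_chars` characters, never splitting a codepoint.
fn truncate_chars(s: &str, max_chars: usize) -> String {
    s.chars().take(max_chars).collect()
}

fn main() {
    let s = "Grüße aus Köln";
    // "Grüße" is 5 chars but 7 bytes; &s[..5] would panic mid-'ß'.
    assert_eq!(truncate_chars(s, 5), "Grüße");
    assert_eq!(truncate_chars(s, 100), s); // shorter input passes through
}
```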

/// Basic HTML entity decoding for title content.
fn html_entity_decode_basic(s: &str) -> String {
    s.replace("&amp;", "&")
        .replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&apos;", "'")
}
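A standalone sketch of chained-`replace` entity decoding. One ordering subtlety worth making explicit: decoding `&amp;` last prevents double-decoding, so `&amp;lt;` yields the literal `&lt;` rather than `<` (the function name and reduced entity set here are illustrative):

```rust
// Minimal entity decoder; `&amp;` is decoded last on purpose, so that
// already-escaped sequences like "&amp;lt;" do not collapse twice.
fn decode_basic(s: &str) -> String {
    s.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&amp;", "&")
}

fn main() {
    assert_eq!(decode_basic("Tom &amp; Jerry &lt;3"), "Tom & Jerry <3");
    assert_eq!(decode_basic("&amp;lt;"), "&lt;"); // no double decode
}
```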

/// Summary of a fetched link.
struct LinkSummary {
    title: String,
    snippet: String,
}

/// Fetch a single URL and extract a summary. Returns `None` on any failure.
async fn fetch_link_summary(url: &str, timeout_secs: u64) -> Option<LinkSummary> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(timeout_secs))
        .connect_timeout(Duration::from_secs(5))
        .redirect(reqwest::redirect::Policy::limited(5))
        .user_agent("ZeroClaw/0.1 (link-enricher)")
        .build()
        .ok()?;

    let response = client.get(url).send().await.ok()?;
    if !response.status().is_success() {
        return None;
    }

    // Only process text/html responses
    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_lowercase();

    if !content_type.contains("text/html") && !content_type.is_empty() {
        return None;
    }

    // Read up to 256KB to extract title and snippet
    let max_bytes: usize = 256 * 1024;
    let bytes = response.bytes().await.ok()?;
    let body = if bytes.len() > max_bytes {
        String::from_utf8_lossy(&bytes[..max_bytes]).into_owned()
    } else {
        String::from_utf8_lossy(&bytes).into_owned()
    };

    let title = extract_title(&body).unwrap_or_else(|| "Untitled".to_string());
    let snippet = extract_body_text(&body, 200);

    Some(LinkSummary { title, snippet })
}

/// Enrich a message by prepending link summaries for any URLs found in the text.
///
/// This is the main entry point called from the channel message processing pipeline.
/// If the enricher is disabled or no URLs are found, the original message is returned
/// unchanged.
pub async fn enrich_message(content: &str, config: &LinkEnricherConfig) -> String {
    if !config.enabled || config.max_links == 0 {
        return content.to_string();
    }

    let urls = extract_urls(content, config.max_links);
    if urls.is_empty() {
        return content.to_string();
    }

    // Filter out SSRF targets
    let safe_urls: Vec<&str> = urls
        .iter()
        .filter(|u| !is_ssrf_target(u))
        .map(|u| u.as_str())
        .collect();
    if safe_urls.is_empty() {
        return content.to_string();
    }

    let mut enrichments = Vec::new();
    for url in safe_urls {
        match fetch_link_summary(url, config.timeout_secs).await {
            Some(summary) => {
                enrichments.push(format!("[Link: {} — {}]", summary.title, summary.snippet));
            }
            None => {
                tracing::debug!(url, "Link enricher: failed to fetch or extract summary");
            }
        }
    }

    if enrichments.is_empty() {
        return content.to_string();
    }

    let prefix = enrichments.join("\n");
    format!("{prefix}\n{content}")
}

#[cfg(test)]
mod tests {
    use super::*;

    // ── URL extraction ──────────────────────────────────────────────

    #[test]
    fn extract_urls_finds_http_and_https() {
        let text = "Check https://example.com and http://test.org/page for info";
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com", "http://test.org/page"]);
    }

    #[test]
    fn extract_urls_respects_max() {
        let text = "https://a.com https://b.com https://c.com https://d.com";
        let urls = extract_urls(text, 2);
        assert_eq!(urls.len(), 2);
        assert_eq!(urls[0], "https://a.com");
        assert_eq!(urls[1], "https://b.com");
    }

    #[test]
    fn extract_urls_deduplicates() {
        let text = "Visit https://example.com and https://example.com again";
        let urls = extract_urls(text, 10);
        assert_eq!(urls.len(), 1);
    }

    #[test]
    fn extract_urls_handles_no_urls() {
        let text = "Just a normal message without links";
        let urls = extract_urls(text, 10);
        assert!(urls.is_empty());
    }

    #[test]
    fn extract_urls_stops_at_angle_brackets() {
        let text = "Link: <https://example.com/path> done";
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com/path"]);
    }

    #[test]
    fn extract_urls_stops_at_quotes() {
        let text = r#"href="https://example.com/page" end"#;
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com/page"]);
    }

    // ── SSRF protection ─────────────────────────────────────────────

    #[test]
    fn ssrf_blocks_localhost() {
        assert!(is_ssrf_target("http://localhost/admin"));
        assert!(is_ssrf_target("https://localhost:8080/api"));
    }

    #[test]
    fn ssrf_blocks_loopback_ip() {
        assert!(is_ssrf_target("http://127.0.0.1/secret"));
        assert!(is_ssrf_target("http://127.0.0.2:9090"));
    }

    #[test]
    fn ssrf_blocks_private_10_network() {
        assert!(is_ssrf_target("http://10.0.0.1/internal"));
        assert!(is_ssrf_target("http://10.255.255.255"));
    }

    #[test]
    fn ssrf_blocks_private_172_network() {
        assert!(is_ssrf_target("http://172.16.0.1/admin"));
        assert!(is_ssrf_target("http://172.31.255.255"));
    }

    #[test]
    fn ssrf_blocks_private_192_168_network() {
        assert!(is_ssrf_target("http://192.168.1.1/router"));
        assert!(is_ssrf_target("http://192.168.0.100:3000"));
    }

    #[test]
    fn ssrf_blocks_link_local() {
        assert!(is_ssrf_target("http://169.254.0.1/metadata"));
        assert!(is_ssrf_target("http://169.254.169.254/latest"));
    }

    #[test]
    fn ssrf_blocks_ipv6_loopback() {
        // IPv6 in brackets is rejected by extract_host
        assert!(is_ssrf_target("http://[::1]/admin"));
    }

    #[test]
    fn ssrf_blocks_dot_local() {
        assert!(is_ssrf_target("http://myhost.local/api"));
    }

    #[test]
    fn ssrf_allows_public_urls() {
        assert!(!is_ssrf_target("https://example.com/page"));
        assert!(!is_ssrf_target("https://www.google.com"));
        assert!(!is_ssrf_target("http://93.184.216.34/resource"));
    }

    // ── Title extraction ────────────────────────────────────────────

    #[test]
    fn extract_title_basic() {
        let html = "<html><head><title>My Page Title</title></head><body>Hello</body></html>";
        assert_eq!(extract_title(html), Some("my page title".to_string()));
    }

    #[test]
    fn extract_title_with_entities() {
        let html = "<title>Tom &amp; Jerry&#39;s Page</title>";
        assert_eq!(extract_title(html), Some("tom & jerry's page".to_string()));
    }

    #[test]
    fn extract_title_case_insensitive() {
        let html = "<HTML><HEAD><TITLE>Upper Case</TITLE></HEAD></HTML>";
        assert_eq!(extract_title(html), Some("upper case".to_string()));
    }

    #[test]
    fn extract_title_multibyte_chars_no_panic() {
        // İ (U+0130) lowercases to 2 chars, changing byte length.
        // This must not panic or produce wrong offsets.
        let html = "<title>İstanbul Guide</title>";
        let result = extract_title(html);
        assert!(result.is_some());
        let title = result.unwrap();
        assert!(title.contains("stanbul"));
    }

    #[test]
    fn extract_title_missing() {
        let html = "<html><body>No title here</body></html>";
        assert_eq!(extract_title(html), None);
    }

    #[test]
    fn extract_title_empty() {
        let html = "<title> </title>";
        assert_eq!(extract_title(html), None);
    }

    // ── Body text extraction ────────────────────────────────────────

    #[test]
    fn extract_body_text_strips_html() {
        let html = "<html><body><h1>Header</h1><p>Some content here</p></body></html>";
        let text = extract_body_text(html, 200);
        assert!(text.contains("Header"));
        assert!(text.contains("Some content"));
        assert!(!text.contains("<h1>"));
    }

    #[test]
    fn extract_body_text_truncates() {
        let html = "<p>A very long paragraph that should be truncated to fit within the limit.</p>";
        let text = extract_body_text(html, 20);
        assert!(text.len() <= 25); // 20 chars + "..."
        assert!(text.ends_with("..."));
    }

    // ── Config toggle ───────────────────────────────────────────────

    #[tokio::test]
    async fn enrich_message_disabled_returns_original() {
        let config = LinkEnricherConfig {
            enabled: false,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "Check https://example.com for details";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[tokio::test]
    async fn enrich_message_no_urls_returns_original() {
        let config = LinkEnricherConfig {
            enabled: true,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "No links in this message";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[tokio::test]
    async fn enrich_message_ssrf_urls_returns_original() {
        let config = LinkEnricherConfig {
            enabled: true,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "Try http://127.0.0.1/admin and http://192.168.1.1/router";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[test]
    fn default_config_is_disabled() {
        let config = LinkEnricherConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_links, 3);
        assert_eq!(config.timeout_secs, 10);
    }
}
@@ -268,6 +268,7 @@ impl LinqChannel {
                timestamp,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            });

        messages

@@ -8,6 +8,7 @@ use matrix_sdk::{
    events::reaction::ReactionEventContent,
    events::receipt::ReceiptThread,
    events::relation::{Annotation, Thread},
    events::room::member::StrippedRoomMemberEvent,
    events::room::message::{
        MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
    },
@@ -32,6 +33,7 @@ pub struct MatrixChannel {
    access_token: String,
    room_id: String,
    allowed_users: Vec<String>,
    allowed_rooms: Vec<String>,
    session_owner_hint: Option<String>,
    session_device_id_hint: Option<String>,
    zeroclaw_dir: Option<PathBuf>,
@@ -48,6 +50,7 @@ impl std::fmt::Debug for MatrixChannel {
            .field("homeserver", &self.homeserver)
            .field("room_id", &self.room_id)
            .field("allowed_users", &self.allowed_users)
            .field("allowed_rooms", &self.allowed_rooms)
            .finish_non_exhaustive()
    }
}
@@ -121,7 +124,16 @@ impl MatrixChannel {
        room_id: String,
        allowed_users: Vec<String>,
    ) -> Self {
        Self::new_with_session_hint(homeserver, access_token, room_id, allowed_users, None, None)
        Self::new_full(
            homeserver,
            access_token,
            room_id,
            allowed_users,
            vec![],
            None,
            None,
            None,
        )
    }

    pub fn new_with_session_hint(
@@ -132,11 +144,12 @@ impl MatrixChannel {
        owner_hint: Option<String>,
        device_id_hint: Option<String>,
    ) -> Self {
        Self::new_with_session_hint_and_zeroclaw_dir(
        Self::new_full(
            homeserver,
            access_token,
            room_id,
            allowed_users,
            vec![],
            owner_hint,
            device_id_hint,
            None,
@@ -151,6 +164,28 @@ impl MatrixChannel {
        owner_hint: Option<String>,
        device_id_hint: Option<String>,
        zeroclaw_dir: Option<PathBuf>,
    ) -> Self {
        Self::new_full(
            homeserver,
            access_token,
            room_id,
            allowed_users,
            vec![],
            owner_hint,
            device_id_hint,
            zeroclaw_dir,
        )
    }

    pub fn new_full(
        homeserver: String,
        access_token: String,
        room_id: String,
        allowed_users: Vec<String>,
        allowed_rooms: Vec<String>,
        owner_hint: Option<String>,
        device_id_hint: Option<String>,
        zeroclaw_dir: Option<PathBuf>,
    ) -> Self {
        let homeserver = homeserver.trim_end_matches('/').to_string();
        let access_token = access_token.trim().to_string();
@@ -160,12 +195,18 @@ impl MatrixChannel {
            .map(|user| user.trim().to_string())
            .filter(|user| !user.is_empty())
            .collect();
        let allowed_rooms = allowed_rooms
            .into_iter()
            .map(|room| room.trim().to_string())
            .filter(|room| !room.is_empty())
            .collect();

        Self {
            homeserver,
            access_token,
            room_id,
            allowed_users,
            allowed_rooms,
            session_owner_hint: Self::normalize_optional_field(owner_hint),
            session_device_id_hint: Self::normalize_optional_field(device_id_hint),
            zeroclaw_dir,
@@ -220,6 +261,21 @@ impl MatrixChannel {
        allowed_users.iter().any(|u| u.eq_ignore_ascii_case(sender))
    }

    /// Check whether a room (by its canonical ID) is in the allowed_rooms list.
    /// If allowed_rooms is empty, all rooms are allowed.
    fn is_room_allowed_static(allowed_rooms: &[String], room_id: &str) -> bool {
        if allowed_rooms.is_empty() {
            return true;
        }
        allowed_rooms
            .iter()
            .any(|r| r.eq_ignore_ascii_case(room_id))
    }

    fn is_room_allowed(&self, room_id: &str) -> bool {
        Self::is_room_allowed_static(&self.allowed_rooms, room_id)
    }
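The allowlist semantics above (empty list allows every room, matching is case-insensitive on the room ID) can be restated as a free function for illustration (`room_allowed` is a hypothetical name, not the crate's API):

```rust
// Standalone restatement of the is_room_allowed_static semantics above.
fn room_allowed(allowed_rooms: &[String], room_id: &str) -> bool {
    allowed_rooms.is_empty()
        || allowed_rooms.iter().any(|r| r.eq_ignore_ascii_case(room_id))
}

fn main() {
    assert!(room_allowed(&[], "!any:matrix.org")); // empty list => allow all
    let list = vec!["!ops:matrix.org".to_string()];
    assert!(room_allowed(&list, "!OPS:matrix.org")); // case-insensitive match
    assert!(!room_allowed(&list, "!other:matrix.org"));
}
```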
|
||||
|
||||
fn is_supported_message_type(msgtype: &str) -> bool {
|
||||
matches!(msgtype, "m.text" | "m.notice")
|
||||
}
|
||||
@@ -228,6 +284,10 @@ impl MatrixChannel {
|
||||
!body.trim().is_empty()
|
||||
}
|
||||
|
||||
fn room_matches_target(target_room_id: &str, incoming_room_id: &str) -> bool {
|
||||
target_room_id == incoming_room_id
|
||||
}
|
||||
|
||||
fn cache_event_id(
|
||||
event_id: &str,
|
||||
recent_order: &mut std::collections::VecDeque<String>,
|
||||
@@ -526,8 +586,9 @@ impl MatrixChannel {
|
||||
if client.encryption().backups().are_enabled().await {
|
||||
tracing::info!("Matrix room-key backup is enabled for this device.");
|
||||
} else {
|
||||
let _ = client.encryption().backups().disable().await;
|
||||
tracing::warn!(
|
||||
"Matrix room-key backup is not enabled for this device; `matrix_sdk_crypto::backups` warnings about missing backup keys may appear until recovery is configured."
|
||||
"Matrix room-key backup is not enabled for this device; automatic backup attempts have been disabled to suppress recurring warnings. To enable backups, configure server-side key backup and recovery for this device."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -697,6 +758,7 @@ impl Channel for MatrixChannel {
|
||||
let target_room_for_handler = target_room.clone();
|
||||
let my_user_id_for_handler = my_user_id.clone();
|
||||
let allowed_users_for_handler = self.allowed_users.clone();
|
||||
let allowed_rooms_for_handler = self.allowed_rooms.clone();
|
||||
let dedupe_for_handler = Arc::clone(&recent_event_cache);
|
||||
let homeserver_for_handler = self.homeserver.clone();
|
||||
let access_token_for_handler = self.access_token.clone();
|
||||
@@ -704,18 +766,29 @@ impl Channel for MatrixChannel {
|
||||
|
||||
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||
let tx = tx_handler.clone();
|
||||
let _target_room = target_room_for_handler.clone();
|
||||
let target_room = target_room_for_handler.clone();
|
||||
let my_user_id = my_user_id_for_handler.clone();
|
||||
let allowed_users = allowed_users_for_handler.clone();
|
||||
let allowed_rooms = allowed_rooms_for_handler.clone();
|
||||
let dedupe = Arc::clone(&dedupe_for_handler);
|
||||
let homeserver = homeserver_for_handler.clone();
|
||||
let access_token = access_token_for_handler.clone();
|
||||
let voice_mode = Arc::clone(&voice_mode_for_handler);
|
||||
|
||||
async move {
|
||||
if false
|
||||
/* multi-room: room_id filter disabled */
|
||||
{
|
||||
if !MatrixChannel::room_matches_target(
|
||||
target_room.as_str(),
|
||||
room.room_id().as_str(),
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Room allowlist: skip messages from rooms not in the configured list
|
||||
if !MatrixChannel::is_room_allowed_static(&allowed_rooms, room.room_id().as_ref()) {
|
||||
tracing::debug!(
|
||||
"Matrix: ignoring message from room {} (not in allowed_rooms)",
|
||||
room.room_id()
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -901,12 +974,52 @@ impl Channel for MatrixChannel {
|
||||
.as_secs(),
|
||||
thread_ts: thread_ts.clone(),
|
||||
interruption_scope_id: thread_ts,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Invite handler: auto-accept invites for allowed rooms, auto-reject others
|
||||
let allowed_rooms_for_invite = self.allowed_rooms.clone();
|
||||
client.add_event_handler(move |event: StrippedRoomMemberEvent, room: Room| {
|
||||
let allowed_rooms = allowed_rooms_for_invite.clone();
|
||||
async move {
|
||||
// Only process invite events targeting us
|
||||
if event.content.membership
|
||||
!= matrix_sdk::ruma::events::room::member::MembershipState::Invite
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let room_id_str = room.room_id().to_string();
|
||||
|
||||
if MatrixChannel::is_room_allowed_static(&allowed_rooms, &room_id_str) {
|
||||
// Room is allowed (or no allowlist configured): auto-accept
|
||||
tracing::info!(
|
||||
"Matrix: auto-accepting invite for allowed room {}",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.join().await {
|
||||
tracing::warn!("Matrix: failed to auto-join room {}: {error}", room_id_str);
|
||||
}
|
||||
} else {
|
||||
// Room is NOT in allowlist: auto-reject
|
||||
tracing::info!(
|
||||
"Matrix: auto-rejecting invite for room {} (not in allowed_rooms)",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.leave().await {
|
||||
tracing::warn!(
|
||||
"Matrix: failed to reject invite for room {}: {error}",
|
||||
room_id_str
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});

let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_secs(30));
client
.sync_with_result_callback(sync_settings, |sync_result| {
@@ -1294,6 +1407,22 @@ mod tests {
assert_eq!(value["room"]["timeline"]["limit"], 1);
}

#[test]
fn room_scope_matches_configured_room() {
assert!(MatrixChannel::room_matches_target(
"!ops:matrix.org",
"!ops:matrix.org"
));
}

#[test]
fn room_scope_rejects_other_rooms() {
assert!(!MatrixChannel::room_matches_target(
"!ops:matrix.org",
"!other:matrix.org"
));
}

#[test]
fn event_id_cache_deduplicates_and_evicts_old_entries() {
let mut recent_order = std::collections::VecDeque::new();
@@ -1549,4 +1678,79 @@ mod tests {
let resp: SyncResponse = serde_json::from_str(json).unwrap();
assert!(resp.rooms.join.is_empty());
}

#[test]
fn empty_allowed_rooms_permits_all() {
let ch = make_channel();
assert!(ch.is_room_allowed("!any:matrix.org"));
assert!(ch.is_room_allowed("!other:evil.org"));
}

#[test]
fn allowed_rooms_filters_by_id() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["@user:m".to_string()],
vec!["!allowed:matrix.org".to_string()],
None,
None,
None,
);
assert!(ch.is_room_allowed("!allowed:matrix.org"));
assert!(!ch.is_room_allowed("!forbidden:matrix.org"));
}

#[test]
fn allowed_rooms_supports_aliases() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["@user:m".to_string()],
vec![
"#ops:matrix.org".to_string(),
"!direct:matrix.org".to_string(),
],
None,
None,
None,
);
assert!(ch.is_room_allowed("!direct:matrix.org"));
assert!(ch.is_room_allowed("#ops:matrix.org"));
assert!(!ch.is_room_allowed("!other:matrix.org"));
}

#[test]
fn allowed_rooms_case_insensitive() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
vec!["!Room:Matrix.org".to_string()],
None,
None,
None,
);
assert!(ch.is_room_allowed("!room:matrix.org"));
assert!(ch.is_room_allowed("!ROOM:MATRIX.ORG"));
}

#[test]
fn allowed_rooms_trims_whitespace() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
vec![" !room:matrix.org ".to_string(), " ".to_string()],
None,
None,
None,
);
assert_eq!(ch.allowed_rooms.len(), 1);
assert!(ch.is_room_allowed("!room:matrix.org"));
}
}

@@ -2,6 +2,9 @@ use super::traits::{Channel, ChannelMessage, SendMessage};
use anyhow::{bail, Result};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;

const MAX_MATTERMOST_AUDIO_BYTES: u64 = 25 * 1024 * 1024;

/// Mattermost channel — polls channel posts via REST API v4.
/// Mattermost is API-compatible with many Slack patterns but uses a dedicated v4 structure.
@@ -19,6 +22,8 @@ pub struct MattermostChannel {
typing_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
transcription: Option<crate::config::TranscriptionConfig>,
transcription_manager: Option<Arc<super::transcription::TranscriptionManager>>,
}

impl MattermostChannel {
@@ -41,6 +46,8 @@ impl MattermostChannel {
mention_only,
typing_handle: Mutex::new(None),
proxy_url: None,
transcription: None,
transcription_manager: None,
}
}

@@ -50,6 +57,24 @@ impl MattermostChannel {
self
}

pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
if !config.enabled {
return self;
}
match super::transcription::TranscriptionManager::new(&config) {
Ok(m) => {
self.transcription_manager = Some(Arc::new(m));
self.transcription = Some(config);
}
Err(e) => {
tracing::warn!(
"transcription manager init failed, voice transcription disabled: {e}"
);
}
}
self
}

fn http_client(&self) -> reqwest::Client {
crate::config::build_channel_proxy_client("channel.mattermost", self.proxy_url.as_deref())
}
@@ -90,6 +115,91 @@ impl MattermostChannel {
.to_string();
(id, username)
}

async fn try_transcribe_audio_attachment(&self, post: &serde_json::Value) -> Option<String> {
let config = self.transcription.as_ref()?;
let manager = self.transcription_manager.as_deref()?;

let files = post
.get("metadata")
.and_then(|m| m.get("files"))
.and_then(|f| f.as_array())?;

let audio_file = files.iter().find(|f| is_audio_file(f))?;

if let Some(duration_ms) = audio_file.get("duration").and_then(|d| d.as_u64()) {
let duration_secs = duration_ms / 1000;
if duration_secs > config.max_duration_secs as u64 {
tracing::debug!(
duration_secs,
max = config.max_duration_secs,
"Mattermost audio attachment exceeds max duration, skipping"
);
return None;
}
}

let file_id = audio_file.get("id").and_then(|i| i.as_str())?;
let file_name = audio_file
.get("name")
.and_then(|n| n.as_str())
.unwrap_or("audio");

let response = match self
.http_client()
.get(format!("{}/api/v4/files/{}", self.base_url, file_id))
.bearer_auth(&self.bot_token)
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::warn!("Mattermost: audio download failed for {file_id}: {e}");
return None;
}
};

if !response.status().is_success() {
tracing::warn!(
"Mattermost: audio download returned {}: {file_id}",
response.status()
);
return None;
}

if let Some(content_length) = response.content_length() {
if content_length > MAX_MATTERMOST_AUDIO_BYTES {
tracing::warn!(
"Mattermost: audio file too large ({content_length} bytes): {file_id}"
);
return None;
}
}

let bytes = match response.bytes().await {
Ok(b) => b,
Err(e) => {
tracing::warn!("Mattermost: failed to read audio bytes for {file_id}: {e}");
return None;
}
};

match manager.transcribe(&bytes, file_name).await {
Ok(text) => {
let trimmed = text.trim();
if trimmed.is_empty() {
tracing::info!("Mattermost: transcription returned empty text, skipping");
None
} else {
Some(format!("[Voice] {trimmed}"))
}
}
Err(e) => {
tracing::warn!("Mattermost audio transcription failed: {e}");
None
}
}
}
}

#[async_trait]
@@ -188,21 +298,35 @@ impl Channel for MattermostChannel {
let mut post_list: Vec<_> = posts.values().collect();
post_list.sort_by_key(|p| p.get("create_at").and_then(|c| c.as_i64()).unwrap_or(0));

let last_create_at_before_this_batch = last_create_at;
for post in post_list {
let msg = self.parse_mattermost_post(
post,
&bot_user_id,
&bot_username,
last_create_at,
&channel_id,
);
let create_at = post
.get("create_at")
.and_then(|c| c.as_i64())
.unwrap_or(last_create_at);
last_create_at = last_create_at.max(create_at);

if let Some(channel_msg) = msg {
let effective_text = if post
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("")
.trim()
.is_empty()
&& post_has_audio_attachment(post)
{
self.try_transcribe_audio_attachment(post).await
} else {
None
};

if let Some(channel_msg) = self.parse_mattermost_post(
post,
&bot_user_id,
&bot_username,
last_create_at_before_this_batch,
&channel_id,
effective_text.as_deref(),
) {
if tx.send(channel_msg).await.is_err() {
return Ok(());
}
@@ -286,6 +410,7 @@ impl MattermostChannel {
bot_username: &str,
last_create_at: i64,
channel_id: &str,
injected_text: Option<&str>,
) -> Option<ChannelMessage> {
let id = post.get("id").and_then(|i| i.as_str()).unwrap_or("");
let user_id = post.get("user_id").and_then(|u| u.as_str()).unwrap_or("");
@@ -293,10 +418,16 @@ impl MattermostChannel {
let create_at = post.get("create_at").and_then(|c| c.as_i64()).unwrap_or(0);
let root_id = post.get("root_id").and_then(|r| r.as_str()).unwrap_or("");

if user_id == bot_user_id || create_at <= last_create_at || text.is_empty() {
if user_id == bot_user_id || create_at <= last_create_at {
return None;
}

let effective_text = if text.is_empty() {
injected_text?
} else {
text
};

if !self.is_user_allowed(user_id) {
tracing::warn!("Mattermost: ignoring message from unauthorized user: {user_id}");
return None;
@@ -304,10 +435,11 @@ impl MattermostChannel {

// mention_only filtering: skip messages that don't @-mention the bot.
let content = if self.mention_only {
let normalized = normalize_mattermost_content(text, bot_user_id, bot_username, post);
let normalized =
normalize_mattermost_content(effective_text, bot_user_id, bot_username, post);
normalized?
} else {
text.to_string()
effective_text.to_string()
};

// Reply routing depends on thread_replies config:
@@ -332,10 +464,32 @@ impl MattermostChannel {
timestamp: (create_at / 1000) as u64,
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
})
}
}

fn post_has_audio_attachment(post: &serde_json::Value) -> bool {
let files = post
.get("metadata")
.and_then(|m| m.get("files"))
.and_then(|f| f.as_array());
let Some(files) = files else { return false };
files.iter().any(is_audio_file)
}

fn is_audio_file(file: &serde_json::Value) -> bool {
let mime = file.get("mime_type").and_then(|m| m.as_str()).unwrap_or("");
if mime.starts_with("audio/") {
return true;
}
let ext = file.get("extension").and_then(|e| e.as_str()).unwrap_or("");
matches!(
ext.to_ascii_lowercase().as_str(),
"ogg" | "mp3" | "m4a" | "wav" | "opus" | "flac"
)
}

/// Check whether a Mattermost post contains an @-mention of the bot.
///
/// Checks two sources:
@@ -518,7 +672,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789")
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
)
.unwrap();
assert_eq!(msg.sender, "user456");
assert_eq!(msg.content, "hello world");
@@ -537,7 +698,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789")
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
)
.unwrap();
assert_eq!(msg.reply_target, "chan789:post123"); // Threaded reply
}
@@ -554,7 +722,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789")
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
)
.unwrap();
assert_eq!(msg.reply_target, "chan789:root789"); // Stays in the thread
}
@@ -569,8 +744,14 @@ mod tests {
"create_at": 1_600_000_000_000_i64
});

let msg =
ch.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789");
let msg = ch.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
);
assert!(msg.is_none());
}

@@ -584,8 +765,14 @@ mod tests {
"create_at": 1_400_000_000_000_i64
});

let msg =
ch.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789");
let msg = ch.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
);
assert!(msg.is_none());
}

@@ -601,7 +788,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789")
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
)
.unwrap();
assert_eq!(msg.reply_target, "chan789"); // No thread suffix
}
@@ -619,7 +813,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "botname", 1_500_000_000_000_i64, "chan789")
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
)
.unwrap();
assert_eq!(msg.reply_target, "chan789:root789"); // Stays in existing thread
}
@@ -637,8 +838,14 @@ mod tests {
"root_id": ""
});

let msg =
ch.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1");
let msg = ch.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
);
assert!(msg.is_none());
}

@@ -654,7 +861,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
assert_eq!(msg.content, "what is the weather?");
}
@@ -671,7 +885,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
assert_eq!(msg.content, "run status");
}
@@ -687,8 +908,14 @@ mod tests {
"root_id": ""
});

let msg =
ch.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1");
let msg = ch.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
);
assert!(msg.is_none());
}

@@ -704,7 +931,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
assert_eq!(msg.content, "hello");
}
@@ -725,7 +959,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
// Content is preserved as-is since no @username was in the text to strip.
assert_eq!(msg.content, "hey check this out");
@@ -743,8 +984,14 @@ mod tests {
"root_id": ""
});

let msg =
ch.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1");
let msg = ch.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
);
assert!(msg.is_none());
}

@@ -760,7 +1007,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
assert_eq!(msg.content, "hey how are you?");
}
@@ -778,7 +1032,14 @@ mod tests {
});

let msg = ch
.parse_mattermost_post(&post, "bot123", "mybot", 1_500_000_000_000_i64, "chan1")
.parse_mattermost_post(
&post,
"bot123",
"mybot",
1_500_000_000_000_i64,
"chan1",
None,
)
.unwrap();
assert_eq!(msg.content, "no mention here");
}
@@ -925,4 +1186,333 @@ mod tests {
normalize_mattermost_content("@mybot hello @mybotx world", "bot123", "mybot", &post);
assert_eq!(result.as_deref(), Some("hello @mybotx world"));
}

// ── Transcription tests ───────────────────────────────────────

#[test]
fn mattermost_manager_none_when_transcription_not_configured() {
let ch = make_channel(vec!["*".into()], false);
assert!(ch.transcription_manager.is_none());
}

#[test]
fn mattermost_manager_some_when_valid_config() {
let ch = make_channel(vec!["*".into()], false).with_transcription(
crate::config::TranscriptionConfig {
enabled: true,
default_provider: "groq".to_string(),
api_key: Some("test_key".to_string()),
api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
model: "whisper-large-v3".to_string(),
language: None,
initial_prompt: None,
max_duration_secs: 600,
openai: None,
deepgram: None,
assemblyai: None,
google: None,
local_whisper: None,
},
);
assert!(ch.transcription_manager.is_some());
}

#[test]
fn mattermost_manager_none_and_warn_on_init_failure() {
let ch = make_channel(vec!["*".into()], false).with_transcription(
crate::config::TranscriptionConfig {
enabled: true,
default_provider: "groq".to_string(),
api_key: Some(String::new()),
api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
model: "whisper-large-v3".to_string(),
language: None,
initial_prompt: None,
max_duration_secs: 600,
openai: None,
deepgram: None,
assemblyai: None,
google: None,
local_whisper: None,
},
);
assert!(ch.transcription_manager.is_none());
}

#[test]
fn mattermost_post_has_audio_attachment_true_for_audio_mime() {
let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "audio/ogg",
"name": "voice.ogg"
}
]
}
});
assert!(post_has_audio_attachment(&post));
}

#[test]
fn mattermost_post_has_audio_attachment_true_for_audio_ext() {
let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "application/octet-stream",
"extension": "ogg"
}
]
}
});
assert!(post_has_audio_attachment(&post));
}

#[test]
fn mattermost_post_has_audio_attachment_false_for_image() {
let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "image/png",
"name": "screenshot.png"
}
]
}
});
assert!(!post_has_audio_attachment(&post));
}

#[test]
fn mattermost_post_has_audio_attachment_false_when_no_files() {
let post = json!({
"metadata": {}
});
assert!(!post_has_audio_attachment(&post));
}

#[test]
fn mattermost_parse_post_uses_injected_text() {
let ch = make_channel(vec!["*".into()], true);
let post = json!({
"id": "post123",
"user_id": "user456",
"message": "",
"create_at": 1_600_000_000_000_i64,
"root_id": ""
});

let msg = ch
.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
Some("transcript text"),
)
.unwrap();
assert_eq!(msg.content, "transcript text");
}

#[test]
fn mattermost_parse_post_rejects_empty_message_without_injected() {
let ch = make_channel(vec!["*".into()], true);
let post = json!({
"id": "post123",
"user_id": "user456",
"message": "",
"create_at": 1_600_000_000_000_i64,
"root_id": ""
});

let msg = ch.parse_mattermost_post(
&post,
"bot123",
"botname",
1_500_000_000_000_i64,
"chan789",
None,
);
assert!(msg.is_none());
}

#[tokio::test]
async fn mattermost_transcribe_skips_when_manager_none() {
let ch = make_channel(vec!["*".into()], false);
let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "audio/ogg",
"name": "voice.ogg"
}
]
}
});
let result = ch.try_transcribe_audio_attachment(&post).await;
assert!(result.is_none());
}

#[tokio::test]
async fn mattermost_transcribe_skips_over_duration_limit() {
let ch = make_channel(vec!["*".into()], false).with_transcription(
crate::config::TranscriptionConfig {
enabled: true,
default_provider: "groq".to_string(),
api_key: Some("test_key".to_string()),
api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
model: "whisper-large-v3".to_string(),
language: None,
initial_prompt: None,
max_duration_secs: 3600,
openai: None,
deepgram: None,
assemblyai: None,
google: None,
local_whisper: None,
},
);

let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "audio/ogg",
"name": "voice.ogg",
"duration": 7_200_000_u64
}
]
}
});

let result = ch.try_transcribe_audio_attachment(&post).await;
assert!(result.is_none());
}

#[cfg(test)]
mod http_tests {
use super::*;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

#[tokio::test]
async fn mattermost_audio_routes_through_local_whisper() {
let mock_server = MockServer::start().await;

Mock::given(method("GET"))
.and(path("/api/v4/files/file1"))
.respond_with(ResponseTemplate::new(200).set_body_bytes(b"audio bytes"))
.mount(&mock_server)
.await;

Mock::given(method("POST"))
.and(path("/v1/audio/transcriptions"))
.respond_with(
ResponseTemplate::new(200).set_body_json(json!({"text": "test transcript"})),
)
.mount(&mock_server)
.await;

let whisper_url = format!("{}/v1/audio/transcriptions", mock_server.uri());
let ch = MattermostChannel::new(
mock_server.uri(),
"test_token".to_string(),
None,
vec!["*".into()],
false,
false,
)
.with_transcription(crate::config::TranscriptionConfig {
enabled: true,
default_provider: "local_whisper".to_string(),
api_key: None,
api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
model: "whisper-large-v3".to_string(),
language: None,
initial_prompt: None,
max_duration_secs: 600,
openai: None,
deepgram: None,
assemblyai: None,
google: None,
local_whisper: Some(crate::config::LocalWhisperConfig {
url: whisper_url,
bearer_token: "test_token".to_string(),
max_audio_bytes: 25_000_000,
timeout_secs: 300,
}),
});

let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "audio/ogg",
"name": "voice.ogg"
}
]
}
});

let result = ch.try_transcribe_audio_attachment(&post).await;
assert_eq!(result.as_deref(), Some("[Voice] test transcript"));
}

#[tokio::test]
async fn mattermost_audio_skips_non_audio_attachment() {
let mock_server = MockServer::start().await;

let ch = MattermostChannel::new(
mock_server.uri(),
"test_token".to_string(),
None,
vec!["*".into()],
false,
false,
)
.with_transcription(crate::config::TranscriptionConfig {
enabled: true,
default_provider: "local_whisper".to_string(),
api_key: None,
api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
model: "whisper-large-v3".to_string(),
language: None,
initial_prompt: None,
max_duration_secs: 600,
openai: None,
deepgram: None,
assemblyai: None,
google: None,
local_whisper: Some(crate::config::LocalWhisperConfig {
url: mock_server.uri(),
bearer_token: "test_token".to_string(),
max_audio_bytes: 25_000_000,
timeout_secs: 300,
}),
});

let post = json!({
"metadata": {
"files": [
{
"id": "file1",
"mime_type": "image/png",
"name": "screenshot.png"
}
]
}
});

let result = ch.try_transcribe_audio_attachment(&post).await;
assert!(result.is_none());
}
}
}

@@ -0,0 +1,409 @@
//! Automatic media understanding pipeline for inbound channel messages.
//!
//! Pre-processes media attachments (audio, images, video) before the agent sees
//! the message, enriching the text with human-readable annotations:
//!
//! - **Audio**: transcribed via the existing [`super::transcription`] infrastructure,
//!   prepended as `[Audio transcription: ...]`.
//! - **Images**: when a vision-capable provider is active, described as `[Image: <description>]`.
//!   Falls back to `[Image: attached]` when vision is unavailable.
//! - **Video**: summarised as `[Video summary: ...]` when an API is available,
//!   otherwise `[Video: attached]`.
//!
//! The pipeline is **opt-in** via `[media_pipeline] enabled = true` in config.

use crate::config::{MediaPipelineConfig, TranscriptionConfig};

/// Classifies an attachment by MIME type or file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MediaKind {
Audio,
Image,
Video,
Unknown,
}

/// A single media attachment on an inbound message.
#[derive(Debug, Clone)]
pub struct MediaAttachment {
/// Original file name (e.g. `voice.ogg`, `photo.jpg`).
pub file_name: String,
/// Raw bytes of the attachment.
pub data: Vec<u8>,
/// MIME type if known (e.g. `audio/ogg`, `image/jpeg`).
pub mime_type: Option<String>,
}

impl MediaAttachment {
/// Classify this attachment into a [`MediaKind`].
pub fn kind(&self) -> MediaKind {
// Try MIME type first.
if let Some(ref mime) = self.mime_type {
let lower = mime.to_ascii_lowercase();
if lower.starts_with("audio/") {
return MediaKind::Audio;
}
if lower.starts_with("image/") {
return MediaKind::Image;
}
if lower.starts_with("video/") {
return MediaKind::Video;
}
}

// Fall back to file extension.
let ext = self
.file_name
.rsplit_once('.')
.map(|(_, e)| e.to_ascii_lowercase())
.unwrap_or_default();

match ext.as_str() {
"flac" | "mp3" | "mpeg" | "mpga" | "m4a" | "ogg" | "oga" | "opus" | "wav" | "webm" => {
MediaKind::Audio
}
"png" | "jpg" | "jpeg" | "gif" | "bmp" | "webp" | "heic" | "tiff" | "svg" => {
MediaKind::Image
}
"mp4" | "mkv" | "avi" | "mov" | "wmv" | "flv" => MediaKind::Video,
_ => MediaKind::Unknown,
}
}
}

/// The media understanding pipeline.
///
/// Consumes a message's text and attachments, returning enriched text with
/// media annotations prepended.
pub struct MediaPipeline<'a> {
config: &'a MediaPipelineConfig,
transcription_config: &'a TranscriptionConfig,
vision_available: bool,
}

impl<'a> MediaPipeline<'a> {
/// Create a new pipeline. `vision_available` indicates whether the current
/// provider supports vision (image description).
pub fn new(
config: &'a MediaPipelineConfig,
transcription_config: &'a TranscriptionConfig,
vision_available: bool,
) -> Self {
Self {
config,
transcription_config,
vision_available,
}
}

/// Process a message's attachments and return enriched text.
///
/// If the pipeline is disabled via config, returns `original_text` unchanged.
pub async fn process(&self, original_text: &str, attachments: &[MediaAttachment]) -> String {
if !self.config.enabled || attachments.is_empty() {
return original_text.to_string();
}

let mut annotations = Vec::new();

for attachment in attachments {
match attachment.kind() {
MediaKind::Audio if self.config.transcribe_audio => {
let annotation = self.process_audio(attachment).await;
annotations.push(annotation);
}
MediaKind::Image if self.config.describe_images => {
let annotation = self.process_image(attachment);
annotations.push(annotation);
}
MediaKind::Video if self.config.summarize_video => {
let annotation = self.process_video(attachment);
annotations.push(annotation);
}
_ => {}
}
}

if annotations.is_empty() {
return original_text.to_string();
}

let mut enriched = String::with_capacity(
annotations.iter().map(|a| a.len() + 1).sum::<usize>() + original_text.len() + 2,
);

for annotation in &annotations {
enriched.push_str(annotation);
enriched.push('\n');
}

if !original_text.is_empty() {
enriched.push('\n');
enriched.push_str(original_text);
}

enriched.trim().to_string()
}
|
||||

    /// Transcribe an audio attachment using the existing transcription infra.
    async fn process_audio(&self, attachment: &MediaAttachment) -> String {
        if !self.transcription_config.enabled {
            return "[Audio: attached]".to_string();
        }

        match super::transcription::transcribe_audio(
            attachment.data.clone(),
            &attachment.file_name,
            self.transcription_config,
        )
        .await
        {
            Ok(text) => {
                let trimmed = text.trim();
                if trimmed.is_empty() {
                    "[Audio transcription: (empty)]".to_string()
                } else {
                    format!("[Audio transcription: {trimmed}]")
                }
            }
            Err(err) => {
                tracing::warn!(
                    file = %attachment.file_name,
                    error = %err,
                    "Media pipeline: audio transcription failed"
                );
                "[Audio: transcription failed]".to_string()
            }
        }
    }

    /// Describe an image attachment.
    ///
    /// When vision is available, the image will be passed through to the
    /// provider as an `[IMAGE:]` marker and described by the model in the
    /// normal flow. Here we only add a placeholder annotation so the agent
    /// knows an image is present.
    fn process_image(&self, attachment: &MediaAttachment) -> String {
        if self.vision_available {
            format!(
                "[Image: {} attached, will be processed by vision model]",
                attachment.file_name
            )
        } else {
            format!("[Image: {} attached]", attachment.file_name)
        }
    }

    /// Summarize a video attachment.
    ///
    /// Video analysis requires external APIs not currently integrated.
    /// For now we add a placeholder annotation.
    fn process_video(&self, attachment: &MediaAttachment) -> String {
        format!("[Video: {} attached]", attachment.file_name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn default_pipeline_config(enabled: bool) -> MediaPipelineConfig {
        MediaPipelineConfig {
            enabled,
            transcribe_audio: true,
            describe_images: true,
            summarize_video: true,
        }
    }

    fn sample_audio() -> MediaAttachment {
        MediaAttachment {
            file_name: "voice.ogg".to_string(),
            data: vec![0u8; 100],
            mime_type: Some("audio/ogg".to_string()),
        }
    }

    fn sample_image() -> MediaAttachment {
        MediaAttachment {
            file_name: "photo.jpg".to_string(),
            data: vec![0u8; 50],
            mime_type: Some("image/jpeg".to_string()),
        }
    }

    fn sample_video() -> MediaAttachment {
        MediaAttachment {
            file_name: "clip.mp4".to_string(),
            data: vec![0u8; 200],
            mime_type: Some("video/mp4".to_string()),
        }
    }

    #[test]
    fn media_kind_from_mime() {
        let audio = MediaAttachment {
            file_name: "file".to_string(),
            data: vec![],
            mime_type: Some("audio/ogg".to_string()),
        };
        assert_eq!(audio.kind(), MediaKind::Audio);

        let image = MediaAttachment {
            file_name: "file".to_string(),
            data: vec![],
            mime_type: Some("image/png".to_string()),
        };
        assert_eq!(image.kind(), MediaKind::Image);

        let video = MediaAttachment {
            file_name: "file".to_string(),
            data: vec![],
            mime_type: Some("video/mp4".to_string()),
        };
        assert_eq!(video.kind(), MediaKind::Video);
    }

    #[test]
    fn media_kind_from_extension() {
        let audio = MediaAttachment {
            file_name: "voice.ogg".to_string(),
            data: vec![],
            mime_type: None,
        };
        assert_eq!(audio.kind(), MediaKind::Audio);

        let image = MediaAttachment {
            file_name: "photo.png".to_string(),
            data: vec![],
            mime_type: None,
        };
        assert_eq!(image.kind(), MediaKind::Image);

        let video = MediaAttachment {
            file_name: "clip.mp4".to_string(),
            data: vec![],
            mime_type: None,
        };
        assert_eq!(video.kind(), MediaKind::Video);

        let unknown = MediaAttachment {
            file_name: "data.bin".to_string(),
            data: vec![],
            mime_type: None,
        };
        assert_eq!(unknown.kind(), MediaKind::Unknown);
    }

    #[tokio::test]
    async fn disabled_pipeline_returns_original_text() {
        let config = default_pipeline_config(false);
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let result = pipeline.process("hello", &[sample_audio()]).await;
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn empty_attachments_returns_original_text() {
        let config = default_pipeline_config(true);
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let result = pipeline.process("hello", &[]).await;
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn image_annotation_with_vision() {
        let config = default_pipeline_config(true);
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, true);

        let result = pipeline.process("check this", &[sample_image()]).await;
        assert!(
            result.contains("[Image: photo.jpg attached, will be processed by vision model]"),
            "expected vision annotation, got: {result}"
        );
        assert!(result.contains("check this"));
    }

    #[tokio::test]
    async fn image_annotation_without_vision() {
        let config = default_pipeline_config(true);
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let result = pipeline.process("check this", &[sample_image()]).await;
        assert!(
            result.contains("[Image: photo.jpg attached]"),
            "expected basic image annotation, got: {result}"
        );
    }

    #[tokio::test]
    async fn video_annotation() {
        let config = default_pipeline_config(true);
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let result = pipeline.process("watch", &[sample_video()]).await;
        assert!(
            result.contains("[Video: clip.mp4 attached]"),
            "expected video annotation, got: {result}"
        );
    }

    #[tokio::test]
    async fn audio_without_transcription_enabled() {
        let config = default_pipeline_config(true);
        let mut tc = TranscriptionConfig::default();
        tc.enabled = false;
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let result = pipeline.process("", &[sample_audio()]).await;
        assert_eq!(result, "[Audio: attached]");
    }

    #[tokio::test]
    async fn multiple_attachments_produce_multiple_annotations() {
        let config = default_pipeline_config(true);
        let mut tc = TranscriptionConfig::default();
        tc.enabled = false;
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let attachments = vec![sample_audio(), sample_image(), sample_video()];
        let result = pipeline.process("context", &attachments).await;

        assert!(
            result.contains("[Audio: attached]"),
            "missing audio annotation"
        );
        assert!(
            result.contains("[Image: photo.jpg attached]"),
            "missing image annotation"
        );
        assert!(
            result.contains("[Video: clip.mp4 attached]"),
            "missing video annotation"
        );
        assert!(result.contains("context"), "missing original text");
    }

    #[tokio::test]
    async fn disabled_sub_features_skip_processing() {
        let config = MediaPipelineConfig {
            enabled: true,
            transcribe_audio: false,
            describe_images: false,
            summarize_video: false,
        };
        let tc = TranscriptionConfig::default();
        let pipeline = MediaPipeline::new(&config, &tc, false);

        let attachments = vec![sample_audio(), sample_image(), sample_video()];
        let result = pipeline.process("hello", &attachments).await;
        assert_eq!(result, "hello");
    }
}
@@ -199,6 +199,7 @@ impl Channel for MochatChannel {
                .as_secs(),
            thread_ts: None,
            interruption_scope_id: None,
            attachments: vec![],
        };

        if tx.send(channel_msg).await.is_err() {

@@ -26,10 +26,12 @@ pub mod imessage;
pub mod irc;
#[cfg(feature = "channel-lark")]
pub mod lark;
pub mod link_enricher;
pub mod linq;
#[cfg(feature = "channel-matrix")]
pub mod matrix;
pub mod mattermost;
pub mod media_pipeline;
pub mod mochat;
pub mod nextcloud_talk;
#[cfg(feature = "channel-nostr")]
@@ -366,6 +368,8 @@ struct ChannelRuntimeContext {
    message_timeout_secs: u64,
    interrupt_on_new_message: InterruptOnNewMessageConfig,
    multimodal: crate::config::MultimodalConfig,
    media_pipeline: crate::config::MediaPipelineConfig,
    transcription_config: crate::config::TranscriptionConfig,
    hooks: Option<Arc<crate::hooks::HookRunner>>,
    non_cli_excluded_tools: Arc<Vec<String>>,
    autonomy_level: AutonomyLevel,
@@ -721,10 +725,6 @@ fn supports_runtime_model_switch(channel_name: &str) -> bool {
}

fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
    if !supports_runtime_model_switch(channel_name) {
        return None;
    }

    let trimmed = content.trim();
    if !trimmed.starts_with('/') {
        return None;
@@ -739,7 +739,10 @@ fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRun
        .to_ascii_lowercase();

    match base_command.as_str() {
        "/models" => {
        // `/new` is available on every channel — no model-switch gate.
        "/new" => Some(ChannelRuntimeCommand::NewSession),
        // Model/provider switching is channel-gated.
        "/models" if supports_runtime_model_switch(channel_name) => {
            if let Some(provider) = parts.next() {
                Some(ChannelRuntimeCommand::SetProvider(
                    provider.trim().to_string(),
@@ -748,7 +751,7 @@ fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRun
                Some(ChannelRuntimeCommand::ShowProviders)
            }
        }
        "/model" => {
        "/model" if supports_runtime_model_switch(channel_name) => {
            let model = parts.collect::<Vec<_>>().join(" ").trim().to_string();
            if model.is_empty() {
                Some(ChannelRuntimeCommand::ShowModel)
@@ -756,7 +759,6 @@ fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRun
                Some(ChannelRuntimeCommand::SetModel(model))
            }
        }
        "/new" => Some(ChannelRuntimeCommand::NewSession),
        _ => None,
    }
}
@@ -2066,6 +2068,36 @@ async fn process_channel_message(
        msg
    };

    // ── Media pipeline: enrich inbound message with media annotations ──
    if ctx.media_pipeline.enabled && !msg.attachments.is_empty() {
        let vision = ctx.provider.supports_vision();
        let pipeline = media_pipeline::MediaPipeline::new(
            &ctx.media_pipeline,
            &ctx.transcription_config,
            vision,
        );
        msg.content = Box::pin(pipeline.process(&msg.content, &msg.attachments)).await;
    }

    // ── Link enricher: prepend URL summaries before agent sees the message ──
    let le_config = &ctx.prompt_config.link_enricher;
    if le_config.enabled {
        let enricher_cfg = link_enricher::LinkEnricherConfig {
            enabled: le_config.enabled,
            max_links: le_config.max_links,
            timeout_secs: le_config.timeout_secs,
        };
        let enriched = link_enricher::enrich_message(&msg.content, &enricher_cfg).await;
        if enriched != msg.content {
            tracing::info!(
                channel = %msg.channel,
                sender = %msg.sender,
                "Link enricher: prepended URL summaries to message"
            );
            msg.content = enriched;
        }
    }

    let target_channel = ctx
        .channels_by_name
        .get(&msg.channel)
@@ -2524,9 +2556,12 @@ async fn process_channel_message(
        break loop_result;
    };

    tracing::debug!("Post-loop: dropping delta_tx and awaiting draft updater");
    drop(delta_tx);
    if let Some(handle) = draft_updater {
        let _ = handle.await;
    }
    tracing::debug!("Post-loop: draft updater completed");

    // Thread the final reply only if tools were used (multi-message response)
    if notify_observer_flag.tools_used.load(Ordering::Relaxed) && msg.channel != "cli" {
@@ -3670,13 +3705,16 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
                .discord
                .as_ref()
                .context("Discord channel is not configured")?;
            Ok(Arc::new(DiscordChannel::new(
                dc.bot_token.clone(),
                dc.guild_id.clone(),
                dc.allowed_users.clone(),
                dc.listen_to_bots,
                dc.mention_only,
            )))
            Ok(Arc::new(
                DiscordChannel::new(
                    dc.bot_token.clone(),
                    dc.guild_id.clone(),
                    dc.allowed_users.clone(),
                    dc.listen_to_bots,
                    dc.mention_only,
                )
                .with_transcription(config.transcription.clone()),
            ))
        }
        "slack" => {
            let sl = config
@@ -3692,7 +3730,8 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
                    Vec::new(),
                    sl.allowed_users.clone(),
                )
                .with_workspace_dir(config.workspace_dir.clone()),
                .with_workspace_dir(config.workspace_dir.clone())
                .with_transcription(config.transcription.clone()),
            ))
        }
        other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
@@ -3778,7 +3817,8 @@ fn collect_configured_channels(
                    dc.listen_to_bots,
                    dc.mention_only,
                )
                .with_proxy_url(dc.proxy_url.clone()),
                .with_proxy_url(dc.proxy_url.clone())
                .with_transcription(config.transcription.clone()),
            ),
        });
    }
@@ -3822,7 +3862,8 @@ fn collect_configured_channels(
                .with_thread_replies(sl.thread_replies.unwrap_or(true))
                .with_group_reply_policy(sl.mention_only, Vec::new())
                .with_workspace_dir(config.workspace_dir.clone())
                .with_proxy_url(sl.proxy_url.clone()),
                .with_proxy_url(sl.proxy_url.clone())
                .with_transcription(config.transcription.clone()),
            ),
        });
    }
@@ -3839,7 +3880,8 @@ fn collect_configured_channels(
                    mm.thread_replies.unwrap_or(true),
                    mm.mention_only.unwrap_or(false),
                )
                .with_proxy_url(mm.proxy_url.clone()),
                .with_proxy_url(mm.proxy_url.clone())
                .with_transcription(config.transcription.clone()),
            ),
        });
    }
@@ -3855,11 +3897,12 @@ fn collect_configured_channels(
    if let Some(ref mx) = config.channels_config.matrix {
        channels.push(ConfiguredChannel {
            display_name: "Matrix",
            channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir(
            channel: Arc::new(MatrixChannel::new_full(
                mx.homeserver.clone(),
                mx.access_token.clone(),
                mx.room_id.clone(),
                mx.allowed_users.clone(),
                mx.allowed_rooms.clone(),
                mx.user_id.clone(),
                mx.device_id.clone(),
                config.config_path.parent().map(|path| path.to_path_buf()),
@@ -3968,15 +4011,18 @@ fn collect_configured_channels(
    }

    if let Some(ref wati_cfg) = config.channels_config.wati {
        let wati_channel = WatiChannel::new_with_proxy(
            wati_cfg.api_token.clone(),
            wati_cfg.api_url.clone(),
            wati_cfg.tenant_id.clone(),
            wati_cfg.allowed_numbers.clone(),
            wati_cfg.proxy_url.clone(),
        )
        .with_transcription(config.transcription.clone());

        channels.push(ConfiguredChannel {
            display_name: "WATI",
            channel: Arc::new(WatiChannel::new_with_proxy(
                wati_cfg.api_token.clone(),
                wati_cfg.api_url.clone(),
                wati_cfg.tenant_id.clone(),
                wati_cfg.allowed_numbers.clone(),
                wati_cfg.proxy_url.clone(),
            )),
            channel: Arc::new(wati_channel),
        });
    }

@@ -4039,13 +4085,19 @@ fn collect_configured_channels(
            );
            channels.push(ConfiguredChannel {
                display_name: "Feishu",
                channel: Arc::new(LarkChannel::from_config(lk)),
                channel: Arc::new(
                    LarkChannel::from_config(lk)
                        .with_transcription(config.transcription.clone()),
                ),
            });
        }
    } else {
        channels.push(ConfiguredChannel {
            display_name: "Lark",
            channel: Arc::new(LarkChannel::from_lark_config(lk)),
            channel: Arc::new(
                LarkChannel::from_lark_config(lk)
                    .with_transcription(config.transcription.clone()),
            ),
        });
    }
}
@@ -4054,7 +4106,10 @@ fn collect_configured_channels(
    if let Some(ref fs) = config.channels_config.feishu {
        channels.push(ConfiguredChannel {
            display_name: "Feishu",
            channel: Arc::new(LarkChannel::from_feishu_config(fs)),
            channel: Arc::new(
                LarkChannel::from_feishu_config(fs)
                    .with_transcription(config.transcription.clone()),
            ),
        });
    }

@@ -4345,22 +4400,23 @@ pub async fn start_channels(config: Config) -> Result<()> {
    };
    // Build system prompt from workspace identity files + skills
    let workspace = config.workspace_dir.clone();
    let (mut built_tools, delegate_handle_ch, reaction_handle_ch) = tools::all_tools_with_runtime(
        Arc::new(config.clone()),
        &security,
        runtime,
        Arc::clone(&mem),
        composio_key,
        composio_entity_id,
        &config.browser,
        &config.http_request,
        &config.web_fetch,
        &workspace,
        &config.agents,
        config.api_key.as_deref(),
        &config,
        None,
    );
    let (mut built_tools, delegate_handle_ch, reaction_handle_ch, _channel_map_handle) =
        tools::all_tools_with_runtime(
            Arc::new(config.clone()),
            &security,
            runtime,
            Arc::clone(&mem),
            composio_key,
            composio_entity_id,
            &config.browser,
            &config.http_request,
            &config.web_fetch,
            &workspace,
            &config.agents,
            config.api_key.as_deref(),
            &config,
            None,
        );

    // Wire MCP tools into the registry before freezing — non-fatal.
    // When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
@@ -4709,6 +4765,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
            matrix: interrupt_on_new_message_matrix,
        },
        multimodal: config.multimodal.clone(),
        media_pipeline: config.media_pipeline.clone(),
        transcription_config: config.transcription.clone(),
        hooks: if config.hooks.enabled {
            let mut runner = crate::hooks::HookRunner::new();
            if config.hooks.builtin.command_logger {
@@ -5080,6 +5138,8 @@ mod tests {
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            provider_runtime_options: providers::ProviderRuntimeOptions::default(),
            workspace_dir: Arc::new(std::env::temp_dir()),
@@ -5197,6 +5257,8 @@ mod tests {
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            provider_runtime_options: providers::ProviderRuntimeOptions::default(),
            workspace_dir: Arc::new(std::env::temp_dir()),
@@ -5270,6 +5332,8 @@ mod tests {
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            provider_runtime_options: providers::ProviderRuntimeOptions::default(),
            workspace_dir: Arc::new(std::env::temp_dir()),
@@ -5362,6 +5426,8 @@ mod tests {
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            provider_runtime_options: providers::ProviderRuntimeOptions::default(),
            workspace_dir: Arc::new(std::env::temp_dir()),
@@ -5911,6 +5977,8 @@ BTC is currently around $65,000 based on latest tool output."#
            autonomy_level: AutonomyLevel::default(),
            tool_call_dedup_exempt: Arc::new(Vec::new()),
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            model_routes: Arc::new(Vec::new()),
            query_classification: crate::config::QueryClassificationConfig::default(),
@@ -5936,6 +6004,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -5993,6 +6062,8 @@ BTC is currently around $65,000 based on latest tool output."#
            autonomy_level: AutonomyLevel::default(),
            tool_call_dedup_exempt: Arc::new(Vec::new()),
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            model_routes: Arc::new(Vec::new()),
            query_classification: crate::config::QueryClassificationConfig::default(),
@@ -6018,6 +6089,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6086,6 +6158,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6114,6 +6188,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 3,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6167,6 +6242,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6195,6 +6272,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 2,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6258,6 +6336,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6286,6 +6366,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6370,6 +6451,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6398,6 +6481,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 2,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6463,6 +6547,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6491,6 +6577,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 3,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6571,6 +6658,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6599,6 +6688,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 4,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6664,6 +6754,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6678,7 +6770,10 @@ BTC is currently around $65,000 based on latest tool output."#
            )),
            activated_tools: None,
            cost_tracking: None,
            pacing: crate::config::PacingConfig::default(),
            pacing: crate::config::PacingConfig {
                loop_detection_enabled: false,
                ..crate::config::PacingConfig::default()
            },
        });

        process_channel_message(
@@ -6692,6 +6787,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6747,6 +6843,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6761,7 +6859,10 @@ BTC is currently around $65,000 based on latest tool output."#
            )),
            activated_tools: None,
            cost_tracking: None,
            pacing: crate::config::PacingConfig::default(),
            pacing: crate::config::PacingConfig {
                loop_detection_enabled: false,
                ..crate::config::PacingConfig::default()
            },
        });

        process_channel_message(
@@ -6775,6 +6876,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 2,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            },
            CancellationToken::new(),
        )
@@ -6875,6 +6977,9 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: "2026-02-20T00:00:00Z".to_string(),
                session_id: None,
                score: Some(0.9),
                namespace: "default".into(),
                importance: None,
                superseded_by: None,
            }])
        }

@@ -6945,6 +7050,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -6972,6 +7079,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            })
            .await
            .unwrap();
@@ -6984,6 +7092,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 2,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            })
            .await
            .unwrap();
@@ -7048,6 +7157,8 @@ BTC is currently around $65,000 based on latest tool output."#
                matrix: false,
            },
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
            autonomy_level: AutonomyLevel::default(),
@@ -7076,6 +7187,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 1,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            })
            .await
            .unwrap();
@@ -7089,6 +7201,7 @@ BTC is currently around $65,000 based on latest tool output."#
                timestamp: 2,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            })
            .await
            .unwrap();
@@ -7169,6 +7282,8 @@ BTC is currently around $65,000 based on latest tool output."#
            show_tool_calls: true,
            session_store: None,
            multimodal: crate::config::MultimodalConfig::default(),
            media_pipeline: crate::config::MediaPipelineConfig::default(),
            transcription_config: crate::config::TranscriptionConfig::default(),
            hooks: None,
            non_cli_excluded_tools: Arc::new(Vec::new()),
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -7194,6 +7309,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: Some("1741234567.100001".to_string()),
|
||||
interruption_scope_id: Some("1741234567.100001".to_string()),
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -7207,6 +7323,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: Some("1741234567.100001".to_string()),
|
||||
interruption_scope_id: Some("1741234567.100001".to_string()),
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -7281,6 +7398,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -7309,6 +7428,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -7322,6 +7442,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -7378,6 +7499,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -7406,6 +7529,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -7459,6 +7583,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -7487,6 +7613,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -7696,9 +7823,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(prompt.contains("<instructions>"));
|
||||
assert!(prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
assert!(!prompt.contains("loaded on demand"));
|
||||
}
|
||||
|
||||
@@ -7741,10 +7868,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8016,6 +8143,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(conversation_memory_key(&msg), "slack_U123_msg_abc123");
|
||||
@@ -8032,6 +8160,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: Some("1741234567.123456".into()),
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
@@ -8051,6 +8180,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(followup_thread_id(&msg).as_deref(), Some("msg_abc123"));
|
||||
@@ -8067,6 +8197,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
@@ -8077,6 +8208,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
assert_ne!(
|
||||
@@ -8099,6 +8231,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
@@ -8109,6 +8242,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
|
||||
mem.store(
|
||||
@@ -8230,6 +8364,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -8258,6 +8394,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8274,6 +8411,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8362,6 +8500,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -8390,6 +8530,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8422,6 +8563,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8460,6 +8602,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 3,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8534,6 +8677,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -8562,6 +8707,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -8643,6 +8789,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -8671,6 +8819,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9216,6 +9365,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9245,6 +9396,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9304,6 +9456,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9332,6 +9486,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9348,6 +9503,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9467,6 +9623,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9495,6 +9653,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9579,6 +9738,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9607,6 +9768,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9683,6 +9845,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9711,6 +9875,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9807,6 +9972,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -9835,6 +10002,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
@@ -9992,6 +10160,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 0,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
attachments: vec![],
|
||||
};
|
||||
assert_eq!(interruption_scope_key(&msg), "matrix_room_alice");
|
||||
}
|
||||
@@ -10007,6 +10176,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 0,
|
||||
thread_ts: Some("$thread1".into()),
|
||||
interruption_scope_id: Some("$thread1".into()),
|
||||
attachments: vec![],
|
||||
};
|
||||
assert_eq!(interruption_scope_key(&msg), "matrix_room_alice_$thread1");
|
||||
}
|
||||
@@ -10023,6 +10193,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 0,
|
||||
thread_ts: Some("1234567890.000100".into()), // Slack top-level fallback
|
||||
interruption_scope_id: None, // but NOT a thread reply
|
||||
attachments: vec![],
|
||||
};
|
||||
assert_eq!(interruption_scope_key(&msg), "slack_C123_alice");
|
||||
}
|
||||
@@ -10069,6 +10240,8 @@ This is an example JSON object for profile settings."#;
|
||||
matrix: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
media_pipeline: crate::config::MediaPipelineConfig::default(),
|
||||
transcription_config: crate::config::TranscriptionConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
autonomy_level: AutonomyLevel::default(),
|
||||
@@ -10099,6 +10272,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 1,
|
||||
thread_ts: Some("1741234567.100001".to_string()),
|
||||
interruption_scope_id: Some("1741234567.100001".to_string()),
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -10112,6 +10286,7 @@ This is an example JSON object for profile settings."#;
|
||||
timestamp: 2,
|
||||
thread_ts: Some("1741234567.200002".to_string()),
|
||||
interruption_scope_id: Some("1741234567.200002".to_string()),
|
||||
attachments: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -206,6 +206,7 @@ impl NextcloudTalkChannel {
timestamp: Self::now_unix_secs(),
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
});

messages
@@ -308,6 +309,7 @@ impl NextcloudTalkChannel {
timestamp,
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
});

messages

@@ -254,6 +254,7 @@ impl Channel for NostrChannel {
timestamp,
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
};
if tx.send(msg).await.is_err() {
tracing::info!("Nostr listener: message bus closed, stopping");

@@ -361,6 +361,7 @@ impl Channel for NotionChannel {
timestamp,
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
})
.await
.is_err()

@@ -1139,6 +1139,7 @@ impl Channel for QQChannel {
.as_secs(),
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
};

if tx.send(channel_msg).await.is_err() {
@@ -1178,6 +1179,7 @@ impl Channel for QQChannel {
.as_secs(),
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
};

if tx.send(channel_msg).await.is_err() {

@@ -226,6 +226,7 @@ impl RedditChannel {
timestamp,
thread_ts: item.parent_id.clone(),
interruption_scope_id: None,
attachments: vec![],
})
}
}

@@ -12,6 +12,8 @@ use chrono::{DateTime, Utc};
pub struct SessionMetadata {
/// Session key (e.g. `telegram_user123`).
pub key: String,
/// Optional human-readable name (e.g. `eyrie-commander-briefing`).
pub name: Option<String>,
/// When the session was first created.
pub created_at: DateTime<Utc>,
/// When the last message was appended.
@@ -54,6 +56,7 @@ pub trait SessionBackend: Send + Sync {
let messages = self.load(&key);
SessionMetadata {
key,
name: None,
created_at: Utc::now(),
last_activity: Utc::now(),
message_count: messages.len(),
@@ -81,6 +84,16 @@ pub trait SessionBackend: Send + Sync {
fn delete_session(&self, _session_key: &str) -> std::io::Result<bool> {
Ok(false)
}

/// Set or update the human-readable name for a session.
fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
Ok(())
}

/// Get the human-readable name for a session (if set).
fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
Ok(None)
}
}

#[cfg(test)]
@@ -91,6 +104,7 @@ mod tests {
fn session_metadata_is_constructible() {
let meta = SessionMetadata {
key: "test".into(),
name: None,
created_at: Utc::now(),
last_activity: Utc::now(),
message_count: 5,

@@ -51,7 +51,8 @@ impl SqliteSessionBackend {
session_key TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
last_activity TEXT NOT NULL,
message_count INTEGER NOT NULL DEFAULT 0
message_count INTEGER NOT NULL DEFAULT 0,
name TEXT
);

CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
@@ -69,6 +70,18 @@ impl SqliteSessionBackend {
)
.context("Failed to initialize session schema")?;

// Migration: add name column to existing databases
let has_name: bool = conn
.query_row(
"SELECT COUNT(*) > 0 FROM pragma_table_info('session_metadata') WHERE name = 'name'",
[],
|row| row.get(0),
)
.unwrap_or(false);
if !has_name {
let _ = conn.execute("ALTER TABLE session_metadata ADD COLUMN name TEXT", []);
}

Ok(Self {
conn: Mutex::new(conn),
db_path,
@@ -226,7 +239,7 @@ impl SessionBackend for SqliteSessionBackend {
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
let conn = self.conn.lock();
let mut stmt = match conn.prepare(
"SELECT session_key, created_at, last_activity, message_count
"SELECT session_key, created_at, last_activity, message_count, name
FROM session_metadata ORDER BY last_activity DESC",
) {
Ok(s) => s,
@@ -238,6 +251,7 @@ impl SessionBackend for SqliteSessionBackend {
let created_str: String = row.get(1)?;
let activity_str: String = row.get(2)?;
let count: i64 = row.get(3)?;
let name: Option<String> = row.get(4)?;

let created = DateTime::parse_from_rfc3339(&created_str)
.map(|dt| dt.with_timezone(&Utc))
@@ -249,6 +263,7 @@ impl SessionBackend for SqliteSessionBackend {
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(SessionMetadata {
key,
name,
created_at: created,
last_activity: activity,
message_count: count as usize,
@@ -321,6 +336,27 @@ impl SessionBackend for SqliteSessionBackend {
Ok(true)
}

fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
let conn = self.conn.lock();
let name_val = if name.is_empty() { None } else { Some(name) };
conn.execute(
"UPDATE session_metadata SET name = ?1 WHERE session_key = ?2",
params![name_val, session_key],
)
.map_err(std::io::Error::other)?;
Ok(())
}

fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
let conn = self.conn.lock();
conn.query_row(
"SELECT name FROM session_metadata WHERE session_key = ?1",
params![session_key],
|row| row.get(0),
)
.map_err(std::io::Error::other)
}

fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
let Some(keyword) = &query.keyword else {
return self.list_sessions_with_metadata();
@@ -357,14 +393,16 @@ impl SessionBackend for SqliteSessionBackend {
keys.iter()
.filter_map(|key| {
conn.query_row(
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
"SELECT created_at, last_activity, message_count, name FROM session_metadata WHERE session_key = ?1",
params![key],
|row| {
let created_str: String = row.get(0)?;
let activity_str: String = row.get(1)?;
let count: i64 = row.get(2)?;
let name: Option<String> = row.get(3)?;
Ok(SessionMetadata {
key: key.clone(),
name,
created_at: DateTime::parse_from_rfc3339(&created_str)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
@@ -555,4 +593,55 @@ mod tests {
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0].content, "hello");
}

#[test]
fn set_session_name_persists() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "My Session").unwrap();

let meta = backend.list_sessions_with_metadata();
assert_eq!(meta.len(), 1);
assert_eq!(meta[0].name.as_deref(), Some("My Session"));
}

#[test]
fn set_session_name_updates_existing() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "First").unwrap();
backend.set_session_name("s1", "Second").unwrap();

let meta = backend.list_sessions_with_metadata();
assert_eq!(meta[0].name.as_deref(), Some("Second"));
}

#[test]
fn sessions_without_name_return_none() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

backend.append("s1", &ChatMessage::user("hello")).unwrap();

let meta = backend.list_sessions_with_metadata();
assert_eq!(meta.len(), 1);
assert!(meta[0].name.is_none());
}

#[test]
fn empty_name_clears_to_none() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();

backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "Named").unwrap();
backend.set_session_name("s1", "").unwrap();

let meta = backend.list_sessions_with_metadata();
assert!(meta[0].name.is_none());
}
}
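The session-naming API above relies on trait default methods: `set_session_name` and `get_session_name` are no-ops unless a backend overrides them, and the SQLite override maps an empty name to `NULL`. A minimal stand-alone sketch of that pattern, assuming the same trait shape — the `MemoryBackend` type here is hypothetical, for illustration only:

```rust
use std::collections::HashMap;
use std::sync::Mutex;

// Mirrors the trait above: defaults make naming optional per backend.
trait SessionBackend {
    fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
        Ok(()) // default: silently ignore
    }
    fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
        Ok(None) // default: no name support
    }
}

// Hypothetical in-memory backend that opts in, clearing the name on
// empty input the way the SQLite implementation maps "" to NULL.
struct MemoryBackend {
    names: Mutex<HashMap<String, String>>,
}

impl SessionBackend for MemoryBackend {
    fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
        let mut names = self.names.lock().unwrap();
        if name.is_empty() {
            names.remove(session_key);
        } else {
            names.insert(session_key.to_string(), name.to_string());
        }
        Ok(())
    }
    fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
        Ok(self.names.lock().unwrap().get(session_key).cloned())
    }
}
```

Backends without name support need no changes at all, which keeps the trait change non-breaking for existing implementations.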
@@ -280,6 +280,7 @@ impl SignalChannel {
timestamp: timestamp / 1000, // millis → secs
thread_ts: None,
interruption_scope_id: None,
attachments: vec![],
})
}
}

@@ -34,6 +34,9 @@ pub struct SlackChannel {
active_assistant_thread: Mutex<HashMap<String, String>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
/// Voice transcription config — when set, audio file attachments are
/// downloaded, transcribed, and their text inlined into the message.
transcription: Option<crate::config::TranscriptionConfig>,
}

const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
@@ -125,6 +128,7 @@ impl SlackChannel {
workspace_dir: None,
active_assistant_thread: Mutex::new(HashMap::new()),
proxy_url: None,
transcription: None,
}
}

@@ -158,6 +162,14 @@ impl SlackChannel {
self
}

/// Configure voice transcription for audio file attachments.
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
if config.enabled {
self.transcription = Some(config);
}
self
}

fn http_client(&self) -> reqwest::Client {
crate::config::build_channel_proxy_client_with_timeouts(
"channel.slack",
@@ -558,6 +570,13 @@ impl SlackChannel {
.await
.unwrap_or_else(|| raw_file.clone());

// Voice / audio transcription: if transcription is configured and the
// file looks like an audio attachment, download and transcribe it.
if Self::is_audio_file(&file) {
if let Some(transcribed) = self.try_transcribe_audio_file(&file).await {
return Some(transcribed);
}
}
if Self::is_image_file(&file) {
if let Some(marker) = self.fetch_image_marker(&file).await {
return Some(marker);
@@ -1449,6 +1468,106 @@ impl SlackChannel {
.is_some_and(|ext| Self::mime_from_extension(ext).is_some())
}

/// Audio file extensions accepted for voice transcription.
const AUDIO_EXTENSIONS: &[&str] = &[
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];

/// Check whether a Slack file object looks like an audio attachment
/// (voice memo, audio message, or uploaded audio file).
fn is_audio_file(file: &serde_json::Value) -> bool {
// Slack voice messages use subtype "slack_audio"
if let Some(subtype) = file.get("subtype").and_then(|v| v.as_str()) {
if subtype == "slack_audio" {
return true;
}
}

if Self::slack_file_mime(file)
.as_deref()
.is_some_and(|mime| mime.starts_with("audio/"))
{
return true;
}

if let Some(ft) = file
.get("filetype")
.and_then(|v| v.as_str())
.map(|v| v.to_ascii_lowercase())
{
if Self::AUDIO_EXTENSIONS.contains(&ft.as_str()) {
return true;
}
}

Self::file_extension(&Self::slack_file_name(file))
.as_deref()
.is_some_and(|ext| Self::AUDIO_EXTENSIONS.contains(&ext))
}
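`is_audio_file` above falls through four signals in order: Slack's `slack_audio` subtype, an `audio/*` mimetype, the `filetype` field, and finally the filename extension. A stand-alone sketch of the same cascade, with the `serde_json` lookups replaced by plain arguments — the `looks_like_audio` helper is hypothetical and assumes the same extension whitelist:

```rust
// Same whitelist as the Slack channel's AUDIO_EXTENSIONS constant.
const AUDIO_EXTENSIONS: &[&str] = &[
    "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];

// Lowercased extension of a file name, if it has one.
fn file_extension(name: &str) -> Option<String> {
    name.rsplit_once('.').map(|(_, ext)| ext.to_ascii_lowercase())
}

// Hypothetical helper mirroring the fallback order: subtype first,
// then mimetype prefix, then filename extension as the last resort.
fn looks_like_audio(subtype: Option<&str>, mime: Option<&str>, file_name: &str) -> bool {
    if subtype == Some("slack_audio") {
        return true;
    }
    if mime.is_some_and(|m| m.starts_with("audio/")) {
        return true;
    }
    file_extension(file_name)
        .is_some_and(|ext| AUDIO_EXTENSIONS.contains(&ext.as_str()))
}
```

Ordering the checks from most to least authoritative means a voice memo is recognized even when the filename is an opaque blob, while an extension match still catches plain uploaded audio files.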
    /// Download an audio file attachment and transcribe it using the configured
    /// transcription provider. Returns `None` if transcription is not configured
    /// or if the download/transcription fails.
    async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
        let config = self.transcription.as_ref()?;

        let url = Self::slack_file_download_url(file)?;
        let file_name = Self::slack_file_name(file);
        let redacted_url = Self::redact_raw_slack_url(url);

        let resp = self.fetch_slack_private_file(url).await?;
        let status = resp.status();
        if !status.is_success() {
            tracing::warn!(
                "Slack voice file download failed for {} ({status})",
                redacted_url
            );
            return None;
        }

        let audio_data = match resp.bytes().await {
            Ok(bytes) => bytes.to_vec(),
            Err(e) => {
                tracing::warn!("Slack voice file read failed for {}: {e}", redacted_url);
                return None;
            }
        };

        // Determine a filename with extension for the transcription API.
        let transcription_filename = if Self::file_extension(&file_name).is_some() {
            file_name.clone()
        } else {
            // Fall back to extension from mimetype or default to .ogg
            let mime_ext = Self::slack_file_mime(file)
                .and_then(|mime| mime.rsplit('/').next().map(|s| s.to_string()))
                .unwrap_or_else(|| "ogg".to_string());
            format!("voice.{mime_ext}")
        };

        match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
            .await
        {
            Ok(text) => {
                let trimmed = text.trim();
                if trimmed.is_empty() {
                    tracing::info!("Slack voice transcription returned empty text, skipping");
                    None
                } else {
                    tracing::info!(
                        "Slack: transcribed voice file {} ({} chars)",
                        file_name,
                        trimmed.len()
                    );
                    Some(format!("[Voice] {trimmed}"))
                }
            }
            Err(e) => {
                tracing::warn!("Slack voice transcription failed for {}: {e}", file_name);
                Some(Self::format_attachment_summary(file))
            }
        }
    }

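The filename fallback used above (keep the original name when it already carries an extension, otherwise derive one from the MIME subtype, defaulting to `.ogg`) can be sketched in isolation. This `transcription_filename` helper is a hypothetical, self-contained mirror of that rule, not part of the patch:

```rust
/// Hypothetical standalone sketch of the filename-fallback rule: a name is
/// usable as-is when a dot separates a non-empty stem from a non-empty
/// extension; otherwise derive the extension from the MIME subtype.
fn transcription_filename(file_name: &str, mime: Option<&str>) -> String {
    let has_ext = file_name
        .rsplit_once('.')
        .is_some_and(|(stem, ext)| !stem.is_empty() && !ext.is_empty());
    if has_ext {
        file_name.to_string()
    } else {
        // "audio/ogg" -> "ogg"; no MIME at all -> default to "ogg".
        let mime_ext = mime.and_then(|m| m.rsplit('/').next()).unwrap_or("ogg");
        format!("voice.{mime_ext}")
    }
}

fn main() {
    assert_eq!(transcription_filename("clip.m4a", None), "clip.m4a");
    assert_eq!(
        transcription_filename("voice message", Some("audio/ogg")),
        "voice.ogg"
    );
    assert_eq!(transcription_filename("voice message", None), "voice.ogg");
    println!("ok");
}
```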
    async fn download_text_snippet(&self, file: &serde_json::Value) -> Option<String> {
        let url = Self::slack_file_download_url(file)?;
        let redacted_url = Self::redact_raw_slack_url(url);

@@ -1908,6 +2027,7 @@ impl SlackChannel {
                 Self::inbound_thread_ts_genuine_only(event)
             },
             interruption_scope_id: Self::inbound_interruption_scope_id(event, ts),
+            attachments: vec![],
         };

         // Track thread context so start_typing can set assistant status.
@@ -2582,6 +2702,7 @@ impl Channel for SlackChannel {
                 Self::inbound_thread_ts_genuine_only(msg)
             },
             interruption_scope_id: Self::inbound_interruption_scope_id(msg, ts),
+            attachments: vec![],
         };

         if tx.send(channel_msg).await.is_err() {
@@ -2667,6 +2788,7 @@ impl Channel for SlackChannel {
                 .as_secs(),
             thread_ts: Some(thread_ts.clone()),
             interruption_scope_id: Some(thread_ts.clone()),
+            attachments: vec![],
         };

         if tx.send(channel_msg).await.is_err() {
@@ -3619,6 +3741,7 @@ mod tests {
             timestamp: 0,
             thread_ts: None, // thread_replies=false → no fallback to ts
             interruption_scope_id: None,
+            attachments: vec![],
         };

         let msg1 = make_msg("100.000");
@@ -3644,6 +3767,7 @@ mod tests {
             timestamp: 0,
             thread_ts: Some(ts.to_string()), // thread_replies=true → ts as thread_ts
             interruption_scope_id: None,
+            attachments: vec![],
         };

         let msg1 = make_msg("100.000");

@@ -1140,6 +1140,11 @@ Allowlist Telegram username (without '@') or numeric user ID.",
             content = format!("{quote}\n\n{content}");
         }

+        // Prepend forwarding attribution when the message was forwarded
+        if let Some(attr) = Self::format_forward_attribution(message) {
+            content = format!("{attr}{content}");
+        }
+
         Some(ChannelMessage {
             id: format!("telegram_{chat_id}_{message_id}"),
             sender: sender_identity,
@@ -1152,6 +1157,7 @@ Allowlist Telegram username (without '@') or numeric user ID.",
                 .as_secs(),
             thread_ts: thread_id,
             interruption_scope_id: None,
+            attachments: vec![],
         })
     }

@@ -1263,6 +1269,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
             format!("[Voice] {text}")
         };

+        // Prepend forwarding attribution when the message was forwarded
+        let content = if let Some(attr) = Self::format_forward_attribution(message) {
+            format!("{attr}{content}")
+        } else {
+            content
+        };
+
         Some(ChannelMessage {
             id: format!("telegram_{chat_id}_{message_id}"),
             sender: sender_identity,
@@ -1275,6 +1288,7 @@ Allowlist Telegram username (without '@') or numeric user ID.",
                 .as_secs(),
             thread_ts: thread_id,
             interruption_scope_id: None,
+            attachments: vec![],
         })
     }

@@ -1299,6 +1313,41 @@ Allowlist Telegram username (without '@') or numeric user ID.",
         (username, sender_id, sender_identity)
     }

+    /// Build a forwarding attribution prefix from Telegram forward fields.
+    ///
+    /// Returns `Some("[Forwarded from ...] ")` when the message is forwarded,
+    /// `None` otherwise.
+    fn format_forward_attribution(message: &serde_json::Value) -> Option<String> {
+        if let Some(from_chat) = message.get("forward_from_chat") {
+            // Forwarded from a channel or group
+            let title = from_chat
+                .get("title")
+                .and_then(serde_json::Value::as_str)
+                .unwrap_or("unknown channel");
+            Some(format!("[Forwarded from channel: {title}] "))
+        } else if let Some(from_user) = message.get("forward_from") {
+            // Forwarded from a user (privacy allows identity)
+            let label = from_user
+                .get("username")
+                .and_then(serde_json::Value::as_str)
+                .map(|u| format!("@{u}"))
+                .or_else(|| {
+                    from_user
+                        .get("first_name")
+                        .and_then(serde_json::Value::as_str)
+                        .map(String::from)
+                })
+                .unwrap_or_else(|| "unknown".to_string());
+            Some(format!("[Forwarded from {label}] "))
+        } else {
+            // Forwarded from a user who hides their identity
+            message
+                .get("forward_sender_name")
+                .and_then(serde_json::Value::as_str)
+                .map(|name| format!("[Forwarded from {name}] "))
+        }
+    }
+
     /// Extract reply context from a Telegram `reply_to_message`, if present.
     fn extract_reply_context(&self, message: &serde_json::Value) -> Option<String> {
         let reply = message.get("reply_to_message")?;

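The attribution precedence above (channel/group title first, then an identified user with username preferred over first name, then the hidden-sender display name) can be sketched without `serde_json`. This `forward_attribution` helper is hypothetical and simplified: the JSON lookups are flattened into plain `Option` parameters, and the "forward_from present but nameless" edge case (original result: `[Forwarded from unknown] `) is omitted:

```rust
// Hypothetical, simplified mirror of the forward-attribution precedence.
fn forward_attribution(
    chat_title: Option<&str>,
    username: Option<&str>,
    first_name: Option<&str>,
    hidden_sender_name: Option<&str>,
) -> Option<String> {
    if let Some(title) = chat_title {
        // Forwarded from a channel or group
        Some(format!("[Forwarded from channel: {title}] "))
    } else if username.is_some() || first_name.is_some() {
        // Forwarded from an identified user; username wins over first name
        let label = username
            .map(|u| format!("@{u}"))
            .or_else(|| first_name.map(String::from))
            .expect("one of username/first_name is Some");
        Some(format!("[Forwarded from {label}] "))
    } else {
        // Forwarded from a user who hides their identity
        hidden_sender_name.map(|n| format!("[Forwarded from {n}] "))
    }
}

fn main() {
    assert_eq!(
        forward_attribution(Some("Daily News"), None, None, None).as_deref(),
        Some("[Forwarded from channel: Daily News] ")
    );
    assert_eq!(
        forward_attribution(None, Some("bob"), Some("Bob"), None).as_deref(),
        Some("[Forwarded from @bob] ")
    );
    assert_eq!(
        forward_attribution(None, None, Some("Charlie"), None).as_deref(),
        Some("[Forwarded from Charlie] ")
    );
    assert_eq!(
        forward_attribution(None, None, None, Some("Hidden User")).as_deref(),
        Some("[Forwarded from Hidden User] ")
    );
    assert_eq!(forward_attribution(None, None, None, None), None);
    println!("ok");
}
```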
@@ -1420,6 +1469,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
             content
         };

+        // Prepend forwarding attribution when the message was forwarded
+        let content = if let Some(attr) = Self::format_forward_attribution(message) {
+            format!("{attr}{content}")
+        } else {
+            content
+        };
+
         // Exit voice-chat mode when user switches back to typing
         if let Ok(mut vc) = self.voice_chats.lock() {
             vc.remove(&reply_target);
@@ -1437,6 +1493,7 @@ Allowlist Telegram username (without '@') or numeric user ID.",
                 .as_secs(),
             thread_ts: thread_id,
             interruption_scope_id: None,
+            attachments: vec![],
         })
     }

@@ -4871,4 +4928,153 @@ mod tests {
         TelegramChannel::new("token".into(), vec!["*".into()], false).with_ack_reactions(true);
         assert!(ch.ack_reactions);
     }
+
+    // ── Forwarded message tests ─────────────────────────────────────
+
+    #[test]
+    fn parse_update_message_forwarded_from_user_with_username() {
+        let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
+        let update = serde_json::json!({
+            "update_id": 100,
+            "message": {
+                "message_id": 50,
+                "text": "Check this out",
+                "from": { "id": 1, "username": "alice" },
+                "chat": { "id": 999 },
+                "forward_from": {
+                    "id": 42,
+                    "first_name": "Bob",
+                    "username": "bob"
+                },
+                "forward_date": 1_700_000_000
+            }
+        });
+
+        let msg = ch
+            .parse_update_message(&update)
+            .expect("forwarded message should parse");
+        assert_eq!(msg.content, "[Forwarded from @bob] Check this out");
+    }
+
+    #[test]
+    fn parse_update_message_forwarded_from_channel() {
+        let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
+        let update = serde_json::json!({
+            "update_id": 101,
+            "message": {
+                "message_id": 51,
+                "text": "Breaking news",
+                "from": { "id": 1, "username": "alice" },
+                "chat": { "id": 999 },
+                "forward_from_chat": {
+                    "id": -1_001_234_567_890_i64,
+                    "title": "Daily News",
+                    "username": "dailynews",
+                    "type": "channel"
+                },
+                "forward_date": 1_700_000_000
+            }
+        });
+
+        let msg = ch
+            .parse_update_message(&update)
+            .expect("channel-forwarded message should parse");
+        assert_eq!(
+            msg.content,
+            "[Forwarded from channel: Daily News] Breaking news"
+        );
+    }
+
+    #[test]
+    fn parse_update_message_forwarded_hidden_sender() {
+        let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
+        let update = serde_json::json!({
+            "update_id": 102,
+            "message": {
+                "message_id": 52,
+                "text": "Secret tip",
+                "from": { "id": 1, "username": "alice" },
+                "chat": { "id": 999 },
+                "forward_sender_name": "Hidden User",
+                "forward_date": 1_700_000_000
+            }
+        });
+
+        let msg = ch
+            .parse_update_message(&update)
+            .expect("hidden-sender forwarded message should parse");
+        assert_eq!(msg.content, "[Forwarded from Hidden User] Secret tip");
+    }
+
+    #[test]
+    fn parse_update_message_non_forwarded_unaffected() {
+        let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
+        let update = serde_json::json!({
+            "update_id": 103,
+            "message": {
+                "message_id": 53,
+                "text": "Normal message",
+                "from": { "id": 1, "username": "alice" },
+                "chat": { "id": 999 }
+            }
+        });
+
+        let msg = ch
+            .parse_update_message(&update)
+            .expect("non-forwarded message should parse");
+        assert_eq!(msg.content, "Normal message");
+    }
+
+    #[test]
+    fn parse_update_message_forwarded_from_user_no_username() {
+        let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
+        let update = serde_json::json!({
+            "update_id": 104,
+            "message": {
+                "message_id": 54,
+                "text": "Hello there",
+                "from": { "id": 1, "username": "alice" },
+                "chat": { "id": 999 },
+                "forward_from": {
+                    "id": 77,
+                    "first_name": "Charlie"
+                },
+                "forward_date": 1_700_000_000
+            }
+        });
+
+        let msg = ch
+            .parse_update_message(&update)
+            .expect("forwarded message without username should parse");
+        assert_eq!(msg.content, "[Forwarded from Charlie] Hello there");
+    }
+
+    #[test]
+    fn forwarded_photo_attachment_has_attribution() {
+        // Verify that format_forward_attribution produces correct prefix
+        // for a photo message (the actual download is async, so we test the
+        // helper directly with a photo-bearing message structure).
+        let message = serde_json::json!({
+            "message_id": 60,
+            "from": { "id": 1, "username": "alice" },
+            "chat": { "id": 999 },
+            "photo": [
+                { "file_id": "abc123", "file_unique_id": "u1", "width": 320, "height": 240 }
+            ],
+            "forward_from": {
+                "id": 42,
+                "username": "bob"
+            },
+            "forward_date": 1_700_000_000
+        });
+
+        let attr =
+            TelegramChannel::format_forward_attribution(&message).expect("should detect forward");
+        assert_eq!(attr, "[Forwarded from @bob] ");
+
+        // Simulate what try_parse_attachment_message does after building content
+        let photo_content = "[IMAGE:/tmp/photo.jpg]".to_string();
+        let content = format!("{attr}{photo_content}");
+        assert_eq!(content, "[Forwarded from @bob] [IMAGE:/tmp/photo.jpg]");
+    }
 }

@@ -17,6 +17,10 @@ pub struct ChannelMessage {
     /// is genuinely inside a reply thread and should be isolated from other threads.
     /// `None` means top-level — scope is sender+channel only.
     pub interruption_scope_id: Option<String>,
+    /// Media attachments (audio, images, video) for the media pipeline.
+    /// Channels populate this when they receive media alongside a text message.
+    /// Defaults to empty — existing channels are unaffected.
+    pub attachments: Vec<super::media_pipeline::MediaAttachment>,
 }

 /// Message to send through a channel
@@ -188,6 +192,7 @@ mod tests {
             timestamp: 123,
             thread_ts: None,
             interruption_scope_id: None,
+            attachments: vec![],
         })
         .await
         .map_err(|e| anyhow::anyhow!(e.to_string()))
@@ -205,6 +210,7 @@ mod tests {
             timestamp: 999,
             thread_ts: None,
             interruption_scope_id: None,
+            attachments: vec![],
         };

         let cloned = message.clone();

@@ -597,7 +597,7 @@ impl TranscriptionProvider for GoogleSttProvider {
 /// Self-hosted faster-whisper-compatible STT provider.
 ///
 /// POSTs audio as `multipart/form-data` (field name `file`) to a configurable
-/// HTTP endpoint (e.g. faster-whisper on GEX44 over WireGuard). The endpoint
+/// HTTP endpoint (e.g. `http://localhost:8000` or a private network host). The endpoint
 /// must return `{"text": "..."}`. No cloud API key required. Size limit is
 /// configurable — not constrained by the 25 MB cloud API cap.
 pub struct LocalWhisperProvider {