Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 256e8ccebf | |||
| 72c9e6b6ca | |||
| 755a129ca2 | |||
| 8b0d3684c5 | |||
| a38a4d132e | |||
| 48aba73d3a | |||
| a1ab1e1a11 | |||
| 87b5bca449 | |||
| be40c0c5a5 | |||
| 6527871928 | |||
| 0bda80de9c | |||
| 02f57f4d98 | |||
| ef83dd44d7 | |||
| a986b6b912 | |||
| b6b1186e3b | |||
| 00dc0c8670 | |||
| 43f2a0a815 | |||
| 50b5bd4d73 | |||
| 8c074870a1 | |||
| 61d1841ce3 | |||
| eb396cf38f | |||
| 9f1657b9be | |||
| 8fecd4286c | |||
| df21d92da3 | |||
| 8d65924704 | |||
| 756c3cadff | |||
| ee870028ff |
@@ -102,6 +102,22 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish to crates.io
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
@@ -67,6 +67,24 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
if: "!inputs.dry_run"
|
||||
shell: bash
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
if: "!inputs.dry_run"
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish (dry run)
|
||||
if: inputs.dry_run
|
||||
run: cargo publish --dry-run --locked --allow-dirty --no-verify
|
||||
|
||||
@@ -323,6 +323,21 @@ jobs:
|
||||
- name: Clean web build artifacts
|
||||
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
|
||||
|
||||
- name: Publish aardvark-sys to crates.io
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
run: |
|
||||
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
|
||||
echo "$OUTPUT"
|
||||
if echo "$OUTPUT" | grep -q 'already exists'; then
|
||||
echo "::notice::aardvark-sys already on crates.io — skipping"
|
||||
exit 0
|
||||
fi
|
||||
exit 1
|
||||
|
||||
- name: Wait for aardvark-sys to index
|
||||
run: sleep 15
|
||||
|
||||
- name: Publish to crates.io
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
|
||||
Generated
+1
-1
@@ -9203,7 +9203,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.4"
|
||||
version = "0.5.6"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
|
||||
+2
-2
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.4"
|
||||
version = "0.5.6"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
@@ -97,7 +97,7 @@ anyhow = "1.0"
|
||||
thiserror = "2.0"
|
||||
|
||||
# Aardvark I2C/SPI/GPIO USB adapter (Total Phase) — stub when SDK absent
|
||||
aardvark-sys = { path = "crates/aardvark-sys" }
|
||||
aardvark-sys = { path = "crates/aardvark-sys", version = "0.1.0" }
|
||||
|
||||
# UUID generation
|
||||
uuid = { version = "1.22", default-features = false, features = ["v4", "std"] }
|
||||
|
||||
@@ -27,6 +27,7 @@ COPY Cargo.toml Cargo.lock ./
|
||||
# Previously we used sed to drop `crates/robot-kit`, which made the manifest disagree
|
||||
# with the lockfile and caused `cargo --locked` to fail (Cargo refused to rewrite the lock).
|
||||
COPY crates/robot-kit/ crates/robot-kit/
|
||||
COPY crates/aardvark-sys/ crates/aardvark-sys/
|
||||
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
|
||||
RUN mkdir -p src benches \
|
||||
&& echo "fn main() {}" > src/main.rs \
|
||||
|
||||
@@ -263,7 +263,7 @@ fn bench_memory_operations(c: &mut Criterion) {
|
||||
c.bench_function("memory_recall_top10", |b| {
|
||||
b.iter(|| {
|
||||
rt.block_on(async {
|
||||
mem.recall(black_box("zeroclaw agent"), 10, None)
|
||||
mem.recall(black_box("zeroclaw agent"), 10, None, None, None)
|
||||
.await
|
||||
.unwrap()
|
||||
})
|
||||
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.5.4
|
||||
pkgver = 0.5.6
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.5.4.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.4.tar.gz
|
||||
source = zeroclaw-0.5.6.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.6.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.5.4
|
||||
pkgver=0.5.6
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.5.4",
|
||||
"version": "0.5.6",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.4/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.6/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -122,6 +122,34 @@ tools = ["mcp_browser_*"]
|
||||
keywords = ["browse", "navigate", "open url", "screenshot"]
|
||||
```
|
||||
|
||||
## `[pacing]`
|
||||
|
||||
Pacing controls for slow/local LLM workloads (Ollama, llama.cpp, vLLM). All keys are optional; when absent, existing behavior is preserved.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `step_timeout_secs` | _none_ | Per-step timeout: maximum seconds for a single LLM inference turn. Catches a truly hung model without terminating the overall task loop |
|
||||
| `loop_detection_min_elapsed_secs` | _none_ | Minimum elapsed seconds before loop detection activates. Tasks completing under this threshold get aggressive loop protection; longer-running tasks receive a grace period |
|
||||
| `loop_ignore_tools` | `[]` | Tool names excluded from identical-output loop detection. Useful for browser workflows where `browser_screenshot` structurally resembles a loop |
|
||||
| `message_timeout_scale_max` | `4` | Override for the hardcoded timeout scaling cap. The channel message timeout budget is `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)` |
|
||||
|
||||
Notes:
|
||||
|
||||
- These settings are intended for local/slow LLM deployments. Cloud-provider users typically do not need them.
|
||||
- `step_timeout_secs` operates independently of the total channel message timeout budget. A step timeout abort does not consume the overall budget; the loop simply stops.
|
||||
- `loop_detection_min_elapsed_secs` delays loop-detection counting, not the task itself. Loop protection remains fully active for short tasks (the default).
|
||||
- `loop_ignore_tools` only suppresses tool-output-based loop detection for the listed tools. Other safety features (max iterations, overall timeout) remain active.
|
||||
- `message_timeout_scale_max` must be >= 1. Setting it higher than `max_tool_iterations` has no additional effect (the formula uses `min()`).
|
||||
- Example configuration for a slow local Ollama deployment:
|
||||
|
||||
```toml
|
||||
[pacing]
|
||||
step_timeout_secs = 120
|
||||
loop_detection_min_elapsed_secs = 60
|
||||
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
|
||||
message_timeout_scale_max = 8
|
||||
```
|
||||
|
||||
## `[security.otp]`
|
||||
|
||||
| Key | Default | Purpose |
|
||||
@@ -185,12 +213,15 @@ Delegate sub-agent configurations. Each key under `[agents]` defines a named sub
|
||||
| `max_iterations` | `10` | Max tool-call iterations for agentic mode |
|
||||
| `timeout_secs` | `120` | Timeout in seconds for non-agentic provider calls (1–3600) |
|
||||
| `agentic_timeout_secs` | `300` | Timeout in seconds for agentic sub-agent loops (1–3600) |
|
||||
| `skills_directory` | unset | Optional skills directory path (workspace-relative) for scoped skill loading |
|
||||
|
||||
Notes:
|
||||
|
||||
- `agentic = false` preserves existing single prompt→response delegate behavior.
|
||||
- `agentic = true` requires at least one matching entry in `allowed_tools`.
|
||||
- The `delegate` tool is excluded from sub-agent allowlists to prevent re-entrant delegation loops.
|
||||
- Sub-agents receive an enriched system prompt containing: tools section (allowed tools with parameters), skills section (from scoped or default directory), workspace path, current date/time, safety constraints, and shell policy when `shell` is in the effective tool list.
|
||||
- When `skills_directory` is unset or empty, the sub-agent loads skills from the default workspace `skills/` directory. When set, skills are loaded exclusively from that directory (relative to workspace root), enabling per-agent scoped skill sets.
|
||||
|
||||
```toml
|
||||
[agents.researcher]
|
||||
@@ -208,6 +239,14 @@ provider = "ollama"
|
||||
model = "qwen2.5-coder:32b"
|
||||
temperature = 0.2
|
||||
timeout_secs = 60
|
||||
|
||||
[agents.code_reviewer]
|
||||
provider = "anthropic"
|
||||
model = "claude-opus-4-5"
|
||||
system_prompt = "You are an expert code reviewer focused on security and performance."
|
||||
agentic = true
|
||||
allowed_tools = ["file_read", "shell"]
|
||||
skills_directory = "skills/code-review"
|
||||
```
|
||||
|
||||
## `[runtime]`
|
||||
@@ -414,6 +453,12 @@ Notes:
|
||||
| `port` | `42617` | gateway listen port |
|
||||
| `require_pairing` | `true` | require pairing before bearer auth |
|
||||
| `allow_public_bind` | `false` | block accidental public exposure |
|
||||
| `path_prefix` | _(none)_ | URL path prefix for reverse-proxy deployments (e.g. `"/zeroclaw"`) |
|
||||
|
||||
When deploying behind a reverse proxy that maps ZeroClaw to a sub-path,
|
||||
set `path_prefix` to that sub-path (e.g. `"/zeroclaw"`). All gateway
|
||||
routes will be served under this prefix. The value must start with `/`
|
||||
and must not end with `/`.
|
||||
|
||||
## `[autonomy]`
|
||||
|
||||
@@ -586,7 +631,7 @@ Top-level channel options are configured under `channels_config`.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) |
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x, overridable via `[pacing].message_timeout_scale_max`) |
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -601,7 +646,7 @@ Examples:
|
||||
Notes:
|
||||
|
||||
- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, cap)` and a minimum of `1`. The default cap is `4`; override with `[pacing].message_timeout_scale_max`.
|
||||
- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete.
|
||||
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
|
||||
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
|
||||
|
||||
+62
@@ -568,6 +568,31 @@ then re-run bootstrap.
|
||||
MSG
|
||||
exit 0
|
||||
fi
|
||||
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
|
||||
# xcrun --show-sdk-path can succeed even without an accepted license,
|
||||
# so we test-compile a trivial C file which reliably triggers the error.
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
|
||||
_xcode_accept_ok=false
|
||||
if [[ "$(id -u)" -eq 0 ]]; then
|
||||
xcodebuild -license accept && _xcode_accept_ok=true
|
||||
elif [[ -c /dev/tty ]] && have_cmd sudo; then
|
||||
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
|
||||
fi
|
||||
if [[ "$_xcode_accept_ok" == true ]]; then
|
||||
step_ok "Xcode license accepted"
|
||||
else
|
||||
error "Could not accept Xcode license. Run manually:"
|
||||
error " sudo xcodebuild -license accept"
|
||||
error "then re-run this installer."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
rm -f "$_xcode_test_file"
|
||||
fi
|
||||
if ! have_cmd git; then
|
||||
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
|
||||
fi
|
||||
@@ -1168,6 +1193,43 @@ else
|
||||
install_system_deps
|
||||
fi
|
||||
|
||||
# Always check Xcode/CLT license on macOS, regardless of --install-system-deps.
|
||||
# An un-accepted license causes `cc` to exit 69, breaking all Rust builds.
|
||||
if [[ "$OS_NAME" == "Darwin" ]]; then
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
|
||||
# Use /dev/tty so sudo can prompt for a password even in a curl|bash pipe.
|
||||
_xcode_accept_ok=false
|
||||
if [[ "$(id -u)" -eq 0 ]]; then
|
||||
xcodebuild -license accept && _xcode_accept_ok=true
|
||||
elif [[ -c /dev/tty ]] && have_cmd sudo; then
|
||||
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
|
||||
fi
|
||||
if [[ "$_xcode_accept_ok" == true ]]; then
|
||||
step_ok "Xcode license accepted"
|
||||
# Re-test compilation to confirm it's fixed.
|
||||
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
|
||||
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
|
||||
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
|
||||
rm -f "$_xcode_test_file"
|
||||
error "C compiler still failing after license accept. Check your Xcode/CLT installation."
|
||||
exit 1
|
||||
fi
|
||||
rm -f "$_xcode_test_file"
|
||||
else
|
||||
error "Could not accept Xcode license. Run manually:"
|
||||
error " sudo xcodebuild -license accept"
|
||||
error "then re-run this installer."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
rm -f "$_xcode_test_file"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$INSTALL_RUST" == true ]]; then
|
||||
install_rust_toolchain
|
||||
fi
|
||||
|
||||
+451
-8
@@ -1,5 +1,8 @@
|
||||
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::config::Config;
|
||||
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
|
||||
use crate::cost::CostTracker;
|
||||
use crate::i18n::ToolDescriptions;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
@@ -23,6 +26,108 @@ use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ── Cost tracking via task-local ──
|
||||
|
||||
/// Context for cost tracking within the tool call loop.
|
||||
/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ToolLoopCostTrackingContext {
|
||||
pub tracker: Arc<CostTracker>,
|
||||
pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
|
||||
}
|
||||
|
||||
impl ToolLoopCostTrackingContext {
|
||||
pub(crate) fn new(
|
||||
tracker: Arc<CostTracker>,
|
||||
prices: Arc<std::collections::HashMap<String, ModelPricing>>,
|
||||
) -> Self {
|
||||
Self { tracker, prices }
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task_local! {
|
||||
pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
|
||||
}
|
||||
|
||||
/// 3-tier model pricing lookup:
|
||||
/// 1. Direct model name
|
||||
/// 2. Qualified `provider/model`
|
||||
/// 3. Suffix after last `/`
|
||||
fn lookup_model_pricing<'a>(
|
||||
prices: &'a std::collections::HashMap<String, ModelPricing>,
|
||||
provider_name: &str,
|
||||
model: &str,
|
||||
) -> Option<&'a ModelPricing> {
|
||||
prices
|
||||
.get(model)
|
||||
.or_else(|| prices.get(&format!("{provider_name}/{model}")))
|
||||
.or_else(|| {
|
||||
model
|
||||
.rsplit_once('/')
|
||||
.and_then(|(_, suffix)| prices.get(suffix))
|
||||
})
|
||||
}
|
||||
|
||||
/// Record token usage from an LLM response via the task-local cost tracker.
|
||||
/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
|
||||
fn record_tool_loop_cost_usage(
|
||||
provider_name: &str,
|
||||
model: &str,
|
||||
usage: &crate::providers::traits::TokenUsage,
|
||||
) -> Option<(u64, f64)> {
|
||||
let input_tokens = usage.input_tokens.unwrap_or(0);
|
||||
let output_tokens = usage.output_tokens.unwrap_or(0);
|
||||
let total_tokens = input_tokens.saturating_add(output_tokens);
|
||||
if total_tokens == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten()?;
|
||||
let pricing = lookup_model_pricing(&ctx.prices, provider_name, model);
|
||||
let cost_usage = CostTokenUsage::new(
|
||||
model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
pricing.map_or(0.0, |entry| entry.input),
|
||||
pricing.map_or(0.0, |entry| entry.output),
|
||||
);
|
||||
|
||||
if pricing.is_none() {
|
||||
tracing::debug!(
|
||||
provider = provider_name,
|
||||
model,
|
||||
"Cost tracking recorded token usage with zero pricing (no pricing entry found)"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model,
|
||||
"Failed to record cost tracking usage: {error}"
|
||||
);
|
||||
}
|
||||
|
||||
Some((cost_usage.total_tokens, cost_usage.cost_usd))
|
||||
}
|
||||
|
||||
/// Check budget before an LLM call. Returns `None` when no cost tracking
|
||||
/// context is scoped (tests, delegate, CLI without cost config).
|
||||
pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
|
||||
TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|ctx| {
|
||||
ctx.tracker
|
||||
.check_budget(0.0)
|
||||
.unwrap_or(BudgetCheck::Allowed)
|
||||
})
|
||||
}
|
||||
|
||||
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
|
||||
const STREAM_CHUNK_MIN_CHARS: usize = 80;
|
||||
|
||||
@@ -465,7 +570,7 @@ async fn build_context(
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
@@ -2226,6 +2331,7 @@ pub(crate) async fn agent_turn(
|
||||
dedup_exempt_tools,
|
||||
activated_tools,
|
||||
model_switch_callback,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2535,6 +2641,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
model_switch_callback: Option<ModelSwitchCallback>,
|
||||
pacing: &crate::config::PacingConfig,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2543,6 +2650,14 @@ pub(crate) async fn run_tool_call_loop(
|
||||
};
|
||||
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let loop_started_at = Instant::now();
|
||||
let loop_ignore_tools: HashSet<&str> = pacing
|
||||
.loop_ignore_tools
|
||||
.iter()
|
||||
.map(String::as_str)
|
||||
.collect();
|
||||
let mut consecutive_identical_outputs: usize = 0;
|
||||
let mut last_tool_output_hash: Option<u64> = None;
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
@@ -2642,6 +2757,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
hooks.fire_llm_input(history, model).await;
|
||||
}
|
||||
|
||||
// Budget enforcement — block if limit exceeded (no-op when not scoped)
|
||||
if let Some(BudgetCheck::Exceeded {
|
||||
current_usd,
|
||||
limit_usd,
|
||||
period,
|
||||
}) = check_tool_loop_budget()
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"Budget exceeded: ${:.4} of ${:.2} {:?} limit. Cannot make further API calls until the budget resets.",
|
||||
current_usd, limit_usd, period
|
||||
));
|
||||
}
|
||||
|
||||
// Unified path via Provider::chat so provider-specific native tool logic
|
||||
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
|
||||
let request_tools = if use_native_tools {
|
||||
@@ -2659,13 +2787,43 @@ pub(crate) async fn run_tool_call_loop(
|
||||
temperature,
|
||||
);
|
||||
|
||||
let chat_result = if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
// Wrap the LLM call with an optional per-step timeout from pacing config.
|
||||
// This catches a truly hung model response without terminating the overall
|
||||
// task loop (the per-message budget handles that separately).
|
||||
let chat_result = match pacing.step_timeout_secs {
|
||||
Some(step_secs) if step_secs > 0 => {
|
||||
let step_timeout = Duration::from_secs(step_secs);
|
||||
if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = tokio::time::timeout(step_timeout, chat_future) => {
|
||||
match result {
|
||||
Ok(inner) => inner,
|
||||
Err(_) => anyhow::bail!(
|
||||
"LLM inference step timed out after {step_secs}s (step_timeout_secs)"
|
||||
),
|
||||
}
|
||||
},
|
||||
}
|
||||
} else {
|
||||
match tokio::time::timeout(step_timeout, chat_future).await {
|
||||
Ok(inner) => inner,
|
||||
Err(_) => anyhow::bail!(
|
||||
"LLM inference step timed out after {step_secs}s (step_timeout_secs)"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
};
|
||||
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||
@@ -2687,6 +2845,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
output_tokens: resp_output_tokens,
|
||||
});
|
||||
|
||||
// Record cost via task-local tracker (no-op when not scoped)
|
||||
let _ = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
|
||||
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
// Fall back to text-based parsing (XML tags, markdown blocks,
|
||||
@@ -3158,7 +3322,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
}
|
||||
|
||||
// Collect tool results and build per-tool output for loop detection.
|
||||
// Only non-ignored tool outputs contribute to the identical-output hash.
|
||||
let mut detection_relevant_output = String::new();
|
||||
for (tool_name, tool_call_id, outcome) in ordered_results.into_iter().flatten() {
|
||||
if !loop_ignore_tools.contains(tool_name.as_str()) {
|
||||
detection_relevant_output.push_str(&outcome.output);
|
||||
}
|
||||
individual_results.push((tool_call_id, outcome.output.clone()));
|
||||
let _ = writeln!(
|
||||
tool_results,
|
||||
@@ -3167,6 +3337,53 @@ pub(crate) async fn run_tool_call_loop(
|
||||
);
|
||||
}
|
||||
|
||||
// ── Time-gated loop detection ──────────────────────────
|
||||
// When pacing.loop_detection_min_elapsed_secs is set, identical-output
|
||||
// loop detection activates after the task has been running that long.
|
||||
// This avoids false-positive aborts on long-running browser/research
|
||||
// workflows while keeping aggressive protection for quick tasks.
|
||||
// When not configured, identical-output detection is disabled (preserving
|
||||
// existing behavior where only max_iterations prevents runaway loops).
|
||||
let loop_detection_active = match pacing.loop_detection_min_elapsed_secs {
|
||||
Some(min_secs) => loop_started_at.elapsed() >= Duration::from_secs(min_secs),
|
||||
None => false, // disabled when not configured (backwards compatible)
|
||||
};
|
||||
|
||||
if loop_detection_active && !detection_relevant_output.is_empty() {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
detection_relevant_output.hash(&mut hasher);
|
||||
let current_hash = hasher.finish();
|
||||
|
||||
if last_tool_output_hash == Some(current_hash) {
|
||||
consecutive_identical_outputs += 1;
|
||||
} else {
|
||||
consecutive_identical_outputs = 0;
|
||||
last_tool_output_hash = Some(current_hash);
|
||||
}
|
||||
|
||||
// Bail if we see 3+ consecutive identical tool outputs (clear runaway).
|
||||
if consecutive_identical_outputs >= 3 {
|
||||
runtime_trace::record_event(
|
||||
"tool_loop_identical_output_abort",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("identical tool output detected 3 consecutive times"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"consecutive_identical": consecutive_identical_outputs,
|
||||
}),
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Agent loop aborted: identical tool output detected {} consecutive times",
|
||||
consecutive_identical_outputs
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls + tool results to history.
|
||||
// Native mode: use JSON-structured messages so convert_messages() can
|
||||
// reconstruct proper OpenAI-format tool_calls and tool result messages.
|
||||
@@ -3716,6 +3933,7 @@ pub async fn run(
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&config.pacing,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3943,6 +4161,7 @@ pub async fn run(
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&config.pacing,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -4840,6 +5059,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -4890,6 +5110,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -4934,6 +5155,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -5064,6 +5286,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -5134,6 +5357,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("cron_add delivery defaults should be injected");
|
||||
@@ -5196,6 +5420,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("explicit delivery mode should be preserved");
|
||||
@@ -5253,6 +5478,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -5322,6 +5548,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("non-interactive shell should succeed for low-risk command");
|
||||
@@ -5382,6 +5609,7 @@ mod tests {
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
@@ -5462,6 +5690,7 @@ mod tests {
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
@@ -5519,6 +5748,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -5600,6 +5830,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("native tool-call text should be relayed through on_delta");
|
||||
@@ -6395,7 +6626,7 @@ Tail"#;
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
@@ -7585,6 +7816,7 @@ Let me check the result."#;
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
@@ -7662,4 +7894,215 @@ Let me check the result."#;
|
||||
let result = filter_by_allowed_tools(specs, Some(&allowed));
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
|
||||
// ── Cost tracking tests ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_records_usage_when_scoped() {
|
||||
use super::{
|
||||
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
|
||||
};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::cost::CostTracker;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let provider = ScriptedProvider {
|
||||
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
|
||||
text: Some("done".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(crate::providers::traits::TokenUsage {
|
||||
input_tokens: Some(1_000),
|
||||
output_tokens: Some(200),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}]))),
|
||||
capabilities: ProviderCapabilities::default(),
|
||||
};
|
||||
let observer = NoopObserver;
|
||||
let workspace = tempfile::TempDir::new().unwrap();
|
||||
let mut cost_config = crate::config::CostConfig {
|
||||
enabled: true,
|
||||
..crate::config::CostConfig::default()
|
||||
};
|
||||
cost_config.prices = HashMap::from([(
|
||||
"mock-model".to_string(),
|
||||
ModelPricing {
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
},
|
||||
)]);
|
||||
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
|
||||
let ctx = ToolLoopCostTrackingContext::new(
|
||||
Arc::clone(&tracker),
|
||||
Arc::new(cost_config.prices.clone()),
|
||||
);
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let result = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.scope(
|
||||
Some(ctx),
|
||||
run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should succeed");
|
||||
|
||||
assert_eq!(result, "done");
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 1);
|
||||
assert_eq!(summary.total_tokens, 1_200);
|
||||
assert!(summary.session_cost_usd > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_enforces_budget() {
|
||||
use super::{
|
||||
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
|
||||
};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::cost::CostTracker;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["should not reach this"]);
|
||||
let observer = NoopObserver;
|
||||
let workspace = tempfile::TempDir::new().unwrap();
|
||||
let cost_config = crate::config::CostConfig {
|
||||
enabled: true,
|
||||
daily_limit_usd: 0.001, // very low limit
|
||||
..crate::config::CostConfig::default()
|
||||
};
|
||||
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
|
||||
// Record a usage that already exceeds the limit
|
||||
tracker
|
||||
.record_usage(crate::cost::types::TokenUsage::new(
|
||||
"mock-model",
|
||||
100_000,
|
||||
50_000,
|
||||
1.0,
|
||||
1.0,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let ctx = ToolLoopCostTrackingContext::new(
|
||||
Arc::clone(&tracker),
|
||||
Arc::new(HashMap::from([(
|
||||
"mock-model".to_string(),
|
||||
ModelPricing {
|
||||
input: 1.0,
|
||||
output: 1.0,
|
||||
},
|
||||
)])),
|
||||
);
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let err = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.scope(
|
||||
Some(ctx),
|
||||
run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail with budget exceeded");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("Budget exceeded"),
|
||||
"error should mention budget: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_is_noop_without_scope() {
|
||||
use super::run_tool_call_loop;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
|
||||
// No TOOL_LOOP_COST_TRACKING_CONTEXT scoped — should run fine
|
||||
let provider = ScriptedProvider {
|
||||
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
|
||||
text: Some("ok".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(crate::providers::traits::TokenUsage {
|
||||
input_tokens: Some(500),
|
||||
output_tokens: Some(100),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}]))),
|
||||
capabilities: ProviderCapabilities::default(),
|
||||
};
|
||||
let observer = NoopObserver;
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("should succeed without cost scope");
|
||||
|
||||
assert_eq!(result, "ok");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,9 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit, session_id).await?;
|
||||
let entries = memory
|
||||
.recall(user_message, self.limit, session_id, None, None)
|
||||
.await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
@@ -102,6 +104,8 @@ mod tests {
|
||||
_query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if limit == 0 {
|
||||
return Ok(vec![]);
|
||||
@@ -163,6 +167,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(self.entries.as_ref().clone())
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ pub struct DingTalkChannel {
|
||||
/// Per-chat session webhooks for sending replies (chatID -> webhook URL).
|
||||
/// DingTalk provides a unique webhook URL with each incoming message.
|
||||
session_webhooks: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Response from DingTalk gateway connection registration.
|
||||
@@ -34,11 +36,18 @@ impl DingTalkChannel {
|
||||
client_secret,
|
||||
allowed_users,
|
||||
session_webhooks: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.dingtalk")
|
||||
crate::config::build_channel_proxy_client("channel.dingtalk", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
|
||||
+10
-1
@@ -18,6 +18,8 @@ pub struct DiscordChannel {
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@@ -35,11 +37,18 @@ impl DiscordChannel {
|
||||
listen_to_bots,
|
||||
mention_only,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.discord")
|
||||
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a Discord user ID is in the allowlist.
|
||||
|
||||
+16
-1
@@ -380,6 +380,8 @@ pub struct LarkChannel {
|
||||
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
|
||||
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
@@ -423,6 +425,7 @@ impl LarkChannel {
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||
tenant_token: Arc::new(RwLock::new(None)),
|
||||
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,6 +447,7 @@ impl LarkChannel {
|
||||
platform,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
@@ -461,6 +465,7 @@ impl LarkChannel {
|
||||
LarkPlatform::Lark,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
@@ -476,11 +481,15 @@ impl LarkChannel {
|
||||
LarkPlatform::Feishu,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client(self.platform.proxy_service_key())
|
||||
crate::config::build_channel_proxy_client(
|
||||
self.platform.proxy_service_key(),
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
fn channel_name(&self) -> &'static str {
|
||||
@@ -2113,6 +2122,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::default(),
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -2135,6 +2145,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -2169,6 +2180,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_config(&cfg);
|
||||
@@ -2193,6 +2205,7 @@ mod tests {
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_lark_config(&cfg);
|
||||
@@ -2214,6 +2227,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_feishu_config(&cfg);
|
||||
@@ -2386,6 +2400,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let ch_feishu = LarkChannel::from_feishu_config(&feishu_cfg);
|
||||
assert_eq!(
|
||||
|
||||
@@ -17,6 +17,8 @@ pub struct MattermostChannel {
|
||||
mention_only: bool,
|
||||
/// Handle for the background typing-indicator loop (aborted on stop_typing).
|
||||
typing_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl MattermostChannel {
|
||||
@@ -38,11 +40,18 @@ impl MattermostChannel {
|
||||
thread_replies,
|
||||
mention_only,
|
||||
typing_handle: Mutex::new(None),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.mattermost")
|
||||
crate::config::build_channel_proxy_client("channel.mattermost", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a user ID is in the allowlist.
|
||||
|
||||
+218
-44
@@ -222,9 +222,21 @@ fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||||
fn channel_message_timeout_budget_secs(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
) -> u64 {
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs,
|
||||
max_tool_iterations,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP,
|
||||
)
|
||||
}
|
||||
|
||||
fn channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
scale_cap: u64,
|
||||
) -> u64 {
|
||||
let iterations = max_tool_iterations.max(1) as u64;
|
||||
let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let scale = iterations.min(scale_cap);
|
||||
message_timeout_secs.saturating_mul(scale)
|
||||
}
|
||||
|
||||
@@ -313,6 +325,12 @@ impl InterruptOnNewMessageConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelCostTrackingState {
|
||||
tracker: Arc<crate::cost::CostTracker>,
|
||||
prices: Arc<HashMap<String, crate::config::schema::ModelPricing>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelRuntimeContext {
|
||||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
@@ -355,6 +373,8 @@ struct ChannelRuntimeContext {
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
cost_tracking: Option<ChannelCostTrackingState>,
|
||||
pacing: crate::config::PacingConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1511,7 +1531,7 @@ async fn build_memory_context(
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
let mut included = 0usize;
|
||||
let mut used_chars = 0usize;
|
||||
|
||||
@@ -2395,8 +2415,18 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let model_switch_callback = get_model_switch_state();
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let scale_cap = ctx
|
||||
.pacing
|
||||
.message_timeout_scale_max
|
||||
.unwrap_or(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let timeout_budget_secs = channel_message_timeout_budget_secs_with_cap(
|
||||
ctx.message_timeout_secs,
|
||||
ctx.max_tool_iterations,
|
||||
scale_cap,
|
||||
);
|
||||
let cost_tracking_context = ctx.cost_tracking.clone().map(|state| {
|
||||
crate::agent::loop_::ToolLoopCostTrackingContext::new(state.tracker, state.prices)
|
||||
});
|
||||
let llm_call_start = Instant::now();
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64;
|
||||
@@ -2406,6 +2436,8 @@ async fn process_channel_message(
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT.scope(
|
||||
cost_tracking_context.clone(),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
@@ -2433,6 +2465,8 @@ async fn process_channel_message(
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&ctx.pacing,
|
||||
),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -3691,7 +3725,8 @@ fn collect_configured_channels(
|
||||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms)
|
||||
.with_transcription(config.transcription.clone())
|
||||
.with_tts(config.tts.clone())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(tg.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3699,13 +3734,16 @@ fn collect_configured_channels(
|
||||
if let Some(ref dc) = config.channels_config.discord {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Discord",
|
||||
channel: Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)),
|
||||
channel: Arc::new(
|
||||
DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_proxy_url(dc.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3722,7 +3760,8 @@ fn collect_configured_channels(
|
||||
)
|
||||
.with_thread_replies(sl.thread_replies.unwrap_or(true))
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(sl.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3730,14 +3769,17 @@ fn collect_configured_channels(
|
||||
if let Some(ref mm) = config.channels_config.mattermost {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Mattermost",
|
||||
channel: Arc::new(MattermostChannel::new(
|
||||
mm.url.clone(),
|
||||
mm.bot_token.clone(),
|
||||
mm.channel_id.clone(),
|
||||
mm.allowed_users.clone(),
|
||||
mm.thread_replies.unwrap_or(true),
|
||||
mm.mention_only.unwrap_or(false),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
MattermostChannel::new(
|
||||
mm.url.clone(),
|
||||
mm.bot_token.clone(),
|
||||
mm.channel_id.clone(),
|
||||
mm.allowed_users.clone(),
|
||||
mm.thread_replies.unwrap_or(true),
|
||||
mm.mention_only.unwrap_or(false),
|
||||
)
|
||||
.with_proxy_url(mm.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3775,14 +3817,17 @@ fn collect_configured_channels(
|
||||
if let Some(ref sig) = config.channels_config.signal {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Signal",
|
||||
channel: Arc::new(SignalChannel::new(
|
||||
sig.http_url.clone(),
|
||||
sig.account.clone(),
|
||||
sig.group_id.clone(),
|
||||
sig.allowed_from.clone(),
|
||||
sig.ignore_attachments,
|
||||
sig.ignore_stories,
|
||||
)),
|
||||
channel: Arc::new(
|
||||
SignalChannel::new(
|
||||
sig.http_url.clone(),
|
||||
sig.account.clone(),
|
||||
sig.group_id.clone(),
|
||||
sig.allowed_from.clone(),
|
||||
sig.ignore_attachments,
|
||||
sig.ignore_stories,
|
||||
)
|
||||
.with_proxy_url(sig.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3799,12 +3844,15 @@ fn collect_configured_channels(
|
||||
if wa.is_cloud_config() {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WhatsApp",
|
||||
channel: Arc::new(WhatsAppChannel::new(
|
||||
wa.access_token.clone().unwrap_or_default(),
|
||||
wa.phone_number_id.clone().unwrap_or_default(),
|
||||
wa.verify_token.clone().unwrap_or_default(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
WhatsAppChannel::new(
|
||||
wa.access_token.clone().unwrap_or_default(),
|
||||
wa.phone_number_id.clone().unwrap_or_default(),
|
||||
wa.verify_token.clone().unwrap_or_default(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)
|
||||
.with_proxy_url(wa.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
} else {
|
||||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||||
@@ -3861,11 +3909,12 @@ fn collect_configured_channels(
|
||||
if let Some(ref wati_cfg) = config.channels_config.wati {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WATI",
|
||||
channel: Arc::new(WatiChannel::new(
|
||||
channel: Arc::new(WatiChannel::new_with_proxy(
|
||||
wati_cfg.api_token.clone(),
|
||||
wati_cfg.api_url.clone(),
|
||||
wati_cfg.tenant_id.clone(),
|
||||
wati_cfg.allowed_numbers.clone(),
|
||||
wati_cfg.proxy_url.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
@@ -3873,10 +3922,11 @@ fn collect_configured_channels(
|
||||
if let Some(ref nc) = config.channels_config.nextcloud_talk {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Nextcloud Talk",
|
||||
channel: Arc::new(NextcloudTalkChannel::new(
|
||||
channel: Arc::new(NextcloudTalkChannel::new_with_proxy(
|
||||
nc.base_url.clone(),
|
||||
nc.app_token.clone(),
|
||||
nc.allowed_users.clone(),
|
||||
nc.proxy_url.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
@@ -3948,11 +3998,14 @@ fn collect_configured_channels(
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "DingTalk",
|
||||
channel: Arc::new(DingTalkChannel::new(
|
||||
dt.client_id.clone(),
|
||||
dt.client_secret.clone(),
|
||||
dt.allowed_users.clone(),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
DingTalkChannel::new(
|
||||
dt.client_id.clone(),
|
||||
dt.client_secret.clone(),
|
||||
dt.allowed_users.clone(),
|
||||
)
|
||||
.with_proxy_url(dt.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3965,7 +4018,8 @@ fn collect_configured_channels(
|
||||
qq.app_secret.clone(),
|
||||
qq.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(qq.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -4600,6 +4654,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
activated_tools: ch_activated_handle,
|
||||
cost_tracking: crate::cost::CostTracker::get_or_init_global(
|
||||
config.cost.clone(),
|
||||
&config.workspace_dir,
|
||||
)
|
||||
.map(|tracker| ChannelCostTrackingState {
|
||||
tracker,
|
||||
prices: Arc::new(config.cost.prices.clone()),
|
||||
}),
|
||||
pacing: config.pacing.clone(),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4696,6 +4759,49 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_message_timeout_budget_with_custom_scale_cap() {
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 8, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 20, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 1),
|
||||
300
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_config_defaults_preserve_existing_behavior() {
|
||||
let pacing = crate::config::PacingConfig::default();
|
||||
assert!(pacing.step_timeout_secs.is_none());
|
||||
assert!(pacing.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(pacing.loop_ignore_tools.is_empty());
|
||||
assert!(pacing.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_message_timeout_scale_max_overrides_default_cap() {
|
||||
// Custom cap of 8 scales budget proportionally
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 8),
|
||||
300 * 8
|
||||
);
|
||||
// Default cap produces the standard behavior
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
300,
|
||||
10,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
),
|
||||
300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_window_overflow_error_detector_matches_known_messages() {
|
||||
let overflow_err = anyhow::anyhow!(
|
||||
@@ -4899,6 +5005,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -5014,6 +5122,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -5085,6 +5195,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -5175,6 +5287,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(
|
||||
@@ -5715,6 +5829,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5795,6 +5911,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5889,6 +6007,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5968,6 +6088,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6057,6 +6179,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6167,6 +6291,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6258,6 +6384,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6364,6 +6492,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6455,6 +6585,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6536,6 +6668,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6584,6 +6718,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -6636,6 +6772,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(vec![crate::memory::MemoryEntry {
|
||||
id: "entry-1".to_string(),
|
||||
@@ -6728,6 +6866,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -6829,6 +6969,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6944,7 +7086,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7058,6 +7202,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7153,6 +7299,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7232,6 +7380,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7884,7 +8034,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
@@ -7997,6 +8147,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8127,6 +8279,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8297,6 +8451,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8404,6 +8560,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8721,6 +8879,7 @@ This is an example JSON object for profile settings."#;
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let channels = collect_configured_channels(&config, "test");
|
||||
@@ -8974,6 +9133,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -9060,6 +9221,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9221,6 +9384,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9331,6 +9496,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9433,6 +9600,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9555,6 +9724,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9613,6 +9784,7 @@ This is an example JSON object for profile settings."#;
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
match build_channel_by_id(&config, "telegram") {
|
||||
Ok(channel) => assert_eq!(channel.name(), "telegram"),
|
||||
@@ -9814,6 +9986,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
|
||||
@@ -17,11 +17,23 @@ pub struct NextcloudTalkChannel {
|
||||
|
||||
impl NextcloudTalkChannel {
|
||||
pub fn new(base_url: String, app_token: String, allowed_users: Vec<String>) -> Self {
|
||||
Self::new_with_proxy(base_url, app_token, allowed_users, None)
|
||||
}
|
||||
|
||||
pub fn new_with_proxy(
|
||||
base_url: String,
|
||||
app_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
proxy_url: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
app_token,
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
client: crate::config::build_channel_proxy_client(
|
||||
"channel.nextcloud_talk",
|
||||
proxy_url.as_deref(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+10
-1
@@ -285,6 +285,8 @@ pub struct QQChannel {
|
||||
upload_cache: Arc<RwLock<HashMap<String, UploadCacheEntry>>>,
|
||||
/// Passive reply tracker for QQ API rate limiting.
|
||||
reply_tracker: Arc<RwLock<HashMap<String, ReplyRecord>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl QQChannel {
|
||||
@@ -298,6 +300,7 @@ impl QQChannel {
|
||||
workspace_dir: None,
|
||||
upload_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
reply_tracker: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,8 +310,14 @@ impl QQChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.qq")
|
||||
crate::config::build_channel_proxy_client("channel.qq", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
|
||||
+14
-1
@@ -28,6 +28,8 @@ pub struct SignalChannel {
|
||||
allowed_from: Vec<String>,
|
||||
ignore_attachments: bool,
|
||||
ignore_stories: bool,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
// ── signal-cli SSE event JSON shapes ────────────────────────────
|
||||
@@ -87,12 +89,23 @@ impl SignalChannel {
|
||||
allowed_from,
|
||||
ignore_attachments,
|
||||
ignore_stories,
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
let builder = Client::builder().connect_timeout(Duration::from_secs(10));
|
||||
let builder = crate::config::apply_runtime_proxy_to_builder(builder, "channel.signal");
|
||||
let builder = crate::config::apply_channel_proxy_to_builder(
|
||||
builder,
|
||||
"channel.signal",
|
||||
self.proxy_url.as_deref(),
|
||||
);
|
||||
builder.build().expect("Signal HTTP client should build")
|
||||
}
|
||||
|
||||
|
||||
+78
-2
@@ -32,6 +32,8 @@ pub struct SlackChannel {
|
||||
workspace_dir: Option<PathBuf>,
|
||||
/// Maps channel_id -> thread_ts for active assistant threads (used for status indicators).
|
||||
active_assistant_thread: Mutex<HashMap<String, String>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@@ -46,6 +48,7 @@ const SLACK_ATTACHMENT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
||||
const SLACK_ATTACHMENT_IMAGE_INLINE_FALLBACK_MAX_BYTES: usize = 512 * 1024;
|
||||
const SLACK_ATTACHMENT_TEXT_DOWNLOAD_MAX_BYTES: usize = 256 * 1024;
|
||||
const SLACK_ATTACHMENT_TEXT_INLINE_MAX_CHARS: usize = 12_000;
|
||||
const SLACK_MARKDOWN_BLOCK_MAX_CHARS: usize = 12_000;
|
||||
const SLACK_ATTACHMENT_FILENAME_MAX_CHARS: usize = 128;
|
||||
const SLACK_USER_CACHE_MAX_ENTRIES: usize = 1000;
|
||||
const SLACK_ATTACHMENT_SAVE_SUBDIR: &str = "slack_files";
|
||||
@@ -121,6 +124,7 @@ impl SlackChannel {
|
||||
user_display_name_cache: Mutex::new(HashMap::new()),
|
||||
workspace_dir: None,
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,8 +152,19 @@ impl SlackChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("channel.slack", 30, 10)
|
||||
crate::config::build_channel_proxy_client_with_timeouts(
|
||||
"channel.slack",
|
||||
self.proxy_url.as_deref(),
|
||||
30,
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if a Slack user ID is in the allowlist.
|
||||
@@ -804,12 +819,13 @@ impl SlackChannel {
|
||||
}
|
||||
|
||||
fn slack_media_http_client_no_redirect(&self) -> anyhow::Result<reqwest::Client> {
|
||||
let builder = crate::config::apply_runtime_proxy_to_builder(
|
||||
let builder = crate::config::apply_channel_proxy_to_builder(
|
||||
reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect_timeout(Duration::from_secs(10)),
|
||||
"channel.slack",
|
||||
self.proxy_url.as_deref(),
|
||||
);
|
||||
builder
|
||||
.build()
|
||||
@@ -2272,6 +2288,14 @@ impl Channel for SlackChannel {
|
||||
"text": message.content
|
||||
});
|
||||
|
||||
// Use Slack's native markdown block for rich formatting when content fits.
|
||||
if message.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": message.content
|
||||
}]);
|
||||
}
|
||||
|
||||
if let Some(ts) = self.outbound_thread_ts(message) {
|
||||
body["thread_ts"] = serde_json::json!(ts);
|
||||
}
|
||||
@@ -3630,6 +3654,58 @@ mod tests {
|
||||
assert_ne!(key1, key2, "session key should differ per thread");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_send_uses_markdown_blocks() {
|
||||
let msg = SendMessage::new("**bold** and _italic_", "C123");
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
|
||||
// Build the same JSON body that send() would construct.
|
||||
let mut body = serde_json::json!({
|
||||
"channel": msg.recipient,
|
||||
"text": msg.content
|
||||
});
|
||||
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": msg.content
|
||||
}]);
|
||||
}
|
||||
|
||||
// Verify blocks are present with correct structure.
|
||||
let blocks = body["blocks"]
|
||||
.as_array()
|
||||
.expect("blocks should be an array");
|
||||
assert_eq!(blocks.len(), 1);
|
||||
assert_eq!(blocks[0]["type"], "markdown");
|
||||
assert_eq!(blocks[0]["text"], msg.content);
|
||||
// text field kept as plaintext fallback.
|
||||
assert_eq!(body["text"], msg.content);
|
||||
// Suppress unused variable warning.
|
||||
let _ = ch.name();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_send_skips_markdown_blocks_for_long_content() {
|
||||
let long_content = "x".repeat(SLACK_MARKDOWN_BLOCK_MAX_CHARS + 1);
|
||||
let msg = SendMessage::new(long_content.clone(), "C123");
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"channel": msg.recipient,
|
||||
"text": msg.content
|
||||
});
|
||||
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": msg.content
|
||||
}]);
|
||||
}
|
||||
|
||||
assert!(
|
||||
body.get("blocks").is_none(),
|
||||
"blocks should not be set for oversized content"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_requires_thread_context() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
|
||||
@@ -337,6 +337,8 @@ pub struct TelegramChannel {
|
||||
voice_chats: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
||||
pending_voice:
|
||||
Arc<std::sync::Mutex<std::collections::HashMap<String, (String, std::time::Instant)>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -379,6 +381,7 @@ impl TelegramChannel {
|
||||
tts_config: None,
|
||||
voice_chats: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
pending_voice: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,6 +391,12 @@ impl TelegramChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure workspace directory for saving downloaded attachments.
|
||||
pub fn with_workspace_dir(mut self, dir: std::path::PathBuf) -> Self {
|
||||
self.workspace_dir = Some(dir);
|
||||
@@ -478,7 +487,7 @@ impl TelegramChannel {
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.telegram")
|
||||
crate::config::build_channel_proxy_client("channel.telegram", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn normalize_identity(value: &str) -> String {
|
||||
|
||||
+389
-18
@@ -80,17 +80,10 @@ fn resolve_transcription_api_key(config: &TranscriptionConfig) -> Result<String>
|
||||
);
|
||||
}
|
||||
|
||||
/// Validate audio data and resolve MIME type from file name.
|
||||
/// Resolve MIME type and normalize filename from extension.
|
||||
///
|
||||
/// Returns `(normalized_filename, mime_type)` on success.
|
||||
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
|
||||
if audio_data.len() > MAX_AUDIO_BYTES {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
|
||||
audio_data.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// No size check — callers enforce their own limits.
|
||||
fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
|
||||
let normalized_name = normalize_audio_filename(file_name);
|
||||
let extension = normalized_name
|
||||
.rsplit_once('.')
|
||||
@@ -98,13 +91,26 @@ fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'stati
|
||||
.unwrap_or("");
|
||||
let mime = mime_for_audio(extension).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Unsupported audio format '.{extension}' — accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
|
||||
"Unsupported audio format '.{extension}' — \
|
||||
accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok((normalized_name, mime))
|
||||
}
|
||||
|
||||
/// Validate audio data and resolve MIME type from file name.
|
||||
///
|
||||
/// Enforces the 25 MB cloud API cap. Returns `(normalized_filename, mime_type)` on success.
|
||||
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
|
||||
if audio_data.len() > MAX_AUDIO_BYTES {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
|
||||
audio_data.len()
|
||||
);
|
||||
}
|
||||
resolve_audio_format(file_name)
|
||||
}
|
||||
|
||||
// ── TranscriptionProvider trait ─────────────────────────────────
|
||||
|
||||
/// Trait for speech-to-text provider implementations.
|
||||
@@ -586,21 +592,120 @@ impl TranscriptionProvider for GoogleSttProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider ────────────────────────────────────────
|
||||
|
||||
/// Self-hosted faster-whisper-compatible STT provider.
|
||||
///
|
||||
/// POSTs audio as `multipart/form-data` (field name `file`) to a configurable
|
||||
/// HTTP endpoint (e.g. faster-whisper on GEX44 over WireGuard). The endpoint
|
||||
/// must return `{"text": "..."}`. No cloud API key required. Size limit is
|
||||
/// configurable — not constrained by the 25 MB cloud API cap.
|
||||
pub struct LocalWhisperProvider {
|
||||
url: String,
|
||||
bearer_token: String,
|
||||
max_audio_bytes: usize,
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl LocalWhisperProvider {
|
||||
/// Build from config. Fails if `url` or `bearer_token` is empty, if `url`
|
||||
/// is not a valid HTTP/HTTPS URL (scheme must be `http` or `https`), if
|
||||
/// `max_audio_bytes` is zero, or if `timeout_secs` is zero.
|
||||
pub fn from_config(config: &crate::config::LocalWhisperConfig) -> Result<Self> {
|
||||
let url = config.url.trim().to_string();
|
||||
anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
|
||||
let parsed = url
|
||||
.parse::<reqwest::Url>()
|
||||
.with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
|
||||
anyhow::ensure!(
|
||||
matches!(parsed.scheme(), "http" | "https"),
|
||||
"local_whisper: `url` must use http or https scheme, got {:?}",
|
||||
parsed.scheme()
|
||||
);
|
||||
|
||||
let bearer_token = config.bearer_token.trim().to_string();
|
||||
anyhow::ensure!(
|
||||
!bearer_token.is_empty(),
|
||||
"local_whisper: `bearer_token` must not be empty"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
config.max_audio_bytes > 0,
|
||||
"local_whisper: `max_audio_bytes` must be greater than zero"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
config.timeout_secs > 0,
|
||||
"local_whisper: `timeout_secs` must be greater than zero"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
url,
|
||||
bearer_token,
|
||||
max_audio_bytes: config.max_audio_bytes,
|
||||
timeout_secs: config.timeout_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for LocalWhisperProvider {
|
||||
fn name(&self) -> &str {
|
||||
"local_whisper"
|
||||
}
|
||||
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
if audio_data.len() > self.max_audio_bytes {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, local_whisper max {})",
|
||||
audio_data.len(),
|
||||
self.max_audio_bytes
|
||||
);
|
||||
}
|
||||
|
||||
let (normalized_name, mime) = resolve_audio_format(file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.local_whisper");
|
||||
|
||||
// to_vec() clones the buffer for the multipart payload; peak memory per
|
||||
// call is ~2× max_audio_bytes. TODO: replace with streaming upload once
|
||||
// reqwest supports body streaming in multipart parts.
|
||||
let file_part = Part::bytes(audio_data.to_vec())
|
||||
.file_name(normalized_name)
|
||||
.mime_str(mime)?;
|
||||
|
||||
let resp = client
|
||||
.post(&self.url)
|
||||
.bearer_auth(&self.bearer_token)
|
||||
.multipart(Form::new().part("file", file_part))
|
||||
.timeout(std::time::Duration::from_secs(self.timeout_secs))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send audio to local Whisper endpoint")?;
|
||||
|
||||
parse_whisper_response(resp).await
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared response parsing ─────────────────────────────────────
|
||||
|
||||
/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`).
|
||||
/// Parse a faster-whisper-compatible JSON response (`{ "text": "..." }`).
|
||||
///
|
||||
/// Checks HTTP status before attempting JSON parsing so that non-JSON error
|
||||
/// bodies (plain text, HTML, empty 5xx) produce a readable status error
|
||||
/// rather than a confusing "Failed to parse transcription response".
|
||||
async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
bail!("Transcription API error ({}): {}", status, body.trim());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse transcription response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
|
||||
bail!("Transcription API error ({}): {}", status, error_msg);
|
||||
}
|
||||
|
||||
let text = body["text"]
|
||||
.as_str()
|
||||
.context("Transcription response missing 'text' field")?
|
||||
@@ -657,6 +762,17 @@ impl TranscriptionManager {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref local_cfg) = config.local_whisper {
|
||||
match LocalWhisperProvider::from_config(local_cfg) {
|
||||
Ok(p) => {
|
||||
providers.insert("local_whisper".to_string(), Box::new(p));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("local_whisper config invalid, provider skipped: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let default_provider = config.default_provider.clone();
|
||||
|
||||
if config.enabled && !providers.contains_key(&default_provider) {
|
||||
@@ -1036,5 +1152,260 @@ mod tests {
|
||||
assert!(config.deepgram.is_none());
|
||||
assert!(config.assemblyai.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.local_whisper.is_none());
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ──
|
||||
|
||||
fn local_whisper_config(url: &str) -> crate::config::LocalWhisperConfig {
|
||||
crate::config::LocalWhisperConfig {
|
||||
url: url.to_string(),
|
||||
bearer_token: "test-token".to_string(),
|
||||
max_audio_bytes: 10 * 1024 * 1024,
|
||||
timeout_secs: 30,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_empty_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = String::new();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("`url` must not be empty"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_invalid_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = "not-a-url".to_string();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(err.to_string().contains("invalid `url`"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_non_http_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(err.to_string().contains("http or https"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_empty_bearer_token() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.bearer_token = String::new();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("`bearer_token` must not be empty"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_zero_max_audio_bytes() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.max_audio_bytes = 0;
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("`max_audio_bytes` must be greater than zero"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_zero_timeout() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.timeout_secs = 0;
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("`timeout_secs` must be greater than zero"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_registered_when_config_present() {
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.local_whisper = Some(local_whisper_config("http://127.0.0.1:9999/v1/transcribe"));
|
||||
config.default_provider = "local_whisper".to_string();
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert!(
|
||||
manager.available_providers().contains(&"local_whisper"),
|
||||
"expected local_whisper in {:?}",
|
||||
manager.available_providers()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_misconfigured_section_fails_manager_construction() {
|
||||
// A misconfigured local_whisper section logs a warning and skips
|
||||
// registration. When local_whisper is also the default_provider and
|
||||
// transcription is enabled, the safety net in TranscriptionManager
|
||||
// surfaces the error: "not configured".
|
||||
let mut config = TranscriptionConfig::default();
|
||||
let mut bad_cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
bad_cfg.bearer_token = String::new();
|
||||
config.local_whisper = Some(bad_cfg);
|
||||
config.enabled = true;
|
||||
config.default_provider = "local_whisper".to_string();
|
||||
|
||||
let err = TranscriptionManager::new(&config).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("not configured"),
|
||||
"expected 'not configured' from manager safety net, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_still_enforces_25mb_cap() {
|
||||
// Regression: extracting resolve_audio_format() must not weaken validate_audio().
|
||||
let at_limit = vec![0u8; MAX_AUDIO_BYTES];
|
||||
assert!(validate_audio(&at_limit, "test.ogg").is_ok());
|
||||
let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
|
||||
let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
|
||||
assert!(err.to_string().contains("too large"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_rejects_oversized_audio() {
|
||||
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
let big = vec![0u8; cfg.max_audio_bytes + 1];
|
||||
let err = provider.transcribe(&big, "voice.ogg").await.unwrap_err();
|
||||
assert!(err.to_string().contains("too large"), "got: {err}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_rejects_unsupported_format() {
|
||||
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
let data = vec![0u8; 100];
|
||||
let err = provider.transcribe(&data, "voice.aiff").await.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("Unsupported audio format"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider HTTP mock tests ────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_returns_text_from_response() {
|
||||
use wiremock::matchers::{header_exists, method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.and(header_exists("authorization"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.set_body_json(serde_json::json!({"text": "hello world"})),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let result = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_sends_bearer_auth_header() {
|
||||
use wiremock::matchers::{header, method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.and(header("authorization", "Bearer test-token"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let result = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "auth ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_propagates_http_error() {
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(503).set_body_json(
|
||||
serde_json::json!({"error": {"message": "service unavailable"}}),
|
||||
),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let err = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("503") || err.to_string().contains("service unavailable"),
|
||||
"expected HTTP error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_propagates_non_json_http_error() {
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(502)
|
||||
.set_body_string("Bad Gateway")
|
||||
.insert_header("content-type", "text/plain"),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let err = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("502"), "got: {err}");
|
||||
assert!(
|
||||
err.to_string().contains("Bad Gateway"),
|
||||
"expected plain-text body in error, got: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+11
-1
@@ -22,13 +22,23 @@ impl WatiChannel {
|
||||
api_url: String,
|
||||
tenant_id: Option<String>,
|
||||
allowed_numbers: Vec<String>,
|
||||
) -> Self {
|
||||
Self::new_with_proxy(api_token, api_url, tenant_id, allowed_numbers, None)
|
||||
}
|
||||
|
||||
pub fn new_with_proxy(
|
||||
api_token: String,
|
||||
api_url: String,
|
||||
tenant_id: Option<String>,
|
||||
allowed_numbers: Vec<String>,
|
||||
proxy_url: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_token,
|
||||
api_url,
|
||||
tenant_id,
|
||||
allowed_numbers,
|
||||
client: crate::config::build_runtime_proxy_client("channel.wati"),
|
||||
client: crate::config::build_channel_proxy_client("channel.wati", proxy_url.as_deref()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ pub struct WhatsAppChannel {
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl WhatsAppChannel {
|
||||
@@ -41,11 +43,18 @@ impl WhatsAppChannel {
|
||||
endpoint_id,
|
||||
verify_token,
|
||||
allowed_numbers,
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.whatsapp")
|
||||
crate::config::build_channel_proxy_client("channel.whatsapp", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a phone number is allowed (E.164 format: +1234567890)
|
||||
|
||||
@@ -249,7 +249,7 @@ async fn check_memory_roundtrip(config: &crate::config::Config) -> CheckResult {
|
||||
return CheckResult::fail("memory", format!("write failed: {e}"));
|
||||
}
|
||||
|
||||
match mem.recall(test_key, 1, None).await {
|
||||
match mem.recall(test_key, 1, None, None, None).await {
|
||||
Ok(entries) if !entries.is_empty() => {
|
||||
let _ = mem.forget(test_key).await;
|
||||
CheckResult::pass("memory", "write/read/delete round-trip OK")
|
||||
|
||||
+28
-22
@@ -4,31 +4,32 @@ pub mod workspace;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
apply_channel_proxy_to_builder, apply_runtime_proxy_to_builder, build_channel_proxy_client,
|
||||
build_channel_proxy_client_with_timeouts, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, CloudOpsConfig, ComposioConfig, Config, ConversationalAiConfig, CostConfig,
|
||||
CronConfig, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig,
|
||||
DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig,
|
||||
EstopConfig, FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig,
|
||||
GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig, HardwareConfig, HardwareTransport,
|
||||
HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig,
|
||||
ImageProviderDalleConfig, ImageProviderFluxConfig, ImageProviderImagenConfig,
|
||||
ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig, LarkConfig, LinkedInConfig,
|
||||
LinkedInContentConfig, LinkedInImageConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
|
||||
WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
|
||||
DEFAULT_GWS_SERVICES,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
|
||||
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
|
||||
ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
|
||||
LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
|
||||
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
|
||||
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
|
||||
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
@@ -58,6 +59,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let discord = DiscordConfig {
|
||||
@@ -67,6 +69,7 @@ mod tests {
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let lark = LarkConfig {
|
||||
@@ -79,6 +82,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let feishu = FeishuConfig {
|
||||
app_id: "app-id".into(),
|
||||
@@ -88,6 +92,7 @@ mod tests {
|
||||
allowed_users: vec![],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let nextcloud_talk = NextcloudTalkConfig {
|
||||
@@ -95,6 +100,7 @@ mod tests {
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: None,
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
assert_eq!(telegram.allowed_users.len(), 1);
|
||||
|
||||
+422
-2
@@ -165,6 +165,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agent: AgentConfig,
|
||||
|
||||
/// Pacing controls for slow/local LLM workloads (`[pacing]`).
|
||||
#[serde(default)]
|
||||
pub pacing: PacingConfig,
|
||||
|
||||
/// Skills loading and community repository behavior (`[skills]`).
|
||||
#[serde(default)]
|
||||
pub skills: SkillsConfig,
|
||||
@@ -371,6 +375,10 @@ pub struct Config {
|
||||
/// Verifiable Intent (VI) credential verification and issuance (`[verifiable_intent]`).
|
||||
#[serde(default)]
|
||||
pub verifiable_intent: VerifiableIntentConfig,
|
||||
|
||||
/// Claude Code tool configuration (`[claude_code]`).
|
||||
#[serde(default)]
|
||||
pub claude_code: ClaudeCodeConfig,
|
||||
}
|
||||
|
||||
/// Multi-client workspace isolation configuration.
|
||||
@@ -515,6 +523,10 @@ pub struct DelegateAgentConfig {
|
||||
/// When `None`, falls back to `[delegate].agentic_timeout_secs` (default: 300).
|
||||
#[serde(default)]
|
||||
pub agentic_timeout_secs: Option<u64>,
|
||||
/// Optional skills directory path (relative to workspace root) for scoped skill loading.
|
||||
/// When unset or empty, the sub-agent falls back to the default workspace `skills/` directory.
|
||||
#[serde(default)]
|
||||
pub skills_directory: Option<String>,
|
||||
}
|
||||
|
||||
fn default_delegate_timeout_secs() -> u64 {
|
||||
@@ -784,6 +796,9 @@ pub struct TranscriptionConfig {
|
||||
/// Google Cloud Speech-to-Text provider configuration.
|
||||
#[serde(default)]
|
||||
pub google: Option<GoogleSttConfig>,
|
||||
/// Local/self-hosted Whisper-compatible STT provider.
|
||||
#[serde(default)]
|
||||
pub local_whisper: Option<LocalWhisperConfig>,
|
||||
}
|
||||
|
||||
impl Default for TranscriptionConfig {
|
||||
@@ -801,6 +816,7 @@ impl Default for TranscriptionConfig {
|
||||
deepgram: None,
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1169,6 +1185,35 @@ pub struct GoogleSttConfig {
|
||||
pub language_code: String,
|
||||
}
|
||||
|
||||
/// Local/self-hosted Whisper-compatible STT endpoint (`[transcription.local_whisper]`).
|
||||
///
|
||||
/// Audio is sent over WireGuard; never leaves the platform perimeter.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct LocalWhisperConfig {
|
||||
/// HTTP or HTTPS endpoint URL, e.g. `"http://10.10.0.1:8001/v1/transcribe"`.
|
||||
pub url: String,
|
||||
/// Bearer token for endpoint authentication.
|
||||
pub bearer_token: String,
|
||||
/// Maximum audio file size in bytes accepted by this endpoint.
|
||||
/// Defaults to 25 MB — matching the cloud API cap for a safe out-of-the-box
|
||||
/// experience. Self-hosted endpoints can accept much larger files; raise this
|
||||
/// as needed, but note that each transcription call clones the audio buffer
|
||||
/// into a multipart payload, so peak memory per request is ~2× this value.
|
||||
#[serde(default = "default_local_whisper_max_audio_bytes")]
|
||||
pub max_audio_bytes: usize,
|
||||
/// Request timeout in seconds. Defaults to 300 (large files on local GPU).
|
||||
#[serde(default = "default_local_whisper_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_local_whisper_max_audio_bytes() -> usize {
|
||||
25 * 1024 * 1024
|
||||
}
|
||||
|
||||
fn default_local_whisper_timeout_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
/// Agent orchestration configuration (`[agent]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentConfig {
|
||||
@@ -1236,6 +1281,43 @@ impl Default for AgentConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pacing ────────────────────────────────────────────────────────
|
||||
|
||||
/// Pacing controls for slow/local LLM workloads (`[pacing]` section).
|
||||
///
|
||||
/// All fields are optional and default to values that preserve existing
|
||||
/// behavior. When set, they extend — not replace — the existing timeout
|
||||
/// and loop-detection subsystems.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PacingConfig {
|
||||
/// Per-step timeout in seconds: the maximum time allowed for a single
|
||||
/// LLM inference turn, independent of the total message budget.
|
||||
/// `None` means no per-step timeout (existing behavior).
|
||||
#[serde(default)]
|
||||
pub step_timeout_secs: Option<u64>,
|
||||
|
||||
/// Minimum elapsed seconds before loop detection activates.
|
||||
/// Tasks completing under this threshold get aggressive loop protection;
|
||||
/// longer-running tasks receive a grace period before the detector starts
|
||||
/// counting. `None` means loop detection is always active (existing behavior).
|
||||
#[serde(default)]
|
||||
pub loop_detection_min_elapsed_secs: Option<u64>,
|
||||
|
||||
/// Tool names excluded from identical-output / alternating-pattern loop
|
||||
/// detection. Useful for browser workflows where `browser_screenshot`
|
||||
/// structurally resembles a loop even when making progress.
|
||||
#[serde(default)]
|
||||
pub loop_ignore_tools: Vec<String>,
|
||||
|
||||
/// Override for the hardcoded timeout scaling cap (default: 4).
|
||||
/// The channel message timeout budget is computed as:
|
||||
/// `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)`
|
||||
/// Raising this value lets long multi-step tasks with slow local models
|
||||
/// receive a proportionally larger budget without inflating the base timeout.
|
||||
#[serde(default)]
|
||||
pub message_timeout_scale_max: Option<u64>,
|
||||
}
|
||||
|
||||
/// Skills loading configuration (`[skills]` section).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -1611,6 +1693,12 @@ pub struct GatewayConfig {
|
||||
#[serde(default)]
|
||||
pub trust_forwarded_headers: bool,
|
||||
|
||||
/// Optional URL path prefix for reverse-proxy deployments.
|
||||
/// When set, all gateway routes are served under this prefix.
|
||||
/// Must start with `/` and must not end with `/`.
|
||||
#[serde(default)]
|
||||
pub path_prefix: Option<String>,
|
||||
|
||||
/// Maximum distinct client keys tracked by gateway rate limiter maps.
|
||||
#[serde(default = "default_gateway_rate_limit_max_keys")]
|
||||
pub rate_limit_max_keys: usize,
|
||||
@@ -1683,6 +1771,7 @@ impl Default for GatewayConfig {
|
||||
pair_rate_limit_per_minute: default_pair_rate_limit(),
|
||||
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
|
||||
trust_forwarded_headers: false,
|
||||
path_prefix: None,
|
||||
rate_limit_max_keys: default_gateway_rate_limit_max_keys(),
|
||||
idempotency_ttl_secs: default_idempotency_ttl_secs(),
|
||||
idempotency_max_keys: default_gateway_idempotency_max_keys(),
|
||||
@@ -2874,6 +2963,60 @@ impl Default for ImageProviderFluxConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude Code ─────────────────────────────────────────────────
|
||||
|
||||
/// Claude Code CLI tool configuration (`[claude_code]` section).
|
||||
///
|
||||
/// Delegates coding tasks to the `claude -p` CLI. Authentication uses the
|
||||
/// binary's own OAuth session (Max subscription) by default — no API key
|
||||
/// needed unless `env_passthrough` includes `ANTHROPIC_API_KEY`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ClaudeCodeConfig {
|
||||
/// Enable the `claude_code` tool
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Maximum execution time in seconds (coding tasks can be long)
|
||||
#[serde(default = "default_claude_code_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
/// Claude Code tools the subprocess is allowed to use
|
||||
#[serde(default = "default_claude_code_allowed_tools")]
|
||||
pub allowed_tools: Vec<String>,
|
||||
/// Optional system prompt appended to Claude Code invocations
|
||||
#[serde(default)]
|
||||
pub system_prompt: Option<String>,
|
||||
/// Maximum output size in bytes (2MB default)
|
||||
#[serde(default = "default_claude_code_max_output_bytes")]
|
||||
pub max_output_bytes: usize,
|
||||
/// Extra env vars passed to the claude subprocess (e.g. ANTHROPIC_API_KEY for API-key billing)
|
||||
#[serde(default)]
|
||||
pub env_passthrough: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_claude_code_timeout_secs() -> u64 {
|
||||
600
|
||||
}
|
||||
|
||||
fn default_claude_code_allowed_tools() -> Vec<String> {
|
||||
vec!["Read".into(), "Edit".into(), "Bash".into(), "Write".into()]
|
||||
}
|
||||
|
||||
fn default_claude_code_max_output_bytes() -> usize {
|
||||
2_097_152
|
||||
}
|
||||
|
||||
impl Default for ClaudeCodeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
timeout_secs: default_claude_code_timeout_secs(),
|
||||
allowed_tools: default_claude_code_allowed_tools(),
|
||||
system_prompt: None,
|
||||
max_output_bytes: default_claude_code_max_output_bytes(),
|
||||
env_passthrough: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Proxy ───────────────────────────────────────────────────────
|
||||
|
||||
/// Proxy application scope — determines which outbound traffic uses the proxy.
|
||||
@@ -3381,6 +3524,116 @@ pub fn build_runtime_proxy_client_with_timeouts(
|
||||
client
|
||||
}
|
||||
|
||||
/// Build an HTTP client for a channel, using an explicit per-channel proxy URL
|
||||
/// when configured. Falls back to the global runtime proxy when `proxy_url` is
|
||||
/// `None` or empty.
|
||||
pub fn build_channel_proxy_client(service_key: &str, proxy_url: Option<&str>) -> reqwest::Client {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => build_explicit_proxy_client(service_key, &url, None, None),
|
||||
None => build_runtime_proxy_client(service_key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an HTTP client for a channel with custom timeouts, using an explicit
|
||||
/// per-channel proxy URL when configured. Falls back to the global runtime
|
||||
/// proxy when `proxy_url` is `None` or empty.
|
||||
pub fn build_channel_proxy_client_with_timeouts(
|
||||
service_key: &str,
|
||||
proxy_url: Option<&str>,
|
||||
timeout_secs: u64,
|
||||
connect_timeout_secs: u64,
|
||||
) -> reqwest::Client {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => build_explicit_proxy_client(
|
||||
service_key,
|
||||
&url,
|
||||
Some(timeout_secs),
|
||||
Some(connect_timeout_secs),
|
||||
),
|
||||
None => build_runtime_proxy_client_with_timeouts(
|
||||
service_key,
|
||||
timeout_secs,
|
||||
connect_timeout_secs,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply an explicit proxy URL to a `reqwest::ClientBuilder`, returning the
|
||||
/// modified builder. Used by channels that specify a per-channel `proxy_url`.
|
||||
pub fn apply_channel_proxy_to_builder(
|
||||
builder: reqwest::ClientBuilder,
|
||||
service_key: &str,
|
||||
proxy_url: Option<&str>,
|
||||
) -> reqwest::ClientBuilder {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => apply_explicit_proxy_to_builder(builder, service_key, &url),
|
||||
None => apply_runtime_proxy_to_builder(builder, service_key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a client with a single explicit proxy URL (http+https via `Proxy::all`).
|
||||
fn build_explicit_proxy_client(
|
||||
service_key: &str,
|
||||
proxy_url: &str,
|
||||
timeout_secs: Option<u64>,
|
||||
connect_timeout_secs: Option<u64>,
|
||||
) -> reqwest::Client {
|
||||
let cache_key = format!(
|
||||
"explicit|{}|{}|timeout={}|connect_timeout={}",
|
||||
service_key.trim().to_ascii_lowercase(),
|
||||
proxy_url,
|
||||
timeout_secs
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
connect_timeout_secs
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
);
|
||||
if let Some(client) = runtime_proxy_cached_client(&cache_key) {
|
||||
return client;
|
||||
}
|
||||
|
||||
let mut builder = reqwest::Client::builder();
|
||||
if let Some(t) = timeout_secs {
|
||||
builder = builder.timeout(std::time::Duration::from_secs(t));
|
||||
}
|
||||
if let Some(ct) = connect_timeout_secs {
|
||||
builder = builder.connect_timeout(std::time::Duration::from_secs(ct));
|
||||
}
|
||||
builder = apply_explicit_proxy_to_builder(builder, service_key, proxy_url);
|
||||
let client = builder.build().unwrap_or_else(|error| {
|
||||
tracing::warn!(
|
||||
service_key,
|
||||
proxy_url,
|
||||
"Failed to build channel proxy client: {error}"
|
||||
);
|
||||
reqwest::Client::new()
|
||||
});
|
||||
set_runtime_proxy_cached_client(cache_key, client.clone());
|
||||
client
|
||||
}
|
||||
|
||||
/// Apply a single explicit proxy URL to a builder via `Proxy::all`.
|
||||
fn apply_explicit_proxy_to_builder(
|
||||
mut builder: reqwest::ClientBuilder,
|
||||
service_key: &str,
|
||||
proxy_url: &str,
|
||||
) -> reqwest::ClientBuilder {
|
||||
match reqwest::Proxy::all(proxy_url) {
|
||||
Ok(proxy) => {
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
proxy_url,
|
||||
service_key,
|
||||
"Ignoring invalid channel proxy_url: {error}"
|
||||
);
|
||||
}
|
||||
}
|
||||
builder
|
||||
}
|
||||
|
||||
fn parse_proxy_scope(raw: &str) -> Option<ProxyScope> {
|
||||
match raw.trim().to_ascii_lowercase().as_str() {
|
||||
"environment" | "env" => Some(ProxyScope::Environment),
|
||||
@@ -4885,6 +5138,10 @@ pub struct TelegramConfig {
|
||||
/// explicitly, it takes precedence.
|
||||
#[serde(default)]
|
||||
pub ack_reactions: Option<bool>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for TelegramConfig {
|
||||
@@ -4918,6 +5175,10 @@ pub struct DiscordConfig {
|
||||
/// Other messages in the guild are silently ignored.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DiscordConfig {
|
||||
@@ -4954,6 +5215,10 @@ pub struct SlackConfig {
|
||||
/// Direct messages remain allowed.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SlackConfig {
|
||||
@@ -4989,6 +5254,10 @@ pub struct MattermostConfig {
|
||||
/// cancels the in-flight request and starts a fresh response with preserved history.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for MattermostConfig {
|
||||
@@ -5101,6 +5370,10 @@ pub struct SignalConfig {
|
||||
/// Skip incoming story messages.
|
||||
#[serde(default)]
|
||||
pub ignore_stories: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SignalConfig {
|
||||
@@ -5195,6 +5468,10 @@ pub struct WhatsAppConfig {
|
||||
/// user's own self-chat (Notes to Self). Defaults to false.
|
||||
#[serde(default)]
|
||||
pub self_chat_mode: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for WhatsAppConfig {
|
||||
@@ -5243,6 +5520,10 @@ pub struct WatiConfig {
|
||||
/// Allowed phone numbers (E.164 format) or "*" for all.
|
||||
#[serde(default)]
|
||||
pub allowed_numbers: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
fn default_wati_api_url() -> String {
|
||||
@@ -5273,6 +5554,10 @@ pub struct NextcloudTalkConfig {
|
||||
/// Allowed Nextcloud actor IDs (`[]` = deny all, `"*"` = allow all).
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for NextcloudTalkConfig {
|
||||
@@ -5400,6 +5685,10 @@ pub struct LarkConfig {
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for LarkConfig {
|
||||
@@ -5434,6 +5723,10 @@ pub struct FeishuConfig {
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for FeishuConfig {
|
||||
@@ -5895,6 +6188,10 @@ pub struct DingTalkConfig {
|
||||
/// Allowed user IDs (staff IDs). Empty = deny all, "*" = allow all
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DingTalkConfig {
|
||||
@@ -5935,6 +6232,10 @@ pub struct QQConfig {
|
||||
/// Allowed user IDs. Empty = deny all, "*" = allow all
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for QQConfig {
|
||||
@@ -6467,6 +6768,7 @@ impl Default for Config {
|
||||
reliability: ReliabilityConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
skills: SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -6512,6 +6814,7 @@ impl Default for Config {
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7571,6 +7874,31 @@ impl Config {
|
||||
if self.gateway.host.trim().is_empty() {
|
||||
anyhow::bail!("gateway.host must not be empty");
|
||||
}
|
||||
if let Some(ref prefix) = self.gateway.path_prefix {
|
||||
// Validate the raw value — no silent trimming so the stored
|
||||
// value is exactly what was validated.
|
||||
if !prefix.is_empty() {
|
||||
if !prefix.starts_with('/') {
|
||||
anyhow::bail!("gateway.path_prefix must start with '/'");
|
||||
}
|
||||
if prefix.ends_with('/') {
|
||||
anyhow::bail!("gateway.path_prefix must not end with '/' (including bare '/')");
|
||||
}
|
||||
// Reject characters unsafe for URL paths or HTML/JS injection.
|
||||
// Whitespace is intentionally excluded from the allowed set.
|
||||
if let Some(bad) = prefix.chars().find(|c| {
|
||||
!matches!(c, '/' | '-' | '_' | '.' | '~'
|
||||
| 'a'..='z' | 'A'..='Z' | '0'..='9'
|
||||
| '!' | '$' | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '='
|
||||
| ':' | '@')
|
||||
}) {
|
||||
anyhow::bail!(
|
||||
"gateway.path_prefix contains invalid character '{bad}'; \
|
||||
only unreserved and sub-delim URI characters are allowed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Autonomy
|
||||
if self.autonomy.max_actions_per_hour == 0 {
|
||||
@@ -8081,10 +8409,10 @@ impl Config {
|
||||
{
|
||||
let dp = self.transcription.default_provider.trim();
|
||||
match dp {
|
||||
"groq" | "openai" | "deepgram" | "assemblyai" | "google" => {}
|
||||
"groq" | "openai" | "deepgram" | "assemblyai" | "google" | "local_whisper" => {}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google (got '{other}')"
|
||||
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google, local_whisper (got '{other}')"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -9335,6 +9663,7 @@ default_temperature = 0.7
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
}),
|
||||
discord: None,
|
||||
slack: None,
|
||||
@@ -9386,6 +9715,7 @@ default_temperature = 0.7
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
@@ -9407,6 +9737,7 @@ default_temperature = 0.7
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
@@ -9656,6 +9987,47 @@ tool_dispatcher = "xml"
|
||||
assert_eq!(parsed.agent.tool_dispatcher, "xml");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_defaults_are_all_none_or_empty() {
|
||||
let cfg = PacingConfig::default();
|
||||
assert!(cfg.step_timeout_secs.is_none());
|
||||
assert!(cfg.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(cfg.loop_ignore_tools.is_empty());
|
||||
assert!(cfg.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_deserializes_from_toml() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
[pacing]
|
||||
step_timeout_secs = 120
|
||||
loop_detection_min_elapsed_secs = 60
|
||||
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
|
||||
message_timeout_scale_max = 8
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(raw).unwrap();
|
||||
assert_eq!(parsed.pacing.step_timeout_secs, Some(120));
|
||||
assert_eq!(parsed.pacing.loop_detection_min_elapsed_secs, Some(60));
|
||||
assert_eq!(
|
||||
parsed.pacing.loop_ignore_tools,
|
||||
vec!["browser_screenshot", "browser_navigate"]
|
||||
);
|
||||
assert_eq!(parsed.pacing.message_timeout_scale_max, Some(8));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_absent_preserves_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(raw).unwrap();
|
||||
assert!(parsed.pacing.step_timeout_secs.is_none());
|
||||
assert!(parsed.pacing.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(parsed.pacing.loop_ignore_tools.is_empty());
|
||||
assert!(parsed.pacing.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sync_directory_handles_existing_directory() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
@@ -9724,6 +10096,7 @@ tool_dispatcher = "xml"
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
@@ -9745,6 +10118,7 @@ tool_dispatcher = "xml"
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await.unwrap();
|
||||
@@ -9789,6 +10163,7 @@ tool_dispatcher = "xml"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
config.agents.insert(
|
||||
@@ -9805,6 +10180,7 @@ tool_dispatcher = "xml"
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -9930,6 +10306,7 @@ tool_dispatcher = "xml"
|
||||
interrupt_on_new_message: true,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&tc).unwrap();
|
||||
let parsed: TelegramConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9958,6 +10335,7 @@ tool_dispatcher = "xml"
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9974,6 +10352,7 @@ tool_dispatcher = "xml"
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10075,6 +10454,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
allowed_from: vec!["+1111111111".into()],
|
||||
ignore_attachments: true,
|
||||
ignore_stories: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&sc).unwrap();
|
||||
let parsed: SignalConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10095,6 +10475,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
allowed_from: vec!["*".into()],
|
||||
ignore_attachments: false,
|
||||
ignore_stories: true,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&sc).unwrap();
|
||||
let parsed: SignalConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10325,6 +10706,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10349,6 +10731,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10378,6 +10761,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10399,6 +10783,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
assert!(wc.is_ambiguous_config());
|
||||
assert_eq!(wc.backend_type(), "cloud");
|
||||
@@ -10419,6 +10804,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
assert!(!wc.is_ambiguous_config());
|
||||
assert_eq!(wc.backend_type(), "web");
|
||||
@@ -10449,6 +10835,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
}),
|
||||
linq: None,
|
||||
wati: None,
|
||||
@@ -10553,6 +10940,7 @@ channel_id = "C123"
|
||||
pair_rate_limit_per_minute: 12,
|
||||
webhook_rate_limit_per_minute: 80,
|
||||
trust_forwarded_headers: true,
|
||||
path_prefix: Some("/zeroclaw".into()),
|
||||
rate_limit_max_keys: 2048,
|
||||
idempotency_ttl_secs: 600,
|
||||
idempotency_max_keys: 4096,
|
||||
@@ -10570,6 +10958,7 @@ channel_id = "C123"
|
||||
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
|
||||
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
|
||||
assert!(parsed.trust_forwarded_headers);
|
||||
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
|
||||
assert_eq!(parsed.rate_limit_max_keys, 2048);
|
||||
assert_eq!(parsed.idempotency_ttl_secs, 600);
|
||||
assert_eq!(parsed.idempotency_max_keys, 4096);
|
||||
@@ -11453,6 +11842,7 @@ default_model = "legacy-model"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
config.save().await.unwrap();
|
||||
|
||||
@@ -12164,6 +12554,7 @@ default_model = "persisted-profile"
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -12187,6 +12578,7 @@ default_model = "persisted-profile"
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -12233,6 +12625,7 @@ default_model = "persisted-profile"
|
||||
allowed_users: vec!["user_123".into(), "user_456".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&fc).unwrap();
|
||||
let parsed: FeishuConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -12253,6 +12646,7 @@ default_model = "persisted-profile"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&fc).unwrap();
|
||||
let parsed: FeishuConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -12280,6 +12674,7 @@ default_model = "persisted-profile"
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: Some("webhook-secret".into()),
|
||||
allowed_users: vec!["user_a".into(), "*".into()],
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&nc).unwrap();
|
||||
@@ -12467,6 +12862,30 @@ require_otp_to_resume = true
|
||||
assert!(err.to_string().contains("gated_domains"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn validate_accepts_local_whisper_as_transcription_default_provider() {
|
||||
let mut config = Config::default();
|
||||
config.transcription.default_provider = "local_whisper".to_string();
|
||||
|
||||
config.validate().expect(
|
||||
"local_whisper must be accepted by the transcription.default_provider allowlist",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn validate_rejects_unknown_transcription_default_provider() {
|
||||
let mut config = Config::default();
|
||||
config.transcription.default_provider = "unknown_stt".to_string();
|
||||
|
||||
let err = config
|
||||
.validate()
|
||||
.expect_err("expected validation to reject unknown transcription provider");
|
||||
assert!(
|
||||
err.to_string().contains("transcription.default_provider"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_secret_telegram_bot_token_roundtrip() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
@@ -12488,6 +12907,7 @@ require_otp_to_resume = true
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
// Save (triggers encryption)
|
||||
|
||||
+30
-1
@@ -7,7 +7,7 @@ use std::collections::HashMap;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
/// Cost tracker for API usage monitoring and budget enforcement.
|
||||
pub struct CostTracker {
|
||||
@@ -175,6 +175,35 @@ impl CostTracker {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Process-global singleton ────────────────────────────────────────
|
||||
// Both the gateway and the channels supervisor share a single CostTracker
|
||||
// so that budget enforcement is consistent across all paths.
|
||||
|
||||
static GLOBAL_COST_TRACKER: OnceLock<Option<Arc<CostTracker>>> = OnceLock::new();
|
||||
|
||||
impl CostTracker {
|
||||
/// Return the process-global `CostTracker`, creating it on first call.
|
||||
/// Subsequent calls (from gateway or channels, whichever starts second)
|
||||
/// receive the same `Arc`. Returns `None` when cost tracking is disabled
|
||||
/// or initialisation fails.
|
||||
pub fn get_or_init_global(config: CostConfig, workspace_dir: &Path) -> Option<Arc<Self>> {
|
||||
GLOBAL_COST_TRACKER
|
||||
.get_or_init(|| {
|
||||
if !config.enabled {
|
||||
return None;
|
||||
}
|
||||
match Self::new(config, workspace_dir) {
|
||||
Ok(ct) => Some(Arc::new(ct)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to initialize global cost tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_storage_path(workspace_dir: &Path) -> Result<PathBuf> {
|
||||
let storage_path = workspace_dir.join("state").join("costs.jsonl");
|
||||
let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db");
|
||||
|
||||
@@ -646,6 +646,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -657,6 +658,7 @@ mod tests {
|
||||
client_id: "client_id".into(),
|
||||
client_secret: "client_secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -672,6 +674,7 @@ mod tests {
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -683,6 +686,7 @@ mod tests {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -695,6 +699,7 @@ mod tests {
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: None,
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -761,6 +766,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
@@ -778,6 +784,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
|
||||
@@ -1283,6 +1283,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
config.agents.insert(
|
||||
@@ -1299,6 +1300,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
+17
-3
@@ -50,6 +50,10 @@ fn require_auth(
|
||||
pub struct MemoryQuery {
|
||||
pub query: Option<String>,
|
||||
pub category: Option<String>,
|
||||
/// Filter memories created at or after (RFC 3339 / ISO 8601)
|
||||
pub since: Option<String>,
|
||||
/// Filter memories created at or before (RFC 3339 / ISO 8601)
|
||||
pub until: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -633,9 +637,12 @@ pub async fn handle_api_memory_list(
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
if let Some(ref query) = params.query {
|
||||
// Search mode
|
||||
match state.mem.recall(query, 50, None).await {
|
||||
// Use recall when query or time range is provided
|
||||
if params.query.is_some() || params.since.is_some() || params.until.is_some() {
|
||||
let query = params.query.as_deref().unwrap_or("");
|
||||
let since = params.since.as_deref();
|
||||
let until = params.until.as_deref();
|
||||
match state.mem.recall(query, 50, None, since, until).await {
|
||||
Ok(entries) => Json(serde_json::json!({"entries": entries})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -1356,6 +1363,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -1429,6 +1438,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
path_prefix: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1457,6 +1467,7 @@ mod tests {
|
||||
api_url: "https://live-mt-server.wati.io".to_string(),
|
||||
tenant_id: None,
|
||||
allowed_numbers: vec![],
|
||||
proxy_url: None,
|
||||
});
|
||||
cfg.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
|
||||
app_id: "cli_aabbcc".to_string(),
|
||||
@@ -1466,6 +1477,7 @@ mod tests {
|
||||
allowed_users: vec!["*".to_string()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
cfg.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
|
||||
imap_host: "imap.example.com".to_string(),
|
||||
@@ -1591,6 +1603,7 @@ mod tests {
|
||||
api_url: "https://live-mt-server.wati.io".to_string(),
|
||||
tenant_id: None,
|
||||
allowed_numbers: vec![],
|
||||
proxy_url: None,
|
||||
});
|
||||
current.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
|
||||
app_id: "cli_current".to_string(),
|
||||
@@ -1600,6 +1613,7 @@ mod tests {
|
||||
allowed_users: vec!["*".to_string()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
current.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
|
||||
imap_host: "imap.example.com".to_string(),
|
||||
|
||||
+61
-34
@@ -348,6 +348,8 @@ pub struct AppState {
|
||||
pub shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
/// Registry of dynamically connected nodes
|
||||
pub node_registry: Arc<nodes::NodeRegistry>,
|
||||
/// Path prefix for reverse-proxy deployments (empty string = no prefix)
|
||||
pub path_prefix: String,
|
||||
/// Session backend for persisting gateway WS chat sessions
|
||||
pub session_backend: Option<Arc<dyn SessionBackend>>,
|
||||
/// Device registry for paired device management
|
||||
@@ -505,18 +507,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
let tools_registry: Arc<Vec<ToolSpec>> =
|
||||
Arc::new(tools_registry_raw.iter().map(|t| t.spec()).collect());
|
||||
|
||||
// Cost tracker (optional)
|
||||
let cost_tracker = if config.cost.enabled {
|
||||
match CostTracker::new(config.cost.clone(), &config.workspace_dir) {
|
||||
Ok(ct) => Some(Arc::new(ct)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to initialize cost tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Cost tracker — process-global singleton so channels share the same instance
|
||||
let cost_tracker = CostTracker::get_or_init_global(config.cost.clone(), &config.workspace_dir);
|
||||
|
||||
// SSE broadcast channel for real-time events
|
||||
let (event_tx, _event_rx) = tokio::sync::broadcast::channel::<serde_json::Value>(256);
|
||||
@@ -683,6 +675,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
idempotency_max_keys,
|
||||
));
|
||||
|
||||
// Resolve optional path prefix for reverse-proxy deployments.
|
||||
let path_prefix: Option<&str> = config
|
||||
.gateway
|
||||
.path_prefix
|
||||
.as_deref()
|
||||
.filter(|p| !p.is_empty());
|
||||
|
||||
// ── Tunnel ────────────────────────────────────────────────
|
||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||
let mut tunnel_url: Option<String> = None;
|
||||
@@ -701,18 +700,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
|
||||
let pfx = path_prefix.unwrap_or("");
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}{pfx}");
|
||||
if let Some(ref url) = tunnel_url {
|
||||
println!(" 🌐 Public URL: {url}");
|
||||
}
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}/");
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}{pfx}/");
|
||||
if let Some(code) = pairing.pairing_code() {
|
||||
println!();
|
||||
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
|
||||
println!(" ┌──────────────┐");
|
||||
println!(" │ {code} │");
|
||||
println!(" └──────────────┘");
|
||||
println!();
|
||||
println!(" Send: POST {pfx}/pair with header X-Pairing-Code: {code}");
|
||||
} else if pairing.require_pairing() {
|
||||
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
|
||||
println!(" To pair a new device: zeroclaw gateway get-paircode --new");
|
||||
@@ -721,29 +721,29 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||
println!();
|
||||
}
|
||||
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" POST {pfx}/pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST {pfx}/webhook — {{\"message\": \"your prompt\"}}");
|
||||
if whatsapp_channel.is_some() {
|
||||
println!(" GET /whatsapp — Meta webhook verification");
|
||||
println!(" POST /whatsapp — WhatsApp message webhook");
|
||||
println!(" GET {pfx}/whatsapp — Meta webhook verification");
|
||||
println!(" POST {pfx}/whatsapp — WhatsApp message webhook");
|
||||
}
|
||||
if linq_channel.is_some() {
|
||||
println!(" POST /linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
println!(" POST {pfx}/linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
}
|
||||
if wati_channel.is_some() {
|
||||
println!(" GET /wati — WATI webhook verification");
|
||||
println!(" POST /wati — WATI message webhook");
|
||||
println!(" GET {pfx}/wati — WATI webhook verification");
|
||||
println!(" POST {pfx}/wati — WATI message webhook");
|
||||
}
|
||||
if nextcloud_talk_channel.is_some() {
|
||||
println!(" POST /nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
println!(" POST {pfx}/nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
}
|
||||
println!(" GET /api/* — REST API (bearer token required)");
|
||||
println!(" GET /ws/chat — WebSocket agent chat");
|
||||
println!(" GET {pfx}/api/* — REST API (bearer token required)");
|
||||
println!(" GET {pfx}/ws/chat — WebSocket agent chat");
|
||||
if config.nodes.enabled {
|
||||
println!(" GET /ws/nodes — WebSocket node discovery");
|
||||
println!(" GET {pfx}/ws/nodes — WebSocket node discovery");
|
||||
}
|
||||
println!(" GET /health — health check");
|
||||
println!(" GET /metrics — Prometheus metrics");
|
||||
println!(" GET {pfx}/health — health check");
|
||||
println!(" GET {pfx}/metrics — Prometheus metrics");
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
crate::health::mark_component_ok("gateway");
|
||||
@@ -809,6 +809,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
session_backend,
|
||||
device_registry,
|
||||
pending_pairings,
|
||||
path_prefix: path_prefix.unwrap_or("").to_string(),
|
||||
};
|
||||
|
||||
// Config PUT needs larger body limit (1MB)
|
||||
@@ -817,7 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.layer(RequestBodyLimitLayer::new(1_048_576));
|
||||
|
||||
// Build router with middleware
|
||||
let app = Router::new()
|
||||
let inner = Router::new()
|
||||
// ── Admin routes (for CLI management) ──
|
||||
.route("/admin/shutdown", post(handle_admin_shutdown))
|
||||
.route("/admin/paircode", get(handle_admin_paircode))
|
||||
@@ -877,12 +878,12 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
|
||||
// ── Plugin management API (requires plugins-wasm feature) ──
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
let app = app.route(
|
||||
let inner = inner.route(
|
||||
"/api/plugins",
|
||||
get(api_plugins::plugin_routes::list_plugins),
|
||||
);
|
||||
|
||||
let app = app
|
||||
let inner = inner
|
||||
// ── SSE event stream ──
|
||||
.route("/api/events", get(sse::handle_sse_events))
|
||||
// ── WebSocket agent chat ──
|
||||
@@ -893,14 +894,27 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/_app/{*path}", get(static_files::handle_static))
|
||||
// ── Config PUT with larger body limit ──
|
||||
.merge(config_put_router)
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback))
|
||||
.with_state(state)
|
||||
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
|
||||
.layer(TimeoutLayer::with_status_code(
|
||||
StatusCode::REQUEST_TIMEOUT,
|
||||
Duration::from_secs(gateway_request_timeout_secs()),
|
||||
))
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback));
|
||||
));
|
||||
|
||||
// Nest under path prefix when configured (axum strips prefix before routing).
|
||||
// nest() at "/prefix" handles both "/prefix" and "/prefix/*" but not "/prefix/"
|
||||
// with a trailing slash, so we add a fallback redirect for that case.
|
||||
let app = if let Some(prefix) = path_prefix {
|
||||
let redirect_target = prefix.to_string();
|
||||
Router::new().nest(prefix, inner).route(
|
||||
&format!("{prefix}/"),
|
||||
get(|| async move { axum::response::Redirect::permanent(&redirect_target) }),
|
||||
)
|
||||
} else {
|
||||
inner
|
||||
};
|
||||
|
||||
// Run the server with graceful shutdown
|
||||
axum::serve(
|
||||
@@ -1992,6 +2006,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2047,6 +2062,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2287,6 +2303,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -2362,6 +2380,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -2427,6 +2447,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2496,6 +2517,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2577,6 +2599,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2630,6 +2653,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2688,6 +2712,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2751,6 +2776,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2810,6 +2836,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
//! Uses `rust-embed` to bundle the `web/dist/` directory into the binary at compile time.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{header, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rust_embed::Embed;
|
||||
|
||||
use super::AppState;
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "web/dist/"]
|
||||
struct WebAssets;
|
||||
@@ -23,16 +26,41 @@ pub async fn handle_static(uri: Uri) -> Response {
|
||||
serve_embedded_file(path)
|
||||
}
|
||||
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request
|
||||
pub async fn handle_spa_fallback() -> Response {
|
||||
if WebAssets::get("index.html").is_none() {
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request.
|
||||
/// Injects `window.__ZEROCLAW_BASE__` so the frontend knows the path prefix.
|
||||
pub async fn handle_spa_fallback(State(state): State<AppState>) -> Response {
|
||||
let Some(content) = WebAssets::get("index.html") else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"Web dashboard not available. Build it with: cd web && npm ci && npm run build",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
serve_embedded_file("index.html")
|
||||
};
|
||||
|
||||
let html = String::from_utf8_lossy(&content.data);
|
||||
|
||||
// Inject path prefix for the SPA and rewrite asset paths in the HTML
|
||||
let html = if state.path_prefix.is_empty() {
|
||||
html.into_owned()
|
||||
} else {
|
||||
let pfx = &state.path_prefix;
|
||||
// JSON-encode the prefix to safely embed in a <script> block
|
||||
let json_pfx = serde_json::to_string(pfx).unwrap_or_else(|_| "\"\"".to_string());
|
||||
let script = format!("<script>window.__ZEROCLAW_BASE__={json_pfx};</script>");
|
||||
// Rewrite absolute /_app/ references so the browser requests {prefix}/_app/...
|
||||
html.replace("/_app/", &format!("{pfx}/_app/"))
|
||||
.replace("<head>", &format!("<head>{script}"))
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, "text/html; charset=utf-8".to_string()),
|
||||
(header::CACHE_CONTROL, "no-cache".to_string()),
|
||||
],
|
||||
html,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn serve_embedded_file(path: &str) -> Response {
|
||||
|
||||
@@ -407,7 +407,11 @@ mod tests {
|
||||
// Simpler: write a temp script.
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let script_path = dir.path().join("tool.sh");
|
||||
std::fs::write(&script_path, format!("#!/bin/sh\necho '{}'\n", result_json)).unwrap();
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
format!("#!/bin/sh\ncat > /dev/null\necho '{}'\n", result_json),
|
||||
)
|
||||
.unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
@@ -841,6 +841,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
let entries = all_integrations();
|
||||
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
|
||||
|
||||
+47
-7
@@ -325,8 +325,27 @@ impl Memory for LucidMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||
let since_dt = since
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
|
||||
let until_dt = until
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
|
||||
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
|
||||
if s >= u {
|
||||
anyhow::bail!("'since' must be before 'until'");
|
||||
}
|
||||
}
|
||||
|
||||
let local_results = self
|
||||
.local
|
||||
.recall(query, limit, session_id, since, until)
|
||||
.await?;
|
||||
if limit == 0
|
||||
|| local_results.len() >= limit
|
||||
|| local_results.len() >= self.local_hit_threshold
|
||||
@@ -341,7 +360,28 @@ impl Memory for LucidMemory {
|
||||
match self.recall_from_lucid(query).await {
|
||||
Ok(lucid_results) if !lucid_results.is_empty() => {
|
||||
self.clear_failure();
|
||||
Ok(Self::merge_results(local_results, lucid_results, limit))
|
||||
let merged = Self::merge_results(local_results, lucid_results, limit);
|
||||
let filtered: Vec<MemoryEntry> = merged
|
||||
.into_iter()
|
||||
.filter(|e| {
|
||||
if let Some(ref s) = since_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
|
||||
if ts < *s {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ref u) = until_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
|
||||
if ts > *u {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
Ok(_) => {
|
||||
self.clear_failure();
|
||||
@@ -541,7 +581,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
@@ -565,7 +605,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
@@ -603,7 +643,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||
let entries = memory.recall("rust", 5, None, None, None).await.unwrap();
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||
@@ -663,8 +703,8 @@ exit 1
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||
let first = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(first.is_empty());
|
||||
assert!(second.is_empty());
|
||||
|
||||
+47
-6
@@ -158,7 +158,23 @@ impl Memory for MarkdownMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let since_dt = since
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
|
||||
let until_dt = until
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
|
||||
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
|
||||
if s >= u {
|
||||
anyhow::bail!("'since' must be before 'until'");
|
||||
}
|
||||
}
|
||||
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
@@ -166,6 +182,24 @@ impl Memory for MarkdownMemory {
|
||||
let mut scored: Vec<MemoryEntry> = all
|
||||
.into_iter()
|
||||
.filter_map(|mut entry| {
|
||||
if let Some(ref s) = since_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
if ts < *s {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ref u) = until_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
if ts > *u {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
if keywords.is_empty() {
|
||||
entry.score = Some(1.0);
|
||||
return Some(entry);
|
||||
}
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = keywords
|
||||
.iter()
|
||||
@@ -183,9 +217,13 @@ impl Memory for MarkdownMemory {
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
if keywords.is_empty() {
|
||||
b.timestamp.as_str().cmp(a.timestamp.as_str())
|
||||
} else {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
});
|
||||
scored.truncate(limit);
|
||||
Ok(scored)
|
||||
@@ -283,7 +321,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
@@ -296,7 +334,10 @@ mod tests {
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("javascript", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -343,7 +384,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||
let results = mem.recall("anything", 10, None, None, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
||||
+5
-1
@@ -364,14 +364,18 @@ impl Memory for Mem0Memory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// mem0 handles filtering server-side; since/until are not yet
|
||||
// supported by the mem0 API, so we pass them through as no-ops.
|
||||
self.recall_filtered(query, limit, session_id, None, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
// mem0 doesn't have a get-by-key API, so we search by key in metadata
|
||||
let results = self.recall(key, 1, None).await?;
|
||||
let results = self.recall(key, 1, None, None, None).await?;
|
||||
Ok(results.into_iter().find(|e| e.key == key))
|
||||
}
|
||||
|
||||
|
||||
+7
-1
@@ -35,6 +35,8 @@ impl Memory for NoneMemory {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -78,7 +80,11 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
assert!(memory.get("k").await.unwrap().is_none());
|
||||
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||
assert!(memory
|
||||
.recall("k", 10, None, None, None)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||
assert!(!memory.forget("k").await.unwrap());
|
||||
assert_eq!(memory.count().await.unwrap(), 0);
|
||||
|
||||
+23
-1
@@ -239,14 +239,30 @@ impl Memory for PostgresMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
let client = self.client.clone();
|
||||
let qualified_table = self.qualified_table.clone();
|
||||
let query = query.trim().to_string();
|
||||
let sid = session_id.map(str::to_string);
|
||||
let since_owned = since.map(str::to_string);
|
||||
let until_owned = until.map(str::to_string);
|
||||
|
||||
run_on_os_thread(move || -> Result<Vec<MemoryEntry>> {
|
||||
let mut client = client.lock();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let time_filter: String = match (since_ref, until_ref) {
|
||||
(Some(_), Some(_)) => {
|
||||
" AND created_at >= $4::TIMESTAMPTZ AND created_at <= $5::TIMESTAMPTZ".into()
|
||||
}
|
||||
(Some(_), None) => " AND created_at >= $4::TIMESTAMPTZ".into(),
|
||||
(None, Some(_)) => " AND created_at <= $4::TIMESTAMPTZ".into(),
|
||||
(None, None) => String::new(),
|
||||
};
|
||||
|
||||
let stmt = format!(
|
||||
"
|
||||
SELECT id, key, content, category, created_at, session_id,
|
||||
@@ -257,6 +273,7 @@ impl Memory for PostgresMemory {
|
||||
FROM {qualified_table}
|
||||
WHERE ($2::TEXT IS NULL OR session_id = $2)
|
||||
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
|
||||
{time_filter}
|
||||
ORDER BY score DESC, updated_at DESC
|
||||
LIMIT $3
|
||||
"
|
||||
@@ -265,7 +282,12 @@ impl Memory for PostgresMemory {
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit_i64 = limit as i64;
|
||||
|
||||
let rows = client.query(&stmt, &[&query, &sid, &limit_i64])?;
|
||||
let rows = match (since_ref, until_ref) {
|
||||
(Some(s), Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &s, &u])?,
|
||||
(Some(s), None) => client.query(&stmt, &[&query, &sid, &limit_i64, &s])?,
|
||||
(None, Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &u])?,
|
||||
(None, None) => client.query(&stmt, &[&query, &sid, &limit_i64])?,
|
||||
};
|
||||
rows.iter()
|
||||
.map(Self::row_to_entry)
|
||||
.collect::<Result<Vec<MemoryEntry>>>()
|
||||
|
||||
+20
-2
@@ -291,9 +291,19 @@ impl Memory for QdrantMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
if query.trim().is_empty() {
|
||||
return self.list(None, session_id).await;
|
||||
let mut entries = self.list(None, session_id).await?;
|
||||
if let Some(s) = since {
|
||||
entries.retain(|e| e.timestamp.as_str() >= s);
|
||||
}
|
||||
if let Some(u) = until {
|
||||
entries.retain(|e| e.timestamp.as_str() <= u);
|
||||
}
|
||||
entries.truncate(limit);
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
self.ensure_initialized().await?;
|
||||
@@ -344,7 +354,7 @@ impl Memory for QdrantMemory {
|
||||
|
||||
let result: QdrantSearchResult = resp.json().await?;
|
||||
|
||||
let entries = result
|
||||
let mut entries: Vec<MemoryEntry> = result
|
||||
.result
|
||||
.into_iter()
|
||||
.filter_map(|point| {
|
||||
@@ -367,6 +377,14 @@ impl Memory for QdrantMemory {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Filter by time range if specified
|
||||
if let Some(s) = since {
|
||||
entries.retain(|e| e.timestamp.as_str() >= s);
|
||||
}
|
||||
if let Some(u) = until {
|
||||
entries.retain(|e| e.timestamp.as_str() <= u);
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
|
||||
+167
-36
@@ -428,6 +428,74 @@ impl SqliteMemory {
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// List memories by time range (used when query is empty).
|
||||
async fn recall_by_time_only(
|
||||
&self,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self.conn.clone();
|
||||
let sid = session_id.map(String::from);
|
||||
let since_owned = since.map(String::from);
|
||||
let until_owned = until.map(String::from);
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let mut sql =
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories \
|
||||
WHERE 1=1"
|
||||
.to_string();
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
let mut idx = 1;
|
||||
|
||||
if let Some(sid) = sid.as_deref() {
|
||||
let _ = write!(sql, " AND session_id = ?{idx}");
|
||||
param_values.push(Box::new(sid.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
if let Some(s) = since_ref {
|
||||
let _ = write!(sql, " AND created_at >= ?{idx}");
|
||||
param_values.push(Box::new(s.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
let _ = write!(sql, " AND created_at <= ?{idx}");
|
||||
param_values.push(Box::new(u.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
let _ = write!(sql, " ORDER BY updated_at DESC LIMIT ?{idx}");
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
param_values.iter().map(AsRef::as_ref).collect();
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
Ok(results)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -481,9 +549,14 @@ impl Memory for SqliteMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// Time-only query: list by time range when no keywords
|
||||
if query.trim().is_empty() {
|
||||
return Ok(Vec::new());
|
||||
return self
|
||||
.recall_by_time_only(limit, session_id, since, until)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Compute query embedding (async, before blocking work)
|
||||
@@ -492,12 +565,16 @@ impl Memory for SqliteMemory {
|
||||
let conn = self.conn.clone();
|
||||
let query = query.to_string();
|
||||
let sid = session_id.map(String::from);
|
||||
let since_owned = since.map(String::from);
|
||||
let until_owned = until.map(String::from);
|
||||
let vector_weight = self.vector_weight;
|
||||
let keyword_weight = self.keyword_weight;
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let session_ref = sid.as_deref();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
// FTS5 BM25 keyword search
|
||||
let keyword_results = Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default();
|
||||
@@ -568,6 +645,16 @@ impl Memory for SqliteMemory {
|
||||
|
||||
for scored in &merged {
|
||||
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||
if let Some(s) = since_ref {
|
||||
if ts.as_str() < s {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
if ts.as_str() > u {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let entry = MemoryEntry {
|
||||
id: scored.id.clone(),
|
||||
key,
|
||||
@@ -588,8 +675,6 @@ impl Memory for SqliteMemory {
|
||||
}
|
||||
|
||||
// If hybrid returned nothing, fall back to LIKE search.
|
||||
// Cap keyword count so we don't create too many SQL shapes,
|
||||
// which helps prepared-statement cache efficiency.
|
||||
if results.is_empty() {
|
||||
const MAX_LIKE_KEYWORDS: usize = 8;
|
||||
let keywords: Vec<String> = query
|
||||
@@ -606,12 +691,21 @@ impl Memory for SqliteMemory {
|
||||
})
|
||||
.collect();
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let mut param_idx = keywords.len() * 2 + 1;
|
||||
let mut time_conditions = String::new();
|
||||
if since_ref.is_some() {
|
||||
let _ = write!(time_conditions, " AND created_at >= ?{param_idx}");
|
||||
param_idx += 1;
|
||||
}
|
||||
if until_ref.is_some() {
|
||||
let _ = write!(time_conditions, " AND created_at <= ?{param_idx}");
|
||||
param_idx += 1;
|
||||
}
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}
|
||||
WHERE {where_clause}{time_conditions}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
keywords.len() * 2 + 1
|
||||
LIMIT ?{param_idx}"
|
||||
);
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
@@ -619,6 +713,12 @@ impl Memory for SqliteMemory {
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
}
|
||||
if let Some(s) = since_ref {
|
||||
param_values.push(Box::new(s.to_string()));
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
param_values.push(Box::new(u.to_string()));
|
||||
}
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
@@ -852,7 +952,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
@@ -869,7 +969,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||
let results = mem.recall("fast safe", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
@@ -881,7 +981,10 @@ mod tests {
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("javascript", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1024,7 +1127,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
// All results should contain "Rust"
|
||||
for r in &results {
|
||||
@@ -1049,30 +1152,34 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||
let results = mem.recall("quick dog", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// "The quick dog runs fast" matches both terms
|
||||
assert!(results[0].content.contains("quick"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_empty_query_returns_empty() {
|
||||
async fn recall_empty_query_returns_recent_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
// Empty query = time-only mode: returns recent entries
|
||||
let results = mem.recall("", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "a");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_whitespace_query_returns_empty() {
|
||||
async fn recall_whitespace_query_returns_recent_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
// Whitespace-only query = time-only mode: returns recent entries
|
||||
let results = mem.recall(" ", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "a");
|
||||
}
|
||||
|
||||
// ── Embedding cache tests ────────────────────────────────────
|
||||
@@ -1283,7 +1390,7 @@ mod tests {
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// FTS should still work after rebuild
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
}
|
||||
|
||||
@@ -1303,7 +1410,10 @@ mod tests {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("common keyword", 5, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
@@ -1316,7 +1426,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||
let results = mem.recall("scored", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||
@@ -1332,7 +1442,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// Quotes in query should not crash FTS5
|
||||
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||
let results = mem.recall("\"hello\"", 10, None, None, None).await.unwrap();
|
||||
// May or may not match depending on FTS5 escaping, but must not error
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
@@ -1343,7 +1453,7 @@ mod tests {
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||
let results = mem.recall("wild*", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
@@ -1353,7 +1463,10 @@ mod tests {
|
||||
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("function()", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
@@ -1365,7 +1478,7 @@ mod tests {
|
||||
.unwrap();
|
||||
// Should not crash or leak data
|
||||
let results = mem
|
||||
.recall("'; DROP TABLE memories; --", 10, None)
|
||||
.recall("'; DROP TABLE memories; --", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
@@ -1441,7 +1554,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// Single char may not match FTS5 but LIKE fallback should work
|
||||
let results = mem.recall("x", 10, None).await.unwrap();
|
||||
let results = mem.recall("x", 10, None, None, None).await.unwrap();
|
||||
// Should not crash; may or may not find results
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
@@ -1452,7 +1565,7 @@ mod tests {
|
||||
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("some", 0, None).await.unwrap();
|
||||
let results = mem.recall("some", 0, None, None, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1465,7 +1578,10 @@ mod tests {
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("matching content", 1, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1481,7 +1597,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("rust", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "Should match by key");
|
||||
}
|
||||
|
||||
@@ -1491,7 +1607,7 @@ mod tests {
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||
let results = mem.recall("日本語", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1541,7 +1657,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("ghost").await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("phantom memory", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Deleted memory should not appear in recall"
|
||||
@@ -1582,7 +1701,7 @@ mod tests {
|
||||
let count = mem.reindex().await.unwrap();
|
||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||
// Data should still be intact
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1686,7 +1805,10 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Recall with session-a filter returns only session-a entry
|
||||
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("fact", 10, Some("sess-a"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||
@@ -1706,7 +1828,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Recall without session filter returns all matching entries
|
||||
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||
let results = mem.recall("fact", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
}
|
||||
|
||||
@@ -1723,11 +1845,17 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Session B cannot see session A data
|
||||
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("secret", 10, Some("sess-b"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Session A can see its own data
|
||||
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("secret", 10, Some("sess-a"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1778,7 +1906,10 @@ mod tests {
|
||||
// Second open: migration runs again but is idempotent
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("reopen", 10, Some("sess-x"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||
|
||||
@@ -96,11 +96,15 @@ pub trait Memory: Send + Sync {
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||
/// and time range. Time bounds use RFC 3339 / ISO 8601 format
|
||||
/// (e.g. "2025-03-01T00:00:00Z"); inclusive (created_at >= since, created_at <= until).
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
|
||||
@@ -154,6 +154,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -199,6 +200,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@@ -575,6 +577,7 @@ async fn run_quick_setup_with_home(
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -620,6 +623,7 @@ async fn run_quick_setup_with_home(
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
@@ -3790,6 +3794,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Discord => {
|
||||
@@ -3890,6 +3895,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Slack => {
|
||||
@@ -4020,6 +4026,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
interrupt_on_new_message: false,
|
||||
thread_replies: None,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::IMessage => {
|
||||
@@ -4271,6 +4278,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
allowed_from,
|
||||
ignore_attachments,
|
||||
ignore_stories,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(" {} Signal configured", style("✅").green().bold());
|
||||
@@ -4372,6 +4380,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(
|
||||
@@ -4477,6 +4486,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Linq => {
|
||||
@@ -4810,6 +4820,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
Some(webhook_secret.trim().to_string())
|
||||
},
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(" {} Nextcloud Talk configured", style("✅").green().bold());
|
||||
@@ -4882,6 +4893,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
client_id,
|
||||
client_secret,
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::QqOfficial => {
|
||||
@@ -4958,6 +4970,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
app_id,
|
||||
app_secret,
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Lark | ChannelMenuChoice::Feishu => {
|
||||
@@ -5147,6 +5160,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
use_feishu: is_feishu,
|
||||
receive_mode,
|
||||
port,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
@@ -7511,6 +7525,7 @@ mod tests {
|
||||
allowed_from: vec!["*".into()],
|
||||
ignore_attachments: false,
|
||||
ignore_stories: true,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7523,6 +7538,7 @@ mod tests {
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7531,6 +7547,7 @@ mod tests {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7540,6 +7557,7 @@ mod tests {
|
||||
app_token: "token".into(),
|
||||
webhook_secret: Some("secret".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7552,6 +7570,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
}
|
||||
|
||||
+1
-1
@@ -146,7 +146,7 @@ fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill
|
||||
load_skills_from_directory(&skills_dir, allow_scripts)
|
||||
}
|
||||
|
||||
fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
pub fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
if !skills_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::ClaudeCodeConfig;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// Environment variables safe to pass through to the `claude` subprocess.
|
||||
const SAFE_ENV_VARS: &[&str] = &[
|
||||
"PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR",
|
||||
];
|
||||
|
||||
/// Delegates coding tasks to the Claude Code CLI (`claude -p`).
|
||||
///
|
||||
/// This creates a two-tier agent architecture: ZeroClaw orchestrates high-level
|
||||
/// tasks and delegates complex coding work to Claude Code, which has its own
|
||||
/// agent loop with Read/Edit/Bash tools.
|
||||
///
|
||||
/// Authentication uses the `claude` binary's own OAuth session (Max subscription)
|
||||
/// by default. No API key is needed unless `env_passthrough` includes
|
||||
/// `ANTHROPIC_API_KEY` for API-key billing.
|
||||
pub struct ClaudeCodeTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
config: ClaudeCodeConfig,
|
||||
}
|
||||
|
||||
impl ClaudeCodeTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>, config: ClaudeCodeConfig) -> Self {
|
||||
Self { security, config }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ClaudeCodeTool {
|
||||
fn name(&self) -> &str {
|
||||
"claude_code"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Delegate a coding task to Claude Code (claude -p). Supports file editing, bash execution, structured output, and multi-turn sessions. Use for complex coding work that benefits from Claude Code's full agent loop."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The coding task to delegate to Claude Code"
|
||||
},
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Override the default tool allowlist (e.g. [\"Read\", \"Edit\", \"Bash\", \"Write\"])"
|
||||
},
|
||||
"system_prompt": {
|
||||
"type": "string",
|
||||
"description": "Override or append a system prompt for this invocation"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Resume a previous Claude Code session by its ID"
|
||||
},
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"description": "Request structured output conforming to this JSON Schema"
|
||||
},
|
||||
"working_directory": {
|
||||
"type": "string",
|
||||
"description": "Working directory within the workspace (must be inside workspace_dir)"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Rate limit check
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Enforce act policy
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "claude_code")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
// Extract prompt (required)
|
||||
let prompt = args
|
||||
.get("prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
|
||||
|
||||
// Extract optional params
|
||||
let allowed_tools: Vec<String> = args
|
||||
.get("allowed_tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_else(|| self.config.allowed_tools.clone());
|
||||
|
||||
let system_prompt = args
|
||||
.get("system_prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.or_else(|| self.config.system_prompt.clone());
|
||||
|
||||
let session_id = args.get("session_id").and_then(|v| v.as_str());
|
||||
|
||||
let json_schema = args.get("json_schema").filter(|v| v.is_object());
|
||||
|
||||
// Validate working directory — require both paths to exist (reject
|
||||
// non-existent paths instead of falling back to the raw value, which
|
||||
// could bypass the workspace containment check via symlinks or
|
||||
// specially-crafted path components).
|
||||
let work_dir = if let Some(wd) = args.get("working_directory").and_then(|v| v.as_str()) {
|
||||
let wd_path = std::path::PathBuf::from(wd);
|
||||
let workspace = &self.security.workspace_dir;
|
||||
let canonical_wd = match wd_path.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"working_directory '{}' does not exist or is not accessible",
|
||||
wd
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
let canonical_ws = match workspace.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"workspace directory '{}' does not exist or is not accessible",
|
||||
workspace.display()
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
if !canonical_wd.starts_with(&canonical_ws) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"working_directory '{}' is outside the workspace '{}'",
|
||||
wd,
|
||||
workspace.display()
|
||||
)),
|
||||
});
|
||||
}
|
||||
canonical_wd
|
||||
} else {
|
||||
self.security.workspace_dir.clone()
|
||||
};
|
||||
|
||||
// Record action budget
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build CLI command
|
||||
let mut cmd = Command::new("claude");
|
||||
cmd.arg("-p").arg(prompt);
|
||||
cmd.arg("--output-format").arg("json");
|
||||
|
||||
if !allowed_tools.is_empty() {
|
||||
for tool in &allowed_tools {
|
||||
cmd.arg("--allowedTools").arg(tool);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sp) = system_prompt {
|
||||
cmd.arg("--append-system-prompt").arg(sp);
|
||||
}
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
cmd.arg("--resume").arg(sid);
|
||||
}
|
||||
|
||||
if let Some(schema) = json_schema {
|
||||
let schema_str = serde_json::to_string(schema).unwrap_or_else(|_| "{}".to_string());
|
||||
cmd.arg("--json-schema").arg(schema_str);
|
||||
}
|
||||
|
||||
// Environment: clear everything, pass only safe vars + configured passthrough.
|
||||
// HOME is critical so `claude` finds its OAuth session in ~/.claude/
|
||||
cmd.env_clear();
|
||||
for var in SAFE_ENV_VARS {
|
||||
if let Ok(val) = std::env::var(var) {
|
||||
cmd.env(var, val);
|
||||
}
|
||||
}
|
||||
for var in &self.config.env_passthrough {
|
||||
let trimmed = var.trim();
|
||||
if !trimmed.is_empty() {
|
||||
if let Ok(val) = std::env::var(trimmed) {
|
||||
cmd.env(trimmed, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd.current_dir(&work_dir);
|
||||
// Execute with timeout — use kill_on_drop(true) so the child process
|
||||
// is automatically killed when the future is dropped on timeout,
|
||||
// preventing zombie processes.
|
||||
let timeout = Duration::from_secs(self.config.timeout_secs);
|
||||
cmd.kill_on_drop(true);
|
||||
|
||||
let result = tokio::time::timeout(timeout, cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
// Truncate to max_output_bytes with char-boundary safety
|
||||
if stdout.len() > self.config.max_output_bytes {
|
||||
let mut b = self.config.max_output_bytes.min(stdout.len());
|
||||
while b > 0 && !stdout.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stdout.truncate(b);
|
||||
stdout.push_str("\n... [output truncated]");
|
||||
}
|
||||
|
||||
// Try to parse JSON response and extract result + session_id
|
||||
if let Ok(json_resp) = serde_json::from_str::<serde_json::Value>(&stdout) {
|
||||
let result_text = json_resp
|
||||
.get("result")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let resp_session_id = json_resp
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let mut formatted = String::new();
|
||||
if result_text.is_empty() {
|
||||
// Fall back to full JSON if no "result" key
|
||||
formatted.push_str(&stdout);
|
||||
} else {
|
||||
formatted.push_str(result_text);
|
||||
}
|
||||
if !resp_session_id.is_empty() {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(formatted, "\n\n[session_id: {}]", resp_session_id);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: formatted,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// JSON parse failed — return raw stdout (defensive)
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
let err_msg = e.to_string();
|
||||
let msg = if err_msg.contains("No such file or directory")
|
||||
|| err_msg.contains("not found")
|
||||
|| err_msg.contains("cannot find")
|
||||
{
|
||||
"Claude Code CLI ('claude') not found in PATH. Install with: npm install -g @anthropic-ai/claude-code".into()
|
||||
} else {
|
||||
format!("Failed to execute claude: {e}")
|
||||
};
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(msg),
|
||||
})
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout — kill_on_drop(true) ensures the child is killed
|
||||
// when the future is dropped.
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Claude Code timed out after {}s and was killed",
|
||||
self.config.timeout_secs
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::ClaudeCodeConfig;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_config() -> ClaudeCodeConfig {
|
||||
ClaudeCodeConfig::default()
|
||||
}
|
||||
|
||||
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_tool_name() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
assert_eq!(tool.name(), "claude_code");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_tool_schema_has_prompt() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.expect("schema required should be an array")
|
||||
.contains(&json!("prompt")));
|
||||
// Optional params exist in properties
|
||||
assert!(schema["properties"]["allowed_tools"].is_object());
|
||||
assert!(schema["properties"]["system_prompt"].is_object());
|
||||
assert!(schema["properties"]["session_id"].is_object());
|
||||
assert!(schema["properties"]["json_schema"].is_object());
|
||||
assert!(schema["properties"]["working_directory"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_blocks_rate_limited() {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
max_actions_per_hour: 0,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ClaudeCodeTool::new(security, test_config());
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "hello"}))
|
||||
.await
|
||||
.expect("rate-limited should return a result");
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_blocks_readonly() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::ReadOnly), test_config());
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "hello"}))
|
||||
.await
|
||||
.expect("readonly should return a result");
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_missing_prompt_param() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_rejects_path_outside_workspace() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Full), test_config());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"prompt": "hello",
|
||||
"working_directory": "/etc"
|
||||
}))
|
||||
.await
|
||||
.expect("should return a result for path validation");
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("outside the workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_env_passthrough_defaults() {
|
||||
let config = ClaudeCodeConfig::default();
|
||||
assert!(
|
||||
config.env_passthrough.is_empty(),
|
||||
"env_passthrough should default to empty (Max subscription needs no API key)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_default_config_values() {
|
||||
let config = ClaudeCodeConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.timeout_secs, 600);
|
||||
assert_eq!(config.max_output_bytes, 2_097_152);
|
||||
assert!(config.system_prompt.is_none());
|
||||
assert_eq!(config.allowed_tools, vec!["Read", "Edit", "Bash", "Write"]);
|
||||
}
|
||||
}
|
||||
+344
-2
@@ -1,5 +1,6 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::agent::loop_::run_tool_call_loop;
|
||||
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
|
||||
use crate::config::{DelegateAgentConfig, DelegateToolConfig};
|
||||
use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
use crate::providers::{self, ChatMessage, Provider};
|
||||
@@ -9,6 +10,7 @@ use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -31,6 +33,8 @@ pub struct DelegateTool {
|
||||
multimodal_config: crate::config::MultimodalConfig,
|
||||
/// Global delegate tool config providing default timeout values.
|
||||
delegate_config: DelegateToolConfig,
|
||||
/// Workspace directory inherited from the root agent context.
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
@@ -62,6 +66,7 @@ impl DelegateTool {
|
||||
parent_tools: Arc::new(RwLock::new(Vec::new())),
|
||||
multimodal_config: crate::config::MultimodalConfig::default(),
|
||||
delegate_config: DelegateToolConfig::default(),
|
||||
workspace_dir: PathBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +104,7 @@ impl DelegateTool {
|
||||
parent_tools: Arc::new(RwLock::new(Vec::new())),
|
||||
multimodal_config: crate::config::MultimodalConfig::default(),
|
||||
delegate_config: DelegateToolConfig::default(),
|
||||
workspace_dir: PathBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +131,12 @@ impl DelegateTool {
|
||||
pub fn parent_tools_handle(&self) -> Arc<RwLock<Vec<Arc<dyn Tool>>>> {
|
||||
Arc::clone(&self.parent_tools)
|
||||
}
|
||||
|
||||
/// Attach the workspace directory for system prompt enrichment.
|
||||
pub fn with_workspace_dir(mut self, workspace_dir: PathBuf) -> Self {
|
||||
self.workspace_dir = workspace_dir;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -300,6 +312,11 @@ impl Tool for DelegateTool {
|
||||
.await;
|
||||
}
|
||||
|
||||
// Build enriched system prompt for non-agentic sub-agent.
|
||||
let enriched_system_prompt =
|
||||
self.build_enriched_system_prompt(agent_config, &[], &self.workspace_dir);
|
||||
let system_prompt_ref = enriched_system_prompt.as_deref();
|
||||
|
||||
// Wrap the provider call in a timeout to prevent indefinite blocking
|
||||
let timeout_secs = agent_config
|
||||
.timeout_secs
|
||||
@@ -307,7 +324,7 @@ impl Tool for DelegateTool {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
system_prompt_ref,
|
||||
&full_prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
@@ -355,6 +372,80 @@ impl Tool for DelegateTool {
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
/// Build an enriched system prompt for a sub-agent by composing structured
|
||||
/// operational sections (tools, skills, workspace, datetime, shell policy)
|
||||
/// with the operator-configured `system_prompt` string.
|
||||
fn build_enriched_system_prompt(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
sub_tools: &[Box<dyn Tool>],
|
||||
workspace_dir: &Path,
|
||||
) -> Option<String> {
|
||||
// Resolve skills directory: scoped if configured, otherwise workspace default.
|
||||
let skills_dir = agent_config
|
||||
.skills_directory
|
||||
.as_ref()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.map(|dir| workspace_dir.join(dir))
|
||||
.unwrap_or_else(|| crate::skills::skills_dir(workspace_dir));
|
||||
let skills = crate::skills::load_skills_from_directory(&skills_dir, false);
|
||||
|
||||
// Determine shell policy instructions when the `shell` tool is in the
|
||||
// effective tool list.
|
||||
let has_shell = sub_tools.iter().any(|t| t.name() == "shell");
|
||||
let shell_policy = if has_shell {
|
||||
"## Shell Policy\n\n\
|
||||
- Prefer non-destructive commands. Use `trash` over `rm` where possible.\n\
|
||||
- Do not run commands that exfiltrate data or modify system-critical paths.\n\
|
||||
- Avoid interactive commands that block on stdin.\n\
|
||||
- Quote paths that may contain spaces."
|
||||
.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Build structured operational context using SystemPromptBuilder sections.
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
model_name: &agent_config.model,
|
||||
tools: sub_tools,
|
||||
skills: &skills,
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode::Full,
|
||||
identity_config: None,
|
||||
dispatcher_instructions: "",
|
||||
tool_descriptions: None,
|
||||
security_summary: None,
|
||||
autonomy_level: crate::security::AutonomyLevel::default(),
|
||||
};
|
||||
|
||||
let builder = SystemPromptBuilder::default()
|
||||
.add_section(Box::new(crate::agent::prompt::ToolsSection))
|
||||
.add_section(Box::new(crate::agent::prompt::SafetySection))
|
||||
.add_section(Box::new(crate::agent::prompt::SkillsSection))
|
||||
.add_section(Box::new(crate::agent::prompt::WorkspaceSection))
|
||||
.add_section(Box::new(crate::agent::prompt::DateTimeSection));
|
||||
|
||||
let mut enriched = builder.build(&ctx).unwrap_or_default();
|
||||
|
||||
if !shell_policy.is_empty() {
|
||||
enriched.push_str(&shell_policy);
|
||||
enriched.push_str("\n\n");
|
||||
}
|
||||
|
||||
// Append the operator-configured system_prompt as the identity/role block.
|
||||
if let Some(operator_prompt) = agent_config.system_prompt.as_ref() {
|
||||
enriched.push_str(operator_prompt);
|
||||
enriched.push('\n');
|
||||
}
|
||||
|
||||
let trimmed = enriched.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_agentic(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
@@ -401,8 +492,12 @@ impl DelegateTool {
|
||||
});
|
||||
}
|
||||
|
||||
// Build enriched system prompt with tools, skills, workspace, datetime context.
|
||||
let enriched_system_prompt =
|
||||
self.build_enriched_system_prompt(agent_config, &sub_tools, &self.workspace_dir);
|
||||
|
||||
let mut history = Vec::new();
|
||||
if let Some(system_prompt) = agent_config.system_prompt.as_ref() {
|
||||
if let Some(system_prompt) = enriched_system_prompt.as_ref() {
|
||||
history.push(ChatMessage::system(system_prompt.clone()));
|
||||
}
|
||||
history.push(ChatMessage::user(full_prompt.to_string()));
|
||||
@@ -435,6 +530,7 @@ impl DelegateTool {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
@@ -548,6 +644,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
@@ -564,6 +661,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents
|
||||
@@ -719,6 +817,7 @@ mod tests {
|
||||
max_iterations,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,6 +928,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -937,6 +1037,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -974,6 +1075,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -1235,6 +1337,113 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_includes_tools_workspace_datetime() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: Some("You are a code reviewer.".to_string()),
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_enrich_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
std::fs::create_dir_all(&workspace).unwrap();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(prompt.contains("## Tools"), "should contain tools section");
|
||||
assert!(prompt.contains("echo_tool"), "should list allowed tools");
|
||||
assert!(
|
||||
prompt.contains("## Workspace"),
|
||||
"should contain workspace section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(&workspace.display().to_string()),
|
||||
"should contain workspace path"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("## Current Date & Time"),
|
||||
"should contain datetime section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("You are a code reviewer."),
|
||||
"should append operator system_prompt"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_includes_shell_policy_when_shell_present() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["shell".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
struct MockShellTool;
|
||||
#[async_trait]
|
||||
impl Tool for MockShellTool {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"Execute shell commands"
|
||||
}
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({"type": "object"})
|
||||
}
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: String::new(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockShellTool)];
|
||||
let workspace = std::env::temp_dir();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.to_path_buf());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("## Shell Policy"),
|
||||
"should contain shell policy when shell tool is present"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parent_tools_handle_returns_shared_reference() {
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security()).with_parent_tools(
|
||||
@@ -1265,6 +1474,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
assert_eq!(
|
||||
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
|
||||
@@ -1278,6 +1488,39 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_omits_shell_policy_without_shell_tool() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
let workspace = std::env::temp_dir();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.to_path_buf());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!prompt.contains("## Shell Policy"),
|
||||
"should not contain shell policy when shell tool is absent"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_timeout_values_are_respected() {
|
||||
let config = DelegateAgentConfig {
|
||||
@@ -1292,6 +1535,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(60),
|
||||
agentic_timeout_secs: Some(600),
|
||||
skills_directory: None,
|
||||
};
|
||||
assert_eq!(
|
||||
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
|
||||
@@ -1346,6 +1590,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(0),
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1372,6 +1617,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: Some(0),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1398,6 +1644,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(7200),
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1424,6 +1671,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: Some(5000),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1450,6 +1698,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(3600),
|
||||
agentic_timeout_secs: Some(3600),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
assert!(config.validate().is_ok());
|
||||
@@ -1472,8 +1721,101 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_loads_skills_from_scoped_directory() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_skills_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
let scoped_skills_dir = workspace.join("skills/code-review");
|
||||
std::fs::create_dir_all(scoped_skills_dir.join("lint-check")).unwrap();
|
||||
std::fs::write(
|
||||
scoped_skills_dir.join("lint-check/SKILL.toml"),
|
||||
"[skill]\nname = \"lint-check\"\ndescription = \"Run lint checks\"\nversion = \"1.0.0\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: Some("skills/code-review".to_string()),
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("lint-check"),
|
||||
"should contain skills from scoped directory"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_falls_back_to_default_skills_dir() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_fallback_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
let default_skills_dir = workspace.join("skills");
|
||||
std::fs::create_dir_all(default_skills_dir.join("deploy")).unwrap();
|
||||
std::fs::write(
|
||||
default_skills_dir.join("deploy/SKILL.toml"),
|
||||
"[skill]\nname = \"deploy\"\ndescription = \"Deploy safely\"\nversion = \"1.0.0\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("deploy"),
|
||||
"should contain skills from default workspace skills/ directory"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
}
|
||||
|
||||
+85
-13
@@ -23,7 +23,7 @@ impl Tool for MemoryRecallTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
|
||||
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance. Supports keyword search, time-only query (since/until), or both."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@@ -32,22 +32,76 @@ impl Tool for MemoryRecallTool {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Keywords or phrase to search for in memory"
|
||||
"description": "Keywords or phrase to search for in memory (optional if since/until provided)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default: 5)"
|
||||
},
|
||||
"since": {
|
||||
"type": "string",
|
||||
"description": "Filter memories created at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
|
||||
},
|
||||
"until": {
|
||||
"type": "string",
|
||||
"description": "Filter memories created at or before this time (RFC 3339)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let since = args.get("since").and_then(|v| v.as_str());
|
||||
let until = args.get("until").and_then(|v| v.as_str());
|
||||
|
||||
if query.trim().is_empty() && since.is_none() && until.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate date strings
|
||||
if let Some(s) = since {
|
||||
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'since' date: {s}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(u) = until {
|
||||
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'until' date: {u}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let (Some(s), Some(u)) = (since, until) {
|
||||
if let (Ok(s_dt), Ok(u_dt)) = (
|
||||
chrono::DateTime::parse_from_rfc3339(s),
|
||||
chrono::DateTime::parse_from_rfc3339(u),
|
||||
) {
|
||||
if s_dt >= u_dt {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'since' must be before 'until'".into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
@@ -55,10 +109,10 @@ impl Tool for MemoryRecallTool {
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(5, |v| v as usize);
|
||||
|
||||
match self.memory.recall(query, limit, None).await {
|
||||
match self.memory.recall(query, limit, None, since, until).await {
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No memories found matching that query.".into(),
|
||||
output: "No memories found.".into(),
|
||||
error: None,
|
||||
}),
|
||||
Ok(entries) => {
|
||||
@@ -150,11 +204,29 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_missing_query() {
|
||||
async fn recall_requires_query_or_time() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("at least"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_time_only_returns_entries() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
// Time-only: since far in past
|
||||
let result = tool
|
||||
.execute(json!({"since": "2020-01-01T00:00:00Z", "limit": 5}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Found 1"));
|
||||
assert!(result.output.contains("Rust"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+13
-1
@@ -20,6 +20,7 @@ pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
pub mod browser_open;
|
||||
pub mod calculator;
|
||||
pub mod claude_code;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
pub mod cloud_patterns;
|
||||
@@ -92,6 +93,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use claude_code::ClaudeCodeTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
pub use composio::ComposioTool;
|
||||
@@ -528,6 +530,14 @@ pub fn all_tools_with_runtime(
|
||||
);
|
||||
}
|
||||
|
||||
// Claude Code delegation tool
|
||||
if root_config.claude_code.enabled {
|
||||
tool_arcs.push(Arc::new(ClaudeCodeTool::new(
|
||||
security.clone(),
|
||||
root_config.claude_code.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// PDF extraction (feature-gated at compile time via rag-pdf)
|
||||
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
|
||||
|
||||
@@ -669,7 +679,8 @@ pub fn all_tools_with_runtime(
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone())
|
||||
.with_delegate_config(root_config.delegate.clone());
|
||||
.with_delegate_config(root_config.delegate.clone())
|
||||
.with_workspace_dir(workspace_dir.to_path_buf());
|
||||
tool_arcs.push(Arc::new(delegate_tool));
|
||||
Some(parent_tools)
|
||||
};
|
||||
@@ -1000,6 +1011,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -707,6 +707,7 @@ impl ModelRoutingConfigTool {
|
||||
max_iterations: DEFAULT_AGENT_MAX_ITERATIONS,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
});
|
||||
|
||||
next_agent.provider = provider;
|
||||
|
||||
@@ -568,6 +568,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
@@ -584,6 +585,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents
|
||||
|
||||
@@ -100,6 +100,10 @@ fn gateway_config_defaults_are_secure() {
|
||||
!gw.trust_forwarded_headers,
|
||||
"forwarded headers should be untrusted by default"
|
||||
);
|
||||
assert!(
|
||||
gw.path_prefix.is_none(),
|
||||
"path_prefix should default to None"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -124,6 +128,7 @@ fn gateway_config_toml_roundtrip() {
|
||||
host: "0.0.0.0".into(),
|
||||
require_pairing: false,
|
||||
pair_rate_limit_per_minute: 5,
|
||||
path_prefix: Some("/zeroclaw".into()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -134,6 +139,7 @@ fn gateway_config_toml_roundtrip() {
|
||||
assert_eq!(parsed.host, "0.0.0.0");
|
||||
assert!(!parsed.require_pairing);
|
||||
assert_eq!(parsed.pair_rate_limit_per_minute, 5);
|
||||
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -163,6 +169,93 @@ port = 9090
|
||||
assert_eq!(parsed.gateway.pair_rate_limit_per_minute, 10);
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// GatewayConfig path_prefix validation
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_missing_leading_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("zeroclaw".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must start with '/'"),
|
||||
"expected leading-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_trailing_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("/zeroclaw/".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not end with '/'"),
|
||||
"expected trailing-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_bare_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("/".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not end with '/'"),
|
||||
"expected bare-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_accepts_valid_prefixes() {
|
||||
for prefix in ["/zeroclaw", "/apps/zeroclaw", "/api/hassio_ingress/abc123"] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
config
|
||||
.validate()
|
||||
.unwrap_or_else(|e| panic!("prefix {prefix:?} should be valid, got: {e}"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_unsafe_characters() {
|
||||
for prefix in [
|
||||
"/zero claw",
|
||||
"/zero<claw",
|
||||
"/zero>claw",
|
||||
"/zero\"claw",
|
||||
"/zero?query",
|
||||
"/zero#frag",
|
||||
] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("invalid character"),
|
||||
"prefix {prefix:?} should be rejected, got: {err}"
|
||||
);
|
||||
}
|
||||
// Leading/trailing whitespace is rejected by the starts_with('/') or
|
||||
// invalid-character check — either way it must not pass validation.
|
||||
for prefix in [" /zeroclaw ", " /zeroclaw"] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
assert!(
|
||||
config.validate().is_err(),
|
||||
"whitespace-padded prefix {prefix:?} should be rejected"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_accepts_none() {
|
||||
let config = Config::default();
|
||||
assert!(config.gateway.path_prefix.is_none());
|
||||
config
|
||||
.validate()
|
||||
.expect("absent path_prefix should be valid");
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// SecurityConfig boundary tests
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -147,8 +147,8 @@ async fn compare_recall_quality() {
|
||||
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||
|
||||
for (query, desc) in &queries {
|
||||
let sq_results = sq.recall(query, 10, None).await.unwrap();
|
||||
let md_results = md.recall(query, 10, None).await.unwrap();
|
||||
let sq_results = sq.recall(query, 10, None, None, None).await.unwrap();
|
||||
let md_results = md.recall(query, 10, None, None, None).await.unwrap();
|
||||
|
||||
println!(" Query: \"{query}\" — {desc}");
|
||||
println!(" SQLite: {} results", sq_results.len());
|
||||
@@ -202,11 +202,17 @@ async fn compare_recall_speed() {
|
||||
|
||||
// Benchmark recall
|
||||
let start = Instant::now();
|
||||
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
|
||||
let sq_results = sq
|
||||
.recall("Rust systems", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let sq_dur = start.elapsed();
|
||||
|
||||
let start = Instant::now();
|
||||
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
|
||||
let md_results = md
|
||||
.recall("Rust systems", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let md_dur = start.elapsed();
|
||||
|
||||
println!("\n============================================================");
|
||||
@@ -312,7 +318,7 @@ async fn compare_upsert() {
|
||||
let md_count = md.count().await.unwrap();
|
||||
|
||||
let sq_entry = sq.get("pref").await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5, None, None, None).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("UPSERT (store same key twice):");
|
||||
|
||||
@@ -216,7 +216,10 @@ async fn sqlite_memory_recall_returns_relevant_results() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust programming", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("Rust programming", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!results.is_empty(), "recall should find matching entries");
|
||||
// The Rust-related entry should be in results
|
||||
assert!(
|
||||
@@ -241,7 +244,10 @@ async fn sqlite_memory_recall_respects_limit() {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("test content", 3, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("test content", 3, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
results.len() <= 3,
|
||||
"recall should respect limit of 3, got {}",
|
||||
@@ -250,7 +256,7 @@ async fn sqlite_memory_recall_respects_limit() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||
async fn sqlite_memory_recall_empty_query_returns_recent_entries() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
||||
@@ -258,8 +264,10 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty(), "empty query should return no results");
|
||||
// Empty query uses time-only path: returns recent entries by updated_at
|
||||
let results = mem.recall("", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1, "empty query should return recent entries");
|
||||
assert_eq!(results[0].key, "fact");
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
+2
-1
@@ -16,6 +16,7 @@ import Pairing from './pages/Pairing';
|
||||
import { AuthProvider, useAuth } from './hooks/useAuth';
|
||||
import { DraftContext, useDraftStore } from './hooks/useDraft';
|
||||
import { setLocale, type Locale } from './lib/i18n';
|
||||
import { basePath } from './lib/basePath';
|
||||
import { getAdminPairCode } from './lib/api';
|
||||
|
||||
// Locale context
|
||||
@@ -131,7 +132,7 @@ function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> })
|
||||
|
||||
<div className="text-center mb-8">
|
||||
<img
|
||||
src="/_app/zeroclaw-trans.png"
|
||||
src={`${basePath}/_app/zeroclaw-trans.png`}
|
||||
alt="ZeroClaw"
|
||||
className="h-20 w-20 rounded-2xl object-cover mx-auto mb-4 animate-float"
|
||||
onError={(e) => { e.currentTarget.style.display = 'none'; }}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { NavLink } from 'react-router-dom';
|
||||
import { basePath } from '../../lib/basePath';
|
||||
import {
|
||||
LayoutDashboard,
|
||||
MessageSquare,
|
||||
@@ -34,7 +35,7 @@ export default function Sidebar() {
|
||||
<div className="relative shrink-0">
|
||||
<div className="absolute -inset-1.5 rounded-xl" style={{ background: 'linear-gradient(135deg, rgba(var(--pc-accent-rgb), 0.15), rgba(var(--pc-accent-rgb), 0.05))' }} />
|
||||
<img
|
||||
src="/_app/zeroclaw-trans.png"
|
||||
src={`${basePath}/_app/zeroclaw-trans.png`}
|
||||
alt="ZeroClaw"
|
||||
className="relative h-9 w-9 rounded-xl object-cover"
|
||||
onError={(e) => {
|
||||
|
||||
+4
-3
@@ -11,6 +11,7 @@ import type {
|
||||
HealthSnapshot,
|
||||
} from '../types/api';
|
||||
import { clearToken, getToken, setToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Base fetch wrapper
|
||||
@@ -42,7 +43,7 @@ export async function apiFetch<T = unknown>(
|
||||
headers.set('Content-Type', 'application/json');
|
||||
}
|
||||
|
||||
const response = await fetch(path, { ...options, headers });
|
||||
const response = await fetch(`${basePath}${path}`, { ...options, headers });
|
||||
|
||||
if (response.status === 401) {
|
||||
clearToken();
|
||||
@@ -78,7 +79,7 @@ function unwrapField<T>(value: T | Record<string, T>, key: string): T {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function pair(code: string): Promise<{ token: string }> {
|
||||
const response = await fetch('/pair', {
|
||||
const response = await fetch(`${basePath}/pair`, {
|
||||
method: 'POST',
|
||||
headers: { 'X-Pairing-Code': code },
|
||||
});
|
||||
@@ -106,7 +107,7 @@ export async function getAdminPairCode(): Promise<{ pairing_code: string | null;
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function getPublicHealth(): Promise<{ require_pairing: boolean; paired: boolean }> {
|
||||
const response = await fetch('/health');
|
||||
const response = await fetch(`${basePath}/health`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Health check failed (${response.status})`);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
// Runtime base path injected by the Rust gateway into index.html.
|
||||
// Allows the SPA to work under a reverse-proxy path prefix.
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
__ZEROCLAW_BASE__?: string;
|
||||
}
|
||||
}
|
||||
|
||||
/** Gateway path prefix (e.g. "/zeroclaw"), or empty string when served at root. */
|
||||
export const basePath: string = (window.__ZEROCLAW_BASE__ ?? '').replace(/\/+$/, '');
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import type { SSEEvent } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
|
||||
export type SSEEventHandler = (event: SSEEvent) => void;
|
||||
export type SSEErrorHandler = (error: Event | Error) => void;
|
||||
@@ -41,7 +42,7 @@ export class SSEClient {
|
||||
private readonly autoReconnect: boolean;
|
||||
|
||||
constructor(options: SSEClientOptions = {}) {
|
||||
this.path = options.path ?? '/api/events';
|
||||
this.path = options.path ?? `${basePath}/api/events`;
|
||||
this.reconnectDelay = options.reconnectDelay ?? DEFAULT_RECONNECT_DELAY;
|
||||
this.maxReconnectDelay = options.maxReconnectDelay ?? MAX_RECONNECT_DELAY;
|
||||
this.autoReconnect = options.autoReconnect ?? true;
|
||||
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import type { WsMessage } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
import { generateUUID } from './uuid';
|
||||
|
||||
export type WsMessageHandler = (msg: WsMessage) => void;
|
||||
@@ -69,7 +70,7 @@ export class WebSocketClient {
|
||||
const params = new URLSearchParams();
|
||||
if (token) params.set('token', token);
|
||||
params.set('session_id', sessionId);
|
||||
const url = `${this.baseUrl}/ws/chat?${params.toString()}`;
|
||||
const url = `${this.baseUrl}${basePath}/ws/chat?${params.toString()}`;
|
||||
|
||||
const protocols: string[] = ['zeroclaw.v1'];
|
||||
if (token) protocols.push(`bearer.${token}`);
|
||||
|
||||
+3
-2
@@ -2,12 +2,13 @@ import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import { BrowserRouter } from 'react-router-dom';
|
||||
import App from './App';
|
||||
import { basePath } from './lib/basePath';
|
||||
import './index.css';
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||
<React.StrictMode>
|
||||
{/* Vite base '/_app/' scopes static asset URLs only; app routes stay rooted at '/' for SPA fallback. */}
|
||||
<BrowserRouter basename="/">
|
||||
{/* basePath is injected by the Rust gateway at serve time for reverse-proxy prefix support. */}
|
||||
<BrowserRouter basename={basePath || '/'}>
|
||||
<App />
|
||||
</BrowserRouter>
|
||||
</React.StrictMode>
|
||||
|
||||
Reference in New Issue
Block a user