Compare commits

...

20 Commits

Author SHA1 Message Date
SimianAstronaut7 87b5bca449 feat(config): add configurable pacing controls for slow/local LLM workloads (#3343)
* feat(config): add configurable pacing controls for slow/local LLM workloads (#2963)

Add a new `[pacing]` config section with four opt-in parameters that
let users tune timeout and loop-detection behavior for local LLMs
(Ollama, llama.cpp, vLLM) without disabling safety features entirely:

- `step_timeout_secs`: per-step LLM inference timeout independent of
  the overall message budget, catching hung model responses early.
- `loop_detection_min_elapsed_secs`: time-gated loop detection that
  only activates after a configurable grace period, avoiding false
  positives on long-running browser/research workflows.
- `loop_ignore_tools`: per-tool loop-detection exclusions so tools
  like `browser_screenshot` that structurally resemble loops are not
  counted toward identical-output detection.
- `message_timeout_scale_max`: overrides the hardcoded 4x ceiling in
  the channel message timeout scaling formula.

All parameters are strictly optional with no effect when absent,
preserving full backwards compatibility.

Closes #2963

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(config): add missing pacing fields in tests and call sites

* fix(config): add pacing arg to remaining cost-tracking test call sites

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-21 08:54:08 -04:00
Argenis be40c0c5a5 Merge pull request #4145 from zeroclaw-labs/feat/gateway-path-prefix
feat(gateway): add path_prefix for reverse-proxy deployments
2026-03-21 08:48:56 -04:00
argenis de la rosa 6527871928 fix: add path_prefix to test AppState in gateway/api.rs 2026-03-21 08:14:28 -04:00
argenis de la rosa 0bda80de9c feat(gateway): add path_prefix for reverse-proxy deployments
Adopted from #3709 by @slayer with minor cleanup.
Supersedes #3709
2026-03-21 08:14:28 -04:00
Argenis 02f57f4d98 Merge pull request #4144 from zeroclaw-labs/feat/claude-code-tool
feat(tools): add ClaudeCodeTool for two-tier agent delegation
2026-03-21 08:14:19 -04:00
Argenis ef83dd44d7 Merge pull request #4146 from zeroclaw-labs/feat/memory-recall-time-range
feat(memory): add time range filter to recall (since/until)
2026-03-21 08:14:12 -04:00
Argenis a986b6b912 fix(install): detect un-accepted Xcode license + bump to v0.5.5 (#4147)
* fix(install): detect un-accepted Xcode license before build

Add an xcrun check after verifying Xcode CLT is installed. When the
Xcode/CLT license has not been accepted, cc exits with code 69 and
the build fails with a cryptic linker error. This surfaces a clear
message telling the user to run `sudo xcodebuild -license accept`.

* chore(release): bump version to v0.5.5

Update version across all distribution manifests:
- Cargo.toml and Cargo.lock
- dist/aur/PKGBUILD and .SRCINFO
- dist/scoop/zeroclaw.json
2026-03-21 08:09:27 -04:00
SimianAstronaut7 b6b1186e3b feat(channel): add per-channel proxy_url support for HTTP/SOCKS5 proxies (#3345)
* feat(channel): add per-channel proxy_url support for HTTP/SOCKS5 proxies

Allow each channel to optionally specify a `proxy_url` in its config,
enabling users behind restrictive networks to route channel traffic
through HTTP or SOCKS5 proxies. When set, the per-channel proxy takes
precedence over the global `[proxy]` config; when absent, the channel
falls back to the existing runtime proxy behavior.

Adds `proxy_url: Option<String>` to all 12 channel config structs
(Telegram, Discord, Slack, Mattermost, Signal, WhatsApp, Wati,
NextcloudTalk, DingTalk, QQ, Lark, Feishu) and introduces
`build_channel_proxy_client`, `build_channel_proxy_client_with_timeouts`,
and `apply_channel_proxy_to_builder` helpers that normalize proxy URLs
and integrate with the existing client cache.
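As a hypothetical TOML sketch of the precedence described above (the channel table name and neighboring keys are illustrative; only `proxy_url` and the global `[proxy]` section come from this commit):

```toml
# Global proxy — the fallback used when a channel sets no proxy_url.
# The key name "url" here is an assumption, not confirmed by the commit.
[proxy]
url = "socks5://127.0.0.1:1080"

# Hypothetical channel entry: its proxy_url takes precedence over [proxy].
[channels.telegram]
bot_token = "..."
proxy_url = "http://proxy.corp.example:3128"
```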

Closes #3262

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channel): add missing proxy_url fields in test initializers

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-21 07:53:20 -04:00
SimianAstronaut7 00dc0c8670 feat(tool): enrich delegate sub-agent system prompt and add skills_directory config key (#3344)
* feat(tool): enrich delegate sub-agent system prompt and add skills_directory config key (#3046)

Sub-agents configured under [agents.<name>] previously received only the
bare system_prompt string. They now receive a structured system prompt
containing: tools section (allowed tools with parameters and invocation
protocol), skills section (from scoped or default directory), workspace
path, current date/time, safety constraints, and shell policy when shell
is in the effective tool list.

Add optional skills_directory field to DelegateAgentConfig for per-agent
scoped skill loading. When unset, falls back to default workspace
skills/ directory.

Closes #3046

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(tools): add missing fields after rebase

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-21 07:53:02 -04:00
argenis de la rosa 43f2a0a815 fix: add ClaudeCodeConfig to config re-exports and fix formatting 2026-03-21 07:51:36 -04:00
argenis de la rosa 50b5bd4d73 ci: retrigger CI after stuck runners 2026-03-21 07:46:34 -04:00
argenis de la rosa 8c074870a1 fix(memory): replace redundant closures with function references
Clippy flagged `.map(|s| chrono::DateTime::parse_from_rfc3339(s))` as
redundant — use `.map(chrono::DateTime::parse_from_rfc3339)` directly.
2026-03-21 07:46:34 -04:00
argenis de la rosa 61d1841ce3 fix: update gateway mock Memory impls with since/until params
Both test mock implementations of Memory::recall() in gateway/mod.rs
were missing the new since/until parameters.
2026-03-21 07:46:34 -04:00
argenis de la rosa eb396cf38f feat(memory): add time range filter to recall (since/until)
Adopted from #3705 by @fangxueshun with fixes:
- Added input validation for date strings (RFC 3339)
- Used chrono DateTime comparison instead of string comparison
- Added since < until validation
- Updated mem0 backend
Supersedes #3705
2026-03-21 07:46:34 -04:00
argenis de la rosa 9f1657b9be fix(tools): use kill_on_drop for ClaudeCodeTool subprocess timeout 2026-03-21 07:46:24 -04:00
argenis de la rosa 8fecd4286c fix(tools): use kill_on_drop for ClaudeCodeTool subprocess timeout
Fixes E0382 borrow-after-move error: wait_with_output() consumed the
child handle, making child.kill() in the timeout branch invalid.
Use kill_on_drop(true) with cmd.output() instead.
2026-03-21 07:46:24 -04:00
argenis de la rosa df21d92da3 feat(tools): add ClaudeCodeTool for two-tier agent delegation
Adopted from #3748 by @ilyasubkhankulov with fixes:
- Removed unused _runtime field
- Fixed subprocess timeout handling
- Excluded unrelated Slack threading and Dockerfile changes

Closes #3748 (superseded)
2026-03-21 07:46:24 -04:00
Argenis 8d65924704 fix(channels): add cost tracking and enforcement to all channels (#4143)
Adds per-channel cost tracking via task-local context in the tool call
loop. Budget enforcement blocks further API calls when limits are
exceeded. Resolves merge conflicts with model-switch retry loop,
reply_target parameter, and autonomy level additions on master.

Supersedes #3758
2026-03-21 07:37:15 -04:00
Argenis 756c3cadff feat(transcription): add LocalWhisperProvider for self-hosted STT (TDD) (#4141)
Self-hosted Whisper-compatible STT provider that POSTs audio to a
configurable HTTP endpoint (e.g. faster-whisper over WireGuard). Audio
never leaves the platform perimeter.

Implemented via red/green TDD cycles:
  Wave 1 — config schema: LocalWhisperConfig struct, local_whisper field
    on TranscriptionConfig + Default impl, re-export in config/mod.rs
  Wave 2 — from_config validation: url non-empty, url parseable, bearer_token
    non-empty, max_audio_bytes > 0, timeout_secs > 0; returns Result<Self>
  Wave 3 — manager integration: registration with ? propagation (not if let Ok
    — credentials come directly from config, no env-var fallback; present
    section with bad values is a hard error, not a silent skip)
  Wave 4 — transcribe(): resolve_audio_format() extracted from validate_audio()
    so LocalWhisperProvider can resolve MIME without the 25 MB cloud cap;
    size check + format resolution before HTTP send
  Wave 5 — HTTP mock tests: success response, bearer auth header, 503 error

33 tests (20 baseline + 13 new), all passing. Clippy clean.
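A hypothetical config sketch matching the fields validated in Wave 2 (the section path and endpoint URL are assumptions, not confirmed by this commit):

```toml
[transcription.local_whisper]   # section path assumed from the field name
url = "http://10.8.0.2:8000/v1/audio/transcriptions"  # must be non-empty and parseable
bearer_token = "change-me"      # must be non-empty (no env-var fallback)
max_audio_bytes = 52428800      # must be > 0; not subject to the 25 MB cloud cap
timeout_secs = 120              # must be > 0
```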

Co-authored-by: Nim G <theredspoon@users.noreply.github.com>
2026-03-21 07:15:36 -04:00
Argenis ee870028ff feat(channel): use Slack native markdown blocks for rich formatting (#4142)
Slack's Block Kit supports a native `markdown` block type that accepts
standard Markdown and handles rendering. This removes the need for a
custom Markdown-to-mrkdwn converter. Messages over 12,000 chars fall
back to plain text.

Co-authored-by: Joe Hoyle <joehoyle@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 07:12:27 -04:00
59 changed files with 3295 additions and 251 deletions
Generated
+1 -1
@@ -9203,7 +9203,7 @@ dependencies = [
[[package]]
name = "zeroclawlabs"
-version = "0.5.4"
+version = "0.5.5"
dependencies = [
"aardvark-sys",
"anyhow",
+1 -1
@@ -4,7 +4,7 @@ resolver = "2"
[package]
name = "zeroclawlabs"
-version = "0.5.4"
+version = "0.5.5"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"
+1 -1
@@ -263,7 +263,7 @@ fn bench_memory_operations(c: &mut Criterion) {
c.bench_function("memory_recall_top10", |b| {
b.iter(|| {
rt.block_on(async {
-mem.recall(black_box("zeroclaw agent"), 10, None)
+mem.recall(black_box("zeroclaw agent"), 10, None, None, None)
.await
.unwrap()
})
+2 -2
@@ -1,6 +1,6 @@
pkgbase = zeroclaw
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
-pkgver = 0.5.4
+pkgver = 0.5.5
pkgrel = 1
url = https://github.com/zeroclaw-labs/zeroclaw
arch = x86_64
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
makedepends = git
depends = gcc-libs
depends = openssl
-source = zeroclaw-0.5.4.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.4.tar.gz
+source = zeroclaw-0.5.5.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.5.tar.gz
sha256sums = SKIP
pkgname = zeroclaw
+1 -1
@@ -1,6 +1,6 @@
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
pkgname=zeroclaw
-pkgver=0.5.4
+pkgver=0.5.5
pkgrel=1
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
arch=('x86_64')
+2 -2
@@ -1,11 +1,11 @@
{
-"version": "0.5.4",
+"version": "0.5.5",
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
"license": "MIT|Apache-2.0",
"architecture": {
"64bit": {
-"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.4/zeroclaw-x86_64-pc-windows-msvc.zip",
+"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.5/zeroclaw-x86_64-pc-windows-msvc.zip",
"hash": "",
"bin": "zeroclaw.exe"
}
+47 -2
@@ -122,6 +122,34 @@ tools = ["mcp_browser_*"]
keywords = ["browse", "navigate", "open url", "screenshot"]
```
## `[pacing]`
Pacing controls for slow/local LLM workloads (Ollama, llama.cpp, vLLM). All keys are optional; when absent, existing behavior is preserved.
| Key | Default | Purpose |
|---|---|---|
| `step_timeout_secs` | _none_ | Per-step timeout: maximum seconds for a single LLM inference turn. Aborts a truly hung model response early with a step-timeout error instead of waiting out the full message timeout budget |
| `loop_detection_min_elapsed_secs` | _none_ | Grace period in seconds before identical-output loop detection starts counting. Iterations inside the grace period never count toward an identical-output abort; when unset, identical-output detection stays disabled |
| `loop_ignore_tools` | `[]` | Tool names excluded from identical-output loop detection. Useful for browser workflows where `browser_screenshot` structurally resembles a loop |
| `message_timeout_scale_max` | `4` | Override for the hardcoded timeout scaling cap. The channel message timeout budget is `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)` |
Notes:
- These settings are intended for local/slow LLM deployments. Cloud-provider users typically do not need them.
- `step_timeout_secs` operates independently of the total channel message timeout budget. When a single inference step exceeds it, the loop aborts immediately with a step-timeout error rather than waiting out the full message budget.
- `loop_detection_min_elapsed_secs` is a grace period, not a delay on the task itself: identical-output detection only begins counting once the task has run at least this long. When unset, identical-output detection is disabled; `max_iterations` and the overall timeout still bound runaway loops.
- `loop_ignore_tools` only suppresses tool-output-based loop detection for the listed tools. Other safety features (max iterations, overall timeout) remain active.
- `message_timeout_scale_max` must be >= 1. Setting it higher than `max_tool_iterations` has no additional effect (the formula uses `min()`).
- Example configuration for a slow local Ollama deployment:
```toml
[pacing]
step_timeout_secs = 120
loop_detection_min_elapsed_secs = 60
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
message_timeout_scale_max = 8
```
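The identical-output counting rule behind these settings can be sketched standalone (a minimal sketch with illustrative inputs; the real loop hashes the concatenation of non-ignored tool outputs per iteration, and only once the time gate has opened):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hash one round of tool output, mirroring the identical-output check.
fn output_hash(s: &str) -> u64 {
    let mut h = DefaultHasher::new();
    s.hash(&mut h);
    h.finish()
}

fn main() {
    // Illustrative outputs: the last four rounds repeat the same content.
    let rounds = ["page A", "page B", "page B", "page B", "page B"];
    let mut last: Option<u64> = None;
    let mut consecutive = 0usize;
    let mut aborted = false;
    for out in rounds {
        let h = output_hash(out);
        if last == Some(h) {
            consecutive += 1;
        } else {
            consecutive = 0;
            last = Some(h);
        }
        if consecutive >= 3 {
            aborted = true; // runaway loop: 3+ consecutive identical outputs
            break;
        }
    }
    assert!(aborted);
    println!("identical-output abort after {consecutive} consecutive repeats");
}
```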
## `[security.otp]`
| Key | Default | Purpose |
@@ -185,12 +213,15 @@ Delegate sub-agent configurations. Each key under `[agents]` defines a named sub
| `max_iterations` | `10` | Max tool-call iterations for agentic mode |
| `timeout_secs` | `120` | Timeout in seconds for non-agentic provider calls (1–3600) |
| `agentic_timeout_secs` | `300` | Timeout in seconds for agentic sub-agent loops (1–3600) |
| `skills_directory` | unset | Optional skills directory path (workspace-relative) for scoped skill loading |
Notes:
- `agentic = false` preserves existing single prompt→response delegate behavior.
- `agentic = true` requires at least one matching entry in `allowed_tools`.
- The `delegate` tool is excluded from sub-agent allowlists to prevent re-entrant delegation loops.
- Sub-agents receive an enriched system prompt containing: tools section (allowed tools with parameters), skills section (from scoped or default directory), workspace path, current date/time, safety constraints, and shell policy when `shell` is in the effective tool list.
- When `skills_directory` is unset or empty, the sub-agent loads skills from the default workspace `skills/` directory. When set, skills are loaded exclusively from that directory (relative to workspace root), enabling per-agent scoped skill sets.
```toml
[agents.researcher]
@@ -208,6 +239,14 @@ provider = "ollama"
model = "qwen2.5-coder:32b"
temperature = 0.2
timeout_secs = 60
[agents.code_reviewer]
provider = "anthropic"
model = "claude-opus-4-5"
system_prompt = "You are an expert code reviewer focused on security and performance."
agentic = true
allowed_tools = ["file_read", "shell"]
skills_directory = "skills/code-review"
```
## `[runtime]`
@@ -414,6 +453,12 @@ Notes:
| `port` | `42617` | gateway listen port |
| `require_pairing` | `true` | require pairing before bearer auth |
| `allow_public_bind` | `false` | block accidental public exposure |
| `path_prefix` | _(none)_ | URL path prefix for reverse-proxy deployments (e.g. `"/zeroclaw"`) |
When deploying behind a reverse proxy that maps ZeroClaw to a sub-path,
set `path_prefix` to that sub-path (e.g. `"/zeroclaw"`). All gateway
routes will be served under this prefix. The value must start with `/`
and must not end with `/`.
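The stated constraint can be sketched as a small standalone check (a hypothetical helper for illustration, not the gateway's actual validation code):

```rust
/// Hypothetical sketch of the documented path_prefix rules:
/// non-empty, starts with '/', does not end with '/'.
fn valid_path_prefix(prefix: &str) -> bool {
    !prefix.is_empty() && prefix.starts_with('/') && !prefix.ends_with('/')
}

fn main() {
    assert!(valid_path_prefix("/zeroclaw"));
    assert!(!valid_path_prefix("zeroclaw"));   // missing leading slash
    assert!(!valid_path_prefix("/zeroclaw/")); // trailing slash rejected
    println!("path_prefix checks pass");
}
```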
## `[autonomy]`
@@ -586,7 +631,7 @@ Top-level channel options are configured under `channels_config`.
| Key | Default | Purpose |
|---|---|---|
-| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) |
+| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x, overridable via `[pacing].message_timeout_scale_max`) |
Examples:
@@ -601,7 +646,7 @@ Examples:
Notes:
- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs.
-- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`.
+- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, cap)` and a minimum of `1`. The default cap is `4`; override with `[pacing].message_timeout_scale_max`.
- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete.
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
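The budget arithmetic in the notes above can be sketched standalone (function name is illustrative; the clamping of sub-30 values is omitted here):

```rust
/// Minimal sketch of the documented formula:
/// budget = message_timeout_secs * max(1, min(max_tool_iterations, cap))
fn timeout_budget_secs(message_timeout_secs: u64, max_tool_iterations: u64, cap: u64) -> u64 {
    let scale = max_tool_iterations.min(cap).max(1);
    message_timeout_secs * scale
}

fn main() {
    // Defaults: 300s base, cap of 4.
    assert_eq!(timeout_budget_secs(300, 10, 4), 1200); // capped at 4x
    assert_eq!(timeout_budget_secs(300, 2, 4), 600);   // fewer iterations than cap
    // Raised cap via [pacing].message_timeout_scale_max = 8:
    assert_eq!(timeout_budget_secs(300, 10, 8), 2400);
    println!("budget checks pass");
}
```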
+7
@@ -568,6 +568,13 @@ then re-run bootstrap.
MSG
exit 0
fi
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
if ! /usr/bin/xcrun --show-sdk-path >/dev/null 2>&1; then
warn "Xcode license has not been accepted. Run:"
warn " sudo xcodebuild -license accept"
warn "then re-run this installer."
exit 1
fi
if ! have_cmd git; then
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
fi
+451 -8
@@ -1,5 +1,8 @@
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
use crate::config::schema::ModelPricing;
use crate::config::Config;
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
use crate::cost::CostTracker;
use crate::i18n::ToolDescriptions;
use crate::memory::{self, Memory, MemoryCategory};
use crate::multimodal;
@@ -23,6 +26,108 @@ use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
// ── Cost tracking via task-local ──
/// Context for cost tracking within the tool call loop.
/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
#[derive(Clone)]
pub(crate) struct ToolLoopCostTrackingContext {
pub tracker: Arc<CostTracker>,
pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
}
impl ToolLoopCostTrackingContext {
pub(crate) fn new(
tracker: Arc<CostTracker>,
prices: Arc<std::collections::HashMap<String, ModelPricing>>,
) -> Self {
Self { tracker, prices }
}
}
tokio::task_local! {
pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
/// 3-tier model pricing lookup:
/// 1. Direct model name
/// 2. Qualified `provider/model`
/// 3. Suffix after last `/`
fn lookup_model_pricing<'a>(
prices: &'a std::collections::HashMap<String, ModelPricing>,
provider_name: &str,
model: &str,
) -> Option<&'a ModelPricing> {
prices
.get(model)
.or_else(|| prices.get(&format!("{provider_name}/{model}")))
.or_else(|| {
model
.rsplit_once('/')
.and_then(|(_, suffix)| prices.get(suffix))
})
}
/// Record token usage from an LLM response via the task-local cost tracker.
/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
fn record_tool_loop_cost_usage(
provider_name: &str,
model: &str,
usage: &crate::providers::traits::TokenUsage,
) -> Option<(u64, f64)> {
let input_tokens = usage.input_tokens.unwrap_or(0);
let output_tokens = usage.output_tokens.unwrap_or(0);
let total_tokens = input_tokens.saturating_add(output_tokens);
if total_tokens == 0 {
return None;
}
let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten()?;
let pricing = lookup_model_pricing(&ctx.prices, provider_name, model);
let cost_usage = CostTokenUsage::new(
model,
input_tokens,
output_tokens,
pricing.map_or(0.0, |entry| entry.input),
pricing.map_or(0.0, |entry| entry.output),
);
if pricing.is_none() {
tracing::debug!(
provider = provider_name,
model,
"Cost tracking recorded token usage with zero pricing (no pricing entry found)"
);
}
if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
tracing::warn!(
provider = provider_name,
model,
"Failed to record cost tracking usage: {error}"
);
}
Some((cost_usage.total_tokens, cost_usage.cost_usd))
}
/// Check budget before an LLM call. Returns `None` when no cost tracking
/// context is scoped (tests, delegate, CLI without cost config).
pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
TOOL_LOOP_COST_TRACKING_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten()
.map(|ctx| {
ctx.tracker
.check_budget(0.0)
.unwrap_or(BudgetCheck::Allowed)
})
}
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
const STREAM_CHUNK_MIN_CHARS: usize = 80;
@@ -465,7 +570,7 @@ async fn build_context(
let mut context = String::new();
// Pull relevant memories for this message
-if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
+if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
let relevant: Vec<_> = entries
.iter()
.filter(|e| match e.score {
@@ -2226,6 +2331,7 @@ pub(crate) async fn agent_turn(
dedup_exempt_tools,
activated_tools,
model_switch_callback,
&crate::config::PacingConfig::default(),
)
.await
}
@@ -2535,6 +2641,7 @@ pub(crate) async fn run_tool_call_loop(
dedup_exempt_tools: &[String],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
model_switch_callback: Option<ModelSwitchCallback>,
pacing: &crate::config::PacingConfig,
) -> Result<String> {
let max_iterations = if max_tool_iterations == 0 {
DEFAULT_MAX_TOOL_ITERATIONS
@@ -2543,6 +2650,14 @@ pub(crate) async fn run_tool_call_loop(
};
let turn_id = Uuid::new_v4().to_string();
let loop_started_at = Instant::now();
let loop_ignore_tools: HashSet<&str> = pacing
.loop_ignore_tools
.iter()
.map(String::as_str)
.collect();
let mut consecutive_identical_outputs: usize = 0;
let mut last_tool_output_hash: Option<u64> = None;
for iteration in 0..max_iterations {
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
@@ -2642,6 +2757,19 @@ pub(crate) async fn run_tool_call_loop(
hooks.fire_llm_input(history, model).await;
}
// Budget enforcement — block if limit exceeded (no-op when not scoped)
if let Some(BudgetCheck::Exceeded {
current_usd,
limit_usd,
period,
}) = check_tool_loop_budget()
{
return Err(anyhow::anyhow!(
"Budget exceeded: ${:.4} of ${:.2} {:?} limit. Cannot make further API calls until the budget resets.",
current_usd, limit_usd, period
));
}
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
@@ -2659,13 +2787,43 @@ pub(crate) async fn run_tool_call_loop(
temperature,
);
-        let chat_result = if let Some(token) = cancellation_token.as_ref() {
-            tokio::select! {
-                () = token.cancelled() => return Err(ToolLoopCancelled.into()),
-                result = chat_future => result,
-            }
-        } else {
-            chat_future.await
-        };
+        // Wrap the LLM call with an optional per-step timeout from pacing config.
+        // This catches a truly hung model response without terminating the overall
+        // task loop (the per-message budget handles that separately).
+        let chat_result = match pacing.step_timeout_secs {
+            Some(step_secs) if step_secs > 0 => {
+                let step_timeout = Duration::from_secs(step_secs);
+                if let Some(token) = cancellation_token.as_ref() {
+                    tokio::select! {
+                        () = token.cancelled() => return Err(ToolLoopCancelled.into()),
+                        result = tokio::time::timeout(step_timeout, chat_future) => {
+                            match result {
+                                Ok(inner) => inner,
+                                Err(_) => anyhow::bail!(
+                                    "LLM inference step timed out after {step_secs}s (step_timeout_secs)"
+                                ),
+                            }
+                        },
+                    }
+                } else {
+                    match tokio::time::timeout(step_timeout, chat_future).await {
+                        Ok(inner) => inner,
+                        Err(_) => anyhow::bail!(
+                            "LLM inference step timed out after {step_secs}s (step_timeout_secs)"
+                        ),
+                    }
+                }
+            }
+            _ => {
+                if let Some(token) = cancellation_token.as_ref() {
+                    tokio::select! {
+                        () = token.cancelled() => return Err(ToolLoopCancelled.into()),
+                        result = chat_future => result,
+                    }
+                } else {
+                    chat_future.await
+                }
+            }
+        };
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
@@ -2687,6 +2845,12 @@ pub(crate) async fn run_tool_call_loop(
output_tokens: resp_output_tokens,
});
// Record cost via task-local tracker (no-op when not scoped)
let _ = resp
.usage
.as_ref()
.and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
let response_text = resp.text_or_empty().to_string();
// First try native structured tool calls (OpenAI-format).
// Fall back to text-based parsing (XML tags, markdown blocks,
@@ -3158,7 +3322,13 @@ pub(crate) async fn run_tool_call_loop(
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
}
// Collect tool results and build per-tool output for loop detection.
// Only non-ignored tool outputs contribute to the identical-output hash.
let mut detection_relevant_output = String::new();
for (tool_name, tool_call_id, outcome) in ordered_results.into_iter().flatten() {
if !loop_ignore_tools.contains(tool_name.as_str()) {
detection_relevant_output.push_str(&outcome.output);
}
individual_results.push((tool_call_id, outcome.output.clone()));
let _ = writeln!(
tool_results,
@@ -3167,6 +3337,53 @@ pub(crate) async fn run_tool_call_loop(
);
}
// ── Time-gated loop detection ──────────────────────────
// When pacing.loop_detection_min_elapsed_secs is set, identical-output
// loop detection only activates once the task has been running at least
// that long; iterations inside the grace period never count toward an
// identical-output abort.
// When not configured, identical-output detection is disabled (preserving
// existing behavior where only max_iterations prevents runaway loops).
let loop_detection_active = match pacing.loop_detection_min_elapsed_secs {
Some(min_secs) => loop_started_at.elapsed() >= Duration::from_secs(min_secs),
None => false, // disabled when not configured (backwards compatible)
};
if loop_detection_active && !detection_relevant_output.is_empty() {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
detection_relevant_output.hash(&mut hasher);
let current_hash = hasher.finish();
if last_tool_output_hash == Some(current_hash) {
consecutive_identical_outputs += 1;
} else {
consecutive_identical_outputs = 0;
last_tool_output_hash = Some(current_hash);
}
// Bail if we see 3+ consecutive identical tool outputs (clear runaway).
if consecutive_identical_outputs >= 3 {
runtime_trace::record_event(
"tool_loop_identical_output_abort",
Some(channel_name),
Some(provider_name),
Some(model),
Some(&turn_id),
Some(false),
Some("identical tool output detected 3 consecutive times"),
serde_json::json!({
"iteration": iteration + 1,
"consecutive_identical": consecutive_identical_outputs,
}),
);
anyhow::bail!(
"Agent loop aborted: identical tool output detected {} consecutive times",
consecutive_identical_outputs
);
}
}
// Add assistant message with tool calls + tool results to history.
// Native mode: use JSON-structured messages so convert_messages() can
// reconstruct proper OpenAI-format tool_calls and tool result messages.
@@ -3716,6 +3933,7 @@ pub async fn run(
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
Some(model_switch_callback.clone()),
&config.pacing,
)
.await
{
@@ -3943,6 +4161,7 @@ pub async fn run(
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
Some(model_switch_callback.clone()),
&config.pacing,
)
.await
{
@@ -4840,6 +5059,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect_err("provider without vision support should fail");
@@ -4890,6 +5110,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect_err("oversized payload must fail");
@@ -4934,6 +5155,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("valid multimodal payload should pass");
@@ -5064,6 +5286,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("parallel execution should complete");
@@ -5134,6 +5357,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("cron_add delivery defaults should be injected");
@@ -5196,6 +5420,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("explicit delivery mode should be preserved");
@@ -5253,6 +5478,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("loop should finish after deduplicating repeated calls");
@@ -5322,6 +5548,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("non-interactive shell should succeed for low-risk command");
@@ -5382,6 +5609,7 @@ mod tests {
&exempt,
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("loop should finish with exempt tool executing twice");
@@ -5462,6 +5690,7 @@ mod tests {
&exempt,
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("loop should complete");
@@ -5519,6 +5748,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("native fallback id flow should complete");
@@ -5600,6 +5830,7 @@ mod tests {
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("native tool-call text should be relayed through on_delta");
@@ -6395,7 +6626,7 @@ Tail"#;
assert_eq!(mem.count().await.unwrap(), 2);
-let recalled = mem.recall("45", 5, None).await.unwrap();
+let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}
@@ -7585,6 +7816,7 @@ Let me check the result."#;
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("tool loop should complete");
@@ -7662,4 +7894,215 @@ Let me check the result."#;
let result = filter_by_allowed_tools(specs, Some(&allowed));
assert!(result.is_empty());
}
// ── Cost tracking tests ──
#[tokio::test]
async fn cost_tracking_records_usage_when_scoped() {
use super::{
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
};
use crate::config::schema::ModelPricing;
use crate::cost::CostTracker;
use crate::observability::noop::NoopObserver;
use std::collections::HashMap;
let provider = ScriptedProvider {
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
text: Some("done".to_string()),
tool_calls: Vec::new(),
usage: Some(crate::providers::traits::TokenUsage {
input_tokens: Some(1_000),
output_tokens: Some(200),
cached_input_tokens: None,
}),
reasoning_content: None,
}]))),
capabilities: ProviderCapabilities::default(),
};
let observer = NoopObserver;
let workspace = tempfile::TempDir::new().unwrap();
let mut cost_config = crate::config::CostConfig {
enabled: true,
..crate::config::CostConfig::default()
};
cost_config.prices = HashMap::from([(
"mock-model".to_string(),
ModelPricing {
input: 3.0,
output: 15.0,
},
)]);
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
let ctx = ToolLoopCostTrackingContext::new(
Arc::clone(&tracker),
Arc::new(cost_config.prices.clone()),
);
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
let result = TOOL_LOOP_COST_TRACKING_CONTEXT
.scope(
Some(ctx),
run_tool_call_loop(
&provider,
&mut history,
&[],
&observer,
"mock-provider",
"mock-model",
0.0,
true,
None,
"test",
None,
&crate::config::MultimodalConfig::default(),
2,
None,
None,
None,
&[],
&[],
None,
None,
&crate::config::PacingConfig::default(),
),
)
.await
.expect("tool loop should succeed");
assert_eq!(result, "done");
let summary = tracker.get_summary().unwrap();
assert_eq!(summary.request_count, 1);
assert_eq!(summary.total_tokens, 1_200);
assert!(summary.session_cost_usd > 0.0);
}
#[tokio::test]
async fn cost_tracking_enforces_budget() {
use super::{
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
};
use crate::config::schema::ModelPricing;
use crate::cost::CostTracker;
use crate::observability::noop::NoopObserver;
use std::collections::HashMap;
let provider = ScriptedProvider::from_text_responses(vec!["should not reach this"]);
let observer = NoopObserver;
let workspace = tempfile::TempDir::new().unwrap();
let cost_config = crate::config::CostConfig {
enabled: true,
daily_limit_usd: 0.001, // very low limit
..crate::config::CostConfig::default()
};
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
// Record a usage that already exceeds the limit
tracker
.record_usage(crate::cost::types::TokenUsage::new(
"mock-model",
100_000,
50_000,
1.0,
1.0,
))
.unwrap();
let ctx = ToolLoopCostTrackingContext::new(
Arc::clone(&tracker),
Arc::new(HashMap::from([(
"mock-model".to_string(),
ModelPricing {
input: 1.0,
output: 1.0,
},
)])),
);
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
let err = TOOL_LOOP_COST_TRACKING_CONTEXT
.scope(
Some(ctx),
run_tool_call_loop(
&provider,
&mut history,
&[],
&observer,
"mock-provider",
"mock-model",
0.0,
true,
None,
"test",
None,
&crate::config::MultimodalConfig::default(),
2,
None,
None,
None,
&[],
&[],
None,
None,
&crate::config::PacingConfig::default(),
),
)
.await
.expect_err("should fail with budget exceeded");
assert!(
err.to_string().contains("Budget exceeded"),
"error should mention budget: {err}"
);
}
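The budget test above pre-records usage so that spend already exceeds `daily_limit_usd` before the loop runs. The gate it exercises can be sketched as a simple comparison (function name hypothetical; the real check lives inside `CostTracker`):

```rust
// Hedged sketch of the daily budget gate: reject before issuing a request
// when accumulated spend already meets the configured limit.
fn budget_exceeded(spent_usd: f64, daily_limit_usd: f64) -> bool {
    spent_usd >= daily_limit_usd
}

fn main() {
    // Prior recorded usage is far above the 0.001 USD limit set in the test.
    assert!(budget_exceeded(2.0, 0.001));
    // A fresh tracker with no spend passes the gate.
    assert!(!budget_exceeded(0.0, 0.001));
}
```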
#[tokio::test]
async fn cost_tracking_is_noop_without_scope() {
use super::run_tool_call_loop;
use crate::observability::noop::NoopObserver;
// No TOOL_LOOP_COST_TRACKING_CONTEXT scoped — should run fine
let provider = ScriptedProvider {
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
text: Some("ok".to_string()),
tool_calls: Vec::new(),
usage: Some(crate::providers::traits::TokenUsage {
input_tokens: Some(500),
output_tokens: Some(100),
cached_input_tokens: None,
}),
reasoning_content: None,
}]))),
capabilities: ProviderCapabilities::default(),
};
let observer = NoopObserver;
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
let result = run_tool_call_loop(
&provider,
&mut history,
&[],
&observer,
"mock-provider",
"mock-model",
0.0,
true,
None,
"test",
None,
&crate::config::MultimodalConfig::default(),
2,
None,
None,
None,
&[],
&[],
None,
None,
&crate::config::PacingConfig::default(),
)
.await
.expect("should succeed without cost scope");
assert_eq!(result, "ok");
}
}
+7 -1
@@ -43,7 +43,9 @@ impl MemoryLoader for DefaultMemoryLoader {
user_message: &str,
session_id: Option<&str>,
) -> anyhow::Result<String> {
let entries = memory.recall(user_message, self.limit, session_id).await?;
let entries = memory
.recall(user_message, self.limit, session_id, None, None)
.await?;
if entries.is_empty() {
return Ok(String::new());
}
@@ -102,6 +104,8 @@ mod tests {
_query: &str,
limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if limit == 0 {
return Ok(vec![]);
@@ -163,6 +167,8 @@ mod tests {
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(self.entries.as_ref().clone())
}
+10 -1
@@ -18,6 +18,8 @@ pub struct DingTalkChannel {
/// Per-chat session webhooks for sending replies (chatID -> webhook URL).
/// DingTalk provides a unique webhook URL with each incoming message.
session_webhooks: Arc<RwLock<HashMap<String, String>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
/// Response from DingTalk gateway connection registration.
@@ -34,11 +36,18 @@ impl DingTalkChannel {
client_secret,
allowed_users,
session_webhooks: Arc::new(RwLock::new(HashMap::new())),
proxy_url: None,
}
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client("channel.dingtalk")
crate::config::build_channel_proxy_client("channel.dingtalk", self.proxy_url.as_deref())
}
fn is_user_allowed(&self, user_id: &str) -> bool {
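The same pattern recurs across the channel diffs below: each channel stores an `Option<String>` proxy override, exposes a `with_proxy_url` builder, and passes the override to the client constructor. A minimal self-contained sketch of that pattern; the fallback order (per-channel URL first, global config otherwise) is an assumption, since `build_channel_proxy_client` itself is not shown in this diff:

```rust
// Hedged sketch of the per-channel proxy override pattern. `Channel` and
// `effective_proxy` are hypothetical stand-ins for the real types.
#[derive(Default)]
struct Channel {
    proxy_url: Option<String>,
}

impl Channel {
    /// Builder-style setter, matching the diffs' `with_proxy_url`.
    fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
        self.proxy_url = proxy_url;
        self
    }

    /// Per-channel override wins; otherwise fall back to the global proxy.
    fn effective_proxy<'a>(&'a self, global: Option<&'a str>) -> Option<&'a str> {
        self.proxy_url.as_deref().or(global)
    }
}

fn main() {
    let ch = Channel::default().with_proxy_url(Some("http://127.0.0.1:8080".to_string()));
    assert_eq!(ch.effective_proxy(Some("http://global:3128")), Some("http://127.0.0.1:8080"));

    let ch_no_override = Channel::default();
    assert_eq!(ch_no_override.effective_proxy(Some("http://global:3128")), Some("http://global:3128"));
}
```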
+10 -1
@@ -18,6 +18,8 @@ pub struct DiscordChannel {
listen_to_bots: bool,
mention_only: bool,
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
impl DiscordChannel {
@@ -35,11 +37,18 @@ impl DiscordChannel {
listen_to_bots,
mention_only,
typing_handles: Mutex::new(HashMap::new()),
proxy_url: None,
}
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client("channel.discord")
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
}
/// Check if a Discord user ID is in the allowlist.
+16 -1
@@ -380,6 +380,8 @@ pub struct LarkChannel {
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
impl LarkChannel {
@@ -423,6 +425,7 @@ impl LarkChannel {
receive_mode: crate::config::schema::LarkReceiveMode::default(),
tenant_token: Arc::new(RwLock::new(None)),
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
proxy_url: None,
}
}
@@ -444,6 +447,7 @@ impl LarkChannel {
platform,
);
ch.receive_mode = config.receive_mode.clone();
ch.proxy_url = config.proxy_url.clone();
ch
}
@@ -461,6 +465,7 @@ impl LarkChannel {
LarkPlatform::Lark,
);
ch.receive_mode = config.receive_mode.clone();
ch.proxy_url = config.proxy_url.clone();
ch
}
@@ -476,11 +481,15 @@ impl LarkChannel {
LarkPlatform::Feishu,
);
ch.receive_mode = config.receive_mode.clone();
ch.proxy_url = config.proxy_url.clone();
ch
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client(self.platform.proxy_service_key())
crate::config::build_channel_proxy_client(
self.platform.proxy_service_key(),
self.proxy_url.as_deref(),
)
}
fn channel_name(&self) -> &'static str {
@@ -2113,6 +2122,7 @@ mod tests {
use_feishu: false,
receive_mode: LarkReceiveMode::default(),
port: None,
proxy_url: None,
};
let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
@@ -2135,6 +2145,7 @@ mod tests {
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
@@ -2169,6 +2180,7 @@ mod tests {
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let ch = LarkChannel::from_config(&cfg);
@@ -2193,6 +2205,7 @@ mod tests {
use_feishu: true,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let ch = LarkChannel::from_lark_config(&cfg);
@@ -2214,6 +2227,7 @@ mod tests {
allowed_users: vec!["*".into()],
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let ch = LarkChannel::from_feishu_config(&cfg);
@@ -2386,6 +2400,7 @@ mod tests {
allowed_users: vec!["*".into()],
receive_mode: crate::config::schema::LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let ch_feishu = LarkChannel::from_feishu_config(&feishu_cfg);
assert_eq!(
+10 -1
@@ -17,6 +17,8 @@ pub struct MattermostChannel {
mention_only: bool,
/// Handle for the background typing-indicator loop (aborted on stop_typing).
typing_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
impl MattermostChannel {
@@ -38,11 +40,18 @@ impl MattermostChannel {
thread_replies,
mention_only,
typing_handle: Mutex::new(None),
proxy_url: None,
}
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client("channel.mattermost")
crate::config::build_channel_proxy_client("channel.mattermost", self.proxy_url.as_deref())
}
/// Check if a user ID is in the allowlist.
+218 -44
@@ -222,9 +222,21 @@ fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
fn channel_message_timeout_budget_secs(
message_timeout_secs: u64,
max_tool_iterations: usize,
) -> u64 {
channel_message_timeout_budget_secs_with_cap(
message_timeout_secs,
max_tool_iterations,
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP,
)
}
fn channel_message_timeout_budget_secs_with_cap(
message_timeout_secs: u64,
max_tool_iterations: usize,
scale_cap: u64,
) -> u64 {
let iterations = max_tool_iterations.max(1) as u64;
let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
let scale = iterations.min(scale_cap);
message_timeout_secs.saturating_mul(scale)
}
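The refactor above makes the scale cap injectable so `[pacing].message_timeout_scale_max` can override the hardcoded ceiling. The formula itself is unchanged; a standalone sketch mirroring `channel_message_timeout_budget_secs_with_cap`:

```rust
// Sketch of the budget formula from the hunk above, with the default cap of 4
// assumed for CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP.
fn budget_secs(message_timeout_secs: u64, max_tool_iterations: usize, scale_cap: u64) -> u64 {
    // Clamp to at least one iteration, then cap the scale factor.
    let iterations = max_tool_iterations.max(1) as u64;
    let scale = iterations.min(scale_cap);
    message_timeout_secs.saturating_mul(scale)
}

fn main() {
    // 10 iterations under the default cap of 4: 300 * 4 = 1200s.
    assert_eq!(budget_secs(300, 10, 4), 1200);
    // Raising the cap to 8 lifts the ceiling proportionally.
    assert_eq!(budget_secs(300, 10, 8), 2400);
    // Zero iterations is clamped to 1.
    assert_eq!(budget_secs(300, 0, 4), 300);
}
```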
@@ -313,6 +325,12 @@ impl InterruptOnNewMessageConfig {
}
}
#[derive(Clone)]
struct ChannelCostTrackingState {
tracker: Arc<crate::cost::CostTracker>,
prices: Arc<HashMap<String, crate::config::schema::ModelPricing>>,
}
#[derive(Clone)]
struct ChannelRuntimeContext {
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
@@ -355,6 +373,8 @@ struct ChannelRuntimeContext {
/// approval since no operator is present on channel runs.
approval_manager: Arc<ApprovalManager>,
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
cost_tracking: Option<ChannelCostTrackingState>,
pacing: crate::config::PacingConfig,
}
#[derive(Clone)]
@@ -1511,7 +1531,7 @@ async fn build_memory_context(
) -> String {
let mut context = String::new();
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
let mut included = 0usize;
let mut used_chars = 0usize;
@@ -2395,8 +2415,18 @@ async fn process_channel_message(
}
let model_switch_callback = get_model_switch_state();
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
let scale_cap = ctx
.pacing
.message_timeout_scale_max
.unwrap_or(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
let timeout_budget_secs = channel_message_timeout_budget_secs_with_cap(
ctx.message_timeout_secs,
ctx.max_tool_iterations,
scale_cap,
);
let cost_tracking_context = ctx.cost_tracking.clone().map(|state| {
crate::agent::loop_::ToolLoopCostTrackingContext::new(state.tracker, state.prices)
});
let llm_call_start = Instant::now();
#[allow(clippy::cast_possible_truncation)]
let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64;
@@ -2406,6 +2436,8 @@ async fn process_channel_message(
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT.scope(
cost_tracking_context.clone(),
run_tool_call_loop(
active_provider.as_ref(),
&mut history,
@@ -2433,6 +2465,8 @@ async fn process_channel_message(
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
Some(model_switch_callback.clone()),
&ctx.pacing,
),
),
) => LlmExecutionResult::Completed(result),
};
@@ -3691,7 +3725,8 @@ fn collect_configured_channels(
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms)
.with_transcription(config.transcription.clone())
.with_tts(config.tts.clone())
.with_workspace_dir(config.workspace_dir.clone()),
.with_workspace_dir(config.workspace_dir.clone())
.with_proxy_url(tg.proxy_url.clone()),
),
});
}
@@ -3699,13 +3734,16 @@ fn collect_configured_channels(
if let Some(ref dc) = config.channels_config.discord {
channels.push(ConfiguredChannel {
display_name: "Discord",
channel: Arc::new(DiscordChannel::new(
dc.bot_token.clone(),
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
dc.mention_only,
)),
channel: Arc::new(
DiscordChannel::new(
dc.bot_token.clone(),
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
dc.mention_only,
)
.with_proxy_url(dc.proxy_url.clone()),
),
});
}
@@ -3722,7 +3760,8 @@ fn collect_configured_channels(
)
.with_thread_replies(sl.thread_replies.unwrap_or(true))
.with_group_reply_policy(sl.mention_only, Vec::new())
.with_workspace_dir(config.workspace_dir.clone()),
.with_workspace_dir(config.workspace_dir.clone())
.with_proxy_url(sl.proxy_url.clone()),
),
});
}
@@ -3730,14 +3769,17 @@ fn collect_configured_channels(
if let Some(ref mm) = config.channels_config.mattermost {
channels.push(ConfiguredChannel {
display_name: "Mattermost",
channel: Arc::new(MattermostChannel::new(
mm.url.clone(),
mm.bot_token.clone(),
mm.channel_id.clone(),
mm.allowed_users.clone(),
mm.thread_replies.unwrap_or(true),
mm.mention_only.unwrap_or(false),
)),
channel: Arc::new(
MattermostChannel::new(
mm.url.clone(),
mm.bot_token.clone(),
mm.channel_id.clone(),
mm.allowed_users.clone(),
mm.thread_replies.unwrap_or(true),
mm.mention_only.unwrap_or(false),
)
.with_proxy_url(mm.proxy_url.clone()),
),
});
}
@@ -3775,14 +3817,17 @@ fn collect_configured_channels(
if let Some(ref sig) = config.channels_config.signal {
channels.push(ConfiguredChannel {
display_name: "Signal",
channel: Arc::new(SignalChannel::new(
sig.http_url.clone(),
sig.account.clone(),
sig.group_id.clone(),
sig.allowed_from.clone(),
sig.ignore_attachments,
sig.ignore_stories,
)),
channel: Arc::new(
SignalChannel::new(
sig.http_url.clone(),
sig.account.clone(),
sig.group_id.clone(),
sig.allowed_from.clone(),
sig.ignore_attachments,
sig.ignore_stories,
)
.with_proxy_url(sig.proxy_url.clone()),
),
});
}
@@ -3799,12 +3844,15 @@ fn collect_configured_channels(
if wa.is_cloud_config() {
channels.push(ConfiguredChannel {
display_name: "WhatsApp",
channel: Arc::new(WhatsAppChannel::new(
wa.access_token.clone().unwrap_or_default(),
wa.phone_number_id.clone().unwrap_or_default(),
wa.verify_token.clone().unwrap_or_default(),
wa.allowed_numbers.clone(),
)),
channel: Arc::new(
WhatsAppChannel::new(
wa.access_token.clone().unwrap_or_default(),
wa.phone_number_id.clone().unwrap_or_default(),
wa.verify_token.clone().unwrap_or_default(),
wa.allowed_numbers.clone(),
)
.with_proxy_url(wa.proxy_url.clone()),
),
});
} else {
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
@@ -3861,11 +3909,12 @@ fn collect_configured_channels(
if let Some(ref wati_cfg) = config.channels_config.wati {
channels.push(ConfiguredChannel {
display_name: "WATI",
channel: Arc::new(WatiChannel::new(
channel: Arc::new(WatiChannel::new_with_proxy(
wati_cfg.api_token.clone(),
wati_cfg.api_url.clone(),
wati_cfg.tenant_id.clone(),
wati_cfg.allowed_numbers.clone(),
wati_cfg.proxy_url.clone(),
)),
});
}
@@ -3873,10 +3922,11 @@ fn collect_configured_channels(
if let Some(ref nc) = config.channels_config.nextcloud_talk {
channels.push(ConfiguredChannel {
display_name: "Nextcloud Talk",
channel: Arc::new(NextcloudTalkChannel::new(
channel: Arc::new(NextcloudTalkChannel::new_with_proxy(
nc.base_url.clone(),
nc.app_token.clone(),
nc.allowed_users.clone(),
nc.proxy_url.clone(),
)),
});
}
@@ -3948,11 +3998,14 @@ fn collect_configured_channels(
if let Some(ref dt) = config.channels_config.dingtalk {
channels.push(ConfiguredChannel {
display_name: "DingTalk",
channel: Arc::new(DingTalkChannel::new(
dt.client_id.clone(),
dt.client_secret.clone(),
dt.allowed_users.clone(),
)),
channel: Arc::new(
DingTalkChannel::new(
dt.client_id.clone(),
dt.client_secret.clone(),
dt.allowed_users.clone(),
)
.with_proxy_url(dt.proxy_url.clone()),
),
});
}
@@ -3965,7 +4018,8 @@ fn collect_configured_channels(
qq.app_secret.clone(),
qq.allowed_users.clone(),
)
.with_workspace_dir(config.workspace_dir.clone()),
.with_workspace_dir(config.workspace_dir.clone())
.with_proxy_url(qq.proxy_url.clone()),
),
});
}
@@ -4600,6 +4654,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
},
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
activated_tools: ch_activated_handle,
cost_tracking: crate::cost::CostTracker::get_or_init_global(
config.cost.clone(),
&config.workspace_dir,
)
.map(|tracker| ChannelCostTrackingState {
tracker,
prices: Arc::new(config.cost.prices.clone()),
}),
pacing: config.pacing.clone(),
});
// Hydrate in-memory conversation histories from persisted JSONL session files.
@@ -4696,6 +4759,49 @@ mod tests {
);
}
#[test]
fn channel_message_timeout_budget_with_custom_scale_cap() {
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 8, 8),
300 * 8
);
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 20, 8),
300 * 8
);
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 10, 1),
300
);
}
#[test]
fn pacing_config_defaults_preserve_existing_behavior() {
let pacing = crate::config::PacingConfig::default();
assert!(pacing.step_timeout_secs.is_none());
assert!(pacing.loop_detection_min_elapsed_secs.is_none());
assert!(pacing.loop_ignore_tools.is_empty());
assert!(pacing.message_timeout_scale_max.is_none());
}
#[test]
fn pacing_message_timeout_scale_max_overrides_default_cap() {
// Custom cap of 8 scales budget proportionally
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 10, 8),
300 * 8
);
// Default cap produces the standard behavior
assert_eq!(
channel_message_timeout_budget_secs_with_cap(
300,
10,
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
),
300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
);
}
#[test]
fn context_window_overflow_error_detector_matches_known_messages() {
let overflow_err = anyhow::anyhow!(
@@ -4899,6 +5005,8 @@ mod tests {
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(compact_sender_history(&ctx, &sender));
@@ -5014,6 +5122,8 @@ mod tests {
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
@@ -5085,6 +5195,8 @@ mod tests {
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
@@ -5175,6 +5287,8 @@ mod tests {
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(rollback_orphan_user_turn(
@@ -5715,6 +5829,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -5795,6 +5911,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -5889,6 +6007,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -5968,6 +6088,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6057,6 +6179,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6167,6 +6291,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6258,6 +6384,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6364,6 +6492,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6455,6 +6585,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6536,6 +6668,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6584,6 +6718,8 @@ BTC is currently around $65,000 based on latest tool output."#
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new())
}
@@ -6636,6 +6772,8 @@ BTC is currently around $65,000 based on latest tool output."#
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(vec![crate::memory::MemoryEntry {
id: "entry-1".to_string(),
@@ -6728,6 +6866,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@@ -6829,6 +6969,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -6944,7 +7086,9 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
query_classification: crate::config::QueryClassificationConfig::default(),
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -7058,6 +7202,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -7153,6 +7299,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -7232,6 +7380,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -7884,7 +8034,7 @@ BTC is currently around $65,000 based on latest tool output."#
assert_eq!(mem.count().await.unwrap(), 2);
let recalled = mem.recall("45", 5, None).await.unwrap();
let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}
@@ -7997,6 +8147,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8127,6 +8279,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8297,6 +8451,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8404,6 +8560,8 @@ BTC is currently around $65,000 based on latest tool output."#
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8721,6 +8879,7 @@ This is an example JSON object for profile settings."#;
thread_replies: Some(true),
mention_only: Some(false),
interrupt_on_new_message: false,
proxy_url: None,
});
let channels = collect_configured_channels(&config, "test");
@@ -8974,6 +9133,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
// Simulate a photo attachment message with [IMAGE:] marker.
@@ -9060,6 +9221,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9221,6 +9384,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9331,6 +9496,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9433,6 +9600,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9555,6 +9724,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9613,6 +9784,7 @@ This is an example JSON object for profile settings."#;
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
match build_channel_by_id(&config, "telegram") {
Ok(channel) => assert_eq!(channel.name(), "telegram"),
@@ -9814,6 +9986,8 @@ This is an example JSON object for profile settings."#;
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
+13 -1
@@ -17,11 +17,23 @@ pub struct NextcloudTalkChannel {
impl NextcloudTalkChannel {
pub fn new(base_url: String, app_token: String, allowed_users: Vec<String>) -> Self {
Self::new_with_proxy(base_url, app_token, allowed_users, None)
}
pub fn new_with_proxy(
base_url: String,
app_token: String,
allowed_users: Vec<String>,
proxy_url: Option<String>,
) -> Self {
Self {
base_url: base_url.trim_end_matches('/').to_string(),
app_token,
allowed_users,
client: reqwest::Client::new(),
client: crate::config::build_channel_proxy_client(
"channel.nextcloud_talk",
proxy_url.as_deref(),
),
}
}
+10 -1
@@ -285,6 +285,8 @@ pub struct QQChannel {
upload_cache: Arc<RwLock<HashMap<String, UploadCacheEntry>>>,
/// Passive reply tracker for QQ API rate limiting.
reply_tracker: Arc<RwLock<HashMap<String, ReplyRecord>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
impl QQChannel {
@@ -298,6 +300,7 @@ impl QQChannel {
workspace_dir: None,
upload_cache: Arc::new(RwLock::new(HashMap::new())),
reply_tracker: Arc::new(RwLock::new(HashMap::new())),
proxy_url: None,
}
}
@@ -307,8 +310,14 @@ impl QQChannel {
self
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client("channel.qq")
crate::config::build_channel_proxy_client("channel.qq", self.proxy_url.as_deref())
}
fn is_user_allowed(&self, user_id: &str) -> bool {
+14 -1
@@ -28,6 +28,8 @@ pub struct SignalChannel {
allowed_from: Vec<String>,
ignore_attachments: bool,
ignore_stories: bool,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
// ── signal-cli SSE event JSON shapes ────────────────────────────
@@ -87,12 +89,23 @@ impl SignalChannel {
allowed_from,
ignore_attachments,
ignore_stories,
proxy_url: None,
}
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> Client {
let builder = Client::builder().connect_timeout(Duration::from_secs(10));
let builder = crate::config::apply_runtime_proxy_to_builder(builder, "channel.signal");
let builder = crate::config::apply_channel_proxy_to_builder(
builder,
"channel.signal",
self.proxy_url.as_deref(),
);
builder.build().expect("Signal HTTP client should build")
}
+78 -2
@@ -32,6 +32,8 @@ pub struct SlackChannel {
workspace_dir: Option<PathBuf>,
/// Maps channel_id -> thread_ts for active assistant threads (used for status indicators).
active_assistant_thread: Mutex<HashMap<String, String>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
@@ -46,6 +48,7 @@ const SLACK_ATTACHMENT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
const SLACK_ATTACHMENT_IMAGE_INLINE_FALLBACK_MAX_BYTES: usize = 512 * 1024;
const SLACK_ATTACHMENT_TEXT_DOWNLOAD_MAX_BYTES: usize = 256 * 1024;
const SLACK_ATTACHMENT_TEXT_INLINE_MAX_CHARS: usize = 12_000;
const SLACK_MARKDOWN_BLOCK_MAX_CHARS: usize = 12_000;
const SLACK_ATTACHMENT_FILENAME_MAX_CHARS: usize = 128;
const SLACK_USER_CACHE_MAX_ENTRIES: usize = 1000;
const SLACK_ATTACHMENT_SAVE_SUBDIR: &str = "slack_files";
@@ -121,6 +124,7 @@ impl SlackChannel {
user_display_name_cache: Mutex::new(HashMap::new()),
workspace_dir: None,
active_assistant_thread: Mutex::new(HashMap::new()),
proxy_url: None,
}
}
@@ -148,8 +152,19 @@ impl SlackChannel {
self
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_runtime_proxy_client_with_timeouts("channel.slack", 30, 10)
crate::config::build_channel_proxy_client_with_timeouts(
"channel.slack",
self.proxy_url.as_deref(),
30,
10,
)
}
/// Check if a Slack user ID is in the allowlist.
@@ -804,12 +819,13 @@ impl SlackChannel {
}
fn slack_media_http_client_no_redirect(&self) -> anyhow::Result<reqwest::Client> {
let builder = crate::config::apply_runtime_proxy_to_builder(
let builder = crate::config::apply_channel_proxy_to_builder(
reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(Duration::from_secs(30))
.connect_timeout(Duration::from_secs(10)),
"channel.slack",
self.proxy_url.as_deref(),
);
builder
.build()
@@ -2272,6 +2288,14 @@ impl Channel for SlackChannel {
"text": message.content
});
// Use Slack's native markdown block for rich formatting when content fits.
if message.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
body["blocks"] = serde_json::json!([{
"type": "markdown",
"text": message.content
}]);
}
if let Some(ts) = self.outbound_thread_ts(message) {
body["thread_ts"] = serde_json::json!(ts);
}
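The hunk above attaches a native markdown block only when the content fits under `SLACK_MARKDOWN_BLOCK_MAX_CHARS`, keeping `text` as the plaintext fallback either way. The gate can be isolated as:

```rust
// Minimal sketch of the length gate added above. Note `len()` counts bytes,
// matching the diff's `message.content.len()` check.
const SLACK_MARKDOWN_BLOCK_MAX_CHARS: usize = 12_000;

fn wants_markdown_block(content: &str) -> bool {
    content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS
}

fn main() {
    // Short rich-text content gets a markdown block.
    assert!(wants_markdown_block("**bold** and _italic_"));
    // Oversized content falls back to the plain `text` field only.
    assert!(!wants_markdown_block(&"x".repeat(SLACK_MARKDOWN_BLOCK_MAX_CHARS + 1)));
}
```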
@@ -3630,6 +3654,58 @@ mod tests {
assert_ne!(key1, key2, "session key should differ per thread");
}
#[test]
fn slack_send_uses_markdown_blocks() {
let msg = SendMessage::new("**bold** and _italic_", "C123");
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
// Build the same JSON body that send() would construct.
let mut body = serde_json::json!({
"channel": msg.recipient,
"text": msg.content
});
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
body["blocks"] = serde_json::json!([{
"type": "markdown",
"text": msg.content
}]);
}
// Verify blocks are present with correct structure.
let blocks = body["blocks"]
.as_array()
.expect("blocks should be an array");
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0]["type"], "markdown");
assert_eq!(blocks[0]["text"], msg.content);
// text field kept as plaintext fallback.
assert_eq!(body["text"], msg.content);
// Suppress unused variable warning.
let _ = ch.name();
}
#[test]
fn slack_send_skips_markdown_blocks_for_long_content() {
let long_content = "x".repeat(SLACK_MARKDOWN_BLOCK_MAX_CHARS + 1);
let msg = SendMessage::new(long_content.clone(), "C123");
let mut body = serde_json::json!({
"channel": msg.recipient,
"text": msg.content
});
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
body["blocks"] = serde_json::json!([{
"type": "markdown",
"text": msg.content
}]);
}
assert!(
body.get("blocks").is_none(),
"blocks should not be set for oversized content"
);
}
#[tokio::test]
async fn start_typing_requires_thread_context() {
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
@@ -337,6 +337,8 @@ pub struct TelegramChannel {
voice_chats: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
pending_voice:
Arc<std::sync::Mutex<std::collections::HashMap<String, (String, std::time::Instant)>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -379,6 +381,7 @@ impl TelegramChannel {
tts_config: None,
voice_chats: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
pending_voice: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
proxy_url: None,
}
}
@@ -388,6 +391,12 @@ impl TelegramChannel {
self
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
/// Configure workspace directory for saving downloaded attachments.
pub fn with_workspace_dir(mut self, dir: std::path::PathBuf) -> Self {
self.workspace_dir = Some(dir);
@@ -478,7 +487,7 @@ impl TelegramChannel {
}
fn http_client(&self) -> reqwest::Client {
-crate::config::build_runtime_proxy_client("channel.telegram")
+crate::config::build_channel_proxy_client("channel.telegram", self.proxy_url.as_deref())
}
fn normalize_identity(value: &str) -> String {
@@ -80,17 +80,10 @@ fn resolve_transcription_api_key(config: &TranscriptionConfig) -> Result<String>
);
}
-/// Validate audio data and resolve MIME type from file name.
+/// Resolve MIME type and normalize filename from extension.
///
/// Returns `(normalized_filename, mime_type)` on success.
-fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
-if audio_data.len() > MAX_AUDIO_BYTES {
-bail!(
-"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
-audio_data.len()
-);
-}
+/// No size check — callers enforce their own limits.
+fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
let normalized_name = normalize_audio_filename(file_name);
let extension = normalized_name
.rsplit_once('.')
@@ -98,13 +91,26 @@ fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'stati
.unwrap_or("");
let mime = mime_for_audio(extension).ok_or_else(|| {
anyhow::anyhow!(
-"Unsupported audio format '.{extension}' — accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
+"Unsupported audio format '.{extension}' — \
+accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
)
})?;
Ok((normalized_name, mime))
}
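The extension lookup above keys off `rsplit_once('.')`. As a minimal stdlib-only sketch of the same idea (the helper name `extension` is ours, not in the diff): the text after the last dot decides the MIME lookup, and a missing dot yields an empty extension, which then fails the `mime_for_audio` lookup.

```rust
// Hypothetical helper mirroring the rsplit_once('.') extraction in
// resolve_audio_format: take everything after the LAST dot.
fn extension(name: &str) -> &str {
    name.rsplit_once('.').map(|(_, ext)| ext).unwrap_or("")
}

fn main() {
    assert_eq!(extension("voice.ogg"), "ogg");
    assert_eq!(extension("a.b.opus"), "opus"); // last dot wins
    assert_eq!(extension("noext"), "");        // no dot -> empty extension
}
```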
/// Validate audio data and resolve MIME type from file name.
///
/// Enforces the 25 MB cloud API cap. Returns `(normalized_filename, mime_type)` on success.
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
if audio_data.len() > MAX_AUDIO_BYTES {
bail!(
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
audio_data.len()
);
}
resolve_audio_format(file_name)
}
// ── TranscriptionProvider trait ─────────────────────────────────
/// Trait for speech-to-text provider implementations.
@@ -586,21 +592,120 @@ impl TranscriptionProvider for GoogleSttProvider {
}
}
// ── LocalWhisperProvider ────────────────────────────────────────
/// Self-hosted faster-whisper-compatible STT provider.
///
/// POSTs audio as `multipart/form-data` (field name `file`) to a configurable
/// HTTP endpoint (e.g. faster-whisper on GEX44 over WireGuard). The endpoint
/// must return `{"text": "..."}`. No cloud API key required. Size limit is
/// configurable — not constrained by the 25 MB cloud API cap.
pub struct LocalWhisperProvider {
url: String,
bearer_token: String,
max_audio_bytes: usize,
timeout_secs: u64,
}
impl LocalWhisperProvider {
/// Build from config. Fails if `url` or `bearer_token` is empty, if `url`
/// is not a valid HTTP/HTTPS URL (scheme must be `http` or `https`), if
/// `max_audio_bytes` is zero, or if `timeout_secs` is zero.
pub fn from_config(config: &crate::config::LocalWhisperConfig) -> Result<Self> {
let url = config.url.trim().to_string();
anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
let parsed = url
.parse::<reqwest::Url>()
.with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
anyhow::ensure!(
matches!(parsed.scheme(), "http" | "https"),
"local_whisper: `url` must use http or https scheme, got {:?}",
parsed.scheme()
);
let bearer_token = config.bearer_token.trim().to_string();
anyhow::ensure!(
!bearer_token.is_empty(),
"local_whisper: `bearer_token` must not be empty"
);
anyhow::ensure!(
config.max_audio_bytes > 0,
"local_whisper: `max_audio_bytes` must be greater than zero"
);
anyhow::ensure!(
config.timeout_secs > 0,
"local_whisper: `timeout_secs` must be greater than zero"
);
Ok(Self {
url,
bearer_token,
max_audio_bytes: config.max_audio_bytes,
timeout_secs: config.timeout_secs,
})
}
}
#[async_trait]
impl TranscriptionProvider for LocalWhisperProvider {
fn name(&self) -> &str {
"local_whisper"
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
if audio_data.len() > self.max_audio_bytes {
bail!(
"Audio file too large ({} bytes, local_whisper max {})",
audio_data.len(),
self.max_audio_bytes
);
}
let (normalized_name, mime) = resolve_audio_format(file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.local_whisper");
// to_vec() clones the buffer for the multipart payload; peak memory per
// call is ~2× max_audio_bytes. TODO: avoid the copy with a streaming
// upload (`reqwest::multipart::Part::stream`) once the audio arrives as
// an owned stream rather than a borrowed slice.
let file_part = Part::bytes(audio_data.to_vec())
.file_name(normalized_name)
.mime_str(mime)?;
let resp = client
.post(&self.url)
.bearer_auth(&self.bearer_token)
.multipart(Form::new().part("file", file_part))
.timeout(std::time::Duration::from_secs(self.timeout_secs))
.send()
.await
.context("Failed to send audio to local Whisper endpoint")?;
parse_whisper_response(resp).await
}
}
// ── Shared response parsing ─────────────────────────────────────
-/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`).
+/// Parse a faster-whisper-compatible JSON response (`{ "text": "..." }`).
///
/// Checks HTTP status before attempting JSON parsing so that non-JSON error
/// bodies (plain text, HTML, empty 5xx) produce a readable status error
/// rather than a confusing "Failed to parse transcription response".
async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
bail!("Transcription API error ({}): {}", status, body.trim());
}
let body: serde_json::Value = resp
.json()
.await
.context("Failed to parse transcription response")?;
-if !status.is_success() {
-let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
-bail!("Transcription API error ({}): {}", status, error_msg);
-}
let text = body["text"]
.as_str()
.context("Transcription response missing 'text' field")?
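The status-first ordering documented above can be sketched without HTTP at all. This is a hedged stdlib-only sketch (function name `parse_response` and the crude string scan are ours, standing in for reqwest and serde_json): a non-2xx status short-circuits with the raw body, so plain-text or HTML error pages never reach the JSON parser.

```rust
// Stdlib-only sketch of parse_whisper_response's ordering: check status
// before parsing, so non-JSON error bodies yield a readable status error.
fn parse_response(status: u16, body: &str) -> Result<String, String> {
    if !(200..300).contains(&status) {
        return Err(format!("Transcription API error ({status}): {}", body.trim()));
    }
    // Crude extraction of {"text": "..."} to keep the sketch dependency-free.
    body.split("\"text\"")
        .nth(1)
        .and_then(|rest| rest.split('"').nth(1))
        .map(String::from)
        .ok_or_else(|| "Transcription response missing 'text' field".to_string())
}

fn main() {
    assert_eq!(parse_response(200, r#"{"text": "hello"}"#).unwrap(), "hello");
    let err = parse_response(502, "Bad Gateway\n").unwrap_err();
    assert!(err.contains("502") && err.contains("Bad Gateway"));
}
```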
@@ -657,6 +762,17 @@ impl TranscriptionManager {
}
}
if let Some(ref local_cfg) = config.local_whisper {
match LocalWhisperProvider::from_config(local_cfg) {
Ok(p) => {
providers.insert("local_whisper".to_string(), Box::new(p));
}
Err(e) => {
tracing::warn!("local_whisper config invalid, provider skipped: {e}");
}
}
}
let default_provider = config.default_provider.clone();
if config.enabled && !providers.contains_key(&default_provider) {
@@ -1036,5 +1152,260 @@ mod tests {
assert!(config.deepgram.is_none());
assert!(config.assemblyai.is_none());
assert!(config.google.is_none());
assert!(config.local_whisper.is_none());
}
// ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ──
fn local_whisper_config(url: &str) -> crate::config::LocalWhisperConfig {
crate::config::LocalWhisperConfig {
url: url.to_string(),
bearer_token: "test-token".to_string(),
max_audio_bytes: 10 * 1024 * 1024,
timeout_secs: 30,
}
}
#[test]
fn local_whisper_rejects_empty_url() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.url = String::new();
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(
err.to_string().contains("`url` must not be empty"),
"got: {err}"
);
}
#[test]
fn local_whisper_rejects_invalid_url() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.url = "not-a-url".to_string();
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(err.to_string().contains("invalid `url`"), "got: {err}");
}
#[test]
fn local_whisper_rejects_non_http_url() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string();
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(err.to_string().contains("http or https"), "got: {err}");
}
#[test]
fn local_whisper_rejects_empty_bearer_token() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.bearer_token = String::new();
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(
err.to_string().contains("`bearer_token` must not be empty"),
"got: {err}"
);
}
#[test]
fn local_whisper_rejects_zero_max_audio_bytes() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.max_audio_bytes = 0;
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(
err.to_string()
.contains("`max_audio_bytes` must be greater than zero"),
"got: {err}"
);
}
#[test]
fn local_whisper_rejects_zero_timeout() {
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
cfg.timeout_secs = 0;
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
assert!(
err.to_string()
.contains("`timeout_secs` must be greater than zero"),
"got: {err}"
);
}
#[test]
fn local_whisper_registered_when_config_present() {
let mut config = TranscriptionConfig::default();
config.local_whisper = Some(local_whisper_config("http://127.0.0.1:9999/v1/transcribe"));
config.default_provider = "local_whisper".to_string();
let manager = TranscriptionManager::new(&config).unwrap();
assert!(
manager.available_providers().contains(&"local_whisper"),
"expected local_whisper in {:?}",
manager.available_providers()
);
}
#[test]
fn local_whisper_misconfigured_section_fails_manager_construction() {
// A misconfigured local_whisper section logs a warning and skips
// registration. When local_whisper is also the default_provider and
// transcription is enabled, the safety net in TranscriptionManager
// surfaces the error: "not configured".
let mut config = TranscriptionConfig::default();
let mut bad_cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
bad_cfg.bearer_token = String::new();
config.local_whisper = Some(bad_cfg);
config.enabled = true;
config.default_provider = "local_whisper".to_string();
let err = TranscriptionManager::new(&config).err().unwrap();
assert!(
err.to_string().contains("not configured"),
"expected 'not configured' from manager safety net, got: {err}"
);
}
#[test]
fn validate_audio_still_enforces_25mb_cap() {
// Regression: extracting resolve_audio_format() must not weaken validate_audio().
let at_limit = vec![0u8; MAX_AUDIO_BYTES];
assert!(validate_audio(&at_limit, "test.ogg").is_ok());
let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
assert!(err.to_string().contains("too large"));
}
#[tokio::test]
async fn local_whisper_rejects_oversized_audio() {
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let big = vec![0u8; cfg.max_audio_bytes + 1];
let err = provider.transcribe(&big, "voice.ogg").await.unwrap_err();
assert!(err.to_string().contains("too large"), "got: {err}");
}
#[tokio::test]
async fn local_whisper_rejects_unsupported_format() {
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let data = vec![0u8; 100];
let err = provider.transcribe(&data, "voice.aiff").await.unwrap_err();
assert!(
err.to_string().contains("Unsupported audio format"),
"got: {err}"
);
}
// ── LocalWhisperProvider HTTP mock tests ────────────────────
#[tokio::test]
async fn local_whisper_returns_text_from_response() {
use wiremock::matchers::{header_exists, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/transcribe"))
.and(header_exists("authorization"))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(serde_json::json!({"text": "hello world"})),
)
.mount(&server)
.await;
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let result = provider
.transcribe(b"fake-audio", "voice.ogg")
.await
.unwrap();
assert_eq!(result, "hello world");
}
#[tokio::test]
async fn local_whisper_sends_bearer_auth_header() {
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/transcribe"))
.and(header("authorization", "Bearer test-token"))
.respond_with(
ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})),
)
.mount(&server)
.await;
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let result = provider
.transcribe(b"fake-audio", "voice.ogg")
.await
.unwrap();
assert_eq!(result, "auth ok");
}
#[tokio::test]
async fn local_whisper_propagates_http_error() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/transcribe"))
.respond_with(
ResponseTemplate::new(503).set_body_json(
serde_json::json!({"error": {"message": "service unavailable"}}),
),
)
.mount(&server)
.await;
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let err = provider
.transcribe(b"fake-audio", "voice.ogg")
.await
.unwrap_err();
assert!(
err.to_string().contains("503") || err.to_string().contains("service unavailable"),
"expected HTTP error, got: {err}"
);
}
#[tokio::test]
async fn local_whisper_propagates_non_json_http_error() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/transcribe"))
.respond_with(
ResponseTemplate::new(502)
.set_body_string("Bad Gateway")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
let err = provider
.transcribe(b"fake-audio", "voice.ogg")
.await
.unwrap_err();
assert!(err.to_string().contains("502"), "got: {err}");
assert!(
err.to_string().contains("Bad Gateway"),
"expected plain-text body in error, got: {err}"
);
}
}
@@ -22,13 +22,23 @@ impl WatiChannel {
api_url: String,
tenant_id: Option<String>,
allowed_numbers: Vec<String>,
) -> Self {
Self::new_with_proxy(api_token, api_url, tenant_id, allowed_numbers, None)
}
pub fn new_with_proxy(
api_token: String,
api_url: String,
tenant_id: Option<String>,
allowed_numbers: Vec<String>,
proxy_url: Option<String>,
) -> Self {
Self {
api_token,
api_url,
tenant_id,
allowed_numbers,
-client: crate::config::build_runtime_proxy_client("channel.wati"),
+client: crate::config::build_channel_proxy_client("channel.wati", proxy_url.as_deref()),
}
}
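The constructor split above keeps old call sites compiling: `new` delegates to `new_with_proxy` with `proxy_url = None`. A minimal sketch of that pattern (the `Wati` struct here is a stripped-down stand-in, not the real channel type):

```rust
// Backward-compatible constructor delegation: the old signature forwards
// to the new one with the added parameter defaulted to None.
struct Wati {
    token: String,
    proxy_url: Option<String>,
}

impl Wati {
    fn new(token: String) -> Self {
        Self::new_with_proxy(token, None)
    }

    fn new_with_proxy(token: String, proxy_url: Option<String>) -> Self {
        Self { token, proxy_url }
    }
}

fn main() {
    let legacy = Wati::new("t".into());
    assert!(legacy.proxy_url.is_none()); // old callers get no proxy
    assert_eq!(legacy.token, "t");

    let proxied = Wati::new_with_proxy("t".into(), Some("http://proxy:8080".into()));
    assert_eq!(proxied.proxy_url.as_deref(), Some("http://proxy:8080"));
}
```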
@@ -27,6 +27,8 @@ pub struct WhatsAppChannel {
endpoint_id: String,
verify_token: String,
allowed_numbers: Vec<String>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
}
impl WhatsAppChannel {
@@ -41,11 +43,18 @@ impl WhatsAppChannel {
endpoint_id,
verify_token,
allowed_numbers,
proxy_url: None,
}
}
/// Set a per-channel proxy URL that overrides the global proxy config.
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
-crate::config::build_runtime_proxy_client("channel.whatsapp")
+crate::config::build_channel_proxy_client("channel.whatsapp", self.proxy_url.as_deref())
}
/// Check if a phone number is allowed (E.164 format: +1234567890)
@@ -249,7 +249,7 @@ async fn check_memory_roundtrip(config: &crate::config::Config) -> CheckResult {
return CheckResult::fail("memory", format!("write failed: {e}"));
}
-match mem.recall(test_key, 1, None).await {
+match mem.recall(test_key, 1, None, None, None).await {
Ok(entries) if !entries.is_empty() => {
let _ = mem.forget(test_key).await;
CheckResult::pass("memory", "write/read/delete round-trip OK")
@@ -4,31 +4,32 @@ pub mod workspace;
#[allow(unused_imports)]
pub use schema::{
-apply_runtime_proxy_to_builder, build_runtime_proxy_client,
+apply_channel_proxy_to_builder, apply_runtime_proxy_to_builder, build_channel_proxy_client,
+build_channel_proxy_client_with_timeouts, build_runtime_proxy_client,
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
-ClassificationRule, CloudOpsConfig, ComposioConfig, Config, ConversationalAiConfig, CostConfig,
-CronConfig, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig,
-DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig,
-EstopConfig, FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig,
-GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig, HardwareConfig, HardwareTransport,
-HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig,
-ImageProviderDalleConfig, ImageProviderFluxConfig, ImageProviderImagenConfig,
-ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig, LarkConfig, LinkedInConfig,
-LinkedInContentConfig, LinkedInImageConfig, MatrixConfig, McpConfig, McpServerConfig,
-McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
-NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
-OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
-PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
-ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
-RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
-SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
-StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
-SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
-TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
-WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
-DEFAULT_GWS_SERVICES,
+ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
+ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
+DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
+ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
+GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
+HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
+IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
+ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
+LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
+MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
+ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
+NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
+OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
+ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
+ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
+SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
+SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
+StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
+TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
+TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
+WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
};
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
@@ -58,6 +59,7 @@ mod tests {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
};
let discord = DiscordConfig {
@@ -67,6 +69,7 @@ mod tests {
listen_to_bots: false,
interrupt_on_new_message: false,
mention_only: false,
proxy_url: None,
};
let lark = LarkConfig {
@@ -79,6 +82,7 @@ mod tests {
use_feishu: false,
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
};
let feishu = FeishuConfig {
app_id: "app-id".into(),
@@ -88,6 +92,7 @@ mod tests {
allowed_users: vec![],
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
};
let nextcloud_talk = NextcloudTalkConfig {
@@ -95,6 +100,7 @@ mod tests {
app_token: "app-token".into(),
webhook_secret: None,
allowed_users: vec!["*".into()],
proxy_url: None,
};
assert_eq!(telegram.allowed_users.len(), 1);
@@ -165,6 +165,10 @@ pub struct Config {
#[serde(default)]
pub agent: AgentConfig,
/// Pacing controls for slow/local LLM workloads (`[pacing]`).
#[serde(default)]
pub pacing: PacingConfig,
/// Skills loading and community repository behavior (`[skills]`).
#[serde(default)]
pub skills: SkillsConfig,
@@ -371,6 +375,10 @@ pub struct Config {
/// Verifiable Intent (VI) credential verification and issuance (`[verifiable_intent]`).
#[serde(default)]
pub verifiable_intent: VerifiableIntentConfig,
/// Claude Code tool configuration (`[claude_code]`).
#[serde(default)]
pub claude_code: ClaudeCodeConfig,
}
/// Multi-client workspace isolation configuration.
@@ -515,6 +523,10 @@ pub struct DelegateAgentConfig {
/// When `None`, falls back to `[delegate].agentic_timeout_secs` (default: 300).
#[serde(default)]
pub agentic_timeout_secs: Option<u64>,
/// Optional skills directory path (relative to workspace root) for scoped skill loading.
/// When unset or empty, the sub-agent falls back to the default workspace `skills/` directory.
#[serde(default)]
pub skills_directory: Option<String>,
}
fn default_delegate_timeout_secs() -> u64 {
@@ -784,6 +796,9 @@ pub struct TranscriptionConfig {
/// Google Cloud Speech-to-Text provider configuration.
#[serde(default)]
pub google: Option<GoogleSttConfig>,
/// Local/self-hosted Whisper-compatible STT provider.
#[serde(default)]
pub local_whisper: Option<LocalWhisperConfig>,
}
impl Default for TranscriptionConfig {
@@ -801,6 +816,7 @@ impl Default for TranscriptionConfig {
deepgram: None,
assemblyai: None,
google: None,
local_whisper: None,
}
}
}
@@ -1169,6 +1185,35 @@ pub struct GoogleSttConfig {
pub language_code: String,
}
/// Local/self-hosted Whisper-compatible STT endpoint (`[transcription.local_whisper]`).
///
/// Audio is sent over WireGuard and never leaves the platform perimeter.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct LocalWhisperConfig {
/// HTTP or HTTPS endpoint URL, e.g. `"http://10.10.0.1:8001/v1/transcribe"`.
pub url: String,
/// Bearer token for endpoint authentication.
pub bearer_token: String,
/// Maximum audio file size in bytes accepted by this endpoint.
/// Defaults to 25 MB — matching the cloud API cap for a safe out-of-the-box
/// experience. Self-hosted endpoints can accept much larger files; raise this
/// as needed, but note that each transcription call clones the audio buffer
/// into a multipart payload, so peak memory per request is ~2× this value.
#[serde(default = "default_local_whisper_max_audio_bytes")]
pub max_audio_bytes: usize,
/// Request timeout in seconds. Defaults to 300 (large files on local GPU).
#[serde(default = "default_local_whisper_timeout_secs")]
pub timeout_secs: u64,
}
fn default_local_whisper_max_audio_bytes() -> usize {
25 * 1024 * 1024
}
fn default_local_whisper_timeout_secs() -> u64 {
300
}
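The serde defaults above, restated with the peak-memory implication the `max_audio_bytes` doc calls out: because each transcription call clones the audio into the multipart payload, peak memory per request is roughly twice the configured cap.

```rust
// The two default fns, copied from the config schema above.
fn default_local_whisper_max_audio_bytes() -> usize {
    25 * 1024 * 1024
}
fn default_local_whisper_timeout_secs() -> u64 {
    300
}

fn main() {
    let cap = default_local_whisper_max_audio_bytes();
    assert_eq!(cap, 26_214_400); // 25 MiB, matching the cloud API cap
    // Approximate peak memory per call: the buffer plus its multipart clone.
    assert_eq!(2 * cap, 52_428_800);
    assert_eq!(default_local_whisper_timeout_secs(), 300);
}
```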
/// Agent orchestration configuration (`[agent]` section).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentConfig {
@@ -1236,6 +1281,43 @@ impl Default for AgentConfig {
}
}
// ── Pacing ────────────────────────────────────────────────────────
/// Pacing controls for slow/local LLM workloads (`[pacing]` section).
///
/// All fields are optional and default to values that preserve existing
/// behavior. When set, they extend — not replace — the existing timeout
/// and loop-detection subsystems.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct PacingConfig {
/// Per-step timeout in seconds: the maximum time allowed for a single
/// LLM inference turn, independent of the total message budget.
/// `None` means no per-step timeout (existing behavior).
#[serde(default)]
pub step_timeout_secs: Option<u64>,
/// Minimum elapsed seconds before loop detection activates.
/// When set, the detector only starts counting identical outputs after this
/// grace period has elapsed, avoiding false positives on long-running
/// browser/research workflows. `None` means loop detection is always active
/// (existing behavior).
#[serde(default)]
pub loop_detection_min_elapsed_secs: Option<u64>,
/// Tool names excluded from identical-output / alternating-pattern loop
/// detection. Useful for browser workflows where `browser_screenshot`
/// structurally resembles a loop even when making progress.
#[serde(default)]
pub loop_ignore_tools: Vec<String>,
/// Override for the hardcoded timeout scaling cap (default: 4).
/// The channel message timeout budget is computed as:
/// `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)`
/// Raising this value lets long multi-step tasks with slow local models
/// receive a proportionally larger budget without inflating the base timeout.
#[serde(default)]
pub message_timeout_scale_max: Option<u64>,
}
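The budget formula in the `message_timeout_scale_max` doc can be sketched directly. A minimal stdlib-only sketch (function name `message_budget_secs` is ours), assuming the field falls back to the previous hardcoded cap of 4 when unset:

```rust
// Sketch of: message_timeout_secs * min(max_tool_iterations, scale_max),
// with scale_max defaulting to the old hardcoded 4x ceiling.
fn message_budget_secs(
    message_timeout_secs: u64,
    max_tool_iterations: u64,
    scale_max: Option<u64>,
) -> u64 {
    message_timeout_secs * max_tool_iterations.min(scale_max.unwrap_or(4))
}

fn main() {
    // Default cap: 10 iterations are clamped to a 4x budget.
    assert_eq!(message_budget_secs(120, 10, None), 480);
    // Raised cap for slow local models: the same task gets an 8x budget.
    assert_eq!(message_budget_secs(120, 10, Some(8)), 960);
    // Few iterations never hit the cap.
    assert_eq!(message_budget_secs(120, 2, Some(8)), 240);
}
```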
/// Skills loading configuration (`[skills]` section).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
@@ -1611,6 +1693,12 @@ pub struct GatewayConfig {
#[serde(default)]
pub trust_forwarded_headers: bool,
/// Optional URL path prefix for reverse-proxy deployments.
/// When set, all gateway routes are served under this prefix.
/// Must start with `/` and must not end with `/`.
#[serde(default)]
pub path_prefix: Option<String>,
/// Maximum distinct client keys tracked by gateway rate limiter maps.
#[serde(default = "default_gateway_rate_limit_max_keys")]
pub rate_limit_max_keys: usize,
@@ -1683,6 +1771,7 @@ impl Default for GatewayConfig {
pair_rate_limit_per_minute: default_pair_rate_limit(),
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
trust_forwarded_headers: false,
path_prefix: None,
rate_limit_max_keys: default_gateway_rate_limit_max_keys(),
idempotency_ttl_secs: default_idempotency_ttl_secs(),
idempotency_max_keys: default_gateway_idempotency_max_keys(),
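The `path_prefix` contract stated above ("must start with `/` and must not end with `/`") reduces to a one-line predicate. A hedged sketch (the gateway's real validation may check more, e.g. empty segments):

```rust
// Sketch of the documented path_prefix shape check only.
fn path_prefix_valid(p: &str) -> bool {
    p.starts_with('/') && !p.ends_with('/')
}

fn main() {
    assert!(path_prefix_valid("/zeroclaw"));
    assert!(!path_prefix_valid("zeroclaw"));   // missing leading slash
    assert!(!path_prefix_valid("/zeroclaw/")); // trailing slash
}
```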
@@ -2874,6 +2963,60 @@ impl Default for ImageProviderFluxConfig {
}
}
// ── Claude Code ─────────────────────────────────────────────────
/// Claude Code CLI tool configuration (`[claude_code]` section).
///
/// Delegates coding tasks to the `claude -p` CLI. Authentication uses the
/// binary's own OAuth session (Max subscription) by default — no API key
/// needed unless `env_passthrough` includes `ANTHROPIC_API_KEY`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClaudeCodeConfig {
/// Enable the `claude_code` tool
#[serde(default)]
pub enabled: bool,
/// Maximum execution time in seconds (coding tasks can be long)
#[serde(default = "default_claude_code_timeout_secs")]
pub timeout_secs: u64,
/// Claude Code tools the subprocess is allowed to use
#[serde(default = "default_claude_code_allowed_tools")]
pub allowed_tools: Vec<String>,
/// Optional system prompt appended to Claude Code invocations
#[serde(default)]
pub system_prompt: Option<String>,
/// Maximum output size in bytes (2MB default)
#[serde(default = "default_claude_code_max_output_bytes")]
pub max_output_bytes: usize,
/// Extra env vars passed to the claude subprocess (e.g. ANTHROPIC_API_KEY for API-key billing)
#[serde(default)]
pub env_passthrough: Vec<String>,
}
fn default_claude_code_timeout_secs() -> u64 {
600
}
fn default_claude_code_allowed_tools() -> Vec<String> {
vec!["Read".into(), "Edit".into(), "Bash".into(), "Write".into()]
}
fn default_claude_code_max_output_bytes() -> usize {
2_097_152
}
impl Default for ClaudeCodeConfig {
fn default() -> Self {
Self {
enabled: false,
timeout_secs: default_claude_code_timeout_secs(),
allowed_tools: default_claude_code_allowed_tools(),
system_prompt: None,
max_output_bytes: default_claude_code_max_output_bytes(),
env_passthrough: Vec::new(),
}
}
}
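A quick sanity check of the ClaudeCodeConfig defaults above: the output cap is exactly 2 MiB, and the default allowlist covers the four core file/shell tools.

```rust
// Default fns copied from the config schema above.
fn default_claude_code_max_output_bytes() -> usize {
    2_097_152
}
fn default_claude_code_allowed_tools() -> Vec<String> {
    vec!["Read".into(), "Edit".into(), "Bash".into(), "Write".into()]
}

fn main() {
    // 2_097_152 bytes is exactly 2 MiB.
    assert_eq!(default_claude_code_max_output_bytes(), 2 * 1024 * 1024);
    assert_eq!(default_claude_code_allowed_tools().len(), 4);
}
```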
// ── Proxy ───────────────────────────────────────────────────────
/// Proxy application scope — determines which outbound traffic uses the proxy.
@@ -3381,6 +3524,116 @@ pub fn build_runtime_proxy_client_with_timeouts(
client
}
/// Build an HTTP client for a channel, using an explicit per-channel proxy URL
/// when configured. Falls back to the global runtime proxy when `proxy_url` is
/// `None` or empty.
pub fn build_channel_proxy_client(service_key: &str, proxy_url: Option<&str>) -> reqwest::Client {
match normalize_proxy_url_option(proxy_url) {
Some(url) => build_explicit_proxy_client(service_key, &url, None, None),
None => build_runtime_proxy_client(service_key),
}
}
/// Build an HTTP client for a channel with custom timeouts, using an explicit
/// per-channel proxy URL when configured. Falls back to the global runtime
/// proxy when `proxy_url` is `None` or empty.
pub fn build_channel_proxy_client_with_timeouts(
service_key: &str,
proxy_url: Option<&str>,
timeout_secs: u64,
connect_timeout_secs: u64,
) -> reqwest::Client {
match normalize_proxy_url_option(proxy_url) {
Some(url) => build_explicit_proxy_client(
service_key,
&url,
Some(timeout_secs),
Some(connect_timeout_secs),
),
None => build_runtime_proxy_client_with_timeouts(
service_key,
timeout_secs,
connect_timeout_secs,
),
}
}
/// Apply an explicit proxy URL to a `reqwest::ClientBuilder`, returning the
/// modified builder. Used by channels that specify a per-channel `proxy_url`.
pub fn apply_channel_proxy_to_builder(
builder: reqwest::ClientBuilder,
service_key: &str,
proxy_url: Option<&str>,
) -> reqwest::ClientBuilder {
match normalize_proxy_url_option(proxy_url) {
Some(url) => apply_explicit_proxy_to_builder(builder, service_key, &url),
None => apply_runtime_proxy_to_builder(builder, service_key),
}
}
/// Build a client with a single explicit proxy URL (http+https via `Proxy::all`).
fn build_explicit_proxy_client(
service_key: &str,
proxy_url: &str,
timeout_secs: Option<u64>,
connect_timeout_secs: Option<u64>,
) -> reqwest::Client {
let cache_key = format!(
"explicit|{}|{}|timeout={}|connect_timeout={}",
service_key.trim().to_ascii_lowercase(),
proxy_url,
timeout_secs
.map(|v| v.to_string())
.unwrap_or_else(|| "none".to_string()),
connect_timeout_secs
.map(|v| v.to_string())
.unwrap_or_else(|| "none".to_string()),
);
if let Some(client) = runtime_proxy_cached_client(&cache_key) {
return client;
}
let mut builder = reqwest::Client::builder();
if let Some(t) = timeout_secs {
builder = builder.timeout(std::time::Duration::from_secs(t));
}
if let Some(ct) = connect_timeout_secs {
builder = builder.connect_timeout(std::time::Duration::from_secs(ct));
}
builder = apply_explicit_proxy_to_builder(builder, service_key, proxy_url);
let client = builder.build().unwrap_or_else(|error| {
tracing::warn!(
service_key,
proxy_url,
"Failed to build channel proxy client: {error}"
);
reqwest::Client::new()
});
set_runtime_proxy_cached_client(cache_key, client.clone());
client
}
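The cache key built above can be exercised in isolation. This std-only mirror of the `format!` call shows why two clients that differ only in timeouts never share a cache slot: absent timeouts serialize as the literal `"none"`, and the service key is trimmed and lowercased:

```rust
// Std-only mirror of the explicit-proxy cache key shown above.
fn explicit_cache_key(
    service_key: &str,
    proxy_url: &str,
    timeout_secs: Option<u64>,
    connect_timeout_secs: Option<u64>,
) -> String {
    format!(
        "explicit|{}|{}|timeout={}|connect_timeout={}",
        service_key.trim().to_ascii_lowercase(),
        proxy_url,
        timeout_secs
            .map(|v| v.to_string())
            .unwrap_or_else(|| "none".to_string()),
        connect_timeout_secs
            .map(|v| v.to_string())
            .unwrap_or_else(|| "none".to_string()),
    )
}

fn main() {
    assert_eq!(
        explicit_cache_key(" Telegram ", "http://p:8080", Some(30), None),
        "explicit|telegram|http://p:8080|timeout=30|connect_timeout=none"
    );
    // Same proxy, different timeout: distinct cache entries.
    assert_ne!(
        explicit_cache_key("telegram", "http://p:8080", Some(30), None),
        explicit_cache_key("telegram", "http://p:8080", Some(60), None)
    );
}
```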
/// Apply a single explicit proxy URL to a builder via `Proxy::all`.
fn apply_explicit_proxy_to_builder(
mut builder: reqwest::ClientBuilder,
service_key: &str,
proxy_url: &str,
) -> reqwest::ClientBuilder {
match reqwest::Proxy::all(proxy_url) {
Ok(proxy) => {
builder = builder.proxy(proxy);
}
Err(error) => {
tracing::warn!(
proxy_url,
service_key,
"Ignoring invalid channel proxy_url: {error}"
);
}
}
builder
}
fn parse_proxy_scope(raw: &str) -> Option<ProxyScope> {
match raw.trim().to_ascii_lowercase().as_str() {
"environment" | "env" => Some(ProxyScope::Environment),
@@ -4885,6 +5138,10 @@ pub struct TelegramConfig {
/// explicitly, it takes precedence.
#[serde(default)]
pub ack_reactions: Option<bool>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for TelegramConfig {
@@ -4918,6 +5175,10 @@ pub struct DiscordConfig {
/// Other messages in the guild are silently ignored.
#[serde(default)]
pub mention_only: bool,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for DiscordConfig {
@@ -4954,6 +5215,10 @@ pub struct SlackConfig {
/// Direct messages remain allowed.
#[serde(default)]
pub mention_only: bool,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for SlackConfig {
@@ -4989,6 +5254,10 @@ pub struct MattermostConfig {
/// cancels the in-flight request and starts a fresh response with preserved history.
#[serde(default)]
pub interrupt_on_new_message: bool,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for MattermostConfig {
@@ -5101,6 +5370,10 @@ pub struct SignalConfig {
/// Skip incoming story messages.
#[serde(default)]
pub ignore_stories: bool,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for SignalConfig {
@@ -5195,6 +5468,10 @@ pub struct WhatsAppConfig {
/// user's own self-chat (Notes to Self). Defaults to false.
#[serde(default)]
pub self_chat_mode: bool,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for WhatsAppConfig {
@@ -5243,6 +5520,10 @@ pub struct WatiConfig {
/// Allowed phone numbers (E.164 format) or "*" for all.
#[serde(default)]
pub allowed_numbers: Vec<String>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
fn default_wati_api_url() -> String {
@@ -5273,6 +5554,10 @@ pub struct NextcloudTalkConfig {
/// Allowed Nextcloud actor IDs (`[]` = deny all, `"*"` = allow all).
#[serde(default)]
pub allowed_users: Vec<String>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for NextcloudTalkConfig {
@@ -5400,6 +5685,10 @@ pub struct LarkConfig {
/// Not required (and ignored) for websocket mode.
#[serde(default)]
pub port: Option<u16>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for LarkConfig {
@@ -5434,6 +5723,10 @@ pub struct FeishuConfig {
/// Not required (and ignored) for websocket mode.
#[serde(default)]
pub port: Option<u16>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for FeishuConfig {
@@ -5895,6 +6188,10 @@ pub struct DingTalkConfig {
/// Allowed user IDs (staff IDs). Empty = deny all, "*" = allow all
#[serde(default)]
pub allowed_users: Vec<String>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for DingTalkConfig {
@@ -5935,6 +6232,10 @@ pub struct QQConfig {
/// Allowed user IDs. Empty = deny all, "*" = allow all
#[serde(default)]
pub allowed_users: Vec<String>,
/// Per-channel proxy URL (http, https, socks5, socks5h).
/// Overrides the global `[proxy]` setting for this channel only.
#[serde(default)]
pub proxy_url: Option<String>,
}
impl ChannelConfig for QQConfig {
@@ -6467,6 +6768,7 @@ impl Default for Config {
reliability: ReliabilityConfig::default(),
scheduler: SchedulerConfig::default(),
agent: AgentConfig::default(),
pacing: PacingConfig::default(),
skills: SkillsConfig::default(),
model_routes: Vec::new(),
embedding_routes: Vec::new(),
@@ -6512,6 +6814,7 @@ impl Default for Config {
plugins: PluginsConfig::default(),
locale: None,
verifiable_intent: VerifiableIntentConfig::default(),
claude_code: ClaudeCodeConfig::default(),
}
}
}
@@ -7571,6 +7874,31 @@ impl Config {
if self.gateway.host.trim().is_empty() {
anyhow::bail!("gateway.host must not be empty");
}
if let Some(ref prefix) = self.gateway.path_prefix {
// Validate the raw value — no silent trimming so the stored
// value is exactly what was validated.
if !prefix.is_empty() {
if !prefix.starts_with('/') {
anyhow::bail!("gateway.path_prefix must start with '/'");
}
if prefix.ends_with('/') {
anyhow::bail!("gateway.path_prefix must not end with '/' (including bare '/')");
}
// Reject characters unsafe for URL paths or HTML/JS injection.
// Whitespace is intentionally excluded from the allowed set.
if let Some(bad) = prefix.chars().find(|c| {
!matches!(c, '/' | '-' | '_' | '.' | '~'
| 'a'..='z' | 'A'..='Z' | '0'..='9'
| '!' | '$' | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '='
| ':' | '@')
}) {
anyhow::bail!(
"gateway.path_prefix contains invalid character '{bad}'; \
only unreserved and sub-delim URI characters are allowed"
);
}
}
}
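The validation rules above can be condensed into a small std-only checker, useful for seeing which prefixes pass: must start with `/`, must not end with `/` (which also rejects a bare `/`), and only unreserved and sub-delim URI characters are allowed:

```rust
// Std-only mirror of the gateway.path_prefix validation rules.
// Returns None when the prefix is acceptable, or a short reason otherwise.
fn path_prefix_error(prefix: &str) -> Option<&'static str> {
    if prefix.is_empty() {
        return None; // empty means "no prefix" and is always accepted
    }
    if !prefix.starts_with('/') {
        return Some("must start with '/'");
    }
    if prefix.ends_with('/') {
        return Some("must not end with '/'");
    }
    let allowed = |c: char| {
        matches!(c, '/' | '-' | '_' | '.' | '~'
            | 'a'..='z' | 'A'..='Z' | '0'..='9'
            | '!' | '$' | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '='
            | ':' | '@')
    };
    if prefix.chars().any(|c| !allowed(c)) {
        return Some("contains an invalid character");
    }
    None
}

fn main() {
    assert_eq!(path_prefix_error("/zeroclaw"), None);
    assert_eq!(path_prefix_error("zeroclaw"), Some("must start with '/'"));
    assert_eq!(path_prefix_error("/"), Some("must not end with '/'"));
    // Whitespace is deliberately outside the allowed set.
    assert_eq!(path_prefix_error("/a b"), Some("contains an invalid character"));
}
```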
// Autonomy
if self.autonomy.max_actions_per_hour == 0 {
@@ -8081,10 +8409,10 @@ impl Config {
{
let dp = self.transcription.default_provider.trim();
match dp {
            "groq" | "openai" | "deepgram" | "assemblyai" | "google" | "local_whisper" => {}
other => {
anyhow::bail!(
                    "transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google, local_whisper (got '{other}')"
);
}
}
@@ -9335,6 +9663,7 @@ default_temperature = 0.7
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
}),
discord: None,
slack: None,
@@ -9386,6 +9715,7 @@ default_temperature = 0.7
google_workspace: GoogleWorkspaceConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
pacing: PacingConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
peripherals: PeripheralsConfig::default(),
@@ -9407,6 +9737,7 @@ default_temperature = 0.7
plugins: PluginsConfig::default(),
locale: None,
verifiable_intent: VerifiableIntentConfig::default(),
claude_code: ClaudeCodeConfig::default(),
};
let toml_str = toml::to_string_pretty(&config).unwrap();
@@ -9656,6 +9987,47 @@ tool_dispatcher = "xml"
assert_eq!(parsed.agent.tool_dispatcher, "xml");
}
#[test]
    fn pacing_config_defaults_are_all_none_or_empty() {
let cfg = PacingConfig::default();
assert!(cfg.step_timeout_secs.is_none());
assert!(cfg.loop_detection_min_elapsed_secs.is_none());
assert!(cfg.loop_ignore_tools.is_empty());
assert!(cfg.message_timeout_scale_max.is_none());
}
#[test]
    fn pacing_config_deserializes_from_toml() {
let raw = r#"
default_temperature = 0.7
[pacing]
step_timeout_secs = 120
loop_detection_min_elapsed_secs = 60
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
message_timeout_scale_max = 8
"#;
let parsed: Config = toml::from_str(raw).unwrap();
assert_eq!(parsed.pacing.step_timeout_secs, Some(120));
assert_eq!(parsed.pacing.loop_detection_min_elapsed_secs, Some(60));
assert_eq!(
parsed.pacing.loop_ignore_tools,
vec!["browser_screenshot", "browser_navigate"]
);
assert_eq!(parsed.pacing.message_timeout_scale_max, Some(8));
}
#[test]
    fn pacing_config_absent_preserves_defaults() {
let raw = r#"
default_temperature = 0.7
"#;
let parsed: Config = toml::from_str(raw).unwrap();
assert!(parsed.pacing.step_timeout_secs.is_none());
assert!(parsed.pacing.loop_detection_min_elapsed_secs.is_none());
assert!(parsed.pacing.loop_ignore_tools.is_empty());
assert!(parsed.pacing.message_timeout_scale_max.is_none());
}
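How `message_timeout_scale_max` interacts with the hardcoded ceiling is not shown in this diff; the sketch below is a hypothetical reading of the commit description (the default ceiling of 4 and the "override" semantics are the only assumptions taken from it, and `effective_scale` is an invented name):

```rust
// Hypothetical cap on the channel message timeout scaling factor.
// None falls back to the hardcoded 4x ceiling; a configured value replaces it.
fn effective_scale(requested_scale: u64, configured_max: Option<u64>) -> u64 {
    requested_scale.min(configured_max.unwrap_or(4))
}

fn main() {
    assert_eq!(effective_scale(10, None), 4); // hardcoded ceiling applies
    assert_eq!(effective_scale(10, Some(8)), 8); // config raises the ceiling
    assert_eq!(effective_scale(2, Some(8)), 2); // small scales pass through
}
```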
#[tokio::test]
async fn sync_directory_handles_existing_directory() {
let dir = std::env::temp_dir().join(format!(
@@ -9724,6 +10096,7 @@ tool_dispatcher = "xml"
google_workspace: GoogleWorkspaceConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
pacing: PacingConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
peripherals: PeripheralsConfig::default(),
@@ -9745,6 +10118,7 @@ tool_dispatcher = "xml"
plugins: PluginsConfig::default(),
locale: None,
verifiable_intent: VerifiableIntentConfig::default(),
claude_code: ClaudeCodeConfig::default(),
};
config.save().await.unwrap();
@@ -9789,6 +10163,7 @@ tool_dispatcher = "xml"
allowed_users: vec!["*".into()],
receive_mode: LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
});
config.agents.insert(
@@ -9805,6 +10180,7 @@ tool_dispatcher = "xml"
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
@@ -9930,6 +10306,7 @@ tool_dispatcher = "xml"
interrupt_on_new_message: true,
mention_only: false,
ack_reactions: None,
proxy_url: None,
};
let json = serde_json::to_string(&tc).unwrap();
let parsed: TelegramConfig = serde_json::from_str(&json).unwrap();
@@ -9958,6 +10335,7 @@ tool_dispatcher = "xml"
listen_to_bots: false,
interrupt_on_new_message: false,
mention_only: false,
proxy_url: None,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@@ -9974,6 +10352,7 @@ tool_dispatcher = "xml"
listen_to_bots: false,
interrupt_on_new_message: false,
mention_only: false,
proxy_url: None,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@@ -10075,6 +10454,7 @@ allowed_users = ["@ops:matrix.org"]
allowed_from: vec!["+1111111111".into()],
ignore_attachments: true,
ignore_stories: false,
proxy_url: None,
};
let json = serde_json::to_string(&sc).unwrap();
let parsed: SignalConfig = serde_json::from_str(&json).unwrap();
@@ -10095,6 +10475,7 @@ allowed_users = ["@ops:matrix.org"]
allowed_from: vec!["*".into()],
ignore_attachments: false,
ignore_stories: true,
proxy_url: None,
};
let toml_str = toml::to_string(&sc).unwrap();
let parsed: SignalConfig = toml::from_str(&toml_str).unwrap();
@@ -10325,6 +10706,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
};
let json = serde_json::to_string(&wc).unwrap();
let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap();
@@ -10349,6 +10731,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
};
let toml_str = toml::to_string(&wc).unwrap();
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
@@ -10378,6 +10761,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
};
let toml_str = toml::to_string(&wc).unwrap();
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
@@ -10399,6 +10783,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
};
assert!(wc.is_ambiguous_config());
assert_eq!(wc.backend_type(), "cloud");
@@ -10419,6 +10804,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
};
assert!(!wc.is_ambiguous_config());
assert_eq!(wc.backend_type(), "web");
@@ -10449,6 +10835,7 @@ channel_id = "C123"
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
}),
linq: None,
wati: None,
@@ -10553,6 +10940,7 @@ channel_id = "C123"
pair_rate_limit_per_minute: 12,
webhook_rate_limit_per_minute: 80,
trust_forwarded_headers: true,
path_prefix: Some("/zeroclaw".into()),
rate_limit_max_keys: 2048,
idempotency_ttl_secs: 600,
idempotency_max_keys: 4096,
@@ -10570,6 +10958,7 @@ channel_id = "C123"
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
assert!(parsed.trust_forwarded_headers);
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
assert_eq!(parsed.rate_limit_max_keys, 2048);
assert_eq!(parsed.idempotency_ttl_secs, 600);
assert_eq!(parsed.idempotency_max_keys, 4096);
@@ -11453,6 +11842,7 @@ default_model = "legacy-model"
allowed_users: vec!["*".into()],
receive_mode: LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
});
config.save().await.unwrap();
@@ -12164,6 +12554,7 @@ default_model = "persisted-profile"
use_feishu: true,
receive_mode: LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
};
let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
@@ -12187,6 +12578,7 @@ default_model = "persisted-profile"
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
@@ -12233,6 +12625,7 @@ default_model = "persisted-profile"
allowed_users: vec!["user_123".into(), "user_456".into()],
receive_mode: LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
};
let json = serde_json::to_string(&fc).unwrap();
let parsed: FeishuConfig = serde_json::from_str(&json).unwrap();
@@ -12253,6 +12646,7 @@ default_model = "persisted-profile"
allowed_users: vec!["*".into()],
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
proxy_url: None,
};
let toml_str = toml::to_string(&fc).unwrap();
let parsed: FeishuConfig = toml::from_str(&toml_str).unwrap();
@@ -12280,6 +12674,7 @@ default_model = "persisted-profile"
app_token: "app-token".into(),
webhook_secret: Some("webhook-secret".into()),
allowed_users: vec!["user_a".into(), "*".into()],
proxy_url: None,
};
let json = serde_json::to_string(&nc).unwrap();
@@ -12467,6 +12862,30 @@ require_otp_to_resume = true
assert!(err.to_string().contains("gated_domains"));
}
#[test]
    fn validate_accepts_local_whisper_as_transcription_default_provider() {
let mut config = Config::default();
config.transcription.default_provider = "local_whisper".to_string();
config.validate().expect(
"local_whisper must be accepted by the transcription.default_provider allowlist",
);
}
#[test]
    fn validate_rejects_unknown_transcription_default_provider() {
let mut config = Config::default();
config.transcription.default_provider = "unknown_stt".to_string();
let err = config
.validate()
.expect_err("expected validation to reject unknown transcription provider");
assert!(
err.to_string().contains("transcription.default_provider"),
"got: {err}"
);
}
#[tokio::test]
async fn channel_secret_telegram_bot_token_roundtrip() {
let dir = std::env::temp_dir().join(format!(
@@ -12488,6 +12907,7 @@ require_otp_to_resume = true
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
// Save (triggers encryption)
@@ -7,7 +7,7 @@ use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
/// Cost tracker for API usage monitoring and budget enforcement.
pub struct CostTracker {
@@ -175,6 +175,35 @@ impl CostTracker {
}
}
// ── Process-global singleton ────────────────────────────────────────
// Both the gateway and the channels supervisor share a single CostTracker
// so that budget enforcement is consistent across all paths.
static GLOBAL_COST_TRACKER: OnceLock<Option<Arc<CostTracker>>> = OnceLock::new();
impl CostTracker {
/// Return the process-global `CostTracker`, creating it on first call.
/// Subsequent calls (from gateway or channels, whichever starts second)
/// receive the same `Arc`. Returns `None` when cost tracking is disabled
/// or initialisation fails.
pub fn get_or_init_global(config: CostConfig, workspace_dir: &Path) -> Option<Arc<Self>> {
GLOBAL_COST_TRACKER
.get_or_init(|| {
if !config.enabled {
return None;
}
match Self::new(config, workspace_dir) {
Ok(ct) => Some(Arc::new(ct)),
Err(e) => {
tracing::warn!("Failed to initialize global cost tracker: {e}");
None
}
}
})
.clone()
}
}
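The singleton pattern used by `get_or_init_global` can be sketched with std only. The point of `OnceLock<Option<Arc<T>>>` is that the first caller decides the outcome once, including a sticky `None` when tracking is disabled or init fails, and every later caller (gateway or channels, whichever starts second) clones the same result. The `Tracker` type below is a toy stand-in:

```rust
use std::sync::{Arc, OnceLock};

// Toy stand-in for CostTracker.
struct Tracker {
    budget: u64,
}

static GLOBAL: OnceLock<Option<Arc<Tracker>>> = OnceLock::new();

fn get_or_init_global(enabled: bool, budget: u64) -> Option<Arc<Tracker>> {
    GLOBAL
        .get_or_init(|| enabled.then(|| Arc::new(Tracker { budget })))
        .clone()
}

fn main() {
    let first = get_or_init_global(true, 100).expect("enabled");
    // A second caller gets the same Arc; its arguments are ignored
    // because the OnceLock has already been initialized.
    let second = get_or_init_global(true, 999).expect("already initialized");
    assert!(Arc::ptr_eq(&first, &second));
    assert_eq!(second.budget, 100);
}
```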
fn resolve_storage_path(workspace_dir: &Path) -> Result<PathBuf> {
let storage_path = workspace_dir.join("state").join("costs.jsonl");
let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db");
@@ -646,6 +646,7 @@ mod tests {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
assert!(has_supervised_channels(&config));
}
@@ -657,6 +658,7 @@ mod tests {
client_id: "client_id".into(),
client_secret: "client_secret".into(),
allowed_users: vec!["*".into()],
proxy_url: None,
});
assert!(has_supervised_channels(&config));
}
@@ -672,6 +674,7 @@ mod tests {
thread_replies: Some(true),
mention_only: Some(false),
interrupt_on_new_message: false,
proxy_url: None,
});
assert!(has_supervised_channels(&config));
}
@@ -683,6 +686,7 @@ mod tests {
app_id: "app-id".into(),
app_secret: "app-secret".into(),
allowed_users: vec!["*".into()],
proxy_url: None,
});
assert!(has_supervised_channels(&config));
}
@@ -695,6 +699,7 @@ mod tests {
app_token: "app-token".into(),
webhook_secret: None,
allowed_users: vec!["*".into()],
proxy_url: None,
});
assert!(has_supervised_channels(&config));
}
@@ -761,6 +766,7 @@ mod tests {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
let target = resolve_heartbeat_delivery(&config).unwrap();
@@ -778,6 +784,7 @@ mod tests {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
let target = resolve_heartbeat_delivery(&config).unwrap();
@@ -1283,6 +1283,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
config.agents.insert(
@@ -1299,6 +1300,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
@@ -50,6 +50,10 @@ fn require_auth(
pub struct MemoryQuery {
pub query: Option<String>,
pub category: Option<String>,
/// Filter memories created at or after (RFC 3339 / ISO 8601)
pub since: Option<String>,
/// Filter memories created at or before (RFC 3339 / ISO 8601)
pub until: Option<String>,
}
#[derive(Deserialize)]
@@ -633,9 +637,12 @@ pub async fn handle_api_memory_list(
return e.into_response();
}
    // Use recall when query or time range is provided
    if params.query.is_some() || params.since.is_some() || params.until.is_some() {
        let query = params.query.as_deref().unwrap_or("");
        let since = params.since.as_deref();
        let until = params.until.as_deref();
        match state.mem.recall(query, 50, None, since, until).await {
Ok(entries) => Json(serde_json::json!({"entries": entries})).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
@@ -1356,6 +1363,8 @@ mod tests {
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -1429,6 +1438,7 @@ mod tests {
session_backend: None,
device_registry: None,
pending_pairings: None,
path_prefix: String::new(),
}
}
@@ -1457,6 +1467,7 @@ mod tests {
api_url: "https://live-mt-server.wati.io".to_string(),
tenant_id: None,
allowed_numbers: vec![],
proxy_url: None,
});
cfg.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
app_id: "cli_aabbcc".to_string(),
@@ -1466,6 +1477,7 @@ mod tests {
allowed_users: vec!["*".to_string()],
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
});
cfg.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
imap_host: "imap.example.com".to_string(),
@@ -1591,6 +1603,7 @@ mod tests {
api_url: "https://live-mt-server.wati.io".to_string(),
tenant_id: None,
allowed_numbers: vec![],
proxy_url: None,
});
current.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
app_id: "cli_current".to_string(),
@@ -1600,6 +1613,7 @@ mod tests {
allowed_users: vec!["*".to_string()],
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
});
current.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
imap_host: "imap.example.com".to_string(),
@@ -348,6 +348,8 @@ pub struct AppState {
pub shutdown_tx: tokio::sync::watch::Sender<bool>,
/// Registry of dynamically connected nodes
pub node_registry: Arc<nodes::NodeRegistry>,
/// Path prefix for reverse-proxy deployments (empty string = no prefix)
pub path_prefix: String,
/// Session backend for persisting gateway WS chat sessions
pub session_backend: Option<Arc<dyn SessionBackend>>,
/// Device registry for paired device management
@@ -505,18 +507,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let tools_registry: Arc<Vec<ToolSpec>> =
Arc::new(tools_registry_raw.iter().map(|t| t.spec()).collect());
// Cost tracker (optional)
    // Cost tracker — process-global singleton so channels share the same instance
    let cost_tracker = CostTracker::get_or_init_global(config.cost.clone(), &config.workspace_dir);
// SSE broadcast channel for real-time events
let (event_tx, _event_rx) = tokio::sync::broadcast::channel::<serde_json::Value>(256);
@@ -683,6 +675,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
idempotency_max_keys,
));
// Resolve optional path prefix for reverse-proxy deployments.
let path_prefix: Option<&str> = config
.gateway
.path_prefix
.as_deref()
.filter(|p| !p.is_empty());
// ── Tunnel ────────────────────────────────────────────────
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
let mut tunnel_url: Option<String> = None;
@@ -701,18 +700,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
}
}
    let pfx = path_prefix.unwrap_or("");
    println!("🦀 ZeroClaw Gateway listening on http://{display_addr}{pfx}");
if let Some(ref url) = tunnel_url {
println!(" 🌐 Public URL: {url}");
}
    println!(" 🌐 Web Dashboard: http://{display_addr}{pfx}/");
if let Some(code) = pairing.pairing_code() {
println!();
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
println!(" ┌──────────────┐");
println!("{code}");
println!(" └──────────────┘");
println!();
println!(" Send: POST {pfx}/pair with header X-Pairing-Code: {code}");
} else if pairing.require_pairing() {
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
println!(" To pair a new device: zeroclaw gateway get-paircode --new");
@@ -721,29 +721,29 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
println!();
}
    println!(" POST {pfx}/pair — pair a new client (X-Pairing-Code header)");
    println!(" POST {pfx}/webhook — {{\"message\": \"your prompt\"}}");
if whatsapp_channel.is_some() {
        println!(" GET {pfx}/whatsapp — Meta webhook verification");
        println!(" POST {pfx}/whatsapp — WhatsApp message webhook");
}
if linq_channel.is_some() {
        println!(" POST {pfx}/linq — Linq message webhook (iMessage/RCS/SMS)");
}
if wati_channel.is_some() {
        println!(" GET {pfx}/wati — WATI webhook verification");
        println!(" POST {pfx}/wati — WATI message webhook");
}
if nextcloud_talk_channel.is_some() {
        println!(" POST {pfx}/nextcloud-talk — Nextcloud Talk bot webhook");
}
    println!(" GET {pfx}/api/* — REST API (bearer token required)");
    println!(" GET {pfx}/ws/chat — WebSocket agent chat");
if config.nodes.enabled {
        println!(" GET {pfx}/ws/nodes — WebSocket node discovery");
}
    println!(" GET {pfx}/health — health check");
    println!(" GET {pfx}/metrics — Prometheus metrics");
println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway");
@@ -809,6 +809,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
session_backend,
device_registry,
pending_pairings,
path_prefix: path_prefix.unwrap_or("").to_string(),
};
// Config PUT needs larger body limit (1MB)
@@ -817,7 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.layer(RequestBodyLimitLayer::new(1_048_576));
// Build router with middleware
    let inner = Router::new()
// ── Admin routes (for CLI management) ──
.route("/admin/shutdown", post(handle_admin_shutdown))
.route("/admin/paircode", get(handle_admin_paircode))
@@ -877,12 +878,12 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
// ── Plugin management API (requires plugins-wasm feature) ──
#[cfg(feature = "plugins-wasm")]
    let inner = inner.route(
"/api/plugins",
get(api_plugins::plugin_routes::list_plugins),
);
    let inner = inner
// ── SSE event stream ──
.route("/api/events", get(sse::handle_sse_events))
// ── WebSocket agent chat ──
@@ -893,14 +894,27 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.route("/_app/{*path}", get(static_files::handle_static))
// ── Config PUT with larger body limit ──
.merge(config_put_router)
        .with_state(state)
        .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
        .layer(TimeoutLayer::with_status_code(
            StatusCode::REQUEST_TIMEOUT,
            Duration::from_secs(gateway_request_timeout_secs()),
        ))
        // ── SPA fallback: non-API GET requests serve index.html ──
        .fallback(get(static_files::handle_spa_fallback));
// Nest under path prefix when configured (axum strips prefix before routing).
// nest() at "/prefix" handles both "/prefix" and "/prefix/*" but not "/prefix/"
// with a trailing slash, so we add a fallback redirect for that case.
let app = if let Some(prefix) = path_prefix {
let redirect_target = prefix.to_string();
Router::new().nest(prefix, inner).route(
&format!("{prefix}/"),
get(|| async move { axum::response::Redirect::permanent(&redirect_target) }),
)
} else {
inner
};
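The trailing-slash special case handled above can be illustrated std-only. This is not axum's matcher, just a sketch of the observable behavior the nesting plus redirect route is meant to produce (`route_decision` is an invented name):

```rust
// Illustration of the routing outcome for a nested path prefix:
// "/prefix" and "/prefix/..." hit the nested router, while exactly
// "/prefix/" needs the extra redirect route registered above.
fn route_decision(prefix: &str, path: &str) -> &'static str {
    let with_slash = format!("{prefix}/");
    if path == prefix || path.starts_with(with_slash.as_str()) {
        if path == with_slash {
            "redirect-to-prefix"
        } else {
            "nested-router"
        }
    } else {
        "not-found"
    }
}

fn main() {
    assert_eq!(route_decision("/zeroclaw", "/zeroclaw"), "nested-router");
    assert_eq!(route_decision("/zeroclaw", "/zeroclaw/"), "redirect-to-prefix");
    assert_eq!(route_decision("/zeroclaw", "/zeroclaw/health"), "nested-router");
    assert_eq!(route_decision("/zeroclaw", "/other"), "not-found");
}
```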
// Run the server with graceful shutdown
axum::serve(
@@ -1992,6 +2006,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2047,6 +2062,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2287,6 +2303,8 @@ mod tests {
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -2362,6 +2380,8 @@ mod tests {
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -2427,6 +2447,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2496,6 +2517,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2577,6 +2599,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2630,6 +2653,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2688,6 +2712,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2751,6 +2776,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
@@ -2810,6 +2836,7 @@ mod tests {
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
+33 -5
@@ -3,11 +3,14 @@
//! Uses `rust-embed` to bundle the `web/dist/` directory into the binary at compile time.
use axum::{
extract::State,
http::{header, StatusCode, Uri},
response::{IntoResponse, Response},
};
use rust_embed::Embed;
use super::AppState;
#[derive(Embed)]
#[folder = "web/dist/"]
struct WebAssets;
@@ -23,16 +26,41 @@ pub async fn handle_static(uri: Uri) -> Response {
serve_embedded_file(path)
}
/// SPA fallback: serve index.html for any non-API, non-static GET request
pub async fn handle_spa_fallback() -> Response {
if WebAssets::get("index.html").is_none() {
/// SPA fallback: serve index.html for any non-API, non-static GET request.
/// Injects `window.__ZEROCLAW_BASE__` so the frontend knows the path prefix.
pub async fn handle_spa_fallback(State(state): State<AppState>) -> Response {
let Some(content) = WebAssets::get("index.html") else {
return (
StatusCode::SERVICE_UNAVAILABLE,
"Web dashboard not available. Build it with: cd web && npm ci && npm run build",
)
.into_response();
}
serve_embedded_file("index.html")
};
let html = String::from_utf8_lossy(&content.data);
// Inject path prefix for the SPA and rewrite asset paths in the HTML
let html = if state.path_prefix.is_empty() {
html.into_owned()
} else {
let pfx = &state.path_prefix;
// JSON-encode the prefix to safely embed in a <script> block
let json_pfx = serde_json::to_string(pfx).unwrap_or_else(|_| "\"\"".to_string());
let script = format!("<script>window.__ZEROCLAW_BASE__={json_pfx};</script>");
// Rewrite absolute /_app/ references so the browser requests {prefix}/_app/...
html.replace("/_app/", &format!("{pfx}/_app/"))
.replace("<head>", &format!("<head>{script}"))
};
(
StatusCode::OK,
[
(header::CONTENT_TYPE, "text/html; charset=utf-8".to_string()),
(header::CACHE_CONTROL, "no-cache".to_string()),
],
html,
)
.into_response()
}
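The injection step above can be sketched with plain string operations, assuming the embedded `index.html` contains a literal `<head>` tag and root-absolute `/_app/` asset paths. `inject_base` is a hypothetical name, and this sketch uses `{:?}` formatting in place of the real handler's `serde_json::to_string` for quoting:

```rust
// Rewrite asset paths and inject the base-path script after <head>.
fn inject_base(html: &str, prefix: &str) -> String {
    if prefix.is_empty() {
        return html.to_string();
    }
    // Debug-format the prefix so it is quoted/escaped inside the <script>.
    let script = format!("<script>window.__ZEROCLAW_BASE__={prefix:?};</script>");
    html.replace("/_app/", &format!("{prefix}/_app/"))
        .replace("<head>", &format!("<head>{script}"))
}

fn main() {
    let out = inject_base("<head><link href=\"/_app/a.css\"></head>", "/zc");
    assert!(out.contains("window.__ZEROCLAW_BASE__=\"/zc\""));
    assert!(out.contains("/zc/_app/a.css"));
}
```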
fn serve_embedded_file(path: &str) -> Response {
+1
@@ -841,6 +841,7 @@ mod tests {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
let entries = all_integrations();
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
+47 -7
@@ -325,8 +325,27 @@ impl Memory for LucidMemory {
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let local_results = self.local.recall(query, limit, session_id).await?;
let since_dt = since
.map(chrono::DateTime::parse_from_rfc3339)
.transpose()
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
let until_dt = until
.map(chrono::DateTime::parse_from_rfc3339)
.transpose()
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
if s >= u {
anyhow::bail!("'since' must be before 'until'");
}
}
let local_results = self
.local
.recall(query, limit, session_id, since, until)
.await?;
if limit == 0
|| local_results.len() >= limit
|| local_results.len() >= self.local_hit_threshold
@@ -341,7 +360,28 @@ impl Memory for LucidMemory {
match self.recall_from_lucid(query).await {
Ok(lucid_results) if !lucid_results.is_empty() => {
self.clear_failure();
Ok(Self::merge_results(local_results, lucid_results, limit))
let merged = Self::merge_results(local_results, lucid_results, limit);
let filtered: Vec<MemoryEntry> = merged
.into_iter()
.filter(|e| {
if let Some(ref s) = since_dt {
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
if ts < *s {
return false;
}
}
}
if let Some(ref u) = until_dt {
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
if ts > *u {
return false;
}
}
}
true
})
.collect();
Ok(filtered)
}
Ok(_) => {
self.clear_failure();
@@ -541,7 +581,7 @@ exit 1
.await
.unwrap();
let entries = memory.recall("auth", 5, None).await.unwrap();
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
assert!(entries
.iter()
@@ -565,7 +605,7 @@ exit 1
.await
.unwrap();
let entries = memory.recall("auth", 5, None).await.unwrap();
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
assert!(entries
.iter()
@@ -603,7 +643,7 @@ exit 1
.await
.unwrap();
let entries = memory.recall("rust", 5, None).await.unwrap();
let entries = memory.recall("rust", 5, None, None, None).await.unwrap();
assert!(entries
.iter()
.any(|e| e.content.contains("Rust should stay local-first")));
@@ -663,8 +703,8 @@ exit 1
Duration::from_secs(5),
);
let first = memory.recall("auth", 5, None).await.unwrap();
let second = memory.recall("auth", 5, None).await.unwrap();
let first = memory.recall("auth", 5, None, None, None).await.unwrap();
let second = memory.recall("auth", 5, None, None, None).await.unwrap();
assert!(first.is_empty());
assert!(second.is_empty());
+47 -6
@@ -158,7 +158,23 @@ impl Memory for MarkdownMemory {
query: &str,
limit: usize,
_session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let since_dt = since
.map(chrono::DateTime::parse_from_rfc3339)
.transpose()
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
let until_dt = until
.map(chrono::DateTime::parse_from_rfc3339)
.transpose()
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
if s >= u {
anyhow::bail!("'since' must be before 'until'");
}
}
let all = self.read_all_entries().await?;
let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
@@ -166,6 +182,24 @@ impl Memory for MarkdownMemory {
let mut scored: Vec<MemoryEntry> = all
.into_iter()
.filter_map(|mut entry| {
if let Some(ref s) = since_dt {
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
if ts < *s {
return None;
}
}
}
if let Some(ref u) = until_dt {
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
if ts > *u {
return None;
}
}
}
if keywords.is_empty() {
entry.score = Some(1.0);
return Some(entry);
}
let content_lower = entry.content.to_lowercase();
let matched = keywords
.iter()
@@ -183,9 +217,13 @@ impl Memory for MarkdownMemory {
.collect();
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
if keywords.is_empty() {
b.timestamp.as_str().cmp(a.timestamp.as_str())
} else {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
}
});
scored.truncate(limit);
Ok(scored)
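The tie-break above exists because an empty keyword list gives every entry the same score (1.0), making the score comparison meaningless; the sort then falls back to newest-first by timestamp. A sketch with simplified `(score, timestamp)` tuples:

```rust
// With no keywords, sort newest-first by RFC 3339 timestamp;
// otherwise sort descending by relevance score.
fn sort_entries(entries: &mut Vec<(f32, String)>, keywords_empty: bool) {
    entries.sort_by(|a, b| {
        if keywords_empty {
            b.1.cmp(&a.1)
        } else {
            b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
        }
    });
}

fn main() {
    let mut v = vec![
        (1.0, "2025-01-01T00:00:00Z".to_string()),
        (1.0, "2025-03-01T00:00:00Z".to_string()),
    ];
    sort_entries(&mut v, true);
    assert_eq!(v[0].1, "2025-03-01T00:00:00Z");
}
```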
@@ -283,7 +321,7 @@ mod tests {
.await
.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
assert!(results.len() >= 2);
assert!(results
.iter()
@@ -296,7 +334,10 @@ mod tests {
mem.store("a", "Rust is great", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("javascript", 10, None).await.unwrap();
let results = mem
.recall("javascript", 10, None, None, None)
.await
.unwrap();
assert!(results.is_empty());
}
@@ -343,7 +384,7 @@ mod tests {
#[tokio::test]
async fn markdown_empty_recall() {
let (_tmp, mem) = temp_workspace();
let results = mem.recall("anything", 10, None).await.unwrap();
let results = mem.recall("anything", 10, None, None, None).await.unwrap();
assert!(results.is_empty());
}
+5 -1
@@ -364,14 +364,18 @@ impl Memory for Mem0Memory {
query: &str,
limit: usize,
session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
        // mem0 handles filtering server-side; since/until are not yet
        // supported by the mem0 API, so they are accepted but ignored here.
self.recall_filtered(query, limit, session_id, None, None, None)
.await
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
// mem0 doesn't have a get-by-key API, so we search by key in metadata
let results = self.recall(key, 1, None).await?;
let results = self.recall(key, 1, None, None, None).await?;
Ok(results.into_iter().find(|e| e.key == key))
}
+7 -1
@@ -35,6 +35,8 @@ impl Memory for NoneMemory {
_query: &str,
_limit: usize,
_session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -78,7 +80,11 @@ mod tests {
.unwrap();
assert!(memory.get("k").await.unwrap().is_none());
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
assert!(memory
.recall("k", 10, None, None, None)
.await
.unwrap()
.is_empty());
assert!(memory.list(None, None).await.unwrap().is_empty());
assert!(!memory.forget("k").await.unwrap());
assert_eq!(memory.count().await.unwrap(), 0);
+23 -1
@@ -239,14 +239,30 @@ impl Memory for PostgresMemory {
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> Result<Vec<MemoryEntry>> {
let client = self.client.clone();
let qualified_table = self.qualified_table.clone();
let query = query.trim().to_string();
let sid = session_id.map(str::to_string);
let since_owned = since.map(str::to_string);
let until_owned = until.map(str::to_string);
run_on_os_thread(move || -> Result<Vec<MemoryEntry>> {
let mut client = client.lock();
let since_ref = since_owned.as_deref();
let until_ref = until_owned.as_deref();
let time_filter: String = match (since_ref, until_ref) {
(Some(_), Some(_)) => {
" AND created_at >= $4::TIMESTAMPTZ AND created_at <= $5::TIMESTAMPTZ".into()
}
(Some(_), None) => " AND created_at >= $4::TIMESTAMPTZ".into(),
(None, Some(_)) => " AND created_at <= $4::TIMESTAMPTZ".into(),
(None, None) => String::new(),
};
let stmt = format!(
"
SELECT id, key, content, category, created_at, session_id,
@@ -257,6 +273,7 @@ impl Memory for PostgresMemory {
FROM {qualified_table}
WHERE ($2::TEXT IS NULL OR session_id = $2)
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
{time_filter}
ORDER BY score DESC, updated_at DESC
LIMIT $3
"
@@ -265,7 +282,12 @@ impl Memory for PostgresMemory {
#[allow(clippy::cast_possible_wrap)]
let limit_i64 = limit as i64;
let rows = client.query(&stmt, &[&query, &sid, &limit_i64])?;
let rows = match (since_ref, until_ref) {
(Some(s), Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &s, &u])?,
(Some(s), None) => client.query(&stmt, &[&query, &sid, &limit_i64, &s])?,
(None, Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &u])?,
(None, None) => client.query(&stmt, &[&query, &sid, &limit_i64])?,
};
rows.iter()
.map(Self::row_to_entry)
.collect::<Result<Vec<MemoryEntry>>>()
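The Postgres branch above has to keep two things aligned by hand: the positional placeholders (`$4`, `$5`) in the SQL fragment and the parameter slice passed to `query`. Both are derived from the same `(since, until)` match, which a small sketch makes explicit:

```rust
// $1..$3 are already taken by query, session_id, and limit, so the
// optional time bounds start at $4. Each arm must match the parameter
// list built from the same (since, until) combination.
fn time_filter(since: Option<&str>, until: Option<&str>) -> &'static str {
    match (since, until) {
        (Some(_), Some(_)) => {
            " AND created_at >= $4::TIMESTAMPTZ AND created_at <= $5::TIMESTAMPTZ"
        }
        (Some(_), None) => " AND created_at >= $4::TIMESTAMPTZ",
        (None, Some(_)) => " AND created_at <= $4::TIMESTAMPTZ",
        (None, None) => "",
    }
}

fn main() {
    assert!(time_filter(Some("a"), None).contains("$4"));
    assert!(!time_filter(Some("a"), None).contains("$5"));
    assert_eq!(time_filter(None, None), "");
}
```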
+20 -2
@@ -291,9 +291,19 @@ impl Memory for QdrantMemory {
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> Result<Vec<MemoryEntry>> {
if query.trim().is_empty() {
return self.list(None, session_id).await;
let mut entries = self.list(None, session_id).await?;
if let Some(s) = since {
entries.retain(|e| e.timestamp.as_str() >= s);
}
if let Some(u) = until {
entries.retain(|e| e.timestamp.as_str() <= u);
}
entries.truncate(limit);
return Ok(entries);
}
self.ensure_initialized().await?;
@@ -344,7 +354,7 @@ impl Memory for QdrantMemory {
let result: QdrantSearchResult = resp.json().await?;
let entries = result
let mut entries: Vec<MemoryEntry> = result
.result
.into_iter()
.filter_map(|point| {
@@ -367,6 +377,14 @@ impl Memory for QdrantMemory {
})
.collect();
// Filter by time range if specified
if let Some(s) = since {
entries.retain(|e| e.timestamp.as_str() >= s);
}
if let Some(u) = until {
entries.retain(|e| e.timestamp.as_str() <= u);
}
Ok(entries)
}
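The Qdrant path above filters by comparing RFC 3339 strings directly rather than parsing them. That is valid under the assumption that all stored timestamps share one fixed format and UTC offset (e.g. a trailing `Z`), in which case lexicographic order matches chronological order:

```rust
// Both bounds are inclusive, mirroring created_at >= since / <= until.
fn in_range(ts: &str, since: Option<&str>, until: Option<&str>) -> bool {
    since.map_or(true, |s| ts >= s) && until.map_or(true, |u| ts <= u)
}

fn main() {
    assert!(in_range("2025-03-02T00:00:00Z", Some("2025-03-01T00:00:00Z"), None));
    assert!(!in_range("2025-02-28T00:00:00Z", Some("2025-03-01T00:00:00Z"), None));
    let bound = Some("2025-03-01T00:00:00Z");
    assert!(in_range("2025-03-01T00:00:00Z", bound, bound));
}
```

Mixed offsets (e.g. `+02:00` entries alongside `Z` entries) would break this ordering, which is why the chrono-parsing backends above normalize through `DateTime::parse_from_rfc3339` instead.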
+167 -36
@@ -428,6 +428,74 @@ impl SqliteMemory {
Ok(count)
}
/// List memories by time range (used when query is empty).
async fn recall_by_time_only(
&self,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self.conn.clone();
let sid = session_id.map(String::from);
let since_owned = since.map(String::from);
let until_owned = until.map(String::from);
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
let conn = conn.lock();
let since_ref = since_owned.as_deref();
let until_ref = until_owned.as_deref();
let mut sql =
"SELECT id, key, content, category, created_at, session_id FROM memories \
WHERE 1=1"
.to_string();
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
let mut idx = 1;
if let Some(sid) = sid.as_deref() {
let _ = write!(sql, " AND session_id = ?{idx}");
param_values.push(Box::new(sid.to_string()));
idx += 1;
}
if let Some(s) = since_ref {
let _ = write!(sql, " AND created_at >= ?{idx}");
param_values.push(Box::new(s.to_string()));
idx += 1;
}
if let Some(u) = until_ref {
let _ = write!(sql, " AND created_at <= ?{idx}");
param_values.push(Box::new(u.to_string()));
idx += 1;
}
let _ = write!(sql, " ORDER BY updated_at DESC LIMIT ?{idx}");
#[allow(clippy::cast_possible_wrap)]
param_values.push(Box::new(limit as i64));
let mut stmt = conn.prepare(&sql)?;
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
})
})?;
let mut results = Vec::new();
for row in rows {
results.push(row?);
}
Ok(results)
})
.await?
}
}
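`recall_by_time_only` above builds its SQL with numbered SQLite placeholders (`?1`, `?2`, …), so each optional clause appends its own index and pushes the matching parameter in the same order. A std-only sketch of that index bookkeeping (`build_sql` is a hypothetical helper; the SELECT column list is abbreviated):

```rust
use std::fmt::Write;

// Append one numbered clause per present filter, keeping the placeholder
// index and the parameter vector in lockstep; LIMIT takes the final index.
fn build_sql(
    session: Option<&str>,
    since: Option<&str>,
    until: Option<&str>,
) -> (String, Vec<String>) {
    let mut sql = String::from("SELECT id FROM memories WHERE 1=1");
    let mut params = Vec::new();
    let mut idx = 1;
    for (clause, value) in [
        (" AND session_id = ?", session),
        (" AND created_at >= ?", since),
        (" AND created_at <= ?", until),
    ] {
        if let Some(v) = value {
            let _ = write!(sql, "{clause}{idx}");
            params.push(v.to_string());
            idx += 1;
        }
    }
    let _ = write!(sql, " ORDER BY updated_at DESC LIMIT ?{idx}");
    (sql, params)
}

fn main() {
    let (sql, params) = build_sql(None, Some("2025-03-01T00:00:00Z"), None);
    assert!(sql.contains("created_at >= ?1"));
    assert!(sql.ends_with("LIMIT ?2"));
    assert_eq!(params.len(), 1);
}
```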
#[async_trait]
@@ -481,9 +549,14 @@ impl Memory for SqliteMemory {
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
// Time-only query: list by time range when no keywords
if query.trim().is_empty() {
return Ok(Vec::new());
return self
.recall_by_time_only(limit, session_id, since, until)
.await;
}
// Compute query embedding (async, before blocking work)
@@ -492,12 +565,16 @@ impl Memory for SqliteMemory {
let conn = self.conn.clone();
let query = query.to_string();
let sid = session_id.map(String::from);
let since_owned = since.map(String::from);
let until_owned = until.map(String::from);
let vector_weight = self.vector_weight;
let keyword_weight = self.keyword_weight;
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
let conn = conn.lock();
let session_ref = sid.as_deref();
let since_ref = since_owned.as_deref();
let until_ref = until_owned.as_deref();
// FTS5 BM25 keyword search
let keyword_results = Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default();
@@ -568,6 +645,16 @@ impl Memory for SqliteMemory {
for scored in &merged {
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
if let Some(s) = since_ref {
if ts.as_str() < s {
continue;
}
}
if let Some(u) = until_ref {
if ts.as_str() > u {
continue;
}
}
let entry = MemoryEntry {
id: scored.id.clone(),
key,
@@ -588,8 +675,6 @@ impl Memory for SqliteMemory {
}
// If hybrid returned nothing, fall back to LIKE search.
// Cap keyword count so we don't create too many SQL shapes,
// which helps prepared-statement cache efficiency.
if results.is_empty() {
const MAX_LIKE_KEYWORDS: usize = 8;
let keywords: Vec<String> = query
@@ -606,12 +691,21 @@ impl Memory for SqliteMemory {
})
.collect();
let where_clause = conditions.join(" OR ");
let mut param_idx = keywords.len() * 2 + 1;
let mut time_conditions = String::new();
if since_ref.is_some() {
let _ = write!(time_conditions, " AND created_at >= ?{param_idx}");
param_idx += 1;
}
if until_ref.is_some() {
let _ = write!(time_conditions, " AND created_at <= ?{param_idx}");
param_idx += 1;
}
let sql = format!(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause}
WHERE {where_clause}{time_conditions}
ORDER BY updated_at DESC
LIMIT ?{}",
keywords.len() * 2 + 1
LIMIT ?{param_idx}"
);
let mut stmt = conn.prepare(&sql)?;
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
@@ -619,6 +713,12 @@ impl Memory for SqliteMemory {
param_values.push(Box::new(kw.clone()));
param_values.push(Box::new(kw.clone()));
}
if let Some(s) = since_ref {
param_values.push(Box::new(s.to_string()));
}
if let Some(u) = until_ref {
param_values.push(Box::new(u.to_string()));
}
#[allow(clippy::cast_possible_wrap)]
param_values.push(Box::new(limit as i64));
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
@@ -852,7 +952,7 @@ mod tests {
.await
.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
@@ -869,7 +969,7 @@ mod tests {
.await
.unwrap();
let results = mem.recall("fast safe", 10, None).await.unwrap();
let results = mem.recall("fast safe", 10, None, None, None).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
@@ -881,7 +981,10 @@ mod tests {
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("javascript", 10, None).await.unwrap();
let results = mem
.recall("javascript", 10, None, None, None)
.await
.unwrap();
assert!(results.is_empty());
}
@@ -1024,7 +1127,7 @@ mod tests {
.await
.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
assert!(results.len() >= 2);
// All results should contain "Rust"
for r in &results {
@@ -1049,30 +1152,34 @@ mod tests {
.await
.unwrap();
let results = mem.recall("quick dog", 10, None).await.unwrap();
let results = mem.recall("quick dog", 10, None, None, None).await.unwrap();
assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick"));
}
#[tokio::test]
async fn recall_empty_query_returns_empty() {
async fn recall_empty_query_returns_recent_entries() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("", 10, None).await.unwrap();
assert!(results.is_empty());
// Empty query = time-only mode: returns recent entries
let results = mem.recall("", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "a");
}
#[tokio::test]
async fn recall_whitespace_query_returns_empty() {
async fn recall_whitespace_query_returns_recent_entries() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall(" ", 10, None).await.unwrap();
assert!(results.is_empty());
// Whitespace-only query = time-only mode: returns recent entries
let results = mem.recall(" ", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "a");
}
// ── Embedding cache tests ────────────────────────────────────
@@ -1283,7 +1390,7 @@ mod tests {
assert_eq!(count, 0);
// FTS should still work after rebuild
let results = mem.recall("reindex", 10, None).await.unwrap();
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 2);
}
@@ -1303,7 +1410,10 @@ mod tests {
.unwrap();
}
let results = mem.recall("common keyword", 5, None).await.unwrap();
let results = mem
.recall("common keyword", 5, None, None, None)
.await
.unwrap();
assert!(results.len() <= 5);
}
@@ -1316,7 +1426,7 @@ mod tests {
.await
.unwrap();
let results = mem.recall("scored", 10, None).await.unwrap();
let results = mem.recall("scored", 10, None, None, None).await.unwrap();
assert!(!results.is_empty());
for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
@@ -1332,7 +1442,7 @@ mod tests {
.await
.unwrap();
// Quotes in query should not crash FTS5
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
let results = mem.recall("\"hello\"", 10, None, None, None).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10);
}
@@ -1343,7 +1453,7 @@ mod tests {
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("wild*", 10, None).await.unwrap();
let results = mem.recall("wild*", 10, None, None, None).await.unwrap();
assert!(results.len() <= 10);
}
@@ -1353,7 +1463,10 @@ mod tests {
mem.store("p1", "function call test", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("function()", 10, None).await.unwrap();
let results = mem
.recall("function()", 10, None, None, None)
.await
.unwrap();
assert!(results.len() <= 10);
}
@@ -1365,7 +1478,7 @@ mod tests {
.unwrap();
// Should not crash or leak data
let results = mem
.recall("'; DROP TABLE memories; --", 10, None)
.recall("'; DROP TABLE memories; --", 10, None, None, None)
.await
.unwrap();
assert!(results.len() <= 10);
@@ -1441,7 +1554,7 @@ mod tests {
.await
.unwrap();
// Single char may not match FTS5 but LIKE fallback should work
let results = mem.recall("x", 10, None).await.unwrap();
let results = mem.recall("x", 10, None, None, None).await.unwrap();
// Should not crash; may or may not find results
assert!(results.len() <= 10);
}
@@ -1452,7 +1565,7 @@ mod tests {
mem.store("a", "some content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("some", 0, None).await.unwrap();
let results = mem.recall("some", 0, None, None, None).await.unwrap();
assert!(results.is_empty());
}
@@ -1465,7 +1578,10 @@ mod tests {
mem.store("b", "matching content beta", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("matching content", 1, None).await.unwrap();
let results = mem
.recall("matching content", 1, None, None, None)
.await
.unwrap();
assert_eq!(results.len(), 1);
}
@@ -1481,7 +1597,7 @@ mod tests {
.await
.unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too
let results = mem.recall("rust", 10, None).await.unwrap();
let results = mem.recall("rust", 10, None, None, None).await.unwrap();
assert!(!results.is_empty(), "Should match by key");
}
@@ -1491,7 +1607,7 @@ mod tests {
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("日本語", 10, None).await.unwrap();
let results = mem.recall("日本語", 10, None, None, None).await.unwrap();
assert!(!results.is_empty());
}
@@ -1541,7 +1657,10 @@ mod tests {
.await
.unwrap();
mem.forget("ghost").await.unwrap();
let results = mem.recall("phantom memory", 10, None).await.unwrap();
let results = mem
.recall("phantom memory", 10, None, None, None)
.await
.unwrap();
assert!(
results.is_empty(),
"Deleted memory should not appear in recall"
@@ -1582,7 +1701,7 @@ mod tests {
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact
let results = mem.recall("reindex", 10, None).await.unwrap();
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 1);
}
@@ -1686,7 +1805,10 @@ mod tests {
.unwrap();
// Recall with session-a filter returns only session-a entry
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
let results = mem
.recall("fact", 10, Some("sess-a"), None, None)
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
@@ -1706,7 +1828,7 @@ mod tests {
.unwrap();
// Recall without session filter returns all matching entries
let results = mem.recall("fact", 10, None).await.unwrap();
let results = mem.recall("fact", 10, None, None, None).await.unwrap();
assert_eq!(results.len(), 3);
}
@@ -1723,11 +1845,17 @@ mod tests {
.unwrap();
// Session B cannot see session A data
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
let results = mem
.recall("secret", 10, Some("sess-b"), None, None)
.await
.unwrap();
assert!(results.is_empty());
// Session A can see its own data
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
let results = mem
.recall("secret", 10, Some("sess-a"), None, None)
.await
.unwrap();
assert_eq!(results.len(), 1);
}
@@ -1778,7 +1906,10 @@ mod tests {
// Second open: migration runs again but is idempotent
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
let results = mem
.recall("reopen", 10, Some("sess-x"), None, None)
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
+4
@@ -96,11 +96,15 @@ pub trait Memory: Send + Sync {
) -> anyhow::Result<()>;
/// Recall memories matching a query (keyword search), optionally scoped to a session
/// and time range. Time bounds use RFC 3339 / ISO 8601 format
/// (e.g. "2025-03-01T00:00:00Z"); both bounds are inclusive (created_at >= since, created_at <= until).
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>>;
/// Get a specific memory by key
+19
@@ -154,6 +154,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
agent: crate::config::schema::AgentConfig::default(),
pacing: crate::config::PacingConfig::default(),
skills: crate::config::SkillsConfig::default(),
model_routes: Vec::new(),
embedding_routes: Vec::new(),
@@ -199,6 +200,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
plugins: crate::config::PluginsConfig::default(),
locale: None,
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
claude_code: crate::config::ClaudeCodeConfig::default(),
};
println!(
@@ -575,6 +577,7 @@ async fn run_quick_setup_with_home(
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
agent: crate::config::schema::AgentConfig::default(),
pacing: crate::config::PacingConfig::default(),
skills: crate::config::SkillsConfig::default(),
model_routes: Vec::new(),
embedding_routes: Vec::new(),
@@ -620,6 +623,7 @@ async fn run_quick_setup_with_home(
plugins: crate::config::PluginsConfig::default(),
locale: None,
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
claude_code: crate::config::ClaudeCodeConfig::default(),
};
config.save().await?;
@@ -3790,6 +3794,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
interrupt_on_new_message: false,
mention_only: false,
ack_reactions: None,
proxy_url: None,
});
}
ChannelMenuChoice::Discord => {
@@ -3890,6 +3895,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
listen_to_bots: false,
interrupt_on_new_message: false,
mention_only: false,
proxy_url: None,
});
}
ChannelMenuChoice::Slack => {
@@ -4020,6 +4026,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
interrupt_on_new_message: false,
thread_replies: None,
mention_only: false,
proxy_url: None,
});
}
ChannelMenuChoice::IMessage => {
@@ -4271,6 +4278,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
allowed_from,
ignore_attachments,
ignore_stories,
proxy_url: None,
});
println!(" {} Signal configured", style("").green().bold());
@@ -4372,6 +4380,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
});
println!(
@@ -4477,6 +4486,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
dm_policy: WhatsAppChatPolicy::default(),
group_policy: WhatsAppChatPolicy::default(),
self_chat_mode: false,
proxy_url: None,
});
}
ChannelMenuChoice::Linq => {
@@ -4810,6 +4820,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
Some(webhook_secret.trim().to_string())
},
allowed_users,
proxy_url: None,
});
println!(" {} Nextcloud Talk configured", style("").green().bold());
@@ -4882,6 +4893,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
client_id,
client_secret,
allowed_users,
proxy_url: None,
});
}
ChannelMenuChoice::QqOfficial => {
@@ -4958,6 +4970,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
app_id,
app_secret,
allowed_users,
proxy_url: None,
});
}
ChannelMenuChoice::Lark | ChannelMenuChoice::Feishu => {
@@ -5147,6 +5160,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
use_feishu: is_feishu,
receive_mode,
port,
proxy_url: None,
});
}
#[cfg(feature = "channel-nostr")]
@@ -7511,6 +7525,7 @@ mod tests {
allowed_from: vec!["*".into()],
ignore_attachments: false,
ignore_stories: true,
proxy_url: None,
});
assert!(has_launchable_channels(&channels));
@@ -7523,6 +7538,7 @@ mod tests {
thread_replies: Some(true),
mention_only: Some(false),
interrupt_on_new_message: false,
proxy_url: None,
});
assert!(has_launchable_channels(&channels));
@@ -7531,6 +7547,7 @@ mod tests {
app_id: "app-id".into(),
app_secret: "app-secret".into(),
allowed_users: vec!["*".into()],
proxy_url: None,
});
assert!(has_launchable_channels(&channels));
@@ -7540,6 +7557,7 @@ mod tests {
app_token: "token".into(),
webhook_secret: Some("secret".into()),
allowed_users: vec!["*".into()],
proxy_url: None,
});
assert!(has_launchable_channels(&channels));
@@ -7552,6 +7570,7 @@ mod tests {
allowed_users: vec!["*".into()],
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
proxy_url: None,
});
assert!(has_launchable_channels(&channels));
}
+1 -1
@@ -146,7 +146,7 @@ fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill
load_skills_from_directory(&skills_dir, allow_scripts)
}
fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
pub fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
if !skills_dir.exists() {
return Vec::new();
}
+446
@@ -0,0 +1,446 @@
use super::traits::{Tool, ToolResult};
use crate::config::ClaudeCodeConfig;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
/// Environment variables safe to pass through to the `claude` subprocess.
const SAFE_ENV_VARS: &[&str] = &[
"PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR",
];
/// Delegates coding tasks to the Claude Code CLI (`claude -p`).
///
/// This creates a two-tier agent architecture: ZeroClaw orchestrates high-level
/// tasks and delegates complex coding work to Claude Code, which has its own
/// agent loop with Read/Edit/Bash tools.
///
/// Authentication uses the `claude` binary's own OAuth session (Max subscription)
/// by default. No API key is needed unless `env_passthrough` includes
/// `ANTHROPIC_API_KEY` for API-key billing.
pub struct ClaudeCodeTool {
security: Arc<SecurityPolicy>,
config: ClaudeCodeConfig,
}
impl ClaudeCodeTool {
pub fn new(security: Arc<SecurityPolicy>, config: ClaudeCodeConfig) -> Self {
Self { security, config }
}
}
#[async_trait]
impl Tool for ClaudeCodeTool {
fn name(&self) -> &str {
"claude_code"
}
fn description(&self) -> &str {
"Delegate a coding task to Claude Code (claude -p). Supports file editing, bash execution, structured output, and multi-turn sessions. Use for complex coding work that benefits from Claude Code's full agent loop."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The coding task to delegate to Claude Code"
},
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "Override the default tool allowlist (e.g. [\"Read\", \"Edit\", \"Bash\", \"Write\"])"
},
"system_prompt": {
"type": "string",
"description": "Override or append a system prompt for this invocation"
},
"session_id": {
"type": "string",
"description": "Resume a previous Claude Code session by its ID"
},
"json_schema": {
"type": "object",
"description": "Request structured output conforming to this JSON Schema"
},
"working_directory": {
"type": "string",
"description": "Working directory within the workspace (must be inside workspace_dir)"
}
},
"required": ["prompt"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
// Rate limit check
if self.security.is_rate_limited() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
});
}
// Enforce act policy
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "claude_code")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
// Extract prompt (required)
let prompt = args
.get("prompt")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
// Extract optional params
let allowed_tools: Vec<String> = args
.get("allowed_tools")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_else(|| self.config.allowed_tools.clone());
let system_prompt = args
.get("system_prompt")
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| self.config.system_prompt.clone());
let session_id = args.get("session_id").and_then(|v| v.as_str());
let json_schema = args.get("json_schema").filter(|v| v.is_object());
// Validate working directory — require both paths to exist (reject
// non-existent paths instead of falling back to the raw value, which
// could bypass the workspace containment check via symlinks or
// specially-crafted path components).
let work_dir = if let Some(wd) = args.get("working_directory").and_then(|v| v.as_str()) {
let wd_path = std::path::PathBuf::from(wd);
let workspace = &self.security.workspace_dir;
let canonical_wd = match wd_path.canonicalize() {
Ok(p) => p,
Err(_) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"working_directory '{}' does not exist or is not accessible",
wd
)),
});
}
};
let canonical_ws = match workspace.canonicalize() {
Ok(p) => p,
Err(_) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"workspace directory '{}' does not exist or is not accessible",
workspace.display()
)),
});
}
};
if !canonical_wd.starts_with(&canonical_ws) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"working_directory '{}' is outside the workspace '{}'",
wd,
workspace.display()
)),
});
}
canonical_wd
} else {
self.security.workspace_dir.clone()
};
// Record action budget
if !self.security.record_action() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Rate limit exceeded: action budget exhausted".into()),
});
}
// Build CLI command
let mut cmd = Command::new("claude");
cmd.arg("-p").arg(prompt);
cmd.arg("--output-format").arg("json");
if !allowed_tools.is_empty() {
for tool in &allowed_tools {
cmd.arg("--allowedTools").arg(tool);
}
}
if let Some(ref sp) = system_prompt {
cmd.arg("--append-system-prompt").arg(sp);
}
if let Some(sid) = session_id {
cmd.arg("--resume").arg(sid);
}
if let Some(schema) = json_schema {
let schema_str = serde_json::to_string(schema).unwrap_or_else(|_| "{}".to_string());
cmd.arg("--json-schema").arg(schema_str);
}
// Environment: clear everything, pass only safe vars + configured passthrough.
// HOME is critical so `claude` finds its OAuth session in ~/.claude/
cmd.env_clear();
for var in SAFE_ENV_VARS {
if let Ok(val) = std::env::var(var) {
cmd.env(var, val);
}
}
for var in &self.config.env_passthrough {
let trimmed = var.trim();
if !trimmed.is_empty() {
if let Ok(val) = std::env::var(trimmed) {
cmd.env(trimmed, val);
}
}
}
cmd.current_dir(&work_dir);
// Execute with timeout — use kill_on_drop(true) so the child process
// is automatically killed when the future is dropped on timeout,
// preventing zombie processes.
let timeout = Duration::from_secs(self.config.timeout_secs);
cmd.kill_on_drop(true);
let result = tokio::time::timeout(timeout, cmd.output()).await;
match result {
Ok(Ok(output)) => {
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
// Truncate to max_output_bytes with char-boundary safety
if stdout.len() > self.config.max_output_bytes {
let mut b = self.config.max_output_bytes.min(stdout.len());
while b > 0 && !stdout.is_char_boundary(b) {
b -= 1;
}
stdout.truncate(b);
stdout.push_str("\n... [output truncated]");
}
// Try to parse JSON response and extract result + session_id
if let Ok(json_resp) = serde_json::from_str::<serde_json::Value>(&stdout) {
let result_text = json_resp
.get("result")
.and_then(|v| v.as_str())
.unwrap_or("");
let resp_session_id = json_resp
.get("session_id")
.and_then(|v| v.as_str())
.unwrap_or("");
let mut formatted = String::new();
if result_text.is_empty() {
// Fall back to full JSON if no "result" key
formatted.push_str(&stdout);
} else {
formatted.push_str(result_text);
}
if !resp_session_id.is_empty() {
use std::fmt::Write;
let _ = write!(formatted, "\n\n[session_id: {}]", resp_session_id);
}
Ok(ToolResult {
success: output.status.success(),
output: formatted,
error: if stderr.is_empty() {
None
} else {
Some(stderr)
},
})
} else {
// JSON parse failed — return raw stdout (defensive)
Ok(ToolResult {
success: output.status.success(),
output: stdout,
error: if stderr.is_empty() {
None
} else {
Some(stderr)
},
})
}
}
Ok(Err(e)) => {
let err_msg = e.to_string();
let msg = if err_msg.contains("No such file or directory")
|| err_msg.contains("not found")
|| err_msg.contains("cannot find")
{
"Claude Code CLI ('claude') not found in PATH. Install with: npm install -g @anthropic-ai/claude-code".into()
} else {
format!("Failed to execute claude: {e}")
};
Ok(ToolResult {
success: false,
output: String::new(),
error: Some(msg),
})
}
Err(_) => {
// Timeout — kill_on_drop(true) ensures the child is killed
// when the future is dropped.
Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Claude Code timed out after {}s and was killed",
self.config.timeout_secs
)),
})
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ClaudeCodeConfig;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_config() -> ClaudeCodeConfig {
ClaudeCodeConfig::default()
}
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
})
}
#[test]
fn claude_code_tool_name() {
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
assert_eq!(tool.name(), "claude_code");
}
#[test]
fn claude_code_tool_schema_has_prompt() {
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
let schema = tool.parameters_schema();
assert!(schema["properties"]["prompt"].is_object());
assert!(schema["required"]
.as_array()
.expect("schema required should be an array")
.contains(&json!("prompt")));
// Optional params exist in properties
assert!(schema["properties"]["allowed_tools"].is_object());
assert!(schema["properties"]["system_prompt"].is_object());
assert!(schema["properties"]["session_id"].is_object());
assert!(schema["properties"]["json_schema"].is_object());
assert!(schema["properties"]["working_directory"].is_object());
}
#[tokio::test]
async fn claude_code_blocks_rate_limited() {
let security = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
max_actions_per_hour: 0,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
});
let tool = ClaudeCodeTool::new(security, test_config());
let result = tool
.execute(json!({"prompt": "hello"}))
.await
.expect("rate-limited should return a result");
assert!(!result.success);
assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
}
#[tokio::test]
async fn claude_code_blocks_readonly() {
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::ReadOnly), test_config());
let result = tool
.execute(json!({"prompt": "hello"}))
.await
.expect("readonly should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("read-only mode"));
}
#[tokio::test]
async fn claude_code_missing_prompt_param() {
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
let result = tool.execute(json!({})).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("prompt"));
}
#[tokio::test]
async fn claude_code_rejects_path_outside_workspace() {
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Full), test_config());
let result = tool
.execute(json!({
"prompt": "hello",
"working_directory": "/etc"
}))
.await
.expect("should return a result for path validation");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("outside the workspace"));
}
#[test]
fn claude_code_env_passthrough_defaults() {
let config = ClaudeCodeConfig::default();
assert!(
config.env_passthrough.is_empty(),
"env_passthrough should default to empty (Max subscription needs no API key)"
);
}
#[test]
fn claude_code_default_config_values() {
let config = ClaudeCodeConfig::default();
assert!(!config.enabled);
assert_eq!(config.timeout_secs, 600);
assert_eq!(config.max_output_bytes, 2_097_152);
assert!(config.system_prompt.is_none());
assert_eq!(config.allowed_tools, vec!["Read", "Edit", "Bash", "Write"]);
}
}
+344 -2
@@ -1,5 +1,6 @@
use super::traits::{Tool, ToolResult};
use crate::agent::loop_::run_tool_call_loop;
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
use crate::config::{DelegateAgentConfig, DelegateToolConfig};
use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric};
use crate::providers::{self, ChatMessage, Provider};
@@ -9,6 +10,7 @@ use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::json;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
@@ -31,6 +33,8 @@ pub struct DelegateTool {
multimodal_config: crate::config::MultimodalConfig,
/// Global delegate tool config providing default timeout values.
delegate_config: DelegateToolConfig,
/// Workspace directory inherited from the root agent context.
workspace_dir: PathBuf,
}
impl DelegateTool {
@@ -62,6 +66,7 @@ impl DelegateTool {
parent_tools: Arc::new(RwLock::new(Vec::new())),
multimodal_config: crate::config::MultimodalConfig::default(),
delegate_config: DelegateToolConfig::default(),
workspace_dir: PathBuf::new(),
}
}
@@ -99,6 +104,7 @@ impl DelegateTool {
parent_tools: Arc::new(RwLock::new(Vec::new())),
multimodal_config: crate::config::MultimodalConfig::default(),
delegate_config: DelegateToolConfig::default(),
workspace_dir: PathBuf::new(),
}
}
@@ -125,6 +131,12 @@ impl DelegateTool {
pub fn parent_tools_handle(&self) -> Arc<RwLock<Vec<Arc<dyn Tool>>>> {
Arc::clone(&self.parent_tools)
}
/// Attach the workspace directory for system prompt enrichment.
pub fn with_workspace_dir(mut self, workspace_dir: PathBuf) -> Self {
self.workspace_dir = workspace_dir;
self
}
}
#[async_trait]
@@ -300,6 +312,11 @@ impl Tool for DelegateTool {
.await;
}
// Build enriched system prompt for non-agentic sub-agent.
let enriched_system_prompt =
self.build_enriched_system_prompt(agent_config, &[], &self.workspace_dir);
let system_prompt_ref = enriched_system_prompt.as_deref();
// Wrap the provider call in a timeout to prevent indefinite blocking
let timeout_secs = agent_config
.timeout_secs
@@ -307,7 +324,7 @@ impl Tool for DelegateTool {
let result = tokio::time::timeout(
Duration::from_secs(timeout_secs),
provider.chat_with_system(
-agent_config.system_prompt.as_deref(),
+system_prompt_ref,
&full_prompt,
&agent_config.model,
temperature,
@@ -355,6 +372,80 @@ impl Tool for DelegateTool {
}
impl DelegateTool {
/// Build an enriched system prompt for a sub-agent by composing structured
/// operational sections (tools, skills, workspace, datetime, shell policy)
/// with the operator-configured `system_prompt` string.
fn build_enriched_system_prompt(
&self,
agent_config: &DelegateAgentConfig,
sub_tools: &[Box<dyn Tool>],
workspace_dir: &Path,
) -> Option<String> {
// Resolve skills directory: scoped if configured, otherwise workspace default.
let skills_dir = agent_config
.skills_directory
.as_ref()
.filter(|s| !s.trim().is_empty())
.map(|dir| workspace_dir.join(dir))
.unwrap_or_else(|| crate::skills::skills_dir(workspace_dir));
let skills = crate::skills::load_skills_from_directory(&skills_dir, false);
// Determine shell policy instructions when the `shell` tool is in the
// effective tool list.
let has_shell = sub_tools.iter().any(|t| t.name() == "shell");
let shell_policy = if has_shell {
"## Shell Policy\n\n\
- Prefer non-destructive commands. Use `trash` over `rm` where possible.\n\
- Do not run commands that exfiltrate data or modify system-critical paths.\n\
- Avoid interactive commands that block on stdin.\n\
- Quote paths that may contain spaces."
.to_string()
} else {
String::new()
};
// Build structured operational context using SystemPromptBuilder sections.
let ctx = PromptContext {
workspace_dir,
model_name: &agent_config.model,
tools: sub_tools,
skills: &skills,
skills_prompt_mode: crate::config::SkillsPromptInjectionMode::Full,
identity_config: None,
dispatcher_instructions: "",
tool_descriptions: None,
security_summary: None,
autonomy_level: crate::security::AutonomyLevel::default(),
};
let builder = SystemPromptBuilder::default()
.add_section(Box::new(crate::agent::prompt::ToolsSection))
.add_section(Box::new(crate::agent::prompt::SafetySection))
.add_section(Box::new(crate::agent::prompt::SkillsSection))
.add_section(Box::new(crate::agent::prompt::WorkspaceSection))
.add_section(Box::new(crate::agent::prompt::DateTimeSection));
let mut enriched = builder.build(&ctx).unwrap_or_default();
if !shell_policy.is_empty() {
enriched.push_str(&shell_policy);
enriched.push_str("\n\n");
}
// Append the operator-configured system_prompt as the identity/role block.
if let Some(operator_prompt) = agent_config.system_prompt.as_ref() {
enriched.push_str(operator_prompt);
enriched.push('\n');
}
let trimmed = enriched.trim().to_string();
if trimmed.is_empty() {
None
} else {
Some(trimmed)
}
}
async fn execute_agentic(
&self,
agent_name: &str,
@@ -401,8 +492,12 @@ impl DelegateTool {
});
}
// Build enriched system prompt with tools, skills, workspace, datetime context.
let enriched_system_prompt =
self.build_enriched_system_prompt(agent_config, &sub_tools, &self.workspace_dir);
let mut history = Vec::new();
-if let Some(system_prompt) = agent_config.system_prompt.as_ref() {
+if let Some(system_prompt) = enriched_system_prompt.as_ref() {
history.push(ChatMessage::system(system_prompt.clone()));
}
history.push(ChatMessage::user(full_prompt.to_string()));
@@ -435,6 +530,7 @@ impl DelegateTool {
&[],
None,
None,
&crate::config::PacingConfig::default(),
),
)
.await;
@@ -548,6 +644,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
agents.insert(
@@ -564,6 +661,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
agents
@@ -719,6 +817,7 @@ mod tests {
max_iterations,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
}
}
@@ -829,6 +928,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
let tool = DelegateTool::new(agents, None, test_security());
@@ -937,6 +1037,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
let tool = DelegateTool::new(agents, None, test_security());
@@ -974,6 +1075,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
let tool = DelegateTool::new(agents, None, test_security());
@@ -1235,6 +1337,113 @@ mod tests {
);
}
#[test]
fn enriched_prompt_includes_tools_workspace_datetime() {
let config = DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "test-model".to_string(),
system_prompt: Some("You are a code reviewer.".to_string()),
api_key: None,
temperature: None,
max_depth: 3,
agentic: true,
allowed_tools: vec!["echo_tool".to_string()],
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
};
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
let workspace = std::env::temp_dir().join(format!(
"zeroclaw_delegate_enrich_test_{}",
uuid::Uuid::new_v4()
));
std::fs::create_dir_all(&workspace).unwrap();
let tool = DelegateTool::new(HashMap::new(), None, test_security())
.with_workspace_dir(workspace.clone());
let prompt = tool
.build_enriched_system_prompt(&config, &tools, &workspace)
.unwrap();
assert!(prompt.contains("## Tools"), "should contain tools section");
assert!(prompt.contains("echo_tool"), "should list allowed tools");
assert!(
prompt.contains("## Workspace"),
"should contain workspace section"
);
assert!(
prompt.contains(&workspace.display().to_string()),
"should contain workspace path"
);
assert!(
prompt.contains("## Current Date & Time"),
"should contain datetime section"
);
assert!(
prompt.contains("You are a code reviewer."),
"should append operator system_prompt"
);
let _ = std::fs::remove_dir_all(workspace);
}
#[test]
fn enriched_prompt_includes_shell_policy_when_shell_present() {
let config = DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "test-model".to_string(),
system_prompt: None,
api_key: None,
temperature: None,
max_depth: 3,
agentic: true,
allowed_tools: vec!["shell".to_string()],
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
};
struct MockShellTool;
#[async_trait]
impl Tool for MockShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> &str {
"Execute shell commands"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({"type": "object"})
}
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
Ok(ToolResult {
success: true,
output: String::new(),
error: None,
})
}
}
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockShellTool)];
let workspace = std::env::temp_dir();
let tool = DelegateTool::new(HashMap::new(), None, test_security())
.with_workspace_dir(workspace.to_path_buf());
let prompt = tool
.build_enriched_system_prompt(&config, &tools, &workspace)
.unwrap();
assert!(
prompt.contains("## Shell Policy"),
"should contain shell policy when shell tool is present"
);
}
#[test]
fn parent_tools_handle_returns_shared_reference() {
let tool = DelegateTool::new(HashMap::new(), None, test_security()).with_parent_tools(
@@ -1265,6 +1474,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
};
assert_eq!(
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
@@ -1278,6 +1488,39 @@ mod tests {
);
}
#[test]
fn enriched_prompt_omits_shell_policy_without_shell_tool() {
let config = DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "test-model".to_string(),
system_prompt: None,
api_key: None,
temperature: None,
max_depth: 3,
agentic: true,
allowed_tools: vec!["echo_tool".to_string()],
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
};
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
let workspace = std::env::temp_dir();
let tool = DelegateTool::new(HashMap::new(), None, test_security())
.with_workspace_dir(workspace.to_path_buf());
let prompt = tool
.build_enriched_system_prompt(&config, &tools, &workspace)
.unwrap();
assert!(
!prompt.contains("## Shell Policy"),
"should not contain shell policy when shell tool is absent"
);
}
#[test]
fn custom_timeout_values_are_respected() {
let config = DelegateAgentConfig {
@@ -1292,6 +1535,7 @@ mod tests {
max_iterations: 10,
timeout_secs: Some(60),
agentic_timeout_secs: Some(600),
skills_directory: None,
};
assert_eq!(
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
@@ -1346,6 +1590,7 @@ mod tests {
max_iterations: 10,
timeout_secs: Some(0),
agentic_timeout_secs: None,
skills_directory: None,
},
);
let err = config.validate().unwrap_err();
@@ -1372,6 +1617,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: Some(0),
skills_directory: None,
},
);
let err = config.validate().unwrap_err();
@@ -1398,6 +1644,7 @@ mod tests {
max_iterations: 10,
timeout_secs: Some(7200),
agentic_timeout_secs: None,
skills_directory: None,
},
);
let err = config.validate().unwrap_err();
@@ -1424,6 +1671,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: Some(5000),
skills_directory: None,
},
);
let err = config.validate().unwrap_err();
@@ -1450,6 +1698,7 @@ mod tests {
max_iterations: 10,
timeout_secs: Some(3600),
agentic_timeout_secs: Some(3600),
skills_directory: None,
},
);
assert!(config.validate().is_ok());
@@ -1472,8 +1721,101 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
assert!(config.validate().is_ok());
}
#[test]
fn enriched_prompt_loads_skills_from_scoped_directory() {
let workspace = std::env::temp_dir().join(format!(
"zeroclaw_delegate_skills_test_{}",
uuid::Uuid::new_v4()
));
let scoped_skills_dir = workspace.join("skills/code-review");
std::fs::create_dir_all(scoped_skills_dir.join("lint-check")).unwrap();
std::fs::write(
scoped_skills_dir.join("lint-check/SKILL.toml"),
"[skill]\nname = \"lint-check\"\ndescription = \"Run lint checks\"\nversion = \"1.0.0\"\n",
)
.unwrap();
let config = DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "test-model".to_string(),
system_prompt: None,
api_key: None,
temperature: None,
max_depth: 3,
agentic: true,
allowed_tools: vec!["echo_tool".to_string()],
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: Some("skills/code-review".to_string()),
};
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
let tool = DelegateTool::new(HashMap::new(), None, test_security())
.with_workspace_dir(workspace.clone());
let prompt = tool
.build_enriched_system_prompt(&config, &tools, &workspace)
.unwrap();
assert!(
prompt.contains("lint-check"),
"should contain skills from scoped directory"
);
let _ = std::fs::remove_dir_all(workspace);
}
#[test]
fn enriched_prompt_falls_back_to_default_skills_dir() {
let workspace = std::env::temp_dir().join(format!(
"zeroclaw_delegate_fallback_test_{}",
uuid::Uuid::new_v4()
));
let default_skills_dir = workspace.join("skills");
std::fs::create_dir_all(default_skills_dir.join("deploy")).unwrap();
std::fs::write(
default_skills_dir.join("deploy/SKILL.toml"),
"[skill]\nname = \"deploy\"\ndescription = \"Deploy safely\"\nversion = \"1.0.0\"\n",
)
.unwrap();
let config = DelegateAgentConfig {
provider: "openrouter".to_string(),
model: "test-model".to_string(),
system_prompt: None,
api_key: None,
temperature: None,
max_depth: 3,
agentic: true,
allowed_tools: vec!["echo_tool".to_string()],
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
};
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
let tool = DelegateTool::new(HashMap::new(), None, test_security())
.with_workspace_dir(workspace.clone());
let prompt = tool
.build_enriched_system_prompt(&config, &tools, &workspace)
.unwrap();
assert!(
prompt.contains("deploy"),
"should contain skills from default workspace skills/ directory"
);
let _ = std::fs::remove_dir_all(workspace);
}
}
+85 -13
@@ -23,7 +23,7 @@ impl Tool for MemoryRecallTool {
}
fn description(&self) -> &str {
-"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
+"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance. Supports keyword search, time-only query (since/until), or both."
}
fn parameters_schema(&self) -> serde_json::Value {
@@ -32,22 +32,76 @@ impl Tool for MemoryRecallTool {
"properties": {
"query": {
"type": "string",
-"description": "Keywords or phrase to search for in memory"
+"description": "Keywords or phrase to search for in memory (optional if since/until provided)"
},
"limit": {
"type": "integer",
"description": "Max results to return (default: 5)"
},
"since": {
"type": "string",
"description": "Filter memories created at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
},
"until": {
"type": "string",
"description": "Filter memories created at or before this time (RFC 3339)"
}
},
-"required": ["query"]
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
-let query = args
-    .get("query")
-    .and_then(|v| v.as_str())
-    .ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
+let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
let since = args.get("since").and_then(|v| v.as_str());
let until = args.get("until").and_then(|v| v.as_str());
if query.trim().is_empty() && since.is_none() && until.is_none() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
),
});
}
// Validate date strings
if let Some(s) = since {
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid 'since' date: {s}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
)),
});
}
}
if let Some(u) = until {
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid 'until' date: {u}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
)),
});
}
}
if let (Some(s), Some(u)) = (since, until) {
if let (Ok(s_dt), Ok(u_dt)) = (
chrono::DateTime::parse_from_rfc3339(s),
chrono::DateTime::parse_from_rfc3339(u),
) {
if s_dt >= u_dt {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'since' must be before 'until'".into()),
});
}
}
}
#[allow(clippy::cast_possible_truncation)]
let limit = args
@@ -55,10 +109,10 @@ impl Tool for MemoryRecallTool {
.and_then(serde_json::Value::as_u64)
.map_or(5, |v| v as usize);
-match self.memory.recall(query, limit, None).await {
+match self.memory.recall(query, limit, None, since, until).await {
Ok(entries) if entries.is_empty() => Ok(ToolResult {
success: true,
-output: "No memories found matching that query.".into(),
+output: "No memories found.".into(),
error: None,
}),
Ok(entries) => {
@@ -150,11 +204,29 @@ mod tests {
}
#[tokio::test]
-async fn recall_missing_query() {
+async fn recall_requires_query_or_time() {
let (_tmp, mem) = seeded_mem();
let tool = MemoryRecallTool::new(mem);
-let result = tool.execute(json!({})).await;
-assert!(result.is_err());
+let result = tool.execute(json!({})).await.unwrap();
+assert!(!result.success);
+assert!(result.error.as_ref().unwrap().contains("at least"));
}
#[tokio::test]
async fn recall_time_only_returns_entries() {
let (_tmp, mem) = seeded_mem();
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
let tool = MemoryRecallTool::new(mem);
// Time-only: since far in past
let result = tool
.execute(json!({"since": "2020-01-01T00:00:00Z", "limit": 5}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Found 1"));
assert!(result.output.contains("Rust"));
}
#[test]
+13 -1
@@ -20,6 +20,7 @@ pub mod browser;
pub mod browser_delegate;
pub mod browser_open;
pub mod calculator;
pub mod claude_code;
pub mod cli_discovery;
pub mod cloud_ops;
pub mod cloud_patterns;
@@ -92,6 +93,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
pub use browser_open::BrowserOpenTool;
pub use calculator::CalculatorTool;
pub use claude_code::ClaudeCodeTool;
pub use cloud_ops::CloudOpsTool;
pub use cloud_patterns::CloudPatternsTool;
pub use composio::ComposioTool;
@@ -528,6 +530,14 @@ pub fn all_tools_with_runtime(
);
}
// Claude Code delegation tool
if root_config.claude_code.enabled {
tool_arcs.push(Arc::new(ClaudeCodeTool::new(
security.clone(),
root_config.claude_code.clone(),
)));
}
// PDF extraction (feature-gated at compile time via rag-pdf)
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
@@ -669,7 +679,8 @@ pub fn all_tools_with_runtime(
)
.with_parent_tools(Arc::clone(&parent_tools))
.with_multimodal_config(root_config.multimodal.clone())
-.with_delegate_config(root_config.delegate.clone());
+.with_delegate_config(root_config.delegate.clone())
+.with_workspace_dir(workspace_dir.to_path_buf());
tool_arcs.push(Arc::new(delegate_tool));
Some(parent_tools)
};
@@ -1000,6 +1011,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
+1
@@ -707,6 +707,7 @@ impl ModelRoutingConfigTool {
max_iterations: DEFAULT_AGENT_MAX_ITERATIONS,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
});
next_agent.provider = provider;
+2
@@ -568,6 +568,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
agents.insert(
@@ -584,6 +585,7 @@ mod tests {
max_iterations: 10,
timeout_secs: None,
agentic_timeout_secs: None,
skills_directory: None,
},
);
agents
+93
@@ -100,6 +100,10 @@ fn gateway_config_defaults_are_secure() {
!gw.trust_forwarded_headers,
"forwarded headers should be untrusted by default"
);
assert!(
gw.path_prefix.is_none(),
"path_prefix should default to None"
);
}
#[test]
@@ -124,6 +128,7 @@ fn gateway_config_toml_roundtrip() {
host: "0.0.0.0".into(),
require_pairing: false,
pair_rate_limit_per_minute: 5,
path_prefix: Some("/zeroclaw".into()),
..Default::default()
};
@@ -134,6 +139,7 @@ fn gateway_config_toml_roundtrip() {
assert_eq!(parsed.host, "0.0.0.0");
assert!(!parsed.require_pairing);
assert_eq!(parsed.pair_rate_limit_per_minute, 5);
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
}
#[test]
@@ -163,6 +169,93 @@ port = 9090
assert_eq!(parsed.gateway.pair_rate_limit_per_minute, 10);
}
// ─────────────────────────────────────────────────────────────────────────────
// GatewayConfig path_prefix validation
// ─────────────────────────────────────────────────────────────────────────────
#[test]
fn gateway_path_prefix_rejects_missing_leading_slash() {
let mut config = Config::default();
config.gateway.path_prefix = Some("zeroclaw".into());
let err = config.validate().unwrap_err();
assert!(
err.to_string().contains("must start with '/'"),
"expected leading-slash error, got: {err}"
);
}
#[test]
fn gateway_path_prefix_rejects_trailing_slash() {
let mut config = Config::default();
config.gateway.path_prefix = Some("/zeroclaw/".into());
let err = config.validate().unwrap_err();
assert!(
err.to_string().contains("must not end with '/'"),
"expected trailing-slash error, got: {err}"
);
}
#[test]
fn gateway_path_prefix_rejects_bare_slash() {
let mut config = Config::default();
config.gateway.path_prefix = Some("/".into());
let err = config.validate().unwrap_err();
assert!(
err.to_string().contains("must not end with '/'"),
"expected bare-slash error, got: {err}"
);
}
#[test]
fn gateway_path_prefix_accepts_valid_prefixes() {
for prefix in ["/zeroclaw", "/apps/zeroclaw", "/api/hassio_ingress/abc123"] {
let mut config = Config::default();
config.gateway.path_prefix = Some(prefix.into());
config
.validate()
.unwrap_or_else(|e| panic!("prefix {prefix:?} should be valid, got: {e}"));
}
}
#[test]
fn gateway_path_prefix_rejects_unsafe_characters() {
for prefix in [
"/zero claw",
"/zero<claw",
"/zero>claw",
"/zero\"claw",
"/zero?query",
"/zero#frag",
] {
let mut config = Config::default();
config.gateway.path_prefix = Some(prefix.into());
let err = config.validate().unwrap_err();
assert!(
err.to_string().contains("invalid character"),
"prefix {prefix:?} should be rejected, got: {err}"
);
}
// Leading/trailing whitespace is rejected by the starts_with('/') or
// invalid-character check — either way it must not pass validation.
for prefix in [" /zeroclaw ", " /zeroclaw"] {
let mut config = Config::default();
config.gateway.path_prefix = Some(prefix.into());
assert!(
config.validate().is_err(),
"whitespace-padded prefix {prefix:?} should be rejected"
);
}
}
#[test]
fn gateway_path_prefix_accepts_none() {
let config = Config::default();
assert!(config.gateway.path_prefix.is_none());
config
.validate()
.expect("absent path_prefix should be valid");
}
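Taken together, these tests pin down a small rule set for `path_prefix`: it must start with `/`, must not end with `/` (which also rejects the bare `/`), and may not contain whitespace or URL-unsafe characters. A minimal sketch of a validator satisfying exactly those assertions (the function name and error wording here are illustrative, not taken from the real `Config::validate`):

```rust
// Illustrative sketch only: a standalone validator covering the rules the
// tests above assert. The real Config::validate may differ in structure
// and wording beyond the substrings the tests check.
fn validate_path_prefix(prefix: &str) -> Result<(), String> {
    if !prefix.starts_with('/') {
        return Err("gateway.path_prefix must start with '/'".to_string());
    }
    if prefix.ends_with('/') {
        // Also rejects the bare "/" prefix, matching the bare-slash test.
        return Err("gateway.path_prefix must not end with '/'".to_string());
    }
    // Reject whitespace and characters that are unsafe in a URL path.
    if let Some(c) = prefix
        .chars()
        .find(|c| c.is_whitespace() || matches!(*c, '<' | '>' | '"' | '?' | '#'))
    {
        return Err(format!("gateway.path_prefix contains invalid character {c:?}"));
    }
    Ok(())
}

fn main() {
    assert!(validate_path_prefix("/apps/zeroclaw").is_ok());
    assert!(validate_path_prefix("zeroclaw").unwrap_err().contains("must start with '/'"));
    assert!(validate_path_prefix("/zeroclaw/").unwrap_err().contains("must not end with '/'"));
    assert!(validate_path_prefix("/zero#frag").unwrap_err().contains("invalid character"));
    println!("all path_prefix rules hold");
}
```

Note that a whitespace-padded prefix like `" /zeroclaw "` fails either the leading-slash check or the invalid-character check, which is all the whitespace test below requires.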
// ─────────────────────────────────────────────────────────────────────────────
// SecurityConfig boundary tests
// ─────────────────────────────────────────────────────────────────────────────
+11 -5
View File
@@ -147,8 +147,8 @@ async fn compare_recall_quality() {
println!("RECALL QUALITY (10 entries seeded):\n");
for (query, desc) in &queries {
-let sq_results = sq.recall(query, 10, None).await.unwrap();
-let md_results = md.recall(query, 10, None).await.unwrap();
+let sq_results = sq.recall(query, 10, None, None, None).await.unwrap();
+let md_results = md.recall(query, 10, None, None, None).await.unwrap();
println!(" Query: \"{query}\"{desc}");
println!(" SQLite: {} results", sq_results.len());
@@ -202,11 +202,17 @@ async fn compare_recall_speed() {
// Benchmark recall
let start = Instant::now();
-let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
+let sq_results = sq
+.recall("Rust systems", 10, None, None, None)
+.await
+.unwrap();
let sq_dur = start.elapsed();
let start = Instant::now();
-let md_results = md.recall("Rust systems", 10, None).await.unwrap();
+let md_results = md
+.recall("Rust systems", 10, None, None, None)
+.await
+.unwrap();
let md_dur = start.elapsed();
println!("\n============================================================");
@@ -312,7 +318,7 @@ async fn compare_upsert() {
let md_count = md.count().await.unwrap();
let sq_entry = sq.get("pref").await.unwrap();
-let md_results = md.recall("loves Rust", 5, None).await.unwrap();
+let md_results = md.recall("loves Rust", 5, None, None, None).await.unwrap();
println!("\n============================================================");
println!("UPSERT (store same key twice):");
+13 -5
View File
@@ -216,7 +216,10 @@ async fn sqlite_memory_recall_returns_relevant_results() {
.await
.unwrap();
-let results = mem.recall("Rust programming", 10, None).await.unwrap();
+let results = mem
+.recall("Rust programming", 10, None, None, None)
+.await
+.unwrap();
assert!(!results.is_empty(), "recall should find matching entries");
// The Rust-related entry should be in results
assert!(
@@ -241,7 +244,10 @@ async fn sqlite_memory_recall_respects_limit() {
.unwrap();
}
-let results = mem.recall("test content", 3, None).await.unwrap();
+let results = mem
+.recall("test content", 3, None, None, None)
+.await
+.unwrap();
assert!(
results.len() <= 3,
"recall should respect limit of 3, got {}",
@@ -250,7 +256,7 @@ async fn sqlite_memory_recall_respects_limit() {
}
#[tokio::test]
-async fn sqlite_memory_recall_empty_query_returns_empty() {
+async fn sqlite_memory_recall_empty_query_returns_recent_entries() {
let tmp = tempfile::TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
@@ -258,8 +264,10 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
.await
.unwrap();
-let results = mem.recall("", 10, None).await.unwrap();
-assert!(results.is_empty(), "empty query should return no results");
+// Empty query uses time-only path: returns recent entries by updated_at
+let results = mem.recall("", 10, None, None, None).await.unwrap();
+assert_eq!(results.len(), 1, "empty query should return recent entries");
+assert_eq!(results[0].key, "fact");
}
// ─────────────────────────────────────────────────────────────────────────────
+2 -1
View File
@@ -16,6 +16,7 @@ import Pairing from './pages/Pairing';
import { AuthProvider, useAuth } from './hooks/useAuth';
import { DraftContext, useDraftStore } from './hooks/useDraft';
import { setLocale, type Locale } from './lib/i18n';
import { basePath } from './lib/basePath';
import { getAdminPairCode } from './lib/api';
// Locale context
@@ -131,7 +132,7 @@ function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> })
<div className="text-center mb-8">
<img
-src="/_app/zeroclaw-trans.png"
+src={`${basePath}/_app/zeroclaw-trans.png`}
alt="ZeroClaw"
className="h-20 w-20 rounded-2xl object-cover mx-auto mb-4 animate-float"
onError={(e) => { e.currentTarget.style.display = 'none'; }}
+2 -1
View File
@@ -1,4 +1,5 @@
import { NavLink } from 'react-router-dom';
import { basePath } from '../../lib/basePath';
import {
LayoutDashboard,
MessageSquare,
@@ -34,7 +35,7 @@ export default function Sidebar() {
<div className="relative shrink-0">
<div className="absolute -inset-1.5 rounded-xl" style={{ background: 'linear-gradient(135deg, rgba(var(--pc-accent-rgb), 0.15), rgba(var(--pc-accent-rgb), 0.05))' }} />
<img
-src="/_app/zeroclaw-trans.png"
+src={`${basePath}/_app/zeroclaw-trans.png`}
alt="ZeroClaw"
className="relative h-9 w-9 rounded-xl object-cover"
onError={(e) => {
+4 -3
View File
@@ -11,6 +11,7 @@ import type {
HealthSnapshot,
} from '../types/api';
import { clearToken, getToken, setToken } from './auth';
import { basePath } from './basePath';
// ---------------------------------------------------------------------------
// Base fetch wrapper
@@ -42,7 +43,7 @@ export async function apiFetch<T = unknown>(
headers.set('Content-Type', 'application/json');
}
-const response = await fetch(path, { ...options, headers });
+const response = await fetch(`${basePath}${path}`, { ...options, headers });
if (response.status === 401) {
clearToken();
@@ -78,7 +79,7 @@ function unwrapField<T>(value: T | Record<string, T>, key: string): T {
// ---------------------------------------------------------------------------
export async function pair(code: string): Promise<{ token: string }> {
-const response = await fetch('/pair', {
+const response = await fetch(`${basePath}/pair`, {
method: 'POST',
headers: { 'X-Pairing-Code': code },
});
@@ -106,7 +107,7 @@ export async function getAdminPairCode(): Promise<{ pairing_code: string | null;
// ---------------------------------------------------------------------------
export async function getPublicHealth(): Promise<{ require_pairing: boolean; paired: boolean }> {
-const response = await fetch('/health');
+const response = await fetch(`${basePath}/health`);
if (!response.ok) {
throw new Error(`Health check failed (${response.status})`);
}
+11
View File
@@ -0,0 +1,11 @@
// Runtime base path injected by the Rust gateway into index.html.
// Allows the SPA to work under a reverse-proxy path prefix.
declare global {
interface Window {
__ZEROCLAW_BASE__?: string;
}
}
/** Gateway path prefix (e.g. "/zeroclaw"), or empty string when served at root. */
export const basePath: string = (window.__ZEROCLAW_BASE__ ?? '').replace(/\/+$/, '');
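The `window.__ZEROCLAW_BASE__` global read above is documented as being injected by the Rust gateway into `index.html` at serve time; that injection itself is not part of this diff. A hypothetical sketch of what the gateway side could look like (the function name and injection point are assumptions for illustration):

```rust
// Illustrative only: one way a gateway could inject the runtime base path
// into the SPA's index.html before serving it. Nothing here is taken from
// the actual ZeroClaw implementation, which this diff does not show.
fn inject_base_path(index_html: &str, path_prefix: Option<&str>) -> String {
    let base = path_prefix.unwrap_or("");
    // Place the <script> right after <head> so the global is set before
    // the app bundle evaluates basePath.
    let tag = format!("<head><script>window.__ZEROCLAW_BASE__ = \"{base}\";</script>");
    index_html.replacen("<head>", &tag, 1)
}

fn main() {
    let html = "<html><head><title>x</title></head></html>";
    let out = inject_base_path(html, Some("/zeroclaw"));
    assert!(out.contains("window.__ZEROCLAW_BASE__ = \"/zeroclaw\";"));
    println!("{out}");
}
```

Serving the value at runtime rather than baking it in at build time is what lets one prebuilt frontend bundle work under any reverse-proxy prefix.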
+2 -1
View File
@@ -1,5 +1,6 @@
import type { SSEEvent } from '../types/api';
import { getToken } from './auth';
import { basePath } from './basePath';
export type SSEEventHandler = (event: SSEEvent) => void;
export type SSEErrorHandler = (error: Event | Error) => void;
@@ -41,7 +42,7 @@ export class SSEClient {
private readonly autoReconnect: boolean;
constructor(options: SSEClientOptions = {}) {
-this.path = options.path ?? '/api/events';
+this.path = options.path ?? `${basePath}/api/events`;
this.reconnectDelay = options.reconnectDelay ?? DEFAULT_RECONNECT_DELAY;
this.maxReconnectDelay = options.maxReconnectDelay ?? MAX_RECONNECT_DELAY;
this.autoReconnect = options.autoReconnect ?? true;
+2 -1
View File
@@ -1,5 +1,6 @@
import type { WsMessage } from '../types/api';
import { getToken } from './auth';
import { basePath } from './basePath';
import { generateUUID } from './uuid';
export type WsMessageHandler = (msg: WsMessage) => void;
@@ -69,7 +70,7 @@ export class WebSocketClient {
const params = new URLSearchParams();
if (token) params.set('token', token);
params.set('session_id', sessionId);
-const url = `${this.baseUrl}/ws/chat?${params.toString()}`;
+const url = `${this.baseUrl}${basePath}/ws/chat?${params.toString()}`;
const protocols: string[] = ['zeroclaw.v1'];
if (token) protocols.push(`bearer.${token}`);
+3 -2
View File
@@ -2,12 +2,13 @@ import React from 'react';
import ReactDOM from 'react-dom/client';
import { BrowserRouter } from 'react-router-dom';
import App from './App';
import { basePath } from './lib/basePath';
import './index.css';
ReactDOM.createRoot(document.getElementById('root')!).render(
<React.StrictMode>
-{/* Vite base '/_app/' scopes static asset URLs only; app routes stay rooted at '/' for SPA fallback. */}
-<BrowserRouter basename="/">
+{/* basePath is injected by the Rust gateway at serve time for reverse-proxy prefix support. */}
+<BrowserRouter basename={basePath || '/'}>
<App />
</BrowserRouter>
</React.StrictMode>