Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 18cfb4e2fe | |||
| ccabfd7167 | |||
| b91844966d | |||
| e46a19c61c | |||
| e8c61522ea | |||
| a12716a065 | |||
| 87b5bca449 | |||
| be40c0c5a5 | |||
| 6527871928 | |||
| 0bda80de9c | |||
| 02f57f4d98 | |||
| ef83dd44d7 | |||
| a986b6b912 | |||
| b6b1186e3b | |||
| 00dc0c8670 | |||
| 43f2a0a815 | |||
| 50b5bd4d73 | |||
| 8c074870a1 | |||
| 61d1841ce3 | |||
| eb396cf38f | |||
| 9f1657b9be | |||
| 8fecd4286c | |||
| df21d92da3 | |||
| 8d65924704 | |||
| 756c3cadff | |||
| ee870028ff |
Generated
+1
-1
@@ -9203,7 +9203,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.4"
|
||||
version = "0.5.5"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.4"
|
||||
version = "0.5.5"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
@@ -263,7 +263,7 @@ fn bench_memory_operations(c: &mut Criterion) {
|
||||
c.bench_function("memory_recall_top10", |b| {
|
||||
b.iter(|| {
|
||||
rt.block_on(async {
|
||||
mem.recall(black_box("zeroclaw agent"), 10, None)
|
||||
mem.recall(black_box("zeroclaw agent"), 10, None, None, None)
|
||||
.await
|
||||
.unwrap()
|
||||
})
|
||||
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.5.4
|
||||
pkgver = 0.5.5
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.5.4.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.4.tar.gz
|
||||
source = zeroclaw-0.5.5.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.5.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.5.4
|
||||
pkgver=0.5.5
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.5.4",
|
||||
"version": "0.5.5",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.4/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.5/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ SOP 审计条目通过 `SopAuditLogger` 持久化到配置的内存后端的 `so
|
||||
- `sop_step_{run_id}_{step_number}`:单步结果
|
||||
- `sop_approval_{run_id}_{step_number}`:操作员审批记录
|
||||
- `sop_timeout_approve_{run_id}_{step_number}`:超时自动审批记录
|
||||
- `sop_gate_decision_{gate_id}_{timestamp_ms}`:门评估器决策记录(启用 `ampersona-gates` 时)
|
||||
- `sop_phase_state`:持久化的信任阶段状态快照(启用 `ampersona-gates` 时)
|
||||
|
||||
## 2. 检查路径
|
||||
|
||||
|
||||
@@ -122,6 +122,34 @@ tools = ["mcp_browser_*"]
|
||||
keywords = ["browse", "navigate", "open url", "screenshot"]
|
||||
```
|
||||
|
||||
## `[pacing]`
|
||||
|
||||
Pacing controls for slow/local LLM workloads (Ollama, llama.cpp, vLLM). All keys are optional; when absent, existing behavior is preserved.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `step_timeout_secs` | _none_ | Per-step timeout: maximum seconds for a single LLM inference turn. Catches a truly hung model without terminating the overall task loop |
|
||||
| `loop_detection_min_elapsed_secs` | _none_ | Minimum elapsed seconds before loop detection activates. Tasks completing under this threshold get aggressive loop protection; longer-running tasks receive a grace period |
|
||||
| `loop_ignore_tools` | `[]` | Tool names excluded from identical-output loop detection. Useful for browser workflows where `browser_screenshot` structurally resembles a loop |
|
||||
| `message_timeout_scale_max` | `4` | Override for the hardcoded timeout scaling cap. The channel message timeout budget is `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)` |
|
||||
|
||||
Notes:
|
||||
|
||||
- These settings are intended for local/slow LLM deployments. Cloud-provider users typically do not need them.
|
||||
- `step_timeout_secs` operates independently of the total channel message timeout budget. A step timeout abort does not consume the overall budget; the loop simply stops.
|
||||
- `loop_detection_min_elapsed_secs` delays loop-detection counting, not the task itself. Loop protection remains fully active for short tasks (the default).
|
||||
- `loop_ignore_tools` only suppresses tool-output-based loop detection for the listed tools. Other safety features (max iterations, overall timeout) remain active.
|
||||
- `message_timeout_scale_max` must be >= 1. Setting it higher than `max_tool_iterations` has no additional effect (the formula uses `min()`).
|
||||
- Example configuration for a slow local Ollama deployment:
|
||||
|
||||
```toml
|
||||
[pacing]
|
||||
step_timeout_secs = 120
|
||||
loop_detection_min_elapsed_secs = 60
|
||||
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
|
||||
message_timeout_scale_max = 8
|
||||
```
|
||||
|
||||
## `[security.otp]`
|
||||
|
||||
| Key | Default | Purpose |
|
||||
@@ -185,12 +213,15 @@ Delegate sub-agent configurations. Each key under `[agents]` defines a named sub
|
||||
| `max_iterations` | `10` | Max tool-call iterations for agentic mode |
|
||||
| `timeout_secs` | `120` | Timeout in seconds for non-agentic provider calls (1–3600) |
|
||||
| `agentic_timeout_secs` | `300` | Timeout in seconds for agentic sub-agent loops (1–3600) |
|
||||
| `skills_directory` | unset | Optional skills directory path (workspace-relative) for scoped skill loading |
|
||||
|
||||
Notes:
|
||||
|
||||
- `agentic = false` preserves existing single prompt→response delegate behavior.
|
||||
- `agentic = true` requires at least one matching entry in `allowed_tools`.
|
||||
- The `delegate` tool is excluded from sub-agent allowlists to prevent re-entrant delegation loops.
|
||||
- Sub-agents receive an enriched system prompt containing: tools section (allowed tools with parameters), skills section (from scoped or default directory), workspace path, current date/time, safety constraints, and shell policy when `shell` is in the effective tool list.
|
||||
- When `skills_directory` is unset or empty, the sub-agent loads skills from the default workspace `skills/` directory. When set, skills are loaded exclusively from that directory (relative to workspace root), enabling per-agent scoped skill sets.
|
||||
|
||||
```toml
|
||||
[agents.researcher]
|
||||
@@ -208,6 +239,14 @@ provider = "ollama"
|
||||
model = "qwen2.5-coder:32b"
|
||||
temperature = 0.2
|
||||
timeout_secs = 60
|
||||
|
||||
[agents.code_reviewer]
|
||||
provider = "anthropic"
|
||||
model = "claude-opus-4-5"
|
||||
system_prompt = "You are an expert code reviewer focused on security and performance."
|
||||
agentic = true
|
||||
allowed_tools = ["file_read", "shell"]
|
||||
skills_directory = "skills/code-review"
|
||||
```
|
||||
|
||||
## `[runtime]`
|
||||
@@ -414,6 +453,12 @@ Notes:
|
||||
| `port` | `42617` | gateway listen port |
|
||||
| `require_pairing` | `true` | require pairing before bearer auth |
|
||||
| `allow_public_bind` | `false` | block accidental public exposure |
|
||||
| `path_prefix` | _(none)_ | URL path prefix for reverse-proxy deployments (e.g. `"/zeroclaw"`) |
|
||||
|
||||
When deploying behind a reverse proxy that maps ZeroClaw to a sub-path,
|
||||
set `path_prefix` to that sub-path (e.g. `"/zeroclaw"`). All gateway
|
||||
routes will be served under this prefix. The value must start with `/`
|
||||
and must not end with `/`.
|
||||
|
||||
## `[autonomy]`
|
||||
|
||||
@@ -586,7 +631,7 @@ Top-level channel options are configured under `channels_config`.
|
||||
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) |
|
||||
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x, overridable via `[pacing].message_timeout_scale_max`) |
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -601,7 +646,7 @@ Examples:
|
||||
Notes:
|
||||
|
||||
- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`.
|
||||
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, cap)` and a minimum of `1`. The default cap is `4`; override with `[pacing].message_timeout_scale_max`.
|
||||
- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete.
|
||||
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
|
||||
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
|
||||
|
||||
@@ -12,8 +12,6 @@ Common key patterns:
|
||||
- `sop_step_{run_id}_{step_number}`: per-step result
|
||||
- `sop_approval_{run_id}_{step_number}`: operator approval record
|
||||
- `sop_timeout_approve_{run_id}_{step_number}`: timeout auto-approval record
|
||||
- `sop_gate_decision_{gate_id}_{timestamp_ms}`: gate evaluator decision record (when `ampersona-gates` is enabled)
|
||||
- `sop_phase_state`: persisted trust-phase state snapshot (when `ampersona-gates` is enabled)
|
||||
|
||||
## 2. Inspection Paths
|
||||
|
||||
|
||||
@@ -568,6 +568,13 @@ then re-run bootstrap.
|
||||
MSG
|
||||
exit 0
|
||||
fi
|
||||
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
|
||||
if ! /usr/bin/xcrun --show-sdk-path >/dev/null 2>&1; then
|
||||
warn "Xcode license has not been accepted. Run:"
|
||||
warn " sudo xcodebuild -license accept"
|
||||
warn "then re-run this installer."
|
||||
exit 1
|
||||
fi
|
||||
if ! have_cmd git; then
|
||||
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
|
||||
fi
|
||||
|
||||
+451
-8
@@ -1,5 +1,8 @@
|
||||
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::config::Config;
|
||||
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
|
||||
use crate::cost::CostTracker;
|
||||
use crate::i18n::ToolDescriptions;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
@@ -23,6 +26,108 @@ use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ── Cost tracking via task-local ──
|
||||
|
||||
/// Context for cost tracking within the tool call loop.
|
||||
/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ToolLoopCostTrackingContext {
|
||||
pub tracker: Arc<CostTracker>,
|
||||
pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
|
||||
}
|
||||
|
||||
impl ToolLoopCostTrackingContext {
|
||||
pub(crate) fn new(
|
||||
tracker: Arc<CostTracker>,
|
||||
prices: Arc<std::collections::HashMap<String, ModelPricing>>,
|
||||
) -> Self {
|
||||
Self { tracker, prices }
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task_local! {
|
||||
pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
|
||||
}
|
||||
|
||||
/// 3-tier model pricing lookup:
|
||||
/// 1. Direct model name
|
||||
/// 2. Qualified `provider/model`
|
||||
/// 3. Suffix after last `/`
|
||||
fn lookup_model_pricing<'a>(
|
||||
prices: &'a std::collections::HashMap<String, ModelPricing>,
|
||||
provider_name: &str,
|
||||
model: &str,
|
||||
) -> Option<&'a ModelPricing> {
|
||||
prices
|
||||
.get(model)
|
||||
.or_else(|| prices.get(&format!("{provider_name}/{model}")))
|
||||
.or_else(|| {
|
||||
model
|
||||
.rsplit_once('/')
|
||||
.and_then(|(_, suffix)| prices.get(suffix))
|
||||
})
|
||||
}
|
||||
|
||||
/// Record token usage from an LLM response via the task-local cost tracker.
|
||||
/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
|
||||
fn record_tool_loop_cost_usage(
|
||||
provider_name: &str,
|
||||
model: &str,
|
||||
usage: &crate::providers::traits::TokenUsage,
|
||||
) -> Option<(u64, f64)> {
|
||||
let input_tokens = usage.input_tokens.unwrap_or(0);
|
||||
let output_tokens = usage.output_tokens.unwrap_or(0);
|
||||
let total_tokens = input_tokens.saturating_add(output_tokens);
|
||||
if total_tokens == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten()?;
|
||||
let pricing = lookup_model_pricing(&ctx.prices, provider_name, model);
|
||||
let cost_usage = CostTokenUsage::new(
|
||||
model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
pricing.map_or(0.0, |entry| entry.input),
|
||||
pricing.map_or(0.0, |entry| entry.output),
|
||||
);
|
||||
|
||||
if pricing.is_none() {
|
||||
tracing::debug!(
|
||||
provider = provider_name,
|
||||
model,
|
||||
"Cost tracking recorded token usage with zero pricing (no pricing entry found)"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model,
|
||||
"Failed to record cost tracking usage: {error}"
|
||||
);
|
||||
}
|
||||
|
||||
Some((cost_usage.total_tokens, cost_usage.cost_usd))
|
||||
}
|
||||
|
||||
/// Check budget before an LLM call. Returns `None` when no cost tracking
|
||||
/// context is scoped (tests, delegate, CLI without cost config).
|
||||
pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
|
||||
TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|ctx| {
|
||||
ctx.tracker
|
||||
.check_budget(0.0)
|
||||
.unwrap_or(BudgetCheck::Allowed)
|
||||
})
|
||||
}
|
||||
|
||||
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
|
||||
const STREAM_CHUNK_MIN_CHARS: usize = 80;
|
||||
|
||||
@@ -465,7 +570,7 @@ async fn build_context(
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
@@ -2226,6 +2331,7 @@ pub(crate) async fn agent_turn(
|
||||
dedup_exempt_tools,
|
||||
activated_tools,
|
||||
model_switch_callback,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2535,6 +2641,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
model_switch_callback: Option<ModelSwitchCallback>,
|
||||
pacing: &crate::config::PacingConfig,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2543,6 +2650,14 @@ pub(crate) async fn run_tool_call_loop(
|
||||
};
|
||||
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let loop_started_at = Instant::now();
|
||||
let loop_ignore_tools: HashSet<&str> = pacing
|
||||
.loop_ignore_tools
|
||||
.iter()
|
||||
.map(String::as_str)
|
||||
.collect();
|
||||
let mut consecutive_identical_outputs: usize = 0;
|
||||
let mut last_tool_output_hash: Option<u64> = None;
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
@@ -2642,6 +2757,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
hooks.fire_llm_input(history, model).await;
|
||||
}
|
||||
|
||||
// Budget enforcement — block if limit exceeded (no-op when not scoped)
|
||||
if let Some(BudgetCheck::Exceeded {
|
||||
current_usd,
|
||||
limit_usd,
|
||||
period,
|
||||
}) = check_tool_loop_budget()
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"Budget exceeded: ${:.4} of ${:.2} {:?} limit. Cannot make further API calls until the budget resets.",
|
||||
current_usd, limit_usd, period
|
||||
));
|
||||
}
|
||||
|
||||
// Unified path via Provider::chat so provider-specific native tool logic
|
||||
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
|
||||
let request_tools = if use_native_tools {
|
||||
@@ -2659,13 +2787,43 @@ pub(crate) async fn run_tool_call_loop(
|
||||
temperature,
|
||||
);
|
||||
|
||||
let chat_result = if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
// Wrap the LLM call with an optional per-step timeout from pacing config.
|
||||
// This catches a truly hung model response without terminating the overall
|
||||
// task loop (the per-message budget handles that separately).
|
||||
let chat_result = match pacing.step_timeout_secs {
|
||||
Some(step_secs) if step_secs > 0 => {
|
||||
let step_timeout = Duration::from_secs(step_secs);
|
||||
if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = tokio::time::timeout(step_timeout, chat_future) => {
|
||||
match result {
|
||||
Ok(inner) => inner,
|
||||
Err(_) => anyhow::bail!(
|
||||
"LLM inference step timed out after {step_secs}s (step_timeout_secs)"
|
||||
),
|
||||
}
|
||||
},
|
||||
}
|
||||
} else {
|
||||
match tokio::time::timeout(step_timeout, chat_future).await {
|
||||
Ok(inner) => inner,
|
||||
Err(_) => anyhow::bail!(
|
||||
"LLM inference step timed out after {step_secs}s (step_timeout_secs)"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
};
|
||||
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||
@@ -2687,6 +2845,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
output_tokens: resp_output_tokens,
|
||||
});
|
||||
|
||||
// Record cost via task-local tracker (no-op when not scoped)
|
||||
let _ = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
|
||||
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
// Fall back to text-based parsing (XML tags, markdown blocks,
|
||||
@@ -3158,7 +3322,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
}
|
||||
|
||||
// Collect tool results and build per-tool output for loop detection.
|
||||
// Only non-ignored tool outputs contribute to the identical-output hash.
|
||||
let mut detection_relevant_output = String::new();
|
||||
for (tool_name, tool_call_id, outcome) in ordered_results.into_iter().flatten() {
|
||||
if !loop_ignore_tools.contains(tool_name.as_str()) {
|
||||
detection_relevant_output.push_str(&outcome.output);
|
||||
}
|
||||
individual_results.push((tool_call_id, outcome.output.clone()));
|
||||
let _ = writeln!(
|
||||
tool_results,
|
||||
@@ -3167,6 +3337,53 @@ pub(crate) async fn run_tool_call_loop(
|
||||
);
|
||||
}
|
||||
|
||||
// ── Time-gated loop detection ──────────────────────────
|
||||
// When pacing.loop_detection_min_elapsed_secs is set, identical-output
|
||||
// loop detection activates after the task has been running that long.
|
||||
// This avoids false-positive aborts on long-running browser/research
|
||||
// workflows while keeping aggressive protection for quick tasks.
|
||||
// When not configured, identical-output detection is disabled (preserving
|
||||
// existing behavior where only max_iterations prevents runaway loops).
|
||||
let loop_detection_active = match pacing.loop_detection_min_elapsed_secs {
|
||||
Some(min_secs) => loop_started_at.elapsed() >= Duration::from_secs(min_secs),
|
||||
None => false, // disabled when not configured (backwards compatible)
|
||||
};
|
||||
|
||||
if loop_detection_active && !detection_relevant_output.is_empty() {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
detection_relevant_output.hash(&mut hasher);
|
||||
let current_hash = hasher.finish();
|
||||
|
||||
if last_tool_output_hash == Some(current_hash) {
|
||||
consecutive_identical_outputs += 1;
|
||||
} else {
|
||||
consecutive_identical_outputs = 0;
|
||||
last_tool_output_hash = Some(current_hash);
|
||||
}
|
||||
|
||||
// Bail if we see 3+ consecutive identical tool outputs (clear runaway).
|
||||
if consecutive_identical_outputs >= 3 {
|
||||
runtime_trace::record_event(
|
||||
"tool_loop_identical_output_abort",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("identical tool output detected 3 consecutive times"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"consecutive_identical": consecutive_identical_outputs,
|
||||
}),
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Agent loop aborted: identical tool output detected {} consecutive times",
|
||||
consecutive_identical_outputs
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls + tool results to history.
|
||||
// Native mode: use JSON-structured messages so convert_messages() can
|
||||
// reconstruct proper OpenAI-format tool_calls and tool result messages.
|
||||
@@ -3716,6 +3933,7 @@ pub async fn run(
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&config.pacing,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3943,6 +4161,7 @@ pub async fn run(
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&config.pacing,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -4840,6 +5059,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -4890,6 +5110,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -4934,6 +5155,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -5064,6 +5286,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -5134,6 +5357,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("cron_add delivery defaults should be injected");
|
||||
@@ -5196,6 +5420,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("explicit delivery mode should be preserved");
|
||||
@@ -5253,6 +5478,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -5322,6 +5548,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("non-interactive shell should succeed for low-risk command");
|
||||
@@ -5382,6 +5609,7 @@ mod tests {
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
@@ -5462,6 +5690,7 @@ mod tests {
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
@@ -5519,6 +5748,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -5600,6 +5830,7 @@ mod tests {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("native tool-call text should be relayed through on_delta");
|
||||
@@ -6395,7 +6626,7 @@ Tail"#;
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
@@ -7585,6 +7816,7 @@ Let me check the result."#;
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
@@ -7662,4 +7894,215 @@ Let me check the result."#;
|
||||
let result = filter_by_allowed_tools(specs, Some(&allowed));
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
|
||||
// ── Cost tracking tests ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_records_usage_when_scoped() {
|
||||
use super::{
|
||||
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
|
||||
};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::cost::CostTracker;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let provider = ScriptedProvider {
|
||||
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
|
||||
text: Some("done".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(crate::providers::traits::TokenUsage {
|
||||
input_tokens: Some(1_000),
|
||||
output_tokens: Some(200),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}]))),
|
||||
capabilities: ProviderCapabilities::default(),
|
||||
};
|
||||
let observer = NoopObserver;
|
||||
let workspace = tempfile::TempDir::new().unwrap();
|
||||
let mut cost_config = crate::config::CostConfig {
|
||||
enabled: true,
|
||||
..crate::config::CostConfig::default()
|
||||
};
|
||||
cost_config.prices = HashMap::from([(
|
||||
"mock-model".to_string(),
|
||||
ModelPricing {
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
},
|
||||
)]);
|
||||
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
|
||||
let ctx = ToolLoopCostTrackingContext::new(
|
||||
Arc::clone(&tracker),
|
||||
Arc::new(cost_config.prices.clone()),
|
||||
);
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let result = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.scope(
|
||||
Some(ctx),
|
||||
run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should succeed");
|
||||
|
||||
assert_eq!(result, "done");
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 1);
|
||||
assert_eq!(summary.total_tokens, 1_200);
|
||||
assert!(summary.session_cost_usd > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_enforces_budget() {
|
||||
use super::{
|
||||
run_tool_call_loop, ToolLoopCostTrackingContext, TOOL_LOOP_COST_TRACKING_CONTEXT,
|
||||
};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::cost::CostTracker;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["should not reach this"]);
|
||||
let observer = NoopObserver;
|
||||
let workspace = tempfile::TempDir::new().unwrap();
|
||||
let cost_config = crate::config::CostConfig {
|
||||
enabled: true,
|
||||
daily_limit_usd: 0.001, // very low limit
|
||||
..crate::config::CostConfig::default()
|
||||
};
|
||||
let tracker = Arc::new(CostTracker::new(cost_config.clone(), workspace.path()).unwrap());
|
||||
// Record a usage that already exceeds the limit
|
||||
tracker
|
||||
.record_usage(crate::cost::types::TokenUsage::new(
|
||||
"mock-model",
|
||||
100_000,
|
||||
50_000,
|
||||
1.0,
|
||||
1.0,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let ctx = ToolLoopCostTrackingContext::new(
|
||||
Arc::clone(&tracker),
|
||||
Arc::new(HashMap::from([(
|
||||
"mock-model".to_string(),
|
||||
ModelPricing {
|
||||
input: 1.0,
|
||||
output: 1.0,
|
||||
},
|
||||
)])),
|
||||
);
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let err = TOOL_LOOP_COST_TRACKING_CONTEXT
|
||||
.scope(
|
||||
Some(ctx),
|
||||
run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail with budget exceeded");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("Budget exceeded"),
|
||||
"error should mention budget: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cost_tracking_is_noop_without_scope() {
|
||||
use super::run_tool_call_loop;
|
||||
use crate::observability::noop::NoopObserver;
|
||||
|
||||
// No TOOL_LOOP_COST_TRACKING_CONTEXT scoped — should run fine
|
||||
let provider = ScriptedProvider {
|
||||
responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse {
|
||||
text: Some("ok".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(crate::providers::traits::TokenUsage {
|
||||
input_tokens: Some(500),
|
||||
output_tokens: Some(100),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}]))),
|
||||
capabilities: ProviderCapabilities::default(),
|
||||
};
|
||||
let observer = NoopObserver;
|
||||
let mut history = vec![ChatMessage::system("test"), ChatMessage::user("hello")];
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&[],
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"test",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("should succeed without cost scope");
|
||||
|
||||
assert_eq!(result, "ok");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,9 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit, session_id).await?;
|
||||
let entries = memory
|
||||
.recall(user_message, self.limit, session_id, None, None)
|
||||
.await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
@@ -102,6 +104,8 @@ mod tests {
|
||||
_query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if limit == 0 {
|
||||
return Ok(vec![]);
|
||||
@@ -163,6 +167,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(self.entries.as_ref().clone())
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ pub struct DingTalkChannel {
|
||||
/// Per-chat session webhooks for sending replies (chatID -> webhook URL).
|
||||
/// DingTalk provides a unique webhook URL with each incoming message.
|
||||
session_webhooks: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Response from DingTalk gateway connection registration.
|
||||
@@ -34,11 +36,18 @@ impl DingTalkChannel {
|
||||
client_secret,
|
||||
allowed_users,
|
||||
session_webhooks: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.dingtalk")
|
||||
crate::config::build_channel_proxy_client("channel.dingtalk", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
|
||||
+10
-1
@@ -18,6 +18,8 @@ pub struct DiscordChannel {
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@@ -35,11 +37,18 @@ impl DiscordChannel {
|
||||
listen_to_bots,
|
||||
mention_only,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.discord")
|
||||
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a Discord user ID is in the allowlist.
|
||||
|
||||
+16
-1
@@ -380,6 +380,8 @@ pub struct LarkChannel {
|
||||
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
|
||||
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
@@ -423,6 +425,7 @@ impl LarkChannel {
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||
tenant_token: Arc::new(RwLock::new(None)),
|
||||
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,6 +447,7 @@ impl LarkChannel {
|
||||
platform,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
@@ -461,6 +465,7 @@ impl LarkChannel {
|
||||
LarkPlatform::Lark,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
@@ -476,11 +481,15 @@ impl LarkChannel {
|
||||
LarkPlatform::Feishu,
|
||||
);
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch.proxy_url = config.proxy_url.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client(self.platform.proxy_service_key())
|
||||
crate::config::build_channel_proxy_client(
|
||||
self.platform.proxy_service_key(),
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
fn channel_name(&self) -> &'static str {
|
||||
@@ -2113,6 +2122,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::default(),
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -2135,6 +2145,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -2169,6 +2180,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_config(&cfg);
|
||||
@@ -2193,6 +2205,7 @@ mod tests {
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_lark_config(&cfg);
|
||||
@@ -2214,6 +2227,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_feishu_config(&cfg);
|
||||
@@ -2386,6 +2400,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let ch_feishu = LarkChannel::from_feishu_config(&feishu_cfg);
|
||||
assert_eq!(
|
||||
|
||||
@@ -17,6 +17,8 @@ pub struct MattermostChannel {
|
||||
mention_only: bool,
|
||||
/// Handle for the background typing-indicator loop (aborted on stop_typing).
|
||||
typing_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl MattermostChannel {
|
||||
@@ -38,11 +40,18 @@ impl MattermostChannel {
|
||||
thread_replies,
|
||||
mention_only,
|
||||
typing_handle: Mutex::new(None),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.mattermost")
|
||||
crate::config::build_channel_proxy_client("channel.mattermost", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a user ID is in the allowlist.
|
||||
|
||||
+218
-44
@@ -222,9 +222,21 @@ fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||||
fn channel_message_timeout_budget_secs(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
) -> u64 {
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs,
|
||||
max_tool_iterations,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP,
|
||||
)
|
||||
}
|
||||
|
||||
fn channel_message_timeout_budget_secs_with_cap(
|
||||
message_timeout_secs: u64,
|
||||
max_tool_iterations: usize,
|
||||
scale_cap: u64,
|
||||
) -> u64 {
|
||||
let iterations = max_tool_iterations.max(1) as u64;
|
||||
let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let scale = iterations.min(scale_cap);
|
||||
message_timeout_secs.saturating_mul(scale)
|
||||
}
|
||||
|
||||
@@ -313,6 +325,12 @@ impl InterruptOnNewMessageConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelCostTrackingState {
|
||||
tracker: Arc<crate::cost::CostTracker>,
|
||||
prices: Arc<HashMap<String, crate::config::schema::ModelPricing>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelRuntimeContext {
|
||||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
@@ -355,6 +373,8 @@ struct ChannelRuntimeContext {
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
cost_tracking: Option<ChannelCostTrackingState>,
|
||||
pacing: crate::config::PacingConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1511,7 +1531,7 @@ async fn build_memory_context(
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
let mut included = 0usize;
|
||||
let mut used_chars = 0usize;
|
||||
|
||||
@@ -2395,8 +2415,18 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let model_switch_callback = get_model_switch_state();
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let scale_cap = ctx
|
||||
.pacing
|
||||
.message_timeout_scale_max
|
||||
.unwrap_or(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
|
||||
let timeout_budget_secs = channel_message_timeout_budget_secs_with_cap(
|
||||
ctx.message_timeout_secs,
|
||||
ctx.max_tool_iterations,
|
||||
scale_cap,
|
||||
);
|
||||
let cost_tracking_context = ctx.cost_tracking.clone().map(|state| {
|
||||
crate::agent::loop_::ToolLoopCostTrackingContext::new(state.tracker, state.prices)
|
||||
});
|
||||
let llm_call_start = Instant::now();
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let elapsed_before_llm_ms = started_at.elapsed().as_millis() as u64;
|
||||
@@ -2406,6 +2436,8 @@ async fn process_channel_message(
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
crate::agent::loop_::TOOL_LOOP_COST_TRACKING_CONTEXT.scope(
|
||||
cost_tracking_context.clone(),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
@@ -2433,6 +2465,8 @@ async fn process_channel_message(
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
&ctx.pacing,
|
||||
),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -3691,7 +3725,8 @@ fn collect_configured_channels(
|
||||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms)
|
||||
.with_transcription(config.transcription.clone())
|
||||
.with_tts(config.tts.clone())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(tg.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3699,13 +3734,16 @@ fn collect_configured_channels(
|
||||
if let Some(ref dc) = config.channels_config.discord {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Discord",
|
||||
channel: Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)),
|
||||
channel: Arc::new(
|
||||
DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_proxy_url(dc.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3722,7 +3760,8 @@ fn collect_configured_channels(
|
||||
)
|
||||
.with_thread_replies(sl.thread_replies.unwrap_or(true))
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(sl.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3730,14 +3769,17 @@ fn collect_configured_channels(
|
||||
if let Some(ref mm) = config.channels_config.mattermost {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Mattermost",
|
||||
channel: Arc::new(MattermostChannel::new(
|
||||
mm.url.clone(),
|
||||
mm.bot_token.clone(),
|
||||
mm.channel_id.clone(),
|
||||
mm.allowed_users.clone(),
|
||||
mm.thread_replies.unwrap_or(true),
|
||||
mm.mention_only.unwrap_or(false),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
MattermostChannel::new(
|
||||
mm.url.clone(),
|
||||
mm.bot_token.clone(),
|
||||
mm.channel_id.clone(),
|
||||
mm.allowed_users.clone(),
|
||||
mm.thread_replies.unwrap_or(true),
|
||||
mm.mention_only.unwrap_or(false),
|
||||
)
|
||||
.with_proxy_url(mm.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3775,14 +3817,17 @@ fn collect_configured_channels(
|
||||
if let Some(ref sig) = config.channels_config.signal {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Signal",
|
||||
channel: Arc::new(SignalChannel::new(
|
||||
sig.http_url.clone(),
|
||||
sig.account.clone(),
|
||||
sig.group_id.clone(),
|
||||
sig.allowed_from.clone(),
|
||||
sig.ignore_attachments,
|
||||
sig.ignore_stories,
|
||||
)),
|
||||
channel: Arc::new(
|
||||
SignalChannel::new(
|
||||
sig.http_url.clone(),
|
||||
sig.account.clone(),
|
||||
sig.group_id.clone(),
|
||||
sig.allowed_from.clone(),
|
||||
sig.ignore_attachments,
|
||||
sig.ignore_stories,
|
||||
)
|
||||
.with_proxy_url(sig.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3799,12 +3844,15 @@ fn collect_configured_channels(
|
||||
if wa.is_cloud_config() {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WhatsApp",
|
||||
channel: Arc::new(WhatsAppChannel::new(
|
||||
wa.access_token.clone().unwrap_or_default(),
|
||||
wa.phone_number_id.clone().unwrap_or_default(),
|
||||
wa.verify_token.clone().unwrap_or_default(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
WhatsAppChannel::new(
|
||||
wa.access_token.clone().unwrap_or_default(),
|
||||
wa.phone_number_id.clone().unwrap_or_default(),
|
||||
wa.verify_token.clone().unwrap_or_default(),
|
||||
wa.allowed_numbers.clone(),
|
||||
)
|
||||
.with_proxy_url(wa.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
} else {
|
||||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||||
@@ -3861,11 +3909,12 @@ fn collect_configured_channels(
|
||||
if let Some(ref wati_cfg) = config.channels_config.wati {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "WATI",
|
||||
channel: Arc::new(WatiChannel::new(
|
||||
channel: Arc::new(WatiChannel::new_with_proxy(
|
||||
wati_cfg.api_token.clone(),
|
||||
wati_cfg.api_url.clone(),
|
||||
wati_cfg.tenant_id.clone(),
|
||||
wati_cfg.allowed_numbers.clone(),
|
||||
wati_cfg.proxy_url.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
@@ -3873,10 +3922,11 @@ fn collect_configured_channels(
|
||||
if let Some(ref nc) = config.channels_config.nextcloud_talk {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Nextcloud Talk",
|
||||
channel: Arc::new(NextcloudTalkChannel::new(
|
||||
channel: Arc::new(NextcloudTalkChannel::new_with_proxy(
|
||||
nc.base_url.clone(),
|
||||
nc.app_token.clone(),
|
||||
nc.allowed_users.clone(),
|
||||
nc.proxy_url.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
@@ -3948,11 +3998,14 @@ fn collect_configured_channels(
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "DingTalk",
|
||||
channel: Arc::new(DingTalkChannel::new(
|
||||
dt.client_id.clone(),
|
||||
dt.client_secret.clone(),
|
||||
dt.allowed_users.clone(),
|
||||
)),
|
||||
channel: Arc::new(
|
||||
DingTalkChannel::new(
|
||||
dt.client_id.clone(),
|
||||
dt.client_secret.clone(),
|
||||
dt.allowed_users.clone(),
|
||||
)
|
||||
.with_proxy_url(dt.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3965,7 +4018,8 @@ fn collect_configured_channels(
|
||||
qq.app_secret.clone(),
|
||||
qq.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(qq.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -4600,6 +4654,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
activated_tools: ch_activated_handle,
|
||||
cost_tracking: crate::cost::CostTracker::get_or_init_global(
|
||||
config.cost.clone(),
|
||||
&config.workspace_dir,
|
||||
)
|
||||
.map(|tracker| ChannelCostTrackingState {
|
||||
tracker,
|
||||
prices: Arc::new(config.cost.prices.clone()),
|
||||
}),
|
||||
pacing: config.pacing.clone(),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4696,6 +4759,49 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_message_timeout_budget_with_custom_scale_cap() {
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 8, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 20, 8),
|
||||
300 * 8
|
||||
);
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 1),
|
||||
300
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_config_defaults_preserve_existing_behavior() {
|
||||
let pacing = crate::config::PacingConfig::default();
|
||||
assert!(pacing.step_timeout_secs.is_none());
|
||||
assert!(pacing.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(pacing.loop_ignore_tools.is_empty());
|
||||
assert!(pacing.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pacing_message_timeout_scale_max_overrides_default_cap() {
|
||||
// Custom cap of 8 scales budget proportionally
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(300, 10, 8),
|
||||
300 * 8
|
||||
);
|
||||
// Default cap produces the standard behavior
|
||||
assert_eq!(
|
||||
channel_message_timeout_budget_secs_with_cap(
|
||||
300,
|
||||
10,
|
||||
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
),
|
||||
300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_window_overflow_error_detector_matches_known_messages() {
|
||||
let overflow_err = anyhow::anyhow!(
|
||||
@@ -4899,6 +5005,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -5014,6 +5122,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -5085,6 +5195,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -5175,6 +5287,8 @@ mod tests {
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(
|
||||
@@ -5715,6 +5829,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5795,6 +5911,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5889,6 +6007,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5968,6 +6088,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6057,6 +6179,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6167,6 +6291,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6258,6 +6384,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6364,6 +6492,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6455,6 +6585,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6536,6 +6668,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6584,6 +6718,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -6636,6 +6772,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(vec![crate::memory::MemoryEntry {
|
||||
id: "entry-1".to_string(),
|
||||
@@ -6728,6 +6866,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -6829,6 +6969,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6944,7 +7086,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7058,6 +7202,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -7153,6 +7299,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7232,6 +7380,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7884,7 +8034,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None, None, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
@@ -7997,6 +8147,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8127,6 +8279,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8297,6 +8451,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8404,6 +8560,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8721,6 +8879,7 @@ This is an example JSON object for profile settings."#;
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let channels = collect_configured_channels(&config, "test");
|
||||
@@ -8974,6 +9133,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -9060,6 +9221,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9221,6 +9384,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9331,6 +9496,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9433,6 +9600,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9555,6 +9724,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -9613,6 +9784,7 @@ This is an example JSON object for profile settings."#;
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
match build_channel_by_id(&config, "telegram") {
|
||||
Ok(channel) => assert_eq!(channel.name(), "telegram"),
|
||||
@@ -9814,6 +9986,8 @@ This is an example JSON object for profile settings."#;
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
cost_tracking: None,
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
|
||||
@@ -17,11 +17,23 @@ pub struct NextcloudTalkChannel {
|
||||
|
||||
impl NextcloudTalkChannel {
|
||||
pub fn new(base_url: String, app_token: String, allowed_users: Vec<String>) -> Self {
|
||||
Self::new_with_proxy(base_url, app_token, allowed_users, None)
|
||||
}
|
||||
|
||||
pub fn new_with_proxy(
|
||||
base_url: String,
|
||||
app_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
proxy_url: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
app_token,
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
client: crate::config::build_channel_proxy_client(
|
||||
"channel.nextcloud_talk",
|
||||
proxy_url.as_deref(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+10
-1
@@ -285,6 +285,8 @@ pub struct QQChannel {
|
||||
upload_cache: Arc<RwLock<HashMap<String, UploadCacheEntry>>>,
|
||||
/// Passive reply tracker for QQ API rate limiting.
|
||||
reply_tracker: Arc<RwLock<HashMap<String, ReplyRecord>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl QQChannel {
|
||||
@@ -298,6 +300,7 @@ impl QQChannel {
|
||||
workspace_dir: None,
|
||||
upload_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
reply_tracker: Arc::new(RwLock::new(HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,8 +310,14 @@ impl QQChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.qq")
|
||||
crate::config::build_channel_proxy_client("channel.qq", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
|
||||
+14
-1
@@ -28,6 +28,8 @@ pub struct SignalChannel {
|
||||
allowed_from: Vec<String>,
|
||||
ignore_attachments: bool,
|
||||
ignore_stories: bool,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
// ── signal-cli SSE event JSON shapes ────────────────────────────
|
||||
@@ -87,12 +89,23 @@ impl SignalChannel {
|
||||
allowed_from,
|
||||
ignore_attachments,
|
||||
ignore_stories,
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
let builder = Client::builder().connect_timeout(Duration::from_secs(10));
|
||||
let builder = crate::config::apply_runtime_proxy_to_builder(builder, "channel.signal");
|
||||
let builder = crate::config::apply_channel_proxy_to_builder(
|
||||
builder,
|
||||
"channel.signal",
|
||||
self.proxy_url.as_deref(),
|
||||
);
|
||||
builder.build().expect("Signal HTTP client should build")
|
||||
}
|
||||
|
||||
|
||||
+78
-2
@@ -32,6 +32,8 @@ pub struct SlackChannel {
|
||||
workspace_dir: Option<PathBuf>,
|
||||
/// Maps channel_id -> thread_ts for active assistant threads (used for status indicators).
|
||||
active_assistant_thread: Mutex<HashMap<String, String>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@@ -46,6 +48,7 @@ const SLACK_ATTACHMENT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
||||
const SLACK_ATTACHMENT_IMAGE_INLINE_FALLBACK_MAX_BYTES: usize = 512 * 1024;
|
||||
const SLACK_ATTACHMENT_TEXT_DOWNLOAD_MAX_BYTES: usize = 256 * 1024;
|
||||
const SLACK_ATTACHMENT_TEXT_INLINE_MAX_CHARS: usize = 12_000;
|
||||
const SLACK_MARKDOWN_BLOCK_MAX_CHARS: usize = 12_000;
|
||||
const SLACK_ATTACHMENT_FILENAME_MAX_CHARS: usize = 128;
|
||||
const SLACK_USER_CACHE_MAX_ENTRIES: usize = 1000;
|
||||
const SLACK_ATTACHMENT_SAVE_SUBDIR: &str = "slack_files";
|
||||
@@ -121,6 +124,7 @@ impl SlackChannel {
|
||||
user_display_name_cache: Mutex::new(HashMap::new()),
|
||||
workspace_dir: None,
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,8 +152,19 @@ impl SlackChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("channel.slack", 30, 10)
|
||||
crate::config::build_channel_proxy_client_with_timeouts(
|
||||
"channel.slack",
|
||||
self.proxy_url.as_deref(),
|
||||
30,
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if a Slack user ID is in the allowlist.
|
||||
@@ -804,12 +819,13 @@ impl SlackChannel {
|
||||
}
|
||||
|
||||
fn slack_media_http_client_no_redirect(&self) -> anyhow::Result<reqwest::Client> {
|
||||
let builder = crate::config::apply_runtime_proxy_to_builder(
|
||||
let builder = crate::config::apply_channel_proxy_to_builder(
|
||||
reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect_timeout(Duration::from_secs(10)),
|
||||
"channel.slack",
|
||||
self.proxy_url.as_deref(),
|
||||
);
|
||||
builder
|
||||
.build()
|
||||
@@ -2272,6 +2288,14 @@ impl Channel for SlackChannel {
|
||||
"text": message.content
|
||||
});
|
||||
|
||||
// Use Slack's native markdown block for rich formatting when content fits.
|
||||
if message.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": message.content
|
||||
}]);
|
||||
}
|
||||
|
||||
if let Some(ts) = self.outbound_thread_ts(message) {
|
||||
body["thread_ts"] = serde_json::json!(ts);
|
||||
}
|
||||
@@ -3630,6 +3654,58 @@ mod tests {
|
||||
assert_ne!(key1, key2, "session key should differ per thread");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_send_uses_markdown_blocks() {
|
||||
let msg = SendMessage::new("**bold** and _italic_", "C123");
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
|
||||
// Build the same JSON body that send() would construct.
|
||||
let mut body = serde_json::json!({
|
||||
"channel": msg.recipient,
|
||||
"text": msg.content
|
||||
});
|
||||
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": msg.content
|
||||
}]);
|
||||
}
|
||||
|
||||
// Verify blocks are present with correct structure.
|
||||
let blocks = body["blocks"]
|
||||
.as_array()
|
||||
.expect("blocks should be an array");
|
||||
assert_eq!(blocks.len(), 1);
|
||||
assert_eq!(blocks[0]["type"], "markdown");
|
||||
assert_eq!(blocks[0]["text"], msg.content);
|
||||
// text field kept as plaintext fallback.
|
||||
assert_eq!(body["text"], msg.content);
|
||||
// Suppress unused variable warning.
|
||||
let _ = ch.name();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_send_skips_markdown_blocks_for_long_content() {
|
||||
let long_content = "x".repeat(SLACK_MARKDOWN_BLOCK_MAX_CHARS + 1);
|
||||
let msg = SendMessage::new(long_content.clone(), "C123");
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"channel": msg.recipient,
|
||||
"text": msg.content
|
||||
});
|
||||
if msg.content.len() <= SLACK_MARKDOWN_BLOCK_MAX_CHARS {
|
||||
body["blocks"] = serde_json::json!([{
|
||||
"type": "markdown",
|
||||
"text": msg.content
|
||||
}]);
|
||||
}
|
||||
|
||||
assert!(
|
||||
body.get("blocks").is_none(),
|
||||
"blocks should not be set for oversized content"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_requires_thread_context() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, None, vec![], vec![]);
|
||||
|
||||
@@ -337,6 +337,8 @@ pub struct TelegramChannel {
|
||||
voice_chats: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
||||
pending_voice:
|
||||
Arc<std::sync::Mutex<std::collections::HashMap<String, (String, std::time::Instant)>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -379,6 +381,7 @@ impl TelegramChannel {
|
||||
tts_config: None,
|
||||
voice_chats: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
pending_voice: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,6 +391,12 @@ impl TelegramChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure workspace directory for saving downloaded attachments.
|
||||
pub fn with_workspace_dir(mut self, dir: std::path::PathBuf) -> Self {
|
||||
self.workspace_dir = Some(dir);
|
||||
@@ -478,7 +487,7 @@ impl TelegramChannel {
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.telegram")
|
||||
crate::config::build_channel_proxy_client("channel.telegram", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
fn normalize_identity(value: &str) -> String {
|
||||
|
||||
+389
-18
@@ -80,17 +80,10 @@ fn resolve_transcription_api_key(config: &TranscriptionConfig) -> Result<String>
|
||||
);
|
||||
}
|
||||
|
||||
/// Validate audio data and resolve MIME type from file name.
|
||||
/// Resolve MIME type and normalize filename from extension.
|
||||
///
|
||||
/// Returns `(normalized_filename, mime_type)` on success.
|
||||
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
|
||||
if audio_data.len() > MAX_AUDIO_BYTES {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
|
||||
audio_data.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// No size check — callers enforce their own limits.
|
||||
fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
|
||||
let normalized_name = normalize_audio_filename(file_name);
|
||||
let extension = normalized_name
|
||||
.rsplit_once('.')
|
||||
@@ -98,13 +91,26 @@ fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'stati
|
||||
.unwrap_or("");
|
||||
let mime = mime_for_audio(extension).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Unsupported audio format '.{extension}' — accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
|
||||
"Unsupported audio format '.{extension}' — \
|
||||
accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok((normalized_name, mime))
|
||||
}
|
||||
|
||||
/// Validate audio data and resolve MIME type from file name.
|
||||
///
|
||||
/// Enforces the 25 MB cloud API cap. Returns `(normalized_filename, mime_type)` on success.
|
||||
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
|
||||
if audio_data.len() > MAX_AUDIO_BYTES {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
|
||||
audio_data.len()
|
||||
);
|
||||
}
|
||||
resolve_audio_format(file_name)
|
||||
}
|
||||
|
||||
// ── TranscriptionProvider trait ─────────────────────────────────
|
||||
|
||||
/// Trait for speech-to-text provider implementations.
|
||||
@@ -586,21 +592,120 @@ impl TranscriptionProvider for GoogleSttProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider ────────────────────────────────────────
|
||||
|
||||
/// Self-hosted faster-whisper-compatible STT provider.
|
||||
///
|
||||
/// POSTs audio as `multipart/form-data` (field name `file`) to a configurable
|
||||
/// HTTP endpoint (e.g. faster-whisper on GEX44 over WireGuard). The endpoint
|
||||
/// must return `{"text": "..."}`. No cloud API key required. Size limit is
|
||||
/// configurable — not constrained by the 25 MB cloud API cap.
|
||||
pub struct LocalWhisperProvider {
|
||||
url: String,
|
||||
bearer_token: String,
|
||||
max_audio_bytes: usize,
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl LocalWhisperProvider {
|
||||
/// Build from config. Fails if `url` or `bearer_token` is empty, if `url`
|
||||
/// is not a valid HTTP/HTTPS URL (scheme must be `http` or `https`), if
|
||||
/// `max_audio_bytes` is zero, or if `timeout_secs` is zero.
|
||||
pub fn from_config(config: &crate::config::LocalWhisperConfig) -> Result<Self> {
|
||||
let url = config.url.trim().to_string();
|
||||
anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
|
||||
let parsed = url
|
||||
.parse::<reqwest::Url>()
|
||||
.with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
|
||||
anyhow::ensure!(
|
||||
matches!(parsed.scheme(), "http" | "https"),
|
||||
"local_whisper: `url` must use http or https scheme, got {:?}",
|
||||
parsed.scheme()
|
||||
);
|
||||
|
||||
let bearer_token = config.bearer_token.trim().to_string();
|
||||
anyhow::ensure!(
|
||||
!bearer_token.is_empty(),
|
||||
"local_whisper: `bearer_token` must not be empty"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
config.max_audio_bytes > 0,
|
||||
"local_whisper: `max_audio_bytes` must be greater than zero"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
config.timeout_secs > 0,
|
||||
"local_whisper: `timeout_secs` must be greater than zero"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
url,
|
||||
bearer_token,
|
||||
max_audio_bytes: config.max_audio_bytes,
|
||||
timeout_secs: config.timeout_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for LocalWhisperProvider {
|
||||
fn name(&self) -> &str {
|
||||
"local_whisper"
|
||||
}
|
||||
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
if audio_data.len() > self.max_audio_bytes {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, local_whisper max {})",
|
||||
audio_data.len(),
|
||||
self.max_audio_bytes
|
||||
);
|
||||
}
|
||||
|
||||
let (normalized_name, mime) = resolve_audio_format(file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.local_whisper");
|
||||
|
||||
// to_vec() clones the buffer for the multipart payload; peak memory per
|
||||
// call is ~2× max_audio_bytes. TODO: replace with streaming upload once
|
||||
// reqwest supports body streaming in multipart parts.
|
||||
let file_part = Part::bytes(audio_data.to_vec())
|
||||
.file_name(normalized_name)
|
||||
.mime_str(mime)?;
|
||||
|
||||
let resp = client
|
||||
.post(&self.url)
|
||||
.bearer_auth(&self.bearer_token)
|
||||
.multipart(Form::new().part("file", file_part))
|
||||
.timeout(std::time::Duration::from_secs(self.timeout_secs))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send audio to local Whisper endpoint")?;
|
||||
|
||||
parse_whisper_response(resp).await
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared response parsing ─────────────────────────────────────
|
||||
|
||||
/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`).
|
||||
/// Parse a faster-whisper-compatible JSON response (`{ "text": "..." }`).
|
||||
///
|
||||
/// Checks HTTP status before attempting JSON parsing so that non-JSON error
|
||||
/// bodies (plain text, HTML, empty 5xx) produce a readable status error
|
||||
/// rather than a confusing "Failed to parse transcription response".
|
||||
async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
bail!("Transcription API error ({}): {}", status, body.trim());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse transcription response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
|
||||
bail!("Transcription API error ({}): {}", status, error_msg);
|
||||
}
|
||||
|
||||
let text = body["text"]
|
||||
.as_str()
|
||||
.context("Transcription response missing 'text' field")?
|
||||
@@ -657,6 +762,17 @@ impl TranscriptionManager {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref local_cfg) = config.local_whisper {
|
||||
match LocalWhisperProvider::from_config(local_cfg) {
|
||||
Ok(p) => {
|
||||
providers.insert("local_whisper".to_string(), Box::new(p));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("local_whisper config invalid, provider skipped: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let default_provider = config.default_provider.clone();
|
||||
|
||||
if config.enabled && !providers.contains_key(&default_provider) {
|
||||
@@ -1036,5 +1152,260 @@ mod tests {
|
||||
assert!(config.deepgram.is_none());
|
||||
assert!(config.assemblyai.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.local_whisper.is_none());
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ──
|
||||
|
||||
fn local_whisper_config(url: &str) -> crate::config::LocalWhisperConfig {
|
||||
crate::config::LocalWhisperConfig {
|
||||
url: url.to_string(),
|
||||
bearer_token: "test-token".to_string(),
|
||||
max_audio_bytes: 10 * 1024 * 1024,
|
||||
timeout_secs: 30,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_empty_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = String::new();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("`url` must not be empty"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_invalid_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = "not-a-url".to_string();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(err.to_string().contains("invalid `url`"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_non_http_url() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(err.to_string().contains("http or https"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_empty_bearer_token() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.bearer_token = String::new();
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("`bearer_token` must not be empty"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_zero_max_audio_bytes() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.max_audio_bytes = 0;
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("`max_audio_bytes` must be greater than zero"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_rejects_zero_timeout() {
|
||||
let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
cfg.timeout_secs = 0;
|
||||
let err = LocalWhisperProvider::from_config(&cfg).err().unwrap();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("`timeout_secs` must be greater than zero"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_registered_when_config_present() {
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.local_whisper = Some(local_whisper_config("http://127.0.0.1:9999/v1/transcribe"));
|
||||
config.default_provider = "local_whisper".to_string();
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert!(
|
||||
manager.available_providers().contains(&"local_whisper"),
|
||||
"expected local_whisper in {:?}",
|
||||
manager.available_providers()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_whisper_misconfigured_section_fails_manager_construction() {
|
||||
// A misconfigured local_whisper section logs a warning and skips
|
||||
// registration. When local_whisper is also the default_provider and
|
||||
// transcription is enabled, the safety net in TranscriptionManager
|
||||
// surfaces the error: "not configured".
|
||||
let mut config = TranscriptionConfig::default();
|
||||
let mut bad_cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
bad_cfg.bearer_token = String::new();
|
||||
config.local_whisper = Some(bad_cfg);
|
||||
config.enabled = true;
|
||||
config.default_provider = "local_whisper".to_string();
|
||||
|
||||
let err = TranscriptionManager::new(&config).err().unwrap();
|
||||
assert!(
|
||||
err.to_string().contains("not configured"),
|
||||
"expected 'not configured' from manager safety net, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_still_enforces_25mb_cap() {
|
||||
// Regression: extracting resolve_audio_format() must not weaken validate_audio().
|
||||
let at_limit = vec![0u8; MAX_AUDIO_BYTES];
|
||||
assert!(validate_audio(&at_limit, "test.ogg").is_ok());
|
||||
let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
|
||||
let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
|
||||
assert!(err.to_string().contains("too large"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_rejects_oversized_audio() {
|
||||
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
let big = vec![0u8; cfg.max_audio_bytes + 1];
|
||||
let err = provider.transcribe(&big, "voice.ogg").await.unwrap_err();
|
||||
assert!(err.to_string().contains("too large"), "got: {err}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_rejects_unsupported_format() {
|
||||
let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
let data = vec![0u8; 100];
|
||||
let err = provider.transcribe(&data, "voice.aiff").await.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("Unsupported audio format"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── LocalWhisperProvider HTTP mock tests ────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_returns_text_from_response() {
|
||||
use wiremock::matchers::{header_exists, method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.and(header_exists("authorization"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.set_body_json(serde_json::json!({"text": "hello world"})),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let result = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_sends_bearer_auth_header() {
|
||||
use wiremock::matchers::{header, method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.and(header("authorization", "Bearer test-token"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let result = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "auth ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_propagates_http_error() {
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(503).set_body_json(
|
||||
serde_json::json!({"error": {"message": "service unavailable"}}),
|
||||
),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let err = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("503") || err.to_string().contains("service unavailable"),
|
||||
"expected HTTP error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_whisper_propagates_non_json_http_error() {
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/transcribe"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(502)
|
||||
.set_body_string("Bad Gateway")
|
||||
.insert_header("content-type", "text/plain"),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
|
||||
let provider = LocalWhisperProvider::from_config(&cfg).unwrap();
|
||||
|
||||
let err = provider
|
||||
.transcribe(b"fake-audio", "voice.ogg")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("502"), "got: {err}");
|
||||
assert!(
|
||||
err.to_string().contains("Bad Gateway"),
|
||||
"expected plain-text body in error, got: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+11
-1
@@ -22,13 +22,23 @@ impl WatiChannel {
|
||||
api_url: String,
|
||||
tenant_id: Option<String>,
|
||||
allowed_numbers: Vec<String>,
|
||||
) -> Self {
|
||||
Self::new_with_proxy(api_token, api_url, tenant_id, allowed_numbers, None)
|
||||
}
|
||||
|
||||
pub fn new_with_proxy(
|
||||
api_token: String,
|
||||
api_url: String,
|
||||
tenant_id: Option<String>,
|
||||
allowed_numbers: Vec<String>,
|
||||
proxy_url: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_token,
|
||||
api_url,
|
||||
tenant_id,
|
||||
allowed_numbers,
|
||||
client: crate::config::build_runtime_proxy_client("channel.wati"),
|
||||
client: crate::config::build_channel_proxy_client("channel.wati", proxy_url.as_deref()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ pub struct WhatsAppChannel {
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl WhatsAppChannel {
|
||||
@@ -41,11 +43,18 @@ impl WhatsAppChannel {
|
||||
endpoint_id,
|
||||
verify_token,
|
||||
allowed_numbers,
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a per-channel proxy URL that overrides the global proxy config.
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_runtime_proxy_client("channel.whatsapp")
|
||||
crate::config::build_channel_proxy_client("channel.whatsapp", self.proxy_url.as_deref())
|
||||
}
|
||||
|
||||
/// Check if a phone number is allowed (E.164 format: +1234567890)
|
||||
|
||||
@@ -249,7 +249,7 @@ async fn check_memory_roundtrip(config: &crate::config::Config) -> CheckResult {
|
||||
return CheckResult::fail("memory", format!("write failed: {e}"));
|
||||
}
|
||||
|
||||
match mem.recall(test_key, 1, None).await {
|
||||
match mem.recall(test_key, 1, None, None, None).await {
|
||||
Ok(entries) if !entries.is_empty() => {
|
||||
let _ = mem.forget(test_key).await;
|
||||
CheckResult::pass("memory", "write/read/delete round-trip OK")
|
||||
|
||||
+28
-22
@@ -4,31 +4,32 @@ pub mod workspace;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
apply_channel_proxy_to_builder, apply_runtime_proxy_to_builder, build_channel_proxy_client,
|
||||
build_channel_proxy_client_with_timeouts, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, CloudOpsConfig, ComposioConfig, Config, ConversationalAiConfig, CostConfig,
|
||||
CronConfig, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig,
|
||||
DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig,
|
||||
EstopConfig, FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig,
|
||||
GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig, HardwareConfig, HardwareTransport,
|
||||
HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig,
|
||||
ImageProviderDalleConfig, ImageProviderFluxConfig, ImageProviderImagenConfig,
|
||||
ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig, LarkConfig, LinkedInConfig,
|
||||
LinkedInContentConfig, LinkedInImageConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
|
||||
WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
|
||||
DEFAULT_GWS_SERVICES,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
|
||||
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
|
||||
ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
|
||||
LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
|
||||
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
|
||||
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
|
||||
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, SopConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
|
||||
TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig,
|
||||
TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig,
|
||||
WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
@@ -58,6 +59,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let discord = DiscordConfig {
|
||||
@@ -67,6 +69,7 @@ mod tests {
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let lark = LarkConfig {
|
||||
@@ -79,6 +82,7 @@ mod tests {
|
||||
use_feishu: false,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let feishu = FeishuConfig {
|
||||
app_id: "app-id".into(),
|
||||
@@ -88,6 +92,7 @@ mod tests {
|
||||
allowed_users: vec![],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let nextcloud_talk = NextcloudTalkConfig {
|
||||
@@ -95,6 +100,7 @@ mod tests {
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: None,
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
assert_eq!(telegram.allowed_users.len(), 1);
|
||||
|
||||
+508
-6
@@ -165,6 +165,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agent: AgentConfig,
|
||||
|
||||
/// Pacing controls for slow/local LLM workloads (`[pacing]`).
|
||||
#[serde(default)]
|
||||
pub pacing: PacingConfig,
|
||||
|
||||
/// Skills loading and community repository behavior (`[skills]`).
|
||||
#[serde(default)]
|
||||
pub skills: SkillsConfig,
|
||||
@@ -371,6 +375,14 @@ pub struct Config {
|
||||
/// Verifiable Intent (VI) credential verification and issuance (`[verifiable_intent]`).
|
||||
#[serde(default)]
|
||||
pub verifiable_intent: VerifiableIntentConfig,
|
||||
|
||||
/// Claude Code tool configuration (`[claude_code]`).
|
||||
#[serde(default)]
|
||||
pub claude_code: ClaudeCodeConfig,
|
||||
|
||||
/// Standard Operating Procedures engine configuration (`[sop]`).
|
||||
#[serde(default)]
|
||||
pub sop: SopConfig,
|
||||
}
|
||||
|
||||
/// Multi-client workspace isolation configuration.
|
||||
@@ -515,6 +527,10 @@ pub struct DelegateAgentConfig {
|
||||
/// When `None`, falls back to `[delegate].agentic_timeout_secs` (default: 300).
|
||||
#[serde(default)]
|
||||
pub agentic_timeout_secs: Option<u64>,
|
||||
/// Optional skills directory path (relative to workspace root) for scoped skill loading.
|
||||
/// When unset or empty, the sub-agent falls back to the default workspace `skills/` directory.
|
||||
#[serde(default)]
|
||||
pub skills_directory: Option<String>,
|
||||
}
|
||||
|
||||
fn default_delegate_timeout_secs() -> u64 {
|
||||
@@ -784,6 +800,9 @@ pub struct TranscriptionConfig {
|
||||
/// Google Cloud Speech-to-Text provider configuration.
|
||||
#[serde(default)]
|
||||
pub google: Option<GoogleSttConfig>,
|
||||
/// Local/self-hosted Whisper-compatible STT provider.
|
||||
#[serde(default)]
|
||||
pub local_whisper: Option<LocalWhisperConfig>,
|
||||
}
|
||||
|
||||
impl Default for TranscriptionConfig {
|
||||
@@ -801,6 +820,7 @@ impl Default for TranscriptionConfig {
|
||||
deepgram: None,
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
local_whisper: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1169,6 +1189,35 @@ pub struct GoogleSttConfig {
|
||||
pub language_code: String,
|
||||
}
|
||||
|
||||
/// Local/self-hosted Whisper-compatible STT endpoint (`[transcription.local_whisper]`).
|
||||
///
|
||||
/// Audio is sent over WireGuard; never leaves the platform perimeter.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct LocalWhisperConfig {
|
||||
/// HTTP or HTTPS endpoint URL, e.g. `"http://10.10.0.1:8001/v1/transcribe"`.
|
||||
pub url: String,
|
||||
/// Bearer token for endpoint authentication.
|
||||
pub bearer_token: String,
|
||||
/// Maximum audio file size in bytes accepted by this endpoint.
|
||||
/// Defaults to 25 MB — matching the cloud API cap for a safe out-of-the-box
|
||||
/// experience. Self-hosted endpoints can accept much larger files; raise this
|
||||
/// as needed, but note that each transcription call clones the audio buffer
|
||||
/// into a multipart payload, so peak memory per request is ~2× this value.
|
||||
#[serde(default = "default_local_whisper_max_audio_bytes")]
|
||||
pub max_audio_bytes: usize,
|
||||
/// Request timeout in seconds. Defaults to 300 (large files on local GPU).
|
||||
#[serde(default = "default_local_whisper_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_local_whisper_max_audio_bytes() -> usize {
|
||||
25 * 1024 * 1024
|
||||
}
|
||||
|
||||
fn default_local_whisper_timeout_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
/// Agent orchestration configuration (`[agent]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentConfig {
|
||||
@@ -1203,6 +1252,12 @@ pub struct AgentConfig {
|
||||
/// Default: `[]` (no filtering — all tools included).
|
||||
#[serde(default)]
|
||||
pub tool_filter_groups: Vec<ToolFilterGroup>,
|
||||
/// Maximum characters for the assembled system prompt. When `> 0`, the prompt
|
||||
/// is truncated to this limit after assembly (keeping the top portion which
|
||||
/// contains identity and safety instructions). `0` means unlimited.
|
||||
/// Useful for small-context models (e.g. glm-4.5-air ~8K tokens → set to 8000).
|
||||
#[serde(default = "default_max_system_prompt_chars")]
|
||||
pub max_system_prompt_chars: usize,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@@ -1221,6 +1276,10 @@ fn default_agent_tool_dispatcher() -> String {
|
||||
"auto".into()
|
||||
}
|
||||
|
||||
fn default_max_system_prompt_chars() -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -1232,10 +1291,48 @@ impl Default for AgentConfig {
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
tool_filter_groups: Vec::new(),
|
||||
max_system_prompt_chars: default_max_system_prompt_chars(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pacing ────────────────────────────────────────────────────────
|
||||
|
||||
/// Pacing controls for slow/local LLM workloads (`[pacing]` section).
|
||||
///
|
||||
/// All fields are optional and default to values that preserve existing
|
||||
/// behavior. When set, they extend — not replace — the existing timeout
|
||||
/// and loop-detection subsystems.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PacingConfig {
|
||||
/// Per-step timeout in seconds: the maximum time allowed for a single
|
||||
/// LLM inference turn, independent of the total message budget.
|
||||
/// `None` means no per-step timeout (existing behavior).
|
||||
#[serde(default)]
|
||||
pub step_timeout_secs: Option<u64>,
|
||||
|
||||
/// Minimum elapsed seconds before loop detection activates.
|
||||
/// Tasks completing under this threshold get aggressive loop protection;
|
||||
/// longer-running tasks receive a grace period before the detector starts
|
||||
/// counting. `None` means loop detection is always active (existing behavior).
|
||||
#[serde(default)]
|
||||
pub loop_detection_min_elapsed_secs: Option<u64>,
|
||||
|
||||
/// Tool names excluded from identical-output / alternating-pattern loop
|
||||
/// detection. Useful for browser workflows where `browser_screenshot`
|
||||
/// structurally resembles a loop even when making progress.
|
||||
#[serde(default)]
|
||||
pub loop_ignore_tools: Vec<String>,
|
||||
|
||||
/// Override for the hardcoded timeout scaling cap (default: 4).
|
||||
/// The channel message timeout budget is computed as:
|
||||
/// `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)`
|
||||
/// Raising this value lets long multi-step tasks with slow local models
|
||||
/// receive a proportionally larger budget without inflating the base timeout.
|
||||
#[serde(default)]
|
||||
pub message_timeout_scale_max: Option<u64>,
|
||||
}
|
||||
|
||||
/// Skills loading configuration (`[skills]` section).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -1611,6 +1708,12 @@ pub struct GatewayConfig {
|
||||
#[serde(default)]
|
||||
pub trust_forwarded_headers: bool,
|
||||
|
||||
/// Optional URL path prefix for reverse-proxy deployments.
|
||||
/// When set, all gateway routes are served under this prefix.
|
||||
/// Must start with `/` and must not end with `/`.
|
||||
#[serde(default)]
|
||||
pub path_prefix: Option<String>,
|
||||
|
||||
/// Maximum distinct client keys tracked by gateway rate limiter maps.
|
||||
#[serde(default = "default_gateway_rate_limit_max_keys")]
|
||||
pub rate_limit_max_keys: usize,
|
||||
@@ -1683,6 +1786,7 @@ impl Default for GatewayConfig {
|
||||
pair_rate_limit_per_minute: default_pair_rate_limit(),
|
||||
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
|
||||
trust_forwarded_headers: false,
|
||||
path_prefix: None,
|
||||
rate_limit_max_keys: default_gateway_rate_limit_max_keys(),
|
||||
idempotency_ttl_secs: default_idempotency_ttl_secs(),
|
||||
idempotency_max_keys: default_gateway_idempotency_max_keys(),
|
||||
@@ -2874,6 +2978,60 @@ impl Default for ImageProviderFluxConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude Code ─────────────────────────────────────────────────
|
||||
|
||||
/// Claude Code CLI tool configuration (`[claude_code]` section).
|
||||
///
|
||||
/// Delegates coding tasks to the `claude -p` CLI. Authentication uses the
|
||||
/// binary's own OAuth session (Max subscription) by default — no API key
|
||||
/// needed unless `env_passthrough` includes `ANTHROPIC_API_KEY`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ClaudeCodeConfig {
|
||||
/// Enable the `claude_code` tool
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Maximum execution time in seconds (coding tasks can be long)
|
||||
#[serde(default = "default_claude_code_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
/// Claude Code tools the subprocess is allowed to use
|
||||
#[serde(default = "default_claude_code_allowed_tools")]
|
||||
pub allowed_tools: Vec<String>,
|
||||
/// Optional system prompt appended to Claude Code invocations
|
||||
#[serde(default)]
|
||||
pub system_prompt: Option<String>,
|
||||
/// Maximum output size in bytes (2MB default)
|
||||
#[serde(default = "default_claude_code_max_output_bytes")]
|
||||
pub max_output_bytes: usize,
|
||||
/// Extra env vars passed to the claude subprocess (e.g. ANTHROPIC_API_KEY for API-key billing)
|
||||
#[serde(default)]
|
||||
pub env_passthrough: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_claude_code_timeout_secs() -> u64 {
|
||||
600
|
||||
}
|
||||
|
||||
fn default_claude_code_allowed_tools() -> Vec<String> {
|
||||
vec!["Read".into(), "Edit".into(), "Bash".into(), "Write".into()]
|
||||
}
|
||||
|
||||
fn default_claude_code_max_output_bytes() -> usize {
|
||||
2_097_152
|
||||
}
|
||||
|
||||
impl Default for ClaudeCodeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
timeout_secs: default_claude_code_timeout_secs(),
|
||||
allowed_tools: default_claude_code_allowed_tools(),
|
||||
system_prompt: None,
|
||||
max_output_bytes: default_claude_code_max_output_bytes(),
|
||||
env_passthrough: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Proxy ───────────────────────────────────────────────────────
|
||||
|
||||
/// Proxy application scope — determines which outbound traffic uses the proxy.
|
||||
@@ -3381,6 +3539,116 @@ pub fn build_runtime_proxy_client_with_timeouts(
|
||||
client
|
||||
}
|
||||
|
||||
/// Build an HTTP client for a channel, using an explicit per-channel proxy URL
|
||||
/// when configured. Falls back to the global runtime proxy when `proxy_url` is
|
||||
/// `None` or empty.
|
||||
pub fn build_channel_proxy_client(service_key: &str, proxy_url: Option<&str>) -> reqwest::Client {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => build_explicit_proxy_client(service_key, &url, None, None),
|
||||
None => build_runtime_proxy_client(service_key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an HTTP client for a channel with custom timeouts, using an explicit
|
||||
/// per-channel proxy URL when configured. Falls back to the global runtime
|
||||
/// proxy when `proxy_url` is `None` or empty.
|
||||
pub fn build_channel_proxy_client_with_timeouts(
|
||||
service_key: &str,
|
||||
proxy_url: Option<&str>,
|
||||
timeout_secs: u64,
|
||||
connect_timeout_secs: u64,
|
||||
) -> reqwest::Client {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => build_explicit_proxy_client(
|
||||
service_key,
|
||||
&url,
|
||||
Some(timeout_secs),
|
||||
Some(connect_timeout_secs),
|
||||
),
|
||||
None => build_runtime_proxy_client_with_timeouts(
|
||||
service_key,
|
||||
timeout_secs,
|
||||
connect_timeout_secs,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply an explicit proxy URL to a `reqwest::ClientBuilder`, returning the
|
||||
/// modified builder. Used by channels that specify a per-channel `proxy_url`.
|
||||
pub fn apply_channel_proxy_to_builder(
|
||||
builder: reqwest::ClientBuilder,
|
||||
service_key: &str,
|
||||
proxy_url: Option<&str>,
|
||||
) -> reqwest::ClientBuilder {
|
||||
match normalize_proxy_url_option(proxy_url) {
|
||||
Some(url) => apply_explicit_proxy_to_builder(builder, service_key, &url),
|
||||
None => apply_runtime_proxy_to_builder(builder, service_key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a client with a single explicit proxy URL (http+https via `Proxy::all`).
|
||||
fn build_explicit_proxy_client(
|
||||
service_key: &str,
|
||||
proxy_url: &str,
|
||||
timeout_secs: Option<u64>,
|
||||
connect_timeout_secs: Option<u64>,
|
||||
) -> reqwest::Client {
|
||||
let cache_key = format!(
|
||||
"explicit|{}|{}|timeout={}|connect_timeout={}",
|
||||
service_key.trim().to_ascii_lowercase(),
|
||||
proxy_url,
|
||||
timeout_secs
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
connect_timeout_secs
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "none".to_string()),
|
||||
);
|
||||
if let Some(client) = runtime_proxy_cached_client(&cache_key) {
|
||||
return client;
|
||||
}
|
||||
|
||||
let mut builder = reqwest::Client::builder();
|
||||
if let Some(t) = timeout_secs {
|
||||
builder = builder.timeout(std::time::Duration::from_secs(t));
|
||||
}
|
||||
if let Some(ct) = connect_timeout_secs {
|
||||
builder = builder.connect_timeout(std::time::Duration::from_secs(ct));
|
||||
}
|
||||
builder = apply_explicit_proxy_to_builder(builder, service_key, proxy_url);
|
||||
let client = builder.build().unwrap_or_else(|error| {
|
||||
tracing::warn!(
|
||||
service_key,
|
||||
proxy_url,
|
||||
"Failed to build channel proxy client: {error}"
|
||||
);
|
||||
reqwest::Client::new()
|
||||
});
|
||||
set_runtime_proxy_cached_client(cache_key, client.clone());
|
||||
client
|
||||
}
|
||||
|
||||
/// Apply a single explicit proxy URL to a builder via `Proxy::all`.
|
||||
fn apply_explicit_proxy_to_builder(
|
||||
mut builder: reqwest::ClientBuilder,
|
||||
service_key: &str,
|
||||
proxy_url: &str,
|
||||
) -> reqwest::ClientBuilder {
|
||||
match reqwest::Proxy::all(proxy_url) {
|
||||
Ok(proxy) => {
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
proxy_url,
|
||||
service_key,
|
||||
"Ignoring invalid channel proxy_url: {error}"
|
||||
);
|
||||
}
|
||||
}
|
||||
builder
|
||||
}
|
||||
|
||||
fn parse_proxy_scope(raw: &str) -> Option<ProxyScope> {
|
||||
match raw.trim().to_ascii_lowercase().as_str() {
|
||||
"environment" | "env" => Some(ProxyScope::Environment),
|
||||
@@ -4885,6 +5153,10 @@ pub struct TelegramConfig {
|
||||
/// explicitly, it takes precedence.
|
||||
#[serde(default)]
|
||||
pub ack_reactions: Option<bool>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for TelegramConfig {
|
||||
@@ -4918,6 +5190,10 @@ pub struct DiscordConfig {
|
||||
/// Other messages in the guild are silently ignored.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DiscordConfig {
|
||||
@@ -4954,6 +5230,10 @@ pub struct SlackConfig {
|
||||
/// Direct messages remain allowed.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SlackConfig {
|
||||
@@ -4989,6 +5269,10 @@ pub struct MattermostConfig {
|
||||
/// cancels the in-flight request and starts a fresh response with preserved history.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for MattermostConfig {
|
||||
@@ -5101,6 +5385,10 @@ pub struct SignalConfig {
|
||||
/// Skip incoming story messages.
|
||||
#[serde(default)]
|
||||
pub ignore_stories: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SignalConfig {
|
||||
@@ -5195,6 +5483,10 @@ pub struct WhatsAppConfig {
|
||||
/// user's own self-chat (Notes to Self). Defaults to false.
|
||||
#[serde(default)]
|
||||
pub self_chat_mode: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for WhatsAppConfig {
|
||||
@@ -5243,6 +5535,10 @@ pub struct WatiConfig {
|
||||
/// Allowed phone numbers (E.164 format) or "*" for all.
|
||||
#[serde(default)]
|
||||
pub allowed_numbers: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
fn default_wati_api_url() -> String {
|
||||
@@ -5273,6 +5569,10 @@ pub struct NextcloudTalkConfig {
|
||||
/// Allowed Nextcloud actor IDs (`[]` = deny all, `"*"` = allow all).
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for NextcloudTalkConfig {
|
||||
@@ -5400,6 +5700,10 @@ pub struct LarkConfig {
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for LarkConfig {
|
||||
@@ -5434,6 +5738,10 @@ pub struct FeishuConfig {
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for FeishuConfig {
|
||||
@@ -5895,6 +6203,10 @@ pub struct DingTalkConfig {
|
||||
/// Allowed user IDs (staff IDs). Empty = deny all, "*" = allow all
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DingTalkConfig {
|
||||
@@ -5935,6 +6247,10 @@ pub struct QQConfig {
|
||||
/// Allowed user IDs. Empty = deny all, "*" = allow all
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
/// Overrides the global `[proxy]` setting for this channel only.
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for QQConfig {
|
||||
@@ -6467,6 +6783,7 @@ impl Default for Config {
|
||||
reliability: ReliabilityConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
skills: SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -6512,6 +6829,8 @@ impl Default for Config {
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7571,6 +7890,31 @@ impl Config {
|
||||
if self.gateway.host.trim().is_empty() {
|
||||
anyhow::bail!("gateway.host must not be empty");
|
||||
}
|
||||
if let Some(ref prefix) = self.gateway.path_prefix {
|
||||
// Validate the raw value — no silent trimming so the stored
|
||||
// value is exactly what was validated.
|
||||
if !prefix.is_empty() {
|
||||
if !prefix.starts_with('/') {
|
||||
anyhow::bail!("gateway.path_prefix must start with '/'");
|
||||
}
|
||||
if prefix.ends_with('/') {
|
||||
anyhow::bail!("gateway.path_prefix must not end with '/' (including bare '/')");
|
||||
}
|
||||
// Reject characters unsafe for URL paths or HTML/JS injection.
|
||||
// Whitespace is intentionally excluded from the allowed set.
|
||||
if let Some(bad) = prefix.chars().find(|c| {
|
||||
!matches!(c, '/' | '-' | '_' | '.' | '~'
|
||||
| 'a'..='z' | 'A'..='Z' | '0'..='9'
|
||||
| '!' | '$' | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '='
|
||||
| ':' | '@')
|
||||
}) {
|
||||
anyhow::bail!(
|
||||
"gateway.path_prefix contains invalid character '{bad}'; \
|
||||
only unreserved and sub-delim URI characters are allowed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Autonomy
|
||||
if self.autonomy.max_actions_per_hour == 0 {
|
||||
@@ -8081,10 +8425,10 @@ impl Config {
|
||||
{
|
||||
let dp = self.transcription.default_provider.trim();
|
||||
match dp {
|
||||
"groq" | "openai" | "deepgram" | "assemblyai" | "google" => {}
|
||||
"groq" | "openai" | "deepgram" | "assemblyai" | "google" | "local_whisper" => {}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google (got '{other}')"
|
||||
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google, local_whisper (got '{other}')"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -8939,6 +9283,70 @@ async fn sync_directory(path: &Path) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// ── SOP engine configuration ───────────────────────────────────
|
||||
|
||||
/// Standard Operating Procedures engine configuration (`[sop]`).
|
||||
///
|
||||
/// The `default_execution_mode` field uses the `SopExecutionMode` type from
|
||||
/// `sop::types` (re-exported via `sop::SopExecutionMode`). To avoid circular
|
||||
/// module references, config stores it using the same enum definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SopConfig {
|
||||
/// Directory containing SOP definitions (subdirs with SOP.toml + SOP.md).
|
||||
/// Falls back to `<workspace>/sops` when omitted.
|
||||
#[serde(default)]
|
||||
pub sops_dir: Option<String>,
|
||||
|
||||
/// Default execution mode for SOPs that omit `execution_mode`.
|
||||
/// Values: `auto`, `supervised` (default), `step_by_step`,
|
||||
/// `priority_based`, `deterministic`.
|
||||
#[serde(default = "default_sop_execution_mode")]
|
||||
pub default_execution_mode: String,
|
||||
|
||||
/// Maximum total concurrent SOP runs across all SOPs.
|
||||
#[serde(default = "default_sop_max_concurrent_total")]
|
||||
pub max_concurrent_total: usize,
|
||||
|
||||
/// Approval timeout in seconds. When a run waits for approval longer than
|
||||
/// this, Critical/High-priority SOPs auto-approve; others stay waiting.
|
||||
/// Set to 0 to disable timeout.
|
||||
#[serde(default = "default_sop_approval_timeout_secs")]
|
||||
pub approval_timeout_secs: u64,
|
||||
|
||||
/// Maximum number of finished runs kept in memory for status queries.
|
||||
/// Oldest runs are evicted when over capacity. 0 = unlimited.
|
||||
#[serde(default = "default_sop_max_finished_runs")]
|
||||
pub max_finished_runs: usize,
|
||||
}
|
||||
|
||||
fn default_sop_execution_mode() -> String {
|
||||
"supervised".to_string()
|
||||
}
|
||||
|
||||
fn default_sop_max_concurrent_total() -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
fn default_sop_approval_timeout_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
fn default_sop_max_finished_runs() -> usize {
|
||||
100
|
||||
}
|
||||
|
||||
impl Default for SopConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sops_dir: None,
|
||||
default_execution_mode: default_sop_execution_mode(),
|
||||
max_concurrent_total: default_sop_max_concurrent_total(),
|
||||
approval_timeout_secs: default_sop_approval_timeout_secs(),
|
||||
max_finished_runs: default_sop_max_finished_runs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -9335,6 +9743,7 @@ default_temperature = 0.7
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
}),
|
||||
discord: None,
|
||||
slack: None,
|
||||
@@ -9386,6 +9795,7 @@ default_temperature = 0.7
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
@@ -9407,6 +9817,8 @@ default_temperature = 0.7
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
@@ -9656,6 +10068,47 @@ tool_dispatcher = "xml"
|
||||
assert_eq!(parsed.agent.tool_dispatcher, "xml");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_defaults_are_all_none_or_empty() {
|
||||
let cfg = PacingConfig::default();
|
||||
assert!(cfg.step_timeout_secs.is_none());
|
||||
assert!(cfg.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(cfg.loop_ignore_tools.is_empty());
|
||||
assert!(cfg.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_deserializes_from_toml() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
[pacing]
|
||||
step_timeout_secs = 120
|
||||
loop_detection_min_elapsed_secs = 60
|
||||
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
|
||||
message_timeout_scale_max = 8
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(raw).unwrap();
|
||||
assert_eq!(parsed.pacing.step_timeout_secs, Some(120));
|
||||
assert_eq!(parsed.pacing.loop_detection_min_elapsed_secs, Some(60));
|
||||
assert_eq!(
|
||||
parsed.pacing.loop_ignore_tools,
|
||||
vec!["browser_screenshot", "browser_navigate"]
|
||||
);
|
||||
assert_eq!(parsed.pacing.message_timeout_scale_max, Some(8));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn pacing_config_absent_preserves_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(raw).unwrap();
|
||||
assert!(parsed.pacing.step_timeout_secs.is_none());
|
||||
assert!(parsed.pacing.loop_detection_min_elapsed_secs.is_none());
|
||||
assert!(parsed.pacing.loop_ignore_tools.is_empty());
|
||||
assert!(parsed.pacing.message_timeout_scale_max.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sync_directory_handles_existing_directory() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
@@ -9724,6 +10177,7 @@ tool_dispatcher = "xml"
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
pacing: PacingConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
@@ -9745,6 +10199,8 @@ tool_dispatcher = "xml"
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
claude_code: ClaudeCodeConfig::default(),
|
||||
sop: SopConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await.unwrap();
|
||||
@@ -9789,6 +10245,7 @@ tool_dispatcher = "xml"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
config.agents.insert(
|
||||
@@ -9805,6 +10262,7 @@ tool_dispatcher = "xml"
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -9930,6 +10388,7 @@ tool_dispatcher = "xml"
|
||||
interrupt_on_new_message: true,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&tc).unwrap();
|
||||
let parsed: TelegramConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9958,6 +10417,7 @@ tool_dispatcher = "xml"
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9974,6 +10434,7 @@ tool_dispatcher = "xml"
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10075,6 +10536,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
allowed_from: vec!["+1111111111".into()],
|
||||
ignore_attachments: true,
|
||||
ignore_stories: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&sc).unwrap();
|
||||
let parsed: SignalConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10095,6 +10557,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
allowed_from: vec!["*".into()],
|
||||
ignore_attachments: false,
|
||||
ignore_stories: true,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&sc).unwrap();
|
||||
let parsed: SignalConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10325,6 +10788,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -10349,6 +10813,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10378,6 +10843,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&wc).unwrap();
|
||||
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -10399,6 +10865,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
assert!(wc.is_ambiguous_config());
|
||||
assert_eq!(wc.backend_type(), "cloud");
|
||||
@@ -10419,6 +10886,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
};
|
||||
assert!(!wc.is_ambiguous_config());
|
||||
assert_eq!(wc.backend_type(), "web");
|
||||
@@ -10449,6 +10917,7 @@ channel_id = "C123"
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
}),
|
||||
linq: None,
|
||||
wati: None,
|
||||
@@ -10553,6 +11022,7 @@ channel_id = "C123"
|
||||
pair_rate_limit_per_minute: 12,
|
||||
webhook_rate_limit_per_minute: 80,
|
||||
trust_forwarded_headers: true,
|
||||
path_prefix: Some("/zeroclaw".into()),
|
||||
rate_limit_max_keys: 2048,
|
||||
idempotency_ttl_secs: 600,
|
||||
idempotency_max_keys: 4096,
|
||||
@@ -10570,6 +11040,7 @@ channel_id = "C123"
|
||||
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
|
||||
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
|
||||
assert!(parsed.trust_forwarded_headers);
|
||||
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
|
||||
assert_eq!(parsed.rate_limit_max_keys, 2048);
|
||||
assert_eq!(parsed.idempotency_ttl_secs, 600);
|
||||
assert_eq!(parsed.idempotency_max_keys, 4096);
|
||||
@@ -11453,6 +11924,7 @@ default_model = "legacy-model"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
config.save().await.unwrap();
|
||||
|
||||
@@ -12164,6 +12636,7 @@ default_model = "persisted-profile"
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -12187,6 +12660,7 @@ default_model = "persisted-profile"
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -12233,6 +12707,7 @@ default_model = "persisted-profile"
|
||||
allowed_users: vec!["user_123".into(), "user_456".into()],
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
let json = serde_json::to_string(&fc).unwrap();
|
||||
let parsed: FeishuConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -12253,6 +12728,7 @@ default_model = "persisted-profile"
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
proxy_url: None,
|
||||
};
|
||||
let toml_str = toml::to_string(&fc).unwrap();
|
||||
let parsed: FeishuConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -12280,6 +12756,7 @@ default_model = "persisted-profile"
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: Some("webhook-secret".into()),
|
||||
allowed_users: vec!["user_a".into(), "*".into()],
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&nc).unwrap();
|
||||
@@ -12467,6 +12944,30 @@ require_otp_to_resume = true
|
||||
assert!(err.to_string().contains("gated_domains"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn validate_accepts_local_whisper_as_transcription_default_provider() {
|
||||
let mut config = Config::default();
|
||||
config.transcription.default_provider = "local_whisper".to_string();
|
||||
|
||||
config.validate().expect(
|
||||
"local_whisper must be accepted by the transcription.default_provider allowlist",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn validate_rejects_unknown_transcription_default_provider() {
|
||||
let mut config = Config::default();
|
||||
config.transcription.default_provider = "unknown_stt".to_string();
|
||||
|
||||
let err = config
|
||||
.validate()
|
||||
.expect_err("expected validation to reject unknown transcription provider");
|
||||
assert!(
|
||||
err.to_string().contains("transcription.default_provider"),
|
||||
"got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_secret_telegram_bot_token_roundtrip() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
@@ -12488,6 +12989,7 @@ require_otp_to_resume = true
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
// Save (triggers encryption)
|
||||
@@ -13247,9 +13749,9 @@ require_otp_to_resume = true
|
||||
|
||||
// ── Bootstrap files ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[tokio::test]
|
||||
async fn ensure_bootstrap_files_creates_missing_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let ws = tmp.path().join("workspace");
|
||||
tokio::fs::create_dir_all(&ws).await.unwrap();
|
||||
|
||||
@@ -13263,9 +13765,9 @@ require_otp_to_resume = true
|
||||
assert!(identity.contains("IDENTITY.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[tokio::test]
|
||||
async fn ensure_bootstrap_files_does_not_overwrite_existing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let ws = tmp.path().join("workspace");
|
||||
tokio::fs::create_dir_all(&ws).await.unwrap();
|
||||
|
||||
|
||||
+30
-1
@@ -7,7 +7,7 @@ use std::collections::HashMap;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
/// Cost tracker for API usage monitoring and budget enforcement.
|
||||
pub struct CostTracker {
|
||||
@@ -175,6 +175,35 @@ impl CostTracker {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Process-global singleton ────────────────────────────────────────
|
||||
// Both the gateway and the channels supervisor share a single CostTracker
|
||||
// so that budget enforcement is consistent across all paths.
|
||||
|
||||
static GLOBAL_COST_TRACKER: OnceLock<Option<Arc<CostTracker>>> = OnceLock::new();
|
||||
|
||||
impl CostTracker {
|
||||
/// Return the process-global `CostTracker`, creating it on first call.
|
||||
/// Subsequent calls (from gateway or channels, whichever starts second)
|
||||
/// receive the same `Arc`. Returns `None` when cost tracking is disabled
|
||||
/// or initialisation fails.
|
||||
pub fn get_or_init_global(config: CostConfig, workspace_dir: &Path) -> Option<Arc<Self>> {
|
||||
GLOBAL_COST_TRACKER
|
||||
.get_or_init(|| {
|
||||
if !config.enabled {
|
||||
return None;
|
||||
}
|
||||
match Self::new(config, workspace_dir) {
|
||||
Ok(ct) => Some(Arc::new(ct)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to initialize global cost tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_storage_path(workspace_dir: &Path) -> Result<PathBuf> {
|
||||
let storage_path = workspace_dir.join("state").join("costs.jsonl");
|
||||
let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db");
|
||||
|
||||
@@ -646,6 +646,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -657,6 +658,7 @@ mod tests {
|
||||
client_id: "client_id".into(),
|
||||
client_secret: "client_secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -672,6 +674,7 @@ mod tests {
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -683,6 +686,7 @@ mod tests {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -695,6 +699,7 @@ mod tests {
|
||||
app_token: "app-token".into(),
|
||||
webhook_secret: None,
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_supervised_channels(&config));
|
||||
}
|
||||
@@ -761,6 +766,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
@@ -778,6 +784,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
|
||||
@@ -1283,6 +1283,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
config.agents.insert(
|
||||
@@ -1299,6 +1300,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
+17
-3
@@ -50,6 +50,10 @@ fn require_auth(
|
||||
pub struct MemoryQuery {
|
||||
pub query: Option<String>,
|
||||
pub category: Option<String>,
|
||||
/// Filter memories created at or after (RFC 3339 / ISO 8601)
|
||||
pub since: Option<String>,
|
||||
/// Filter memories created at or before (RFC 3339 / ISO 8601)
|
||||
pub until: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -633,9 +637,12 @@ pub async fn handle_api_memory_list(
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
if let Some(ref query) = params.query {
|
||||
// Search mode
|
||||
match state.mem.recall(query, 50, None).await {
|
||||
// Use recall when query or time range is provided
|
||||
if params.query.is_some() || params.since.is_some() || params.until.is_some() {
|
||||
let query = params.query.as_deref().unwrap_or("");
|
||||
let since = params.since.as_deref();
|
||||
let until = params.until.as_deref();
|
||||
match state.mem.recall(query, 50, None, since, until).await {
|
||||
Ok(entries) => Json(serde_json::json!({"entries": entries})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -1356,6 +1363,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -1429,6 +1438,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
path_prefix: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1457,6 +1467,7 @@ mod tests {
|
||||
api_url: "https://live-mt-server.wati.io".to_string(),
|
||||
tenant_id: None,
|
||||
allowed_numbers: vec![],
|
||||
proxy_url: None,
|
||||
});
|
||||
cfg.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
|
||||
app_id: "cli_aabbcc".to_string(),
|
||||
@@ -1466,6 +1477,7 @@ mod tests {
|
||||
allowed_users: vec!["*".to_string()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
cfg.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
|
||||
imap_host: "imap.example.com".to_string(),
|
||||
@@ -1591,6 +1603,7 @@ mod tests {
|
||||
api_url: "https://live-mt-server.wati.io".to_string(),
|
||||
tenant_id: None,
|
||||
allowed_numbers: vec![],
|
||||
proxy_url: None,
|
||||
});
|
||||
current.channels_config.feishu = Some(crate::config::schema::FeishuConfig {
|
||||
app_id: "cli_current".to_string(),
|
||||
@@ -1600,6 +1613,7 @@ mod tests {
|
||||
allowed_users: vec!["*".to_string()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
current.channels_config.email = Some(crate::channels::email_channel::EmailConfig {
|
||||
imap_host: "imap.example.com".to_string(),
|
||||
|
||||
+61
-34
@@ -348,6 +348,8 @@ pub struct AppState {
|
||||
pub shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
/// Registry of dynamically connected nodes
|
||||
pub node_registry: Arc<nodes::NodeRegistry>,
|
||||
/// Path prefix for reverse-proxy deployments (empty string = no prefix)
|
||||
pub path_prefix: String,
|
||||
/// Session backend for persisting gateway WS chat sessions
|
||||
pub session_backend: Option<Arc<dyn SessionBackend>>,
|
||||
/// Device registry for paired device management
|
||||
@@ -505,18 +507,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
let tools_registry: Arc<Vec<ToolSpec>> =
|
||||
Arc::new(tools_registry_raw.iter().map(|t| t.spec()).collect());
|
||||
|
||||
// Cost tracker (optional)
|
||||
let cost_tracker = if config.cost.enabled {
|
||||
match CostTracker::new(config.cost.clone(), &config.workspace_dir) {
|
||||
Ok(ct) => Some(Arc::new(ct)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to initialize cost tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Cost tracker — process-global singleton so channels share the same instance
|
||||
let cost_tracker = CostTracker::get_or_init_global(config.cost.clone(), &config.workspace_dir);
|
||||
|
||||
// SSE broadcast channel for real-time events
|
||||
let (event_tx, _event_rx) = tokio::sync::broadcast::channel::<serde_json::Value>(256);
|
||||
@@ -683,6 +675,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
idempotency_max_keys,
|
||||
));
|
||||
|
||||
// Resolve optional path prefix for reverse-proxy deployments.
|
||||
let path_prefix: Option<&str> = config
|
||||
.gateway
|
||||
.path_prefix
|
||||
.as_deref()
|
||||
.filter(|p| !p.is_empty());
|
||||
|
||||
// ── Tunnel ────────────────────────────────────────────────
|
||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||
let mut tunnel_url: Option<String> = None;
|
||||
@@ -701,18 +700,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
|
||||
let pfx = path_prefix.unwrap_or("");
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}{pfx}");
|
||||
if let Some(ref url) = tunnel_url {
|
||||
println!(" 🌐 Public URL: {url}");
|
||||
}
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}/");
|
||||
println!(" 🌐 Web Dashboard: http://{display_addr}{pfx}/");
|
||||
if let Some(code) = pairing.pairing_code() {
|
||||
println!();
|
||||
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
|
||||
println!(" ┌──────────────┐");
|
||||
println!(" │ {code} │");
|
||||
println!(" └──────────────┘");
|
||||
println!();
|
||||
println!(" Send: POST {pfx}/pair with header X-Pairing-Code: {code}");
|
||||
} else if pairing.require_pairing() {
|
||||
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
|
||||
println!(" To pair a new device: zeroclaw gateway get-paircode --new");
|
||||
@@ -721,29 +721,29 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||
println!();
|
||||
}
|
||||
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" POST {pfx}/pair — pair a new client (X-Pairing-Code header)");
|
||||
println!(" POST {pfx}/webhook — {{\"message\": \"your prompt\"}}");
|
||||
if whatsapp_channel.is_some() {
|
||||
println!(" GET /whatsapp — Meta webhook verification");
|
||||
println!(" POST /whatsapp — WhatsApp message webhook");
|
||||
println!(" GET {pfx}/whatsapp — Meta webhook verification");
|
||||
println!(" POST {pfx}/whatsapp — WhatsApp message webhook");
|
||||
}
|
||||
if linq_channel.is_some() {
|
||||
println!(" POST /linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
println!(" POST {pfx}/linq — Linq message webhook (iMessage/RCS/SMS)");
|
||||
}
|
||||
if wati_channel.is_some() {
|
||||
println!(" GET /wati — WATI webhook verification");
|
||||
println!(" POST /wati — WATI message webhook");
|
||||
println!(" GET {pfx}/wati — WATI webhook verification");
|
||||
println!(" POST {pfx}/wati — WATI message webhook");
|
||||
}
|
||||
if nextcloud_talk_channel.is_some() {
|
||||
println!(" POST /nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
println!(" POST {pfx}/nextcloud-talk — Nextcloud Talk bot webhook");
|
||||
}
|
||||
println!(" GET /api/* — REST API (bearer token required)");
|
||||
println!(" GET /ws/chat — WebSocket agent chat");
|
||||
println!(" GET {pfx}/api/* — REST API (bearer token required)");
|
||||
println!(" GET {pfx}/ws/chat — WebSocket agent chat");
|
||||
if config.nodes.enabled {
|
||||
println!(" GET /ws/nodes — WebSocket node discovery");
|
||||
println!(" GET {pfx}/ws/nodes — WebSocket node discovery");
|
||||
}
|
||||
println!(" GET /health — health check");
|
||||
println!(" GET /metrics — Prometheus metrics");
|
||||
println!(" GET {pfx}/health — health check");
|
||||
println!(" GET {pfx}/metrics — Prometheus metrics");
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
crate::health::mark_component_ok("gateway");
|
||||
@@ -809,6 +809,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
session_backend,
|
||||
device_registry,
|
||||
pending_pairings,
|
||||
path_prefix: path_prefix.unwrap_or("").to_string(),
|
||||
};
|
||||
|
||||
// Config PUT needs larger body limit (1MB)
|
||||
@@ -817,7 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.layer(RequestBodyLimitLayer::new(1_048_576));
|
||||
|
||||
// Build router with middleware
|
||||
let app = Router::new()
|
||||
let inner = Router::new()
|
||||
// ── Admin routes (for CLI management) ──
|
||||
.route("/admin/shutdown", post(handle_admin_shutdown))
|
||||
.route("/admin/paircode", get(handle_admin_paircode))
|
||||
@@ -877,12 +878,12 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
|
||||
// ── Plugin management API (requires plugins-wasm feature) ──
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
let app = app.route(
|
||||
let inner = inner.route(
|
||||
"/api/plugins",
|
||||
get(api_plugins::plugin_routes::list_plugins),
|
||||
);
|
||||
|
||||
let app = app
|
||||
let inner = inner
|
||||
// ── SSE event stream ──
|
||||
.route("/api/events", get(sse::handle_sse_events))
|
||||
// ── WebSocket agent chat ──
|
||||
@@ -893,14 +894,27 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/_app/{*path}", get(static_files::handle_static))
|
||||
// ── Config PUT with larger body limit ──
|
||||
.merge(config_put_router)
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback))
|
||||
.with_state(state)
|
||||
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
|
||||
.layer(TimeoutLayer::with_status_code(
|
||||
StatusCode::REQUEST_TIMEOUT,
|
||||
Duration::from_secs(gateway_request_timeout_secs()),
|
||||
))
|
||||
// ── SPA fallback: non-API GET requests serve index.html ──
|
||||
.fallback(get(static_files::handle_spa_fallback));
|
||||
));
|
||||
|
||||
// Nest under path prefix when configured (axum strips prefix before routing).
|
||||
// nest() at "/prefix" handles both "/prefix" and "/prefix/*" but not "/prefix/"
|
||||
// with a trailing slash, so we add a fallback redirect for that case.
|
||||
let app = if let Some(prefix) = path_prefix {
|
||||
let redirect_target = prefix.to_string();
|
||||
Router::new().nest(prefix, inner).route(
|
||||
&format!("{prefix}/"),
|
||||
get(|| async move { axum::response::Redirect::permanent(&redirect_target) }),
|
||||
)
|
||||
} else {
|
||||
inner
|
||||
};
|
||||
|
||||
// Run the server with graceful shutdown
|
||||
axum::serve(
|
||||
@@ -1992,6 +2006,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2047,6 +2062,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2287,6 +2303,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -2362,6 +2380,8 @@ mod tests {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -2427,6 +2447,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2496,6 +2517,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2577,6 +2599,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2630,6 +2653,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2688,6 +2712,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2751,6 +2776,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
@@ -2810,6 +2836,7 @@ mod tests {
|
||||
event_tx: tokio::sync::broadcast::channel(16).0,
|
||||
shutdown_tx: tokio::sync::watch::channel(false).0,
|
||||
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
|
||||
path_prefix: String::new(),
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
//! Uses `rust-embed` to bundle the `web/dist/` directory into the binary at compile time.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{header, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rust_embed::Embed;
|
||||
|
||||
use super::AppState;
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "web/dist/"]
|
||||
struct WebAssets;
|
||||
@@ -23,16 +26,41 @@ pub async fn handle_static(uri: Uri) -> Response {
|
||||
serve_embedded_file(path)
|
||||
}
|
||||
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request
|
||||
pub async fn handle_spa_fallback() -> Response {
|
||||
if WebAssets::get("index.html").is_none() {
|
||||
/// SPA fallback: serve index.html for any non-API, non-static GET request.
|
||||
/// Injects `window.__ZEROCLAW_BASE__` so the frontend knows the path prefix.
|
||||
pub async fn handle_spa_fallback(State(state): State<AppState>) -> Response {
|
||||
let Some(content) = WebAssets::get("index.html") else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"Web dashboard not available. Build it with: cd web && npm ci && npm run build",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
serve_embedded_file("index.html")
|
||||
};
|
||||
|
||||
let html = String::from_utf8_lossy(&content.data);
|
||||
|
||||
// Inject path prefix for the SPA and rewrite asset paths in the HTML
|
||||
let html = if state.path_prefix.is_empty() {
|
||||
html.into_owned()
|
||||
} else {
|
||||
let pfx = &state.path_prefix;
|
||||
// JSON-encode the prefix to safely embed in a <script> block
|
||||
let json_pfx = serde_json::to_string(pfx).unwrap_or_else(|_| "\"\"".to_string());
|
||||
let script = format!("<script>window.__ZEROCLAW_BASE__={json_pfx};</script>");
|
||||
// Rewrite absolute /_app/ references so the browser requests {prefix}/_app/...
|
||||
html.replace("/_app/", &format!("{pfx}/_app/"))
|
||||
.replace("<head>", &format!("<head>{script}"))
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
[
|
||||
(header::CONTENT_TYPE, "text/html; charset=utf-8".to_string()),
|
||||
(header::CACHE_CONTROL, "no-cache".to_string()),
|
||||
],
|
||||
html,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn serve_embedded_file(path: &str) -> Response {
|
||||
|
||||
@@ -841,6 +841,7 @@ mod tests {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
let entries = all_integrations();
|
||||
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
|
||||
|
||||
+18
@@ -71,6 +71,7 @@ pub mod runtime;
|
||||
pub(crate) mod security;
|
||||
pub(crate) mod service;
|
||||
pub(crate) mod skills;
|
||||
pub mod sop;
|
||||
pub mod tools;
|
||||
pub(crate) mod tunnel;
|
||||
pub(crate) mod util;
|
||||
@@ -561,3 +562,20 @@ Examples:
|
||||
/// Flash ZeroClaw firmware to Nucleo-F401RE (builds + probe-rs run)
|
||||
FlashNucleo,
|
||||
}
|
||||
|
||||
/// SOP management subcommands
|
||||
#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum SopCommands {
|
||||
/// List loaded SOPs
|
||||
List,
|
||||
/// Validate SOP definitions
|
||||
Validate {
|
||||
/// SOP name to validate (all if omitted)
|
||||
name: Option<String>,
|
||||
},
|
||||
/// Show details of an SOP
|
||||
Show {
|
||||
/// Name of the SOP to show
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
+2
-1
@@ -107,6 +107,7 @@ mod security;
|
||||
mod service;
|
||||
mod skillforge;
|
||||
mod skills;
|
||||
mod sop;
|
||||
mod tools;
|
||||
mod tunnel;
|
||||
mod util;
|
||||
@@ -117,7 +118,7 @@ use config::Config;
|
||||
// Re-export so binary modules can use crate::<CommandEnum> while keeping a single source of truth.
|
||||
pub use zeroclaw::{
|
||||
ChannelCommands, CronCommands, GatewayCommands, HardwareCommands, IntegrationCommands,
|
||||
MigrateCommands, PeripheralCommands, ServiceCommands, SkillCommands,
|
||||
MigrateCommands, PeripheralCommands, ServiceCommands, SkillCommands, SopCommands,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
|
||||
|
||||
+47
-7
@@ -325,8 +325,27 @@ impl Memory for LucidMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||
let since_dt = since
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
|
||||
let until_dt = until
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
|
||||
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
|
||||
if s >= u {
|
||||
anyhow::bail!("'since' must be before 'until'");
|
||||
}
|
||||
}
|
||||
|
||||
let local_results = self
|
||||
.local
|
||||
.recall(query, limit, session_id, since, until)
|
||||
.await?;
|
||||
if limit == 0
|
||||
|| local_results.len() >= limit
|
||||
|| local_results.len() >= self.local_hit_threshold
|
||||
@@ -341,7 +360,28 @@ impl Memory for LucidMemory {
|
||||
match self.recall_from_lucid(query).await {
|
||||
Ok(lucid_results) if !lucid_results.is_empty() => {
|
||||
self.clear_failure();
|
||||
Ok(Self::merge_results(local_results, lucid_results, limit))
|
||||
let merged = Self::merge_results(local_results, lucid_results, limit);
|
||||
let filtered: Vec<MemoryEntry> = merged
|
||||
.into_iter()
|
||||
.filter(|e| {
|
||||
if let Some(ref s) = since_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
|
||||
if ts < *s {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ref u) = until_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&e.timestamp) {
|
||||
if ts > *u {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
Ok(_) => {
|
||||
self.clear_failure();
|
||||
@@ -541,7 +581,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
@@ -565,7 +605,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
let entries = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
@@ -603,7 +643,7 @@ exit 1
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||
let entries = memory.recall("rust", 5, None, None, None).await.unwrap();
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||
@@ -663,8 +703,8 @@ exit 1
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||
let first = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None, None, None).await.unwrap();
|
||||
|
||||
assert!(first.is_empty());
|
||||
assert!(second.is_empty());
|
||||
|
||||
+47
-6
@@ -158,7 +158,23 @@ impl Memory for MarkdownMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let since_dt = since
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'since' date (expected RFC 3339): {e}"))?;
|
||||
let until_dt = until
|
||||
.map(chrono::DateTime::parse_from_rfc3339)
|
||||
.transpose()
|
||||
.map_err(|e| anyhow::anyhow!("invalid 'until' date (expected RFC 3339): {e}"))?;
|
||||
if let (Some(s), Some(u)) = (&since_dt, &until_dt) {
|
||||
if s >= u {
|
||||
anyhow::bail!("'since' must be before 'until'");
|
||||
}
|
||||
}
|
||||
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
@@ -166,6 +182,24 @@ impl Memory for MarkdownMemory {
|
||||
let mut scored: Vec<MemoryEntry> = all
|
||||
.into_iter()
|
||||
.filter_map(|mut entry| {
|
||||
if let Some(ref s) = since_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
if ts < *s {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ref u) = until_dt {
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
if ts > *u {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
if keywords.is_empty() {
|
||||
entry.score = Some(1.0);
|
||||
return Some(entry);
|
||||
}
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = keywords
|
||||
.iter()
|
||||
@@ -183,9 +217,13 @@ impl Memory for MarkdownMemory {
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
if keywords.is_empty() {
|
||||
b.timestamp.as_str().cmp(a.timestamp.as_str())
|
||||
} else {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
});
|
||||
scored.truncate(limit);
|
||||
Ok(scored)
|
||||
@@ -283,7 +321,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
@@ -296,7 +334,10 @@ mod tests {
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("javascript", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -343,7 +384,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||
let results = mem.recall("anything", 10, None, None, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
||||
+5
-1
@@ -364,14 +364,18 @@ impl Memory for Mem0Memory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// mem0 handles filtering server-side; since/until are not yet
|
||||
// supported by the mem0 API, so we pass them through as no-ops.
|
||||
self.recall_filtered(query, limit, session_id, None, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
// mem0 doesn't have a get-by-key API, so we search by key in metadata
|
||||
let results = self.recall(key, 1, None).await?;
|
||||
let results = self.recall(key, 1, None, None, None).await?;
|
||||
Ok(results.into_iter().find(|e| e.key == key))
|
||||
}
|
||||
|
||||
|
||||
+7
-1
@@ -35,6 +35,8 @@ impl Memory for NoneMemory {
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
_since: Option<&str>,
|
||||
_until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
@@ -78,7 +80,11 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
assert!(memory.get("k").await.unwrap().is_none());
|
||||
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||
assert!(memory
|
||||
.recall("k", 10, None, None, None)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||
assert!(!memory.forget("k").await.unwrap());
|
||||
assert_eq!(memory.count().await.unwrap(), 0);
|
||||
|
||||
+23
-1
@@ -239,14 +239,30 @@ impl Memory for PostgresMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
let client = self.client.clone();
|
||||
let qualified_table = self.qualified_table.clone();
|
||||
let query = query.trim().to_string();
|
||||
let sid = session_id.map(str::to_string);
|
||||
let since_owned = since.map(str::to_string);
|
||||
let until_owned = until.map(str::to_string);
|
||||
|
||||
run_on_os_thread(move || -> Result<Vec<MemoryEntry>> {
|
||||
let mut client = client.lock();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let time_filter: String = match (since_ref, until_ref) {
|
||||
(Some(_), Some(_)) => {
|
||||
" AND created_at >= $4::TIMESTAMPTZ AND created_at <= $5::TIMESTAMPTZ".into()
|
||||
}
|
||||
(Some(_), None) => " AND created_at >= $4::TIMESTAMPTZ".into(),
|
||||
(None, Some(_)) => " AND created_at <= $4::TIMESTAMPTZ".into(),
|
||||
(None, None) => String::new(),
|
||||
};
|
||||
|
||||
let stmt = format!(
|
||||
"
|
||||
SELECT id, key, content, category, created_at, session_id,
|
||||
@@ -257,6 +273,7 @@ impl Memory for PostgresMemory {
|
||||
FROM {qualified_table}
|
||||
WHERE ($2::TEXT IS NULL OR session_id = $2)
|
||||
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
|
||||
{time_filter}
|
||||
ORDER BY score DESC, updated_at DESC
|
||||
LIMIT $3
|
||||
"
|
||||
@@ -265,7 +282,12 @@ impl Memory for PostgresMemory {
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let limit_i64 = limit as i64;
|
||||
|
||||
let rows = client.query(&stmt, &[&query, &sid, &limit_i64])?;
|
||||
let rows = match (since_ref, until_ref) {
|
||||
(Some(s), Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &s, &u])?,
|
||||
(Some(s), None) => client.query(&stmt, &[&query, &sid, &limit_i64, &s])?,
|
||||
(None, Some(u)) => client.query(&stmt, &[&query, &sid, &limit_i64, &u])?,
|
||||
(None, None) => client.query(&stmt, &[&query, &sid, &limit_i64])?,
|
||||
};
|
||||
rows.iter()
|
||||
.map(Self::row_to_entry)
|
||||
.collect::<Result<Vec<MemoryEntry>>>()
|
||||
|
||||
+20
-2
@@ -291,9 +291,19 @@ impl Memory for QdrantMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
if query.trim().is_empty() {
|
||||
return self.list(None, session_id).await;
|
||||
let mut entries = self.list(None, session_id).await?;
|
||||
if let Some(s) = since {
|
||||
entries.retain(|e| e.timestamp.as_str() >= s);
|
||||
}
|
||||
if let Some(u) = until {
|
||||
entries.retain(|e| e.timestamp.as_str() <= u);
|
||||
}
|
||||
entries.truncate(limit);
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
self.ensure_initialized().await?;
|
||||
@@ -344,7 +354,7 @@ impl Memory for QdrantMemory {
|
||||
|
||||
let result: QdrantSearchResult = resp.json().await?;
|
||||
|
||||
let entries = result
|
||||
let mut entries: Vec<MemoryEntry> = result
|
||||
.result
|
||||
.into_iter()
|
||||
.filter_map(|point| {
|
||||
@@ -367,6 +377,14 @@ impl Memory for QdrantMemory {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Filter by time range if specified
|
||||
if let Some(s) = since {
|
||||
entries.retain(|e| e.timestamp.as_str() >= s);
|
||||
}
|
||||
if let Some(u) = until {
|
||||
entries.retain(|e| e.timestamp.as_str() <= u);
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
|
||||
+167
-36
@@ -428,6 +428,74 @@ impl SqliteMemory {
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// List memories by time range (used when query is empty).
|
||||
async fn recall_by_time_only(
|
||||
&self,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self.conn.clone();
|
||||
let sid = session_id.map(String::from);
|
||||
let since_owned = since.map(String::from);
|
||||
let until_owned = until.map(String::from);
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
let mut sql =
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories \
|
||||
WHERE 1=1"
|
||||
.to_string();
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
let mut idx = 1;
|
||||
|
||||
if let Some(sid) = sid.as_deref() {
|
||||
let _ = write!(sql, " AND session_id = ?{idx}");
|
||||
param_values.push(Box::new(sid.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
if let Some(s) = since_ref {
|
||||
let _ = write!(sql, " AND created_at >= ?{idx}");
|
||||
param_values.push(Box::new(s.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
let _ = write!(sql, " AND created_at <= ?{idx}");
|
||||
param_values.push(Box::new(u.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
let _ = write!(sql, " ORDER BY updated_at DESC LIMIT ?{idx}");
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
param_values.iter().map(AsRef::as_ref).collect();
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
Ok(results)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -481,9 +549,14 @@ impl Memory for SqliteMemory {
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
// Time-only query: list by time range when no keywords
|
||||
if query.trim().is_empty() {
|
||||
return Ok(Vec::new());
|
||||
return self
|
||||
.recall_by_time_only(limit, session_id, since, until)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Compute query embedding (async, before blocking work)
|
||||
@@ -492,12 +565,16 @@ impl Memory for SqliteMemory {
|
||||
let conn = self.conn.clone();
|
||||
let query = query.to_string();
|
||||
let sid = session_id.map(String::from);
|
||||
let since_owned = since.map(String::from);
|
||||
let until_owned = until.map(String::from);
|
||||
let vector_weight = self.vector_weight;
|
||||
let keyword_weight = self.keyword_weight;
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = conn.lock();
|
||||
let session_ref = sid.as_deref();
|
||||
let since_ref = since_owned.as_deref();
|
||||
let until_ref = until_owned.as_deref();
|
||||
|
||||
// FTS5 BM25 keyword search
|
||||
let keyword_results = Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default();
|
||||
@@ -568,6 +645,16 @@ impl Memory for SqliteMemory {
|
||||
|
||||
for scored in &merged {
|
||||
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||
if let Some(s) = since_ref {
|
||||
if ts.as_str() < s {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
if ts.as_str() > u {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let entry = MemoryEntry {
|
||||
id: scored.id.clone(),
|
||||
key,
|
||||
@@ -588,8 +675,6 @@ impl Memory for SqliteMemory {
|
||||
}
|
||||
|
||||
// If hybrid returned nothing, fall back to LIKE search.
|
||||
// Cap keyword count so we don't create too many SQL shapes,
|
||||
// which helps prepared-statement cache efficiency.
|
||||
if results.is_empty() {
|
||||
const MAX_LIKE_KEYWORDS: usize = 8;
|
||||
let keywords: Vec<String> = query
|
||||
@@ -606,12 +691,21 @@ impl Memory for SqliteMemory {
|
||||
})
|
||||
.collect();
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let mut param_idx = keywords.len() * 2 + 1;
|
||||
let mut time_conditions = String::new();
|
||||
if since_ref.is_some() {
|
||||
let _ = write!(time_conditions, " AND created_at >= ?{param_idx}");
|
||||
param_idx += 1;
|
||||
}
|
||||
if until_ref.is_some() {
|
||||
let _ = write!(time_conditions, " AND created_at <= ?{param_idx}");
|
||||
param_idx += 1;
|
||||
}
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}
|
||||
WHERE {where_clause}{time_conditions}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
keywords.len() * 2 + 1
|
||||
LIMIT ?{param_idx}"
|
||||
);
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
@@ -619,6 +713,12 @@ impl Memory for SqliteMemory {
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
}
|
||||
if let Some(s) = since_ref {
|
||||
param_values.push(Box::new(s.to_string()));
|
||||
}
|
||||
if let Some(u) = until_ref {
|
||||
param_values.push(Box::new(u.to_string()));
|
||||
}
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
@@ -852,7 +952,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
@@ -869,7 +969,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||
let results = mem.recall("fast safe", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
@@ -881,7 +981,10 @@ mod tests {
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("javascript", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1024,7 +1127,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
// All results should contain "Rust"
|
||||
for r in &results {
|
||||
@@ -1049,30 +1152,34 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||
let results = mem.recall("quick dog", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// "The quick dog runs fast" matches both terms
|
||||
assert!(results[0].content.contains("quick"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_empty_query_returns_empty() {
|
||||
async fn recall_empty_query_returns_recent_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
// Empty query = time-only mode: returns recent entries
|
||||
let results = mem.recall("", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "a");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_whitespace_query_returns_empty() {
|
||||
async fn recall_whitespace_query_returns_recent_entries() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
// Whitespace-only query = time-only mode: returns recent entries
|
||||
let results = mem.recall(" ", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "a");
|
||||
}
|
||||
|
||||
// ── Embedding cache tests ────────────────────────────────────
|
||||
@@ -1283,7 +1390,7 @@ mod tests {
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// FTS should still work after rebuild
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
}
|
||||
|
||||
@@ -1303,7 +1410,10 @@ mod tests {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("common keyword", 5, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
@@ -1316,7 +1426,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||
let results = mem.recall("scored", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||
@@ -1332,7 +1442,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// Quotes in query should not crash FTS5
|
||||
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||
let results = mem.recall("\"hello\"", 10, None, None, None).await.unwrap();
|
||||
// May or may not match depending on FTS5 escaping, but must not error
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
@@ -1343,7 +1453,7 @@ mod tests {
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||
let results = mem.recall("wild*", 10, None, None, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
@@ -1353,7 +1463,10 @@ mod tests {
|
||||
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("function()", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
@@ -1365,7 +1478,7 @@ mod tests {
|
||||
.unwrap();
|
||||
// Should not crash or leak data
|
||||
let results = mem
|
||||
.recall("'; DROP TABLE memories; --", 10, None)
|
||||
.recall("'; DROP TABLE memories; --", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
@@ -1441,7 +1554,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// Single char may not match FTS5 but LIKE fallback should work
|
||||
let results = mem.recall("x", 10, None).await.unwrap();
|
||||
let results = mem.recall("x", 10, None, None, None).await.unwrap();
|
||||
// Should not crash; may or may not find results
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
@@ -1452,7 +1565,7 @@ mod tests {
|
||||
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("some", 0, None).await.unwrap();
|
||||
let results = mem.recall("some", 0, None, None, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1465,7 +1578,10 @@ mod tests {
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("matching content", 1, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1481,7 +1597,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||
let results = mem.recall("rust", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "Should match by key");
|
||||
}
|
||||
|
||||
@@ -1491,7 +1607,7 @@ mod tests {
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||
let results = mem.recall("日本語", 10, None, None, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
@@ -1541,7 +1657,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("ghost").await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("phantom memory", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Deleted memory should not appear in recall"
|
||||
@@ -1582,7 +1701,7 @@ mod tests {
|
||||
let count = mem.reindex().await.unwrap();
|
||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||
// Data should still be intact
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1686,7 +1805,10 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Recall with session-a filter returns only session-a entry
|
||||
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("fact", 10, Some("sess-a"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||
@@ -1706,7 +1828,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Recall without session filter returns all matching entries
|
||||
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||
let results = mem.recall("fact", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
}
|
||||
|
||||
@@ -1723,11 +1845,17 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Session B cannot see session A data
|
||||
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("secret", 10, Some("sess-b"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Session A can see its own data
|
||||
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("secret", 10, Some("sess-a"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
@@ -1778,7 +1906,10 @@ mod tests {
|
||||
// Second open: migration runs again but is idempotent
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||
let results = mem
|
||||
.recall("reopen", 10, Some("sess-x"), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||
|
||||
@@ -96,11 +96,15 @@ pub trait Memory: Send + Sync {
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||
/// and time range. Time bounds use RFC 3339 / ISO 8601 format
|
||||
/// (e.g. "2025-03-01T00:00:00Z"); inclusive (created_at >= since, created_at <= until).
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
since: Option<&str>,
|
||||
until: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
|
||||
@@ -154,6 +154,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -199,6 +200,8 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
sop: crate::config::SopConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@@ -575,6 +578,7 @@ async fn run_quick_setup_with_home(
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||
agent: crate::config::schema::AgentConfig::default(),
|
||||
pacing: crate::config::PacingConfig::default(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
embedding_routes: Vec::new(),
|
||||
@@ -620,6 +624,8 @@ async fn run_quick_setup_with_home(
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
claude_code: crate::config::ClaudeCodeConfig::default(),
|
||||
sop: crate::config::SopConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
@@ -3790,6 +3796,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
ack_reactions: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Discord => {
|
||||
@@ -3890,6 +3897,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
listen_to_bots: false,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Slack => {
|
||||
@@ -4020,6 +4028,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
interrupt_on_new_message: false,
|
||||
thread_replies: None,
|
||||
mention_only: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::IMessage => {
|
||||
@@ -4271,6 +4280,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
allowed_from,
|
||||
ignore_attachments,
|
||||
ignore_stories,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(" {} Signal configured", style("✅").green().bold());
|
||||
@@ -4372,6 +4382,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(
|
||||
@@ -4477,6 +4488,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
dm_policy: WhatsAppChatPolicy::default(),
|
||||
group_policy: WhatsAppChatPolicy::default(),
|
||||
self_chat_mode: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Linq => {
|
||||
@@ -4810,6 +4822,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
Some(webhook_secret.trim().to_string())
|
||||
},
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
|
||||
println!(" {} Nextcloud Talk configured", style("✅").green().bold());
|
||||
@@ -4882,6 +4895,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
client_id,
|
||||
client_secret,
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::QqOfficial => {
|
||||
@@ -4958,6 +4972,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
app_id,
|
||||
app_secret,
|
||||
allowed_users,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::Lark | ChannelMenuChoice::Feishu => {
|
||||
@@ -5147,6 +5162,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
use_feishu: is_feishu,
|
||||
receive_mode,
|
||||
port,
|
||||
proxy_url: None,
|
||||
});
|
||||
}
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
@@ -7511,6 +7527,7 @@ mod tests {
|
||||
allowed_from: vec!["*".into()],
|
||||
ignore_attachments: false,
|
||||
ignore_stories: true,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7523,6 +7540,7 @@ mod tests {
|
||||
thread_replies: Some(true),
|
||||
mention_only: Some(false),
|
||||
interrupt_on_new_message: false,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7531,6 +7549,7 @@ mod tests {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7540,6 +7559,7 @@ mod tests {
|
||||
app_token: "token".into(),
|
||||
webhook_secret: Some("secret".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
@@ -7552,6 +7572,7 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
proxy_url: None,
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
}
|
||||
|
||||
+1
-1
@@ -146,7 +146,7 @@ fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill
|
||||
load_skills_from_directory(&skills_dir, allow_scripts)
|
||||
}
|
||||
|
||||
fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
pub fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
if !skills_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
+1
-27
@@ -78,33 +78,6 @@ impl SopAuditLogger {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Log a gate evaluation decision record.
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
pub async fn log_gate_decision(
|
||||
&self,
|
||||
record: &ersona_engine::gates::decision::GateDecisionRecord,
|
||||
) -> Result<()> {
|
||||
let timestamp_ms = chrono::Utc::now().timestamp_millis();
|
||||
let key = format!("sop_gate_decision_{}_{timestamp_ms}", record.gate_id);
|
||||
let content = serde_json::to_string_pretty(record)?;
|
||||
self.memory.store(&key, &content, category(), None).await?;
|
||||
info!(
|
||||
gate_id = %record.gate_id,
|
||||
decision = %record.decision,
|
||||
"SOP audit: gate decision logged"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist (upsert) the current gate phase state.
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
pub async fn log_phase_state(&self, state: &ersona_core::state::PhaseState) -> Result<()> {
|
||||
let key = "sop_phase_state";
|
||||
let content = serde_json::to_string_pretty(state)?;
|
||||
self.memory.store(key, &content, category(), None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve a stored run by ID (if it exists in memory).
|
||||
pub async fn get_run(&self, run_id: &str) -> Result<Option<SopRun>> {
|
||||
let key = run_key(run_id);
|
||||
@@ -166,6 +139,7 @@ mod tests {
|
||||
completed_at: None,
|
||||
step_results: Vec::new(),
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+33
-9
@@ -24,7 +24,7 @@ pub enum DispatchResult {
|
||||
Started {
|
||||
run_id: String,
|
||||
sop_name: String,
|
||||
action: SopRunAction,
|
||||
action: Box<SopRunAction>,
|
||||
},
|
||||
/// A matching SOP was found but could not start (cooldown / concurrency).
|
||||
Skipped { sop_name: String, reason: String },
|
||||
@@ -39,6 +39,8 @@ fn extract_run_id_from_action(action: &SopRunAction) -> &str {
|
||||
match action {
|
||||
SopRunAction::ExecuteStep { run_id, .. }
|
||||
| SopRunAction::WaitApproval { run_id, .. }
|
||||
| SopRunAction::DeterministicStep { run_id, .. }
|
||||
| SopRunAction::CheckpointWait { run_id, .. }
|
||||
| SopRunAction::Completed { run_id, .. }
|
||||
| SopRunAction::Failed { run_id, .. } => run_id,
|
||||
}
|
||||
@@ -49,6 +51,8 @@ fn action_label(action: &SopRunAction) -> &'static str {
|
||||
match action {
|
||||
SopRunAction::ExecuteStep { .. } => "ExecuteStep",
|
||||
SopRunAction::WaitApproval { .. } => "WaitApproval",
|
||||
SopRunAction::DeterministicStep { .. } => "DeterministicStep",
|
||||
SopRunAction::CheckpointWait { .. } => "CheckpointWait",
|
||||
SopRunAction::Completed { .. } => "Completed",
|
||||
SopRunAction::Failed { .. } => "Failed",
|
||||
}
|
||||
@@ -62,7 +66,6 @@ fn action_label(action: &SopRunAction) -> &'static str {
|
||||
/// 1. Lock → `match_trigger` → collect SOP names → drop lock
|
||||
/// 2. Lock → for each name: `start_run` → collect results → drop lock
|
||||
/// 3. Async (no lock): audit each started run
|
||||
#[tracing::instrument(skip(engine, audit), fields(source = %event.source, topic = ?event.topic))]
|
||||
pub async fn dispatch_sop_event(
|
||||
engine: &Arc<Mutex<SopEngine>>,
|
||||
audit: &SopAuditLogger,
|
||||
@@ -124,7 +127,7 @@ pub async fn dispatch_sop_event(
|
||||
results.push(DispatchResult::Started {
|
||||
run_id,
|
||||
sop_name: sop_name.clone(),
|
||||
action,
|
||||
action: Box::new(action),
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -158,14 +161,14 @@ pub async fn dispatch_sop_event(
|
||||
/// approval timeout polling in the scheduler handles progression.
|
||||
/// For `ExecuteStep` actions, the run is started in the engine but steps
|
||||
/// cannot be executed without an agent loop — this is logged as a warning.
|
||||
pub async fn process_headless_results(results: &[DispatchResult]) {
|
||||
pub fn process_headless_results(results: &[DispatchResult]) {
|
||||
for result in results {
|
||||
match result {
|
||||
DispatchResult::Started {
|
||||
run_id,
|
||||
sop_name,
|
||||
action,
|
||||
} => match action {
|
||||
} => match action.as_ref() {
|
||||
SopRunAction::ExecuteStep { step, .. } => {
|
||||
warn!(
|
||||
"SOP headless dispatch: run {run_id} ('{sop_name}') ready for step {} \
|
||||
@@ -180,6 +183,24 @@ pub async fn process_headless_results(results: &[DispatchResult]) {
|
||||
step.number, step.title,
|
||||
);
|
||||
}
|
||||
SopRunAction::DeterministicStep { step, .. } => {
|
||||
info!(
|
||||
"SOP headless dispatch: run {run_id} ('{sop_name}') deterministic step {} \
|
||||
'{}'",
|
||||
step.number, step.title,
|
||||
);
|
||||
}
|
||||
SopRunAction::CheckpointWait {
|
||||
step, state_file, ..
|
||||
} => {
|
||||
info!(
|
||||
"SOP headless dispatch: run {run_id} ('{sop_name}') checkpoint at step {} \
|
||||
'{}', state persisted to {}",
|
||||
step.number,
|
||||
step.title,
|
||||
state_file.display(),
|
||||
);
|
||||
}
|
||||
SopRunAction::Completed { .. } => {
|
||||
info!(
|
||||
"SOP headless dispatch: run {run_id} ('{sop_name}') completed immediately"
|
||||
@@ -250,7 +271,7 @@ impl SopCronCache {
|
||||
for trigger in &sop.triggers {
|
||||
if let super::types::SopTrigger::Cron { expression } = trigger {
|
||||
// Normalize 5-field crontab to 6-field (prepend seconds)
|
||||
let normalized = match crate::cron::schedule::normalize_expression(expression) {
|
||||
let normalized = match crate::cron::normalize_expression(expression) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
@@ -349,10 +370,13 @@ mod tests {
|
||||
body: "Do step one".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: crate::sop::SopStepKind::default(),
|
||||
schema: None,
|
||||
}],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 2,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,7 +420,7 @@ mod tests {
|
||||
let results = dispatch_sop_event(&engine, &audit, event).await;
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(
|
||||
matches!(&results[0], DispatchResult::Started { sop_name, action, .. } if sop_name == "mqtt-sop" && matches!(action, SopRunAction::ExecuteStep { .. }))
|
||||
matches!(&results[0], DispatchResult::Started { sop_name, action, .. } if sop_name == "mqtt-sop" && matches!(action.as_ref(), SopRunAction::ExecuteStep { .. }))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -534,7 +558,7 @@ mod tests {
|
||||
assert_eq!(sop_name, "supervised-sop");
|
||||
assert!(!run_id.is_empty());
|
||||
assert!(
|
||||
matches!(action, SopRunAction::WaitApproval { .. }),
|
||||
matches!(action.as_ref(), SopRunAction::WaitApproval { .. }),
|
||||
"Supervised SOP must return WaitApproval, got {:?}",
|
||||
action
|
||||
);
|
||||
@@ -561,7 +585,7 @@ mod tests {
|
||||
match &results[0] {
|
||||
DispatchResult::Started { action, .. } => {
|
||||
assert!(
|
||||
matches!(action, SopRunAction::ExecuteStep { .. }),
|
||||
matches!(action.as_ref(), SopRunAction::ExecuteStep { .. }),
|
||||
"Auto SOP must return ExecuteStep, got {:?}",
|
||||
action
|
||||
);
|
||||
|
||||
+303
-15
@@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use tracing::{info, warn};
|
||||
@@ -8,8 +8,9 @@ use tracing::{info, warn};
|
||||
use super::condition::evaluate_condition;
|
||||
use super::load_sops;
|
||||
use super::types::{
|
||||
Sop, SopEvent, SopPriority, SopRun, SopRunAction, SopRunStatus, SopStep, SopStepResult,
|
||||
SopStepStatus, SopTrigger, SopTriggerSource,
|
||||
DeterministicRunState, DeterministicSavings, Sop, SopEvent, SopExecutionMode, SopPriority,
|
||||
SopRun, SopRunAction, SopRunStatus, SopStep, SopStepKind, SopStepResult, SopStepStatus,
|
||||
SopTrigger, SopTriggerSource,
|
||||
};
|
||||
use crate::config::SopConfig;
|
||||
|
||||
@@ -21,6 +22,8 @@ pub struct SopEngine {
|
||||
finished_runs: Vec<SopRun>,
|
||||
config: SopConfig,
|
||||
run_counter: u64,
|
||||
/// Cumulative savings from deterministic execution.
|
||||
deterministic_savings: DeterministicSavings,
|
||||
}
|
||||
|
||||
impl SopEngine {
|
||||
@@ -32,6 +35,7 @@ impl SopEngine {
|
||||
finished_runs: Vec::new(),
|
||||
config,
|
||||
run_counter: 0,
|
||||
deterministic_savings: DeterministicSavings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +44,7 @@ impl SopEngine {
|
||||
self.sops = load_sops(
|
||||
workspace_dir,
|
||||
self.config.sops_dir.as_deref(),
|
||||
self.config.default_execution_mode,
|
||||
super::parse_execution_mode(&self.config.default_execution_mode),
|
||||
);
|
||||
info!("SOP engine loaded {} SOPs", self.sops.len());
|
||||
}
|
||||
@@ -118,7 +122,15 @@ impl SopEngine {
|
||||
}
|
||||
|
||||
/// Start a new SOP run. Returns the first action to take.
|
||||
/// Deterministic SOPs are automatically routed to `start_deterministic_run`.
|
||||
pub fn start_run(&mut self, sop_name: &str, event: SopEvent) -> Result<SopRunAction> {
|
||||
// Route deterministic SOPs to dedicated path
|
||||
if self.get_sop(sop_name).map_or(false, |s| {
|
||||
s.execution_mode == SopExecutionMode::Deterministic
|
||||
}) {
|
||||
return self.start_deterministic_run(sop_name, event);
|
||||
}
|
||||
|
||||
let sop = self
|
||||
.get_sop(sop_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("SOP not found: {sop_name}"))?
|
||||
@@ -154,6 +166,7 @@ impl SopEngine {
|
||||
completed_at: None,
|
||||
step_results: Vec::new(),
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
|
||||
self.active_runs.insert(run_id.clone(), run);
|
||||
@@ -283,6 +296,273 @@ impl SopEngine {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return cumulative deterministic execution savings.
|
||||
pub fn deterministic_savings(&self) -> &DeterministicSavings {
|
||||
&self.deterministic_savings
|
||||
}
|
||||
|
||||
// ── Deterministic execution ─────────────────────────────────
|
||||
|
||||
/// Start a deterministic SOP run. Steps execute sequentially without LLM
|
||||
/// round-trips. Returns the first action (DeterministicStep or CheckpointWait).
|
||||
pub fn start_deterministic_run(
|
||||
&mut self,
|
||||
sop_name: &str,
|
||||
event: SopEvent,
|
||||
) -> Result<SopRunAction> {
|
||||
let sop = self
|
||||
.get_sop(sop_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("SOP not found: {sop_name}"))?
|
||||
.clone();
|
||||
|
||||
if sop.execution_mode != SopExecutionMode::Deterministic {
|
||||
bail!(
|
||||
"SOP '{}' is not in deterministic mode (mode: {})",
|
||||
sop_name,
|
||||
sop.execution_mode
|
||||
);
|
||||
}
|
||||
|
||||
if !self.can_start(sop_name) {
|
||||
bail!(
|
||||
"Cannot start SOP '{}': cooldown or concurrency limit reached",
|
||||
sop_name
|
||||
);
|
||||
}
|
||||
|
||||
if sop.steps.is_empty() {
|
||||
bail!("SOP '{}' has no steps defined", sop_name);
|
||||
}
|
||||
|
||||
self.run_counter += 1;
|
||||
let dur = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
let epoch_ms = dur.as_secs() * 1000 + u64::from(dur.subsec_millis());
|
||||
let run_id = format!("det-{epoch_ms}-{:04}", self.run_counter);
|
||||
let now = now_iso8601();
|
||||
|
||||
let total_steps = u32::try_from(sop.steps.len()).unwrap_or(u32::MAX);
|
||||
let run = SopRun {
|
||||
run_id: run_id.clone(),
|
||||
sop_name: sop_name.to_string(),
|
||||
trigger_event: event,
|
||||
status: SopRunStatus::Running,
|
||||
current_step: 1,
|
||||
total_steps,
|
||||
started_at: now,
|
||||
completed_at: None,
|
||||
step_results: Vec::new(),
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
|
||||
self.active_runs.insert(run_id.clone(), run);
|
||||
info!(
|
||||
"Deterministic SOP run {} started for '{}'",
|
||||
run_id, sop_name
|
||||
);
|
||||
|
||||
// Produce first step action
|
||||
let step = sop.steps[0].clone();
|
||||
let input = serde_json::Value::Null;
|
||||
self.resolve_deterministic_action(&sop, &run_id, &step, input)
|
||||
}
|
||||
|
||||
/// Advance a deterministic run with the output of the current step.
|
||||
/// The output is piped as input to the next step.
|
||||
pub fn advance_deterministic_step(
|
||||
&mut self,
|
||||
run_id: &str,
|
||||
step_output: serde_json::Value,
|
||||
) -> Result<SopRunAction> {
|
||||
let run = self
|
||||
.active_runs
|
||||
.get_mut(run_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Active run not found: {run_id}"))?;
|
||||
|
||||
let sop = self
|
||||
.sops
|
||||
.iter()
|
||||
.find(|s| s.name == run.sop_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("SOP '{}' no longer loaded", run.sop_name))?
|
||||
.clone();
|
||||
|
||||
// Record step result
|
||||
let now = now_iso8601();
|
||||
let step_result = SopStepResult {
|
||||
step_number: run.current_step,
|
||||
status: SopStepStatus::Completed,
|
||||
output: step_output.to_string(),
|
||||
started_at: run.started_at.clone(),
|
||||
completed_at: Some(now),
|
||||
};
|
||||
run.step_results.push(step_result);
|
||||
|
||||
// Each deterministic step saves one LLM call
|
||||
run.llm_calls_saved += 1;
|
||||
|
||||
// Advance to next step
|
||||
let next_step_num = run.current_step + 1;
|
||||
if next_step_num > run.total_steps {
|
||||
info!(
|
||||
"Deterministic SOP run {run_id} completed ({} LLM calls saved)",
|
||||
run.llm_calls_saved
|
||||
);
|
||||
let saved = run.llm_calls_saved;
|
||||
self.deterministic_savings.total_llm_calls_saved += saved;
|
||||
self.deterministic_savings.total_runs += 1;
|
||||
return Ok(self.finish_run(run_id, SopRunStatus::Completed, None));
|
||||
}
|
||||
|
||||
let run = self.active_runs.get_mut(run_id).unwrap();
|
||||
run.current_step = next_step_num;
|
||||
|
||||
let step_idx = (next_step_num - 1) as usize;
|
||||
let step = sop.steps[step_idx].clone();
|
||||
let run_id_owned = run_id.to_string();
|
||||
|
||||
self.resolve_deterministic_action(&sop, &run_id_owned, &step, step_output)
|
||||
}
|
||||
|
||||
/// Resume a deterministic run from persisted state.
|
||||
pub fn resume_deterministic_run(
|
||||
&mut self,
|
||||
state: DeterministicRunState,
|
||||
) -> Result<SopRunAction> {
|
||||
let run = self
|
||||
.active_runs
|
||||
.get_mut(&state.run_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Active run not found: {}", state.run_id))?;
|
||||
|
||||
if run.status != SopRunStatus::PausedCheckpoint {
|
||||
bail!(
|
||||
"Run {} is not paused at checkpoint (status: {})",
|
||||
state.run_id,
|
||||
run.status
|
||||
);
|
||||
}
|
||||
|
||||
let sop = self
|
||||
.sops
|
||||
.iter()
|
||||
.find(|s| s.name == run.sop_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("SOP '{}' no longer loaded", run.sop_name))?
|
||||
.clone();
|
||||
|
||||
run.status = SopRunStatus::Running;
|
||||
run.waiting_since = None;
|
||||
run.llm_calls_saved = state.llm_calls_saved;
|
||||
|
||||
// Resume from the step after the last completed one
|
||||
let next_step_num = state.last_completed_step + 1;
|
||||
if next_step_num > state.total_steps {
|
||||
info!(
|
||||
"Deterministic SOP run {} completed on resume ({} LLM calls saved)",
|
||||
state.run_id, state.llm_calls_saved
|
||||
);
|
||||
self.deterministic_savings.total_llm_calls_saved += state.llm_calls_saved;
|
||||
self.deterministic_savings.total_runs += 1;
|
||||
return Ok(self.finish_run(&state.run_id, SopRunStatus::Completed, None));
|
||||
}
|
||||
|
||||
let run = self.active_runs.get_mut(&state.run_id).unwrap();
|
||||
run.current_step = next_step_num;
|
||||
|
||||
let step_idx = (next_step_num - 1) as usize;
|
||||
let step = sop.steps[step_idx].clone();
|
||||
|
||||
// Use last step's output as input, or Null
|
||||
let last_output = state
|
||||
.step_outputs
|
||||
.get(&state.last_completed_step)
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
let run_id = state.run_id.clone();
|
||||
self.resolve_deterministic_action(&sop, &run_id, &step, last_output)
|
||||
}
|
||||
|
||||
/// Resolve the action for a deterministic step (execute or checkpoint).
|
||||
fn resolve_deterministic_action(
|
||||
&mut self,
|
||||
sop: &Sop,
|
||||
run_id: &str,
|
||||
step: &SopStep,
|
||||
input: serde_json::Value,
|
||||
) -> Result<SopRunAction> {
|
||||
if step.kind == SopStepKind::Checkpoint {
|
||||
// Pause at checkpoint — persist state and wait for approval
|
||||
if let Some(run) = self.active_runs.get_mut(run_id) {
|
||||
run.status = SopRunStatus::PausedCheckpoint;
|
||||
run.waiting_since = Some(now_iso8601());
|
||||
}
|
||||
|
||||
let state_file = self.persist_deterministic_state(run_id, sop)?;
|
||||
|
||||
info!(
|
||||
"Deterministic SOP run {run_id}: checkpoint at step {} '{}', state persisted to {}",
|
||||
step.number,
|
||||
step.title,
|
||||
state_file.display()
|
||||
);
|
||||
|
||||
Ok(SopRunAction::CheckpointWait {
|
||||
run_id: run_id.to_string(),
|
||||
step: step.clone(),
|
||||
state_file,
|
||||
})
|
||||
} else {
|
||||
Ok(SopRunAction::DeterministicStep {
|
||||
run_id: run_id.to_string(),
|
||||
step: step.clone(),
|
||||
input,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Persist the current deterministic run state to a JSON file.
|
||||
fn persist_deterministic_state(&self, run_id: &str, sop: &Sop) -> Result<PathBuf> {
|
||||
let run = self
|
||||
.active_runs
|
||||
.get(run_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Run not found: {run_id}"))?;
|
||||
|
||||
let mut step_outputs = HashMap::new();
|
||||
for result in &run.step_results {
|
||||
// Try to parse output as JSON, fall back to string value
|
||||
let value = serde_json::from_str(&result.output)
|
||||
.unwrap_or_else(|_| serde_json::Value::String(result.output.clone()));
|
||||
step_outputs.insert(result.step_number, value);
|
||||
}
|
||||
|
||||
let state = DeterministicRunState {
|
||||
run_id: run_id.to_string(),
|
||||
sop_name: run.sop_name.clone(),
|
||||
last_completed_step: run.current_step.saturating_sub(1),
|
||||
total_steps: run.total_steps,
|
||||
step_outputs,
|
||||
persisted_at: now_iso8601(),
|
||||
llm_calls_saved: run.llm_calls_saved,
|
||||
paused_at_checkpoint: run.status == SopRunStatus::PausedCheckpoint,
|
||||
};
|
||||
|
||||
// Write to SOP location directory, or temp dir
|
||||
let dir = sop.location.as_deref().unwrap_or_else(|| Path::new("."));
|
||||
let state_file = dir.join(format!("{run_id}.state.json"));
|
||||
let json = serde_json::to_string_pretty(&state)?;
|
||||
std::fs::write(&state_file, json)?;
|
||||
|
||||
Ok(state_file)
|
||||
}
|
||||
|
||||
/// Load a persisted deterministic run state from a JSON file.
|
||||
pub fn load_deterministic_state(path: &Path) -> Result<DeterministicRunState> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let state: DeterministicRunState = serde_json::from_str(&content)?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
// ── Approval timeout ──────────────────────────────────────────
|
||||
|
||||
/// Check all WaitingApproval runs for timeout. For Critical/High-priority SOPs,
|
||||
@@ -487,21 +767,21 @@ fn resolve_step_action(sop: &Sop, step: &SopStep, run_id: String, context: Strin
|
||||
}
|
||||
|
||||
let needs_approval = match sop.execution_mode {
|
||||
crate::sop::SopExecutionMode::Auto => false,
|
||||
crate::sop::SopExecutionMode::Supervised => {
|
||||
// Deterministic mode is handled via start_deterministic_run;
|
||||
// if we reach here via the standard path, treat as Auto.
|
||||
SopExecutionMode::Auto | SopExecutionMode::Deterministic => false,
|
||||
SopExecutionMode::Supervised => {
|
||||
// Supervised: approval only before the first step
|
||||
step.number == 1
|
||||
}
|
||||
crate::sop::SopExecutionMode::StepByStep => true,
|
||||
crate::sop::SopExecutionMode::PriorityBased => {
|
||||
match sop.priority {
|
||||
SopPriority::Critical | SopPriority::High => false,
|
||||
SopPriority::Normal | SopPriority::Low => {
|
||||
// Supervised behavior for normal/low
|
||||
step.number == 1
|
||||
}
|
||||
SopExecutionMode::StepByStep => true,
|
||||
SopExecutionMode::PriorityBased => match sop.priority {
|
||||
SopPriority::Critical | SopPriority::High => false,
|
||||
SopPriority::Normal | SopPriority::Low => {
|
||||
// Supervised behavior for normal/low
|
||||
step.number == 1
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if needs_approval {
|
||||
@@ -680,6 +960,8 @@ mod tests {
|
||||
body: "Do step one".into(),
|
||||
suggested_tools: vec!["shell".into()],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
SopStep {
|
||||
number: 2,
|
||||
@@ -687,11 +969,14 @@ mod tests {
|
||||
body: "Do step two".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -706,6 +991,8 @@ mod tests {
|
||||
match action {
|
||||
SopRunAction::ExecuteStep { run_id, .. }
|
||||
| SopRunAction::WaitApproval { run_id, .. }
|
||||
| SopRunAction::DeterministicStep { run_id, .. }
|
||||
| SopRunAction::CheckpointWait { run_id, .. }
|
||||
| SopRunAction::Completed { run_id, .. }
|
||||
| SopRunAction::Failed { run_id, .. } => run_id,
|
||||
}
|
||||
@@ -1359,6 +1646,7 @@ mod tests {
|
||||
completed_at: None,
|
||||
step_results: Vec::new(),
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
let ctx = format_step_context(&sop, &run, &sop.steps[0]);
|
||||
assert!(ctx.contains("pump-shutdown"));
|
||||
|
||||
@@ -1,746 +0,0 @@
|
||||
//! Gate evaluation state for ampersona trust-phase transitions.
|
||||
//!
|
||||
//! This module is only compiled when the `ampersona-gates` feature is active
|
||||
//! (module declaration in `mod.rs` is behind `#[cfg]`).
|
||||
//!
|
||||
//! Gate decisions do NOT change SOP execution behavior — this is purely
|
||||
//! observation + phase state tracking + audit logging.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use ampersona_core::spec::gates::Gate;
|
||||
use ampersona_core::state::{PendingTransition, PhaseState, TransitionRecord};
|
||||
use ampersona_core::traits::MetricsProvider;
|
||||
use ampersona_engine::gates::decision::GateDecisionRecord;
|
||||
use ampersona_engine::gates::evaluator::DefaultGateEvaluator;
|
||||
use anyhow::Result;
|
||||
use chrono::Utc;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
|
||||
const PHASE_STATE_KEY: &str = "sop_phase_state";
|
||||
|
||||
fn sop_category() -> MemoryCategory {
|
||||
MemoryCategory::Custom("sop".into())
|
||||
}
|
||||
|
||||
// ── Inner state ────────────────────────────────────────────────
|
||||
|
||||
struct GateEvalInner {
|
||||
phase_state: PhaseState,
|
||||
last_tick: Instant,
|
||||
}
|
||||
|
||||
// ── GateEvalState ──────────────────────────────────────────────
|
||||
|
||||
/// Manages trust-phase gate evaluation state.
|
||||
///
|
||||
/// Single `Mutex<GateEvalInner>` ensures atomic interval-check + evaluate + apply.
|
||||
/// `DefaultGateEvaluator` is a unit struct — called inline, not stored.
|
||||
pub struct GateEvalState {
|
||||
inner: Mutex<GateEvalInner>,
|
||||
memory: Arc<dyn Memory>,
|
||||
gates: Vec<Gate>,
|
||||
tick_interval: Duration,
|
||||
}
|
||||
|
||||
impl GateEvalState {
|
||||
/// Create with fresh (default) phase state.
|
||||
pub fn new(
|
||||
agent_name: &str,
|
||||
gates: Vec<Gate>,
|
||||
interval_secs: u64,
|
||||
memory: Arc<dyn Memory>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Mutex::new(GateEvalInner {
|
||||
phase_state: PhaseState::new(agent_name.to_string()),
|
||||
last_tick: Instant::now(),
|
||||
}),
|
||||
memory,
|
||||
gates,
|
||||
tick_interval: Duration::from_secs(interval_secs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a known phase state (warm-start).
|
||||
pub fn with_state(
|
||||
state: PhaseState,
|
||||
gates: Vec<Gate>,
|
||||
interval_secs: u64,
|
||||
memory: Arc<dyn Memory>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Mutex::new(GateEvalInner {
|
||||
phase_state: state,
|
||||
last_tick: Instant::now(),
|
||||
}),
|
||||
memory,
|
||||
gates,
|
||||
tick_interval: Duration::from_secs(interval_secs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load gate definitions from a persona JSON file.
|
||||
///
|
||||
/// Expects `{"gates": [...]}` at the top level. Missing file → empty Vec.
|
||||
/// Parse error → warn log + empty Vec.
|
||||
pub fn load_gates_from_file(path: &Path) -> Vec<Gate> {
|
||||
let content = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PersonaGates {
|
||||
#[serde(default)]
|
||||
gates: Vec<Gate>,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<PersonaGates>(&content) {
|
||||
Ok(parsed) => parsed.gates,
|
||||
Err(e) => {
|
||||
warn!(path = %path.display(), error = %e, "failed to parse gates from persona file");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rebuild from Memory backend (warm-start).
|
||||
///
|
||||
/// Loads `PhaseState` from Memory key `sop_phase_state`, loads gates from
|
||||
/// file, falls back to fresh state on parse error.
|
||||
pub async fn rebuild_from_memory(
|
||||
memory: Arc<dyn Memory>,
|
||||
agent_name: &str,
|
||||
gates_file: Option<&Path>,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self> {
|
||||
let gates = gates_file
|
||||
.map(Self::load_gates_from_file)
|
||||
.unwrap_or_default();
|
||||
|
||||
let phase_state = match memory.get(PHASE_STATE_KEY).await? {
|
||||
Some(entry) => match serde_json::from_str::<PhaseState>(&entry.content) {
|
||||
Ok(state) => {
|
||||
info!(
|
||||
phase = ?state.current_phase,
|
||||
rev = state.state_rev,
|
||||
"gate eval warm-started from memory"
|
||||
);
|
||||
state
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "failed to parse phase state from memory, using fresh state");
|
||||
PhaseState::new(agent_name.to_string())
|
||||
}
|
||||
},
|
||||
None => PhaseState::new(agent_name.to_string()),
|
||||
};
|
||||
|
||||
Ok(Self::with_state(phase_state, gates, interval_secs, memory))
|
||||
}
|
||||
|
||||
/// Atomic tick: interval check + evaluate + apply under single lock.
|
||||
///
|
||||
/// Returns `Some(record)` if a gate fired, `None` otherwise.
|
||||
pub fn tick(&self, metrics: &dyn MetricsProvider) -> Option<GateDecisionRecord> {
|
||||
let _span = tracing::info_span!("gate_eval_tick", gates = self.gates.len()).entered();
|
||||
|
||||
// interval_secs=0 means disabled
|
||||
if self.tick_interval.is_zero() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.inner.is_poisoned() {
|
||||
error!("gate eval mutex poisoned — loss of gate evaluation until restart");
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut inner = self.inner.lock().ok()?;
|
||||
|
||||
// Check interval
|
||||
if inner.last_tick.elapsed() < self.tick_interval {
|
||||
return None;
|
||||
}
|
||||
inner.last_tick = Instant::now();
|
||||
|
||||
// Evaluate
|
||||
let record = DefaultGateEvaluator.evaluate(&self.gates, &inner.phase_state, metrics);
|
||||
|
||||
match record {
|
||||
Some(ref record) => {
|
||||
// Apply decision in-place under the same lock
|
||||
apply_decision(&mut inner.phase_state, record);
|
||||
info!(
|
||||
gate_id = %record.gate_id,
|
||||
decision = %record.decision,
|
||||
from = ?record.from_phase,
|
||||
to = %record.to_phase,
|
||||
"gate decision"
|
||||
);
|
||||
}
|
||||
None => {
|
||||
debug!("no gate fired");
|
||||
}
|
||||
}
|
||||
|
||||
record
|
||||
}
|
||||
|
||||
/// Persist current phase state to Memory.
|
||||
pub async fn persist(&self) -> Result<()> {
|
||||
let content = {
|
||||
let inner = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("gate eval lock poisoned: {e}"))?;
|
||||
serde_json::to_string_pretty(&inner.phase_state)?
|
||||
};
|
||||
self.memory
|
||||
.store(PHASE_STATE_KEY, &content, sop_category(), None)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Snapshot of current phase state (for diagnostics / sop_status).
|
||||
pub fn phase_state_snapshot(&self) -> Option<PhaseState> {
|
||||
self.inner.lock().ok().map(|g| g.phase_state.clone())
|
||||
}
|
||||
|
||||
/// Number of loaded gate definitions.
|
||||
pub fn gate_count(&self) -> usize {
|
||||
self.gates.len()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Decision application ───────────────────────────────────────
|
||||
|
||||
fn apply_decision(state: &mut PhaseState, record: &GateDecisionRecord) {
|
||||
match record.decision.as_str() {
|
||||
"transition" => {
|
||||
state.current_phase = Some(record.to_phase.clone());
|
||||
state.state_rev += 1;
|
||||
state.last_transition = Some(TransitionRecord {
|
||||
gate_id: record.gate_id.clone(),
|
||||
from_phase: record.from_phase.clone(),
|
||||
to_phase: record.to_phase.clone(),
|
||||
at: Utc::now(),
|
||||
decision_id: format!(
|
||||
"{}-{}-{}",
|
||||
record.gate_id, record.state_rev, record.metrics_hash
|
||||
),
|
||||
metrics_hash: Some(record.metrics_hash.clone()),
|
||||
state_rev: state.state_rev,
|
||||
});
|
||||
state.pending_transition = None;
|
||||
state.updated_at = Utc::now();
|
||||
}
|
||||
"observed" => {
|
||||
debug!(
|
||||
gate_id = %record.gate_id,
|
||||
"observed gate — no state change"
|
||||
);
|
||||
}
|
||||
"pending_human" => {
|
||||
state.pending_transition = Some(PendingTransition {
|
||||
gate_id: record.gate_id.clone(),
|
||||
from_phase: record.from_phase.clone(),
|
||||
to_phase: record.to_phase.clone(),
|
||||
decision: record.decision.clone(),
|
||||
metrics_hash: record.metrics_hash.clone(),
|
||||
state_rev: record.state_rev,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
state.updated_at = Utc::now();
|
||||
}
|
||||
other => {
|
||||
warn!(decision = %other, gate_id = %record.gate_id, "unknown gate decision — skipping");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ampersona_core::errors::MetricError;
|
||||
use ampersona_core::spec::gates::Gate;
|
||||
use ampersona_core::traits::{MetricQuery, MetricSample};
|
||||
use ampersona_core::types::{CriterionOp, GateApproval, GateDirection, GateEnforcement};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ── Mock MetricsProvider ──────────────────────────────────
|
||||
|
||||
struct MockMetrics {
|
||||
values: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl MockMetrics {
|
||||
fn new(values: Vec<(&str, serde_json::Value)>) -> Self {
|
||||
Self {
|
||||
values: values
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricsProvider for MockMetrics {
|
||||
fn get_metric(&self, query: &MetricQuery) -> Result<MetricSample, MetricError> {
|
||||
self.values
|
||||
.get(&query.name)
|
||||
.cloned()
|
||||
.map(|value| MetricSample {
|
||||
name: query.name.clone(),
|
||||
value,
|
||||
sampled_at: Utc::now(),
|
||||
})
|
||||
.ok_or_else(|| MetricError::NotFound(query.name.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────
|
||||
|
||||
fn make_promote_gate(
|
||||
id: &str,
|
||||
metric: &str,
|
||||
op: CriterionOp,
|
||||
value: serde_json::Value,
|
||||
to_phase: &str,
|
||||
) -> Gate {
|
||||
Gate {
|
||||
id: id.into(),
|
||||
direction: GateDirection::Promote,
|
||||
enforcement: GateEnforcement::Enforce,
|
||||
priority: 0,
|
||||
cooldown_seconds: 0,
|
||||
from_phase: None,
|
||||
to_phase: to_phase.into(),
|
||||
criteria: vec![ampersona_core::spec::gates::Criterion {
|
||||
metric: metric.into(),
|
||||
op,
|
||||
value,
|
||||
window_seconds: None,
|
||||
}],
|
||||
metrics_schema: None,
|
||||
approval: GateApproval::Auto,
|
||||
on_pass: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_memory() -> Arc<dyn Memory> {
|
||||
let mem_cfg = crate::config::MemoryConfig {
|
||||
backend: "sqlite".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap())
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tick_no_gates_returns_none() {
|
||||
let mem = test_memory();
|
||||
let ge = GateEvalState::new("test-agent", vec![], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![]);
|
||||
// Force past interval
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
assert!(ge.tick(&metrics).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tick_with_passing_gate_returns_decision() {
|
||||
let mem = test_memory();
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.9))]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let record = ge.tick(&metrics);
|
||||
assert!(record.is_some());
|
||||
let record = record.unwrap();
|
||||
assert_eq!(record.gate_id, "g1");
|
||||
assert_eq!(record.to_phase, "active");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tick_transition_advances_phase() {
|
||||
let mem = test_memory();
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
ge.tick(&metrics);
|
||||
|
||||
let snap = ge.phase_state_snapshot().unwrap();
|
||||
assert_eq!(snap.current_phase, Some("active".into()));
|
||||
assert!(snap.state_rev > 0);
|
||||
assert!(snap.last_transition.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tick_observed_no_state_change() {
|
||||
let mem = test_memory();
|
||||
let mut gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
gate.enforcement = GateEnforcement::Observe;
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let record = ge.tick(&metrics);
|
||||
assert!(record.is_some());
|
||||
assert_eq!(record.unwrap().decision, "observed");
|
||||
|
||||
let snap = ge.phase_state_snapshot().unwrap();
|
||||
assert!(snap.current_phase.is_none()); // no change
|
||||
assert_eq!(snap.state_rev, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tick_pending_human_sets_pending() {
|
||||
let mem = test_memory();
|
||||
let mut gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
gate.approval = GateApproval::Human;
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let record = ge.tick(&metrics);
|
||||
assert!(record.is_some());
|
||||
assert_eq!(record.unwrap().decision, "pending_human");
|
||||
|
||||
let snap = ge.phase_state_snapshot().unwrap();
|
||||
assert!(snap.pending_transition.is_some());
|
||||
assert_eq!(snap.pending_transition.unwrap().to_phase, "active");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_gates_missing_file_returns_empty() {
|
||||
let gates = GateEvalState::load_gates_from_file(Path::new("/nonexistent/persona.json"));
|
||||
assert!(gates.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_gates_valid_persona() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("persona.json");
|
||||
std::fs::write(
|
||||
&path,
|
||||
r#"{
|
||||
"gates": [{
|
||||
"id": "g1",
|
||||
"direction": "promote",
|
||||
"to_phase": "active",
|
||||
"criteria": [{"metric": "sop.completion_rate", "op": "gte", "value": 0.8}]
|
||||
}]
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let gates = GateEvalState::load_gates_from_file(&path);
|
||||
assert_eq!(gates.len(), 1);
|
||||
assert_eq!(gates[0].id, "g1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_gates_no_gates_key_returns_empty() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("persona.json");
|
||||
std::fs::write(&path, r#"{"name": "test"}"#).unwrap();
|
||||
let gates = GateEvalState::load_gates_from_file(&path);
|
||||
assert!(gates.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_gates_invalid_json_returns_empty() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("persona.json");
|
||||
std::fs::write(&path, "not json at all {{{").unwrap();
|
||||
let gates = GateEvalState::load_gates_from_file(&path);
|
||||
assert!(gates.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warm_start_roundtrip() {
|
||||
let mem = test_memory();
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
|
||||
// Create, tick to advance state, persist
|
||||
let ge = GateEvalState::new("test-agent", vec![gate.clone()], 1, Arc::clone(&mem));
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
ge.tick(&metrics);
|
||||
ge.persist().await.unwrap();
|
||||
|
||||
// Write gates file for rebuild
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let gates_path = dir.path().join("persona.json");
|
||||
std::fs::write(
|
||||
&gates_path,
|
||||
serde_json::to_string(&serde_json::json!({"gates": [gate]})).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Rebuild
|
||||
let ge2 = GateEvalState::rebuild_from_memory(
|
||||
Arc::clone(&mem),
|
||||
"test-agent",
|
||||
Some(gates_path.as_path()),
|
||||
1,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let snap = ge2.phase_state_snapshot().unwrap();
|
||||
assert_eq!(snap.current_phase, Some("active".into()));
|
||||
assert!(snap.state_rev > 0);
|
||||
assert_eq!(ge2.gate_count(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warm_start_empty_memory() {
|
||||
let mem = test_memory();
|
||||
let ge = GateEvalState::rebuild_from_memory(Arc::clone(&mem), "test-agent", None, 60)
|
||||
.await
|
||||
.unwrap();
|
||||
let snap = ge.phase_state_snapshot().unwrap();
|
||||
assert!(snap.current_phase.is_none());
|
||||
assert_eq!(snap.state_rev, 0);
|
||||
assert_eq!(ge.gate_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn demote_priority_over_promote() {
|
||||
let mem = test_memory();
|
||||
let promote = make_promote_gate(
|
||||
"promote-g",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
let mut demote = make_promote_gate(
|
||||
"demote-g",
|
||||
"sop.deviation_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.3),
|
||||
"restricted",
|
||||
);
|
||||
demote.direction = GateDirection::Demote;
|
||||
demote.from_phase = Some("active".into());
|
||||
|
||||
let state = PhaseState {
|
||||
current_phase: Some("active".into()),
|
||||
..PhaseState::new("test-agent".into())
|
||||
};
|
||||
let ge = GateEvalState::with_state(state, vec![promote, demote], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![
|
||||
("sop.completion_rate", json!(0.95)),
|
||||
("sop.deviation_rate", json!(0.5)),
|
||||
]);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let record = ge.tick(&metrics).unwrap();
|
||||
// Demote should fire first (evaluator sorts demote before promote)
|
||||
assert_eq!(record.gate_id, "demote-g");
|
||||
assert_eq!(record.to_phase, "restricted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn idempotent_tick_after_apply() {
|
||||
let mem = test_memory();
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
|
||||
// First tick — fires
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let first = ge.tick(&metrics);
|
||||
assert!(first.is_some());
|
||||
|
||||
// Second tick with same metrics + updated state_rev — should not fire again
|
||||
// (evaluator idempotency via metrics_hash + state_rev)
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let second = ge.tick(&metrics);
|
||||
assert!(second.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gate_tick_with_real_collector() {
|
||||
use crate::sop::metrics::SopMetricsCollector;
|
||||
use crate::sop::types::{
|
||||
SopEvent, SopRun, SopRunStatus, SopStepResult, SopStepStatus, SopTriggerSource,
|
||||
};
|
||||
|
||||
let mem = test_memory();
|
||||
let collector = SopMetricsCollector::new();
|
||||
|
||||
// Record a completed run
|
||||
let run = SopRun {
|
||||
run_id: "r1".into(),
|
||||
sop_name: "test-sop".into(),
|
||||
trigger_event: SopEvent {
|
||||
source: SopTriggerSource::Manual,
|
||||
topic: None,
|
||||
payload: None,
|
||||
timestamp: "2026-02-19T12:00:00Z".into(),
|
||||
},
|
||||
status: SopRunStatus::Completed,
|
||||
current_step: 1,
|
||||
total_steps: 1,
|
||||
started_at: "2026-02-19T12:00:00Z".into(),
|
||||
completed_at: Some("2026-02-19T12:05:00Z".into()),
|
||||
step_results: vec![SopStepResult {
|
||||
step_number: 1,
|
||||
status: SopStepStatus::Completed,
|
||||
output: "done".into(),
|
||||
started_at: "2026-02-19T12:00:00Z".into(),
|
||||
completed_at: Some("2026-02-19T12:01:00Z".into()),
|
||||
}],
|
||||
waiting_since: None,
|
||||
};
|
||||
collector.record_run_complete(&run);
|
||||
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
let ge = GateEvalState::new("test-agent", vec![gate], 1, mem);
|
||||
{
|
||||
let mut inner = ge.inner.lock().unwrap();
|
||||
inner.last_tick = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();
|
||||
}
|
||||
let record = ge.tick(&collector);
|
||||
assert!(record.is_some());
|
||||
assert_eq!(record.unwrap().to_phase, "active");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tick_respects_interval() {
|
||||
let mem = test_memory();
|
||||
let gate = make_promote_gate(
|
||||
"g1",
|
||||
"sop.completion_rate",
|
||||
CriterionOp::Gte,
|
||||
json!(0.8),
|
||||
"active",
|
||||
);
|
||||
|
||||
// Long interval
|
||||
let ge = GateEvalState::new("test-agent", vec![gate.clone()], 3600, mem.clone());
|
||||
let metrics = MockMetrics::new(vec![("sop.completion_rate", json!(0.95))]);
|
||||
// last_tick is Instant::now() — not enough elapsed
|
||||
assert!(ge.tick(&metrics).is_none());
|
||||
|
||||
// Zero interval = disabled
|
||||
let ge_disabled = GateEvalState::new("test-agent", vec![gate], 0, mem);
|
||||
assert!(ge_disabled.tick(&metrics).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ampersona_decision_strings_stable() {
|
||||
// Canary test: verifies that DefaultGateEvaluator produces the decision
|
||||
// strings we expect. If ampersona changes them, this test fails.
|
||||
let state = PhaseState::new("test".into());
|
||||
|
||||
// Enforce promote → "transition"
|
||||
let enforce_gate =
|
||||
make_promote_gate("g-enforce", "m", CriterionOp::Gte, json!(1), "phase-b");
|
||||
let metrics = MockMetrics::new(vec![("m", json!(1))]);
|
||||
let record = DefaultGateEvaluator.evaluate(&[enforce_gate], &state, &metrics);
|
||||
assert_eq!(
|
||||
record.as_ref().map(|r| r.decision.as_str()),
|
||||
Some("transition")
|
||||
);
|
||||
|
||||
// Observe promote → "observed"
|
||||
let mut observe_gate =
|
||||
make_promote_gate("g-observe", "m", CriterionOp::Gte, json!(1), "phase-b");
|
||||
observe_gate.enforcement = GateEnforcement::Observe;
|
||||
let record = DefaultGateEvaluator.evaluate(&[observe_gate], &state, &metrics);
|
||||
assert_eq!(
|
||||
record.as_ref().map(|r| r.decision.as_str()),
|
||||
Some("observed")
|
||||
);
|
||||
|
||||
// RequireApproval promote → "pending_human"
|
||||
let mut approval_gate =
|
||||
make_promote_gate("g-approval", "m", CriterionOp::Gte, json!(1), "phase-b");
|
||||
approval_gate.approval = GateApproval::Human;
|
||||
let record = DefaultGateEvaluator.evaluate(&[approval_gate], &state, &metrics);
|
||||
assert_eq!(
|
||||
record.as_ref().map(|r| r.decision.as_str()),
|
||||
Some("pending_human")
|
||||
);
|
||||
}
|
||||
}
|
||||
+9
-96
@@ -413,34 +413,6 @@ impl Default for SopMetricsCollector {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Conditional MetricsProvider impl ───────────────────────────
|
||||
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
impl ampersona_core::traits::MetricsProvider for SopMetricsCollector {
|
||||
fn get_metric(
|
||||
&self,
|
||||
query: &ersona_core::traits::MetricQuery,
|
||||
) -> Result<ampersona_core::traits::MetricSample, ampersona_core::errors::MetricError> {
|
||||
if self.inner.is_poisoned() {
|
||||
return Err(ampersona_core::errors::MetricError::ProviderUnavailable);
|
||||
}
|
||||
let value = if let Some(ref window) = query.window {
|
||||
// Window specified by evaluator (from Criterion.window_seconds)
|
||||
self.get_metric_value_windowed(&query.name, window)
|
||||
} else {
|
||||
// No window — use name as-is (may include _7d/_30d suffix or be all-time)
|
||||
self.get_metric_value(&query.name)
|
||||
};
|
||||
value
|
||||
.map(|v| ampersona_core::traits::MetricSample {
|
||||
name: query.name.clone(),
|
||||
value: v,
|
||||
sampled_at: Utc::now(),
|
||||
})
|
||||
.ok_or_else(|| ampersona_core::errors::MetricError::NotFound(query.name.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────
|
||||
|
||||
fn build_snapshot(run: &SopRun, human_count: u64, timeout_count: u64) -> RunSnapshot {
|
||||
@@ -637,6 +609,9 @@ mod tests {
|
||||
total_steps: u32,
|
||||
step_results: Vec<SopStepResult>,
|
||||
) -> SopRun {
|
||||
let now = Utc::now();
|
||||
let started = (now - chrono::Duration::minutes(5)).to_rfc3339();
|
||||
let completed = now.to_rfc3339();
|
||||
SopRun {
|
||||
run_id: run_id.into(),
|
||||
sop_name: sop_name.into(),
|
||||
@@ -644,10 +619,11 @@ mod tests {
|
||||
status,
|
||||
current_step: total_steps,
|
||||
total_steps,
|
||||
started_at: "2026-02-19T12:00:00Z".into(),
|
||||
completed_at: Some("2026-02-19T12:05:00Z".into()),
|
||||
started_at: started,
|
||||
completed_at: Some(completed),
|
||||
step_results,
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1141,43 +1117,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// ── MetricsProvider impl (ampersona-gates feature) ───────
|
||||
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
#[test]
|
||||
fn metrics_provider_get_metric() {
|
||||
use ampersona_core::traits::{MetricQuery, MetricsProvider};
|
||||
|
||||
let c = SopMetricsCollector::new();
|
||||
let run = make_run(
|
||||
"r1",
|
||||
"test-sop",
|
||||
SopRunStatus::Completed,
|
||||
1,
|
||||
vec![make_step(1, SopStepStatus::Completed)],
|
||||
);
|
||||
c.record_run_complete(&run);
|
||||
|
||||
let query = MetricQuery {
|
||||
name: "sop.runs_completed".into(),
|
||||
window: None,
|
||||
};
|
||||
let sample = c.get_metric(&query).unwrap();
|
||||
assert_eq!(sample.value, json!(1u64));
|
||||
assert_eq!(sample.name, "sop.runs_completed");
|
||||
|
||||
// NotFound for unknown metric
|
||||
let bad_query = MetricQuery {
|
||||
name: "sop.nonexistent".into(),
|
||||
window: None,
|
||||
};
|
||||
let err = c.get_metric(&bad_query).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
ampersona_core::errors::MetricError::NotFound(_)
|
||||
));
|
||||
}
|
||||
|
||||
// ── Warm-start tests ─────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1245,6 +1184,7 @@ mod tests {
|
||||
completed_at: None,
|
||||
step_results: vec![],
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
audit.log_run_start(&run).await.unwrap();
|
||||
|
||||
@@ -1342,6 +1282,7 @@ mod tests {
|
||||
completed_at: None,
|
||||
step_results: vec![],
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
audit.log_run_start(&running_run).await.unwrap();
|
||||
audit.log_approval(&running_run, 1).await.unwrap();
|
||||
@@ -1385,7 +1326,7 @@ mod tests {
|
||||
assert_eq!(hic_7d, 1);
|
||||
}
|
||||
|
||||
// ── Windowed MetricsProvider tests (ampersona-gates feature) ──
|
||||
// ── Windowed MetricsProvider tests ──
|
||||
|
||||
#[test]
|
||||
fn get_metric_windowed_7d_matches_suffix() {
|
||||
@@ -1461,32 +1402,4 @@ mod tests {
|
||||
.unwrap();
|
||||
assert_eq!(val, 2);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
#[test]
|
||||
fn get_metric_provider_window_propagation() {
|
||||
use ampersona_core::traits::{MetricQuery, MetricsProvider};
|
||||
|
||||
let c = SopMetricsCollector::new();
|
||||
let run = make_run(
|
||||
"r1",
|
||||
"test-sop",
|
||||
SopRunStatus::Completed,
|
||||
1,
|
||||
vec![make_step(1, SopStepStatus::Completed)],
|
||||
);
|
||||
c.record_run_complete(&run);
|
||||
|
||||
// Query with window via MetricsProvider trait
|
||||
let query = MetricQuery {
|
||||
name: "sop.runs_completed".into(),
|
||||
window: Some(std::time::Duration::from_secs(7 * 86400)),
|
||||
};
|
||||
let sample = c.get_metric(&query).unwrap();
|
||||
assert_eq!(sample.value, json!(1u64));
|
||||
|
||||
// Same result as suffix-based query
|
||||
let suffix_val = c.get_metric_value("sop.runs_completed_7d");
|
||||
assert_eq!(Some(sample.value), suffix_val);
|
||||
}
|
||||
}
|
||||
|
||||
+61
-14
@@ -2,20 +2,17 @@ pub mod audit;
|
||||
pub mod condition;
|
||||
pub mod dispatch;
|
||||
pub mod engine;
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
pub mod gates;
|
||||
pub mod metrics;
|
||||
pub mod types;
|
||||
|
||||
pub use audit::SopAuditLogger;
|
||||
pub use engine::SopEngine;
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
pub use gates::GateEvalState;
|
||||
pub use metrics::SopMetricsCollector;
|
||||
#[allow(unused_imports)]
|
||||
pub use types::{
|
||||
Sop, SopEvent, SopExecutionMode, SopPriority, SopRun, SopRunAction, SopRunStatus, SopStep,
|
||||
SopStepResult, SopStepStatus, SopTrigger, SopTriggerSource,
|
||||
DeterministicRunState, DeterministicSavings, Sop, SopEvent, SopExecutionMode, SopPriority,
|
||||
SopRun, SopRunAction, SopRunStatus, SopStep, SopStepKind, SopStepResult, SopStepStatus,
|
||||
SopTrigger, SopTriggerSource, StepSchema,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@@ -24,6 +21,19 @@ use tracing::warn;
|
||||
|
||||
use types::{SopManifest, SopMeta};
|
||||
|
||||
/// Parse an execution mode string into `SopExecutionMode`, falling back to
|
||||
/// `Supervised` for unknown values.
|
||||
pub fn parse_execution_mode(s: &str) -> SopExecutionMode {
|
||||
match s.trim().to_lowercase().as_str() {
|
||||
"auto" => SopExecutionMode::Auto,
|
||||
"step_by_step" => SopExecutionMode::StepByStep,
|
||||
"priority_based" => SopExecutionMode::PriorityBased,
|
||||
"deterministic" => SopExecutionMode::Deterministic,
|
||||
// "supervised" and any unknown value
|
||||
_ => SopExecutionMode::Supervised,
|
||||
}
|
||||
}
|
||||
|
||||
// ── SOP directory helpers ───────────────────────────────────────
|
||||
|
||||
/// Return the default SOPs directory: `<workspace>/sops`.
|
||||
@@ -112,19 +122,28 @@ fn load_sop(sop_dir: &Path, default_execution_mode: SopExecutionMode) -> Result<
|
||||
execution_mode,
|
||||
cooldown_secs,
|
||||
max_concurrent,
|
||||
deterministic,
|
||||
} = manifest.sop;
|
||||
|
||||
// When deterministic=true, override execution_mode to Deterministic
|
||||
let effective_mode = if deterministic {
|
||||
SopExecutionMode::Deterministic
|
||||
} else {
|
||||
execution_mode.unwrap_or(default_execution_mode)
|
||||
};
|
||||
|
||||
Ok(Sop {
|
||||
name,
|
||||
description,
|
||||
version,
|
||||
priority,
|
||||
execution_mode: execution_mode.unwrap_or(default_execution_mode),
|
||||
execution_mode: effective_mode,
|
||||
triggers: manifest.triggers,
|
||||
steps,
|
||||
cooldown_secs,
|
||||
max_concurrent,
|
||||
location: Some(sop_dir.to_path_buf()),
|
||||
deterministic,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -143,6 +162,7 @@ pub fn parse_steps(md: &str) -> Vec<SopStep> {
|
||||
let mut current_body = String::new();
|
||||
let mut current_tools: Vec<String> = Vec::new();
|
||||
let mut current_requires_confirmation = false;
|
||||
let mut current_kind = SopStepKind::Execute;
|
||||
|
||||
for line in md.lines() {
|
||||
let trimmed = line.trim();
|
||||
@@ -164,6 +184,7 @@ pub fn parse_steps(md: &str) -> Vec<SopStep> {
|
||||
&mut current_body,
|
||||
&mut current_tools,
|
||||
&mut current_requires_confirmation,
|
||||
&mut current_kind,
|
||||
);
|
||||
in_steps_section = false;
|
||||
}
|
||||
@@ -184,6 +205,7 @@ pub fn parse_steps(md: &str) -> Vec<SopStep> {
|
||||
&mut current_body,
|
||||
&mut current_tools,
|
||||
&mut current_requires_confirmation,
|
||||
&mut current_kind,
|
||||
);
|
||||
|
||||
let step_num = u32::try_from(steps.len())
|
||||
@@ -217,6 +239,15 @@ pub fn parse_steps(md: &str) -> Vec<SopStep> {
|
||||
if let Some(val) = bullet.strip_prefix("requires_confirmation:") {
|
||||
current_requires_confirmation = val.trim().eq_ignore_ascii_case("true");
|
||||
}
|
||||
} else if bullet.starts_with("kind:") {
|
||||
if let Some(val) = bullet.strip_prefix("kind:") {
|
||||
let val = val.trim();
|
||||
if val.eq_ignore_ascii_case("checkpoint") {
|
||||
current_kind = SopStepKind::Checkpoint;
|
||||
} else {
|
||||
current_kind = SopStepKind::Execute;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Continuation body line
|
||||
if !current_body.is_empty() {
|
||||
@@ -244,6 +275,7 @@ pub fn parse_steps(md: &str) -> Vec<SopStep> {
|
||||
&mut current_body,
|
||||
&mut current_tools,
|
||||
&mut current_requires_confirmation,
|
||||
&mut current_kind,
|
||||
);
|
||||
|
||||
steps
|
||||
@@ -257,6 +289,7 @@ fn flush_step(
|
||||
body: &mut String,
|
||||
tools: &mut Vec<String>,
|
||||
requires_confirmation: &mut bool,
|
||||
kind: &mut SopStepKind,
|
||||
) {
|
||||
if let Some(n) = number.take() {
|
||||
steps.push(SopStep {
|
||||
@@ -265,9 +298,12 @@ fn flush_step(
|
||||
body: body.trim().to_string(),
|
||||
suggested_tools: std::mem::take(tools),
|
||||
requires_confirmation: *requires_confirmation,
|
||||
kind: *kind,
|
||||
schema: None,
|
||||
});
|
||||
*body = String::new();
|
||||
*requires_confirmation = false;
|
||||
*kind = SopStepKind::Execute;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,7 +385,7 @@ pub fn handle_command(command: crate::SopCommands, config: &crate::config::Confi
|
||||
let sops = load_sops(
|
||||
&config.workspace_dir,
|
||||
sops_dir_override,
|
||||
config.sop.default_execution_mode,
|
||||
parse_execution_mode(&config.sop.default_execution_mode),
|
||||
);
|
||||
if sops.is_empty() {
|
||||
println!("No SOPs found.");
|
||||
@@ -393,7 +429,7 @@ pub fn handle_command(command: crate::SopCommands, config: &crate::config::Confi
|
||||
let sops = load_sops(
|
||||
&config.workspace_dir,
|
||||
sops_dir_override,
|
||||
config.sop.default_execution_mode,
|
||||
parse_execution_mode(&config.sop.default_execution_mode),
|
||||
);
|
||||
let matching: Vec<&Sop> = if let Some(ref name) = name {
|
||||
sops.iter().filter(|s| s.name == *name).collect()
|
||||
@@ -443,7 +479,7 @@ pub fn handle_command(command: crate::SopCommands, config: &crate::config::Confi
|
||||
let sops = load_sops(
|
||||
&config.workspace_dir,
|
||||
sops_dir_override,
|
||||
config.sop.default_execution_mode,
|
||||
parse_execution_mode(&config.sop.default_execution_mode),
|
||||
);
|
||||
let sop = sops
|
||||
.iter()
|
||||
@@ -474,16 +510,23 @@ pub fn handle_command(command: crate::SopCommands, config: &crate::config::Confi
|
||||
if !sop.steps.is_empty() {
|
||||
println!("Steps:");
|
||||
for step in &sop.steps {
|
||||
let confirm_tag = if step.requires_confirmation {
|
||||
" [requires confirmation]"
|
||||
let mut tags = Vec::new();
|
||||
if step.requires_confirmation {
|
||||
tags.push("requires confirmation");
|
||||
}
|
||||
if step.kind == SopStepKind::Checkpoint {
|
||||
tags.push("checkpoint");
|
||||
}
|
||||
let tag_str = if tags.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
""
|
||||
format!(" [{}]", tags.join(", "))
|
||||
};
|
||||
println!(
|
||||
" {}. {}{}",
|
||||
step.number,
|
||||
console::style(&step.title).bold(),
|
||||
confirm_tag
|
||||
tag_str
|
||||
);
|
||||
if !step.body.is_empty() {
|
||||
for line in step.body.lines() {
|
||||
@@ -705,6 +748,7 @@ type = "manual"
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
};
|
||||
|
||||
let warnings = validate_sop(&sop);
|
||||
@@ -729,10 +773,13 @@ type = "manual"
|
||||
body: "Do the thing".into(),
|
||||
suggested_tools: vec!["shell".into()],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
}],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
};
|
||||
|
||||
let warnings = validate_sop(&sop);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -42,6 +43,10 @@ pub enum SopExecutionMode {
|
||||
StepByStep,
|
||||
/// Critical/High → Auto, Normal/Low → Supervised.
|
||||
PriorityBased,
|
||||
/// Execute steps sequentially without LLM round-trips.
|
||||
/// Step outputs are piped as inputs to the next step.
|
||||
/// Checkpoint steps pause for human approval.
|
||||
Deterministic,
|
||||
}
|
||||
|
||||
impl fmt::Display for SopExecutionMode {
|
||||
@@ -51,6 +56,7 @@ impl fmt::Display for SopExecutionMode {
|
||||
Self::Supervised => write!(f, "supervised"),
|
||||
Self::StepByStep => write!(f, "step_by_step"),
|
||||
Self::PriorityBased => write!(f, "priority_based"),
|
||||
Self::Deterministic => write!(f, "deterministic"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,6 +99,44 @@ impl fmt::Display for SopTrigger {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Step kind ────────────────────────────────────────────────────
|
||||
|
||||
/// The kind of a workflow step.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SopStepKind {
|
||||
/// Normal step — executed by the agent (or deterministic handler).
|
||||
#[default]
|
||||
Execute,
|
||||
/// Checkpoint step — pauses execution and waits for human approval.
|
||||
Checkpoint,
|
||||
}
|
||||
|
||||
impl fmt::Display for SopStepKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Execute => write!(f, "execute"),
|
||||
Self::Checkpoint => write!(f, "checkpoint"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Typed step parameters ────────────────────────────────────────
|
||||
|
||||
/// JSON Schema fragment for validating step input/output data.
|
||||
///
|
||||
/// Stored as a raw `serde_json::Value` so callers can validate without
|
||||
/// pulling in a full JSON Schema library.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct StepSchema {
|
||||
/// JSON Schema object describing expected input shape.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub input: Option<serde_json::Value>,
|
||||
/// JSON Schema object describing expected output shape.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ── Step ────────────────────────────────────────────────────────
|
||||
|
||||
/// A single step in an SOP procedure, parsed from SOP.md.
|
||||
@@ -105,6 +149,12 @@ pub struct SopStep {
|
||||
pub suggested_tools: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub requires_confirmation: bool,
|
||||
/// Step kind: `execute` (default) or `checkpoint`.
|
||||
#[serde(default)]
|
||||
pub kind: SopStepKind,
|
||||
/// Typed input/output schemas for deterministic data flow validation.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub schema: Option<StepSchema>,
|
||||
}
|
||||
|
||||
// ── SOP ─────────────────────────────────────────────────────────
|
||||
@@ -125,6 +175,10 @@ pub struct Sop {
|
||||
pub max_concurrent: u32,
|
||||
#[serde(skip)]
|
||||
pub location: Option<PathBuf>,
|
||||
/// When true, sets execution_mode to Deterministic.
|
||||
/// Steps execute sequentially without LLM round-trips.
|
||||
#[serde(default)]
|
||||
pub deterministic: bool,
|
||||
}
|
||||
|
||||
fn default_cooldown_secs() -> u64 {
|
||||
@@ -160,6 +214,9 @@ pub(crate) struct SopMeta {
|
||||
pub cooldown_secs: u64,
|
||||
#[serde(default = "default_max_concurrent")]
|
||||
pub max_concurrent: u32,
|
||||
/// Opt-in deterministic execution (no LLM round-trips between steps).
|
||||
#[serde(default)]
|
||||
pub deterministic: bool,
|
||||
}
|
||||
|
||||
fn default_sop_version() -> String {
|
||||
@@ -214,6 +271,8 @@ pub enum SopRunStatus {
|
||||
Pending,
|
||||
Running,
|
||||
WaitingApproval,
|
||||
/// Paused at a checkpoint in a deterministic workflow.
|
||||
PausedCheckpoint,
|
||||
Completed,
|
||||
Failed,
|
||||
Cancelled,
|
||||
@@ -225,6 +284,7 @@ impl fmt::Display for SopRunStatus {
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::WaitingApproval => write!(f, "waiting_approval"),
|
||||
Self::PausedCheckpoint => write!(f, "paused_checkpoint"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Cancelled => write!(f, "cancelled"),
|
||||
@@ -276,6 +336,44 @@ pub struct SopRun {
|
||||
/// ISO-8601 timestamp when the run entered WaitingApproval (for timeout tracking).
|
||||
#[serde(default)]
|
||||
pub waiting_since: Option<String>,
|
||||
/// Number of LLM calls saved by deterministic execution in this run.
|
||||
#[serde(default)]
|
||||
pub llm_calls_saved: u64,
|
||||
}
|
||||
|
||||
// ── Deterministic workflow state (persistence + resume) ──────────
|
||||
|
||||
/// Persisted state for a deterministic workflow run, enabling resume
|
||||
/// after interruption. Serialized to a JSON file alongside the SOP.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeterministicRunState {
|
||||
/// Identifier of this run.
|
||||
pub run_id: String,
|
||||
/// SOP name this state belongs to.
|
||||
pub sop_name: String,
|
||||
/// Last successfully completed step number (0 = none completed).
|
||||
pub last_completed_step: u32,
|
||||
/// Total steps in the workflow.
|
||||
pub total_steps: u32,
|
||||
/// Output of each completed step, keyed by step number.
|
||||
pub step_outputs: HashMap<u32, serde_json::Value>,
|
||||
/// ISO-8601 timestamp when this state was last persisted.
|
||||
pub persisted_at: String,
|
||||
/// Number of LLM calls that were saved by deterministic execution.
|
||||
pub llm_calls_saved: u64,
|
||||
/// Whether the run is paused at a checkpoint awaiting approval.
|
||||
pub paused_at_checkpoint: bool,
|
||||
}
|
||||
|
||||
// ── Cost savings metric ──────────────────────────────────────────
|
||||
|
||||
/// Tracks how many LLM round-trips were saved by deterministic execution.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct DeterministicSavings {
|
||||
/// Total LLM calls saved across all deterministic runs.
|
||||
pub total_llm_calls_saved: u64,
|
||||
/// Total deterministic runs completed.
|
||||
pub total_runs: u64,
|
||||
}
|
||||
|
||||
/// What the engine instructs the caller to do next after a state transition.
|
||||
@@ -293,6 +391,20 @@ pub enum SopRunAction {
|
||||
step: SopStep,
|
||||
context: String,
|
||||
},
|
||||
/// Execute a step deterministically (no LLM). The `input` is the piped
|
||||
/// output from the previous step (or trigger payload for step 1).
|
||||
DeterministicStep {
|
||||
run_id: String,
|
||||
step: SopStep,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
/// Deterministic workflow hit a checkpoint — pause for human approval.
|
||||
/// Workflow state has been persisted so it can resume after approval.
|
||||
CheckpointWait {
|
||||
run_id: String,
|
||||
step: SopStep,
|
||||
state_file: PathBuf,
|
||||
},
|
||||
/// The SOP run completed successfully.
|
||||
Completed { run_id: String, sop_name: String },
|
||||
/// The SOP run failed.
|
||||
@@ -459,6 +571,7 @@ path = "/sop/test"
|
||||
completed_at: Some("2026-02-19T12:00:05Z".into()),
|
||||
}],
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
let json = serde_json::to_string(&run).unwrap();
|
||||
let parsed: SopRun = serde_json::from_str(&json).unwrap();
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::ClaudeCodeConfig;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// Environment variables safe to pass through to the `claude` subprocess.
|
||||
const SAFE_ENV_VARS: &[&str] = &[
|
||||
"PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR",
|
||||
];
|
||||
|
||||
/// Delegates coding tasks to the Claude Code CLI (`claude -p`).
|
||||
///
|
||||
/// This creates a two-tier agent architecture: ZeroClaw orchestrates high-level
|
||||
/// tasks and delegates complex coding work to Claude Code, which has its own
|
||||
/// agent loop with Read/Edit/Bash tools.
|
||||
///
|
||||
/// Authentication uses the `claude` binary's own OAuth session (Max subscription)
|
||||
/// by default. No API key is needed unless `env_passthrough` includes
|
||||
/// `ANTHROPIC_API_KEY` for API-key billing.
|
||||
pub struct ClaudeCodeTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
config: ClaudeCodeConfig,
|
||||
}
|
||||
|
||||
impl ClaudeCodeTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>, config: ClaudeCodeConfig) -> Self {
|
||||
Self { security, config }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ClaudeCodeTool {
|
||||
fn name(&self) -> &str {
|
||||
"claude_code"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Delegate a coding task to Claude Code (claude -p). Supports file editing, bash execution, structured output, and multi-turn sessions. Use for complex coding work that benefits from Claude Code's full agent loop."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The coding task to delegate to Claude Code"
|
||||
},
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Override the default tool allowlist (e.g. [\"Read\", \"Edit\", \"Bash\", \"Write\"])"
|
||||
},
|
||||
"system_prompt": {
|
||||
"type": "string",
|
||||
"description": "Override or append a system prompt for this invocation"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Resume a previous Claude Code session by its ID"
|
||||
},
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"description": "Request structured output conforming to this JSON Schema"
|
||||
},
|
||||
"working_directory": {
|
||||
"type": "string",
|
||||
"description": "Working directory within the workspace (must be inside workspace_dir)"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Rate limit check
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Enforce act policy
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "claude_code")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
// Extract prompt (required)
|
||||
let prompt = args
|
||||
.get("prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
|
||||
|
||||
// Extract optional params
|
||||
let allowed_tools: Vec<String> = args
|
||||
.get("allowed_tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_else(|| self.config.allowed_tools.clone());
|
||||
|
||||
let system_prompt = args
|
||||
.get("system_prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.or_else(|| self.config.system_prompt.clone());
|
||||
|
||||
let session_id = args.get("session_id").and_then(|v| v.as_str());
|
||||
|
||||
let json_schema = args.get("json_schema").filter(|v| v.is_object());
|
||||
|
||||
// Validate working directory — require both paths to exist (reject
|
||||
// non-existent paths instead of falling back to the raw value, which
|
||||
// could bypass the workspace containment check via symlinks or
|
||||
// specially-crafted path components).
|
||||
let work_dir = if let Some(wd) = args.get("working_directory").and_then(|v| v.as_str()) {
|
||||
let wd_path = std::path::PathBuf::from(wd);
|
||||
let workspace = &self.security.workspace_dir;
|
||||
let canonical_wd = match wd_path.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"working_directory '{}' does not exist or is not accessible",
|
||||
wd
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
let canonical_ws = match workspace.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"workspace directory '{}' does not exist or is not accessible",
|
||||
workspace.display()
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
if !canonical_wd.starts_with(&canonical_ws) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"working_directory '{}' is outside the workspace '{}'",
|
||||
wd,
|
||||
workspace.display()
|
||||
)),
|
||||
});
|
||||
}
|
||||
canonical_wd
|
||||
} else {
|
||||
self.security.workspace_dir.clone()
|
||||
};
|
||||
|
||||
// Record action budget
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build CLI command
|
||||
let mut cmd = Command::new("claude");
|
||||
cmd.arg("-p").arg(prompt);
|
||||
cmd.arg("--output-format").arg("json");
|
||||
|
||||
if !allowed_tools.is_empty() {
|
||||
for tool in &allowed_tools {
|
||||
cmd.arg("--allowedTools").arg(tool);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sp) = system_prompt {
|
||||
cmd.arg("--append-system-prompt").arg(sp);
|
||||
}
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
cmd.arg("--resume").arg(sid);
|
||||
}
|
||||
|
||||
if let Some(schema) = json_schema {
|
||||
let schema_str = serde_json::to_string(schema).unwrap_or_else(|_| "{}".to_string());
|
||||
cmd.arg("--json-schema").arg(schema_str);
|
||||
}
|
||||
|
||||
// Environment: clear everything, pass only safe vars + configured passthrough.
|
||||
// HOME is critical so `claude` finds its OAuth session in ~/.claude/
|
||||
cmd.env_clear();
|
||||
for var in SAFE_ENV_VARS {
|
||||
if let Ok(val) = std::env::var(var) {
|
||||
cmd.env(var, val);
|
||||
}
|
||||
}
|
||||
for var in &self.config.env_passthrough {
|
||||
let trimmed = var.trim();
|
||||
if !trimmed.is_empty() {
|
||||
if let Ok(val) = std::env::var(trimmed) {
|
||||
cmd.env(trimmed, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd.current_dir(&work_dir);
|
||||
// Execute with timeout — use kill_on_drop(true) so the child process
|
||||
// is automatically killed when the future is dropped on timeout,
|
||||
// preventing zombie processes.
|
||||
let timeout = Duration::from_secs(self.config.timeout_secs);
|
||||
cmd.kill_on_drop(true);
|
||||
|
||||
let result = tokio::time::timeout(timeout, cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
// Truncate to max_output_bytes with char-boundary safety
|
||||
if stdout.len() > self.config.max_output_bytes {
|
||||
let mut b = self.config.max_output_bytes.min(stdout.len());
|
||||
while b > 0 && !stdout.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stdout.truncate(b);
|
||||
stdout.push_str("\n... [output truncated]");
|
||||
}
|
||||
|
||||
// Try to parse JSON response and extract result + session_id
|
||||
if let Ok(json_resp) = serde_json::from_str::<serde_json::Value>(&stdout) {
|
||||
let result_text = json_resp
|
||||
.get("result")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let resp_session_id = json_resp
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let mut formatted = String::new();
|
||||
if result_text.is_empty() {
|
||||
// Fall back to full JSON if no "result" key
|
||||
formatted.push_str(&stdout);
|
||||
} else {
|
||||
formatted.push_str(result_text);
|
||||
}
|
||||
if !resp_session_id.is_empty() {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(formatted, "\n\n[session_id: {}]", resp_session_id);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: formatted,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// JSON parse failed — return raw stdout (defensive)
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
let err_msg = e.to_string();
|
||||
let msg = if err_msg.contains("No such file or directory")
|
||||
|| err_msg.contains("not found")
|
||||
|| err_msg.contains("cannot find")
|
||||
{
|
||||
"Claude Code CLI ('claude') not found in PATH. Install with: npm install -g @anthropic-ai/claude-code".into()
|
||||
} else {
|
||||
format!("Failed to execute claude: {e}")
|
||||
};
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(msg),
|
||||
})
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout — kill_on_drop(true) ensures the child is killed
|
||||
// when the future is dropped.
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Claude Code timed out after {}s and was killed",
|
||||
self.config.timeout_secs
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::ClaudeCodeConfig;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_config() -> ClaudeCodeConfig {
|
||||
ClaudeCodeConfig::default()
|
||||
}
|
||||
|
||||
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_tool_name() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
assert_eq!(tool.name(), "claude_code");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_tool_schema_has_prompt() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.expect("schema required should be an array")
|
||||
.contains(&json!("prompt")));
|
||||
// Optional params exist in properties
|
||||
assert!(schema["properties"]["allowed_tools"].is_object());
|
||||
assert!(schema["properties"]["system_prompt"].is_object());
|
||||
assert!(schema["properties"]["session_id"].is_object());
|
||||
assert!(schema["properties"]["json_schema"].is_object());
|
||||
assert!(schema["properties"]["working_directory"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_blocks_rate_limited() {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
max_actions_per_hour: 0,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ClaudeCodeTool::new(security, test_config());
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "hello"}))
|
||||
.await
|
||||
.expect("rate-limited should return a result");
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_blocks_readonly() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::ReadOnly), test_config());
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "hello"}))
|
||||
.await
|
||||
.expect("readonly should return a result");
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_missing_prompt_param() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Supervised), test_config());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claude_code_rejects_path_outside_workspace() {
|
||||
let tool = ClaudeCodeTool::new(test_security(AutonomyLevel::Full), test_config());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"prompt": "hello",
|
||||
"working_directory": "/etc"
|
||||
}))
|
||||
.await
|
||||
.expect("should return a result for path validation");
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("outside the workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_env_passthrough_defaults() {
|
||||
let config = ClaudeCodeConfig::default();
|
||||
assert!(
|
||||
config.env_passthrough.is_empty(),
|
||||
"env_passthrough should default to empty (Max subscription needs no API key)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_default_config_values() {
|
||||
let config = ClaudeCodeConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.timeout_secs, 600);
|
||||
assert_eq!(config.max_output_bytes, 2_097_152);
|
||||
assert!(config.system_prompt.is_none());
|
||||
assert_eq!(config.allowed_tools, vec!["Read", "Edit", "Bash", "Write"]);
|
||||
}
|
||||
}
|
||||
+344
-2
@@ -1,5 +1,6 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::agent::loop_::run_tool_call_loop;
|
||||
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
|
||||
use crate::config::{DelegateAgentConfig, DelegateToolConfig};
|
||||
use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
use crate::providers::{self, ChatMessage, Provider};
|
||||
@@ -9,6 +10,7 @@ use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -31,6 +33,8 @@ pub struct DelegateTool {
|
||||
multimodal_config: crate::config::MultimodalConfig,
|
||||
/// Global delegate tool config providing default timeout values.
|
||||
delegate_config: DelegateToolConfig,
|
||||
/// Workspace directory inherited from the root agent context.
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
@@ -62,6 +66,7 @@ impl DelegateTool {
|
||||
parent_tools: Arc::new(RwLock::new(Vec::new())),
|
||||
multimodal_config: crate::config::MultimodalConfig::default(),
|
||||
delegate_config: DelegateToolConfig::default(),
|
||||
workspace_dir: PathBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +104,7 @@ impl DelegateTool {
|
||||
parent_tools: Arc::new(RwLock::new(Vec::new())),
|
||||
multimodal_config: crate::config::MultimodalConfig::default(),
|
||||
delegate_config: DelegateToolConfig::default(),
|
||||
workspace_dir: PathBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +131,12 @@ impl DelegateTool {
|
||||
pub fn parent_tools_handle(&self) -> Arc<RwLock<Vec<Arc<dyn Tool>>>> {
|
||||
Arc::clone(&self.parent_tools)
|
||||
}
|
||||
|
||||
/// Attach the workspace directory for system prompt enrichment.
|
||||
pub fn with_workspace_dir(mut self, workspace_dir: PathBuf) -> Self {
|
||||
self.workspace_dir = workspace_dir;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -300,6 +312,11 @@ impl Tool for DelegateTool {
|
||||
.await;
|
||||
}
|
||||
|
||||
// Build enriched system prompt for non-agentic sub-agent.
|
||||
let enriched_system_prompt =
|
||||
self.build_enriched_system_prompt(agent_config, &[], &self.workspace_dir);
|
||||
let system_prompt_ref = enriched_system_prompt.as_deref();
|
||||
|
||||
// Wrap the provider call in a timeout to prevent indefinite blocking
|
||||
let timeout_secs = agent_config
|
||||
.timeout_secs
|
||||
@@ -307,7 +324,7 @@ impl Tool for DelegateTool {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
system_prompt_ref,
|
||||
&full_prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
@@ -355,6 +372,80 @@ impl Tool for DelegateTool {
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
/// Build an enriched system prompt for a sub-agent by composing structured
|
||||
/// operational sections (tools, skills, workspace, datetime, shell policy)
|
||||
/// with the operator-configured `system_prompt` string.
|
||||
fn build_enriched_system_prompt(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
sub_tools: &[Box<dyn Tool>],
|
||||
workspace_dir: &Path,
|
||||
) -> Option<String> {
|
||||
// Resolve skills directory: scoped if configured, otherwise workspace default.
|
||||
let skills_dir = agent_config
|
||||
.skills_directory
|
||||
.as_ref()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.map(|dir| workspace_dir.join(dir))
|
||||
.unwrap_or_else(|| crate::skills::skills_dir(workspace_dir));
|
||||
let skills = crate::skills::load_skills_from_directory(&skills_dir, false);
|
||||
|
||||
// Determine shell policy instructions when the `shell` tool is in the
|
||||
// effective tool list.
|
||||
let has_shell = sub_tools.iter().any(|t| t.name() == "shell");
|
||||
let shell_policy = if has_shell {
|
||||
"## Shell Policy\n\n\
|
||||
- Prefer non-destructive commands. Use `trash` over `rm` where possible.\n\
|
||||
- Do not run commands that exfiltrate data or modify system-critical paths.\n\
|
||||
- Avoid interactive commands that block on stdin.\n\
|
||||
- Quote paths that may contain spaces."
|
||||
.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Build structured operational context using SystemPromptBuilder sections.
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
model_name: &agent_config.model,
|
||||
tools: sub_tools,
|
||||
skills: &skills,
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode::Full,
|
||||
identity_config: None,
|
||||
dispatcher_instructions: "",
|
||||
tool_descriptions: None,
|
||||
security_summary: None,
|
||||
autonomy_level: crate::security::AutonomyLevel::default(),
|
||||
};
|
||||
|
||||
let builder = SystemPromptBuilder::default()
|
||||
.add_section(Box::new(crate::agent::prompt::ToolsSection))
|
||||
.add_section(Box::new(crate::agent::prompt::SafetySection))
|
||||
.add_section(Box::new(crate::agent::prompt::SkillsSection))
|
||||
.add_section(Box::new(crate::agent::prompt::WorkspaceSection))
|
||||
.add_section(Box::new(crate::agent::prompt::DateTimeSection));
|
||||
|
||||
let mut enriched = builder.build(&ctx).unwrap_or_default();
|
||||
|
||||
if !shell_policy.is_empty() {
|
||||
enriched.push_str(&shell_policy);
|
||||
enriched.push_str("\n\n");
|
||||
}
|
||||
|
||||
// Append the operator-configured system_prompt as the identity/role block.
|
||||
if let Some(operator_prompt) = agent_config.system_prompt.as_ref() {
|
||||
enriched.push_str(operator_prompt);
|
||||
enriched.push('\n');
|
||||
}
|
||||
|
||||
let trimmed = enriched.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_agentic(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
@@ -401,8 +492,12 @@ impl DelegateTool {
|
||||
});
|
||||
}
|
||||
|
||||
// Build enriched system prompt with tools, skills, workspace, datetime context.
|
||||
let enriched_system_prompt =
|
||||
self.build_enriched_system_prompt(agent_config, &sub_tools, &self.workspace_dir);
|
||||
|
||||
let mut history = Vec::new();
|
||||
if let Some(system_prompt) = agent_config.system_prompt.as_ref() {
|
||||
if let Some(system_prompt) = enriched_system_prompt.as_ref() {
|
||||
history.push(ChatMessage::system(system_prompt.clone()));
|
||||
}
|
||||
history.push(ChatMessage::user(full_prompt.to_string()));
|
||||
@@ -435,6 +530,7 @@ impl DelegateTool {
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
@@ -548,6 +644,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
@@ -564,6 +661,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents
|
||||
@@ -719,6 +817,7 @@ mod tests {
|
||||
max_iterations,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,6 +928,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -937,6 +1037,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -974,6 +1075,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
@@ -1235,6 +1337,113 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_includes_tools_workspace_datetime() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: Some("You are a code reviewer.".to_string()),
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_enrich_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
std::fs::create_dir_all(&workspace).unwrap();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(prompt.contains("## Tools"), "should contain tools section");
|
||||
assert!(prompt.contains("echo_tool"), "should list allowed tools");
|
||||
assert!(
|
||||
prompt.contains("## Workspace"),
|
||||
"should contain workspace section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(&workspace.display().to_string()),
|
||||
"should contain workspace path"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("## Current Date & Time"),
|
||||
"should contain datetime section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("You are a code reviewer."),
|
||||
"should append operator system_prompt"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_includes_shell_policy_when_shell_present() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["shell".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
struct MockShellTool;
|
||||
#[async_trait]
|
||||
impl Tool for MockShellTool {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"Execute shell commands"
|
||||
}
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({"type": "object"})
|
||||
}
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: String::new(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockShellTool)];
|
||||
let workspace = std::env::temp_dir();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.to_path_buf());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("## Shell Policy"),
|
||||
"should contain shell policy when shell tool is present"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parent_tools_handle_returns_shared_reference() {
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security()).with_parent_tools(
|
||||
@@ -1265,6 +1474,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
assert_eq!(
|
||||
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
|
||||
@@ -1278,6 +1488,39 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_omits_shell_policy_without_shell_tool() {
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
let workspace = std::env::temp_dir();
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.to_path_buf());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!prompt.contains("## Shell Policy"),
|
||||
"should not contain shell policy when shell tool is absent"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_timeout_values_are_respected() {
|
||||
let config = DelegateAgentConfig {
|
||||
@@ -1292,6 +1535,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(60),
|
||||
agentic_timeout_secs: Some(600),
|
||||
skills_directory: None,
|
||||
};
|
||||
assert_eq!(
|
||||
config.timeout_secs.unwrap_or(DEFAULT_DELEGATE_TIMEOUT_SECS),
|
||||
@@ -1346,6 +1590,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(0),
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1372,6 +1617,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: Some(0),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1398,6 +1644,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(7200),
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1424,6 +1671,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: Some(5000),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
let err = config.validate().unwrap_err();
|
||||
@@ -1450,6 +1698,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: Some(3600),
|
||||
agentic_timeout_secs: Some(3600),
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
assert!(config.validate().is_ok());
|
||||
@@ -1472,8 +1721,101 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_loads_skills_from_scoped_directory() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_skills_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
let scoped_skills_dir = workspace.join("skills/code-review");
|
||||
std::fs::create_dir_all(scoped_skills_dir.join("lint-check")).unwrap();
|
||||
std::fs::write(
|
||||
scoped_skills_dir.join("lint-check/SKILL.toml"),
|
||||
"[skill]\nname = \"lint-check\"\ndescription = \"Run lint checks\"\nversion = \"1.0.0\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: Some("skills/code-review".to_string()),
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("lint-check"),
|
||||
"should contain skills from scoped directory"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enriched_prompt_falls_back_to_default_skills_dir() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_delegate_fallback_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
let default_skills_dir = workspace.join("skills");
|
||||
std::fs::create_dir_all(default_skills_dir.join("deploy")).unwrap();
|
||||
std::fs::write(
|
||||
default_skills_dir.join("deploy/SKILL.toml"),
|
||||
"[skill]\nname = \"deploy\"\ndescription = \"Deploy safely\"\nversion = \"1.0.0\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: None,
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
agentic: true,
|
||||
allowed_tools: vec!["echo_tool".to_string()],
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(EchoTool)];
|
||||
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security())
|
||||
.with_workspace_dir(workspace.clone());
|
||||
|
||||
let prompt = tool
|
||||
.build_enriched_system_prompt(&config, &tools, &workspace)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
prompt.contains("deploy"),
|
||||
"should contain skills from default workspace skills/ directory"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
}
|
||||
|
||||
+85
-13
@@ -23,7 +23,7 @@ impl Tool for MemoryRecallTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
|
||||
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance. Supports keyword search, time-only query (since/until), or both."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@@ -32,22 +32,76 @@ impl Tool for MemoryRecallTool {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Keywords or phrase to search for in memory"
|
||||
"description": "Keywords or phrase to search for in memory (optional if since/until provided)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default: 5)"
|
||||
},
|
||||
"since": {
|
||||
"type": "string",
|
||||
"description": "Filter memories created at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
|
||||
},
|
||||
"until": {
|
||||
"type": "string",
|
||||
"description": "Filter memories created at or before this time (RFC 3339)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let since = args.get("since").and_then(|v| v.as_str());
|
||||
let until = args.get("until").and_then(|v| v.as_str());
|
||||
|
||||
if query.trim().is_empty() && since.is_none() && until.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate date strings
|
||||
if let Some(s) = since {
|
||||
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'since' date: {s}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(u) = until {
|
||||
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'until' date: {u}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let (Some(s), Some(u)) = (since, until) {
|
||||
if let (Ok(s_dt), Ok(u_dt)) = (
|
||||
chrono::DateTime::parse_from_rfc3339(s),
|
||||
chrono::DateTime::parse_from_rfc3339(u),
|
||||
) {
|
||||
if s_dt >= u_dt {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'since' must be before 'until'".into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
@@ -55,10 +109,10 @@ impl Tool for MemoryRecallTool {
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(5, |v| v as usize);
|
||||
|
||||
match self.memory.recall(query, limit, None).await {
|
||||
match self.memory.recall(query, limit, None, since, until).await {
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No memories found matching that query.".into(),
|
||||
output: "No memories found.".into(),
|
||||
error: None,
|
||||
}),
|
||||
Ok(entries) => {
|
||||
@@ -150,11 +204,29 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_missing_query() {
|
||||
async fn recall_requires_query_or_time() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("at least"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_time_only_returns_entries() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
// Time-only: since far in past
|
||||
let result = tool
|
||||
.execute(json!({"since": "2020-01-01T00:00:00Z", "limit": 5}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Found 1"));
|
||||
assert!(result.output.contains("Rust"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+35
-1
@@ -20,6 +20,7 @@ pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
pub mod browser_open;
|
||||
pub mod calculator;
|
||||
pub mod claude_code;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
pub mod cloud_patterns;
|
||||
@@ -75,6 +76,11 @@ pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod shell;
|
||||
pub mod sop_advance;
|
||||
pub mod sop_approve;
|
||||
pub mod sop_execute;
|
||||
pub mod sop_list;
|
||||
pub mod sop_status;
|
||||
pub mod swarm;
|
||||
pub mod text_browser;
|
||||
pub mod tool_search;
|
||||
@@ -92,6 +98,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use claude_code::ClaudeCodeTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
pub use composio::ComposioTool;
|
||||
@@ -144,6 +151,11 @@ pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use sop_advance::SopAdvanceTool;
|
||||
pub use sop_approve::SopApproveTool;
|
||||
pub use sop_execute::SopExecuteTool;
|
||||
pub use sop_list::SopListTool;
|
||||
pub use sop_status::SopStatusTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use text_browser::TextBrowserTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
@@ -528,6 +540,14 @@ pub fn all_tools_with_runtime(
|
||||
);
|
||||
}
|
||||
|
||||
// Claude Code delegation tool
|
||||
if root_config.claude_code.enabled {
|
||||
tool_arcs.push(Arc::new(ClaudeCodeTool::new(
|
||||
security.clone(),
|
||||
root_config.claude_code.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// PDF extraction (feature-gated at compile time via rag-pdf)
|
||||
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
|
||||
|
||||
@@ -546,6 +566,18 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// SOP tools (registered when sops_dir is configured)
|
||||
if root_config.sop.sops_dir.is_some() {
|
||||
let sop_engine = Arc::new(std::sync::Mutex::new(crate::sop::SopEngine::new(
|
||||
root_config.sop.clone(),
|
||||
)));
|
||||
tool_arcs.push(Arc::new(SopListTool::new(Arc::clone(&sop_engine))));
|
||||
tool_arcs.push(Arc::new(SopExecuteTool::new(Arc::clone(&sop_engine))));
|
||||
tool_arcs.push(Arc::new(SopAdvanceTool::new(Arc::clone(&sop_engine))));
|
||||
tool_arcs.push(Arc::new(SopApproveTool::new(Arc::clone(&sop_engine))));
|
||||
tool_arcs.push(Arc::new(SopStatusTool::new(Arc::clone(&sop_engine))));
|
||||
}
|
||||
|
||||
if let Some(key) = composio_key {
|
||||
if !key.is_empty() {
|
||||
tool_arcs.push(Arc::new(ComposioTool::new(
|
||||
@@ -669,7 +701,8 @@ pub fn all_tools_with_runtime(
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone())
|
||||
.with_delegate_config(root_config.delegate.clone());
|
||||
.with_delegate_config(root_config.delegate.clone())
|
||||
.with_workspace_dir(workspace_dir.to_path_buf());
|
||||
tool_arcs.push(Arc::new(delegate_tool));
|
||||
Some(parent_tools)
|
||||
};
|
||||
@@ -1000,6 +1033,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -707,6 +707,7 @@ impl ModelRoutingConfigTool {
|
||||
max_iterations: DEFAULT_AGENT_MAX_ITERATIONS,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
});
|
||||
|
||||
next_agent.provider = provider;
|
||||
|
||||
@@ -181,6 +181,18 @@ impl Tool for SopAdvanceTool {
|
||||
} => {
|
||||
format!("SOP '{sop_name}' run {run_id} failed: {reason}")
|
||||
}
|
||||
SopRunAction::DeterministicStep { run_id, step, .. } => {
|
||||
format!(
|
||||
"Step recorded. Next deterministic step for run {run_id}: {}",
|
||||
step.title
|
||||
)
|
||||
}
|
||||
SopRunAction::CheckpointWait { run_id, step, .. } => {
|
||||
format!(
|
||||
"Step recorded. Run {run_id} paused at checkpoint: {}",
|
||||
step.title
|
||||
)
|
||||
}
|
||||
};
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
@@ -222,6 +234,8 @@ mod tests {
|
||||
body: "Do step one".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
SopStep {
|
||||
number: 2,
|
||||
@@ -229,11 +243,14 @@ mod tests {
|
||||
body: "Do step two".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -143,10 +143,13 @@ mod tests {
|
||||
body: "Do it".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
}],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -118,6 +118,18 @@ impl Tool for SopExecuteTool {
|
||||
SopRunAction::Failed { run_id, reason, .. } => {
|
||||
format!("SOP run {run_id} failed: {reason}")
|
||||
}
|
||||
SopRunAction::DeterministicStep { run_id, step, .. } => {
|
||||
format!(
|
||||
"SOP run started (deterministic): {run_id}\nFirst step: {}",
|
||||
step.title
|
||||
)
|
||||
}
|
||||
SopRunAction::CheckpointWait { run_id, step, .. } => {
|
||||
format!(
|
||||
"SOP run started: {run_id} (paused at checkpoint: {})",
|
||||
step.title
|
||||
)
|
||||
}
|
||||
};
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
@@ -140,7 +152,9 @@ fn action_run_id(action: &SopRunAction) -> Option<&str> {
|
||||
SopRunAction::ExecuteStep { run_id, .. }
|
||||
| SopRunAction::WaitApproval { run_id, .. }
|
||||
| SopRunAction::Completed { run_id, .. }
|
||||
| SopRunAction::Failed { run_id, .. } => Some(run_id),
|
||||
| SopRunAction::Failed { run_id, .. }
|
||||
| SopRunAction::DeterministicStep { run_id, .. }
|
||||
| SopRunAction::CheckpointWait { run_id, .. } => Some(run_id),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,6 +182,8 @@ mod tests {
|
||||
body: "Do step one".into(),
|
||||
suggested_tools: vec!["shell".into()],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
SopStep {
|
||||
number: 2,
|
||||
@@ -175,11 +191,14 @@ mod tests {
|
||||
body: "Do step two".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
},
|
||||
],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -137,10 +137,13 @@ mod tests {
|
||||
body: "Do it".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
}],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 1,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+6
-55
@@ -11,8 +11,6 @@ use crate::sop::{SopEngine, SopMetricsCollector};
|
||||
pub struct SopStatusTool {
|
||||
engine: Arc<Mutex<SopEngine>>,
|
||||
collector: Option<Arc<SopMetricsCollector>>,
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
gate_eval: Option<Arc<crate::sop::GateEvalState>>,
|
||||
}
|
||||
|
||||
impl SopStatusTool {
|
||||
@@ -20,8 +18,6 @@ impl SopStatusTool {
|
||||
Self {
|
||||
engine,
|
||||
collector: None,
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
gate_eval: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,61 +26,11 @@ impl SopStatusTool {
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
pub fn with_gate_eval(mut self, gate_eval: Arc<crate::sop::GateEvalState>) -> Self {
|
||||
self.gate_eval = Some(gate_eval);
|
||||
self
|
||||
}
|
||||
|
||||
fn append_gate_status(&self, output: &mut String, include_gate_status: bool) {
|
||||
#[cfg(feature = "ampersona-gates")]
|
||||
if include_gate_status {
|
||||
if let Some(ref ge) = self.gate_eval {
|
||||
if let Some(snap) = ge.phase_state_snapshot() {
|
||||
let _ = writeln!(output, "\nGate Status:");
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" current_phase: {}",
|
||||
snap.current_phase.as_deref().unwrap_or("(none)")
|
||||
);
|
||||
let _ = writeln!(output, " state_rev: {}", snap.state_rev);
|
||||
let _ = writeln!(output, " gates_loaded: {}", ge.gate_count());
|
||||
if let Some(ref tr) = snap.last_transition {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" last_transition: {} ({} → {})",
|
||||
tr.at.to_rfc3339(),
|
||||
tr.from_phase.as_deref().unwrap_or("(none)"),
|
||||
tr.to_phase,
|
||||
);
|
||||
} else {
|
||||
let _ = writeln!(output, " last_transition: none");
|
||||
}
|
||||
if let Some(ref pt) = snap.pending_transition {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" pending_transition: {} → {} ({})",
|
||||
pt.from_phase.as_deref().unwrap_or("(none)"),
|
||||
pt.to_phase,
|
||||
pt.decision,
|
||||
);
|
||||
} else {
|
||||
let _ = writeln!(output, " pending_transition: none");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"\nGate Status: not available (gate eval not configured)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ampersona-gates"))]
|
||||
if include_gate_status {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"\nGate Status: not available (ampersona-gates feature not enabled)"
|
||||
"\nGate Status: not available (gate evaluation not supported)"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -309,10 +255,13 @@ mod tests {
|
||||
body: "Do it".into(),
|
||||
suggested_tools: vec![],
|
||||
requires_confirmation: false,
|
||||
kind: SopStepKind::default(),
|
||||
schema: None,
|
||||
}],
|
||||
cooldown_secs: 0,
|
||||
max_concurrent: 2,
|
||||
location: None,
|
||||
deterministic: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,6 +380,7 @@ mod tests {
|
||||
completed_at: Some("2026-02-19T12:01:00Z".into()),
|
||||
}],
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
collector.record_run_complete(&run);
|
||||
|
||||
@@ -466,6 +416,7 @@ mod tests {
|
||||
completed_at: Some("2026-02-19T12:01:00Z".into()),
|
||||
}],
|
||||
waiting_since: None,
|
||||
llm_calls_saved: 0,
|
||||
};
|
||||
collector.record_run_complete(&run);
|
||||
|
||||
|
||||
@@ -568,6 +568,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
@@ -584,6 +585,7 @@ mod tests {
|
||||
max_iterations: 10,
|
||||
timeout_secs: None,
|
||||
agentic_timeout_secs: None,
|
||||
skills_directory: None,
|
||||
},
|
||||
);
|
||||
agents
|
||||
|
||||
@@ -100,6 +100,10 @@ fn gateway_config_defaults_are_secure() {
|
||||
!gw.trust_forwarded_headers,
|
||||
"forwarded headers should be untrusted by default"
|
||||
);
|
||||
assert!(
|
||||
gw.path_prefix.is_none(),
|
||||
"path_prefix should default to None"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -124,6 +128,7 @@ fn gateway_config_toml_roundtrip() {
|
||||
host: "0.0.0.0".into(),
|
||||
require_pairing: false,
|
||||
pair_rate_limit_per_minute: 5,
|
||||
path_prefix: Some("/zeroclaw".into()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -134,6 +139,7 @@ fn gateway_config_toml_roundtrip() {
|
||||
assert_eq!(parsed.host, "0.0.0.0");
|
||||
assert!(!parsed.require_pairing);
|
||||
assert_eq!(parsed.pair_rate_limit_per_minute, 5);
|
||||
assert_eq!(parsed.path_prefix.as_deref(), Some("/zeroclaw"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -163,6 +169,93 @@ port = 9090
|
||||
assert_eq!(parsed.gateway.pair_rate_limit_per_minute, 10);
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// GatewayConfig path_prefix validation
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_missing_leading_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("zeroclaw".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must start with '/'"),
|
||||
"expected leading-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_trailing_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("/zeroclaw/".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not end with '/'"),
|
||||
"expected trailing-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_bare_slash() {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some("/".into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not end with '/'"),
|
||||
"expected bare-slash error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_accepts_valid_prefixes() {
|
||||
for prefix in ["/zeroclaw", "/apps/zeroclaw", "/api/hassio_ingress/abc123"] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
config
|
||||
.validate()
|
||||
.unwrap_or_else(|e| panic!("prefix {prefix:?} should be valid, got: {e}"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_rejects_unsafe_characters() {
|
||||
for prefix in [
|
||||
"/zero claw",
|
||||
"/zero<claw",
|
||||
"/zero>claw",
|
||||
"/zero\"claw",
|
||||
"/zero?query",
|
||||
"/zero#frag",
|
||||
] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
let err = config.validate().unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("invalid character"),
|
||||
"prefix {prefix:?} should be rejected, got: {err}"
|
||||
);
|
||||
}
|
||||
// Leading/trailing whitespace is rejected by the starts_with('/') or
|
||||
// invalid-character check — either way it must not pass validation.
|
||||
for prefix in [" /zeroclaw ", " /zeroclaw"] {
|
||||
let mut config = Config::default();
|
||||
config.gateway.path_prefix = Some(prefix.into());
|
||||
assert!(
|
||||
config.validate().is_err(),
|
||||
"whitespace-padded prefix {prefix:?} should be rejected"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_path_prefix_accepts_none() {
|
||||
let config = Config::default();
|
||||
assert!(config.gateway.path_prefix.is_none());
|
||||
config
|
||||
.validate()
|
||||
.expect("absent path_prefix should be valid");
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// SecurityConfig boundary tests
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -147,8 +147,8 @@ async fn compare_recall_quality() {
|
||||
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||
|
||||
for (query, desc) in &queries {
|
||||
let sq_results = sq.recall(query, 10, None).await.unwrap();
|
||||
let md_results = md.recall(query, 10, None).await.unwrap();
|
||||
let sq_results = sq.recall(query, 10, None, None, None).await.unwrap();
|
||||
let md_results = md.recall(query, 10, None, None, None).await.unwrap();
|
||||
|
||||
println!(" Query: \"{query}\" — {desc}");
|
||||
println!(" SQLite: {} results", sq_results.len());
|
||||
@@ -202,11 +202,17 @@ async fn compare_recall_speed() {
|
||||
|
||||
// Benchmark recall
|
||||
let start = Instant::now();
|
||||
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
|
||||
let sq_results = sq
|
||||
.recall("Rust systems", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let sq_dur = start.elapsed();
|
||||
|
||||
let start = Instant::now();
|
||||
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
|
||||
let md_results = md
|
||||
.recall("Rust systems", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let md_dur = start.elapsed();
|
||||
|
||||
println!("\n============================================================");
|
||||
@@ -312,7 +318,7 @@ async fn compare_upsert() {
|
||||
let md_count = md.count().await.unwrap();
|
||||
|
||||
let sq_entry = sq.get("pref").await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5, None, None, None).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("UPSERT (store same key twice):");
|
||||
|
||||
@@ -216,7 +216,10 @@ async fn sqlite_memory_recall_returns_relevant_results() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust programming", 10, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("Rust programming", 10, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!results.is_empty(), "recall should find matching entries");
|
||||
// The Rust-related entry should be in results
|
||||
assert!(
|
||||
@@ -241,7 +244,10 @@ async fn sqlite_memory_recall_respects_limit() {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("test content", 3, None).await.unwrap();
|
||||
let results = mem
|
||||
.recall("test content", 3, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
results.len() <= 3,
|
||||
"recall should respect limit of 3, got {}",
|
||||
@@ -250,7 +256,7 @@ async fn sqlite_memory_recall_respects_limit() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||
async fn sqlite_memory_recall_empty_query_returns_recent_entries() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
||||
@@ -258,8 +264,10 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty(), "empty query should return no results");
|
||||
// Empty query uses time-only path: returns recent entries by updated_at
|
||||
let results = mem.recall("", 10, None, None, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1, "empty query should return recent entries");
|
||||
assert_eq!(results[0].key, "fact");
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
+2
-1
@@ -16,6 +16,7 @@ import Pairing from './pages/Pairing';
|
||||
import { AuthProvider, useAuth } from './hooks/useAuth';
|
||||
import { DraftContext, useDraftStore } from './hooks/useDraft';
|
||||
import { setLocale, type Locale } from './lib/i18n';
|
||||
import { basePath } from './lib/basePath';
|
||||
import { getAdminPairCode } from './lib/api';
|
||||
|
||||
// Locale context
|
||||
@@ -131,7 +132,7 @@ function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> })
|
||||
|
||||
<div className="text-center mb-8">
|
||||
<img
|
||||
src="/_app/zeroclaw-trans.png"
|
||||
src={`${basePath}/_app/zeroclaw-trans.png`}
|
||||
alt="ZeroClaw"
|
||||
className="h-20 w-20 rounded-2xl object-cover mx-auto mb-4 animate-float"
|
||||
onError={(e) => { e.currentTarget.style.display = 'none'; }}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { NavLink } from 'react-router-dom';
|
||||
import { basePath } from '../../lib/basePath';
|
||||
import {
|
||||
LayoutDashboard,
|
||||
MessageSquare,
|
||||
@@ -34,7 +35,7 @@ export default function Sidebar() {
|
||||
<div className="relative shrink-0">
|
||||
<div className="absolute -inset-1.5 rounded-xl" style={{ background: 'linear-gradient(135deg, rgba(var(--pc-accent-rgb), 0.15), rgba(var(--pc-accent-rgb), 0.05))' }} />
|
||||
<img
|
||||
src="/_app/zeroclaw-trans.png"
|
||||
src={`${basePath}/_app/zeroclaw-trans.png`}
|
||||
alt="ZeroClaw"
|
||||
className="relative h-9 w-9 rounded-xl object-cover"
|
||||
onError={(e) => {
|
||||
|
||||
+4
-3
@@ -11,6 +11,7 @@ import type {
|
||||
HealthSnapshot,
|
||||
} from '../types/api';
|
||||
import { clearToken, getToken, setToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Base fetch wrapper
|
||||
@@ -42,7 +43,7 @@ export async function apiFetch<T = unknown>(
|
||||
headers.set('Content-Type', 'application/json');
|
||||
}
|
||||
|
||||
const response = await fetch(path, { ...options, headers });
|
||||
const response = await fetch(`${basePath}${path}`, { ...options, headers });
|
||||
|
||||
if (response.status === 401) {
|
||||
clearToken();
|
||||
@@ -78,7 +79,7 @@ function unwrapField<T>(value: T | Record<string, T>, key: string): T {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function pair(code: string): Promise<{ token: string }> {
|
||||
const response = await fetch('/pair', {
|
||||
const response = await fetch(`${basePath}/pair`, {
|
||||
method: 'POST',
|
||||
headers: { 'X-Pairing-Code': code },
|
||||
});
|
||||
@@ -106,7 +107,7 @@ export async function getAdminPairCode(): Promise<{ pairing_code: string | null;
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function getPublicHealth(): Promise<{ require_pairing: boolean; paired: boolean }> {
|
||||
const response = await fetch('/health');
|
||||
const response = await fetch(`${basePath}/health`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Health check failed (${response.status})`);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
// Runtime base path injected by the Rust gateway into index.html.
|
||||
// Allows the SPA to work under a reverse-proxy path prefix.
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
__ZEROCLAW_BASE__?: string;
|
||||
}
|
||||
}
|
||||
|
||||
/** Gateway path prefix (e.g. "/zeroclaw"), or empty string when served at root. */
|
||||
export const basePath: string = (window.__ZEROCLAW_BASE__ ?? '').replace(/\/+$/, '');
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import type { SSEEvent } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
|
||||
export type SSEEventHandler = (event: SSEEvent) => void;
|
||||
export type SSEErrorHandler = (error: Event | Error) => void;
|
||||
@@ -41,7 +42,7 @@ export class SSEClient {
|
||||
private readonly autoReconnect: boolean;
|
||||
|
||||
constructor(options: SSEClientOptions = {}) {
|
||||
this.path = options.path ?? '/api/events';
|
||||
this.path = options.path ?? `${basePath}/api/events`;
|
||||
this.reconnectDelay = options.reconnectDelay ?? DEFAULT_RECONNECT_DELAY;
|
||||
this.maxReconnectDelay = options.maxReconnectDelay ?? MAX_RECONNECT_DELAY;
|
||||
this.autoReconnect = options.autoReconnect ?? true;
|
||||
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import type { WsMessage } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { basePath } from './basePath';
|
||||
import { generateUUID } from './uuid';
|
||||
|
||||
export type WsMessageHandler = (msg: WsMessage) => void;
|
||||
@@ -69,7 +70,7 @@ export class WebSocketClient {
|
||||
const params = new URLSearchParams();
|
||||
if (token) params.set('token', token);
|
||||
params.set('session_id', sessionId);
|
||||
const url = `${this.baseUrl}/ws/chat?${params.toString()}`;
|
||||
const url = `${this.baseUrl}${basePath}/ws/chat?${params.toString()}`;
|
||||
|
||||
const protocols: string[] = ['zeroclaw.v1'];
|
||||
if (token) protocols.push(`bearer.${token}`);
|
||||
|
||||
+3
-2
@@ -2,12 +2,13 @@ import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import { BrowserRouter } from 'react-router-dom';
|
||||
import App from './App';
|
||||
import { basePath } from './lib/basePath';
|
||||
import './index.css';
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||
<React.StrictMode>
|
||||
{/* Vite base '/_app/' scopes static asset URLs only; app routes stay rooted at '/' for SPA fallback. */}
|
||||
<BrowserRouter basename="/">
|
||||
{/* basePath is injected by the Rust gateway at serve time for reverse-proxy prefix support. */}
|
||||
<BrowserRouter basename={basePath || '/'}>
|
||||
<App />
|
||||
</BrowserRouter>
|
||||
</React.StrictMode>
|
||||
|
||||
Reference in New Issue
Block a user