Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a18b15b4cc | |||
| 6a4ccaeb73 | |||
| 4aead04916 | |||
| 7b23c8934c | |||
| 5d1543100d | |||
| e3a91bc805 | |||
| 833fdefbe5 | |||
| 13f74f0ecc | |||
| 9ff045d2e9 | |||
| 6fe8e3a5bb | |||
| 5dc1750df7 | |||
| b40c9e77af | |||
| 34cac3d9dd | |||
| badf96dcab | |||
| c1e1228fb0 | |||
| d2b923ae07 | |||
| 21fdef95f4 | |||
| d02fbf2d76 | |||
| 05cede29a8 | |||
| d34a2e6d3f | |||
| 576d22fedd | |||
| d5455c694c | |||
| 90275b057e | |||
| d46b4f29d2 | |||
| f25835f98c | |||
| 376579f9fa | |||
| b620fd6bba | |||
| 98d6c5af9e | |||
| c51ca19dc1 | |||
| ea6abc9f42 | |||
| e2f6f20bfb | |||
| 88df3d4b2e | |||
| 0dba55959d | |||
| 893788f04d | |||
| 0fea62d114 | |||
| cca3cf8f84 |
@@ -96,7 +96,7 @@ jobs:
|
||||
|
||||
- name: Build release
|
||||
shell: bash
|
||||
run: cargo build --release --locked --target ${{ matrix.target }}
|
||||
run: cargo build --profile ci --locked --target ${{ matrix.target }}
|
||||
env:
|
||||
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER: clang
|
||||
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS: "-C link-arg=-fuse-ld=mold"
|
||||
|
||||
@@ -124,7 +124,7 @@ jobs:
|
||||
|
||||
- name: Build release
|
||||
shell: bash
|
||||
run: cargo build --release --locked --target ${{ matrix.target }}
|
||||
run: cargo build --profile ci --locked --target ${{ matrix.target }}
|
||||
env:
|
||||
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER: clang
|
||||
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS: "-C link-arg=-fuse-ld=mold"
|
||||
|
||||
@@ -65,11 +65,11 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-22.04
|
||||
- os: ubuntu-latest
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
@@ -83,11 +83,11 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-22.04
|
||||
- os: ubuntu-latest
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
+27
-7
@@ -2,20 +2,41 @@
|
||||
|
||||
Thanks for your interest in contributing to ZeroClaw! This guide will help you get started.
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ Branch Migration Notice (March 2026)
|
||||
|
||||
**`master` is the ONLY default branch. The `main` branch no longer exists.**
|
||||
|
||||
If you have an existing fork or local clone that tracks `main`, you **must** update it:
|
||||
|
||||
```bash
|
||||
# Update your local clone to track master
|
||||
git checkout master
|
||||
git branch -D main 2>/dev/null # delete local main if it exists
|
||||
git remote set-head origin master
|
||||
git fetch origin --prune # remove stale remote refs
|
||||
|
||||
# If your fork still has a main branch, delete it
|
||||
git push origin --delete main 2>/dev/null
|
||||
```
|
||||
|
||||
All PRs must target **`master`**. PRs targeting `main` will be rejected.
|
||||
|
||||
**Background:** ZeroClaw previously used `main` in some documentation and scripts, which caused 404 errors, broken CI refs, and contributor confusion (see [#2929](https://github.com/zeroclaw-labs/zeroclaw/issues/2929), [#3061](https://github.com/zeroclaw-labs/zeroclaw/issues/3061), [#3194](https://github.com/zeroclaw-labs/zeroclaw/pull/3194)). As of March 2026, all references have been corrected, stale branches cleaned up, and the `main` branch permanently deleted.
|
||||
|
||||
---
|
||||
|
||||
## Branching Model
|
||||
|
||||
> **Important — `master` is the default branch.**
|
||||
>
|
||||
> ZeroClaw uses **`master`** as its single source-of-truth branch. The `main` branch has been removed.
|
||||
>
|
||||
> Previously, some documentation and scripts referenced a `main` branch, which caused 404 errors and contributor confusion (see [#2929](https://github.com/zeroclaw-labs/zeroclaw/issues/2929), [#3061](https://github.com/zeroclaw-labs/zeroclaw/issues/3061), [#3194](https://github.com/zeroclaw-labs/zeroclaw/pull/3194)). As of March 2026, all references have been corrected and the `main` branch deleted.
|
||||
> **`master`** is the single source-of-truth branch.
|
||||
>
|
||||
> **How contributors should work:**
|
||||
> 1. Fork the repository
|
||||
> 2. Create a `feat/*` or `fix/*` branch from `master`
|
||||
> 3. Open a PR targeting `master`
|
||||
>
|
||||
> Do **not** create or push to a `main` branch.
|
||||
> Do **not** create or push to a `main` branch. There is no `main` branch — it will not work.
|
||||
|
||||
## First-Time Contributors
|
||||
|
||||
@@ -559,4 +580,3 @@ Recommended scope keys in commit titles:
|
||||
## License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the MIT License.
|
||||
# Contributing Guide Update
|
||||
|
||||
+9
-2
@@ -48,8 +48,8 @@ schemars = "1.2"
|
||||
tracing = { version = "0.1", default-features = false }
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
|
||||
|
||||
# Observability - Prometheus metrics
|
||||
prometheus = { version = "0.14", default-features = false }
|
||||
# Observability - Prometheus metrics (optional; requires AtomicU64, unavailable on 32-bit)
|
||||
prometheus = { version = "0.14", default-features = false, optional = true }
|
||||
|
||||
# Base64 encoding (screenshots, image data)
|
||||
base64 = "0.22"
|
||||
@@ -205,6 +205,8 @@ sandbox-landlock = ["dep:landlock"]
|
||||
sandbox-bubblewrap = []
|
||||
# Backward-compatible alias for older invocations
|
||||
landlock = ["sandbox-landlock"]
|
||||
# Prometheus metrics observer (requires 64-bit atomics; disable on 32-bit targets)
|
||||
metrics = ["dep:prometheus"]
|
||||
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
||||
probe = ["dep:probe-rs"]
|
||||
# rag-pdf = PDF ingestion for datasheet RAG
|
||||
@@ -225,6 +227,11 @@ inherits = "release"
|
||||
codegen-units = 8 # Parallel codegen for faster builds on powerful machines (16GB+ RAM recommended)
|
||||
# Use: cargo build --profile release-fast
|
||||
|
||||
[profile.ci]
|
||||
inherits = "release"
|
||||
lto = "thin" # Much faster than fat LTO; still catches release-mode issues
|
||||
codegen-units = 16 # Full parallelism for CI runners
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
opt-level = "z"
|
||||
|
||||
@@ -90,6 +90,8 @@ COPY dev/config.template.toml /zeroclaw-data/.zeroclaw/config.toml
|
||||
RUN chown 65534:65534 /zeroclaw-data/.zeroclaw/config.toml
|
||||
|
||||
# Environment setup
|
||||
# Ensure UTF-8 locale so CJK / multibyte input is handled correctly
|
||||
ENV LANG=C.UTF-8
|
||||
# Use consistent workspace path
|
||||
ENV ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace
|
||||
ENV HOME=/zeroclaw-data
|
||||
@@ -114,6 +116,8 @@ COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
|
||||
COPY --from=builder /zeroclaw-data /zeroclaw-data
|
||||
|
||||
# Environment setup
|
||||
# Ensure UTF-8 locale so CJK / multibyte input is handled correctly
|
||||
ENV LANG=C.UTF-8
|
||||
ENV ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace
|
||||
ENV HOME=/zeroclaw-data
|
||||
# Default provider and model are set in config.toml, not here,
|
||||
|
||||
@@ -1,6 +1,110 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
let dir = std::path::Path::new("web/dist");
|
||||
if !dir.exists() {
|
||||
std::fs::create_dir_all(dir).expect("failed to create web/dist/");
|
||||
let dist_dir = Path::new("web/dist");
|
||||
let web_dir = Path::new("web");
|
||||
|
||||
// Tell Cargo to re-run this script when web source files change.
|
||||
println!("cargo:rerun-if-changed=web/src");
|
||||
println!("cargo:rerun-if-changed=web/index.html");
|
||||
println!("cargo:rerun-if-changed=web/package.json");
|
||||
println!("cargo:rerun-if-changed=web/vite.config.ts");
|
||||
|
||||
// Attempt to build the web frontend if npm is available and web/dist is
|
||||
// missing or stale. The build is best-effort: when Node.js is not
|
||||
// installed (e.g. CI containers, cross-compilation, minimal dev setups)
|
||||
// we fall back to the existing stub/empty dist directory so the Rust
|
||||
// build still succeeds.
|
||||
let needs_build = !dist_dir.join("index.html").exists();
|
||||
|
||||
if needs_build && web_dir.join("package.json").exists() {
|
||||
if let Ok(npm) = which_npm() {
|
||||
eprintln!("cargo:warning=Building web frontend (web/dist is missing or stale)...");
|
||||
|
||||
// npm ci / npm install
|
||||
let install_status = Command::new(&npm)
|
||||
.args(["ci", "--ignore-scripts"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
|
||||
match install_status {
|
||||
Ok(s) if s.success() => {}
|
||||
Ok(s) => {
|
||||
// Fall back to `npm install` if `npm ci` fails (no lockfile, etc.)
|
||||
eprintln!("cargo:warning=npm ci exited with {s}, trying npm install...");
|
||||
let fallback = Command::new(&npm)
|
||||
.args(["install"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
if !matches!(fallback, Ok(s) if s.success()) {
|
||||
eprintln!("cargo:warning=npm install failed — skipping web build");
|
||||
ensure_dist_dir(dist_dir);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("cargo:warning=Could not run npm: {e} — skipping web build");
|
||||
ensure_dist_dir(dist_dir);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// npm run build
|
||||
let build_status = Command::new(&npm)
|
||||
.args(["run", "build"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
|
||||
match build_status {
|
||||
Ok(s) if s.success() => {
|
||||
eprintln!("cargo:warning=Web frontend built successfully.");
|
||||
}
|
||||
Ok(s) => {
|
||||
eprintln!(
|
||||
"cargo:warning=npm run build exited with {s} — web dashboard may be unavailable"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"cargo:warning=Could not run npm build: {e} — web dashboard may be unavailable"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ensure_dist_dir(dist_dir);
|
||||
}
|
||||
|
||||
/// Ensure the dist directory exists so `rust-embed` does not fail at compile
|
||||
/// time even when the web frontend is not built.
|
||||
fn ensure_dist_dir(dist_dir: &Path) {
|
||||
if !dist_dir.exists() {
|
||||
std::fs::create_dir_all(dist_dir).expect("failed to create web/dist/");
|
||||
}
|
||||
}
|
||||
|
||||
/// Locate the `npm` binary on the system PATH.
|
||||
fn which_npm() -> Result<String, ()> {
|
||||
let cmd = if cfg!(target_os = "windows") {
|
||||
"where"
|
||||
} else {
|
||||
"which"
|
||||
};
|
||||
|
||||
Command::new(cmd)
|
||||
.arg("npm")
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|output| {
|
||||
if output.status.success() {
|
||||
String::from_utf8(output.stdout)
|
||||
.ok()
|
||||
.map(|s| s.lines().next().unwrap_or("npm").trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(())
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ Lưu ý cho người dùng container:
|
||||
| `max_history_messages` | `50` | Số tin nhắn lịch sử tối đa giữ lại mỗi phiên |
|
||||
| `parallel_tools` | `false` | Bật thực thi tool song song trong một lượt |
|
||||
| `tool_dispatcher` | `auto` | Chiến lược dispatch tool |
|
||||
| `tool_call_dedup_exempt` | `[]` | Tên tool được miễn kiểm tra trùng lặp trong cùng một lượt |
|
||||
|
||||
Lưu ý:
|
||||
|
||||
@@ -77,6 +78,7 @@ Lưu ý:
|
||||
- Nếu tin nhắn kênh vượt giá trị này, runtime trả về: `Agent exceeded maximum tool iterations (<value>)`.
|
||||
- Trong vòng lặp tool của CLI, gateway và channel, các lời gọi tool độc lập được thực thi đồng thời mặc định khi không cần phê duyệt; thứ tự kết quả giữ ổn định.
|
||||
- `parallel_tools` áp dụng cho API `Agent::turn()`. Không ảnh hưởng đến vòng lặp runtime của CLI, gateway hay channel.
|
||||
- `tool_call_dedup_exempt` nhận mảng tên tool chính xác. Các tool trong danh sách được phép gọi nhiều lần với cùng tham số trong một lượt. Ví dụ: `tool_call_dedup_exempt = ["browser"]`.
|
||||
|
||||
## `[agents.<name>]`
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ Operational note for container users:
|
||||
| `max_history_messages` | `50` | Maximum conversation history messages retained per session |
|
||||
| `parallel_tools` | `false` | Enable parallel tool execution within a single iteration |
|
||||
| `tool_dispatcher` | `auto` | Tool dispatch strategy |
|
||||
| `tool_call_dedup_exempt` | `[]` | Tool names exempt from within-turn duplicate-call suppression |
|
||||
|
||||
Notes:
|
||||
|
||||
@@ -88,6 +89,7 @@ Notes:
|
||||
- If a channel message exceeds this value, the runtime returns: `Agent exceeded maximum tool iterations (<value>)`.
|
||||
- In CLI, gateway, and channel tool loops, multiple independent tool calls are executed concurrently by default when the pending calls do not require approval gating; result order remains stable.
|
||||
- `parallel_tools` applies to the `Agent::turn()` API surface. It does not gate the runtime loop used by CLI, gateway, or channel handlers.
|
||||
- `tool_call_dedup_exempt` accepts an array of exact tool names. Tools listed here are allowed to be called multiple times with identical arguments in the same turn, bypassing the dedup check. Example: `tool_call_dedup_exempt = ["browser"]`.
|
||||
|
||||
## `[security.otp]`
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ Lưu ý cho người dùng container:
|
||||
| `max_history_messages` | `50` | Số tin nhắn lịch sử tối đa giữ lại mỗi phiên |
|
||||
| `parallel_tools` | `false` | Bật thực thi tool song song trong một lượt |
|
||||
| `tool_dispatcher` | `auto` | Chiến lược dispatch tool |
|
||||
| `tool_call_dedup_exempt` | `[]` | Tên tool được miễn kiểm tra trùng lặp trong cùng một lượt |
|
||||
|
||||
Lưu ý:
|
||||
|
||||
@@ -77,6 +78,7 @@ Lưu ý:
|
||||
- Nếu tin nhắn kênh vượt giá trị này, runtime trả về: `Agent exceeded maximum tool iterations (<value>)`.
|
||||
- Trong vòng lặp tool của CLI, gateway và channel, các lời gọi tool độc lập được thực thi đồng thời mặc định khi không cần phê duyệt; thứ tự kết quả giữ ổn định.
|
||||
- `parallel_tools` áp dụng cho API `Agent::turn()`. Không ảnh hưởng đến vòng lặp runtime của CLI, gateway hay channel.
|
||||
- `tool_call_dedup_exempt` nhận mảng tên tool chính xác. Các tool trong danh sách được phép gọi nhiều lần với cùng tham số trong một lượt. Ví dụ: `tool_call_dedup_exempt = ["browser"]`.
|
||||
|
||||
## `[agents.<name>]`
|
||||
|
||||
|
||||
+38
-3
@@ -211,8 +211,35 @@ should_attempt_prebuilt_for_resources() {
|
||||
return 1
|
||||
}
|
||||
|
||||
resolve_asset_url() {
|
||||
local asset_name="$1"
|
||||
local api_url="https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases"
|
||||
local releases_json download_url
|
||||
|
||||
# Fetch up to 10 recent releases (includes prereleases) and find the first
|
||||
# one that contains the requested asset.
|
||||
releases_json="$(curl -fsSL "${api_url}?per_page=10" 2>/dev/null || true)"
|
||||
if [[ -z "$releases_json" ]]; then
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Parse with simple grep/sed — avoids jq dependency.
|
||||
download_url="$(printf '%s\n' "$releases_json" \
|
||||
| tr ',' '\n' \
|
||||
| grep '"browser_download_url"' \
|
||||
| sed 's/.*"browser_download_url"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/' \
|
||||
| grep "/${asset_name}\$" \
|
||||
| head -n 1)"
|
||||
|
||||
if [[ -z "$download_url" ]]; then
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "$download_url"
|
||||
}
|
||||
|
||||
install_prebuilt_binary() {
|
||||
local target archive_url temp_dir archive_path extracted_bin install_dir
|
||||
local target archive_url temp_dir archive_path extracted_bin install_dir asset_name
|
||||
|
||||
if ! have_cmd curl; then
|
||||
warn "curl is required for pre-built binary installation."
|
||||
@@ -229,9 +256,17 @@ install_prebuilt_binary() {
|
||||
return 1
|
||||
fi
|
||||
|
||||
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-${target}.tar.gz"
|
||||
asset_name="zeroclaw-${target}.tar.gz"
|
||||
|
||||
# Try the GitHub API first to find the newest release (including prereleases)
|
||||
# that actually contains the asset, then fall back to /releases/latest/.
|
||||
archive_url="$(resolve_asset_url "$asset_name" || true)"
|
||||
if [[ -z "$archive_url" ]]; then
|
||||
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/${asset_name}"
|
||||
fi
|
||||
|
||||
temp_dir="$(mktemp -d -t zeroclaw-prebuilt-XXXXXX)"
|
||||
archive_path="$temp_dir/zeroclaw-${target}.tar.gz"
|
||||
archive_path="$temp_dir/${asset_name}"
|
||||
|
||||
info "Attempting pre-built binary install for target: $target"
|
||||
if ! curl -fsSL "$archive_url" -o "$archive_path"; then
|
||||
|
||||
@@ -52,7 +52,7 @@ dev = [
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
Documentation = "https://github.com/zeroclaw-labs/zeroclaw/tree/main/python"
|
||||
Documentation = "https://github.com/zeroclaw-labs/zeroclaw/tree/master/python"
|
||||
Repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
Issues = "https://github.com/zeroclaw-labs/zeroclaw/issues"
|
||||
|
||||
|
||||
+196
-7
@@ -63,8 +63,17 @@ pub(crate) fn scrub_credentials(input: &str) -> String {
|
||||
.map(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Preserve first 4 chars for context, then redact
|
||||
let prefix = if val.len() > 4 { &val[..4] } else { "" };
|
||||
// Preserve first 4 chars for context, then redact.
|
||||
// Use char_indices to find the byte offset of the 4th character
|
||||
// so we never slice in the middle of a multi-byte UTF-8 sequence.
|
||||
let prefix = if val.len() > 4 {
|
||||
val.char_indices()
|
||||
.nth(4)
|
||||
.map(|(byte_idx, _)| &val[..byte_idx])
|
||||
.unwrap_or(val)
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
if full_match.contains(':') {
|
||||
if full_match.contains('"') {
|
||||
@@ -258,6 +267,12 @@ async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f6
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
continue;
|
||||
}
|
||||
// Skip entries containing tool_result blocks — they can leak
|
||||
// stale tool output from previous heartbeat ticks into new
|
||||
// sessions, presenting the LLM with orphan tool_result data.
|
||||
if entry.content.contains("<tool_result") {
|
||||
continue;
|
||||
}
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
if context == "[Memory context]\n" {
|
||||
@@ -1957,6 +1972,7 @@ pub(crate) async fn agent_turn(
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2156,6 +2172,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
dedup_exempt_tools: &[String],
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2565,7 +2582,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
|
||||
let signature = tool_call_signature(&tool_name, &tool_args);
|
||||
if !seen_tool_signatures.insert(signature) {
|
||||
let dedup_exempt = dedup_exempt_tools.iter().any(|e| e == &tool_name);
|
||||
if !dedup_exempt && !seen_tool_signatures.insert(signature) {
|
||||
let duplicate = format!(
|
||||
"Skipped duplicate tool call '{tool_name}' with identical arguments in this turn."
|
||||
);
|
||||
@@ -2879,6 +2897,7 @@ pub async fn run(
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(config.provider_timeout_secs),
|
||||
};
|
||||
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
|
||||
@@ -3108,6 +3127,7 @@ pub async fn run(
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
@@ -3125,8 +3145,11 @@ pub async fn run(
|
||||
print!("> ");
|
||||
let _ = std::io::stdout().flush();
|
||||
|
||||
let mut input = String::new();
|
||||
match std::io::stdin().read_line(&mut input) {
|
||||
// Read raw bytes to avoid UTF-8 validation errors when PTY
|
||||
// transport splits multi-byte characters at frame boundaries
|
||||
// (e.g. CJK input with spaces over kubectl exec / SSH).
|
||||
let mut raw = Vec::new();
|
||||
match std::io::BufRead::read_until(&mut std::io::stdin().lock(), b'\n', &mut raw) {
|
||||
Ok(0) => break,
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
@@ -3134,6 +3157,7 @@ pub async fn run(
|
||||
break;
|
||||
}
|
||||
}
|
||||
let input = String::from_utf8_lossy(&raw).into_owned();
|
||||
|
||||
let user_input = input.trim().to_string();
|
||||
if user_input.is_empty() {
|
||||
@@ -3156,10 +3180,17 @@ pub async fn run(
|
||||
print!("Continue? [y/N] ");
|
||||
let _ = std::io::stdout().flush();
|
||||
|
||||
let mut confirm = String::new();
|
||||
if std::io::stdin().read_line(&mut confirm).is_err() {
|
||||
let mut confirm_raw = Vec::new();
|
||||
if std::io::BufRead::read_until(
|
||||
&mut std::io::stdin().lock(),
|
||||
b'\n',
|
||||
&mut confirm_raw,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let confirm = String::from_utf8_lossy(&confirm_raw);
|
||||
if !matches!(confirm.trim().to_lowercase().as_str(), "y" | "yes") {
|
||||
println!("Cancelled.\n");
|
||||
continue;
|
||||
@@ -3230,6 +3261,7 @@ pub async fn run(
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3337,6 +3369,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(config.provider_timeout_secs),
|
||||
};
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
|
||||
provider_name,
|
||||
@@ -3774,6 +3807,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -3820,6 +3854,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -3860,6 +3895,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -3986,6 +4022,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -4055,6 +4092,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -4074,6 +4112,142 @@ mod tests {
|
||||
assert!(tool_results.content.contains("Skipped duplicate tool call"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_dedup_exempt_allows_repeated_calls() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"count_tool","arguments":{"value":"A"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name":"count_tool","arguments":{"value":"A"}}
|
||||
</tool_call>"#,
|
||||
"done",
|
||||
]);
|
||||
|
||||
let invocations = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingTool::new(
|
||||
"count_tool",
|
||||
Arc::clone(&invocations),
|
||||
))];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("run tool calls"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
let exempt = vec!["count_tool".to_string()];
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
|
||||
assert_eq!(result, "done");
|
||||
assert_eq!(
|
||||
invocations.load(Ordering::SeqCst),
|
||||
2,
|
||||
"exempt tool should execute both duplicate calls"
|
||||
);
|
||||
|
||||
let tool_results = history
|
||||
.iter()
|
||||
.find(|msg| msg.role == "user" && msg.content.starts_with("[Tool results]"))
|
||||
.expect("prompt-mode tool result payload should be present");
|
||||
assert!(
|
||||
!tool_results.content.contains("Skipped duplicate tool call"),
|
||||
"exempt tool calls should not be suppressed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_dedup_exempt_only_affects_listed_tools() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"count_tool","arguments":{"value":"A"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name":"count_tool","arguments":{"value":"A"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name":"other_tool","arguments":{"value":"B"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name":"other_tool","arguments":{"value":"B"}}
|
||||
</tool_call>"#,
|
||||
"done",
|
||||
]);
|
||||
|
||||
let count_invocations = Arc::new(AtomicUsize::new(0));
|
||||
let other_invocations = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![
|
||||
Box::new(CountingTool::new(
|
||||
"count_tool",
|
||||
Arc::clone(&count_invocations),
|
||||
)),
|
||||
Box::new(CountingTool::new(
|
||||
"other_tool",
|
||||
Arc::clone(&other_invocations),
|
||||
)),
|
||||
];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("run tool calls"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
let exempt = vec!["count_tool".to_string()];
|
||||
|
||||
let _result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
|
||||
assert_eq!(
|
||||
count_invocations.load(Ordering::SeqCst),
|
||||
2,
|
||||
"exempt tool should execute both calls"
|
||||
);
|
||||
assert_eq!(
|
||||
other_invocations.load(Ordering::SeqCst),
|
||||
1,
|
||||
"non-exempt tool should still be deduped"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_native_mode_preserves_fallback_tool_call_ids() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
@@ -4111,6 +4285,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -5468,6 +5643,20 @@ Let me check the result."#;
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_credentials_multibyte_chars_no_panic() {
|
||||
// Regression test for #3024: byte index 4 is not a char boundary
|
||||
// when the captured value contains multi-byte UTF-8 characters.
|
||||
// The regex only matches quoted values for non-ASCII content, since
|
||||
// capture group 4 is restricted to [a-zA-Z0-9_\-\.].
|
||||
let input = "password=\"\u{4f60}\u{7684}WiFi\u{5bc6}\u{7801}ab\"";
|
||||
let result = scrub_credentials(input);
|
||||
assert!(
|
||||
result.contains("[REDACTED]"),
|
||||
"multi-byte quoted value should be redacted without panic, got: {result}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_credentials_short_values_not_redacted() {
|
||||
// Values shorter than 8 chars should not be redacted
|
||||
|
||||
@@ -622,7 +622,18 @@ impl Channel for DiscordChannel {
|
||||
msg = read.next() => {
|
||||
let msg = match msg {
|
||||
Some(Ok(Message::Text(t))) => t,
|
||||
Some(Ok(Message::Ping(payload))) => {
|
||||
if write.send(Message::Pong(payload)).await.is_err() {
|
||||
tracing::warn!("Discord: pong send failed, reconnecting");
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Err(e)) => {
|
||||
tracing::warn!("Discord: websocket read error: {e}, reconnecting");
|
||||
break;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
|
||||
+8
-1
@@ -1,6 +1,10 @@
|
||||
use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use std::sync::atomic::AtomicU32;
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
@@ -13,7 +17,10 @@ use tokio_rustls::rustls;
|
||||
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||||
|
||||
/// Monotonic counter to ensure unique message IDs under burst traffic.
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
static MSG_SEQ: AtomicU64 = AtomicU64::new(0);
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
static MSG_SEQ: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
/// IRC over TLS channel.
|
||||
///
|
||||
|
||||
+289
-22
@@ -61,20 +61,33 @@ impl LinqChannel {
|
||||
|
||||
/// Parse an incoming webhook payload from Linq and extract messages.
|
||||
///
|
||||
/// Linq webhook envelope:
|
||||
/// Supports two webhook formats:
|
||||
///
|
||||
/// **New format (webhook_version 2026-02-03):**
|
||||
/// ```json
|
||||
/// {
|
||||
/// "api_version": "v3",
|
||||
/// "webhook_version": "2026-02-03",
|
||||
/// "event_type": "message.received",
|
||||
/// "data": {
|
||||
/// "id": "msg-...",
|
||||
/// "direction": "inbound",
|
||||
/// "sender_handle": { "handle": "+1...", "is_me": false },
|
||||
/// "chat": { "id": "chat-..." },
|
||||
/// "parts": [{ "type": "text", "value": "..." }]
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// **Legacy format (webhook_version 2025-01-01):**
|
||||
/// ```json
|
||||
/// {
|
||||
/// "api_version": "v3",
|
||||
/// "event_type": "message.received",
|
||||
/// "event_id": "...",
|
||||
/// "created_at": "...",
|
||||
/// "trace_id": "...",
|
||||
/// "data": {
|
||||
/// "chat_id": "...",
|
||||
/// "from": "+1...",
|
||||
/// "recipient_phone": "+1...",
|
||||
/// "is_from_me": false,
|
||||
/// "service": "iMessage",
|
||||
/// "message": {
|
||||
/// "id": "...",
|
||||
/// "parts": [{ "type": "text", "value": "..." }]
|
||||
@@ -99,18 +112,44 @@ impl LinqChannel {
|
||||
return messages;
|
||||
};
|
||||
|
||||
// Detect format: new format has `sender_handle`, legacy has `from`.
|
||||
let is_new_format = data.get("sender_handle").is_some();
|
||||
|
||||
// Skip messages sent by the bot itself
|
||||
if data
|
||||
.get("is_from_me")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let is_from_me = if is_new_format {
|
||||
// New format: data.sender_handle.is_me or data.direction == "outbound"
|
||||
data.get("sender_handle")
|
||||
.and_then(|sh| sh.get("is_me"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
|| data
|
||||
.get("direction")
|
||||
.and_then(|d| d.as_str())
|
||||
.is_some_and(|d| d == "outbound")
|
||||
} else {
|
||||
// Legacy format: data.is_from_me
|
||||
data.get("is_from_me")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
};
|
||||
|
||||
if is_from_me {
|
||||
tracing::debug!("Linq: skipping is_from_me message");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Get sender phone number
|
||||
let Some(from) = data.get("from").and_then(|f| f.as_str()) else {
|
||||
let from = if is_new_format {
|
||||
// New format: data.sender_handle.handle
|
||||
data.get("sender_handle")
|
||||
.and_then(|sh| sh.get("handle"))
|
||||
.and_then(|h| h.as_str())
|
||||
} else {
|
||||
// Legacy format: data.from
|
||||
data.get("from").and_then(|f| f.as_str())
|
||||
};
|
||||
|
||||
let Some(from) = from else {
|
||||
return messages;
|
||||
};
|
||||
|
||||
@@ -132,18 +171,33 @@ impl LinqChannel {
|
||||
}
|
||||
|
||||
// Get chat_id for reply routing
|
||||
let chat_id = data
|
||||
.get("chat_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Extract text from message parts
|
||||
let Some(message) = data.get("message") else {
|
||||
return messages;
|
||||
let chat_id = if is_new_format {
|
||||
// New format: data.chat.id
|
||||
data.get("chat")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(|id| id.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
// Legacy format: data.chat_id
|
||||
data.get("chat_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
};
|
||||
|
||||
let Some(parts) = message.get("parts").and_then(|p| p.as_array()) else {
|
||||
// Extract message parts
|
||||
let parts = if is_new_format {
|
||||
// New format: data.parts (directly on data)
|
||||
data.get("parts").and_then(|p| p.as_array())
|
||||
} else {
|
||||
// Legacy format: data.message.parts
|
||||
data.get("message")
|
||||
.and_then(|m| m.get("parts"))
|
||||
.and_then(|p| p.as_array())
|
||||
};
|
||||
|
||||
let Some(parts) = parts else {
|
||||
return messages;
|
||||
};
|
||||
|
||||
@@ -790,4 +844,217 @@ mod tests {
|
||||
let ch = make_channel();
|
||||
assert_eq!(ch.phone_number(), "+15551234567");
|
||||
}
|
||||
|
||||
// ---- New format (2026-02-03) tests ----
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_text_message() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"api_version": "v3",
|
||||
"webhook_version": "2026-02-03",
|
||||
"event_type": "message.received",
|
||||
"event_id": "evt-123",
|
||||
"created_at": "2026-03-01T12:00:00Z",
|
||||
"trace_id": "trace-456",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"service": "iMessage",
|
||||
"parts": [{
|
||||
"type": "text",
|
||||
"value": "Hello from new format!"
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "+1234567890");
|
||||
assert_eq!(msgs[0].content, "Hello from new format!");
|
||||
assert_eq!(msgs[0].channel, "linq");
|
||||
assert_eq!(msgs[0].reply_target, "chat-789");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_skip_is_me() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "outbound",
|
||||
"sender_handle": {
|
||||
"handle": "+15551234567",
|
||||
"is_me": true
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "My own message" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(
|
||||
msgs.is_empty(),
|
||||
"is_me messages should be skipped in new format"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_skip_outbound_direction() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "outbound",
|
||||
"sender_handle": {
|
||||
"handle": "+15551234567",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Outbound" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(msgs.is_empty(), "outbound direction should be skipped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_unauthorized_sender() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+9999999999",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Spam" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(
|
||||
msgs.is_empty(),
|
||||
"Unauthorized senders should be filtered in new format"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_media_image() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{
|
||||
"type": "media",
|
||||
"url": "https://example.com/photo.png",
|
||||
"mime_type": "image/png"
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "[IMAGE:https://example.com/photo.png]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_multiple_parts() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [
|
||||
{ "type": "text", "value": "Check this out" },
|
||||
{ "type": "media", "url": "https://example.com/img.jpg", "mime_type": "image/jpeg" }
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(
|
||||
msgs[0].content,
|
||||
"Check this out\n[IMAGE:https://example.com/img.jpg]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_fallback_reply_target_when_no_chat() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"parts": [{ "type": "text", "value": "Hi" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].reply_target, "+1234567890");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_normalizes_phone() {
|
||||
let ch = LinqChannel::new(
|
||||
"tok".into(),
|
||||
"+15551234567".into(),
|
||||
vec!["+1234567890".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Hi" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "+1234567890");
|
||||
}
|
||||
}
|
||||
|
||||
+337
-10
@@ -89,7 +89,11 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use std::sync::atomic::AtomicU32;
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -290,6 +294,7 @@ struct ChannelRuntimeContext {
|
||||
multimodal: crate::config::MultimodalConfig,
|
||||
hooks: Option<Arc<crate::hooks::HookRunner>>,
|
||||
non_cli_excluded_tools: Arc<Vec<String>>,
|
||||
tool_call_dedup_exempt: Arc<Vec<String>>,
|
||||
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
|
||||
}
|
||||
|
||||
@@ -570,6 +575,25 @@ fn normalize_cached_channel_turns(turns: Vec<ChatMessage>) -> Vec<ChatMessage> {
|
||||
normalized
|
||||
}
|
||||
|
||||
/// Remove `<tool_result …>…</tool_result>` blocks (and a leading `[Tool results]`
|
||||
/// header, if present) from a conversation-history entry so that stale tool
|
||||
/// output is never presented to the LLM without the corresponding `<tool_call>`.
|
||||
fn strip_tool_result_content(text: &str) -> String {
|
||||
static TOOL_RESULT_RE: std::sync::LazyLock<regex::Regex> = std::sync::LazyLock::new(|| {
|
||||
regex::Regex::new(r"(?s)<tool_result[^>]*>.*?</tool_result>").unwrap()
|
||||
});
|
||||
|
||||
let cleaned = TOOL_RESULT_RE.replace_all(text, "");
|
||||
let cleaned = cleaned.trim();
|
||||
|
||||
// If the only remaining content is the header, drop it entirely.
|
||||
if cleaned == "[Tool results]" || cleaned.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
cleaned.to_string()
|
||||
}
|
||||
|
||||
fn supports_runtime_model_switch(channel_name: &str) -> bool {
|
||||
matches!(channel_name, "telegram" | "discord" | "matrix")
|
||||
}
|
||||
@@ -943,6 +967,22 @@ fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Skip entries containing image markers to prevent duplication.
|
||||
// When auto_save stores a photo message to memory, a subsequent
|
||||
// memory recall on the same turn would surface the marker again,
|
||||
// causing two identical image blocks in the provider request.
|
||||
if content.contains("[IMAGE:") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Skip entries containing tool_result blocks. After a daemon restart
|
||||
// these can be recalled from SQLite and injected as memory context,
|
||||
// presenting the LLM with a `<tool_result>` without a preceding
|
||||
// `<tool_call>` and triggering hallucinated output.
|
||||
if content.contains("<tool_result") {
|
||||
return true;
|
||||
}
|
||||
|
||||
content.chars().count() > MEMORY_CONTEXT_MAX_CHARS
|
||||
}
|
||||
|
||||
@@ -1749,6 +1789,15 @@ async fn process_channel_message(
|
||||
.unwrap_or_default();
|
||||
let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||||
|
||||
// Strip stale tool_result blocks from cached turns so the LLM never
|
||||
// sees a `<tool_result>` without a preceding `<tool_call>`, which
|
||||
// causes hallucinated output on subsequent heartbeat ticks or sessions.
|
||||
for turn in &mut prior_turns {
|
||||
if turn.content.contains("<tool_result") {
|
||||
turn.content = strip_tool_result_content(&turn.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Only enrich with memory context when there is no prior conversation
|
||||
// history. Follow-up turns already include context from previous messages.
|
||||
if !had_prior_history {
|
||||
@@ -1919,6 +1968,7 @@ async fn process_channel_message(
|
||||
} else {
|
||||
ctx.non_cli_excluded_tools.as_ref()
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -2295,7 +2345,10 @@ async fn run_message_dispatch_loop(
|
||||
String,
|
||||
InFlightSenderTaskState,
|
||||
>::new()));
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
let task_sequence = Arc::new(AtomicU64::new(1));
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
let task_sequence = Arc::new(AtomicU32::new(1));
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||||
@@ -2313,7 +2366,7 @@ async fn run_message_dispatch_loop(
|
||||
let sender_scope_key = interruption_scope_key(&msg);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed) as u64;
|
||||
|
||||
if interrupt_enabled {
|
||||
let previous = {
|
||||
@@ -2831,9 +2884,86 @@ pub(crate) async fn handle_command(command: crate::ChannelCommands, config: &Con
|
||||
crate::ChannelCommands::BindTelegram { identity } => {
|
||||
bind_telegram_identity(config, &identity).await
|
||||
}
|
||||
crate::ChannelCommands::Send {
|
||||
message,
|
||||
channel_id,
|
||||
recipient,
|
||||
} => send_channel_message(config, &channel_id, &recipient, &message).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a single channel instance by config section name (e.g. "telegram").
|
||||
fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Channel>> {
|
||||
match channel_id {
|
||||
"telegram" => {
|
||||
let tg = config
|
||||
.channels_config
|
||||
.telegram
|
||||
.as_ref()
|
||||
.context("Telegram channel is not configured")?;
|
||||
Ok(Arc::new(
|
||||
TelegramChannel::new(
|
||||
tg.bot_token.clone(),
|
||||
tg.allowed_users.clone(),
|
||||
tg.mention_only,
|
||||
)
|
||||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms)
|
||||
.with_transcription(config.transcription.clone())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
))
|
||||
}
|
||||
"discord" => {
|
||||
let dc = config
|
||||
.channels_config
|
||||
.discord
|
||||
.as_ref()
|
||||
.context("Discord channel is not configured")?;
|
||||
Ok(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)))
|
||||
}
|
||||
"slack" => {
|
||||
let sl = config
|
||||
.channels_config
|
||||
.slack
|
||||
.as_ref()
|
||||
.context("Slack channel is not configured")?;
|
||||
Ok(Arc::new(
|
||||
SlackChannel::new(
|
||||
sl.bot_token.clone(),
|
||||
sl.app_token.clone(),
|
||||
sl.channel_id.clone(),
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
))
|
||||
}
|
||||
other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a one-off message to a configured channel.
|
||||
async fn send_channel_message(
|
||||
config: &Config,
|
||||
channel_id: &str,
|
||||
recipient: &str,
|
||||
message: &str,
|
||||
) -> Result<()> {
|
||||
let channel = build_channel_by_id(config, channel_id)?;
|
||||
let msg = SendMessage::new(message, recipient);
|
||||
channel
|
||||
.send(&msg)
|
||||
.await
|
||||
.with_context(|| format!("Failed to send message via {channel_id}"))?;
|
||||
println!("Message sent via {channel_id}.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ChannelHealthState {
|
||||
Healthy,
|
||||
@@ -3217,6 +3347,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(config.provider_timeout_secs),
|
||||
};
|
||||
let provider: Arc<dyn Provider> = Arc::from(
|
||||
create_resilient_provider_nonblocking(
|
||||
@@ -3276,7 +3407,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
};
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||
let mut built_tools = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
@@ -3290,7 +3421,44 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
));
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
config.mcp.servers.len()
|
||||
);
|
||||
match crate::tools::mcp_client::McpRegistry::connect_all(&config.mcp.servers).await {
|
||||
Ok(registry) => {
|
||||
let registry = std::sync::Arc::new(registry);
|
||||
let names = registry.tool_names();
|
||||
let mut registered = 0usize;
|
||||
for name in names {
|
||||
if let Some(def) = registry.get_tool_def(&name).await {
|
||||
let wrapper = crate::tools::mcp_tool::McpToolWrapper::new(
|
||||
name,
|
||||
def,
|
||||
std::sync::Arc::clone(®istry),
|
||||
);
|
||||
built_tools.push(Box::new(wrapper));
|
||||
registered += 1;
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"MCP: {} tool(s) registered from {} server(s)",
|
||||
registered,
|
||||
registry.server_count()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal — daemon continues with the tools registered above.
|
||||
tracing::error!("MCP registry failed to initialize: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tools_registry = Arc::new(built_tools);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&workspace, &config);
|
||||
|
||||
@@ -3514,6 +3682,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
None
|
||||
},
|
||||
non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()),
|
||||
tool_call_dedup_exempt: Arc::new(config.agent.tool_call_dedup_exempt.clone()),
|
||||
model_routes: Arc::new(config.model_routes.clone()),
|
||||
});
|
||||
|
||||
@@ -3614,6 +3783,53 @@ mod tests {
|
||||
"fabricated memory"
|
||||
));
|
||||
assert!(!should_skip_memory_context_entry("telegram_123_45", "hi"));
|
||||
|
||||
// Entries containing image markers must be skipped to prevent
|
||||
// auto-saved photo messages from duplicating image blocks (#2403).
|
||||
assert!(should_skip_memory_context_entry(
|
||||
"telegram_user_msg_99",
|
||||
"[IMAGE:/tmp/workspace/photo_1_2.jpg]"
|
||||
));
|
||||
assert!(should_skip_memory_context_entry(
|
||||
"telegram_user_msg_100",
|
||||
"[IMAGE:/tmp/workspace/photo_1_2.jpg]\n\nCheck this screenshot"
|
||||
));
|
||||
// Plain text without image markers should not be skipped.
|
||||
assert!(!should_skip_memory_context_entry(
|
||||
"telegram_user_msg_101",
|
||||
"Please describe the image"
|
||||
));
|
||||
|
||||
// Entries containing tool_result blocks must be skipped (#3402).
|
||||
assert!(should_skip_memory_context_entry(
|
||||
"telegram_user_msg_200",
|
||||
r#"[Tool results]
|
||||
<tool_result name="shell">Mon Feb 20</tool_result>"#
|
||||
));
|
||||
assert!(!should_skip_memory_context_entry(
|
||||
"telegram_user_msg_201",
|
||||
"plain text without tool results"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_tool_result_content_removes_blocks_and_header() {
|
||||
let input = r#"[Tool results]
|
||||
<tool_result name="shell">Mon Feb 20</tool_result>
|
||||
<tool_result name="http_request">{"status":200}</tool_result>"#;
|
||||
assert_eq!(strip_tool_result_content(input), "");
|
||||
|
||||
let mixed = "Some context\n<tool_result name=\"shell\">ok</tool_result>\nMore text";
|
||||
let cleaned = strip_tool_result_content(mixed);
|
||||
assert!(cleaned.contains("Some context"));
|
||||
assert!(cleaned.contains("More text"));
|
||||
assert!(!cleaned.contains("tool_result"));
|
||||
|
||||
assert_eq!(
|
||||
strip_tool_result_content("no tool results here"),
|
||||
"no tool results here"
|
||||
);
|
||||
assert_eq!(strip_tool_result_content(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -3728,16 +3944,17 @@ mod tests {
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
|
||||
let histories = ctx
|
||||
let locked_histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let kept = histories
|
||||
let kept = locked_histories
|
||||
.get(&sender)
|
||||
.expect("sender history should remain");
|
||||
assert_eq!(kept.len(), CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES);
|
||||
@@ -3778,6 +3995,7 @@ mod tests {
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
};
|
||||
|
||||
@@ -3831,16 +4049,17 @@ mod tests {
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
|
||||
let histories = ctx
|
||||
let locked_histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let turns = histories
|
||||
let turns = locked_histories
|
||||
.get(&sender)
|
||||
.expect("sender history should remain");
|
||||
assert_eq!(turns.len(), 2);
|
||||
@@ -4305,6 +4524,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
@@ -4366,6 +4586,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
@@ -4443,6 +4664,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4503,6 +4725,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4573,6 +4796,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4663,6 +4887,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4735,6 +4960,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4822,6 +5048,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4841,10 +5068,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
.await;
|
||||
|
||||
{
|
||||
let mut store = runtime_config_store()
|
||||
let mut cleanup_store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
store.remove(&config_path);
|
||||
cleanup_store.remove(&config_path);
|
||||
}
|
||||
|
||||
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
@@ -4894,6 +5121,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -4956,6 +5184,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5129,6 +5358,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5210,6 +5440,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5303,6 +5534,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5378,6 +5610,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5438,6 +5671,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -5919,6 +6153,47 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(context.contains("Age is 45"));
|
||||
}
|
||||
|
||||
/// Auto-saved photo messages must not surface through memory context,
|
||||
/// otherwise the image marker gets duplicated in the provider request (#2403).
|
||||
#[tokio::test]
|
||||
async fn build_memory_context_excludes_image_marker_entries() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
||||
// Simulate auto-save of a photo message containing an [IMAGE:] marker.
|
||||
mem.store(
|
||||
"telegram_user_msg_photo",
|
||||
"[IMAGE:/tmp/workspace/photo_1_2.jpg]\n\nDescribe this screenshot",
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// Also store a plain text entry that shares a word with the query
|
||||
// so the FTS recall returns both entries.
|
||||
mem.store(
|
||||
"screenshot_preference",
|
||||
"User prefers screenshot descriptions to be concise",
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_memory_context(&mem, "screenshot", 0.0).await;
|
||||
|
||||
// The image-marker entry must be excluded to prevent duplication.
|
||||
assert!(
|
||||
!context.contains("[IMAGE:"),
|
||||
"memory context must not contain image markers, got: {context}"
|
||||
);
|
||||
// Plain text entries should still be included.
|
||||
assert!(
|
||||
context.contains("screenshot descriptions"),
|
||||
"plain text entry should remain in context, got: {context}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_restores_per_sender_history_on_follow_ups() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
@@ -5955,6 +6230,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -6041,6 +6317,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -6127,6 +6404,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -6677,6 +6955,7 @@ This is an example JSON object for profile settings."#;
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -6744,6 +7023,7 @@ This is an example JSON object for profile settings."#;
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
});
|
||||
|
||||
@@ -6808,4 +7088,51 @@ This is an example JSON object for profile settings."#;
|
||||
"failed vision turn must not persist image marker content"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_unknown_channel_returns_error() {
|
||||
let config = Config::default();
|
||||
match build_channel_by_id(&config, "nonexistent") {
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
assert!(
|
||||
err_msg.contains("Unknown channel"),
|
||||
"expected 'Unknown channel' in error, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
Ok(_) => panic!("should fail for unknown channel"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_unconfigured_telegram_returns_error() {
|
||||
let config = Config::default();
|
||||
match build_channel_by_id(&config, "telegram") {
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
assert!(
|
||||
err_msg.contains("not configured"),
|
||||
"expected 'not configured' in error, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
Ok(_) => panic!("should fail when telegram is not configured"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_configured_telegram_succeeds() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.telegram = Some(crate::config::schema::TelegramConfig {
|
||||
bot_token: "test-token".to_string(),
|
||||
allowed_users: vec![],
|
||||
stream_mode: crate::config::StreamMode::Off,
|
||||
draft_update_interval_ms: 1000,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
});
|
||||
match build_channel_by_id(&config, "telegram") {
|
||||
Ok(channel) => assert_eq!(channel.name(), "telegram"),
|
||||
Err(e) => panic!("should succeed when telegram is configured: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
-6
@@ -10,12 +10,13 @@ pub use schema::{
|
||||
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
@@ -90,6 +90,13 @@ pub struct Config {
|
||||
)]
|
||||
pub default_temperature: f64,
|
||||
|
||||
/// HTTP request timeout in seconds for LLM provider API calls. Default: `120`.
|
||||
///
|
||||
/// Increase for slower backends (e.g., llama.cpp on constrained hardware)
|
||||
/// that need more time processing large contexts.
|
||||
#[serde(default = "default_provider_timeout_secs")]
|
||||
pub provider_timeout_secs: u64,
|
||||
|
||||
/// Observability backend configuration (`[observability]`).
|
||||
#[serde(default)]
|
||||
pub observability: ObservabilityConfig,
|
||||
@@ -225,6 +232,10 @@ pub struct Config {
|
||||
/// Text-to-Speech configuration (`[tts]`).
|
||||
#[serde(default)]
|
||||
pub tts: TtsConfig,
|
||||
|
||||
/// External MCP server connections (`[mcp]`).
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub mcp: McpConfig,
|
||||
}
|
||||
|
||||
/// Named provider profile definition compatible with Codex app-server style config.
|
||||
@@ -295,6 +306,13 @@ fn default_temperature() -> f64 {
|
||||
DEFAULT_TEMPERATURE
|
||||
}
|
||||
|
||||
/// Default provider HTTP request timeout: 120 seconds.
|
||||
const DEFAULT_PROVIDER_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
fn default_provider_timeout_secs() -> u64 {
|
||||
DEFAULT_PROVIDER_TIMEOUT_SECS
|
||||
}
|
||||
|
||||
/// Validate that a temperature value is within the allowed range.
|
||||
pub fn validate_temperature(value: f64) -> std::result::Result<f64, String> {
|
||||
if TEMPERATURE_RANGE.contains(&value) {
|
||||
@@ -441,6 +459,60 @@ impl Default for TranscriptionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── MCP ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Transport type for MCP server connections.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum McpTransport {
|
||||
/// Spawn a local process and communicate over stdin/stdout.
|
||||
#[default]
|
||||
Stdio,
|
||||
/// Connect via HTTP POST.
|
||||
Http,
|
||||
/// Connect via HTTP + Server-Sent Events.
|
||||
Sse,
|
||||
}
|
||||
|
||||
/// Configuration for a single external MCP server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpServerConfig {
|
||||
/// Display name used as a tool prefix (`<server>__<tool>`).
|
||||
pub name: String,
|
||||
/// Transport type (default: stdio).
|
||||
#[serde(default)]
|
||||
pub transport: McpTransport,
|
||||
/// URL for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub url: Option<String>,
|
||||
/// Executable to spawn for stdio transport.
|
||||
#[serde(default)]
|
||||
pub command: String,
|
||||
/// Command arguments for stdio transport.
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
/// Optional environment variables for stdio transport.
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
/// Optional HTTP headers for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
/// Optional per-call timeout in seconds (hard capped in validation).
|
||||
#[serde(default)]
|
||||
pub tool_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// External MCP client configuration (`[mcp]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpConfig {
|
||||
/// Enable MCP tool loading.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Configured MCP servers.
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
// ── TTS (Text-to-Speech) ─────────────────────────────────────────
|
||||
|
||||
fn default_tts_provider() -> String {
|
||||
@@ -604,6 +676,9 @@ pub struct AgentConfig {
|
||||
/// Tool dispatch strategy (e.g. `"auto"`). Default: `"auto"`.
|
||||
#[serde(default = "default_agent_tool_dispatcher")]
|
||||
pub tool_dispatcher: String,
|
||||
/// Tools exempt from the within-turn duplicate-call dedup check. Default: `[]`.
|
||||
#[serde(default)]
|
||||
pub tool_call_dedup_exempt: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@@ -626,6 +701,7 @@ impl Default for AgentConfig {
|
||||
max_history_messages: default_agent_max_history_messages(),
|
||||
parallel_tools: false,
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1616,6 +1692,65 @@ fn service_selector_matches(selector: &str, service_key: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
const MCP_MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
fn validate_mcp_config(config: &McpConfig) -> Result<()> {
|
||||
let mut seen_names = std::collections::HashSet::new();
|
||||
for (i, server) in config.servers.iter().enumerate() {
|
||||
let name = server.name.trim();
|
||||
if name.is_empty() {
|
||||
anyhow::bail!("mcp.servers[{i}].name must not be empty");
|
||||
}
|
||||
if !seen_names.insert(name.to_ascii_lowercase()) {
|
||||
anyhow::bail!("mcp.servers contains duplicate name: {name}");
|
||||
}
|
||||
|
||||
if let Some(timeout) = server.tool_timeout_secs {
|
||||
if timeout == 0 {
|
||||
anyhow::bail!("mcp.servers[{i}].tool_timeout_secs must be greater than 0");
|
||||
}
|
||||
if timeout > MCP_MAX_TOOL_TIMEOUT_SECS {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}].tool_timeout_secs exceeds max {MCP_MAX_TOOL_TIMEOUT_SECS}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match server.transport {
|
||||
McpTransport::Stdio => {
|
||||
if server.command.trim().is_empty() {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}] with transport=stdio requires non-empty command"
|
||||
);
|
||||
}
|
||||
}
|
||||
McpTransport::Http | McpTransport::Sse => {
|
||||
let url = server
|
||||
.url
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"mcp.servers[{i}] with transport={} requires url",
|
||||
match server.transport {
|
||||
McpTransport::Http => "http",
|
||||
McpTransport::Sse => "sse",
|
||||
McpTransport::Stdio => "stdio",
|
||||
}
|
||||
)
|
||||
})?;
|
||||
let parsed = reqwest::Url::parse(url)
|
||||
.with_context(|| format!("mcp.servers[{i}].url is not a valid URL"))?;
|
||||
if !matches!(parsed.scheme(), "http" | "https") {
|
||||
anyhow::bail!("mcp.servers[{i}].url must use http/https");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_proxy_url(field: &str, url: &str) -> Result<()> {
|
||||
let parsed = reqwest::Url::parse(url)
|
||||
.with_context(|| format!("Invalid {field} URL: '{url}' is not a valid URL"))?;
|
||||
@@ -3832,6 +3967,7 @@ impl Default for Config {
|
||||
default_model: Some("anthropic/claude-sonnet-4.6".to_string()),
|
||||
model_providers: HashMap::new(),
|
||||
default_temperature: default_temperature(),
|
||||
provider_timeout_secs: default_provider_timeout_secs(),
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
security: SecurityConfig::default(),
|
||||
@@ -3866,6 +4002,7 @@ impl Default for Config {
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4825,6 +4962,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// MCP
|
||||
if self.mcp.enabled {
|
||||
validate_mcp_config(&self.mcp)?;
|
||||
}
|
||||
|
||||
// Proxy (delegate to existing validation)
|
||||
self.proxy.validate()?;
|
||||
|
||||
@@ -4888,6 +5030,15 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Provider HTTP timeout: ZEROCLAW_PROVIDER_TIMEOUT_SECS
|
||||
if let Ok(timeout_secs) = std::env::var("ZEROCLAW_PROVIDER_TIMEOUT_SECS") {
|
||||
if let Ok(timeout_secs) = timeout_secs.parse::<u64>() {
|
||||
if timeout_secs > 0 {
|
||||
self.provider_timeout_secs = timeout_secs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply named provider profile remapping (Codex app-server compatibility).
|
||||
self.apply_named_model_provider_profile();
|
||||
|
||||
@@ -5525,6 +5676,7 @@ mod tests {
|
||||
c.skills.prompt_injection_mode,
|
||||
SkillsPromptInjectionMode::Full
|
||||
);
|
||||
assert_eq!(c.provider_timeout_secs, 120);
|
||||
assert!(c.workspace_dir.to_string_lossy().contains("workspace"));
|
||||
assert!(c.config_path.to_string_lossy().contains("config.toml"));
|
||||
}
|
||||
@@ -5729,6 +5881,7 @@ default_temperature = 0.7
|
||||
default_model: Some("gpt-4o".into()),
|
||||
model_providers: HashMap::new(),
|
||||
default_temperature: 0.5,
|
||||
provider_timeout_secs: 120,
|
||||
observability: ObservabilityConfig {
|
||||
backend: "log".into(),
|
||||
..ObservabilityConfig::default()
|
||||
@@ -5820,6 +5973,7 @@ default_temperature = 0.7
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
@@ -5869,6 +6023,18 @@ default_temperature = 0.7
|
||||
assert_eq!(parsed.memory.archive_after_days, 7);
|
||||
assert_eq!(parsed.memory.purge_after_days, 30);
|
||||
assert_eq!(parsed.memory.conversation_retention_days, 30);
|
||||
// provider_timeout_secs defaults to 120 when not specified
|
||||
assert_eq!(parsed.provider_timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_timeout_secs_parses_from_toml() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
provider_timeout_secs = 300
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(raw).unwrap();
|
||||
assert_eq!(parsed.provider_timeout_secs, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -5969,6 +6135,7 @@ tool_dispatcher = "xml"
|
||||
default_model: Some("test-model".into()),
|
||||
model_providers: HashMap::new(),
|
||||
default_temperature: 0.9,
|
||||
provider_timeout_secs: 120,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
security: SecurityConfig::default(),
|
||||
@@ -6003,6 +6170,7 @@ tool_dispatcher = "xml"
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await.unwrap();
|
||||
|
||||
+21
-9
@@ -351,6 +351,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(config.provider_timeout_secs),
|
||||
},
|
||||
)?);
|
||||
let model = config
|
||||
@@ -747,15 +748,25 @@ const PROMETHEUS_CONTENT_TYPE: &str = "text/plain; version=0.0.4; charset=utf-8"
|
||||
|
||||
/// GET /metrics — Prometheus text exposition format
|
||||
async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let body = if let Some(prom) = state
|
||||
.observer
|
||||
.as_ref()
|
||||
.as_any()
|
||||
.downcast_ref::<crate::observability::PrometheusObserver>()
|
||||
{
|
||||
prom.encode()
|
||||
} else {
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
let body = {
|
||||
#[cfg(feature = "metrics")]
|
||||
{
|
||||
if let Some(prom) = state
|
||||
.observer
|
||||
.as_ref()
|
||||
.as_any()
|
||||
.downcast_ref::<crate::observability::PrometheusObserver>()
|
||||
{
|
||||
prom.encode()
|
||||
} else {
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "metrics"))]
|
||||
{
|
||||
let _ = &state;
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
@@ -1738,6 +1749,7 @@ mod tests {
|
||||
assert!(text.contains("Prometheus backend not enabled"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "metrics")]
|
||||
#[tokio::test]
|
||||
async fn metrics_endpoint_renders_prometheus_output() {
|
||||
let prom = Arc::new(crate::observability::PrometheusObserver::new());
|
||||
|
||||
+136
-4
@@ -15,7 +15,7 @@ use axum::{
|
||||
ws::{Message, WebSocket},
|
||||
Query, State, WebSocketUpgrade,
|
||||
},
|
||||
http::HeaderMap,
|
||||
http::{header, HeaderMap},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
@@ -24,12 +24,62 @@ use serde::Deserialize;
|
||||
/// The sub-protocol we support for the chat WebSocket.
|
||||
const WS_PROTOCOL: &str = "zeroclaw.v1";
|
||||
|
||||
/// Prefix used in `Sec-WebSocket-Protocol` to carry a bearer token.
|
||||
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WsQuery {
|
||||
pub token: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a bearer token from WebSocket-compatible sources.
|
||||
///
|
||||
/// Precedence (first non-empty wins):
|
||||
/// 1. `Authorization: Bearer <token>` header
|
||||
/// 2. `Sec-WebSocket-Protocol: bearer.<token>` subprotocol
|
||||
/// 3. `?token=<token>` query parameter
|
||||
///
|
||||
/// Browsers cannot set custom headers on `new WebSocket(url)`, so the query
|
||||
/// parameter and subprotocol paths are required for browser-based clients.
|
||||
fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
|
||||
// 1. Authorization header
|
||||
if let Some(t) = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
{
|
||||
if !t.is_empty() {
|
||||
return Some(t);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Sec-WebSocket-Protocol: bearer.<token>
|
||||
if let Some(t) = headers
|
||||
.get("sec-websocket-protocol")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|protos| {
|
||||
protos
|
||||
.split(',')
|
||||
.map(|p| p.trim())
|
||||
.find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
|
||||
})
|
||||
{
|
||||
if !t.is_empty() {
|
||||
return Some(t);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. ?token= query parameter
|
||||
if let Some(t) = query_token {
|
||||
if !t.is_empty() {
|
||||
return Some(t);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// GET /ws/chat — WebSocket upgrade for agent chat
|
||||
pub async fn handle_ws_chat(
|
||||
State(state): State<AppState>,
|
||||
@@ -37,13 +87,13 @@ pub async fn handle_ws_chat(
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
// Auth via query param (browser WebSocket limitation)
|
||||
// Auth: check header, subprotocol, then query param (precedence order)
|
||||
if state.pairing.require_pairing() {
|
||||
let token = params.token.as_deref().unwrap_or("");
|
||||
let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
|
||||
if !state.pairing.is_authenticated(token) {
|
||||
return (
|
||||
axum::http::StatusCode::UNAUTHORIZED,
|
||||
"Unauthorized — provide ?token=<bearer_token>",
|
||||
"Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -183,3 +233,85 @@ async fn handle_socket(socket: WebSocket, state: AppState, _session_id: Option<S
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_from_authorization_header() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("authorization", "Bearer zc_test123".parse().unwrap());
|
||||
assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_from_subprotocol() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"sec-websocket-protocol",
|
||||
"zeroclaw.v1, bearer.zc_sub456".parse().unwrap(),
|
||||
);
|
||||
assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_from_query_param() {
|
||||
let headers = HeaderMap::new();
|
||||
assert_eq!(
|
||||
extract_ws_token(&headers, Some("zc_query789")),
|
||||
Some("zc_query789")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_precedence_header_over_subprotocol() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("authorization", "Bearer zc_header".parse().unwrap());
|
||||
headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
|
||||
assert_eq!(
|
||||
extract_ws_token(&headers, Some("zc_query")),
|
||||
Some("zc_header")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_precedence_subprotocol_over_query() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
|
||||
assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_returns_none_when_empty() {
|
||||
let headers = HeaderMap::new();
|
||||
assert_eq!(extract_ws_token(&headers, None), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_skips_empty_header_value() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("authorization", "Bearer ".parse().unwrap());
|
||||
assert_eq!(
|
||||
extract_ws_token(&headers, Some("zc_fallback")),
|
||||
Some("zc_fallback")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_skips_empty_query_param() {
|
||||
let headers = HeaderMap::new();
|
||||
assert_eq!(extract_ws_token(&headers, Some("")), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ws_token_subprotocol_with_multiple_entries() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"sec-websocket-protocol",
|
||||
"zeroclaw.v1, bearer.zc_tok, other".parse().unwrap(),
|
||||
);
|
||||
assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
|
||||
}
|
||||
}
|
||||
|
||||
+25
@@ -202,6 +202,31 @@ Examples:
|
||||
/// Telegram identity to allow (username without '@' or numeric user ID)
|
||||
identity: String,
|
||||
},
|
||||
/// Send a message to a configured channel
|
||||
#[command(long_about = "\
|
||||
Send a one-off message to a configured channel.
|
||||
|
||||
Sends a text message through the specified channel without starting \
|
||||
the full agent loop. Useful for scripted notifications, hardware \
|
||||
sensor alerts, and automation pipelines.
|
||||
|
||||
The --channel-id selects the channel by its config section name \
|
||||
(e.g. 'telegram', 'discord', 'slack'). The --recipient is the \
|
||||
platform-specific destination (e.g. a Telegram chat ID).
|
||||
|
||||
Examples:
|
||||
zeroclaw channel send 'Someone is near your device.' --channel-id telegram --recipient 123456789
|
||||
zeroclaw channel send 'Build succeeded!' --channel-id discord --recipient 987654321")]
|
||||
Send {
|
||||
/// Message text to send
|
||||
message: String,
|
||||
/// Channel config name (e.g. telegram, discord, slack)
|
||||
#[arg(long)]
|
||||
channel_id: String,
|
||||
/// Recipient identifier (platform-specific, e.g. Telegram chat ID)
|
||||
#[arg(long)]
|
||||
recipient: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Skills management subcommands
|
||||
|
||||
+3
-2
@@ -324,7 +324,7 @@ Examples:
|
||||
#[command(long_about = "\
|
||||
Manage communication channels.
|
||||
|
||||
Add, remove, list, and health-check channels that connect ZeroClaw \
|
||||
Add, remove, list, send, and health-check channels that connect ZeroClaw \
|
||||
to messaging platforms. Supported channel types: telegram, discord, \
|
||||
slack, whatsapp, matrix, imessage, email.
|
||||
|
||||
@@ -333,7 +333,8 @@ Examples:
|
||||
zeroclaw channel doctor
|
||||
zeroclaw channel add telegram '{\"bot_token\":\"...\",\"name\":\"my-bot\"}'
|
||||
zeroclaw channel remove my-bot
|
||||
zeroclaw channel bind-telegram zeroclaw_user")]
|
||||
zeroclaw channel bind-telegram zeroclaw_user
|
||||
zeroclaw channel send 'Alert!' --channel-id telegram --recipient 123456789")]
|
||||
Channel {
|
||||
#[command(subcommand)]
|
||||
channel_command: ChannelCommands,
|
||||
|
||||
@@ -3,6 +3,7 @@ pub mod multi;
|
||||
pub mod noop;
|
||||
#[cfg(feature = "observability-otel")]
|
||||
pub mod otel;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub mod prometheus;
|
||||
pub mod runtime_trace;
|
||||
pub mod traits;
|
||||
@@ -15,6 +16,7 @@ pub use self::multi::MultiObserver;
|
||||
pub use noop::NoopObserver;
|
||||
#[cfg(feature = "observability-otel")]
|
||||
pub use otel::OtelObserver;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub use prometheus::PrometheusObserver;
|
||||
pub use traits::{Observer, ObserverEvent};
|
||||
#[allow(unused_imports)]
|
||||
@@ -26,7 +28,19 @@ use crate::config::ObservabilityConfig;
|
||||
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
|
||||
match config.backend.as_str() {
|
||||
"log" => Box::new(LogObserver::new()),
|
||||
"prometheus" => Box::new(PrometheusObserver::new()),
|
||||
"prometheus" => {
|
||||
#[cfg(feature = "metrics")]
|
||||
{
|
||||
Box::new(PrometheusObserver::new())
|
||||
}
|
||||
#[cfg(not(feature = "metrics"))]
|
||||
{
|
||||
tracing::warn!(
|
||||
"Prometheus backend requested but this build was compiled without `metrics`; falling back to noop."
|
||||
);
|
||||
Box::new(NoopObserver)
|
||||
}
|
||||
}
|
||||
"otel" | "opentelemetry" | "otlp" => {
|
||||
#[cfg(feature = "observability-otel")]
|
||||
match OtelObserver::new(
|
||||
@@ -104,7 +118,12 @@ mod tests {
|
||||
backend: "prometheus".into(),
|
||||
..ObservabilityConfig::default()
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "prometheus");
|
||||
let expected = if cfg!(feature = "metrics") {
|
||||
"prometheus"
|
||||
} else {
|
||||
"noop"
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -138,6 +138,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
default_model: Some(model),
|
||||
model_providers: std::collections::HashMap::new(),
|
||||
default_temperature: 0.7,
|
||||
provider_timeout_secs: 120,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
@@ -172,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@@ -490,6 +492,7 @@ async fn run_quick_setup_with_home(
|
||||
default_model: Some(model.clone()),
|
||||
model_providers: std::collections::HashMap::new(),
|
||||
default_temperature: 0.7,
|
||||
provider_timeout_secs: 120,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
security: crate::config::SecurityConfig::default(),
|
||||
@@ -524,6 +527,7 @@ async fn run_quick_setup_with_home(
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
|
||||
@@ -37,6 +37,8 @@ pub struct OpenAiCompatibleProvider {
|
||||
/// Whether this provider supports OpenAI-style native tool calling.
|
||||
/// When false, tools are injected into the system prompt as text.
|
||||
native_tool_calling: bool,
|
||||
/// HTTP request timeout in seconds for LLM API calls. Default: 120.
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
/// How the provider expects the API key to be sent.
|
||||
@@ -170,9 +172,16 @@ impl OpenAiCompatibleProvider {
|
||||
user_agent: user_agent.map(ToString::to_string),
|
||||
merge_system_into_user,
|
||||
native_tool_calling: !merge_system_into_user,
|
||||
timeout_secs: 120,
|
||||
}
|
||||
}
|
||||
|
||||
/// Override the HTTP request timeout for LLM API calls.
|
||||
pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
|
||||
self.timeout_secs = timeout_secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Collect all `system` role messages, concatenate their content,
|
||||
/// and prepend to the first `user` message. Drop all system messages.
|
||||
/// Used for providers (e.g. MiniMax) that reject `role: system`.
|
||||
@@ -205,6 +214,7 @@ impl OpenAiCompatibleProvider {
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
let timeout = self.timeout_secs;
|
||||
if let Some(ua) = self.user_agent.as_deref() {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Ok(value) = HeaderValue::from_str(ua) {
|
||||
@@ -212,7 +222,7 @@ impl OpenAiCompatibleProvider {
|
||||
}
|
||||
|
||||
let builder = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.timeout(std::time::Duration::from_secs(timeout))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.default_headers(headers);
|
||||
let builder =
|
||||
@@ -224,7 +234,7 @@ impl OpenAiCompatibleProvider {
|
||||
});
|
||||
}
|
||||
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("provider.compatible", 120, 10)
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("provider.compatible", timeout, 10)
|
||||
}
|
||||
|
||||
/// Build the full URL for chat completions, detecting if base_url already includes the path.
|
||||
@@ -2899,4 +2909,16 @@ mod tests {
|
||||
);
|
||||
assert!(json.contains("thinking..."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_timeout_is_120s() {
|
||||
let p = make_provider("test", "https://example.com", None);
|
||||
assert_eq!(p.timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn with_timeout_secs_overrides_default() {
|
||||
let p = make_provider("test", "https://example.com", None).with_timeout_secs(300);
|
||||
assert_eq!(p.timeout_secs, 300);
|
||||
}
|
||||
}
|
||||
|
||||
+48
-32
@@ -677,6 +677,9 @@ pub struct ProviderRuntimeOptions {
|
||||
pub zeroclaw_dir: Option<PathBuf>,
|
||||
pub secrets_encrypt: bool,
|
||||
pub reasoning_enabled: Option<bool>,
|
||||
/// HTTP request timeout in seconds for LLM provider API calls.
|
||||
/// `None` uses the provider's built-in default (120s for compatible providers).
|
||||
pub provider_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl Default for ProviderRuntimeOptions {
|
||||
@@ -687,6 +690,7 @@ impl Default for ProviderRuntimeOptions {
|
||||
zeroclaw_dir: None,
|
||||
secrets_encrypt: true,
|
||||
reasoning_enabled: None,
|
||||
provider_timeout_secs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -993,6 +997,18 @@ fn create_provider_with_url_and_options(
|
||||
api_url: Option<&str>,
|
||||
options: &ProviderRuntimeOptions,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
// Closure to optionally apply the configured provider timeout to
|
||||
// OpenAI-compatible providers before boxing them as trait objects.
|
||||
let compat = {
|
||||
let timeout = options.provider_timeout_secs;
|
||||
move |p: OpenAiCompatibleProvider| -> Box<dyn Provider> {
|
||||
match timeout {
|
||||
Some(t) => Box::new(p.with_timeout_secs(t)),
|
||||
None => Box::new(p),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let qwen_oauth_context = is_qwen_oauth_alias(name).then(|| resolve_qwen_oauth_context(api_key));
|
||||
|
||||
// Resolve credential and break static-analysis taint chain from the
|
||||
@@ -1066,28 +1082,28 @@ fn create_provider_with_url_and_options(
|
||||
"telnyx" => Ok(Box::new(telnyx::TelnyxProvider::new(key))),
|
||||
|
||||
// ── OpenAI-compatible providers ──────────────────────
|
||||
"venice" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"venice" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Venice", "https://api.venice.ai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"vercel" | "vercel-ai" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Vercel AI Gateway",
|
||||
VERCEL_AI_GATEWAY_BASE_URL,
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
"cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"cloudflare" | "cloudflare-ai" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Cloudflare AI Gateway",
|
||||
"https://gateway.ai.cloudflare.com/v1",
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
name if moonshot_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
name if moonshot_base_url(name).is_some() => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Moonshot",
|
||||
moonshot_base_url(name).expect("checked in guard"),
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
"kimi-code" | "kimi_coding" | "kimi_for_coding" => Ok(Box::new(
|
||||
"kimi-code" | "kimi_coding" | "kimi_for_coding" => Ok(compat(
|
||||
OpenAiCompatibleProvider::new_with_user_agent(
|
||||
"Kimi Code",
|
||||
"https://api.kimi.com/coding/v1",
|
||||
@@ -1096,30 +1112,30 @@ fn create_provider_with_url_and_options(
|
||||
"KimiCLI/0.77",
|
||||
),
|
||||
)),
|
||||
"synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"synthetic" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Synthetic", "https://api.synthetic.new/openai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"opencode" | "opencode-zen" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"OpenCode Zen", "https://opencode.ai/zen/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"opencode-go" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"opencode-go" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"OpenCode Go", "https://opencode.ai/zen/go/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
name if zai_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
name if zai_base_url(name).is_some() => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Z.AI",
|
||||
zai_base_url(name).expect("checked in guard"),
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
name if glm_base_url(name).is_some() => {
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new_no_responses_fallback(
|
||||
Ok(compat(OpenAiCompatibleProvider::new_no_responses_fallback(
|
||||
"GLM",
|
||||
glm_base_url(name).expect("checked in guard"),
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
)))
|
||||
}
|
||||
name if minimax_base_url(name).is_some() => Ok(Box::new(
|
||||
name if minimax_base_url(name).is_some() => Ok(compat(
|
||||
OpenAiCompatibleProvider::new_merge_system_into_user(
|
||||
"MiniMax",
|
||||
minimax_base_url(name).expect("checked in guard"),
|
||||
@@ -1149,7 +1165,7 @@ fn create_provider_with_url_and_options(
|
||||
.or_else(|| qwen_oauth_context.as_ref().and_then(|context| context.base_url.clone()))
|
||||
.unwrap_or_else(|| QWEN_OAUTH_BASE_FALLBACK_URL.to_string());
|
||||
|
||||
Ok(Box::new(
|
||||
Ok(compat(
|
||||
OpenAiCompatibleProvider::new_with_user_agent_and_vision(
|
||||
"Qwen Code",
|
||||
&base_url,
|
||||
@@ -1159,16 +1175,16 @@ fn create_provider_with_url_and_options(
|
||||
true,
|
||||
)))
|
||||
}
|
||||
name if is_qianfan_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
name if is_qianfan_alias(name) => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer,
|
||||
))),
|
||||
name if is_doubao_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
name if is_doubao_alias(name) => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Doubao",
|
||||
"https://ark.cn-beijing.volces.com/api/v3",
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
name if qwen_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
|
||||
name if qwen_base_url(name).is_some() => Ok(compat(OpenAiCompatibleProvider::new_with_vision(
|
||||
"Qwen",
|
||||
qwen_base_url(name).expect("checked in guard"),
|
||||
key,
|
||||
@@ -1177,31 +1193,31 @@ fn create_provider_with_url_and_options(
|
||||
))),
|
||||
|
||||
// ── Extended ecosystem (community favorites) ─────────
|
||||
"groq" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"groq" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Groq", "https://api.groq.com/openai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"mistral" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"xai" | "grok" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"deepseek" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"DeepSeek", "https://api.deepseek.com", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"together" | "together-ai" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Together AI", "https://api.together.xyz", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"fireworks" | "fireworks-ai" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Fireworks AI", "https://api.fireworks.ai/inference/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"novita" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"novita" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Novita AI", "https://api.novita.ai/openai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"perplexity" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"cohere" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(copilot::CopilotProvider::new(key))),
|
||||
@@ -1210,7 +1226,7 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("lm-studio");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"LM Studio",
|
||||
"http://localhost:1234/v1",
|
||||
Some(lm_studio_key),
|
||||
@@ -1226,7 +1242,7 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("llama.cpp");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"llama.cpp",
|
||||
base_url,
|
||||
Some(llama_cpp_key),
|
||||
@@ -1238,7 +1254,7 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("http://localhost:30000/v1");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"SGLang",
|
||||
base_url,
|
||||
key,
|
||||
@@ -1250,7 +1266,7 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("http://localhost:8000/v1");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"vLLM",
|
||||
base_url,
|
||||
key,
|
||||
@@ -1266,14 +1282,14 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("osaurus");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Osaurus",
|
||||
base_url,
|
||||
Some(osaurus_key),
|
||||
AuthStyle::Bearer,
|
||||
)))
|
||||
}
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(compat(
|
||||
OpenAiCompatibleProvider::new_no_responses_fallback(
|
||||
"NVIDIA NIM",
|
||||
"https://integrate.api.nvidia.com/v1",
|
||||
@@ -1283,7 +1299,7 @@ fn create_provider_with_url_and_options(
|
||||
)),
|
||||
|
||||
// ── AI inference routers ─────────────────────────────
|
||||
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"astrai" => Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
@@ -1301,7 +1317,7 @@ fn create_provider_with_url_and_options(
|
||||
"Custom provider",
|
||||
"custom:https://your-api.com",
|
||||
)?;
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
|
||||
Ok(compat(OpenAiCompatibleProvider::new_with_vision(
|
||||
"Custom",
|
||||
&base_url,
|
||||
key,
|
||||
|
||||
+81
-9
@@ -27,7 +27,7 @@ struct ChatRequest {
|
||||
tools: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -40,14 +40,14 @@ struct Message {
|
||||
tool_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct OutgoingToolCall {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: OutgoingFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct OutgoingFunction {
|
||||
name: String,
|
||||
arguments: serde_json::Value,
|
||||
@@ -258,13 +258,31 @@ impl OllamaProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> ChatRequest {
|
||||
self.build_chat_request_with_think(
|
||||
messages,
|
||||
model,
|
||||
temperature,
|
||||
tools,
|
||||
self.reasoning_enabled,
|
||||
)
|
||||
}
|
||||
|
||||
/// Build a chat request with an explicit `think` value.
|
||||
fn build_chat_request_with_think(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
think: Option<bool>,
|
||||
) -> ChatRequest {
|
||||
ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
think: self.reasoning_enabled,
|
||||
think,
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
}
|
||||
}
|
||||
@@ -396,17 +414,18 @@ impl OllamaProvider {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response.
|
||||
/// Pass `tools` to enable native function-calling for models that support it.
|
||||
async fn send_request(
|
||||
/// Send a single HTTP request to Ollama and parse the response.
|
||||
async fn send_request_inner(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
think: Option<bool>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = self.build_chat_request(messages, model, temperature, tools);
|
||||
let request =
|
||||
self.build_chat_request_with_think(messages.to_vec(), model, temperature, tools, think);
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
@@ -466,6 +485,59 @@ impl OllamaProvider {
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response.
|
||||
/// Pass `tools` to enable native function-calling for models that support it.
|
||||
///
|
||||
/// When `reasoning_enabled` (`think`) is set to `true`, the first request
|
||||
/// includes `think: true`. If that request fails (the model may not support
|
||||
/// the `think` parameter), we automatically retry once with `think` omitted
|
||||
/// so the call succeeds instead of entering an infinite retry loop.
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let result = self
|
||||
.send_request_inner(
|
||||
&messages,
|
||||
model,
|
||||
temperature,
|
||||
should_auth,
|
||||
tools,
|
||||
self.reasoning_enabled,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(first_err) if self.reasoning_enabled == Some(true) => {
|
||||
tracing::warn!(
|
||||
model = model,
|
||||
error = %first_err,
|
||||
"Ollama request failed with think=true; retrying without reasoning \
|
||||
(model may not support it)"
|
||||
);
|
||||
// Retry with think omitted from the request entirely.
|
||||
self.send_request_inner(&messages, model, temperature, should_auth, tools, None)
|
||||
.await
|
||||
.map_err(|retry_err| {
|
||||
// Both attempts failed — return the original error for clarity.
|
||||
tracing::error!(
|
||||
model = model,
|
||||
original_error = %first_err,
|
||||
retry_error = %retry_err,
|
||||
"Ollama request also failed without think; returning original error"
|
||||
);
|
||||
first_err
|
||||
})
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
|
||||
@@ -1017,6 +1017,7 @@ data: [DONE]
|
||||
secrets_encrypt: false,
|
||||
auth_profile_override: None,
|
||||
reasoning_enabled: None,
|
||||
provider_timeout_secs: None,
|
||||
};
|
||||
let provider =
|
||||
OpenAiCodexProvider::new(&options, None).expect("provider should initialize");
|
||||
|
||||
@@ -7,8 +7,12 @@
|
||||
//! Contributed from RustyClaw (MIT licensed).
|
||||
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// Minimum token length considered for high-entropy detection.
|
||||
const ENTROPY_TOKEN_MIN_LEN: usize = 24;
|
||||
|
||||
/// Result of leak detection.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LeakResult {
|
||||
@@ -61,6 +65,7 @@ impl LeakDetector {
|
||||
self.check_private_keys(content, &mut patterns, &mut redacted);
|
||||
self.check_jwt_tokens(content, &mut patterns, &mut redacted);
|
||||
self.check_database_urls(content, &mut patterns, &mut redacted);
|
||||
self.check_high_entropy_tokens(content, &mut patterns, &mut redacted);
|
||||
|
||||
if patterns.is_empty() {
|
||||
LeakResult::Clean
|
||||
@@ -288,6 +293,72 @@ impl LeakDetector {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check for high-entropy tokens that may be leaked credentials.
|
||||
///
|
||||
/// Extracts candidate tokens from content (after stripping URLs to avoid
|
||||
/// false-positives on path segments) and flags any that exceed the Shannon
|
||||
/// entropy threshold derived from the detector's sensitivity.
|
||||
fn check_high_entropy_tokens(
|
||||
&self,
|
||||
content: &str,
|
||||
patterns: &mut Vec<String>,
|
||||
redacted: &mut String,
|
||||
) {
|
||||
// Entropy threshold scales with sensitivity: at 0.7 this is ~4.37.
|
||||
let entropy_threshold = 3.5 + self.sensitivity * 1.25;
|
||||
|
||||
// Strip URLs before extracting tokens so that path segments like
|
||||
// "org/documents/2024-report-a1b2c3d4e5f6g7h8i9j0" are not mistaken
|
||||
// for high-entropy credentials.
|
||||
static URL_PATTERN: OnceLock<Regex> = OnceLock::new();
|
||||
let url_re = URL_PATTERN.get_or_init(|| Regex::new(r"https?://\S+").unwrap());
|
||||
let content_without_urls = url_re.replace_all(content, "");
|
||||
|
||||
let tokens = extract_candidate_tokens(&content_without_urls);
|
||||
|
||||
for token in tokens {
|
||||
if token.len() >= ENTROPY_TOKEN_MIN_LEN {
|
||||
let entropy = shannon_entropy(token);
|
||||
if entropy >= entropy_threshold && has_mixed_alpha_digit(token) {
|
||||
patterns.push("High-entropy token".to_string());
|
||||
*redacted = redacted.replace(token, "[REDACTED_HIGH_ENTROPY_TOKEN]");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract candidate tokens by splitting on characters outside the
|
||||
/// alphanumeric + common credential character set.
|
||||
fn extract_candidate_tokens(content: &str) -> Vec<&str> {
|
||||
content
|
||||
.split(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '-' && c != '+' && c != '/')
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute Shannon entropy (bits per character) for the given string.
|
||||
fn shannon_entropy(s: &str) -> f64 {
|
||||
let len = s.len() as f64;
|
||||
if len == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
let mut freq: HashMap<u8, usize> = HashMap::new();
|
||||
for &b in s.as_bytes() {
|
||||
*freq.entry(b).or_insert(0) += 1;
|
||||
}
|
||||
freq.values().fold(0.0, |acc, &count| {
|
||||
let p = count as f64 / len;
|
||||
acc - p * p.log2()
|
||||
})
|
||||
}
|
||||
|
||||
/// Check whether a token contains both alphabetic and digit characters.
|
||||
fn has_mixed_alpha_digit(s: &str) -> bool {
|
||||
let has_alpha = s.bytes().any(|b| b.is_ascii_alphabetic());
|
||||
let has_digit = s.bytes().any(|b| b.is_ascii_digit());
|
||||
has_alpha && has_digit
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -381,4 +452,87 @@ MIIEowIBAAKCAQEA0ZPr5JeyVDonXsKhfq...
|
||||
// Low sensitivity should not flag generic secrets
|
||||
assert!(matches!(result, LeakResult::Clean));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn url_path_segments_not_flagged() {
|
||||
let detector = LeakDetector::new();
|
||||
// URL with a long mixed-alphanumeric path segment that would previously
|
||||
// false-positive as a high-entropy token.
|
||||
let content =
|
||||
"See https://example.org/documents/2024-report-a1b2c3d4e5f6g7h8i9j0.pdf for details";
|
||||
let result = detector.scan(content);
|
||||
assert!(
|
||||
matches!(result, LeakResult::Clean),
|
||||
"URL path segments should not trigger high-entropy detection"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn url_with_long_path_not_redacted() {
|
||||
let detector = LeakDetector::new();
|
||||
let content = "Reference: https://gov.example.com/publications/research/2024-annual-fiscal-policy-review-9a8b7c6d5e4f3g2h1i0j.html";
|
||||
let result = detector.scan(content);
|
||||
assert!(
|
||||
matches!(result, LeakResult::Clean),
|
||||
"Long URL paths should not be redacted"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_high_entropy_token_outside_url() {
|
||||
let detector = LeakDetector::new();
|
||||
// A standalone high-entropy token (not in a URL) should still be detected.
|
||||
let content = "Found credential: aB3xK9mW2pQ7vL4nR8sT1yU6hD0jF5cG";
|
||||
let result = detector.scan(content);
|
||||
match result {
|
||||
LeakResult::Detected { patterns, redacted } => {
|
||||
assert!(patterns.iter().any(|p| p.contains("High-entropy")));
|
||||
assert!(redacted.contains("[REDACTED_HIGH_ENTROPY_TOKEN]"));
|
||||
}
|
||||
LeakResult::Clean => panic!("Should detect high-entropy token"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn low_sensitivity_raises_entropy_threshold() {
|
||||
let detector = LeakDetector::with_sensitivity(0.3);
|
||||
// At low sensitivity the entropy threshold is higher (3.5 + 0.3*1.25 = 3.875).
|
||||
// A repetitive mixed token has low entropy and should not be flagged.
|
||||
let content = "token found: ab12ab12ab12ab12ab12ab12ab12ab12";
|
||||
let result = detector.scan(content);
|
||||
assert!(
|
||||
matches!(result, LeakResult::Clean),
|
||||
"Low-entropy repetitive tokens should not be flagged"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_candidate_tokens_splits_correctly() {
|
||||
let tokens = extract_candidate_tokens("foo.bar:baz qux-quux key=val");
|
||||
assert!(tokens.contains(&"foo"));
|
||||
assert!(tokens.contains(&"bar"));
|
||||
assert!(tokens.contains(&"baz"));
|
||||
assert!(tokens.contains(&"qux-quux"));
|
||||
// '=' is a delimiter, not part of tokens
|
||||
assert!(tokens.contains(&"key"));
|
||||
assert!(tokens.contains(&"val"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shannon_entropy_empty_string() {
|
||||
assert_eq!(shannon_entropy(""), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shannon_entropy_single_char() {
|
||||
// All same characters: entropy = 0
|
||||
assert_eq!(shannon_entropy("aaaa"), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shannon_entropy_two_equal_chars() {
|
||||
// "ab" repeated: entropy = 1.0 bit
|
||||
let e = shannon_entropy("abab");
|
||||
assert!((e - 1.0).abs() < 0.001);
|
||||
}
|
||||
}
|
||||
|
||||
+54
-4
@@ -922,9 +922,28 @@ impl SecurityPolicy {
|
||||
// Expand "~" for consistent matching with forbidden paths and allowlists.
|
||||
let expanded_path = expand_user_path(path);
|
||||
|
||||
// Block absolute paths when workspace_only is set
|
||||
if self.workspace_only && expanded_path.is_absolute() {
|
||||
return false;
|
||||
// When workspace_only is set and the path is absolute, only allow it
|
||||
// if it falls within the workspace directory or an explicit allowed
|
||||
// root. The workspace/allowed-root check runs BEFORE the forbidden
|
||||
// prefix list so that workspace paths under broad defaults like
|
||||
// "/home" are not rejected. This mirrors the priority order in
|
||||
// `is_resolved_path_allowed`. See #2880.
|
||||
if expanded_path.is_absolute() {
|
||||
let in_workspace = expanded_path.starts_with(&self.workspace_dir);
|
||||
let in_allowed_root = self
|
||||
.allowed_roots
|
||||
.iter()
|
||||
.any(|root| expanded_path.starts_with(root));
|
||||
|
||||
if in_workspace || in_allowed_root {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Absolute path outside workspace/allowed roots — block when
|
||||
// workspace_only, or fall through to forbidden-prefix check.
|
||||
if self.workspace_only {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Block forbidden paths using path-component-aware matching
|
||||
@@ -1384,6 +1403,37 @@ mod tests {
|
||||
assert!(!p.is_path_allowed("/tmp/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_path_inside_workspace_allowed_when_workspace_only() {
|
||||
let p = SecurityPolicy {
|
||||
workspace_dir: PathBuf::from("/home/user/.zeroclaw/workspace"),
|
||||
workspace_only: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
// Absolute path inside workspace should be allowed
|
||||
assert!(p.is_path_allowed("/home/user/.zeroclaw/workspace/images/example.png"));
|
||||
assert!(p.is_path_allowed("/home/user/.zeroclaw/workspace/file.txt"));
|
||||
// Absolute path outside workspace should still be blocked
|
||||
assert!(!p.is_path_allowed("/home/user/other/file.txt"));
|
||||
assert!(!p.is_path_allowed("/tmp/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_path_in_allowed_root_permitted_when_workspace_only() {
|
||||
let p = SecurityPolicy {
|
||||
workspace_dir: PathBuf::from("/home/user/.zeroclaw/workspace"),
|
||||
workspace_only: true,
|
||||
allowed_roots: vec![PathBuf::from("/home/user/.zeroclaw/shared")],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
// Path in allowed root should be permitted
|
||||
assert!(p.is_path_allowed("/home/user/.zeroclaw/shared/data.txt"));
|
||||
// Path in workspace should still be permitted
|
||||
assert!(p.is_path_allowed("/home/user/.zeroclaw/workspace/file.txt"));
|
||||
// Path outside both should still be blocked
|
||||
assert!(!p.is_path_allowed("/home/user/other/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_paths_allowed_when_not_workspace_only() {
|
||||
let p = SecurityPolicy {
|
||||
@@ -2122,7 +2172,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checklist_workspace_only_blocks_all_absolute() {
|
||||
fn checklist_workspace_only_blocks_absolute_outside_workspace() {
|
||||
let p = SecurityPolicy {
|
||||
workspace_only: true,
|
||||
..SecurityPolicy::default()
|
||||
|
||||
@@ -411,6 +411,7 @@ impl DelegateTool {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
//! MCP (Model Context Protocol) client — connects to external tool servers.
|
||||
//!
|
||||
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
use crate::config::schema::McpServerConfig;
|
||||
use crate::tools::mcp_protocol::{
|
||||
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
|
||||
};
|
||||
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
|
||||
|
||||
/// Timeout for receiving a response from an MCP server during init/list.
|
||||
/// Prevents a hung server from blocking the daemon indefinitely.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Default timeout for tool calls (seconds) when not configured per-server.
|
||||
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;
|
||||
|
||||
/// Maximum allowed tool call timeout (seconds) — hard safety ceiling.
|
||||
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
// ── Internal server state ──────────────────────────────────────────────────
|
||||
|
||||
struct McpServerInner {
|
||||
config: McpServerConfig,
|
||||
transport: Box<dyn McpTransportConn>,
|
||||
next_id: AtomicU64,
|
||||
tools: Vec<McpToolDef>,
|
||||
}
|
||||
|
||||
// ── McpServer ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// A live connection to one MCP server (any transport).
|
||||
#[derive(Clone)]
|
||||
pub struct McpServer {
|
||||
inner: Arc<Mutex<McpServerInner>>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
/// Connect to the server, perform the initialize handshake, and fetch the tool list.
|
||||
pub async fn connect(config: McpServerConfig) -> Result<Self> {
|
||||
// Create transport based on config
|
||||
let mut transport = create_transport(&config).with_context(|| {
|
||||
format!(
|
||||
"failed to create transport for MCP server `{}`",
|
||||
config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
// Initialize handshake
|
||||
let id = 1u64;
|
||||
let init_req = JsonRpcRequest::new(
|
||||
id,
|
||||
"initialize",
|
||||
json!({
|
||||
"protocolVersion": MCP_PROTOCOL_VERSION,
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "zeroclaw",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let init_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&init_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for initialize response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
if init_resp.error.is_some() {
|
||||
bail!(
|
||||
"MCP server `{}` rejected initialize: {:?}",
|
||||
config.name,
|
||||
init_resp.error
|
||||
);
|
||||
}
|
||||
|
||||
// Notify server that client is initialized (no response expected for notifications)
|
||||
let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
|
||||
// Best effort - ignore errors for notifications
|
||||
let _ = transport.send_and_recv(¬if).await;
|
||||
|
||||
// Fetch available tools
|
||||
let id = 2u64;
|
||||
let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
|
||||
|
||||
let list_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&list_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for tools/list response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
let result = list_resp
|
||||
.result
|
||||
.ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
|
||||
let tool_list: McpToolsListResult = serde_json::from_value(result)
|
||||
.with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
|
||||
|
||||
let tool_count = tool_list.tools.len();
|
||||
|
||||
let inner = McpServerInner {
|
||||
config,
|
||||
transport,
|
||||
next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
|
||||
tools: tool_list.tools,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"MCP server `{}` connected — {} tool(s) available",
|
||||
inner.config.name,
|
||||
tool_count
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(inner)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Tools advertised by this server.
|
||||
pub async fn tools(&self) -> Vec<McpToolDef> {
|
||||
self.inner.lock().await.tools.clone()
|
||||
}
|
||||
|
||||
/// Server display name.
|
||||
#[allow(dead_code)]
|
||||
pub async fn name(&self) -> String {
|
||||
self.inner.lock().await.config.name.clone()
|
||||
}
|
||||
|
||||
/// Call a tool on this server. Returns the raw JSON result.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let req = JsonRpcRequest::new(
|
||||
id,
|
||||
"tools/call",
|
||||
json!({ "name": tool_name, "arguments": arguments }),
|
||||
);
|
||||
|
||||
// Use per-server tool timeout if configured, otherwise default.
|
||||
// Cap at MAX_TOOL_TIMEOUT_SECS for safety.
|
||||
let tool_timeout = inner
|
||||
.config
|
||||
.tool_timeout_secs
|
||||
.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
|
||||
.min(MAX_TOOL_TIMEOUT_SECS);
|
||||
|
||||
let resp = timeout(
|
||||
Duration::from_secs(tool_timeout),
|
||||
inner.transport.send_and_recv(&req),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow!(
|
||||
"MCP server `{}` timed out after {}s during tool call `{tool_name}`",
|
||||
inner.config.name,
|
||||
tool_timeout
|
||||
)
|
||||
})?
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` error during tool call `{tool_name}`",
|
||||
inner.config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(err) = resp.error {
|
||||
bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
|
||||
}
|
||||
Ok(resp.result.unwrap_or(serde_json::Value::Null))
|
||||
}
|
||||
}
|
||||
|
||||
// ── McpRegistry ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Registry of all connected MCP servers, with a flat tool index.
|
||||
pub struct McpRegistry {
|
||||
servers: Vec<McpServer>,
|
||||
/// prefixed_name -> (server_index, original_tool_name)
|
||||
tool_index: HashMap<String, (usize, String)>,
|
||||
}
|
||||
|
||||
impl McpRegistry {
|
||||
/// Connect to all configured servers. Non-fatal: failures are logged and skipped.
|
||||
pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
|
||||
let mut servers = Vec::new();
|
||||
let mut tool_index = HashMap::new();
|
||||
|
||||
for config in configs {
|
||||
match McpServer::connect(config.clone()).await {
|
||||
Ok(server) => {
|
||||
let server_idx = servers.len();
|
||||
// Collect tools while holding the lock once, then release
|
||||
let tools = server.tools().await;
|
||||
for tool in &tools {
|
||||
// Prefix prevents name collisions across servers
|
||||
let prefixed = format!("{}__{}", config.name, tool.name);
|
||||
tool_index.insert(prefixed, (server_idx, tool.name.clone()));
|
||||
}
|
||||
servers.push(server);
|
||||
}
|
||||
// Non-fatal — log and continue with remaining servers
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
servers,
|
||||
tool_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// All prefixed tool names across all connected servers.
|
||||
pub fn tool_names(&self) -> Vec<String> {
|
||||
self.tool_index.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Tool definition for a given prefixed name (cloned).
|
||||
pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
|
||||
let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
|
||||
let inner = self.servers[*server_idx].inner.lock().await;
|
||||
inner
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| &t.name == original_name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Execute a tool by prefixed name.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
prefixed_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<String> {
|
||||
let (server_idx, original_name) = self
|
||||
.tool_index
|
||||
.get(prefixed_name)
|
||||
.ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
|
||||
let result = self.servers[*server_idx]
|
||||
.call_tool(original_name, arguments)
|
||||
.await?;
|
||||
serde_json::to_string_pretty(&result)
|
||||
.with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
|
||||
pub fn server_count(&self) -> usize {
|
||||
self.servers.len()
|
||||
}
|
||||
|
||||
pub fn tool_count(&self) -> usize {
|
||||
self.tool_index.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Pin the `<server>__<tool>` prefix format used by the registry index.
    #[test]
    fn tool_name_prefix_format() {
        let prefixed = format!("{}__{}", "filesystem", "read_file");
        assert_eq!(prefixed, "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // A command that doesn't exist should fail at spawn, not panic.
        let config = McpServerConfig {
            name: "nonexistent".to_string(),
            command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
            args: vec![],
            env: HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: HashMap::default(),
        };
        let result = McpServer::connect(config).await;
        assert!(result.is_err());
        // The spawn failure surfaces through the transport-creation context.
        let msg = result.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // If one server config is bad, connect_all should succeed (with 0 servers).
        let configs = vec![McpServerConfig {
            name: "bad".to_string(),
            command: "/usr/bin/does_not_exist_zc_test".to_string(),
            args: vec![],
            env: HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: HashMap::default(),
        }];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    /// HTTP transport construction must reject configs without a `url`.
    #[test]
    fn http_transport_requires_url() {
        let config = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Http,
            ..Default::default()
        };
        let result = create_transport(&config);
        assert!(result.is_err());
    }

    /// SSE transport construction must likewise reject configs without a `url`.
    #[test]
    fn sse_transport_requires_url() {
        let config = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Sse,
            ..Default::default()
        };
        let result = create_transport(&config);
        assert!(result.is_err());
    }
}
|
||||
@@ -0,0 +1,130 @@
|
||||
//! MCP (Model Context Protocol) JSON-RPC 2.0 protocol types.
|
||||
//! Protocol version: 2024-11-05
|
||||
//! Adapted from ops-mcp-server/src/protocol.rs for client use.
|
||||
//! Both Serialize and Deserialize are derived — the client both sends (Serialize)
|
||||
//! and receives (Deserialize) JSON-RPC messages.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JSON-RPC protocol version stamped on every message.
pub const JSONRPC_VERSION: &str = "2.0";
/// MCP protocol revision this client speaks.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";

// Standard JSON-RPC 2.0 error codes
#[allow(dead_code)]
pub const PARSE_ERROR: i32 = -32700;
#[allow(dead_code)]
pub const INVALID_REQUEST: i32 = -32600;
#[allow(dead_code)]
pub const METHOD_NOT_FOUND: i32 = -32601;
#[allow(dead_code)]
pub const INVALID_PARAMS: i32 = -32602;
/// Used by the SSE transport to synthesize errors for requests whose
/// connection was dropped before a response arrived.
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
/// Outbound JSON-RPC request (client -> MCP server).
/// Used for both method calls (with id) and notifications (id = None).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Always [`JSONRPC_VERSION`] ("2.0").
    pub jsonrpc: String,
    /// Request id; omitted from the wire when `None` (a notification).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<serde_json::Value>,
    /// JSON-RPC method name, e.g. `tools/list` or `tools/call`.
    pub method: String,
    /// Method parameters; omitted from the wire when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
/// Create a method call request with a numeric id.
|
||||
pub fn new(id: u64, method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: Some(serde_json::Value::Number(id.into())),
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a notification — no id, no response expected from server.
|
||||
pub fn notification(method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inbound JSON-RPC response (MCP server -> client).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
    /// Protocol version echoed by the server.
    pub jsonrpc: String,
    /// Id of the request being answered; absent on some error replies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<serde_json::Value>,
    /// Success payload; mutually exclusive with `error` in practice.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error payload, when the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
|
||||
|
||||
/// JSON-RPC error object embedded in a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (see the JSON-RPC constants above).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Optional structured detail supplied by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
|
||||
|
||||
/// A tool advertised by an MCP server (from `tools/list` response).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolDef {
    /// Tool name as the server knows it (unprefixed).
    pub name: String,
    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON schema for the tool's arguments; the wire field is camelCase.
    #[serde(rename = "inputSchema")]
    pub input_schema: serde_json::Value,
}
|
||||
|
||||
/// Expected shape of the `tools/list` result payload.
#[derive(Debug, Deserialize)]
pub struct McpToolsListResult {
    /// The tools the server advertises.
    pub tools: Vec<McpToolDef>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A call with an id serializes all three mandatory fields.
    #[test]
    fn request_serializes_with_id() {
        let req = JsonRpcRequest::new(1, "tools/list", serde_json::json!({}));
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"tools/list\""));
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
    }

    /// Notifications must not serialize an `id` field at all
    /// (skip_serializing_if on the Option).
    #[test]
    fn notification_omits_id() {
        let notif =
            JsonRpcRequest::notification("notifications/initialized", serde_json::json!({}));
        let s = serde_json::to_string(&notif).unwrap();
        assert!(!s.contains("\"id\""));
    }

    /// A well-formed success response round-trips with result set, error unset.
    #[test]
    fn response_deserializes() {
        let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
        let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }

    /// The camelCase `inputSchema` wire field maps onto `input_schema`.
    #[test]
    fn tool_def_deserializes_input_schema() {
        let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#;
        let def: McpToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(def.name, "read_file");
        assert!(def.input_schema.is_object());
    }
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//! Wraps a discovered MCP tool as a zeroclaw [`Tool`] so it is dispatched
|
||||
//! through the existing tool registry and agent loop without modification.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::tools::mcp_client::McpRegistry;
|
||||
use crate::tools::mcp_protocol::McpToolDef;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
/// A zeroclaw [`Tool`] backed by an MCP server tool.
///
/// The `prefixed_name` (e.g. `filesystem__read_file`) is what the agent loop
/// sees. The registry knows how to route it to the correct server.
pub struct McpToolWrapper {
    /// Prefixed name: `<server_name>__<tool_name>`.
    prefixed_name: String,
    /// Description extracted from the MCP tool definition. Stored as an owned
    /// String so that `description()` can return `&str` with self's lifetime.
    description: String,
    /// JSON schema for the tool's input parameters.
    input_schema: serde_json::Value,
    /// Shared registry — used to dispatch actual tool calls.
    registry: Arc<McpRegistry>,
}
|
||||
|
||||
impl McpToolWrapper {
|
||||
pub fn new(prefixed_name: String, def: McpToolDef, registry: Arc<McpRegistry>) -> Self {
|
||||
let description = def.description.unwrap_or_else(|| "MCP tool".to_string());
|
||||
Self {
|
||||
prefixed_name,
|
||||
description,
|
||||
input_schema: def.input_schema,
|
||||
registry,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for McpToolWrapper {
|
||||
fn name(&self) -> &str {
|
||||
&self.prefixed_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.input_schema.clone()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
match self.registry.call_tool(&self.prefixed_name, args).await {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,868 @@
|
||||
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::{oneshot, Mutex, Notify};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::config::schema::{McpServerConfig, McpTransport};
|
||||
use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR};
|
||||
|
||||
/// Maximum bytes for a single JSON-RPC response.
/// NOTE(review): enforced in `StdioTransport::recv_raw` only after the full
/// line is already buffered — see the note there.
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB

/// Timeout for init/list operations.
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
// ── Transport Trait ──────────────────────────────────────────────────────

/// Abstract transport for MCP communication.
///
/// Implemented by [`StdioTransport`], [`HttpTransport`], and [`SseTransport`];
/// callers hold a boxed trait object behind the server's lock.
#[async_trait::async_trait]
pub trait McpTransportConn: Send + Sync {
    /// Send a JSON-RPC request and receive the response.
    /// For notifications (no id) implementations return a synthetic empty
    /// response instead of waiting.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;

    /// Close the connection.
    async fn close(&mut self) -> Result<()>;
}
|
||||
|
||||
// ── Stdio Transport ──────────────────────────────────────────────────────

/// Stdio-based transport (spawn local process).
pub struct StdioTransport {
    /// Kept so the child stays owned; `kill_on_drop(true)` (set in `new`)
    /// tears the process down with this transport.
    _child: Child,
    /// Child's stdin — one JSON-RPC message per line.
    stdin: tokio::process::ChildStdin,
    /// Line-buffered reader over the child's stdout.
    stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
}
|
||||
|
||||
impl StdioTransport {
    /// Spawn the configured command with piped stdin/stdout.
    ///
    /// stderr is inherited so server diagnostics reach the user's terminal;
    /// `kill_on_drop` ensures the child dies when this transport is dropped.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        let mut child = Command::new(&config.command)
            .args(&config.args)
            .envs(&config.env)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;

        // Both pipes were requested above, so these should always be present;
        // the errors guard against unexpected runtime states.
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
        let stdout_lines = BufReader::new(stdout).lines();

        Ok(Self {
            _child: child,
            stdin,
            stdout_lines,
        })
    }

    /// Write one newline-delimited JSON-RPC message to the child's stdin
    /// and flush it.
    async fn send_raw(&mut self, line: &str) -> Result<()> {
        self.stdin
            .write_all(line.as_bytes())
            .await
            .context("failed to write to MCP server stdin")?;
        self.stdin
            .write_all(b"\n")
            .await
            .context("failed to write newline to MCP server stdin")?;
        self.stdin.flush().await.context("failed to flush stdin")?;
        Ok(())
    }

    /// Read one line from the child's stdout, erroring on EOF.
    ///
    /// NOTE(review): the MAX_LINE_BYTES limit is checked only after the full
    /// line has already been buffered in memory, so a misbehaving server can
    /// allocate well past the limit before this fires — a length-capped
    /// reader would enforce it during the read. Confirm whether that matters
    /// for the threat model here.
    async fn recv_raw(&mut self) -> Result<String> {
        let line = self
            .stdout_lines
            .next_line()
            .await?
            .ok_or_else(|| anyhow!("MCP server closed stdout"))?;
        if line.len() > MAX_LINE_BYTES {
            bail!("MCP response too large: {} bytes", line.len());
        }
        Ok(line)
    }
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for StdioTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let line = serde_json::to_string(request)?;
|
||||
self.send_raw(&line).await?;
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
let _ = self.stdin.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── HTTP Transport ───────────────────────────────────────────────────────

/// HTTP-based transport (POST requests).
pub struct HttpTransport {
    /// Endpoint that receives JSON-RPC POSTs.
    url: String,
    client: reqwest::Client,
    /// Extra headers from config, applied to every request.
    headers: std::collections::HashMap<String, String>,
}
|
||||
|
||||
impl HttpTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for HTTP transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
url,
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for HttpTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let mut req = self.client.post(&self.url).body(body);
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.context("HTTP request to MCP server failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── SSE Transport ─────────────────────────────────────────────────────────

/// SSE-based transport (HTTP POST for requests, SSE for responses).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SseStreamState {
    /// Stream support has not been probed yet.
    Unknown,
    /// A background reader task is (or was) attached to the event stream.
    Connected,
    /// Server answered the GET with 404/405 or a non-SSE content type;
    /// responses must come back directly on the POST instead.
    Unsupported,
}
|
||||
|
||||
pub struct SseTransport {
    /// URL of the SSE event stream (GET side).
    sse_url: String,
    /// Config name, used in log messages.
    server_name: String,
    client: reqwest::Client,
    /// Extra headers from config, applied to every request.
    headers: std::collections::HashMap<String, String>,
    /// Whether the server actually supports a live SSE stream.
    stream_state: SseStreamState,
    /// State shared with the background reader task (message URL + pending ids).
    shared: std::sync::Arc<Mutex<SseSharedState>>,
    /// Wakes senders when the server announces its message endpoint.
    notify: std::sync::Arc<Notify>,
    /// Sending (or dropping) this stops the background reader task.
    shutdown_tx: Option<oneshot::Sender<()>>,
    reader_task: Option<tokio::task::JoinHandle<()>>,
}
|
||||
|
||||
impl SseTransport {
    /// Build an SSE transport from config. The `url` field is mandatory;
    /// no connection is made until the first `send_and_recv`.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        let sse_url = config
            .url
            .as_ref()
            .ok_or_else(|| anyhow!("URL required for SSE transport"))?
            .clone();

        // No default timeout here: the SSE stream is long-lived by design.
        let client = reqwest::Client::builder()
            .build()
            .context("failed to build HTTP client")?;

        Ok(Self {
            sse_url,
            server_name: config.name.clone(),
            client,
            headers: config.headers.clone(),
            stream_state: SseStreamState::Unknown,
            shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
            notify: std::sync::Arc::new(Notify::new()),
            shutdown_tx: None,
            reader_task: None,
        })
    }

    /// Ensure the background SSE reader task is running.
    ///
    /// Probes the server with a GET; 404/405 or a non-SSE content type marks
    /// the stream as `Unsupported` (not an error — responses then come back
    /// on the POST directly). Otherwise spawns a task that parses the event
    /// stream and routes responses to pending callers.
    async fn ensure_connected(&mut self) -> Result<()> {
        if self.stream_state == SseStreamState::Unsupported {
            return Ok(());
        }
        // Reuse a still-running reader; a finished one means reconnect.
        if let Some(task) = &self.reader_task {
            if !task.is_finished() {
                self.stream_state = SseStreamState::Connected;
                return Ok(());
            }
        }

        let mut req = self
            .client
            .get(&self.sse_url)
            .header("Accept", "text/event-stream")
            .header("Cache-Control", "no-cache");
        for (key, value) in &self.headers {
            req = req.header(key, value);
        }

        let resp = req.send().await.context("SSE GET to MCP server failed")?;
        if resp.status() == reqwest::StatusCode::NOT_FOUND
            || resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
        {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }
        if !resp.status().is_success() {
            return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
        }
        // A 2xx that is not an event stream also means "no SSE here".
        let is_event_stream = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
        if !is_event_stream {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }

        let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
        self.shutdown_tx = Some(shutdown_tx);

        let shared = self.shared.clone();
        let notify = self.notify.clone();
        let sse_url = self.sse_url.clone();
        let server_name = self.server_name.clone();

        // Background task: incrementally parse the SSE wire format
        // (event:/data:/id: fields, blank-line event terminator) and hand
        // each complete event to `handle_sse_event`.
        self.reader_task = Some(tokio::spawn(async move {
            let stream = resp
                .bytes_stream()
                .map(|item| item.map_err(std::io::Error::other));
            let reader = tokio_util::io::StreamReader::new(stream);
            let mut lines = BufReader::new(reader).lines();

            // Accumulators for the event currently being parsed.
            let mut cur_event: Option<String> = None;
            let mut cur_id: Option<String> = None;
            let mut cur_data: Vec<String> = Vec::new();

            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    line = lines.next_line() => {
                        // Any read error or EOF ends the reader task.
                        let Ok(line_opt) = line else { break; };
                        let Some(mut line) = line_opt else { break; };
                        if line.ends_with('\r') {
                            line.pop();
                        }
                        if line.is_empty() {
                            // Blank line terminates an event; skip if empty.
                            if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
                                continue;
                            }
                            let event = cur_event.take();
                            let data = cur_data.join("\n");
                            cur_data.clear();
                            let id = cur_id.take();
                            handle_sse_event(
                                &server_name,
                                &sse_url,
                                &shared,
                                &notify,
                                event.as_deref(),
                                id.as_deref(),
                                data,
                            )
                            .await;
                            continue;
                        }

                        // SSE comment lines start with ':'.
                        if line.starts_with(':') {
                            continue;
                        }

                        if let Some(rest) = line.strip_prefix("event:") {
                            cur_event = Some(rest.trim().to_string());
                            continue;
                        }
                        if let Some(rest) = line.strip_prefix("data:") {
                            // Spec allows exactly one optional leading space.
                            let rest = rest.strip_prefix(' ').unwrap_or(rest);
                            cur_data.push(rest.to_string());
                            continue;
                        }
                        if let Some(rest) = line.strip_prefix("id:") {
                            cur_id = Some(rest.trim().to_string());
                        }
                    }
                }
            }

            // Stream ended: fail every caller still waiting on a response so
            // they don't hang until their own timeout.
            let pending = {
                let mut guard = shared.lock().await;
                std::mem::take(&mut guard.pending)
            };
            for (_, tx) in pending {
                let _ = tx.send(JsonRpcResponse {
                    jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
                    id: None,
                    result: None,
                    error: Some(JsonRpcError {
                        code: INTERNAL_ERROR,
                        message: "SSE connection closed".to_string(),
                        data: None,
                    }),
                });
            }
        }));
        self.stream_state = SseStreamState::Connected;

        Ok(())
    }

    /// Return the POST endpoint plus whether it came from the server's own
    /// `endpoint` event (true) or was derived from the SSE URL (false).
    async fn get_message_url(&self) -> Result<(String, bool)> {
        let guard = self.shared.lock().await;
        if let Some(url) = &guard.message_url {
            return Ok((url.clone(), guard.message_url_from_endpoint));
        }
        drop(guard);

        // No announcement yet — derive a conventional ".../messages" (or
        // ".../message") sibling of the SSE URL as a best-effort guess.
        let derived = derive_message_url(&self.sse_url, "messages")
            .or_else(|| derive_message_url(&self.sse_url, "message"))
            .ok_or_else(|| anyhow!("invalid SSE URL"))?;
        // Re-check under the lock: an endpoint event may have landed meanwhile.
        let mut guard = self.shared.lock().await;
        if guard.message_url.is_none() {
            guard.message_url = Some(derived.clone());
            guard.message_url_from_endpoint = false;
        }
        Ok((derived, false))
    }
}
|
||||
|
||||
/// State shared between `SseTransport` and its background reader task.
#[derive(Default)]
struct SseSharedState {
    /// POST endpoint for requests (derived, or announced via `endpoint` event).
    message_url: Option<String>,
    /// True when `message_url` came from the server's `endpoint` event
    /// (authoritative) rather than being derived from the SSE URL.
    message_url_from_endpoint: bool,
    /// In-flight request ids awaiting a response on the event stream.
    pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
}
|
||||
|
||||
fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
|
||||
let url = reqwest::Url::parse(sse_url).ok()?;
|
||||
let mut segments: Vec<&str> = url.path_segments()?.collect();
|
||||
if segments.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if segments.last().copied() == Some("sse") {
|
||||
segments.pop();
|
||||
segments.push(message_path);
|
||||
let mut new_url = url.clone();
|
||||
new_url.set_path(&format!("/{}", segments.join("/")));
|
||||
return Some(new_url.to_string());
|
||||
}
|
||||
let mut new_url = url.clone();
|
||||
let mut path = url.path().trim_end_matches('/').to_string();
|
||||
path.push('/');
|
||||
path.push_str(message_path);
|
||||
new_url.set_path(&path);
|
||||
Some(new_url.to_string())
|
||||
}
|
||||
|
||||
/// Process one complete SSE event from the background reader task.
///
/// `endpoint` events update the shared message URL; `message` events are
/// parsed as JSON-RPC responses and routed by id to the pending caller.
/// Everything else is ignored.
async fn handle_sse_event(
    server_name: &str,
    sse_url: &str,
    shared: &std::sync::Arc<Mutex<SseSharedState>>,
    notify: &std::sync::Arc<Notify>,
    event: Option<&str>,
    _id: Option<&str>,
    data: String,
) {
    // Per the SSE spec, an absent event name defaults to "message".
    let event = event.unwrap_or("message");
    let trimmed = data.trim();
    if trimmed.is_empty() {
        return;
    }

    if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
        if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
            let mut guard = shared.lock().await;
            guard.message_url = Some(url);
            guard.message_url_from_endpoint = true;
            drop(guard);
            // Wake any sender polling for the endpoint announcement.
            notify.notify_waiters();
        }
        return;
    }

    if !event.eq_ignore_ascii_case("message") {
        return;
    }

    let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
        return;
    };

    // If it isn't a response, it may be a server-initiated request; parse
    // and drop it (this client does not handle server requests).
    let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
        let _ = serde_json::from_value::<JsonRpcRequest>(value);
        return;
    };

    // Only numeric-id responses can be matched to the u64-keyed pending map.
    let Some(id_val) = resp.id.clone() else {
        return;
    };
    let id = match id_val.as_u64() {
        Some(v) => v,
        None => return,
    };

    // Pull the waiting sender out while holding the lock as briefly as possible.
    let tx = {
        let mut guard = shared.lock().await;
        guard.pending.remove(&id)
    };
    if let Some(tx) = tx {
        let _ = tx.send(resp);
    } else {
        tracing::debug!(
            "MCP SSE `{}` received response for unknown id {}",
            server_name,
            id
        );
    }
}
|
||||
|
||||
fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
|
||||
if data.starts_with('{') {
|
||||
let v: serde_json::Value = serde_json::from_str(data).ok()?;
|
||||
let endpoint = v.get("endpoint")?.as_str()?;
|
||||
return parse_endpoint_from_data(sse_url, endpoint);
|
||||
}
|
||||
if data.starts_with("http://") || data.starts_with("https://") {
|
||||
return Some(data.to_string());
|
||||
}
|
||||
let base = reqwest::Url::parse(sse_url).ok()?;
|
||||
base.join(data).ok().map(|u| u.to_string())
|
||||
}
|
||||
|
||||
/// Extract the JSON payload of the *last* event from a buffered SSE body.
///
/// Joins multi-line `data:` fields with `\n`; ignores comment lines; an
/// unterminated trailing event still counts. Falls back to the trimmed whole
/// body when no `data:` lines are present. Borrows where possible.
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
    // Strip a UTF-8 BOM if the server prepended one.
    let body = resp_text.trim_start_matches('\u{feff}');
    let mut pending: Vec<&str> = Vec::new();
    let mut latest: Vec<&str> = Vec::new();

    for raw in body.lines() {
        let line = raw.trim_end_matches('\r').trim_start();
        if line.is_empty() {
            // Blank line terminates an event; remember its data block.
            if !pending.is_empty() {
                latest = std::mem::take(&mut pending);
            }
        } else if line.starts_with(':') {
            // SSE comment line — skip.
        } else if let Some(payload) = line.strip_prefix("data:") {
            // Spec allows exactly one optional leading space after "data:".
            pending.push(payload.strip_prefix(' ').unwrap_or(payload));
        }
    }

    // A trailing event without a terminating blank line still wins.
    if !pending.is_empty() {
        latest = pending;
    }

    match latest.as_slice() {
        [] => Cow::Borrowed(body.trim()),
        [only] => Cow::Borrowed(only.trim()),
        many => Cow::Owned(many.join("\n").trim().to_string()),
    }
}
|
||||
|
||||
/// Drain an SSE-formatted HTTP response body until the first parseable
/// JSON-RPC response event is found, or `None` when the stream ends first.
///
/// Used when a POST unexpectedly answers with `text/event-stream`; endpoint
/// and non-`message` events are skipped.
async fn read_first_jsonrpc_from_sse_response(
    resp: reqwest::Response,
) -> Result<Option<JsonRpcResponse>> {
    let stream = resp
        .bytes_stream()
        .map(|item| item.map_err(std::io::Error::other));
    let reader = tokio_util::io::StreamReader::new(stream);
    let mut lines = BufReader::new(reader).lines();

    // Accumulators for the event currently being parsed.
    let mut cur_event: Option<String> = None;
    let mut cur_data: Vec<String> = Vec::new();

    while let Ok(line_opt) = lines.next_line().await {
        let Some(mut line) = line_opt else { break };
        if line.ends_with('\r') {
            line.pop();
        }
        if line.is_empty() {
            // Blank line terminates an event; skip fully empty ones.
            if cur_event.is_none() && cur_data.is_empty() {
                continue;
            }
            let event = cur_event.take();
            let data = cur_data.join("\n");
            cur_data.clear();

            // Per the SSE spec, an absent event name defaults to "message".
            let event = event.unwrap_or_else(|| "message".to_string());
            if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
            {
                continue;
            }
            if !event.eq_ignore_ascii_case("message") {
                continue;
            }

            let trimmed = data.trim();
            if trimmed.is_empty() {
                continue;
            }
            // The data may itself be SSE-wrapped; unwrap before parsing.
            let json_str = extract_json_from_sse_text(trimmed);
            if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
                return Ok(Some(resp));
            }
            continue;
        }

        // SSE comment lines start with ':'.
        if line.starts_with(':') {
            continue;
        }
        if let Some(rest) = line.strip_prefix("event:") {
            cur_event = Some(rest.trim().to_string());
            continue;
        }
        if let Some(rest) = line.strip_prefix("data:") {
            // Spec allows exactly one optional leading space after "data:".
            let rest = rest.strip_prefix(' ').unwrap_or(rest);
            cur_data.push(rest.to_string());
        }
    }

    Ok(None)
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for SseTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
self.ensure_connected().await?;
|
||||
|
||||
let id = request.id.as_ref().and_then(|v| v.as_u64());
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
|
||||
if self.stream_state == SseStreamState::Connected && !from_endpoint {
|
||||
for _ in 0..3 {
|
||||
{
|
||||
let guard = self.shared.lock().await;
|
||||
if guard.message_url_from_endpoint {
|
||||
if let Some(url) = &guard.message_url {
|
||||
message_url = url.clone();
|
||||
from_endpoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
|
||||
}
|
||||
}
|
||||
let primary_url = if from_endpoint {
|
||||
message_url.clone()
|
||||
} else {
|
||||
self.sse_url.clone()
|
||||
};
|
||||
let secondary_url = if message_url == self.sse_url {
|
||||
None
|
||||
} else if primary_url == message_url {
|
||||
Some(self.sse_url.clone())
|
||||
} else {
|
||||
Some(message_url.clone())
|
||||
};
|
||||
let has_secondary = secondary_url.is_some();
|
||||
|
||||
let mut rx = None;
|
||||
if let Some(id) = id {
|
||||
if self.stream_state == SseStreamState::Connected {
|
||||
let (tx, ch) = oneshot::channel();
|
||||
{
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.insert(id, tx);
|
||||
}
|
||||
rx = Some((id, ch));
|
||||
}
|
||||
}
|
||||
|
||||
let mut got_direct = None;
|
||||
let mut last_status = None;
|
||||
|
||||
for (i, url) in std::iter::once(primary_url)
|
||||
.chain(secondary_url.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(120))
|
||||
.body(body.clone())
|
||||
.header("Content-Type", "application/json");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"))
|
||||
{
|
||||
req = req.header("Accept", "application/json, text/event-stream");
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
let status = resp.status();
|
||||
last_status = Some(status);
|
||||
|
||||
if (status == reqwest::StatusCode::NOT_FOUND
|
||||
|| status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
|
||||
&& i == 0
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
break;
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
got_direct = Some(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
let is_sse = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
|
||||
if is_sse {
|
||||
if i == 0 && has_secondary {
|
||||
match timeout(
|
||||
Duration::from_secs(3),
|
||||
read_first_jsonrpc_from_sse_response(resp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if let Some(resp) = res? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let text = if i == 0 && has_secondary {
|
||||
match timeout(Duration::from_secs(3), resp.text()).await {
|
||||
Ok(Ok(t)) => t,
|
||||
Ok(Err(_)) => String::new(),
|
||||
Err(_) => continue,
|
||||
}
|
||||
} else {
|
||||
resp.text().await.unwrap_or_default()
|
||||
};
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
|
||||
extract_json_from_sse_text(trimmed)
|
||||
} else {
|
||||
Cow::Borrowed(trimmed)
|
||||
};
|
||||
if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
|
||||
got_direct = Some(mcp_resp);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some((id, _)) = rx.as_ref() {
|
||||
if got_direct.is_some() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
} else if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(resp) = got_direct {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
bail!("MCP server returned HTTP {}", status);
|
||||
}
|
||||
} else {
|
||||
bail!("MCP request not sent");
|
||||
}
|
||||
|
||||
let Some((_id, rx)) = rx else {
|
||||
bail!("MCP server returned no response");
|
||||
};
|
||||
|
||||
rx.await.map_err(|_| anyhow!("SSE response channel closed"))
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
if let Some(task) = self.reader_task.take() {
|
||||
task.abort();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Factory ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Create a transport based on config.
|
||||
pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
|
||||
match config.transport {
|
||||
McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
|
||||
McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
|
||||
McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_transport_default_is_stdio() {
|
||||
let config = McpServerConfig::default();
|
||||
assert_eq!(config.transport, McpTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Http,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(HttpTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Sse,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(SseTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_data_no_space() {
|
||||
let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_with_event_and_id() {
|
||||
let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_multiline_data() {
|
||||
let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
|
||||
let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_uses_last_event_with_data() {
|
||||
let input =
|
||||
": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
}
|
||||
+8
-1
@@ -40,6 +40,10 @@ pub mod hardware_memory_map;
|
||||
pub mod hardware_memory_read;
|
||||
pub mod http_request;
|
||||
pub mod image_info;
|
||||
pub mod mcp_client;
|
||||
pub mod mcp_protocol;
|
||||
pub mod mcp_tool;
|
||||
pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
@@ -289,11 +293,13 @@ pub fn all_tools_with_runtime(
|
||||
|
||||
// Web search tool (enabled by default for GLM and other models)
|
||||
if root_config.web_search.enabled {
|
||||
tool_arcs.push(Arc::new(WebSearchTool::new(
|
||||
tool_arcs.push(Arc::new(WebSearchTool::new_with_config(
|
||||
root_config.web_search.provider.clone(),
|
||||
root_config.web_search.brave_api_key.clone(),
|
||||
root_config.web_search.max_results,
|
||||
root_config.web_search.timeout_secs,
|
||||
root_config.config_path.clone(),
|
||||
root_config.secrets.encrypt,
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -338,6 +344,7 @@ pub fn all_tools_with_runtime(
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
},
|
||||
)
|
||||
.with_parent_tools(parent_tools)
|
||||
|
||||
@@ -2,15 +2,26 @@ use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::json;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Web search tool for searching the internet.
|
||||
/// Supports multiple providers: DuckDuckGo (free), Brave (requires API key).
|
||||
///
|
||||
/// The Brave API key is resolved lazily at execution time: if the boot-time key
|
||||
/// is missing or still encrypted, the tool re-reads `config.toml`, decrypts the
|
||||
/// `[web_search] brave_api_key` field, and uses the result. This ensures that
|
||||
/// keys set or rotated after boot, and encrypted keys, are correctly picked up.
|
||||
pub struct WebSearchTool {
|
||||
provider: String,
|
||||
brave_api_key: Option<String>,
|
||||
/// Boot-time key snapshot (may be `None` if not yet configured at startup).
|
||||
boot_brave_api_key: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
/// Path to `config.toml` for lazy re-read of keys at execution time.
|
||||
config_path: PathBuf,
|
||||
/// Whether secret encryption is enabled (needed to create a `SecretStore`).
|
||||
secrets_encrypt: bool,
|
||||
}
|
||||
|
||||
impl WebSearchTool {
|
||||
@@ -22,9 +33,85 @@ impl WebSearchTool {
|
||||
) -> Self {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
brave_api_key,
|
||||
boot_brave_api_key: brave_api_key,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path: PathBuf::new(),
|
||||
secrets_encrypt: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a `WebSearchTool` with config-reload and decryption support.
|
||||
///
|
||||
/// `config_path` is the path to `config.toml` so the tool can re-read the
|
||||
/// Brave API key at execution time. `secrets_encrypt` controls whether the
|
||||
/// key is decrypted via `SecretStore`.
|
||||
pub fn new_with_config(
|
||||
provider: String,
|
||||
brave_api_key: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
config_path: PathBuf,
|
||||
secrets_encrypt: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
boot_brave_api_key: brave_api_key,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path,
|
||||
secrets_encrypt,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the Brave API key, preferring the boot-time value but falling
|
||||
/// back to a fresh config read + decryption when the boot-time value is
|
||||
/// absent.
|
||||
fn resolve_brave_api_key(&self) -> anyhow::Result<String> {
|
||||
// Fast path: boot-time key is present and usable (not an encrypted blob).
|
||||
if let Some(ref key) = self.boot_brave_api_key {
|
||||
if !key.is_empty() && !crate::security::SecretStore::is_encrypted(key) {
|
||||
return Ok(key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: re-read config.toml to pick up keys set/rotated after boot.
|
||||
self.reload_brave_api_key()
|
||||
}
|
||||
|
||||
/// Re-read `config.toml` and decrypt `[web_search] brave_api_key`.
|
||||
fn reload_brave_api_key(&self) -> anyhow::Result<String> {
|
||||
let contents = std::fs::read_to_string(&self.config_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read config file {} for Brave API key: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let config: crate::config::Config = toml::from_str(&contents).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to parse config file {} for Brave API key: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let raw_key = config
|
||||
.web_search
|
||||
.brave_api_key
|
||||
.filter(|k| !k.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("Brave API key not configured"))?;
|
||||
|
||||
// Decrypt if necessary.
|
||||
if crate::security::SecretStore::is_encrypted(&raw_key) {
|
||||
let zeroclaw_dir = self.config_path.parent().unwrap_or_else(|| Path::new("."));
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets_encrypt);
|
||||
let plaintext = store.decrypt(&raw_key)?;
|
||||
if plaintext.is_empty() {
|
||||
anyhow::bail!("Brave API key not configured (decrypted value is empty)");
|
||||
}
|
||||
Ok(plaintext)
|
||||
} else {
|
||||
Ok(raw_key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,10 +186,7 @@ impl WebSearchTool {
|
||||
}
|
||||
|
||||
async fn search_brave(&self, query: &str) -> anyhow::Result<String> {
|
||||
let api_key = self
|
||||
.brave_api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Brave API key not configured"))?;
|
||||
let api_key = self.resolve_brave_api_key()?;
|
||||
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let search_url = format!(
|
||||
@@ -117,7 +201,7 @@ impl WebSearchTool {
|
||||
let response = client
|
||||
.get(&search_url)
|
||||
.header("Accept", "application/json")
|
||||
.header("X-Subscription-Token", api_key)
|
||||
.header("X-Subscription-Token", &api_key)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
@@ -328,4 +412,91 @@ mod tests {
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_uses_boot_key() {
|
||||
let tool = WebSearchTool::new(
|
||||
"brave".to_string(),
|
||||
Some("sk-plaintext-key".to_string()),
|
||||
5,
|
||||
15,
|
||||
);
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "sk-plaintext-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_reloads_from_config() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
"[web_search]\nbrave_api_key = \"fresh-key-from-disk\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// No boot key -- forces reload from config
|
||||
let tool =
|
||||
WebSearchTool::new_with_config("brave".to_string(), None, 5, 15, config_path, false);
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "fresh-key-from-disk");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_decrypts_encrypted_key() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let store = crate::security::SecretStore::new(tmp.path(), true);
|
||||
let encrypted = store.encrypt("brave-secret-key").unwrap();
|
||||
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
format!("[web_search]\nbrave_api_key = \"{}\"\n", encrypted),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Boot key is the encrypted blob -- should trigger reload + decrypt
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
Some(encrypted),
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
true,
|
||||
);
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "brave-secret-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_picks_up_runtime_update() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
|
||||
// Start with no key in config
|
||||
std::fs::write(&config_path, "[web_search]\n").unwrap();
|
||||
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path.clone(),
|
||||
false,
|
||||
);
|
||||
|
||||
// Key not configured yet -- should fail
|
||||
assert!(tool.resolve_brave_api_key().is_err());
|
||||
|
||||
// Simulate runtime config update (e.g. via web_search_config set)
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
"[web_search]\nbrave_api_key = \"runtime-updated-key\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Now should succeed with the updated key
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "runtime-updated-key");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,34 @@ use anyhow::{bail, Result};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// Try to extract a real tunnel URL from a cloudflared log line.
|
||||
///
|
||||
/// Returns `Some(url)` when the line contains a genuine tunnel endpoint,
|
||||
/// skipping documentation and warning URLs (quic-go GitHub links,
|
||||
/// Cloudflare docs pages, etc.).
|
||||
fn extract_tunnel_url(line: &str) -> Option<String> {
|
||||
let idx = line.find("https://")?;
|
||||
let url_part = &line[idx..];
|
||||
let end = url_part
|
||||
.find(|c: char| c.is_whitespace())
|
||||
.unwrap_or(url_part.len());
|
||||
let candidate = &url_part[..end];
|
||||
|
||||
let is_tunnel_line = line.contains("Visit it at")
|
||||
|| line.contains("Route at")
|
||||
|| line.contains("Registered tunnel connection");
|
||||
let is_tunnel_domain = candidate.contains(".trycloudflare.com");
|
||||
let is_docs_url = candidate.contains("github.com")
|
||||
|| candidate.contains("cloudflare.com/docs")
|
||||
|| candidate.contains("developers.cloudflare.com");
|
||||
|
||||
if is_tunnel_line || is_tunnel_domain || !is_docs_url {
|
||||
Some(candidate.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Cloudflare Tunnel — wraps the `cloudflared` binary.
|
||||
///
|
||||
/// Requires `cloudflared` installed and a tunnel token from the
|
||||
@@ -62,13 +90,8 @@ impl Tunnel for CloudflareTunnel {
|
||||
match line {
|
||||
Ok(Ok(Some(l))) => {
|
||||
tracing::debug!("cloudflared: {l}");
|
||||
// Look for the URL pattern in cloudflared output
|
||||
if let Some(idx) = l.find("https://") {
|
||||
let url_part = &l[idx..];
|
||||
let end = url_part
|
||||
.find(|c: char| c.is_whitespace())
|
||||
.unwrap_or(url_part.len());
|
||||
public_url = url_part[..end].to_string();
|
||||
if let Some(url) = extract_tunnel_url(&l) {
|
||||
public_url = url;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -138,4 +161,55 @@ mod tests {
|
||||
let tunnel = CloudflareTunnel::new("cf-token".into());
|
||||
assert!(!tunnel.health_check().await);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_skips_quic_go_github_url() {
|
||||
let line = "2024-01-01T00:00:00Z WRN failed to sufficiently increase receive buffer size. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.";
|
||||
assert_eq!(extract_tunnel_url(line), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_skips_cloudflare_docs_url() {
|
||||
let line = "2024-01-01T00:00:00Z INF For more info see https://cloudflare.com/docs/tunnels";
|
||||
assert_eq!(extract_tunnel_url(line), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_skips_developers_cloudflare_url() {
|
||||
let line = "2024-01-01T00:00:00Z INF See https://developers.cloudflare.com/cloudflare-one/connections/connect-apps";
|
||||
assert_eq!(extract_tunnel_url(line), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_captures_trycloudflare_url() {
|
||||
let line = "2024-01-01T00:00:00Z INF Visit it at https://my-tunnel-abc.trycloudflare.com";
|
||||
assert_eq!(
|
||||
extract_tunnel_url(line),
|
||||
Some("https://my-tunnel-abc.trycloudflare.com".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_captures_url_on_visit_it_at_line() {
|
||||
let line = "2024-01-01T00:00:00Z INF Visit it at https://some-custom-domain.example.com";
|
||||
assert_eq!(
|
||||
extract_tunnel_url(line),
|
||||
Some("https://some-custom-domain.example.com".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_captures_url_on_route_at_line() {
|
||||
let line = "2024-01-01T00:00:00Z INF Route at https://tunnel.example.com/path";
|
||||
assert_eq!(
|
||||
extract_tunnel_url(line),
|
||||
Some("https://tunnel.example.com/path".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_returns_none_for_line_without_url() {
|
||||
let line = "2024-01-01T00:00:00Z INF Starting tunnel";
|
||||
assert_eq!(extract_tunnel_url(line), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,6 +151,7 @@ async fn openai_codex_second_vision_support() -> Result<()> {
|
||||
zeroclaw_dir: None,
|
||||
secrets_encrypt: false,
|
||||
reasoning_enabled: None,
|
||||
provider_timeout_secs: None,
|
||||
};
|
||||
|
||||
let provider = zeroclaw::providers::create_provider_with_options("openai-codex", None, &opts)?;
|
||||
|
||||
+2
-2
@@ -80,7 +80,7 @@ function PairingDialog({ onPair }: { onPair: (code: string) => Promise<void> })
|
||||
}
|
||||
|
||||
function AppContent() {
|
||||
const { isAuthenticated, loading, pair, logout } = useAuth();
|
||||
const { isAuthenticated, requiresPairing, loading, pair, logout } = useAuth();
|
||||
const [locale, setLocaleState] = useState('tr');
|
||||
|
||||
const setAppLocale = (newLocale: string) => {
|
||||
@@ -105,7 +105,7 @@ function AppContent() {
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
if (!isAuthenticated && requiresPairing) {
|
||||
return <PairingDialog onPair={pair} />;
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ export interface AuthState {
|
||||
token: string | null;
|
||||
/** Whether the user is currently authenticated. */
|
||||
isAuthenticated: boolean;
|
||||
/** Whether the server requires pairing. Defaults to true (safe fallback). */
|
||||
requiresPairing: boolean;
|
||||
/** True while the initial auth check is in progress. */
|
||||
loading: boolean;
|
||||
/** Pair with the agent using a pairing code. Stores the token on success. */
|
||||
@@ -45,6 +47,7 @@ export interface AuthProviderProps {
|
||||
export function AuthProvider({ children }: AuthProviderProps) {
|
||||
const [token, setTokenState] = useState<string | null>(readToken);
|
||||
const [authenticated, setAuthenticated] = useState<boolean>(checkAuth);
|
||||
const [requiresPairing, setRequiresPairing] = useState<boolean>(true);
|
||||
const [loading, setLoading] = useState<boolean>(!checkAuth());
|
||||
|
||||
// On mount: check if server requires pairing at all
|
||||
@@ -55,6 +58,7 @@ export function AuthProvider({ children }: AuthProviderProps) {
|
||||
.then((health) => {
|
||||
if (cancelled) return;
|
||||
if (!health.require_pairing) {
|
||||
setRequiresPairing(false);
|
||||
setAuthenticated(true);
|
||||
}
|
||||
})
|
||||
@@ -98,6 +102,7 @@ export function AuthProvider({ children }: AuthProviderProps) {
|
||||
const value: AuthState = {
|
||||
token,
|
||||
isAuthenticated: authenticated,
|
||||
requiresPairing,
|
||||
loading,
|
||||
pair,
|
||||
logout,
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Generate a UUID v4 string.
|
||||
*
|
||||
* Uses `crypto.randomUUID()` when available (modern browsers, secure contexts)
|
||||
* and falls back to a manual implementation backed by `crypto.getRandomValues()`
|
||||
* for older browsers (e.g. Safari < 15.4, some Electron/Raspberry-Pi builds).
|
||||
*
|
||||
* Closes #3303, #3261.
|
||||
*/
|
||||
export function generateUUID(): string {
|
||||
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
|
||||
return crypto.randomUUID();
|
||||
}
|
||||
|
||||
// Fallback: RFC 4122 version 4 UUID via getRandomValues
|
||||
// crypto must exist if we reached here (only randomUUID is missing)
|
||||
const c = globalThis.crypto;
|
||||
const bytes = new Uint8Array(16);
|
||||
c.getRandomValues(bytes);
|
||||
|
||||
// Set version (4) and variant (10xx) bits per RFC 4122
|
||||
bytes[6] = (bytes[6]! & 0x0f) | 0x40;
|
||||
bytes[8] = (bytes[8]! & 0x3f) | 0x80;
|
||||
|
||||
const hex = Array.from(bytes, (b) => b.toString(16).padStart(2, '0')).join('');
|
||||
return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}`;
|
||||
}
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import type { WsMessage } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { generateUUID } from './uuid';
|
||||
|
||||
export type WsMessageHandler = (msg: WsMessage) => void;
|
||||
export type WsOpenHandler = () => void;
|
||||
@@ -26,7 +27,7 @@ const SESSION_STORAGE_KEY = 'zeroclaw_session_id';
|
||||
function getOrCreateSessionId(): string {
|
||||
let id = sessionStorage.getItem(SESSION_STORAGE_KEY);
|
||||
if (!id) {
|
||||
id = crypto.randomUUID();
|
||||
id = generateUUID();
|
||||
sessionStorage.setItem(SESSION_STORAGE_KEY, id);
|
||||
}
|
||||
return id;
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { Send, Bot, User, AlertCircle, Copy, Check } from 'lucide-react';
|
||||
import type { WsMessage } from '@/types/api';
|
||||
import { WebSocketClient } from '@/lib/ws';
|
||||
import { generateUUID } from '@/lib/uuid';
|
||||
|
||||
interface ChatMessage {
|
||||
id: string;
|
||||
@@ -53,7 +54,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content,
|
||||
timestamp: new Date(),
|
||||
@@ -69,7 +70,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Tool Call] ${msg.name ?? 'unknown'}(${JSON.stringify(msg.args ?? {})})`,
|
||||
timestamp: new Date(),
|
||||
@@ -81,7 +82,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Tool Result] ${msg.output ?? ''}`,
|
||||
timestamp: new Date(),
|
||||
@@ -93,7 +94,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Error] ${msg.message ?? 'Unknown error'}`,
|
||||
timestamp: new Date(),
|
||||
@@ -124,7 +125,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'user',
|
||||
content: trimmed,
|
||||
timestamp: new Date(),
|
||||
|
||||
Reference in New Issue
Block a user