diff --git a/.github/workflows/ci-queue-hygiene.yml b/.github/workflows/ci-queue-hygiene.yml
index ada0baf02..b1655435a 100644
--- a/.github/workflows/ci-queue-hygiene.yml
+++ b/.github/workflows/ci-queue-hygiene.yml
@@ -8,7 +8,7 @@ on:
apply:
description: "Cancel selected queued runs (false = dry-run report only)"
required: true
- default: true
+ default: false
type: boolean
status:
description: "Queued-run status scope"
@@ -57,22 +57,27 @@ jobs:
status_scope="queued"
max_cancel="120"
- apply_mode="true"
+ apply_mode="false"
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
status_scope="${{ github.event.inputs.status || 'queued' }}"
max_cancel="${{ github.event.inputs.max_cancel || '120' }}"
- apply_mode="${{ github.event.inputs.apply || 'true' }}"
+ apply_mode="${{ github.event.inputs.apply || 'false' }}"
fi
cmd=(python3 scripts/ci/queue_hygiene.py
--repo "${{ github.repository }}"
--status "${status_scope}"
--max-cancel "${max_cancel}"
+ --dedupe-workflow "CI Run"
+ --dedupe-workflow "Test E2E"
+ --dedupe-workflow "Docs Deploy"
--dedupe-workflow "PR Intake Checks"
--dedupe-workflow "PR Labeler"
--dedupe-workflow "PR Auto Responder"
--dedupe-workflow "Workflow Sanity"
--dedupe-workflow "PR Label Policy Check"
+ --dedupe-include-non-pr
+ --non-pr-key branch
--output-json artifacts/queue-hygiene-report.json
--verbose)
diff --git a/.github/workflows/ci-run.yml b/.github/workflows/ci-run.yml
index 196b15cc6..d28abcf0a 100644
--- a/.github/workflows/ci-run.yml
+++ b/.github/workflows/ci-run.yml
@@ -9,7 +9,7 @@ on:
branches: [dev, main]
concurrency:
- group: ci-${{ github.event.pull_request.number || github.sha }}
+ group: ci-run-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:
diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml
index 6ac5c220a..c1f55d7db 100644
--- a/.github/workflows/docs-deploy.yml
+++ b/.github/workflows/docs-deploy.yml
@@ -41,7 +41,7 @@ on:
default: ""
concurrency:
- group: docs-deploy-${{ github.event.pull_request.number || github.sha }}
+ group: docs-deploy-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:
diff --git a/.github/workflows/test-e2e.yml b/.github/workflows/test-e2e.yml
index ce3b00a17..8f9a005fd 100644
--- a/.github/workflows/test-e2e.yml
+++ b/.github/workflows/test-e2e.yml
@@ -14,7 +14,7 @@ on:
workflow_dispatch:
concurrency:
- group: e2e-${{ github.event.pull_request.number || github.sha }}
+ group: test-e2e-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:
diff --git a/README.md b/README.md
index 8b8831366..d4b66ddaf 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
Getting Started |
- One-Click Setup |
+ One-Click Setup |
Docs Hub |
Docs TOC
@@ -108,11 +108,11 @@ cargo install zeroclaw
### First Run
```bash
-# Start the gateway daemon
-zeroclaw gateway start
+# Start the gateway (serves the Web Dashboard API/UI)
+zeroclaw gateway
-# Open the web UI
-zeroclaw dashboard
+# Open the dashboard URL shown in startup logs
+# (default: http://127.0.0.1:3000/)
# Or chat directly
zeroclaw chat "Hello!"
@@ -120,6 +120,16 @@ zeroclaw chat "Hello!"
For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md).
+### Installation Docs (Canonical Source)
+
+Use repository docs as the source of truth for install/setup instructions:
+
+- [README Quick Start](#quick-start)
+- [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md)
+- [docs/getting-started/README.md](docs/getting-started/README.md)
+
+Issue comments can provide context, but they are not canonical installation documentation.
+
## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible)
Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware.
diff --git a/docs/README.md b/docs/README.md
index 05d6c6cb1..317ae8422 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -29,6 +29,8 @@ Localized hubs: [简体中文](i18n/zh-CN/README.md) · [日本語](i18n/ja/READ
| See project PR/issue docs snapshot | [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md) |
| Perform i18n completion for docs changes | [i18n-guide.md](i18n-guide.md) |
+Installation source-of-truth: keep install/run instructions in repository docs and README pages; issue comments are supplemental context only.
+
## Quick Decision Tree (10 seconds)
- Need first-time setup or install? → [getting-started/README.md](getting-started/README.md)
diff --git a/docs/commands-reference.md b/docs/commands-reference.md
index c15fc8514..4b4740997 100644
--- a/docs/commands-reference.md
+++ b/docs/commands-reference.md
@@ -15,6 +15,7 @@ Last verified: **February 28, 2026**.
| `service` | Manage user-level OS service lifecycle |
| `doctor` | Run diagnostics and freshness checks |
| `status` | Print current configuration and system summary |
+| `update` | Check or install latest ZeroClaw release |
| `estop` | Engage/resume emergency stop levels and inspect estop state |
| `cron` | Manage scheduled tasks |
| `models` | Refresh provider model catalogs |
@@ -103,6 +104,18 @@ Notes:
- `zeroclaw service status`
- `zeroclaw service uninstall`
+### `update`
+
+- `zeroclaw update --check` (check for new release, no install)
+- `zeroclaw update` (install latest release binary for current platform)
+- `zeroclaw update --force` (reinstall even if current version matches latest)
+- `zeroclaw update --instructions` (print install-method-specific guidance)
+
+Notes:
+
+- If ZeroClaw is installed via Homebrew, prefer `brew upgrade zeroclaw`.
+- `update --instructions` detects common install methods and prints the safest path.
+
### `cron`
- `zeroclaw cron list`
@@ -264,6 +277,11 @@ Registry packages are installed to `~/.zeroclaw/workspace/skills/<name>/`.
Use `skills audit` to manually validate a candidate skill directory (or an installed skill by name) before sharing it.
+Workspace symlink policy:
+- Symlinked entries under `~/.zeroclaw/workspace/skills/` are blocked by default.
+- To allow shared local skill directories, set `[skills].trusted_skill_roots` in `config.toml`.
+- A symlinked skill is accepted only when its resolved canonical target is inside one of the trusted roots.
+
Skill manifests (`SKILL.toml`) support `prompts` and `[[tools]]`; both are injected into the agent system prompt at runtime, so the model can follow skill instructions without manually reading skill files.
### `migrate`
diff --git a/docs/config-reference.md b/docs/config-reference.md
index 1211ccbb6..b4145909f 100644
--- a/docs/config-reference.md
+++ b/docs/config-reference.md
@@ -536,6 +536,7 @@ Notes:
|---|---|---|
| `open_skills_enabled` | `false` | Opt-in loading/sync of community `open-skills` repository |
| `open_skills_dir` | unset | Optional local path for `open-skills` (defaults to `$HOME/open-skills` when enabled) |
+| `trusted_skill_roots` | `[]` | Allowlist of directory roots for symlink targets in `workspace/skills/*` |
| `prompt_injection_mode` | `full` | Skill prompt verbosity: `full` (inline instructions/tools) or `compact` (name/description/location only) |
| `clawhub_token` | unset | Optional Bearer token for authenticated ClawhHub skill downloads |
@@ -548,7 +549,8 @@ Notes:
- `ZEROCLAW_SKILLS_PROMPT_MODE` accepts `full` or `compact`.
- Precedence for enable flag: `ZEROCLAW_OPEN_SKILLS_ENABLED` → `skills.open_skills_enabled` in `config.toml` → default `false`.
- `prompt_injection_mode = "compact"` is recommended on low-context local models to reduce startup prompt size while keeping skill files available on demand.
-- Skill loading and `zeroclaw skills install` both apply a static security audit. Skills that contain symlinks, script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
+- Symlinked workspace skills are blocked by default. Set `trusted_skill_roots` to allow local shared-skill directories after explicit trust review.
+- `zeroclaw skills install` and `zeroclaw skills audit` apply a static security audit. Skills that contain script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
- `clawhub_token` is sent as `Authorization: Bearer <token>` when downloading from ClawhHub. Obtain a token from [https://clawhub.ai](https://clawhub.ai) after signing in. Required if the API returns 429 (rate-limited) or 401 (unauthorized) for anonymous requests.
**ClawhHub token example:**
diff --git a/docs/getting-started/macos-update-uninstall.md b/docs/getting-started/macos-update-uninstall.md
index 944cd4ce3..f08bc5042 100644
--- a/docs/getting-started/macos-update-uninstall.md
+++ b/docs/getting-started/macos-update-uninstall.md
@@ -20,6 +20,13 @@ If both exist, your shell `PATH` order decides which one runs.
## 2) Update on macOS
+Quick way to get install-method-specific guidance:
+
+```bash
+zeroclaw update --instructions
+zeroclaw update --check
+```
+
### A) Homebrew install
```bash
@@ -54,6 +61,13 @@ Re-run your download/install flow with the latest release asset, then verify:
zeroclaw --version
```
+You can also use the built-in updater for manual/local installs:
+
+```bash
+zeroclaw update
+zeroclaw --version
+```
+
## 3) Uninstall on macOS
### A) Stop and remove background service first
diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh
index cee7251ad..4bd1ac7a5 100755
--- a/scripts/bootstrap.sh
+++ b/scripts/bootstrap.sh
@@ -423,11 +423,18 @@ string_to_bool() {
}
guided_input_stream() {
- if [[ -t 0 ]]; then
+ # Some constrained Linux containers report interactive stdin but deny opening
+ # /dev/stdin directly. Probe readability before selecting it.
+ if [[ -t 0 ]] && (: </dev/stdin) 2>/dev/null; then
echo "/dev/stdin"
return 0
fi
+ if [[ -t 0 ]] && (: </proc/self/fd/0) 2>/dev/null; then
+ echo "/proc/self/fd/0"
+ return 0
+ fi
+
if (: </dev/tty) 2>/dev/null; then
echo "/dev/tty"
return 0
diff --git a/scripts/ci/queue_hygiene.py b/scripts/ci/queue_hygiene.py
index 9255e9b64..ebeb22699 100755
--- a/scripts/ci/queue_hygiene.py
+++ b/scripts/ci/queue_hygiene.py
@@ -66,6 +66,15 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Also dedupe non-PR runs (push/manual). Default dedupe scope is PR-originated runs only.",
)
+ parser.add_argument(
+ "--non-pr-key",
+ default="sha",
+ choices=["sha", "branch"],
+ help=(
+ "Identity key mode for non-PR dedupe when --dedupe-include-non-pr is enabled: "
+ "'sha' keeps one run per commit (default), 'branch' keeps one run per branch."
+ ),
+ )
parser.add_argument(
"--max-cancel",
type=int,
@@ -165,7 +174,7 @@ def parse_timestamp(value: str | None) -> datetime:
return datetime.fromtimestamp(0, tz=timezone.utc)
-def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
+def run_identity_key(run: dict[str, Any], *, non_pr_key: str) -> tuple[str, str, str, str]:
name = str(run.get("name", ""))
event = str(run.get("event", ""))
head_branch = str(run.get("head_branch", ""))
@@ -179,7 +188,10 @@ def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
if pr_number:
# For PR traffic, cancel stale runs across synchronize updates for the same PR.
return (name, event, f"pr:{pr_number}", "")
- # For push/manual traffic, key by SHA to avoid collapsing distinct commits.
+ if non_pr_key == "branch":
+ # Branch-level supersedence for push/manual lanes.
+ return (name, event, head_branch, "")
+ # SHA-level supersedence for push/manual lanes.
return (name, event, head_branch, head_sha)
@@ -189,6 +201,7 @@ def collect_candidates(
dedupe_workflows: set[str],
*,
include_non_pr: bool,
+ non_pr_key: str,
) -> tuple[list[dict[str, Any]], Counter[str]]:
reasons_by_id: dict[int, set[str]] = defaultdict(set)
runs_by_id: dict[int, dict[str, Any]] = {}
@@ -220,7 +233,7 @@ def collect_candidates(
has_pr_context = isinstance(pull_requests, list) and len(pull_requests) > 0
if is_pr_event and not has_pr_context and not include_non_pr:
continue
- key = run_identity_key(run)
+ key = run_identity_key(run, non_pr_key=non_pr_key)
by_workflow[name][key].append(run)
for groups in by_workflow.values():
@@ -324,6 +337,7 @@ def main() -> int:
obsolete_workflows,
dedupe_workflows,
include_non_pr=args.dedupe_include_non_pr,
+ non_pr_key=args.non_pr_key,
)
capped = selected[: max(0, args.max_cancel)]
@@ -338,6 +352,7 @@ def main() -> int:
"obsolete_workflows": sorted(obsolete_workflows),
"dedupe_workflows": sorted(dedupe_workflows),
"dedupe_include_non_pr": args.dedupe_include_non_pr,
+ "non_pr_key": args.non_pr_key,
"max_cancel": args.max_cancel,
},
"counts": {
diff --git a/scripts/ci/tests/test_ci_scripts.py b/scripts/ci/tests/test_ci_scripts.py
index 1e5c7921a..f18bec46c 100644
--- a/scripts/ci/tests/test_ci_scripts.py
+++ b/scripts/ci/tests/test_ci_scripts.py
@@ -3759,6 +3759,119 @@ class CiScriptsBehaviorTest(unittest.TestCase):
planned_ids = [item["id"] for item in report["planned_actions"]]
self.assertEqual(planned_ids, [101, 102])
+ def test_queue_hygiene_non_pr_branch_mode_dedupes_push_runs(self) -> None:
+ runs_json = self.tmp / "runs-non-pr-branch.json"
+ output_json = self.tmp / "queue-hygiene-non-pr-branch.json"
+ runs_json.write_text(
+ json.dumps(
+ {
+ "workflow_runs": [
+ {
+ "id": 201,
+ "name": "CI Run",
+ "event": "push",
+ "head_branch": "main",
+ "head_sha": "sha-201",
+ "created_at": "2026-02-27T20:00:00Z",
+ },
+ {
+ "id": 202,
+ "name": "CI Run",
+ "event": "push",
+ "head_branch": "main",
+ "head_sha": "sha-202",
+ "created_at": "2026-02-27T20:01:00Z",
+ },
+ {
+ "id": 203,
+ "name": "CI Run",
+ "event": "push",
+ "head_branch": "dev",
+ "head_sha": "sha-203",
+ "created_at": "2026-02-27T20:02:00Z",
+ },
+ ]
+ }
+ )
+ + "\n",
+ encoding="utf-8",
+ )
+
+ proc = run_cmd(
+ [
+ "python3",
+ self._script("queue_hygiene.py"),
+ "--runs-json",
+ str(runs_json),
+ "--dedupe-workflow",
+ "CI Run",
+ "--dedupe-include-non-pr",
+ "--non-pr-key",
+ "branch",
+ "--output-json",
+ str(output_json),
+ ]
+ )
+ self.assertEqual(proc.returncode, 0, msg=proc.stderr)
+
+ report = json.loads(output_json.read_text(encoding="utf-8"))
+ self.assertEqual(report["counts"]["candidate_runs_before_cap"], 1)
+ planned_ids = [item["id"] for item in report["planned_actions"]]
+ self.assertEqual(planned_ids, [201])
+ reasons = report["planned_actions"][0]["reasons"]
+ self.assertTrue(any(reason.startswith("dedupe-superseded-by:202") for reason in reasons))
+ self.assertEqual(report["policies"]["non_pr_key"], "branch")
+
+ def test_queue_hygiene_non_pr_sha_mode_keeps_distinct_push_commits(self) -> None:
+ runs_json = self.tmp / "runs-non-pr-sha.json"
+ output_json = self.tmp / "queue-hygiene-non-pr-sha.json"
+ runs_json.write_text(
+ json.dumps(
+ {
+ "workflow_runs": [
+ {
+ "id": 301,
+ "name": "CI Run",
+ "event": "push",
+ "head_branch": "main",
+ "head_sha": "sha-301",
+ "created_at": "2026-02-27T20:00:00Z",
+ },
+ {
+ "id": 302,
+ "name": "CI Run",
+ "event": "push",
+ "head_branch": "main",
+ "head_sha": "sha-302",
+ "created_at": "2026-02-27T20:01:00Z",
+ },
+ ]
+ }
+ )
+ + "\n",
+ encoding="utf-8",
+ )
+
+ proc = run_cmd(
+ [
+ "python3",
+ self._script("queue_hygiene.py"),
+ "--runs-json",
+ str(runs_json),
+ "--dedupe-workflow",
+ "CI Run",
+ "--dedupe-include-non-pr",
+ "--output-json",
+ str(output_json),
+ ]
+ )
+ self.assertEqual(proc.returncode, 0, msg=proc.stderr)
+
+ report = json.loads(output_json.read_text(encoding="utf-8"))
+ self.assertEqual(report["counts"]["candidate_runs_before_cap"], 0)
+ self.assertEqual(report["planned_actions"], [])
+ self.assertEqual(report["policies"]["non_pr_key"], "sha")
+
if __name__ == "__main__": # pragma: no cover
unittest.main(verbosity=2)
diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs
index faa31c671..f6f03d932 100644
--- a/src/agent/loop_.rs
+++ b/src/agent/loop_.rs
@@ -1,5 +1,7 @@
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
+use crate::config::schema::{CostEnforcementMode, ModelPricing};
use crate::config::{Config, ProgressMode};
+use crate::cost::{BudgetCheck, CostTracker, UsagePeriod};
use crate::memory::{self, Memory, MemoryCategory};
use crate::multimodal;
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
@@ -19,9 +21,11 @@ use rustyline::hint::Hinter;
use rustyline::validate::Validator;
use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper};
use std::borrow::Cow;
-use std::collections::{BTreeSet, HashSet};
+use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Write;
+use std::future::Future;
use std::io::Write as _;
+use std::path::Path;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@@ -297,6 +301,7 @@ tokio::task_local! {
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
static TOOL_LOOP_PROGRESS_MODE: ProgressMode;
+ static TOOL_LOOP_COST_ENFORCEMENT_CONTEXT: Option<CostEnforcementContext>;
}
/// Configuration for periodic safety-constraint re-injection (heartbeat).
@@ -308,6 +313,56 @@ pub(crate) struct SafetyHeartbeatConfig {
pub interval: usize,
}
+#[derive(Clone)]
+pub(crate) struct CostEnforcementContext {
+ tracker: Arc<CostTracker>,
+ prices: HashMap<String, ModelPricing>,
+ mode: CostEnforcementMode,
+ route_down_model: Option<String>,
+ reserve_percent: u8,
+}
+
+pub(crate) fn create_cost_enforcement_context(
+ cost_config: &crate::config::CostConfig,
+ workspace_dir: &Path,
+) -> Option<CostEnforcementContext> {
+ if !cost_config.enabled {
+ return None;
+ }
+ let tracker = match CostTracker::new(cost_config.clone(), workspace_dir) {
+ Ok(tracker) => Arc::new(tracker),
+ Err(error) => {
+ tracing::warn!("Cost budget preflight disabled: failed to initialize tracker: {error}");
+ return None;
+ }
+ };
+ let route_down_model = cost_config
+ .enforcement
+ .route_down_model
+ .clone()
+ .map(|value| value.trim().to_string())
+ .filter(|value| !value.is_empty());
+ Some(CostEnforcementContext {
+ tracker,
+ prices: cost_config.prices.clone(),
+ mode: cost_config.enforcement.mode,
+ route_down_model,
+ reserve_percent: cost_config.enforcement.reserve_percent.min(100),
+ })
+}
+
+pub(crate) async fn scope_cost_enforcement_context<F>(
+ context: Option<CostEnforcementContext>,
+ future: F,
+) -> F::Output
+where
+ F: Future,
+{
+ TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
+ .scope(context, future)
+ .await
+}
+
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
interval > 0 && counter > 0 && counter % interval == 0
}
@@ -320,6 +375,100 @@ fn should_emit_tool_progress(mode: ProgressMode) -> bool {
mode != ProgressMode::Off
}
+fn estimate_prompt_tokens(
+ messages: &[ChatMessage],
+ tools: Option<&[crate::tools::ToolSpec]>,
+) -> u64 {
+ let message_chars: usize = messages
+ .iter()
+ .map(|msg| {
+ msg.role
+ .len()
+ .saturating_add(msg.content.chars().count())
+ .saturating_add(16)
+ })
+ .sum();
+ let tool_chars: usize = tools
+ .map(|specs| {
+ specs
+ .iter()
+ .map(|spec| serde_json::to_string(spec).map_or(0, |value| value.chars().count()))
+ .sum()
+ })
+ .unwrap_or(0);
+ let total_chars = message_chars.saturating_add(tool_chars);
+ let char_estimate = (total_chars as f64 / 4.0).ceil() as u64;
+ let framing_overhead = (messages.len() as u64).saturating_mul(6).saturating_add(64);
+ char_estimate.saturating_add(framing_overhead)
+}
+
+fn lookup_model_pricing(
+ prices: &HashMap<String, ModelPricing>,
+ provider: &str,
+ model: &str,
+) -> (f64, f64) {
+ let full_name = format!("{provider}/{model}");
+ if let Some(pricing) = prices.get(&full_name) {
+ return (pricing.input, pricing.output);
+ }
+ if let Some(pricing) = prices.get(model) {
+ return (pricing.input, pricing.output);
+ }
+ for (key, pricing) in prices {
+ let key_model = key.split('/').next_back().unwrap_or(key);
+ if model.starts_with(key_model) || key_model.starts_with(model) {
+ return (pricing.input, pricing.output);
+ }
+ let normalized_model = model.replace('-', ".");
+ let normalized_key = key_model.replace('-', ".");
+ if normalized_model.contains(&normalized_key) || normalized_key.contains(&normalized_model)
+ {
+ return (pricing.input, pricing.output);
+ }
+ }
+ (3.0, 15.0)
+}
+
+fn estimate_request_cost_usd(
+ context: &CostEnforcementContext,
+ provider: &str,
+ model: &str,
+ messages: &[ChatMessage],
+ tools: Option<&[crate::tools::ToolSpec]>,
+) -> f64 {
+ let reserve_multiplier = 1.0 + (f64::from(context.reserve_percent) / 100.0);
+ let input_tokens = estimate_prompt_tokens(messages, tools);
+ let output_tokens = (input_tokens / 4).max(256);
+ let input_tokens = ((input_tokens as f64) * reserve_multiplier).ceil() as u64;
+ let output_tokens = ((output_tokens as f64) * reserve_multiplier).ceil() as u64;
+
+ let (input_price, output_price) = lookup_model_pricing(&context.prices, provider, model);
+ let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price.max(0.0);
+ let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price.max(0.0);
+ input_cost + output_cost
+}
+
+fn usage_period_label(period: UsagePeriod) -> &'static str {
+ match period {
+ UsagePeriod::Session => "session",
+ UsagePeriod::Day => "daily",
+ UsagePeriod::Month => "monthly",
+ }
+}
+
+fn budget_exceeded_message(
+ model: &str,
+ estimated_cost_usd: f64,
+ current_usd: f64,
+ limit_usd: f64,
+ period: UsagePeriod,
+) -> String {
+ format!(
+ "Budget enforcement blocked request for model '{model}': projected cost (+${estimated_cost_usd:.4}) exceeds {period_label} limit (${limit_usd:.2}, current ${current_usd:.2}).",
+ period_label = usage_period_label(period)
+ )
+}
+
#[derive(Debug, Clone)]
struct ProgressEntry {
name: String,
@@ -778,36 +927,40 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
on_delta: Option>,
hooks: Option<&crate::hooks::HookRunner>,
excluded_tools: &[String],
+ progress_mode: ProgressMode,
+ safety_heartbeat: Option<SafetyHeartbeatConfig>,
) -> Result {
let reply_target = non_cli_approval_context
.as_ref()
.map(|ctx| ctx.reply_target.clone());
- SAFETY_HEARTBEAT_CONFIG
+ TOOL_LOOP_PROGRESS_MODE
.scope(
- safety_heartbeat,
- TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
- non_cli_approval_context,
- TOOL_LOOP_REPLY_TARGET.scope(
- reply_target,
- run_tool_call_loop(
- provider,
- history,
- tools_registry,
- observer,
- provider_name,
- model,
- temperature,
- silent,
- approval,
- channel_name,
- multimodal_config,
- max_tool_iterations,
- cancellation_token,
- on_delta,
- hooks,
- excluded_tools,
+ progress_mode,
+ SAFETY_HEARTBEAT_CONFIG.scope(
+ safety_heartbeat,
+ TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
+ non_cli_approval_context,
+ TOOL_LOOP_REPLY_TARGET.scope(
+ reply_target,
+ run_tool_call_loop(
+ provider,
+ history,
+ tools_registry,
+ observer,
+ provider_name,
+ model,
+ temperature,
+ silent,
+ approval,
+ channel_name,
+ multimodal_config,
+ max_tool_iterations,
+ cancellation_token,
+ on_delta,
+ hooks,
+ excluded_tools,
+ ),
),
),
),
@@ -890,7 +1043,12 @@ pub(crate) async fn run_tool_call_loop(
let progress_mode = TOOL_LOOP_PROGRESS_MODE
.try_with(|mode| *mode)
.unwrap_or(ProgressMode::Verbose);
+ let cost_enforcement_context = TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
+ .try_with(Clone::clone)
+ .ok()
+ .flatten();
let mut progress_tracker = ProgressTracker::default();
+ let mut active_model = model.to_string();
let bypass_non_cli_approval_for_turn =
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
if bypass_non_cli_approval_for_turn {
@@ -898,7 +1056,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_one_time_all_tools_consumed",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("consumed one-time non-cli allow-all approval token"),
@@ -950,6 +1108,13 @@ pub(crate) async fn run_tool_call_loop(
request_messages.push(ChatMessage::user(reminder));
}
}
+ // Unified path via Provider::chat so provider-specific native tool logic
+ // (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
+ let request_tools = if use_native_tools {
+ Some(tool_specs.as_slice())
+ } else {
+ None
+ };
// ── Progress: LLM thinking ────────────────────────────
if should_emit_verbose_progress(progress_mode) {
@@ -963,16 +1128,175 @@ pub(crate) async fn run_tool_call_loop(
}
}
+ if let Some(cost_ctx) = cost_enforcement_context.as_ref() {
+ let mut estimated_cost_usd = estimate_request_cost_usd(
+ cost_ctx,
+ provider_name,
+ active_model.as_str(),
+ &request_messages,
+ request_tools,
+ );
+
+ let mut budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
+ Ok(check) => Some(check),
+ Err(error) => {
+ tracing::warn!("Cost preflight check failed: {error}");
+ None
+ }
+ };
+
+ if matches!(cost_ctx.mode, CostEnforcementMode::RouteDown)
+ && matches!(budget_check, Some(BudgetCheck::Exceeded { .. }))
+ {
+ if let Some(route_down_model) = cost_ctx
+ .route_down_model
+ .as_deref()
+ .map(str::trim)
+ .filter(|value| !value.is_empty())
+ {
+ if route_down_model != active_model {
+ let previous_model = active_model.clone();
+ active_model = route_down_model.to_string();
+ estimated_cost_usd = estimate_request_cost_usd(
+ cost_ctx,
+ provider_name,
+ active_model.as_str(),
+ &request_messages,
+ request_tools,
+ );
+ budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
+ Ok(check) => Some(check),
+ Err(error) => {
+ tracing::warn!(
+ "Cost preflight check failed after route-down: {error}"
+ );
+ None
+ }
+ };
+ runtime_trace::record_event(
+ "cost_budget_route_down",
+ Some(channel_name),
+ Some(provider_name),
+ Some(active_model.as_str()),
+ Some(&turn_id),
+ Some(true),
+ Some("budget exceeded on primary model; route-down candidate applied"),
+ serde_json::json!({
+ "iteration": iteration + 1,
+ "from_model": previous_model,
+ "to_model": active_model,
+ "estimated_cost_usd": estimated_cost_usd,
+ }),
+ );
+ }
+ }
+ }
+
+ if let Some(check) = budget_check {
+ match check {
+ BudgetCheck::Allowed => {}
+ BudgetCheck::Warning {
+ current_usd,
+ limit_usd,
+ period,
+ } => {
+ tracing::warn!(
+ model = active_model.as_str(),
+ period = usage_period_label(period),
+ current_usd,
+ limit_usd,
+ estimated_cost_usd,
+ "Cost budget warning threshold reached"
+ );
+ runtime_trace::record_event(
+ "cost_budget_warning",
+ Some(channel_name),
+ Some(provider_name),
+ Some(active_model.as_str()),
+ Some(&turn_id),
+ Some(true),
+ Some("budget warning threshold reached"),
+ serde_json::json!({
+ "iteration": iteration + 1,
+ "period": usage_period_label(period),
+ "current_usd": current_usd,
+ "limit_usd": limit_usd,
+ "estimated_cost_usd": estimated_cost_usd,
+ }),
+ );
+ }
+ BudgetCheck::Exceeded {
+ current_usd,
+ limit_usd,
+ period,
+ } => match cost_ctx.mode {
+ CostEnforcementMode::Warn => {
+ tracing::warn!(
+ model = active_model.as_str(),
+ period = usage_period_label(period),
+ current_usd,
+ limit_usd,
+ estimated_cost_usd,
+ "Cost budget exceeded (warn mode): continuing request"
+ );
+ runtime_trace::record_event(
+ "cost_budget_exceeded_warn_mode",
+ Some(channel_name),
+ Some(provider_name),
+ Some(active_model.as_str()),
+ Some(&turn_id),
+ Some(true),
+ Some("budget exceeded but proceeding due to warn mode"),
+ serde_json::json!({
+ "iteration": iteration + 1,
+ "period": usage_period_label(period),
+ "current_usd": current_usd,
+ "limit_usd": limit_usd,
+ "estimated_cost_usd": estimated_cost_usd,
+ }),
+ );
+ }
+ CostEnforcementMode::RouteDown | CostEnforcementMode::Block => {
+ let message = budget_exceeded_message(
+ active_model.as_str(),
+ estimated_cost_usd,
+ current_usd,
+ limit_usd,
+ period,
+ );
+ runtime_trace::record_event(
+ "cost_budget_blocked",
+ Some(channel_name),
+ Some(provider_name),
+ Some(active_model.as_str()),
+ Some(&turn_id),
+ Some(false),
+ Some(&message),
+ serde_json::json!({
+ "iteration": iteration + 1,
+ "period": usage_period_label(period),
+ "current_usd": current_usd,
+ "limit_usd": limit_usd,
+ "estimated_cost_usd": estimated_cost_usd,
+ }),
+ );
+ return Err(anyhow::anyhow!(message));
+ }
+ },
+ }
+ }
+ }
+
observer.record_event(&ObserverEvent::LlmRequest {
provider: provider_name.to_string(),
- model: model.to_string(),
+ model: active_model.clone(),
messages_count: history.len(),
});
runtime_trace::record_event(
"llm_request",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@@ -986,23 +1310,15 @@ pub(crate) async fn run_tool_call_loop(
// Fire void hook before LLM call
if let Some(hooks) = hooks {
- hooks.fire_llm_input(history, model).await;
+ hooks.fire_llm_input(history, active_model.as_str()).await;
}
- // Unified path via Provider::chat so provider-specific native tool logic
- // (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
- let request_tools = if use_native_tools {
- Some(tool_specs.as_slice())
- } else {
- None
- };
-
let chat_future = provider.chat(
ChatRequest {
messages: &request_messages,
tools: request_tools,
},
- model,
+ active_model.as_str(),
temperature,
);
@@ -1032,7 +1348,7 @@ pub(crate) async fn run_tool_call_loop(
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
- model: model.to_string(),
+ model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
@@ -1062,7 +1378,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_parse_issue",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(parse_issue),
@@ -1080,7 +1396,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@@ -1131,7 +1447,7 @@ pub(crate) async fn run_tool_call_loop(
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
- model: model.to_string(),
+ model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(safe_error.clone()),
@@ -1142,7 +1458,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&safe_error),
@@ -1195,7 +1511,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_followthrough_retry",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("llm response implied follow-up action but emitted no tool call"),
@@ -1223,7 +1539,7 @@ pub(crate) async fn run_tool_call_loop(
"turn_final_response",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@@ -1299,7 +1615,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&cancelled),
@@ -1341,7 +1657,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&blocked),
@@ -1381,7 +1697,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_non_cli_session_grant",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("using runtime non-cli session approval grant"),
@@ -1438,7 +1754,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&denied),
@@ -1472,7 +1788,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&duplicate),
@@ -1500,7 +1816,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_start",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@@ -1560,7 +1876,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(outcome.success),
outcome.error_reason.as_deref(),
@@ -1672,7 +1988,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_warning",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop pattern detected, injecting self-correction prompt"),
@@ -1694,7 +2010,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_hard_stop",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop persisted after warning, stopping early"),
@@ -1714,7 +2030,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_loop_exhausted",
Some(channel_name),
Some(provider_name),
- Some(model),
+ Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("agent exceeded maximum tool iterations"),
@@ -2150,6 +2466,8 @@ pub async fn run(
// ── Execute ──────────────────────────────────────────────────
let start = Instant::now();
+ let cost_enforcement_context =
+ create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let mut final_output = String::new();
@@ -2196,8 +2514,9 @@ pub async fn run(
} else {
None
};
- let response = SAFETY_HEARTBEAT_CONFIG
- .scope(
+ let response = scope_cost_enforcement_context(
+ cost_enforcement_context.clone(),
+ SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@@ -2220,8 +2539,9 @@ pub async fn run(
&[],
),
),
- )
- .await?;
+ ),
+ )
+ .await?;
final_output = response.clone();
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let assistant_key = autosave_memory_key("assistant_resp");
@@ -2373,8 +2693,9 @@ pub async fn run(
} else {
None
};
- let response = match SAFETY_HEARTBEAT_CONFIG
- .scope(
+ let response = match scope_cost_enforcement_context(
+ cost_enforcement_context.clone(),
+ SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@@ -2397,8 +2718,9 @@ pub async fn run(
&[],
),
),
- )
- .await
+ ),
+ )
+ .await
{
Ok(resp) => resp,
Err(e) => {
@@ -2684,6 +3006,8 @@ pub async fn process_message_with_session(
ChatMessage::user(&enriched),
];
+ let cost_enforcement_context =
+ create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
@@ -2692,8 +3016,9 @@ pub async fn process_message_with_session(
} else {
None
};
- SAFETY_HEARTBEAT_CONFIG
- .scope(
+ scope_cost_enforcement_context(
+ cost_enforcement_context,
+ SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
agent_turn(
provider.as_ref(),
@@ -2707,8 +3032,9 @@ pub async fn process_message_with_session(
&config.multimodal,
config.agent.max_tool_iterations,
),
- )
- .await
+ ),
+ )
+ .await
}
#[cfg(test)]
@@ -3623,6 +3949,7 @@ mod tests {
None,
None,
&[],
+ ProgressMode::Verbose,
None,
)
.await
diff --git a/src/channels/mod.rs b/src/channels/mod.rs
index 64253e209..2ee2e03ad 100644
--- a/src/channels/mod.rs
+++ b/src/channels/mod.rs
@@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel;
use crate::agent::loop_::{
build_shell_policy_instructions, build_tool_instructions_from_specs,
- run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
+ run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext,
+ NonCliApprovalPrompt, SafetyHeartbeatConfig,
};
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
@@ -249,6 +250,7 @@ struct ChannelRuntimeDefaults {
api_key: Option<String>,
api_url: Option<String>,
reliability: crate::config::ReliabilityConfig,
+ cost: crate::config::CostConfig,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -1053,6 +1055,7 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
api_key: config.api_key.clone(),
api_url: config.api_url.clone(),
reliability: config.reliability.clone(),
+ cost: config.cost.clone(),
}
}
@@ -1098,6 +1101,7 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
api_key: ctx.api_key.clone(),
api_url: ctx.api_url.clone(),
reliability: (*ctx.reliability).clone(),
+ cost: crate::config::CostConfig::default(),
}
}
@@ -3664,29 +3668,78 @@ or tune thresholds in config.",
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
+ let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
+ &runtime_defaults.cost,
+ ctx.workspace_dir.as_path(),
+ );
+
+ let (approval_prompt_tx, mut approval_prompt_rx) =
+ tokio::sync::mpsc::unbounded_channel::<NonCliApprovalPrompt>();
+ let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() {
+ Some(NonCliApprovalContext {
+ sender: msg.sender.clone(),
+ reply_target: msg.reply_target.clone(),
+ prompt_tx: approval_prompt_tx,
+ })
+ } else {
+ drop(approval_prompt_tx);
+ None
+ };
+ let approval_prompt_dispatcher = if let (Some(channel_ref), true) =
+ (target_channel.as_ref(), non_cli_approval_context.is_some())
+ {
+ let channel = Arc::clone(channel_ref);
+ let reply_target = msg.reply_target.clone();
+ let thread_ts = msg.thread_ts.clone();
+ Some(tokio::spawn(async move {
+ while let Some(prompt) = approval_prompt_rx.recv().await {
+ if let Err(err) = channel
+ .send_approval_prompt(
+ &reply_target,
+ &prompt.request_id,
+ &prompt.tool_name,
+ &prompt.arguments,
+ thread_ts.clone(),
+ )
+ .await
+ {
+ tracing::warn!(
+ "Failed to send non-CLI approval prompt for request {}: {err}",
+ prompt.request_id
+ );
+ }
+ }
+ }))
+ } else {
+ None
+ };
let llm_result = tokio::select! {
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
- run_tool_call_loop_with_reply_target(
- active_provider.as_ref(),
- &mut history,
- ctx.tools_registry.as_ref(),
- ctx.observer.as_ref(),
- route.provider.as_str(),
- route.model.as_str(),
- runtime_defaults.temperature,
- true,
- Some(ctx.approval_manager.as_ref()),
- msg.channel.as_str(),
- Some(msg.reply_target.as_str()),
- &ctx.multimodal,
- ctx.max_tool_iterations,
- Some(cancellation_token.clone()),
- delta_tx,
- ctx.hooks.as_deref(),
- &excluded_tools_snapshot,
- progress_mode,
+ crate::agent::loop_::scope_cost_enforcement_context(
+ cost_enforcement_context,
+ run_tool_call_loop_with_non_cli_approval_context(
+ active_provider.as_ref(),
+ &mut history,
+ ctx.tools_registry.as_ref(),
+ ctx.observer.as_ref(),
+ route.provider.as_str(),
+ route.model.as_str(),
+ runtime_defaults.temperature,
+ true,
+ Some(ctx.approval_manager.as_ref()),
+ msg.channel.as_str(),
+ non_cli_approval_context,
+ &ctx.multimodal,
+ ctx.max_tool_iterations,
+ Some(cancellation_token.clone()),
+ delta_tx,
+ ctx.hooks.as_deref(),
+ &excluded_tools_snapshot,
+ progress_mode,
+ ctx.safety_heartbeat.clone(),
+ ),
),
) => LlmExecutionResult::Completed(result),
};
@@ -3694,6 +3747,9 @@ or tune thresholds in config.",
if let Some(handle) = draft_updater {
let _ = handle.await;
}
+ if let Some(handle) = approval_prompt_dispatcher {
+ let _ = handle.await;
+ }
if let Some(token) = typing_cancellation.as_ref() {
token.cancel();
@@ -7656,6 +7712,131 @@ BTC is currently around $65,000 based on latest tool output."#
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0);
}
+ #[tokio::test]
+ async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() {
+ let channel_impl = Arc::new(TelegramRecordingChannel::default());
+ let channel: Arc<dyn Channel> = channel_impl.clone();
+
+ let mut channels_by_name = HashMap::new();
+ channels_by_name.insert(channel.name().to_string(), channel);
+
+ let autonomy_cfg = crate::config::AutonomyConfig {
+ always_ask: vec!["mock_price".to_string()],
+ ..crate::config::AutonomyConfig::default()
+ };
+
+ let runtime_ctx = Arc::new(ChannelRuntimeContext {
+ channels_by_name: Arc::new(channels_by_name),
+ provider: Arc::new(ToolCallingProvider),
+ default_provider: Arc::new("test-provider".to_string()),
+ memory: Arc::new(NoopMemory),
+ tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
+ observer: Arc::new(NoopObserver),
+ system_prompt: Arc::new("test-system-prompt".to_string()),
+ model: Arc::new("test-model".to_string()),
+ temperature: 0.0,
+ auto_save_memory: false,
+ max_tool_iterations: 10,
+ min_relevance_score: 0.0,
+ conversation_histories: Arc::new(Mutex::new(HashMap::new())),
+ conversation_locks: Default::default(),
+ session_config: crate::config::AgentSessionConfig::default(),
+ session_manager: None,
+ provider_cache: Arc::new(Mutex::new(HashMap::new())),
+ route_overrides: Arc::new(Mutex::new(HashMap::new())),
+ api_key: None,
+ api_url: None,
+ reliability: Arc::new(crate::config::ReliabilityConfig::default()),
+ provider_runtime_options: providers::ProviderRuntimeOptions::default(),
+ workspace_dir: Arc::new(std::env::temp_dir()),
+ message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
+ interrupt_on_new_message: false,
+ multimodal: crate::config::MultimodalConfig::default(),
+ hooks: None,
+ non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
+ query_classification: crate::config::QueryClassificationConfig::default(),
+ model_routes: Vec::new(),
+ approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)),
+ safety_heartbeat: None,
+ startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
+ });
+
+ let runtime_ctx_for_first_turn = runtime_ctx.clone();
+ let first_turn = tokio::spawn(async move {
+ process_channel_message(
+ runtime_ctx_for_first_turn,
+ traits::ChannelMessage {
+ id: "msg-non-cli-approval-1".to_string(),
+ sender: "alice".to_string(),
+ reply_target: "chat-1".to_string(),
+ content: "What is the BTC price now?".to_string(),
+ channel: "telegram".to_string(),
+ timestamp: 1,
+ thread_ts: None,
+ },
+ CancellationToken::new(),
+ )
+ .await;
+ });
+
+ let request_id = tokio::time::timeout(Duration::from_secs(2), async {
+ loop {
+ let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests(
+ Some("alice"),
+ Some("telegram"),
+ Some("chat-1"),
+ );
+ if let Some(req) = pending.first() {
+ break req.request_id.clone();
+ }
+ tokio::time::sleep(Duration::from_millis(20)).await;
+ }
+ })
+ .await
+ .expect("pending approval request should be created for always_ask tool");
+
+ process_channel_message(
+ runtime_ctx.clone(),
+ traits::ChannelMessage {
+ id: "msg-non-cli-approval-2".to_string(),
+ sender: "alice".to_string(),
+ reply_target: "chat-1".to_string(),
+ content: format!("/approve-allow {request_id}"),
+ channel: "telegram".to_string(),
+ timestamp: 2,
+ thread_ts: None,
+ },
+ CancellationToken::new(),
+ )
+ .await;
+
+ tokio::time::timeout(Duration::from_secs(5), first_turn)
+ .await
+ .expect("first channel turn should finish after approval")
+ .expect("first channel turn task should not panic");
+
+ let sent = channel_impl.sent_messages.lock().await;
+ assert!(
+ sent.iter()
+ .any(|entry| entry.contains("Approval required for tool `mock_price`")),
+ "channel should emit non-cli approval prompt"
+ );
+ assert!(
+ sent.iter()
+ .any(|entry| entry.contains("Approved supervised execution for `mock_price`")),
+ "channel should acknowledge explicit approval command"
+ );
+ assert!(
+ sent.iter()
+ .any(|entry| entry.contains("BTC is currently around")),
+ "tool call should execute after approval and produce final response"
+ );
+ assert!(
+ sent.iter().all(|entry| !entry.contains("Denied by user.")),
+ "always_ask tool should not be silently denied once non-cli approval prompt path is wired"
+ );
+ }
+
#[tokio::test]
async fn process_channel_message_denies_approval_management_for_unlisted_sender() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
@@ -9232,6 +9413,7 @@ BTC is currently around $65,000 based on latest tool output."#
api_key: None,
api_url: None,
reliability: crate::config::ReliabilityConfig::default(),
+ cost: crate::config::CostConfig::default(),
},
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),
diff --git a/src/config/schema.rs b/src/config/schema.rs
index d4d971d84..c99c31779 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -1026,6 +1026,11 @@ pub struct SkillsConfig {
/// If unset, defaults to `$HOME/open-skills` when enabled.
#[serde(default)]
pub open_skills_dir: Option<String>,
+ /// Optional allowlist of canonical directory roots for workspace skill symlink targets.
+ /// Symlinked workspace skills are rejected unless their resolved targets are under one
+ /// of these roots. Accepts absolute paths and `~/` home-relative paths.
+ #[serde(default)]
+ pub trusted_skill_roots: Vec<String>,
/// Allow script-like files in skills (`.sh`, `.bash`, `.ps1`, shebang shell files).
/// Default: `false` (secure by default).
#[serde(default)]
@@ -1195,6 +1200,58 @@ pub struct CostConfig {
/// Per-model pricing (USD per 1M tokens)
#[serde(default)]
pub prices: std::collections::HashMap,
+
+ /// Runtime budget enforcement policy (`[cost.enforcement]`).
+ #[serde(default)]
+ pub enforcement: CostEnforcementConfig,
+}
+
+/// Budget enforcement behavior when projected spend approaches/exceeds limits.
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum CostEnforcementMode {
+ /// Log warnings only; never block the request.
+ Warn,
+ /// Attempt one downgrade to a cheaper route/model, then block if still over budget.
+ RouteDown,
+ /// Block immediately when projected spend exceeds configured limits.
+ Block,
+}
+
+fn default_cost_enforcement_mode() -> CostEnforcementMode {
+ CostEnforcementMode::Warn
+}
+
+/// Runtime budget enforcement controls (`[cost.enforcement]`).
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct CostEnforcementConfig {
+ /// Enforcement behavior. Default: `warn`.
+ #[serde(default = "default_cost_enforcement_mode")]
+ pub mode: CostEnforcementMode,
+ /// Optional fallback model (or `hint:*`) when `mode = "route_down"`.
+ #[serde(default = "default_route_down_model")]
+ pub route_down_model: Option<String>,
+ /// Extra reserve added to token/cost estimates (percentage, 0-100). Default: `10`.
+ #[serde(default = "default_cost_reserve_percent")]
+ pub reserve_percent: u8,
+}
+
+fn default_route_down_model() -> Option<String> {
+ Some("hint:fast".to_string())
+}
+
+fn default_cost_reserve_percent() -> u8 {
+ 10
+}
+
+impl Default for CostEnforcementConfig {
+ fn default() -> Self {
+ Self {
+ mode: default_cost_enforcement_mode(),
+ route_down_model: default_route_down_model(),
+ reserve_percent: default_cost_reserve_percent(),
+ }
+ }
}
/// Per-model pricing entry (USD per 1M tokens).
@@ -1230,6 +1287,7 @@ impl Default for CostConfig {
warn_at_percent: default_warn_percent(),
allow_override: false,
prices: get_default_pricing(),
+ enforcement: CostEnforcementConfig::default(),
}
}
}
@@ -7764,6 +7822,44 @@ impl Config {
anyhow::bail!("web_search.timeout_secs must be greater than 0");
}
+ // Cost
+ if self.cost.warn_at_percent > 100 {
+ anyhow::bail!("cost.warn_at_percent must be between 0 and 100");
+ }
+ if self.cost.enforcement.reserve_percent > 100 {
+ anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100");
+ }
+ if matches!(self.cost.enforcement.mode, CostEnforcementMode::RouteDown) {
+ let route_down_model = self
+ .cost
+ .enforcement
+ .route_down_model
+ .as_deref()
+ .map(str::trim)
+ .filter(|model| !model.is_empty())
+ .ok_or_else(|| {
+ anyhow::anyhow!(
+ "cost.enforcement.route_down_model must be set when mode is route_down"
+ )
+ })?;
+
+ if let Some(route_hint) = route_down_model
+ .strip_prefix("hint:")
+ .map(str::trim)
+ .filter(|hint| !hint.is_empty())
+ {
+ if !self
+ .model_routes
+ .iter()
+ .any(|route| route.hint.trim() == route_hint)
+ {
+ anyhow::bail!(
+ "cost.enforcement.route_down_model uses hint '{route_hint}', but no matching [[model_routes]] entry exists"
+ );
+ }
+ }
+ }
+
// Scheduler
if self.scheduler.max_concurrent == 0 {
anyhow::bail!("scheduler.max_concurrent must be greater than 0");
@@ -13738,4 +13834,80 @@ sensitivity = 0.9
.validate()
.expect("disabled coordination should allow empty lead agent");
}
+
+ #[test]
+ fn cost_enforcement_defaults_are_stable() {
+ let cost = CostConfig::default();
+ assert_eq!(cost.enforcement.mode, CostEnforcementMode::Warn);
+ assert_eq!(
+ cost.enforcement.route_down_model.as_deref(),
+ Some("hint:fast")
+ );
+ assert_eq!(cost.enforcement.reserve_percent, 10);
+ }
+
+ #[test]
+ fn cost_enforcement_config_parses_route_down_mode() {
+ let parsed: CostConfig = toml::from_str(
+ r#"
+enabled = true
+
+[enforcement]
+mode = "route_down"
+route_down_model = "hint:fast"
+reserve_percent = 15
+"#,
+ )
+ .expect("cost enforcement should parse");
+
+ assert!(parsed.enabled);
+ assert_eq!(parsed.enforcement.mode, CostEnforcementMode::RouteDown);
+ assert_eq!(
+ parsed.enforcement.route_down_model.as_deref(),
+ Some("hint:fast")
+ );
+ assert_eq!(parsed.enforcement.reserve_percent, 15);
+ }
+
+ #[test]
+ fn validation_rejects_cost_enforcement_reserve_over_100() {
+ let mut config = Config::default();
+ config.cost.enforcement.reserve_percent = 150;
+ let err = config
+ .validate()
+ .expect_err("expected cost.enforcement.reserve_percent validation failure");
+ assert!(err.to_string().contains("cost.enforcement.reserve_percent"));
+ }
+
+ #[test]
+ fn validation_rejects_route_down_hint_without_matching_route() {
+ let mut config = Config::default();
+ config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
+ config.cost.enforcement.route_down_model = Some("hint:fast".to_string());
+ let err = config
+ .validate()
+ .expect_err("route_down hint should require a matching model route");
+ assert!(err
+ .to_string()
+ .contains("cost.enforcement.route_down_model uses hint 'fast'"));
+ }
+
+ #[test]
+ fn validation_accepts_route_down_hint_with_matching_route() {
+ let mut config = Config::default();
+ config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
+ config.cost.enforcement.route_down_model = Some("hint:fast".to_string());
+ config.model_routes = vec![ModelRouteConfig {
+ hint: "fast".to_string(),
+ provider: "openrouter".to_string(),
+ model: "openai/gpt-4.1-mini".to_string(),
+ api_key: None,
+ max_tokens: None,
+ transport: None,
+ }];
+
+ config
+ .validate()
+ .expect("matching route_down hint route should validate");
+ }
}
diff --git a/src/main.rs b/src/main.rs
index 913ed6139..978235848 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -333,15 +333,20 @@ the binary location.
Examples:
zeroclaw update # Update to latest version
zeroclaw update --check # Check for updates without installing
+ zeroclaw update --instructions # Show install-method-specific update instructions
zeroclaw update --force # Reinstall even if already up to date")]
Update {
/// Check for updates without installing
- #[arg(long)]
+ #[arg(long, conflicts_with_all = ["force", "instructions"])]
check: bool,
/// Force update even if already at latest version
- #[arg(long)]
+ #[arg(long, conflicts_with = "instructions")]
force: bool,
+
+ /// Show human-friendly update instructions for your installation method
+ #[arg(long, conflicts_with_all = ["check", "force"])]
+ instructions: bool,
},
/// Engage, inspect, and resume emergency-stop states.
@@ -1107,9 +1112,18 @@ async fn main() -> Result<()> {
Ok(())
}
- Commands::Update { check, force } => {
- update::self_update(force, check).await?;
- Ok(())
+ Commands::Update {
+ check,
+ force,
+ instructions,
+ } => {
+ if instructions {
+ update::print_update_instructions()?;
+ Ok(())
+ } else {
+ update::self_update(force, check).await?;
+ Ok(())
+ }
}
Commands::Estop {
@@ -2630,4 +2644,41 @@ mod tests {
);
assert_eq!(payload["nested"]["non_secret"], serde_json::json!("ok"));
}
+
+ #[test]
+ fn update_help_mentions_instructions_flag() {
+ let cmd = Cli::command();
+ let update_cmd = cmd
+ .get_subcommands()
+ .find(|subcommand| subcommand.get_name() == "update")
+ .expect("update subcommand must exist");
+
+ let mut output = Vec::new();
+ update_cmd
+ .clone()
+ .write_long_help(&mut output)
+ .expect("help generation should succeed");
+ let help = String::from_utf8(output).expect("help output should be utf-8");
+
+ assert!(help.contains("--instructions"));
+ }
+
+ #[test]
+ fn update_cli_parses_instructions_flag() {
+ let cli = Cli::try_parse_from(["zeroclaw", "update", "--instructions"])
+ .expect("update --instructions should parse");
+
+ match cli.command {
+ Commands::Update {
+ check,
+ force,
+ instructions,
+ } => {
+ assert!(!check);
+ assert!(!force);
+ assert!(instructions);
+ }
+ other => panic!("expected update command, got {other:?}"),
+ }
+ }
}
diff --git a/src/memory/backend.rs b/src/memory/backend.rs
index c6759fbe8..231f6af4b 100644
--- a/src/memory/backend.rs
+++ b/src/memory/backend.rs
@@ -103,8 +103,9 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: false,
};
-const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [
+const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 6] = [
SQLITE_PROFILE,
+ SQLITE_QDRANT_HYBRID_PROFILE,
LUCID_PROFILE,
CORTEX_MEM_PROFILE,
MARKDOWN_PROFILE,
@@ -194,12 +195,13 @@ mod tests {
#[test]
fn selectable_backends_are_ordered_for_onboarding() {
let backends = selectable_memory_backends();
- assert_eq!(backends.len(), 5);
+ assert_eq!(backends.len(), 6);
assert_eq!(backends[0].key, "sqlite");
- assert_eq!(backends[1].key, "lucid");
- assert_eq!(backends[2].key, "cortex-mem");
- assert_eq!(backends[3].key, "markdown");
- assert_eq!(backends[4].key, "none");
+ assert_eq!(backends[1].key, "sqlite_qdrant_hybrid");
+ assert_eq!(backends[2].key, "lucid");
+ assert_eq!(backends[3].key, "cortex-mem");
+ assert_eq!(backends[4].key, "markdown");
+ assert_eq!(backends[5].key, "none");
}
#[test]
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index f92359b87..5954227fb 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -4091,9 +4091,68 @@ fn setup_memory() -> Result<MemoryConfig> {
let mut config = memory_config_defaults_for_backend(backend);
config.auto_save = auto_save;
+
+ if classify_memory_backend(backend) == MemoryBackendKind::SqliteQdrantHybrid {
+ configure_hybrid_qdrant_memory(&mut config)?;
+ }
+
Ok(config)
}
+fn configure_hybrid_qdrant_memory(config: &mut MemoryConfig) -> Result<()> {
+ print_bullet("Hybrid memory keeps local SQLite metadata and uses Qdrant for semantic ranking.");
+ print_bullet("SQLite storage path stays at the default workspace database.");
+
+ let qdrant_url_default = config
+ .qdrant
+ .url
+ .clone()
+ .unwrap_or_else(|| "http://localhost:6333".to_string());
+ let qdrant_url: String = Input::new()
+ .with_prompt(" Qdrant URL")
+ .default(qdrant_url_default)
+ .interact_text()?;
+ let qdrant_url = qdrant_url.trim();
+ if qdrant_url.is_empty() {
+ bail!("Qdrant URL is required for sqlite_qdrant_hybrid backend");
+ }
+ config.qdrant.url = Some(qdrant_url.to_string());
+
+ let qdrant_collection: String = Input::new()
+ .with_prompt(" Qdrant collection")
+ .default(config.qdrant.collection.clone())
+ .interact_text()?;
+ let qdrant_collection = qdrant_collection.trim();
+ if !qdrant_collection.is_empty() {
+ config.qdrant.collection = qdrant_collection.to_string();
+ }
+
+ let qdrant_api_key: String = Input::new()
+ .with_prompt(" Qdrant API key (optional, Enter to skip)")
+ .allow_empty(true)
+ .interact_text()?;
+ let qdrant_api_key = qdrant_api_key.trim();
+ config.qdrant.api_key = if qdrant_api_key.is_empty() {
+ None
+ } else {
+ Some(qdrant_api_key.to_string())
+ };
+
+ println!(
+ " {} Qdrant: {} (collection: {}, api key: {})",
+ style("✓").green().bold(),
+ style(config.qdrant.url.as_deref().unwrap_or_default()).green(),
+ style(&config.qdrant.collection).green(),
+ if config.qdrant.api_key.is_some() {
+ style("set").green().to_string()
+ } else {
+ style("not set").dim().to_string()
+ }
+ );
+
+ Ok(())
+}
+
fn setup_identity_backend() -> Result {
print_bullet("Choose the identity format ZeroClaw should scaffold for this workspace.");
print_bullet("You can switch later in config.toml under [identity].");
@@ -8515,10 +8574,11 @@ mod tests {
#[test]
fn backend_key_from_choice_maps_supported_backends() {
assert_eq!(backend_key_from_choice(0), "sqlite");
- assert_eq!(backend_key_from_choice(1), "lucid");
- assert_eq!(backend_key_from_choice(2), "cortex-mem");
- assert_eq!(backend_key_from_choice(3), "markdown");
- assert_eq!(backend_key_from_choice(4), "none");
+ assert_eq!(backend_key_from_choice(1), "sqlite_qdrant_hybrid");
+ assert_eq!(backend_key_from_choice(2), "lucid");
+ assert_eq!(backend_key_from_choice(3), "cortex-mem");
+ assert_eq!(backend_key_from_choice(4), "markdown");
+ assert_eq!(backend_key_from_choice(5), "none");
assert_eq!(backend_key_from_choice(999), "sqlite");
}
@@ -8560,6 +8620,18 @@ mod tests {
assert_eq!(config.embedding_cache_size, 10000);
}
+ #[test]
+ fn memory_config_defaults_for_hybrid_enable_sqlite_hygiene() {
+ let config = memory_config_defaults_for_backend("sqlite_qdrant_hybrid");
+ assert_eq!(config.backend, "sqlite_qdrant_hybrid");
+ assert!(config.auto_save);
+ assert!(config.hygiene_enabled);
+ assert_eq!(config.archive_after_days, 7);
+ assert_eq!(config.purge_after_days, 30);
+ assert_eq!(config.embedding_cache_size, 10000);
+ assert_eq!(config.qdrant.collection, "zeroclaw_memories");
+ }
+
#[test]
fn memory_config_defaults_for_none_disable_sqlite_hygiene() {
let config = memory_config_defaults_for_backend("none");
diff --git a/src/skills/mod.rs b/src/skills/mod.rs
index 82d467084..1982f7e91 100644
--- a/src/skills/mod.rs
+++ b/src/skills/mod.rs
@@ -80,7 +80,7 @@ fn default_version() -> String {
/// Load all skills from the workspace skills directory
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
- load_skills_with_open_skills_config(workspace_dir, None, None, None)
+ load_skills_with_open_skills_config(workspace_dir, None, None, None, None)
}
/// Load skills using runtime config values (preferred at runtime).
@@ -90,6 +90,7 @@ pub fn load_skills_with_config(workspace_dir: &Path, config: &crate::config::Con
Some(config.skills.open_skills_enabled),
config.skills.open_skills_dir.as_deref(),
Some(config.skills.allow_scripts),
+ Some(&config.skills.trusted_skill_roots),
)
}
@@ -98,9 +99,12 @@ fn load_skills_with_open_skills_config(
config_open_skills_enabled: Option<bool>,
config_open_skills_dir: Option<&str>,
config_allow_scripts: Option<bool>,
+ config_trusted_skill_roots: Option<&[String]>,
) -> Vec<Skill> {
let mut skills = Vec::new();
let allow_scripts = config_allow_scripts.unwrap_or(false);
+ let trusted_skill_roots =
+ resolve_trusted_skill_roots(workspace_dir, config_trusted_skill_roots.unwrap_or(&[]));
if let Some(open_skills_dir) =
ensure_open_skills_repo(config_open_skills_enabled, config_open_skills_dir)
@@ -108,16 +112,113 @@ fn load_skills_with_open_skills_config(
skills.extend(load_open_skills(&open_skills_dir, allow_scripts));
}
- skills.extend(load_workspace_skills(workspace_dir, allow_scripts));
+ skills.extend(load_workspace_skills(
+ workspace_dir,
+ allow_scripts,
+ &trusted_skill_roots,
+ ));
skills
}
-fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
+fn load_workspace_skills(
+ workspace_dir: &Path,
+ allow_scripts: bool,
+ trusted_skill_roots: &[PathBuf],
+) -> Vec<Skill> {
let skills_dir = workspace_dir.join("skills");
- load_skills_from_directory(&skills_dir, allow_scripts)
+ load_skills_from_directory(&skills_dir, allow_scripts, trusted_skill_roots)
}
-fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
+fn resolve_trusted_skill_roots(workspace_dir: &Path, raw_roots: &[String]) -> Vec<PathBuf> {
+ let home_dir = UserDirs::new().map(|dirs| dirs.home_dir().to_path_buf());
+ let mut resolved = Vec::new();
+
+ for raw in raw_roots {
+ let trimmed = raw.trim();
+ if trimmed.is_empty() {
+ continue;
+ }
+
+ let expanded = if trimmed == "~" {
+ home_dir.clone().unwrap_or_else(|| PathBuf::from(trimmed))
+ } else if let Some(rest) = trimmed
+ .strip_prefix("~/")
+ .or_else(|| trimmed.strip_prefix("~\\"))
+ {
+ home_dir
+ .as_ref()
+ .map(|home| home.join(rest))
+ .unwrap_or_else(|| PathBuf::from(trimmed))
+ } else {
+ PathBuf::from(trimmed)
+ };
+
+ let candidate = if expanded.is_relative() {
+ workspace_dir.join(expanded)
+ } else {
+ expanded
+ };
+
+ match candidate.canonicalize() {
+ Ok(canonical) if canonical.is_dir() => resolved.push(canonical),
+ Ok(canonical) => tracing::warn!(
+ "ignoring [skills].trusted_skill_roots entry '{}': canonical path is not a directory ({})",
+ trimmed,
+ canonical.display()
+ ),
+ Err(err) => tracing::warn!(
+ "ignoring [skills].trusted_skill_roots entry '{}': failed to canonicalize {} ({err})",
+ trimmed,
+ candidate.display()
+ ),
+ }
+ }
+
+ resolved.sort();
+ resolved.dedup();
+ resolved
+}
+
+fn enforce_workspace_skill_symlink_trust(
+ path: &Path,
+ trusted_skill_roots: &[PathBuf],
+) -> Result<()> {
+ let canonical_target = path
+ .canonicalize()
+ .with_context(|| format!("failed to resolve skill symlink target {}", path.display()))?;
+
+ if !canonical_target.is_dir() {
+ anyhow::bail!(
+ "symlink target is not a directory: {}",
+ canonical_target.display()
+ );
+ }
+
+ if trusted_skill_roots
+ .iter()
+ .any(|root| canonical_target.starts_with(root))
+ {
+ return Ok(());
+ }
+
+ if trusted_skill_roots.is_empty() {
+ anyhow::bail!(
+ "symlink target {} is not allowed because [skills].trusted_skill_roots is empty",
+ canonical_target.display()
+ );
+ }
+
+ anyhow::bail!(
+ "symlink target {} is outside configured [skills].trusted_skill_roots",
+ canonical_target.display()
+ );
+}
+
+fn load_skills_from_directory(
+ skills_dir: &Path,
+ allow_scripts: bool,
+ trusted_skill_roots: &[PathBuf],
+) -> Vec<Skill> {
if !skills_dir.exists() {
return Vec::new();
}
@@ -130,7 +231,26 @@ fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
+        let metadata = match std::fs::symlink_metadata(&path) {
+            Ok(meta) => meta,
+ Err(err) => {
+ tracing::warn!(
+ "skipping skill entry {}: failed to read metadata ({err})",
+ path.display()
+ );
+ continue;
+ }
+ };
+
+ if metadata.file_type().is_symlink() {
+ if let Err(err) = enforce_workspace_skill_symlink_trust(&path, trusted_skill_roots) {
+ tracing::warn!(
+ "skipping untrusted symlinked skill entry {}: {err}",
+ path.display()
+ );
+ continue;
+ }
+ } else if !metadata.is_dir() {
continue;
}
@@ -180,7 +300,7 @@ fn load_open_skills(repo_dir: &Path, allow_scripts: bool) -> Vec {
// as executable skills.
let nested_skills_dir = repo_dir.join("skills");
if nested_skills_dir.is_dir() {
- return load_skills_from_directory(&nested_skills_dir, allow_scripts);
+ return load_skills_from_directory(&nested_skills_dir, allow_scripts, &[]);
}
let mut skills = Vec::new();
@@ -2137,6 +2257,20 @@ pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Con
anyhow::bail!("Skill source or installed skill not found: {source}");
}
+ let trusted_skill_roots =
+ resolve_trusted_skill_roots(workspace_dir, &config.skills.trusted_skill_roots);
+ if let Ok(metadata) = std::fs::symlink_metadata(&target) {
+ if metadata.file_type().is_symlink() {
+ enforce_workspace_skill_symlink_trust(&target, &trusted_skill_roots)
+ .with_context(|| {
+ format!(
+ "trusted-symlink policy rejected audit target {}",
+ target.display()
+ )
+ })?;
+ }
+ }
+
let report = audit::audit_skill_directory_with_options(
&target,
audit::SkillAuditOptions {
diff --git a/src/skills/symlink_tests.rs b/src/skills/symlink_tests.rs
index da50891a4..b7bcb726a 100644
--- a/src/skills/symlink_tests.rs
+++ b/src/skills/symlink_tests.rs
@@ -1,6 +1,8 @@
#[cfg(test)]
mod tests {
- use crate::skills::skills_dir;
+ use crate::config::Config;
+ use crate::skills::{handle_command, load_skills_with_config, skills_dir};
+ use crate::SkillCommands;
use std::path::Path;
use tempfile::TempDir;
@@ -83,7 +85,7 @@ mod tests {
}
#[tokio::test]
- async fn test_skills_symlink_permissions_and_safety() {
+ async fn test_workspace_symlink_loading_requires_trusted_roots() {
let tmp = TempDir::new().unwrap();
let workspace_dir = tmp.path().join("workspace");
tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
@@ -93,7 +95,6 @@ mod tests {
#[cfg(unix)]
{
- // Test case: Symlink outside workspace should be allowed (user responsibility)
let outside_dir = tmp.path().join("outside_skill");
tokio::fs::create_dir_all(&outside_dir).await.unwrap();
tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
@@ -102,15 +103,74 @@ mod tests {
let dest_link = skills_path.join("outside_skill");
let result = std::os::unix::fs::symlink(&outside_dir, &dest_link);
+ assert!(result.is_ok(), "symlink creation should succeed on unix");
+
+ let mut config = Config::default();
+ config.workspace_dir = workspace_dir.clone();
+ config.config_path = workspace_dir.join("config.toml");
+
+ let blocked = load_skills_with_config(&workspace_dir, &config);
assert!(
- result.is_ok(),
- "Should allow symlinking to directories outside workspace"
+ blocked.is_empty(),
+ "symlinked skill should be rejected when trusted_skill_roots is empty"
);
- // Should still be readable
- let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await;
- assert!(content.is_ok());
- assert!(content.unwrap().contains("Outside Skill"));
+ config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
+ let allowed = load_skills_with_config(&workspace_dir, &config);
+ assert_eq!(
+ allowed.len(),
+ 1,
+ "symlinked skill should load when target is inside trusted roots"
+ );
+ assert_eq!(allowed[0].name, "outside_skill");
+ }
+ }
+
+ #[tokio::test]
+ async fn test_skills_audit_respects_trusted_symlink_roots() {
+ let tmp = TempDir::new().unwrap();
+ let workspace_dir = tmp.path().join("workspace");
+ tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
+
+ let skills_path = skills_dir(&workspace_dir);
+ tokio::fs::create_dir_all(&skills_path).await.unwrap();
+
+ #[cfg(unix)]
+ {
+ let outside_dir = tmp.path().join("outside_skill");
+ tokio::fs::create_dir_all(&outside_dir).await.unwrap();
+ tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
+ .await
+ .unwrap();
+ let link_path = skills_path.join("outside_skill");
+ std::os::unix::fs::symlink(&outside_dir, &link_path).unwrap();
+
+ let mut config = Config::default();
+ config.workspace_dir = workspace_dir.clone();
+ config.config_path = workspace_dir.join("config.toml");
+
+ let blocked = handle_command(
+ SkillCommands::Audit {
+ source: "outside_skill".to_string(),
+ },
+ &config,
+ );
+ assert!(
+ blocked.is_err(),
+ "audit should reject symlink target when trusted roots are not configured"
+ );
+
+ config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
+ let allowed = handle_command(
+ SkillCommands::Audit {
+ source: "outside_skill".to_string(),
+ },
+ &config,
+ );
+ assert!(
+ allowed.is_ok(),
+ "audit should pass when symlink target is inside a trusted root"
+ );
}
}
}
diff --git a/src/tools/mcp_transport.rs b/src/tools/mcp_transport.rs
index 61052a343..27398451c 100644
--- a/src/tools/mcp_transport.rs
+++ b/src/tools/mcp_transport.rs
@@ -107,12 +107,27 @@ impl McpTransportConn for StdioTransport {
error: None,
});
}
- let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
- .await
- .context("timeout waiting for MCP response")??;
- let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
- .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
- Ok(resp)
+ let deadline = std::time::Instant::now() + Duration::from_secs(RECV_TIMEOUT_SECS);
+ loop {
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ bail!("timeout waiting for MCP response");
+ }
+ let resp_line = timeout(remaining, self.recv_raw())
+ .await
+ .context("timeout waiting for MCP response")??;
+ let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
+ .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
+ if resp.id.is_none() {
+ // Server-sent notification (e.g. `notifications/initialized`) — skip and
+ // keep waiting for the actual response to our request.
+ tracing::debug!(
+ "MCP stdio: skipping server notification while waiting for response"
+ );
+ continue;
+ }
+ return Ok(resp);
+ }
}
async fn close(&mut self) -> Result<()> {
diff --git a/src/tools/mod.rs b/src/tools/mod.rs
index a95244de7..7a4a0d070 100644
--- a/src/tools/mod.rs
+++ b/src/tools/mod.rs
@@ -82,6 +82,7 @@ pub mod web_access_config;
pub mod web_fetch;
pub mod web_search_config;
pub mod web_search_tool;
+pub mod xlsx_read;
pub use apply_patch::ApplyPatchTool;
pub use bg_run::{
@@ -147,6 +148,7 @@ pub use web_access_config::WebAccessConfigTool;
pub use web_fetch::WebFetchTool;
pub use web_search_config::WebSearchConfigTool;
pub use web_search_tool::WebSearchTool;
+pub use xlsx_read::XlsxReadTool;
pub use auth_profile::ManageAuthProfileTool;
pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool};
@@ -511,6 +513,9 @@ pub fn all_tools_with_runtime(
// PPTX text extraction
tool_arcs.push(Arc::new(PptxReadTool::new(security.clone())));
+ // XLSX text extraction
+ tool_arcs.push(Arc::new(XlsxReadTool::new(security.clone())));
+
// Vision tools are always available
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));
diff --git a/src/tools/xlsx_read.rs b/src/tools/xlsx_read.rs
new file mode 100644
index 000000000..655bf112f
--- /dev/null
+++ b/src/tools/xlsx_read.rs
@@ -0,0 +1,1177 @@
+use super::traits::{Tool, ToolResult};
+use crate::security::SecurityPolicy;
+use async_trait::async_trait;
+use serde_json::json;
+use std::collections::HashMap;
+use std::path::{Component, Path};
+use std::sync::Arc;
+
+/// Maximum XLSX file size (50 MB).
+const MAX_XLSX_BYTES: u64 = 50 * 1024 * 1024;
+/// Default character limit returned to the LLM.
+const DEFAULT_MAX_CHARS: usize = 50_000;
+/// Hard ceiling regardless of what the caller requests.
+const MAX_OUTPUT_CHARS: usize = 200_000;
+/// Upper bound for total uncompressed XML read from sheet files.
+const MAX_TOTAL_SHEET_XML_BYTES: u64 = 16 * 1024 * 1024;
+
+/// Extract plain text from an XLSX file in the workspace.
+pub struct XlsxReadTool {
+ security: Arc<SecurityPolicy>,
+}
+
+impl XlsxReadTool {
+ pub fn new(security: Arc<SecurityPolicy>) -> Self {
+ Self { security }
+ }
+}
+
+/// Extract plain text from XLSX bytes.
+///
+/// XLSX is a ZIP archive containing `xl/worksheets/sheet*.xml` with cell data,
+/// `xl/sharedStrings.xml` with a string pool, and `xl/workbook.xml` with sheet
+/// names. Text cells reference the shared string pool by index; inline and
+ numeric values are taken directly from `<v>` elements.
+fn extract_xlsx_text(bytes: &[u8]) -> anyhow::Result<String> {
+ extract_xlsx_text_with_limits(bytes, MAX_TOTAL_SHEET_XML_BYTES)
+}
+
+fn extract_xlsx_text_with_limits(
+ bytes: &[u8],
+ max_total_sheet_xml_bytes: u64,
+) -> anyhow::Result<String> {
+ use std::io::Read;
+
+ let cursor = std::io::Cursor::new(bytes);
+ let mut archive = zip::ZipArchive::new(cursor)?;
+
+ // 1. Parse shared strings table.
+ let shared_strings = parse_shared_strings(&mut archive)?;
+
+ // 2. Parse workbook.xml to get sheet names and rIds.
+ let sheet_entries = parse_workbook_sheets(&mut archive)?;
+
+ // 3. Parse workbook.xml.rels to map rId → Target path.
+ let rel_targets = parse_workbook_rels(&mut archive)?;
+
+ // 4. Build ordered list of (sheet_name, file_path) pairs.
+ let mut ordered_sheets: Vec<(String, String)> = Vec::new();
+ for (sheet_name, r_id) in &sheet_entries {
+ if let Some(target) = rel_targets.get(r_id) {
+ if let Some(normalized) = normalize_sheet_target(target) {
+ ordered_sheets.push((sheet_name.clone(), normalized));
+ }
+ }
+ }
+
+ // Fallback: if workbook parsing yielded no sheets, scan ZIP entries directly.
+ if ordered_sheets.is_empty() {
+ let mut fallback_paths: Vec<String> = (0..archive.len())
+ .filter_map(|i| {
+ let name = archive.by_index(i).ok()?.name().to_string();
+ if name.starts_with("xl/worksheets/sheet") && name.ends_with(".xml") {
+ Some(name)
+ } else {
+ None
+ }
+ })
+ .collect();
+ fallback_paths.sort_by(|a, b| {
+ let a_idx = sheet_numeric_index(a);
+ let b_idx = sheet_numeric_index(b);
+ a_idx.cmp(&b_idx).then_with(|| a.cmp(b))
+ });
+
+ if fallback_paths.is_empty() {
+ anyhow::bail!("Not a valid XLSX (no worksheet XML files found)");
+ }
+
+ for (i, path) in fallback_paths.into_iter().enumerate() {
+ ordered_sheets.push((format!("Sheet{}", i + 1), path));
+ }
+ }
+
+ // 5. Extract cell text from each sheet.
+ let mut output = String::new();
+ let mut total_sheet_xml_bytes = 0u64;
+ let multi_sheet = ordered_sheets.len() > 1;
+
+ for (sheet_name, sheet_path) in &ordered_sheets {
+ let mut sheet_file = match archive.by_name(sheet_path) {
+ Ok(f) => f,
+ Err(_) => continue,
+ };
+
+ let sheet_xml_size = sheet_file.size();
+ total_sheet_xml_bytes = total_sheet_xml_bytes
+ .checked_add(sheet_xml_size)
+ .ok_or_else(|| anyhow::anyhow!("Sheet XML payload size overflow"))?;
+ if total_sheet_xml_bytes > max_total_sheet_xml_bytes {
+ anyhow::bail!(
+ "Sheet XML payload too large: {} bytes (limit: {} bytes)",
+ total_sheet_xml_bytes,
+ max_total_sheet_xml_bytes
+ );
+ }
+
+ let mut xml_content = String::new();
+ sheet_file.read_to_string(&mut xml_content)?;
+
+ if multi_sheet {
+ if !output.is_empty() {
+ output.push('\n');
+ }
+ use std::fmt::Write as _;
+ let _ = writeln!(output, "--- Sheet: {} ---", sheet_name);
+ }
+
+ let sheet_text = extract_sheet_cells(&xml_content, &shared_strings)?;
+ output.push_str(&sheet_text);
+ }
+
+ Ok(output)
+}
+
+/// Parse `xl/sharedStrings.xml` into a `Vec<String>` indexed by position.
+fn parse_shared_strings(
+ archive: &mut zip::ZipArchive<std::io::Cursor<&[u8]>>,
+) -> anyhow::Result<Vec<String>> {
+ use quick_xml::events::Event;
+ use quick_xml::Reader;
+ use std::io::Read;
+
+ let mut xml = String::new();
+ match archive.by_name("xl/sharedStrings.xml") {
+ Ok(mut f) => {
+ f.read_to_string(&mut xml)?;
+ }
+ Err(zip::result::ZipError::FileNotFound) => return Ok(Vec::new()),
+ Err(e) => return Err(e.into()),
+ }
+
+ let mut strings = Vec::new();
+ let mut reader = Reader::from_str(&xml);
+ let mut in_si = false;
+ let mut in_t = false;
+ let mut current = String::new();
+
+ loop {
+ match reader.read_event() {
+ Ok(Event::Start(e)) => {
+ let qname = e.name();
+ let name = local_name(qname.as_ref());
+ if name == b"si" {
+ in_si = true;
+ current.clear();
+ } else if in_si && name == b"t" {
+ in_t = true;
+ }
+ }
+ Ok(Event::End(e)) => {
+ let qname = e.name();
+ let name = local_name(qname.as_ref());
+ if name == b"t" {
+ in_t = false;
+ } else if name == b"si" {
+ in_si = false;
+ strings.push(std::mem::take(&mut current));
+ }
+ }
+ Ok(Event::Text(e)) => {
+ if in_t {
+ current.push_str(&e.unescape()?);
+ }
+ }
+ Ok(Event::Eof) => break,
+ Err(e) => return Err(e.into()),
+ _ => {}
+ }
+ }
+
+ Ok(strings)
+}
+
+/// Parse `xl/workbook.xml` → Vec<(sheet_name, rId)>.
+fn parse_workbook_sheets(
+ archive: &mut zip::ZipArchive<std::io::Cursor<&[u8]>>,
+) -> anyhow::Result<Vec<(String, String)>> {
+ use quick_xml::events::Event;
+ use quick_xml::Reader;
+ use std::io::Read;
+
+ let mut xml = String::new();
+ match archive.by_name("xl/workbook.xml") {
+ Ok(mut f) => {
+ f.read_to_string(&mut xml)?;
+ }
+ Err(zip::result::ZipError::FileNotFound) => return Ok(Vec::new()),
+ Err(e) => return Err(e.into()),
+ }
+
+ let mut sheets = Vec::new();
+ let mut reader = Reader::from_str(&xml);
+
+ loop {
+ match reader.read_event() {
+ Ok(Event::Start(ref e) | Event::Empty(ref e)) => {
+ let qname = e.name();
+ if local_name(qname.as_ref()) == b"sheet" {
+ let mut name = None;
+ let mut r_id = None;
+ for attr in e.attributes().flatten() {
+ let key = attr.key.as_ref();
+ let local = local_name(key);
+ if local == b"name" {
+ name = Some(
+ attr.decode_and_unescape_value(reader.decoder())?
+ .into_owned(),
+ );
+ } else if key == b"r:id" || local == b"id" {
+ // Accept both r:id and {ns}:id variants.
+ // Only take the relationship id (starts with "rId").
+ let val = attr
+ .decode_and_unescape_value(reader.decoder())?
+ .into_owned();
+ if val.starts_with("rId") {
+ r_id = Some(val);
+ }
+ }
+ }
+ if let (Some(n), Some(r)) = (name, r_id) {
+ sheets.push((n, r));
+ }
+ }
+ }
+ Ok(Event::Eof) => break,
+ Err(e) => return Err(e.into()),
+ _ => {}
+ }
+ }
+
+ Ok(sheets)
+}
+
+/// Parse `xl/_rels/workbook.xml.rels` → HashMap.
+fn parse_workbook_rels(
+ archive: &mut zip::ZipArchive<std::io::Cursor<&[u8]>>,
+) -> anyhow::Result<HashMap<String, String>> {
+ use quick_xml::events::Event;
+ use quick_xml::Reader;
+ use std::io::Read;
+
+ let mut xml = String::new();
+ match archive.by_name("xl/_rels/workbook.xml.rels") {
+ Ok(mut f) => {
+ f.read_to_string(&mut xml)?;
+ }
+ Err(zip::result::ZipError::FileNotFound) => return Ok(HashMap::new()),
+ Err(e) => return Err(e.into()),
+ }
+
+ let mut rels = HashMap::new();
+ let mut reader = Reader::from_str(&xml);
+
+ loop {
+ match reader.read_event() {
+ Ok(Event::Start(ref e) | Event::Empty(ref e)) => {
+ let qname = e.name();
+ if local_name(qname.as_ref()) == b"Relationship" {
+ let mut rel_id = None;
+ let mut target = None;
+ for attr in e.attributes().flatten() {
+ let key = local_name(attr.key.as_ref());
+ if key.eq_ignore_ascii_case(b"id") {
+ rel_id = Some(
+ attr.decode_and_unescape_value(reader.decoder())?
+ .into_owned(),
+ );
+ } else if key.eq_ignore_ascii_case(b"target") {
+ target = Some(
+ attr.decode_and_unescape_value(reader.decoder())?
+ .into_owned(),
+ );
+ }
+ }
+ if let (Some(id), Some(t)) = (rel_id, target) {
+ rels.insert(id, t);
+ }
+ }
+ }
+ Ok(Event::Eof) => break,
+ Err(e) => return Err(e.into()),
+ _ => {}
+ }
+ }
+
+ Ok(rels)
+}
+
+/// Extract cell text from a single worksheet XML string.
+///
+/// Cells are output as tab-separated values per row, newline-separated per row.
+fn extract_sheet_cells(xml: &str, shared_strings: &[String]) -> anyhow::Result<String> {
+ use quick_xml::events::Event;
+ use quick_xml::Reader;
+
+ let mut reader = Reader::from_str(xml);
+ let mut output = String::new();
+
+ let mut in_row = false;
+ let mut in_cell = false;
+ let mut in_value = false;
+ let mut cell_type = CellType::Number;
+ let mut cell_value = String::new();
+ let mut row_cells: Vec = Vec::new();
+
+ loop {
+ match reader.read_event() {
+ Ok(Event::Start(e)) => {
+ let qname = e.name();
+ let name = local_name(qname.as_ref());
+ match name {
+ b"row" => {
+ in_row = true;
+ row_cells.clear();
+ }
+ b"c" if in_row => {
+ in_cell = true;
+ cell_type = CellType::Number;
+ cell_value.clear();
+ for attr in e.attributes().flatten() {
+ if attr.key.as_ref() == b"t" {
+ let val = attr.decode_and_unescape_value(reader.decoder())?;
+ cell_type = match val.as_ref() {
+ "s" => CellType::SharedString,
+ "inlineStr" => CellType::InlineString,
+ "b" => CellType::Boolean,
+ _ => CellType::Number,
+ };
+ }
+ }
+ }
+ b"v" if in_cell => {
+ in_value = true;
+ }
+ b"t" if in_cell && cell_type == CellType::InlineString => {
+ // Inline string: text is inside ...
+ in_value = true;
+ }
+ _ => {}
+ }
+ }
+ Ok(Event::End(e)) => {
+ let qname = e.name();
+ let name = local_name(qname.as_ref());
+ match name {
+ b"row" => {
+ in_row = false;
+ if !row_cells.is_empty() {
+ if !output.is_empty() {
+ output.push('\n');
+ }
+ output.push_str(&row_cells.join("\t"));
+ }
+ }
+ b"c" if in_cell => {
+ in_cell = false;
+ let resolved = match cell_type {
+ CellType::SharedString => {
+ if let Ok(idx) = cell_value.trim().parse::<usize>() {
+ shared_strings.get(idx).cloned().unwrap_or_default()
+ } else {
+ cell_value.clone()
+ }
+ }
+ CellType::Boolean => match cell_value.trim() {
+ "1" => "TRUE".to_string(),
+ "0" => "FALSE".to_string(),
+ other => other.to_string(),
+ },
+ _ => cell_value.clone(),
+ };
+ row_cells.push(resolved);
+ }
+ b"v" => {
+ in_value = false;
+ }
+ b"t" if in_cell => {
+ in_value = false;
+ }
+ _ => {}
+ }
+ }
+ Ok(Event::Text(e)) => {
+ if in_value {
+ cell_value.push_str(&e.unescape()?);
+ }
+ }
+ Ok(Event::Eof) => break,
+ Err(e) => return Err(e.into()),
+ _ => {}
+ }
+ }
+
+ // Flush last row if not terminated by .
+ if in_row && !row_cells.is_empty() {
+ if !output.is_empty() {
+ output.push('\n');
+ }
+ output.push_str(&row_cells.join("\t"));
+ }
+
+ if !output.is_empty() {
+ output.push('\n');
+ }
+
+ Ok(output)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum CellType {
+ Number,
+ SharedString,
+ InlineString,
+ Boolean,
+}
+
+fn sheet_numeric_index(sheet_path: &str) -> Option<u32> {
+ let stem = Path::new(sheet_path).file_stem()?.to_string_lossy();
+ let digits = stem.strip_prefix("sheet")?;
+ digits.parse::<u32>().ok()
+}
+
+fn local_name(name: &[u8]) -> &[u8] {
+ name.rsplit(|b| *b == b':').next().unwrap_or(name)
+}
+
+fn normalize_sheet_target(target: &str) -> Option<String> {
+ if target.contains("://") {
+ return None;
+ }
+
+ let mut segments = Vec::new();
+ for component in Path::new("xl").join(target).components() {
+ match component {
+ Component::Normal(part) => segments.push(part.to_string_lossy().to_string()),
+ Component::ParentDir => {
+ segments.pop()?;
+ }
+ _ => {}
+ }
+ }
+
+ let normalized = segments.join("/");
+ if normalized.starts_with("xl/worksheets/") && normalized.ends_with(".xml") {
+ Some(normalized)
+ } else {
+ None
+ }
+}
+
+fn parse_max_chars(args: &serde_json::Value) -> anyhow::Result<usize> {
+ let Some(value) = args.get("max_chars") else {
+ return Ok(DEFAULT_MAX_CHARS);
+ };
+
+ let serde_json::Value::Number(number) = value else {
+ anyhow::bail!("Invalid 'max_chars': expected a positive integer");
+ };
+ let Some(raw) = number.as_u64() else {
+ anyhow::bail!("Invalid 'max_chars': expected a positive integer");
+ };
+ if raw == 0 {
+ anyhow::bail!("Invalid 'max_chars': must be >= 1");
+ }
+
+ Ok(usize::try_from(raw)
+ .unwrap_or(MAX_OUTPUT_CHARS)
+ .min(MAX_OUTPUT_CHARS))
+}
+
+#[async_trait]
+impl Tool for XlsxReadTool {
+ fn name(&self) -> &str {
+ "xlsx_read"
+ }
+
+ fn description(&self) -> &str {
+ "Extract plain text and numeric data from an XLSX (Excel) file in the workspace. \
+ Returns tab-separated cell values per row for each sheet. \
+ No formulas, charts, styles, or merged-cell awareness."
+ }
+
+ fn parameters_schema(&self) -> serde_json::Value {
+ json!({
+ "type": "object",
+ "properties": {
+ "path": {
+ "type": "string",
+ "description": "Path to the XLSX file. Relative paths resolve from workspace."
+ },
+ "max_chars": {
+ "type": "integer",
+ "description": "Maximum characters to return (default: 50000, max: 200000)",
+ "minimum": 1,
+ "maximum": 200_000
+ }
+ },
+ "required": ["path"]
+ })
+ }
+
+ async fn execute(&self, args: serde_json::Value) -> anyhow::Result {
+ let path = args
+ .get("path")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
+
+ let max_chars = match parse_max_chars(&args) {
+ Ok(value) => value,
+ Err(err) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(err.to_string()),
+ })
+ }
+ };
+
+ if self.security.is_rate_limited() {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some("Rate limit exceeded: too many actions in the last hour".into()),
+ });
+ }
+
+ if !self.security.is_path_allowed(path) {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("Path not allowed by security policy: {path}")),
+ });
+ }
+
+ if !self.security.record_action() {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some("Rate limit exceeded: action budget exhausted".into()),
+ });
+ }
+
+ let full_path = self.security.workspace_dir.join(path);
+
+ let resolved_path = match tokio::fs::canonicalize(&full_path).await {
+ Ok(p) => p,
+ Err(e) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("Failed to resolve file path: {e}")),
+ });
+ }
+ };
+
+ if !self.security.is_resolved_path_allowed(&resolved_path) {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(
+ self.security
+ .resolved_path_violation_message(&resolved_path),
+ ),
+ });
+ }
+
+ tracing::debug!("Reading XLSX: {}", resolved_path.display());
+
+ match tokio::fs::metadata(&resolved_path).await {
+ Ok(meta) => {
+ if meta.len() > MAX_XLSX_BYTES {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!(
+ "XLSX too large: {} bytes (limit: {MAX_XLSX_BYTES} bytes)",
+ meta.len()
+ )),
+ });
+ }
+ }
+ Err(e) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("Failed to read file metadata: {e}")),
+ });
+ }
+ }
+
+ let bytes = match tokio::fs::read(&resolved_path).await {
+ Ok(b) => b,
+ Err(e) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("Failed to read XLSX file: {e}")),
+ });
+ }
+ };
+
+ let text = match tokio::task::spawn_blocking(move || extract_xlsx_text(&bytes)).await {
+ Ok(Ok(t)) => t,
+ Ok(Err(e)) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("XLSX extraction failed: {e}")),
+ });
+ }
+ Err(e) => {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(format!("XLSX extraction task panicked: {e}")),
+ });
+ }
+ };
+
+ if text.trim().is_empty() {
+ return Ok(ToolResult {
+ success: true,
+ output: "XLSX contains no extractable text".into(),
+ error: None,
+ });
+ }
+
+ let output = if text.chars().count() > max_chars {
+ let mut truncated: String = text.chars().take(max_chars).collect();
+ use std::fmt::Write as _;
+ let _ = write!(truncated, "\n\n... [truncated at {max_chars} chars]");
+ truncated
+ } else {
+ text
+ };
+
+ Ok(ToolResult {
+ success: true,
+ output,
+ error: None,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::security::{AutonomyLevel, SecurityPolicy};
+ use tempfile::TempDir;
+
+ fn test_security(workspace: std::path::PathBuf) -> Arc<SecurityPolicy> {
+ Arc::new(SecurityPolicy {
+ autonomy: AutonomyLevel::Supervised,
+ workspace_dir: workspace,
+ ..SecurityPolicy::default()
+ })
+ }
+
+ fn test_security_with_limit(
+ workspace: std::path::PathBuf,
+ max_actions: u32,
+ ) -> Arc<SecurityPolicy> {
+ Arc::new(SecurityPolicy {
+ autonomy: AutonomyLevel::Supervised,
+ workspace_dir: workspace,
+ max_actions_per_hour: max_actions,
+ ..SecurityPolicy::default()
+ })
+ }
+
+ /// Build a minimal valid XLSX (ZIP) in memory with one sheet containing
+ /// the given rows. Each inner `Vec<&str>` is a row of cell values.
+ fn minimal_xlsx_bytes(rows: &[Vec<&str>]) -> Vec<u8> {
+ use std::io::Write;
+
+ // Build shared strings from all unique cell values.
+ let mut all_values: Vec<String> = Vec::new();
+ for row in rows {
+ for cell in row {
+ if !all_values.contains(&cell.to_string()) {
+ all_values.push(cell.to_string());
+ }
+ }
+ }
+
+ let mut ss_entries = String::new();
+ for val in &all_values {
+ ss_entries.push_str(&format!("<si><t>{val}</t></si>"));
+ }
+ let shared_strings_xml = format!(
+ r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<sst xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main" count="{}" uniqueCount="{}">{ss_entries}</sst>"#,
+ all_values.len(),
+ all_values.len()
+ );
+
+ // Build sheet XML.
+ let mut sheet_rows = String::new();
+ for (r_idx, row) in rows.iter().enumerate() {
+ sheet_rows.push_str(&format!(r#""#, r_idx + 1));
+ for (c_idx, cell) in row.iter().enumerate() {
+ let col_letter = (b'A' + c_idx as u8) as char;
+ let cell_ref = format!("{}{}", col_letter, r_idx + 1);
+ let ss_idx = all_values.iter().position(|v| v == cell).unwrap();
+ sheet_rows.push_str(&format!(r#"{ss_idx}"#));
+ }
+ sheet_rows.push_str("</row>
+");
+ }
+ let sheet_xml = format!(
+ r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<worksheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main"><sheetData>
+{sheet_rows}</sheetData>
+</worksheet>"#
+ );
+
+ let workbook_xml = r#"
+
+
+"#;
+
+ let rels_xml = r#"
+
+
+"#;
+
+ let buf = std::io::Cursor::new(Vec::new());
+ let mut zip = zip::ZipWriter::new(buf);
+ let options = zip::write::SimpleFileOptions::default()
+ .compression_method(zip::CompressionMethod::Stored);
+
+ zip.start_file("xl/sharedStrings.xml", options).unwrap();
+ zip.write_all(shared_strings_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/workbook.xml", options).unwrap();
+ zip.write_all(workbook_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/_rels/workbook.xml.rels", options)
+ .unwrap();
+ zip.write_all(rels_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/worksheets/sheet1.xml", options).unwrap();
+ zip.write_all(sheet_xml.as_bytes()).unwrap();
+
+ zip.finish().unwrap().into_inner()
+ }
+
+ /// Build an XLSX with two sheets.
+ fn two_sheet_xlsx_bytes(
+ sheet1_name: &str,
+ sheet1_rows: &[Vec<&str>],
+ sheet2_name: &str,
+ sheet2_rows: &[Vec<&str>],
+ ) -> Vec<u8> {
+ use std::io::Write;
+
+ // Collect all unique values across both sheets.
+ let mut all_values: Vec<String> = Vec::new();
+ for rows in [sheet1_rows, sheet2_rows] {
+ for row in rows {
+ for cell in row {
+ if !all_values.contains(&cell.to_string()) {
+ all_values.push(cell.to_string());
+ }
+ }
+ }
+ }
+
+ let mut ss_entries = String::new();
+ for val in &all_values {
+ ss_entries.push_str(&format!("<si><t>{val}</t></si>"));
+ }
+ let shared_strings_xml = format!(
+ r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<sst xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main" count="{}" uniqueCount="{}">{ss_entries}</sst>"#,
+ all_values.len(),
+ all_values.len()
+ );
+
+ let build_sheet = |rows: &[Vec<&str>]| -> String {
+ let mut sheet_rows = String::new();
+ for (r_idx, row) in rows.iter().enumerate() {
+ sheet_rows.push_str(&format!(r#""#, r_idx + 1));
+ for (c_idx, cell) in row.iter().enumerate() {
+ let col_letter = (b'A' + c_idx as u8) as char;
+ let cell_ref = format!("{}{}", col_letter, r_idx + 1);
+ let ss_idx = all_values.iter().position(|v| v == cell).unwrap();
+ sheet_rows.push_str(&format!(r#"{ss_idx}"#));
+ }
+ sheet_rows.push_str("</row>
+");
+ }
+ format!(
+ r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<worksheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main"><sheetData>
+{sheet_rows}</sheetData>
+</worksheet>"#
+ )
+ };
+
+ let workbook_xml = format!(
+ r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<workbook xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main" xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships">
+<sheets>
+<sheet name="{sheet1_name}" sheetId="1" r:id="rId1"/>
+<sheet name="{sheet2_name}" sheetId="2" r:id="rId2"/>
+</sheets>
+</workbook>"#
+ );
+
+ let rels_xml = r#"
+
+
+
+"#;
+
+ let buf = std::io::Cursor::new(Vec::new());
+ let mut zip = zip::ZipWriter::new(buf);
+ let options = zip::write::SimpleFileOptions::default()
+ .compression_method(zip::CompressionMethod::Stored);
+
+ zip.start_file("xl/sharedStrings.xml", options).unwrap();
+ zip.write_all(shared_strings_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/workbook.xml", options).unwrap();
+ zip.write_all(workbook_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/_rels/workbook.xml.rels", options)
+ .unwrap();
+ zip.write_all(rels_xml.as_bytes()).unwrap();
+
+ zip.start_file("xl/worksheets/sheet1.xml", options).unwrap();
+ zip.write_all(build_sheet(sheet1_rows).as_bytes()).unwrap();
+
+ zip.start_file("xl/worksheets/sheet2.xml", options).unwrap();
+ zip.write_all(build_sheet(sheet2_rows).as_bytes()).unwrap();
+
+ zip.finish().unwrap().into_inner()
+ }
+
+ #[test]
+ fn name_is_xlsx_read() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ assert_eq!(tool.name(), "xlsx_read");
+ }
+
+ #[test]
+ fn description_not_empty() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ assert!(!tool.description().is_empty());
+ }
+
+ #[test]
+ fn schema_has_path_required() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ let schema = tool.parameters_schema();
+ assert!(schema["properties"]["path"].is_object());
+ assert!(schema["properties"]["max_chars"].is_object());
+ let required = schema["required"].as_array().unwrap();
+ assert!(required.contains(&json!("path")));
+ }
+
+ #[test]
+ fn spec_matches_metadata() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ let spec = tool.spec();
+ assert_eq!(spec.name, "xlsx_read");
+ assert!(spec.parameters.is_object());
+ }
+
+ #[tokio::test]
+ async fn missing_path_param_returns_error() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ let result = tool.execute(json!({})).await;
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("path"));
+ }
+
+ #[tokio::test]
+ async fn absolute_path_is_blocked() {
+ let tool = XlsxReadTool::new(test_security(std::env::temp_dir()));
+ let result = tool.execute(json!({"path": "/etc/passwd"})).await.unwrap();
+ assert!(!result.success);
+ assert!(result
+ .error
+ .as_deref()
+ .unwrap_or("")
+ .contains("not allowed"));
+ }
+
+ #[tokio::test]
+ async fn path_traversal_is_blocked() {
+ let tmp = TempDir::new().unwrap();
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool
+ .execute(json!({"path": "../../../etc/passwd"}))
+ .await
+ .unwrap();
+ assert!(!result.success);
+ assert!(result
+ .error
+ .as_deref()
+ .unwrap_or("")
+ .contains("not allowed"));
+ }
+
+ #[tokio::test]
+ async fn nonexistent_file_returns_error() {
+ let tmp = TempDir::new().unwrap();
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool.execute(json!({"path": "missing.xlsx"})).await.unwrap();
+ assert!(!result.success);
+ assert!(result
+ .error
+ .as_deref()
+ .unwrap_or("")
+ .contains("Failed to resolve"));
+ }
+
+ #[tokio::test]
+ async fn rate_limit_blocks_request() {
+ let tmp = TempDir::new().unwrap();
+ let tool = XlsxReadTool::new(test_security_with_limit(tmp.path().to_path_buf(), 0));
+ let result = tool.execute(json!({"path": "any.xlsx"})).await.unwrap();
+ assert!(!result.success);
+ assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
+ }
+
+ #[tokio::test]
+ async fn extracts_text_from_valid_xlsx() {
+ let tmp = TempDir::new().unwrap();
+ let xlsx_path = tmp.path().join("data.xlsx");
+ let rows = vec![vec!["Name", "Age"], vec!["Alice", "30"]];
+ tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows))
+ .await
+ .unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool.execute(json!({"path": "data.xlsx"})).await.unwrap();
+ assert!(result.success, "error: {:?}", result.error);
+ assert!(
+ result.output.contains("Name"),
+ "expected 'Name' in output, got: {}",
+ result.output
+ );
+ assert!(result.output.contains("Age"));
+ assert!(result.output.contains("Alice"));
+ assert!(result.output.contains("30"));
+ }
+
+ #[tokio::test]
+ async fn extracts_tab_separated_columns() {
+ let tmp = TempDir::new().unwrap();
+ let xlsx_path = tmp.path().join("cols.xlsx");
+ let rows = vec![vec!["A", "B", "C"]];
+ tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows))
+ .await
+ .unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool.execute(json!({"path": "cols.xlsx"})).await.unwrap();
+ assert!(result.success);
+ assert!(
+ result.output.contains("A\tB\tC"),
+ "expected tab-separated output, got: {:?}",
+ result.output
+ );
+ }
+
+ #[tokio::test]
+ async fn extracts_multiple_sheets() {
+ let tmp = TempDir::new().unwrap();
+ let xlsx_path = tmp.path().join("multi.xlsx");
+ let bytes = two_sheet_xlsx_bytes(
+ "Sales",
+ &[vec!["Product", "Revenue"], vec!["Widget", "1000"]],
+ "Costs",
+ &[vec!["Item", "Amount"], vec!["Rent", "500"]],
+ );
+ tokio::fs::write(&xlsx_path, bytes).await.unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool.execute(json!({"path": "multi.xlsx"})).await.unwrap();
+ assert!(result.success, "error: {:?}", result.error);
+ assert!(result.output.contains("--- Sheet: Sales ---"));
+ assert!(result.output.contains("--- Sheet: Costs ---"));
+ assert!(result.output.contains("Widget"));
+ assert!(result.output.contains("Rent"));
+ }
+
+ #[tokio::test]
+ async fn invalid_zip_returns_extraction_error() {
+ let tmp = TempDir::new().unwrap();
+ let xlsx_path = tmp.path().join("bad.xlsx");
+ tokio::fs::write(&xlsx_path, b"this is not a zip file")
+ .await
+ .unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool.execute(json!({"path": "bad.xlsx"})).await.unwrap();
+ assert!(!result.success);
+ assert!(result
+ .error
+ .as_deref()
+ .unwrap_or("")
+ .contains("extraction failed"));
+ }
+
+ #[tokio::test]
+ async fn max_chars_truncates_output() {
+ let tmp = TempDir::new().unwrap();
+ let long_text = "B".repeat(200);
+ let rows = vec![vec![long_text.as_str(); 10]];
+ let xlsx_path = tmp.path().join("long.xlsx");
+ tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows))
+ .await
+ .unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool
+ .execute(json!({"path": "long.xlsx", "max_chars": 50}))
+ .await
+ .unwrap();
+ assert!(result.success);
+ assert!(result.output.contains("truncated"));
+ }
+
+ #[tokio::test]
+ async fn invalid_max_chars_returns_tool_error() {
+ let tmp = TempDir::new().unwrap();
+ let xlsx_path = tmp.path().join("data.xlsx");
+ let rows = vec![vec!["Hello"]];
+ tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows))
+ .await
+ .unwrap();
+
+ let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf()));
+ let result = tool
+ .execute(json!({"path": "data.xlsx", "max_chars": "100"}))
+ .await
+ .unwrap();
+ assert!(!result.success);
+ assert!(result.error.as_deref().unwrap_or("").contains("max_chars"));
+ }
+
+ #[test]
+ fn shared_string_reference_resolved() {
+ let rows = vec![vec!["Hello", "World"]];
+ let bytes = minimal_xlsx_bytes(&rows);
+ let text = extract_xlsx_text(&bytes).unwrap();
+ assert!(text.contains("Hello"));
+ assert!(text.contains("World"));
+ }
+
+ #[test]
+ fn cumulative_sheet_xml_limit_is_enforced() {
+ let rows = vec![vec!["Alpha", "Beta"]];
+ let bytes = minimal_xlsx_bytes(&rows);
+ let error = extract_xlsx_text_with_limits(&bytes, 64).unwrap_err();
+ assert!(error.to_string().contains("Sheet XML payload too large"));
+ }
+
+ #[test]
+ fn numeric_cells_extracted_directly() {
+ use std::io::Write;
+
+ // Build a sheet with numeric cells (no t="s" attribute).
+ let sheet_xml = r#"<?xml version="1.0" encoding="UTF-8"?>
+<worksheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">
+<sheetData>
+<row r="1"><c r="A1"><v>42</v></c><c r="B1"><v>3.14</v></c></row>
+</sheetData>
+</worksheet>"#;
+
+ let workbook_xml = r#"<?xml version="1.0" encoding="UTF-8"?>
+<workbook xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main" xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships">
+<sheets><sheet name="Sheet1" sheetId="1" r:id="rId1"/></sheets>
+</workbook>"#;
+
+ let rels_xml = r#"<?xml version="1.0" encoding="UTF-8"?>
+<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">
+<Relationship Id="rId1" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/worksheet" Target="worksheets/sheet1.xml"/>
+</Relationships>"#;
+
+ let buf = std::io::Cursor::new(Vec::new());
+ let mut zip = zip::ZipWriter::new(buf);
+ let options = zip::write::SimpleFileOptions::default()
+ .compression_method(zip::CompressionMethod::Stored);
+
+ zip.start_file("xl/workbook.xml", options).unwrap();
+ zip.write_all(workbook_xml.as_bytes()).unwrap();
+ zip.start_file("xl/_rels/workbook.xml.rels", options)
+ .unwrap();
+ zip.write_all(rels_xml.as_bytes()).unwrap();
+ zip.start_file("xl/worksheets/sheet1.xml", options).unwrap();
+ zip.write_all(sheet_xml.as_bytes()).unwrap();
+
+ let bytes = zip.finish().unwrap().into_inner();
+ let text = extract_xlsx_text(&bytes).unwrap();
+ assert!(text.contains("42"), "got: {text}");
+ assert!(text.contains("3.14"), "got: {text}");
+ assert!(text.contains("42\t3.14"), "got: {text}");
+ }
+
+ #[test]
+ fn fallback_when_no_workbook() {
+ use std::io::Write;
+
+ // ZIP with only sheet files, no workbook.xml.
+ let sheet_xml = r#"<?xml version="1.0" encoding="UTF-8"?>
+<worksheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">
+<sheetData>
+<row r="1"><c r="A1"><v>99</v></c></row>
+</sheetData>
+</worksheet>"#;
+
+ let buf = std::io::Cursor::new(Vec::new());
+ let mut zip = zip::ZipWriter::new(buf);
+ let options = zip::write::SimpleFileOptions::default()
+ .compression_method(zip::CompressionMethod::Stored);
+
+ zip.start_file("xl/worksheets/sheet1.xml", options).unwrap();
+ zip.write_all(sheet_xml.as_bytes()).unwrap();
+
+ let bytes = zip.finish().unwrap().into_inner();
+ let text = extract_xlsx_text(&bytes).unwrap();
+ assert!(text.contains("99"), "got: {text}");
+ }
+
+ #[cfg(unix)]
+ #[tokio::test]
+ async fn symlink_escape_is_blocked() {
+ use std::os::unix::fs::symlink;
+
+ let root = TempDir::new().unwrap();
+ let workspace = root.path().join("workspace");
+ let outside = root.path().join("outside");
+ tokio::fs::create_dir_all(&workspace).await.unwrap();
+ tokio::fs::create_dir_all(&outside).await.unwrap();
+ let rows = vec![vec!["secret"]];
+ tokio::fs::write(outside.join("secret.xlsx"), minimal_xlsx_bytes(&rows))
+ .await
+ .unwrap();
+ symlink(outside.join("secret.xlsx"), workspace.join("link.xlsx")).unwrap();
+
+ let tool = XlsxReadTool::new(test_security(workspace));
+ let result = tool.execute(json!({"path": "link.xlsx"})).await.unwrap();
+ assert!(!result.success);
+ assert!(result
+ .error
+ .as_deref()
+ .unwrap_or("")
+ .contains("escapes workspace"));
+ }
+
+}
diff --git a/src/update.rs b/src/update.rs
index b0b328e44..b86b6cbb1 100644
--- a/src/update.rs
+++ b/src/update.rs
@@ -5,6 +5,7 @@
use anyhow::{bail, Context, Result};
use std::env;
use std::fs;
+use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use std::process::Command;
@@ -26,6 +27,13 @@ struct Asset {
browser_download_url: String,
}
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum InstallMethod {
+ Homebrew,
+ CargoOrLocal,
+ Unknown,
+}
+
/// Get the current version of the binary
pub fn current_version() -> &'static str {
env!("CARGO_PKG_VERSION")
@@ -213,6 +221,79 @@ fn get_current_exe() -> Result<PathBuf> {
env::current_exe().context("Failed to get current executable path")
}
+fn detect_install_method_for_path(resolved_path: &Path, home_dir: Option<&Path>) -> InstallMethod {
+ let lower = resolved_path.to_string_lossy().to_ascii_lowercase();
+ if lower.contains("/cellar/zeroclaw/") || lower.contains("/homebrew/cellar/zeroclaw/") {
+ return InstallMethod::Homebrew;
+ }
+
+ if let Some(home) = home_dir {
+ if resolved_path.starts_with(home.join(".cargo").join("bin"))
+ || resolved_path.starts_with(home.join(".local").join("bin"))
+ {
+ return InstallMethod::CargoOrLocal;
+ }
+ }
+
+ InstallMethod::Unknown
+}
+
+fn detect_install_method(current_exe: &Path) -> InstallMethod {
+ let resolved = fs::canonicalize(current_exe).unwrap_or_else(|_| current_exe.to_path_buf());
+ let home_dir = env::var_os("HOME").map(PathBuf::from);
+ detect_install_method_for_path(&resolved, home_dir.as_deref())
+}
+
+/// Print human-friendly update instructions based on detected install method.
+pub fn print_update_instructions() -> Result<()> {
+ let current_exe = get_current_exe()?;
+ let install_method = detect_install_method(&current_exe);
+
+ println!("ZeroClaw update guide");
+ println!("Detected binary: {}", current_exe.display());
+ println!();
+ println!("1) Check if a new release exists:");
+ println!(" zeroclaw update --check");
+ println!();
+
+ match install_method {
+ InstallMethod::Homebrew => {
+ println!("Detected install method: Homebrew");
+ println!("Recommended update commands:");
+ println!(" brew update");
+ println!(" brew upgrade zeroclaw");
+ println!(" zeroclaw --version");
+ println!();
+ println!(
+ "Tip: avoid `zeroclaw update` on Homebrew installs unless you intentionally want to override the managed binary."
+ );
+ }
+ InstallMethod::CargoOrLocal => {
+ println!("Detected install method: local binary (~/.cargo/bin or ~/.local/bin)");
+ println!("Recommended update command:");
+ println!(" zeroclaw update");
+ println!("Optional force reinstall:");
+ println!(" zeroclaw update --force");
+ println!("Verify:");
+ println!(" zeroclaw --version");
+ }
+ InstallMethod::Unknown => {
+ println!("Detected install method: unknown");
+ println!("Try the built-in updater first:");
+ println!(" zeroclaw update");
+ println!(
+ "If your package manager owns the binary, use that manager's upgrade command."
+ );
+ println!("Verify:");
+ println!(" zeroclaw --version");
+ }
+ }
+
+ println!();
+ println!("Release source: https://github.com/{GITHUB_REPO}/releases/latest");
+ Ok(())
+}
+
/// Replace the current binary with the new one
fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
// On Windows, we can't replace a running executable directly
@@ -226,11 +307,43 @@ fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
let _ = fs::remove_file(&old_path);
}
- // On Unix, we can overwrite the running executable
+ // On Unix, stage the binary in the destination directory first.
+ // This avoids cross-filesystem rename failures (EXDEV) from temp dirs.
#[cfg(unix)]
{
- // Use rename for atomic replacement on Unix
- fs::rename(new_binary, current_exe).context("Failed to replace binary")?;
+ use std::os::unix::fs::PermissionsExt;
+
+ let parent = current_exe
+ .parent()
+ .context("Current executable has no parent directory")?;
+ let binary_name = current_exe
+ .file_name()
+ .context("Current executable path is missing a file name")?
+ .to_string_lossy()
+ .into_owned();
+ let staged_path = parent.join(format!(".{binary_name}.new"));
+ let backup_path = parent.join(format!(".{binary_name}.bak"));
+
+ fs::copy(new_binary, &staged_path).context("Failed to stage updated binary")?;
+ fs::set_permissions(&staged_path, fs::Permissions::from_mode(0o755))
+ .context("Failed to set permissions on staged binary")?;
+
+ if let Err(err) = fs::remove_file(&backup_path) {
+ if err.kind() != ErrorKind::NotFound {
+ return Err(err).context("Failed to remove stale backup binary");
+ }
+ }
+
+ fs::rename(current_exe, &backup_path).context("Failed to backup current binary")?;
+
+ if let Err(err) = fs::rename(&staged_path, current_exe) {
+ let _ = fs::rename(&backup_path, current_exe);
+ let _ = fs::remove_file(&staged_path);
+ return Err(err).context("Failed to activate updated binary");
+ }
+
+ // Best-effort cleanup of backup.
+ let _ = fs::remove_file(&backup_path);
}
Ok(())
@@ -258,6 +371,7 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
println!();
let current_exe = get_current_exe()?;
+ let install_method = detect_install_method(&current_exe);
println!("Current binary: {}", current_exe.display());
println!("Current version: v{}", current_version());
println!();
@@ -268,6 +382,31 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
println!("Latest version: {}", release.tag_name);
+ if check_only {
+ println!();
+ if latest_version == current_version() {
+ println!("✅ Already up to date.");
+ } else {
+ println!(
+ "Update available: {} -> {}",
+ current_version(),
+ latest_version
+ );
+ println!("Run `zeroclaw update` to install the update.");
+ }
+ return Ok(());
+ }
+
+ if install_method == InstallMethod::Homebrew && !force {
+ println!();
+ println!("Detected a Homebrew-managed installation.");
+ println!("Use `brew upgrade zeroclaw` for the safest update path.");
+ println!(
+ "Run `zeroclaw update --force` only if you intentionally want to override Homebrew."
+ );
+ return Ok(());
+ }
+
// Check if update is needed
if latest_version == current_version() && !force {
println!();
@@ -275,17 +414,6 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
return Ok(());
}
- if check_only {
- println!();
- println!(
- "Update available: {} -> {}",
- current_version(),
- latest_version
- );
- println!("Run `zeroclaw update` to install the update.");
- return Ok(());
- }
-
println!();
println!(
"Updating from v{} to {}...",
@@ -315,3 +443,50 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
Ok(())
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn archive_name_uses_zip_for_windows_and_targz_elsewhere() {
+ assert_eq!(
+ get_archive_name("x86_64-pc-windows-msvc"),
+ "zeroclaw-x86_64-pc-windows-msvc.zip"
+ );
+ assert_eq!(
+ get_archive_name("x86_64-unknown-linux-gnu"),
+ "zeroclaw-x86_64-unknown-linux-gnu.tar.gz"
+ );
+ }
+
+ #[test]
+ fn detect_install_method_identifies_homebrew_paths() {
+ let path = Path::new("/opt/homebrew/Cellar/zeroclaw/0.1.7/bin/zeroclaw");
+ let method = detect_install_method_for_path(path, None);
+ assert_eq!(method, InstallMethod::Homebrew);
+ }
+
+ #[test]
+ fn detect_install_method_identifies_local_bin_paths() {
+ let home = Path::new("/Users/example");
+ let cargo_path = Path::new("/Users/example/.cargo/bin/zeroclaw");
+ let local_path = Path::new("/Users/example/.local/bin/zeroclaw");
+
+ assert_eq!(
+ detect_install_method_for_path(cargo_path, Some(home)),
+ InstallMethod::CargoOrLocal
+ );
+ assert_eq!(
+ detect_install_method_for_path(local_path, Some(home)),
+ InstallMethod::CargoOrLocal
+ );
+ }
+
+ #[test]
+ fn detect_install_method_returns_unknown_for_other_paths() {
+ let path = Path::new("/usr/bin/zeroclaw");
+ let method = detect_install_method_for_path(path, Some(Path::new("/Users/example")));
+ assert_eq!(method, InstallMethod::Unknown);
+ }
+}