Merge branch 'main' into feat/wasm-plugin-runtime-exec

This commit is contained in:
xj 2026-03-01 00:57:15 -08:00 committed by GitHub
commit 1da53f154c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 2711 additions and 153 deletions

View File

@ -8,7 +8,7 @@ on:
apply:
description: "Cancel selected queued runs (false = dry-run report only)"
required: true
default: true
default: false
type: boolean
status:
description: "Queued-run status scope"
@ -57,22 +57,27 @@ jobs:
status_scope="queued"
max_cancel="120"
apply_mode="true"
apply_mode="false"
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
status_scope="${{ github.event.inputs.status || 'queued' }}"
max_cancel="${{ github.event.inputs.max_cancel || '120' }}"
apply_mode="${{ github.event.inputs.apply || 'true' }}"
apply_mode="${{ github.event.inputs.apply || 'false' }}"
fi
cmd=(python3 scripts/ci/queue_hygiene.py
--repo "${{ github.repository }}"
--status "${status_scope}"
--max-cancel "${max_cancel}"
--dedupe-workflow "CI Run"
--dedupe-workflow "Test E2E"
--dedupe-workflow "Docs Deploy"
--dedupe-workflow "PR Intake Checks"
--dedupe-workflow "PR Labeler"
--dedupe-workflow "PR Auto Responder"
--dedupe-workflow "Workflow Sanity"
--dedupe-workflow "PR Label Policy Check"
--dedupe-include-non-pr
--non-pr-key branch
--output-json artifacts/queue-hygiene-report.json
--verbose)

View File

@ -9,7 +9,7 @@ on:
branches: [dev, main]
concurrency:
group: ci-${{ github.event.pull_request.number || github.sha }}
group: ci-run-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:

View File

@ -41,7 +41,7 @@ on:
default: ""
concurrency:
group: docs-deploy-${{ github.event.pull_request.number || github.sha }}
group: docs-deploy-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:

View File

@ -14,7 +14,7 @@ on:
workflow_dispatch:
concurrency:
group: e2e-${{ github.event.pull_request.number || github.sha }}
group: test-e2e-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
cancel-in-progress: true
permissions:

View File

@ -30,7 +30,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
<p align="center">
<a href="#quick-start">Getting Started</a> |
<a href="bootstrap.sh">One-Click Setup</a> |
<a href="docs/one-click-bootstrap.md">One-Click Setup</a> |
<a href="docs/README.md">Docs Hub</a> |
<a href="docs/SUMMARY.md">Docs TOC</a>
</p>
@ -108,11 +108,11 @@ cargo install zeroclaw
### First Run
```bash
# Start the gateway daemon
zeroclaw gateway start
# Start the gateway (serves the Web Dashboard API/UI)
zeroclaw gateway
# Open the web UI
zeroclaw dashboard
# Open the dashboard URL shown in startup logs
# (default: http://127.0.0.1:3000/)
# Or chat directly
zeroclaw chat "Hello!"
@ -120,6 +120,16 @@ zeroclaw chat "Hello!"
For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md).
### Installation Docs (Canonical Source)
Use repository docs as the source of truth for install/setup instructions:
- [README Quick Start](#quick-start)
- [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md)
- [docs/getting-started/README.md](docs/getting-started/README.md)
Issue comments can provide context, but they are not canonical installation documentation.
## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible)
Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware.

View File

@ -29,6 +29,8 @@ Localized hubs: [简体中文](i18n/zh-CN/README.md) · [日本語](i18n/ja/READ
| See project PR/issue docs snapshot | [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md) |
| Perform i18n completion for docs changes | [i18n-guide.md](i18n-guide.md) |
Installation source-of-truth: keep install/run instructions in repository docs and README pages; issue comments are supplemental context only.
## Quick Decision Tree (10 seconds)
- Need first-time setup or install? → [getting-started/README.md](getting-started/README.md)

View File

@ -15,6 +15,7 @@ Last verified: **February 28, 2026**.
| `service` | Manage user-level OS service lifecycle |
| `doctor` | Run diagnostics and freshness checks |
| `status` | Print current configuration and system summary |
| `update` | Check or install latest ZeroClaw release |
| `estop` | Engage/resume emergency stop levels and inspect estop state |
| `cron` | Manage scheduled tasks |
| `models` | Refresh provider model catalogs |
@ -103,6 +104,18 @@ Notes:
- `zeroclaw service status`
- `zeroclaw service uninstall`
### `update`
- `zeroclaw update --check` (check for new release, no install)
- `zeroclaw update` (install latest release binary for current platform)
- `zeroclaw update --force` (reinstall even if current version matches latest)
- `zeroclaw update --instructions` (print install-method-specific guidance)
Notes:
- If ZeroClaw is installed via Homebrew, prefer `brew upgrade zeroclaw`.
- `update --instructions` detects common install methods and prints the safest path.
### `cron`
- `zeroclaw cron list`
@ -264,6 +277,11 @@ Registry packages are installed to `~/.zeroclaw/workspace/skills/<name>/`.
Use `skills audit` to manually validate a candidate skill directory (or an installed skill by name) before sharing it.
Workspace symlink policy:
- Symlinked entries under `~/.zeroclaw/workspace/skills/` are blocked by default.
- To allow shared local skill directories, set `[skills].trusted_skill_roots` in `config.toml`.
- A symlinked skill is accepted only when its resolved canonical target is inside one of the trusted roots.
Skill manifests (`SKILL.toml`) support `prompts` and `[[tools]]`; both are injected into the agent system prompt at runtime, so the model can follow skill instructions without manually reading skill files.
### `migrate`

View File

@ -536,6 +536,7 @@ Notes:
|---|---|---|
| `open_skills_enabled` | `false` | Opt-in loading/sync of community `open-skills` repository |
| `open_skills_dir` | unset | Optional local path for `open-skills` (defaults to `$HOME/open-skills` when enabled) |
| `trusted_skill_roots` | `[]` | Allowlist of directory roots for symlink targets in `workspace/skills/*` |
| `prompt_injection_mode` | `full` | Skill prompt verbosity: `full` (inline instructions/tools) or `compact` (name/description/location only) |
| `clawhub_token` | unset | Optional Bearer token for authenticated ClawHub skill downloads |
@ -548,7 +549,8 @@ Notes:
- `ZEROCLAW_SKILLS_PROMPT_MODE` accepts `full` or `compact`.
- Precedence for enable flag: `ZEROCLAW_OPEN_SKILLS_ENABLED` → `skills.open_skills_enabled` in `config.toml` → default `false`.
- `prompt_injection_mode = "compact"` is recommended on low-context local models to reduce startup prompt size while keeping skill files available on demand.
- Skill loading and `zeroclaw skills install` both apply a static security audit. Skills that contain symlinks, script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
- Symlinked workspace skills are blocked by default. Set `trusted_skill_roots` to allow local shared-skill directories after explicit trust review.
- `zeroclaw skills install` and `zeroclaw skills audit` apply a static security audit. Skills that contain script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
- `clawhub_token` is sent as `Authorization: Bearer <token>` when downloading from ClawHub. Obtain a token from [https://clawhub.ai](https://clawhub.ai) after signing in. Required if the API returns 429 (rate-limited) or 401 (unauthorized) for anonymous requests.
**ClawHub token example:**

View File

@ -20,6 +20,13 @@ If both exist, your shell `PATH` order decides which one runs.
## 2) Update on macOS
Quick way to get install-method-specific guidance:
```bash
zeroclaw update --instructions
zeroclaw update --check
```
### A) Homebrew install
```bash
@ -54,6 +61,13 @@ Re-run your download/install flow with the latest release asset, then verify:
zeroclaw --version
```
You can also use the built-in updater for manual/local installs:
```bash
zeroclaw update
zeroclaw --version
```
## 3) Uninstall on macOS
### A) Stop and remove background service first

View File

@ -423,11 +423,18 @@ string_to_bool() {
}
guided_input_stream() {
if [[ -t 0 ]]; then
# Some constrained Linux containers report interactive stdin but deny opening
# /dev/stdin directly. Probe readability before selecting it.
if [[ -t 0 ]] && (: </dev/stdin) 2>/dev/null; then
echo "/dev/stdin"
return 0
fi
if [[ -t 0 ]] && (: </proc/self/fd/0) 2>/dev/null; then
echo "/proc/self/fd/0"
return 0
fi
if (: </dev/tty) 2>/dev/null; then
echo "/dev/tty"
return 0

View File

@ -66,6 +66,15 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Also dedupe non-PR runs (push/manual). Default dedupe scope is PR-originated runs only.",
)
parser.add_argument(
"--non-pr-key",
default="sha",
choices=["sha", "branch"],
help=(
"Identity key mode for non-PR dedupe when --dedupe-include-non-pr is enabled: "
"'sha' keeps one run per commit (default), 'branch' keeps one run per branch."
),
)
parser.add_argument(
"--max-cancel",
type=int,
@ -165,7 +174,7 @@ def parse_timestamp(value: str | None) -> datetime:
return datetime.fromtimestamp(0, tz=timezone.utc)
def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
def run_identity_key(run: dict[str, Any], *, non_pr_key: str) -> tuple[str, str, str, str]:
name = str(run.get("name", ""))
event = str(run.get("event", ""))
head_branch = str(run.get("head_branch", ""))
@ -179,7 +188,10 @@ def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
if pr_number:
# For PR traffic, cancel stale runs across synchronize updates for the same PR.
return (name, event, f"pr:{pr_number}", "")
# For push/manual traffic, key by SHA to avoid collapsing distinct commits.
if non_pr_key == "branch":
# Branch-level supersedence for push/manual lanes.
return (name, event, head_branch, "")
# SHA-level supersedence for push/manual lanes.
return (name, event, head_branch, head_sha)
@ -189,6 +201,7 @@ def collect_candidates(
dedupe_workflows: set[str],
*,
include_non_pr: bool,
non_pr_key: str,
) -> tuple[list[dict[str, Any]], Counter[str]]:
reasons_by_id: dict[int, set[str]] = defaultdict(set)
runs_by_id: dict[int, dict[str, Any]] = {}
@ -220,7 +233,7 @@ def collect_candidates(
has_pr_context = isinstance(pull_requests, list) and len(pull_requests) > 0
if is_pr_event and not has_pr_context and not include_non_pr:
continue
key = run_identity_key(run)
key = run_identity_key(run, non_pr_key=non_pr_key)
by_workflow[name][key].append(run)
for groups in by_workflow.values():
@ -324,6 +337,7 @@ def main() -> int:
obsolete_workflows,
dedupe_workflows,
include_non_pr=args.dedupe_include_non_pr,
non_pr_key=args.non_pr_key,
)
capped = selected[: max(0, args.max_cancel)]
@ -338,6 +352,7 @@ def main() -> int:
"obsolete_workflows": sorted(obsolete_workflows),
"dedupe_workflows": sorted(dedupe_workflows),
"dedupe_include_non_pr": args.dedupe_include_non_pr,
"non_pr_key": args.non_pr_key,
"max_cancel": args.max_cancel,
},
"counts": {

View File

@ -3759,6 +3759,119 @@ class CiScriptsBehaviorTest(unittest.TestCase):
planned_ids = [item["id"] for item in report["planned_actions"]]
self.assertEqual(planned_ids, [101, 102])
def test_queue_hygiene_non_pr_branch_mode_dedupes_push_runs(self) -> None:
runs_json = self.tmp / "runs-non-pr-branch.json"
output_json = self.tmp / "queue-hygiene-non-pr-branch.json"
runs_json.write_text(
json.dumps(
{
"workflow_runs": [
{
"id": 201,
"name": "CI Run",
"event": "push",
"head_branch": "main",
"head_sha": "sha-201",
"created_at": "2026-02-27T20:00:00Z",
},
{
"id": 202,
"name": "CI Run",
"event": "push",
"head_branch": "main",
"head_sha": "sha-202",
"created_at": "2026-02-27T20:01:00Z",
},
{
"id": 203,
"name": "CI Run",
"event": "push",
"head_branch": "dev",
"head_sha": "sha-203",
"created_at": "2026-02-27T20:02:00Z",
},
]
}
)
+ "\n",
encoding="utf-8",
)
proc = run_cmd(
[
"python3",
self._script("queue_hygiene.py"),
"--runs-json",
str(runs_json),
"--dedupe-workflow",
"CI Run",
"--dedupe-include-non-pr",
"--non-pr-key",
"branch",
"--output-json",
str(output_json),
]
)
self.assertEqual(proc.returncode, 0, msg=proc.stderr)
report = json.loads(output_json.read_text(encoding="utf-8"))
self.assertEqual(report["counts"]["candidate_runs_before_cap"], 1)
planned_ids = [item["id"] for item in report["planned_actions"]]
self.assertEqual(planned_ids, [201])
reasons = report["planned_actions"][0]["reasons"]
self.assertTrue(any(reason.startswith("dedupe-superseded-by:202") for reason in reasons))
self.assertEqual(report["policies"]["non_pr_key"], "branch")
def test_queue_hygiene_non_pr_sha_mode_keeps_distinct_push_commits(self) -> None:
runs_json = self.tmp / "runs-non-pr-sha.json"
output_json = self.tmp / "queue-hygiene-non-pr-sha.json"
runs_json.write_text(
json.dumps(
{
"workflow_runs": [
{
"id": 301,
"name": "CI Run",
"event": "push",
"head_branch": "main",
"head_sha": "sha-301",
"created_at": "2026-02-27T20:00:00Z",
},
{
"id": 302,
"name": "CI Run",
"event": "push",
"head_branch": "main",
"head_sha": "sha-302",
"created_at": "2026-02-27T20:01:00Z",
},
]
}
)
+ "\n",
encoding="utf-8",
)
proc = run_cmd(
[
"python3",
self._script("queue_hygiene.py"),
"--runs-json",
str(runs_json),
"--dedupe-workflow",
"CI Run",
"--dedupe-include-non-pr",
"--output-json",
str(output_json),
]
)
self.assertEqual(proc.returncode, 0, msg=proc.stderr)
report = json.loads(output_json.read_text(encoding="utf-8"))
self.assertEqual(report["counts"]["candidate_runs_before_cap"], 0)
self.assertEqual(report["planned_actions"], [])
self.assertEqual(report["policies"]["non_pr_key"], "sha")
if __name__ == "__main__": # pragma: no cover
unittest.main(verbosity=2)

View File

@ -1,5 +1,7 @@
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
use crate::config::schema::{CostEnforcementMode, ModelPricing};
use crate::config::{Config, ProgressMode};
use crate::cost::{BudgetCheck, CostTracker, UsagePeriod};
use crate::memory::{self, Memory, MemoryCategory};
use crate::multimodal;
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
@ -19,9 +21,11 @@ use rustyline::hint::Hinter;
use rustyline::validate::Validator;
use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper};
use std::borrow::Cow;
use std::collections::{BTreeSet, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Write;
use std::future::Future;
use std::io::Write as _;
use std::path::Path;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@ -297,6 +301,7 @@ tokio::task_local! {
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
static TOOL_LOOP_PROGRESS_MODE: ProgressMode;
static TOOL_LOOP_COST_ENFORCEMENT_CONTEXT: Option<CostEnforcementContext>;
}
/// Configuration for periodic safety-constraint re-injection (heartbeat).
@ -308,6 +313,56 @@ pub(crate) struct SafetyHeartbeatConfig {
pub interval: usize,
}
#[derive(Clone)]
pub(crate) struct CostEnforcementContext {
tracker: Arc<CostTracker>,
prices: HashMap<String, ModelPricing>,
mode: CostEnforcementMode,
route_down_model: Option<String>,
reserve_percent: u8,
}
pub(crate) fn create_cost_enforcement_context(
cost_config: &crate::config::CostConfig,
workspace_dir: &Path,
) -> Option<CostEnforcementContext> {
if !cost_config.enabled {
return None;
}
let tracker = match CostTracker::new(cost_config.clone(), workspace_dir) {
Ok(tracker) => Arc::new(tracker),
Err(error) => {
tracing::warn!("Cost budget preflight disabled: failed to initialize tracker: {error}");
return None;
}
};
let route_down_model = cost_config
.enforcement
.route_down_model
.clone()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty());
Some(CostEnforcementContext {
tracker,
prices: cost_config.prices.clone(),
mode: cost_config.enforcement.mode,
route_down_model,
reserve_percent: cost_config.enforcement.reserve_percent.min(100),
})
}
pub(crate) async fn scope_cost_enforcement_context<F>(
context: Option<CostEnforcementContext>,
future: F,
) -> F::Output
where
F: Future,
{
TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
.scope(context, future)
.await
}
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
interval > 0 && counter > 0 && counter % interval == 0
}
@ -320,6 +375,100 @@ fn should_emit_tool_progress(mode: ProgressMode) -> bool {
mode != ProgressMode::Off
}
fn estimate_prompt_tokens(
messages: &[ChatMessage],
tools: Option<&[crate::tools::ToolSpec]>,
) -> u64 {
let message_chars: usize = messages
.iter()
.map(|msg| {
msg.role
.len()
.saturating_add(msg.content.chars().count())
.saturating_add(16)
})
.sum();
let tool_chars: usize = tools
.map(|specs| {
specs
.iter()
.map(|spec| serde_json::to_string(spec).map_or(0, |value| value.chars().count()))
.sum()
})
.unwrap_or(0);
let total_chars = message_chars.saturating_add(tool_chars);
let char_estimate = (total_chars as f64 / 4.0).ceil() as u64;
let framing_overhead = (messages.len() as u64).saturating_mul(6).saturating_add(64);
char_estimate.saturating_add(framing_overhead)
}
fn lookup_model_pricing(
prices: &HashMap<String, ModelPricing>,
provider: &str,
model: &str,
) -> (f64, f64) {
let full_name = format!("{provider}/{model}");
if let Some(pricing) = prices.get(&full_name) {
return (pricing.input, pricing.output);
}
if let Some(pricing) = prices.get(model) {
return (pricing.input, pricing.output);
}
for (key, pricing) in prices {
let key_model = key.split('/').next_back().unwrap_or(key);
if model.starts_with(key_model) || key_model.starts_with(model) {
return (pricing.input, pricing.output);
}
let normalized_model = model.replace('-', ".");
let normalized_key = key_model.replace('-', ".");
if normalized_model.contains(&normalized_key) || normalized_key.contains(&normalized_model)
{
return (pricing.input, pricing.output);
}
}
(3.0, 15.0)
}
fn estimate_request_cost_usd(
context: &CostEnforcementContext,
provider: &str,
model: &str,
messages: &[ChatMessage],
tools: Option<&[crate::tools::ToolSpec]>,
) -> f64 {
let reserve_multiplier = 1.0 + (f64::from(context.reserve_percent) / 100.0);
let input_tokens = estimate_prompt_tokens(messages, tools);
let output_tokens = (input_tokens / 4).max(256);
let input_tokens = ((input_tokens as f64) * reserve_multiplier).ceil() as u64;
let output_tokens = ((output_tokens as f64) * reserve_multiplier).ceil() as u64;
let (input_price, output_price) = lookup_model_pricing(&context.prices, provider, model);
let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price.max(0.0);
let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price.max(0.0);
input_cost + output_cost
}
fn usage_period_label(period: UsagePeriod) -> &'static str {
match period {
UsagePeriod::Session => "session",
UsagePeriod::Day => "daily",
UsagePeriod::Month => "monthly",
}
}
fn budget_exceeded_message(
model: &str,
estimated_cost_usd: f64,
current_usd: f64,
limit_usd: f64,
period: UsagePeriod,
) -> String {
format!(
"Budget enforcement blocked request for model '{model}': projected cost (+${estimated_cost_usd:.4}) exceeds {period_label} limit (${limit_usd:.2}, current ${current_usd:.2}).",
period_label = usage_period_label(period)
)
}
#[derive(Debug, Clone)]
struct ProgressEntry {
name: String,
@ -778,36 +927,40 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
hooks: Option<&crate::hooks::HookRunner>,
excluded_tools: &[String],
progress_mode: ProgressMode,
safety_heartbeat: Option<SafetyHeartbeatConfig>,
) -> Result<String> {
let reply_target = non_cli_approval_context
.as_ref()
.map(|ctx| ctx.reply_target.clone());
SAFETY_HEARTBEAT_CONFIG
TOOL_LOOP_PROGRESS_MODE
.scope(
safety_heartbeat,
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
non_cli_approval_context,
TOOL_LOOP_REPLY_TARGET.scope(
reply_target,
run_tool_call_loop(
provider,
history,
tools_registry,
observer,
provider_name,
model,
temperature,
silent,
approval,
channel_name,
multimodal_config,
max_tool_iterations,
cancellation_token,
on_delta,
hooks,
excluded_tools,
progress_mode,
SAFETY_HEARTBEAT_CONFIG.scope(
safety_heartbeat,
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
non_cli_approval_context,
TOOL_LOOP_REPLY_TARGET.scope(
reply_target,
run_tool_call_loop(
provider,
history,
tools_registry,
observer,
provider_name,
model,
temperature,
silent,
approval,
channel_name,
multimodal_config,
max_tool_iterations,
cancellation_token,
on_delta,
hooks,
excluded_tools,
),
),
),
),
@ -890,7 +1043,12 @@ pub(crate) async fn run_tool_call_loop(
let progress_mode = TOOL_LOOP_PROGRESS_MODE
.try_with(|mode| *mode)
.unwrap_or(ProgressMode::Verbose);
let cost_enforcement_context = TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten();
let mut progress_tracker = ProgressTracker::default();
let mut active_model = model.to_string();
let bypass_non_cli_approval_for_turn =
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
if bypass_non_cli_approval_for_turn {
@ -898,7 +1056,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_one_time_all_tools_consumed",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("consumed one-time non-cli allow-all approval token"),
@ -950,6 +1108,13 @@ pub(crate) async fn run_tool_call_loop(
request_messages.push(ChatMessage::user(reminder));
}
}
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
Some(tool_specs.as_slice())
} else {
None
};
// ── Progress: LLM thinking ────────────────────────────
if should_emit_verbose_progress(progress_mode) {
@ -963,16 +1128,175 @@ pub(crate) async fn run_tool_call_loop(
}
}
if let Some(cost_ctx) = cost_enforcement_context.as_ref() {
let mut estimated_cost_usd = estimate_request_cost_usd(
cost_ctx,
provider_name,
active_model.as_str(),
&request_messages,
request_tools,
);
let mut budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
Ok(check) => Some(check),
Err(error) => {
tracing::warn!("Cost preflight check failed: {error}");
None
}
};
if matches!(cost_ctx.mode, CostEnforcementMode::RouteDown)
&& matches!(budget_check, Some(BudgetCheck::Exceeded { .. }))
{
if let Some(route_down_model) = cost_ctx
.route_down_model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
{
if route_down_model != active_model {
let previous_model = active_model.clone();
active_model = route_down_model.to_string();
estimated_cost_usd = estimate_request_cost_usd(
cost_ctx,
provider_name,
active_model.as_str(),
&request_messages,
request_tools,
);
budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
Ok(check) => Some(check),
Err(error) => {
tracing::warn!(
"Cost preflight check failed after route-down: {error}"
);
None
}
};
runtime_trace::record_event(
"cost_budget_route_down",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget exceeded on primary model; route-down candidate applied"),
serde_json::json!({
"iteration": iteration + 1,
"from_model": previous_model,
"to_model": active_model,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
}
}
if let Some(check) = budget_check {
match check {
BudgetCheck::Allowed => {}
BudgetCheck::Warning {
current_usd,
limit_usd,
period,
} => {
tracing::warn!(
model = active_model.as_str(),
period = usage_period_label(period),
current_usd,
limit_usd,
estimated_cost_usd,
"Cost budget warning threshold reached"
);
runtime_trace::record_event(
"cost_budget_warning",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget warning threshold reached"),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
BudgetCheck::Exceeded {
current_usd,
limit_usd,
period,
} => match cost_ctx.mode {
CostEnforcementMode::Warn => {
tracing::warn!(
model = active_model.as_str(),
period = usage_period_label(period),
current_usd,
limit_usd,
estimated_cost_usd,
"Cost budget exceeded (warn mode): continuing request"
);
runtime_trace::record_event(
"cost_budget_exceeded_warn_mode",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget exceeded but proceeding due to warn mode"),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
CostEnforcementMode::RouteDown | CostEnforcementMode::Block => {
let message = budget_exceeded_message(
active_model.as_str(),
estimated_cost_usd,
current_usd,
limit_usd,
period,
);
runtime_trace::record_event(
"cost_budget_blocked",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&message),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
return Err(anyhow::anyhow!(message));
}
},
}
}
}
observer.record_event(&ObserverEvent::LlmRequest {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
messages_count: history.len(),
});
runtime_trace::record_event(
"llm_request",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@ -986,23 +1310,15 @@ pub(crate) async fn run_tool_call_loop(
// Fire void hook before LLM call
if let Some(hooks) = hooks {
hooks.fire_llm_input(history, model).await;
hooks.fire_llm_input(history, active_model.as_str()).await;
}
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
Some(tool_specs.as_slice())
} else {
None
};
let chat_future = provider.chat(
ChatRequest {
messages: &request_messages,
tools: request_tools,
},
model,
active_model.as_str(),
temperature,
);
@ -1032,7 +1348,7 @@ pub(crate) async fn run_tool_call_loop(
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
@ -1062,7 +1378,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_parse_issue",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(parse_issue),
@ -1080,7 +1396,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@ -1131,7 +1447,7 @@ pub(crate) async fn run_tool_call_loop(
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(safe_error.clone()),
@ -1142,7 +1458,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&safe_error),
@ -1195,7 +1511,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_followthrough_retry",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("llm response implied follow-up action but emitted no tool call"),
@ -1223,7 +1539,7 @@ pub(crate) async fn run_tool_call_loop(
"turn_final_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@ -1299,7 +1615,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&cancelled),
@ -1341,7 +1657,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&blocked),
@ -1381,7 +1697,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_non_cli_session_grant",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("using runtime non-cli session approval grant"),
@ -1438,7 +1754,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&denied),
@ -1472,7 +1788,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&duplicate),
@ -1500,7 +1816,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_start",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@ -1560,7 +1876,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(outcome.success),
outcome.error_reason.as_deref(),
@ -1672,7 +1988,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_warning",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop pattern detected, injecting self-correction prompt"),
@ -1694,7 +2010,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_hard_stop",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop persisted after warning, stopping early"),
@ -1714,7 +2030,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_loop_exhausted",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("agent exceeded maximum tool iterations"),
@ -2150,6 +2466,8 @@ pub async fn run(
// ── Execute ──────────────────────────────────────────────────
let start = Instant::now();
let cost_enforcement_context =
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let mut final_output = String::new();
@ -2196,8 +2514,9 @@ pub async fn run(
} else {
None
};
let response = SAFETY_HEARTBEAT_CONFIG
.scope(
let response = scope_cost_enforcement_context(
cost_enforcement_context.clone(),
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@ -2220,8 +2539,9 @@ pub async fn run(
&[],
),
),
)
.await?;
),
)
.await?;
final_output = response.clone();
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let assistant_key = autosave_memory_key("assistant_resp");
@ -2373,8 +2693,9 @@ pub async fn run(
} else {
None
};
let response = match SAFETY_HEARTBEAT_CONFIG
.scope(
let response = match scope_cost_enforcement_context(
cost_enforcement_context.clone(),
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@ -2397,8 +2718,9 @@ pub async fn run(
&[],
),
),
)
.await
),
)
.await
{
Ok(resp) => resp,
Err(e) => {
@ -2684,6 +3006,8 @@ pub async fn process_message_with_session(
ChatMessage::user(&enriched),
];
let cost_enforcement_context =
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
@ -2692,8 +3016,9 @@ pub async fn process_message_with_session(
} else {
None
};
SAFETY_HEARTBEAT_CONFIG
.scope(
scope_cost_enforcement_context(
cost_enforcement_context,
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
agent_turn(
provider.as_ref(),
@ -2707,8 +3032,9 @@ pub async fn process_message_with_session(
&config.multimodal,
config.agent.max_tool_iterations,
),
)
.await
),
)
.await
}
#[cfg(test)]
@ -3623,6 +3949,7 @@ mod tests {
None,
None,
&[],
ProgressMode::Verbose,
None,
)
.await

View File

@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel;
use crate::agent::loop_::{
build_shell_policy_instructions, build_tool_instructions_from_specs,
run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext,
NonCliApprovalPrompt, SafetyHeartbeatConfig,
};
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
@ -249,6 +250,7 @@ struct ChannelRuntimeDefaults {
api_key: Option<String>,
api_url: Option<String>,
reliability: crate::config::ReliabilityConfig,
cost: crate::config::CostConfig,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -1053,6 +1055,7 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
api_key: config.api_key.clone(),
api_url: config.api_url.clone(),
reliability: config.reliability.clone(),
cost: config.cost.clone(),
}
}
@ -1098,6 +1101,7 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
api_key: ctx.api_key.clone(),
api_url: ctx.api_url.clone(),
reliability: (*ctx.reliability).clone(),
cost: crate::config::CostConfig::default(),
}
}
@ -3664,29 +3668,78 @@ or tune thresholds in config.",
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
&runtime_defaults.cost,
ctx.workspace_dir.as_path(),
);
let (approval_prompt_tx, mut approval_prompt_rx) =
tokio::sync::mpsc::unbounded_channel::<NonCliApprovalPrompt>();
let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() {
Some(NonCliApprovalContext {
sender: msg.sender.clone(),
reply_target: msg.reply_target.clone(),
prompt_tx: approval_prompt_tx,
})
} else {
drop(approval_prompt_tx);
None
};
let approval_prompt_dispatcher = if let (Some(channel_ref), true) =
(target_channel.as_ref(), non_cli_approval_context.is_some())
{
let channel = Arc::clone(channel_ref);
let reply_target = msg.reply_target.clone();
let thread_ts = msg.thread_ts.clone();
Some(tokio::spawn(async move {
while let Some(prompt) = approval_prompt_rx.recv().await {
if let Err(err) = channel
.send_approval_prompt(
&reply_target,
&prompt.request_id,
&prompt.tool_name,
&prompt.arguments,
thread_ts.clone(),
)
.await
{
tracing::warn!(
"Failed to send non-CLI approval prompt for request {}: {err}",
prompt.request_id
);
}
}
}))
} else {
None
};
let llm_result = tokio::select! {
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
run_tool_call_loop_with_reply_target(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
ctx.observer.as_ref(),
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(ctx.approval_manager.as_ref()),
msg.channel.as_str(),
Some(msg.reply_target.as_str()),
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
&excluded_tools_snapshot,
progress_mode,
crate::agent::loop_::scope_cost_enforcement_context(
cost_enforcement_context,
run_tool_call_loop_with_non_cli_approval_context(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
ctx.observer.as_ref(),
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(ctx.approval_manager.as_ref()),
msg.channel.as_str(),
non_cli_approval_context,
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
&excluded_tools_snapshot,
progress_mode,
ctx.safety_heartbeat.clone(),
),
),
) => LlmExecutionResult::Completed(result),
};
@ -3694,6 +3747,9 @@ or tune thresholds in config.",
if let Some(handle) = draft_updater {
let _ = handle.await;
}
if let Some(handle) = approval_prompt_dispatcher {
let _ = handle.await;
}
if let Some(token) = typing_cancellation.as_ref() {
token.cancel();
@ -7656,6 +7712,131 @@ BTC is currently around $65,000 based on latest tool output."#
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0);
}
/// End-to-end non-CLI approval flow: an `always_ask` tool triggers an approval
/// prompt on the channel, a second message from the same sender approves it,
/// and the original (blocked) turn then completes with the tool's output.
#[tokio::test]
async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() {
    let channel_impl = Arc::new(TelegramRecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let mut channels_by_name = HashMap::new();
    channels_by_name.insert(channel.name().to_string(), channel);

    // `mock_price` requires explicit approval on every invocation.
    let autonomy_cfg = crate::config::AutonomyConfig {
        always_ask: vec!["mock_price".to_string()],
        ..crate::config::AutonomyConfig::default()
    };

    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::new(ToolCallingProvider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("test-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 10,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(HashMap::new())),
        conversation_locks: Default::default(),
        session_config: crate::config::AgentSessionConfig::default(),
        session_manager: None,
        provider_cache: Arc::new(Mutex::new(HashMap::new())),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
        message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
        interrupt_on_new_message: false,
        multimodal: crate::config::MultimodalConfig::default(),
        hooks: None,
        non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
        query_classification: crate::config::QueryClassificationConfig::default(),
        model_routes: Vec::new(),
        approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)),
        safety_heartbeat: None,
        startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
    });

    // Run the first turn in the background; it should block awaiting approval.
    let runtime_ctx_for_first_turn = runtime_ctx.clone();
    let first_turn = tokio::spawn(async move {
        process_channel_message(
            runtime_ctx_for_first_turn,
            traits::ChannelMessage {
                id: "msg-non-cli-approval-1".to_string(),
                sender: "alice".to_string(),
                reply_target: "chat-1".to_string(),
                content: "What is the BTC price now?".to_string(),
                channel: "telegram".to_string(),
                timestamp: 1,
                thread_ts: None,
            },
            CancellationToken::new(),
        )
        .await;
    });

    // Poll (bounded at 2s) until the pending approval request appears for this
    // sender/channel/chat triple.
    let request_id = tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests(
                Some("alice"),
                Some("telegram"),
                Some("chat-1"),
            );
            if let Some(req) = pending.first() {
                break req.request_id.clone();
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    })
    .await
    .expect("pending approval request should be created for always_ask tool");

    // Second message from the same sender approves the pending request.
    process_channel_message(
        runtime_ctx.clone(),
        traits::ChannelMessage {
            id: "msg-non-cli-approval-2".to_string(),
            sender: "alice".to_string(),
            reply_target: "chat-1".to_string(),
            content: format!("/approve-allow {request_id}"),
            channel: "telegram".to_string(),
            timestamp: 2,
            thread_ts: None,
        },
        CancellationToken::new(),
    )
    .await;

    // The blocked first turn must now finish without panicking.
    tokio::time::timeout(Duration::from_secs(5), first_turn)
        .await
        .expect("first channel turn should finish after approval")
        .expect("first channel turn task should not panic");

    let sent = channel_impl.sent_messages.lock().await;
    assert!(
        sent.iter()
            .any(|entry| entry.contains("Approval required for tool `mock_price`")),
        "channel should emit non-cli approval prompt"
    );
    assert!(
        sent.iter()
            .any(|entry| entry.contains("Approved supervised execution for `mock_price`")),
        "channel should acknowledge explicit approval command"
    );
    assert!(
        sent.iter()
            .any(|entry| entry.contains("BTC is currently around")),
        "tool call should execute after approval and produce final response"
    );
    assert!(
        sent.iter().all(|entry| !entry.contains("Denied by user.")),
        "always_ask tool should not be silently denied once non-cli approval prompt path is wired"
    );
}
#[tokio::test]
async fn process_channel_message_denies_approval_management_for_unlisted_sender() {
let channel_impl = Arc::new(TelegramRecordingChannel::default());
@ -9232,6 +9413,7 @@ BTC is currently around $65,000 based on latest tool output."#
api_key: None,
api_url: None,
reliability: crate::config::ReliabilityConfig::default(),
cost: crate::config::CostConfig::default(),
},
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),

View File

@ -1026,6 +1026,11 @@ pub struct SkillsConfig {
/// If unset, defaults to `$HOME/open-skills` when enabled.
#[serde(default)]
pub open_skills_dir: Option<String>,
/// Optional allowlist of canonical directory roots for workspace skill symlink targets.
/// Symlinked workspace skills are rejected unless their resolved targets are under one
/// of these roots. Accepts absolute paths and `~/` home-relative paths.
#[serde(default)]
pub trusted_skill_roots: Vec<String>,
/// Allow script-like files in skills (`.sh`, `.bash`, `.ps1`, shebang shell files).
/// Default: `false` (secure by default).
#[serde(default)]
@ -1195,6 +1200,58 @@ pub struct CostConfig {
/// Per-model pricing (USD per 1M tokens)
#[serde(default)]
pub prices: std::collections::HashMap<String, ModelPricing>,
/// Runtime budget enforcement policy (`[cost.enforcement]`).
#[serde(default)]
pub enforcement: CostEnforcementConfig,
}
/// Budget enforcement behavior when projected spend approaches/exceeds limits.
///
/// Serialized in `snake_case`, so config files use `"warn"`, `"route_down"`,
/// or `"block"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CostEnforcementMode {
    /// Log warnings only; never block the request.
    Warn,
    /// Attempt one downgrade to a cheaper route/model, then block if still over budget.
    RouteDown,
    /// Block immediately when projected spend exceeds configured limits.
    Block,
}
/// Serde default for `CostEnforcementConfig::mode`: warn-only, the least
/// disruptive behavior.
fn default_cost_enforcement_mode() -> CostEnforcementMode {
    CostEnforcementMode::Warn
}
/// Runtime budget enforcement controls (`[cost.enforcement]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct CostEnforcementConfig {
    /// Enforcement behavior. Default: `warn`.
    #[serde(default = "default_cost_enforcement_mode")]
    pub mode: CostEnforcementMode,
    /// Optional fallback model (or `hint:*`) when `mode = "route_down"`.
    #[serde(default = "default_route_down_model")]
    pub route_down_model: Option<String>,
    /// Extra reserve added to token/cost estimates (percentage, 0-100). Default: `10`.
    /// Values above 100 are rejected by `Config::validate`.
    #[serde(default = "default_cost_reserve_percent")]
    pub reserve_percent: u8,
}
/// Serde default for `route_down_model`: downgrade via the `fast` model-route hint.
fn default_route_down_model() -> Option<String> {
    Some("hint:fast".to_string())
}

/// Serde default for `reserve_percent`: 10% headroom on spend estimates.
fn default_cost_reserve_percent() -> u8 {
    10
}

impl Default for CostEnforcementConfig {
    /// Mirrors the per-field serde defaults so programmatic construction and
    /// deserializing an empty `[cost.enforcement]` table agree.
    fn default() -> Self {
        Self {
            mode: default_cost_enforcement_mode(),
            route_down_model: default_route_down_model(),
            reserve_percent: default_cost_reserve_percent(),
        }
    }
}
/// Per-model pricing entry (USD per 1M tokens).
@ -1230,6 +1287,7 @@ impl Default for CostConfig {
warn_at_percent: default_warn_percent(),
allow_override: false,
prices: get_default_pricing(),
enforcement: CostEnforcementConfig::default(),
}
}
}
@ -7764,6 +7822,44 @@ impl Config {
anyhow::bail!("web_search.timeout_secs must be greater than 0");
}
// Cost
if self.cost.warn_at_percent > 100 {
anyhow::bail!("cost.warn_at_percent must be between 0 and 100");
}
if self.cost.enforcement.reserve_percent > 100 {
anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100");
}
if matches!(self.cost.enforcement.mode, CostEnforcementMode::RouteDown) {
let route_down_model = self
.cost
.enforcement
.route_down_model
.as_deref()
.map(str::trim)
.filter(|model| !model.is_empty())
.ok_or_else(|| {
anyhow::anyhow!(
"cost.enforcement.route_down_model must be set when mode is route_down"
)
})?;
if let Some(route_hint) = route_down_model
.strip_prefix("hint:")
.map(str::trim)
.filter(|hint| !hint.is_empty())
{
if !self
.model_routes
.iter()
.any(|route| route.hint.trim() == route_hint)
{
anyhow::bail!(
"cost.enforcement.route_down_model uses hint '{route_hint}', but no matching [[model_routes]] entry exists"
);
}
}
}
// Scheduler
if self.scheduler.max_concurrent == 0 {
anyhow::bail!("scheduler.max_concurrent must be greater than 0");
@ -13738,4 +13834,80 @@ sensitivity = 0.9
.validate()
.expect("disabled coordination should allow empty lead agent");
}
/// `CostEnforcementConfig::default()` must keep its documented defaults
/// (`warn` mode, `hint:fast` route-down target, 10% reserve) so existing
/// configs behave the same across upgrades.
#[test]
fn cost_enforcement_defaults_are_stable() {
    // Fix: was `#[test] async fn`, which does not compile (`#[test]` rejects
    // async functions); the body is fully synchronous, so plain `fn` suffices.
    let cost = CostConfig::default();
    assert_eq!(cost.enforcement.mode, CostEnforcementMode::Warn);
    assert_eq!(
        cost.enforcement.route_down_model.as_deref(),
        Some("hint:fast")
    );
    assert_eq!(cost.enforcement.reserve_percent, 10);
}
/// `[cost.enforcement]` round-trips from TOML with `mode = "route_down"`,
/// carrying the explicit route-down model and reserve through deserialization.
#[test]
fn cost_enforcement_config_parses_route_down_mode() {
    // Fix: was `#[test] async fn`, which does not compile; `toml::from_str`
    // is synchronous, so the test is a plain `fn`.
    let parsed: CostConfig = toml::from_str(
        r#"
enabled = true

[enforcement]
mode = "route_down"
route_down_model = "hint:fast"
reserve_percent = 15
"#,
    )
    .expect("cost enforcement should parse");

    assert!(parsed.enabled);
    assert_eq!(parsed.enforcement.mode, CostEnforcementMode::RouteDown);
    assert_eq!(
        parsed.enforcement.route_down_model.as_deref(),
        Some("hint:fast")
    );
    assert_eq!(parsed.enforcement.reserve_percent, 15);
}
/// `Config::validate` rejects `cost.enforcement.reserve_percent` above 100,
/// since the reserve is a percentage.
#[test]
fn validation_rejects_cost_enforcement_reserve_over_100() {
    // Fix: was `#[test] async fn` (compile error); body is synchronous.
    let mut config = Config::default();
    config.cost.enforcement.reserve_percent = 150;

    let err = config
        .validate()
        .expect_err("expected cost.enforcement.reserve_percent validation failure");
    assert!(err.to_string().contains("cost.enforcement.reserve_percent"));
}
/// A `hint:*` route-down target must reference an existing `[[model_routes]]`
/// hint; validation fails when no matching route is configured.
#[test]
fn validation_rejects_route_down_hint_without_matching_route() {
    // Fix: was `#[test] async fn` (compile error); body is synchronous.
    let mut config = Config::default();
    config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
    config.cost.enforcement.route_down_model = Some("hint:fast".to_string());

    let err = config
        .validate()
        .expect_err("route_down hint should require a matching model route");
    assert!(err
        .to_string()
        .contains("cost.enforcement.route_down_model uses hint 'fast'"));
}
/// Counterpart of the rejection test: a `hint:fast` route-down target
/// validates once a `[[model_routes]]` entry with hint `fast` exists.
#[test]
fn validation_accepts_route_down_hint_with_matching_route() {
    // Fix: was `#[test] async fn` (compile error); body is synchronous.
    let mut config = Config::default();
    config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
    config.cost.enforcement.route_down_model = Some("hint:fast".to_string());
    config.model_routes = vec![ModelRouteConfig {
        hint: "fast".to_string(),
        provider: "openrouter".to_string(),
        model: "openai/gpt-4.1-mini".to_string(),
        api_key: None,
        max_tokens: None,
        transport: None,
    }];

    config
        .validate()
        .expect("matching route_down hint route should validate");
}
}

View File

@ -333,15 +333,20 @@ the binary location.
Examples:
zeroclaw update # Update to latest version
zeroclaw update --check # Check for updates without installing
zeroclaw update --instructions # Show install-method-specific update instructions
zeroclaw update --force # Reinstall even if already up to date")]
Update {
/// Check for updates without installing
#[arg(long)]
#[arg(long, conflicts_with_all = ["force", "instructions"])]
check: bool,
/// Force update even if already at latest version
#[arg(long)]
#[arg(long, conflicts_with = "instructions")]
force: bool,
/// Show human-friendly update instructions for your installation method
#[arg(long, conflicts_with_all = ["check", "force"])]
instructions: bool,
},
/// Engage, inspect, and resume emergency-stop states.
@ -1107,9 +1112,18 @@ async fn main() -> Result<()> {
Ok(())
}
Commands::Update { check, force } => {
update::self_update(force, check).await?;
Ok(())
Commands::Update {
check,
force,
instructions,
} => {
if instructions {
update::print_update_instructions()?;
Ok(())
} else {
update::self_update(force, check).await?;
Ok(())
}
}
Commands::Estop {
@ -2630,4 +2644,41 @@ mod tests {
);
assert_eq!(payload["nested"]["non_secret"], serde_json::json!("ok"));
}
/// Renders the long help for the `update` subcommand and verifies the new
/// `--instructions` flag is advertised to users.
#[test]
fn update_help_mentions_instructions_flag() {
    let cmd = Cli::command();
    let update_cmd = cmd
        .get_subcommands()
        .find(|subcommand| subcommand.get_name() == "update")
        .expect("update subcommand must exist");

    // `write_long_help` needs a mutable command, hence the clone.
    let mut output = Vec::new();
    update_cmd
        .clone()
        .write_long_help(&mut output)
        .expect("help generation should succeed");
    let help = String::from_utf8(output).expect("help output should be utf-8");
    assert!(help.contains("--instructions"));
}
/// `zeroclaw update --instructions` must parse on its own and set only the
/// `instructions` flag (it is declared to conflict with --check/--force).
#[test]
fn update_cli_parses_instructions_flag() {
    let cli = Cli::try_parse_from(["zeroclaw", "update", "--instructions"])
        .expect("update --instructions should parse");
    match cli.command {
        Commands::Update {
            check,
            force,
            instructions,
        } => {
            assert!(!check);
            assert!(!force);
            assert!(instructions);
        }
        other => panic!("expected update command, got {other:?}"),
    }
}
}

View File

@ -103,8 +103,9 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: false,
};
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 6] = [
SQLITE_PROFILE,
SQLITE_QDRANT_HYBRID_PROFILE,
LUCID_PROFILE,
CORTEX_MEM_PROFILE,
MARKDOWN_PROFILE,
@ -194,12 +195,13 @@ mod tests {
#[test]
fn selectable_backends_are_ordered_for_onboarding() {
let backends = selectable_memory_backends();
assert_eq!(backends.len(), 5);
assert_eq!(backends.len(), 6);
assert_eq!(backends[0].key, "sqlite");
assert_eq!(backends[1].key, "lucid");
assert_eq!(backends[2].key, "cortex-mem");
assert_eq!(backends[3].key, "markdown");
assert_eq!(backends[4].key, "none");
assert_eq!(backends[1].key, "sqlite_qdrant_hybrid");
assert_eq!(backends[2].key, "lucid");
assert_eq!(backends[3].key, "cortex-mem");
assert_eq!(backends[4].key, "markdown");
assert_eq!(backends[5].key, "none");
}
#[test]

View File

@ -4091,9 +4091,68 @@ fn setup_memory() -> Result<MemoryConfig> {
let mut config = memory_config_defaults_for_backend(backend);
config.auto_save = auto_save;
if classify_memory_backend(backend) == MemoryBackendKind::SqliteQdrantHybrid {
configure_hybrid_qdrant_memory(&mut config)?;
}
Ok(config)
}
/// Interactive onboarding step for the `sqlite_qdrant_hybrid` memory backend:
/// prompts for the Qdrant URL (required), collection name, and optional API
/// key, writing the answers into `config.qdrant`.
///
/// Errors propagate from prompt I/O, or if the URL answer is effectively empty.
fn configure_hybrid_qdrant_memory(config: &mut MemoryConfig) -> Result<()> {
    print_bullet("Hybrid memory keeps local SQLite metadata and uses Qdrant for semantic ranking.");
    print_bullet("SQLite storage path stays at the default workspace database.");

    // Offer any previously configured URL as the prompt default, otherwise
    // the standard local Qdrant endpoint.
    let qdrant_url_default = config
        .qdrant
        .url
        .clone()
        .unwrap_or_else(|| "http://localhost:6333".to_string());
    let qdrant_url: String = Input::new()
        .with_prompt(" Qdrant URL")
        .default(qdrant_url_default)
        .interact_text()?;
    let qdrant_url = qdrant_url.trim();
    if qdrant_url.is_empty() {
        // Hybrid mode cannot work without a Qdrant endpoint.
        bail!("Qdrant URL is required for sqlite_qdrant_hybrid backend");
    }
    config.qdrant.url = Some(qdrant_url.to_string());

    // An empty answer keeps the current collection name.
    let qdrant_collection: String = Input::new()
        .with_prompt(" Qdrant collection")
        .default(config.qdrant.collection.clone())
        .interact_text()?;
    let qdrant_collection = qdrant_collection.trim();
    if !qdrant_collection.is_empty() {
        config.qdrant.collection = qdrant_collection.to_string();
    }

    // API key is optional; pressing Enter skips and stores `None`.
    let qdrant_api_key: String = Input::new()
        .with_prompt(" Qdrant API key (optional, Enter to skip)")
        .allow_empty(true)
        .interact_text()?;
    let qdrant_api_key = qdrant_api_key.trim();
    config.qdrant.api_key = if qdrant_api_key.is_empty() {
        None
    } else {
        Some(qdrant_api_key.to_string())
    };

    // Summarize the result; the API key value itself is never echoed.
    println!(
        " {} Qdrant: {} (collection: {}, api key: {})",
        style("").green().bold(),
        style(config.qdrant.url.as_deref().unwrap_or_default()).green(),
        style(&config.qdrant.collection).green(),
        if config.qdrant.api_key.is_some() {
            style("set").green().to_string()
        } else {
            style("not set").dim().to_string()
        }
    );

    Ok(())
}
fn setup_identity_backend() -> Result<IdentityConfig> {
print_bullet("Choose the identity format ZeroClaw should scaffold for this workspace.");
print_bullet("You can switch later in config.toml under [identity].");
@ -8515,10 +8574,11 @@ mod tests {
#[test]
fn backend_key_from_choice_maps_supported_backends() {
assert_eq!(backend_key_from_choice(0), "sqlite");
assert_eq!(backend_key_from_choice(1), "lucid");
assert_eq!(backend_key_from_choice(2), "cortex-mem");
assert_eq!(backend_key_from_choice(3), "markdown");
assert_eq!(backend_key_from_choice(4), "none");
assert_eq!(backend_key_from_choice(1), "sqlite_qdrant_hybrid");
assert_eq!(backend_key_from_choice(2), "lucid");
assert_eq!(backend_key_from_choice(3), "cortex-mem");
assert_eq!(backend_key_from_choice(4), "markdown");
assert_eq!(backend_key_from_choice(5), "none");
assert_eq!(backend_key_from_choice(999), "sqlite");
}
@ -8560,6 +8620,18 @@ mod tests {
assert_eq!(config.embedding_cache_size, 10000);
}
/// The hybrid backend keeps SQLite hygiene defaults enabled and uses the
/// default Qdrant collection name.
#[test]
fn memory_config_defaults_for_hybrid_enable_sqlite_hygiene() {
    let config = memory_config_defaults_for_backend("sqlite_qdrant_hybrid");
    assert_eq!(config.backend, "sqlite_qdrant_hybrid");
    assert!(config.auto_save);
    assert!(config.hygiene_enabled);
    assert_eq!(config.archive_after_days, 7);
    assert_eq!(config.purge_after_days, 30);
    assert_eq!(config.embedding_cache_size, 10000);
    assert_eq!(config.qdrant.collection, "zeroclaw_memories");
}
#[test]
fn memory_config_defaults_for_none_disable_sqlite_hygiene() {
let config = memory_config_defaults_for_backend("none");

View File

@ -80,7 +80,7 @@ fn default_version() -> String {
/// Load all skills from the workspace skills directory
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
load_skills_with_open_skills_config(workspace_dir, None, None, None)
load_skills_with_open_skills_config(workspace_dir, None, None, None, None)
}
/// Load skills using runtime config values (preferred at runtime).
@ -90,6 +90,7 @@ pub fn load_skills_with_config(workspace_dir: &Path, config: &crate::config::Con
Some(config.skills.open_skills_enabled),
config.skills.open_skills_dir.as_deref(),
Some(config.skills.allow_scripts),
Some(&config.skills.trusted_skill_roots),
)
}
@ -98,9 +99,12 @@ fn load_skills_with_open_skills_config(
config_open_skills_enabled: Option<bool>,
config_open_skills_dir: Option<&str>,
config_allow_scripts: Option<bool>,
config_trusted_skill_roots: Option<&[String]>,
) -> Vec<Skill> {
let mut skills = Vec::new();
let allow_scripts = config_allow_scripts.unwrap_or(false);
let trusted_skill_roots =
resolve_trusted_skill_roots(workspace_dir, config_trusted_skill_roots.unwrap_or(&[]));
if let Some(open_skills_dir) =
ensure_open_skills_repo(config_open_skills_enabled, config_open_skills_dir)
@ -108,16 +112,113 @@ fn load_skills_with_open_skills_config(
skills.extend(load_open_skills(&open_skills_dir, allow_scripts));
}
skills.extend(load_workspace_skills(workspace_dir, allow_scripts));
skills.extend(load_workspace_skills(
workspace_dir,
allow_scripts,
&trusted_skill_roots,
));
skills
}
fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
fn load_workspace_skills(
workspace_dir: &Path,
allow_scripts: bool,
trusted_skill_roots: &[PathBuf],
) -> Vec<Skill> {
let skills_dir = workspace_dir.join("skills");
load_skills_from_directory(&skills_dir, allow_scripts)
load_skills_from_directory(&skills_dir, allow_scripts, trusted_skill_roots)
}
fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
/// Expand and canonicalize the configured `[skills].trusted_skill_roots`.
///
/// Each entry may be absolute, home-relative (`~`, `~/...`, `~\...`), or
/// relative (resolved against `workspace_dir`). Entries that are empty, fail
/// to canonicalize, or do not resolve to a directory are skipped with a
/// warning. Returns sorted, deduplicated canonical directory paths.
fn resolve_trusted_skill_roots(workspace_dir: &Path, raw_roots: &[String]) -> Vec<PathBuf> {
    let home_dir = UserDirs::new().map(|dirs| dirs.home_dir().to_path_buf());

    let mut resolved = Vec::new();
    for raw in raw_roots {
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            continue;
        }

        // Home expansion. If the home directory cannot be determined, the
        // literal text is used as-is (and will then likely fail
        // canonicalization below, producing a warning).
        let expanded = if trimmed == "~" {
            home_dir.clone().unwrap_or_else(|| PathBuf::from(trimmed))
        } else if let Some(rest) = trimmed
            .strip_prefix("~/")
            .or_else(|| trimmed.strip_prefix("~\\"))
        {
            home_dir
                .as_ref()
                .map(|home| home.join(rest))
                .unwrap_or_else(|| PathBuf::from(trimmed))
        } else {
            PathBuf::from(trimmed)
        };

        // Relative roots are anchored at the workspace directory.
        let candidate = if expanded.is_relative() {
            workspace_dir.join(expanded)
        } else {
            expanded
        };

        // Canonicalize so later `starts_with` trust checks compare fully
        // resolved paths rather than textual prefixes.
        match candidate.canonicalize() {
            Ok(canonical) if canonical.is_dir() => resolved.push(canonical),
            Ok(canonical) => tracing::warn!(
                "ignoring [skills].trusted_skill_roots entry '{}': canonical path is not a directory ({})",
                trimmed,
                canonical.display()
            ),
            Err(err) => tracing::warn!(
                "ignoring [skills].trusted_skill_roots entry '{}': failed to canonicalize {} ({err})",
                trimmed,
                candidate.display()
            ),
        }
    }

    resolved.sort();
    resolved.dedup();
    resolved
}
/// Verify that a symlinked workspace skill resolves to a directory inside one
/// of the configured trusted roots.
///
/// Returns `Ok(())` only when the canonicalized link target is a directory
/// under some entry of `trusted_skill_roots`; otherwise returns an error
/// describing why the target was rejected (dangling link, non-directory
/// target, empty allowlist, or target outside every root).
fn enforce_workspace_skill_symlink_trust(
    path: &Path,
    trusted_skill_roots: &[PathBuf],
) -> Result<()> {
    // Fully resolve the link; a dangling symlink fails here with context.
    let target = path
        .canonicalize()
        .with_context(|| format!("failed to resolve skill symlink target {}", path.display()))?;

    anyhow::ensure!(
        target.is_dir(),
        "symlink target is not a directory: {}",
        target.display()
    );

    // Canonical-prefix containment check against each trusted root.
    let trusted = trusted_skill_roots
        .iter()
        .any(|root| target.starts_with(root));
    if trusted {
        return Ok(());
    }

    // Distinguish "no allowlist configured" from "target outside allowlist"
    // so the log points the operator at the right fix.
    if trusted_skill_roots.is_empty() {
        anyhow::bail!(
            "symlink target {} is not allowed because [skills].trusted_skill_roots is empty",
            target.display()
        );
    }
    anyhow::bail!(
        "symlink target {} is outside configured [skills].trusted_skill_roots",
        target.display()
    );
}
fn load_skills_from_directory(
skills_dir: &Path,
allow_scripts: bool,
trusted_skill_roots: &[PathBuf],
) -> Vec<Skill> {
if !skills_dir.exists() {
return Vec::new();
}
@ -130,7 +231,26 @@ fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Ski
for entry in entries.flatten() {
let path = entry.path();
if !path.is_dir() {
let metadata = match std::fs::symlink_metadata(&path) {
Ok(meta) => meta,
Err(err) => {
tracing::warn!(
"skipping skill entry {}: failed to read metadata ({err})",
path.display()
);
continue;
}
};
if metadata.file_type().is_symlink() {
if let Err(err) = enforce_workspace_skill_symlink_trust(&path, trusted_skill_roots) {
tracing::warn!(
"skipping untrusted symlinked skill entry {}: {err}",
path.display()
);
continue;
}
} else if !metadata.is_dir() {
continue;
}
@ -180,7 +300,7 @@ fn load_open_skills(repo_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
// as executable skills.
let nested_skills_dir = repo_dir.join("skills");
if nested_skills_dir.is_dir() {
return load_skills_from_directory(&nested_skills_dir, allow_scripts);
return load_skills_from_directory(&nested_skills_dir, allow_scripts, &[]);
}
let mut skills = Vec::new();
@ -2137,6 +2257,20 @@ pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Con
anyhow::bail!("Skill source or installed skill not found: {source}");
}
let trusted_skill_roots =
resolve_trusted_skill_roots(workspace_dir, &config.skills.trusted_skill_roots);
if let Ok(metadata) = std::fs::symlink_metadata(&target) {
if metadata.file_type().is_symlink() {
enforce_workspace_skill_symlink_trust(&target, &trusted_skill_roots)
.with_context(|| {
format!(
"trusted-symlink policy rejected audit target {}",
target.display()
)
})?;
}
}
let report = audit::audit_skill_directory_with_options(
&target,
audit::SkillAuditOptions {

View File

@ -1,6 +1,8 @@
#[cfg(test)]
mod tests {
use crate::skills::skills_dir;
use crate::config::Config;
use crate::skills::{handle_command, load_skills_with_config, skills_dir};
use crate::SkillCommands;
use std::path::Path;
use tempfile::TempDir;
@ -83,7 +85,7 @@ mod tests {
}
#[tokio::test]
async fn test_skills_symlink_permissions_and_safety() {
async fn test_workspace_symlink_loading_requires_trusted_roots() {
let tmp = TempDir::new().unwrap();
let workspace_dir = tmp.path().join("workspace");
tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
@ -93,7 +95,6 @@ mod tests {
#[cfg(unix)]
{
// Test case: Symlink outside workspace should be allowed (user responsibility)
let outside_dir = tmp.path().join("outside_skill");
tokio::fs::create_dir_all(&outside_dir).await.unwrap();
tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
@ -102,15 +103,74 @@ mod tests {
let dest_link = skills_path.join("outside_skill");
let result = std::os::unix::fs::symlink(&outside_dir, &dest_link);
assert!(result.is_ok(), "symlink creation should succeed on unix");
let mut config = Config::default();
config.workspace_dir = workspace_dir.clone();
config.config_path = workspace_dir.join("config.toml");
let blocked = load_skills_with_config(&workspace_dir, &config);
assert!(
result.is_ok(),
"Should allow symlinking to directories outside workspace"
blocked.is_empty(),
"symlinked skill should be rejected when trusted_skill_roots is empty"
);
// Should still be readable
let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await;
assert!(content.is_ok());
assert!(content.unwrap().contains("Outside Skill"));
config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
let allowed = load_skills_with_config(&workspace_dir, &config);
assert_eq!(
allowed.len(),
1,
"symlinked skill should load when target is inside trusted roots"
);
assert_eq!(allowed[0].name, "outside_skill");
}
}
/// The skills `audit` command applies the same trusted-symlink policy as skill
/// loading: a symlinked skill whose target lies outside the workspace is
/// rejected until its root is added to `[skills].trusted_skill_roots`.
#[tokio::test]
async fn test_skills_audit_respects_trusted_symlink_roots() {
    let tmp = TempDir::new().unwrap();
    let workspace_dir = tmp.path().join("workspace");
    tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
    let skills_path = skills_dir(&workspace_dir);
    tokio::fs::create_dir_all(&skills_path).await.unwrap();

    // Symlink creation below is unix-specific.
    #[cfg(unix)]
    {
        // Real skill directory outside the workspace, linked in via symlink.
        let outside_dir = tmp.path().join("outside_skill");
        tokio::fs::create_dir_all(&outside_dir).await.unwrap();
        tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
            .await
            .unwrap();
        let link_path = skills_path.join("outside_skill");
        std::os::unix::fs::symlink(&outside_dir, &link_path).unwrap();

        let mut config = Config::default();
        config.workspace_dir = workspace_dir.clone();
        config.config_path = workspace_dir.join("config.toml");

        // With no trusted roots configured, auditing the symlinked skill must fail.
        let blocked = handle_command(
            SkillCommands::Audit {
                source: "outside_skill".to_string(),
            },
            &config,
        );
        assert!(
            blocked.is_err(),
            "audit should reject symlink target when trusted roots are not configured"
        );

        // Trusting the tempdir (which contains the link target) unblocks the audit.
        config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
        let allowed = handle_command(
            SkillCommands::Audit {
                source: "outside_skill".to_string(),
            },
            &config,
        );
        assert!(
            allowed.is_ok(),
            "audit should pass when symlink target is inside a trusted root"
        );
    }
}
}

View File

@ -107,12 +107,27 @@ impl McpTransportConn for StdioTransport {
error: None,
});
}
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
.await
.context("timeout waiting for MCP response")??;
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
Ok(resp)
let deadline = std::time::Instant::now() + Duration::from_secs(RECV_TIMEOUT_SECS);
loop {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
if remaining.is_zero() {
bail!("timeout waiting for MCP response");
}
let resp_line = timeout(remaining, self.recv_raw())
.await
.context("timeout waiting for MCP response")??;
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
if resp.id.is_none() {
// Server-sent notification (e.g. `notifications/initialized`) — skip and
// keep waiting for the actual response to our request.
tracing::debug!(
"MCP stdio: skipping server notification while waiting for response"
);
continue;
}
return Ok(resp);
}
}
async fn close(&mut self) -> Result<()> {

View File

@ -82,6 +82,7 @@ pub mod web_access_config;
pub mod web_fetch;
pub mod web_search_config;
pub mod web_search_tool;
pub mod xlsx_read;
pub use apply_patch::ApplyPatchTool;
pub use bg_run::{
@ -147,6 +148,7 @@ pub use web_access_config::WebAccessConfigTool;
pub use web_fetch::WebFetchTool;
pub use web_search_config::WebSearchConfigTool;
pub use web_search_tool::WebSearchTool;
pub use xlsx_read::XlsxReadTool;
pub use auth_profile::ManageAuthProfileTool;
pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool};
@ -511,6 +513,9 @@ pub fn all_tools_with_runtime(
// PPTX text extraction
tool_arcs.push(Arc::new(PptxReadTool::new(security.clone())));
// XLSX text extraction
tool_arcs.push(Arc::new(XlsxReadTool::new(security.clone())));
// Vision tools are always available
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));

1177
src/tools/xlsx_read.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@ -5,6 +5,7 @@
use anyhow::{bail, Context, Result};
use std::env;
use std::fs;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use std::process::Command;
@ -26,6 +27,13 @@ struct Asset {
browser_download_url: String,
}
/// How the currently running binary appears to have been installed.
///
/// Used to tailor update guidance: Homebrew-managed binaries should be
/// upgraded via `brew`, while cargo/local installs can use the built-in
/// updater. Derived from path heuristics only — see
/// `detect_install_method_for_path`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InstallMethod {
    // Binary resolves into a Homebrew Cellar directory.
    Homebrew,
    // Binary lives under ~/.cargo/bin or ~/.local/bin.
    CargoOrLocal,
    // Path matched no known install location.
    Unknown,
}
/// Get the current version of the binary
pub fn current_version() -> &'static str {
env!("CARGO_PKG_VERSION")
@ -213,6 +221,79 @@ fn get_current_exe() -> Result<PathBuf> {
env::current_exe().context("Failed to get current executable path")
}
/// Classify how the binary at `resolved_path` was installed.
///
/// `resolved_path` should already have symlinks resolved (see
/// `detect_install_method`). `home_dir` is injected rather than read from the
/// environment so the heuristic is testable with arbitrary home locations.
fn detect_install_method_for_path(resolved_path: &Path, home_dir: Option<&Path>) -> InstallMethod {
    // Homebrew check is case-insensitive on the full path: Cellar paths vary
    // between /usr/local and /opt/homebrew prefixes.
    let normalized = resolved_path.to_string_lossy().to_ascii_lowercase();
    let homebrew_managed = normalized.contains("/cellar/zeroclaw/")
        || normalized.contains("/homebrew/cellar/zeroclaw/");
    if homebrew_managed {
        return InstallMethod::Homebrew;
    }

    // User-local installs: cargo's bin dir or a manual copy into ~/.local/bin.
    let user_local = home_dir.map_or(false, |home| {
        let candidates = [
            home.join(".cargo").join("bin"),
            home.join(".local").join("bin"),
        ];
        candidates.iter().any(|dir| resolved_path.starts_with(dir))
    });

    if user_local {
        InstallMethod::CargoOrLocal
    } else {
        InstallMethod::Unknown
    }
}
/// Classify the install method of `current_exe`, following symlinks first.
///
/// Canonicalization matters because package managers often expose the binary
/// through a symlinked bin/ directory; falling back to the raw path keeps the
/// heuristic best-effort when resolution fails.
fn detect_install_method(current_exe: &Path) -> InstallMethod {
    let real_path = match fs::canonicalize(current_exe) {
        Ok(resolved) => resolved,
        Err(_) => current_exe.to_path_buf(),
    };
    let home = env::var_os("HOME").map(PathBuf::from);
    detect_install_method_for_path(&real_path, home.as_deref())
}
/// Print human-friendly update instructions based on detected install method.
///
/// Emits a shared preamble, then method-specific guidance. Returns an error
/// only if the current executable path cannot be determined.
pub fn print_update_instructions() -> Result<()> {
    let current_exe = get_current_exe()?;
    let method = detect_install_method(&current_exe);

    // Preamble shared by every install method.
    println!("ZeroClaw update guide");
    println!("Detected binary: {}", current_exe.display());
    println!();
    println!("1) Check if a new release exists:");
    println!(" zeroclaw update --check");
    println!();

    // Method-specific guidance. InstallMethod derives PartialEq, so a simple
    // comparison chain reads fine here.
    if method == InstallMethod::Homebrew {
        println!("Detected install method: Homebrew");
        println!("Recommended update commands:");
        println!(" brew update");
        println!(" brew upgrade zeroclaw");
        println!(" zeroclaw --version");
        println!();
        println!(
            "Tip: avoid `zeroclaw update` on Homebrew installs unless you intentionally want to override the managed binary."
        );
    } else if method == InstallMethod::CargoOrLocal {
        println!("Detected install method: local binary (~/.cargo/bin or ~/.local/bin)");
        println!("Recommended update command:");
        println!(" zeroclaw update");
        println!("Optional force reinstall:");
        println!(" zeroclaw update --force");
        println!("Verify:");
        println!(" zeroclaw --version");
    } else {
        println!("Detected install method: unknown");
        println!("Try the built-in updater first:");
        println!(" zeroclaw update");
        println!(
            "If your package manager owns the binary, use that manager's upgrade command."
        );
        println!("Verify:");
        println!(" zeroclaw --version");
    }

    println!();
    println!("Release source: https://github.com/{GITHUB_REPO}/releases/latest");
    Ok(())
}
/// Replace the current binary with the new one
fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
// On Windows, we can't replace a running executable directly
@ -226,11 +307,43 @@ fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
let _ = fs::remove_file(&old_path);
}
// On Unix, we can overwrite the running executable
// On Unix, stage the binary in the destination directory first.
// This avoids cross-filesystem rename failures (EXDEV) from temp dirs.
#[cfg(unix)]
{
// Use rename for atomic replacement on Unix
fs::rename(new_binary, current_exe).context("Failed to replace binary")?;
use std::os::unix::fs::PermissionsExt;
let parent = current_exe
.parent()
.context("Current executable has no parent directory")?;
let binary_name = current_exe
.file_name()
.context("Current executable path is missing a file name")?
.to_string_lossy()
.into_owned();
let staged_path = parent.join(format!(".{binary_name}.new"));
let backup_path = parent.join(format!(".{binary_name}.bak"));
fs::copy(new_binary, &staged_path).context("Failed to stage updated binary")?;
fs::set_permissions(&staged_path, fs::Permissions::from_mode(0o755))
.context("Failed to set permissions on staged binary")?;
if let Err(err) = fs::remove_file(&backup_path) {
if err.kind() != ErrorKind::NotFound {
return Err(err).context("Failed to remove stale backup binary");
}
}
fs::rename(current_exe, &backup_path).context("Failed to backup current binary")?;
if let Err(err) = fs::rename(&staged_path, current_exe) {
let _ = fs::rename(&backup_path, current_exe);
let _ = fs::remove_file(&staged_path);
return Err(err).context("Failed to activate updated binary");
}
// Best-effort cleanup of backup.
let _ = fs::remove_file(&backup_path);
}
Ok(())
@ -258,6 +371,7 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
println!();
let current_exe = get_current_exe()?;
let install_method = detect_install_method(&current_exe);
println!("Current binary: {}", current_exe.display());
println!("Current version: v{}", current_version());
println!();
@ -268,6 +382,31 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
println!("Latest version: {}", release.tag_name);
if check_only {
println!();
if latest_version == current_version() {
println!("✅ Already up to date.");
} else {
println!(
"Update available: {} -> {}",
current_version(),
latest_version
);
println!("Run `zeroclaw update` to install the update.");
}
return Ok(());
}
if install_method == InstallMethod::Homebrew && !force {
println!();
println!("Detected a Homebrew-managed installation.");
println!("Use `brew upgrade zeroclaw` for the safest update path.");
println!(
"Run `zeroclaw update --force` only if you intentionally want to override Homebrew."
);
return Ok(());
}
// Check if update is needed
if latest_version == current_version() && !force {
println!();
@ -275,17 +414,6 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
return Ok(());
}
if check_only {
println!();
println!(
"Update available: {} -> {}",
current_version(),
latest_version
);
println!("Run `zeroclaw update` to install the update.");
return Ok(());
}
println!();
println!(
"Updating from v{} to {}...",
@ -315,3 +443,50 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Asset archive names must match what the release pipeline publishes:
    // .zip for Windows targets, .tar.gz for everything else.
    #[test]
    fn archive_name_uses_zip_for_windows_and_targz_elsewhere() {
        assert_eq!(
            get_archive_name("x86_64-pc-windows-msvc"),
            "zeroclaw-x86_64-pc-windows-msvc.zip"
        );
        assert_eq!(
            get_archive_name("x86_64-unknown-linux-gnu"),
            "zeroclaw-x86_64-unknown-linux-gnu.tar.gz"
        );
    }

    // A Cellar path should be detected even with no home directory supplied;
    // the match is on the lowercased path, so the capital "Cellar" still hits.
    #[test]
    fn detect_install_method_identifies_homebrew_paths() {
        let path = Path::new("/opt/homebrew/Cellar/zeroclaw/0.1.7/bin/zeroclaw");
        let method = detect_install_method_for_path(path, None);
        assert_eq!(method, InstallMethod::Homebrew);
    }

    // Both user-local bin directories (~/.cargo/bin and ~/.local/bin) are
    // classified as CargoOrLocal when they fall under the provided home dir.
    #[test]
    fn detect_install_method_identifies_local_bin_paths() {
        let home = Path::new("/Users/example");
        let cargo_path = Path::new("/Users/example/.cargo/bin/zeroclaw");
        let local_path = Path::new("/Users/example/.local/bin/zeroclaw");
        assert_eq!(
            detect_install_method_for_path(cargo_path, Some(home)),
            InstallMethod::CargoOrLocal
        );
        assert_eq!(
            detect_install_method_for_path(local_path, Some(home)),
            InstallMethod::CargoOrLocal
        );
    }

    // System paths like /usr/bin are neither Homebrew nor user-local, so the
    // heuristic should decline to guess.
    #[test]
    fn detect_install_method_returns_unknown_for_other_paths() {
        let path = Path::new("/usr/bin/zeroclaw");
        let method = detect_install_method_for_path(path, Some(Path::new("/Users/example")));
        assert_eq!(method, InstallMethod::Unknown);
    }
}