Merge branch 'main' into feat/wasm-plugin-runtime-exec
This commit is contained in:
commit
1da53f154c
11
.github/workflows/ci-queue-hygiene.yml
vendored
11
.github/workflows/ci-queue-hygiene.yml
vendored
@ -8,7 +8,7 @@ on:
|
||||
apply:
|
||||
description: "Cancel selected queued runs (false = dry-run report only)"
|
||||
required: true
|
||||
default: true
|
||||
default: false
|
||||
type: boolean
|
||||
status:
|
||||
description: "Queued-run status scope"
|
||||
@ -57,22 +57,27 @@ jobs:
|
||||
|
||||
status_scope="queued"
|
||||
max_cancel="120"
|
||||
apply_mode="true"
|
||||
apply_mode="false"
|
||||
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
|
||||
status_scope="${{ github.event.inputs.status || 'queued' }}"
|
||||
max_cancel="${{ github.event.inputs.max_cancel || '120' }}"
|
||||
apply_mode="${{ github.event.inputs.apply || 'true' }}"
|
||||
apply_mode="${{ github.event.inputs.apply || 'false' }}"
|
||||
fi
|
||||
|
||||
cmd=(python3 scripts/ci/queue_hygiene.py
|
||||
--repo "${{ github.repository }}"
|
||||
--status "${status_scope}"
|
||||
--max-cancel "${max_cancel}"
|
||||
--dedupe-workflow "CI Run"
|
||||
--dedupe-workflow "Test E2E"
|
||||
--dedupe-workflow "Docs Deploy"
|
||||
--dedupe-workflow "PR Intake Checks"
|
||||
--dedupe-workflow "PR Labeler"
|
||||
--dedupe-workflow "PR Auto Responder"
|
||||
--dedupe-workflow "Workflow Sanity"
|
||||
--dedupe-workflow "PR Label Policy Check"
|
||||
--dedupe-include-non-pr
|
||||
--non-pr-key branch
|
||||
--output-json artifacts/queue-hygiene-report.json
|
||||
--verbose)
|
||||
|
||||
|
||||
2
.github/workflows/ci-run.yml
vendored
2
.github/workflows/ci-run.yml
vendored
@ -9,7 +9,7 @@ on:
|
||||
branches: [dev, main]
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.event.pull_request.number || github.sha }}
|
||||
group: ci-run-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
|
||||
2
.github/workflows/docs-deploy.yml
vendored
2
.github/workflows/docs-deploy.yml
vendored
@ -41,7 +41,7 @@ on:
|
||||
default: ""
|
||||
|
||||
concurrency:
|
||||
group: docs-deploy-${{ github.event.pull_request.number || github.sha }}
|
||||
group: docs-deploy-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
|
||||
2
.github/workflows/test-e2e.yml
vendored
2
.github/workflows/test-e2e.yml
vendored
@ -14,7 +14,7 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: e2e-${{ github.event.pull_request.number || github.sha }}
|
||||
group: test-e2e-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
|
||||
20
README.md
20
README.md
@ -30,7 +30,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.
|
||||
|
||||
<p align="center">
|
||||
<a href="#quick-start">Getting Started</a> |
|
||||
<a href="bootstrap.sh">One-Click Setup</a> |
|
||||
<a href="docs/one-click-bootstrap.md">One-Click Setup</a> |
|
||||
<a href="docs/README.md">Docs Hub</a> |
|
||||
<a href="docs/SUMMARY.md">Docs TOC</a>
|
||||
</p>
|
||||
@ -108,11 +108,11 @@ cargo install zeroclaw
|
||||
### First Run
|
||||
|
||||
```bash
|
||||
# Start the gateway daemon
|
||||
zeroclaw gateway start
|
||||
# Start the gateway (serves the Web Dashboard API/UI)
|
||||
zeroclaw gateway
|
||||
|
||||
# Open the web UI
|
||||
zeroclaw dashboard
|
||||
# Open the dashboard URL shown in startup logs
|
||||
# (default: http://127.0.0.1:3000/)
|
||||
|
||||
# Or chat directly
|
||||
zeroclaw chat "Hello!"
|
||||
@ -120,6 +120,16 @@ zeroclaw chat "Hello!"
|
||||
|
||||
For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md).
|
||||
|
||||
### Installation Docs (Canonical Source)
|
||||
|
||||
Use repository docs as the source of truth for install/setup instructions:
|
||||
|
||||
- [README Quick Start](#quick-start)
|
||||
- [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md)
|
||||
- [docs/getting-started/README.md](docs/getting-started/README.md)
|
||||
|
||||
Issue comments can provide context, but they are not canonical installation documentation.
|
||||
|
||||
## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible)
|
||||
|
||||
Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware.
|
||||
|
||||
@ -29,6 +29,8 @@ Localized hubs: [简体中文](i18n/zh-CN/README.md) · [日本語](i18n/ja/READ
|
||||
| See project PR/issue docs snapshot | [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md) |
|
||||
| Perform i18n completion for docs changes | [i18n-guide.md](i18n-guide.md) |
|
||||
|
||||
Installation source-of-truth: keep install/run instructions in repository docs and README pages; issue comments are supplemental context only.
|
||||
|
||||
## Quick Decision Tree (10 seconds)
|
||||
|
||||
- Need first-time setup or install? → [getting-started/README.md](getting-started/README.md)
|
||||
|
||||
@ -15,6 +15,7 @@ Last verified: **February 28, 2026**.
|
||||
| `service` | Manage user-level OS service lifecycle |
|
||||
| `doctor` | Run diagnostics and freshness checks |
|
||||
| `status` | Print current configuration and system summary |
|
||||
| `update` | Check or install latest ZeroClaw release |
|
||||
| `estop` | Engage/resume emergency stop levels and inspect estop state |
|
||||
| `cron` | Manage scheduled tasks |
|
||||
| `models` | Refresh provider model catalogs |
|
||||
@ -103,6 +104,18 @@ Notes:
|
||||
- `zeroclaw service status`
|
||||
- `zeroclaw service uninstall`
|
||||
|
||||
### `update`
|
||||
|
||||
- `zeroclaw update --check` (check for new release, no install)
|
||||
- `zeroclaw update` (install latest release binary for current platform)
|
||||
- `zeroclaw update --force` (reinstall even if current version matches latest)
|
||||
- `zeroclaw update --instructions` (print install-method-specific guidance)
|
||||
|
||||
Notes:
|
||||
|
||||
- If ZeroClaw is installed via Homebrew, prefer `brew upgrade zeroclaw`.
|
||||
- `update --instructions` detects common install methods and prints the safest path.
|
||||
|
||||
### `cron`
|
||||
|
||||
- `zeroclaw cron list`
|
||||
@ -264,6 +277,11 @@ Registry packages are installed to `~/.zeroclaw/workspace/skills/<name>/`.
|
||||
|
||||
Use `skills audit` to manually validate a candidate skill directory (or an installed skill by name) before sharing it.
|
||||
|
||||
Workspace symlink policy:
|
||||
- Symlinked entries under `~/.zeroclaw/workspace/skills/` are blocked by default.
|
||||
- To allow shared local skill directories, set `[skills].trusted_skill_roots` in `config.toml`.
|
||||
- A symlinked skill is accepted only when its resolved canonical target is inside one of the trusted roots.
|
||||
|
||||
Skill manifests (`SKILL.toml`) support `prompts` and `[[tools]]`; both are injected into the agent system prompt at runtime, so the model can follow skill instructions without manually reading skill files.
|
||||
|
||||
### `migrate`
|
||||
|
||||
@ -536,6 +536,7 @@ Notes:
|
||||
|---|---|---|
|
||||
| `open_skills_enabled` | `false` | Opt-in loading/sync of community `open-skills` repository |
|
||||
| `open_skills_dir` | unset | Optional local path for `open-skills` (defaults to `$HOME/open-skills` when enabled) |
|
||||
| `trusted_skill_roots` | `[]` | Allowlist of directory roots for symlink targets in `workspace/skills/*` |
|
||||
| `prompt_injection_mode` | `full` | Skill prompt verbosity: `full` (inline instructions/tools) or `compact` (name/description/location only) |
|
||||
| `clawhub_token` | unset | Optional Bearer token for authenticated ClawHub skill downloads |
|
||||
|
||||
@ -548,7 +549,8 @@ Notes:
|
||||
- `ZEROCLAW_SKILLS_PROMPT_MODE` accepts `full` or `compact`.
|
||||
- Precedence for enable flag: `ZEROCLAW_OPEN_SKILLS_ENABLED` → `skills.open_skills_enabled` in `config.toml` → default `false`.
|
||||
- `prompt_injection_mode = "compact"` is recommended on low-context local models to reduce startup prompt size while keeping skill files available on demand.
|
||||
- Skill loading and `zeroclaw skills install` both apply a static security audit. Skills that contain symlinks, script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
|
||||
- Symlinked workspace skills are blocked by default. Set `trusted_skill_roots` to allow local shared-skill directories after explicit trust review.
|
||||
- `zeroclaw skills install` and `zeroclaw skills audit` apply a static security audit. Skills that contain script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected.
|
||||
- `clawhub_token` is sent as `Authorization: Bearer <token>` when downloading from ClawHub. Obtain a token from [https://clawhub.ai](https://clawhub.ai) after signing in. Required if the API returns 429 (rate-limited) or 401 (unauthorized) for anonymous requests.
|
||||
|
||||
**ClawHub token example:**
|
||||
|
||||
@ -20,6 +20,13 @@ If both exist, your shell `PATH` order decides which one runs.
|
||||
|
||||
## 2) Update on macOS
|
||||
|
||||
Quick way to get install-method-specific guidance:
|
||||
|
||||
```bash
|
||||
zeroclaw update --instructions
|
||||
zeroclaw update --check
|
||||
```
|
||||
|
||||
### A) Homebrew install
|
||||
|
||||
```bash
|
||||
@ -54,6 +61,13 @@ Re-run your download/install flow with the latest release asset, then verify:
|
||||
zeroclaw --version
|
||||
```
|
||||
|
||||
You can also use the built-in updater for manual/local installs:
|
||||
|
||||
```bash
|
||||
zeroclaw update
|
||||
zeroclaw --version
|
||||
```
|
||||
|
||||
## 3) Uninstall on macOS
|
||||
|
||||
### A) Stop and remove background service first
|
||||
|
||||
@ -423,11 +423,18 @@ string_to_bool() {
|
||||
}
|
||||
|
||||
guided_input_stream() {
|
||||
if [[ -t 0 ]]; then
|
||||
# Some constrained Linux containers report interactive stdin but deny opening
|
||||
# /dev/stdin directly. Probe readability before selecting it.
|
||||
if [[ -t 0 ]] && (: </dev/stdin) 2>/dev/null; then
|
||||
echo "/dev/stdin"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ -t 0 ]] && (: </proc/self/fd/0) 2>/dev/null; then
|
||||
echo "/proc/self/fd/0"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if (: </dev/tty) 2>/dev/null; then
|
||||
echo "/dev/tty"
|
||||
return 0
|
||||
|
||||
@ -66,6 +66,15 @@ def parse_args() -> argparse.Namespace:
|
||||
action="store_true",
|
||||
help="Also dedupe non-PR runs (push/manual). Default dedupe scope is PR-originated runs only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-pr-key",
|
||||
default="sha",
|
||||
choices=["sha", "branch"],
|
||||
help=(
|
||||
"Identity key mode for non-PR dedupe when --dedupe-include-non-pr is enabled: "
|
||||
"'sha' keeps one run per commit (default), 'branch' keeps one run per branch."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-cancel",
|
||||
type=int,
|
||||
@ -165,7 +174,7 @@ def parse_timestamp(value: str | None) -> datetime:
|
||||
return datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
|
||||
def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
|
||||
def run_identity_key(run: dict[str, Any], *, non_pr_key: str) -> tuple[str, str, str, str]:
|
||||
name = str(run.get("name", ""))
|
||||
event = str(run.get("event", ""))
|
||||
head_branch = str(run.get("head_branch", ""))
|
||||
@ -179,7 +188,10 @@ def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]:
|
||||
if pr_number:
|
||||
# For PR traffic, cancel stale runs across synchronize updates for the same PR.
|
||||
return (name, event, f"pr:{pr_number}", "")
|
||||
# For push/manual traffic, key by SHA to avoid collapsing distinct commits.
|
||||
if non_pr_key == "branch":
|
||||
# Branch-level supersedence for push/manual lanes.
|
||||
return (name, event, head_branch, "")
|
||||
# SHA-level supersedence for push/manual lanes.
|
||||
return (name, event, head_branch, head_sha)
|
||||
|
||||
|
||||
@ -189,6 +201,7 @@ def collect_candidates(
|
||||
dedupe_workflows: set[str],
|
||||
*,
|
||||
include_non_pr: bool,
|
||||
non_pr_key: str,
|
||||
) -> tuple[list[dict[str, Any]], Counter[str]]:
|
||||
reasons_by_id: dict[int, set[str]] = defaultdict(set)
|
||||
runs_by_id: dict[int, dict[str, Any]] = {}
|
||||
@ -220,7 +233,7 @@ def collect_candidates(
|
||||
has_pr_context = isinstance(pull_requests, list) and len(pull_requests) > 0
|
||||
if is_pr_event and not has_pr_context and not include_non_pr:
|
||||
continue
|
||||
key = run_identity_key(run)
|
||||
key = run_identity_key(run, non_pr_key=non_pr_key)
|
||||
by_workflow[name][key].append(run)
|
||||
|
||||
for groups in by_workflow.values():
|
||||
@ -324,6 +337,7 @@ def main() -> int:
|
||||
obsolete_workflows,
|
||||
dedupe_workflows,
|
||||
include_non_pr=args.dedupe_include_non_pr,
|
||||
non_pr_key=args.non_pr_key,
|
||||
)
|
||||
|
||||
capped = selected[: max(0, args.max_cancel)]
|
||||
@ -338,6 +352,7 @@ def main() -> int:
|
||||
"obsolete_workflows": sorted(obsolete_workflows),
|
||||
"dedupe_workflows": sorted(dedupe_workflows),
|
||||
"dedupe_include_non_pr": args.dedupe_include_non_pr,
|
||||
"non_pr_key": args.non_pr_key,
|
||||
"max_cancel": args.max_cancel,
|
||||
},
|
||||
"counts": {
|
||||
|
||||
@ -3759,6 +3759,119 @@ class CiScriptsBehaviorTest(unittest.TestCase):
|
||||
planned_ids = [item["id"] for item in report["planned_actions"]]
|
||||
self.assertEqual(planned_ids, [101, 102])
|
||||
|
||||
def test_queue_hygiene_non_pr_branch_mode_dedupes_push_runs(self) -> None:
    """With ``--non-pr-key branch`` an older push run on the same branch is superseded.

    Two push runs target ``main`` and one targets ``dev``; only the older
    ``main`` run (201) is expected to be planned, superseded by run 202.
    """

    def push_run(run_id, branch, sha, created_at):
        # Key order mirrors the fixture layout used by the other queue-hygiene tests.
        return {
            "id": run_id,
            "name": "CI Run",
            "event": "push",
            "head_branch": branch,
            "head_sha": sha,
            "created_at": created_at,
        }

    runs_json = self.tmp / "runs-non-pr-branch.json"
    output_json = self.tmp / "queue-hygiene-non-pr-branch.json"

    fixture = {
        "workflow_runs": [
            push_run(201, "main", "sha-201", "2026-02-27T20:00:00Z"),
            push_run(202, "main", "sha-202", "2026-02-27T20:01:00Z"),
            push_run(203, "dev", "sha-203", "2026-02-27T20:02:00Z"),
        ]
    }
    runs_json.write_text(json.dumps(fixture) + "\n", encoding="utf-8")

    argv = ["python3", self._script("queue_hygiene.py")]
    argv += ["--runs-json", str(runs_json)]
    argv += ["--dedupe-workflow", "CI Run"]
    argv += ["--dedupe-include-non-pr"]
    argv += ["--non-pr-key", "branch"]
    argv += ["--output-json", str(output_json)]
    proc = run_cmd(argv)
    self.assertEqual(proc.returncode, 0, msg=proc.stderr)

    report = json.loads(output_json.read_text(encoding="utf-8"))
    self.assertEqual(report["counts"]["candidate_runs_before_cap"], 1)

    planned = report["planned_actions"]
    self.assertEqual([action["id"] for action in planned], [201])

    # The surviving (newer) run must be named as the superseding run.
    superseded_by_202 = [
        reason
        for reason in planned[0]["reasons"]
        if reason.startswith("dedupe-superseded-by:202")
    ]
    self.assertTrue(bool(superseded_by_202))
    self.assertEqual(report["policies"]["non_pr_key"], "branch")
|
||||
|
||||
def test_queue_hygiene_non_pr_sha_mode_keeps_distinct_push_commits(self) -> None:
    """Without ``--non-pr-key`` the default ``sha`` mode keeps runs for distinct commits.

    Two push runs on ``main`` with different SHAs must both survive, so the
    report plans no actions and records ``sha`` as the effective policy.
    """

    def push_run(run_id, sha, created_at):
        # Both runs share workflow, event and branch; only id/SHA/time differ.
        return {
            "id": run_id,
            "name": "CI Run",
            "event": "push",
            "head_branch": "main",
            "head_sha": sha,
            "created_at": created_at,
        }

    runs_json = self.tmp / "runs-non-pr-sha.json"
    output_json = self.tmp / "queue-hygiene-non-pr-sha.json"

    fixture = {
        "workflow_runs": [
            push_run(301, "sha-301", "2026-02-27T20:00:00Z"),
            push_run(302, "sha-302", "2026-02-27T20:01:00Z"),
        ]
    }
    runs_json.write_text(json.dumps(fixture) + "\n", encoding="utf-8")

    argv = ["python3", self._script("queue_hygiene.py")]
    argv += ["--runs-json", str(runs_json)]
    argv += ["--dedupe-workflow", "CI Run"]
    argv += ["--dedupe-include-non-pr"]
    argv += ["--output-json", str(output_json)]
    proc = run_cmd(argv)
    self.assertEqual(proc.returncode, 0, msg=proc.stderr)

    report = json.loads(output_json.read_text(encoding="utf-8"))
    self.assertEqual(report["counts"]["candidate_runs_before_cap"], 0)
    self.assertEqual(report["planned_actions"], [])
    self.assertEqual(report["policies"]["non_pr_key"], "sha")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
|
||||
use crate::config::schema::{CostEnforcementMode, ModelPricing};
|
||||
use crate::config::{Config, ProgressMode};
|
||||
use crate::cost::{BudgetCheck, CostTracker, UsagePeriod};
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
|
||||
@ -19,9 +21,11 @@ use rustyline::hint::Hinter;
|
||||
use rustyline::validate::Validator;
|
||||
use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{BTreeSet, HashSet};
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::fmt::Write;
|
||||
use std::future::Future;
|
||||
use std::io::Write as _;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@ -297,6 +301,7 @@ tokio::task_local! {
|
||||
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
|
||||
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
|
||||
static TOOL_LOOP_PROGRESS_MODE: ProgressMode;
|
||||
static TOOL_LOOP_COST_ENFORCEMENT_CONTEXT: Option<CostEnforcementContext>;
|
||||
}
|
||||
|
||||
/// Configuration for periodic safety-constraint re-injection (heartbeat).
|
||||
@ -308,6 +313,56 @@ pub(crate) struct SafetyHeartbeatConfig {
|
||||
pub interval: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CostEnforcementContext {
|
||||
tracker: Arc<CostTracker>,
|
||||
prices: HashMap<String, ModelPricing>,
|
||||
mode: CostEnforcementMode,
|
||||
route_down_model: Option<String>,
|
||||
reserve_percent: u8,
|
||||
}
|
||||
|
||||
pub(crate) fn create_cost_enforcement_context(
|
||||
cost_config: &crate::config::CostConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> Option<CostEnforcementContext> {
|
||||
if !cost_config.enabled {
|
||||
return None;
|
||||
}
|
||||
let tracker = match CostTracker::new(cost_config.clone(), workspace_dir) {
|
||||
Ok(tracker) => Arc::new(tracker),
|
||||
Err(error) => {
|
||||
tracing::warn!("Cost budget preflight disabled: failed to initialize tracker: {error}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let route_down_model = cost_config
|
||||
.enforcement
|
||||
.route_down_model
|
||||
.clone()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty());
|
||||
Some(CostEnforcementContext {
|
||||
tracker,
|
||||
prices: cost_config.prices.clone(),
|
||||
mode: cost_config.enforcement.mode,
|
||||
route_down_model,
|
||||
reserve_percent: cost_config.enforcement.reserve_percent.min(100),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn scope_cost_enforcement_context<F>(
|
||||
context: Option<CostEnforcementContext>,
|
||||
future: F,
|
||||
) -> F::Output
|
||||
where
|
||||
F: Future,
|
||||
{
|
||||
TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
|
||||
.scope(context, future)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Reports whether the safety heartbeat is due at this `counter` value.
///
/// The heartbeat is due on every positive multiple of `interval`; an
/// `interval` of zero never fires, and neither does a `counter` of zero.
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
    // Guard first so the modulo below can never divide by zero.
    if interval == 0 || counter == 0 {
        return false;
    }
    counter % interval == 0
}
|
||||
@ -320,6 +375,100 @@ fn should_emit_tool_progress(mode: ProgressMode) -> bool {
|
||||
mode != ProgressMode::Off
|
||||
}
|
||||
|
||||
fn estimate_prompt_tokens(
|
||||
messages: &[ChatMessage],
|
||||
tools: Option<&[crate::tools::ToolSpec]>,
|
||||
) -> u64 {
|
||||
let message_chars: usize = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
msg.role
|
||||
.len()
|
||||
.saturating_add(msg.content.chars().count())
|
||||
.saturating_add(16)
|
||||
})
|
||||
.sum();
|
||||
let tool_chars: usize = tools
|
||||
.map(|specs| {
|
||||
specs
|
||||
.iter()
|
||||
.map(|spec| serde_json::to_string(spec).map_or(0, |value| value.chars().count()))
|
||||
.sum()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
let total_chars = message_chars.saturating_add(tool_chars);
|
||||
let char_estimate = (total_chars as f64 / 4.0).ceil() as u64;
|
||||
let framing_overhead = (messages.len() as u64).saturating_mul(6).saturating_add(64);
|
||||
char_estimate.saturating_add(framing_overhead)
|
||||
}
|
||||
|
||||
fn lookup_model_pricing(
|
||||
prices: &HashMap<String, ModelPricing>,
|
||||
provider: &str,
|
||||
model: &str,
|
||||
) -> (f64, f64) {
|
||||
let full_name = format!("{provider}/{model}");
|
||||
if let Some(pricing) = prices.get(&full_name) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
if let Some(pricing) = prices.get(model) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
for (key, pricing) in prices {
|
||||
let key_model = key.split('/').next_back().unwrap_or(key);
|
||||
if model.starts_with(key_model) || key_model.starts_with(model) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
let normalized_model = model.replace('-', ".");
|
||||
let normalized_key = key_model.replace('-', ".");
|
||||
if normalized_model.contains(&normalized_key) || normalized_key.contains(&normalized_model)
|
||||
{
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
}
|
||||
(3.0, 15.0)
|
||||
}
|
||||
|
||||
fn estimate_request_cost_usd(
|
||||
context: &CostEnforcementContext,
|
||||
provider: &str,
|
||||
model: &str,
|
||||
messages: &[ChatMessage],
|
||||
tools: Option<&[crate::tools::ToolSpec]>,
|
||||
) -> f64 {
|
||||
let reserve_multiplier = 1.0 + (f64::from(context.reserve_percent) / 100.0);
|
||||
let input_tokens = estimate_prompt_tokens(messages, tools);
|
||||
let output_tokens = (input_tokens / 4).max(256);
|
||||
let input_tokens = ((input_tokens as f64) * reserve_multiplier).ceil() as u64;
|
||||
let output_tokens = ((output_tokens as f64) * reserve_multiplier).ceil() as u64;
|
||||
|
||||
let (input_price, output_price) = lookup_model_pricing(&context.prices, provider, model);
|
||||
let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price.max(0.0);
|
||||
let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price.max(0.0);
|
||||
input_cost + output_cost
|
||||
}
|
||||
|
||||
fn usage_period_label(period: UsagePeriod) -> &'static str {
|
||||
match period {
|
||||
UsagePeriod::Session => "session",
|
||||
UsagePeriod::Day => "daily",
|
||||
UsagePeriod::Month => "monthly",
|
||||
}
|
||||
}
|
||||
|
||||
fn budget_exceeded_message(
|
||||
model: &str,
|
||||
estimated_cost_usd: f64,
|
||||
current_usd: f64,
|
||||
limit_usd: f64,
|
||||
period: UsagePeriod,
|
||||
) -> String {
|
||||
format!(
|
||||
"Budget enforcement blocked request for model '{model}': projected cost (+${estimated_cost_usd:.4}) exceeds {period_label} limit (${limit_usd:.2}, current ${current_usd:.2}).",
|
||||
period_label = usage_period_label(period)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ProgressEntry {
|
||||
name: String,
|
||||
@ -778,36 +927,40 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
progress_mode: ProgressMode,
|
||||
safety_heartbeat: Option<SafetyHeartbeatConfig>,
|
||||
) -> Result<String> {
|
||||
let reply_target = non_cli_approval_context
|
||||
.as_ref()
|
||||
.map(|ctx| ctx.reply_target.clone());
|
||||
|
||||
SAFETY_HEARTBEAT_CONFIG
|
||||
TOOL_LOOP_PROGRESS_MODE
|
||||
.scope(
|
||||
safety_heartbeat,
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
progress_mode,
|
||||
SAFETY_HEARTBEAT_CONFIG.scope(
|
||||
safety_heartbeat,
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -890,7 +1043,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let progress_mode = TOOL_LOOP_PROGRESS_MODE
|
||||
.try_with(|mode| *mode)
|
||||
.unwrap_or(ProgressMode::Verbose);
|
||||
let cost_enforcement_context = TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten();
|
||||
let mut progress_tracker = ProgressTracker::default();
|
||||
let mut active_model = model.to_string();
|
||||
let bypass_non_cli_approval_for_turn =
|
||||
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
|
||||
if bypass_non_cli_approval_for_turn {
|
||||
@ -898,7 +1056,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"approval_bypass_one_time_all_tools_consumed",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("consumed one-time non-cli allow-all approval token"),
|
||||
@ -950,6 +1108,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
request_messages.push(ChatMessage::user(reminder));
|
||||
}
|
||||
}
|
||||
// Unified path via Provider::chat so provider-specific native tool logic
|
||||
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
|
||||
let request_tools = if use_native_tools {
|
||||
Some(tool_specs.as_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// ── Progress: LLM thinking ────────────────────────────
|
||||
if should_emit_verbose_progress(progress_mode) {
|
||||
@ -963,16 +1128,175 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cost_ctx) = cost_enforcement_context.as_ref() {
|
||||
let mut estimated_cost_usd = estimate_request_cost_usd(
|
||||
cost_ctx,
|
||||
provider_name,
|
||||
active_model.as_str(),
|
||||
&request_messages,
|
||||
request_tools,
|
||||
);
|
||||
|
||||
let mut budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
|
||||
Ok(check) => Some(check),
|
||||
Err(error) => {
|
||||
tracing::warn!("Cost preflight check failed: {error}");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if matches!(cost_ctx.mode, CostEnforcementMode::RouteDown)
|
||||
&& matches!(budget_check, Some(BudgetCheck::Exceeded { .. }))
|
||||
{
|
||||
if let Some(route_down_model) = cost_ctx
|
||||
.route_down_model
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
if route_down_model != active_model {
|
||||
let previous_model = active_model.clone();
|
||||
active_model = route_down_model.to_string();
|
||||
estimated_cost_usd = estimate_request_cost_usd(
|
||||
cost_ctx,
|
||||
provider_name,
|
||||
active_model.as_str(),
|
||||
&request_messages,
|
||||
request_tools,
|
||||
);
|
||||
budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
|
||||
Ok(check) => Some(check),
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
"Cost preflight check failed after route-down: {error}"
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
runtime_trace::record_event(
|
||||
"cost_budget_route_down",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("budget exceeded on primary model; route-down candidate applied"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"from_model": previous_model,
|
||||
"to_model": active_model,
|
||||
"estimated_cost_usd": estimated_cost_usd,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(check) = budget_check {
|
||||
match check {
|
||||
BudgetCheck::Allowed => {}
|
||||
BudgetCheck::Warning {
|
||||
current_usd,
|
||||
limit_usd,
|
||||
period,
|
||||
} => {
|
||||
tracing::warn!(
|
||||
model = active_model.as_str(),
|
||||
period = usage_period_label(period),
|
||||
current_usd,
|
||||
limit_usd,
|
||||
estimated_cost_usd,
|
||||
"Cost budget warning threshold reached"
|
||||
);
|
||||
runtime_trace::record_event(
|
||||
"cost_budget_warning",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("budget warning threshold reached"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"period": usage_period_label(period),
|
||||
"current_usd": current_usd,
|
||||
"limit_usd": limit_usd,
|
||||
"estimated_cost_usd": estimated_cost_usd,
|
||||
}),
|
||||
);
|
||||
}
|
||||
BudgetCheck::Exceeded {
|
||||
current_usd,
|
||||
limit_usd,
|
||||
period,
|
||||
} => match cost_ctx.mode {
|
||||
CostEnforcementMode::Warn => {
|
||||
tracing::warn!(
|
||||
model = active_model.as_str(),
|
||||
period = usage_period_label(period),
|
||||
current_usd,
|
||||
limit_usd,
|
||||
estimated_cost_usd,
|
||||
"Cost budget exceeded (warn mode): continuing request"
|
||||
);
|
||||
runtime_trace::record_event(
|
||||
"cost_budget_exceeded_warn_mode",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("budget exceeded but proceeding due to warn mode"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"period": usage_period_label(period),
|
||||
"current_usd": current_usd,
|
||||
"limit_usd": limit_usd,
|
||||
"estimated_cost_usd": estimated_cost_usd,
|
||||
}),
|
||||
);
|
||||
}
|
||||
CostEnforcementMode::RouteDown | CostEnforcementMode::Block => {
|
||||
let message = budget_exceeded_message(
|
||||
active_model.as_str(),
|
||||
estimated_cost_usd,
|
||||
current_usd,
|
||||
limit_usd,
|
||||
period,
|
||||
);
|
||||
runtime_trace::record_event(
|
||||
"cost_budget_blocked",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&message),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"period": usage_period_label(period),
|
||||
"current_usd": current_usd,
|
||||
"limit_usd": limit_usd,
|
||||
"estimated_cost_usd": estimated_cost_usd,
|
||||
}),
|
||||
);
|
||||
return Err(anyhow::anyhow!(message));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmRequest {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
model: active_model.clone(),
|
||||
messages_count: history.len(),
|
||||
});
|
||||
runtime_trace::record_event(
|
||||
"llm_request",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
None,
|
||||
None,
|
||||
@ -986,23 +1310,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
|
||||
// Fire void hook before LLM call
|
||||
if let Some(hooks) = hooks {
|
||||
hooks.fire_llm_input(history, model).await;
|
||||
hooks.fire_llm_input(history, active_model.as_str()).await;
|
||||
}
|
||||
|
||||
// Unified path via Provider::chat so provider-specific native tool logic
|
||||
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
|
||||
let request_tools = if use_native_tools {
|
||||
Some(tool_specs.as_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let chat_future = provider.chat(
|
||||
ChatRequest {
|
||||
messages: &request_messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
active_model.as_str(),
|
||||
temperature,
|
||||
);
|
||||
|
||||
@ -1032,7 +1348,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
model: active_model.clone(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
@ -1062,7 +1378,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_parse_issue",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(parse_issue),
|
||||
@ -1080,7 +1396,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
@ -1131,7 +1447,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
model: active_model.clone(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(safe_error.clone()),
|
||||
@ -1142,7 +1458,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&safe_error),
|
||||
@ -1195,7 +1511,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_followthrough_retry",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("llm response implied follow-up action but emitted no tool call"),
|
||||
@ -1223,7 +1539,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"turn_final_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
@ -1299,7 +1615,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_result",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&cancelled),
|
||||
@ -1341,7 +1657,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_result",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&blocked),
|
||||
@ -1381,7 +1697,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"approval_bypass_non_cli_session_grant",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("using runtime non-cli session approval grant"),
|
||||
@ -1438,7 +1754,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_result",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&denied),
|
||||
@ -1472,7 +1788,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_result",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&duplicate),
|
||||
@ -1500,7 +1816,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_start",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
None,
|
||||
None,
|
||||
@ -1560,7 +1876,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_call_result",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(outcome.success),
|
||||
outcome.error_reason.as_deref(),
|
||||
@ -1672,7 +1988,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"loop_detected_warning",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("loop pattern detected, injecting self-correction prompt"),
|
||||
@ -1694,7 +2010,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"loop_detected_hard_stop",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("loop persisted after warning, stopping early"),
|
||||
@ -1714,7 +2030,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"tool_loop_exhausted",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_model.as_str()),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("agent exceeded maximum tool iterations"),
|
||||
@ -2150,6 +2466,8 @@ pub async fn run(
|
||||
|
||||
// ── Execute ──────────────────────────────────────────────────
|
||||
let start = Instant::now();
|
||||
let cost_enforcement_context =
|
||||
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
|
||||
|
||||
let mut final_output = String::new();
|
||||
|
||||
@ -2196,8 +2514,9 @@ pub async fn run(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let response = SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
let response = scope_cost_enforcement_context(
|
||||
cost_enforcement_context.clone(),
|
||||
SAFETY_HEARTBEAT_CONFIG.scope(
|
||||
hb_cfg,
|
||||
LOOP_DETECTION_CONFIG.scope(
|
||||
ld_cfg,
|
||||
@ -2220,8 +2539,9 @@ pub async fn run(
|
||||
&[],
|
||||
),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let assistant_key = autosave_memory_key("assistant_resp");
|
||||
@ -2373,8 +2693,9 @@ pub async fn run(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let response = match SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
let response = match scope_cost_enforcement_context(
|
||||
cost_enforcement_context.clone(),
|
||||
SAFETY_HEARTBEAT_CONFIG.scope(
|
||||
hb_cfg,
|
||||
LOOP_DETECTION_CONFIG.scope(
|
||||
ld_cfg,
|
||||
@ -2397,8 +2718,9 @@ pub async fn run(
|
||||
&[],
|
||||
),
|
||||
),
|
||||
)
|
||||
.await
|
||||
),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
@ -2684,6 +3006,8 @@ pub async fn process_message_with_session(
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
|
||||
let cost_enforcement_context =
|
||||
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
|
||||
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
|
||||
Some(SafetyHeartbeatConfig {
|
||||
body: security.summary_for_heartbeat(),
|
||||
@ -2692,8 +3016,9 @@ pub async fn process_message_with_session(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
scope_cost_enforcement_context(
|
||||
cost_enforcement_context,
|
||||
SAFETY_HEARTBEAT_CONFIG.scope(
|
||||
hb_cfg,
|
||||
agent_turn(
|
||||
provider.as_ref(),
|
||||
@ -2707,8 +3032,9 @@ pub async fn process_message_with_session(
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
),
|
||||
)
|
||||
.await
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -3623,6 +3949,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
ProgressMode::Verbose,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
|
||||
@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{
|
||||
build_shell_policy_instructions, build_tool_instructions_from_specs,
|
||||
run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig,
|
||||
run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext,
|
||||
NonCliApprovalPrompt, SafetyHeartbeatConfig,
|
||||
};
|
||||
use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager};
|
||||
use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError};
|
||||
@ -249,6 +250,7 @@ struct ChannelRuntimeDefaults {
|
||||
api_key: Option<String>,
|
||||
api_url: Option<String>,
|
||||
reliability: crate::config::ReliabilityConfig,
|
||||
cost: crate::config::CostConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -1053,6 +1055,7 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
|
||||
api_key: config.api_key.clone(),
|
||||
api_url: config.api_url.clone(),
|
||||
reliability: config.reliability.clone(),
|
||||
cost: config.cost.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1098,6 +1101,7 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
|
||||
api_key: ctx.api_key.clone(),
|
||||
api_url: ctx.api_url.clone(),
|
||||
reliability: (*ctx.reliability).clone(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -3664,29 +3668,78 @@ or tune thresholds in config.",
|
||||
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
|
||||
&runtime_defaults.cost,
|
||||
ctx.workspace_dir.as_path(),
|
||||
);
|
||||
|
||||
let (approval_prompt_tx, mut approval_prompt_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<NonCliApprovalPrompt>();
|
||||
let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() {
|
||||
Some(NonCliApprovalContext {
|
||||
sender: msg.sender.clone(),
|
||||
reply_target: msg.reply_target.clone(),
|
||||
prompt_tx: approval_prompt_tx,
|
||||
})
|
||||
} else {
|
||||
drop(approval_prompt_tx);
|
||||
None
|
||||
};
|
||||
let approval_prompt_dispatcher = if let (Some(channel_ref), true) =
|
||||
(target_channel.as_ref(), non_cli_approval_context.is_some())
|
||||
{
|
||||
let channel = Arc::clone(channel_ref);
|
||||
let reply_target = msg.reply_target.clone();
|
||||
let thread_ts = msg.thread_ts.clone();
|
||||
Some(tokio::spawn(async move {
|
||||
while let Some(prompt) = approval_prompt_rx.recv().await {
|
||||
if let Err(err) = channel
|
||||
.send_approval_prompt(
|
||||
&reply_target,
|
||||
&prompt.request_id,
|
||||
&prompt.tool_name,
|
||||
&prompt.arguments,
|
||||
thread_ts.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to send non-CLI approval prompt for request {}: {err}",
|
||||
prompt.request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let llm_result = tokio::select! {
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_budget_secs),
|
||||
run_tool_call_loop_with_reply_target(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
Some(ctx.approval_manager.as_ref()),
|
||||
msg.channel.as_str(),
|
||||
Some(msg.reply_target.as_str()),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx,
|
||||
ctx.hooks.as_deref(),
|
||||
&excluded_tools_snapshot,
|
||||
progress_mode,
|
||||
crate::agent::loop_::scope_cost_enforcement_context(
|
||||
cost_enforcement_context,
|
||||
run_tool_call_loop_with_non_cli_approval_context(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
Some(ctx.approval_manager.as_ref()),
|
||||
msg.channel.as_str(),
|
||||
non_cli_approval_context,
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx,
|
||||
ctx.hooks.as_deref(),
|
||||
&excluded_tools_snapshot,
|
||||
progress_mode,
|
||||
ctx.safety_heartbeat.clone(),
|
||||
),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@ -3694,6 +3747,9 @@ or tune thresholds in config.",
|
||||
if let Some(handle) = draft_updater {
|
||||
let _ = handle.await;
|
||||
}
|
||||
if let Some(handle) = approval_prompt_dispatcher {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
if let Some(token) = typing_cancellation.as_ref() {
|
||||
token.cancel();
|
||||
@ -7656,6 +7712,131 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let autonomy_cfg = crate::config::AutonomyConfig {
|
||||
always_ask: vec!["mock_price".to_string()],
|
||||
..crate::config::AutonomyConfig::default()
|
||||
};
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
conversation_locks: Default::default(),
|
||||
session_config: crate::config::AgentSessionConfig::default(),
|
||||
session_manager: None,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)),
|
||||
safety_heartbeat: None,
|
||||
startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(),
|
||||
});
|
||||
|
||||
let runtime_ctx_for_first_turn = runtime_ctx.clone();
|
||||
let first_turn = tokio::spawn(async move {
|
||||
process_channel_message(
|
||||
runtime_ctx_for_first_turn,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-non-cli-approval-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let request_id = tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests(
|
||||
Some("alice"),
|
||||
Some("telegram"),
|
||||
Some("chat-1"),
|
||||
);
|
||||
if let Some(req) = pending.first() {
|
||||
break req.request_id.clone();
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("pending approval request should be created for always_ask tool");
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx.clone(),
|
||||
traits::ChannelMessage {
|
||||
id: "msg-non-cli-approval-2".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: format!("/approve-allow {request_id}"),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), first_turn)
|
||||
.await
|
||||
.expect("first channel turn should finish after approval")
|
||||
.expect("first channel turn task should not panic");
|
||||
|
||||
let sent = channel_impl.sent_messages.lock().await;
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("Approval required for tool `mock_price`")),
|
||||
"channel should emit non-cli approval prompt"
|
||||
);
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("Approved supervised execution for `mock_price`")),
|
||||
"channel should acknowledge explicit approval command"
|
||||
);
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("BTC is currently around")),
|
||||
"tool call should execute after approval and produce final response"
|
||||
);
|
||||
assert!(
|
||||
sent.iter().all(|entry| !entry.contains("Denied by user.")),
|
||||
"always_ask tool should not be silently denied once non-cli approval prompt path is wired"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_denies_approval_management_for_unlisted_sender() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
@ -9232,6 +9413,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
},
|
||||
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
|
||||
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),
|
||||
|
||||
@ -1026,6 +1026,11 @@ pub struct SkillsConfig {
|
||||
/// If unset, defaults to `$HOME/open-skills` when enabled.
|
||||
#[serde(default)]
|
||||
pub open_skills_dir: Option<String>,
|
||||
/// Optional allowlist of canonical directory roots for workspace skill symlink targets.
|
||||
/// Symlinked workspace skills are rejected unless their resolved targets are under one
|
||||
/// of these roots. Accepts absolute paths and `~/` home-relative paths.
|
||||
#[serde(default)]
|
||||
pub trusted_skill_roots: Vec<String>,
|
||||
/// Allow script-like files in skills (`.sh`, `.bash`, `.ps1`, shebang shell files).
|
||||
/// Default: `false` (secure by default).
|
||||
#[serde(default)]
|
||||
@ -1195,6 +1200,58 @@ pub struct CostConfig {
|
||||
/// Per-model pricing (USD per 1M tokens)
|
||||
#[serde(default)]
|
||||
pub prices: std::collections::HashMap<String, ModelPricing>,
|
||||
|
||||
/// Runtime budget enforcement policy (`[cost.enforcement]`).
|
||||
#[serde(default)]
|
||||
pub enforcement: CostEnforcementConfig,
|
||||
}
|
||||
|
||||
/// Budget enforcement behavior when projected spend approaches/exceeds limits.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CostEnforcementMode {
|
||||
/// Log warnings only; never block the request.
|
||||
Warn,
|
||||
/// Attempt one downgrade to a cheaper route/model, then block if still over budget.
|
||||
RouteDown,
|
||||
/// Block immediately when projected spend exceeds configured limits.
|
||||
Block,
|
||||
}
|
||||
|
||||
fn default_cost_enforcement_mode() -> CostEnforcementMode {
|
||||
CostEnforcementMode::Warn
|
||||
}
|
||||
|
||||
/// Runtime budget enforcement controls (`[cost.enforcement]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CostEnforcementConfig {
|
||||
/// Enforcement behavior. Default: `warn`.
|
||||
#[serde(default = "default_cost_enforcement_mode")]
|
||||
pub mode: CostEnforcementMode,
|
||||
/// Optional fallback model (or `hint:*`) when `mode = "route_down"`.
|
||||
#[serde(default = "default_route_down_model")]
|
||||
pub route_down_model: Option<String>,
|
||||
/// Extra reserve added to token/cost estimates (percentage, 0-100). Default: `10`.
|
||||
#[serde(default = "default_cost_reserve_percent")]
|
||||
pub reserve_percent: u8,
|
||||
}
|
||||
|
||||
/// Default route-down target: the `hint:fast` model hint.
fn default_route_down_model() -> Option<String> {
    Some("hint:fast".to_string())
}
|
||||
|
||||
/// Default reserve, in percent, added on top of token/cost estimates.
fn default_cost_reserve_percent() -> u8 {
    10
}
|
||||
|
||||
impl Default for CostEnforcementConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: default_cost_enforcement_mode(),
|
||||
route_down_model: default_route_down_model(),
|
||||
reserve_percent: default_cost_reserve_percent(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-model pricing entry (USD per 1M tokens).
|
||||
@ -1230,6 +1287,7 @@ impl Default for CostConfig {
|
||||
warn_at_percent: default_warn_percent(),
|
||||
allow_override: false,
|
||||
prices: get_default_pricing(),
|
||||
enforcement: CostEnforcementConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -7764,6 +7822,44 @@ impl Config {
|
||||
anyhow::bail!("web_search.timeout_secs must be greater than 0");
|
||||
}
|
||||
|
||||
// Cost
|
||||
if self.cost.warn_at_percent > 100 {
|
||||
anyhow::bail!("cost.warn_at_percent must be between 0 and 100");
|
||||
}
|
||||
if self.cost.enforcement.reserve_percent > 100 {
|
||||
anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100");
|
||||
}
|
||||
if matches!(self.cost.enforcement.mode, CostEnforcementMode::RouteDown) {
|
||||
let route_down_model = self
|
||||
.cost
|
||||
.enforcement
|
||||
.route_down_model
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|model| !model.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"cost.enforcement.route_down_model must be set when mode is route_down"
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(route_hint) = route_down_model
|
||||
.strip_prefix("hint:")
|
||||
.map(str::trim)
|
||||
.filter(|hint| !hint.is_empty())
|
||||
{
|
||||
if !self
|
||||
.model_routes
|
||||
.iter()
|
||||
.any(|route| route.hint.trim() == route_hint)
|
||||
{
|
||||
anyhow::bail!(
|
||||
"cost.enforcement.route_down_model uses hint '{route_hint}', but no matching [[model_routes]] entry exists"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
if self.scheduler.max_concurrent == 0 {
|
||||
anyhow::bail!("scheduler.max_concurrent must be greater than 0");
|
||||
@ -13738,4 +13834,80 @@ sensitivity = 0.9
|
||||
.validate()
|
||||
.expect("disabled coordination should allow empty lead agent");
|
||||
}
|
||||
|
||||
/// Pins the default enforcement policy exposed through `CostConfig::default()`:
/// warn mode, `hint:fast` as the route-down target, and a 10% reserve.
// NOTE(review): `#[test]` on an `async fn` only compiles when `test` resolves to an
// async-aware attribute in this module (e.g. an imported `tokio::test`) — confirm.
#[test]
async fn cost_enforcement_defaults_are_stable() {
    let cost = CostConfig::default();
    assert_eq!(cost.enforcement.mode, CostEnforcementMode::Warn);
    assert_eq!(
        cost.enforcement.route_down_model.as_deref(),
        Some("hint:fast")
    );
    assert_eq!(cost.enforcement.reserve_percent, 10);
}
|
||||
|
||||
/// Parses a `CostConfig` TOML fragment with an `[enforcement]` table and checks that the
/// `route_down` mode, the route-down model, and the reserve percentage are all read back.
// NOTE(review): `#[test]` on an `async fn` only compiles when `test` resolves to an
// async-aware attribute in this module (e.g. an imported `tokio::test`) — confirm.
#[test]
async fn cost_enforcement_config_parses_route_down_mode() {
    let parsed: CostConfig = toml::from_str(
        r#"
enabled = true

[enforcement]
mode = "route_down"
route_down_model = "hint:fast"
reserve_percent = 15
"#,
    )
    .expect("cost enforcement should parse");

    assert!(parsed.enabled);
    assert_eq!(parsed.enforcement.mode, CostEnforcementMode::RouteDown);
    assert_eq!(
        parsed.enforcement.route_down_model.as_deref(),
        Some("hint:fast")
    );
    assert_eq!(parsed.enforcement.reserve_percent, 15);
}
|
||||
|
||||
/// A reserve percentage above 100 must make `Config::validate` fail, and the error text
/// must name the offending key.
// NOTE(review): `#[test]` on an `async fn` only compiles when `test` resolves to an
// async-aware attribute in this module (e.g. an imported `tokio::test`) — confirm.
#[test]
async fn validation_rejects_cost_enforcement_reserve_over_100() {
    let mut config = Config::default();
    config.cost.enforcement.reserve_percent = 150;
    let err = config
        .validate()
        .expect_err("expected cost.enforcement.reserve_percent validation failure");
    assert!(err.to_string().contains("cost.enforcement.reserve_percent"));
}
|
||||
|
||||
/// In `route_down` mode, a `hint:*` target must be rejected when no `[[model_routes]]`
/// entry carries that hint.
// Assumes `Config::default()` defines no route with hint `fast` — verify if defaults change.
// NOTE(review): `#[test]` on an `async fn` only compiles when `test` resolves to an
// async-aware attribute in this module (e.g. an imported `tokio::test`) — confirm.
#[test]
async fn validation_rejects_route_down_hint_without_matching_route() {
    let mut config = Config::default();
    config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
    config.cost.enforcement.route_down_model = Some("hint:fast".to_string());
    let err = config
        .validate()
        .expect_err("route_down hint should require a matching model route");
    assert!(err
        .to_string()
        .contains("cost.enforcement.route_down_model uses hint 'fast'"));
}
|
||||
|
||||
/// Counterpart to the rejection case: once a `[[model_routes]]` entry with hint `fast`
/// exists, a `hint:fast` route-down target must pass validation.
// NOTE(review): `#[test]` on an `async fn` only compiles when `test` resolves to an
// async-aware attribute in this module (e.g. an imported `tokio::test`) — confirm.
#[test]
async fn validation_accepts_route_down_hint_with_matching_route() {
    let mut config = Config::default();
    config.cost.enforcement.mode = CostEnforcementMode::RouteDown;
    config.cost.enforcement.route_down_model = Some("hint:fast".to_string());
    config.model_routes = vec![ModelRouteConfig {
        hint: "fast".to_string(),
        provider: "openrouter".to_string(),
        model: "openai/gpt-4.1-mini".to_string(),
        api_key: None,
        max_tokens: None,
        transport: None,
    }];

    config
        .validate()
        .expect("matching route_down hint route should validate");
}
|
||||
}
|
||||
|
||||
61
src/main.rs
61
src/main.rs
@ -333,15 +333,20 @@ the binary location.
|
||||
Examples:
|
||||
zeroclaw update # Update to latest version
|
||||
zeroclaw update --check # Check for updates without installing
|
||||
zeroclaw update --instructions # Show install-method-specific update instructions
|
||||
zeroclaw update --force # Reinstall even if already up to date")]
|
||||
Update {
|
||||
/// Check for updates without installing
|
||||
#[arg(long)]
|
||||
#[arg(long, conflicts_with_all = ["force", "instructions"])]
|
||||
check: bool,
|
||||
|
||||
/// Force update even if already at latest version
|
||||
#[arg(long)]
|
||||
#[arg(long, conflicts_with = "instructions")]
|
||||
force: bool,
|
||||
|
||||
/// Show human-friendly update instructions for your installation method
|
||||
#[arg(long, conflicts_with_all = ["check", "force"])]
|
||||
instructions: bool,
|
||||
},
|
||||
|
||||
/// Engage, inspect, and resume emergency-stop states.
|
||||
@ -1107,9 +1112,18 @@ async fn main() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Commands::Update { check, force } => {
|
||||
update::self_update(force, check).await?;
|
||||
Ok(())
|
||||
Commands::Update {
|
||||
check,
|
||||
force,
|
||||
instructions,
|
||||
} => {
|
||||
if instructions {
|
||||
update::print_update_instructions()?;
|
||||
Ok(())
|
||||
} else {
|
||||
update::self_update(force, check).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Commands::Estop {
|
||||
@ -2630,4 +2644,41 @@ mod tests {
|
||||
);
|
||||
assert_eq!(payload["nested"]["non_secret"], serde_json::json!("ok"));
|
||||
}
|
||||
|
||||
/// The long help of the `update` subcommand must advertise the `--instructions` flag.
#[test]
fn update_help_mentions_instructions_flag() {
    let cmd = Cli::command();
    let update_cmd = cmd
        .get_subcommands()
        .find(|subcommand| subcommand.get_name() == "update")
        .expect("update subcommand must exist");

    // Render the help into an in-memory buffer; the subcommand is cloned because it is
    // only borrowed from `cmd`.
    let mut output = Vec::new();
    update_cmd
        .clone()
        .write_long_help(&mut output)
        .expect("help generation should succeed");
    let help = String::from_utf8(output).expect("help output should be utf-8");

    assert!(help.contains("--instructions"));
}
|
||||
|
||||
/// `zeroclaw update --instructions` must parse to `Commands::Update` with only the
/// `instructions` flag set.
#[test]
fn update_cli_parses_instructions_flag() {
    let cli = Cli::try_parse_from(["zeroclaw", "update", "--instructions"])
        .expect("update --instructions should parse");

    match cli.command {
        Commands::Update {
            check,
            force,
            instructions,
        } => {
            assert!(!check);
            assert!(!force);
            assert!(instructions);
        }
        other => panic!("expected update command, got {other:?}"),
    }
}
|
||||
}
|
||||
|
||||
@ -103,8 +103,9 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [
|
||||
const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 6] = [
|
||||
SQLITE_PROFILE,
|
||||
SQLITE_QDRANT_HYBRID_PROFILE,
|
||||
LUCID_PROFILE,
|
||||
CORTEX_MEM_PROFILE,
|
||||
MARKDOWN_PROFILE,
|
||||
@ -194,12 +195,13 @@ mod tests {
|
||||
/// Pins the order in which memory backends are offered during onboarding.
///
/// Six backends are selectable, with `sqlite_qdrant_hybrid` listed directly after
/// `sqlite`; the remaining entries keep their relative order.
#[test]
fn selectable_backends_are_ordered_for_onboarding() {
    let backends = selectable_memory_backends();
    assert_eq!(backends.len(), 6);
    assert_eq!(backends[0].key, "sqlite");
    assert_eq!(backends[1].key, "sqlite_qdrant_hybrid");
    assert_eq!(backends[2].key, "lucid");
    assert_eq!(backends[3].key, "cortex-mem");
    assert_eq!(backends[4].key, "markdown");
    assert_eq!(backends[5].key, "none");
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -4091,9 +4091,68 @@ fn setup_memory() -> Result<MemoryConfig> {
|
||||
|
||||
let mut config = memory_config_defaults_for_backend(backend);
|
||||
config.auto_save = auto_save;
|
||||
|
||||
if classify_memory_backend(backend) == MemoryBackendKind::SqliteQdrantHybrid {
|
||||
configure_hybrid_qdrant_memory(&mut config)?;
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn configure_hybrid_qdrant_memory(config: &mut MemoryConfig) -> Result<()> {
|
||||
print_bullet("Hybrid memory keeps local SQLite metadata and uses Qdrant for semantic ranking.");
|
||||
print_bullet("SQLite storage path stays at the default workspace database.");
|
||||
|
||||
let qdrant_url_default = config
|
||||
.qdrant
|
||||
.url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "http://localhost:6333".to_string());
|
||||
let qdrant_url: String = Input::new()
|
||||
.with_prompt(" Qdrant URL")
|
||||
.default(qdrant_url_default)
|
||||
.interact_text()?;
|
||||
let qdrant_url = qdrant_url.trim();
|
||||
if qdrant_url.is_empty() {
|
||||
bail!("Qdrant URL is required for sqlite_qdrant_hybrid backend");
|
||||
}
|
||||
config.qdrant.url = Some(qdrant_url.to_string());
|
||||
|
||||
let qdrant_collection: String = Input::new()
|
||||
.with_prompt(" Qdrant collection")
|
||||
.default(config.qdrant.collection.clone())
|
||||
.interact_text()?;
|
||||
let qdrant_collection = qdrant_collection.trim();
|
||||
if !qdrant_collection.is_empty() {
|
||||
config.qdrant.collection = qdrant_collection.to_string();
|
||||
}
|
||||
|
||||
let qdrant_api_key: String = Input::new()
|
||||
.with_prompt(" Qdrant API key (optional, Enter to skip)")
|
||||
.allow_empty(true)
|
||||
.interact_text()?;
|
||||
let qdrant_api_key = qdrant_api_key.trim();
|
||||
config.qdrant.api_key = if qdrant_api_key.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(qdrant_api_key.to_string())
|
||||
};
|
||||
|
||||
println!(
|
||||
" {} Qdrant: {} (collection: {}, api key: {})",
|
||||
style("✓").green().bold(),
|
||||
style(config.qdrant.url.as_deref().unwrap_or_default()).green(),
|
||||
style(&config.qdrant.collection).green(),
|
||||
if config.qdrant.api_key.is_some() {
|
||||
style("set").green().to_string()
|
||||
} else {
|
||||
style("not set").dim().to_string()
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn setup_identity_backend() -> Result<IdentityConfig> {
|
||||
print_bullet("Choose the identity format ZeroClaw should scaffold for this workspace.");
|
||||
print_bullet("You can switch later in config.toml under [identity].");
|
||||
@ -8515,10 +8574,11 @@ mod tests {
|
||||
#[test]
fn backend_key_from_choice_maps_supported_backends() {
    // Menu indices map onto backend keys in onboarding order. The stale
    // pre-change assertions (1 => "lucid" ... 4 => "none") were removed:
    // they asserted a second, conflicting value for the same indices.
    assert_eq!(backend_key_from_choice(0), "sqlite");
    assert_eq!(backend_key_from_choice(1), "sqlite_qdrant_hybrid");
    assert_eq!(backend_key_from_choice(2), "lucid");
    assert_eq!(backend_key_from_choice(3), "cortex-mem");
    assert_eq!(backend_key_from_choice(4), "markdown");
    assert_eq!(backend_key_from_choice(5), "none");
    // An out-of-range choice is expected to resolve to "sqlite".
    assert_eq!(backend_key_from_choice(999), "sqlite");
}
|
||||
|
||||
@ -8560,6 +8620,18 @@ mod tests {
|
||||
assert_eq!(config.embedding_cache_size, 10000);
|
||||
}
|
||||
|
||||
#[test]
fn memory_config_defaults_for_hybrid_enable_sqlite_hygiene() {
    // Defaults produced for the hybrid backend key.
    let hybrid = memory_config_defaults_for_backend("sqlite_qdrant_hybrid");

    assert_eq!(hybrid.backend, "sqlite_qdrant_hybrid");
    assert!(hybrid.auto_save);

    // Hygiene settings expected for this backend.
    assert!(hybrid.hygiene_enabled);
    assert_eq!(hybrid.archive_after_days, 7);
    assert_eq!(hybrid.purge_after_days, 30);
    assert_eq!(hybrid.embedding_cache_size, 10_000);

    // Default Qdrant collection name.
    assert_eq!(hybrid.qdrant.collection, "zeroclaw_memories");
}
|
||||
|
||||
#[test]
|
||||
fn memory_config_defaults_for_none_disable_sqlite_hygiene() {
|
||||
let config = memory_config_defaults_for_backend("none");
|
||||
|
||||
@ -80,7 +80,7 @@ fn default_version() -> String {
|
||||
|
||||
/// Load all skills from the workspace skills directory
|
||||
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
|
||||
load_skills_with_open_skills_config(workspace_dir, None, None, None)
|
||||
load_skills_with_open_skills_config(workspace_dir, None, None, None, None)
|
||||
}
|
||||
|
||||
/// Load skills using runtime config values (preferred at runtime).
|
||||
@ -90,6 +90,7 @@ pub fn load_skills_with_config(workspace_dir: &Path, config: &crate::config::Con
|
||||
Some(config.skills.open_skills_enabled),
|
||||
config.skills.open_skills_dir.as_deref(),
|
||||
Some(config.skills.allow_scripts),
|
||||
Some(&config.skills.trusted_skill_roots),
|
||||
)
|
||||
}
|
||||
|
||||
@ -98,9 +99,12 @@ fn load_skills_with_open_skills_config(
|
||||
config_open_skills_enabled: Option<bool>,
|
||||
config_open_skills_dir: Option<&str>,
|
||||
config_allow_scripts: Option<bool>,
|
||||
config_trusted_skill_roots: Option<&[String]>,
|
||||
) -> Vec<Skill> {
|
||||
let mut skills = Vec::new();
|
||||
let allow_scripts = config_allow_scripts.unwrap_or(false);
|
||||
let trusted_skill_roots =
|
||||
resolve_trusted_skill_roots(workspace_dir, config_trusted_skill_roots.unwrap_or(&[]));
|
||||
|
||||
if let Some(open_skills_dir) =
|
||||
ensure_open_skills_repo(config_open_skills_enabled, config_open_skills_dir)
|
||||
@ -108,16 +112,113 @@ fn load_skills_with_open_skills_config(
|
||||
skills.extend(load_open_skills(&open_skills_dir, allow_scripts));
|
||||
}
|
||||
|
||||
skills.extend(load_workspace_skills(workspace_dir, allow_scripts));
|
||||
skills.extend(load_workspace_skills(
|
||||
workspace_dir,
|
||||
allow_scripts,
|
||||
&trusted_skill_roots,
|
||||
));
|
||||
skills
|
||||
}
|
||||
|
||||
fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
fn load_workspace_skills(
|
||||
workspace_dir: &Path,
|
||||
allow_scripts: bool,
|
||||
trusted_skill_roots: &[PathBuf],
|
||||
) -> Vec<Skill> {
|
||||
let skills_dir = workspace_dir.join("skills");
|
||||
load_skills_from_directory(&skills_dir, allow_scripts)
|
||||
load_skills_from_directory(&skills_dir, allow_scripts, trusted_skill_roots)
|
||||
}
|
||||
|
||||
fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
fn resolve_trusted_skill_roots(workspace_dir: &Path, raw_roots: &[String]) -> Vec<PathBuf> {
|
||||
let home_dir = UserDirs::new().map(|dirs| dirs.home_dir().to_path_buf());
|
||||
let mut resolved = Vec::new();
|
||||
|
||||
for raw in raw_roots {
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let expanded = if trimmed == "~" {
|
||||
home_dir.clone().unwrap_or_else(|| PathBuf::from(trimmed))
|
||||
} else if let Some(rest) = trimmed
|
||||
.strip_prefix("~/")
|
||||
.or_else(|| trimmed.strip_prefix("~\\"))
|
||||
{
|
||||
home_dir
|
||||
.as_ref()
|
||||
.map(|home| home.join(rest))
|
||||
.unwrap_or_else(|| PathBuf::from(trimmed))
|
||||
} else {
|
||||
PathBuf::from(trimmed)
|
||||
};
|
||||
|
||||
let candidate = if expanded.is_relative() {
|
||||
workspace_dir.join(expanded)
|
||||
} else {
|
||||
expanded
|
||||
};
|
||||
|
||||
match candidate.canonicalize() {
|
||||
Ok(canonical) if canonical.is_dir() => resolved.push(canonical),
|
||||
Ok(canonical) => tracing::warn!(
|
||||
"ignoring [skills].trusted_skill_roots entry '{}': canonical path is not a directory ({})",
|
||||
trimmed,
|
||||
canonical.display()
|
||||
),
|
||||
Err(err) => tracing::warn!(
|
||||
"ignoring [skills].trusted_skill_roots entry '{}': failed to canonicalize {} ({err})",
|
||||
trimmed,
|
||||
candidate.display()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
resolved.sort();
|
||||
resolved.dedup();
|
||||
resolved
|
||||
}
|
||||
|
||||
fn enforce_workspace_skill_symlink_trust(
|
||||
path: &Path,
|
||||
trusted_skill_roots: &[PathBuf],
|
||||
) -> Result<()> {
|
||||
let canonical_target = path
|
||||
.canonicalize()
|
||||
.with_context(|| format!("failed to resolve skill symlink target {}", path.display()))?;
|
||||
|
||||
if !canonical_target.is_dir() {
|
||||
anyhow::bail!(
|
||||
"symlink target is not a directory: {}",
|
||||
canonical_target.display()
|
||||
);
|
||||
}
|
||||
|
||||
if trusted_skill_roots
|
||||
.iter()
|
||||
.any(|root| canonical_target.starts_with(root))
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if trusted_skill_roots.is_empty() {
|
||||
anyhow::bail!(
|
||||
"symlink target {} is not allowed because [skills].trusted_skill_roots is empty",
|
||||
canonical_target.display()
|
||||
);
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"symlink target {} is outside configured [skills].trusted_skill_roots",
|
||||
canonical_target.display()
|
||||
);
|
||||
}
|
||||
|
||||
fn load_skills_from_directory(
|
||||
skills_dir: &Path,
|
||||
allow_scripts: bool,
|
||||
trusted_skill_roots: &[PathBuf],
|
||||
) -> Vec<Skill> {
|
||||
if !skills_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
@ -130,7 +231,26 @@ fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec<Ski
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if !path.is_dir() {
|
||||
let metadata = match std::fs::symlink_metadata(&path) {
|
||||
Ok(meta) => meta,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"skipping skill entry {}: failed to read metadata ({err})",
|
||||
path.display()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if metadata.file_type().is_symlink() {
|
||||
if let Err(err) = enforce_workspace_skill_symlink_trust(&path, trusted_skill_roots) {
|
||||
tracing::warn!(
|
||||
"skipping untrusted symlinked skill entry {}: {err}",
|
||||
path.display()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
} else if !metadata.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -180,7 +300,7 @@ fn load_open_skills(repo_dir: &Path, allow_scripts: bool) -> Vec<Skill> {
|
||||
// as executable skills.
|
||||
let nested_skills_dir = repo_dir.join("skills");
|
||||
if nested_skills_dir.is_dir() {
|
||||
return load_skills_from_directory(&nested_skills_dir, allow_scripts);
|
||||
return load_skills_from_directory(&nested_skills_dir, allow_scripts, &[]);
|
||||
}
|
||||
|
||||
let mut skills = Vec::new();
|
||||
@ -2137,6 +2257,20 @@ pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Con
|
||||
anyhow::bail!("Skill source or installed skill not found: {source}");
|
||||
}
|
||||
|
||||
let trusted_skill_roots =
|
||||
resolve_trusted_skill_roots(workspace_dir, &config.skills.trusted_skill_roots);
|
||||
if let Ok(metadata) = std::fs::symlink_metadata(&target) {
|
||||
if metadata.file_type().is_symlink() {
|
||||
enforce_workspace_skill_symlink_trust(&target, &trusted_skill_roots)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"trusted-symlink policy rejected audit target {}",
|
||||
target.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let report = audit::audit_skill_directory_with_options(
|
||||
&target,
|
||||
audit::SkillAuditOptions {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::skills::skills_dir;
|
||||
use crate::config::Config;
|
||||
use crate::skills::{handle_command, load_skills_with_config, skills_dir};
|
||||
use crate::SkillCommands;
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
|
||||
@ -83,7 +85,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_skills_symlink_permissions_and_safety() {
|
||||
async fn test_workspace_symlink_loading_requires_trusted_roots() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let workspace_dir = tmp.path().join("workspace");
|
||||
tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
|
||||
@ -93,7 +95,6 @@ mod tests {
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
// Test case: Symlink outside workspace should be allowed (user responsibility)
|
||||
let outside_dir = tmp.path().join("outside_skill");
|
||||
tokio::fs::create_dir_all(&outside_dir).await.unwrap();
|
||||
tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
|
||||
@ -102,15 +103,74 @@ mod tests {
|
||||
|
||||
let dest_link = skills_path.join("outside_skill");
|
||||
let result = std::os::unix::fs::symlink(&outside_dir, &dest_link);
|
||||
assert!(result.is_ok(), "symlink creation should succeed on unix");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.workspace_dir = workspace_dir.clone();
|
||||
config.config_path = workspace_dir.join("config.toml");
|
||||
|
||||
let blocked = load_skills_with_config(&workspace_dir, &config);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should allow symlinking to directories outside workspace"
|
||||
blocked.is_empty(),
|
||||
"symlinked skill should be rejected when trusted_skill_roots is empty"
|
||||
);
|
||||
|
||||
// Should still be readable
|
||||
let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await;
|
||||
assert!(content.is_ok());
|
||||
assert!(content.unwrap().contains("Outside Skill"));
|
||||
config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
|
||||
let allowed = load_skills_with_config(&workspace_dir, &config);
|
||||
assert_eq!(
|
||||
allowed.len(),
|
||||
1,
|
||||
"symlinked skill should load when target is inside trusted roots"
|
||||
);
|
||||
assert_eq!(allowed[0].name, "outside_skill");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_skills_audit_respects_trusted_symlink_roots() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let workspace_dir = tmp.path().join("workspace");
|
||||
tokio::fs::create_dir_all(&workspace_dir).await.unwrap();
|
||||
|
||||
let skills_path = skills_dir(&workspace_dir);
|
||||
tokio::fs::create_dir_all(&skills_path).await.unwrap();
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let outside_dir = tmp.path().join("outside_skill");
|
||||
tokio::fs::create_dir_all(&outside_dir).await.unwrap();
|
||||
tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent")
|
||||
.await
|
||||
.unwrap();
|
||||
let link_path = skills_path.join("outside_skill");
|
||||
std::os::unix::fs::symlink(&outside_dir, &link_path).unwrap();
|
||||
|
||||
let mut config = Config::default();
|
||||
config.workspace_dir = workspace_dir.clone();
|
||||
config.config_path = workspace_dir.join("config.toml");
|
||||
|
||||
let blocked = handle_command(
|
||||
SkillCommands::Audit {
|
||||
source: "outside_skill".to_string(),
|
||||
},
|
||||
&config,
|
||||
);
|
||||
assert!(
|
||||
blocked.is_err(),
|
||||
"audit should reject symlink target when trusted roots are not configured"
|
||||
);
|
||||
|
||||
config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()];
|
||||
let allowed = handle_command(
|
||||
SkillCommands::Audit {
|
||||
source: "outside_skill".to_string(),
|
||||
},
|
||||
&config,
|
||||
);
|
||||
assert!(
|
||||
allowed.is_ok(),
|
||||
"audit should pass when symlink target is inside a trusted root"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,12 +107,27 @@ impl McpTransportConn for StdioTransport {
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
||||
Ok(resp)
|
||||
let deadline = std::time::Instant::now() + Duration::from_secs(RECV_TIMEOUT_SECS);
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
bail!("timeout waiting for MCP response");
|
||||
}
|
||||
let resp_line = timeout(remaining, self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
||||
if resp.id.is_none() {
|
||||
// Server-sent notification (e.g. `notifications/initialized`) — skip and
|
||||
// keep waiting for the actual response to our request.
|
||||
tracing::debug!(
|
||||
"MCP stdio: skipping server notification while waiting for response"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
|
||||
@ -82,6 +82,7 @@ pub mod web_access_config;
|
||||
pub mod web_fetch;
|
||||
pub mod web_search_config;
|
||||
pub mod web_search_tool;
|
||||
pub mod xlsx_read;
|
||||
|
||||
pub use apply_patch::ApplyPatchTool;
|
||||
pub use bg_run::{
|
||||
@ -147,6 +148,7 @@ pub use web_access_config::WebAccessConfigTool;
|
||||
pub use web_fetch::WebFetchTool;
|
||||
pub use web_search_config::WebSearchConfigTool;
|
||||
pub use web_search_tool::WebSearchTool;
|
||||
pub use xlsx_read::XlsxReadTool;
|
||||
|
||||
pub use auth_profile::ManageAuthProfileTool;
|
||||
pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool};
|
||||
@ -511,6 +513,9 @@ pub fn all_tools_with_runtime(
|
||||
// PPTX text extraction
|
||||
tool_arcs.push(Arc::new(PptxReadTool::new(security.clone())));
|
||||
|
||||
// XLSX text extraction
|
||||
tool_arcs.push(Arc::new(XlsxReadTool::new(security.clone())));
|
||||
|
||||
// Vision tools are always available
|
||||
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
|
||||
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));
|
||||
|
||||
1177
src/tools/xlsx_read.rs
Normal file
1177
src/tools/xlsx_read.rs
Normal file
File diff suppressed because it is too large
Load Diff
203
src/update.rs
203
src/update.rs
@ -5,6 +5,7 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
@ -26,6 +27,13 @@ struct Asset {
|
||||
browser_download_url: String,
|
||||
}
|
||||
|
||||
/// Installation channel inferred from where the running binary lives on disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum InstallMethod {
|
||||
/// The binary path contains a Homebrew `Cellar/zeroclaw` directory.
Homebrew,
|
||||
/// The binary sits under `~/.cargo/bin` or `~/.local/bin`.
CargoOrLocal,
|
||||
/// None of the known install locations matched.
Unknown,
|
||||
}
|
||||
|
||||
/// Get the current version of the binary
|
||||
pub fn current_version() -> &'static str {
|
||||
env!("CARGO_PKG_VERSION")
|
||||
@ -213,6 +221,79 @@ fn get_current_exe() -> Result<PathBuf> {
|
||||
env::current_exe().context("Failed to get current executable path")
|
||||
}
|
||||
|
||||
fn detect_install_method_for_path(resolved_path: &Path, home_dir: Option<&Path>) -> InstallMethod {
|
||||
let lower = resolved_path.to_string_lossy().to_ascii_lowercase();
|
||||
if lower.contains("/cellar/zeroclaw/") || lower.contains("/homebrew/cellar/zeroclaw/") {
|
||||
return InstallMethod::Homebrew;
|
||||
}
|
||||
|
||||
if let Some(home) = home_dir {
|
||||
if resolved_path.starts_with(home.join(".cargo").join("bin"))
|
||||
|| resolved_path.starts_with(home.join(".local").join("bin"))
|
||||
{
|
||||
return InstallMethod::CargoOrLocal;
|
||||
}
|
||||
}
|
||||
|
||||
InstallMethod::Unknown
|
||||
}
|
||||
|
||||
fn detect_install_method(current_exe: &Path) -> InstallMethod {
|
||||
let resolved = fs::canonicalize(current_exe).unwrap_or_else(|_| current_exe.to_path_buf());
|
||||
let home_dir = env::var_os("HOME").map(PathBuf::from);
|
||||
detect_install_method_for_path(&resolved, home_dir.as_deref())
|
||||
}
|
||||
|
||||
/// Print human-friendly update instructions based on detected install method.
|
||||
pub fn print_update_instructions() -> Result<()> {
|
||||
let current_exe = get_current_exe()?;
|
||||
let install_method = detect_install_method(¤t_exe);
|
||||
|
||||
println!("ZeroClaw update guide");
|
||||
println!("Detected binary: {}", current_exe.display());
|
||||
println!();
|
||||
println!("1) Check if a new release exists:");
|
||||
println!(" zeroclaw update --check");
|
||||
println!();
|
||||
|
||||
match install_method {
|
||||
InstallMethod::Homebrew => {
|
||||
println!("Detected install method: Homebrew");
|
||||
println!("Recommended update commands:");
|
||||
println!(" brew update");
|
||||
println!(" brew upgrade zeroclaw");
|
||||
println!(" zeroclaw --version");
|
||||
println!();
|
||||
println!(
|
||||
"Tip: avoid `zeroclaw update` on Homebrew installs unless you intentionally want to override the managed binary."
|
||||
);
|
||||
}
|
||||
InstallMethod::CargoOrLocal => {
|
||||
println!("Detected install method: local binary (~/.cargo/bin or ~/.local/bin)");
|
||||
println!("Recommended update command:");
|
||||
println!(" zeroclaw update");
|
||||
println!("Optional force reinstall:");
|
||||
println!(" zeroclaw update --force");
|
||||
println!("Verify:");
|
||||
println!(" zeroclaw --version");
|
||||
}
|
||||
InstallMethod::Unknown => {
|
||||
println!("Detected install method: unknown");
|
||||
println!("Try the built-in updater first:");
|
||||
println!(" zeroclaw update");
|
||||
println!(
|
||||
"If your package manager owns the binary, use that manager's upgrade command."
|
||||
);
|
||||
println!("Verify:");
|
||||
println!(" zeroclaw --version");
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
println!("Release source: https://github.com/{GITHUB_REPO}/releases/latest");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replace the current binary with the new one
|
||||
fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
|
||||
// On Windows, we can't replace a running executable directly
|
||||
@ -226,11 +307,43 @@ fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> {
|
||||
let _ = fs::remove_file(&old_path);
|
||||
}
|
||||
|
||||
// On Unix, we can overwrite the running executable
|
||||
// On Unix, stage the binary in the destination directory first.
|
||||
// This avoids cross-filesystem rename failures (EXDEV) from temp dirs.
|
||||
#[cfg(unix)]
|
||||
{
|
||||
// Use rename for atomic replacement on Unix
|
||||
fs::rename(new_binary, current_exe).context("Failed to replace binary")?;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let parent = current_exe
|
||||
.parent()
|
||||
.context("Current executable has no parent directory")?;
|
||||
let binary_name = current_exe
|
||||
.file_name()
|
||||
.context("Current executable path is missing a file name")?
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
let staged_path = parent.join(format!(".{binary_name}.new"));
|
||||
let backup_path = parent.join(format!(".{binary_name}.bak"));
|
||||
|
||||
fs::copy(new_binary, &staged_path).context("Failed to stage updated binary")?;
|
||||
fs::set_permissions(&staged_path, fs::Permissions::from_mode(0o755))
|
||||
.context("Failed to set permissions on staged binary")?;
|
||||
|
||||
if let Err(err) = fs::remove_file(&backup_path) {
|
||||
if err.kind() != ErrorKind::NotFound {
|
||||
return Err(err).context("Failed to remove stale backup binary");
|
||||
}
|
||||
}
|
||||
|
||||
fs::rename(current_exe, &backup_path).context("Failed to backup current binary")?;
|
||||
|
||||
if let Err(err) = fs::rename(&staged_path, current_exe) {
|
||||
let _ = fs::rename(&backup_path, current_exe);
|
||||
let _ = fs::remove_file(&staged_path);
|
||||
return Err(err).context("Failed to activate updated binary");
|
||||
}
|
||||
|
||||
// Best-effort cleanup of backup.
|
||||
let _ = fs::remove_file(&backup_path);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -258,6 +371,7 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
|
||||
println!();
|
||||
|
||||
let current_exe = get_current_exe()?;
|
||||
let install_method = detect_install_method(&current_exe);
|
||||
println!("Current binary: {}", current_exe.display());
|
||||
println!("Current version: v{}", current_version());
|
||||
println!();
|
||||
@ -268,6 +382,31 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
|
||||
|
||||
println!("Latest version: {}", release.tag_name);
|
||||
|
||||
if check_only {
|
||||
println!();
|
||||
if latest_version == current_version() {
|
||||
println!("✅ Already up to date.");
|
||||
} else {
|
||||
println!(
|
||||
"Update available: {} -> {}",
|
||||
current_version(),
|
||||
latest_version
|
||||
);
|
||||
println!("Run `zeroclaw update` to install the update.");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if install_method == InstallMethod::Homebrew && !force {
|
||||
println!();
|
||||
println!("Detected a Homebrew-managed installation.");
|
||||
println!("Use `brew upgrade zeroclaw` for the safest update path.");
|
||||
println!(
|
||||
"Run `zeroclaw update --force` only if you intentionally want to override Homebrew."
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check if update is needed
|
||||
if latest_version == current_version() && !force {
|
||||
println!();
|
||||
@ -275,17 +414,6 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if check_only {
|
||||
println!();
|
||||
println!(
|
||||
"Update available: {} -> {}",
|
||||
current_version(),
|
||||
latest_version
|
||||
);
|
||||
println!("Run `zeroclaw update` to install the update.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Updating from v{} to {}...",
|
||||
@ -315,3 +443,50 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Windows targets get a `.zip` archive name; other targets get `.tar.gz`.
    #[test]
    fn archive_name_uses_zip_for_windows_and_targz_elsewhere() {
        let windows_archive = get_archive_name("x86_64-pc-windows-msvc");
        assert_eq!(windows_archive, "zeroclaw-x86_64-pc-windows-msvc.zip");

        let linux_archive = get_archive_name("x86_64-unknown-linux-gnu");
        assert_eq!(linux_archive, "zeroclaw-x86_64-unknown-linux-gnu.tar.gz");
    }

    /// A path inside a Homebrew Cellar is recognized without any home directory.
    #[test]
    fn detect_install_method_identifies_homebrew_paths() {
        let cellar_binary = Path::new("/opt/homebrew/Cellar/zeroclaw/0.1.7/bin/zeroclaw");
        assert_eq!(
            detect_install_method_for_path(cellar_binary, None),
            InstallMethod::Homebrew
        );
    }

    /// Binaries under `~/.cargo/bin` and `~/.local/bin` count as local installs.
    #[test]
    fn detect_install_method_identifies_local_bin_paths() {
        let home = Path::new("/Users/example");
        let local_binaries = [
            "/Users/example/.cargo/bin/zeroclaw",
            "/Users/example/.local/bin/zeroclaw",
        ];

        for binary in local_binaries {
            assert_eq!(
                detect_install_method_for_path(Path::new(binary), Some(home)),
                InstallMethod::CargoOrLocal
            );
        }
    }

    /// A system-wide location outside the home directory stays unclassified.
    #[test]
    fn detect_install_method_returns_unknown_for_other_paths() {
        let system_binary = Path::new("/usr/bin/zeroclaw");
        let home = Path::new("/Users/example");
        assert_eq!(
            detect_install_method_for_path(system_binary, Some(home)),
            InstallMethod::Unknown
        );
    }
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user