From ddfbf3d9f86e453bd3db31429bb4253c786fd6ef Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 22:17:53 -0500 Subject: [PATCH 01/13] fix(bootstrap): fallback when /dev/stdin is unreadable in guided mode --- scripts/bootstrap.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index cee7251ad..4bd1ac7a5 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -423,11 +423,18 @@ string_to_bool() { } guided_input_stream() { - if [[ -t 0 ]]; then + # Some constrained Linux containers report interactive stdin but deny opening + # /dev/stdin directly. Probe readability before selecting it. + if [[ -t 0 ]] && (: </dev/stdin) 2>/dev/null; then echo "/dev/stdin" return 0 fi + if [[ -t 0 ]] && (: </proc/self/fd/0) 2>/dev/null; then + echo "/proc/self/fd/0" + return 0 + fi + if (: </dev/tty) 2>/dev/null; then echo "/dev/tty" return 0 From 0129b5da066daec5ae5fca3a045576cc7368bcfe Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 23:12:24 -0500 Subject: [PATCH 02/13] feat(onboard): add hybrid sqlite+qdrant memory option in wizard --- src/memory/backend.rs | 14 ++++---- src/onboard/wizard.rs | 80 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/memory/backend.rs b/src/memory/backend.rs index c6759fbe8..231f6af4b 100644 --- a/src/memory/backend.rs +++ b/src/memory/backend.rs @@ -103,8 +103,9 @@ const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile { optional_dependency: false, }; -const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 5] = [ +const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 6] = [ SQLITE_PROFILE, + SQLITE_QDRANT_HYBRID_PROFILE, LUCID_PROFILE, CORTEX_MEM_PROFILE, MARKDOWN_PROFILE, @@ -194,12 +195,13 @@ mod tests { #[test] fn selectable_backends_are_ordered_for_onboarding() { let backends = selectable_memory_backends(); - assert_eq!(backends.len(), 5); + assert_eq!(backends.len(), 6); 
assert_eq!(backends[0].key, "sqlite"); - assert_eq!(backends[1].key, "lucid"); - assert_eq!(backends[2].key, "cortex-mem"); - assert_eq!(backends[3].key, "markdown"); - assert_eq!(backends[4].key, "none"); + assert_eq!(backends[1].key, "sqlite_qdrant_hybrid"); + assert_eq!(backends[2].key, "lucid"); + assert_eq!(backends[3].key, "cortex-mem"); + assert_eq!(backends[4].key, "markdown"); + assert_eq!(backends[5].key, "none"); } #[test] diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index f92359b87..5954227fb 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -4091,9 +4091,68 @@ fn setup_memory() -> Result { let mut config = memory_config_defaults_for_backend(backend); config.auto_save = auto_save; + + if classify_memory_backend(backend) == MemoryBackendKind::SqliteQdrantHybrid { + configure_hybrid_qdrant_memory(&mut config)?; + } + Ok(config) } +fn configure_hybrid_qdrant_memory(config: &mut MemoryConfig) -> Result<()> { + print_bullet("Hybrid memory keeps local SQLite metadata and uses Qdrant for semantic ranking."); + print_bullet("SQLite storage path stays at the default workspace database."); + + let qdrant_url_default = config + .qdrant + .url + .clone() + .unwrap_or_else(|| "http://localhost:6333".to_string()); + let qdrant_url: String = Input::new() + .with_prompt(" Qdrant URL") + .default(qdrant_url_default) + .interact_text()?; + let qdrant_url = qdrant_url.trim(); + if qdrant_url.is_empty() { + bail!("Qdrant URL is required for sqlite_qdrant_hybrid backend"); + } + config.qdrant.url = Some(qdrant_url.to_string()); + + let qdrant_collection: String = Input::new() + .with_prompt(" Qdrant collection") + .default(config.qdrant.collection.clone()) + .interact_text()?; + let qdrant_collection = qdrant_collection.trim(); + if !qdrant_collection.is_empty() { + config.qdrant.collection = qdrant_collection.to_string(); + } + + let qdrant_api_key: String = Input::new() + .with_prompt(" Qdrant API key (optional, Enter to skip)") + 
.allow_empty(true) + .interact_text()?; + let qdrant_api_key = qdrant_api_key.trim(); + config.qdrant.api_key = if qdrant_api_key.is_empty() { + None + } else { + Some(qdrant_api_key.to_string()) + }; + + println!( + " {} Qdrant: {} (collection: {}, api key: {})", + style("✓").green().bold(), + style(config.qdrant.url.as_deref().unwrap_or_default()).green(), + style(&config.qdrant.collection).green(), + if config.qdrant.api_key.is_some() { + style("set").green().to_string() + } else { + style("not set").dim().to_string() + } + ); + + Ok(()) +} + fn setup_identity_backend() -> Result { print_bullet("Choose the identity format ZeroClaw should scaffold for this workspace."); print_bullet("You can switch later in config.toml under [identity]."); @@ -8515,10 +8574,11 @@ mod tests { #[test] fn backend_key_from_choice_maps_supported_backends() { assert_eq!(backend_key_from_choice(0), "sqlite"); - assert_eq!(backend_key_from_choice(1), "lucid"); - assert_eq!(backend_key_from_choice(2), "cortex-mem"); - assert_eq!(backend_key_from_choice(3), "markdown"); - assert_eq!(backend_key_from_choice(4), "none"); + assert_eq!(backend_key_from_choice(1), "sqlite_qdrant_hybrid"); + assert_eq!(backend_key_from_choice(2), "lucid"); + assert_eq!(backend_key_from_choice(3), "cortex-mem"); + assert_eq!(backend_key_from_choice(4), "markdown"); + assert_eq!(backend_key_from_choice(5), "none"); assert_eq!(backend_key_from_choice(999), "sqlite"); } @@ -8560,6 +8620,18 @@ mod tests { assert_eq!(config.embedding_cache_size, 10000); } + #[test] + fn memory_config_defaults_for_hybrid_enable_sqlite_hygiene() { + let config = memory_config_defaults_for_backend("sqlite_qdrant_hybrid"); + assert_eq!(config.backend, "sqlite_qdrant_hybrid"); + assert!(config.auto_save); + assert!(config.hygiene_enabled); + assert_eq!(config.archive_after_days, 7); + assert_eq!(config.purge_after_days, 30); + assert_eq!(config.embedding_cache_size, 10000); + assert_eq!(config.qdrant.collection, "zeroclaw_memories"); + } + 
#[test] fn memory_config_defaults_for_none_disable_sqlite_hygiene() { let config = memory_config_defaults_for_backend("none"); From 1ecace23a7c56b87be97b5636f01c3d2cca8aa1a Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 15:00:57 -0500 Subject: [PATCH 03/13] feat(update): add install-aware guidance and safer self-update --- docs/commands-reference.md | 13 ++ .../getting-started/macos-update-uninstall.md | 14 ++ src/main.rs | 61 +++++- src/update.rs | 203 ++++++++++++++++-- 4 files changed, 272 insertions(+), 19 deletions(-) diff --git a/docs/commands-reference.md b/docs/commands-reference.md index c15fc8514..e570d468c 100644 --- a/docs/commands-reference.md +++ b/docs/commands-reference.md @@ -15,6 +15,7 @@ Last verified: **February 28, 2026**. | `service` | Manage user-level OS service lifecycle | | `doctor` | Run diagnostics and freshness checks | | `status` | Print current configuration and system summary | +| `update` | Check or install latest ZeroClaw release | | `estop` | Engage/resume emergency stop levels and inspect estop state | | `cron` | Manage scheduled tasks | | `models` | Refresh provider model catalogs | @@ -103,6 +104,18 @@ Notes: - `zeroclaw service status` - `zeroclaw service uninstall` +### `update` + +- `zeroclaw update --check` (check for new release, no install) +- `zeroclaw update` (install latest release binary for current platform) +- `zeroclaw update --force` (reinstall even if current version matches latest) +- `zeroclaw update --instructions` (print install-method-specific guidance) + +Notes: + +- If ZeroClaw is installed via Homebrew, prefer `brew upgrade zeroclaw`. +- `update --instructions` detects common install methods and prints the safest path. 
+ ### `cron` - `zeroclaw cron list` diff --git a/docs/getting-started/macos-update-uninstall.md b/docs/getting-started/macos-update-uninstall.md index 944cd4ce3..f08bc5042 100644 --- a/docs/getting-started/macos-update-uninstall.md +++ b/docs/getting-started/macos-update-uninstall.md @@ -20,6 +20,13 @@ If both exist, your shell `PATH` order decides which one runs. ## 2) Update on macOS +Quick way to get install-method-specific guidance: + +```bash +zeroclaw update --instructions +zeroclaw update --check +``` + ### A) Homebrew install ```bash @@ -54,6 +61,13 @@ Re-run your download/install flow with the latest release asset, then verify: zeroclaw --version ``` +You can also use the built-in updater for manual/local installs: + +```bash +zeroclaw update +zeroclaw --version +``` + ## 3) Uninstall on macOS ### A) Stop and remove background service first diff --git a/src/main.rs b/src/main.rs index 913ed6139..978235848 100644 --- a/src/main.rs +++ b/src/main.rs @@ -333,15 +333,20 @@ the binary location. Examples: zeroclaw update # Update to latest version zeroclaw update --check # Check for updates without installing + zeroclaw update --instructions # Show install-method-specific update instructions zeroclaw update --force # Reinstall even if already up to date")] Update { /// Check for updates without installing - #[arg(long)] + #[arg(long, conflicts_with_all = ["force", "instructions"])] check: bool, /// Force update even if already at latest version - #[arg(long)] + #[arg(long, conflicts_with = "instructions")] force: bool, + + /// Show human-friendly update instructions for your installation method + #[arg(long, conflicts_with_all = ["check", "force"])] + instructions: bool, }, /// Engage, inspect, and resume emergency-stop states. 
@@ -1107,9 +1112,18 @@ async fn main() -> Result<()> { Ok(()) } - Commands::Update { check, force } => { - update::self_update(force, check).await?; - Ok(()) + Commands::Update { + check, + force, + instructions, + } => { + if instructions { + update::print_update_instructions()?; + Ok(()) + } else { + update::self_update(force, check).await?; + Ok(()) + } } Commands::Estop { @@ -2630,4 +2644,41 @@ mod tests { ); assert_eq!(payload["nested"]["non_secret"], serde_json::json!("ok")); } + + #[test] + fn update_help_mentions_instructions_flag() { + let cmd = Cli::command(); + let update_cmd = cmd + .get_subcommands() + .find(|subcommand| subcommand.get_name() == "update") + .expect("update subcommand must exist"); + + let mut output = Vec::new(); + update_cmd + .clone() + .write_long_help(&mut output) + .expect("help generation should succeed"); + let help = String::from_utf8(output).expect("help output should be utf-8"); + + assert!(help.contains("--instructions")); + } + + #[test] + fn update_cli_parses_instructions_flag() { + let cli = Cli::try_parse_from(["zeroclaw", "update", "--instructions"]) + .expect("update --instructions should parse"); + + match cli.command { + Commands::Update { + check, + force, + instructions, + } => { + assert!(!check); + assert!(!force); + assert!(instructions); + } + other => panic!("expected update command, got {other:?}"), + } + } } diff --git a/src/update.rs b/src/update.rs index b0b328e44..b86b6cbb1 100644 --- a/src/update.rs +++ b/src/update.rs @@ -5,6 +5,7 @@ use anyhow::{bail, Context, Result}; use std::env; use std::fs; +use std::io::ErrorKind; use std::path::{Path, PathBuf}; use std::process::Command; @@ -26,6 +27,13 @@ struct Asset { browser_download_url: String, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InstallMethod { + Homebrew, + CargoOrLocal, + Unknown, +} + /// Get the current version of the binary pub fn current_version() -> &'static str { env!("CARGO_PKG_VERSION") @@ -213,6 +221,79 @@ fn get_current_exe() 
-> Result<PathBuf> { env::current_exe().context("Failed to get current executable path") } +fn detect_install_method_for_path(resolved_path: &Path, home_dir: Option<&Path>) -> InstallMethod { + let lower = resolved_path.to_string_lossy().to_ascii_lowercase(); + if lower.contains("/cellar/zeroclaw/") || lower.contains("/homebrew/cellar/zeroclaw/") { + return InstallMethod::Homebrew; + } + + if let Some(home) = home_dir { + if resolved_path.starts_with(home.join(".cargo").join("bin")) + || resolved_path.starts_with(home.join(".local").join("bin")) + { + return InstallMethod::CargoOrLocal; + } + } + + InstallMethod::Unknown +} + +fn detect_install_method(current_exe: &Path) -> InstallMethod { + let resolved = fs::canonicalize(current_exe).unwrap_or_else(|_| current_exe.to_path_buf()); + let home_dir = env::var_os("HOME").map(PathBuf::from); + detect_install_method_for_path(&resolved, home_dir.as_deref()) +} + +/// Print human-friendly update instructions based on detected install method. +pub fn print_update_instructions() -> Result<()> { + let current_exe = get_current_exe()?; + let install_method = detect_install_method(&current_exe); + + println!("ZeroClaw update guide"); + println!("Detected binary: {}", current_exe.display()); + println!(); + println!("1) Check if a new release exists:"); + println!(" zeroclaw update --check"); + println!(); + + match install_method { + InstallMethod::Homebrew => { + println!("Detected install method: Homebrew"); + println!("Recommended update commands:"); + println!(" brew update"); + println!(" brew upgrade zeroclaw"); + println!(" zeroclaw --version"); + println!(); + println!( + "Tip: avoid `zeroclaw update` on Homebrew installs unless you intentionally want to override the managed binary." 
+ ); + } + InstallMethod::CargoOrLocal => { + println!("Detected install method: local binary (~/.cargo/bin or ~/.local/bin)"); + println!("Recommended update command:"); + println!(" zeroclaw update"); + println!("Optional force reinstall:"); + println!(" zeroclaw update --force"); + println!("Verify:"); + println!(" zeroclaw --version"); + } + InstallMethod::Unknown => { + println!("Detected install method: unknown"); + println!("Try the built-in updater first:"); + println!(" zeroclaw update"); + println!( + "If your package manager owns the binary, use that manager's upgrade command." + ); + println!("Verify:"); + println!(" zeroclaw --version"); + } + } + + println!(); + println!("Release source: https://github.com/{GITHUB_REPO}/releases/latest"); + Ok(()) +} + /// Replace the current binary with the new one fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> { // On Windows, we can't replace a running executable directly @@ -226,11 +307,43 @@ fn replace_binary(new_binary: &Path, current_exe: &Path) -> Result<()> { let _ = fs::remove_file(&old_path); } - // On Unix, we can overwrite the running executable + // On Unix, stage the binary in the destination directory first. + // This avoids cross-filesystem rename failures (EXDEV) from temp dirs. #[cfg(unix)] { - // Use rename for atomic replacement on Unix - fs::rename(new_binary, current_exe).context("Failed to replace binary")?; + use std::os::unix::fs::PermissionsExt; + + let parent = current_exe + .parent() + .context("Current executable has no parent directory")?; + let binary_name = current_exe + .file_name() + .context("Current executable path is missing a file name")? 
+ .to_string_lossy() + .into_owned(); + let staged_path = parent.join(format!(".{binary_name}.new")); + let backup_path = parent.join(format!(".{binary_name}.bak")); + + fs::copy(new_binary, &staged_path).context("Failed to stage updated binary")?; + fs::set_permissions(&staged_path, fs::Permissions::from_mode(0o755)) + .context("Failed to set permissions on staged binary")?; + + if let Err(err) = fs::remove_file(&backup_path) { + if err.kind() != ErrorKind::NotFound { + return Err(err).context("Failed to remove stale backup binary"); + } + } + + fs::rename(current_exe, &backup_path).context("Failed to backup current binary")?; + + if let Err(err) = fs::rename(&staged_path, current_exe) { + let _ = fs::rename(&backup_path, current_exe); + let _ = fs::remove_file(&staged_path); + return Err(err).context("Failed to activate updated binary"); + } + + // Best-effort cleanup of backup. + let _ = fs::remove_file(&backup_path); } Ok(()) @@ -258,6 +371,7 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> { println!(); let current_exe = get_current_exe()?; + let install_method = detect_install_method(&current_exe); println!("Current binary: {}", current_exe.display()); println!("Current version: v{}", current_version()); println!(); @@ -268,6 +382,31 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> { println!("Latest version: {}", release.tag_name); + if check_only { + println!(); + if latest_version == current_version() { + println!("✅ Already up to date."); + } else { + println!( + "Update available: {} -> {}", + current_version(), + latest_version + ); + println!("Run `zeroclaw update` to install the update."); + } + return Ok(()); + } + + if install_method == InstallMethod::Homebrew && !force { + println!(); + println!("Detected a Homebrew-managed installation."); + println!("Use `brew upgrade zeroclaw` for the safest update path."); + println!( + "Run `zeroclaw update --force` only if you intentionally want to override Homebrew." 
+ ); + return Ok(()); + } + // Check if update is needed if latest_version == current_version() && !force { println!(); @@ -275,17 +414,6 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> { return Ok(()); } - if check_only { - println!(); - println!( - "Update available: {} -> {}", - current_version(), - latest_version - ); - println!("Run `zeroclaw update` to install the update."); - return Ok(()); - } - println!(); println!( "Updating from v{} to {}...", @@ -315,3 +443,50 @@ pub async fn self_update(force: bool, check_only: bool) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn archive_name_uses_zip_for_windows_and_targz_elsewhere() { + assert_eq!( + get_archive_name("x86_64-pc-windows-msvc"), + "zeroclaw-x86_64-pc-windows-msvc.zip" + ); + assert_eq!( + get_archive_name("x86_64-unknown-linux-gnu"), + "zeroclaw-x86_64-unknown-linux-gnu.tar.gz" + ); + } + + #[test] + fn detect_install_method_identifies_homebrew_paths() { + let path = Path::new("/opt/homebrew/Cellar/zeroclaw/0.1.7/bin/zeroclaw"); + let method = detect_install_method_for_path(path, None); + assert_eq!(method, InstallMethod::Homebrew); + } + + #[test] + fn detect_install_method_identifies_local_bin_paths() { + let home = Path::new("/Users/example"); + let cargo_path = Path::new("/Users/example/.cargo/bin/zeroclaw"); + let local_path = Path::new("/Users/example/.local/bin/zeroclaw"); + + assert_eq!( + detect_install_method_for_path(cargo_path, Some(home)), + InstallMethod::CargoOrLocal + ); + assert_eq!( + detect_install_method_for_path(local_path, Some(home)), + InstallMethod::CargoOrLocal + ); + } + + #[test] + fn detect_install_method_returns_unknown_for_other_paths() { + let path = Path::new("/usr/bin/zeroclaw"); + let method = detect_install_method_for_path(path, Some(Path::new("/Users/example"))); + assert_eq!(method, InstallMethod::Unknown); + } +} From 28eaef17826867c013c858f247b60b2df46e9863 Mon Sep 17 00:00:00 2001 From: argenis de 
la rosa Date: Sat, 28 Feb 2026 22:52:30 -0500 Subject: [PATCH 04/13] fix(ci): reduce queue saturation via branch supersedence --- .github/workflows/ci-queue-hygiene.yml | 11 ++- .github/workflows/ci-run.yml | 2 +- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/test-e2e.yml | 2 +- scripts/ci/queue_hygiene.py | 21 ++++- scripts/ci/tests/test_ci_scripts.py | 113 +++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci-queue-hygiene.yml b/.github/workflows/ci-queue-hygiene.yml index ada0baf02..b1655435a 100644 --- a/.github/workflows/ci-queue-hygiene.yml +++ b/.github/workflows/ci-queue-hygiene.yml @@ -8,7 +8,7 @@ on: apply: description: "Cancel selected queued runs (false = dry-run report only)" required: true - default: true + default: false type: boolean status: description: "Queued-run status scope" @@ -57,22 +57,27 @@ jobs: status_scope="queued" max_cancel="120" - apply_mode="true" + apply_mode="false" if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then status_scope="${{ github.event.inputs.status || 'queued' }}" max_cancel="${{ github.event.inputs.max_cancel || '120' }}" - apply_mode="${{ github.event.inputs.apply || 'true' }}" + apply_mode="${{ github.event.inputs.apply || 'false' }}" fi cmd=(python3 scripts/ci/queue_hygiene.py --repo "${{ github.repository }}" --status "${status_scope}" --max-cancel "${max_cancel}" + --dedupe-workflow "CI Run" + --dedupe-workflow "Test E2E" + --dedupe-workflow "Docs Deploy" --dedupe-workflow "PR Intake Checks" --dedupe-workflow "PR Labeler" --dedupe-workflow "PR Auto Responder" --dedupe-workflow "Workflow Sanity" --dedupe-workflow "PR Label Policy Check" + --dedupe-include-non-pr + --non-pr-key branch --output-json artifacts/queue-hygiene-report.json --verbose) diff --git a/.github/workflows/ci-run.yml b/.github/workflows/ci-run.yml index 196b15cc6..d28abcf0a 100644 --- a/.github/workflows/ci-run.yml +++ b/.github/workflows/ci-run.yml @@ -9,7 +9,7 @@ on: 
branches: [dev, main] concurrency: - group: ci-${{ github.event.pull_request.number || github.sha }} + group: ci-run-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 6ac5c220a..c1f55d7db 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -41,7 +41,7 @@ on: default: "" concurrency: - group: docs-deploy-${{ github.event.pull_request.number || github.sha }} + group: docs-deploy-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/test-e2e.yml b/.github/workflows/test-e2e.yml index ce3b00a17..8f9a005fd 100644 --- a/.github/workflows/test-e2e.yml +++ b/.github/workflows/test-e2e.yml @@ -14,7 +14,7 @@ on: workflow_dispatch: concurrency: - group: e2e-${{ github.event.pull_request.number || github.sha }} + group: test-e2e-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref_name || github.sha }} cancel-in-progress: true permissions: diff --git a/scripts/ci/queue_hygiene.py b/scripts/ci/queue_hygiene.py index 9255e9b64..ebeb22699 100755 --- a/scripts/ci/queue_hygiene.py +++ b/scripts/ci/queue_hygiene.py @@ -66,6 +66,15 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Also dedupe non-PR runs (push/manual). Default dedupe scope is PR-originated runs only.", ) + parser.add_argument( + "--non-pr-key", + default="sha", + choices=["sha", "branch"], + help=( + "Identity key mode for non-PR dedupe when --dedupe-include-non-pr is enabled: " + "'sha' keeps one run per commit (default), 'branch' keeps one run per branch." 
+ ), + ) parser.add_argument( "--max-cancel", type=int, @@ -165,7 +174,7 @@ def parse_timestamp(value: str | None) -> datetime: return datetime.fromtimestamp(0, tz=timezone.utc) -def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]: +def run_identity_key(run: dict[str, Any], *, non_pr_key: str) -> tuple[str, str, str, str]: name = str(run.get("name", "")) event = str(run.get("event", "")) head_branch = str(run.get("head_branch", "")) @@ -179,7 +188,10 @@ def run_identity_key(run: dict[str, Any]) -> tuple[str, str, str, str]: if pr_number: # For PR traffic, cancel stale runs across synchronize updates for the same PR. return (name, event, f"pr:{pr_number}", "") - # For push/manual traffic, key by SHA to avoid collapsing distinct commits. + if non_pr_key == "branch": + # Branch-level supersedence for push/manual lanes. + return (name, event, head_branch, "") + # SHA-level supersedence for push/manual lanes. return (name, event, head_branch, head_sha) @@ -189,6 +201,7 @@ def collect_candidates( dedupe_workflows: set[str], *, include_non_pr: bool, + non_pr_key: str, ) -> tuple[list[dict[str, Any]], Counter[str]]: reasons_by_id: dict[int, set[str]] = defaultdict(set) runs_by_id: dict[int, dict[str, Any]] = {} @@ -220,7 +233,7 @@ def collect_candidates( has_pr_context = isinstance(pull_requests, list) and len(pull_requests) > 0 if is_pr_event and not has_pr_context and not include_non_pr: continue - key = run_identity_key(run) + key = run_identity_key(run, non_pr_key=non_pr_key) by_workflow[name][key].append(run) for groups in by_workflow.values(): @@ -324,6 +337,7 @@ def main() -> int: obsolete_workflows, dedupe_workflows, include_non_pr=args.dedupe_include_non_pr, + non_pr_key=args.non_pr_key, ) capped = selected[: max(0, args.max_cancel)] @@ -338,6 +352,7 @@ def main() -> int: "obsolete_workflows": sorted(obsolete_workflows), "dedupe_workflows": sorted(dedupe_workflows), "dedupe_include_non_pr": args.dedupe_include_non_pr, + "non_pr_key": 
args.non_pr_key, "max_cancel": args.max_cancel, }, "counts": { diff --git a/scripts/ci/tests/test_ci_scripts.py b/scripts/ci/tests/test_ci_scripts.py index 1e5c7921a..f18bec46c 100644 --- a/scripts/ci/tests/test_ci_scripts.py +++ b/scripts/ci/tests/test_ci_scripts.py @@ -3759,6 +3759,119 @@ class CiScriptsBehaviorTest(unittest.TestCase): planned_ids = [item["id"] for item in report["planned_actions"]] self.assertEqual(planned_ids, [101, 102]) + def test_queue_hygiene_non_pr_branch_mode_dedupes_push_runs(self) -> None: + runs_json = self.tmp / "runs-non-pr-branch.json" + output_json = self.tmp / "queue-hygiene-non-pr-branch.json" + runs_json.write_text( + json.dumps( + { + "workflow_runs": [ + { + "id": 201, + "name": "CI Run", + "event": "push", + "head_branch": "main", + "head_sha": "sha-201", + "created_at": "2026-02-27T20:00:00Z", + }, + { + "id": 202, + "name": "CI Run", + "event": "push", + "head_branch": "main", + "head_sha": "sha-202", + "created_at": "2026-02-27T20:01:00Z", + }, + { + "id": 203, + "name": "CI Run", + "event": "push", + "head_branch": "dev", + "head_sha": "sha-203", + "created_at": "2026-02-27T20:02:00Z", + }, + ] + } + ) + + "\n", + encoding="utf-8", + ) + + proc = run_cmd( + [ + "python3", + self._script("queue_hygiene.py"), + "--runs-json", + str(runs_json), + "--dedupe-workflow", + "CI Run", + "--dedupe-include-non-pr", + "--non-pr-key", + "branch", + "--output-json", + str(output_json), + ] + ) + self.assertEqual(proc.returncode, 0, msg=proc.stderr) + + report = json.loads(output_json.read_text(encoding="utf-8")) + self.assertEqual(report["counts"]["candidate_runs_before_cap"], 1) + planned_ids = [item["id"] for item in report["planned_actions"]] + self.assertEqual(planned_ids, [201]) + reasons = report["planned_actions"][0]["reasons"] + self.assertTrue(any(reason.startswith("dedupe-superseded-by:202") for reason in reasons)) + self.assertEqual(report["policies"]["non_pr_key"], "branch") + + def 
test_queue_hygiene_non_pr_sha_mode_keeps_distinct_push_commits(self) -> None: + runs_json = self.tmp / "runs-non-pr-sha.json" + output_json = self.tmp / "queue-hygiene-non-pr-sha.json" + runs_json.write_text( + json.dumps( + { + "workflow_runs": [ + { + "id": 301, + "name": "CI Run", + "event": "push", + "head_branch": "main", + "head_sha": "sha-301", + "created_at": "2026-02-27T20:00:00Z", + }, + { + "id": 302, + "name": "CI Run", + "event": "push", + "head_branch": "main", + "head_sha": "sha-302", + "created_at": "2026-02-27T20:01:00Z", + }, + ] + } + ) + + "\n", + encoding="utf-8", + ) + + proc = run_cmd( + [ + "python3", + self._script("queue_hygiene.py"), + "--runs-json", + str(runs_json), + "--dedupe-workflow", + "CI Run", + "--dedupe-include-non-pr", + "--output-json", + str(output_json), + ] + ) + self.assertEqual(proc.returncode, 0, msg=proc.stderr) + + report = json.loads(output_json.read_text(encoding="utf-8")) + self.assertEqual(report["counts"]["candidate_runs_before_cap"], 0) + self.assertEqual(report["planned_actions"], []) + self.assertEqual(report["policies"]["non_pr_key"], "sha") + if __name__ == "__main__": # pragma: no cover unittest.main(verbosity=2) From fb124b61d4edbfacd28a912ef4e626dcc9cc2d37 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 22:52:25 -0500 Subject: [PATCH 05/13] fix(docs): correct first-run gateway commands --- README.md | 8 ++++---- src/main.rs | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8b8831366..da4a5c3d7 100644 --- a/README.md +++ b/README.md @@ -108,11 +108,11 @@ cargo install zeroclaw ### First Run ```bash -# Start the gateway daemon -zeroclaw gateway start +# Start the gateway (serves the Web Dashboard API/UI) +zeroclaw gateway -# Open the web UI -zeroclaw dashboard +# Open the dashboard URL shown in startup logs +# (default: http://127.0.0.1:3000/) # Or chat directly zeroclaw chat "Hello!" 
diff --git a/src/main.rs b/src/main.rs index 978235848..826b33bd9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2425,6 +2425,19 @@ mod tests { } } + #[test] + fn readme_does_not_reference_removed_gateway_or_dashboard_commands() { + let readme = include_str!("../README.md"); + assert!( + !readme.contains("zeroclaw gateway start"), + "README should not suggest obsolete 'zeroclaw gateway start'" + ); + assert!( + !readme.contains("zeroclaw dashboard"), + "README should not suggest nonexistent 'zeroclaw dashboard'" + ); + } + #[test] fn completion_generation_mentions_binary_name() { let mut output = Vec::new(); From f83c9732ca72421952fd10d7e981d9a7f30a5bae Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 23:31:37 -0500 Subject: [PATCH 06/13] chore(ci): keep gateway docs fix docs-only --- src/main.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/main.rs b/src/main.rs index 826b33bd9..978235848 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2425,19 +2425,6 @@ mod tests { } } - #[test] - fn readme_does_not_reference_removed_gateway_or_dashboard_commands() { - let readme = include_str!("../README.md"); - assert!( - !readme.contains("zeroclaw gateway start"), - "README should not suggest obsolete 'zeroclaw gateway start'" - ); - assert!( - !readme.contains("zeroclaw dashboard"), - "README should not suggest nonexistent 'zeroclaw dashboard'" - ); - } - #[test] fn completion_generation_mentions_binary_name() { let mut output = Vec::new(); From f3c82cb13a1c53c01d17090b76568c2206edcb16 Mon Sep 17 00:00:00 2001 From: Argenis Date: Sat, 28 Feb 2026 23:51:34 -0500 Subject: [PATCH 07/13] feat(tools): add xlsx_read tool for spreadsheet extraction (#2338) * feat(tools): add xlsx_read tool for spreadsheet extraction * chore(ci): retrigger intake after PR template update --- src/tools/mod.rs | 5 + src/tools/xlsx_read.rs | 1177 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1182 insertions(+) create mode 100644 
src/tools/xlsx_read.rs diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 20d6296fd..f2f18ad27 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -82,6 +82,7 @@ pub mod web_access_config; pub mod web_fetch; pub mod web_search_config; pub mod web_search_tool; +pub mod xlsx_read; pub use apply_patch::ApplyPatchTool; pub use bg_run::{ @@ -147,6 +148,7 @@ pub use web_access_config::WebAccessConfigTool; pub use web_fetch::WebFetchTool; pub use web_search_config::WebSearchConfigTool; pub use web_search_tool::WebSearchTool; +pub use xlsx_read::XlsxReadTool; pub use auth_profile::ManageAuthProfileTool; pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool}; @@ -511,6 +513,9 @@ pub fn all_tools_with_runtime( // PPTX text extraction tool_arcs.push(Arc::new(PptxReadTool::new(security.clone()))); + // XLSX text extraction + tool_arcs.push(Arc::new(XlsxReadTool::new(security.clone()))); + // Vision tools are always available tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone()))); tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone()))); diff --git a/src/tools/xlsx_read.rs b/src/tools/xlsx_read.rs new file mode 100644 index 000000000..655bf112f --- /dev/null +++ b/src/tools/xlsx_read.rs @@ -0,0 +1,1177 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::collections::HashMap; +use std::path::{Component, Path}; +use std::sync::Arc; + +/// Maximum XLSX file size (50 MB). +const MAX_XLSX_BYTES: u64 = 50 * 1024 * 1024; +/// Default character limit returned to the LLM. +const DEFAULT_MAX_CHARS: usize = 50_000; +/// Hard ceiling regardless of what the caller requests. +const MAX_OUTPUT_CHARS: usize = 200_000; +/// Upper bound for total uncompressed XML read from sheet files. +const MAX_TOTAL_SHEET_XML_BYTES: u64 = 16 * 1024 * 1024; + +/// Extract plain text from an XLSX file in the workspace. 
+pub struct XlsxReadTool { + security: Arc, +} + +impl XlsxReadTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +/// Extract plain text from XLSX bytes. +/// +/// XLSX is a ZIP archive containing `xl/worksheets/sheet*.xml` with cell data, +/// `xl/sharedStrings.xml` with a string pool, and `xl/workbook.xml` with sheet +/// names. Text cells reference the shared string pool by index; inline and +/// numeric values are taken directly from `` elements. +fn extract_xlsx_text(bytes: &[u8]) -> anyhow::Result { + extract_xlsx_text_with_limits(bytes, MAX_TOTAL_SHEET_XML_BYTES) +} + +fn extract_xlsx_text_with_limits( + bytes: &[u8], + max_total_sheet_xml_bytes: u64, +) -> anyhow::Result { + use std::io::Read; + + let cursor = std::io::Cursor::new(bytes); + let mut archive = zip::ZipArchive::new(cursor)?; + + // 1. Parse shared strings table. + let shared_strings = parse_shared_strings(&mut archive)?; + + // 2. Parse workbook.xml to get sheet names and rIds. + let sheet_entries = parse_workbook_sheets(&mut archive)?; + + // 3. Parse workbook.xml.rels to map rId → Target path. + let rel_targets = parse_workbook_rels(&mut archive)?; + + // 4. Build ordered list of (sheet_name, file_path) pairs. + let mut ordered_sheets: Vec<(String, String)> = Vec::new(); + for (sheet_name, r_id) in &sheet_entries { + if let Some(target) = rel_targets.get(r_id) { + if let Some(normalized) = normalize_sheet_target(target) { + ordered_sheets.push((sheet_name.clone(), normalized)); + } + } + } + + // Fallback: if workbook parsing yielded no sheets, scan ZIP entries directly. 
+ if ordered_sheets.is_empty() { + let mut fallback_paths: Vec = (0..archive.len()) + .filter_map(|i| { + let name = archive.by_index(i).ok()?.name().to_string(); + if name.starts_with("xl/worksheets/sheet") && name.ends_with(".xml") { + Some(name) + } else { + None + } + }) + .collect(); + fallback_paths.sort_by(|a, b| { + let a_idx = sheet_numeric_index(a); + let b_idx = sheet_numeric_index(b); + a_idx.cmp(&b_idx).then_with(|| a.cmp(b)) + }); + + if fallback_paths.is_empty() { + anyhow::bail!("Not a valid XLSX (no worksheet XML files found)"); + } + + for (i, path) in fallback_paths.into_iter().enumerate() { + ordered_sheets.push((format!("Sheet{}", i + 1), path)); + } + } + + // 5. Extract cell text from each sheet. + let mut output = String::new(); + let mut total_sheet_xml_bytes = 0u64; + let multi_sheet = ordered_sheets.len() > 1; + + for (sheet_name, sheet_path) in &ordered_sheets { + let mut sheet_file = match archive.by_name(sheet_path) { + Ok(f) => f, + Err(_) => continue, + }; + + let sheet_xml_size = sheet_file.size(); + total_sheet_xml_bytes = total_sheet_xml_bytes + .checked_add(sheet_xml_size) + .ok_or_else(|| anyhow::anyhow!("Sheet XML payload size overflow"))?; + if total_sheet_xml_bytes > max_total_sheet_xml_bytes { + anyhow::bail!( + "Sheet XML payload too large: {} bytes (limit: {} bytes)", + total_sheet_xml_bytes, + max_total_sheet_xml_bytes + ); + } + + let mut xml_content = String::new(); + sheet_file.read_to_string(&mut xml_content)?; + + if multi_sheet { + if !output.is_empty() { + output.push('\n'); + } + use std::fmt::Write as _; + let _ = writeln!(output, "--- Sheet: {} ---", sheet_name); + } + + let sheet_text = extract_sheet_cells(&xml_content, &shared_strings)?; + output.push_str(&sheet_text); + } + + Ok(output) +} + +/// Parse `xl/sharedStrings.xml` into a `Vec` indexed by position. 
+fn parse_shared_strings( + archive: &mut zip::ZipArchive, +) -> anyhow::Result> { + use quick_xml::events::Event; + use quick_xml::Reader; + use std::io::Read; + + let mut xml = String::new(); + match archive.by_name("xl/sharedStrings.xml") { + Ok(mut f) => { + f.read_to_string(&mut xml)?; + } + Err(zip::result::ZipError::FileNotFound) => return Ok(Vec::new()), + Err(e) => return Err(e.into()), + } + + let mut strings = Vec::new(); + let mut reader = Reader::from_str(&xml); + let mut in_si = false; + let mut in_t = false; + let mut current = String::new(); + + loop { + match reader.read_event() { + Ok(Event::Start(e)) => { + let qname = e.name(); + let name = local_name(qname.as_ref()); + if name == b"si" { + in_si = true; + current.clear(); + } else if in_si && name == b"t" { + in_t = true; + } + } + Ok(Event::End(e)) => { + let qname = e.name(); + let name = local_name(qname.as_ref()); + if name == b"t" { + in_t = false; + } else if name == b"si" { + in_si = false; + strings.push(std::mem::take(&mut current)); + } + } + Ok(Event::Text(e)) => { + if in_t { + current.push_str(&e.unescape()?); + } + } + Ok(Event::Eof) => break, + Err(e) => return Err(e.into()), + _ => {} + } + } + + Ok(strings) +} + +/// Parse `xl/workbook.xml` → Vec<(sheet_name, rId)>. 
+fn parse_workbook_sheets( + archive: &mut zip::ZipArchive, +) -> anyhow::Result> { + use quick_xml::events::Event; + use quick_xml::Reader; + use std::io::Read; + + let mut xml = String::new(); + match archive.by_name("xl/workbook.xml") { + Ok(mut f) => { + f.read_to_string(&mut xml)?; + } + Err(zip::result::ZipError::FileNotFound) => return Ok(Vec::new()), + Err(e) => return Err(e.into()), + } + + let mut sheets = Vec::new(); + let mut reader = Reader::from_str(&xml); + + loop { + match reader.read_event() { + Ok(Event::Start(ref e) | Event::Empty(ref e)) => { + let qname = e.name(); + if local_name(qname.as_ref()) == b"sheet" { + let mut name = None; + let mut r_id = None; + for attr in e.attributes().flatten() { + let key = attr.key.as_ref(); + let local = local_name(key); + if local == b"name" { + name = Some( + attr.decode_and_unescape_value(reader.decoder())? + .into_owned(), + ); + } else if key == b"r:id" || local == b"id" { + // Accept both r:id and {ns}:id variants. + // Only take the relationship id (starts with "rId"). + let val = attr + .decode_and_unescape_value(reader.decoder())? + .into_owned(); + if val.starts_with("rId") { + r_id = Some(val); + } + } + } + if let (Some(n), Some(r)) = (name, r_id) { + sheets.push((n, r)); + } + } + } + Ok(Event::Eof) => break, + Err(e) => return Err(e.into()), + _ => {} + } + } + + Ok(sheets) +} + +/// Parse `xl/_rels/workbook.xml.rels` → HashMap. 
+fn parse_workbook_rels( + archive: &mut zip::ZipArchive, +) -> anyhow::Result> { + use quick_xml::events::Event; + use quick_xml::Reader; + use std::io::Read; + + let mut xml = String::new(); + match archive.by_name("xl/_rels/workbook.xml.rels") { + Ok(mut f) => { + f.read_to_string(&mut xml)?; + } + Err(zip::result::ZipError::FileNotFound) => return Ok(HashMap::new()), + Err(e) => return Err(e.into()), + } + + let mut rels = HashMap::new(); + let mut reader = Reader::from_str(&xml); + + loop { + match reader.read_event() { + Ok(Event::Start(ref e) | Event::Empty(ref e)) => { + let qname = e.name(); + if local_name(qname.as_ref()) == b"Relationship" { + let mut rel_id = None; + let mut target = None; + for attr in e.attributes().flatten() { + let key = local_name(attr.key.as_ref()); + if key.eq_ignore_ascii_case(b"id") { + rel_id = Some( + attr.decode_and_unescape_value(reader.decoder())? + .into_owned(), + ); + } else if key.eq_ignore_ascii_case(b"target") { + target = Some( + attr.decode_and_unescape_value(reader.decoder())? + .into_owned(), + ); + } + } + if let (Some(id), Some(t)) = (rel_id, target) { + rels.insert(id, t); + } + } + } + Ok(Event::Eof) => break, + Err(e) => return Err(e.into()), + _ => {} + } + } + + Ok(rels) +} + +/// Extract cell text from a single worksheet XML string. +/// +/// Cells are output as tab-separated values per row, newline-separated per row. 
+fn extract_sheet_cells(xml: &str, shared_strings: &[String]) -> anyhow::Result { + use quick_xml::events::Event; + use quick_xml::Reader; + + let mut reader = Reader::from_str(xml); + let mut output = String::new(); + + let mut in_row = false; + let mut in_cell = false; + let mut in_value = false; + let mut cell_type = CellType::Number; + let mut cell_value = String::new(); + let mut row_cells: Vec = Vec::new(); + + loop { + match reader.read_event() { + Ok(Event::Start(e)) => { + let qname = e.name(); + let name = local_name(qname.as_ref()); + match name { + b"row" => { + in_row = true; + row_cells.clear(); + } + b"c" if in_row => { + in_cell = true; + cell_type = CellType::Number; + cell_value.clear(); + for attr in e.attributes().flatten() { + if attr.key.as_ref() == b"t" { + let val = attr.decode_and_unescape_value(reader.decoder())?; + cell_type = match val.as_ref() { + "s" => CellType::SharedString, + "inlineStr" => CellType::InlineString, + "b" => CellType::Boolean, + _ => CellType::Number, + }; + } + } + } + b"v" if in_cell => { + in_value = true; + } + b"t" if in_cell && cell_type == CellType::InlineString => { + // Inline string: text is inside ... 
+ in_value = true; + } + _ => {} + } + } + Ok(Event::End(e)) => { + let qname = e.name(); + let name = local_name(qname.as_ref()); + match name { + b"row" => { + in_row = false; + if !row_cells.is_empty() { + if !output.is_empty() { + output.push('\n'); + } + output.push_str(&row_cells.join("\t")); + } + } + b"c" if in_cell => { + in_cell = false; + let resolved = match cell_type { + CellType::SharedString => { + if let Ok(idx) = cell_value.trim().parse::() { + shared_strings.get(idx).cloned().unwrap_or_default() + } else { + cell_value.clone() + } + } + CellType::Boolean => match cell_value.trim() { + "1" => "TRUE".to_string(), + "0" => "FALSE".to_string(), + other => other.to_string(), + }, + _ => cell_value.clone(), + }; + row_cells.push(resolved); + } + b"v" => { + in_value = false; + } + b"t" if in_cell => { + in_value = false; + } + _ => {} + } + } + Ok(Event::Text(e)) => { + if in_value { + cell_value.push_str(&e.unescape()?); + } + } + Ok(Event::Eof) => break, + Err(e) => return Err(e.into()), + _ => {} + } + } + + // Flush last row if not terminated by . 
+ if in_row && !row_cells.is_empty() { + if !output.is_empty() { + output.push('\n'); + } + output.push_str(&row_cells.join("\t")); + } + + if !output.is_empty() { + output.push('\n'); + } + + Ok(output) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CellType { + Number, + SharedString, + InlineString, + Boolean, +} + +fn sheet_numeric_index(sheet_path: &str) -> Option { + let stem = Path::new(sheet_path).file_stem()?.to_string_lossy(); + let digits = stem.strip_prefix("sheet")?; + digits.parse::().ok() +} + +fn local_name(name: &[u8]) -> &[u8] { + name.rsplit(|b| *b == b':').next().unwrap_or(name) +} + +fn normalize_sheet_target(target: &str) -> Option { + if target.contains("://") { + return None; + } + + let mut segments = Vec::new(); + for component in Path::new("xl").join(target).components() { + match component { + Component::Normal(part) => segments.push(part.to_string_lossy().to_string()), + Component::ParentDir => { + segments.pop()?; + } + _ => {} + } + } + + let normalized = segments.join("/"); + if normalized.starts_with("xl/worksheets/") && normalized.ends_with(".xml") { + Some(normalized) + } else { + None + } +} + +fn parse_max_chars(args: &serde_json::Value) -> anyhow::Result { + let Some(value) = args.get("max_chars") else { + return Ok(DEFAULT_MAX_CHARS); + }; + + let serde_json::Value::Number(number) = value else { + anyhow::bail!("Invalid 'max_chars': expected a positive integer"); + }; + let Some(raw) = number.as_u64() else { + anyhow::bail!("Invalid 'max_chars': expected a positive integer"); + }; + if raw == 0 { + anyhow::bail!("Invalid 'max_chars': must be >= 1"); + } + + Ok(usize::try_from(raw) + .unwrap_or(MAX_OUTPUT_CHARS) + .min(MAX_OUTPUT_CHARS)) +} + +#[async_trait] +impl Tool for XlsxReadTool { + fn name(&self) -> &str { + "xlsx_read" + } + + fn description(&self) -> &str { + "Extract plain text and numeric data from an XLSX (Excel) file in the workspace. \ + Returns tab-separated cell values per row for each sheet. 
\ + No formulas, charts, styles, or merged-cell awareness." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the XLSX file. Relative paths resolve from workspace." + }, + "max_chars": { + "type": "integer", + "description": "Maximum characters to return (default: 50000, max: 200000)", + "minimum": 1, + "maximum": 200_000 + } + }, + "required": ["path"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + + let max_chars = match parse_max_chars(&args) { + Ok(value) => value, + Err(err) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(err.to_string()), + }) + } + }; + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: too many actions in the last hour".into()), + }); + } + + if !self.security.is_path_allowed(path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed by security policy: {path}")), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let full_path = self.security.workspace_dir.join(path); + + let resolved_path = match tokio::fs::canonicalize(&full_path).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + self.security + .resolved_path_violation_message(&resolved_path), + ), + }); + } 
+ + tracing::debug!("Reading XLSX: {}", resolved_path.display()); + + match tokio::fs::metadata(&resolved_path).await { + Ok(meta) => { + if meta.len() > MAX_XLSX_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "XLSX too large: {} bytes (limit: {MAX_XLSX_BYTES} bytes)", + meta.len() + )), + }); + } + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to read file metadata: {e}")), + }); + } + } + + let bytes = match tokio::fs::read(&resolved_path).await { + Ok(b) => b, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to read XLSX file: {e}")), + }); + } + }; + + let text = match tokio::task::spawn_blocking(move || extract_xlsx_text(&bytes)).await { + Ok(Ok(t)) => t, + Ok(Err(e)) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("XLSX extraction failed: {e}")), + }); + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("XLSX extraction task panicked: {e}")), + }); + } + }; + + if text.trim().is_empty() { + return Ok(ToolResult { + success: true, + output: "XLSX contains no extractable text".into(), + error: None, + }); + } + + let output = if text.chars().count() > max_chars { + let mut truncated: String = text.chars().take(max_chars).collect(); + use std::fmt::Write as _; + let _ = write!(truncated, "\n\n... 
[truncated at {max_chars} chars]"); + truncated + } else { + text + }; + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + use tempfile::TempDir; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + fn test_security_with_limit( + workspace: std::path::PathBuf, + max_actions: u32, + ) -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + workspace_dir: workspace, + max_actions_per_hour: max_actions, + ..SecurityPolicy::default() + }) + } + + /// Build a minimal valid XLSX (ZIP) in memory with one sheet containing + /// the given rows. Each inner `Vec<&str>` is a row of cell values. + fn minimal_xlsx_bytes(rows: &[Vec<&str>]) -> Vec { + use std::io::Write; + + // Build shared strings from all unique cell values. + let mut all_values: Vec = Vec::new(); + for row in rows { + for cell in row { + if !all_values.contains(&cell.to_string()) { + all_values.push(cell.to_string()); + } + } + } + + let mut ss_entries = String::new(); + for val in &all_values { + ss_entries.push_str(&format!("{val}")); + } + let shared_strings_xml = format!( + r#" +{ss_entries}"#, + all_values.len(), + all_values.len() + ); + + // Build sheet XML. 
+ let mut sheet_rows = String::new(); + for (r_idx, row) in rows.iter().enumerate() { + sheet_rows.push_str(&format!(r#""#, r_idx + 1)); + for (c_idx, cell) in row.iter().enumerate() { + let col_letter = (b'A' + c_idx as u8) as char; + let cell_ref = format!("{}{}", col_letter, r_idx + 1); + let ss_idx = all_values.iter().position(|v| v == cell).unwrap(); + sheet_rows.push_str(&format!(r#"{ss_idx}"#)); + } + sheet_rows.push_str(""); + } + let sheet_xml = format!( + r#" + +{sheet_rows} +"# + ); + + let workbook_xml = r#" + + +"#; + + let rels_xml = r#" + + +"#; + + let buf = std::io::Cursor::new(Vec::new()); + let mut zip = zip::ZipWriter::new(buf); + let options = zip::write::SimpleFileOptions::default() + .compression_method(zip::CompressionMethod::Stored); + + zip.start_file("xl/sharedStrings.xml", options).unwrap(); + zip.write_all(shared_strings_xml.as_bytes()).unwrap(); + + zip.start_file("xl/workbook.xml", options).unwrap(); + zip.write_all(workbook_xml.as_bytes()).unwrap(); + + zip.start_file("xl/_rels/workbook.xml.rels", options) + .unwrap(); + zip.write_all(rels_xml.as_bytes()).unwrap(); + + zip.start_file("xl/worksheets/sheet1.xml", options).unwrap(); + zip.write_all(sheet_xml.as_bytes()).unwrap(); + + zip.finish().unwrap().into_inner() + } + + /// Build an XLSX with two sheets. + fn two_sheet_xlsx_bytes( + sheet1_name: &str, + sheet1_rows: &[Vec<&str>], + sheet2_name: &str, + sheet2_rows: &[Vec<&str>], + ) -> Vec { + use std::io::Write; + + // Collect all unique values across both sheets. 
+ let mut all_values: Vec = Vec::new(); + for rows in [sheet1_rows, sheet2_rows] { + for row in rows { + for cell in row { + if !all_values.contains(&cell.to_string()) { + all_values.push(cell.to_string()); + } + } + } + } + + let mut ss_entries = String::new(); + for val in &all_values { + ss_entries.push_str(&format!("{val}")); + } + let shared_strings_xml = format!( + r#" +{ss_entries}"#, + all_values.len(), + all_values.len() + ); + + let build_sheet = |rows: &[Vec<&str>]| -> String { + let mut sheet_rows = String::new(); + for (r_idx, row) in rows.iter().enumerate() { + sheet_rows.push_str(&format!(r#""#, r_idx + 1)); + for (c_idx, cell) in row.iter().enumerate() { + let col_letter = (b'A' + c_idx as u8) as char; + let cell_ref = format!("{}{}", col_letter, r_idx + 1); + let ss_idx = all_values.iter().position(|v| v == cell).unwrap(); + sheet_rows.push_str(&format!(r#"{ss_idx}"#)); + } + sheet_rows.push_str(""); + } + format!( + r#" + +{sheet_rows} +"# + ) + }; + + let workbook_xml = format!( + r#" + + + + + +"# + ); + + let rels_xml = r#" + + + +"#; + + let buf = std::io::Cursor::new(Vec::new()); + let mut zip = zip::ZipWriter::new(buf); + let options = zip::write::SimpleFileOptions::default() + .compression_method(zip::CompressionMethod::Stored); + + zip.start_file("xl/sharedStrings.xml", options).unwrap(); + zip.write_all(shared_strings_xml.as_bytes()).unwrap(); + + zip.start_file("xl/workbook.xml", options).unwrap(); + zip.write_all(workbook_xml.as_bytes()).unwrap(); + + zip.start_file("xl/_rels/workbook.xml.rels", options) + .unwrap(); + zip.write_all(rels_xml.as_bytes()).unwrap(); + + zip.start_file("xl/worksheets/sheet1.xml", options).unwrap(); + zip.write_all(build_sheet(sheet1_rows).as_bytes()).unwrap(); + + zip.start_file("xl/worksheets/sheet2.xml", options).unwrap(); + zip.write_all(build_sheet(sheet2_rows).as_bytes()).unwrap(); + + zip.finish().unwrap().into_inner() + } + + #[test] + fn name_is_xlsx_read() { + let tool = 
XlsxReadTool::new(test_security(std::env::temp_dir())); + assert_eq!(tool.name(), "xlsx_read"); + } + + #[test] + fn description_not_empty() { + let tool = XlsxReadTool::new(test_security(std::env::temp_dir())); + assert!(!tool.description().is_empty()); + } + + #[test] + fn schema_has_path_required() { + let tool = XlsxReadTool::new(test_security(std::env::temp_dir())); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["path"].is_object()); + assert!(schema["properties"]["max_chars"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("path"))); + } + + #[test] + fn spec_matches_metadata() { + let tool = XlsxReadTool::new(test_security(std::env::temp_dir())); + let spec = tool.spec(); + assert_eq!(spec.name, "xlsx_read"); + assert!(spec.parameters.is_object()); + } + + #[tokio::test] + async fn missing_path_param_returns_error() { + let tool = XlsxReadTool::new(test_security(std::env::temp_dir())); + let result = tool.execute(json!({})).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("path")); + } + + #[tokio::test] + async fn absolute_path_is_blocked() { + let tool = XlsxReadTool::new(test_security(std::env::temp_dir())); + let result = tool.execute(json!({"path": "/etc/passwd"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("not allowed")); + } + + #[tokio::test] + async fn path_traversal_is_blocked() { + let tmp = TempDir::new().unwrap(); + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool + .execute(json!({"path": "../../../etc/passwd"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("not allowed")); + } + + #[tokio::test] + async fn nonexistent_file_returns_error() { + let tmp = TempDir::new().unwrap(); + let tool = 
XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool.execute(json!({"path": "missing.xlsx"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Failed to resolve")); + } + + #[tokio::test] + async fn rate_limit_blocks_request() { + let tmp = TempDir::new().unwrap(); + let tool = XlsxReadTool::new(test_security_with_limit(tmp.path().to_path_buf(), 0)); + let result = tool.execute(json!({"path": "any.xlsx"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("Rate limit")); + } + + #[tokio::test] + async fn extracts_text_from_valid_xlsx() { + let tmp = TempDir::new().unwrap(); + let xlsx_path = tmp.path().join("data.xlsx"); + let rows = vec![vec!["Name", "Age"], vec!["Alice", "30"]]; + tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows)) + .await + .unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool.execute(json!({"path": "data.xlsx"})).await.unwrap(); + assert!(result.success, "error: {:?}", result.error); + assert!( + result.output.contains("Name"), + "expected 'Name' in output, got: {}", + result.output + ); + assert!(result.output.contains("Age")); + assert!(result.output.contains("Alice")); + assert!(result.output.contains("30")); + } + + #[tokio::test] + async fn extracts_tab_separated_columns() { + let tmp = TempDir::new().unwrap(); + let xlsx_path = tmp.path().join("cols.xlsx"); + let rows = vec![vec!["A", "B", "C"]]; + tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows)) + .await + .unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool.execute(json!({"path": "cols.xlsx"})).await.unwrap(); + assert!(result.success); + assert!( + result.output.contains("A\tB\tC"), + "expected tab-separated output, got: {:?}", + result.output + ); + } + + #[tokio::test] + async fn extracts_multiple_sheets() { + let tmp = 
TempDir::new().unwrap(); + let xlsx_path = tmp.path().join("multi.xlsx"); + let bytes = two_sheet_xlsx_bytes( + "Sales", + &[vec!["Product", "Revenue"], vec!["Widget", "1000"]], + "Costs", + &[vec!["Item", "Amount"], vec!["Rent", "500"]], + ); + tokio::fs::write(&xlsx_path, bytes).await.unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool.execute(json!({"path": "multi.xlsx"})).await.unwrap(); + assert!(result.success, "error: {:?}", result.error); + assert!(result.output.contains("--- Sheet: Sales ---")); + assert!(result.output.contains("--- Sheet: Costs ---")); + assert!(result.output.contains("Widget")); + assert!(result.output.contains("Rent")); + } + + #[tokio::test] + async fn invalid_zip_returns_extraction_error() { + let tmp = TempDir::new().unwrap(); + let xlsx_path = tmp.path().join("bad.xlsx"); + tokio::fs::write(&xlsx_path, b"this is not a zip file") + .await + .unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool.execute(json!({"path": "bad.xlsx"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("extraction failed")); + } + + #[tokio::test] + async fn max_chars_truncates_output() { + let tmp = TempDir::new().unwrap(); + let long_text = "B".repeat(200); + let rows = vec![vec![long_text.as_str(); 10]]; + let xlsx_path = tmp.path().join("long.xlsx"); + tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows)) + .await + .unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool + .execute(json!({"path": "long.xlsx", "max_chars": 50})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("truncated")); + } + + #[tokio::test] + async fn invalid_max_chars_returns_tool_error() { + let tmp = TempDir::new().unwrap(); + let xlsx_path = tmp.path().join("data.xlsx"); + let rows = vec![vec!["Hello"]]; + 
tokio::fs::write(&xlsx_path, minimal_xlsx_bytes(&rows)) + .await + .unwrap(); + + let tool = XlsxReadTool::new(test_security(tmp.path().to_path_buf())); + let result = tool + .execute(json!({"path": "data.xlsx", "max_chars": "100"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("max_chars")); + } + + #[test] + fn shared_string_reference_resolved() { + let rows = vec![vec!["Hello", "World"]]; + let bytes = minimal_xlsx_bytes(&rows); + let text = extract_xlsx_text(&bytes).unwrap(); + assert!(text.contains("Hello")); + assert!(text.contains("World")); + } + + #[test] + fn cumulative_sheet_xml_limit_is_enforced() { + let rows = vec![vec!["Alpha", "Beta"]]; + let bytes = minimal_xlsx_bytes(&rows); + let error = extract_xlsx_text_with_limits(&bytes, 64).unwrap_err(); + assert!(error.to_string().contains("Sheet XML payload too large")); + } + + #[test] + fn numeric_cells_extracted_directly() { + use std::io::Write; + + // Build a sheet with numeric cells (no t="s" attribute). 
+ let sheet_xml = r#" + + +423.14 + +"#; + + let workbook_xml = r#" + + +"#; + + let rels_xml = r#" + + +"#; + + let buf = std::io::Cursor::new(Vec::new()); + let mut zip = zip::ZipWriter::new(buf); + let options = zip::write::SimpleFileOptions::default() + .compression_method(zip::CompressionMethod::Stored); + + zip.start_file("xl/workbook.xml", options).unwrap(); + zip.write_all(workbook_xml.as_bytes()).unwrap(); + zip.start_file("xl/_rels/workbook.xml.rels", options) + .unwrap(); + zip.write_all(rels_xml.as_bytes()).unwrap(); + zip.start_file("xl/worksheets/sheet1.xml", options).unwrap(); + zip.write_all(sheet_xml.as_bytes()).unwrap(); + + let bytes = zip.finish().unwrap().into_inner(); + let text = extract_xlsx_text(&bytes).unwrap(); + assert!(text.contains("42"), "got: {text}"); + assert!(text.contains("3.14"), "got: {text}"); + assert!(text.contains("42\t3.14"), "got: {text}"); + } + + #[test] + fn fallback_when_no_workbook() { + use std::io::Write; + + // ZIP with only sheet files, no workbook.xml. 
+ let sheet_xml = r#" + + +99 + +"#; + + let buf = std::io::Cursor::new(Vec::new()); + let mut zip = zip::ZipWriter::new(buf); + let options = zip::write::SimpleFileOptions::default() + .compression_method(zip::CompressionMethod::Stored); + + zip.start_file("xl/worksheets/sheet1.xml", options).unwrap(); + zip.write_all(sheet_xml.as_bytes()).unwrap(); + + let bytes = zip.finish().unwrap().into_inner(); + let text = extract_xlsx_text(&bytes).unwrap(); + assert!(text.contains("99"), "got: {text}"); + } + + #[cfg(unix)] + #[tokio::test] + async fn symlink_escape_is_blocked() { + use std::os::unix::fs::symlink; + + let root = TempDir::new().unwrap(); + let workspace = root.path().join("workspace"); + let outside = root.path().join("outside"); + tokio::fs::create_dir_all(&workspace).await.unwrap(); + tokio::fs::create_dir_all(&outside).await.unwrap(); + let rows = vec![vec!["secret"]]; + tokio::fs::write(outside.join("secret.xlsx"), minimal_xlsx_bytes(&rows)) + .await + .unwrap(); + symlink(outside.join("secret.xlsx"), workspace.join("link.xlsx")).unwrap(); + + let tool = XlsxReadTool::new(test_security(workspace)); + let result = tool.execute(json!({"path": "link.xlsx"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("escapes workspace")); + } + +} From 0683467bc10fe97cae2a2a119c411911f67eb132 Mon Sep 17 00:00:00 2001 From: Argenis Date: Sat, 28 Feb 2026 23:53:59 -0500 Subject: [PATCH 08/13] fix(channels): prompt non-CLI always_ask approvals (#2337) * fix(channels): prompt non-cli always_ask approvals * chore(ci): retrigger intake after PR template update --- src/agent/loop_.rs | 51 +++++++------ src/channels/mod.rs | 178 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 26 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index bf5793fc1..802dd7453 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -778,36 +778,40 @@ pub(crate) async fn 
run_tool_call_loop_with_non_cli_approval_context( on_delta: Option>, hooks: Option<&crate::hooks::HookRunner>, excluded_tools: &[String], + progress_mode: ProgressMode, safety_heartbeat: Option, ) -> Result { let reply_target = non_cli_approval_context .as_ref() .map(|ctx| ctx.reply_target.clone()); - SAFETY_HEARTBEAT_CONFIG + TOOL_LOOP_PROGRESS_MODE .scope( - safety_heartbeat, - TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope( - non_cli_approval_context, - TOOL_LOOP_REPLY_TARGET.scope( - reply_target, - run_tool_call_loop( - provider, - history, - tools_registry, - observer, - provider_name, - model, - temperature, - silent, - approval, - channel_name, - multimodal_config, - max_tool_iterations, - cancellation_token, - on_delta, - hooks, - excluded_tools, + progress_mode, + SAFETY_HEARTBEAT_CONFIG.scope( + safety_heartbeat, + TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope( + non_cli_approval_context, + TOOL_LOOP_REPLY_TARGET.scope( + reply_target, + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + provider_name, + model, + temperature, + silent, + approval, + channel_name, + multimodal_config, + max_tool_iterations, + cancellation_token, + on_delta, + hooks, + excluded_tools, + ), ), ), ), @@ -3617,6 +3621,7 @@ mod tests { None, None, &[], + ProgressMode::Verbose, None, ) .await diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 51cf345de..1a5251895 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -78,7 +78,8 @@ pub use whatsapp_web::WhatsAppWebChannel; use crate::agent::loop_::{ build_shell_policy_instructions, build_tool_instructions_from_specs, - run_tool_call_loop_with_reply_target, scrub_credentials, SafetyHeartbeatConfig, + run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext, + NonCliApprovalPrompt, SafetyHeartbeatConfig, }; use crate::agent::session::{resolve_session_id, shared_session_manager, Session, SessionManager}; use crate::approval::{ApprovalManager, ApprovalResponse, 
PendingApprovalError}; @@ -3664,11 +3665,53 @@ or tune thresholds in config.", let timeout_budget_secs = channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations); + + let (approval_prompt_tx, mut approval_prompt_rx) = + tokio::sync::mpsc::unbounded_channel::(); + let non_cli_approval_context = if msg.channel != "cli" && target_channel.is_some() { + Some(NonCliApprovalContext { + sender: msg.sender.clone(), + reply_target: msg.reply_target.clone(), + prompt_tx: approval_prompt_tx, + }) + } else { + drop(approval_prompt_tx); + None + }; + let approval_prompt_dispatcher = if let (Some(channel_ref), true) = + (target_channel.as_ref(), non_cli_approval_context.is_some()) + { + let channel = Arc::clone(channel_ref); + let reply_target = msg.reply_target.clone(); + let thread_ts = msg.thread_ts.clone(); + Some(tokio::spawn(async move { + while let Some(prompt) = approval_prompt_rx.recv().await { + if let Err(err) = channel + .send_approval_prompt( + &reply_target, + &prompt.request_id, + &prompt.tool_name, + &prompt.arguments, + thread_ts.clone(), + ) + .await + { + tracing::warn!( + "Failed to send non-CLI approval prompt for request {}: {err}", + prompt.request_id + ); + } + } + })) + } else { + None + }; + let llm_result = tokio::select! 
{ () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, result = tokio::time::timeout( Duration::from_secs(timeout_budget_secs), - run_tool_call_loop_with_reply_target( + run_tool_call_loop_with_non_cli_approval_context( active_provider.as_ref(), &mut history, ctx.tools_registry.as_ref(), @@ -3679,7 +3722,7 @@ or tune thresholds in config.", true, Some(ctx.approval_manager.as_ref()), msg.channel.as_str(), - Some(msg.reply_target.as_str()), + non_cli_approval_context, &ctx.multimodal, ctx.max_tool_iterations, Some(cancellation_token.clone()), @@ -3687,6 +3730,7 @@ or tune thresholds in config.", ctx.hooks.as_deref(), &excluded_tools_snapshot, progress_mode, + ctx.safety_heartbeat.clone(), ), ) => LlmExecutionResult::Completed(result), }; @@ -3694,6 +3738,9 @@ or tune thresholds in config.", if let Some(handle) = draft_updater { let _ = handle.await; } + if let Some(handle) = approval_prompt_dispatcher { + let _ = handle.await; + } if let Some(token) = typing_cancellation.as_ref() { token.cancel(); @@ -7653,6 +7700,131 @@ BTC is currently around $65,000 based on latest tool output."# assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0); } + #[tokio::test] + async fn process_channel_message_prompts_and_waits_for_non_cli_always_ask_approval() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let autonomy_cfg = crate::config::AutonomyConfig { + always_ask: vec!["mock_price".to_string()], + ..crate::config::AutonomyConfig::default() + }; + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(ToolCallingProvider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), + observer: Arc::new(NoopObserver), + 
system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + conversation_locks: Default::default(), + session_config: crate::config::AgentSessionConfig::default(), + session_manager: None, + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + hooks: None, + non_cli_excluded_tools: Arc::new(Mutex::new(Vec::new())), + query_classification: crate::config::QueryClassificationConfig::default(), + model_routes: Vec::new(), + approval_manager: Arc::new(ApprovalManager::from_config(&autonomy_cfg)), + safety_heartbeat: None, + startup_perplexity_filter: crate::config::PerplexityFilterConfig::default(), + }); + + let runtime_ctx_for_first_turn = runtime_ctx.clone(); + let first_turn = tokio::spawn(async move { + process_channel_message( + runtime_ctx_for_first_turn, + traits::ChannelMessage { + id: "msg-non-cli-approval-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "What is the BTC price now?".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + }); + + let request_id = tokio::time::timeout(Duration::from_secs(2), async { + loop { + let pending = runtime_ctx.approval_manager.list_non_cli_pending_requests( + Some("alice"), + Some("telegram"), + Some("chat-1"), + ); + if let Some(req) = pending.first() { + break 
req.request_id.clone(); + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("pending approval request should be created for always_ask tool"); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "msg-non-cli-approval-2".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: format!("/approve-allow {request_id}"), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + tokio::time::timeout(Duration::from_secs(5), first_turn) + .await + .expect("first channel turn should finish after approval") + .expect("first channel turn task should not panic"); + + let sent = channel_impl.sent_messages.lock().await; + assert!( + sent.iter() + .any(|entry| entry.contains("Approval required for tool `mock_price`")), + "channel should emit non-cli approval prompt" + ); + assert!( + sent.iter() + .any(|entry| entry.contains("Approved supervised execution for `mock_price`")), + "channel should acknowledge explicit approval command" + ); + assert!( + sent.iter() + .any(|entry| entry.contains("BTC is currently around")), + "tool call should execute after approval and produce final response" + ); + assert!( + sent.iter().all(|entry| !entry.contains("Denied by user.")), + "always_ask tool should not be silently denied once non-cli approval prompt path is wired" + ); + } + #[tokio::test] async fn process_channel_message_denies_approval_management_for_unlisted_sender() { let channel_impl = Arc::new(TelegramRecordingChannel::default()); From 305d9ccb7c10b1bcf48c41eec99f859ff6728f93 Mon Sep 17 00:00:00 2001 From: Argenis Date: Sat, 28 Feb 2026 23:54:26 -0500 Subject: [PATCH 09/13] fix(docs): keep install guidance canonical in README/docs (#2335) --- README.md | 12 +++++++++++- docs/README.md | 2 ++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index da4a5c3d7..d4b66ddaf 100644 --- 
a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.

Getting Started | - One-Click Setup | + One-Click Setup | Docs Hub | Docs TOC

@@ -120,6 +120,16 @@ zeroclaw chat "Hello!" For detailed setup options, see [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md). +### Installation Docs (Canonical Source) + +Use repository docs as the source of truth for install/setup instructions: + +- [README Quick Start](#quick-start) +- [docs/one-click-bootstrap.md](docs/one-click-bootstrap.md) +- [docs/getting-started/README.md](docs/getting-started/README.md) + +Issue comments can provide context, but they are not canonical installation documentation. + ## Benchmark Snapshot (ZeroClaw vs OpenClaw, Reproducible) Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware. diff --git a/docs/README.md b/docs/README.md index 05d6c6cb1..317ae8422 100644 --- a/docs/README.md +++ b/docs/README.md @@ -29,6 +29,8 @@ Localized hubs: [简体中文](i18n/zh-CN/README.md) · [日本語](i18n/ja/READ | See project PR/issue docs snapshot | [project-triage-snapshot-2026-02-18.md](project-triage-snapshot-2026-02-18.md) | | Perform i18n completion for docs changes | [i18n-guide.md](i18n-guide.md) | +Installation source-of-truth: keep install/run instructions in repository docs and README pages; issue comments are supplemental context only. + ## Quick Decision Tree (10 seconds) - Need first-time setup or install? 
→ [getting-started/README.md](getting-started/README.md) From 20d4e1599a34ee84b3241065e3be2b43b3f8e555 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 23:00:59 -0500 Subject: [PATCH 10/13] feat(skills): add trusted symlink roots for workspace skills --- docs/commands-reference.md | 5 ++ docs/config-reference.md | 4 +- src/config/schema.rs | 5 ++ src/skills/mod.rs | 148 ++++++++++++++++++++++++++++++++++-- src/skills/symlink_tests.rs | 78 ++++++++++++++++--- 5 files changed, 223 insertions(+), 17 deletions(-) diff --git a/docs/commands-reference.md b/docs/commands-reference.md index e570d468c..4b4740997 100644 --- a/docs/commands-reference.md +++ b/docs/commands-reference.md @@ -277,6 +277,11 @@ Registry packages are installed to `~/.zeroclaw/workspace/skills//`. Use `skills audit` to manually validate a candidate skill directory (or an installed skill by name) before sharing it. +Workspace symlink policy: +- Symlinked entries under `~/.zeroclaw/workspace/skills/` are blocked by default. +- To allow shared local skill directories, set `[skills].trusted_skill_roots` in `config.toml`. +- A symlinked skill is accepted only when its resolved canonical target is inside one of the trusted roots. + Skill manifests (`SKILL.toml`) support `prompts` and `[[tools]]`; both are injected into the agent system prompt at runtime, so the model can follow skill instructions without manually reading skill files. 
### `migrate` diff --git a/docs/config-reference.md b/docs/config-reference.md index 1211ccbb6..b4145909f 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -536,6 +536,7 @@ Notes: |---|---|---| | `open_skills_enabled` | `false` | Opt-in loading/sync of community `open-skills` repository | | `open_skills_dir` | unset | Optional local path for `open-skills` (defaults to `$HOME/open-skills` when enabled) | +| `trusted_skill_roots` | `[]` | Allowlist of directory roots for symlink targets in `workspace/skills/*` | | `prompt_injection_mode` | `full` | Skill prompt verbosity: `full` (inline instructions/tools) or `compact` (name/description/location only) | | `clawhub_token` | unset | Optional Bearer token for authenticated ClawhHub skill downloads | @@ -548,7 +549,8 @@ Notes: - `ZEROCLAW_SKILLS_PROMPT_MODE` accepts `full` or `compact`. - Precedence for enable flag: `ZEROCLAW_OPEN_SKILLS_ENABLED` → `skills.open_skills_enabled` in `config.toml` → default `false`. - `prompt_injection_mode = "compact"` is recommended on low-context local models to reduce startup prompt size while keeping skill files available on demand. -- Skill loading and `zeroclaw skills install` both apply a static security audit. Skills that contain symlinks, script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected. +- Symlinked workspace skills are blocked by default. Set `trusted_skill_roots` to allow local shared-skill directories after explicit trust review. +- `zeroclaw skills install` and `zeroclaw skills audit` apply a static security audit. Skills that contain script-like files, high-risk shell payload snippets, or unsafe markdown link traversal are rejected. - `clawhub_token` is sent as `Authorization: Bearer ` when downloading from ClawhHub. Obtain a token from [https://clawhub.ai](https://clawhub.ai) after signing in. Required if the API returns 429 (rate-limited) or 401 (unauthorized) for anonymous requests. 
**ClawhHub token example:** diff --git a/src/config/schema.rs b/src/config/schema.rs index d4d971d84..e48902f13 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1026,6 +1026,11 @@ pub struct SkillsConfig { /// If unset, defaults to `$HOME/open-skills` when enabled. #[serde(default)] pub open_skills_dir: Option, + /// Optional allowlist of canonical directory roots for workspace skill symlink targets. + /// Symlinked workspace skills are rejected unless their resolved targets are under one + /// of these roots. Accepts absolute paths and `~/` home-relative paths. + #[serde(default)] + pub trusted_skill_roots: Vec, /// Allow script-like files in skills (`.sh`, `.bash`, `.ps1`, shebang shell files). /// Default: `false` (secure by default). #[serde(default)] diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 82d467084..1982f7e91 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -80,7 +80,7 @@ fn default_version() -> String { /// Load all skills from the workspace skills directory pub fn load_skills(workspace_dir: &Path) -> Vec { - load_skills_with_open_skills_config(workspace_dir, None, None, None) + load_skills_with_open_skills_config(workspace_dir, None, None, None, None) } /// Load skills using runtime config values (preferred at runtime). 
@@ -90,6 +90,7 @@ pub fn load_skills_with_config(workspace_dir: &Path, config: &crate::config::Con Some(config.skills.open_skills_enabled), config.skills.open_skills_dir.as_deref(), Some(config.skills.allow_scripts), + Some(&config.skills.trusted_skill_roots), ) } @@ -98,9 +99,12 @@ fn load_skills_with_open_skills_config( config_open_skills_enabled: Option, config_open_skills_dir: Option<&str>, config_allow_scripts: Option, + config_trusted_skill_roots: Option<&[String]>, ) -> Vec { let mut skills = Vec::new(); let allow_scripts = config_allow_scripts.unwrap_or(false); + let trusted_skill_roots = + resolve_trusted_skill_roots(workspace_dir, config_trusted_skill_roots.unwrap_or(&[])); if let Some(open_skills_dir) = ensure_open_skills_repo(config_open_skills_enabled, config_open_skills_dir) @@ -108,16 +112,113 @@ fn load_skills_with_open_skills_config( skills.extend(load_open_skills(&open_skills_dir, allow_scripts)); } - skills.extend(load_workspace_skills(workspace_dir, allow_scripts)); + skills.extend(load_workspace_skills( + workspace_dir, + allow_scripts, + &trusted_skill_roots, + )); skills } -fn load_workspace_skills(workspace_dir: &Path, allow_scripts: bool) -> Vec { +fn load_workspace_skills( + workspace_dir: &Path, + allow_scripts: bool, + trusted_skill_roots: &[PathBuf], +) -> Vec { let skills_dir = workspace_dir.join("skills"); - load_skills_from_directory(&skills_dir, allow_scripts) + load_skills_from_directory(&skills_dir, allow_scripts, trusted_skill_roots) } -fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec { +fn resolve_trusted_skill_roots(workspace_dir: &Path, raw_roots: &[String]) -> Vec { + let home_dir = UserDirs::new().map(|dirs| dirs.home_dir().to_path_buf()); + let mut resolved = Vec::new(); + + for raw in raw_roots { + let trimmed = raw.trim(); + if trimmed.is_empty() { + continue; + } + + let expanded = if trimmed == "~" { + home_dir.clone().unwrap_or_else(|| PathBuf::from(trimmed)) + } else if let Some(rest) = 
trimmed + .strip_prefix("~/") + .or_else(|| trimmed.strip_prefix("~\\")) + { + home_dir + .as_ref() + .map(|home| home.join(rest)) + .unwrap_or_else(|| PathBuf::from(trimmed)) + } else { + PathBuf::from(trimmed) + }; + + let candidate = if expanded.is_relative() { + workspace_dir.join(expanded) + } else { + expanded + }; + + match candidate.canonicalize() { + Ok(canonical) if canonical.is_dir() => resolved.push(canonical), + Ok(canonical) => tracing::warn!( + "ignoring [skills].trusted_skill_roots entry '{}': canonical path is not a directory ({})", + trimmed, + canonical.display() + ), + Err(err) => tracing::warn!( + "ignoring [skills].trusted_skill_roots entry '{}': failed to canonicalize {} ({err})", + trimmed, + candidate.display() + ), + } + } + + resolved.sort(); + resolved.dedup(); + resolved +} + +fn enforce_workspace_skill_symlink_trust( + path: &Path, + trusted_skill_roots: &[PathBuf], +) -> Result<()> { + let canonical_target = path + .canonicalize() + .with_context(|| format!("failed to resolve skill symlink target {}", path.display()))?; + + if !canonical_target.is_dir() { + anyhow::bail!( + "symlink target is not a directory: {}", + canonical_target.display() + ); + } + + if trusted_skill_roots + .iter() + .any(|root| canonical_target.starts_with(root)) + { + return Ok(()); + } + + if trusted_skill_roots.is_empty() { + anyhow::bail!( + "symlink target {} is not allowed because [skills].trusted_skill_roots is empty", + canonical_target.display() + ); + } + + anyhow::bail!( + "symlink target {} is outside configured [skills].trusted_skill_roots", + canonical_target.display() + ); +} + +fn load_skills_from_directory( + skills_dir: &Path, + allow_scripts: bool, + trusted_skill_roots: &[PathBuf], +) -> Vec { if !skills_dir.exists() { return Vec::new(); } @@ -130,7 +231,26 @@ fn load_skills_from_directory(skills_dir: &Path, allow_scripts: bool) -> Vec meta, + Err(err) => { + tracing::warn!( + "skipping skill entry {}: failed to read metadata ({err})", + 
path.display() + ); + continue; + } + }; + + if metadata.file_type().is_symlink() { + if let Err(err) = enforce_workspace_skill_symlink_trust(&path, trusted_skill_roots) { + tracing::warn!( + "skipping untrusted symlinked skill entry {}: {err}", + path.display() + ); + continue; + } + } else if !metadata.is_dir() { continue; } @@ -180,7 +300,7 @@ fn load_open_skills(repo_dir: &Path, allow_scripts: bool) -> Vec { // as executable skills. let nested_skills_dir = repo_dir.join("skills"); if nested_skills_dir.is_dir() { - return load_skills_from_directory(&nested_skills_dir, allow_scripts); + return load_skills_from_directory(&nested_skills_dir, allow_scripts, &[]); } let mut skills = Vec::new(); @@ -2137,6 +2257,20 @@ pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Con anyhow::bail!("Skill source or installed skill not found: {source}"); } + let trusted_skill_roots = + resolve_trusted_skill_roots(workspace_dir, &config.skills.trusted_skill_roots); + if let Ok(metadata) = std::fs::symlink_metadata(&target) { + if metadata.file_type().is_symlink() { + enforce_workspace_skill_symlink_trust(&target, &trusted_skill_roots) + .with_context(|| { + format!( + "trusted-symlink policy rejected audit target {}", + target.display() + ) + })?; + } + } + let report = audit::audit_skill_directory_with_options( &target, audit::SkillAuditOptions { diff --git a/src/skills/symlink_tests.rs b/src/skills/symlink_tests.rs index da50891a4..b7bcb726a 100644 --- a/src/skills/symlink_tests.rs +++ b/src/skills/symlink_tests.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod tests { - use crate::skills::skills_dir; + use crate::config::Config; + use crate::skills::{handle_command, load_skills_with_config, skills_dir}; + use crate::SkillCommands; use std::path::Path; use tempfile::TempDir; @@ -83,7 +85,7 @@ mod tests { } #[tokio::test] - async fn test_skills_symlink_permissions_and_safety() { + async fn test_workspace_symlink_loading_requires_trusted_roots() { let tmp = 
TempDir::new().unwrap(); let workspace_dir = tmp.path().join("workspace"); tokio::fs::create_dir_all(&workspace_dir).await.unwrap(); @@ -93,7 +95,6 @@ mod tests { #[cfg(unix)] { - // Test case: Symlink outside workspace should be allowed (user responsibility) let outside_dir = tmp.path().join("outside_skill"); tokio::fs::create_dir_all(&outside_dir).await.unwrap(); tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent") @@ -102,15 +103,74 @@ mod tests { let dest_link = skills_path.join("outside_skill"); let result = std::os::unix::fs::symlink(&outside_dir, &dest_link); + assert!(result.is_ok(), "symlink creation should succeed on unix"); + + let mut config = Config::default(); + config.workspace_dir = workspace_dir.clone(); + config.config_path = workspace_dir.join("config.toml"); + + let blocked = load_skills_with_config(&workspace_dir, &config); assert!( - result.is_ok(), - "Should allow symlinking to directories outside workspace" + blocked.is_empty(), + "symlinked skill should be rejected when trusted_skill_roots is empty" ); - // Should still be readable - let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await; - assert!(content.is_ok()); - assert!(content.unwrap().contains("Outside Skill")); + config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()]; + let allowed = load_skills_with_config(&workspace_dir, &config); + assert_eq!( + allowed.len(), + 1, + "symlinked skill should load when target is inside trusted roots" + ); + assert_eq!(allowed[0].name, "outside_skill"); + } + } + + #[tokio::test] + async fn test_skills_audit_respects_trusted_symlink_roots() { + let tmp = TempDir::new().unwrap(); + let workspace_dir = tmp.path().join("workspace"); + tokio::fs::create_dir_all(&workspace_dir).await.unwrap(); + + let skills_path = skills_dir(&workspace_dir); + tokio::fs::create_dir_all(&skills_path).await.unwrap(); + + #[cfg(unix)] + { + let outside_dir = tmp.path().join("outside_skill"); + 
tokio::fs::create_dir_all(&outside_dir).await.unwrap(); + tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent") + .await + .unwrap(); + let link_path = skills_path.join("outside_skill"); + std::os::unix::fs::symlink(&outside_dir, &link_path).unwrap(); + + let mut config = Config::default(); + config.workspace_dir = workspace_dir.clone(); + config.config_path = workspace_dir.join("config.toml"); + + let blocked = handle_command( + SkillCommands::Audit { + source: "outside_skill".to_string(), + }, + &config, + ); + assert!( + blocked.is_err(), + "audit should reject symlink target when trusted roots are not configured" + ); + + config.skills.trusted_skill_roots = vec![tmp.path().display().to_string()]; + let allowed = handle_command( + SkillCommands::Audit { + source: "outside_skill".to_string(), + }, + &config, + ); + assert!( + allowed.is_ok(), + "audit should pass when symlink target is inside a trusted root" + ); } } } From 9ef617289fbb4594c28a19a5afa9519be566882c Mon Sep 17 00:00:00 2001 From: Preventnetworkhacking Date: Sat, 28 Feb 2026 20:14:57 -0800 Subject: [PATCH 11/13] fix(mcp): stdio transport reads server notifications as tool responses, registering 0 tools [CDV-2327] Replace single read with deadline-bounded loop that skips JSON-RPC messages where id is None (server notifications like notifications/initialized). Some MCP servers send notifications/initialized after the initialize response but before the tools/list response. The old code would read this notification as the tools/list reply, see result: None, and report 0 tools registered. The fix uses a deadline-bounded loop to skip any JSON-RPC message where id is None while preserving the total timeout across all iterations. 
Fixes: zeroclaw-labs/zeroclaw#2327 --- src/tools/mcp_transport.rs | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/tools/mcp_transport.rs b/src/tools/mcp_transport.rs index 61052a343..27398451c 100644 --- a/src/tools/mcp_transport.rs +++ b/src/tools/mcp_transport.rs @@ -107,12 +107,27 @@ impl McpTransportConn for StdioTransport { error: None, }); } - let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw()) - .await - .context("timeout waiting for MCP response")??; - let resp: JsonRpcResponse = serde_json::from_str(&resp_line) - .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?; - Ok(resp) + let deadline = std::time::Instant::now() + Duration::from_secs(RECV_TIMEOUT_SECS); + loop { + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + bail!("timeout waiting for MCP response"); + } + let resp_line = timeout(remaining, self.recv_raw()) + .await + .context("timeout waiting for MCP response")??; + let resp: JsonRpcResponse = serde_json::from_str(&resp_line) + .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?; + if resp.id.is_none() { + // Server-sent notification (e.g. `notifications/initialized`) — skip and + // keep waiting for the actual response to our request. 
+ tracing::debug!( + "MCP stdio: skipping server notification while waiting for response" + ); + continue; + } + return Ok(resp); + } } async fn close(&mut self) -> Result<()> { From 404305633264efeb6dcf29cdc290918cd6ab375d Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 23:16:27 -0500 Subject: [PATCH 12/13] feat(cost): enforce preflight budget policy in agent loop --- src/agent/loop_.rs | 408 ++++++++++++++++++++++++++++++++++++++----- src/channels/mod.rs | 52 +++--- src/config/schema.rs | 105 +++++++++++ 3 files changed, 501 insertions(+), 64 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 802dd7453..568facfac 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,5 +1,7 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse}; +use crate::config::schema::{CostEnforcementMode, ModelPricing}; use crate::config::{Config, ProgressMode}; +use crate::cost::{BudgetCheck, CostTracker, UsagePeriod}; use crate::memory::{self, Memory, MemoryCategory}; use crate::multimodal; use crate::observability::{self, runtime_trace, Observer, ObserverEvent}; @@ -19,9 +21,11 @@ use rustyline::hint::Hinter; use rustyline::validate::Validator; use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper}; use std::borrow::Cow; -use std::collections::{BTreeSet, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Write; +use std::future::Future; use std::io::Write as _; +use std::path::Path; use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; @@ -297,6 +301,7 @@ tokio::task_local! { static LOOP_DETECTION_CONFIG: LoopDetectionConfig; static SAFETY_HEARTBEAT_CONFIG: Option; static TOOL_LOOP_PROGRESS_MODE: ProgressMode; + static TOOL_LOOP_COST_ENFORCEMENT_CONTEXT: Option; } /// Configuration for periodic safety-constraint re-injection (heartbeat). 
@@ -308,6 +313,56 @@ pub(crate) struct SafetyHeartbeatConfig { pub interval: usize, } +#[derive(Clone)] +pub(crate) struct CostEnforcementContext { + tracker: Arc, + prices: HashMap, + mode: CostEnforcementMode, + route_down_model: Option, + reserve_percent: u8, +} + +pub(crate) fn create_cost_enforcement_context( + cost_config: &crate::config::CostConfig, + workspace_dir: &Path, +) -> Option { + if !cost_config.enabled { + return None; + } + let tracker = match CostTracker::new(cost_config.clone(), workspace_dir) { + Ok(tracker) => Arc::new(tracker), + Err(error) => { + tracing::warn!("Cost budget preflight disabled: failed to initialize tracker: {error}"); + return None; + } + }; + let route_down_model = cost_config + .enforcement + .route_down_model + .clone() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()); + Some(CostEnforcementContext { + tracker, + prices: cost_config.prices.clone(), + mode: cost_config.enforcement.mode, + route_down_model, + reserve_percent: cost_config.enforcement.reserve_percent.min(100), + }) +} + +pub(crate) async fn scope_cost_enforcement_context( + context: Option, + future: F, +) -> F::Output +where + F: Future, +{ + TOOL_LOOP_COST_ENFORCEMENT_CONTEXT + .scope(context, future) + .await +} + fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool { interval > 0 && counter > 0 && counter % interval == 0 } @@ -320,6 +375,100 @@ fn should_emit_tool_progress(mode: ProgressMode) -> bool { mode != ProgressMode::Off } +fn estimate_prompt_tokens( + messages: &[ChatMessage], + tools: Option<&[crate::tools::ToolSpec]>, +) -> u64 { + let message_chars: usize = messages + .iter() + .map(|msg| { + msg.role + .len() + .saturating_add(msg.content.chars().count()) + .saturating_add(16) + }) + .sum(); + let tool_chars: usize = tools + .map(|specs| { + specs + .iter() + .map(|spec| serde_json::to_string(spec).map_or(0, |value| value.chars().count())) + .sum() + }) + .unwrap_or(0); + let total_chars = 
message_chars.saturating_add(tool_chars); + let char_estimate = (total_chars as f64 / 4.0).ceil() as u64; + let framing_overhead = (messages.len() as u64).saturating_mul(6).saturating_add(64); + char_estimate.saturating_add(framing_overhead) +} + +fn lookup_model_pricing( + prices: &HashMap, + provider: &str, + model: &str, +) -> (f64, f64) { + let full_name = format!("{provider}/{model}"); + if let Some(pricing) = prices.get(&full_name) { + return (pricing.input, pricing.output); + } + if let Some(pricing) = prices.get(model) { + return (pricing.input, pricing.output); + } + for (key, pricing) in prices { + let key_model = key.split('/').next_back().unwrap_or(key); + if model.starts_with(key_model) || key_model.starts_with(model) { + return (pricing.input, pricing.output); + } + let normalized_model = model.replace('-', "."); + let normalized_key = key_model.replace('-', "."); + if normalized_model.contains(&normalized_key) || normalized_key.contains(&normalized_model) + { + return (pricing.input, pricing.output); + } + } + (3.0, 15.0) +} + +fn estimate_request_cost_usd( + context: &CostEnforcementContext, + provider: &str, + model: &str, + messages: &[ChatMessage], + tools: Option<&[crate::tools::ToolSpec]>, +) -> f64 { + let reserve_multiplier = 1.0 + (f64::from(context.reserve_percent) / 100.0); + let input_tokens = estimate_prompt_tokens(messages, tools); + let output_tokens = (input_tokens / 4).max(256); + let input_tokens = ((input_tokens as f64) * reserve_multiplier).ceil() as u64; + let output_tokens = ((output_tokens as f64) * reserve_multiplier).ceil() as u64; + + let (input_price, output_price) = lookup_model_pricing(&context.prices, provider, model); + let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price.max(0.0); + let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price.max(0.0); + input_cost + output_cost +} + +fn usage_period_label(period: UsagePeriod) -> &'static str { + match period { + UsagePeriod::Session => 
"session", + UsagePeriod::Day => "daily", + UsagePeriod::Month => "monthly", + } +} + +fn budget_exceeded_message( + model: &str, + estimated_cost_usd: f64, + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, +) -> String { + format!( + "Budget enforcement blocked request for model '{model}': projected cost (+${estimated_cost_usd:.4}) exceeds {period_label} limit (${limit_usd:.2}, current ${current_usd:.2}).", + period_label = usage_period_label(period) + ) +} + #[derive(Debug, Clone)] struct ProgressEntry { name: String, @@ -894,7 +1043,12 @@ pub(crate) async fn run_tool_call_loop( let progress_mode = TOOL_LOOP_PROGRESS_MODE .try_with(|mode| *mode) .unwrap_or(ProgressMode::Verbose); + let cost_enforcement_context = TOOL_LOOP_COST_ENFORCEMENT_CONTEXT + .try_with(Clone::clone) + .ok() + .flatten(); let mut progress_tracker = ProgressTracker::default(); + let mut active_model = model.to_string(); let bypass_non_cli_approval_for_turn = approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once()); if bypass_non_cli_approval_for_turn { @@ -902,7 +1056,7 @@ pub(crate) async fn run_tool_call_loop( "approval_bypass_one_time_all_tools_consumed", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(true), Some("consumed one-time non-cli allow-all approval token"), @@ -954,6 +1108,13 @@ pub(crate) async fn run_tool_call_loop( request_messages.push(ChatMessage::user(reminder)); } } + // Unified path via Provider::chat so provider-specific native tool logic + // (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored. 
+ let request_tools = if use_native_tools { + Some(tool_specs.as_slice()) + } else { + None + }; // ── Progress: LLM thinking ──────────────────────────── if should_emit_verbose_progress(progress_mode) { @@ -967,16 +1128,175 @@ pub(crate) async fn run_tool_call_loop( } } + if let Some(cost_ctx) = cost_enforcement_context.as_ref() { + let mut estimated_cost_usd = estimate_request_cost_usd( + cost_ctx, + provider_name, + active_model.as_str(), + &request_messages, + request_tools, + ); + + let mut budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) { + Ok(check) => Some(check), + Err(error) => { + tracing::warn!("Cost preflight check failed: {error}"); + None + } + }; + + if matches!(cost_ctx.mode, CostEnforcementMode::RouteDown) + && matches!(budget_check, Some(BudgetCheck::Exceeded { .. })) + { + if let Some(route_down_model) = cost_ctx + .route_down_model + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + { + if route_down_model != active_model { + let previous_model = active_model.clone(); + active_model = route_down_model.to_string(); + estimated_cost_usd = estimate_request_cost_usd( + cost_ctx, + provider_name, + active_model.as_str(), + &request_messages, + request_tools, + ); + budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) { + Ok(check) => Some(check), + Err(error) => { + tracing::warn!( + "Cost preflight check failed after route-down: {error}" + ); + None + } + }; + runtime_trace::record_event( + "cost_budget_route_down", + Some(channel_name), + Some(provider_name), + Some(active_model.as_str()), + Some(&turn_id), + Some(true), + Some("budget exceeded on primary model; route-down candidate applied"), + serde_json::json!({ + "iteration": iteration + 1, + "from_model": previous_model, + "to_model": active_model, + "estimated_cost_usd": estimated_cost_usd, + }), + ); + } + } + } + + if let Some(check) = budget_check { + match check { + BudgetCheck::Allowed => {} + BudgetCheck::Warning { + 
current_usd, + limit_usd, + period, + } => { + tracing::warn!( + model = active_model.as_str(), + period = usage_period_label(period), + current_usd, + limit_usd, + estimated_cost_usd, + "Cost budget warning threshold reached" + ); + runtime_trace::record_event( + "cost_budget_warning", + Some(channel_name), + Some(provider_name), + Some(active_model.as_str()), + Some(&turn_id), + Some(true), + Some("budget warning threshold reached"), + serde_json::json!({ + "iteration": iteration + 1, + "period": usage_period_label(period), + "current_usd": current_usd, + "limit_usd": limit_usd, + "estimated_cost_usd": estimated_cost_usd, + }), + ); + } + BudgetCheck::Exceeded { + current_usd, + limit_usd, + period, + } => match cost_ctx.mode { + CostEnforcementMode::Warn => { + tracing::warn!( + model = active_model.as_str(), + period = usage_period_label(period), + current_usd, + limit_usd, + estimated_cost_usd, + "Cost budget exceeded (warn mode): continuing request" + ); + runtime_trace::record_event( + "cost_budget_exceeded_warn_mode", + Some(channel_name), + Some(provider_name), + Some(active_model.as_str()), + Some(&turn_id), + Some(true), + Some("budget exceeded but proceeding due to warn mode"), + serde_json::json!({ + "iteration": iteration + 1, + "period": usage_period_label(period), + "current_usd": current_usd, + "limit_usd": limit_usd, + "estimated_cost_usd": estimated_cost_usd, + }), + ); + } + CostEnforcementMode::RouteDown | CostEnforcementMode::Block => { + let message = budget_exceeded_message( + active_model.as_str(), + estimated_cost_usd, + current_usd, + limit_usd, + period, + ); + runtime_trace::record_event( + "cost_budget_blocked", + Some(channel_name), + Some(provider_name), + Some(active_model.as_str()), + Some(&turn_id), + Some(false), + Some(&message), + serde_json::json!({ + "iteration": iteration + 1, + "period": usage_period_label(period), + "current_usd": current_usd, + "limit_usd": limit_usd, + "estimated_cost_usd": estimated_cost_usd, + }), + ); 
+ return Err(anyhow::anyhow!(message)); + } + }, + } + } + } + observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), - model: model.to_string(), + model: active_model.clone(), messages_count: history.len(), }); runtime_trace::record_event( "llm_request", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), None, None, @@ -990,23 +1310,15 @@ pub(crate) async fn run_tool_call_loop( // Fire void hook before LLM call if let Some(hooks) = hooks { - hooks.fire_llm_input(history, model).await; + hooks.fire_llm_input(history, active_model.as_str()).await; } - // Unified path via Provider::chat so provider-specific native tool logic - // (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored. - let request_tools = if use_native_tools { - Some(tool_specs.as_slice()) - } else { - None - }; - let chat_future = provider.chat( ChatRequest { messages: &request_messages, tools: request_tools, }, - model, + active_model.as_str(), temperature, ); @@ -1036,7 +1348,7 @@ pub(crate) async fn run_tool_call_loop( observer.record_event(&ObserverEvent::LlmResponse { provider: provider_name.to_string(), - model: model.to_string(), + model: active_model.clone(), duration: llm_started_at.elapsed(), success: true, error_message: None, @@ -1066,7 +1378,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_parse_issue", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(parse_issue), @@ -1084,7 +1396,7 @@ pub(crate) async fn run_tool_call_loop( "llm_response", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(true), None, @@ -1135,7 +1447,7 @@ pub(crate) async fn run_tool_call_loop( let safe_error = crate::providers::sanitize_api_error(&e.to_string()); observer.record_event(&ObserverEvent::LlmResponse { provider: provider_name.to_string(), - model: model.to_string(), + model: 
active_model.clone(), duration: llm_started_at.elapsed(), success: false, error_message: Some(safe_error.clone()), @@ -1146,7 +1458,7 @@ pub(crate) async fn run_tool_call_loop( "llm_response", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(&safe_error), @@ -1199,7 +1511,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_followthrough_retry", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(true), Some("llm response implied follow-up action but emitted no tool call"), @@ -1227,7 +1539,7 @@ pub(crate) async fn run_tool_call_loop( "turn_final_response", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(true), None, @@ -1303,7 +1615,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_result", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(&cancelled), @@ -1345,7 +1657,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_result", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(&blocked), @@ -1385,7 +1697,7 @@ pub(crate) async fn run_tool_call_loop( "approval_bypass_non_cli_session_grant", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(true), Some("using runtime non-cli session approval grant"), @@ -1442,7 +1754,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_result", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(&denied), @@ -1476,7 +1788,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_result", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some(&duplicate), @@ -1504,7 +1816,7 @@ pub(crate) async fn 
run_tool_call_loop( "tool_call_start", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), None, None, @@ -1564,7 +1876,7 @@ pub(crate) async fn run_tool_call_loop( "tool_call_result", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(outcome.success), outcome.error_reason.as_deref(), @@ -1676,7 +1988,7 @@ pub(crate) async fn run_tool_call_loop( "loop_detected_warning", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some("loop pattern detected, injecting self-correction prompt"), @@ -1698,7 +2010,7 @@ pub(crate) async fn run_tool_call_loop( "loop_detected_hard_stop", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some("loop persisted after warning, stopping early"), @@ -1718,7 +2030,7 @@ pub(crate) async fn run_tool_call_loop( "tool_loop_exhausted", Some(channel_name), Some(provider_name), - Some(model), + Some(active_model.as_str()), Some(&turn_id), Some(false), Some("agent exceeded maximum tool iterations"), @@ -2151,6 +2463,8 @@ pub async fn run( // ── Execute ────────────────────────────────────────────────── let start = Instant::now(); + let cost_enforcement_context = + create_cost_enforcement_context(&config.cost, &config.workspace_dir); let mut final_output = String::new(); @@ -2197,8 +2511,9 @@ pub async fn run( } else { None }; - let response = SAFETY_HEARTBEAT_CONFIG - .scope( + let response = scope_cost_enforcement_context( + cost_enforcement_context.clone(), + SAFETY_HEARTBEAT_CONFIG.scope( hb_cfg, LOOP_DETECTION_CONFIG.scope( ld_cfg, @@ -2221,8 +2536,9 @@ pub async fn run( &[], ), ), - ) - .await?; + ), + ) + .await?; final_output = response.clone(); if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS { let assistant_key = autosave_memory_key("assistant_resp"); @@ -2374,8 
+2690,9 @@ pub async fn run( } else { None }; - let response = match SAFETY_HEARTBEAT_CONFIG - .scope( + let response = match scope_cost_enforcement_context( + cost_enforcement_context.clone(), + SAFETY_HEARTBEAT_CONFIG.scope( hb_cfg, LOOP_DETECTION_CONFIG.scope( ld_cfg, @@ -2398,8 +2715,9 @@ pub async fn run( &[], ), ), - ) - .await + ), + ) + .await { Ok(resp) => resp, Err(e) => { @@ -2682,6 +3000,8 @@ pub async fn process_message_with_session( ChatMessage::user(&enriched), ]; + let cost_enforcement_context = + create_cost_enforcement_context(&config.cost, &config.workspace_dir); let hb_cfg = if config.agent.safety_heartbeat_interval > 0 { Some(SafetyHeartbeatConfig { body: security.summary_for_heartbeat(), @@ -2690,8 +3010,9 @@ pub async fn process_message_with_session( } else { None }; - SAFETY_HEARTBEAT_CONFIG - .scope( + scope_cost_enforcement_context( + cost_enforcement_context, + SAFETY_HEARTBEAT_CONFIG.scope( hb_cfg, agent_turn( provider.as_ref(), @@ -2705,8 +3026,9 @@ pub async fn process_message_with_session( &config.multimodal, config.agent.max_tool_iterations, ), - ) - .await + ), + ) + .await } #[cfg(test)] diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1a5251895..e085f6cd5 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -250,6 +250,7 @@ struct ChannelRuntimeDefaults { api_key: Option, api_url: Option, reliability: crate::config::ReliabilityConfig, + cost: crate::config::CostConfig, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1054,6 +1055,7 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults { api_key: config.api_key.clone(), api_url: config.api_url.clone(), reliability: config.reliability.clone(), + cost: config.cost.clone(), } } @@ -1099,6 +1101,7 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau api_key: ctx.api_key.clone(), api_url: ctx.api_url.clone(), reliability: (*ctx.reliability).clone(), + cost: crate::config::CostConfig::default(), } } @@ -3665,6 
+3668,10 @@ or tune thresholds in config.", let timeout_budget_secs = channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations); + let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context( + &runtime_defaults.cost, + ctx.workspace_dir.as_path(), + ); let (approval_prompt_tx, mut approval_prompt_rx) = tokio::sync::mpsc::unbounded_channel::(); @@ -3706,31 +3713,33 @@ or tune thresholds in config.", } else { None }; - let llm_result = tokio::select! { () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, result = tokio::time::timeout( Duration::from_secs(timeout_budget_secs), - run_tool_call_loop_with_non_cli_approval_context( - active_provider.as_ref(), - &mut history, - ctx.tools_registry.as_ref(), - ctx.observer.as_ref(), - route.provider.as_str(), - route.model.as_str(), - runtime_defaults.temperature, - true, - Some(ctx.approval_manager.as_ref()), - msg.channel.as_str(), - non_cli_approval_context, - &ctx.multimodal, - ctx.max_tool_iterations, - Some(cancellation_token.clone()), - delta_tx, - ctx.hooks.as_deref(), - &excluded_tools_snapshot, - progress_mode, - ctx.safety_heartbeat.clone(), + crate::agent::loop_::scope_cost_enforcement_context( + cost_enforcement_context, + run_tool_call_loop_with_non_cli_approval_context( + active_provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + ctx.observer.as_ref(), + route.provider.as_str(), + route.model.as_str(), + runtime_defaults.temperature, + true, + Some(ctx.approval_manager.as_ref()), + msg.channel.as_str(), + non_cli_approval_context, + &ctx.multimodal, + ctx.max_tool_iterations, + Some(cancellation_token.clone()), + delta_tx, + ctx.hooks.as_deref(), + &excluded_tools_snapshot, + progress_mode, + ctx.safety_heartbeat.clone(), + ), ), ) => LlmExecutionResult::Completed(result), }; @@ -9401,6 +9410,7 @@ BTC is currently around $65,000 based on latest tool output."# api_key: None, api_url: None, reliability: 
crate::config::ReliabilityConfig::default(), + cost: crate::config::CostConfig::default(), }, perplexity_filter: crate::config::PerplexityFilterConfig::default(), outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(), diff --git a/src/config/schema.rs b/src/config/schema.rs index e48902f13..b00538eaa 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1200,6 +1200,58 @@ pub struct CostConfig { /// Per-model pricing (USD per 1M tokens) #[serde(default)] pub prices: std::collections::HashMap, + + /// Runtime budget enforcement policy (`[cost.enforcement]`). + #[serde(default)] + pub enforcement: CostEnforcementConfig, +} + +/// Budget enforcement behavior when projected spend approaches/exceeds limits. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CostEnforcementMode { + /// Log warnings only; never block the request. + Warn, + /// Attempt one downgrade to a cheaper route/model, then block if still over budget. + RouteDown, + /// Block immediately when projected spend exceeds configured limits. + Block, +} + +fn default_cost_enforcement_mode() -> CostEnforcementMode { + CostEnforcementMode::Warn +} + +/// Runtime budget enforcement controls (`[cost.enforcement]`). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct CostEnforcementConfig { + /// Enforcement behavior. Default: `warn`. + #[serde(default = "default_cost_enforcement_mode")] + pub mode: CostEnforcementMode, + /// Optional fallback model (or `hint:*`) when `mode = "route_down"`. + #[serde(default = "default_route_down_model")] + pub route_down_model: Option, + /// Extra reserve added to token/cost estimates (percentage, 0-100). Default: `10`. 
+ #[serde(default = "default_cost_reserve_percent")] + pub reserve_percent: u8, +} + +fn default_route_down_model() -> Option { + Some("hint:fast".to_string()) +} + +fn default_cost_reserve_percent() -> u8 { + 10 +} + +impl Default for CostEnforcementConfig { + fn default() -> Self { + Self { + mode: default_cost_enforcement_mode(), + route_down_model: default_route_down_model(), + reserve_percent: default_cost_reserve_percent(), + } + } } /// Per-model pricing entry (USD per 1M tokens). @@ -1235,6 +1287,7 @@ impl Default for CostConfig { warn_at_percent: default_warn_percent(), allow_override: false, prices: get_default_pricing(), + enforcement: CostEnforcementConfig::default(), } } } @@ -7769,6 +7822,14 @@ impl Config { anyhow::bail!("web_search.timeout_secs must be greater than 0"); } + // Cost + if self.cost.warn_at_percent > 100 { + anyhow::bail!("cost.warn_at_percent must be between 0 and 100"); + } + if self.cost.enforcement.reserve_percent > 100 { + anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100"); + } + // Scheduler if self.scheduler.max_concurrent == 0 { anyhow::bail!("scheduler.max_concurrent must be greater than 0"); @@ -13743,4 +13804,48 @@ sensitivity = 0.9 .validate() .expect("disabled coordination should allow empty lead agent"); } + + #[test] + async fn cost_enforcement_defaults_are_stable() { + let cost = CostConfig::default(); + assert_eq!(cost.enforcement.mode, CostEnforcementMode::Warn); + assert_eq!( + cost.enforcement.route_down_model.as_deref(), + Some("hint:fast") + ); + assert_eq!(cost.enforcement.reserve_percent, 10); + } + + #[test] + async fn cost_enforcement_config_parses_route_down_mode() { + let parsed: CostConfig = toml::from_str( + r#" +enabled = true + +[enforcement] +mode = "route_down" +route_down_model = "hint:fast" +reserve_percent = 15 +"#, + ) + .expect("cost enforcement should parse"); + + assert!(parsed.enabled); + assert_eq!(parsed.enforcement.mode, CostEnforcementMode::RouteDown); + assert_eq!( 
+ parsed.enforcement.route_down_model.as_deref(), + Some("hint:fast") + ); + assert_eq!(parsed.enforcement.reserve_percent, 15); + } + + #[test] + async fn validation_rejects_cost_enforcement_reserve_over_100() { + let mut config = Config::default(); + config.cost.enforcement.reserve_percent = 150; + let err = config + .validate() + .expect_err("expected cost.enforcement.reserve_percent validation failure"); + assert!(err.to_string().contains("cost.enforcement.reserve_percent")); + } } From 4e70abf407aa66f635365c3dcb36c41eba107521 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 23:45:42 -0500 Subject: [PATCH 13/13] fix(cost): validate route_down hint against model routes --- src/config/schema.rs | 62 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/src/config/schema.rs b/src/config/schema.rs index b00538eaa..c99c31779 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -7829,6 +7829,36 @@ impl Config { if self.cost.enforcement.reserve_percent > 100 { anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100"); } + if matches!(self.cost.enforcement.mode, CostEnforcementMode::RouteDown) { + let route_down_model = self + .cost + .enforcement + .route_down_model + .as_deref() + .map(str::trim) + .filter(|model| !model.is_empty()) + .ok_or_else(|| { + anyhow::anyhow!( + "cost.enforcement.route_down_model must be set when mode is route_down" + ) + })?; + + if let Some(route_hint) = route_down_model + .strip_prefix("hint:") + .map(str::trim) + .filter(|hint| !hint.is_empty()) + { + if !self + .model_routes + .iter() + .any(|route| route.hint.trim() == route_hint) + { + anyhow::bail!( + "cost.enforcement.route_down_model uses hint '{route_hint}', but no matching [[model_routes]] entry exists" + ); + } + } + } // Scheduler if self.scheduler.max_concurrent == 0 { @@ -13848,4 +13878,36 @@ reserve_percent = 15 .expect_err("expected cost.enforcement.reserve_percent validation 
failure"); assert!(err.to_string().contains("cost.enforcement.reserve_percent")); } + + #[test] + async fn validation_rejects_route_down_hint_without_matching_route() { + let mut config = Config::default(); + config.cost.enforcement.mode = CostEnforcementMode::RouteDown; + config.cost.enforcement.route_down_model = Some("hint:fast".to_string()); + let err = config + .validate() + .expect_err("route_down hint should require a matching model route"); + assert!(err + .to_string() + .contains("cost.enforcement.route_down_model uses hint 'fast'")); + } + + #[test] + async fn validation_accepts_route_down_hint_with_matching_route() { + let mut config = Config::default(); + config.cost.enforcement.mode = CostEnforcementMode::RouteDown; + config.cost.enforcement.route_down_model = Some("hint:fast".to_string()); + config.model_routes = vec![ModelRouteConfig { + hint: "fast".to_string(), + provider: "openrouter".to_string(), + model: "openai/gpt-4.1-mini".to_string(), + api_key: None, + max_tokens: None, + transport: None, + }]; + + config + .validate() + .expect("matching route_down hint route should validate"); + } }