diff --git a/Cargo.toml b/Cargo.toml index 98dee2f67..7d7906d0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -219,6 +219,8 @@ probe = ["dep:probe-rs"] rag-pdf = ["dep:pdf-extract"] # whatsapp-web = Native WhatsApp Web client with custom rusqlite storage backend whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-proto", "dep:wa-rs-ureq-http", "dep:wa-rs-tokio-transport", "dep:serde-big-array", "dep:prost", "dep:qrcode"] +# Legacy opt-in live integration tests for removed quota tools. +quota-tools-live = [] [profile.release] opt-level = "z" # Optimize for size diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..4383be486 --- /dev/null +++ b/build.rs @@ -0,0 +1,38 @@ +use std::fs; +use std::path::PathBuf; + +const PLACEHOLDER_INDEX_HTML: &str = r#" + + + + + ZeroClaw Dashboard Placeholder + + +
+

ZeroClaw dashboard assets are not built

+

Run the web build to replace this placeholder with the real dashboard.

+
+ + +"#; + +fn main() { + let manifest_dir = + PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR missing")); + let dist_dir = manifest_dir.join("web").join("dist"); + let index_path = dist_dir.join("index.html"); + + println!("cargo:rerun-if-changed=web/dist"); + + if index_path.exists() { + return; + } + + fs::create_dir_all(&dist_dir).expect("failed to create web/dist placeholder directory"); + fs::write(&index_path, PLACEHOLDER_INDEX_HTML) + .expect("failed to write placeholder web/dist/index.html"); + println!( + "cargo:warning=web/dist was missing; generated a placeholder dashboard so the Rust build can continue" + ); +} diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 2ee352490..803c611d2 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,5 +1,7 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse}; +use crate::config::schema::ModelPricing; use crate::config::Config; +use crate::cost::{CostTracker, TokenUsage as CostTokenUsage}; use crate::memory::{self, Memory, MemoryCategory}; use crate::multimodal; use crate::observability::{self, runtime_trace, Observer, ObserverEvent}; @@ -15,7 +17,7 @@ use anyhow::Result; use futures_util::StreamExt; use regex::{Regex, RegexSet}; use rustyline::error::ReadlineError; -use std::collections::{BTreeSet, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Write; use std::io::Write as _; use std::sync::{Arc, LazyLock}; @@ -194,6 +196,84 @@ tokio::task_local! { static TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT: Option; } +#[derive(Clone)] +pub(crate) struct ToolLoopCostTrackingContext { + tracker: Arc, + prices: Arc>, +} + +impl ToolLoopCostTrackingContext { + pub(crate) fn new( + tracker: Arc, + prices: Arc>, + ) -> Self { + Self { tracker, prices } + } +} + +tokio::task_local! { + static TOOL_LOOP_COST_TRACKING_CONTEXT: Option; +} + +fn lookup_model_pricing<'a>( + prices: &'a HashMap, + provider_name: &str, + model: &str, +) -> Option<&'a ModelPricing> { + prices + .get(model) + .or_else(|| prices.get(&format!("{provider_name}/{model}"))) + .or_else(|| { + model + .rsplit_once('/') + .and_then(|(_, suffix)| prices.get(suffix)) + }) +} + +fn record_tool_loop_cost_usage( + provider_name: &str, + model: &str, + usage: &crate::providers::traits::TokenUsage, +) -> Option<(u64, f64)> { + let input_tokens = usage.input_tokens.unwrap_or(0); + let output_tokens = usage.output_tokens.unwrap_or(0); + let total_tokens = input_tokens.saturating_add(output_tokens); + if total_tokens == 0 { + return None; + } + + let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT + .try_with(Clone::clone) + .ok() + .flatten()?; + let pricing = lookup_model_pricing(&ctx.prices, provider_name, model); + let cost_usage = CostTokenUsage::new( + model, + input_tokens, + output_tokens, + pricing.map_or(0.0, |entry| entry.input), + pricing.map_or(0.0, |entry| entry.output), + ); + + if pricing.is_none() { + tracing::debug!( + provider = provider_name, + model, + "Cost tracking recorded token usage with zero pricing because no model price was configured" + ); + } + + if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) { + tracing::warn!( + provider = provider_name, + model, + "Failed to record cost tracking usage: {error}" + ); + } + + Some((cost_usage.total_tokens, cost_usage.cost_usd)) +} + /// Extract a short hint from tool call arguments for progress display. fn truncate_tool_args_for_progress(name: &str, args: &serde_json::Value, max_len: usize) -> String { let hint = match name { @@ -475,6 +555,31 @@ fn build_tool_unavailable_retry_prompt(tool_specs: &[crate::tools::ToolSpec]) -> ) } +fn display_text_for_turn( + response_text: &str, + parsed_text: &str, + tool_call_count: usize, + native_tool_call_count: usize, +) -> String { + if tool_call_count == 0 { + return if parsed_text.is_empty() { + response_text.to_string() + } else { + parsed_text.to_string() + }; + } + + if !parsed_text.is_empty() { + return parsed_text.to_string(); + } + + if native_tool_call_count > 0 { + return response_text.to_string(); + } + + String::new() +} + #[derive(Debug)] pub(crate) struct ToolLoopCancelled; @@ -734,6 +839,7 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context( approval: Option<&ApprovalManager>, channel_name: &str, non_cli_approval_context: Option, + cost_tracking_context: Option, multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, cancellation_token: Option, @@ -750,23 +856,26 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context( non_cli_approval_context, TOOL_LOOP_REPLY_TARGET.scope( reply_target, - run_tool_call_loop( - provider, - history, - tools_registry, - observer, - provider_name, - model, - temperature, - silent, - approval, - channel_name, - multimodal_config, - max_tool_iterations, - cancellation_token, - on_delta, - hooks, - excluded_tools, + TOOL_LOOP_COST_TRACKING_CONTEXT.scope( + cost_tracking_context, + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + provider_name, + model, + temperature, + silent, + approval, + channel_name, + multimodal_config, + max_tool_iterations, + cancellation_token, + on_delta, + hooks, + excluded_tools, + ), ), ), ) @@ -1019,6 +1128,11 @@ pub(crate) async fn run_tool_call_loop( output_tokens: resp_output_tokens, }); + let _ = resp + .usage + .as_ref() + .and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage)); + let response_text = resp.text_or_empty().to_string(); // First try native structured tool calls (OpenAI-format). // Fall back to text-based parsing (XML tags, markdown blocks, @@ -1135,11 +1249,12 @@ pub(crate) async fn run_tool_call_loop( } }; - let display_text = if parsed_text.is_empty() { - response_text.clone() - } else { - parsed_text - }; + let display_text = display_text_for_turn( + &response_text, + &parsed_text, + tool_calls.len(), + native_tool_calls.len(), + ); // ── Progress: LLM responded ───────────────────────────── if let Some(ref tx) = on_delta { @@ -1845,7 +1960,7 @@ pub async fn run( } else { (None, None) }; - let mut tools_registry = tools::all_tools_with_runtime( + let (mut tool_arcs, shared_tool_registry) = tools::all_tools_with_runtime_arcs( Arc::new(config.clone()), &security, runtime, @@ -1861,12 +1976,16 @@ pub async fn run( &config, ); - let peripheral_tools: Vec> = + let peripheral_tools: Vec> = crate::peripherals::create_peripheral_tools(&config.peripherals).await?; if !peripheral_tools.is_empty() { tracing::info!(count = peripheral_tools.len(), "Peripheral tools added"); - tools_registry.extend(peripheral_tools); + tool_arcs.extend(peripheral_tools); + if let Some(shared_registry) = shared_tool_registry.as_ref() { + tools::sync_shared_tool_registry(shared_registry, &tool_arcs); + } } + let tools_registry = tools::boxed_registry_from_arcs(tool_arcs); // ── Resolve provider ───────────────────────────────────────── let provider_name = provider_override @@ -2333,7 +2452,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { } else { (None, None) }; - let mut tools_registry = tools::all_tools_with_runtime( + let (mut tool_arcs, shared_tool_registry) = tools::all_tools_with_runtime_arcs( Arc::new(config.clone()), &security, runtime, @@ -2348,9 +2467,13 @@ pub async fn process_message(config: Config, message: &str) -> Result { config.api_key.as_deref(), &config, ); - let peripheral_tools: Vec> = + let peripheral_tools: Vec> = crate::peripherals::create_peripheral_tools(&config.peripherals).await?; - tools_registry.extend(peripheral_tools); + tool_arcs.extend(peripheral_tools); + if let Some(shared_registry) = shared_tool_registry.as_ref() { + tools::sync_shared_tool_registry(shared_registry, &tool_arcs); + } + let tools_registry = tools::boxed_registry_from_arcs(tool_arcs); let provider_name = config.default_provider.as_deref().unwrap_or("openrouter"); let model_name = config @@ -2508,10 +2631,11 @@ mod tests { use super::*; use async_trait::async_trait; use base64::{engine::general_purpose::STANDARD, Engine as _}; - use std::collections::VecDeque; + use std::collections::{HashMap, VecDeque}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; + use tempfile::TempDir; #[test] fn test_scrub_credentials() { @@ -2584,6 +2708,37 @@ mod tests { assert!(args.get("delivery").is_none()); } + #[test] + fn display_text_for_turn_hides_prompt_guided_tool_payloads() { + let display = display_text_for_turn( + "{\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}}", + "", + 1, + 0, + ); + + assert!(display.is_empty()); + } + + #[test] + fn display_text_for_turn_preserves_prompt_guided_preface_text() { + let display = display_text_for_turn( + "Let me check.\n{\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}}", + "Let me check.", + 1, + 0, + ); + + assert_eq!(display, "Let me check."); + } + + #[test] + fn display_text_for_turn_preserves_native_tool_preface_text() { + let display = display_text_for_turn("Let me check.", "", 1, 1); + + assert_eq!(display, "Let me check."); + } + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use crate::observability::NoopObserver; use crate::providers::router::{Route, RouterProvider}; @@ -2591,8 +2746,6 @@ mod tests { use crate::providers::ChatResponse; use crate::runtime::NativeRuntime; use crate::security::{AutonomyLevel, SecurityPolicy, ShellRedirectPolicy}; - use tempfile::TempDir; - struct NonVisionProvider { calls: Arc, } @@ -3472,6 +3625,7 @@ mod tests { reply_target: "chat-approval".to_string(), prompt_tx, }), + None, &crate::config::MultimodalConfig::default(), 4, None, @@ -3491,6 +3645,86 @@ mod tests { ); } + #[tokio::test] + async fn run_tool_call_loop_records_cost_usage_when_tracking_context_is_scoped() { + let provider = ScriptedProvider { + responses: Arc::new(Mutex::new(VecDeque::from([ChatResponse { + text: Some("done".to_string()), + tool_calls: Vec::new(), + usage: Some(crate::providers::traits::TokenUsage { + input_tokens: Some(1_200), + output_tokens: Some(300), + }), + reasoning_content: None, + }]))), + capabilities: ProviderCapabilities::default(), + }; + let observer = NoopObserver; + let workspace = TempDir::new().expect("temp workspace should be created"); + let mut cost_config = crate::config::CostConfig { + enabled: true, + ..crate::config::CostConfig::default() + }; + cost_config.prices = HashMap::from([( + "mock-provider/mock-model".to_string(), + ModelPricing { + input: 2.0, + output: 4.0, + }, + )]); + let tracker = Arc::new( + CostTracker::new(cost_config.clone(), workspace.path()) + .expect("cost tracker should initialize"), + ); + let mut history = vec![ + ChatMessage::system("test-system"), + ChatMessage::user("hello"), + ]; + + let result = run_tool_call_loop_with_non_cli_approval_context( + &provider, + &mut history, + &[], + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "telegram", + None, + Some(ToolLoopCostTrackingContext::new( + Arc::clone(&tracker), + Arc::new(cost_config.prices.clone()), + )), + &crate::config::MultimodalConfig::default(), + 2, + None, + None, + None, + &[], + ) + .await + .expect("tool loop should succeed"); + + assert_eq!(result, "done"); + + let summary = tracker + .get_summary() + .expect("cost summary should be readable"); + assert_eq!(summary.request_count, 1); + assert_eq!(summary.total_tokens, 1_500); + assert!(summary.session_cost_usd > 0.0); + assert_eq!( + summary + .by_model + .get("mock-model") + .expect("model stats should exist") + .total_tokens, + 1_500 + ); + } + #[tokio::test] async fn run_tool_call_loop_consumes_one_time_non_cli_allow_all_token() { let provider = ScriptedProvider::from_text_responses(vec![ diff --git a/src/channels/mod.rs b/src/channels/mod.rs index fc3a00e46..e5922e5f4 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -72,9 +72,12 @@ pub use whatsapp_web::WhatsAppWebChannel; use crate::agent::loop_::{ build_shell_policy_instructions, build_tool_instructions_from_specs, run_tool_call_loop_with_non_cli_approval_context, scrub_credentials, NonCliApprovalContext, + ToolLoopCostTrackingContext, }; use crate::approval::{ApprovalManager, ApprovalResponse, PendingApprovalError}; +use crate::config::schema::ModelPricing; use crate::config::{Config, NonCliNaturalLanguageApprovalMode}; +use crate::cost::CostTracker; use crate::identity; use crate::memory::{self, Memory}; use crate::observability::{self, runtime_trace, Observer}; @@ -207,6 +210,12 @@ struct RuntimeConfigState { last_applied_stamp: Option, } +#[derive(Clone)] +struct ChannelCostTrackingState { + tracker: Arc, + prices: Arc>, +} + #[derive(Debug, Clone)] struct RuntimeAutonomyPolicy { auto_approve: Vec, @@ -223,6 +232,11 @@ fn runtime_config_store() -> &'static Mutex STORE.get_or_init(|| Mutex::new(HashMap::new())) } +fn channel_cost_tracking_store() -> &'static Mutex> { + static STORE: OnceLock>> = OnceLock::new(); + STORE.get_or_init(|| Mutex::new(None)) +} + const SYSTEMD_STATUS_ARGS: [&str; 3] = ["--user", "is-active", "zeroclaw.service"]; const SYSTEMD_RESTART_ARGS: [&str; 3] = ["--user", "restart", "zeroclaw.service"]; const OPENRC_STATUS_ARGS: [&str; 2] = ["zeroclaw", "status"]; @@ -3309,6 +3323,14 @@ semantic_match={:.2} (threshold {:.2}), category={}.", sender: msg.sender.clone(), message_id: msg.id.clone(), }; + let cost_tracking_context = { + let store = channel_cost_tracking_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store + .clone() + .map(|state| ToolLoopCostTrackingContext::new(state.tracker, state.prices)) + }; let llm_result = tokio::select! { () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, result = tokio::time::timeout( @@ -3327,6 +3349,7 @@ semantic_match={:.2} (threshold {:.2}), category={}.", Some(ctx.approval_manager.as_ref()), msg.channel.as_str(), non_cli_approval_context, + cost_tracking_context, &ctx.multimodal, ctx.max_tool_iterations, Some(cancellation_token.clone()), @@ -5071,6 +5094,26 @@ pub async fn start_channels(config: Config) -> Result<()> { .telegram .as_ref() .is_some_and(|tg| tg.interrupt_on_new_message); + let cost_tracking_state = if config.cost.enabled { + match CostTracker::new(config.cost.clone(), &config.workspace_dir) { + Ok(tracker) => Some(ChannelCostTrackingState { + tracker: Arc::new(tracker), + prices: Arc::new(config.cost.prices.clone()), + }), + Err(error) => { + tracing::warn!("Failed to initialize channel cost tracker: {error}"); + None + } + } + } else { + None + }; + { + let mut store = channel_cost_tracking_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + *store = cost_tracking_state; + } let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name, @@ -5120,6 +5163,13 @@ pub async fn start_channels(config: Config) -> Result<()> { let _ = h.await; } + { + let mut store = channel_cost_tracking_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + *store = None; + } + Ok(()) } @@ -9687,7 +9737,7 @@ BTC is currently around $65,000 based on latest tool output."# .get("test-channel_alice") .expect("history should be stored for sender"); assert_eq!(turns[0].role, "user"); - assert_eq!(turns[0].content, "hello"); + assert!(turns[0].content.ends_with("hello")); assert!(!turns[0].content.contains("[Memory context]")); } @@ -10511,7 +10561,7 @@ BTC is currently around $65,000 based on latest tool output."#; .expect("history should exist for sender"); assert_eq!(turns.len(), 2); assert_eq!(turns[0].role, "user"); - assert_eq!(turns[0].content, "What is WAL?"); + assert!(turns[0].content.ends_with("What is WAL?")); assert_eq!(turns[1].role, "assistant"); assert_eq!(turns[1].content, "ok"); assert!( diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 87596234c..9a42deff2 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -1555,13 +1555,6 @@ Allowlist Telegram username (without '@') or numeric user ID.", chat_id.clone() }; - // Check mention_only for group messages - // Voice messages cannot contain mentions, so skip in group chats when mention_only is set - let is_group = Self::is_group_message(message); - if self.mention_only && is_group { - return None; - } - // Download and transcribe let file_path = match self.get_file_path(&metadata.file_id).await { Ok(p) => p, @@ -3233,6 +3226,8 @@ Ensure only one `zeroclaw` process is using this bot token." mod tests { use super::*; use std::path::Path; + use wiremock::matchers::{header, method, path, query_param}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[cfg(unix)] fn symlink_file(src: &Path, dst: &Path) { @@ -5045,6 +5040,83 @@ mod tests { assert!(ch.voice_transcriptions.lock().is_empty()); } + #[tokio::test] + async fn try_parse_voice_message_allows_group_sender_override_and_transcribes() { + let telegram_api = MockServer::start().await; + let transcription_api = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/bottoken/getFile")) + .and(query_param("file_id", "voice_file")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "ok": true, + "result": { "file_path": "voice/file_without_ext" } + }))) + .mount(&telegram_api) + .await; + Mock::given(method("GET")) + .and(path("/file/bottoken/voice/file_without_ext")) + .respond_with( + ResponseTemplate::new(200) + .set_body_bytes(vec![0x4f, 0x67, 0x67, 0x53, 0x00, 0x02, 0x03, 0x04]), + ) + .mount(&telegram_api) + .await; + Mock::given(method("POST")) + .and(path("/transcribe")) + .and(header("authorization", "Bearer test-groq-key")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "text": "hello from telegram voice" + }))) + .mount(&transcription_api) + .await; + + let previous_api_key = std::env::var("GROQ_API_KEY").ok(); + std::env::set_var("GROQ_API_KEY", "test-groq-key"); + + let mut tc = crate::config::TranscriptionConfig::default(); + tc.enabled = true; + tc.api_url = format!("{}/transcribe", transcription_api.uri()); + + let ch = TelegramChannel::new("token".into(), vec!["555".into()], true) + .with_group_reply_allowed_senders(vec!["555".into()]) + .with_api_base(telegram_api.uri()) + .with_transcription(tc); + let update = serde_json::json!({ + "message": { + "message_id": 42, + "voice": { + "file_id": "voice_file", + "duration": 4, + "mime_type": "audio/ogg" + }, + "from": { "id": 555, "username": "alice" }, + "chat": { "id": -100123, "type": "supergroup" } + } + }); + + let parsed = ch + .try_parse_voice_message(&update) + .await + .expect("voice message should be transcribed for configured sender override"); + + match previous_api_key { + Some(value) => std::env::set_var("GROQ_API_KEY", value), + None => std::env::remove_var("GROQ_API_KEY"), + } + + assert_eq!(parsed.reply_target, "-100123"); + assert_eq!(parsed.sender, "alice"); + assert_eq!(parsed.content, "[Voice] hello from telegram voice"); + assert_eq!( + ch.voice_transcriptions + .lock() + .get("-100123:42") + .cloned() + .as_deref(), + Some("hello from telegram voice") + ); + } + // ───────────────────────────────────────────────────────────────────── // Live e2e: voice transcription via Groq Whisper + reply cache lookup // ───────────────────────────────────────────────────────────────────── diff --git a/src/config/schema.rs b/src/config/schema.rs index dee0e8f28..da9cc2757 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -5858,6 +5858,16 @@ impl Config { &mut config.browser.computer_use.api_key, "config.browser.computer_use.api_key", )?; + decrypt_optional_secret( + &store, + &mut config.web_fetch.api_key, + "config.web_fetch.api_key", + )?; + decrypt_optional_secret( + &store, + &mut config.web_search.api_key, + "config.web_search.api_key", + )?; decrypt_optional_secret( &store, @@ -5889,6 +5899,20 @@ impl Config { for agent in config.agents.values_mut() { decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } + for route in &mut config.model_routes { + decrypt_optional_secret( + &store, + &mut route.api_key, + "config.model_routes.*.api_key", + )?; + } + for route in &mut config.embedding_routes { + decrypt_optional_secret( + &store, + &mut route.api_key, + "config.embedding_routes.*.api_key", + )?; + } decrypt_channel_secrets(&store, &mut config.channels_config)?; resolve_telegram_allowed_users_env_refs(&mut config.channels_config)?; @@ -6727,6 +6751,16 @@ impl Config { &mut config_to_save.browser.computer_use.api_key, "config.browser.computer_use.api_key", )?; + encrypt_optional_secret( + &store, + &mut config_to_save.web_fetch.api_key, + "config.web_fetch.api_key", + )?; + encrypt_optional_secret( + &store, + &mut config_to_save.web_search.api_key, + "config.web_search.api_key", + )?; encrypt_optional_secret( &store, @@ -6758,6 +6792,16 @@ impl Config { for agent in config_to_save.agents.values_mut() { encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } + for route in &mut config_to_save.model_routes { + encrypt_optional_secret(&store, &mut route.api_key, "config.model_routes.*.api_key")?; + } + for route in &mut config_to_save.embedding_routes { + encrypt_optional_secret( + &store, + &mut route.api_key, + "config.embedding_routes.*.api_key", + )?; + } encrypt_channel_secrets(&store, &mut config_to_save.channels_config)?; @@ -7747,6 +7791,8 @@ tool_dispatcher = "xml" config.proxy.https_proxy = Some("https://user:pass@proxy.internal:8443".into()); config.proxy.all_proxy = Some("socks5://user:pass@proxy.internal:1080".into()); config.browser.computer_use.api_key = Some("browser-credential".into()); + config.web_fetch.api_key = Some("web-fetch-credential".into()); + config.web_search.api_key = Some("web-search-credential".into()); config.web_search.brave_api_key = Some("brave-credential".into()); config.storage.provider.config.db_url = Some("postgres://user:pw@host/db".into()); config.reliability.api_keys = vec!["backup-credential".into()]; @@ -7754,6 +7800,20 @@ tool_dispatcher = "xml" "custom:https://api-a.example.com/v1".into(), "fallback-a-credential".into(), ); + config.model_routes = vec![ModelRouteConfig { + hint: "reasoning".into(), + provider: "openrouter".into(), + model: "anthropic/claude-sonnet-4".into(), + max_tokens: None, + api_key: Some("route-credential".into()), + }]; + config.embedding_routes = vec![EmbeddingRouteConfig { + hint: "semantic".into(), + provider: "openai".into(), + model: "text-embedding-3-small".into(), + dimensions: Some(1536), + api_key: Some("embedding-credential".into()), + }]; config.gateway.paired_tokens = vec!["zc_0123456789abcdef".into()]; config.channels_config.telegram = Some(TelegramConfig { bot_token: "telegram-credential".into(), @@ -7835,6 +7895,22 @@ tool_dispatcher = "xml" store.decrypt(browser_encrypted).unwrap(), "browser-credential" ); + let web_fetch_encrypted = stored.web_fetch.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + web_fetch_encrypted + )); + assert_eq!( + store.decrypt(web_fetch_encrypted).unwrap(), + "web-fetch-credential" + ); + let web_search_api_encrypted = stored.web_search.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + web_search_api_encrypted + )); + assert_eq!( + store.decrypt(web_search_api_encrypted).unwrap(), + "web-search-credential" + ); let web_search_encrypted = stored.web_search.brave_api_key.as_deref().unwrap(); assert!(crate::security::SecretStore::is_encrypted( @@ -7870,6 +7946,15 @@ tool_dispatcher = "xml" store.decrypt(fallback_key).unwrap(), "fallback-a-credential" ); + let routed_key = stored.model_routes[0].api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(routed_key)); + assert_eq!(store.decrypt(routed_key).unwrap(), "route-credential"); + let embedding_key = stored.embedding_routes[0].api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(embedding_key)); + assert_eq!( + store.decrypt(embedding_key).unwrap(), + "embedding-credential" + ); let paired_token = &stored.gateway.paired_tokens[0]; assert!(crate::security::SecretStore::is_encrypted(paired_token)); diff --git a/src/gateway/api.rs b/src/gateway/api.rs index 13da1c5e2..bed593e68 100644 --- a/src/gateway/api.rs +++ b/src/gateway/api.rs @@ -1096,6 +1096,12 @@ fn mask_sensitive_fields(config: &crate::config::Config) -> crate::config::Confi for agent in masked.agents.values_mut() { mask_optional_secret(&mut agent.api_key); } + for route in &mut masked.model_routes { + mask_optional_secret(&mut route.api_key); + } + for route in &mut masked.embedding_routes { + mask_optional_secret(&mut route.api_key); + } if let Some(telegram) = masked.channels_config.telegram.as_mut() { mask_required_secret(&mut telegram.bot_token); @@ -1214,6 +1220,20 @@ fn restore_masked_sensitive_fields( restore_optional_secret(&mut agent.api_key, ¤t_agent.api_key); } } + for (incoming_route, current_route) in incoming + .model_routes + .iter_mut() + .zip(current.model_routes.iter()) + { + restore_optional_secret(&mut incoming_route.api_key, ¤t_route.api_key); + } + for (incoming_route, current_route) in incoming + .embedding_routes + .iter_mut() + .zip(current.embedding_routes.iter()) + { + restore_optional_secret(&mut incoming_route.api_key, ¤t_route.api_key); + } if let (Some(incoming_ch), Some(current_ch)) = ( incoming.channels_config.telegram.as_mut(), @@ -1393,11 +1413,27 @@ mod tests { current.workspace_dir = std::path::PathBuf::from("/tmp/current/workspace"); current.api_key = Some("real-key".to_string()); current.reliability.api_keys = vec!["r1".to_string(), "r2".to_string()]; + current.model_routes = vec![crate::config::ModelRouteConfig { + hint: "reasoning".to_string(), + provider: "openrouter".to_string(), + model: "anthropic/claude-sonnet-4".to_string(), + max_tokens: None, + api_key: Some("route-key".to_string()), + }]; + current.embedding_routes = vec![crate::config::EmbeddingRouteConfig { + hint: "semantic".to_string(), + provider: "openai".to_string(), + model: "text-embedding-3-small".to_string(), + dimensions: None, + api_key: Some("embedding-key".to_string()), + }]; let mut incoming = mask_sensitive_fields(¤t); incoming.default_model = Some("gpt-4.1-mini".to_string()); // Simulate UI changing only one key and keeping the first masked. incoming.reliability.api_keys = vec![MASKED_SECRET.to_string(), "r2-new".to_string()]; + incoming.model_routes[0].api_key = Some(MASKED_SECRET.to_string()); + incoming.embedding_routes[0].api_key = Some(MASKED_SECRET.to_string()); let hydrated = hydrate_config_for_save(incoming, ¤t); @@ -1409,6 +1445,14 @@ mod tests { hydrated.reliability.api_keys, vec!["r1".to_string(), "r2-new".to_string()] ); + assert_eq!( + hydrated.model_routes[0].api_key.as_deref(), + Some("route-key") + ); + assert_eq!( + hydrated.embedding_routes[0].api_key.as_deref(), + Some("embedding-key") + ); } #[test] @@ -1518,6 +1562,36 @@ mod tests { ); } + #[test] + fn mask_sensitive_fields_masks_route_api_keys() { + let mut cfg = crate::config::Config::default(); + cfg.model_routes = vec![crate::config::ModelRouteConfig { + hint: "reasoning".to_string(), + provider: "openrouter".to_string(), + model: "anthropic/claude-sonnet-4".to_string(), + max_tokens: None, + api_key: Some("route-real-key".to_string()), + }]; + cfg.embedding_routes = vec![crate::config::EmbeddingRouteConfig { + hint: "semantic".to_string(), + provider: "openai".to_string(), + model: "text-embedding-3-small".to_string(), + dimensions: None, + api_key: Some("embedding-real-key".to_string()), + }]; + + let masked = mask_sensitive_fields(&cfg); + + assert_eq!( + masked.model_routes[0].api_key.as_deref(), + Some(MASKED_SECRET) + ); + assert_eq!( + masked.embedding_routes[0].api_key.as_deref(), + Some(MASKED_SECRET) + ); + } + #[test] fn hydrate_config_for_save_restores_wati_email_and_feishu_secrets() { let mut current = crate::config::Config::default(); diff --git a/src/peripherals/mod.rs b/src/peripherals/mod.rs index 26aebf9a8..29cce3b27 100644 --- a/src/peripherals/mod.rs +++ b/src/peripherals/mod.rs @@ -31,6 +31,7 @@ use crate::peripherals::traits::Peripheral; use crate::tools::HardwareMemoryMapTool; use crate::tools::Tool; use anyhow::Result; +use std::sync::Arc; /// List configured boards from config (no connection yet). pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoardConfig> { @@ -137,20 +138,20 @@ pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> /// Create and connect peripherals from config, returning their tools. /// Returns empty vec if peripherals disabled or hardware feature off. #[cfg(feature = "hardware")] -pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result>> { +pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result>> { if !config.enabled || config.boards.is_empty() { return Ok(Vec::new()); } - let mut tools: Vec> = Vec::new(); + let mut tools: Vec> = Vec::new(); let mut serial_transports: Vec<(String, std::sync::Arc)> = Vec::new(); for board in &config.boards { // Arduino Uno Q: Bridge transport (socket to local Bridge app) if board.transport == "bridge" && (board.board == "arduino-uno-q" || board.board == "uno-q") { - tools.push(Box::new(uno_q_bridge::UnoQGpioReadTool)); - tools.push(Box::new(uno_q_bridge::UnoQGpioWriteTool)); + tools.push(Arc::new(uno_q_bridge::UnoQGpioReadTool)); + tools.push(Arc::new(uno_q_bridge::UnoQGpioWriteTool)); tracing::info!(board = %board.board, "Uno Q Bridge GPIO tools added"); continue; } @@ -191,7 +192,7 @@ pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result Result = config.boards.iter().map(|b| b.board.clone()).collect(); - tools.push(Box::new(HardwareMemoryMapTool::new(board_names.clone()))); - tools.push(Box::new(crate::tools::HardwareBoardInfoTool::new( + tools.push(Arc::new(HardwareMemoryMapTool::new(board_names.clone()))); + tools.push(Arc::new(crate::tools::HardwareBoardInfoTool::new( board_names.clone(), ))); - tools.push(Box::new(crate::tools::HardwareMemoryReadTool::new( + tools.push(Arc::new(crate::tools::HardwareMemoryReadTool::new( board_names, ))); } // Phase C: Add hardware_capabilities tool when any serial boards if !serial_transports.is_empty() { - tools.push(Box::new(capabilities_tool::HardwareCapabilitiesTool::new( + tools.push(Arc::new(capabilities_tool::HardwareCapabilitiesTool::new( serial_transports, ))); } @@ -229,7 +230,7 @@ pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result Result>> { +pub async fn create_peripheral_tools(_config: &PeripheralsConfig) -> Result>> { Ok(Vec::new()) } diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs index 982a7ea4e..36b0d472f 100644 --- a/src/providers/openai_codex.rs +++ b/src/providers/openai_codex.rs @@ -116,7 +116,6 @@ impl OpenAiCodexProvider { .unwrap_or_else(|_| Client::new()), }) } - } fn default_zeroclaw_dir() -> PathBuf { diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 8a368ea70..ad8929d04 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -518,6 +518,12 @@ pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String { instructions.push_str("\n"); instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#); instructions.push_str("\n\n\n"); + instructions.push_str( + "If you would otherwise say an action is only available in a code agent or OpenClaw, do not refuse. ", + ); + instructions.push_str( + "This runtime is the tool environment, so emit the matching payload instead.\n\n", + ); instructions.push_str("You may use multiple tool calls in a single response. "); instructions.push_str("After tool execution, results appear in tags. "); instructions @@ -750,6 +756,7 @@ mod tests { assert!(instructions.contains("Tool Use Protocol")); assert!(instructions.contains("")); assert!(instructions.contains("")); + assert!(instructions.contains("only available in a code agent or OpenClaw")); // Check for tool listings assert!(instructions.contains("**shell**")); diff --git a/src/tools/cron_add.rs b/src/tools/cron_add.rs index bf731b81b..c0aae2a7f 100644 --- a/src/tools/cron_add.rs +++ b/src/tools/cron_add.rs @@ -184,7 +184,7 @@ impl Tool for CronAddTool { return Ok(blocked); } - cron::add_shell_job(&self.config, name, schedule, command) + cron::add_shell_job_with_approval(&self.config, name, schedule, command, approved) } JobType::Agent => { let prompt = match args.get("prompt").and_then(serde_json::Value::as_str) { diff --git a/src/tools/cron_remove.rs b/src/tools/cron_remove.rs index b4dc110c9..0233ea184 100644 --- a/src/tools/cron_remove.rs +++ b/src/tools/cron_remove.rs @@ -166,10 +166,13 @@ mod tests { config_path: tmp.path().join("config.toml"), ..Config::default() }; - config.autonomy.level = AutonomyLevel::ReadOnly; std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let mut writable_config = config.clone(); + writable_config.autonomy.level = AutonomyLevel::Full; + let writable_cfg = Arc::new(writable_config); + let job = cron::add_job(&writable_cfg, "*/5 * * * *", "echo ok").unwrap(); + config.autonomy.level = AutonomyLevel::ReadOnly; let cfg = Arc::new(config); - let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let tool = CronRemoveTool::new(cfg.clone(), test_security(&cfg)); let result = tool.execute(json!({"job_id": job.id})).await.unwrap(); diff --git a/src/tools/cron_run.rs b/src/tools/cron_run.rs index bb3c9e419..6ac0b3ab6 100644 --- a/src/tools/cron_run.rs +++ b/src/tools/cron_run.rs @@ -211,10 +211,13 @@ mod tests { config_path: tmp.path().join("config.toml"), ..Config::default() }; - config.autonomy.level = AutonomyLevel::ReadOnly; std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let mut writable_config = config.clone(); + writable_config.autonomy.level = AutonomyLevel::Full; + let writable_cfg = Arc::new(writable_config); + let job = cron::add_job(&writable_cfg, "*/5 * * * *", "echo run-now").unwrap(); + config.autonomy.level = AutonomyLevel::ReadOnly; let cfg = Arc::new(config); - let job = cron::add_job(&cfg, "*/5 * * * *", "echo run-now").unwrap(); let tool = CronRunTool::new(cfg.clone(), test_security(&cfg)); let result = tool.execute(json!({ "job_id": job.id })).await.unwrap(); @@ -234,7 +237,8 @@ mod tests { config.autonomy.allowed_commands = vec!["touch".into()]; std::fs::create_dir_all(&config.workspace_dir).unwrap(); let cfg = Arc::new(config); - let job = cron::add_job(&cfg, "*/5 * * * *", "touch cron-run-approval").unwrap(); + let job = + cron::add_job_approved(&cfg, "*/5 * * * *", "touch cron-run-approval", true).unwrap(); let tool = CronRunTool::new(cfg.clone(), test_security(&cfg)); let denied = tool.execute(json!({ "job_id": job.id })).await.unwrap(); diff --git a/src/tools/cron_update.rs b/src/tools/cron_update.rs index f41bacb15..b70e903d7 100644 --- a/src/tools/cron_update.rs +++ b/src/tools/cron_update.rs @@ -133,7 +133,13 @@ impl Tool for CronUpdateTool { return Ok(blocked); } - match cron::update_job(&self.config, job_id, patch) { + let update_result = if patch.command.is_some() { + cron::update_shell_job_with_approval(&self.config, job_id, patch, approved) + } else { + cron::update_job(&self.config, job_id, patch) + }; + + match update_result { Ok(job) => Ok(ToolResult { success: true, output: serde_json::to_string_pretty(&job)?, @@ -228,10 +234,13 @@ mod tests { config_path: tmp.path().join("config.toml"), ..Config::default() }; - config.autonomy.level = AutonomyLevel::ReadOnly; std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let mut writable_config = config.clone(); + writable_config.autonomy.level = AutonomyLevel::Full; + let writable_cfg = Arc::new(writable_config); + let job = cron::add_job(&writable_cfg, "*/5 * * * *", "echo ok").unwrap(); + config.autonomy.level = AutonomyLevel::ReadOnly; let cfg = Arc::new(config); - let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg)); let result = tool diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index ea26a1f0a..5d38354c7 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -6,6 +6,7 @@ use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric}; use crate::providers::{self, ChatMessage, Provider}; use crate::security::policy::ToolOperation; use crate::security::SecurityPolicy; +use crate::tools::SharedToolRegistry; use async_trait::async_trait; use serde_json::json; use std::collections::HashMap; @@ -36,7 +37,7 @@ pub struct DelegateTool { /// Depth at which this tool instance lives in the delegation chain. depth: u32, /// Parent tool registry for agentic sub-agents. - parent_tools: Arc>>, + parent_tools: SharedToolRegistry, /// Inherited multimodal handling config for sub-agent loops. multimodal_config: crate::config::MultimodalConfig, /// Optional typed coordination bus used to trace delegate lifecycle events. @@ -72,7 +73,7 @@ impl DelegateTool { fallback_credential, provider_runtime_options, depth: 0, - parent_tools: Arc::new(Vec::new()), + parent_tools: crate::tools::new_shared_tool_registry(), multimodal_config: crate::config::MultimodalConfig::default(), coordination_bus, coordination_lead_agent: DEFAULT_COORDINATION_LEAD_AGENT.to_string(), @@ -111,7 +112,7 @@ impl DelegateTool { fallback_credential, provider_runtime_options, depth, - parent_tools: Arc::new(Vec::new()), + parent_tools: crate::tools::new_shared_tool_registry(), multimodal_config: crate::config::MultimodalConfig::default(), coordination_bus, coordination_lead_agent: DEFAULT_COORDINATION_LEAD_AGENT.to_string(), @@ -119,7 +120,7 @@ impl DelegateTool { } /// Attach parent tools used to build sub-agent allowlist registries. - pub fn with_parent_tools(mut self, parent_tools: Arc>>) -> Self { + pub fn with_parent_tools(mut self, parent_tools: SharedToolRegistry) -> Self { self.parent_tools = parent_tools; self } @@ -462,11 +463,20 @@ impl DelegateTool { .filter(|name| !name.is_empty()) .collect::>(); - let sub_tools: Vec> = self + let parent_tools = self .parent_tools + .lock() + .map(|tools| tools.clone()) + .unwrap_or_default(); + + let sub_tools: Vec> = parent_tools .iter() .filter(|tool| allowed.contains(tool.name())) - .filter(|tool| tool.name() != "delegate") + .filter(|tool| { + tool.name() != "delegate" + && tool.name() != "subagent_spawn" + && tool.name() != "subagent_manage" + }) .map(|tool| Box::new(ToolArcRef::new(tool.clone())) as Box) .collect(); @@ -967,6 +977,12 @@ mod tests { } } + fn shared_parent_tools(tools: Vec>) -> SharedToolRegistry { + let shared = crate::tools::new_shared_tool_registry(); + crate::tools::sync_shared_tool_registry(&shared, &tools); + shared + } + #[test] fn name_and_schema() { let tool = DelegateTool::new(sample_agents(), None, test_security()); @@ -1278,7 +1294,7 @@ mod tests { ); let tool = DelegateTool::new(agents, None, test_security()) - .with_parent_tools(Arc::new(vec![Arc::new(EchoTool)])); + .with_parent_tools(shared_parent_tools(vec![Arc::new(EchoTool)])); let result = tool .execute(json!({"agent": "agentic", "prompt": "test"})) .await @@ -1296,7 +1312,7 @@ mod tests { async fn execute_agentic_runs_tool_call_loop_with_filtered_tools() { let config = agentic_config(vec!["echo_tool".to_string()], 10); let tool = DelegateTool::new(HashMap::new(), None, test_security()).with_parent_tools( - Arc::new(vec![ + shared_parent_tools(vec![ Arc::new(EchoTool), Arc::new(DelegateTool::new(HashMap::new(), None, test_security())), ]), @@ -1313,11 +1329,33 @@ mod tests { assert!(result.output.contains("done")); } + #[tokio::test] + async fn execute_agentic_reads_late_bound_parent_tools() { + let config = agentic_config(vec!["echo_tool".to_string()], 10); + let parent_tools = crate::tools::new_shared_tool_registry(); + let tool = DelegateTool::new(HashMap::new(), None, test_security()) + .with_parent_tools(parent_tools.clone()); + + crate::tools::sync_shared_tool_registry( + &parent_tools, + &[Arc::new(EchoTool) as Arc], + ); + + let provider = OneToolThenFinalProvider; + let result = tool + .execute_agentic("agentic", &config, &provider, "run", 0.2) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("done")); + } + #[tokio::test] async fn execute_agentic_excludes_delegate_even_if_allowlisted() { let config = agentic_config(vec!["delegate".to_string()], 10); let tool = DelegateTool::new(HashMap::new(), None, test_security()).with_parent_tools( - Arc::new(vec![Arc::new(DelegateTool::new( + shared_parent_tools(vec![Arc::new(DelegateTool::new( HashMap::new(), None, test_security(), @@ -1342,7 +1380,7 @@ mod tests { async fn execute_agentic_respects_max_iterations() { let config = agentic_config(vec!["echo_tool".to_string()], 2); let tool = DelegateTool::new(HashMap::new(), None, test_security()) - .with_parent_tools(Arc::new(vec![Arc::new(EchoTool)])); + .with_parent_tools(shared_parent_tools(vec![Arc::new(EchoTool)])); let provider = InfiniteToolCallProvider; let result = tool @@ -1362,7 +1400,7 @@ mod tests { async fn execute_agentic_propagates_provider_errors() { let config = agentic_config(vec!["echo_tool".to_string()], 10); let tool = DelegateTool::new(HashMap::new(), None, test_security()) - .with_parent_tools(Arc::new(vec![Arc::new(EchoTool)])); + .with_parent_tools(shared_parent_tools(vec![Arc::new(EchoTool)])); let provider = FailingProvider; let result = tool diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 9d4f320c0..17f3e23a1 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -126,7 +126,7 @@ use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::SecurityPolicy; use async_trait::async_trait; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; #[derive(Clone)] struct ArcDelegatingTool { @@ -139,6 +139,21 @@ impl ArcDelegatingTool { } } +pub(crate) type SharedToolRegistry = Arc>>>; + +pub(crate) fn new_shared_tool_registry() -> SharedToolRegistry { + Arc::new(Mutex::new(Vec::new())) +} + +pub(crate) fn sync_shared_tool_registry( + shared_registry: &SharedToolRegistry, + tools: &[Arc], +) { + if let Ok(mut guard) = shared_registry.lock() { + *guard = tools.to_vec(); + } +} + #[async_trait] impl Tool for ArcDelegatingTool { fn name(&self) -> &str { @@ -158,7 +173,7 @@ impl Tool for ArcDelegatingTool { } } -fn boxed_registry_from_arcs(tools: Vec>) -> Vec> { +pub(crate) fn boxed_registry_from_arcs(tools: Vec>) -> Vec> { tools.into_iter().map(ArcDelegatingTool::boxed).collect() } @@ -244,6 +259,40 @@ pub fn all_tools_with_runtime( fallback_api_key: Option<&str>, root_config: &crate::config::Config, ) -> Vec> { + let (tool_arcs, _shared_registry) = all_tools_with_runtime_arcs( + config, + security, + runtime, + memory, + composio_key, + composio_entity_id, + browser_config, + http_config, + web_fetch_config, + workspace_dir, + agents, + fallback_api_key, + root_config, + ); + boxed_registry_from_arcs(tool_arcs) +} + +#[allow(clippy::implicit_hasher, clippy::too_many_arguments)] +pub(crate) fn all_tools_with_runtime_arcs( + config: Arc, + security: &Arc, + runtime: Arc, + memory: Arc, + composio_key: Option<&str>, + composio_entity_id: Option<&str>, + browser_config: &crate::config::BrowserConfig, + http_config: &crate::config::HttpRequestConfig, + web_fetch_config: &crate::config::WebFetchConfig, + workspace_dir: &std::path::Path, + agents: &HashMap, + fallback_api_key: Option<&str>, + root_config: &crate::config::Config, +) -> (Vec>, Option) { let has_shell_access = runtime.has_shell_access(); let has_filesystem_access = runtime.has_filesystem_access(); let zeroclaw_dir = root_config @@ -417,6 +466,8 @@ pub fn all_tools_with_runtime( } // Add delegation and sub-agent orchestration tools when agents are configured + let mut shared_parent_tools = None; + if !agents.is_empty() { let delegate_agents: HashMap = agents .iter() @@ -442,7 +493,8 @@ pub fn all_tools_with_runtime( max_tokens_override: None, model_support_vision: root_config.model_support_vision, }; - let parent_tools = Arc::new(tool_arcs.clone()); + let parent_tools = new_shared_tool_registry(); + shared_parent_tools = Some(parent_tools.clone()); let mut delegate_tool = DelegateTool::new_with_options( delegate_agents.clone(), delegate_fallback_credential.clone(), @@ -536,7 +588,11 @@ pub fn all_tools_with_runtime( } } - boxed_registry_from_arcs(tool_arcs) + if let Some(shared_registry) = shared_parent_tools.as_ref() { + sync_shared_tool_registry(shared_registry, &tool_arcs); + } + + (tool_arcs, shared_parent_tools) } #[cfg(test)] @@ -651,6 +707,7 @@ mod tests { allowed_users: vec!["*".into()], listen_to_bots: false, mention_only: false, + group_reply: None, }); let tools = all_tools( diff --git a/src/tools/subagent_spawn.rs b/src/tools/subagent_spawn.rs index 488aa5ffe..4c04d4e82 100644 --- a/src/tools/subagent_spawn.rs +++ b/src/tools/subagent_spawn.rs @@ -11,6 +11,7 @@ use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric}; use crate::providers::{self, ChatMessage, Provider}; use crate::security::policy::ToolOperation; use crate::security::SecurityPolicy; +use crate::tools::SharedToolRegistry; use async_trait::async_trait; use chrono::Utc; use serde_json::json; @@ -32,7 +33,7 @@ pub struct SubAgentSpawnTool { fallback_credential: Option, provider_runtime_options: providers::ProviderRuntimeOptions, registry: Arc, - parent_tools: Arc>>, + parent_tools: SharedToolRegistry, multimodal_config: crate::config::MultimodalConfig, } @@ -44,7 +45,7 @@ impl SubAgentSpawnTool { security: Arc, provider_runtime_options: providers::ProviderRuntimeOptions, registry: Arc, - parent_tools: Arc>>, + parent_tools: SharedToolRegistry, multimodal_config: crate::config::MultimodalConfig, ) -> Self { Self { @@ -395,7 +396,7 @@ async fn run_agentic_background( agent_config: &DelegateAgentConfig, provider: &dyn Provider, full_prompt: &str, - parent_tools: &[Arc], + parent_tools: &SharedToolRegistry, multimodal_config: &crate::config::MultimodalConfig, ) -> anyhow::Result { if agent_config.allowed_tools.is_empty() { @@ -415,6 +416,11 @@ async fn run_agentic_background( .filter(|name| !name.is_empty()) .collect::>(); + let parent_tools = parent_tools + .lock() + .map(|tools| tools.clone()) + .unwrap_or_default(); + let sub_tools: Vec> = parent_tools .iter() .filter(|tool| allowed.contains(tool.name())) @@ -540,7 +546,7 @@ mod tests { security, providers::ProviderRuntimeOptions::default(), Arc::new(SubAgentRegistry::new()), - Arc::new(Vec::new()), + Arc::new(std::sync::Mutex::new(Vec::new())), crate::config::MultimodalConfig::default(), ) } @@ -705,7 +711,7 @@ mod tests { test_security(), providers::ProviderRuntimeOptions::default(), registry, - Arc::new(Vec::new()), + Arc::new(std::sync::Mutex::new(Vec::new())), crate::config::MultimodalConfig::default(), ); diff --git a/tests/quota_tools_live.rs b/tests/quota_tools_live.rs index 6ac48db4a..59bd50773 100644 --- a/tests/quota_tools_live.rs +++ b/tests/quota_tools_live.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "quota-tools-live")] + //! Live E2E tests for quota tools with real auth profiles. //! //! These tests require real auth-profiles.json at ~/.zeroclaw/auth-profiles.json