Compare commits

..

1 Commits

Author SHA1 Message Date
argenis de la rosa aa9c6ded42 fix(cron): prevent one-shot jobs from re-executing indefinitely
Handle Schedule::At jobs in reschedule_after_run by disabling them
instead of rescheduling to a past timestamp. Also add a fallback in
persist_job_result to disable one-shot jobs if removal fails.

Closes #3868
2026-03-18 09:52:39 -04:00
17 changed files with 143 additions and 1426 deletions
+7 -6
View File
@@ -23,14 +23,13 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
# 1. Copy manifests to cache dependencies
COPY Cargo.toml Cargo.lock ./
# Remove robot-kit from workspace members — it is excluded by .dockerignore
# and is not needed for the Docker build (hardware-only crate).
RUN sed -i 's/members = \[".", "crates\/robot-kit"\]/members = ["."]/' Cargo.toml
COPY crates/robot-kit/Cargo.toml crates/robot-kit/Cargo.toml
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
RUN mkdir -p src benches \
RUN mkdir -p src benches crates/robot-kit/src \
&& echo "fn main() {}" > src/main.rs \
&& echo "" > src/lib.rs \
&& echo "fn main() {}" > benches/agent_benchmarks.rs
&& echo "fn main() {}" > benches/agent_benchmarks.rs \
&& echo "pub fn placeholder() {}" > crates/robot-kit/src/lib.rs
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
@@ -39,11 +38,13 @@ RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/regist
else \
cargo build --release --locked; \
fi
RUN rm -rf src benches
RUN rm -rf src benches crates/robot-kit/src
# 2. Copy only build-relevant source paths (avoid cache-busting on docs/tests/scripts)
COPY src/ src/
COPY benches/ benches/
COPY crates/ crates/
COPY firmware/ firmware/
COPY --from=web-builder /web/dist web/dist
COPY *.rs .
RUN touch src/main.rs
+7 -6
View File
@@ -38,14 +38,13 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
# 1. Copy manifests to cache dependencies
COPY Cargo.toml Cargo.lock ./
# Remove robot-kit from workspace members — it is excluded by .dockerignore
# and is not needed for the Docker build (hardware-only crate).
RUN sed -i 's/members = \[".", "crates\/robot-kit"\]/members = ["."]/' Cargo.toml
COPY crates/robot-kit/Cargo.toml crates/robot-kit/Cargo.toml
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
RUN mkdir -p src benches \
RUN mkdir -p src benches crates/robot-kit/src \
&& echo "fn main() {}" > src/main.rs \
&& echo "" > src/lib.rs \
&& echo "fn main() {}" > benches/agent_benchmarks.rs
&& echo "fn main() {}" > benches/agent_benchmarks.rs \
&& echo "pub fn placeholder() {}" > crates/robot-kit/src/lib.rs
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
@@ -54,11 +53,13 @@ RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/regist
else \
cargo build --release --locked; \
fi
RUN rm -rf src benches
RUN rm -rf src benches crates/robot-kit/src
# 2. Copy only build-relevant source paths (avoid cache-busting on docs/tests/scripts)
COPY src/ src/
COPY benches/ benches/
COPY crates/ crates/
COPY firmware/ firmware/
COPY --from=web-builder /web/dist web/dist
RUN touch src/main.rs
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
+1 -149
View File
@@ -767,140 +767,6 @@ run_guided_installer() {
fi
}
ensure_default_config_and_workspace() {
# Creates a minimal config.toml and workspace scaffold files when the
# onboard wizard was skipped (e.g. --skip-build --prefer-prebuilt, or
# Docker mode without an API key).
#
# $1 — config directory (e.g. ~/.zeroclaw or $docker_data_dir/.zeroclaw)
# $2 — workspace directory (e.g. ~/.zeroclaw/workspace or $docker_data_dir/workspace)
# $3 — provider name (default: openrouter)
local config_dir="$1"
local workspace_dir="$2"
local provider="${3:-openrouter}"
mkdir -p "$config_dir" "$workspace_dir"
# --- config.toml ---
local config_path="$config_dir/config.toml"
if [[ ! -f "$config_path" ]]; then
step_dot "Creating default config.toml"
cat > "$config_path" <<TOML
# ZeroClaw configuration — generated by install.sh
# Edit this file or run 'zeroclaw onboard' to reconfigure.
default_provider = "${provider}"
workspace_dir = "${workspace_dir}"
TOML
if [[ -n "${API_KEY:-}" ]]; then
printf 'api_key = "%s"\n' "$API_KEY" >> "$config_path"
fi
if [[ -n "${MODEL:-}" ]]; then
printf 'default_model = "%s"\n' "$MODEL" >> "$config_path"
fi
chmod 600 "$config_path" 2>/dev/null || true
step_ok "Default config.toml created at $config_path"
else
step_dot "config.toml already exists, skipping"
fi
# --- Workspace scaffold ---
local subdirs=(sessions memory state cron skills)
for dir in "${subdirs[@]}"; do
mkdir -p "$workspace_dir/$dir"
done
# Seed workspace markdown files only if they don't already exist.
local user_name="${USER:-User}"
local agent_name="ZeroClaw"
_write_if_missing() {
local filepath="$1"
local content="$2"
if [[ ! -f "$filepath" ]]; then
printf '%s\n' "$content" > "$filepath"
fi
}
_write_if_missing "$workspace_dir/IDENTITY.md" \
"# IDENTITY.md — Who Am I?
- **Name:** ${agent_name}
- **Creature:** A Rust-forged AI — fast, lean, and relentless
- **Vibe:** Sharp, direct, resourceful. Not corporate. Not a chatbot.
---
Update this file as you evolve. Your identity is yours to shape."
_write_if_missing "$workspace_dir/USER.md" \
"# USER.md — Who You're Helping
## About You
- **Name:** ${user_name}
- **Timezone:** UTC
- **Languages:** English
## Preferences
- (Add your preferences here)
## Work Context
- (Add your work context here)
---
*Update this anytime. The more ${agent_name} knows, the better it helps.*"
_write_if_missing "$workspace_dir/MEMORY.md" \
"# MEMORY.md — Long-Term Memory
## Key Facts
(Add important facts here)
## Decisions & Preferences
(Record decisions and preferences here)
## Lessons Learned
(Document mistakes and insights here)
## Open Loops
(Track unfinished tasks and follow-ups here)"
_write_if_missing "$workspace_dir/AGENTS.md" \
"# AGENTS.md — ${agent_name} Personal Assistant
## Every Session (required)
Before doing anything else:
1. Read SOUL.md — this is who you are
2. Read USER.md — this is who you're helping
3. Use memory_recall for recent context
---
*Add your own conventions, style, and rules.*"
_write_if_missing "$workspace_dir/SOUL.md" \
"# SOUL.md — Who You Are
## Core Truths
**Be genuinely helpful, not performatively helpful.**
**Have opinions.** You're allowed to disagree.
**Be resourceful before asking.** Try to figure it out first.
**Earn trust through competence.**
## Identity
You are **${agent_name}**. Built in Rust. 3MB binary. Zero bloat.
---
*This file is yours to evolve.*"
step_ok "Workspace scaffold ready at $workspace_dir"
unset -f _write_if_missing
}
resolve_container_cli() {
local requested_cli
requested_cli="${ZEROCLAW_CONTAINER_CLI:-docker}"
@@ -1018,17 +884,10 @@ run_docker_bootstrap() {
-v "$config_mount" \
-v "$workspace_mount" \
"$docker_image" \
"${onboard_cmd[@]}" || true
"${onboard_cmd[@]}"
else
info "Docker image ready. Run zeroclaw onboard inside the container to configure."
fi
# Ensure config.toml and workspace scaffold exist on the host even when
# onboard was skipped, failed, or ran non-interactively inside the container.
ensure_default_config_and_workspace \
"$docker_data_dir/.zeroclaw" \
"$docker_data_dir/workspace" \
"$PROVIDER"
}
SCRIPT_PATH="${BASH_SOURCE[0]:-$0}"
@@ -1449,13 +1308,6 @@ elif [[ -z "$ZEROCLAW_BIN" ]]; then
warn "ZeroClaw binary not found — cannot configure provider"
fi
# Ensure config.toml and workspace scaffold exist even when onboard was
# skipped, unavailable, or failed (e.g. --skip-build --prefer-prebuilt
# without an API key, or when the binary could not run onboard).
_native_config_dir="${ZEROCLAW_CONFIG_DIR:-$HOME/.zeroclaw}"
_native_workspace_dir="${ZEROCLAW_WORKSPACE:-$_native_config_dir/workspace}"
ensure_default_config_and_workspace "$_native_config_dir" "$_native_workspace_dir" "$PROVIDER"
# --- Gateway service management ---
if [[ -n "$ZEROCLAW_BIN" ]]; then
# Try to install and start the gateway service
+57 -225
View File
@@ -17,7 +17,7 @@ use std::collections::HashSet;
use std::fmt::Write;
use std::io::Write as _;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock, Mutex};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
@@ -33,29 +33,6 @@ const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10;
/// Matches the channel-side constant in `channels/mod.rs`.
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
/// Callback type for checking if model has been switched during tool execution.
/// Returns Some((provider, model)) if a switch was requested, None otherwise.
pub type ModelSwitchCallback = Arc<Mutex<Option<(String, String)>>>;
/// Global model switch request state - used for runtime model switching via model_switch tool.
/// This is set by the model_switch tool and checked by the agent loop.
#[allow(clippy::type_complexity)]
static MODEL_SWITCH_REQUEST: LazyLock<Arc<Mutex<Option<(String, String)>>>> =
LazyLock::new(|| Arc::new(Mutex::new(None)));
/// Get the global model switch request state
pub fn get_model_switch_state() -> ModelSwitchCallback {
Arc::clone(&MODEL_SWITCH_REQUEST)
}
/// Clear any pending model switch request
pub fn clear_model_switch_request() {
if let Ok(guard) = MODEL_SWITCH_REQUEST.lock() {
let mut guard = guard;
*guard = None;
}
}
fn glob_match(pattern: &str, name: &str) -> bool {
match pattern.find('*') {
None => pattern == name,
@@ -2141,31 +2118,6 @@ pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool {
err.chain().any(|source| source.is::<ToolLoopCancelled>())
}
#[derive(Debug)]
pub(crate) struct ModelSwitchRequested {
pub provider: String,
pub model: String,
}
impl std::fmt::Display for ModelSwitchRequested {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"model switch requested to {} {}",
self.provider, self.model
)
}
}
impl std::error::Error for ModelSwitchRequested {}
pub(crate) fn is_model_switch_requested(err: &anyhow::Error) -> Option<(String, String)> {
err.chain()
.filter_map(|source| source.downcast_ref::<ModelSwitchRequested>())
.map(|e| (e.provider.clone(), e.model.clone()))
.next()
}
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use).
@@ -2185,7 +2137,6 @@ pub(crate) async fn agent_turn(
excluded_tools: &[String],
dedup_exempt_tools: &[String],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
model_switch_callback: Option<ModelSwitchCallback>,
) -> Result<String> {
run_tool_call_loop(
provider,
@@ -2206,7 +2157,6 @@ pub(crate) async fn agent_turn(
excluded_tools,
dedup_exempt_tools,
activated_tools,
model_switch_callback,
)
.await
}
@@ -2412,7 +2362,6 @@ pub(crate) async fn run_tool_call_loop(
excluded_tools: &[String],
dedup_exempt_tools: &[String],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
model_switch_callback: Option<ModelSwitchCallback>,
) -> Result<String> {
let max_iterations = if max_tool_iterations == 0 {
DEFAULT_MAX_TOOL_ITERATIONS
@@ -2431,28 +2380,6 @@ pub(crate) async fn run_tool_call_loop(
return Err(ToolLoopCancelled.into());
}
// Check if model switch was requested via model_switch tool
if let Some(ref callback) = model_switch_callback {
if let Ok(guard) = callback.lock() {
if let Some((new_provider, new_model)) = guard.as_ref() {
if new_provider != provider_name || new_model != model {
tracing::info!(
"Model switch detected: {} {} -> {} {}",
provider_name,
model,
new_provider,
new_model
);
return Err(ModelSwitchRequested {
provider: new_provider.clone(),
model: new_model.clone(),
}
.into());
}
}
}
}
// Rebuild tool_specs each iteration so newly activated deferred tools appear.
let mut tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
.iter()
@@ -3272,32 +3199,28 @@ pub async fn run(
}
// ── Resolve provider ─────────────────────────────────────────
let mut provider_name = provider_override
let provider_name = provider_override
.as_deref()
.or(config.default_provider.as_deref())
.unwrap_or("openrouter")
.to_string();
.unwrap_or("openrouter");
let mut model_name = model_override
let model_name = model_override
.as_deref()
.or(config.default_model.as_deref())
.unwrap_or("anthropic/claude-sonnet-4")
.to_string();
.unwrap_or("anthropic/claude-sonnet-4");
let provider_runtime_options = providers::provider_runtime_options_from_config(&config);
let mut provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
&provider_name,
let provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
provider_name,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&model_name,
model_name,
&provider_runtime_options,
)?;
let model_switch_callback = get_model_switch_state();
observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.to_string(),
model: model_name.to_string(),
@@ -3441,7 +3364,7 @@ pub async fn run(
let native_tools = provider.supports_native_tools();
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
&config.workspace_dir,
&model_name,
model_name,
&tool_descs,
&skills,
Some(&config.identity),
@@ -3524,72 +3447,27 @@ pub async fn run(
let excluded_tools =
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg);
#[allow(unused_assignments)]
let mut response = String::new();
loop {
match run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
&provider_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
Some(model_switch_callback.clone()),
)
.await
{
Ok(resp) => {
response = resp;
break;
}
Err(e) => {
if let Some((new_provider, new_model)) = is_model_switch_requested(&e) {
tracing::info!(
"Model switch requested, switching from {} {} to {} {}",
provider_name,
model_name,
new_provider,
new_model
);
provider = providers::create_routed_provider_with_options(
&new_provider,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&new_model,
&provider_runtime_options,
)?;
provider_name = new_provider;
model_name = new_model;
clear_model_switch_request();
observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.to_string(),
model: model_name.to_string(),
});
continue;
}
return Err(e);
}
}
}
let response = run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
)
.await?;
final_output = response.clone();
println!("{response}");
observer.record_event(&ObserverEvent::TurnComplete);
@@ -3731,66 +3609,32 @@ pub async fn run(
&user_input,
);
let response = loop {
match run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
&provider_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
Some(model_switch_callback.clone()),
)
.await
{
Ok(resp) => break resp,
Err(e) => {
if let Some((new_provider, new_model)) = is_model_switch_requested(&e) {
tracing::info!(
"Model switch requested, switching from {} {} to {} {}",
provider_name,
model_name,
new_provider,
new_model
);
provider = providers::create_routed_provider_with_options(
&new_provider,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&new_model,
&provider_runtime_options,
)?;
provider_name = new_provider;
model_name = new_model;
clear_model_switch_request();
observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.to_string(),
model: model_name.to_string(),
});
continue;
}
eprintln!("\nError: {e}\n");
break String::new();
}
let response = match run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
)
.await
{
Ok(resp) => resp,
Err(e) => {
eprintln!("\nError: {e}\n");
continue;
}
};
final_output = response.clone();
@@ -3808,7 +3652,7 @@ pub async fn run(
if let Ok(compacted) = auto_compact_history(
&mut history,
provider.as_ref(),
&model_name,
model_name,
config.agent.max_history_messages,
config.agent.max_context_tokens,
)
@@ -4102,7 +3946,6 @@ pub async fn process_message(
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle_pm.as_ref(),
None,
)
.await
}
@@ -4560,7 +4403,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect_err("provider without vision support should fail");
@@ -4609,7 +4451,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect_err("oversized payload must fail");
@@ -4652,7 +4493,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect("valid multimodal payload should pass");
@@ -4781,7 +4621,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect("parallel execution should complete");
@@ -4853,7 +4692,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect("loop should finish after deduplicating repeated calls");
@@ -4921,7 +4759,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect("non-interactive shell should succeed for low-risk command");
@@ -4980,7 +4817,6 @@ mod tests {
&[],
&exempt,
None,
None,
)
.await
.expect("loop should finish with exempt tool executing twice");
@@ -5059,7 +4895,6 @@ mod tests {
&[],
&exempt,
None,
None,
)
.await
.expect("loop should complete");
@@ -5115,7 +4950,6 @@ mod tests {
&[],
&[],
None,
None,
)
.await
.expect("native fallback id flow should complete");
@@ -5184,7 +5018,6 @@ mod tests {
&[],
&[],
Some(&activated),
None,
)
.await
.expect("wrapper path should execute activated tools");
@@ -7076,7 +6909,6 @@ Let me check the result."#;
&[],
&[],
None,
None,
)
.await
.expect("tool loop should complete");
+17 -61
View File
@@ -227,10 +227,6 @@ fn channel_message_timeout_budget_secs(
struct ChannelRouteSelection {
provider: String,
model: String,
/// Route-specific API key override. When set, this takes precedence over
/// the global `api_key` in [`ChannelRuntimeContext`] when creating the
/// provider for this route.
api_key: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -908,7 +904,6 @@ fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection
ChannelRouteSelection {
provider: defaults.default_provider,
model: defaults.model,
api_key: None,
}
}
@@ -1127,43 +1122,21 @@ fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<S
.unwrap_or_default()
}
/// Build a cache key that includes the provider name and, when a
/// route-specific API key is supplied, a hash of that key. This prevents
/// cache poisoning when multiple routes target the same provider with
/// different credentials.
fn provider_cache_key(provider_name: &str, route_api_key: Option<&str>) -> String {
match route_api_key {
Some(key) => {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
key.hash(&mut hasher);
format!("{provider_name}@{:x}", hasher.finish())
}
None => provider_name.to_string(),
}
}
async fn get_or_create_provider(
ctx: &ChannelRuntimeContext,
provider_name: &str,
route_api_key: Option<&str>,
) -> anyhow::Result<Arc<dyn Provider>> {
let cache_key = provider_cache_key(provider_name, route_api_key);
if let Some(existing) = ctx
.provider_cache
.lock()
.unwrap_or_else(|e| e.into_inner())
.get(&cache_key)
.get(provider_name)
.cloned()
{
return Ok(existing);
}
// Only return the pre-built default provider when there is no
// route-specific credential override — otherwise the default was
// created with the global key and would be wrong.
if route_api_key.is_none() && provider_name == ctx.default_provider.as_str() {
if provider_name == ctx.default_provider.as_str() {
return Ok(Arc::clone(&ctx.provider));
}
@@ -1174,14 +1147,9 @@ async fn get_or_create_provider(
None
};
// Prefer route-specific credential; fall back to the global key.
let effective_api_key = route_api_key
.map(ToString::to_string)
.or_else(|| ctx.api_key.clone());
let provider = create_resilient_provider_nonblocking(
provider_name,
effective_api_key,
ctx.api_key.clone(),
api_url.map(ToString::to_string),
ctx.reliability.as_ref().clone(),
ctx.provider_runtime_options.clone(),
@@ -1195,7 +1163,7 @@ async fn get_or_create_provider(
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
let cached = cache
.entry(cache_key)
.entry(provider_name.to_string())
.or_insert_with(|| Arc::clone(&provider));
Ok(Arc::clone(cached))
}
@@ -1311,27 +1279,25 @@ async fn handle_runtime_command_if_needed(
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(&current),
ChannelRuntimeCommand::SetProvider(raw_provider) => {
match resolve_provider_alias(&raw_provider) {
Some(provider_name) => {
match get_or_create_provider(ctx, &provider_name, None).await {
Ok(_) => {
if provider_name != current.provider {
current.provider = provider_name.clone();
set_route_selection(ctx, &sender_key, current.clone());
}
Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
Ok(_) => {
if provider_name != current.provider {
current.provider = provider_name.clone();
set_route_selection(ctx, &sender_key, current.clone());
}
format!(
format!(
"Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
current.model
)
}
Err(err) => {
let safe_err = providers::sanitize_api_error(&err.to_string());
format!(
}
Err(err) => {
let safe_err = providers::sanitize_api_error(&err.to_string());
format!(
"Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
)
}
}
}
},
None => format!(
"Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
),
@@ -1351,7 +1317,6 @@ async fn handle_runtime_command_if_needed(
}) {
current.provider = route.provider.clone();
current.model = route.model.clone();
current.api_key = route.api_key.clone();
} else {
current.model = model.clone();
}
@@ -1957,19 +1922,12 @@ async fn process_channel_message(
route = ChannelRouteSelection {
provider: matched_route.provider.clone(),
model: matched_route.model.clone(),
api_key: matched_route.api_key.clone(),
};
}
}
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
let active_provider = match get_or_create_provider(
ctx.as_ref(),
&route.provider,
route.api_key.as_deref(),
)
.await
{
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
Ok(provider) => provider,
Err(err) => {
let safe_err = providers::sanitize_api_error(&err.to_string());
@@ -2251,7 +2209,6 @@ async fn process_channel_message(
},
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
None,
),
) => LlmExecutionResult::Completed(result),
};
@@ -5645,7 +5602,6 @@ BTC is currently around $65,000 based on latest tool output."#
ChannelRouteSelection {
provider: "openrouter".to_string(),
model: "route-model".to_string(),
api_key: None,
},
);
+1 -4
View File
@@ -17,10 +17,7 @@ pub use store::{
add_agent_job, due_jobs, get_job, list_jobs, list_runs, record_last_run, record_run,
remove_job, reschedule_after_run, update_job,
};
pub use types::{
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
Schedule, SessionTarget,
};
pub use types::{CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType, Schedule, SessionTarget};
/// Validate a shell command against the full security policy (allowlist + risk gate).
///
+1 -66
View File
@@ -1,32 +1,6 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Try to deserialize a `serde_json::Value` as `T`. If the value is a JSON
/// string that looks like an object (i.e. the LLM double-serialized it), parse
/// the inner string first and then deserialize the resulting object. This
/// provides backward-compatible handling for both `Value::Object` and
/// `Value::String` representations.
pub fn deserialize_maybe_stringified<T: serde::de::DeserializeOwned>(
v: &serde_json::Value,
) -> Result<T, serde_json::Error> {
// Fast path: value is already the right shape (object, array, etc.)
match serde_json::from_value::<T>(v.clone()) {
Ok(parsed) => Ok(parsed),
Err(first_err) => {
// If it's a string, try parsing the string as JSON first.
if let Some(s) = v.as_str() {
let s = s.trim();
if s.starts_with('{') || s.starts_with('[') {
if let Ok(inner) = serde_json::from_str::<serde_json::Value>(s) {
return serde_json::from_value::<T>(inner);
}
}
}
Err(first_err)
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum JobType {
@@ -180,46 +154,7 @@ pub struct CronJobPatch {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_schedule_from_object() {
let val = serde_json::json!({"kind": "cron", "expr": "*/5 * * * *"});
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
assert!(matches!(sched, Schedule::Cron { ref expr, .. } if expr == "*/5 * * * *"));
}
#[test]
fn deserialize_schedule_from_string() {
let val = serde_json::Value::String(r#"{"kind":"cron","expr":"*/5 * * * *"}"#.to_string());
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
assert!(matches!(sched, Schedule::Cron { ref expr, .. } if expr == "*/5 * * * *"));
}
#[test]
fn deserialize_schedule_string_with_tz() {
let val = serde_json::Value::String(
r#"{"kind":"cron","expr":"*/30 9-15 * * 1-5","tz":"Asia/Shanghai"}"#.to_string(),
);
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
match sched {
Schedule::Cron { tz, .. } => assert_eq!(tz.as_deref(), Some("Asia/Shanghai")),
_ => panic!("expected Cron variant"),
}
}
#[test]
fn deserialize_every_from_string() {
let val = serde_json::Value::String(r#"{"kind":"every","every_ms":60000}"#.to_string());
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
assert!(matches!(sched, Schedule::Every { every_ms: 60000 }));
}
#[test]
fn deserialize_invalid_string_returns_error() {
let val = serde_json::Value::String("not json at all".to_string());
assert!(deserialize_maybe_stringified::<Schedule>(&val).is_err());
}
use super::JobType;
#[test]
fn job_type_try_from_accepts_known_values_case_insensitive() {
+4 -150
View File
@@ -17,6 +17,9 @@
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
@@ -31,7 +34,7 @@
//!
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
@@ -209,54 +212,6 @@ impl Provider for ClaudeCodeProvider {
self.invoke_cli(&full_message, model).await
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
// Separate system prompt from conversation messages.
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
// Build conversation turns (skip system messages).
let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
// If there's only one user message, use the simple path.
if turns.len() <= 1 {
let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
let full_message = match system {
Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
_ => last_user.to_string(),
};
return self.invoke_cli(&full_message, model).await;
}
// Format multi-turn conversation into a single prompt.
let mut parts = Vec::new();
if let Some(s) = system {
if !s.is_empty() {
parts.push(format!("[system]\n{s}"));
}
}
for msg in &turns {
let label = match msg.role.as_str() {
"user" => "[user]",
"assistant" => "[assistant]",
other => other,
};
parts.push(format!("{label}\n{}", msg.content));
}
parts.push("[assistant]".to_string());
let full_message = parts.join("\n\n");
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
@@ -372,105 +327,4 @@ mod tests {
"unexpected error message: {msg}"
);
}
/// Helper: create a provider that uses a shell script echoing stdin back.
/// The script ignores CLI flags (`--print`, `--model`, `-`) and just cats stdin.
///
/// Uses `OnceLock` to write the script file exactly once, avoiding
/// "Text file busy" (ETXTBSY) races when parallel tests try to
/// overwrite a script that another test is currently executing.
fn echo_provider() -> ClaudeCodeProvider {
use std::sync::OnceLock;
static SCRIPT_PATH: OnceLock<PathBuf> = OnceLock::new();
let script = SCRIPT_PATH.get_or_init(|| {
use std::io::Write;
let dir = std::env::temp_dir().join("zeroclaw_test_claude_code");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("fake_claude.sh");
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
drop(f);
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
}
path
});
ClaudeCodeProvider {
binary_path: script.clone(),
}
}
#[tokio::test]
async fn chat_with_history_single_user_message() {
let provider = echo_provider();
let messages = vec![ChatMessage::user("hello")];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert_eq!(result, "hello");
}
#[tokio::test]
async fn chat_with_history_single_user_with_system() {
let provider = echo_provider();
let messages = vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("hello"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert_eq!(result, "You are helpful.\n\nhello");
}
#[tokio::test]
async fn chat_with_history_multi_turn_includes_all_messages() {
let provider = echo_provider();
let messages = vec![
ChatMessage::system("Be concise."),
ChatMessage::user("What is 2+2?"),
ChatMessage::assistant("4"),
ChatMessage::user("And 3+3?"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert!(result.contains("[system]\nBe concise."));
assert!(result.contains("[user]\nWhat is 2+2?"));
assert!(result.contains("[assistant]\n4"));
assert!(result.contains("[user]\nAnd 3+3?"));
assert!(result.ends_with("[assistant]"));
}
#[tokio::test]
async fn chat_with_history_multi_turn_without_system() {
let provider = echo_provider();
let messages = vec![
ChatMessage::user("hi"),
ChatMessage::assistant("hello"),
ChatMessage::user("bye"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert!(!result.contains("[system]"));
assert!(result.contains("[user]\nhi"));
assert!(result.contains("[assistant]\nhello"));
assert!(result.contains("[user]\nbye"));
}
#[tokio::test]
async fn chat_with_history_rejects_bad_temperature() {
let provider = echo_provider();
let messages = vec![ChatMessage::user("test")];
let result = provider.chat_with_history(&messages, "default", 0.5).await;
assert!(result.is_err());
}
}
+1 -2
View File
@@ -1320,12 +1320,11 @@ fn create_provider_with_url_and_options(
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or("llama.cpp");
Ok(compat(OpenAiCompatibleProvider::new_with_vision(
Ok(compat(OpenAiCompatibleProvider::new(
"llama.cpp",
base_url,
Some(llama_cpp_key),
AuthStyle::Bearer,
true,
)))
}
"sglang" => {
+41 -250
View File
@@ -16,10 +16,8 @@ use std::time::Duration;
/// Check if an error is non-retryable (client errors that won't resolve with retries).
pub fn is_non_retryable(err: &anyhow::Error) -> bool {
// Context window errors are NOT non-retryable — they can be recovered
// by truncating conversation history, so let the retry loop handle them.
if is_context_window_exceeded(err) {
return false;
return true;
}
// 4xx errors are generally non-retryable (bad request, auth failure, etc.),
@@ -77,7 +75,6 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
let lower = err.to_string().to_lowercase();
let hints = [
"exceeds the context window",
"exceeds the available context size",
"context window of this model",
"maximum context length",
"context length exceeded",
@@ -200,35 +197,6 @@ fn compact_error_detail(err: &anyhow::Error) -> String {
.join(" ")
}
/// Truncate conversation history by dropping the oldest non-system messages.
/// Returns the number of messages dropped. Keeps at least the system message
/// (if any) and the most recent user message.
fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
// Find all non-system message indices
let non_system: Vec<usize> = messages
.iter()
.enumerate()
.filter(|(_, m)| m.role != "system")
.map(|(i, _)| i)
.collect();
// Keep at least the last non-system message (most recent user turn)
if non_system.len() <= 1 {
return 0;
}
// Drop the oldest half of non-system messages
let drop_count = non_system.len() / 2;
let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
// Remove in reverse order to preserve indices
for &idx in indices_to_remove.iter().rev() {
messages.remove(idx);
}
drop_count
}
fn push_failure(
failures: &mut Vec<String>,
provider_name: &str,
@@ -370,25 +338,6 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
// Context window exceeded: no history to truncate
// in chat_with_system, bail immediately.
if is_context_window_exceeded(&e) {
let error_detail = compact_error_detail(&e);
push_failure(
&mut failures,
provider_name,
current_model,
attempt + 1,
self.max_retries + 1,
"non_retryable",
&error_detail,
);
anyhow::bail!(
"Request exceeds model context window. Attempts:\n{}",
failures.join("\n")
);
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@@ -427,6 +376,14 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@@ -478,8 +435,6 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@@ -487,39 +442,22 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
match provider
.chat_with_history(&effective_messages, current_model, temperature)
.chat_with_history(messages, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model || context_truncated {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@@ -556,6 +494,14 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@@ -613,8 +559,6 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<ChatResponse> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@@ -622,39 +566,22 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
match provider
.chat_with_tools(&effective_messages, tools, current_model, temperature)
.chat_with_tools(messages, tools, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model || context_truncated {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@@ -691,6 +618,14 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@@ -734,8 +669,6 @@ impl Provider for ReliableProvider {
) -> anyhow::Result<ChatResponse> {
let models = self.model_chain(model);
let mut failures = Vec::new();
let mut effective_messages = request.messages.to_vec();
let mut context_truncated = false;
for current_model in &models {
for (provider_name, provider) in &self.providers {
@@ -743,40 +676,23 @@ impl Provider for ReliableProvider {
for attempt in 0..=self.max_retries {
let req = ChatRequest {
messages: &effective_messages,
messages: request.messages,
tools: request.tools,
};
match provider.chat(req, current_model, temperature).await {
Ok(resp) => {
if attempt > 0 || *current_model != model || context_truncated {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
context_truncated,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
// Context window exceeded: truncate history and retry
if is_context_window_exceeded(&e) && !context_truncated {
let dropped = truncate_for_context(&mut effective_messages);
if dropped > 0 {
context_truncated = true;
tracing::warn!(
provider = provider_name,
model = *current_model,
dropped,
remaining = effective_messages.len(),
"Context window exceeded; truncated history and retrying"
);
continue; // Retry with truncated messages (counts as an attempt)
}
}
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
@@ -813,6 +729,14 @@ impl Provider for ReliableProvider {
error = %error_detail,
"Non-retryable error, moving on"
);
if is_context_window_exceeded(&e) {
anyhow::bail!(
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
failures.join("\n")
);
}
break;
}
@@ -1147,8 +1071,7 @@ mod tests {
assert!(!is_non_retryable(&anyhow::anyhow!(
"model overloaded, try again later"
)));
// Context window errors are now recoverable (not non-retryable)
assert!(!is_non_retryable(&anyhow::anyhow!(
assert!(is_non_retryable(&anyhow::anyhow!(
"OpenAI Codex stream error: Your input exceeds the context window of this model."
)));
}
@@ -1184,7 +1107,7 @@ mod tests {
let msg = err.to_string();
assert!(msg.contains("context window"));
// chat_with_system has no history to truncate, so it bails immediately
assert!(msg.contains("skipped"));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@@ -2057,136 +1980,4 @@ mod tests {
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
// ── Context window truncation tests ─────────────────────────
#[test]
fn context_window_error_is_not_non_retryable() {
// Context window errors should be recoverable via truncation
assert!(!is_non_retryable(&anyhow::anyhow!(
"exceeds the context window"
)));
assert!(!is_non_retryable(&anyhow::anyhow!(
"maximum context length exceeded"
)));
assert!(!is_non_retryable(&anyhow::anyhow!(
"too many tokens in the request"
)));
assert!(!is_non_retryable(&anyhow::anyhow!("token limit exceeded")));
}
#[test]
fn is_context_window_exceeded_detects_llamacpp() {
assert!(is_context_window_exceeded(&anyhow::anyhow!(
"request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
)));
}
#[test]
fn truncate_for_context_drops_oldest_non_system() {
let mut messages = vec![
ChatMessage::system("sys"),
ChatMessage::user("msg1"),
ChatMessage::assistant("resp1"),
ChatMessage::user("msg2"),
ChatMessage::assistant("resp2"),
ChatMessage::user("msg3"),
];
let dropped = truncate_for_context(&mut messages);
// 5 non-system messages, drop oldest half = 2
assert_eq!(dropped, 2);
// System message preserved
assert_eq!(messages[0].role, "system");
// Remaining messages should be the newer ones
assert_eq!(messages.len(), 4); // system + 3 remaining non-system
// The last message should still be the most recent user message
assert_eq!(messages.last().unwrap().content, "msg3");
}
#[test]
fn truncate_for_context_preserves_system_and_last_message() {
// Only one non-system message: nothing to drop
let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
let dropped = truncate_for_context(&mut messages);
assert_eq!(dropped, 0);
assert_eq!(messages.len(), 2);
// No system message, only one user message
let mut messages = vec![ChatMessage::user("only")];
let dropped = truncate_for_context(&mut messages);
assert_eq!(dropped, 0);
assert_eq!(messages.len(), 1);
}
/// Mock that fails with context error on first N calls, then succeeds.
/// Tracks the number of messages received on each call.
struct ContextOverflowMock {
calls: Arc<AtomicUsize>,
fail_until_attempt: usize,
message_counts: parking_lot::Mutex<Vec<usize>>,
}
#[async_trait]
impl Provider for ContextOverflowMock {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok("ok".to_string())
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
self.message_counts.lock().push(messages.len());
if attempt <= self.fail_until_attempt {
anyhow::bail!(
"request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
);
}
Ok("recovered after truncation".to_string())
}
}
#[tokio::test]
async fn chat_with_history_truncates_on_context_overflow() {
let calls = Arc::new(AtomicUsize::new(0));
let mock = ContextOverflowMock {
calls: Arc::clone(&calls),
fail_until_attempt: 1, // fail first call, succeed after truncation
message_counts: parking_lot::Mutex::new(Vec::new()),
};
let provider = ReliableProvider::new(
vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
3,
1,
);
let messages = vec![
ChatMessage::system("system prompt"),
ChatMessage::user("old message 1"),
ChatMessage::assistant("old response 1"),
ChatMessage::user("old message 2"),
ChatMessage::assistant("old response 2"),
ChatMessage::user("current question"),
];
let result = provider
.chat_with_history(&messages, "local-model", 0.0)
.await
.unwrap();
assert_eq!(result, "recovered after truncation");
// Should have been called twice: once with full messages, once with truncated
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
}
+2 -61
View File
@@ -1,8 +1,6 @@
use super::traits::{Tool, ToolResult};
use crate::config::Config;
use crate::cron::{
self, deserialize_maybe_stringified, DeliveryConfig, JobType, Schedule, SessionTarget,
};
use crate::cron::{self, DeliveryConfig, JobType, Schedule, SessionTarget};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
@@ -178,7 +176,7 @@ impl Tool for CronAddTool {
}
let schedule = match args.get("schedule") {
Some(v) => match deserialize_maybe_stringified::<Schedule>(v) {
Some(v) => match serde_json::from_value::<Schedule>(v.clone()) {
Ok(schedule) => schedule,
Err(e) => {
return Ok(ToolResult {
@@ -513,63 +511,6 @@ mod tests {
assert!(approved.success, "{:?}", approved.error);
}
#[tokio::test]
async fn accepts_schedule_passed_as_json_string() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
// Simulate the LLM double-serializing the schedule: the value arrives
// as a JSON string containing a JSON object, rather than an object.
let result = tool
.execute(json!({
"schedule": r#"{"kind":"cron","expr":"*/5 * * * *"}"#,
"job_type": "shell",
"command": "echo string-schedule"
}))
.await
.unwrap();
assert!(result.success, "{:?}", result.error);
assert!(result.output.contains("next_run"));
}
#[tokio::test]
async fn accepts_stringified_interval_schedule() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": r#"{"kind":"every","every_ms":60000}"#,
"job_type": "shell",
"command": "echo interval"
}))
.await
.unwrap();
assert!(result.success, "{:?}", result.error);
}
#[tokio::test]
async fn accepts_stringified_schedule_with_timezone() {
let tmp = TempDir::new().unwrap();
let cfg = test_config(&tmp).await;
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
let result = tool
.execute(json!({
"schedule": r#"{"kind":"cron","expr":"*/30 9-15 * * 1-5","tz":"Asia/Shanghai"}"#,
"job_type": "shell",
"command": "echo tz-test"
}))
.await
.unwrap();
assert!(result.success, "{:?}", result.error);
}
#[tokio::test]
async fn rejects_invalid_schedule() {
let tmp = TempDir::new().unwrap();
+2 -2
View File
@@ -1,6 +1,6 @@
use super::traits::{Tool, ToolResult};
use crate::config::Config;
use crate::cron::{self, deserialize_maybe_stringified, CronJobPatch};
use crate::cron::{self, CronJobPatch};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
@@ -202,7 +202,7 @@ impl Tool for CronUpdateTool {
}
};
let patch = match deserialize_maybe_stringified::<CronJobPatch>(&patch_val) {
let patch = match serde_json::from_value::<CronJobPatch>(patch_val) {
Ok(patch) => patch,
Err(e) => {
return Ok(ToolResult {
-1
View File
@@ -422,7 +422,6 @@ impl DelegateTool {
&[],
&[],
None,
None,
),
)
.await;
+2 -92
View File
@@ -233,22 +233,12 @@ impl Default for ActivatedToolSet {
/// Build the `<available-deferred-tools>` section for the system prompt.
/// Lists only tool names so the LLM knows what is available without
/// consuming context window on full schemas. Includes an instruction
/// block that tells the LLM to call `tool_search` to activate them.
/// consuming context window on full schemas.
pub fn build_deferred_tools_section(deferred: &DeferredMcpToolSet) -> String {
if deferred.is_empty() {
return String::new();
}
let mut out = String::new();
out.push_str("## Deferred Tools\n\n");
out.push_str(
"The tools listed below are available but NOT yet loaded. \
To use any of them you MUST first call the `tool_search` tool \
to fetch their full schemas. Use `\"select:name1,name2\"` for \
exact tools or keywords to search. Once activated, the tools \
become callable for the rest of the conversation.\n\n",
);
out.push_str("<available-deferred-tools>\n");
let mut out = String::from("<available-deferred-tools>\n");
for stub in &deferred.stubs {
out.push_str(&stub.prefixed_name);
out.push('\n');
@@ -426,55 +416,6 @@ mod tests {
assert!(section.contains("</available-deferred-tools>"));
}
#[test]
fn build_deferred_section_includes_tool_search_instruction() {
let stubs = vec![make_stub("fs__read_file", "Read a file")];
let set = DeferredMcpToolSet {
stubs,
registry: std::sync::Arc::new(
tokio::runtime::Runtime::new()
.unwrap()
.block_on(McpRegistry::connect_all(&[]))
.unwrap(),
),
};
let section = build_deferred_tools_section(&set);
assert!(
section.contains("tool_search"),
"deferred section must instruct the LLM to use tool_search"
);
assert!(
section.contains("## Deferred Tools"),
"deferred section must include a heading"
);
}
#[test]
fn build_deferred_section_multiple_servers() {
let stubs = vec![
make_stub("server_a__list", "List items"),
make_stub("server_a__create", "Create item"),
make_stub("server_b__query", "Query records"),
];
let set = DeferredMcpToolSet {
stubs,
registry: std::sync::Arc::new(
tokio::runtime::Runtime::new()
.unwrap()
.block_on(McpRegistry::connect_all(&[]))
.unwrap(),
),
};
let section = build_deferred_tools_section(&set);
assert!(section.contains("server_a__list"));
assert!(section.contains("server_a__create"));
assert!(section.contains("server_b__query"));
assert!(
section.contains("tool_search"),
"section must mention tool_search for multi-server setups"
);
}
#[test]
fn keyword_search_ranks_by_hits() {
let stubs = vec![
@@ -516,35 +457,4 @@ mod tests {
assert!(set.get_by_name("a__one").is_some());
assert!(set.get_by_name("nonexistent").is_none());
}
#[test]
fn search_across_multiple_servers() {
let stubs = vec![
make_stub("server_a__read_file", "Read a file from disk"),
make_stub("server_b__read_config", "Read configuration from database"),
];
let set = DeferredMcpToolSet {
stubs,
registry: std::sync::Arc::new(
tokio::runtime::Runtime::new()
.unwrap()
.block_on(McpRegistry::connect_all(&[]))
.unwrap(),
),
};
// "read" should match stubs from both servers
let results = set.search("read", 10);
assert_eq!(results.len(), 2);
// "file" should match only server_a
let results = set.search("file", 10);
assert_eq!(results.len(), 1);
assert_eq!(results[0].prefixed_name, "server_a__read_file");
// "config database" should rank server_b highest (2 hits)
let results = set.search("config database", 10);
assert!(!results.is_empty());
assert_eq!(results[0].prefixed_name, "server_b__read_config");
}
}
-3
View File
@@ -59,7 +59,6 @@ pub mod memory_recall;
pub mod memory_store;
pub mod microsoft365;
pub mod model_routing_config;
pub mod model_switch;
pub mod node_tool;
pub mod notion_tool;
pub mod pdf_read;
@@ -120,7 +119,6 @@ pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use microsoft365::Microsoft365Tool;
pub use model_routing_config::ModelRoutingConfigTool;
pub use model_switch::ModelSwitchTool;
#[allow(unused_imports)]
pub use node_tool::NodeTool;
pub use notion_tool::NotionTool;
@@ -304,7 +302,6 @@ pub fn all_tools_with_runtime(
config.clone(),
security.clone(),
)),
Arc::new(ModelSwitchTool::new(security.clone())),
Arc::new(ProxyConfigTool::new(config.clone(), security.clone())),
Arc::new(GitOperationsTool::new(
security.clone(),
-264
View File
@@ -1,264 +0,0 @@
use super::traits::{Tool, ToolResult};
use crate::agent::loop_::get_model_switch_state;
use crate::providers;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
pub struct ModelSwitchTool {
security: Arc<SecurityPolicy>,
}
impl ModelSwitchTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
}
#[async_trait]
impl Tool for ModelSwitchTool {
fn name(&self) -> &str {
"model_switch"
}
fn description(&self) -> &str {
"Switch the AI model at runtime. Use 'get' to see current model, 'list_providers' to see available providers, 'list_models' to see models for a provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["get", "set", "list_providers", "list_models"],
"description": "Action to perform: get current model, set a new model, list available providers, or list models for a provider"
},
"provider": {
"type": "string",
"description": "Provider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
},
"model": {
"type": "string",
"description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
}
},
"required": ["action"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "model_switch")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
match action {
"get" => self.handle_get(),
"set" => self.handle_set(&args),
"list_providers" => self.handle_list_providers(),
"list_models" => self.handle_list_models(&args),
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: {}. Valid actions: get, set, list_providers, list_models",
action
)),
}),
}
}
}
impl ModelSwitchTool {
fn handle_get(&self) -> anyhow::Result<ToolResult> {
let switch_state = get_model_switch_state();
let pending = switch_state.lock().unwrap().clone();
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&json!({
"pending_switch": pending,
"note": "To switch models, use action 'set' with provider and model parameters"
}))?,
error: None,
})
}
fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let provider = args.get("provider").and_then(|v| v.as_str());
let provider = match provider {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing 'provider' parameter for 'set' action".to_string()),
});
}
};
let model = args.get("model").and_then(|v| v.as_str());
let model = match model {
Some(m) => m,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing 'model' parameter for 'set' action".to_string()),
});
}
};
// Validate the provider exists
let known_providers = providers::list_providers();
let provider_valid = known_providers.iter().any(|p| {
p.name.eq_ignore_ascii_case(provider)
|| p.aliases.iter().any(|a| a.eq_ignore_ascii_case(provider))
});
if !provider_valid {
return Ok(ToolResult {
success: false,
output: serde_json::to_string_pretty(&json!({
"available_providers": known_providers.iter().map(|p| p.name).collect::<Vec<_>>()
}))?,
error: Some(format!(
"Unknown provider: {}. Use 'list_providers' to see available options.",
provider
)),
});
}
// Set the global model switch request
let switch_state = get_model_switch_state();
*switch_state.lock().unwrap() = Some((provider.to_string(), model.to_string()));
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&json!({
"message": "Model switch requested",
"provider": provider,
"model": model,
"note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
}))?,
error: None,
})
}
fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
let providers_list = providers::list_providers();
let providers: Vec<serde_json::Value> = providers_list
.iter()
.map(|p| {
json!({
"name": p.name,
"display_name": p.display_name,
"aliases": p.aliases,
"local": p.local
})
})
.collect();
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&json!({
"providers": providers,
"count": providers.len(),
"example": "Use action 'set' with provider and model to switch"
}))?,
error: None,
})
}
fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let provider = args.get("provider").and_then(|v| v.as_str());
let provider = match provider {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Missing 'provider' parameter for 'list_models' action".to_string(),
),
});
}
};
// Return common models for known providers
let models = match provider.to_lowercase().as_str() {
"openai" => vec![
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-3.5-turbo",
],
"anthropic" => vec![
"claude-sonnet-4-6",
"claude-sonnet-4-5",
"claude-3-5-sonnet",
"claude-3-opus",
"claude-3-haiku",
],
"openrouter" => vec![
"anthropic/claude-sonnet-4-6",
"openai/gpt-4o",
"google/gemini-pro",
"meta-llama/llama-3-70b-instruct",
],
"groq" => vec![
"llama-3.3-70b-versatile",
"mixtral-8x7b-32768",
"llama-3.1-70b-speculative",
],
"ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
"deepseek" => vec!["deepseek-chat", "deepseek-coder"],
"mistral" => vec![
"mistral-large-latest",
"mistral-small-latest",
"mistral-nemo",
],
"google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
"xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
_ => vec![],
};
if models.is_empty() {
return Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&json!({
"provider": provider,
"models": [],
"note": "No common models listed for this provider. Check provider documentation for available models."
}))?,
error: None,
});
}
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&json!({
"provider": provider,
"models": models,
"example": "Use action 'set' with this provider and a model ID to switch"
}))?,
error: None,
})
}
}
-84
View File
@@ -281,88 +281,4 @@ mod tests {
// Tool should now be activated
assert!(activated.lock().unwrap().is_activated("fs__read"));
}
/// Verify tool_search works with stubs from multiple MCP servers,
/// simulating a daemon-mode setup where several servers are deferred.
#[tokio::test]
async fn multiple_servers_stubs_all_searchable() {
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
let stubs = vec![
make_stub("server_a__list_files", "List files on server A"),
make_stub("server_a__read_file", "Read file on server A"),
make_stub("server_b__query_db", "Query database on server B"),
make_stub("server_b__insert_row", "Insert row on server B"),
];
let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));
// Search should find tools across both servers
let result = tool
.execute(serde_json::json!({"query": "file"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("server_a__list_files"));
assert!(result.output.contains("server_a__read_file"));
// Server B tools should also be searchable
let result = tool
.execute(serde_json::json!({"query": "database query"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("server_b__query_db"));
}
/// Verify select mode activates tools and they stay activated across calls,
/// matching the daemon-mode pattern where a single ActivatedToolSet persists.
#[tokio::test]
async fn select_activates_and_persists_across_calls() {
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
let stubs = vec![
make_stub("srv__tool_a", "Tool A"),
make_stub("srv__tool_b", "Tool B"),
];
let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));
// Activate tool_a
let result = tool
.execute(serde_json::json!({"query": "select:srv__tool_a"}))
.await
.unwrap();
assert!(result.success);
assert!(activated.lock().unwrap().is_activated("srv__tool_a"));
assert!(!activated.lock().unwrap().is_activated("srv__tool_b"));
// Activate tool_b in a separate call
let result = tool
.execute(serde_json::json!({"query": "select:srv__tool_b"}))
.await
.unwrap();
assert!(result.success);
// Both should remain activated
let guard = activated.lock().unwrap();
assert!(guard.is_activated("srv__tool_a"));
assert!(guard.is_activated("srv__tool_b"));
assert_eq!(guard.tool_specs().len(), 2);
}
/// Verify re-activating an already-activated tool does not duplicate it.
#[tokio::test]
async fn reactivation_is_idempotent() {
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
let tool = ToolSearchTool::new(
make_deferred_set(vec![make_stub("srv__tool", "A tool")]).await,
Arc::clone(&activated),
);
tool.execute(serde_json::json!({"query": "select:srv__tool"}))
.await
.unwrap();
tool.execute(serde_json::json!({"query": "select:srv__tool"}))
.await
.unwrap();
assert_eq!(activated.lock().unwrap().tool_specs().len(), 1);
}
}