Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aa9c6ded42 |
+7
-6
@@ -23,14 +23,13 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
|
||||
# 1. Copy manifests to cache dependencies
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
# Remove robot-kit from workspace members — it is excluded by .dockerignore
|
||||
# and is not needed for the Docker build (hardware-only crate).
|
||||
RUN sed -i 's/members = \[".", "crates\/robot-kit"\]/members = ["."]/' Cargo.toml
|
||||
COPY crates/robot-kit/Cargo.toml crates/robot-kit/Cargo.toml
|
||||
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
|
||||
RUN mkdir -p src benches \
|
||||
RUN mkdir -p src benches crates/robot-kit/src \
|
||||
&& echo "fn main() {}" > src/main.rs \
|
||||
&& echo "" > src/lib.rs \
|
||||
&& echo "fn main() {}" > benches/agent_benchmarks.rs
|
||||
&& echo "fn main() {}" > benches/agent_benchmarks.rs \
|
||||
&& echo "pub fn placeholder() {}" > crates/robot-kit/src/lib.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
@@ -39,11 +38,13 @@ RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/regist
|
||||
else \
|
||||
cargo build --release --locked; \
|
||||
fi
|
||||
RUN rm -rf src benches
|
||||
RUN rm -rf src benches crates/robot-kit/src
|
||||
|
||||
# 2. Copy only build-relevant source paths (avoid cache-busting on docs/tests/scripts)
|
||||
COPY src/ src/
|
||||
COPY benches/ benches/
|
||||
COPY crates/ crates/
|
||||
COPY firmware/ firmware/
|
||||
COPY --from=web-builder /web/dist web/dist
|
||||
COPY *.rs .
|
||||
RUN touch src/main.rs
|
||||
|
||||
+7
-6
@@ -38,14 +38,13 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
|
||||
# 1. Copy manifests to cache dependencies
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
# Remove robot-kit from workspace members — it is excluded by .dockerignore
|
||||
# and is not needed for the Docker build (hardware-only crate).
|
||||
RUN sed -i 's/members = \[".", "crates\/robot-kit"\]/members = ["."]/' Cargo.toml
|
||||
COPY crates/robot-kit/Cargo.toml crates/robot-kit/Cargo.toml
|
||||
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
|
||||
RUN mkdir -p src benches \
|
||||
RUN mkdir -p src benches crates/robot-kit/src \
|
||||
&& echo "fn main() {}" > src/main.rs \
|
||||
&& echo "" > src/lib.rs \
|
||||
&& echo "fn main() {}" > benches/agent_benchmarks.rs
|
||||
&& echo "fn main() {}" > benches/agent_benchmarks.rs \
|
||||
&& echo "pub fn placeholder() {}" > crates/robot-kit/src/lib.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
@@ -54,11 +53,13 @@ RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/regist
|
||||
else \
|
||||
cargo build --release --locked; \
|
||||
fi
|
||||
RUN rm -rf src benches
|
||||
RUN rm -rf src benches crates/robot-kit/src
|
||||
|
||||
# 2. Copy only build-relevant source paths (avoid cache-busting on docs/tests/scripts)
|
||||
COPY src/ src/
|
||||
COPY benches/ benches/
|
||||
COPY crates/ crates/
|
||||
COPY firmware/ firmware/
|
||||
COPY --from=web-builder /web/dist web/dist
|
||||
RUN touch src/main.rs
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
|
||||
+1
-149
@@ -767,140 +767,6 @@ run_guided_installer() {
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_default_config_and_workspace() {
|
||||
# Creates a minimal config.toml and workspace scaffold files when the
|
||||
# onboard wizard was skipped (e.g. --skip-build --prefer-prebuilt, or
|
||||
# Docker mode without an API key).
|
||||
#
|
||||
# $1 — config directory (e.g. ~/.zeroclaw or $docker_data_dir/.zeroclaw)
|
||||
# $2 — workspace directory (e.g. ~/.zeroclaw/workspace or $docker_data_dir/workspace)
|
||||
# $3 — provider name (default: openrouter)
|
||||
local config_dir="$1"
|
||||
local workspace_dir="$2"
|
||||
local provider="${3:-openrouter}"
|
||||
|
||||
mkdir -p "$config_dir" "$workspace_dir"
|
||||
|
||||
# --- config.toml ---
|
||||
local config_path="$config_dir/config.toml"
|
||||
if [[ ! -f "$config_path" ]]; then
|
||||
step_dot "Creating default config.toml"
|
||||
cat > "$config_path" <<TOML
|
||||
# ZeroClaw configuration — generated by install.sh
|
||||
# Edit this file or run 'zeroclaw onboard' to reconfigure.
|
||||
|
||||
default_provider = "${provider}"
|
||||
workspace_dir = "${workspace_dir}"
|
||||
TOML
|
||||
if [[ -n "${API_KEY:-}" ]]; then
|
||||
printf 'api_key = "%s"\n' "$API_KEY" >> "$config_path"
|
||||
fi
|
||||
if [[ -n "${MODEL:-}" ]]; then
|
||||
printf 'default_model = "%s"\n' "$MODEL" >> "$config_path"
|
||||
fi
|
||||
chmod 600 "$config_path" 2>/dev/null || true
|
||||
step_ok "Default config.toml created at $config_path"
|
||||
else
|
||||
step_dot "config.toml already exists, skipping"
|
||||
fi
|
||||
|
||||
# --- Workspace scaffold ---
|
||||
local subdirs=(sessions memory state cron skills)
|
||||
for dir in "${subdirs[@]}"; do
|
||||
mkdir -p "$workspace_dir/$dir"
|
||||
done
|
||||
|
||||
# Seed workspace markdown files only if they don't already exist.
|
||||
local user_name="${USER:-User}"
|
||||
local agent_name="ZeroClaw"
|
||||
|
||||
_write_if_missing() {
|
||||
local filepath="$1"
|
||||
local content="$2"
|
||||
if [[ ! -f "$filepath" ]]; then
|
||||
printf '%s\n' "$content" > "$filepath"
|
||||
fi
|
||||
}
|
||||
|
||||
_write_if_missing "$workspace_dir/IDENTITY.md" \
|
||||
"# IDENTITY.md — Who Am I?
|
||||
|
||||
- **Name:** ${agent_name}
|
||||
- **Creature:** A Rust-forged AI — fast, lean, and relentless
|
||||
- **Vibe:** Sharp, direct, resourceful. Not corporate. Not a chatbot.
|
||||
|
||||
---
|
||||
|
||||
Update this file as you evolve. Your identity is yours to shape."
|
||||
|
||||
_write_if_missing "$workspace_dir/USER.md" \
|
||||
"# USER.md — Who You're Helping
|
||||
|
||||
## About You
|
||||
- **Name:** ${user_name}
|
||||
- **Timezone:** UTC
|
||||
- **Languages:** English
|
||||
|
||||
## Preferences
|
||||
- (Add your preferences here)
|
||||
|
||||
## Work Context
|
||||
- (Add your work context here)
|
||||
|
||||
---
|
||||
*Update this anytime. The more ${agent_name} knows, the better it helps.*"
|
||||
|
||||
_write_if_missing "$workspace_dir/MEMORY.md" \
|
||||
"# MEMORY.md — Long-Term Memory
|
||||
|
||||
## Key Facts
|
||||
(Add important facts here)
|
||||
|
||||
## Decisions & Preferences
|
||||
(Record decisions and preferences here)
|
||||
|
||||
## Lessons Learned
|
||||
(Document mistakes and insights here)
|
||||
|
||||
## Open Loops
|
||||
(Track unfinished tasks and follow-ups here)"
|
||||
|
||||
_write_if_missing "$workspace_dir/AGENTS.md" \
|
||||
"# AGENTS.md — ${agent_name} Personal Assistant
|
||||
|
||||
## Every Session (required)
|
||||
|
||||
Before doing anything else:
|
||||
|
||||
1. Read SOUL.md — this is who you are
|
||||
2. Read USER.md — this is who you're helping
|
||||
3. Use memory_recall for recent context
|
||||
|
||||
---
|
||||
*Add your own conventions, style, and rules.*"
|
||||
|
||||
_write_if_missing "$workspace_dir/SOUL.md" \
|
||||
"# SOUL.md — Who You Are
|
||||
|
||||
## Core Truths
|
||||
|
||||
**Be genuinely helpful, not performatively helpful.**
|
||||
**Have opinions.** You're allowed to disagree.
|
||||
**Be resourceful before asking.** Try to figure it out first.
|
||||
**Earn trust through competence.**
|
||||
|
||||
## Identity
|
||||
|
||||
You are **${agent_name}**. Built in Rust. 3MB binary. Zero bloat.
|
||||
|
||||
---
|
||||
*This file is yours to evolve.*"
|
||||
|
||||
step_ok "Workspace scaffold ready at $workspace_dir"
|
||||
|
||||
unset -f _write_if_missing
|
||||
}
|
||||
|
||||
resolve_container_cli() {
|
||||
local requested_cli
|
||||
requested_cli="${ZEROCLAW_CONTAINER_CLI:-docker}"
|
||||
@@ -1018,17 +884,10 @@ run_docker_bootstrap() {
|
||||
-v "$config_mount" \
|
||||
-v "$workspace_mount" \
|
||||
"$docker_image" \
|
||||
"${onboard_cmd[@]}" || true
|
||||
"${onboard_cmd[@]}"
|
||||
else
|
||||
info "Docker image ready. Run zeroclaw onboard inside the container to configure."
|
||||
fi
|
||||
|
||||
# Ensure config.toml and workspace scaffold exist on the host even when
|
||||
# onboard was skipped, failed, or ran non-interactively inside the container.
|
||||
ensure_default_config_and_workspace \
|
||||
"$docker_data_dir/.zeroclaw" \
|
||||
"$docker_data_dir/workspace" \
|
||||
"$PROVIDER"
|
||||
}
|
||||
|
||||
SCRIPT_PATH="${BASH_SOURCE[0]:-$0}"
|
||||
@@ -1449,13 +1308,6 @@ elif [[ -z "$ZEROCLAW_BIN" ]]; then
|
||||
warn "ZeroClaw binary not found — cannot configure provider"
|
||||
fi
|
||||
|
||||
# Ensure config.toml and workspace scaffold exist even when onboard was
|
||||
# skipped, unavailable, or failed (e.g. --skip-build --prefer-prebuilt
|
||||
# without an API key, or when the binary could not run onboard).
|
||||
_native_config_dir="${ZEROCLAW_CONFIG_DIR:-$HOME/.zeroclaw}"
|
||||
_native_workspace_dir="${ZEROCLAW_WORKSPACE:-$_native_config_dir/workspace}"
|
||||
ensure_default_config_and_workspace "$_native_config_dir" "$_native_workspace_dir" "$PROVIDER"
|
||||
|
||||
# --- Gateway service management ---
|
||||
if [[ -n "$ZEROCLAW_BIN" ]]; then
|
||||
# Try to install and start the gateway service
|
||||
|
||||
+57
-225
@@ -17,7 +17,7 @@ use std::collections::HashSet;
|
||||
use std::fmt::Write;
|
||||
use std::io::Write as _;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, LazyLock, Mutex};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
@@ -33,29 +33,6 @@ const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10;
|
||||
/// Matches the channel-side constant in `channels/mod.rs`.
|
||||
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
|
||||
|
||||
/// Callback type for checking if model has been switched during tool execution.
|
||||
/// Returns Some((provider, model)) if a switch was requested, None otherwise.
|
||||
pub type ModelSwitchCallback = Arc<Mutex<Option<(String, String)>>>;
|
||||
|
||||
/// Global model switch request state - used for runtime model switching via model_switch tool.
|
||||
/// This is set by the model_switch tool and checked by the agent loop.
|
||||
#[allow(clippy::type_complexity)]
|
||||
static MODEL_SWITCH_REQUEST: LazyLock<Arc<Mutex<Option<(String, String)>>>> =
|
||||
LazyLock::new(|| Arc::new(Mutex::new(None)));
|
||||
|
||||
/// Get the global model switch request state
|
||||
pub fn get_model_switch_state() -> ModelSwitchCallback {
|
||||
Arc::clone(&MODEL_SWITCH_REQUEST)
|
||||
}
|
||||
|
||||
/// Clear any pending model switch request
|
||||
pub fn clear_model_switch_request() {
|
||||
if let Ok(guard) = MODEL_SWITCH_REQUEST.lock() {
|
||||
let mut guard = guard;
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn glob_match(pattern: &str, name: &str) -> bool {
|
||||
match pattern.find('*') {
|
||||
None => pattern == name,
|
||||
@@ -2141,31 +2118,6 @@ pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool {
|
||||
err.chain().any(|source| source.is::<ToolLoopCancelled>())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ModelSwitchRequested {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelSwitchRequested {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"model switch requested to {} {}",
|
||||
self.provider, self.model
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ModelSwitchRequested {}
|
||||
|
||||
pub(crate) fn is_model_switch_requested(err: &anyhow::Error) -> Option<(String, String)> {
|
||||
err.chain()
|
||||
.filter_map(|source| source.downcast_ref::<ModelSwitchRequested>())
|
||||
.map(|e| (e.provider.clone(), e.model.clone()))
|
||||
.next()
|
||||
}
|
||||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
@@ -2185,7 +2137,6 @@ pub(crate) async fn agent_turn(
|
||||
excluded_tools: &[String],
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
model_switch_callback: Option<ModelSwitchCallback>,
|
||||
) -> Result<String> {
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
@@ -2206,7 +2157,6 @@ pub(crate) async fn agent_turn(
|
||||
excluded_tools,
|
||||
dedup_exempt_tools,
|
||||
activated_tools,
|
||||
model_switch_callback,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2412,7 +2362,6 @@ pub(crate) async fn run_tool_call_loop(
|
||||
excluded_tools: &[String],
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
model_switch_callback: Option<ModelSwitchCallback>,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2431,28 +2380,6 @@ pub(crate) async fn run_tool_call_loop(
|
||||
return Err(ToolLoopCancelled.into());
|
||||
}
|
||||
|
||||
// Check if model switch was requested via model_switch tool
|
||||
if let Some(ref callback) = model_switch_callback {
|
||||
if let Ok(guard) = callback.lock() {
|
||||
if let Some((new_provider, new_model)) = guard.as_ref() {
|
||||
if new_provider != provider_name || new_model != model {
|
||||
tracing::info!(
|
||||
"Model switch detected: {} {} -> {} {}",
|
||||
provider_name,
|
||||
model,
|
||||
new_provider,
|
||||
new_model
|
||||
);
|
||||
return Err(ModelSwitchRequested {
|
||||
provider: new_provider.clone(),
|
||||
model: new_model.clone(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild tool_specs each iteration so newly activated deferred tools appear.
|
||||
let mut tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
|
||||
.iter()
|
||||
@@ -3272,32 +3199,28 @@ pub async fn run(
|
||||
}
|
||||
|
||||
// ── Resolve provider ─────────────────────────────────────────
|
||||
let mut provider_name = provider_override
|
||||
let provider_name = provider_override
|
||||
.as_deref()
|
||||
.or(config.default_provider.as_deref())
|
||||
.unwrap_or("openrouter")
|
||||
.to_string();
|
||||
.unwrap_or("openrouter");
|
||||
|
||||
let mut model_name = model_override
|
||||
let model_name = model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref())
|
||||
.unwrap_or("anthropic/claude-sonnet-4")
|
||||
.to_string();
|
||||
.unwrap_or("anthropic/claude-sonnet-4");
|
||||
|
||||
let provider_runtime_options = providers::provider_runtime_options_from_config(&config);
|
||||
|
||||
let mut provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
|
||||
&provider_name,
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider_with_options(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
model_name,
|
||||
&provider_runtime_options,
|
||||
)?;
|
||||
|
||||
let model_switch_callback = get_model_switch_state();
|
||||
|
||||
observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.to_string(),
|
||||
model: model_name.to_string(),
|
||||
@@ -3441,7 +3364,7 @@ pub async fn run(
|
||||
let native_tools = provider.supports_native_tools();
|
||||
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
|
||||
&config.workspace_dir,
|
||||
&model_name,
|
||||
model_name,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
Some(&config.identity),
|
||||
@@ -3524,72 +3447,27 @@ pub async fn run(
|
||||
let excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg);
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
let mut response = String::new();
|
||||
loop {
|
||||
match run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
response = resp;
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some((new_provider, new_model)) = is_model_switch_requested(&e) {
|
||||
tracing::info!(
|
||||
"Model switch requested, switching from {} {} to {} {}",
|
||||
provider_name,
|
||||
model_name,
|
||||
new_provider,
|
||||
new_model
|
||||
);
|
||||
|
||||
provider = providers::create_routed_provider_with_options(
|
||||
&new_provider,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&new_model,
|
||||
&provider_runtime_options,
|
||||
)?;
|
||||
|
||||
provider_name = new_provider;
|
||||
model_name = new_model;
|
||||
|
||||
clear_model_switch_request();
|
||||
|
||||
observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.to_string(),
|
||||
model: model_name.to_string(),
|
||||
});
|
||||
|
||||
continue;
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
println!("{response}");
|
||||
observer.record_event(&ObserverEvent::TurnComplete);
|
||||
@@ -3731,66 +3609,32 @@ pub async fn run(
|
||||
&user_input,
|
||||
);
|
||||
|
||||
let response = loop {
|
||||
match run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
Some(model_switch_callback.clone()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => break resp,
|
||||
Err(e) => {
|
||||
if let Some((new_provider, new_model)) = is_model_switch_requested(&e) {
|
||||
tracing::info!(
|
||||
"Model switch requested, switching from {} {} to {} {}",
|
||||
provider_name,
|
||||
model_name,
|
||||
new_provider,
|
||||
new_model
|
||||
);
|
||||
|
||||
provider = providers::create_routed_provider_with_options(
|
||||
&new_provider,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&new_model,
|
||||
&provider_runtime_options,
|
||||
)?;
|
||||
|
||||
provider_name = new_provider;
|
||||
model_name = new_model;
|
||||
|
||||
clear_model_switch_request();
|
||||
|
||||
observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.to_string(),
|
||||
model: model_name.to_string(),
|
||||
});
|
||||
|
||||
continue;
|
||||
}
|
||||
eprintln!("\nError: {e}\n");
|
||||
break String::new();
|
||||
}
|
||||
let response = match run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
eprintln!("\nError: {e}\n");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
final_output = response.clone();
|
||||
@@ -3808,7 +3652,7 @@ pub async fn run(
|
||||
if let Ok(compacted) = auto_compact_history(
|
||||
&mut history,
|
||||
provider.as_ref(),
|
||||
&model_name,
|
||||
model_name,
|
||||
config.agent.max_history_messages,
|
||||
config.agent.max_context_tokens,
|
||||
)
|
||||
@@ -4102,7 +3946,6 @@ pub async fn process_message(
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle_pm.as_ref(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -4560,7 +4403,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -4609,7 +4451,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -4652,7 +4493,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -4781,7 +4621,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -4853,7 +4692,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -4921,7 +4759,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("non-interactive shell should succeed for low-risk command");
|
||||
@@ -4980,7 +4817,6 @@ mod tests {
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
@@ -5059,7 +4895,6 @@ mod tests {
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
@@ -5115,7 +4950,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -5184,7 +5018,6 @@ mod tests {
|
||||
&[],
|
||||
&[],
|
||||
Some(&activated),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("wrapper path should execute activated tools");
|
||||
@@ -7076,7 +6909,6 @@ Let me check the result."#;
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
+17
-61
@@ -227,10 +227,6 @@ fn channel_message_timeout_budget_secs(
|
||||
struct ChannelRouteSelection {
|
||||
provider: String,
|
||||
model: String,
|
||||
/// Route-specific API key override. When set, this takes precedence over
|
||||
/// the global `api_key` in [`ChannelRuntimeContext`] when creating the
|
||||
/// provider for this route.
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -908,7 +904,6 @@ fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection
|
||||
ChannelRouteSelection {
|
||||
provider: defaults.default_provider,
|
||||
model: defaults.model,
|
||||
api_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1127,43 +1122,21 @@ fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<S
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Build a cache key that includes the provider name and, when a
|
||||
/// route-specific API key is supplied, a hash of that key. This prevents
|
||||
/// cache poisoning when multiple routes target the same provider with
|
||||
/// different credentials.
|
||||
fn provider_cache_key(provider_name: &str, route_api_key: Option<&str>) -> String {
|
||||
match route_api_key {
|
||||
Some(key) => {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
format!("{provider_name}@{:x}", hasher.finish())
|
||||
}
|
||||
None => provider_name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create_provider(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
provider_name: &str,
|
||||
route_api_key: Option<&str>,
|
||||
) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
let cache_key = provider_cache_key(provider_name, route_api_key);
|
||||
|
||||
if let Some(existing) = ctx
|
||||
.provider_cache
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.get(&cache_key)
|
||||
.get(provider_name)
|
||||
.cloned()
|
||||
{
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
// Only return the pre-built default provider when there is no
|
||||
// route-specific credential override — otherwise the default was
|
||||
// created with the global key and would be wrong.
|
||||
if route_api_key.is_none() && provider_name == ctx.default_provider.as_str() {
|
||||
if provider_name == ctx.default_provider.as_str() {
|
||||
return Ok(Arc::clone(&ctx.provider));
|
||||
}
|
||||
|
||||
@@ -1174,14 +1147,9 @@ async fn get_or_create_provider(
|
||||
None
|
||||
};
|
||||
|
||||
// Prefer route-specific credential; fall back to the global key.
|
||||
let effective_api_key = route_api_key
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| ctx.api_key.clone());
|
||||
|
||||
let provider = create_resilient_provider_nonblocking(
|
||||
provider_name,
|
||||
effective_api_key,
|
||||
ctx.api_key.clone(),
|
||||
api_url.map(ToString::to_string),
|
||||
ctx.reliability.as_ref().clone(),
|
||||
ctx.provider_runtime_options.clone(),
|
||||
@@ -1195,7 +1163,7 @@ async fn get_or_create_provider(
|
||||
|
||||
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let cached = cache
|
||||
.entry(cache_key)
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert_with(|| Arc::clone(&provider));
|
||||
Ok(Arc::clone(cached))
|
||||
}
|
||||
@@ -1311,27 +1279,25 @@ async fn handle_runtime_command_if_needed(
|
||||
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t),
|
||||
ChannelRuntimeCommand::SetProvider(raw_provider) => {
|
||||
match resolve_provider_alias(&raw_provider) {
|
||||
Some(provider_name) => {
|
||||
match get_or_create_provider(ctx, &provider_name, None).await {
|
||||
Ok(_) => {
|
||||
if provider_name != current.provider {
|
||||
current.provider = provider_name.clone();
|
||||
set_route_selection(ctx, &sender_key, current.clone());
|
||||
}
|
||||
Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
|
||||
Ok(_) => {
|
||||
if provider_name != current.provider {
|
||||
current.provider = provider_name.clone();
|
||||
set_route_selection(ctx, &sender_key, current.clone());
|
||||
}
|
||||
|
||||
format!(
|
||||
format!(
|
||||
"Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
|
||||
current.model
|
||||
)
|
||||
}
|
||||
Err(err) => {
|
||||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||||
format!(
|
||||
}
|
||||
Err(err) => {
|
||||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||||
format!(
|
||||
"Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
None => format!(
|
||||
"Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
|
||||
),
|
||||
@@ -1351,7 +1317,6 @@ async fn handle_runtime_command_if_needed(
|
||||
}) {
|
||||
current.provider = route.provider.clone();
|
||||
current.model = route.model.clone();
|
||||
current.api_key = route.api_key.clone();
|
||||
} else {
|
||||
current.model = model.clone();
|
||||
}
|
||||
@@ -1957,19 +1922,12 @@ async fn process_channel_message(
|
||||
route = ChannelRouteSelection {
|
||||
provider: matched_route.provider.clone(),
|
||||
model: matched_route.model.clone(),
|
||||
api_key: matched_route.api_key.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(
|
||||
ctx.as_ref(),
|
||||
&route.provider,
|
||||
route.api_key.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
Err(err) => {
|
||||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||||
@@ -2251,7 +2209,6 @@ async fn process_channel_message(
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
None,
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -5645,7 +5602,6 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
ChannelRouteSelection {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "route-model".to_string(),
|
||||
api_key: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
+1
-4
@@ -17,10 +17,7 @@ pub use store::{
|
||||
add_agent_job, due_jobs, get_job, list_jobs, list_runs, record_last_run, record_run,
|
||||
remove_job, reschedule_after_run, update_job,
|
||||
};
|
||||
pub use types::{
|
||||
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
|
||||
Schedule, SessionTarget,
|
||||
};
|
||||
pub use types::{CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType, Schedule, SessionTarget};
|
||||
|
||||
/// Validate a shell command against the full security policy (allowlist + risk gate).
|
||||
///
|
||||
|
||||
+1
-66
@@ -1,32 +1,6 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Try to deserialize a `serde_json::Value` as `T`. If the value is a JSON
|
||||
/// string that looks like an object (i.e. the LLM double-serialized it), parse
|
||||
/// the inner string first and then deserialize the resulting object. This
|
||||
/// provides backward-compatible handling for both `Value::Object` and
|
||||
/// `Value::String` representations.
|
||||
pub fn deserialize_maybe_stringified<T: serde::de::DeserializeOwned>(
|
||||
v: &serde_json::Value,
|
||||
) -> Result<T, serde_json::Error> {
|
||||
// Fast path: value is already the right shape (object, array, etc.)
|
||||
match serde_json::from_value::<T>(v.clone()) {
|
||||
Ok(parsed) => Ok(parsed),
|
||||
Err(first_err) => {
|
||||
// If it's a string, try parsing the string as JSON first.
|
||||
if let Some(s) = v.as_str() {
|
||||
let s = s.trim();
|
||||
if s.starts_with('{') || s.starts_with('[') {
|
||||
if let Ok(inner) = serde_json::from_str::<serde_json::Value>(s) {
|
||||
return serde_json::from_value::<T>(inner);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(first_err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum JobType {
|
||||
@@ -180,46 +154,7 @@ pub struct CronJobPatch {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialize_schedule_from_object() {
|
||||
let val = serde_json::json!({"kind": "cron", "expr": "*/5 * * * *"});
|
||||
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
|
||||
assert!(matches!(sched, Schedule::Cron { ref expr, .. } if expr == "*/5 * * * *"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_schedule_from_string() {
|
||||
let val = serde_json::Value::String(r#"{"kind":"cron","expr":"*/5 * * * *"}"#.to_string());
|
||||
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
|
||||
assert!(matches!(sched, Schedule::Cron { ref expr, .. } if expr == "*/5 * * * *"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_schedule_string_with_tz() {
|
||||
let val = serde_json::Value::String(
|
||||
r#"{"kind":"cron","expr":"*/30 9-15 * * 1-5","tz":"Asia/Shanghai"}"#.to_string(),
|
||||
);
|
||||
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
|
||||
match sched {
|
||||
Schedule::Cron { tz, .. } => assert_eq!(tz.as_deref(), Some("Asia/Shanghai")),
|
||||
_ => panic!("expected Cron variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_every_from_string() {
|
||||
let val = serde_json::Value::String(r#"{"kind":"every","every_ms":60000}"#.to_string());
|
||||
let sched = deserialize_maybe_stringified::<Schedule>(&val).unwrap();
|
||||
assert!(matches!(sched, Schedule::Every { every_ms: 60000 }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_invalid_string_returns_error() {
|
||||
let val = serde_json::Value::String("not json at all".to_string());
|
||||
assert!(deserialize_maybe_stringified::<Schedule>(&val).is_err());
|
||||
}
|
||||
use super::JobType;
|
||||
|
||||
#[test]
|
||||
fn job_type_try_from_accepts_known_values_case_insensitive() {
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
@@ -31,7 +34,7 @@
|
||||
//!
|
||||
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
|
||||
|
||||
use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -209,54 +212,6 @@ impl Provider for ClaudeCodeProvider {
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
// Separate system prompt from conversation messages.
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
|
||||
// Build conversation turns (skip system messages).
|
||||
let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
|
||||
|
||||
// If there's only one user message, use the simple path.
|
||||
if turns.len() <= 1 {
|
||||
let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
|
||||
let full_message = match system {
|
||||
Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
|
||||
_ => last_user.to_string(),
|
||||
};
|
||||
return self.invoke_cli(&full_message, model).await;
|
||||
}
|
||||
|
||||
// Format multi-turn conversation into a single prompt.
|
||||
let mut parts = Vec::new();
|
||||
if let Some(s) = system {
|
||||
if !s.is_empty() {
|
||||
parts.push(format!("[system]\n{s}"));
|
||||
}
|
||||
}
|
||||
for msg in &turns {
|
||||
let label = match msg.role.as_str() {
|
||||
"user" => "[user]",
|
||||
"assistant" => "[assistant]",
|
||||
other => other,
|
||||
};
|
||||
parts.push(format!("{label}\n{}", msg.content));
|
||||
}
|
||||
parts.push("[assistant]".to_string());
|
||||
|
||||
let full_message = parts.join("\n\n");
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
@@ -372,105 +327,4 @@ mod tests {
|
||||
"unexpected error message: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper: create a provider that uses a shell script echoing stdin back.
|
||||
/// The script ignores CLI flags (`--print`, `--model`, `-`) and just cats stdin.
|
||||
///
|
||||
/// Uses `OnceLock` to write the script file exactly once, avoiding
|
||||
/// "Text file busy" (ETXTBSY) races when parallel tests try to
|
||||
/// overwrite a script that another test is currently executing.
|
||||
fn echo_provider() -> ClaudeCodeProvider {
|
||||
use std::sync::OnceLock;
|
||||
|
||||
static SCRIPT_PATH: OnceLock<PathBuf> = OnceLock::new();
|
||||
let script = SCRIPT_PATH.get_or_init(|| {
|
||||
use std::io::Write;
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_claude_code");
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
let path = dir.join("fake_claude.sh");
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
|
||||
drop(f);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
|
||||
}
|
||||
path
|
||||
});
|
||||
ClaudeCodeProvider {
|
||||
binary_path: script.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_single_user_message() {
|
||||
let provider = echo_provider();
|
||||
let messages = vec![ChatMessage::user("hello")];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "default", 1.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_single_user_with_system() {
|
||||
let provider = echo_provider();
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful."),
|
||||
ChatMessage::user("hello"),
|
||||
];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "default", 1.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "You are helpful.\n\nhello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_multi_turn_includes_all_messages() {
|
||||
let provider = echo_provider();
|
||||
let messages = vec![
|
||||
ChatMessage::system("Be concise."),
|
||||
ChatMessage::user("What is 2+2?"),
|
||||
ChatMessage::assistant("4"),
|
||||
ChatMessage::user("And 3+3?"),
|
||||
];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "default", 1.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.contains("[system]\nBe concise."));
|
||||
assert!(result.contains("[user]\nWhat is 2+2?"));
|
||||
assert!(result.contains("[assistant]\n4"));
|
||||
assert!(result.contains("[user]\nAnd 3+3?"));
|
||||
assert!(result.ends_with("[assistant]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_multi_turn_without_system() {
|
||||
let provider = echo_provider();
|
||||
let messages = vec![
|
||||
ChatMessage::user("hi"),
|
||||
ChatMessage::assistant("hello"),
|
||||
ChatMessage::user("bye"),
|
||||
];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "default", 1.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.contains("[system]"));
|
||||
assert!(result.contains("[user]\nhi"));
|
||||
assert!(result.contains("[assistant]\nhello"));
|
||||
assert!(result.contains("[user]\nbye"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_rejects_bad_temperature() {
|
||||
let provider = echo_provider();
|
||||
let messages = vec![ChatMessage::user("test")];
|
||||
let result = provider.chat_with_history(&messages, "default", 0.5).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1320,12 +1320,11 @@ fn create_provider_with_url_and_options(
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("llama.cpp");
|
||||
Ok(compat(OpenAiCompatibleProvider::new_with_vision(
|
||||
Ok(compat(OpenAiCompatibleProvider::new(
|
||||
"llama.cpp",
|
||||
base_url,
|
||||
Some(llama_cpp_key),
|
||||
AuthStyle::Bearer,
|
||||
true,
|
||||
)))
|
||||
}
|
||||
"sglang" => {
|
||||
|
||||
+41
-250
@@ -16,10 +16,8 @@ use std::time::Duration;
|
||||
|
||||
/// Check if an error is non-retryable (client errors that won't resolve with retries).
|
||||
pub fn is_non_retryable(err: &anyhow::Error) -> bool {
|
||||
// Context window errors are NOT non-retryable — they can be recovered
|
||||
// by truncating conversation history, so let the retry loop handle them.
|
||||
if is_context_window_exceeded(err) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 4xx errors are generally non-retryable (bad request, auth failure, etc.),
|
||||
@@ -77,7 +75,6 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
|
||||
let lower = err.to_string().to_lowercase();
|
||||
let hints = [
|
||||
"exceeds the context window",
|
||||
"exceeds the available context size",
|
||||
"context window of this model",
|
||||
"maximum context length",
|
||||
"context length exceeded",
|
||||
@@ -200,35 +197,6 @@ fn compact_error_detail(err: &anyhow::Error) -> String {
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
/// Truncate conversation history by dropping the oldest non-system messages.
|
||||
/// Returns the number of messages dropped. Keeps at least the system message
|
||||
/// (if any) and the most recent user message.
|
||||
fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
|
||||
// Find all non-system message indices
|
||||
let non_system: Vec<usize> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, m)| m.role != "system")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Keep at least the last non-system message (most recent user turn)
|
||||
if non_system.len() <= 1 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Drop the oldest half of non-system messages
|
||||
let drop_count = non_system.len() / 2;
|
||||
let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
|
||||
|
||||
// Remove in reverse order to preserve indices
|
||||
for &idx in indices_to_remove.iter().rev() {
|
||||
messages.remove(idx);
|
||||
}
|
||||
|
||||
drop_count
|
||||
}
|
||||
|
||||
fn push_failure(
|
||||
failures: &mut Vec<String>,
|
||||
provider_name: &str,
|
||||
@@ -370,25 +338,6 @@ impl Provider for ReliableProvider {
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
// Context window exceeded: no history to truncate
|
||||
// in chat_with_system, bail immediately.
|
||||
if is_context_window_exceeded(&e) {
|
||||
let error_detail = compact_error_detail(&e);
|
||||
push_failure(
|
||||
&mut failures,
|
||||
provider_name,
|
||||
current_model,
|
||||
attempt + 1,
|
||||
self.max_retries + 1,
|
||||
"non_retryable",
|
||||
&error_detail,
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
@@ -427,6 +376,14 @@ impl Provider for ReliableProvider {
|
||||
error = %error_detail,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
|
||||
if is_context_window_exceeded(&e) {
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -478,8 +435,6 @@ impl Provider for ReliableProvider {
|
||||
) -> anyhow::Result<String> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
let mut effective_messages = messages.to_vec();
|
||||
let mut context_truncated = false;
|
||||
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
@@ -487,39 +442,22 @@ impl Provider for ReliableProvider {
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
match provider
|
||||
.chat_with_history(&effective_messages, current_model, temperature)
|
||||
.chat_with_history(messages, current_model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model || context_truncated {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
context_truncated,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
// Context window exceeded: truncate history and retry
|
||||
if is_context_window_exceeded(&e) && !context_truncated {
|
||||
let dropped = truncate_for_context(&mut effective_messages);
|
||||
if dropped > 0 {
|
||||
context_truncated = true;
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
dropped,
|
||||
remaining = effective_messages.len(),
|
||||
"Context window exceeded; truncated history and retrying"
|
||||
);
|
||||
continue; // Retry with truncated messages (counts as an attempt)
|
||||
}
|
||||
}
|
||||
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
@@ -556,6 +494,14 @@ impl Provider for ReliableProvider {
|
||||
error = %error_detail,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
|
||||
if is_context_window_exceeded(&e) {
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -613,8 +559,6 @@ impl Provider for ReliableProvider {
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
let mut effective_messages = messages.to_vec();
|
||||
let mut context_truncated = false;
|
||||
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
@@ -622,39 +566,22 @@ impl Provider for ReliableProvider {
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
match provider
|
||||
.chat_with_tools(&effective_messages, tools, current_model, temperature)
|
||||
.chat_with_tools(messages, tools, current_model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model || context_truncated {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
context_truncated,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
// Context window exceeded: truncate history and retry
|
||||
if is_context_window_exceeded(&e) && !context_truncated {
|
||||
let dropped = truncate_for_context(&mut effective_messages);
|
||||
if dropped > 0 {
|
||||
context_truncated = true;
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
dropped,
|
||||
remaining = effective_messages.len(),
|
||||
"Context window exceeded; truncated history and retrying"
|
||||
);
|
||||
continue; // Retry with truncated messages (counts as an attempt)
|
||||
}
|
||||
}
|
||||
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
@@ -691,6 +618,14 @@ impl Provider for ReliableProvider {
|
||||
error = %error_detail,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
|
||||
if is_context_window_exceeded(&e) {
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -734,8 +669,6 @@ impl Provider for ReliableProvider {
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
let mut effective_messages = request.messages.to_vec();
|
||||
let mut context_truncated = false;
|
||||
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
@@ -743,40 +676,23 @@ impl Provider for ReliableProvider {
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
let req = ChatRequest {
|
||||
messages: &effective_messages,
|
||||
messages: request.messages,
|
||||
tools: request.tools,
|
||||
};
|
||||
match provider.chat(req, current_model, temperature).await {
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model || context_truncated {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
context_truncated,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
// Context window exceeded: truncate history and retry
|
||||
if is_context_window_exceeded(&e) && !context_truncated {
|
||||
let dropped = truncate_for_context(&mut effective_messages);
|
||||
if dropped > 0 {
|
||||
context_truncated = true;
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
dropped,
|
||||
remaining = effective_messages.len(),
|
||||
"Context window exceeded; truncated history and retrying"
|
||||
);
|
||||
continue; // Retry with truncated messages (counts as an attempt)
|
||||
}
|
||||
}
|
||||
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
@@ -813,6 +729,14 @@ impl Provider for ReliableProvider {
|
||||
error = %error_detail,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
|
||||
if is_context_window_exceeded(&e) {
|
||||
anyhow::bail!(
|
||||
"Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1147,8 +1071,7 @@ mod tests {
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
"model overloaded, try again later"
|
||||
)));
|
||||
// Context window errors are now recoverable (not non-retryable)
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
assert!(is_non_retryable(&anyhow::anyhow!(
|
||||
"OpenAI Codex stream error: Your input exceeds the context window of this model."
|
||||
)));
|
||||
}
|
||||
@@ -1184,7 +1107,7 @@ mod tests {
|
||||
let msg = err.to_string();
|
||||
|
||||
assert!(msg.contains("context window"));
|
||||
// chat_with_system has no history to truncate, so it bails immediately
|
||||
assert!(msg.contains("skipped"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
@@ -2057,136 +1980,4 @@ mod tests {
|
||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ── Context window truncation tests ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn context_window_error_is_not_non_retryable() {
|
||||
// Context window errors should be recoverable via truncation
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
"exceeds the context window"
|
||||
)));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
"maximum context length exceeded"
|
||||
)));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
"too many tokens in the request"
|
||||
)));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("token limit exceeded")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_context_window_exceeded_detects_llamacpp() {
|
||||
assert!(is_context_window_exceeded(&anyhow::anyhow!(
|
||||
"request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_context_drops_oldest_non_system() {
|
||||
let mut messages = vec![
|
||||
ChatMessage::system("sys"),
|
||||
ChatMessage::user("msg1"),
|
||||
ChatMessage::assistant("resp1"),
|
||||
ChatMessage::user("msg2"),
|
||||
ChatMessage::assistant("resp2"),
|
||||
ChatMessage::user("msg3"),
|
||||
];
|
||||
|
||||
let dropped = truncate_for_context(&mut messages);
|
||||
|
||||
// 5 non-system messages, drop oldest half = 2
|
||||
assert_eq!(dropped, 2);
|
||||
// System message preserved
|
||||
assert_eq!(messages[0].role, "system");
|
||||
// Remaining messages should be the newer ones
|
||||
assert_eq!(messages.len(), 4); // system + 3 remaining non-system
|
||||
// The last message should still be the most recent user message
|
||||
assert_eq!(messages.last().unwrap().content, "msg3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_context_preserves_system_and_last_message() {
|
||||
// Only one non-system message: nothing to drop
|
||||
let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
|
||||
let dropped = truncate_for_context(&mut messages);
|
||||
assert_eq!(dropped, 0);
|
||||
assert_eq!(messages.len(), 2);
|
||||
|
||||
// No system message, only one user message
|
||||
let mut messages = vec![ChatMessage::user("only")];
|
||||
let dropped = truncate_for_context(&mut messages);
|
||||
assert_eq!(dropped, 0);
|
||||
assert_eq!(messages.len(), 1);
|
||||
}
|
||||
|
||||
/// Mock that fails with context error on first N calls, then succeeds.
|
||||
/// Tracks the number of messages received on each call.
|
||||
struct ContextOverflowMock {
|
||||
calls: Arc<AtomicUsize>,
|
||||
fail_until_attempt: usize,
|
||||
message_counts: parking_lot::Mutex<Vec<usize>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ContextOverflowMock {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
self.message_counts.lock().push(messages.len());
|
||||
if attempt <= self.fail_until_attempt {
|
||||
anyhow::bail!(
|
||||
"request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
|
||||
);
|
||||
}
|
||||
Ok("recovered after truncation".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_truncates_on_context_overflow() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = ContextOverflowMock {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 1, // fail first call, succeed after truncation
|
||||
message_counts: parking_lot::Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
|
||||
3,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("system prompt"),
|
||||
ChatMessage::user("old message 1"),
|
||||
ChatMessage::assistant("old response 1"),
|
||||
ChatMessage::user("old message 2"),
|
||||
ChatMessage::assistant("old response 2"),
|
||||
ChatMessage::user("current question"),
|
||||
];
|
||||
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "local-model", 0.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "recovered after truncation");
|
||||
// Should have been called twice: once with full messages, once with truncated
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
}
|
||||
|
||||
+2
-61
@@ -1,8 +1,6 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::Config;
|
||||
use crate::cron::{
|
||||
self, deserialize_maybe_stringified, DeliveryConfig, JobType, Schedule, SessionTarget,
|
||||
};
|
||||
use crate::cron::{self, DeliveryConfig, JobType, Schedule, SessionTarget};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
@@ -178,7 +176,7 @@ impl Tool for CronAddTool {
|
||||
}
|
||||
|
||||
let schedule = match args.get("schedule") {
|
||||
Some(v) => match deserialize_maybe_stringified::<Schedule>(v) {
|
||||
Some(v) => match serde_json::from_value::<Schedule>(v.clone()) {
|
||||
Ok(schedule) => schedule,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
@@ -513,63 +511,6 @@ mod tests {
|
||||
assert!(approved.success, "{:?}", approved.error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn accepts_schedule_passed_as_json_string() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
// Simulate the LLM double-serializing the schedule: the value arrives
|
||||
// as a JSON string containing a JSON object, rather than an object.
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": r#"{"kind":"cron","expr":"*/5 * * * *"}"#,
|
||||
"job_type": "shell",
|
||||
"command": "echo string-schedule"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
assert!(result.output.contains("next_run"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn accepts_stringified_interval_schedule() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": r#"{"kind":"every","every_ms":60000}"#,
|
||||
"job_type": "shell",
|
||||
"command": "echo interval"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn accepts_stringified_schedule_with_timezone() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = test_config(&tmp).await;
|
||||
let tool = CronAddTool::new(cfg.clone(), test_security(&cfg));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"schedule": r#"{"kind":"cron","expr":"*/30 9-15 * * 1-5","tz":"Asia/Shanghai"}"#,
|
||||
"job_type": "shell",
|
||||
"command": "echo tz-test"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_invalid_schedule() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::Config;
|
||||
use crate::cron::{self, deserialize_maybe_stringified, CronJobPatch};
|
||||
use crate::cron::{self, CronJobPatch};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
@@ -202,7 +202,7 @@ impl Tool for CronUpdateTool {
|
||||
}
|
||||
};
|
||||
|
||||
let patch = match deserialize_maybe_stringified::<CronJobPatch>(&patch_val) {
|
||||
let patch = match serde_json::from_value::<CronJobPatch>(patch_val) {
|
||||
Ok(patch) => patch,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
|
||||
@@ -422,7 +422,6 @@ impl DelegateTool {
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -233,22 +233,12 @@ impl Default for ActivatedToolSet {
|
||||
|
||||
/// Build the `<available-deferred-tools>` section for the system prompt.
|
||||
/// Lists only tool names so the LLM knows what is available without
|
||||
/// consuming context window on full schemas. Includes an instruction
|
||||
/// block that tells the LLM to call `tool_search` to activate them.
|
||||
/// consuming context window on full schemas.
|
||||
pub fn build_deferred_tools_section(deferred: &DeferredMcpToolSet) -> String {
|
||||
if deferred.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
let mut out = String::new();
|
||||
out.push_str("## Deferred Tools\n\n");
|
||||
out.push_str(
|
||||
"The tools listed below are available but NOT yet loaded. \
|
||||
To use any of them you MUST first call the `tool_search` tool \
|
||||
to fetch their full schemas. Use `\"select:name1,name2\"` for \
|
||||
exact tools or keywords to search. Once activated, the tools \
|
||||
become callable for the rest of the conversation.\n\n",
|
||||
);
|
||||
out.push_str("<available-deferred-tools>\n");
|
||||
let mut out = String::from("<available-deferred-tools>\n");
|
||||
for stub in &deferred.stubs {
|
||||
out.push_str(&stub.prefixed_name);
|
||||
out.push('\n');
|
||||
@@ -426,55 +416,6 @@ mod tests {
|
||||
assert!(section.contains("</available-deferred-tools>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_deferred_section_includes_tool_search_instruction() {
|
||||
let stubs = vec![make_stub("fs__read_file", "Read a file")];
|
||||
let set = DeferredMcpToolSet {
|
||||
stubs,
|
||||
registry: std::sync::Arc::new(
|
||||
tokio::runtime::Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(McpRegistry::connect_all(&[]))
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
let section = build_deferred_tools_section(&set);
|
||||
assert!(
|
||||
section.contains("tool_search"),
|
||||
"deferred section must instruct the LLM to use tool_search"
|
||||
);
|
||||
assert!(
|
||||
section.contains("## Deferred Tools"),
|
||||
"deferred section must include a heading"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_deferred_section_multiple_servers() {
|
||||
let stubs = vec![
|
||||
make_stub("server_a__list", "List items"),
|
||||
make_stub("server_a__create", "Create item"),
|
||||
make_stub("server_b__query", "Query records"),
|
||||
];
|
||||
let set = DeferredMcpToolSet {
|
||||
stubs,
|
||||
registry: std::sync::Arc::new(
|
||||
tokio::runtime::Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(McpRegistry::connect_all(&[]))
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
let section = build_deferred_tools_section(&set);
|
||||
assert!(section.contains("server_a__list"));
|
||||
assert!(section.contains("server_a__create"));
|
||||
assert!(section.contains("server_b__query"));
|
||||
assert!(
|
||||
section.contains("tool_search"),
|
||||
"section must mention tool_search for multi-server setups"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyword_search_ranks_by_hits() {
|
||||
let stubs = vec![
|
||||
@@ -516,35 +457,4 @@ mod tests {
|
||||
assert!(set.get_by_name("a__one").is_some());
|
||||
assert!(set.get_by_name("nonexistent").is_none());
|
||||
}
|
||||
|
||||
/// Keyword search must reach stubs from every server and rank results
/// by the number of keyword hits.
#[test]
fn search_across_multiple_servers() {
    let registry = tokio::runtime::Runtime::new()
        .unwrap()
        .block_on(McpRegistry::connect_all(&[]))
        .unwrap();
    let set = DeferredMcpToolSet {
        stubs: vec![
            make_stub("server_a__read_file", "Read a file from disk"),
            make_stub("server_b__read_config", "Read configuration from database"),
        ],
        registry: std::sync::Arc::new(registry),
    };

    // Both servers share the word "read".
    let both = set.search("read", 10);
    assert_eq!(both.len(), 2);

    // Only server_a mentions "file".
    let only_a = set.search("file", 10);
    assert_eq!(only_a.len(), 1);
    assert_eq!(only_a[0].prefixed_name, "server_a__read_file");

    // Two keyword hits ("config" + "database") should put server_b first.
    let ranked = set.search("config database", 10);
    assert!(!ranked.is_empty());
    assert_eq!(ranked[0].prefixed_name, "server_b__read_config");
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,6 @@ pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod microsoft365;
|
||||
pub mod model_routing_config;
|
||||
pub mod model_switch;
|
||||
pub mod node_tool;
|
||||
pub mod notion_tool;
|
||||
pub mod pdf_read;
|
||||
@@ -120,7 +119,6 @@ pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use microsoft365::Microsoft365Tool;
|
||||
pub use model_routing_config::ModelRoutingConfigTool;
|
||||
pub use model_switch::ModelSwitchTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use node_tool::NodeTool;
|
||||
pub use notion_tool::NotionTool;
|
||||
@@ -304,7 +302,6 @@ pub fn all_tools_with_runtime(
|
||||
config.clone(),
|
||||
security.clone(),
|
||||
)),
|
||||
Arc::new(ModelSwitchTool::new(security.clone())),
|
||||
Arc::new(ProxyConfigTool::new(config.clone(), security.clone())),
|
||||
Arc::new(GitOperationsTool::new(
|
||||
security.clone(),
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::agent::loop_::get_model_switch_state;
|
||||
use crate::providers;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Tool that lets the agent inspect and request runtime model switches
/// (actions: get / set / list_providers / list_models; see `execute`).
pub struct ModelSwitchTool {
    // Security policy gate checked before any action is executed.
    security: Arc<SecurityPolicy>,
}

impl ModelSwitchTool {
    /// Build the tool around the security policy it will enforce.
    pub fn new(security: Arc<SecurityPolicy>) -> Self {
        Self { security }
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ModelSwitchTool {
|
||||
fn name(&self) -> &str {
|
||||
"model_switch"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Switch the AI model at runtime. Use 'get' to see current model, 'list_providers' to see available providers, 'list_models' to see models for a provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["get", "set", "list_providers", "list_models"],
|
||||
"description": "Action to perform: get current model, set a new model, list available providers, or list models for a provider"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Provider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "model_switch")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match action {
|
||||
"get" => self.handle_get(),
|
||||
"set" => self.handle_set(&args),
|
||||
"list_providers" => self.handle_list_providers(),
|
||||
"list_models" => self.handle_list_models(&args),
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: {}. Valid actions: get, set, list_providers, list_models",
|
||||
action
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Private action handlers dispatched from `Tool::execute`.
impl ModelSwitchTool {
    /// Report any pending model switch recorded in the global switch state.
    ///
    /// Note: this returns the *pending* request, not the model currently in
    /// use — the agent loop applies the switch on its next turn.
    fn handle_get(&self) -> anyhow::Result<ToolResult> {
        let switch_state = get_model_switch_state();
        // Clone while holding the lock so the guard is dropped immediately.
        let pending = switch_state.lock().unwrap().clone();

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&json!({
                "pending_switch": pending,
                "note": "To switch models, use action 'set' with provider and model parameters"
            }))?,
            error: None,
        })
    }

    /// Validate the provider/model arguments and record a switch request
    /// in the global switch state; the actual switch happens later.
    fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let provider = args.get("provider").and_then(|v| v.as_str());

        let provider = match provider {
            Some(p) => p,
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("Missing 'provider' parameter for 'set' action".to_string()),
                });
            }
        };

        let model = args.get("model").and_then(|v| v.as_str());

        let model = match model {
            Some(m) => m,
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("Missing 'model' parameter for 'set' action".to_string()),
                });
            }
        };

        // Validate the provider exists
        // (case-insensitive match against canonical names and aliases).
        let known_providers = providers::list_providers();
        let provider_valid = known_providers.iter().any(|p| {
            p.name.eq_ignore_ascii_case(provider)
                || p.aliases.iter().any(|a| a.eq_ignore_ascii_case(provider))
        });

        if !provider_valid {
            // Include the valid provider names in the output to help recovery.
            return Ok(ToolResult {
                success: false,
                output: serde_json::to_string_pretty(&json!({
                    "available_providers": known_providers.iter().map(|p| p.name).collect::<Vec<_>>()
                }))?,
                error: Some(format!(
                    "Unknown provider: {}. Use 'list_providers' to see available options.",
                    provider
                )),
            });
        }

        // Set the global model switch request
        // (overwrites any previously pending request).
        let switch_state = get_model_switch_state();
        *switch_state.lock().unwrap() = Some((provider.to_string(), model.to_string()));

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&json!({
                "message": "Model switch requested",
                "provider": provider,
                "model": model,
                "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
            }))?,
            error: None,
        })
    }

    /// List every provider known to the `providers` registry, with display
    /// names, aliases, and whether the provider runs locally.
    fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
        let providers_list = providers::list_providers();

        let providers: Vec<serde_json::Value> = providers_list
            .iter()
            .map(|p| {
                json!({
                    "name": p.name,
                    "display_name": p.display_name,
                    "aliases": p.aliases,
                    "local": p.local
                })
            })
            .collect();

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&json!({
                "providers": providers,
                "count": providers.len(),
                "example": "Use action 'set' with provider and model to switch"
            }))?,
            error: None,
        })
    }

    /// Return a hard-coded convenience list of common model IDs for a known
    /// provider. The list is static — it is NOT fetched from the provider,
    /// so it may lag behind what the provider actually serves.
    fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
        let provider = args.get("provider").and_then(|v| v.as_str());

        let provider = match provider {
            Some(p) => p,
            None => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(
                        "Missing 'provider' parameter for 'list_models' action".to_string(),
                    ),
                });
            }
        };

        // Return common models for known providers
        let models = match provider.to_lowercase().as_str() {
            "openai" => vec![
                "gpt-4o",
                "gpt-4o-mini",
                "gpt-4-turbo",
                "gpt-4",
                "gpt-3.5-turbo",
            ],
            "anthropic" => vec![
                "claude-sonnet-4-6",
                "claude-sonnet-4-5",
                "claude-3-5-sonnet",
                "claude-3-opus",
                "claude-3-haiku",
            ],
            "openrouter" => vec![
                "anthropic/claude-sonnet-4-6",
                "openai/gpt-4o",
                "google/gemini-pro",
                "meta-llama/llama-3-70b-instruct",
            ],
            "groq" => vec![
                "llama-3.3-70b-versatile",
                "mixtral-8x7b-32768",
                "llama-3.1-70b-speculative",
            ],
            "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
            "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
            "mistral" => vec![
                "mistral-large-latest",
                "mistral-small-latest",
                "mistral-nemo",
            ],
            "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
            "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
            // Unknown-but-valid providers fall through to an empty list below.
            _ => vec![],
        };

        if models.is_empty() {
            // Still a success: the provider may be valid, we just have no
            // curated list for it.
            return Ok(ToolResult {
                success: true,
                output: serde_json::to_string_pretty(&json!({
                    "provider": provider,
                    "models": [],
                    "note": "No common models listed for this provider. Check provider documentation for available models."
                }))?,
                error: None,
            });
        }

        Ok(ToolResult {
            success: true,
            output: serde_json::to_string_pretty(&json!({
                "provider": provider,
                "models": models,
                "example": "Use action 'set' with this provider and a model ID to switch"
            }))?,
            error: None,
        })
    }
}
|
||||
@@ -281,88 +281,4 @@ mod tests {
|
||||
// Tool should now be activated
|
||||
assert!(activated.lock().unwrap().is_activated("fs__read"));
|
||||
}
|
||||
|
||||
/// Verify tool_search works with stubs from multiple MCP servers,
|
||||
/// simulating a daemon-mode setup where several servers are deferred.
|
||||
#[tokio::test]
|
||||
async fn multiple_servers_stubs_all_searchable() {
|
||||
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
|
||||
let stubs = vec![
|
||||
make_stub("server_a__list_files", "List files on server A"),
|
||||
make_stub("server_a__read_file", "Read file on server A"),
|
||||
make_stub("server_b__query_db", "Query database on server B"),
|
||||
make_stub("server_b__insert_row", "Insert row on server B"),
|
||||
];
|
||||
let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));
|
||||
|
||||
// Search should find tools across both servers
|
||||
let result = tool
|
||||
.execute(serde_json::json!({"query": "file"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("server_a__list_files"));
|
||||
assert!(result.output.contains("server_a__read_file"));
|
||||
|
||||
// Server B tools should also be searchable
|
||||
let result = tool
|
||||
.execute(serde_json::json!({"query": "database query"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("server_b__query_db"));
|
||||
}
|
||||
|
||||
/// Verify select mode activates tools and they stay activated across calls,
|
||||
/// matching the daemon-mode pattern where a single ActivatedToolSet persists.
|
||||
#[tokio::test]
|
||||
async fn select_activates_and_persists_across_calls() {
|
||||
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
|
||||
let stubs = vec![
|
||||
make_stub("srv__tool_a", "Tool A"),
|
||||
make_stub("srv__tool_b", "Tool B"),
|
||||
];
|
||||
let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));
|
||||
|
||||
// Activate tool_a
|
||||
let result = tool
|
||||
.execute(serde_json::json!({"query": "select:srv__tool_a"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(activated.lock().unwrap().is_activated("srv__tool_a"));
|
||||
assert!(!activated.lock().unwrap().is_activated("srv__tool_b"));
|
||||
|
||||
// Activate tool_b in a separate call
|
||||
let result = tool
|
||||
.execute(serde_json::json!({"query": "select:srv__tool_b"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
// Both should remain activated
|
||||
let guard = activated.lock().unwrap();
|
||||
assert!(guard.is_activated("srv__tool_a"));
|
||||
assert!(guard.is_activated("srv__tool_b"));
|
||||
assert_eq!(guard.tool_specs().len(), 2);
|
||||
}
|
||||
|
||||
/// Verify re-activating an already-activated tool does not duplicate it.
|
||||
#[tokio::test]
|
||||
async fn reactivation_is_idempotent() {
|
||||
let activated = Arc::new(Mutex::new(ActivatedToolSet::new()));
|
||||
let tool = ToolSearchTool::new(
|
||||
make_deferred_set(vec![make_stub("srv__tool", "A tool")]).await,
|
||||
Arc::clone(&activated),
|
||||
);
|
||||
|
||||
tool.execute(serde_json::json!({"query": "select:srv__tool"}))
|
||||
.await
|
||||
.unwrap();
|
||||
tool.execute(serde_json::json!({"query": "select:srv__tool"}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(activated.lock().unwrap().tool_specs().len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user