Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5dd724edd4 | |||
| a695ca4b9c | |||
| 811fab3b87 | |||
| 1a5d91fe69 | |||
| 6eec1c81b9 | |||
| 602db8bca1 | |||
| 314e1d3ae8 | |||
| 82be05b1e9 | |||
| 1373659058 | |||
| c7f064e866 | |||
| 9c1d63e109 | |||
| 966edf1553 | |||
| a1af84d992 | |||
| 70e8e7ebcd | |||
| 2bcb82c5b3 | |||
| e211b5c3e3 | |||
| 8691476577 | |||
| 996dbe95cf |
@@ -155,11 +155,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
@@ -156,11 +156,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
+6
-4
@@ -2696,11 +2696,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
arguments: tool_args.clone(),
|
||||
};
|
||||
|
||||
// Only prompt interactively on CLI; auto-approve on other channels.
|
||||
let decision = if channel_name == "cli" {
|
||||
mgr.prompt_cli(&request)
|
||||
// Interactive CLI: prompt the operator.
|
||||
// Non-interactive (channels): auto-deny since no operator
|
||||
// is present to approve.
|
||||
let decision = if mgr.is_non_interactive() {
|
||||
ApprovalResponse::No
|
||||
} else {
|
||||
ApprovalResponse::Yes
|
||||
mgr.prompt_cli(&request)
|
||||
};
|
||||
|
||||
mgr.record_decision(&tool_name, &tool_args, decision, channel_name);
|
||||
|
||||
+128
-4
@@ -44,11 +44,18 @@ pub struct ApprovalLogEntry {
|
||||
|
||||
// ── ApprovalManager ──────────────────────────────────────────────
|
||||
|
||||
/// Manages the interactive approval workflow.
|
||||
/// Manages the approval workflow for tool calls.
|
||||
///
|
||||
/// - Checks config-level `auto_approve` / `always_ask` lists
|
||||
/// - Maintains a session-scoped "always" allowlist
|
||||
/// - Records an audit trail of all decisions
|
||||
///
|
||||
/// Two modes:
|
||||
/// - **Interactive** (CLI): tools needing approval trigger a stdin prompt.
|
||||
/// - **Non-interactive** (channels): tools needing approval are auto-denied
|
||||
/// because there is no interactive operator to approve them. `auto_approve`
|
||||
/// policy is still enforced, and `always_ask` / supervised-default tools are
|
||||
/// denied rather than silently allowed.
|
||||
pub struct ApprovalManager {
|
||||
/// Tools that never need approval (from config).
|
||||
auto_approve: HashSet<String>,
|
||||
@@ -56,6 +63,9 @@ pub struct ApprovalManager {
|
||||
always_ask: HashSet<String>,
|
||||
/// Autonomy level from config.
|
||||
autonomy_level: AutonomyLevel,
|
||||
/// When `true`, tools that would require interactive approval are
|
||||
/// auto-denied instead. Used for channel-driven (non-CLI) runs.
|
||||
non_interactive: bool,
|
||||
/// Session-scoped allowlist built from "Always" responses.
|
||||
session_allowlist: Mutex<HashSet<String>>,
|
||||
/// Audit trail of approval decisions.
|
||||
@@ -63,17 +73,40 @@ pub struct ApprovalManager {
|
||||
}
|
||||
|
||||
impl ApprovalManager {
|
||||
/// Create from autonomy config.
|
||||
/// Create an interactive (CLI) approval manager from autonomy config.
|
||||
pub fn from_config(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: false,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-interactive approval manager for channel-driven runs.
|
||||
///
|
||||
/// Enforces the same `auto_approve` / `always_ask` / supervised policies
|
||||
/// as the CLI manager, but tools that would require interactive approval
|
||||
/// are auto-denied instead of prompting (since there is no operator).
|
||||
pub fn for_non_interactive(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: true,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when this manager operates in non-interactive mode
|
||||
/// (i.e. for channel-driven runs where no operator can approve).
|
||||
pub fn is_non_interactive(&self) -> bool {
|
||||
self.non_interactive
|
||||
}
|
||||
|
||||
/// Check whether a tool call requires interactive approval.
|
||||
///
|
||||
/// Returns `true` if the call needs a prompt, `false` if it can proceed.
|
||||
@@ -147,8 +180,8 @@ impl ApprovalManager {
|
||||
|
||||
/// Prompt the user on the CLI and return their decision.
|
||||
///
|
||||
/// For non-CLI channels, returns `Yes` automatically (interactive
|
||||
/// approval is only supported on CLI for now).
|
||||
/// Only called for interactive (CLI) managers. Non-interactive managers
|
||||
/// auto-deny in the tool-call loop before reaching this point.
|
||||
pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse {
|
||||
prompt_cli_interactive(request)
|
||||
}
|
||||
@@ -401,6 +434,97 @@ mod tests {
|
||||
assert!(summary.contains("just a string"));
|
||||
}
|
||||
|
||||
// ── non-interactive (channel) mode ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn non_interactive_manager_reports_non_interactive() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_manager_reports_interactive() {
|
||||
let mgr = ApprovalManager::from_config(&supervised_config());
|
||||
assert!(!mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_auto_approve_tools_skip_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// auto_approve tools (file_read, memory_recall) should not need approval.
|
||||
assert!(!mgr.needs_approval("file_read"));
|
||||
assert!(!mgr.needs_approval("memory_recall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_tools_need_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// always_ask tools (shell) still report as needing approval,
|
||||
// so the tool-call loop will auto-deny them in non-interactive mode.
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_unknown_tools_need_approval_in_supervised() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// Unknown tools in supervised mode need approval (will be auto-denied
|
||||
// by the tool-call loop for non-interactive managers).
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
assert!(mgr.needs_approval("http_request"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_full_autonomy_never_needs_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&full_config());
|
||||
// Full autonomy means no approval needed, even in non-interactive mode.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
assert!(!mgr.needs_approval("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_readonly_never_needs_approval() {
|
||||
let config = AutonomyConfig {
|
||||
level: AutonomyLevel::ReadOnly,
|
||||
..AutonomyConfig::default()
|
||||
};
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
// ReadOnly blocks execution elsewhere; approval manager does not prompt.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_session_allowlist_still_works() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
|
||||
// Simulate an "Always" decision (would come from a prior channel run
|
||||
// if the tool was auto-approved somehow, e.g. via config change).
|
||||
mgr.record_decision(
|
||||
"file_write",
|
||||
&serde_json::json!({"path": "test.txt"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_overrides_session_allowlist() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
|
||||
mgr.record_decision(
|
||||
"shell",
|
||||
&serde_json::json!({"command": "ls"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
// shell is in always_ask, so it still needs approval even after "Always".
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
// ── ApprovalResponse serde ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -888,7 +888,7 @@ impl Channel for MatrixChannel {
|
||||
sender: sender.clone(),
|
||||
reply_target: format!("{}||{}", sender, room.room_id()),
|
||||
content: body,
|
||||
channel: format!("matrix:{}", room.room_id()),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
|
||||
+535
-2
@@ -76,6 +76,7 @@ pub use whatsapp::WhatsAppChannel;
|
||||
pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
@@ -311,9 +312,15 @@ struct ChannelRuntimeContext {
|
||||
non_cli_excluded_tools: Arc<Vec<String>>,
|
||||
tool_call_dedup_exempt: Arc<Vec<String>>,
|
||||
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
|
||||
query_classification: crate::config::QueryClassificationConfig,
|
||||
ack_reactions: bool,
|
||||
show_tool_calls: bool,
|
||||
session_store: Option<Arc<session_store::SessionStore>>,
|
||||
/// Non-interactive approval manager for channel-driven runs.
|
||||
/// Enforces `auto_approve` / `always_ask` / supervised policy from
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1786,7 +1793,31 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
let mut route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
|
||||
// ── Query classification: override route when a rule matches ──
|
||||
if let Some(hint) = crate::agent::classifier::classify(&ctx.query_classification, &msg.content)
|
||||
{
|
||||
if let Some(matched_route) = ctx
|
||||
.model_routes
|
||||
.iter()
|
||||
.find(|r| r.hint.eq_ignore_ascii_case(&hint))
|
||||
{
|
||||
tracing::info!(
|
||||
target: "query_classification",
|
||||
hint = hint.as_str(),
|
||||
provider = matched_route.provider.as_str(),
|
||||
model = matched_route.model.as_str(),
|
||||
channel = %msg.channel,
|
||||
"Channel message classified — overriding route"
|
||||
);
|
||||
route = ChannelRouteSelection {
|
||||
provider: matched_route.provider.clone(),
|
||||
model: matched_route.model.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
@@ -2025,7 +2056,7 @@ async fn process_channel_message(
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
None,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
@@ -3235,6 +3266,8 @@ fn collect_configured_channels(
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||||
eprintln!(" ⚠ WhatsApp Web is configured but the 'whatsapp-web' feature is not compiled in.");
|
||||
eprintln!(" Rebuild with: cargo build --features whatsapp-web");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -3835,6 +3868,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()),
|
||||
tool_call_dedup_exempt: Arc::new(config.agent.tool_call_dedup_exempt.clone()),
|
||||
model_routes: Arc::new(config.model_routes.clone()),
|
||||
query_classification: config.query_classification.clone(),
|
||||
ack_reactions: config.channels_config.ack_reactions,
|
||||
show_tool_calls: config.channels_config.show_tool_calls,
|
||||
session_store: if config.channels_config.session_persistence {
|
||||
@@ -3851,6 +3885,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
} else {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4136,9 +4171,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4240,9 +4279,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4300,9 +4343,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4818,9 +4865,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4886,9 +4937,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4968,9 +5023,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5035,9 +5094,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5112,9 +5175,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5209,9 +5276,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5288,9 +5359,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5382,9 +5457,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5461,9 +5540,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5530,9 +5613,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5710,9 +5797,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -5798,9 +5889,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5904,6 +5999,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6001,9 +6100,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6083,9 +6186,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6150,9 +6257,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6775,9 +6886,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6868,9 +6983,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6961,9 +7080,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7518,9 +7641,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7592,9 +7719,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7674,6 +7805,408 @@ This is an example JSON object for profile settings."#;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Query classification in channel message processing ─────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_applies_query_classification_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Vision provider should have been called instead of the default.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
vision_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["gpt-4-vision".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_disabled_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification is disabled — matching keyword should NOT trigger reroute.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: false,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-disabled".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since classification is disabled.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_no_match_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification enabled with a rule that won't match the message.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-nomatch".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "just a regular text message".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since no classification rule matched.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_priority_selects_highest() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let fast_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let fast_provider: Arc<dyn Provider> = fast_provider_impl.clone();
|
||||
let code_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let code_provider: Arc<dyn Provider> = code_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("fast-provider".to_string(), fast_provider);
|
||||
provider_cache_seed.insert("code-provider".to_string(), code_provider);
|
||||
|
||||
// Both rules match "code" keyword, but "code" rule has higher priority.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 1,
|
||||
..Default::default()
|
||||
},
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "code".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 10,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let model_routes = vec![
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "fast".into(),
|
||||
provider: "fast-provider".into(),
|
||||
model: "fast-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "code".into(),
|
||||
provider: "code-provider".into(),
|
||||
model: "code-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-prio".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "write some code for me".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Higher-priority "code" rule (priority=10) should win over "fast" (priority=1).
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(fast_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(code_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
code_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["code-model".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_unconfigured_telegram_returns_error() {
|
||||
let config = Config::default();
|
||||
|
||||
+8
-7
@@ -12,13 +12,14 @@ pub use schema::{
|
||||
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig,
|
||||
WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+164
-2
@@ -232,6 +232,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agents: HashMap<String, DelegateAgentConfig>,
|
||||
|
||||
/// Swarm configurations for multi-agent orchestration.
|
||||
#[serde(default)]
|
||||
pub swarms: HashMap<String, SwarmConfig>,
|
||||
|
||||
/// Hooks configuration (lifecycle hooks and built-in hook toggles).
|
||||
#[serde(default)]
|
||||
pub hooks: HooksConfig,
|
||||
@@ -319,6 +323,44 @@ pub struct DelegateAgentConfig {
|
||||
pub max_iterations: usize,
|
||||
}
|
||||
|
||||
// ── Swarms ──────────────────────────────────────────────────────
|
||||
|
||||
/// Orchestration strategy for a swarm of agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SwarmStrategy {
|
||||
/// Run agents sequentially; each agent's output feeds into the next.
|
||||
Sequential,
|
||||
/// Run agents in parallel; collect all outputs.
|
||||
Parallel,
|
||||
/// Use the LLM to pick the best agent for the task.
|
||||
Router,
|
||||
}
|
||||
|
||||
/// Configuration for a swarm of coordinated agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SwarmConfig {
|
||||
/// Ordered list of agent names (must reference keys in `agents`).
|
||||
pub agents: Vec<String>,
|
||||
/// Orchestration strategy.
|
||||
pub strategy: SwarmStrategy,
|
||||
/// System prompt for router strategy (used to pick the best agent).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub router_prompt: Option<String>,
|
||||
/// Optional description shown to the LLM when choosing swarms.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Maximum total timeout for the swarm execution in seconds.
|
||||
#[serde(default = "default_swarm_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
const DEFAULT_SWARM_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn default_swarm_timeout_secs() -> u64 {
|
||||
DEFAULT_SWARM_TIMEOUT_SECS
|
||||
}
|
||||
|
||||
/// Valid temperature range for all paths (config, CLI, env override).
|
||||
pub const TEMPERATURE_RANGE: std::ops::RangeInclusive<f64> = 0.0..=2.0;
|
||||
|
||||
@@ -2975,10 +3017,10 @@ impl Default for CronConfig {
|
||||
|
||||
/// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section).
|
||||
///
|
||||
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`.
|
||||
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, `"custom"`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TunnelConfig {
|
||||
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`.
|
||||
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, or `"custom"`. Default: `"none"`.
|
||||
pub provider: String,
|
||||
|
||||
/// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`).
|
||||
@@ -2993,6 +3035,10 @@ pub struct TunnelConfig {
|
||||
#[serde(default)]
|
||||
pub ngrok: Option<NgrokTunnelConfig>,
|
||||
|
||||
/// OpenVPN tunnel configuration (used when `provider = "openvpn"`).
|
||||
#[serde(default)]
|
||||
pub openvpn: Option<OpenVpnTunnelConfig>,
|
||||
|
||||
/// Custom tunnel command configuration (used when `provider = "custom"`).
|
||||
#[serde(default)]
|
||||
pub custom: Option<CustomTunnelConfig>,
|
||||
@@ -3005,6 +3051,7 @@ impl Default for TunnelConfig {
|
||||
cloudflare: None,
|
||||
tailscale: None,
|
||||
ngrok: None,
|
||||
openvpn: None,
|
||||
custom: None,
|
||||
}
|
||||
}
|
||||
@@ -3033,6 +3080,36 @@ pub struct NgrokTunnelConfig {
|
||||
pub domain: Option<String>,
|
||||
}
|
||||
|
||||
/// OpenVPN tunnel configuration (`[tunnel.openvpn]`).
|
||||
///
|
||||
/// Required when `tunnel.provider = "openvpn"`. Omitting this section entirely
|
||||
/// preserves previous behavior. Setting `tunnel.provider = "none"` (or removing
|
||||
/// the `[tunnel.openvpn]` block) cleanly reverts to no-tunnel mode.
|
||||
///
|
||||
/// Defaults: `connect_timeout_secs = 30`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenVpnTunnelConfig {
|
||||
/// Path to `.ovpn` configuration file (must not be empty).
|
||||
pub config_file: String,
|
||||
/// Optional path to auth credentials file (`--auth-user-pass`).
|
||||
#[serde(default)]
|
||||
pub auth_file: Option<String>,
|
||||
/// Advertised address once VPN is connected (e.g., `"10.8.0.2:42617"`).
|
||||
/// When omitted the tunnel falls back to `http://{local_host}:{local_port}`.
|
||||
#[serde(default)]
|
||||
pub advertise_address: Option<String>,
|
||||
/// Connection timeout in seconds (default: 30, must be > 0).
|
||||
#[serde(default = "default_openvpn_timeout")]
|
||||
pub connect_timeout_secs: u64,
|
||||
/// Extra openvpn CLI arguments forwarded verbatim.
|
||||
#[serde(default)]
|
||||
pub extra_args: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_openvpn_timeout() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CustomTunnelConfig {
|
||||
/// Command template to start the tunnel. Use {port} and {host} placeholders.
|
||||
@@ -4202,6 +4279,7 @@ impl Default for Config {
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
@@ -5076,6 +5154,20 @@ impl Config {
|
||||
/// Called after TOML deserialization and env-override application to catch
|
||||
/// obviously invalid values early instead of failing at arbitrary runtime points.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Tunnel â OpenVPN
|
||||
if self.tunnel.provider.trim() == "openvpn" {
|
||||
let openvpn = self.tunnel.openvpn.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("tunnel.provider='openvpn' requires [tunnel.openvpn]")
|
||||
})?;
|
||||
|
||||
if openvpn.config_file.trim().is_empty() {
|
||||
anyhow::bail!("tunnel.openvpn.config_file must not be empty");
|
||||
}
|
||||
if openvpn.connect_timeout_secs == 0 {
|
||||
anyhow::bail!("tunnel.openvpn.connect_timeout_secs must be greater than 0");
|
||||
}
|
||||
}
|
||||
|
||||
// Gateway
|
||||
if self.gateway.host.trim().is_empty() {
|
||||
anyhow::bail!("gateway.host must not be empty");
|
||||
@@ -6309,6 +6401,7 @@ default_temperature = 0.7
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -6600,6 +6693,7 @@ tool_dispatcher = "xml"
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -9352,4 +9446,72 @@ require_otp_to_resume = true
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_strategy_roundtrip() {
|
||||
let cases = vec![
|
||||
(SwarmStrategy::Sequential, "\"sequential\""),
|
||||
(SwarmStrategy::Parallel, "\"parallel\""),
|
||||
(SwarmStrategy::Router, "\"router\""),
|
||||
];
|
||||
for (variant, expected_json) in &cases {
|
||||
let serialized = serde_json::to_string(variant).expect("serialize");
|
||||
assert_eq!(&serialized, expected_json, "variant: {variant:?}");
|
||||
let deserialized: SwarmStrategy =
|
||||
serde_json::from_str(expected_json).expect("deserialize");
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_with_defaults() {
|
||||
let toml_str = r#"
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents, vec!["researcher", "writer"]);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Sequential);
|
||||
assert!(config.router_prompt.is_none());
|
||||
assert!(config.description.is_none());
|
||||
assert_eq!(config.timeout_secs, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_full() {
|
||||
let toml_str = r#"
|
||||
agents = ["a", "b", "c"]
|
||||
strategy = "router"
|
||||
router_prompt = "Pick the best."
|
||||
description = "Multi-agent router"
|
||||
timeout_secs = 120
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 3);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Router);
|
||||
assert_eq!(config.router_prompt.as_deref(), Some("Pick the best."));
|
||||
assert_eq!(config.description.as_deref(), Some("Multi-agent router"));
|
||||
assert_eq!(config.timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_with_swarms_section_deserializes() {
|
||||
let toml_str = r#"
|
||||
[agents.researcher]
|
||||
provider = "ollama"
|
||||
model = "llama3"
|
||||
|
||||
[agents.writer]
|
||||
provider = "openrouter"
|
||||
model = "claude-sonnet"
|
||||
|
||||
[swarms.pipeline]
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: Config = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 2);
|
||||
assert_eq!(config.swarms.len(), 1);
|
||||
assert!(config.swarms.contains_key("pipeline"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{Hand, HandContext, HandRun, HandRunStatus};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// Load all hand definitions from TOML files in the given directory.
|
||||
///
|
||||
/// Each `.toml` file in `hands_dir` is expected to deserialize into a [`Hand`].
|
||||
/// Files that fail to parse are logged and skipped.
|
||||
pub fn load_hands(hands_dir: &Path) -> Result<Vec<Hand>> {
|
||||
if !hands_dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut hands = Vec::new();
|
||||
let entries = std::fs::read_dir(hands_dir)
|
||||
.with_context(|| format!("failed to read hands directory: {}", hands_dir.display()))?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("toml") {
|
||||
continue;
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand file: {}", path.display()))?;
|
||||
match toml::from_str::<Hand>(&content) {
|
||||
Ok(hand) => hands.push(hand),
|
||||
Err(e) => {
|
||||
tracing::warn!(path = %path.display(), error = %e, "skipping malformed hand file");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hands)
|
||||
}
|
||||
|
||||
/// Load the rolling context for a hand.
|
||||
///
|
||||
/// Reads from `{hands_dir}/{name}/context.json`. Returns a fresh
|
||||
/// [`HandContext`] if the file does not exist yet.
|
||||
pub fn load_hand_context(hands_dir: &Path, name: &str) -> Result<HandContext> {
|
||||
let path = hands_dir.join(name).join("context.json");
|
||||
if !path.exists() {
|
||||
return Ok(HandContext::new(name));
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand context: {}", path.display()))?;
|
||||
let ctx: HandContext = serde_json::from_str(&content)
|
||||
.with_context(|| format!("failed to parse hand context: {}", path.display()))?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// Persist the rolling context for a hand.
|
||||
///
|
||||
/// Writes to `{hands_dir}/{name}/context.json`, creating the
|
||||
/// directory if it does not exist.
|
||||
pub fn save_hand_context(hands_dir: &Path, context: &HandContext) -> Result<()> {
|
||||
let dir = hands_dir.join(&context.hand_name);
|
||||
std::fs::create_dir_all(&dir)
|
||||
.with_context(|| format!("failed to create hand context dir: {}", dir.display()))?;
|
||||
let path = dir.join("context.json");
|
||||
let json = serde_json::to_string_pretty(context)?;
|
||||
std::fs::write(&path, json)
|
||||
.with_context(|| format!("failed to write hand context: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn write_hand_toml(dir: &Path, filename: &str, content: &str) {
|
||||
std::fs::write(dir.join(filename), content).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_nonexistent_dir() {
|
||||
let hands = load_hands(Path::new("/nonexistent/path/hands")).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_parses_valid_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"scanner.toml",
|
||||
r#"
|
||||
name = "scanner"
|
||||
description = "Market scanner"
|
||||
prompt = "Scan markets."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * *"
|
||||
"#,
|
||||
);
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"digest.toml",
|
||||
r#"
|
||||
name = "digest"
|
||||
description = "News digest"
|
||||
prompt = "Digest news."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_skips_malformed_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(tmp.path(), "bad.toml", "this is not valid toml struct");
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"good.toml",
|
||||
r#"
|
||||
name = "good"
|
||||
description = "A good hand"
|
||||
prompt = "Do good things."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 60000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 1);
|
||||
assert_eq!(hands[0].name, "good");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_ignores_non_toml_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
std::fs::write(tmp.path().join("readme.md"), "# Hands").unwrap();
|
||||
std::fs::write(tmp.path().join("notes.txt"), "some notes").unwrap();
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_roundtrip_through_filesystem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("test-hand");
|
||||
let run = HandRun {
|
||||
hand_name: "test-hand".into(),
|
||||
run_id: "run-001".into(),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec!["found something".into()],
|
||||
knowledge_added: vec!["learned something".into()],
|
||||
duration_ms: Some(500),
|
||||
};
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "test-hand").unwrap();
|
||||
|
||||
assert_eq!(loaded.hand_name, "test-hand");
|
||||
assert_eq!(loaded.total_runs, 1);
|
||||
assert_eq!(loaded.history.len(), 1);
|
||||
assert_eq!(loaded.learned_facts, vec!["learned something"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_context_returns_fresh_when_missing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = load_hand_context(tmp.path(), "nonexistent").unwrap();
|
||||
assert_eq!(ctx.hand_name, "nonexistent");
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_context_creates_directory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = HandContext::new("new-hand");
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
|
||||
assert!(tmp.path().join("new-hand").join("context.json").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_then_load_preserves_multiple_runs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("multi");
|
||||
|
||||
for i in 0..5 {
|
||||
let run = HandRun {
|
||||
hand_name: "multi".into(),
|
||||
run_id: format!("run-{i:03}"),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec![format!("finding-{i}")],
|
||||
knowledge_added: vec![format!("fact-{i}")],
|
||||
duration_ms: Some(100),
|
||||
};
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "multi").unwrap();
|
||||
|
||||
assert_eq!(loaded.total_runs, 5);
|
||||
assert_eq!(loaded.history.len(), 3, "history capped at max_history=3");
|
||||
assert_eq!(loaded.learned_facts.len(), 5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::cron::Schedule;
|
||||
|
||||
// ── Hand ───────────────────────────────────────────────────────
|
||||
|
||||
/// A Hand is an autonomous agent package that runs on a schedule,
|
||||
/// accumulates knowledge over time, and reports results.
|
||||
///
|
||||
/// Hands are defined as TOML files in `~/.zeroclaw/hands/` and each
|
||||
/// maintains a rolling context of findings across runs so the agent
|
||||
/// grows smarter with every execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Hand {
|
||||
/// Unique name (also used as directory/file stem)
|
||||
pub name: String,
|
||||
/// Human-readable description of what this hand does
|
||||
pub description: String,
|
||||
/// The schedule this hand runs on (reuses cron schedule types)
|
||||
pub schedule: Schedule,
|
||||
/// System prompt / execution plan for this hand
|
||||
pub prompt: String,
|
||||
/// Domain knowledge lines to inject into context
|
||||
#[serde(default)]
|
||||
pub knowledge: Vec<String>,
|
||||
/// Tools this hand is allowed to use (None = all available)
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Model override for this hand (None = default provider)
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Whether this hand is currently active
|
||||
#[serde(default = "default_true")]
|
||||
pub active: bool,
|
||||
/// Maximum runs to keep in history
|
||||
#[serde(default = "default_max_runs")]
|
||||
pub max_history: usize,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_max_runs() -> usize {
|
||||
100
|
||||
}
|
||||
|
||||
// ── Hand Run ───────────────────────────────────────────────────
|
||||
|
||||
/// The status of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "status")]
|
||||
pub enum HandRunStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
/// Record of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandRun {
|
||||
/// Name of the hand that produced this run
|
||||
pub hand_name: String,
|
||||
/// Unique identifier for this run
|
||||
pub run_id: String,
|
||||
/// When the run started
|
||||
pub started_at: DateTime<Utc>,
|
||||
/// When the run finished (None if still running)
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
/// Outcome of the run
|
||||
pub status: HandRunStatus,
|
||||
/// Key findings/outputs extracted from this run
|
||||
#[serde(default)]
|
||||
pub findings: Vec<String>,
|
||||
/// New knowledge accumulated and stored to memory
|
||||
#[serde(default)]
|
||||
pub knowledge_added: Vec<String>,
|
||||
/// Wall-clock duration in milliseconds
|
||||
pub duration_ms: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Hand Context ───────────────────────────────────────────────
|
||||
|
||||
/// Rolling context that accumulates across hand runs.
|
||||
///
|
||||
/// Persisted as `~/.zeroclaw/hands/{name}/context.json`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandContext {
|
||||
/// Name of the hand this context belongs to
|
||||
pub hand_name: String,
|
||||
/// Past runs, most-recent first, capped at `Hand::max_history`
|
||||
#[serde(default)]
|
||||
pub history: Vec<HandRun>,
|
||||
/// Persistent facts learned across runs
|
||||
#[serde(default)]
|
||||
pub learned_facts: Vec<String>,
|
||||
/// Timestamp of the last completed run
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
/// Total number of successful runs
|
||||
#[serde(default)]
|
||||
pub total_runs: u64,
|
||||
}
|
||||
|
||||
impl HandContext {
|
||||
/// Create a fresh, empty context for a hand.
|
||||
pub fn new(hand_name: &str) -> Self {
|
||||
Self {
|
||||
hand_name: hand_name.to_string(),
|
||||
history: Vec::new(),
|
||||
learned_facts: Vec::new(),
|
||||
last_run: None,
|
||||
total_runs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed run, updating counters and trimming history.
|
||||
pub fn record_run(&mut self, run: HandRun, max_history: usize) {
|
||||
if run.status == (HandRunStatus::Completed) {
|
||||
self.total_runs += 1;
|
||||
self.last_run = run.finished_at;
|
||||
}
|
||||
|
||||
// Merge new knowledge
|
||||
for fact in &run.knowledge_added {
|
||||
if !self.learned_facts.contains(fact) {
|
||||
self.learned_facts.push(fact.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Insert at the front (most-recent first)
|
||||
self.history.insert(0, run);
|
||||
|
||||
// Cap history length
|
||||
self.history.truncate(max_history);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cron::Schedule;
|
||||
|
||||
fn sample_hand() -> Hand {
|
||||
Hand {
|
||||
name: "market-scanner".into(),
|
||||
description: "Scans market trends and reports findings".into(),
|
||||
schedule: Schedule::Cron {
|
||||
expr: "0 9 * * 1-5".into(),
|
||||
tz: Some("America/New_York".into()),
|
||||
},
|
||||
prompt: "Scan market trends and report key findings.".into(),
|
||||
knowledge: vec!["Focus on tech sector.".into()],
|
||||
allowed_tools: Some(vec!["web_search".into(), "memory".into()]),
|
||||
model: Some("claude-opus-4-6".into()),
|
||||
active: true,
|
||||
max_history: 50,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_run(name: &str, status: HandRunStatus) -> HandRun {
|
||||
let now = Utc::now();
|
||||
HandRun {
|
||||
hand_name: name.into(),
|
||||
run_id: uuid::Uuid::new_v4().to_string(),
|
||||
started_at: now,
|
||||
finished_at: Some(now),
|
||||
status,
|
||||
findings: vec!["finding-1".into()],
|
||||
knowledge_added: vec!["learned-fact-A".into()],
|
||||
duration_ms: Some(1234),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Deserialization ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_from_toml() {
|
||||
let toml_str = r#"
|
||||
name = "market-scanner"
|
||||
description = "Scans market trends"
|
||||
prompt = "Scan trends."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * 1-5"
|
||||
tz = "America/New_York"
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "market-scanner");
|
||||
assert!(hand.active, "active should default to true");
|
||||
assert_eq!(hand.max_history, 100, "max_history should default to 100");
|
||||
assert!(hand.knowledge.is_empty());
|
||||
assert!(hand.allowed_tools.is_none());
|
||||
assert!(hand.model.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_full_toml() {
|
||||
let toml_str = r#"
|
||||
name = "news-digest"
|
||||
description = "Daily news digest"
|
||||
prompt = "Summarize the day's news."
|
||||
knowledge = ["focus on AI", "include funding rounds"]
|
||||
allowed_tools = ["web_search"]
|
||||
model = "claude-opus-4-6"
|
||||
active = false
|
||||
max_history = 25
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "news-digest");
|
||||
assert!(!hand.active);
|
||||
assert_eq!(hand.max_history, 25);
|
||||
assert_eq!(hand.knowledge.len(), 2);
|
||||
assert_eq!(hand.allowed_tools.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(hand.model.as_deref(), Some("claude-opus-4-6"));
|
||||
assert!(matches!(
|
||||
hand.schedule,
|
||||
Schedule::Every {
|
||||
every_ms: 3_600_000
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_roundtrip_json() {
|
||||
let hand = sample_hand();
|
||||
let json = serde_json::to_string(&hand).unwrap();
|
||||
let parsed: Hand = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, hand.name);
|
||||
assert_eq!(parsed.max_history, hand.max_history);
|
||||
}
|
||||
|
||||
// ── HandRunStatus ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_run_status_serde_roundtrip() {
|
||||
let statuses = vec![
|
||||
HandRunStatus::Running,
|
||||
HandRunStatus::Completed,
|
||||
HandRunStatus::Failed {
|
||||
error: "timeout".into(),
|
||||
},
|
||||
];
|
||||
for status in statuses {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
let parsed: HandRunStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, status);
|
||||
}
|
||||
}
|
||||
|
||||
// ── HandContext ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn context_new_is_empty() {
|
||||
let ctx = HandContext::new("test-hand");
|
||||
assert_eq!(ctx.hand_name, "test-hand");
|
||||
assert!(ctx.history.is_empty());
|
||||
assert!(ctx.learned_facts.is_empty());
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_run_increments_counters() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 1);
|
||||
assert!(ctx.last_run.is_some());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
assert_eq!(ctx.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_failed_run_does_not_increment_total() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run(
|
||||
"scanner",
|
||||
HandRunStatus::Failed {
|
||||
error: "boom".into(),
|
||||
},
|
||||
);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_caps_history_at_max() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for _ in 0..10 {
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
assert_eq!(ctx.history.len(), 3);
|
||||
assert_eq!(ctx.total_runs, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_deduplicates_learned_facts() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run1 = sample_run("scanner", HandRunStatus::Completed);
|
||||
let run2 = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run1, 100);
|
||||
ctx.record_run(run2, 100);
|
||||
|
||||
// Both runs add "learned-fact-A" but it should appear only once
|
||||
assert_eq!(ctx.learned_facts.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_json_roundtrip() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
let json = serde_json::to_string_pretty(&ctx).unwrap();
|
||||
let parsed: HandContext = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.hand_name, "scanner");
|
||||
assert_eq!(parsed.total_runs, 1);
|
||||
assert_eq!(parsed.history.len(), 1);
|
||||
assert_eq!(parsed.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn most_recent_run_is_first_in_history() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for i in 0..3 {
|
||||
let mut run = sample_run("scanner", HandRunStatus::Completed);
|
||||
run.findings = vec![format!("finding-{i}")];
|
||||
ctx.record_run(run, 100);
|
||||
}
|
||||
assert_eq!(ctx.history[0].findings[0], "finding-2");
|
||||
assert_eq!(ctx.history[2].findings[0], "finding-0");
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,7 @@ pub(crate) mod cron;
|
||||
pub(crate) mod daemon;
|
||||
pub(crate) mod doctor;
|
||||
pub mod gateway;
|
||||
pub mod hands;
|
||||
pub(crate) mod hardware;
|
||||
pub(crate) mod health;
|
||||
pub(crate) mod heartbeat;
|
||||
|
||||
+30
-5
@@ -37,7 +37,7 @@ use anyhow::{bail, Context, Result};
|
||||
use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use dialoguer::{Input, Password};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
@@ -719,10 +719,11 @@ async fn main() -> Result<()> {
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
||||
// Onboard runs quick setup by default, or the interactive wizard with --interactive.
|
||||
// The onboard wizard uses reqwest::blocking internally, which creates its own
|
||||
// Tokio runtime. To avoid "Cannot drop a runtime in a context where blocking is
|
||||
// not allowed", we run the wizard on a blocking thread via spawn_blocking.
|
||||
// Onboard auto-detects the environment: if stdin/stdout are a TTY and no
|
||||
// provider flags were given, it runs the full interactive wizard; otherwise
|
||||
// it runs the quick (scriptable) setup. This means `curl … | bash` and
|
||||
// `zeroclaw onboard --api-key …` both take the fast path, while a bare
|
||||
// `zeroclaw onboard` in a terminal launches the wizard.
|
||||
if let Commands::Onboard {
|
||||
force,
|
||||
reinit,
|
||||
@@ -794,8 +795,16 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect: run the interactive wizard when in a TTY with no
|
||||
// provider flags, quick setup otherwise (scriptable path).
|
||||
let has_provider_flags =
|
||||
api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some();
|
||||
let is_tty = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
|
||||
|
||||
let config = if channels_only {
|
||||
Box::pin(onboard::run_channels_repair_wizard()).await
|
||||
} else if is_tty && !has_provider_flags {
|
||||
Box::pin(onboard::run_wizard(force)).await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
@@ -2207,6 +2216,22 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_rejects_removed_interactive_flag() {
|
||||
// --interactive was removed; onboard auto-detects TTY instead.
|
||||
assert!(Cli::try_parse_from(["zeroclaw", "onboard", "--interactive"]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_bare_parses() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "onboard"]).expect("bare onboard should parse");
|
||||
|
||||
match cli.command {
|
||||
Commands::Onboard { .. } => {}
|
||||
other => panic!("expected onboard command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_parses_estop_default_engage() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "estop"]).expect("estop command should parse");
|
||||
|
||||
+2
-1
@@ -4,7 +4,7 @@ pub mod wizard;
|
||||
#[allow(unused_imports)]
|
||||
pub use wizard::{
|
||||
run_channels_repair_wizard, run_models_list, run_models_refresh, run_models_refresh_all,
|
||||
run_models_set, run_models_status, run_quick_setup,
|
||||
run_models_set, run_models_status, run_quick_setup, run_wizard,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -17,6 +17,7 @@ mod tests {
|
||||
fn wizard_functions_are_reexported() {
|
||||
assert_reexport_exists(run_channels_repair_wizard);
|
||||
assert_reexport_exists(run_quick_setup);
|
||||
assert_reexport_exists(run_wizard);
|
||||
assert_reexport_exists(run_models_refresh);
|
||||
assert_reexport_exists(run_models_list);
|
||||
assert_reexport_exists(run_models_set);
|
||||
|
||||
@@ -170,6 +170,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: hardware_config,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -527,6 +528,7 @@ async fn run_quick_setup_with_home(
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: crate::config::HardwareConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -4147,6 +4149,23 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
.interact()?;
|
||||
|
||||
if mode_idx == 0 {
|
||||
// Compile-time check: warn early if the feature is not enabled.
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
println!();
|
||||
println!(
|
||||
" {} {}",
|
||||
style("⚠").yellow().bold(),
|
||||
style("The 'whatsapp-web' feature is not compiled in. WhatsApp Web will not work at runtime.").yellow()
|
||||
);
|
||||
println!(
|
||||
" {} Rebuild with: {}",
|
||||
style("→").dim(),
|
||||
style("cargo build --features whatsapp-web").white().bold()
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
println!(" {}", style("Mode: WhatsApp Web").dim());
|
||||
print_bullet("1. Build with --features whatsapp-web");
|
||||
print_bullet(
|
||||
|
||||
@@ -500,19 +500,23 @@ struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
function: Option<Function>,
|
||||
|
||||
// Compatibility: Some providers (e.g., older GLM) may use 'name' directly
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
arguments: Option<String>,
|
||||
|
||||
// Compatibility: DeepSeek sometimes wraps arguments differently
|
||||
#[serde(rename = "parameters", default)]
|
||||
#[serde(
|
||||
rename = "parameters",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -3094,4 +3098,50 @@ mod tests {
|
||||
// Should not panic
|
||||
let _client = p.http_client();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_none_fields_omitted_from_json() {
|
||||
// Ensures providers like Mistral that reject extra fields (e.g. "name": null)
|
||||
// don't receive them when the ToolCall compat fields are None.
|
||||
let tc = ToolCall {
|
||||
id: Some("call_1".to_string()),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert!(!json.as_object().unwrap().contains_key("name"));
|
||||
assert!(!json.as_object().unwrap().contains_key("arguments"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
// Standard fields must be present
|
||||
assert!(json.as_object().unwrap().contains_key("id"));
|
||||
assert!(json.as_object().unwrap().contains_key("type"));
|
||||
assert!(json.as_object().unwrap().contains_key("function"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_with_compat_fields_serializes_them() {
|
||||
// When compat fields are Some, they should appear in the output.
|
||||
let tc = ToolCall {
|
||||
id: None,
|
||||
kind: None,
|
||||
function: None,
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["name"], "shell");
|
||||
assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
|
||||
// None fields should be omitted
|
||||
assert!(!json.as_object().unwrap().contains_key("id"));
|
||||
assert!(!json.as_object().unwrap().contains_key("type"));
|
||||
assert!(!json.as_object().unwrap().contains_key("function"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
+91
-6
@@ -442,8 +442,24 @@ fn install_linux_systemd(config: &Config) -> Result<()> {
|
||||
|
||||
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
|
||||
let unit = format!(
|
||||
"[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
|
||||
exe.display()
|
||||
"[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart={exe} daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n",
|
||||
exe = exe.display()
|
||||
);
|
||||
|
||||
fs::write(&file, unit)?;
|
||||
@@ -826,8 +842,8 @@ fn generate_openrc_script(exe_path: &Path, config_dir: &Path) -> String {
|
||||
name="zeroclaw"
|
||||
description="ZeroClaw daemon"
|
||||
|
||||
command="{}"
|
||||
command_args="--config-dir {} daemon"
|
||||
command="{exe}"
|
||||
command_args="--config-dir {config_dir} daemon"
|
||||
command_background="yes"
|
||||
command_user="zeroclaw:zeroclaw"
|
||||
pidfile="/run/${{RC_SVCNAME}}.pid"
|
||||
@@ -835,13 +851,21 @@ umask 027
|
||||
output_log="/var/log/zeroclaw/access.log"
|
||||
error_log="/var/log/zeroclaw/error.log"
|
||||
|
||||
# Provide HOME so headless browsers can create profile/cache directories.
|
||||
# Without this, Chromium/Firefox fail with sandbox or profile errors.
|
||||
export HOME="/var/lib/zeroclaw"
|
||||
|
||||
depend() {{
|
||||
need net
|
||||
after firewall
|
||||
}}
|
||||
|
||||
start_pre() {{
|
||||
checkpath --directory --owner zeroclaw:zeroclaw --mode 0750 /var/lib/zeroclaw
|
||||
}}
|
||||
"#,
|
||||
exe_path.display(),
|
||||
config_dir.display()
|
||||
exe = exe_path.display(),
|
||||
config_dir = config_dir.display(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1196,6 +1220,67 @@ mod tests {
|
||||
assert!(script.contains("after firewall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_sets_home_for_browser() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("export HOME=\"/var/lib/zeroclaw\""),
|
||||
"OpenRC script must set HOME for headless browser support"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_creates_home_directory() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("start_pre()"),
|
||||
"OpenRC script must have start_pre to create HOME dir"
|
||||
);
|
||||
assert!(
|
||||
script.contains("checkpath --directory --owner zeroclaw:zeroclaw"),
|
||||
"start_pre must ensure /var/lib/zeroclaw exists with correct ownership"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn systemd_unit_contains_home_and_pass_environment() {
|
||||
let unit = "[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart=/usr/local/bin/zeroclaw daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n"
|
||||
.to_string();
|
||||
|
||||
assert!(
|
||||
unit.contains("Environment=HOME=%h"),
|
||||
"systemd unit must set HOME for headless browser support"
|
||||
);
|
||||
assert!(
|
||||
unit.contains("PassEnvironment=DISPLAY XDG_RUNTIME_DIR"),
|
||||
"systemd unit must pass through display/runtime env vars"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warn_if_binary_in_home_detects_home_path() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -440,6 +440,12 @@ impl BrowserTool {
|
||||
async fn run_command(&self, args: &[&str]) -> anyhow::Result<AgentBrowserResponse> {
|
||||
let mut cmd = Command::new("agent-browser");
|
||||
|
||||
// When running as a service (systemd/OpenRC), the process may lack
|
||||
// HOME which browsers need for profile directories.
|
||||
if is_service_environment() {
|
||||
ensure_browser_env(&mut cmd);
|
||||
}
|
||||
|
||||
// Add session if configured
|
||||
if let Some(ref session) = self.session_name {
|
||||
cmd.arg("--session").arg(session);
|
||||
@@ -1461,6 +1467,14 @@ mod native_backend {
|
||||
args.push(Value::String("--disable-gpu".to_string()));
|
||||
}
|
||||
|
||||
// When running as a service (systemd/OpenRC), the browser sandbox
|
||||
// fails because the process lacks a user namespace / session.
|
||||
// --no-sandbox and --disable-dev-shm-usage are required in this context.
|
||||
if is_service_environment() {
|
||||
args.push(Value::String("--no-sandbox".to_string()));
|
||||
args.push(Value::String("--disable-dev-shm-usage".to_string()));
|
||||
}
|
||||
|
||||
if !args.is_empty() {
|
||||
chrome_options.insert("args".to_string(), Value::Array(args));
|
||||
}
|
||||
@@ -2111,6 +2125,44 @@ fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
|
||||
|| v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
|
||||
}
|
||||
|
||||
/// Detect whether the current process is running inside a service environment
|
||||
/// (e.g. systemd, OpenRC, or launchd) where the browser sandbox and
|
||||
/// environment setup may be restricted.
|
||||
fn is_service_environment() -> bool {
|
||||
if std::env::var_os("INVOCATION_ID").is_some() {
|
||||
return true;
|
||||
}
|
||||
if std::env::var_os("JOURNAL_STREAM").is_some() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::path::Path::new("/run/openrc").exists() && std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Ensure environment variables required by headless browsers are present
|
||||
/// when running inside a service context.
|
||||
fn ensure_browser_env(cmd: &mut Command) {
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
cmd.env("HOME", "/tmp");
|
||||
}
|
||||
let existing = std::env::var("CHROMIUM_FLAGS").unwrap_or_default();
|
||||
if !existing.contains("--no-sandbox") {
|
||||
let new_flags = if existing.is_empty() {
|
||||
"--no-sandbox --disable-dev-shm-usage".to_string()
|
||||
} else {
|
||||
format!("{existing} --no-sandbox --disable-dev-shm-usage")
|
||||
};
|
||||
cmd.env("CHROMIUM_FLAGS", new_flags);
|
||||
}
|
||||
}
|
||||
|
||||
fn host_matches_allowlist(host: &str, allowed: &[String]) -> bool {
|
||||
allowed.iter().any(|pattern| {
|
||||
if pattern == "*" {
|
||||
@@ -2492,4 +2544,78 @@ mod tests {
|
||||
state.reset_session().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_home_when_missing() {
|
||||
let original_home = std::env::var_os("HOME");
|
||||
unsafe { std::env::remove_var("HOME") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
// Function completes without panic — HOME and CHROMIUM_FLAGS set on cmd.
|
||||
|
||||
if let Some(home) = original_home {
|
||||
unsafe { std::env::set_var("HOME", home) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_chromium_flags() {
|
||||
let original = std::env::var_os("CHROMIUM_FLAGS");
|
||||
unsafe { std::env::remove_var("CHROMIUM_FLAGS") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("CHROMIUM_FLAGS", val) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_invocation_id() {
|
||||
let original = std::env::var_os("INVOCATION_ID");
|
||||
unsafe { std::env::set_var("INVOCATION_ID", "test-unit-id") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_journal_stream() {
|
||||
let original = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", "8:12345") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_false_in_normal_context() {
|
||||
let inv = std::env::var_os("INVOCATION_ID");
|
||||
let journal = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
|
||||
if std::env::var_os("HOME").is_some() {
|
||||
assert!(!is_service_environment());
|
||||
}
|
||||
|
||||
if let Some(val) = inv {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
}
|
||||
if let Some(val) = journal {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+37
-18
@@ -57,6 +57,7 @@ pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod tool_search;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
@@ -103,6 +104,7 @@ pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
pub use traits::Tool;
|
||||
#[allow(unused_imports)]
|
||||
@@ -358,6 +360,24 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
|
||||
// Add delegation tool when agents are configured
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
};
|
||||
|
||||
let delegate_handle: Option<DelegateParentToolsHandle> = if agents.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -365,28 +385,12 @@ pub fn all_tools_with_runtime(
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let parent_tools = Arc::new(RwLock::new(tool_arcs.clone()));
|
||||
let delegate_tool = DelegateTool::new_with_options(
|
||||
delegate_agents,
|
||||
delegate_fallback_credential,
|
||||
delegate_fallback_credential.clone(),
|
||||
security.clone(),
|
||||
crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
},
|
||||
provider_runtime_options.clone(),
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone());
|
||||
@@ -394,6 +398,21 @@ pub fn all_tools_with_runtime(
|
||||
Some(parent_tools)
|
||||
};
|
||||
|
||||
// Add swarm tool when swarms are configured
|
||||
if !root_config.swarms.is_empty() {
|
||||
let swarm_agents: HashMap<String, DelegateAgentConfig> = agents
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
tool_arcs.push(Arc::new(SwarmTool::new(
|
||||
root_config.swarms.clone(),
|
||||
swarm_agents,
|
||||
delegate_fallback_credential,
|
||||
security.clone(),
|
||||
provider_runtime_options,
|
||||
)));
|
||||
}
|
||||
|
||||
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,953 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::{DelegateAgentConfig, SwarmConfig, SwarmStrategy};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default timeout for individual agent calls within a swarm.
|
||||
const SWARM_AGENT_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
/// Tool that orchestrates multiple agents as a swarm. Supports sequential
|
||||
/// (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies.
|
||||
pub struct SwarmTool {
|
||||
swarms: Arc<HashMap<String, SwarmConfig>>,
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
fallback_credential: Option<String>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
}
|
||||
|
||||
impl SwarmTool {
|
||||
pub fn new(
|
||||
swarms: HashMap<String, SwarmConfig>,
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
) -> Self {
|
||||
Self {
|
||||
swarms: Arc::new(swarms),
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
provider_runtime_options,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider_for_agent(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
agent_name: &str,
|
||||
) -> Result<Box<dyn Provider>, ToolResult> {
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
)
|
||||
.map_err(|e| ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider '{}' for agent '{agent_name}': {e}",
|
||||
agent_config.provider
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn call_agent(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
prompt: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<String, String> {
|
||||
let provider = self
|
||||
.create_provider_for_agent(agent_config, agent_name)
|
||||
.map_err(|r| r.error.unwrap_or_default())?;
|
||||
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(response)) => {
|
||||
if response.trim().is_empty() {
|
||||
Ok("[Empty response]".to_string())
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("Agent '{agent_name}' failed: {e}")),
|
||||
Err(_) => Err(format!(
|
||||
"Agent '{agent_name}' timed out after {timeout_secs}s"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_sequential(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let mut current_input = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let per_agent_timeout = swarm_config.timeout_secs / swarm_config.agents.len().max(1) as u64;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (i, agent_name) in swarm_config.agents.iter().enumerate() {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let agent_prompt = if i == 0 {
|
||||
current_input.clone()
|
||||
} else {
|
||||
format!("[Previous agent output]\n{current_input}\n\n[Original task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(agent_name, agent_config, &agent_prompt, per_agent_timeout)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
results.push(format!(
|
||||
"[{agent_name} ({}/{})] {output}",
|
||||
agent_config.provider, agent_config.model
|
||||
));
|
||||
current_input = output;
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: results.join("\n\n"),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm sequential — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_parallel(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for agent_name in &swarm_config.agents {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg.clone(),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
let provider = match providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider for agent '{agent_name}': {e}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let name = agent_name.clone();
|
||||
let prompt_clone = full_prompt.clone();
|
||||
let timeout = swarm_config.timeout_secs;
|
||||
let model = agent_config.model.clone();
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
let system_prompt = agent_config.system_prompt.clone();
|
||||
let provider_name = agent_config.provider.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout),
|
||||
provider.chat_with_system(
|
||||
system_prompt.as_deref(),
|
||||
&prompt_clone,
|
||||
&model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let output = match result {
|
||||
Ok(Ok(text)) => {
|
||||
if text.trim().is_empty() {
|
||||
"[Empty response]".to_string()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => format!("[Error] {e}"),
|
||||
Err(_) => format!("[Timed out after {timeout}s]"),
|
||||
};
|
||||
|
||||
(name, provider_name, model, output)
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(join_result) = join_set.join_next().await {
|
||||
match join_result {
|
||||
Ok((name, provider_name, model, output)) => {
|
||||
results.push(format!("[{name} ({provider_name}/{model})]\n{output}"));
|
||||
}
|
||||
Err(e) => {
|
||||
results.push(format!("[join error] {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm parallel — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n---\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_router(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router swarm has no agents to choose from".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build agent descriptions for the router prompt
|
||||
let agent_descriptions: Vec<String> = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.filter_map(|name| {
|
||||
self.agents.get(name).map(|cfg| {
|
||||
let desc = cfg
|
||||
.system_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("General purpose agent");
|
||||
format!(
|
||||
"- {name}: {desc} (provider: {}, model: {})",
|
||||
cfg.provider, cfg.model
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Use the first agent's provider for routing
|
||||
let first_agent_name = &swarm_config.agents[0];
|
||||
let first_agent_config = match self.agents.get(first_agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Swarm references unknown agent '{first_agent_name}'"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let router_provider = self
|
||||
.create_provider_for_agent(first_agent_config, first_agent_name)
|
||||
.map_err(|r| anyhow::anyhow!(r.error.unwrap_or_default()))?;
|
||||
|
||||
let base_router_prompt = swarm_config
|
||||
.router_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("Pick the single best agent for this task.");
|
||||
|
||||
let routing_prompt = format!(
|
||||
"{base_router_prompt}\n\nAvailable agents:\n{}\n\nUser task: {prompt}\n\n\
|
||||
Respond with ONLY the agent name, nothing else.",
|
||||
agent_descriptions.join("\n")
|
||||
);
|
||||
|
||||
let chosen = tokio::time::timeout(
|
||||
Duration::from_secs(SWARM_AGENT_TIMEOUT_SECS),
|
||||
router_provider.chat_with_system(
|
||||
Some("You are a routing assistant. Respond with only the agent name."),
|
||||
&routing_prompt,
|
||||
&first_agent_config.model,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let chosen_name = match chosen {
|
||||
Ok(Ok(name)) => name.trim().to_string(),
|
||||
Ok(Err(e)) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router LLM call failed: {e}")),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router LLM call timed out".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Case-insensitive matching with fallback to first agent
|
||||
let matched_name = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.find(|name| name.eq_ignore_ascii_case(&chosen_name))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| swarm_config.agents[0].clone());
|
||||
|
||||
let agent_config = match self.agents.get(&matched_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router selected unknown agent '{matched_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(
|
||||
&matched_name,
|
||||
agent_config,
|
||||
&full_prompt,
|
||||
swarm_config.timeout_secs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm router — selected '{matched_name}' ({}/{})]\n{output}",
|
||||
agent_config.provider, agent_config.model
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SwarmTool {
|
||||
fn name(&self) -> &str {
|
||||
"swarm"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Orchestrate a swarm of agents to collaboratively handle a task. Supports sequential \
|
||||
(pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
let swarm_names: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
json!({
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"swarm": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": format!(
|
||||
"Name of the swarm to invoke. Available: {}",
|
||||
if swarm_names.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
swarm_names.join(", ")
|
||||
}
|
||||
)
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "The task/prompt to send to the swarm"
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Optional context to include (e.g. relevant code, prior findings)"
|
||||
}
|
||||
},
|
||||
"required": ["swarm", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let swarm_name = args
|
||||
.get("swarm")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'swarm' parameter"))?;
|
||||
|
||||
if swarm_name.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'swarm' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let prompt = args
|
||||
.get("prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
|
||||
|
||||
if prompt.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'prompt' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let context = args
|
||||
.get("context")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.unwrap_or("");
|
||||
|
||||
let swarm_config = match self.swarms.get(swarm_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
let available: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown swarm '{swarm_name}'. Available swarms: {}",
|
||||
if available.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
available.join(", ")
|
||||
}
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm '{swarm_name}' has no agents configured")),
|
||||
});
|
||||
}
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "swarm")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match swarm_config.strategy {
|
||||
SwarmStrategy::Sequential => {
|
||||
self.execute_sequential(swarm_config, prompt, context).await
|
||||
}
|
||||
SwarmStrategy::Parallel => self.execute_parallel(swarm_config, prompt, context).await,
|
||||
SwarmStrategy::Router => self.execute_router(swarm_config, prompt, context).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
|
||||
let mut agents = HashMap::new();
|
||||
agents.insert(
|
||||
"researcher".to_string(),
|
||||
DelegateAgentConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "llama3".to_string(),
|
||||
system_prompt: Some("You are a research assistant.".to_string()),
|
||||
api_key: None,
|
||||
temperature: Some(0.3),
|
||||
max_depth: 3,
|
||||
agentic: false,
|
||||
allowed_tools: Vec::new(),
|
||||
max_iterations: 10,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
"writer".to_string(),
|
||||
DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||
system_prompt: Some("You are a technical writer.".to_string()),
|
||||
api_key: Some("test-key".to_string()),
|
||||
temperature: Some(0.5),
|
||||
max_depth: 3,
|
||||
agentic: false,
|
||||
allowed_tools: Vec::new(),
|
||||
max_iterations: 10,
|
||||
},
|
||||
);
|
||||
agents
|
||||
}
|
||||
|
||||
fn sample_swarms() -> HashMap<String, SwarmConfig> {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"pipeline".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: Some("Research then write".to_string()),
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms.insert(
|
||||
"fanout".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms.insert(
|
||||
"router".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Router,
|
||||
router_prompt: Some("Pick the best agent.".to_string()),
|
||||
description: None,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert_eq!(tool.name(), "swarm");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["swarm"].is_object());
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
assert!(schema["properties"]["context"].is_object());
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.contains(&json!("swarm")));
|
||||
assert!(required.contains(&json!("prompt")));
|
||||
assert_eq!(schema["additionalProperties"], json!(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn description_not_empty() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert!(!tool.description().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_lists_swarm_names() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["swarm"]["description"]
|
||||
.as_str()
|
||||
.unwrap();
|
||||
assert!(desc.contains("pipeline") || desc.contains("fanout") || desc.contains("router"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_swarms_schema() {
|
||||
let tool = SwarmTool::new(
|
||||
HashMap::new(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["swarm"]["description"]
|
||||
.as_str()
|
||||
.unwrap();
|
||||
assert!(desc.contains("none configured"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_swarm_returns_error() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "nonexistent", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Unknown swarm"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_swarm_param() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool.execute(json!({"prompt": "test"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_prompt_param() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool.execute(json!({"swarm": "pipeline"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blank_swarm_rejected() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": " ", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("must not be empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blank_prompt_rejected() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": " "}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("must not be empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_with_missing_agent_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"broken".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["nonexistent_agent".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "broken", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("unknown agent"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_with_empty_agents_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"empty".to_string(),
|
||||
SwarmConfig {
|
||||
agents: Vec::new(),
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "empty", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("no agents configured"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_blocked_in_readonly_mode() {
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
readonly,
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_blocked_when_rate_limited() {
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
limited,
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sequential_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"seq".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
// researcher uses "ollama" which won't be running in CI
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "seq", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
// Should fail at provider creation or call level
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parallel_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"par".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "par", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
// Parallel strategy returns success with error annotations in output
|
||||
assert!(result.success || result.error.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn router_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"rout".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Router,
|
||||
router_prompt: Some("Pick.".to_string()),
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "rout", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
}
|
||||
+58
-2
@@ -2,6 +2,7 @@ mod cloudflare;
|
||||
mod custom;
|
||||
mod ngrok;
|
||||
mod none;
|
||||
mod openvpn;
|
||||
mod tailscale;
|
||||
|
||||
pub use cloudflare::CloudflareTunnel;
|
||||
@@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
|
||||
pub use ngrok::NgrokTunnel;
|
||||
#[allow(unused_imports)]
|
||||
pub use none::NoneTunnel;
|
||||
pub use openvpn::OpenVpnTunnel;
|
||||
pub use tailscale::TailscaleTunnel;
|
||||
|
||||
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
|
||||
@@ -104,6 +106,19 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
"openvpn" => {
|
||||
let ov = config
|
||||
.openvpn
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
|
||||
Ok(Some(Box::new(OpenVpnTunnel::new(
|
||||
ov.config_file.clone(),
|
||||
ov.auth_file.clone(),
|
||||
ov.advertise_address.clone(),
|
||||
ov.connect_timeout_secs,
|
||||
ov.extra_args.clone(),
|
||||
))))
|
||||
}
|
||||
"custom" => {
|
||||
let cu = config
|
||||
.custom
|
||||
@@ -116,7 +131,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"),
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +141,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::{
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig,
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
|
||||
TunnelConfig,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
@@ -315,6 +331,46 @@ mod tests {
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openvpn_missing_config_errors() {
|
||||
let cfg = TunnelConfig {
|
||||
provider: "openvpn".into(),
|
||||
..TunnelConfig::default()
|
||||
};
|
||||
assert_tunnel_err(&cfg, "[tunnel.openvpn]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openvpn_with_config_ok() {
|
||||
let cfg = TunnelConfig {
|
||||
provider: "openvpn".into(),
|
||||
openvpn: Some(OpenVpnTunnelConfig {
|
||||
config_file: "client.ovpn".into(),
|
||||
auth_file: None,
|
||||
advertise_address: None,
|
||||
connect_timeout_secs: 30,
|
||||
extra_args: vec![],
|
||||
}),
|
||||
..TunnelConfig::default()
|
||||
};
|
||||
let t = create_tunnel(&cfg).unwrap();
|
||||
assert!(t.is_some());
|
||||
assert_eq!(t.unwrap().name(), "openvpn");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openvpn_tunnel_name() {
|
||||
let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert_eq!(t.name(), "openvpn");
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openvpn_health_false_before_start() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert!(!tunnel.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_shared_no_process_is_ok() {
|
||||
let proc = new_shared_process();
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
|
||||
use anyhow::{bail, Result};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
///
/// Requires the `openvpn` binary installed and accessible. On most systems,
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
///
/// The tunnel exposes the gateway via the VPN network using a configured
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
pub struct OpenVpnTunnel {
    // Path to the `.ovpn` profile, passed to the binary via `--config`.
    config_file: String,
    // Optional credentials file, forwarded via `--auth-user-pass` when set.
    auth_file: Option<String>,
    // Address reported as the public URL once connected; when absent,
    // `start()` falls back to `http://{local_host}:{local_port}`.
    advertise_address: Option<String>,
    // Seconds `start()` waits for the "Initialization Sequence Completed"
    // marker before killing the child and failing.
    connect_timeout_secs: u64,
    // Extra CLI arguments appended verbatim after the built-in flags.
    extra_args: Vec<String>,
    // Shared handle to the spawned child process; `None` until `start()`
    // succeeds, cleared again by `stop()`.
    proc: SharedProcess,
}
|
||||
|
||||
impl OpenVpnTunnel {
|
||||
/// Create a new OpenVPN tunnel instance.
|
||||
///
|
||||
/// * `config_file` — path to the `.ovpn` configuration file.
|
||||
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
|
||||
/// * `advertise_address` — optional public address to advertise once connected.
|
||||
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
|
||||
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
|
||||
pub fn new(
|
||||
config_file: String,
|
||||
auth_file: Option<String>,
|
||||
advertise_address: Option<String>,
|
||||
connect_timeout_secs: u64,
|
||||
extra_args: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config_file,
|
||||
auth_file,
|
||||
advertise_address,
|
||||
connect_timeout_secs,
|
||||
extra_args,
|
||||
proc: new_shared_process(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the openvpn command arguments.
|
||||
fn build_args(&self) -> Vec<String> {
|
||||
let mut args = vec!["--config".to_string(), self.config_file.clone()];
|
||||
|
||||
if let Some(ref auth) = self.auth_file {
|
||||
args.push("--auth-user-pass".to_string());
|
||||
args.push(auth.clone());
|
||||
}
|
||||
|
||||
args.extend(self.extra_args.iter().cloned());
|
||||
args
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
impl Tunnel for OpenVpnTunnel {
    /// Provider identifier used by the factory and in logs.
    fn name(&self) -> &str {
        "openvpn"
    }

    /// Spawn the `openvpn` process and wait for the "Initialization Sequence
    /// Completed" marker on stderr. Returns the public URL on success.
    ///
    /// NOTE(review): openvpn typically writes its log (including this marker)
    /// to stdout when run in the foreground, and stdout is nulled here —
    /// confirm deployments route logs to stderr (e.g. via config), otherwise
    /// this may always time out. TODO confirm.
    async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
        // Validate config file exists before spawning
        if !std::path::Path::new(&self.config_file).exists() {
            bail!("OpenVPN config file not found: {}", self.config_file);
        }

        let args = self.build_args();

        // kill_on_drop ensures the child does not outlive us if this future
        // is dropped before the process handle is stored in `self.proc`.
        let mut child = Command::new("openvpn")
            .args(&args)
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::piped())
            .kill_on_drop(true)
            .spawn()?;

        // Wait for "Initialization Sequence Completed" in stderr
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;

        let mut reader = tokio::io::BufReader::new(stderr).lines();
        let deadline = tokio::time::Instant::now()
            + tokio::time::Duration::from_secs(self.connect_timeout_secs);

        // Poll stderr line by line until the marker appears or the deadline
        // passes. The 3s per-read timeout keeps the loop re-checking the
        // deadline even when openvpn is silent; worst case we overshoot the
        // deadline by up to one 3s read.
        let mut connected = false;
        while tokio::time::Instant::now() < deadline {
            let line =
                tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;

            match line {
                Ok(Ok(Some(l))) => {
                    tracing::debug!("openvpn: {l}");
                    if l.contains("Initialization Sequence Completed") {
                        connected = true;
                        break;
                    }
                }
                // EOF on stderr: the child exited (or closed the pipe) before
                // signalling a successful connection.
                Ok(Ok(None)) => {
                    bail!("OpenVPN process exited before connection was established");
                }
                Ok(Err(e)) => {
                    bail!("Error reading openvpn output: {e}");
                }
                Err(_) => {
                    // Timeout on individual line read, continue waiting
                }
            }
        }

        if !connected {
            // Best-effort kill; `.ok()` because we are already reporting the
            // timeout as the primary error.
            child.kill().await.ok();
            bail!(
                "OpenVPN connection timed out after {}s waiting for initialization",
                self.connect_timeout_secs
            );
        }

        let public_url = self
            .advertise_address
            .clone()
            .unwrap_or_else(|| format!("http://{local_host}:{local_port}"));

        // Drain stderr in background to prevent OS pipe buffer from filling and
        // blocking the openvpn process.
        tokio::spawn(async move {
            while let Ok(Some(line)) = reader.next_line().await {
                tracing::trace!("openvpn: {line}");
            }
        });

        // Store the running child so stop()/health_check()/public_url() can
        // observe it.
        let mut guard = self.proc.lock().await;
        *guard = Some(TunnelProcess {
            child,
            public_url: public_url.clone(),
        });

        Ok(public_url)
    }

    /// Kill the openvpn child process and release its resources.
    async fn stop(&self) -> Result<()> {
        kill_shared(&self.proc).await
    }

    /// Return `true` if the openvpn child process is still running.
    async fn health_check(&self) -> bool {
        let guard = self.proc.lock().await;
        guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
    }

    /// Return the public URL if the tunnel has been started.
    ///
    /// Uses `try_lock` because this accessor is synchronous; if the lock is
    /// momentarily held (e.g. during start/stop) this returns `None`.
    fn public_url(&self) -> Option<String> {
        self.proc
            .try_lock()
            .ok()
            .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor must store every argument verbatim.
    #[test]
    fn constructor_stores_fields() {
        let t = OpenVpnTunnel::new(
            "/etc/openvpn/client.ovpn".into(),
            Some("/etc/openvpn/auth.txt".into()),
            Some("10.8.0.2:42617".into()),
            45,
            vec!["--verb".into(), "3".into()],
        );
        assert_eq!(t.config_file, "/etc/openvpn/client.ovpn");
        assert_eq!(t.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
        assert_eq!(t.advertise_address.as_deref(), Some("10.8.0.2:42617"));
        assert_eq!(t.connect_timeout_secs, 45);
        assert_eq!(t.extra_args, vec!["--verb", "3"]);
    }

    /// Minimal config produces only the `--config <file>` pair.
    #[test]
    fn build_args_basic() {
        let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert_eq!(t.build_args(), ["--config", "client.ovpn"]);
    }

    /// Auth file and extra args are appended after `--config`, in order.
    #[test]
    fn build_args_with_auth_and_extras() {
        let t = OpenVpnTunnel::new(
            "client.ovpn".into(),
            Some("auth.txt".into()),
            None,
            30,
            vec!["--verb".into(), "5".into()],
        );
        let expected = [
            "--config",
            "client.ovpn",
            "--auth-user-pass",
            "auth.txt",
            "--verb",
            "5",
        ];
        assert_eq!(t.build_args(), expected);
    }

    /// No URL is available until `start()` has stored a process handle.
    #[test]
    fn public_url_is_none_before_start() {
        let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(t.public_url().is_none());
    }

    /// Health check reports dead when nothing was ever started.
    #[tokio::test]
    async fn health_check_is_false_before_start() {
        let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(!t.health_check().await);
    }

    /// Stopping an unstarted tunnel is a harmless no-op.
    #[tokio::test]
    async fn stop_without_started_process_is_ok() {
        let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
        assert!(t.stop().await.is_ok());
    }

    /// A missing config file must fail fast, before any process is spawned.
    #[tokio::test]
    async fn start_with_missing_config_file_errors() {
        let t = OpenVpnTunnel::new(
            "/nonexistent/path/to/client.ovpn".into(),
            None,
            None,
            30,
            vec![],
        );
        let err = t
            .start("127.0.0.1", 8080)
            .await
            .expect_err("start must fail when the config file is absent");
        assert!(err.to_string().contains("config file not found"));
    }
}
|
||||
Reference in New Issue
Block a user