Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5dd724edd4 | |||
| a695ca4b9c | |||
| 811fab3b87 | |||
| 1a5d91fe69 | |||
| 6eec1c81b9 | |||
| 602db8bca1 | |||
| 314e1d3ae8 | |||
| 82be05b1e9 | |||
| 1373659058 | |||
| c7f064e866 | |||
| 9c1d63e109 | |||
| 966edf1553 | |||
| a1af84d992 | |||
| 0ad1965081 | |||
| 70e8e7ebcd | |||
| 2bcb82c5b3 | |||
| e211b5c3e3 | |||
| 8691476577 | |||
| e34a804255 | |||
| 6120b3f705 | |||
| f175261e32 | |||
| fd9f66cad7 | |||
| d928ebc92e | |||
| 9fca9f478a | |||
| 7106632b51 | |||
| b834278754 | |||
| 186f6d9797 | |||
| 6cdc92a256 | |||
| 02599dcd3c | |||
| fe64d7ef7e | |||
| 996dbe95cf | |||
| 45f953be6d | |||
| 82f29bbcb1 | |||
| 93b5a0b824 |
@@ -155,11 +155,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
@@ -156,11 +156,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
# Use ubuntu-22.04 for Linux builds to link against glibc 2.35,
|
||||
# ensuring compatibility with Ubuntu 22.04+ (#3573).
|
||||
- os: ubuntu-22.04
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
- os: ubuntu-latest
|
||||
- os: ubuntu-22.04
|
||||
target: aarch64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
ext: tar.gz
|
||||
|
||||
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.2"
|
||||
version = "0.3.4"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.2"
|
||||
version = "0.3.4"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
+202
-10
@@ -195,6 +195,18 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;
|
||||
/// Max characters retained in stored compaction summary.
|
||||
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
|
||||
|
||||
/// Estimate token count for a message history using ~4 chars/token heuristic.
|
||||
/// Includes a small overhead per message for role/framing tokens.
|
||||
fn estimate_history_tokens(history: &[ChatMessage]) -> usize {
|
||||
history
|
||||
.iter()
|
||||
.map(|m| {
|
||||
// ~4 chars per token + ~4 framing tokens per message (role, delimiters)
|
||||
m.content.len().div_ceil(4) + 4
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Minimum interval between progress sends to avoid flooding the draft channel.
|
||||
pub(crate) const PROGRESS_MIN_INTERVAL_MS: u64 = 500;
|
||||
|
||||
@@ -288,6 +300,7 @@ async fn auto_compact_history(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
max_history: usize,
|
||||
max_context_tokens: usize,
|
||||
) -> Result<bool> {
|
||||
let has_system = history.first().map_or(false, |m| m.role == "system");
|
||||
let non_system_count = if has_system {
|
||||
@@ -296,7 +309,10 @@ async fn auto_compact_history(
|
||||
history.len()
|
||||
};
|
||||
|
||||
if non_system_count <= max_history {
|
||||
let estimated_tokens = estimate_history_tokens(history);
|
||||
|
||||
// Trigger compaction when either token budget OR message count is exceeded.
|
||||
if estimated_tokens <= max_context_tokens && non_system_count <= max_history {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
@@ -307,7 +323,16 @@ async fn auto_compact_history(
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let compact_end = start + compact_count;
|
||||
let mut compact_end = start + compact_count;
|
||||
|
||||
// Snap compact_end to a user-turn boundary so we don't split mid-conversation.
|
||||
while compact_end > start && history.get(compact_end).map_or(false, |m| m.role != "user") {
|
||||
compact_end -= 1;
|
||||
}
|
||||
if compact_end <= start {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
@@ -2635,6 +2660,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!(
|
||||
"\u{274c} {}: {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(&scrub_credentials(&cancelled), 200)
|
||||
))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
call.name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2662,11 +2696,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
arguments: tool_args.clone(),
|
||||
};
|
||||
|
||||
// Only prompt interactively on CLI; auto-approve on other channels.
|
||||
let decision = if channel_name == "cli" {
|
||||
mgr.prompt_cli(&request)
|
||||
// Interactive CLI: prompt the operator.
|
||||
// Non-interactive (channels): auto-deny since no operator
|
||||
// is present to approve.
|
||||
let decision = if mgr.is_non_interactive() {
|
||||
ApprovalResponse::No
|
||||
} else {
|
||||
ApprovalResponse::Yes
|
||||
mgr.prompt_cli(&request)
|
||||
};
|
||||
|
||||
mgr.record_decision(&tool_name, &tool_args, decision, channel_name);
|
||||
@@ -2687,6 +2723,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, denied))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2723,6 +2764,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"deduplicated": true,
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, duplicate))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2825,13 +2871,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
// ── Progress: tool completion ───────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
let secs = outcome.duration.as_secs();
|
||||
let icon = if outcome.success {
|
||||
"\u{2705}"
|
||||
let progress_msg = if outcome.success {
|
||||
format!("\u{2705} {} ({secs}s)\n", call.name)
|
||||
} else if let Some(ref reason) = outcome.error_reason {
|
||||
format!(
|
||||
"\u{274c} {} ({secs}s): {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(reason, 200)
|
||||
)
|
||||
} else {
|
||||
"\u{274c}"
|
||||
format!("\u{274c} {} ({secs}s)\n", call.name)
|
||||
};
|
||||
tracing::debug!(tool = %call.name, secs, "Sending progress complete to draft");
|
||||
let _ = tx.send(format!("{icon} {} ({secs}s)\n", call.name)).await;
|
||||
let _ = tx.send(progress_msg).await;
|
||||
}
|
||||
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
@@ -3508,6 +3560,7 @@ pub async fn run(
|
||||
provider.as_ref(),
|
||||
model_name,
|
||||
config.agent.max_history_messages,
|
||||
config.agent.max_context_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -4109,6 +4162,52 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that always returns a failure with a given error reason.
|
||||
struct FailingTool {
|
||||
tool_name: String,
|
||||
error_reason: String,
|
||||
}
|
||||
|
||||
impl FailingTool {
|
||||
fn new(name: &str, error_reason: &str) -> Self {
|
||||
Self {
|
||||
tool_name: name.to_string(),
|
||||
error_reason: error_reason.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FailingTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"A tool that always fails for testing failure surfacing"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(self.error_reason.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
@@ -6449,4 +6548,97 @@ Let me check the result."#;
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "BROWSE the site");
|
||||
assert_eq!(result.len(), 1);
|
||||
}
|
||||
|
||||
// ── Token-based compaction tests ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_empty() {
|
||||
assert_eq!(super::estimate_history_tokens(&[]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_single_message() {
|
||||
let history = vec![ChatMessage::user("hello world")]; // 11 chars
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
// 11.div_ceil(4) + 4 = 3 + 4 = 7
|
||||
assert_eq!(tokens, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_multiple_messages() {
|
||||
let history = vec![
|
||||
ChatMessage::system("You are helpful."), // 16 chars → 4 + 4 = 8
|
||||
ChatMessage::user("What is Rust?"), // 13 chars → 4 + 4 = 8
|
||||
ChatMessage::assistant("A language."), // 11 chars → 3 + 4 = 7
|
||||
];
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
assert_eq!(tokens, 23);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_surfaces_tool_failure_reason_in_on_delta() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"failing_shell","arguments":{"command":"rm -rf /"}}
|
||||
</tool_call>"#,
|
||||
"I could not execute that command.",
|
||||
]);
|
||||
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(FailingTool::new(
|
||||
"failing_shell",
|
||||
"Command not allowed by security policy: rm -rf /",
|
||||
))];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("delete everything"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"telegram",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
Some(tx),
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
// Collect all messages sent to the on_delta channel.
|
||||
let mut deltas = Vec::new();
|
||||
while let Ok(msg) = rx.try_recv() {
|
||||
deltas.push(msg);
|
||||
}
|
||||
|
||||
let all_deltas = deltas.join("");
|
||||
|
||||
// The failure reason should appear in the progress messages.
|
||||
assert!(
|
||||
all_deltas.contains("Command not allowed by security policy"),
|
||||
"on_delta messages should include the tool failure reason, got: {all_deltas}"
|
||||
);
|
||||
|
||||
// Should also contain the cross mark (❌) icon to indicate failure.
|
||||
assert!(
|
||||
all_deltas.contains('\u{274c}'),
|
||||
"on_delta messages should include ❌ for failed tool calls, got: {all_deltas}"
|
||||
);
|
||||
|
||||
assert_eq!(result, "I could not execute that command.");
|
||||
}
|
||||
}
|
||||
|
||||
+128
-4
@@ -44,11 +44,18 @@ pub struct ApprovalLogEntry {
|
||||
|
||||
// ── ApprovalManager ──────────────────────────────────────────────
|
||||
|
||||
/// Manages the interactive approval workflow.
|
||||
/// Manages the approval workflow for tool calls.
|
||||
///
|
||||
/// - Checks config-level `auto_approve` / `always_ask` lists
|
||||
/// - Maintains a session-scoped "always" allowlist
|
||||
/// - Records an audit trail of all decisions
|
||||
///
|
||||
/// Two modes:
|
||||
/// - **Interactive** (CLI): tools needing approval trigger a stdin prompt.
|
||||
/// - **Non-interactive** (channels): tools needing approval are auto-denied
|
||||
/// because there is no interactive operator to approve them. `auto_approve`
|
||||
/// policy is still enforced, and `always_ask` / supervised-default tools are
|
||||
/// denied rather than silently allowed.
|
||||
pub struct ApprovalManager {
|
||||
/// Tools that never need approval (from config).
|
||||
auto_approve: HashSet<String>,
|
||||
@@ -56,6 +63,9 @@ pub struct ApprovalManager {
|
||||
always_ask: HashSet<String>,
|
||||
/// Autonomy level from config.
|
||||
autonomy_level: AutonomyLevel,
|
||||
/// When `true`, tools that would require interactive approval are
|
||||
/// auto-denied instead. Used for channel-driven (non-CLI) runs.
|
||||
non_interactive: bool,
|
||||
/// Session-scoped allowlist built from "Always" responses.
|
||||
session_allowlist: Mutex<HashSet<String>>,
|
||||
/// Audit trail of approval decisions.
|
||||
@@ -63,17 +73,40 @@ pub struct ApprovalManager {
|
||||
}
|
||||
|
||||
impl ApprovalManager {
|
||||
/// Create from autonomy config.
|
||||
/// Create an interactive (CLI) approval manager from autonomy config.
|
||||
pub fn from_config(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: false,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-interactive approval manager for channel-driven runs.
|
||||
///
|
||||
/// Enforces the same `auto_approve` / `always_ask` / supervised policies
|
||||
/// as the CLI manager, but tools that would require interactive approval
|
||||
/// are auto-denied instead of prompting (since there is no operator).
|
||||
pub fn for_non_interactive(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: true,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when this manager operates in non-interactive mode
|
||||
/// (i.e. for channel-driven runs where no operator can approve).
|
||||
pub fn is_non_interactive(&self) -> bool {
|
||||
self.non_interactive
|
||||
}
|
||||
|
||||
/// Check whether a tool call requires interactive approval.
|
||||
///
|
||||
/// Returns `true` if the call needs a prompt, `false` if it can proceed.
|
||||
@@ -147,8 +180,8 @@ impl ApprovalManager {
|
||||
|
||||
/// Prompt the user on the CLI and return their decision.
|
||||
///
|
||||
/// For non-CLI channels, returns `Yes` automatically (interactive
|
||||
/// approval is only supported on CLI for now).
|
||||
/// Only called for interactive (CLI) managers. Non-interactive managers
|
||||
/// auto-deny in the tool-call loop before reaching this point.
|
||||
pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse {
|
||||
prompt_cli_interactive(request)
|
||||
}
|
||||
@@ -401,6 +434,97 @@ mod tests {
|
||||
assert!(summary.contains("just a string"));
|
||||
}
|
||||
|
||||
// ── non-interactive (channel) mode ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn non_interactive_manager_reports_non_interactive() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_manager_reports_interactive() {
|
||||
let mgr = ApprovalManager::from_config(&supervised_config());
|
||||
assert!(!mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_auto_approve_tools_skip_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// auto_approve tools (file_read, memory_recall) should not need approval.
|
||||
assert!(!mgr.needs_approval("file_read"));
|
||||
assert!(!mgr.needs_approval("memory_recall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_tools_need_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// always_ask tools (shell) still report as needing approval,
|
||||
// so the tool-call loop will auto-deny them in non-interactive mode.
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_unknown_tools_need_approval_in_supervised() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// Unknown tools in supervised mode need approval (will be auto-denied
|
||||
// by the tool-call loop for non-interactive managers).
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
assert!(mgr.needs_approval("http_request"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_full_autonomy_never_needs_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&full_config());
|
||||
// Full autonomy means no approval needed, even in non-interactive mode.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
assert!(!mgr.needs_approval("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_readonly_never_needs_approval() {
|
||||
let config = AutonomyConfig {
|
||||
level: AutonomyLevel::ReadOnly,
|
||||
..AutonomyConfig::default()
|
||||
};
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
// ReadOnly blocks execution elsewhere; approval manager does not prompt.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_session_allowlist_still_works() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
|
||||
// Simulate an "Always" decision (would come from a prior channel run
|
||||
// if the tool was auto-approved somehow, e.g. via config change).
|
||||
mgr.record_decision(
|
||||
"file_write",
|
||||
&serde_json::json!({"path": "test.txt"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_overrides_session_allowlist() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
|
||||
mgr.record_decision(
|
||||
"shell",
|
||||
&serde_json::json!({"command": "ls"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
// shell is in always_ask, so it still needs approval even after "Always".
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
// ── ApprovalResponse serde ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -746,7 +746,7 @@ impl Channel for MatrixChannel {
|
||||
MessageType::Notice(content) => (content.body.clone(), None),
|
||||
MessageType::Image(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
(format!("[image: {}]", content.body), dl)
|
||||
(format!("[IMAGE:{}]", content.body), dl)
|
||||
}
|
||||
MessageType::File(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
@@ -888,7 +888,7 @@ impl Channel for MatrixChannel {
|
||||
sender: sender.clone(),
|
||||
reply_target: format!("{}||{}", sender, room.room_id()),
|
||||
content: body,
|
||||
channel: format!("matrix:{}", room.room_id()),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
|
||||
+625
-2
@@ -31,6 +31,7 @@ pub mod nextcloud_talk;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub mod nostr;
|
||||
pub mod qq;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
pub mod telegram;
|
||||
@@ -75,6 +76,7 @@ pub use whatsapp::WhatsAppChannel;
|
||||
pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
@@ -310,8 +312,15 @@ struct ChannelRuntimeContext {
|
||||
non_cli_excluded_tools: Arc<Vec<String>>,
|
||||
tool_call_dedup_exempt: Arc<Vec<String>>,
|
||||
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
|
||||
query_classification: crate::config::QueryClassificationConfig,
|
||||
ack_reactions: bool,
|
||||
show_tool_calls: bool,
|
||||
session_store: Option<Arc<session_store::SessionStore>>,
|
||||
/// Non-interactive approval manager for channel-driven runs.
|
||||
/// Enforces `auto_approve` / `always_ask` / supervised policy from
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -965,6 +974,13 @@ fn proactive_trim_turns(turns: &mut Vec<ChatMessage>, budget: usize) -> usize {
|
||||
}
|
||||
|
||||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||||
// Persist to JSONL before adding to in-memory history.
|
||||
if let Some(ref store) = ctx.session_store {
|
||||
if let Err(e) = store.append(sender_key, &turn) {
|
||||
tracing::warn!("Failed to persist session turn: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
@@ -1777,7 +1793,31 @@ async fn process_channel_message(
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
let mut route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
|
||||
// ── Query classification: override route when a rule matches ──
|
||||
if let Some(hint) = crate::agent::classifier::classify(&ctx.query_classification, &msg.content)
|
||||
{
|
||||
if let Some(matched_route) = ctx
|
||||
.model_routes
|
||||
.iter()
|
||||
.find(|r| r.hint.eq_ignore_ascii_case(&hint))
|
||||
{
|
||||
tracing::info!(
|
||||
target: "query_classification",
|
||||
hint = hint.as_str(),
|
||||
provider = matched_route.provider.as_str(),
|
||||
model = matched_route.model.as_str(),
|
||||
channel = %msg.channel,
|
||||
"Channel message classified — overriding route"
|
||||
);
|
||||
route = ChannelRouteSelection {
|
||||
provider: matched_route.provider.clone(),
|
||||
model: matched_route.model.clone(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
@@ -2016,7 +2056,7 @@ async fn process_channel_message(
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
None,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
@@ -2186,6 +2226,29 @@ async fn process_channel_message(
|
||||
&history_key,
|
||||
ChatMessage::assistant(&history_response),
|
||||
);
|
||||
|
||||
// Fire-and-forget LLM-driven memory consolidation.
|
||||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let provider = Arc::clone(&ctx.provider);
|
||||
let model = ctx.model.to_string();
|
||||
let memory = Arc::clone(&ctx.memory);
|
||||
let user_msg = msg.content.clone();
|
||||
let assistant_resp = delivered_response.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::memory::consolidation::consolidate_turn(
|
||||
provider.as_ref(),
|
||||
&model,
|
||||
memory.as_ref(),
|
||||
&user_msg,
|
||||
&assistant_resp,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::debug!("Memory consolidation skipped: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
@@ -3203,6 +3266,8 @@ fn collect_configured_channels(
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||||
eprintln!(" ⚠ WhatsApp Web is configured but the 'whatsapp-web' feature is not compiled in.");
|
||||
eprintln!(" Rebuild with: cargo build --features whatsapp-web");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -3803,10 +3868,46 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
non_cli_excluded_tools: Arc::new(config.autonomy.non_cli_excluded_tools.clone()),
|
||||
tool_call_dedup_exempt: Arc::new(config.agent.tool_call_dedup_exempt.clone()),
|
||||
model_routes: Arc::new(config.model_routes.clone()),
|
||||
query_classification: config.query_classification.clone(),
|
||||
ack_reactions: config.channels_config.ack_reactions,
|
||||
show_tool_calls: config.channels_config.show_tool_calls,
|
||||
session_store: if config.channels_config.session_persistence {
|
||||
match session_store::SessionStore::new(&config.workspace_dir) {
|
||||
Ok(store) => {
|
||||
tracing::info!("📂 Session persistence enabled");
|
||||
Some(Arc::new(store))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Session persistence disabled: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
if let Some(ref store) = runtime_ctx.session_store {
|
||||
let mut hydrated = 0usize;
|
||||
let mut histories = runtime_ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
for key in store.list_sessions() {
|
||||
let msgs = store.load(&key);
|
||||
if !msgs.is_empty() {
|
||||
hydrated += 1;
|
||||
histories.insert(key, msgs);
|
||||
}
|
||||
}
|
||||
drop(histories);
|
||||
if hydrated > 0 {
|
||||
tracing::info!("📂 Restored {hydrated} session(s) from disk");
|
||||
}
|
||||
}
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||
|
||||
// Wait for all channel tasks
|
||||
@@ -4070,8 +4171,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4173,8 +4279,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4232,8 +4343,13 @@ mod tests {
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4749,8 +4865,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4816,8 +4937,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4897,8 +5023,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4963,8 +5094,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5039,8 +5175,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5135,8 +5276,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5213,8 +5359,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5306,8 +5457,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5384,8 +5540,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5452,8 +5613,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5631,8 +5797,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -5718,8 +5889,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5817,11 +5993,16 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
},
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5919,8 +6100,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6000,8 +6186,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6066,8 +6257,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6690,8 +6886,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6782,8 +6983,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6874,8 +7080,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7430,8 +7641,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7503,8 +7719,13 @@ This is an example JSON object for profile settings."#;
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7584,6 +7805,408 @@ This is an example JSON object for profile settings."#;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Query classification in channel message processing ─────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_applies_query_classification_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Vision provider should have been called instead of the default.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
vision_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["gpt-4-vision".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_disabled_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification is disabled — matching keyword should NOT trigger reroute.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: false,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-disabled".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "please analyze-image from the dataset".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since classification is disabled.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_no_match_uses_default_route() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let vision_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let vision_provider: Arc<dyn Provider> = vision_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("vision-provider".to_string(), vision_provider);
|
||||
|
||||
// Classification enabled with a rule that won't match the message.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![crate::config::schema::ClassificationRule {
|
||||
hint: "vision".into(),
|
||||
keywords: vec!["analyze-image".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
};
|
||||
|
||||
let model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "vision".into(),
|
||||
provider: "vision-provider".into(),
|
||||
model: "gpt-4-vision".into(),
|
||||
api_key: None,
|
||||
}];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-nomatch".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "just a regular text message".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Default provider should be used since no classification rule matched.
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(vision_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_classification_priority_selects_highest() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let fast_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let fast_provider: Arc<dyn Provider> = fast_provider_impl.clone();
|
||||
let code_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let code_provider: Arc<dyn Provider> = code_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("fast-provider".to_string(), fast_provider);
|
||||
provider_cache_seed.insert("code-provider".to_string(), code_provider);
|
||||
|
||||
// Both rules match "code" keyword, but "code" rule has higher priority.
|
||||
let classification_config = crate::config::QueryClassificationConfig {
|
||||
enabled: true,
|
||||
rules: vec![
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 1,
|
||||
..Default::default()
|
||||
},
|
||||
crate::config::schema::ClassificationRule {
|
||||
hint: "code".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 10,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let model_routes = vec![
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "fast".into(),
|
||||
provider: "fast-provider".into(),
|
||||
model: "fast-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
crate::config::ModelRouteConfig {
|
||||
hint: "code".into(),
|
||||
provider: "code-provider".into(),
|
||||
model: "code-model".into(),
|
||||
api_key: None,
|
||||
},
|
||||
];
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: InterruptOnNewMessageConfig {
|
||||
telegram: false,
|
||||
slack: false,
|
||||
},
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(model_routes),
|
||||
query_classification: classification_config,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-qc-prio".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "write some code for me".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Higher-priority "code" rule (priority=10) should win over "fast" (priority=1).
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(fast_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(code_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
code_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["code-model".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_channel_by_id_unconfigured_telegram_returns_error() {
|
||||
let config = Config::default();
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
//! JSONL-based session persistence for channel conversations.
|
||||
//!
|
||||
//! Each session (keyed by `channel_sender` or `channel_thread_sender`) is stored
|
||||
//! as an append-only JSONL file in `{workspace}/sessions/`. Messages are appended
|
||||
//! one-per-line as JSON, never modifying old lines. On daemon restart, sessions
|
||||
//! are loaded from disk to restore conversation context.
|
||||
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Append-only JSONL session store for channel conversations.
|
||||
pub struct SessionStore {
|
||||
sessions_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Create a new session store, ensuring the sessions directory exists.
|
||||
pub fn new(workspace_dir: &Path) -> std::io::Result<Self> {
|
||||
let sessions_dir = workspace_dir.join("sessions");
|
||||
std::fs::create_dir_all(&sessions_dir)?;
|
||||
Ok(Self { sessions_dir })
|
||||
}
|
||||
|
||||
/// Compute the file path for a session key, sanitizing for filesystem safety.
|
||||
fn session_path(&self, session_key: &str) -> PathBuf {
|
||||
let safe_key: String = session_key
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '_' || c == '-' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.sessions_dir.join(format!("{safe_key}.jsonl"))
|
||||
}
|
||||
|
||||
/// Load all messages for a session from its JSONL file.
|
||||
/// Returns an empty vec if the file does not exist or is unreadable.
|
||||
pub fn load(&self, session_key: &str) -> Vec<ChatMessage> {
|
||||
let path = self.session_path(session_key);
|
||||
let file = match std::fs::File::open(&path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
|
||||
messages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Append a single message to the session JSONL file.
|
||||
pub fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
|
||||
let path = self.session_path(session_key);
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)?;
|
||||
|
||||
let json = serde_json::to_string(message)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
writeln!(file, "{json}")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session keys that have files on disk.
|
||||
pub fn list_sessions(&self) -> Vec<String> {
|
||||
let entries = match std::fs::read_dir(&self.sessions_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
entries
|
||||
.filter_map(|entry| {
|
||||
let entry = entry.ok()?;
|
||||
let name = entry.file_name().into_string().ok()?;
|
||||
name.strip_suffix(".jsonl").map(String::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn round_trip_append_and_load() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
store
|
||||
.append("telegram_user123", &ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
store
|
||||
.append("telegram_user123", &ChatMessage::assistant("hi there"))
|
||||
.unwrap();
|
||||
|
||||
let messages = store.load("telegram_user123");
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].role, "user");
|
||||
assert_eq!(messages[0].content, "hello");
|
||||
assert_eq!(messages[1].role, "assistant");
|
||||
assert_eq!(messages[1].content, "hi there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_nonexistent_session_returns_empty() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
let messages = store.load("nonexistent");
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_sanitization() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
// Keys with special chars should be sanitized
|
||||
store
|
||||
.append("slack/thread:123/user", &ChatMessage::user("test"))
|
||||
.unwrap();
|
||||
|
||||
let messages = store.load("slack/thread:123/user");
|
||||
assert_eq!(messages.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_sessions_returns_keys() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
store
|
||||
.append("telegram_alice", &ChatMessage::user("hi"))
|
||||
.unwrap();
|
||||
store
|
||||
.append("discord_bob", &ChatMessage::user("hey"))
|
||||
.unwrap();
|
||||
|
||||
let mut sessions = store.list_sessions();
|
||||
sessions.sort();
|
||||
assert_eq!(sessions.len(), 2);
|
||||
assert!(sessions.contains(&"discord_bob".to_string()));
|
||||
assert!(sessions.contains(&"telegram_alice".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn append_is_truly_append_only() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
let key = "test_session";
|
||||
|
||||
store.append(key, &ChatMessage::user("msg1")).unwrap();
|
||||
store.append(key, &ChatMessage::user("msg2")).unwrap();
|
||||
|
||||
// Read raw file to verify append-only format
|
||||
let path = store.session_path(key);
|
||||
let content = std::fs::read_to_string(&path).unwrap();
|
||||
let lines: Vec<&str> = content.trim().lines().collect();
|
||||
assert_eq!(lines.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_corrupt_lines_gracefully() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
let key = "corrupt_test";
|
||||
|
||||
// Write valid message + corrupt line + valid message
|
||||
let path = store.session_path(key);
|
||||
std::fs::create_dir_all(path.parent().unwrap()).unwrap();
|
||||
let mut file = std::fs::File::create(&path).unwrap();
|
||||
writeln!(file, r#"{{"role":"user","content":"hello"}}"#).unwrap();
|
||||
writeln!(file, "this is not valid json").unwrap();
|
||||
writeln!(file, r#"{{"role":"assistant","content":"world"}}"#).unwrap();
|
||||
|
||||
let messages = store.load(key);
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].content, "hello");
|
||||
assert_eq!(messages[1].content, "world");
|
||||
}
|
||||
}
|
||||
+8
-7
@@ -12,13 +12,14 @@ pub use schema::{
|
||||
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
NextcloudTalkConfig, NodesConfig, ObservabilityConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig,
|
||||
WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+188
-2
@@ -232,6 +232,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agents: HashMap<String, DelegateAgentConfig>,
|
||||
|
||||
/// Swarm configurations for multi-agent orchestration.
|
||||
#[serde(default)]
|
||||
pub swarms: HashMap<String, SwarmConfig>,
|
||||
|
||||
/// Hooks configuration (lifecycle hooks and built-in hook toggles).
|
||||
#[serde(default)]
|
||||
pub hooks: HooksConfig,
|
||||
@@ -319,6 +323,44 @@ pub struct DelegateAgentConfig {
|
||||
pub max_iterations: usize,
|
||||
}
|
||||
|
||||
// ── Swarms ──────────────────────────────────────────────────────
|
||||
|
||||
/// Orchestration strategy for a swarm of agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SwarmStrategy {
|
||||
/// Run agents sequentially; each agent's output feeds into the next.
|
||||
Sequential,
|
||||
/// Run agents in parallel; collect all outputs.
|
||||
Parallel,
|
||||
/// Use the LLM to pick the best agent for the task.
|
||||
Router,
|
||||
}
|
||||
|
||||
/// Configuration for a swarm of coordinated agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SwarmConfig {
|
||||
/// Ordered list of agent names (must reference keys in `agents`).
|
||||
pub agents: Vec<String>,
|
||||
/// Orchestration strategy.
|
||||
pub strategy: SwarmStrategy,
|
||||
/// System prompt for router strategy (used to pick the best agent).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub router_prompt: Option<String>,
|
||||
/// Optional description shown to the LLM when choosing swarms.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Maximum total timeout for the swarm execution in seconds.
|
||||
#[serde(default = "default_swarm_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
const DEFAULT_SWARM_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn default_swarm_timeout_secs() -> u64 {
|
||||
DEFAULT_SWARM_TIMEOUT_SECS
|
||||
}
|
||||
|
||||
/// Valid temperature range for all paths (config, CLI, env override).
|
||||
pub const TEMPERATURE_RANGE: std::ops::RangeInclusive<f64> = 0.0..=2.0;
|
||||
|
||||
@@ -791,6 +833,11 @@ pub struct AgentConfig {
|
||||
/// Maximum conversation history messages retained per session. Default: `50`.
|
||||
#[serde(default = "default_agent_max_history_messages")]
|
||||
pub max_history_messages: usize,
|
||||
/// Maximum estimated tokens for conversation history before compaction triggers.
|
||||
/// Uses ~4 chars/token heuristic. When this threshold is exceeded, older messages
|
||||
/// are summarized to preserve context while staying within budget. Default: `32000`.
|
||||
#[serde(default = "default_agent_max_context_tokens")]
|
||||
pub max_context_tokens: usize,
|
||||
/// Enable parallel tool execution within a single iteration. Default: `false`.
|
||||
#[serde(default)]
|
||||
pub parallel_tools: bool,
|
||||
@@ -817,6 +864,10 @@ fn default_agent_max_history_messages() -> usize {
|
||||
50
|
||||
}
|
||||
|
||||
fn default_agent_max_context_tokens() -> usize {
|
||||
32_000
|
||||
}
|
||||
|
||||
fn default_agent_tool_dispatcher() -> String {
|
||||
"auto".into()
|
||||
}
|
||||
@@ -827,6 +878,7 @@ impl Default for AgentConfig {
|
||||
compact_context: false,
|
||||
max_tool_iterations: default_agent_max_tool_iterations(),
|
||||
max_history_messages: default_agent_max_history_messages(),
|
||||
max_context_tokens: default_agent_max_context_tokens(),
|
||||
parallel_tools: false,
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
@@ -1413,6 +1465,10 @@ pub struct HttpRequestConfig {
|
||||
/// Request timeout in seconds (default: 30)
|
||||
#[serde(default = "default_http_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
/// Allow requests to private/LAN hosts (RFC 1918, loopback, link-local, .local).
|
||||
/// Default: false (deny private hosts for SSRF protection).
|
||||
#[serde(default)]
|
||||
pub allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl Default for HttpRequestConfig {
|
||||
@@ -1422,6 +1478,7 @@ impl Default for HttpRequestConfig {
|
||||
allowed_domains: vec![],
|
||||
max_response_size: default_http_max_response_size(),
|
||||
timeout_secs: default_http_timeout_secs(),
|
||||
allow_private_hosts: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2960,10 +3017,10 @@ impl Default for CronConfig {
|
||||
|
||||
/// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section).
|
||||
///
|
||||
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`.
|
||||
/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, `"custom"`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TunnelConfig {
|
||||
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`.
|
||||
/// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"openvpn"`, or `"custom"`. Default: `"none"`.
|
||||
pub provider: String,
|
||||
|
||||
/// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`).
|
||||
@@ -2978,6 +3035,10 @@ pub struct TunnelConfig {
|
||||
#[serde(default)]
|
||||
pub ngrok: Option<NgrokTunnelConfig>,
|
||||
|
||||
/// OpenVPN tunnel configuration (used when `provider = "openvpn"`).
|
||||
#[serde(default)]
|
||||
pub openvpn: Option<OpenVpnTunnelConfig>,
|
||||
|
||||
/// Custom tunnel command configuration (used when `provider = "custom"`).
|
||||
#[serde(default)]
|
||||
pub custom: Option<CustomTunnelConfig>,
|
||||
@@ -2990,6 +3051,7 @@ impl Default for TunnelConfig {
|
||||
cloudflare: None,
|
||||
tailscale: None,
|
||||
ngrok: None,
|
||||
openvpn: None,
|
||||
custom: None,
|
||||
}
|
||||
}
|
||||
@@ -3018,6 +3080,36 @@ pub struct NgrokTunnelConfig {
|
||||
pub domain: Option<String>,
|
||||
}
|
||||
|
||||
/// OpenVPN tunnel configuration (`[tunnel.openvpn]`).
|
||||
///
|
||||
/// Required when `tunnel.provider = "openvpn"`. Omitting this section entirely
|
||||
/// preserves previous behavior. Setting `tunnel.provider = "none"` (or removing
|
||||
/// the `[tunnel.openvpn]` block) cleanly reverts to no-tunnel mode.
|
||||
///
|
||||
/// Defaults: `connect_timeout_secs = 30`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenVpnTunnelConfig {
|
||||
/// Path to `.ovpn` configuration file (must not be empty).
|
||||
pub config_file: String,
|
||||
/// Optional path to auth credentials file (`--auth-user-pass`).
|
||||
#[serde(default)]
|
||||
pub auth_file: Option<String>,
|
||||
/// Advertised address once VPN is connected (e.g., `"10.8.0.2:42617"`).
|
||||
/// When omitted the tunnel falls back to `http://{local_host}:{local_port}`.
|
||||
#[serde(default)]
|
||||
pub advertise_address: Option<String>,
|
||||
/// Connection timeout in seconds (default: 30, must be > 0).
|
||||
#[serde(default = "default_openvpn_timeout")]
|
||||
pub connect_timeout_secs: u64,
|
||||
/// Extra openvpn CLI arguments forwarded verbatim.
|
||||
#[serde(default)]
|
||||
pub extra_args: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_openvpn_timeout() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CustomTunnelConfig {
|
||||
/// Command template to start the tunnel. Use {port} and {host} placeholders.
|
||||
@@ -3052,6 +3144,7 @@ impl<T: ChannelConfig> crate::config::traits::ConfigHandle for ConfigWrapper<T>
|
||||
///
|
||||
/// Each channel sub-section (e.g. `telegram`, `discord`) is optional;
|
||||
/// setting it to `Some(...)` enables that channel.
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ChannelsConfig {
|
||||
/// Enable the CLI interactive channel. Default: `true`.
|
||||
@@ -3114,6 +3207,10 @@ pub struct ChannelsConfig {
|
||||
/// not forwarded as individual channel messages. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub show_tool_calls: bool,
|
||||
/// Persist channel conversation history to JSONL files so sessions survive
|
||||
/// daemon restarts. Files are stored in `{workspace}/sessions/`. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub session_persistence: bool,
|
||||
}
|
||||
|
||||
impl ChannelsConfig {
|
||||
@@ -3248,6 +3345,7 @@ impl Default for ChannelsConfig {
|
||||
message_timeout_secs: default_channel_message_timeout_secs(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4181,6 +4279,7 @@ impl Default for Config {
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
@@ -5055,6 +5154,20 @@ impl Config {
|
||||
/// Called after TOML deserialization and env-override application to catch
|
||||
/// obviously invalid values early instead of failing at arbitrary runtime points.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Tunnel â OpenVPN
|
||||
if self.tunnel.provider.trim() == "openvpn" {
|
||||
let openvpn = self.tunnel.openvpn.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("tunnel.provider='openvpn' requires [tunnel.openvpn]")
|
||||
})?;
|
||||
|
||||
if openvpn.config_file.trim().is_empty() {
|
||||
anyhow::bail!("tunnel.openvpn.config_file must not be empty");
|
||||
}
|
||||
if openvpn.connect_timeout_secs == 0 {
|
||||
anyhow::bail!("tunnel.openvpn.connect_timeout_secs must be greater than 0");
|
||||
}
|
||||
}
|
||||
|
||||
// Gateway
|
||||
if self.gateway.host.trim().is_empty() {
|
||||
anyhow::bail!("gateway.host must not be empty");
|
||||
@@ -6269,6 +6382,7 @@ default_temperature = 0.7
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
},
|
||||
memory: MemoryConfig::default(),
|
||||
storage: StorageConfig::default(),
|
||||
@@ -6287,6 +6401,7 @@ default_temperature = 0.7
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -6578,6 +6693,7 @@ tool_dispatcher = "xml"
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -6983,6 +7099,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -7210,6 +7327,7 @@ channel_id = "C123"
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -9328,4 +9446,72 @@ require_otp_to_resume = true
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_strategy_roundtrip() {
|
||||
let cases = vec![
|
||||
(SwarmStrategy::Sequential, "\"sequential\""),
|
||||
(SwarmStrategy::Parallel, "\"parallel\""),
|
||||
(SwarmStrategy::Router, "\"router\""),
|
||||
];
|
||||
for (variant, expected_json) in &cases {
|
||||
let serialized = serde_json::to_string(variant).expect("serialize");
|
||||
assert_eq!(&serialized, expected_json, "variant: {variant:?}");
|
||||
let deserialized: SwarmStrategy =
|
||||
serde_json::from_str(expected_json).expect("deserialize");
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_with_defaults() {
|
||||
let toml_str = r#"
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents, vec!["researcher", "writer"]);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Sequential);
|
||||
assert!(config.router_prompt.is_none());
|
||||
assert!(config.description.is_none());
|
||||
assert_eq!(config.timeout_secs, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_full() {
|
||||
let toml_str = r#"
|
||||
agents = ["a", "b", "c"]
|
||||
strategy = "router"
|
||||
router_prompt = "Pick the best."
|
||||
description = "Multi-agent router"
|
||||
timeout_secs = 120
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 3);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Router);
|
||||
assert_eq!(config.router_prompt.as_deref(), Some("Pick the best."));
|
||||
assert_eq!(config.description.as_deref(), Some("Multi-agent router"));
|
||||
assert_eq!(config.timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_with_swarms_section_deserializes() {
|
||||
let toml_str = r#"
|
||||
[agents.researcher]
|
||||
provider = "ollama"
|
||||
model = "llama3"
|
||||
|
||||
[agents.writer]
|
||||
provider = "openrouter"
|
||||
model = "claude-sonnet"
|
||||
|
||||
[swarms.pipeline]
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: Config = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 2);
|
||||
assert_eq!(config.swarms.len(), 1);
|
||||
assert!(config.swarms.contains_key("pipeline"));
|
||||
}
|
||||
}
|
||||
|
||||
+172
-21
@@ -152,44 +152,122 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<(
|
||||
crate::CronCommands::Add {
|
||||
expression,
|
||||
tz,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Cron {
|
||||
expr: expression,
|
||||
tz,
|
||||
};
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added agent cron job {}", job.id);
|
||||
println!(" Expr : {}", job.expression);
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddAt { at, command } => {
|
||||
crate::CronCommands::AddAt { at, agent, command } => {
|
||||
let at = chrono::DateTime::parse_from_rfc3339(&at)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid RFC3339 timestamp for --at: {e}"))?
|
||||
.with_timezone(&chrono::Utc);
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddEvery { every_ms, command } => {
|
||||
crate::CronCommands::AddEvery {
|
||||
every_ms,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Every { every_ms };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added interval agent cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt : {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Once { delay, command } => {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
crate::CronCommands::Once {
|
||||
delay,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
if agent {
|
||||
let duration = parse_delay(&delay)?;
|
||||
let at = chrono::Utc::now() + duration;
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Update {
|
||||
@@ -686,4 +764,77 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("blocked by security policy"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_creates_agent_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
assert_eq!(
|
||||
jobs[0].prompt.as_deref(),
|
||||
Some("Check server health: disk space, memory, CPU load")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_bypasses_shell_security_validation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut config = test_config(&tmp);
|
||||
config.autonomy.allowed_commands = vec!["echo".into()];
|
||||
config.autonomy.level = crate::security::AutonomyLevel::Supervised;
|
||||
|
||||
// Without --agent, a natural language string would be blocked by shell
|
||||
// security policy. With --agent, it routes to agent job and skips
|
||||
// shell validation entirely.
|
||||
let result = handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_without_agent_flag_defaults_to_shell_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/5 * * * *".into(),
|
||||
tz: None,
|
||||
agent: false,
|
||||
command: "echo ok".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Shell);
|
||||
assert_eq!(jobs[0].command, "echo ok");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{Hand, HandContext, HandRun, HandRunStatus};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// Load all hand definitions from TOML files in the given directory.
|
||||
///
|
||||
/// Each `.toml` file in `hands_dir` is expected to deserialize into a [`Hand`].
|
||||
/// Files that fail to parse are logged and skipped.
|
||||
pub fn load_hands(hands_dir: &Path) -> Result<Vec<Hand>> {
|
||||
if !hands_dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut hands = Vec::new();
|
||||
let entries = std::fs::read_dir(hands_dir)
|
||||
.with_context(|| format!("failed to read hands directory: {}", hands_dir.display()))?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("toml") {
|
||||
continue;
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand file: {}", path.display()))?;
|
||||
match toml::from_str::<Hand>(&content) {
|
||||
Ok(hand) => hands.push(hand),
|
||||
Err(e) => {
|
||||
tracing::warn!(path = %path.display(), error = %e, "skipping malformed hand file");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hands)
|
||||
}
|
||||
|
||||
/// Load the rolling context for a hand.
|
||||
///
|
||||
/// Reads from `{hands_dir}/{name}/context.json`. Returns a fresh
|
||||
/// [`HandContext`] if the file does not exist yet.
|
||||
pub fn load_hand_context(hands_dir: &Path, name: &str) -> Result<HandContext> {
|
||||
let path = hands_dir.join(name).join("context.json");
|
||||
if !path.exists() {
|
||||
return Ok(HandContext::new(name));
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand context: {}", path.display()))?;
|
||||
let ctx: HandContext = serde_json::from_str(&content)
|
||||
.with_context(|| format!("failed to parse hand context: {}", path.display()))?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// Persist the rolling context for a hand.
|
||||
///
|
||||
/// Writes to `{hands_dir}/{name}/context.json`, creating the
|
||||
/// directory if it does not exist.
|
||||
pub fn save_hand_context(hands_dir: &Path, context: &HandContext) -> Result<()> {
|
||||
let dir = hands_dir.join(&context.hand_name);
|
||||
std::fs::create_dir_all(&dir)
|
||||
.with_context(|| format!("failed to create hand context dir: {}", dir.display()))?;
|
||||
let path = dir.join("context.json");
|
||||
let json = serde_json::to_string_pretty(context)?;
|
||||
std::fs::write(&path, json)
|
||||
.with_context(|| format!("failed to write hand context: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn write_hand_toml(dir: &Path, filename: &str, content: &str) {
|
||||
std::fs::write(dir.join(filename), content).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_nonexistent_dir() {
|
||||
let hands = load_hands(Path::new("/nonexistent/path/hands")).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_parses_valid_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"scanner.toml",
|
||||
r#"
|
||||
name = "scanner"
|
||||
description = "Market scanner"
|
||||
prompt = "Scan markets."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * *"
|
||||
"#,
|
||||
);
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"digest.toml",
|
||||
r#"
|
||||
name = "digest"
|
||||
description = "News digest"
|
||||
prompt = "Digest news."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_skips_malformed_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(tmp.path(), "bad.toml", "this is not valid toml struct");
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"good.toml",
|
||||
r#"
|
||||
name = "good"
|
||||
description = "A good hand"
|
||||
prompt = "Do good things."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 60000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 1);
|
||||
assert_eq!(hands[0].name, "good");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_ignores_non_toml_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
std::fs::write(tmp.path().join("readme.md"), "# Hands").unwrap();
|
||||
std::fs::write(tmp.path().join("notes.txt"), "some notes").unwrap();
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_roundtrip_through_filesystem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("test-hand");
|
||||
let run = HandRun {
|
||||
hand_name: "test-hand".into(),
|
||||
run_id: "run-001".into(),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec!["found something".into()],
|
||||
knowledge_added: vec!["learned something".into()],
|
||||
duration_ms: Some(500),
|
||||
};
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "test-hand").unwrap();
|
||||
|
||||
assert_eq!(loaded.hand_name, "test-hand");
|
||||
assert_eq!(loaded.total_runs, 1);
|
||||
assert_eq!(loaded.history.len(), 1);
|
||||
assert_eq!(loaded.learned_facts, vec!["learned something"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_context_returns_fresh_when_missing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = load_hand_context(tmp.path(), "nonexistent").unwrap();
|
||||
assert_eq!(ctx.hand_name, "nonexistent");
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_context_creates_directory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = HandContext::new("new-hand");
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
|
||||
assert!(tmp.path().join("new-hand").join("context.json").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_then_load_preserves_multiple_runs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("multi");
|
||||
|
||||
for i in 0..5 {
|
||||
let run = HandRun {
|
||||
hand_name: "multi".into(),
|
||||
run_id: format!("run-{i:03}"),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec![format!("finding-{i}")],
|
||||
knowledge_added: vec![format!("fact-{i}")],
|
||||
duration_ms: Some(100),
|
||||
};
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "multi").unwrap();
|
||||
|
||||
assert_eq!(loaded.total_runs, 5);
|
||||
assert_eq!(loaded.history.len(), 3, "history capped at max_history=3");
|
||||
assert_eq!(loaded.learned_facts.len(), 5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::cron::Schedule;
|
||||
|
||||
// ── Hand ───────────────────────────────────────────────────────
|
||||
|
||||
/// A Hand is an autonomous agent package that runs on a schedule,
|
||||
/// accumulates knowledge over time, and reports results.
|
||||
///
|
||||
/// Hands are defined as TOML files in `~/.zeroclaw/hands/` and each
|
||||
/// maintains a rolling context of findings across runs so the agent
|
||||
/// grows smarter with every execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Hand {
|
||||
/// Unique name (also used as directory/file stem)
|
||||
pub name: String,
|
||||
/// Human-readable description of what this hand does
|
||||
pub description: String,
|
||||
/// The schedule this hand runs on (reuses cron schedule types)
|
||||
pub schedule: Schedule,
|
||||
/// System prompt / execution plan for this hand
|
||||
pub prompt: String,
|
||||
/// Domain knowledge lines to inject into context
|
||||
#[serde(default)]
|
||||
pub knowledge: Vec<String>,
|
||||
/// Tools this hand is allowed to use (None = all available)
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Model override for this hand (None = default provider)
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Whether this hand is currently active
|
||||
#[serde(default = "default_true")]
|
||||
pub active: bool,
|
||||
/// Maximum runs to keep in history
|
||||
#[serde(default = "default_max_runs")]
|
||||
pub max_history: usize,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_max_runs() -> usize {
|
||||
100
|
||||
}
|
||||
|
||||
// ── Hand Run ───────────────────────────────────────────────────
|
||||
|
||||
/// The status of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "status")]
|
||||
pub enum HandRunStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
/// Record of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandRun {
|
||||
/// Name of the hand that produced this run
|
||||
pub hand_name: String,
|
||||
/// Unique identifier for this run
|
||||
pub run_id: String,
|
||||
/// When the run started
|
||||
pub started_at: DateTime<Utc>,
|
||||
/// When the run finished (None if still running)
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
/// Outcome of the run
|
||||
pub status: HandRunStatus,
|
||||
/// Key findings/outputs extracted from this run
|
||||
#[serde(default)]
|
||||
pub findings: Vec<String>,
|
||||
/// New knowledge accumulated and stored to memory
|
||||
#[serde(default)]
|
||||
pub knowledge_added: Vec<String>,
|
||||
/// Wall-clock duration in milliseconds
|
||||
pub duration_ms: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Hand Context ───────────────────────────────────────────────
|
||||
|
||||
/// Rolling context that accumulates across hand runs.
|
||||
///
|
||||
/// Persisted as `~/.zeroclaw/hands/{name}/context.json`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandContext {
|
||||
/// Name of the hand this context belongs to
|
||||
pub hand_name: String,
|
||||
/// Past runs, most-recent first, capped at `Hand::max_history`
|
||||
#[serde(default)]
|
||||
pub history: Vec<HandRun>,
|
||||
/// Persistent facts learned across runs
|
||||
#[serde(default)]
|
||||
pub learned_facts: Vec<String>,
|
||||
/// Timestamp of the last completed run
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
/// Total number of successful runs
|
||||
#[serde(default)]
|
||||
pub total_runs: u64,
|
||||
}
|
||||
|
||||
impl HandContext {
|
||||
/// Create a fresh, empty context for a hand.
|
||||
pub fn new(hand_name: &str) -> Self {
|
||||
Self {
|
||||
hand_name: hand_name.to_string(),
|
||||
history: Vec::new(),
|
||||
learned_facts: Vec::new(),
|
||||
last_run: None,
|
||||
total_runs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed run, updating counters and trimming history.
|
||||
pub fn record_run(&mut self, run: HandRun, max_history: usize) {
|
||||
if run.status == (HandRunStatus::Completed) {
|
||||
self.total_runs += 1;
|
||||
self.last_run = run.finished_at;
|
||||
}
|
||||
|
||||
// Merge new knowledge
|
||||
for fact in &run.knowledge_added {
|
||||
if !self.learned_facts.contains(fact) {
|
||||
self.learned_facts.push(fact.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Insert at the front (most-recent first)
|
||||
self.history.insert(0, run);
|
||||
|
||||
// Cap history length
|
||||
self.history.truncate(max_history);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cron::Schedule;
|
||||
|
||||
fn sample_hand() -> Hand {
|
||||
Hand {
|
||||
name: "market-scanner".into(),
|
||||
description: "Scans market trends and reports findings".into(),
|
||||
schedule: Schedule::Cron {
|
||||
expr: "0 9 * * 1-5".into(),
|
||||
tz: Some("America/New_York".into()),
|
||||
},
|
||||
prompt: "Scan market trends and report key findings.".into(),
|
||||
knowledge: vec!["Focus on tech sector.".into()],
|
||||
allowed_tools: Some(vec!["web_search".into(), "memory".into()]),
|
||||
model: Some("claude-opus-4-6".into()),
|
||||
active: true,
|
||||
max_history: 50,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_run(name: &str, status: HandRunStatus) -> HandRun {
|
||||
let now = Utc::now();
|
||||
HandRun {
|
||||
hand_name: name.into(),
|
||||
run_id: uuid::Uuid::new_v4().to_string(),
|
||||
started_at: now,
|
||||
finished_at: Some(now),
|
||||
status,
|
||||
findings: vec!["finding-1".into()],
|
||||
knowledge_added: vec!["learned-fact-A".into()],
|
||||
duration_ms: Some(1234),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Deserialization ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_from_toml() {
|
||||
let toml_str = r#"
|
||||
name = "market-scanner"
|
||||
description = "Scans market trends"
|
||||
prompt = "Scan trends."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * 1-5"
|
||||
tz = "America/New_York"
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "market-scanner");
|
||||
assert!(hand.active, "active should default to true");
|
||||
assert_eq!(hand.max_history, 100, "max_history should default to 100");
|
||||
assert!(hand.knowledge.is_empty());
|
||||
assert!(hand.allowed_tools.is_none());
|
||||
assert!(hand.model.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_full_toml() {
|
||||
let toml_str = r#"
|
||||
name = "news-digest"
|
||||
description = "Daily news digest"
|
||||
prompt = "Summarize the day's news."
|
||||
knowledge = ["focus on AI", "include funding rounds"]
|
||||
allowed_tools = ["web_search"]
|
||||
model = "claude-opus-4-6"
|
||||
active = false
|
||||
max_history = 25
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "news-digest");
|
||||
assert!(!hand.active);
|
||||
assert_eq!(hand.max_history, 25);
|
||||
assert_eq!(hand.knowledge.len(), 2);
|
||||
assert_eq!(hand.allowed_tools.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(hand.model.as_deref(), Some("claude-opus-4-6"));
|
||||
assert!(matches!(
|
||||
hand.schedule,
|
||||
Schedule::Every {
|
||||
every_ms: 3_600_000
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_roundtrip_json() {
|
||||
let hand = sample_hand();
|
||||
let json = serde_json::to_string(&hand).unwrap();
|
||||
let parsed: Hand = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, hand.name);
|
||||
assert_eq!(parsed.max_history, hand.max_history);
|
||||
}
|
||||
|
||||
// ── HandRunStatus ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_run_status_serde_roundtrip() {
|
||||
let statuses = vec![
|
||||
HandRunStatus::Running,
|
||||
HandRunStatus::Completed,
|
||||
HandRunStatus::Failed {
|
||||
error: "timeout".into(),
|
||||
},
|
||||
];
|
||||
for status in statuses {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
let parsed: HandRunStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, status);
|
||||
}
|
||||
}
|
||||
|
||||
// ── HandContext ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn context_new_is_empty() {
|
||||
let ctx = HandContext::new("test-hand");
|
||||
assert_eq!(ctx.hand_name, "test-hand");
|
||||
assert!(ctx.history.is_empty());
|
||||
assert!(ctx.learned_facts.is_empty());
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_run_increments_counters() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 1);
|
||||
assert!(ctx.last_run.is_some());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
assert_eq!(ctx.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_failed_run_does_not_increment_total() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run(
|
||||
"scanner",
|
||||
HandRunStatus::Failed {
|
||||
error: "boom".into(),
|
||||
},
|
||||
);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_caps_history_at_max() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for _ in 0..10 {
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
assert_eq!(ctx.history.len(), 3);
|
||||
assert_eq!(ctx.total_runs, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_deduplicates_learned_facts() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run1 = sample_run("scanner", HandRunStatus::Completed);
|
||||
let run2 = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run1, 100);
|
||||
ctx.record_run(run2, 100);
|
||||
|
||||
// Both runs add "learned-fact-A" but it should appear only once
|
||||
assert_eq!(ctx.learned_facts.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_json_roundtrip() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
let json = serde_json::to_string_pretty(&ctx).unwrap();
|
||||
let parsed: HandContext = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.hand_name, "scanner");
|
||||
assert_eq!(parsed.total_runs, 1);
|
||||
assert_eq!(parsed.history.len(), 1);
|
||||
assert_eq!(parsed.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn most_recent_run_is_first_in_history() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for i in 0..3 {
|
||||
let mut run = sample_run("scanner", HandRunStatus::Completed);
|
||||
run.findings = vec![format!("finding-{i}")];
|
||||
ctx.record_run(run, 100);
|
||||
}
|
||||
assert_eq!(ctx.history[0].findings[0], "finding-2");
|
||||
assert_eq!(ctx.history[2].findings[0], "finding-0");
|
||||
}
|
||||
}
|
||||
+20
-6
@@ -48,6 +48,7 @@ pub(crate) mod cron;
|
||||
pub(crate) mod daemon;
|
||||
pub(crate) mod doctor;
|
||||
pub mod gateway;
|
||||
pub mod hands;
|
||||
pub(crate) mod hardware;
|
||||
pub(crate) mod health;
|
||||
pub(crate) mod heartbeat;
|
||||
@@ -280,15 +281,19 @@ Times are evaluated in UTC by default; use --tz with an IANA \
|
||||
timezone name to override.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'")]
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'")]
|
||||
Add {
|
||||
/// Cron expression
|
||||
expression: String,
|
||||
/// Optional IANA timezone (e.g. America/Los_Angeles)
|
||||
#[arg(long)]
|
||||
tz: Option<String>,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot scheduled task at an RFC3339 timestamp
|
||||
@@ -303,7 +308,10 @@ Examples:
|
||||
AddAt {
|
||||
/// One-shot timestamp in RFC3339 format
|
||||
at: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a fixed-interval scheduled task
|
||||
@@ -318,7 +326,10 @@ Examples:
|
||||
AddEvery {
|
||||
/// Interval in milliseconds
|
||||
every_ms: u64,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot delayed task (e.g. "30m", "2h", "1d")
|
||||
@@ -335,7 +346,10 @@ Examples:
|
||||
Once {
|
||||
/// Delay duration
|
||||
delay: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Remove a scheduled task
|
||||
|
||||
+35
-9
@@ -37,7 +37,7 @@ use anyhow::{bail, Context, Result};
|
||||
use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use dialoguer::{Input, Password};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
@@ -325,11 +325,12 @@ override with --tz and an IANA timezone name.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron list
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder'
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' --agent
|
||||
zeroclaw cron add-every 60000 'Ping heartbeat'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes' --agent
|
||||
zeroclaw cron pause <task-id>
|
||||
zeroclaw cron update <task-id> --expression '0 8 * * *' --tz Europe/London")]
|
||||
Cron {
|
||||
@@ -718,10 +719,11 @@ async fn main() -> Result<()> {
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
||||
// Onboard runs quick setup by default, or the interactive wizard with --interactive.
|
||||
// The onboard wizard uses reqwest::blocking internally, which creates its own
|
||||
// Tokio runtime. To avoid "Cannot drop a runtime in a context where blocking is
|
||||
// not allowed", we run the wizard on a blocking thread via spawn_blocking.
|
||||
// Onboard auto-detects the environment: if stdin/stdout are a TTY and no
|
||||
// provider flags were given, it runs the full interactive wizard; otherwise
|
||||
// it runs the quick (scriptable) setup. This means `curl … | bash` and
|
||||
// `zeroclaw onboard --api-key …` both take the fast path, while a bare
|
||||
// `zeroclaw onboard` in a terminal launches the wizard.
|
||||
if let Commands::Onboard {
|
||||
force,
|
||||
reinit,
|
||||
@@ -793,8 +795,16 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect: run the interactive wizard when in a TTY with no
|
||||
// provider flags, quick setup otherwise (scriptable path).
|
||||
let has_provider_flags =
|
||||
api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some();
|
||||
let is_tty = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
|
||||
|
||||
let config = if channels_only {
|
||||
Box::pin(onboard::run_channels_repair_wizard()).await
|
||||
} else if is_tty && !has_provider_flags {
|
||||
Box::pin(onboard::run_wizard(force)).await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
@@ -2206,6 +2216,22 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_rejects_removed_interactive_flag() {
|
||||
// --interactive was removed; onboard auto-detects TTY instead.
|
||||
assert!(Cli::try_parse_from(["zeroclaw", "onboard", "--interactive"]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_bare_parses() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "onboard"]).expect("bare onboard should parse");
|
||||
|
||||
match cli.command {
|
||||
Commands::Onboard { .. } => {}
|
||||
other => panic!("expected onboard command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_parses_estop_default_engage() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "estop"]).expect("estop command should parse");
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
//! LLM-driven memory consolidation.
|
||||
//!
|
||||
//! After each conversation turn, extracts structured information:
|
||||
//! - `history_entry`: A timestamped summary for the daily conversation log.
|
||||
//! - `memory_update`: New facts, preferences, or decisions worth remembering
|
||||
//! long-term (or `null` if nothing new was learned).
|
||||
//!
|
||||
//! This two-phase approach replaces the naive raw-message auto-save with
|
||||
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
|
||||
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
/// Output of consolidation extraction.
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct ConsolidationResult {
|
||||
/// Brief timestamped summary for the conversation history log.
|
||||
pub history_entry: String,
|
||||
/// New facts/preferences/decisions to store long-term, or None.
|
||||
pub memory_update: Option<String>,
|
||||
}
|
||||
|
||||
const CONSOLIDATION_SYSTEM_PROMPT: &str = r#"You are a memory consolidation engine. Given a conversation turn, extract:
|
||||
1. "history_entry": A brief summary of what happened in this turn (1-2 sentences). Include the key topic or action.
|
||||
2. "memory_update": Any NEW facts, preferences, decisions, or commitments worth remembering long-term. Return null if nothing new was learned.
|
||||
|
||||
Respond ONLY with valid JSON: {"history_entry": "...", "memory_update": "..." or null}
|
||||
Do not include any text outside the JSON object."#;
|
||||
|
||||
/// Run two-phase LLM-driven consolidation on a conversation turn.
|
||||
///
|
||||
/// Phase 1: Write a history entry to the Daily memory category.
|
||||
/// Phase 2: Write a memory update to the Core category (if the LLM identified new facts).
|
||||
///
|
||||
/// This function is designed to be called fire-and-forget via `tokio::spawn`.
|
||||
pub async fn consolidate_turn(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
assistant_response: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let turn_text = format!("User: {user_message}\nAssistant: {assistant_response}");
|
||||
|
||||
// Truncate very long turns to avoid wasting tokens on consolidation.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
|
||||
let truncated = if turn_text.len() > 4000 {
|
||||
let mut end = 4000;
|
||||
while end > 0 && !turn_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &turn_text[..end])
|
||||
} else {
|
||||
turn_text.clone()
|
||||
};
|
||||
|
||||
let raw = provider
|
||||
.chat_with_system(Some(CONSOLIDATION_SYSTEM_PROMPT), &truncated, model, 0.1)
|
||||
.await?;
|
||||
|
||||
let result: ConsolidationResult = parse_consolidation_response(&raw, &turn_text);
|
||||
|
||||
// Phase 1: Write history entry to Daily category.
|
||||
let date = chrono::Local::now().format("%Y-%m-%d").to_string();
|
||||
let history_key = format!("daily_{date}_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(
|
||||
&history_key,
|
||||
&result.history_entry,
|
||||
MemoryCategory::Daily,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Phase 2: Write memory update to Core category (if present).
|
||||
if let Some(ref update) = result.memory_update {
|
||||
if !update.trim().is_empty() {
|
||||
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(&mem_key, update, MemoryCategory::Core, None)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse the LLM's consolidation response, with fallback for malformed JSON.
|
||||
fn parse_consolidation_response(raw: &str, fallback_text: &str) -> ConsolidationResult {
|
||||
// Try to extract JSON from the response (LLM may wrap in markdown code blocks).
|
||||
let cleaned = raw
|
||||
.trim()
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim();
|
||||
|
||||
serde_json::from_str(cleaned).unwrap_or_else(|_| {
|
||||
// Fallback: use truncated turn text as history entry.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let summary = if fallback_text.len() > 200 {
|
||||
let mut end = 200;
|
||||
while end > 0 && !fallback_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &fallback_text[..end])
|
||||
} else {
|
||||
fallback_text.to_string()
|
||||
};
|
||||
ConsolidationResult {
|
||||
history_entry: summary,
|
||||
memory_update: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_valid_json_response() {
|
||||
let raw = r#"{"history_entry": "User asked about Rust.", "memory_update": "User prefers Rust over Go."}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "User asked about Rust.");
|
||||
assert_eq!(
|
||||
result.memory_update.as_deref(),
|
||||
Some("User prefers Rust over Go.")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_with_null_memory() {
|
||||
let raw = r#"{"history_entry": "Routine greeting.", "memory_update": null}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Routine greeting.");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_wrapped_in_code_block() {
|
||||
let raw =
|
||||
"```json\n{\"history_entry\": \"Discussed deployment.\", \"memory_update\": null}\n```";
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Discussed deployment.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_on_malformed_response() {
|
||||
let raw = "I'm sorry, I can't do that.";
|
||||
let result = parse_consolidation_response(raw, "User: hello\nAssistant: hi");
|
||||
assert_eq!(result.history_entry, "User: hello\nAssistant: hi");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_long_text() {
|
||||
let long_text = "x".repeat(500);
|
||||
let result = parse_consolidation_response("invalid", &long_text);
|
||||
// 200 bytes + "…" (3 bytes in UTF-8) = 203
|
||||
assert!(result.history_entry.len() <= 203);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_cjk_text_without_panic() {
|
||||
// Each CJK character is 3 bytes in UTF-8; byte index 200 may land
|
||||
// inside a character. This must not panic.
|
||||
let cjk_text = "二手书项目".repeat(50); // 250 chars = 750 bytes
|
||||
let result = parse_consolidation_response("invalid", &cjk_text);
|
||||
assert!(result
|
||||
.history_entry
|
||||
.is_char_boundary(result.history_entry.len()));
|
||||
assert!(result.history_entry.ends_with('…'));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod consolidation;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod lucid;
|
||||
|
||||
+2
-1
@@ -4,7 +4,7 @@ pub mod wizard;
|
||||
#[allow(unused_imports)]
|
||||
pub use wizard::{
|
||||
run_channels_repair_wizard, run_models_list, run_models_refresh, run_models_refresh_all,
|
||||
run_models_set, run_models_status, run_quick_setup,
|
||||
run_models_set, run_models_status, run_quick_setup, run_wizard,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -17,6 +17,7 @@ mod tests {
|
||||
fn wizard_functions_are_reexported() {
|
||||
assert_reexport_exists(run_channels_repair_wizard);
|
||||
assert_reexport_exists(run_quick_setup);
|
||||
assert_reexport_exists(run_wizard);
|
||||
assert_reexport_exists(run_models_refresh);
|
||||
assert_reexport_exists(run_models_list);
|
||||
assert_reexport_exists(run_models_set);
|
||||
|
||||
@@ -170,6 +170,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: hardware_config,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -527,6 +528,7 @@ async fn run_quick_setup_with_home(
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: crate::config::HardwareConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -4147,6 +4149,23 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
.interact()?;
|
||||
|
||||
if mode_idx == 0 {
|
||||
// Compile-time check: warn early if the feature is not enabled.
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
println!();
|
||||
println!(
|
||||
" {} {}",
|
||||
style("⚠").yellow().bold(),
|
||||
style("The 'whatsapp-web' feature is not compiled in. WhatsApp Web will not work at runtime.").yellow()
|
||||
);
|
||||
println!(
|
||||
" {} Rebuild with: {}",
|
||||
style("→").dim(),
|
||||
style("cargo build --features whatsapp-web").white().bold()
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
println!(" {}", style("Mode: WhatsApp Web").dim());
|
||||
print_bullet("1. Build with --features whatsapp-web");
|
||||
print_bullet(
|
||||
|
||||
@@ -500,19 +500,23 @@ struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
function: Option<Function>,
|
||||
|
||||
// Compatibility: Some providers (e.g., older GLM) may use 'name' directly
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
arguments: Option<String>,
|
||||
|
||||
// Compatibility: DeepSeek sometimes wraps arguments differently
|
||||
#[serde(rename = "parameters", default)]
|
||||
#[serde(
|
||||
rename = "parameters",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -3094,4 +3098,50 @@ mod tests {
|
||||
// Should not panic
|
||||
let _client = p.http_client();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_none_fields_omitted_from_json() {
|
||||
// Ensures providers like Mistral that reject extra fields (e.g. "name": null)
|
||||
// don't receive them when the ToolCall compat fields are None.
|
||||
let tc = ToolCall {
|
||||
id: Some("call_1".to_string()),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert!(!json.as_object().unwrap().contains_key("name"));
|
||||
assert!(!json.as_object().unwrap().contains_key("arguments"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
// Standard fields must be present
|
||||
assert!(json.as_object().unwrap().contains_key("id"));
|
||||
assert!(json.as_object().unwrap().contains_key("type"));
|
||||
assert!(json.as_object().unwrap().contains_key("function"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_with_compat_fields_serializes_them() {
|
||||
// When compat fields are Some, they should appear in the output.
|
||||
let tc = ToolCall {
|
||||
id: None,
|
||||
kind: None,
|
||||
function: None,
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["name"], "shell");
|
||||
assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
|
||||
// None fields should be omitted
|
||||
assert!(!json.as_object().unwrap().contains_key("id"));
|
||||
assert!(!json.as_object().unwrap().contains_key("type"));
|
||||
assert!(!json.as_object().unwrap().contains_key("function"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::multimodal;
|
||||
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
|
||||
use crate::providers::ProviderRuntimeOptions;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -472,8 +473,24 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Read the response body incrementally via `bytes_stream()` to avoid
|
||||
/// buffering the entire SSE payload in memory. The previous implementation
|
||||
/// used `response.text().await?` which holds the HTTP connection open until
|
||||
/// every byte has arrived — on high-latency links the long-lived connection
|
||||
/// often drops mid-read, producing the "error decoding response body" failure
|
||||
/// reported in #3544.
|
||||
async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<String> {
|
||||
let body = response.text().await?;
|
||||
let mut body = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let bytes = chunk
|
||||
.map_err(|err| anyhow::anyhow!("error reading OpenAI Codex response stream: {err}"))?;
|
||||
let text = std::str::from_utf8(&bytes).map_err(|err| {
|
||||
anyhow::anyhow!("OpenAI Codex response contained invalid UTF-8: {err}")
|
||||
})?;
|
||||
body.push_str(text);
|
||||
}
|
||||
|
||||
if let Some(text) = parse_sse_text(&body)? {
|
||||
return Ok(text);
|
||||
|
||||
+16
-1
@@ -67,7 +67,13 @@ pub fn redact(value: &str) -> String {
|
||||
if value.len() <= 4 {
|
||||
"***".to_string()
|
||||
} else {
|
||||
format!("{}***", &value[..4])
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let prefix = value
|
||||
.char_indices()
|
||||
.nth(4)
|
||||
.map(|(byte_idx, _)| &value[..byte_idx])
|
||||
.unwrap_or(value);
|
||||
format!("{}***", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,4 +108,13 @@ mod tests {
|
||||
assert_eq!(redact(""), "***");
|
||||
assert_eq!(redact("12345"), "1234***");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn redact_handles_multibyte_utf8_without_panic() {
|
||||
// CJK characters are 3 bytes each; slicing at byte 4 would panic
|
||||
// without char-boundary-safe handling.
|
||||
let result = redact("密码是很长的秘密");
|
||||
assert!(result.ends_with("***"));
|
||||
assert!(result.is_char_boundary(result.len()));
|
||||
}
|
||||
}
|
||||
|
||||
+144
-3
@@ -793,6 +793,8 @@ impl SecurityPolicy {
|
||||
// 1. Allowlist check (is the base command permitted at all?)
|
||||
// 2. Risk classification (high / medium / low)
|
||||
// 3. Policy flags (block_high_risk_commands, require_approval_for_medium_risk)
|
||||
// — explicit allowlist entries exempt a command from the high-risk block,
|
||||
// but the wildcard "*" does NOT grant an exemption.
|
||||
// 4. Autonomy level × approval status (supervised requires explicit approval)
|
||||
// This ordering ensures deny-by-default: unknown commands are rejected
|
||||
// before any risk or autonomy logic runs.
|
||||
@@ -810,7 +812,7 @@ impl SecurityPolicy {
|
||||
let risk = self.command_risk_level(command);
|
||||
|
||||
if risk == CommandRiskLevel::High {
|
||||
if self.block_high_risk_commands {
|
||||
if self.block_high_risk_commands && !self.is_command_explicitly_allowed(command) {
|
||||
return Err("Command blocked: high-risk command is disallowed by policy".into());
|
||||
}
|
||||
if self.autonomy == AutonomyLevel::Supervised && !approved {
|
||||
@@ -834,6 +836,48 @@ impl SecurityPolicy {
|
||||
Ok(risk)
|
||||
}
|
||||
|
||||
/// Check whether **every** segment of a command is explicitly listed in
|
||||
/// `allowed_commands` — i.e., matched by a concrete entry rather than by
|
||||
/// the wildcard `"*"`.
|
||||
///
|
||||
/// This is used to exempt explicitly-allowlisted high-risk commands from
|
||||
/// the `block_high_risk_commands` gate. The wildcard entry intentionally
|
||||
/// does **not** qualify as an explicit allowlist match, so that operators
|
||||
/// who set `allowed_commands = ["*"]` still get the high-risk safety net.
|
||||
fn is_command_explicitly_allowed(&self, command: &str) -> bool {
|
||||
let segments = split_unquoted_segments(command);
|
||||
for segment in &segments {
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
|
||||
let base_cmd_owned = command_basename(executable).to_ascii_lowercase();
|
||||
let base_cmd = strip_windows_exe_suffix(&base_cmd_owned);
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let explicitly_listed = self.allowed_commands.iter().any(|allowed| {
|
||||
let allowed = strip_wrapping_quotes(allowed).trim();
|
||||
// Skip wildcard — it does not count as an explicit entry.
|
||||
if allowed.is_empty() || allowed == "*" {
|
||||
return false;
|
||||
}
|
||||
is_allowlist_entry_match(allowed, executable, base_cmd)
|
||||
});
|
||||
|
||||
if !explicitly_listed {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At least one real command must be present.
|
||||
segments.iter().any(|s| {
|
||||
let s = skip_env_assignments(s.trim());
|
||||
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
|
||||
})
|
||||
}
|
||||
|
||||
// ── Layered Command Allowlist ──────────────────────────────────────────
|
||||
// Defence-in-depth: five independent gates run in order before the
|
||||
// per-segment allowlist check. Each gate targets a specific bypass
|
||||
@@ -1503,10 +1547,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_high_risk_by_default() {
|
||||
fn validate_command_blocks_high_risk_via_wildcard() {
|
||||
// Wildcard allows the command through is_command_allowed, but
|
||||
// block_high_risk_commands still rejects it because "*" does not
|
||||
// count as an explicit allowlist entry.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
allowed_commands: vec!["*".into()],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
@@ -1515,6 +1562,100 @@ mod tests {
|
||||
assert!(result.unwrap_err().contains("high-risk"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_explicitly_listed_high_risk() {
|
||||
// When a high-risk command is explicitly in allowed_commands, the
|
||||
// block_high_risk_commands gate is bypassed — the operator has made
|
||||
// a deliberate decision to permit it.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com/data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_wget_when_explicitly_listed() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["wget".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result =
|
||||
p.validate_command_execution("wget https://releases.example.com/v1.tar.gz", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_non_listed_high_risk_when_another_is_allowed() {
|
||||
// Allowing curl explicitly should not exempt wget.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("wget https://evil.com", true);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_explicit_rm_bypasses_high_risk_block() {
|
||||
// Operator explicitly listed "rm" — they accept the risk.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("rm -rf /tmp/test", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_high_risk_still_needs_approval_in_supervised() {
|
||||
// Even when explicitly allowed, supervised mode still requires
|
||||
// approval for high-risk commands (the approval gate is separate
|
||||
// from the block gate).
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let denied = p.validate_command_execution("curl https://api.example.com", false);
|
||||
assert!(denied.is_err());
|
||||
assert!(denied.unwrap_err().contains("requires explicit approval"));
|
||||
|
||||
let allowed = p.validate_command_execution("curl https://api.example.com", true);
|
||||
assert_eq!(allowed.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_pipe_needs_all_segments_explicitly_allowed() {
|
||||
// When a pipeline contains a high-risk command, every segment
|
||||
// must be explicitly allowed for the exemption to apply.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into(), "grep".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com | grep data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_full_mode_skips_medium_risk_approval_gate() {
|
||||
let p = SecurityPolicy {
|
||||
|
||||
+91
-6
@@ -442,8 +442,24 @@ fn install_linux_systemd(config: &Config) -> Result<()> {
|
||||
|
||||
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
|
||||
let unit = format!(
|
||||
"[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
|
||||
exe.display()
|
||||
"[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart={exe} daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n",
|
||||
exe = exe.display()
|
||||
);
|
||||
|
||||
fs::write(&file, unit)?;
|
||||
@@ -826,8 +842,8 @@ fn generate_openrc_script(exe_path: &Path, config_dir: &Path) -> String {
|
||||
name="zeroclaw"
|
||||
description="ZeroClaw daemon"
|
||||
|
||||
command="{}"
|
||||
command_args="--config-dir {} daemon"
|
||||
command="{exe}"
|
||||
command_args="--config-dir {config_dir} daemon"
|
||||
command_background="yes"
|
||||
command_user="zeroclaw:zeroclaw"
|
||||
pidfile="/run/${{RC_SVCNAME}}.pid"
|
||||
@@ -835,13 +851,21 @@ umask 027
|
||||
output_log="/var/log/zeroclaw/access.log"
|
||||
error_log="/var/log/zeroclaw/error.log"
|
||||
|
||||
# Provide HOME so headless browsers can create profile/cache directories.
|
||||
# Without this, Chromium/Firefox fail with sandbox or profile errors.
|
||||
export HOME="/var/lib/zeroclaw"
|
||||
|
||||
depend() {{
|
||||
need net
|
||||
after firewall
|
||||
}}
|
||||
|
||||
start_pre() {{
|
||||
checkpath --directory --owner zeroclaw:zeroclaw --mode 0750 /var/lib/zeroclaw
|
||||
}}
|
||||
"#,
|
||||
exe_path.display(),
|
||||
config_dir.display()
|
||||
exe = exe_path.display(),
|
||||
config_dir = config_dir.display(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1196,6 +1220,67 @@ mod tests {
|
||||
assert!(script.contains("after firewall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_sets_home_for_browser() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("export HOME=\"/var/lib/zeroclaw\""),
|
||||
"OpenRC script must set HOME for headless browser support"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_openrc_script_creates_home_directory() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
let exe_path = PathBuf::from("/usr/local/bin/zeroclaw");
|
||||
let script = generate_openrc_script(&exe_path, Path::new("/etc/zeroclaw"));
|
||||
|
||||
assert!(
|
||||
script.contains("start_pre()"),
|
||||
"OpenRC script must have start_pre to create HOME dir"
|
||||
);
|
||||
assert!(
|
||||
script.contains("checkpath --directory --owner zeroclaw:zeroclaw"),
|
||||
"start_pre must ensure /var/lib/zeroclaw exists with correct ownership"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn systemd_unit_contains_home_and_pass_environment() {
|
||||
let unit = "[Unit]\n\
|
||||
Description=ZeroClaw daemon\n\
|
||||
After=network.target\n\
|
||||
\n\
|
||||
[Service]\n\
|
||||
Type=simple\n\
|
||||
ExecStart=/usr/local/bin/zeroclaw daemon\n\
|
||||
Restart=always\n\
|
||||
RestartSec=3\n\
|
||||
# Ensure HOME is set so headless browsers can create profile/cache dirs.\n\
|
||||
Environment=HOME=%h\n\
|
||||
# Allow inheriting DISPLAY and XDG_RUNTIME_DIR from the user session\n\
|
||||
# so graphical/headless browsers can function correctly.\n\
|
||||
PassEnvironment=DISPLAY XDG_RUNTIME_DIR\n\
|
||||
\n\
|
||||
[Install]\n\
|
||||
WantedBy=default.target\n"
|
||||
.to_string();
|
||||
|
||||
assert!(
|
||||
unit.contains("Environment=HOME=%h"),
|
||||
"systemd unit must set HOME for headless browser support"
|
||||
);
|
||||
assert!(
|
||||
unit.contains("PassEnvironment=DISPLAY XDG_RUNTIME_DIR"),
|
||||
"systemd unit must pass through display/runtime env vars"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warn_if_binary_in_home_detects_home_path() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -440,6 +440,12 @@ impl BrowserTool {
|
||||
async fn run_command(&self, args: &[&str]) -> anyhow::Result<AgentBrowserResponse> {
|
||||
let mut cmd = Command::new("agent-browser");
|
||||
|
||||
// When running as a service (systemd/OpenRC), the process may lack
|
||||
// HOME which browsers need for profile directories.
|
||||
if is_service_environment() {
|
||||
ensure_browser_env(&mut cmd);
|
||||
}
|
||||
|
||||
// Add session if configured
|
||||
if let Some(ref session) = self.session_name {
|
||||
cmd.arg("--session").arg(session);
|
||||
@@ -1461,6 +1467,14 @@ mod native_backend {
|
||||
args.push(Value::String("--disable-gpu".to_string()));
|
||||
}
|
||||
|
||||
// When running as a service (systemd/OpenRC), the browser sandbox
|
||||
// fails because the process lacks a user namespace / session.
|
||||
// --no-sandbox and --disable-dev-shm-usage are required in this context.
|
||||
if is_service_environment() {
|
||||
args.push(Value::String("--no-sandbox".to_string()));
|
||||
args.push(Value::String("--disable-dev-shm-usage".to_string()));
|
||||
}
|
||||
|
||||
if !args.is_empty() {
|
||||
chrome_options.insert("args".to_string(), Value::Array(args));
|
||||
}
|
||||
@@ -2111,6 +2125,44 @@ fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
|
||||
|| v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
|
||||
}
|
||||
|
||||
/// Detect whether the current process is running inside a service environment
|
||||
/// (e.g. systemd, OpenRC, or launchd) where the browser sandbox and
|
||||
/// environment setup may be restricted.
|
||||
fn is_service_environment() -> bool {
|
||||
if std::env::var_os("INVOCATION_ID").is_some() {
|
||||
return true;
|
||||
}
|
||||
if std::env::var_os("JOURNAL_STREAM").is_some() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::path::Path::new("/run/openrc").exists() && std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Ensure environment variables required by headless browsers are present
|
||||
/// when running inside a service context.
|
||||
fn ensure_browser_env(cmd: &mut Command) {
|
||||
if std::env::var_os("HOME").is_none() {
|
||||
cmd.env("HOME", "/tmp");
|
||||
}
|
||||
let existing = std::env::var("CHROMIUM_FLAGS").unwrap_or_default();
|
||||
if !existing.contains("--no-sandbox") {
|
||||
let new_flags = if existing.is_empty() {
|
||||
"--no-sandbox --disable-dev-shm-usage".to_string()
|
||||
} else {
|
||||
format!("{existing} --no-sandbox --disable-dev-shm-usage")
|
||||
};
|
||||
cmd.env("CHROMIUM_FLAGS", new_flags);
|
||||
}
|
||||
}
|
||||
|
||||
fn host_matches_allowlist(host: &str, allowed: &[String]) -> bool {
|
||||
allowed.iter().any(|pattern| {
|
||||
if pattern == "*" {
|
||||
@@ -2492,4 +2544,78 @@ mod tests {
|
||||
state.reset_session().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_home_when_missing() {
|
||||
let original_home = std::env::var_os("HOME");
|
||||
unsafe { std::env::remove_var("HOME") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
// Function completes without panic — HOME and CHROMIUM_FLAGS set on cmd.
|
||||
|
||||
if let Some(home) = original_home {
|
||||
unsafe { std::env::set_var("HOME", home) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_browser_env_sets_chromium_flags() {
|
||||
let original = std::env::var_os("CHROMIUM_FLAGS");
|
||||
unsafe { std::env::remove_var("CHROMIUM_FLAGS") };
|
||||
|
||||
let mut cmd = Command::new("true");
|
||||
ensure_browser_env(&mut cmd);
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("CHROMIUM_FLAGS", val) };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_invocation_id() {
|
||||
let original = std::env::var_os("INVOCATION_ID");
|
||||
unsafe { std::env::set_var("INVOCATION_ID", "test-unit-id") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_detects_journal_stream() {
|
||||
let original = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", "8:12345") };
|
||||
|
||||
assert!(is_service_environment());
|
||||
|
||||
if let Some(val) = original {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
} else {
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_service_environment_false_in_normal_context() {
|
||||
let inv = std::env::var_os("INVOCATION_ID");
|
||||
let journal = std::env::var_os("JOURNAL_STREAM");
|
||||
unsafe { std::env::remove_var("INVOCATION_ID") };
|
||||
unsafe { std::env::remove_var("JOURNAL_STREAM") };
|
||||
|
||||
if std::env::var_os("HOME").is_some() {
|
||||
assert!(!is_service_environment());
|
||||
}
|
||||
|
||||
if let Some(val) = inv {
|
||||
unsafe { std::env::set_var("INVOCATION_ID", val) };
|
||||
}
|
||||
if let Some(val) = journal {
|
||||
unsafe { std::env::set_var("JOURNAL_STREAM", val) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl HttpRequestTool {
|
||||
@@ -20,12 +21,14 @@ impl HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
allowed_domains: normalize_allowed_domains(allowed_domains),
|
||||
max_response_size,
|
||||
timeout_secs,
|
||||
allow_private_hosts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +55,7 @@ impl HttpRequestTool {
|
||||
|
||||
let host = extract_host(url)?;
|
||||
|
||||
if is_private_or_local_host(&host) {
|
||||
if !self.allow_private_hosts && is_private_or_local_host(&host) {
|
||||
anyhow::bail!("Blocked local/private host: {host}");
|
||||
}
|
||||
|
||||
@@ -454,6 +457,13 @@ mod tests {
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
|
||||
test_tool_with_private(allowed_domains, false)
|
||||
}
|
||||
|
||||
fn test_tool_with_private(
|
||||
allowed_domains: Vec<&str>,
|
||||
allow_private_hosts: bool,
|
||||
) -> HttpRequestTool {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
..SecurityPolicy::default()
|
||||
@@ -463,6 +473,7 @@ mod tests {
|
||||
allowed_domains.into_iter().map(String::from).collect(),
|
||||
1_000_000,
|
||||
30,
|
||||
allow_private_hosts,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -570,7 +581,7 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_requires_allowlist() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30, false);
|
||||
let err = tool
|
||||
.validate_url("https://example.com")
|
||||
.unwrap_err()
|
||||
@@ -686,7 +697,7 @@ mod tests {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -701,7 +712,7 @@ mod tests {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -724,6 +735,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
10,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world this is long";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -738,6 +750,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
0, // max_response_size = 0 means no limit
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "a".repeat(10_000_000);
|
||||
assert_eq!(tool.truncate_response(&text), text);
|
||||
@@ -750,6 +763,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
5,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -935,4 +949,70 @@ mod tests {
|
||||
.to_string();
|
||||
assert!(err.contains("IPv6"));
|
||||
}
|
||||
|
||||
// ── allow_private_hosts opt-in tests ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn default_blocks_private_hosts() {
|
||||
let tool = test_tool(vec!["localhost", "192.168.1.5", "*"]);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://10.0.0.1")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_localhost() {
|
||||
let tool = test_tool_with_private(vec!["localhost"], true);
|
||||
assert!(tool.validate_url("https://localhost:8080").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_private_ipv4() {
|
||||
let tool = test_tool_with_private(vec!["192.168.1.5"], true);
|
||||
assert!(tool.validate_url("https://192.168.1.5").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_rfc1918_with_wildcard() {
|
||||
let tool = test_tool_with_private(vec!["*"], true);
|
||||
assert!(tool.validate_url("https://10.0.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://172.16.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://192.168.1.1").is_ok());
|
||||
assert!(tool.validate_url("http://localhost:8123").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_still_requires_allowlist() {
|
||||
let tool = test_tool_with_private(vec!["example.com"], true);
|
||||
let err = tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
assert!(
|
||||
err.contains("allowed_domains"),
|
||||
"Private host should still need allowlist match, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_false_still_blocks() {
|
||||
let tool = test_tool_with_private(vec!["*"], false);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
}
|
||||
|
||||
+38
-18
@@ -57,6 +57,7 @@ pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod tool_search;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
@@ -103,6 +104,7 @@ pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
pub use traits::Tool;
|
||||
#[allow(unused_imports)]
|
||||
@@ -314,6 +316,7 @@ pub fn all_tools_with_runtime(
|
||||
http_config.allowed_domains.clone(),
|
||||
http_config.max_response_size,
|
||||
http_config.timeout_secs,
|
||||
http_config.allow_private_hosts,
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -357,6 +360,24 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
|
||||
// Add delegation tool when agents are configured
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
};
|
||||
|
||||
let delegate_handle: Option<DelegateParentToolsHandle> = if agents.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -364,28 +385,12 @@ pub fn all_tools_with_runtime(
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let parent_tools = Arc::new(RwLock::new(tool_arcs.clone()));
|
||||
let delegate_tool = DelegateTool::new_with_options(
|
||||
delegate_agents,
|
||||
delegate_fallback_credential,
|
||||
delegate_fallback_credential.clone(),
|
||||
security.clone(),
|
||||
crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
},
|
||||
provider_runtime_options.clone(),
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone());
|
||||
@@ -393,6 +398,21 @@ pub fn all_tools_with_runtime(
|
||||
Some(parent_tools)
|
||||
};
|
||||
|
||||
// Add swarm tool when swarms are configured
|
||||
if !root_config.swarms.is_empty() {
|
||||
let swarm_agents: HashMap<String, DelegateAgentConfig> = agents
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
tool_arcs.push(Arc::new(SwarmTool::new(
|
||||
root_config.swarms.clone(),
|
||||
swarm_agents,
|
||||
delegate_fallback_credential,
|
||||
security.clone(),
|
||||
provider_runtime_options,
|
||||
)));
|
||||
}
|
||||
|
||||
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,953 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::{DelegateAgentConfig, SwarmConfig, SwarmStrategy};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default timeout for individual agent calls within a swarm.
|
||||
const SWARM_AGENT_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
/// Tool that orchestrates multiple agents as a swarm. Supports sequential
|
||||
/// (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies.
|
||||
pub struct SwarmTool {
|
||||
swarms: Arc<HashMap<String, SwarmConfig>>,
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
fallback_credential: Option<String>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
}
|
||||
|
||||
impl SwarmTool {
|
||||
pub fn new(
|
||||
swarms: HashMap<String, SwarmConfig>,
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
) -> Self {
|
||||
Self {
|
||||
swarms: Arc::new(swarms),
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
provider_runtime_options,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider_for_agent(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
agent_name: &str,
|
||||
) -> Result<Box<dyn Provider>, ToolResult> {
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
)
|
||||
.map_err(|e| ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider '{}' for agent '{agent_name}': {e}",
|
||||
agent_config.provider
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn call_agent(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
prompt: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<String, String> {
|
||||
let provider = self
|
||||
.create_provider_for_agent(agent_config, agent_name)
|
||||
.map_err(|r| r.error.unwrap_or_default())?;
|
||||
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(response)) => {
|
||||
if response.trim().is_empty() {
|
||||
Ok("[Empty response]".to_string())
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("Agent '{agent_name}' failed: {e}")),
|
||||
Err(_) => Err(format!(
|
||||
"Agent '{agent_name}' timed out after {timeout_secs}s"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_sequential(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let mut current_input = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let per_agent_timeout = swarm_config.timeout_secs / swarm_config.agents.len().max(1) as u64;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (i, agent_name) in swarm_config.agents.iter().enumerate() {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let agent_prompt = if i == 0 {
|
||||
current_input.clone()
|
||||
} else {
|
||||
format!("[Previous agent output]\n{current_input}\n\n[Original task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(agent_name, agent_config, &agent_prompt, per_agent_timeout)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
results.push(format!(
|
||||
"[{agent_name} ({}/{})] {output}",
|
||||
agent_config.provider, agent_config.model
|
||||
));
|
||||
current_input = output;
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: results.join("\n\n"),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm sequential — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_parallel(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for agent_name in &swarm_config.agents {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg.clone(),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
let provider = match providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider for agent '{agent_name}': {e}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let name = agent_name.clone();
|
||||
let prompt_clone = full_prompt.clone();
|
||||
let timeout = swarm_config.timeout_secs;
|
||||
let model = agent_config.model.clone();
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
let system_prompt = agent_config.system_prompt.clone();
|
||||
let provider_name = agent_config.provider.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout),
|
||||
provider.chat_with_system(
|
||||
system_prompt.as_deref(),
|
||||
&prompt_clone,
|
||||
&model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let output = match result {
|
||||
Ok(Ok(text)) => {
|
||||
if text.trim().is_empty() {
|
||||
"[Empty response]".to_string()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => format!("[Error] {e}"),
|
||||
Err(_) => format!("[Timed out after {timeout}s]"),
|
||||
};
|
||||
|
||||
(name, provider_name, model, output)
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(join_result) = join_set.join_next().await {
|
||||
match join_result {
|
||||
Ok((name, provider_name, model, output)) => {
|
||||
results.push(format!("[{name} ({provider_name}/{model})]\n{output}"));
|
||||
}
|
||||
Err(e) => {
|
||||
results.push(format!("[join error] {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm parallel — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n---\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_router(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router swarm has no agents to choose from".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build agent descriptions for the router prompt
|
||||
let agent_descriptions: Vec<String> = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.filter_map(|name| {
|
||||
self.agents.get(name).map(|cfg| {
|
||||
let desc = cfg
|
||||
.system_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("General purpose agent");
|
||||
format!(
|
||||
"- {name}: {desc} (provider: {}, model: {})",
|
||||
cfg.provider, cfg.model
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Use the first agent's provider for routing
|
||||
let first_agent_name = &swarm_config.agents[0];
|
||||
let first_agent_config = match self.agents.get(first_agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Swarm references unknown agent '{first_agent_name}'"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let router_provider = self
|
||||
.create_provider_for_agent(first_agent_config, first_agent_name)
|
||||
.map_err(|r| anyhow::anyhow!(r.error.unwrap_or_default()))?;
|
||||
|
||||
let base_router_prompt = swarm_config
|
||||
.router_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("Pick the single best agent for this task.");
|
||||
|
||||
let routing_prompt = format!(
|
||||
"{base_router_prompt}\n\nAvailable agents:\n{}\n\nUser task: {prompt}\n\n\
|
||||
Respond with ONLY the agent name, nothing else.",
|
||||
agent_descriptions.join("\n")
|
||||
);
|
||||
|
||||
let chosen = tokio::time::timeout(
|
||||
Duration::from_secs(SWARM_AGENT_TIMEOUT_SECS),
|
||||
router_provider.chat_with_system(
|
||||
Some("You are a routing assistant. Respond with only the agent name."),
|
||||
&routing_prompt,
|
||||
&first_agent_config.model,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let chosen_name = match chosen {
|
||||
Ok(Ok(name)) => name.trim().to_string(),
|
||||
Ok(Err(e)) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router LLM call failed: {e}")),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router LLM call timed out".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Case-insensitive matching with fallback to first agent
|
||||
let matched_name = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.find(|name| name.eq_ignore_ascii_case(&chosen_name))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| swarm_config.agents[0].clone());
|
||||
|
||||
let agent_config = match self.agents.get(&matched_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router selected unknown agent '{matched_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(
|
||||
&matched_name,
|
||||
agent_config,
|
||||
&full_prompt,
|
||||
swarm_config.timeout_secs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm router — selected '{matched_name}' ({}/{})]\n{output}",
|
||||
agent_config.provider, agent_config.model
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SwarmTool {
|
||||
fn name(&self) -> &str {
|
||||
"swarm"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Orchestrate a swarm of agents to collaboratively handle a task. Supports sequential \
|
||||
(pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
let swarm_names: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
json!({
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"swarm": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": format!(
|
||||
"Name of the swarm to invoke. Available: {}",
|
||||
if swarm_names.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
swarm_names.join(", ")
|
||||
}
|
||||
)
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "The task/prompt to send to the swarm"
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Optional context to include (e.g. relevant code, prior findings)"
|
||||
}
|
||||
},
|
||||
"required": ["swarm", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let swarm_name = args
|
||||
.get("swarm")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'swarm' parameter"))?;
|
||||
|
||||
if swarm_name.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'swarm' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let prompt = args
|
||||
.get("prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
|
||||
|
||||
if prompt.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'prompt' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let context = args
|
||||
.get("context")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.unwrap_or("");
|
||||
|
||||
let swarm_config = match self.swarms.get(swarm_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
let available: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown swarm '{swarm_name}'. Available swarms: {}",
|
||||
if available.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
available.join(", ")
|
||||
}
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm '{swarm_name}' has no agents configured")),
|
||||
});
|
||||
}
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "swarm")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match swarm_config.strategy {
|
||||
SwarmStrategy::Sequential => {
|
||||
self.execute_sequential(swarm_config, prompt, context).await
|
||||
}
|
||||
SwarmStrategy::Parallel => self.execute_parallel(swarm_config, prompt, context).await,
|
||||
SwarmStrategy::Router => self.execute_router(swarm_config, prompt, context).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
|
||||
let mut agents = HashMap::new();
|
||||
agents.insert(
|
||||
"researcher".to_string(),
|
||||
DelegateAgentConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "llama3".to_string(),
|
||||
system_prompt: Some("You are a research assistant.".to_string()),
|
||||
api_key: None,
|
||||
temperature: Some(0.3),
|
||||
max_depth: 3,
|
||||
agentic: false,
|
||||
allowed_tools: Vec::new(),
|
||||
max_iterations: 10,
|
||||
},
|
||||
);
|
||||
agents.insert(
|
||||
"writer".to_string(),
|
||||
DelegateAgentConfig {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||
system_prompt: Some("You are a technical writer.".to_string()),
|
||||
api_key: Some("test-key".to_string()),
|
||||
temperature: Some(0.5),
|
||||
max_depth: 3,
|
||||
agentic: false,
|
||||
allowed_tools: Vec::new(),
|
||||
max_iterations: 10,
|
||||
},
|
||||
);
|
||||
agents
|
||||
}
|
||||
|
||||
fn sample_swarms() -> HashMap<String, SwarmConfig> {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"pipeline".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: Some("Research then write".to_string()),
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms.insert(
|
||||
"fanout".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms.insert(
|
||||
"router".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string(), "writer".to_string()],
|
||||
strategy: SwarmStrategy::Router,
|
||||
router_prompt: Some("Pick the best agent.".to_string()),
|
||||
description: None,
|
||||
timeout_secs: 300,
|
||||
},
|
||||
);
|
||||
swarms
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert_eq!(tool.name(), "swarm");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["swarm"].is_object());
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
assert!(schema["properties"]["context"].is_object());
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.contains(&json!("swarm")));
|
||||
assert!(required.contains(&json!("prompt")));
|
||||
assert_eq!(schema["additionalProperties"], json!(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn description_not_empty() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
assert!(!tool.description().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_lists_swarm_names() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["swarm"]["description"]
|
||||
.as_str()
|
||||
.unwrap();
|
||||
assert!(desc.contains("pipeline") || desc.contains("fanout") || desc.contains("router"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_swarms_schema() {
|
||||
let tool = SwarmTool::new(
|
||||
HashMap::new(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["swarm"]["description"]
|
||||
.as_str()
|
||||
.unwrap();
|
||||
assert!(desc.contains("none configured"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_swarm_returns_error() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "nonexistent", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Unknown swarm"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_swarm_param() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool.execute(json!({"prompt": "test"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_prompt_param() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool.execute(json!({"swarm": "pipeline"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blank_swarm_rejected() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": " ", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("must not be empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blank_prompt_rejected() {
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": " "}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("must not be empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_with_missing_agent_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"broken".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["nonexistent_agent".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "broken", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("unknown agent"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_with_empty_agents_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"empty".to_string(),
|
||||
SwarmConfig {
|
||||
agents: Vec::new(),
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "empty", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("no agents configured"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_blocked_in_readonly_mode() {
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
readonly,
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn swarm_blocked_when_rate_limited() {
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = SwarmTool::new(
|
||||
sample_swarms(),
|
||||
sample_agents(),
|
||||
None,
|
||||
limited,
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "pipeline", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sequential_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"seq".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Sequential,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
// researcher uses "ollama" which won't be running in CI
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "seq", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
// Should fail at provider creation or call level
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parallel_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"par".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Parallel,
|
||||
router_prompt: None,
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "par", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
// Parallel strategy returns success with error annotations in output
|
||||
assert!(result.success || result.error.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn router_invalid_provider_returns_error() {
|
||||
let mut swarms = HashMap::new();
|
||||
swarms.insert(
|
||||
"rout".to_string(),
|
||||
SwarmConfig {
|
||||
agents: vec!["researcher".to_string()],
|
||||
strategy: SwarmStrategy::Router,
|
||||
router_prompt: Some("Pick.".to_string()),
|
||||
description: None,
|
||||
timeout_secs: 60,
|
||||
},
|
||||
);
|
||||
let tool = SwarmTool::new(
|
||||
swarms,
|
||||
sample_agents(),
|
||||
None,
|
||||
test_security(),
|
||||
providers::ProviderRuntimeOptions::default(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"swarm": "rout", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
}
|
||||
+58
-2
@@ -2,6 +2,7 @@ mod cloudflare;
|
||||
mod custom;
|
||||
mod ngrok;
|
||||
mod none;
|
||||
mod openvpn;
|
||||
mod tailscale;
|
||||
|
||||
pub use cloudflare::CloudflareTunnel;
|
||||
@@ -9,6 +10,7 @@ pub use custom::CustomTunnel;
|
||||
pub use ngrok::NgrokTunnel;
|
||||
#[allow(unused_imports)]
|
||||
pub use none::NoneTunnel;
|
||||
pub use openvpn::OpenVpnTunnel;
|
||||
pub use tailscale::TailscaleTunnel;
|
||||
|
||||
use crate::config::schema::{TailscaleTunnelConfig, TunnelConfig};
|
||||
@@ -104,6 +106,19 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
"openvpn" => {
|
||||
let ov = config
|
||||
.openvpn
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("tunnel.provider = \"openvpn\" but [tunnel.openvpn] section is missing"))?;
|
||||
Ok(Some(Box::new(OpenVpnTunnel::new(
|
||||
ov.config_file.clone(),
|
||||
ov.auth_file.clone(),
|
||||
ov.advertise_address.clone(),
|
||||
ov.connect_timeout_secs,
|
||||
ov.extra_args.clone(),
|
||||
))))
|
||||
}
|
||||
"custom" => {
|
||||
let cu = config
|
||||
.custom
|
||||
@@ -116,7 +131,7 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
))))
|
||||
}
|
||||
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, custom"),
|
||||
other => bail!("Unknown tunnel provider: \"{other}\". Valid: none, cloudflare, tailscale, ngrok, openvpn, custom"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +141,8 @@ pub fn create_tunnel(config: &TunnelConfig) -> Result<Option<Box<dyn Tunnel>>> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::{
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig,
|
||||
CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, OpenVpnTunnelConfig,
|
||||
TunnelConfig,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
@@ -315,6 +331,46 @@ mod tests {
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openvpn_missing_config_errors() {
|
||||
let cfg = TunnelConfig {
|
||||
provider: "openvpn".into(),
|
||||
..TunnelConfig::default()
|
||||
};
|
||||
assert_tunnel_err(&cfg, "[tunnel.openvpn]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openvpn_with_config_ok() {
|
||||
let cfg = TunnelConfig {
|
||||
provider: "openvpn".into(),
|
||||
openvpn: Some(OpenVpnTunnelConfig {
|
||||
config_file: "client.ovpn".into(),
|
||||
auth_file: None,
|
||||
advertise_address: None,
|
||||
connect_timeout_secs: 30,
|
||||
extra_args: vec![],
|
||||
}),
|
||||
..TunnelConfig::default()
|
||||
};
|
||||
let t = create_tunnel(&cfg).unwrap();
|
||||
assert!(t.is_some());
|
||||
assert_eq!(t.unwrap().name(), "openvpn");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openvpn_tunnel_name() {
|
||||
let t = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert_eq!(t.name(), "openvpn");
|
||||
assert!(t.public_url().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openvpn_health_false_before_start() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert!(!tunnel.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_shared_no_process_is_ok() {
|
||||
let proc = new_shared_process();
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
use super::{kill_shared, new_shared_process, SharedProcess, Tunnel, TunnelProcess};
|
||||
use anyhow::{bail, Result};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
/// OpenVPN Tunnel — uses the `openvpn` CLI to establish a VPN connection.
|
||||
///
|
||||
/// Requires the `openvpn` binary installed and accessible. On most systems,
|
||||
/// OpenVPN requires root/administrator privileges to create tun/tap devices.
|
||||
///
|
||||
/// The tunnel exposes the gateway via the VPN network using a configured
|
||||
/// `advertise_address` (e.g., `"10.8.0.2:42617"`).
|
||||
pub struct OpenVpnTunnel {
|
||||
config_file: String,
|
||||
auth_file: Option<String>,
|
||||
advertise_address: Option<String>,
|
||||
connect_timeout_secs: u64,
|
||||
extra_args: Vec<String>,
|
||||
proc: SharedProcess,
|
||||
}
|
||||
|
||||
impl OpenVpnTunnel {
|
||||
/// Create a new OpenVPN tunnel instance.
|
||||
///
|
||||
/// * `config_file` — path to the `.ovpn` configuration file.
|
||||
/// * `auth_file` — optional path to a credentials file for `--auth-user-pass`.
|
||||
/// * `advertise_address` — optional public address to advertise once connected.
|
||||
/// * `connect_timeout_secs` — seconds to wait for the initialization sequence.
|
||||
/// * `extra_args` — additional CLI arguments forwarded to the `openvpn` binary.
|
||||
pub fn new(
|
||||
config_file: String,
|
||||
auth_file: Option<String>,
|
||||
advertise_address: Option<String>,
|
||||
connect_timeout_secs: u64,
|
||||
extra_args: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config_file,
|
||||
auth_file,
|
||||
advertise_address,
|
||||
connect_timeout_secs,
|
||||
extra_args,
|
||||
proc: new_shared_process(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the openvpn command arguments.
|
||||
fn build_args(&self) -> Vec<String> {
|
||||
let mut args = vec!["--config".to_string(), self.config_file.clone()];
|
||||
|
||||
if let Some(ref auth) = self.auth_file {
|
||||
args.push("--auth-user-pass".to_string());
|
||||
args.push(auth.clone());
|
||||
}
|
||||
|
||||
args.extend(self.extra_args.iter().cloned());
|
||||
args
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Tunnel for OpenVpnTunnel {
|
||||
fn name(&self) -> &str {
|
||||
"openvpn"
|
||||
}
|
||||
|
||||
/// Spawn the `openvpn` process and wait for the "Initialization Sequence
|
||||
/// Completed" marker on stderr. Returns the public URL on success.
|
||||
async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
|
||||
// Validate config file exists before spawning
|
||||
if !std::path::Path::new(&self.config_file).exists() {
|
||||
bail!("OpenVPN config file not found: {}", self.config_file);
|
||||
}
|
||||
|
||||
let args = self.build_args();
|
||||
|
||||
let mut child = Command::new("openvpn")
|
||||
.args(&args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
// Wait for "Initialization Sequence Completed" in stderr
|
||||
let stderr = child
|
||||
.stderr
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to capture openvpn stderr"))?;
|
||||
|
||||
let mut reader = tokio::io::BufReader::new(stderr).lines();
|
||||
let deadline = tokio::time::Instant::now()
|
||||
+ tokio::time::Duration::from_secs(self.connect_timeout_secs);
|
||||
|
||||
let mut connected = false;
|
||||
while tokio::time::Instant::now() < deadline {
|
||||
let line =
|
||||
tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line()).await;
|
||||
|
||||
match line {
|
||||
Ok(Ok(Some(l))) => {
|
||||
tracing::debug!("openvpn: {l}");
|
||||
if l.contains("Initialization Sequence Completed") {
|
||||
connected = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Ok(None)) => {
|
||||
bail!("OpenVPN process exited before connection was established");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
bail!("Error reading openvpn output: {e}");
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout on individual line read, continue waiting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !connected {
|
||||
child.kill().await.ok();
|
||||
bail!(
|
||||
"OpenVPN connection timed out after {}s waiting for initialization",
|
||||
self.connect_timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
let public_url = self
|
||||
.advertise_address
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("http://{local_host}:{local_port}"));
|
||||
|
||||
// Drain stderr in background to prevent OS pipe buffer from filling and
|
||||
// blocking the openvpn process.
|
||||
tokio::spawn(async move {
|
||||
while let Ok(Some(line)) = reader.next_line().await {
|
||||
tracing::trace!("openvpn: {line}");
|
||||
}
|
||||
});
|
||||
|
||||
let mut guard = self.proc.lock().await;
|
||||
*guard = Some(TunnelProcess {
|
||||
child,
|
||||
public_url: public_url.clone(),
|
||||
});
|
||||
|
||||
Ok(public_url)
|
||||
}
|
||||
|
||||
/// Kill the openvpn child process and release its resources.
|
||||
async fn stop(&self) -> Result<()> {
|
||||
kill_shared(&self.proc).await
|
||||
}
|
||||
|
||||
/// Return `true` if the openvpn child process is still running.
|
||||
async fn health_check(&self) -> bool {
|
||||
let guard = self.proc.lock().await;
|
||||
guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
|
||||
}
|
||||
|
||||
/// Return the public URL if the tunnel has been started.
|
||||
fn public_url(&self) -> Option<String> {
|
||||
self.proc
|
||||
.try_lock()
|
||||
.ok()
|
||||
.and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn constructor_stores_fields() {
|
||||
let tunnel = OpenVpnTunnel::new(
|
||||
"/etc/openvpn/client.ovpn".into(),
|
||||
Some("/etc/openvpn/auth.txt".into()),
|
||||
Some("10.8.0.2:42617".into()),
|
||||
45,
|
||||
vec!["--verb".into(), "3".into()],
|
||||
);
|
||||
assert_eq!(tunnel.config_file, "/etc/openvpn/client.ovpn");
|
||||
assert_eq!(tunnel.auth_file.as_deref(), Some("/etc/openvpn/auth.txt"));
|
||||
assert_eq!(tunnel.advertise_address.as_deref(), Some("10.8.0.2:42617"));
|
||||
assert_eq!(tunnel.connect_timeout_secs, 45);
|
||||
assert_eq!(tunnel.extra_args, vec!["--verb", "3"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_args_basic() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
let args = tunnel.build_args();
|
||||
assert_eq!(args, vec!["--config", "client.ovpn"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_args_with_auth_and_extras() {
|
||||
let tunnel = OpenVpnTunnel::new(
|
||||
"client.ovpn".into(),
|
||||
Some("auth.txt".into()),
|
||||
None,
|
||||
30,
|
||||
vec!["--verb".into(), "5".into()],
|
||||
);
|
||||
let args = tunnel.build_args();
|
||||
assert_eq!(
|
||||
args,
|
||||
vec![
|
||||
"--config",
|
||||
"client.ovpn",
|
||||
"--auth-user-pass",
|
||||
"auth.txt",
|
||||
"--verb",
|
||||
"5"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_url_is_none_before_start() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert!(tunnel.public_url().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_check_is_false_before_start() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
assert!(!tunnel.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stop_without_started_process_is_ok() {
|
||||
let tunnel = OpenVpnTunnel::new("client.ovpn".into(), None, None, 30, vec![]);
|
||||
let result = tunnel.stop().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_with_missing_config_file_errors() {
|
||||
let tunnel = OpenVpnTunnel::new(
|
||||
"/nonexistent/path/to/client.ovpn".into(),
|
||||
None,
|
||||
None,
|
||||
30,
|
||||
vec![],
|
||||
);
|
||||
let result = tunnel.start("127.0.0.1", 8080).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("config file not found"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user