Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2539bcafe0 | |||
| 314e1d3ae8 | |||
| 82be05b1e9 | |||
| 1373659058 | |||
| c7f064e866 | |||
| 9c1d63e109 | |||
| 966edf1553 | |||
| a1af84d992 | |||
| 0ad1965081 | |||
| 70e8e7ebcd | |||
| 2bcb82c5b3 | |||
| e211b5c3e3 | |||
| 8691476577 | |||
| e34a804255 | |||
| 6120b3f705 | |||
| f175261e32 | |||
| fd9f66cad7 | |||
| d928ebc92e | |||
| 9fca9f478a | |||
| 7106632b51 | |||
| b834278754 | |||
| 186f6d9797 | |||
| 6cdc92a256 | |||
| 02599dcd3c | |||
| fe64d7ef7e | |||
| 996dbe95cf | |||
| 45f953be6d | |||
| 82f29bbcb1 | |||
| 93b5a0b824 | |||
| 08a67c4a2d | |||
| c86a0673ba |
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.1"
|
||||
version = "0.3.4"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.3.1"
|
||||
version = "0.3.4"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
+202
-10
@@ -195,6 +195,18 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;
|
||||
/// Max characters retained in stored compaction summary.
|
||||
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
|
||||
|
||||
/// Estimate token count for a message history using ~4 chars/token heuristic.
|
||||
/// Includes a small overhead per message for role/framing tokens.
|
||||
fn estimate_history_tokens(history: &[ChatMessage]) -> usize {
|
||||
history
|
||||
.iter()
|
||||
.map(|m| {
|
||||
// ~4 chars per token + ~4 framing tokens per message (role, delimiters)
|
||||
m.content.len().div_ceil(4) + 4
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Minimum interval between progress sends to avoid flooding the draft channel.
|
||||
pub(crate) const PROGRESS_MIN_INTERVAL_MS: u64 = 500;
|
||||
|
||||
@@ -288,6 +300,7 @@ async fn auto_compact_history(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
max_history: usize,
|
||||
max_context_tokens: usize,
|
||||
) -> Result<bool> {
|
||||
let has_system = history.first().map_or(false, |m| m.role == "system");
|
||||
let non_system_count = if has_system {
|
||||
@@ -296,7 +309,10 @@ async fn auto_compact_history(
|
||||
history.len()
|
||||
};
|
||||
|
||||
if non_system_count <= max_history {
|
||||
let estimated_tokens = estimate_history_tokens(history);
|
||||
|
||||
// Trigger compaction when either token budget OR message count is exceeded.
|
||||
if estimated_tokens <= max_context_tokens && non_system_count <= max_history {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
@@ -307,7 +323,16 @@ async fn auto_compact_history(
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let compact_end = start + compact_count;
|
||||
let mut compact_end = start + compact_count;
|
||||
|
||||
// Snap compact_end to a user-turn boundary so we don't split mid-conversation.
|
||||
while compact_end > start && history.get(compact_end).map_or(false, |m| m.role != "user") {
|
||||
compact_end -= 1;
|
||||
}
|
||||
if compact_end <= start {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
@@ -2635,6 +2660,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!(
|
||||
"\u{274c} {}: {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(&scrub_credentials(&cancelled), 200)
|
||||
))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
call.name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2662,11 +2696,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||
arguments: tool_args.clone(),
|
||||
};
|
||||
|
||||
// Only prompt interactively on CLI; auto-approve on other channels.
|
||||
let decision = if channel_name == "cli" {
|
||||
mgr.prompt_cli(&request)
|
||||
// Interactive CLI: prompt the operator.
|
||||
// Non-interactive (channels): auto-deny since no operator
|
||||
// is present to approve.
|
||||
let decision = if mgr.is_non_interactive() {
|
||||
ApprovalResponse::No
|
||||
} else {
|
||||
ApprovalResponse::Yes
|
||||
mgr.prompt_cli(&request)
|
||||
};
|
||||
|
||||
mgr.record_decision(&tool_name, &tool_args, decision, channel_name);
|
||||
@@ -2687,6 +2723,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"arguments": scrub_credentials(&tool_args.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, denied))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2723,6 +2764,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
"deduplicated": true,
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!("\u{274c} {}: {}\n", tool_name, duplicate))
|
||||
.await;
|
||||
}
|
||||
ordered_results[idx] = Some((
|
||||
tool_name.clone(),
|
||||
call.tool_call_id.clone(),
|
||||
@@ -2825,13 +2871,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
// ── Progress: tool completion ───────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
let secs = outcome.duration.as_secs();
|
||||
let icon = if outcome.success {
|
||||
"\u{2705}"
|
||||
let progress_msg = if outcome.success {
|
||||
format!("\u{2705} {} ({secs}s)\n", call.name)
|
||||
} else if let Some(ref reason) = outcome.error_reason {
|
||||
format!(
|
||||
"\u{274c} {} ({secs}s): {}\n",
|
||||
call.name,
|
||||
truncate_with_ellipsis(reason, 200)
|
||||
)
|
||||
} else {
|
||||
"\u{274c}"
|
||||
format!("\u{274c} {} ({secs}s)\n", call.name)
|
||||
};
|
||||
tracing::debug!(tool = %call.name, secs, "Sending progress complete to draft");
|
||||
let _ = tx.send(format!("{icon} {} ({secs}s)\n", call.name)).await;
|
||||
let _ = tx.send(progress_msg).await;
|
||||
}
|
||||
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
@@ -3508,6 +3560,7 @@ pub async fn run(
|
||||
provider.as_ref(),
|
||||
model_name,
|
||||
config.agent.max_history_messages,
|
||||
config.agent.max_context_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -4109,6 +4162,52 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that always returns a failure with a given error reason.
|
||||
struct FailingTool {
|
||||
tool_name: String,
|
||||
error_reason: String,
|
||||
}
|
||||
|
||||
impl FailingTool {
|
||||
fn new(name: &str, error_reason: &str) -> Self {
|
||||
Self {
|
||||
tool_name: name.to_string(),
|
||||
error_reason: error_reason.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FailingTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"A tool that always fails for testing failure surfacing"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(self.error_reason.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
@@ -6449,4 +6548,97 @@ Let me check the result."#;
|
||||
let result = filter_tool_specs_for_turn(specs, &groups, "BROWSE the site");
|
||||
assert_eq!(result.len(), 1);
|
||||
}
|
||||
|
||||
// ── Token-based compaction tests ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_empty() {
|
||||
assert_eq!(super::estimate_history_tokens(&[]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_single_message() {
|
||||
let history = vec![ChatMessage::user("hello world")]; // 11 chars
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
// 11.div_ceil(4) + 4 = 3 + 4 = 7
|
||||
assert_eq!(tokens, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_history_tokens_multiple_messages() {
|
||||
let history = vec![
|
||||
ChatMessage::system("You are helpful."), // 16 chars → 4 + 4 = 8
|
||||
ChatMessage::user("What is Rust?"), // 13 chars → 4 + 4 = 8
|
||||
ChatMessage::assistant("A language."), // 11 chars → 3 + 4 = 7
|
||||
];
|
||||
let tokens = super::estimate_history_tokens(&history);
|
||||
assert_eq!(tokens, 23);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_surfaces_tool_failure_reason_in_on_delta() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"failing_shell","arguments":{"command":"rm -rf /"}}
|
||||
</tool_call>"#,
|
||||
"I could not execute that command.",
|
||||
]);
|
||||
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(FailingTool::new(
|
||||
"failing_shell",
|
||||
"Command not allowed by security policy: rm -rf /",
|
||||
))];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("delete everything"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"telegram",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
Some(tx),
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
// Collect all messages sent to the on_delta channel.
|
||||
let mut deltas = Vec::new();
|
||||
while let Ok(msg) = rx.try_recv() {
|
||||
deltas.push(msg);
|
||||
}
|
||||
|
||||
let all_deltas = deltas.join("");
|
||||
|
||||
// The failure reason should appear in the progress messages.
|
||||
assert!(
|
||||
all_deltas.contains("Command not allowed by security policy"),
|
||||
"on_delta messages should include the tool failure reason, got: {all_deltas}"
|
||||
);
|
||||
|
||||
// Should also contain the cross mark (❌) icon to indicate failure.
|
||||
assert!(
|
||||
all_deltas.contains('\u{274c}'),
|
||||
"on_delta messages should include ❌ for failed tool calls, got: {all_deltas}"
|
||||
);
|
||||
|
||||
assert_eq!(result, "I could not execute that command.");
|
||||
}
|
||||
}
|
||||
|
||||
+128
-4
@@ -44,11 +44,18 @@ pub struct ApprovalLogEntry {
|
||||
|
||||
// ── ApprovalManager ──────────────────────────────────────────────
|
||||
|
||||
/// Manages the interactive approval workflow.
|
||||
/// Manages the approval workflow for tool calls.
|
||||
///
|
||||
/// - Checks config-level `auto_approve` / `always_ask` lists
|
||||
/// - Maintains a session-scoped "always" allowlist
|
||||
/// - Records an audit trail of all decisions
|
||||
///
|
||||
/// Two modes:
|
||||
/// - **Interactive** (CLI): tools needing approval trigger a stdin prompt.
|
||||
/// - **Non-interactive** (channels): tools needing approval are auto-denied
|
||||
/// because there is no interactive operator to approve them. `auto_approve`
|
||||
/// policy is still enforced, and `always_ask` / supervised-default tools are
|
||||
/// denied rather than silently allowed.
|
||||
pub struct ApprovalManager {
|
||||
/// Tools that never need approval (from config).
|
||||
auto_approve: HashSet<String>,
|
||||
@@ -56,6 +63,9 @@ pub struct ApprovalManager {
|
||||
always_ask: HashSet<String>,
|
||||
/// Autonomy level from config.
|
||||
autonomy_level: AutonomyLevel,
|
||||
/// When `true`, tools that would require interactive approval are
|
||||
/// auto-denied instead. Used for channel-driven (non-CLI) runs.
|
||||
non_interactive: bool,
|
||||
/// Session-scoped allowlist built from "Always" responses.
|
||||
session_allowlist: Mutex<HashSet<String>>,
|
||||
/// Audit trail of approval decisions.
|
||||
@@ -63,17 +73,40 @@ pub struct ApprovalManager {
|
||||
}
|
||||
|
||||
impl ApprovalManager {
|
||||
/// Create from autonomy config.
|
||||
/// Create an interactive (CLI) approval manager from autonomy config.
|
||||
pub fn from_config(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: false,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-interactive approval manager for channel-driven runs.
|
||||
///
|
||||
/// Enforces the same `auto_approve` / `always_ask` / supervised policies
|
||||
/// as the CLI manager, but tools that would require interactive approval
|
||||
/// are auto-denied instead of prompting (since there is no operator).
|
||||
pub fn for_non_interactive(config: &AutonomyConfig) -> Self {
|
||||
Self {
|
||||
auto_approve: config.auto_approve.iter().cloned().collect(),
|
||||
always_ask: config.always_ask.iter().cloned().collect(),
|
||||
autonomy_level: config.level,
|
||||
non_interactive: true,
|
||||
session_allowlist: Mutex::new(HashSet::new()),
|
||||
audit_log: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when this manager operates in non-interactive mode
|
||||
/// (i.e. for channel-driven runs where no operator can approve).
|
||||
pub fn is_non_interactive(&self) -> bool {
|
||||
self.non_interactive
|
||||
}
|
||||
|
||||
/// Check whether a tool call requires interactive approval.
|
||||
///
|
||||
/// Returns `true` if the call needs a prompt, `false` if it can proceed.
|
||||
@@ -147,8 +180,8 @@ impl ApprovalManager {
|
||||
|
||||
/// Prompt the user on the CLI and return their decision.
|
||||
///
|
||||
/// For non-CLI channels, returns `Yes` automatically (interactive
|
||||
/// approval is only supported on CLI for now).
|
||||
/// Only called for interactive (CLI) managers. Non-interactive managers
|
||||
/// auto-deny in the tool-call loop before reaching this point.
|
||||
pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse {
|
||||
prompt_cli_interactive(request)
|
||||
}
|
||||
@@ -401,6 +434,97 @@ mod tests {
|
||||
assert!(summary.contains("just a string"));
|
||||
}
|
||||
|
||||
// ── non-interactive (channel) mode ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn non_interactive_manager_reports_non_interactive() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_manager_reports_interactive() {
|
||||
let mgr = ApprovalManager::from_config(&supervised_config());
|
||||
assert!(!mgr.is_non_interactive());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_auto_approve_tools_skip_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// auto_approve tools (file_read, memory_recall) should not need approval.
|
||||
assert!(!mgr.needs_approval("file_read"));
|
||||
assert!(!mgr.needs_approval("memory_recall"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_tools_need_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// always_ask tools (shell) still report as needing approval,
|
||||
// so the tool-call loop will auto-deny them in non-interactive mode.
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_unknown_tools_need_approval_in_supervised() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
// Unknown tools in supervised mode need approval (will be auto-denied
|
||||
// by the tool-call loop for non-interactive managers).
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
assert!(mgr.needs_approval("http_request"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_full_autonomy_never_needs_approval() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&full_config());
|
||||
// Full autonomy means no approval needed, even in non-interactive mode.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
assert!(!mgr.needs_approval("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_readonly_never_needs_approval() {
|
||||
let config = AutonomyConfig {
|
||||
level: AutonomyLevel::ReadOnly,
|
||||
..AutonomyConfig::default()
|
||||
};
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
// ReadOnly blocks execution elsewhere; approval manager does not prompt.
|
||||
assert!(!mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_session_allowlist_still_works() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
assert!(mgr.needs_approval("file_write"));
|
||||
|
||||
// Simulate an "Always" decision (would come from a prior channel run
|
||||
// if the tool was auto-approved somehow, e.g. via config change).
|
||||
mgr.record_decision(
|
||||
"file_write",
|
||||
&serde_json::json!({"path": "test.txt"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
assert!(!mgr.needs_approval("file_write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_always_ask_overrides_session_allowlist() {
|
||||
let mgr = ApprovalManager::for_non_interactive(&supervised_config());
|
||||
|
||||
mgr.record_decision(
|
||||
"shell",
|
||||
&serde_json::json!({"command": "ls"}),
|
||||
ApprovalResponse::Always,
|
||||
"telegram",
|
||||
);
|
||||
|
||||
// shell is in always_ask, so it still needs approval even after "Always".
|
||||
assert!(mgr.needs_approval("shell"));
|
||||
}
|
||||
|
||||
// ── ApprovalResponse serde ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -746,7 +746,7 @@ impl Channel for MatrixChannel {
|
||||
MessageType::Notice(content) => (content.body.clone(), None),
|
||||
MessageType::Image(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
(format!("[image: {}]", content.body), dl)
|
||||
(format!("[IMAGE:{}]", content.body), dl)
|
||||
}
|
||||
MessageType::File(content) => {
|
||||
let dl = media_info(&content.source, &content.body);
|
||||
@@ -888,7 +888,7 @@ impl Channel for MatrixChannel {
|
||||
sender: sender.clone(),
|
||||
reply_target: format!("{}||{}", sender, room.room_id()),
|
||||
content: body,
|
||||
channel: format!("matrix:{}", room.room_id()),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
|
||||
+172
-1
@@ -31,6 +31,7 @@ pub mod nextcloud_talk;
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
pub mod nostr;
|
||||
pub mod qq;
|
||||
pub mod session_store;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
pub mod telegram;
|
||||
@@ -75,6 +76,7 @@ pub use whatsapp::WhatsAppChannel;
|
||||
pub use whatsapp_web::WhatsAppWebChannel;
|
||||
|
||||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, scrub_credentials};
|
||||
use crate::approval::ApprovalManager;
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
@@ -312,6 +314,12 @@ struct ChannelRuntimeContext {
|
||||
model_routes: Arc<Vec<crate::config::ModelRouteConfig>>,
|
||||
ack_reactions: bool,
|
||||
show_tool_calls: bool,
|
||||
session_store: Option<Arc<session_store::SessionStore>>,
|
||||
/// Non-interactive approval manager for channel-driven runs.
|
||||
/// Enforces `auto_approve` / `always_ask` / supervised policy from
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -965,6 +973,13 @@ fn proactive_trim_turns(turns: &mut Vec<ChatMessage>, budget: usize) -> usize {
|
||||
}
|
||||
|
||||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||||
// Persist to JSONL before adding to in-memory history.
|
||||
if let Some(ref store) = ctx.session_store {
|
||||
if let Err(e) = store.append(sender_key, &turn) {
|
||||
tracing::warn!("Failed to persist session turn: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
@@ -2016,7 +2031,7 @@ async fn process_channel_message(
|
||||
route.model.as_str(),
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
None,
|
||||
Some(&*ctx.approval_manager),
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
@@ -2186,6 +2201,29 @@ async fn process_channel_message(
|
||||
&history_key,
|
||||
ChatMessage::assistant(&history_response),
|
||||
);
|
||||
|
||||
// Fire-and-forget LLM-driven memory consolidation.
|
||||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
let provider = Arc::clone(&ctx.provider);
|
||||
let model = ctx.model.to_string();
|
||||
let memory = Arc::clone(&ctx.memory);
|
||||
let user_msg = msg.content.clone();
|
||||
let assistant_resp = delivered_response.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::memory::consolidation::consolidate_turn(
|
||||
provider.as_ref(),
|
||||
&model,
|
||||
memory.as_ref(),
|
||||
&user_msg,
|
||||
&assistant_resp,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::debug!("Memory consolidation skipped: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
@@ -3203,6 +3241,8 @@ fn collect_configured_channels(
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||||
eprintln!(" ⚠ WhatsApp Web is configured but the 'whatsapp-web' feature is not compiled in.");
|
||||
eprintln!(" Rebuild with: cargo build --features whatsapp-web");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -3805,8 +3845,43 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
model_routes: Arc::new(config.model_routes.clone()),
|
||||
ack_reactions: config.channels_config.ack_reactions,
|
||||
show_tool_calls: config.channels_config.show_tool_calls,
|
||||
session_store: if config.channels_config.session_persistence {
|
||||
match session_store::SessionStore::new(&config.workspace_dir) {
|
||||
Ok(store) => {
|
||||
tracing::info!("📂 Session persistence enabled");
|
||||
Some(Arc::new(store))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Session persistence disabled: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
if let Some(ref store) = runtime_ctx.session_store {
|
||||
let mut hydrated = 0usize;
|
||||
let mut histories = runtime_ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
for key in store.list_sessions() {
|
||||
let msgs = store.load(&key);
|
||||
if !msgs.is_empty() {
|
||||
hydrated += 1;
|
||||
histories.insert(key, msgs);
|
||||
}
|
||||
}
|
||||
drop(histories);
|
||||
if hydrated > 0 {
|
||||
tracing::info!("📂 Restored {hydrated} session(s) from disk");
|
||||
}
|
||||
}
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||
|
||||
// Wait for all channel tasks
|
||||
@@ -4072,6 +4147,10 @@ mod tests {
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4175,6 +4254,10 @@ mod tests {
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4234,6 +4317,10 @@ mod tests {
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4751,6 +4838,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4818,6 +4909,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4899,6 +4994,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -4965,6 +5064,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5041,6 +5144,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5137,6 +5244,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5215,6 +5326,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5308,6 +5423,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5386,6 +5505,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5454,6 +5577,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5633,6 +5760,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -5720,6 +5851,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5817,11 +5952,15 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
},
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
hooks: None,
|
||||
non_cli_excluded_tools: Arc::new(Vec::new()),
|
||||
tool_call_dedup_exempt: Arc::new(Vec::new()),
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -5921,6 +6060,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6002,6 +6145,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6068,6 +6215,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6692,6 +6843,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6784,6 +6939,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6876,6 +7035,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7432,6 +7595,10 @@ This is an example JSON object for profile settings."#;
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7505,6 +7672,10 @@ This is an example JSON object for profile settings."#;
|
||||
model_routes: Arc::new(Vec::new()),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_store: None,
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
//! JSONL-based session persistence for channel conversations.
|
||||
//!
|
||||
//! Each session (keyed by `channel_sender` or `channel_thread_sender`) is stored
|
||||
//! as an append-only JSONL file in `{workspace}/sessions/`. Messages are appended
|
||||
//! one-per-line as JSON, never modifying old lines. On daemon restart, sessions
|
||||
//! are loaded from disk to restore conversation context.
|
||||
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Append-only JSONL session store for channel conversations.
|
||||
pub struct SessionStore {
|
||||
sessions_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Create a new session store, ensuring the sessions directory exists.
|
||||
pub fn new(workspace_dir: &Path) -> std::io::Result<Self> {
|
||||
let sessions_dir = workspace_dir.join("sessions");
|
||||
std::fs::create_dir_all(&sessions_dir)?;
|
||||
Ok(Self { sessions_dir })
|
||||
}
|
||||
|
||||
/// Compute the file path for a session key, sanitizing for filesystem safety.
|
||||
fn session_path(&self, session_key: &str) -> PathBuf {
|
||||
let safe_key: String = session_key
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '_' || c == '-' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.sessions_dir.join(format!("{safe_key}.jsonl"))
|
||||
}
|
||||
|
||||
/// Load all messages for a session from its JSONL file.
|
||||
/// Returns an empty vec if the file does not exist or is unreadable.
|
||||
pub fn load(&self, session_key: &str) -> Vec<ChatMessage> {
|
||||
let path = self.session_path(session_key);
|
||||
let file = match std::fs::File::open(&path) {
|
||||
Ok(f) => f,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(msg) = serde_json::from_str::<ChatMessage>(trimmed) {
|
||||
messages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Append a single message to the session JSONL file.
|
||||
pub fn append(&self, session_key: &str, message: &ChatMessage) -> std::io::Result<()> {
|
||||
let path = self.session_path(session_key);
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)?;
|
||||
|
||||
let json = serde_json::to_string(message)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
writeln!(file, "{json}")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session keys that have files on disk.
|
||||
pub fn list_sessions(&self) -> Vec<String> {
|
||||
let entries = match std::fs::read_dir(&self.sessions_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
entries
|
||||
.filter_map(|entry| {
|
||||
let entry = entry.ok()?;
|
||||
let name = entry.file_name().into_string().ok()?;
|
||||
name.strip_suffix(".jsonl").map(String::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn round_trip_append_and_load() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
store
|
||||
.append("telegram_user123", &ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
store
|
||||
.append("telegram_user123", &ChatMessage::assistant("hi there"))
|
||||
.unwrap();
|
||||
|
||||
let messages = store.load("telegram_user123");
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].role, "user");
|
||||
assert_eq!(messages[0].content, "hello");
|
||||
assert_eq!(messages[1].role, "assistant");
|
||||
assert_eq!(messages[1].content, "hi there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_nonexistent_session_returns_empty() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
let messages = store.load("nonexistent");
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_sanitization() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
// Keys with special chars should be sanitized
|
||||
store
|
||||
.append("slack/thread:123/user", &ChatMessage::user("test"))
|
||||
.unwrap();
|
||||
|
||||
let messages = store.load("slack/thread:123/user");
|
||||
assert_eq!(messages.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_sessions_returns_keys() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
|
||||
store
|
||||
.append("telegram_alice", &ChatMessage::user("hi"))
|
||||
.unwrap();
|
||||
store
|
||||
.append("discord_bob", &ChatMessage::user("hey"))
|
||||
.unwrap();
|
||||
|
||||
let mut sessions = store.list_sessions();
|
||||
sessions.sort();
|
||||
assert_eq!(sessions.len(), 2);
|
||||
assert!(sessions.contains(&"discord_bob".to_string()));
|
||||
assert!(sessions.contains(&"telegram_alice".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn append_is_truly_append_only() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
let key = "test_session";
|
||||
|
||||
store.append(key, &ChatMessage::user("msg1")).unwrap();
|
||||
store.append(key, &ChatMessage::user("msg2")).unwrap();
|
||||
|
||||
// Read raw file to verify append-only format
|
||||
let path = store.session_path(key);
|
||||
let content = std::fs::read_to_string(&path).unwrap();
|
||||
let lines: Vec<&str> = content.trim().lines().collect();
|
||||
assert_eq!(lines.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_corrupt_lines_gracefully() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
let key = "corrupt_test";
|
||||
|
||||
// Write valid message + corrupt line + valid message
|
||||
let path = store.session_path(key);
|
||||
std::fs::create_dir_all(path.parent().unwrap()).unwrap();
|
||||
let mut file = std::fs::File::create(&path).unwrap();
|
||||
writeln!(file, r#"{{"role":"user","content":"hello"}}"#).unwrap();
|
||||
writeln!(file, "this is not valid json").unwrap();
|
||||
writeln!(file, r#"{{"role":"assistant","content":"world"}}"#).unwrap();
|
||||
|
||||
let messages = store.load(key);
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].content, "hello");
|
||||
assert_eq!(messages[1].content, "world");
|
||||
}
|
||||
}
|
||||
+3
-2
@@ -17,8 +17,9 @@ pub use schema::{
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, TunnelConfig,
|
||||
WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
+151
-1
@@ -232,6 +232,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agents: HashMap<String, DelegateAgentConfig>,
|
||||
|
||||
/// Swarm configurations for multi-agent orchestration.
|
||||
#[serde(default)]
|
||||
pub swarms: HashMap<String, SwarmConfig>,
|
||||
|
||||
/// Hooks configuration (lifecycle hooks and built-in hook toggles).
|
||||
#[serde(default)]
|
||||
pub hooks: HooksConfig,
|
||||
@@ -319,6 +323,44 @@ pub struct DelegateAgentConfig {
|
||||
pub max_iterations: usize,
|
||||
}
|
||||
|
||||
// ── Swarms ──────────────────────────────────────────────────────
|
||||
|
||||
/// Orchestration strategy for a swarm of agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SwarmStrategy {
|
||||
/// Run agents sequentially; each agent's output feeds into the next.
|
||||
Sequential,
|
||||
/// Run agents in parallel; collect all outputs.
|
||||
Parallel,
|
||||
/// Use the LLM to pick the best agent for the task.
|
||||
Router,
|
||||
}
|
||||
|
||||
/// Configuration for a swarm of coordinated agents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SwarmConfig {
|
||||
/// Ordered list of agent names (must reference keys in `agents`).
|
||||
pub agents: Vec<String>,
|
||||
/// Orchestration strategy.
|
||||
pub strategy: SwarmStrategy,
|
||||
/// System prompt for router strategy (used to pick the best agent).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub router_prompt: Option<String>,
|
||||
/// Optional description shown to the LLM when choosing swarms.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Maximum total timeout for the swarm execution in seconds.
|
||||
#[serde(default = "default_swarm_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
const DEFAULT_SWARM_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn default_swarm_timeout_secs() -> u64 {
|
||||
DEFAULT_SWARM_TIMEOUT_SECS
|
||||
}
|
||||
|
||||
/// Valid temperature range for all paths (config, CLI, env override).
|
||||
pub const TEMPERATURE_RANGE: std::ops::RangeInclusive<f64> = 0.0..=2.0;
|
||||
|
||||
@@ -791,6 +833,11 @@ pub struct AgentConfig {
|
||||
/// Maximum conversation history messages retained per session. Default: `50`.
|
||||
#[serde(default = "default_agent_max_history_messages")]
|
||||
pub max_history_messages: usize,
|
||||
/// Maximum estimated tokens for conversation history before compaction triggers.
|
||||
/// Uses ~4 chars/token heuristic. When this threshold is exceeded, older messages
|
||||
/// are summarized to preserve context while staying within budget. Default: `32000`.
|
||||
#[serde(default = "default_agent_max_context_tokens")]
|
||||
pub max_context_tokens: usize,
|
||||
/// Enable parallel tool execution within a single iteration. Default: `false`.
|
||||
#[serde(default)]
|
||||
pub parallel_tools: bool,
|
||||
@@ -817,6 +864,10 @@ fn default_agent_max_history_messages() -> usize {
|
||||
50
|
||||
}
|
||||
|
||||
fn default_agent_max_context_tokens() -> usize {
|
||||
32_000
|
||||
}
|
||||
|
||||
fn default_agent_tool_dispatcher() -> String {
|
||||
"auto".into()
|
||||
}
|
||||
@@ -827,6 +878,7 @@ impl Default for AgentConfig {
|
||||
compact_context: false,
|
||||
max_tool_iterations: default_agent_max_tool_iterations(),
|
||||
max_history_messages: default_agent_max_history_messages(),
|
||||
max_context_tokens: default_agent_max_context_tokens(),
|
||||
parallel_tools: false,
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
@@ -1413,6 +1465,10 @@ pub struct HttpRequestConfig {
|
||||
/// Request timeout in seconds (default: 30)
|
||||
#[serde(default = "default_http_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
/// Allow requests to private/LAN hosts (RFC 1918, loopback, link-local, .local).
|
||||
/// Default: false (deny private hosts for SSRF protection).
|
||||
#[serde(default)]
|
||||
pub allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl Default for HttpRequestConfig {
|
||||
@@ -1422,6 +1478,7 @@ impl Default for HttpRequestConfig {
|
||||
allowed_domains: vec![],
|
||||
max_response_size: default_http_max_response_size(),
|
||||
timeout_secs: default_http_timeout_secs(),
|
||||
allow_private_hosts: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2895,22 +2952,34 @@ pub struct HeartbeatConfig {
|
||||
pub enabled: bool,
|
||||
/// Interval in minutes between heartbeat pings. Default: `30`.
|
||||
pub interval_minutes: u32,
|
||||
/// Enable two-phase heartbeat: Phase 1 asks LLM whether to run, Phase 2
|
||||
/// executes only when the LLM decides there is work to do. Saves API cost
|
||||
/// during quiet periods. Default: `true`.
|
||||
#[serde(default = "default_two_phase")]
|
||||
pub two_phase: bool,
|
||||
/// Optional fallback task text when `HEARTBEAT.md` has no task entries.
|
||||
#[serde(default)]
|
||||
pub message: Option<String>,
|
||||
/// Optional delivery channel for heartbeat output (for example: `telegram`).
|
||||
/// When omitted, auto-selects the first configured channel.
|
||||
#[serde(default, alias = "channel")]
|
||||
pub target: Option<String>,
|
||||
/// Optional delivery recipient/chat identifier (required when `target` is set).
|
||||
/// Optional delivery recipient/chat identifier (required when `target` is
|
||||
/// explicitly set).
|
||||
#[serde(default, alias = "recipient")]
|
||||
pub to: Option<String>,
|
||||
}
|
||||
|
||||
fn default_two_phase() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
interval_minutes: 30,
|
||||
two_phase: true,
|
||||
message: None,
|
||||
target: None,
|
||||
to: None,
|
||||
@@ -3040,6 +3109,7 @@ impl<T: ChannelConfig> crate::config::traits::ConfigHandle for ConfigWrapper<T>
|
||||
///
|
||||
/// Each channel sub-section (e.g. `telegram`, `discord`) is optional;
|
||||
/// setting it to `Some(...)` enables that channel.
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ChannelsConfig {
|
||||
/// Enable the CLI interactive channel. Default: `true`.
|
||||
@@ -3102,6 +3172,10 @@ pub struct ChannelsConfig {
|
||||
/// not forwarded as individual channel messages. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub show_tool_calls: bool,
|
||||
/// Persist channel conversation history to JSONL files so sessions survive
|
||||
/// daemon restarts. Files are stored in `{workspace}/sessions/`. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub session_persistence: bool,
|
||||
}
|
||||
|
||||
impl ChannelsConfig {
|
||||
@@ -3236,6 +3310,7 @@ impl Default for ChannelsConfig {
|
||||
message_timeout_secs: default_channel_message_timeout_secs(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4169,6 +4244,7 @@ impl Default for Config {
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
@@ -6217,6 +6293,7 @@ default_temperature = 0.7
|
||||
heartbeat: HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 15,
|
||||
two_phase: true,
|
||||
message: Some("Check London time".into()),
|
||||
target: Some("telegram".into()),
|
||||
to: Some("123456".into()),
|
||||
@@ -6256,6 +6333,7 @@ default_temperature = 0.7
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
},
|
||||
memory: MemoryConfig::default(),
|
||||
storage: StorageConfig::default(),
|
||||
@@ -6274,6 +6352,7 @@ default_temperature = 0.7
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -6565,6 +6644,7 @@ tool_dispatcher = "xml"
|
||||
cost: CostConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
swarms: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
@@ -6970,6 +7050,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -7197,6 +7278,7 @@ channel_id = "C123"
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
session_persistence: true,
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
@@ -9315,4 +9397,72 @@ require_otp_to_resume = true
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_strategy_roundtrip() {
|
||||
let cases = vec![
|
||||
(SwarmStrategy::Sequential, "\"sequential\""),
|
||||
(SwarmStrategy::Parallel, "\"parallel\""),
|
||||
(SwarmStrategy::Router, "\"router\""),
|
||||
];
|
||||
for (variant, expected_json) in &cases {
|
||||
let serialized = serde_json::to_string(variant).expect("serialize");
|
||||
assert_eq!(&serialized, expected_json, "variant: {variant:?}");
|
||||
let deserialized: SwarmStrategy =
|
||||
serde_json::from_str(expected_json).expect("deserialize");
|
||||
assert_eq!(&deserialized, variant);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_with_defaults() {
|
||||
let toml_str = r#"
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents, vec!["researcher", "writer"]);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Sequential);
|
||||
assert!(config.router_prompt.is_none());
|
||||
assert!(config.description.is_none());
|
||||
assert_eq!(config.timeout_secs, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn swarm_config_deserializes_full() {
|
||||
let toml_str = r#"
|
||||
agents = ["a", "b", "c"]
|
||||
strategy = "router"
|
||||
router_prompt = "Pick the best."
|
||||
description = "Multi-agent router"
|
||||
timeout_secs = 120
|
||||
"#;
|
||||
let config: SwarmConfig = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 3);
|
||||
assert_eq!(config.strategy, SwarmStrategy::Router);
|
||||
assert_eq!(config.router_prompt.as_deref(), Some("Pick the best."));
|
||||
assert_eq!(config.description.as_deref(), Some("Multi-agent router"));
|
||||
assert_eq!(config.timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_with_swarms_section_deserializes() {
|
||||
let toml_str = r#"
|
||||
[agents.researcher]
|
||||
provider = "ollama"
|
||||
model = "llama3"
|
||||
|
||||
[agents.writer]
|
||||
provider = "openrouter"
|
||||
model = "claude-sonnet"
|
||||
|
||||
[swarms.pipeline]
|
||||
agents = ["researcher", "writer"]
|
||||
strategy = "sequential"
|
||||
"#;
|
||||
let config: Config = toml::from_str(toml_str).expect("deserialize");
|
||||
assert_eq!(config.agents.len(), 2);
|
||||
assert_eq!(config.swarms.len(), 1);
|
||||
assert!(config.swarms.contains_key("pipeline"));
|
||||
}
|
||||
}
|
||||
|
||||
+172
-21
@@ -152,44 +152,122 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<(
|
||||
crate::CronCommands::Add {
|
||||
expression,
|
||||
tz,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Cron {
|
||||
expr: expression,
|
||||
tz,
|
||||
};
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added agent cron job {}", job.id);
|
||||
println!(" Expr : {}", job.expression);
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added cron job {}", job.id);
|
||||
println!(" Expr: {}", job.expression);
|
||||
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddAt { at, command } => {
|
||||
crate::CronCommands::AddAt { at, agent, command } => {
|
||||
let at = chrono::DateTime::parse_from_rfc3339(&at)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid RFC3339 timestamp for --at: {e}"))?
|
||||
.with_timezone(&chrono::Utc);
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::AddEvery { every_ms, command } => {
|
||||
crate::CronCommands::AddEvery {
|
||||
every_ms,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
let schedule = Schedule::Every { every_ms };
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
if agent {
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
println!("✅ Added interval agent cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt : {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_shell_job(config, None, schedule, &command)?;
|
||||
println!("✅ Added interval cron job {}", job.id);
|
||||
println!(" Every(ms): {every_ms}");
|
||||
println!(" Next : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Once { delay, command } => {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
crate::CronCommands::Once {
|
||||
delay,
|
||||
agent,
|
||||
command,
|
||||
} => {
|
||||
if agent {
|
||||
let duration = parse_delay(&delay)?;
|
||||
let at = chrono::Utc::now() + duration;
|
||||
let schedule = Schedule::At { at };
|
||||
let job = add_agent_job(
|
||||
config,
|
||||
None,
|
||||
schedule,
|
||||
&command,
|
||||
SessionTarget::Isolated,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
println!("✅ Added one-shot agent cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Prompt: {}", job.prompt.as_deref().unwrap_or_default());
|
||||
} else {
|
||||
let job = add_once(config, &delay, &command)?;
|
||||
println!("✅ Added one-shot cron job {}", job.id);
|
||||
println!(" At : {}", job.next_run.to_rfc3339());
|
||||
println!(" Cmd : {}", job.command);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::CronCommands::Update {
|
||||
@@ -686,4 +764,77 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("blocked by security policy"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_creates_agent_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
assert_eq!(
|
||||
jobs[0].prompt.as_deref(),
|
||||
Some("Check server health: disk space, memory, CPU load")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_agent_flag_bypasses_shell_security_validation() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut config = test_config(&tmp);
|
||||
config.autonomy.allowed_commands = vec!["echo".into()];
|
||||
config.autonomy.level = crate::security::AutonomyLevel::Supervised;
|
||||
|
||||
// Without --agent, a natural language string would be blocked by shell
|
||||
// security policy. With --agent, it routes to agent job and skips
|
||||
// shell validation entirely.
|
||||
let result = handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/15 * * * *".into(),
|
||||
tz: None,
|
||||
agent: true,
|
||||
command: "Check server health: disk space, memory, CPU load".into(),
|
||||
},
|
||||
&config,
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Agent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_without_agent_flag_defaults_to_shell_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
handle_command(
|
||||
crate::CronCommands::Add {
|
||||
expression: "*/5 * * * *".into(),
|
||||
tz: None,
|
||||
agent: false,
|
||||
command: "echo ok".into(),
|
||||
},
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let jobs = list_jobs(&config).unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].job_type, JobType::Shell);
|
||||
assert_eq!(jobs[0].command, "echo ok");
|
||||
}
|
||||
}
|
||||
|
||||
+140
-58
@@ -203,14 +203,17 @@ where
|
||||
}
|
||||
|
||||
async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
use crate::heartbeat::engine::HeartbeatEngine;
|
||||
|
||||
let observer: std::sync::Arc<dyn crate::observability::Observer> =
|
||||
std::sync::Arc::from(crate::observability::create_observer(&config.observability));
|
||||
let engine = crate::heartbeat::engine::HeartbeatEngine::new(
|
||||
let engine = HeartbeatEngine::new(
|
||||
config.heartbeat.clone(),
|
||||
config.workspace_dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let delivery = heartbeat_delivery_target(&config)?;
|
||||
let delivery = resolve_heartbeat_delivery(&config)?;
|
||||
let two_phase = config.heartbeat.two_phase;
|
||||
|
||||
let interval_mins = config.heartbeat.interval_minutes.max(5);
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
|
||||
@@ -218,14 +221,71 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let file_tasks = engine.collect_tasks().await?;
|
||||
let tasks = heartbeat_tasks_for_tick(file_tasks, config.heartbeat.message.as_deref());
|
||||
// Collect runnable tasks (active only, sorted by priority)
|
||||
let mut tasks = engine.collect_runnable_tasks().await?;
|
||||
if tasks.is_empty() {
|
||||
continue;
|
||||
// Try fallback message
|
||||
if let Some(fallback) = config
|
||||
.heartbeat
|
||||
.message
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
{
|
||||
tasks.push(crate::heartbeat::engine::HeartbeatTask {
|
||||
text: fallback.to_string(),
|
||||
priority: crate::heartbeat::engine::TaskPriority::Medium,
|
||||
status: crate::heartbeat::engine::TaskStatus::Active,
|
||||
});
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let prompt = format!("[Heartbeat Task] {task}");
|
||||
// ── Phase 1: LLM decision (two-phase mode) ──────────────
|
||||
let tasks_to_run = if two_phase {
|
||||
let decision_prompt = HeartbeatEngine::build_decision_prompt(&tasks);
|
||||
match crate::agent::run(
|
||||
config.clone(),
|
||||
Some(decision_prompt),
|
||||
None,
|
||||
None,
|
||||
0.0, // Low temperature for deterministic decision
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
let indices = HeartbeatEngine::parse_decision_response(&response, tasks.len());
|
||||
if indices.is_empty() {
|
||||
tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
continue;
|
||||
}
|
||||
tracing::info!(
|
||||
"💓 Heartbeat Phase 1: run {} of {} tasks",
|
||||
indices.len(),
|
||||
tasks.len()
|
||||
);
|
||||
indices
|
||||
.into_iter()
|
||||
.filter_map(|i| tasks.get(i).cloned())
|
||||
.collect()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("💓 Heartbeat Phase 1 failed, running all tasks: {e}");
|
||||
tasks
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tasks
|
||||
};
|
||||
|
||||
// ── Phase 2: Execute selected tasks ─────────────────────
|
||||
for task in &tasks_to_run {
|
||||
let prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
|
||||
let temp = config.default_temperature;
|
||||
match crate::agent::run(
|
||||
config.clone(),
|
||||
@@ -242,7 +302,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
Ok(output) => {
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
let announcement = if output.trim().is_empty() {
|
||||
"heartbeat task executed".to_string()
|
||||
format!("💓 heartbeat task completed: {}", task.text)
|
||||
} else {
|
||||
output
|
||||
};
|
||||
@@ -272,22 +332,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn heartbeat_tasks_for_tick(
|
||||
file_tasks: Vec<String>,
|
||||
fallback_message: Option<&str>,
|
||||
) -> Vec<String> {
|
||||
if !file_tasks.is_empty() {
|
||||
return file_tasks;
|
||||
}
|
||||
|
||||
fallback_message
|
||||
.map(str::trim)
|
||||
.filter(|message| !message.is_empty())
|
||||
.map(|message| vec![message.to_string()])
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn heartbeat_delivery_target(config: &Config) -> Result<Option<(String, String)>> {
|
||||
/// Resolve delivery target: explicit config > auto-detect first configured channel.
|
||||
fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)>> {
|
||||
let channel = config
|
||||
.heartbeat
|
||||
.target
|
||||
@@ -302,16 +348,45 @@ fn heartbeat_delivery_target(config: &Config) -> Result<Option<(String, String)>
|
||||
.filter(|value| !value.is_empty());
|
||||
|
||||
match (channel, target) {
|
||||
(None, None) => Ok(None),
|
||||
(Some(_), None) => anyhow::bail!("heartbeat.to is required when heartbeat.target is set"),
|
||||
(None, Some(_)) => anyhow::bail!("heartbeat.target is required when heartbeat.to is set"),
|
||||
// Both explicitly set — validate and use.
|
||||
(Some(channel), Some(target)) => {
|
||||
validate_heartbeat_channel_config(config, channel)?;
|
||||
Ok(Some((channel.to_string(), target.to_string())))
|
||||
}
|
||||
// Only one set — error.
|
||||
(Some(_), None) => anyhow::bail!("heartbeat.to is required when heartbeat.target is set"),
|
||||
(None, Some(_)) => anyhow::bail!("heartbeat.target is required when heartbeat.to is set"),
|
||||
// Neither set — try auto-detect the first configured channel.
|
||||
(None, None) => Ok(auto_detect_heartbeat_channel(config)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Auto-detect the best channel for heartbeat delivery by checking which
|
||||
/// channels are configured. Returns the first match in priority order.
|
||||
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {
|
||||
// Priority order: telegram > discord > slack > mattermost
|
||||
if let Some(tg) = &config.channels_config.telegram {
|
||||
// Use the first allowed_user as target, or fall back to empty (broadcast)
|
||||
let target = tg.allowed_users.first().cloned().unwrap_or_default();
|
||||
if !target.is_empty() {
|
||||
return Some(("telegram".to_string(), target));
|
||||
}
|
||||
}
|
||||
if config.channels_config.discord.is_some() {
|
||||
// Discord requires explicit target — can't auto-detect
|
||||
return None;
|
||||
}
|
||||
if config.channels_config.slack.is_some() {
|
||||
// Slack requires explicit target
|
||||
return None;
|
||||
}
|
||||
if config.channels_config.mattermost.is_some() {
|
||||
// Mattermost requires explicit target
|
||||
return None;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn validate_heartbeat_channel_config(config: &Config, channel: &str) -> Result<()> {
|
||||
match channel.to_ascii_lowercase().as_str() {
|
||||
"telegram" => {
|
||||
@@ -487,75 +562,56 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_use_file_tasks_when_available() {
|
||||
let tasks =
|
||||
heartbeat_tasks_for_tick(vec!["From file".to_string()], Some("Fallback from config"));
|
||||
assert_eq!(tasks, vec!["From file".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_fall_back_to_config_message() {
|
||||
let tasks = heartbeat_tasks_for_tick(vec![], Some(" check london time "));
|
||||
assert_eq!(tasks, vec!["check london time".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_tasks_ignore_empty_fallback_message() {
|
||||
let tasks = heartbeat_tasks_for_tick(vec![], Some(" "));
|
||||
assert!(tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_none_when_unset() {
|
||||
fn resolve_delivery_none_when_unset() {
|
||||
let config = Config::default();
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert!(target.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_to_field() {
|
||||
fn resolve_delivery_requires_to_field() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("heartbeat.to is required when heartbeat.target is set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_target_field() {
|
||||
fn resolve_delivery_requires_target_field() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("heartbeat.target is required when heartbeat.to is set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_rejects_unsupported_channel() {
|
||||
fn resolve_delivery_rejects_unsupported_channel() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("email".into());
|
||||
config.heartbeat.to = Some("ops@example.com".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("unsupported heartbeat.target channel"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_requires_channel_configuration() {
|
||||
fn resolve_delivery_requires_channel_configuration() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
let err = resolve_heartbeat_delivery(&config).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("channels_config.telegram is not configured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_accepts_telegram_configuration() {
|
||||
fn resolve_delivery_accepts_telegram_configuration() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("telegram".into());
|
||||
config.heartbeat.to = Some("123456".into());
|
||||
@@ -568,7 +624,33 @@ mod tests {
|
||||
mention_only: false,
|
||||
});
|
||||
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert_eq!(target, Some(("telegram".to_string(), "123456".to_string())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_detect_telegram_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.telegram = Some(crate::config::TelegramConfig {
|
||||
bot_token: "bot-token".into(),
|
||||
allowed_users: vec!["user123".into()],
|
||||
stream_mode: crate::config::StreamMode::default(),
|
||||
draft_update_interval_ms: 1000,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
});
|
||||
|
||||
let target = resolve_heartbeat_delivery(&config).unwrap();
|
||||
assert_eq!(
|
||||
target,
|
||||
Some(("telegram".to_string(), "user123".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_detect_none_when_no_channels() {
|
||||
let config = Config::default();
|
||||
let target = auto_detect_heartbeat_channel(&config);
|
||||
assert!(target.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
pub mod types;
|
||||
|
||||
pub use types::{Hand, HandContext, HandRun, HandRunStatus};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// Load all hand definitions from TOML files in the given directory.
|
||||
///
|
||||
/// Each `.toml` file in `hands_dir` is expected to deserialize into a [`Hand`].
|
||||
/// Files that fail to parse are logged and skipped.
|
||||
pub fn load_hands(hands_dir: &Path) -> Result<Vec<Hand>> {
|
||||
if !hands_dir.is_dir() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut hands = Vec::new();
|
||||
let entries = std::fs::read_dir(hands_dir)
|
||||
.with_context(|| format!("failed to read hands directory: {}", hands_dir.display()))?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("toml") {
|
||||
continue;
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand file: {}", path.display()))?;
|
||||
match toml::from_str::<Hand>(&content) {
|
||||
Ok(hand) => hands.push(hand),
|
||||
Err(e) => {
|
||||
tracing::warn!(path = %path.display(), error = %e, "skipping malformed hand file");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hands)
|
||||
}
|
||||
|
||||
/// Load the rolling context for a hand.
|
||||
///
|
||||
/// Reads from `{hands_dir}/{name}/context.json`. Returns a fresh
|
||||
/// [`HandContext`] if the file does not exist yet.
|
||||
pub fn load_hand_context(hands_dir: &Path, name: &str) -> Result<HandContext> {
|
||||
let path = hands_dir.join(name).join("context.json");
|
||||
if !path.exists() {
|
||||
return Ok(HandContext::new(name));
|
||||
}
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("failed to read hand context: {}", path.display()))?;
|
||||
let ctx: HandContext = serde_json::from_str(&content)
|
||||
.with_context(|| format!("failed to parse hand context: {}", path.display()))?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// Persist the rolling context for a hand.
|
||||
///
|
||||
/// Writes to `{hands_dir}/{name}/context.json`, creating the
|
||||
/// directory if it does not exist.
|
||||
pub fn save_hand_context(hands_dir: &Path, context: &HandContext) -> Result<()> {
|
||||
let dir = hands_dir.join(&context.hand_name);
|
||||
std::fs::create_dir_all(&dir)
|
||||
.with_context(|| format!("failed to create hand context dir: {}", dir.display()))?;
|
||||
let path = dir.join("context.json");
|
||||
let json = serde_json::to_string_pretty(context)?;
|
||||
std::fs::write(&path, json)
|
||||
.with_context(|| format!("failed to write hand context: {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn write_hand_toml(dir: &Path, filename: &str, content: &str) {
|
||||
std::fs::write(dir.join(filename), content).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_empty_dir() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_nonexistent_dir() {
|
||||
let hands = load_hands(Path::new("/nonexistent/path/hands")).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_parses_valid_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"scanner.toml",
|
||||
r#"
|
||||
name = "scanner"
|
||||
description = "Market scanner"
|
||||
prompt = "Scan markets."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * *"
|
||||
"#,
|
||||
);
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"digest.toml",
|
||||
r#"
|
||||
name = "digest"
|
||||
description = "News digest"
|
||||
prompt = "Digest news."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_skips_malformed_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
write_hand_toml(tmp.path(), "bad.toml", "this is not valid toml struct");
|
||||
write_hand_toml(
|
||||
tmp.path(),
|
||||
"good.toml",
|
||||
r#"
|
||||
name = "good"
|
||||
description = "A good hand"
|
||||
prompt = "Do good things."
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 60000
|
||||
"#,
|
||||
);
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert_eq!(hands.len(), 1);
|
||||
assert_eq!(hands[0].name, "good");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_hands_ignores_non_toml_files() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
std::fs::write(tmp.path().join("readme.md"), "# Hands").unwrap();
|
||||
std::fs::write(tmp.path().join("notes.txt"), "some notes").unwrap();
|
||||
|
||||
let hands = load_hands(tmp.path()).unwrap();
|
||||
assert!(hands.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_roundtrip_through_filesystem() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("test-hand");
|
||||
let run = HandRun {
|
||||
hand_name: "test-hand".into(),
|
||||
run_id: "run-001".into(),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec!["found something".into()],
|
||||
knowledge_added: vec!["learned something".into()],
|
||||
duration_ms: Some(500),
|
||||
};
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "test-hand").unwrap();
|
||||
|
||||
assert_eq!(loaded.hand_name, "test-hand");
|
||||
assert_eq!(loaded.total_runs, 1);
|
||||
assert_eq!(loaded.history.len(), 1);
|
||||
assert_eq!(loaded.learned_facts, vec!["learned something"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_context_returns_fresh_when_missing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = load_hand_context(tmp.path(), "nonexistent").unwrap();
|
||||
assert_eq!(ctx.hand_name, "nonexistent");
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_context_creates_directory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let ctx = HandContext::new("new-hand");
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
|
||||
assert!(tmp.path().join("new-hand").join("context.json").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_then_load_preserves_multiple_runs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut ctx = HandContext::new("multi");
|
||||
|
||||
for i in 0..5 {
|
||||
let run = HandRun {
|
||||
hand_name: "multi".into(),
|
||||
run_id: format!("run-{i:03}"),
|
||||
started_at: chrono::Utc::now(),
|
||||
finished_at: Some(chrono::Utc::now()),
|
||||
status: HandRunStatus::Completed,
|
||||
findings: vec![format!("finding-{i}")],
|
||||
knowledge_added: vec![format!("fact-{i}")],
|
||||
duration_ms: Some(100),
|
||||
};
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
|
||||
save_hand_context(tmp.path(), &ctx).unwrap();
|
||||
let loaded = load_hand_context(tmp.path(), "multi").unwrap();
|
||||
|
||||
assert_eq!(loaded.total_runs, 5);
|
||||
assert_eq!(loaded.history.len(), 3, "history capped at max_history=3");
|
||||
assert_eq!(loaded.learned_facts.len(), 5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::cron::Schedule;
|
||||
|
||||
// ── Hand ───────────────────────────────────────────────────────
|
||||
|
||||
/// A Hand is an autonomous agent package that runs on a schedule,
|
||||
/// accumulates knowledge over time, and reports results.
|
||||
///
|
||||
/// Hands are defined as TOML files in `~/.zeroclaw/hands/` and each
|
||||
/// maintains a rolling context of findings across runs so the agent
|
||||
/// grows smarter with every execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Hand {
|
||||
/// Unique name (also used as directory/file stem)
|
||||
pub name: String,
|
||||
/// Human-readable description of what this hand does
|
||||
pub description: String,
|
||||
/// The schedule this hand runs on (reuses cron schedule types)
|
||||
pub schedule: Schedule,
|
||||
/// System prompt / execution plan for this hand
|
||||
pub prompt: String,
|
||||
/// Domain knowledge lines to inject into context
|
||||
#[serde(default)]
|
||||
pub knowledge: Vec<String>,
|
||||
/// Tools this hand is allowed to use (None = all available)
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Model override for this hand (None = default provider)
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Whether this hand is currently active
|
||||
#[serde(default = "default_true")]
|
||||
pub active: bool,
|
||||
/// Maximum runs to keep in history
|
||||
#[serde(default = "default_max_runs")]
|
||||
pub max_history: usize,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_max_runs() -> usize {
|
||||
100
|
||||
}
|
||||
|
||||
// ── Hand Run ───────────────────────────────────────────────────
|
||||
|
||||
/// The status of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "status")]
|
||||
pub enum HandRunStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
/// Record of a single hand execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandRun {
|
||||
/// Name of the hand that produced this run
|
||||
pub hand_name: String,
|
||||
/// Unique identifier for this run
|
||||
pub run_id: String,
|
||||
/// When the run started
|
||||
pub started_at: DateTime<Utc>,
|
||||
/// When the run finished (None if still running)
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
/// Outcome of the run
|
||||
pub status: HandRunStatus,
|
||||
/// Key findings/outputs extracted from this run
|
||||
#[serde(default)]
|
||||
pub findings: Vec<String>,
|
||||
/// New knowledge accumulated and stored to memory
|
||||
#[serde(default)]
|
||||
pub knowledge_added: Vec<String>,
|
||||
/// Wall-clock duration in milliseconds
|
||||
pub duration_ms: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Hand Context ───────────────────────────────────────────────
|
||||
|
||||
/// Rolling context that accumulates across hand runs.
|
||||
///
|
||||
/// Persisted as `~/.zeroclaw/hands/{name}/context.json`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandContext {
|
||||
/// Name of the hand this context belongs to
|
||||
pub hand_name: String,
|
||||
/// Past runs, most-recent first, capped at `Hand::max_history`
|
||||
#[serde(default)]
|
||||
pub history: Vec<HandRun>,
|
||||
/// Persistent facts learned across runs
|
||||
#[serde(default)]
|
||||
pub learned_facts: Vec<String>,
|
||||
/// Timestamp of the last completed run
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
/// Total number of successful runs
|
||||
#[serde(default)]
|
||||
pub total_runs: u64,
|
||||
}
|
||||
|
||||
impl HandContext {
|
||||
/// Create a fresh, empty context for a hand.
|
||||
pub fn new(hand_name: &str) -> Self {
|
||||
Self {
|
||||
hand_name: hand_name.to_string(),
|
||||
history: Vec::new(),
|
||||
learned_facts: Vec::new(),
|
||||
last_run: None,
|
||||
total_runs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed run, updating counters and trimming history.
|
||||
pub fn record_run(&mut self, run: HandRun, max_history: usize) {
|
||||
if run.status == (HandRunStatus::Completed) {
|
||||
self.total_runs += 1;
|
||||
self.last_run = run.finished_at;
|
||||
}
|
||||
|
||||
// Merge new knowledge
|
||||
for fact in &run.knowledge_added {
|
||||
if !self.learned_facts.contains(fact) {
|
||||
self.learned_facts.push(fact.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Insert at the front (most-recent first)
|
||||
self.history.insert(0, run);
|
||||
|
||||
// Cap history length
|
||||
self.history.truncate(max_history);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cron::Schedule;
|
||||
|
||||
fn sample_hand() -> Hand {
|
||||
Hand {
|
||||
name: "market-scanner".into(),
|
||||
description: "Scans market trends and reports findings".into(),
|
||||
schedule: Schedule::Cron {
|
||||
expr: "0 9 * * 1-5".into(),
|
||||
tz: Some("America/New_York".into()),
|
||||
},
|
||||
prompt: "Scan market trends and report key findings.".into(),
|
||||
knowledge: vec!["Focus on tech sector.".into()],
|
||||
allowed_tools: Some(vec!["web_search".into(), "memory".into()]),
|
||||
model: Some("claude-opus-4-6".into()),
|
||||
active: true,
|
||||
max_history: 50,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_run(name: &str, status: HandRunStatus) -> HandRun {
|
||||
let now = Utc::now();
|
||||
HandRun {
|
||||
hand_name: name.into(),
|
||||
run_id: uuid::Uuid::new_v4().to_string(),
|
||||
started_at: now,
|
||||
finished_at: Some(now),
|
||||
status,
|
||||
findings: vec!["finding-1".into()],
|
||||
knowledge_added: vec!["learned-fact-A".into()],
|
||||
duration_ms: Some(1234),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Deserialization ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_from_toml() {
|
||||
let toml_str = r#"
|
||||
name = "market-scanner"
|
||||
description = "Scans market trends"
|
||||
prompt = "Scan trends."
|
||||
|
||||
[schedule]
|
||||
kind = "cron"
|
||||
expr = "0 9 * * 1-5"
|
||||
tz = "America/New_York"
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "market-scanner");
|
||||
assert!(hand.active, "active should default to true");
|
||||
assert_eq!(hand.max_history, 100, "max_history should default to 100");
|
||||
assert!(hand.knowledge.is_empty());
|
||||
assert!(hand.allowed_tools.is_none());
|
||||
assert!(hand.model.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_deserializes_full_toml() {
|
||||
let toml_str = r#"
|
||||
name = "news-digest"
|
||||
description = "Daily news digest"
|
||||
prompt = "Summarize the day's news."
|
||||
knowledge = ["focus on AI", "include funding rounds"]
|
||||
allowed_tools = ["web_search"]
|
||||
model = "claude-opus-4-6"
|
||||
active = false
|
||||
max_history = 25
|
||||
|
||||
[schedule]
|
||||
kind = "every"
|
||||
every_ms = 3600000
|
||||
"#;
|
||||
let hand: Hand = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(hand.name, "news-digest");
|
||||
assert!(!hand.active);
|
||||
assert_eq!(hand.max_history, 25);
|
||||
assert_eq!(hand.knowledge.len(), 2);
|
||||
assert_eq!(hand.allowed_tools.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(hand.model.as_deref(), Some("claude-opus-4-6"));
|
||||
assert!(matches!(
|
||||
hand.schedule,
|
||||
Schedule::Every {
|
||||
every_ms: 3_600_000
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hand_roundtrip_json() {
|
||||
let hand = sample_hand();
|
||||
let json = serde_json::to_string(&hand).unwrap();
|
||||
let parsed: Hand = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, hand.name);
|
||||
assert_eq!(parsed.max_history, hand.max_history);
|
||||
}
|
||||
|
||||
// ── HandRunStatus ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hand_run_status_serde_roundtrip() {
|
||||
let statuses = vec![
|
||||
HandRunStatus::Running,
|
||||
HandRunStatus::Completed,
|
||||
HandRunStatus::Failed {
|
||||
error: "timeout".into(),
|
||||
},
|
||||
];
|
||||
for status in statuses {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
let parsed: HandRunStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, status);
|
||||
}
|
||||
}
|
||||
|
||||
// ── HandContext ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn context_new_is_empty() {
|
||||
let ctx = HandContext::new("test-hand");
|
||||
assert_eq!(ctx.hand_name, "test-hand");
|
||||
assert!(ctx.history.is_empty());
|
||||
assert!(ctx.learned_facts.is_empty());
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_run_increments_counters() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 1);
|
||||
assert!(ctx.last_run.is_some());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
assert_eq!(ctx.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_record_failed_run_does_not_increment_total() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run(
|
||||
"scanner",
|
||||
HandRunStatus::Failed {
|
||||
error: "boom".into(),
|
||||
},
|
||||
);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
assert_eq!(ctx.total_runs, 0);
|
||||
assert!(ctx.last_run.is_none());
|
||||
assert_eq!(ctx.history.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_caps_history_at_max() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for _ in 0..10 {
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 3);
|
||||
}
|
||||
assert_eq!(ctx.history.len(), 3);
|
||||
assert_eq!(ctx.total_runs, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_deduplicates_learned_facts() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run1 = sample_run("scanner", HandRunStatus::Completed);
|
||||
let run2 = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run1, 100);
|
||||
ctx.record_run(run2, 100);
|
||||
|
||||
// Both runs add "learned-fact-A" but it should appear only once
|
||||
assert_eq!(ctx.learned_facts.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_json_roundtrip() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
let run = sample_run("scanner", HandRunStatus::Completed);
|
||||
ctx.record_run(run, 100);
|
||||
|
||||
let json = serde_json::to_string_pretty(&ctx).unwrap();
|
||||
let parsed: HandContext = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.hand_name, "scanner");
|
||||
assert_eq!(parsed.total_runs, 1);
|
||||
assert_eq!(parsed.history.len(), 1);
|
||||
assert_eq!(parsed.learned_facts, vec!["learned-fact-A"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn most_recent_run_is_first_in_history() {
|
||||
let mut ctx = HandContext::new("scanner");
|
||||
for i in 0..3 {
|
||||
let mut run = sample_run("scanner", HandRunStatus::Completed);
|
||||
run.findings = vec![format!("finding-{i}")];
|
||||
ctx.record_run(run, 100);
|
||||
}
|
||||
assert_eq!(ctx.history[0].findings[0], "finding-2");
|
||||
assert_eq!(ctx.history[2].findings[0], "finding-0");
|
||||
}
|
||||
}
|
||||
+399
-27
@@ -1,11 +1,75 @@
|
||||
use crate::config::HeartbeatConfig;
|
||||
use crate::observability::{Observer, ObserverEvent};
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{info, warn};
|
||||
|
||||
// ── Structured task types ────────────────────────────────────────
|
||||
|
||||
/// Priority level for a heartbeat task.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TaskPriority {
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskPriority {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Low => write!(f, "low"),
|
||||
Self::Medium => write!(f, "medium"),
|
||||
Self::High => write!(f, "high"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of a heartbeat task.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TaskStatus {
|
||||
Active,
|
||||
Paused,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Active => write!(f, "active"),
|
||||
Self::Paused => write!(f, "paused"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A structured heartbeat task with priority and status metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HeartbeatTask {
|
||||
pub text: String,
|
||||
pub priority: TaskPriority,
|
||||
pub status: TaskStatus,
|
||||
}
|
||||
|
||||
impl HeartbeatTask {
|
||||
pub fn is_runnable(&self) -> bool {
|
||||
self.status == TaskStatus::Active
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for HeartbeatTask {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "[{}] {}", self.priority, self.text)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Engine ───────────────────────────────────────────────────────
|
||||
|
||||
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
|
||||
pub struct HeartbeatEngine {
|
||||
config: HeartbeatConfig,
|
||||
@@ -64,8 +128,8 @@ impl HeartbeatEngine {
|
||||
Ok(self.collect_tasks().await?.len())
|
||||
}
|
||||
|
||||
/// Read HEARTBEAT.md and return all parsed tasks.
|
||||
pub async fn collect_tasks(&self) -> Result<Vec<String>> {
|
||||
/// Read HEARTBEAT.md and return all parsed structured tasks.
|
||||
pub async fn collect_tasks(&self) -> Result<Vec<HeartbeatTask>> {
|
||||
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
|
||||
if !heartbeat_path.exists() {
|
||||
return Ok(Vec::new());
|
||||
@@ -74,13 +138,145 @@ impl HeartbeatEngine {
|
||||
Ok(Self::parse_tasks(&content))
|
||||
}
|
||||
|
||||
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
|
||||
fn parse_tasks(content: &str) -> Vec<String> {
|
||||
/// Collect only runnable (active) tasks, sorted by priority (high first).
|
||||
pub async fn collect_runnable_tasks(&self) -> Result<Vec<HeartbeatTask>> {
|
||||
let mut tasks: Vec<HeartbeatTask> = self
|
||||
.collect_tasks()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(HeartbeatTask::is_runnable)
|
||||
.collect();
|
||||
// Sort by priority descending (High > Medium > Low)
|
||||
tasks.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
/// Parse tasks from HEARTBEAT.md with structured metadata support.
|
||||
///
|
||||
/// Supports both legacy flat format and new structured format:
|
||||
///
|
||||
/// Legacy:
|
||||
/// `- Check email` → medium priority, active status
|
||||
///
|
||||
/// Structured:
|
||||
/// `- [high] Check email` → high priority, active
|
||||
/// `- [low|paused] Review old PRs` → low priority, paused
|
||||
/// `- [completed] Old task` → medium priority, completed
|
||||
fn parse_tasks(content: &str) -> Vec<HeartbeatTask> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
trimmed.strip_prefix("- ").map(ToString::to_string)
|
||||
let text = trimmed.strip_prefix("- ")?;
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(Self::parse_task_line(text))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single task line into a structured `HeartbeatTask`.
|
||||
///
|
||||
/// Format: `[priority|status] task text` or just `task text`.
|
||||
fn parse_task_line(text: &str) -> HeartbeatTask {
|
||||
if let Some(rest) = text.strip_prefix('[') {
|
||||
if let Some((meta, task_text)) = rest.split_once(']') {
|
||||
let task_text = task_text.trim();
|
||||
if !task_text.is_empty() {
|
||||
let (priority, status) = Self::parse_meta(meta);
|
||||
return HeartbeatTask {
|
||||
text: task_text.to_string(),
|
||||
priority,
|
||||
status,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
// No metadata — default to medium/active
|
||||
HeartbeatTask {
|
||||
text: text.to_string(),
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse metadata tags like `high`, `low|paused`, `completed`.
|
||||
fn parse_meta(meta: &str) -> (TaskPriority, TaskStatus) {
|
||||
let mut priority = TaskPriority::Medium;
|
||||
let mut status = TaskStatus::Active;
|
||||
|
||||
for part in meta.split('|') {
|
||||
match part.trim().to_ascii_lowercase().as_str() {
|
||||
"high" => priority = TaskPriority::High,
|
||||
"medium" | "med" => priority = TaskPriority::Medium,
|
||||
"low" => priority = TaskPriority::Low,
|
||||
"active" => status = TaskStatus::Active,
|
||||
"paused" | "pause" => status = TaskStatus::Paused,
|
||||
"completed" | "complete" | "done" => status = TaskStatus::Completed,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
(priority, status)
|
||||
}
|
||||
|
||||
/// Build the Phase 1 LLM decision prompt for two-phase heartbeat.
|
||||
pub fn build_decision_prompt(tasks: &[HeartbeatTask]) -> String {
|
||||
let mut prompt = String::from(
|
||||
"You are a heartbeat scheduler. Review the following periodic tasks and decide \
|
||||
whether any should be executed right now.\n\n\
|
||||
Consider:\n\
|
||||
- Task priority (high tasks are more urgent)\n\
|
||||
- Whether the task is time-sensitive or can wait\n\
|
||||
- Whether running the task now would provide value\n\n\
|
||||
Tasks:\n",
|
||||
);
|
||||
|
||||
for (i, task) in tasks.iter().enumerate() {
|
||||
use std::fmt::Write;
|
||||
let _ = writeln!(prompt, "{}. [{}] {}", i + 1, task.priority, task.text);
|
||||
}
|
||||
|
||||
prompt.push_str(
|
||||
"\nRespond with ONLY one of:\n\
|
||||
- `run: 1,2,3` (comma-separated task numbers to execute)\n\
|
||||
- `skip` (nothing needs to run right now)\n\n\
|
||||
Be conservative — skip if tasks are routine and not time-sensitive.",
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Parse the Phase 1 LLM decision response.
|
||||
///
|
||||
/// Returns indices of tasks to run, or empty vec if skipped.
|
||||
pub fn parse_decision_response(response: &str, task_count: usize) -> Vec<usize> {
|
||||
let trimmed = response.trim().to_ascii_lowercase();
|
||||
|
||||
if trimmed == "skip" || trimmed.starts_with("skip") {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Look for "run: 1,2,3" pattern
|
||||
let numbers_part = if let Some(after_run) = trimmed.strip_prefix("run:") {
|
||||
after_run.trim()
|
||||
} else if let Some(after_run) = trimmed.strip_prefix("run ") {
|
||||
after_run.trim()
|
||||
} else {
|
||||
// Try to parse as bare numbers
|
||||
trimmed.as_str()
|
||||
};
|
||||
|
||||
numbers_part
|
||||
.split(',')
|
||||
.filter_map(|s| {
|
||||
let n: usize = s.trim().parse().ok()?;
|
||||
if n >= 1 && n <= task_count {
|
||||
Some(n - 1) // Convert to 0-indexed
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -93,10 +289,14 @@ impl HeartbeatEngine {
|
||||
# Add tasks below (one per line, starting with `- `)\n\
|
||||
# The agent will check this file on each heartbeat tick.\n\
|
||||
#\n\
|
||||
# Format: - [priority|status] Task description\n\
|
||||
# priority: high, medium (default), low\n\
|
||||
# status: active (default), paused, completed\n\
|
||||
#\n\
|
||||
# Examples:\n\
|
||||
# - Check my email for important messages\n\
|
||||
# - [high] Check my email for important messages\n\
|
||||
# - Review my calendar for upcoming events\n\
|
||||
# - Check the weather forecast\n";
|
||||
# - [low|paused] Check the weather forecast\n";
|
||||
tokio::fs::write(&path, default).await?;
|
||||
}
|
||||
Ok(())
|
||||
@@ -112,9 +312,9 @@ mod tests {
|
||||
let content = "# Tasks\n\n- Check email\n- Review calendar\nNot a task\n- Third task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert_eq!(tasks[0], "Check email");
|
||||
assert_eq!(tasks[1], "Review calendar");
|
||||
assert_eq!(tasks[2], "Third task");
|
||||
assert_eq!(tasks[0].text, "Check email");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -133,26 +333,21 @@ mod tests {
|
||||
let content = " - Indented task\n\t- Tab indented";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Indented task");
|
||||
assert_eq!(tasks[1], "Tab indented");
|
||||
assert_eq!(tasks[0].text, "Indented task");
|
||||
assert_eq!(tasks[1].text, "Tab indented");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_dash_without_space_ignored() {
|
||||
let content = "- Real task\n-\n- Another";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
// "-" trimmed = "-", does NOT start with "- " => skipped
|
||||
// "- Real task" => "Real task"
|
||||
// "- Another" => "Another"
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Real task");
|
||||
assert_eq!(tasks[1], "Another");
|
||||
assert_eq!(tasks[0].text, "Real task");
|
||||
assert_eq!(tasks[1].text, "Another");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_trailing_space_bullet_trimmed_to_dash() {
|
||||
// "- " trimmed becomes "-" (trim removes trailing space)
|
||||
// "-" does NOT start with "- " => skipped
|
||||
let content = "- ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 0);
|
||||
@@ -160,11 +355,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_bullet_with_content_after_spaces() {
|
||||
// "- hello " trimmed becomes "- hello" => starts_with "- " => "hello"
|
||||
let content = "- hello ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "hello");
|
||||
assert_eq!(tasks[0].text, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -172,8 +366,8 @@ mod tests {
|
||||
let content = "- Check email 📧\n- Review calendar 📅\n- 日本語タスク";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert!(tasks[0].contains("📧"));
|
||||
assert!(tasks[2].contains("日本語"));
|
||||
assert!(tasks[0].text.contains('📧'));
|
||||
assert!(tasks[2].text.contains("日本語"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -181,15 +375,15 @@ mod tests {
|
||||
let content = "# Periodic Tasks\n\n## Quick\n- Task A\n\n## Long\n- Task B\n\n* Not a dash bullet\n1. Not numbered";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Task A");
|
||||
assert_eq!(tasks[1], "Task B");
|
||||
assert_eq!(tasks[0].text, "Task A");
|
||||
assert_eq!(tasks[1].text, "Task B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_single_task() {
|
||||
let tasks = HeartbeatEngine::parse_tasks("- Only one");
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "Only one");
|
||||
assert_eq!(tasks[0].text, "Only one");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -201,9 +395,153 @@ mod tests {
|
||||
});
|
||||
let tasks = HeartbeatEngine::parse_tasks(&content);
|
||||
assert_eq!(tasks.len(), 100);
|
||||
assert_eq!(tasks[99], "Task 99");
|
||||
assert_eq!(tasks[99].text, "Task 99");
|
||||
}
|
||||
|
||||
// ── Structured task parsing tests ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_task_with_high_priority() {
|
||||
let content = "- [high] Urgent email check";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Urgent email check");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_with_low_paused() {
|
||||
let content = "- [low|paused] Review old PRs";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Review old PRs");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Low);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Paused);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_completed() {
|
||||
let content = "- [completed] Old task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Completed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_task_without_metadata_defaults() {
|
||||
let content = "- Plain task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0].text, "Plain task");
|
||||
assert_eq!(tasks[0].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[0].status, TaskStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mixed_structured_and_legacy() {
|
||||
let content = "- [high] Urgent\n- Normal task\n- [low|paused] Later";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[1].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[2].priority, TaskPriority::Low);
|
||||
assert_eq!(tasks[2].status, TaskStatus::Paused);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runnable_filters_paused_and_completed() {
|
||||
let content = "- [high] Active\n- [low|paused] Paused\n- [completed] Done";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
let runnable: Vec<_> = tasks
|
||||
.into_iter()
|
||||
.filter(HeartbeatTask::is_runnable)
|
||||
.collect();
|
||||
assert_eq!(runnable.len(), 1);
|
||||
assert_eq!(runnable[0].text, "Active");
|
||||
}
|
||||
|
||||
// ── Two-phase decision tests ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn decision_prompt_includes_all_tasks() {
|
||||
let tasks = vec![
|
||||
HeartbeatTask {
|
||||
text: "Check email".into(),
|
||||
priority: TaskPriority::High,
|
||||
status: TaskStatus::Active,
|
||||
},
|
||||
HeartbeatTask {
|
||||
text: "Review calendar".into(),
|
||||
priority: TaskPriority::Medium,
|
||||
status: TaskStatus::Active,
|
||||
},
|
||||
];
|
||||
let prompt = HeartbeatEngine::build_decision_prompt(&tasks);
|
||||
assert!(prompt.contains("1. [high] Check email"));
|
||||
assert!(prompt.contains("2. [medium] Review calendar"));
|
||||
assert!(prompt.contains("skip"));
|
||||
assert!(prompt.contains("run:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_skip() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("skip", 3);
|
||||
assert!(indices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_skip_with_reason() {
|
||||
let indices =
|
||||
HeartbeatEngine::parse_decision_response("skip — nothing urgent right now", 3);
|
||||
assert!(indices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_single() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1", 3);
|
||||
assert_eq!(indices, vec![0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_multiple() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1, 3", 3);
|
||||
assert_eq!(indices, vec![0, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_out_of_range_ignored() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 1, 5, 2", 3);
|
||||
assert_eq!(indices, vec![0, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_decision_run_zero_ignored() {
|
||||
let indices = HeartbeatEngine::parse_decision_response("run: 0, 1", 3);
|
||||
assert_eq!(indices, vec![0]);
|
||||
}
|
||||
|
||||
// ── Task display ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn task_display_format() {
|
||||
let task = HeartbeatTask {
|
||||
text: "Check email".into(),
|
||||
priority: TaskPriority::High,
|
||||
status: TaskStatus::Active,
|
||||
};
|
||||
assert_eq!(format!("{task}"), "[high] Check email");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn priority_ordering() {
|
||||
assert!(TaskPriority::High > TaskPriority::Medium);
|
||||
assert!(TaskPriority::Medium > TaskPriority::Low);
|
||||
}
|
||||
|
||||
// ── Async tests ─────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_heartbeat_file_creates_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat");
|
||||
@@ -216,6 +554,7 @@ mod tests {
|
||||
assert!(path.exists());
|
||||
let content = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert!(content.contains("Periodic Tasks"));
|
||||
assert!(content.contains("[high]"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
@@ -301,4 +640,37 @@ mod tests {
|
||||
let result = engine.run().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn collect_runnable_tasks_sorts_by_priority() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_runnable_sort");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
tokio::fs::write(
|
||||
dir.join("HEARTBEAT.md"),
|
||||
"- [low] Low task\n- [high] High task\n- Medium task\n- [low|paused] Skip me",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
|
||||
let engine = HeartbeatEngine::new(
|
||||
HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
..HeartbeatConfig::default()
|
||||
},
|
||||
dir.clone(),
|
||||
observer,
|
||||
);
|
||||
|
||||
let tasks = engine.collect_runnable_tasks().await.unwrap();
|
||||
assert_eq!(tasks.len(), 3); // paused one excluded
|
||||
assert_eq!(tasks[0].priority, TaskPriority::High);
|
||||
assert_eq!(tasks[1].priority, TaskPriority::Medium);
|
||||
assert_eq!(tasks[2].priority, TaskPriority::Low);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
}
|
||||
|
||||
+20
-6
@@ -48,6 +48,7 @@ pub(crate) mod cron;
|
||||
pub(crate) mod daemon;
|
||||
pub(crate) mod doctor;
|
||||
pub mod gateway;
|
||||
pub mod hands;
|
||||
pub(crate) mod hardware;
|
||||
pub(crate) mod health;
|
||||
pub(crate) mod heartbeat;
|
||||
@@ -280,15 +281,19 @@ Times are evaluated in UTC by default; use --tz with an IANA \
|
||||
timezone name to override.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'")]
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'")]
|
||||
Add {
|
||||
/// Cron expression
|
||||
expression: String,
|
||||
/// Optional IANA timezone (e.g. America/Los_Angeles)
|
||||
#[arg(long)]
|
||||
tz: Option<String>,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot scheduled task at an RFC3339 timestamp
|
||||
@@ -303,7 +308,10 @@ Examples:
|
||||
AddAt {
|
||||
/// One-shot timestamp in RFC3339 format
|
||||
at: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a fixed-interval scheduled task
|
||||
@@ -318,7 +326,10 @@ Examples:
|
||||
AddEvery {
|
||||
/// Interval in milliseconds
|
||||
every_ms: u64,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Add a one-shot delayed task (e.g. "30m", "2h", "1d")
|
||||
@@ -335,7 +346,10 @@ Examples:
|
||||
Once {
|
||||
/// Delay duration
|
||||
delay: String,
|
||||
/// Command to run
|
||||
/// Treat the argument as an agent prompt instead of a shell command
|
||||
#[arg(long)]
|
||||
agent: bool,
|
||||
/// Command (shell) or prompt (agent) to run
|
||||
command: String,
|
||||
},
|
||||
/// Remove a scheduled task
|
||||
|
||||
+32
-4
@@ -166,6 +166,10 @@ enum Commands {
|
||||
#[arg(long)]
|
||||
reinit: bool,
|
||||
|
||||
/// Run the full interactive setup wizard
|
||||
#[arg(long)]
|
||||
interactive: bool,
|
||||
|
||||
/// Reconfigure channels only (fast repair flow)
|
||||
#[arg(long)]
|
||||
channels_only: bool,
|
||||
@@ -325,11 +329,12 @@ override with --tz and an IANA timezone name.
|
||||
|
||||
Examples:
|
||||
zeroclaw cron list
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder'
|
||||
zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York --agent
|
||||
zeroclaw cron add '*/30 * * * *' 'Check system health' --agent
|
||||
zeroclaw cron add '*/5 * * * *' 'echo ok'
|
||||
zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' --agent
|
||||
zeroclaw cron add-every 60000 'Ping heartbeat'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes'
|
||||
zeroclaw cron once 30m 'Run backup in 30 minutes' --agent
|
||||
zeroclaw cron pause <task-id>
|
||||
zeroclaw cron update <task-id> --expression '0 8 * * *' --tz Europe/London")]
|
||||
Cron {
|
||||
@@ -725,6 +730,7 @@ async fn main() -> Result<()> {
|
||||
if let Commands::Onboard {
|
||||
force,
|
||||
reinit,
|
||||
interactive,
|
||||
channels_only,
|
||||
api_key,
|
||||
provider,
|
||||
@@ -734,6 +740,7 @@ async fn main() -> Result<()> {
|
||||
{
|
||||
let force = *force;
|
||||
let reinit = *reinit;
|
||||
let interactive = *interactive;
|
||||
let channels_only = *channels_only;
|
||||
let api_key = api_key.clone();
|
||||
let provider = provider.clone();
|
||||
@@ -743,6 +750,14 @@ async fn main() -> Result<()> {
|
||||
if reinit && channels_only {
|
||||
bail!("--reinit and --channels-only cannot be used together");
|
||||
}
|
||||
if interactive && channels_only {
|
||||
bail!("--interactive and --channels-only cannot be used together");
|
||||
}
|
||||
if interactive
|
||||
&& (api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some())
|
||||
{
|
||||
bail!("--interactive does not accept --api-key, --provider, --model, or --memory");
|
||||
}
|
||||
if channels_only
|
||||
&& (api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some())
|
||||
{
|
||||
@@ -795,6 +810,8 @@ async fn main() -> Result<()> {
|
||||
|
||||
let config = if channels_only {
|
||||
Box::pin(onboard::run_channels_repair_wizard()).await
|
||||
} else if interactive {
|
||||
Box::pin(onboard::run_wizard(force)).await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
@@ -2206,6 +2223,17 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onboard_cli_accepts_interactive_flag() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "onboard", "--interactive"])
|
||||
.expect("onboard --interactive should parse");
|
||||
|
||||
match cli.command {
|
||||
Commands::Onboard { interactive, .. } => assert!(interactive),
|
||||
other => panic!("expected onboard command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_parses_estop_default_engage() {
|
||||
let cli = Cli::try_parse_from(["zeroclaw", "estop"]).expect("estop command should parse");
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
//! LLM-driven memory consolidation.
|
||||
//!
|
||||
//! After each conversation turn, extracts structured information:
|
||||
//! - `history_entry`: A timestamped summary for the daily conversation log.
|
||||
//! - `memory_update`: New facts, preferences, or decisions worth remembering
|
||||
//! long-term (or `null` if nothing new was learned).
|
||||
//!
|
||||
//! This two-phase approach replaces the naive raw-message auto-save with
|
||||
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
|
||||
|
||||
use crate::memory::traits::{Memory, MemoryCategory};
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
/// Output of consolidation extraction.
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct ConsolidationResult {
|
||||
/// Brief timestamped summary for the conversation history log.
|
||||
pub history_entry: String,
|
||||
/// New facts/preferences/decisions to store long-term, or None.
|
||||
pub memory_update: Option<String>,
|
||||
}
|
||||
|
||||
const CONSOLIDATION_SYSTEM_PROMPT: &str = r#"You are a memory consolidation engine. Given a conversation turn, extract:
|
||||
1. "history_entry": A brief summary of what happened in this turn (1-2 sentences). Include the key topic or action.
|
||||
2. "memory_update": Any NEW facts, preferences, decisions, or commitments worth remembering long-term. Return null if nothing new was learned.
|
||||
|
||||
Respond ONLY with valid JSON: {"history_entry": "...", "memory_update": "..." or null}
|
||||
Do not include any text outside the JSON object."#;
|
||||
|
||||
/// Run two-phase LLM-driven consolidation on a conversation turn.
|
||||
///
|
||||
/// Phase 1: Write a history entry to the Daily memory category.
|
||||
/// Phase 2: Write a memory update to the Core category (if the LLM identified new facts).
|
||||
///
|
||||
/// This function is designed to be called fire-and-forget via `tokio::spawn`.
|
||||
pub async fn consolidate_turn(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
assistant_response: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let turn_text = format!("User: {user_message}\nAssistant: {assistant_response}");
|
||||
|
||||
// Truncate very long turns to avoid wasting tokens on consolidation.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
|
||||
let truncated = if turn_text.len() > 4000 {
|
||||
let mut end = 4000;
|
||||
while end > 0 && !turn_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &turn_text[..end])
|
||||
} else {
|
||||
turn_text.clone()
|
||||
};
|
||||
|
||||
let raw = provider
|
||||
.chat_with_system(Some(CONSOLIDATION_SYSTEM_PROMPT), &truncated, model, 0.1)
|
||||
.await?;
|
||||
|
||||
let result: ConsolidationResult = parse_consolidation_response(&raw, &turn_text);
|
||||
|
||||
// Phase 1: Write history entry to Daily category.
|
||||
let date = chrono::Local::now().format("%Y-%m-%d").to_string();
|
||||
let history_key = format!("daily_{date}_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(
|
||||
&history_key,
|
||||
&result.history_entry,
|
||||
MemoryCategory::Daily,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Phase 2: Write memory update to Core category (if present).
|
||||
if let Some(ref update) = result.memory_update {
|
||||
if !update.trim().is_empty() {
|
||||
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
|
||||
memory
|
||||
.store(&mem_key, update, MemoryCategory::Core, None)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse the LLM's consolidation response, with fallback for malformed JSON.
|
||||
fn parse_consolidation_response(raw: &str, fallback_text: &str) -> ConsolidationResult {
|
||||
// Try to extract JSON from the response (LLM may wrap in markdown code blocks).
|
||||
let cleaned = raw
|
||||
.trim()
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim();
|
||||
|
||||
serde_json::from_str(cleaned).unwrap_or_else(|_| {
|
||||
// Fallback: use truncated turn text as history entry.
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let summary = if fallback_text.len() > 200 {
|
||||
let mut end = 200;
|
||||
while end > 0 && !fallback_text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}…", &fallback_text[..end])
|
||||
} else {
|
||||
fallback_text.to_string()
|
||||
};
|
||||
ConsolidationResult {
|
||||
history_entry: summary,
|
||||
memory_update: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_valid_json_response() {
|
||||
let raw = r#"{"history_entry": "User asked about Rust.", "memory_update": "User prefers Rust over Go."}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "User asked about Rust.");
|
||||
assert_eq!(
|
||||
result.memory_update.as_deref(),
|
||||
Some("User prefers Rust over Go.")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_with_null_memory() {
|
||||
let raw = r#"{"history_entry": "Routine greeting.", "memory_update": null}"#;
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Routine greeting.");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_wrapped_in_code_block() {
|
||||
let raw =
|
||||
"```json\n{\"history_entry\": \"Discussed deployment.\", \"memory_update\": null}\n```";
|
||||
let result = parse_consolidation_response(raw, "fallback");
|
||||
assert_eq!(result.history_entry, "Discussed deployment.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_on_malformed_response() {
|
||||
let raw = "I'm sorry, I can't do that.";
|
||||
let result = parse_consolidation_response(raw, "User: hello\nAssistant: hi");
|
||||
assert_eq!(result.history_entry, "User: hello\nAssistant: hi");
|
||||
assert!(result.memory_update.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_long_text() {
|
||||
let long_text = "x".repeat(500);
|
||||
let result = parse_consolidation_response("invalid", &long_text);
|
||||
// 200 bytes + "…" (3 bytes in UTF-8) = 203
|
||||
assert!(result.history_entry.len() <= 203);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_truncates_cjk_text_without_panic() {
|
||||
// Each CJK character is 3 bytes in UTF-8; byte index 200 may land
|
||||
// inside a character. This must not panic.
|
||||
let cjk_text = "二手书项目".repeat(50); // 250 chars = 750 bytes
|
||||
let result = parse_consolidation_response("invalid", &cjk_text);
|
||||
assert!(result
|
||||
.history_entry
|
||||
.is_char_boundary(result.history_entry.len()));
|
||||
assert!(result.history_entry.ends_with('…'));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod consolidation;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod lucid;
|
||||
|
||||
+2
-1
@@ -4,7 +4,7 @@ pub mod wizard;
|
||||
#[allow(unused_imports)]
|
||||
pub use wizard::{
|
||||
run_channels_repair_wizard, run_models_list, run_models_refresh, run_models_refresh_all,
|
||||
run_models_set, run_models_status, run_quick_setup,
|
||||
run_models_set, run_models_status, run_quick_setup, run_wizard,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -17,6 +17,7 @@ mod tests {
|
||||
fn wizard_functions_are_reexported() {
|
||||
assert_reexport_exists(run_channels_repair_wizard);
|
||||
assert_reexport_exists(run_quick_setup);
|
||||
assert_reexport_exists(run_wizard);
|
||||
assert_reexport_exists(run_models_refresh);
|
||||
assert_reexport_exists(run_models_list);
|
||||
assert_reexport_exists(run_models_set);
|
||||
|
||||
@@ -170,6 +170,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: hardware_config,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -527,6 +528,7 @@ async fn run_quick_setup_with_home(
|
||||
cost: crate::config::CostConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
swarms: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
hardware: crate::config::HardwareConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
@@ -4147,6 +4149,23 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
.interact()?;
|
||||
|
||||
if mode_idx == 0 {
|
||||
// Compile-time check: warn early if the feature is not enabled.
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
println!();
|
||||
println!(
|
||||
" {} {}",
|
||||
style("⚠").yellow().bold(),
|
||||
style("The 'whatsapp-web' feature is not compiled in. WhatsApp Web will not work at runtime.").yellow()
|
||||
);
|
||||
println!(
|
||||
" {} Rebuild with: {}",
|
||||
style("→").dim(),
|
||||
style("cargo build --features whatsapp-web").white().bold()
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
println!(" {}", style("Mode: WhatsApp Web").dim());
|
||||
print_bullet("1. Build with --features whatsapp-web");
|
||||
print_bullet(
|
||||
|
||||
@@ -500,19 +500,23 @@ struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
function: Option<Function>,
|
||||
|
||||
// Compatibility: Some providers (e.g., older GLM) may use 'name' directly
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
arguments: Option<String>,
|
||||
|
||||
// Compatibility: DeepSeek sometimes wraps arguments differently
|
||||
#[serde(rename = "parameters", default)]
|
||||
#[serde(
|
||||
rename = "parameters",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -3094,4 +3098,50 @@ mod tests {
|
||||
// Should not panic
|
||||
let _client = p.http_client();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_none_fields_omitted_from_json() {
|
||||
// Ensures providers like Mistral that reject extra fields (e.g. "name": null)
|
||||
// don't receive them when the ToolCall compat fields are None.
|
||||
let tc = ToolCall {
|
||||
id: Some("call_1".to_string()),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert!(!json.as_object().unwrap().contains_key("name"));
|
||||
assert!(!json.as_object().unwrap().contains_key("arguments"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
// Standard fields must be present
|
||||
assert!(json.as_object().unwrap().contains_key("id"));
|
||||
assert!(json.as_object().unwrap().contains_key("type"));
|
||||
assert!(json.as_object().unwrap().contains_key("function"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_with_compat_fields_serializes_them() {
|
||||
// When compat fields are Some, they should appear in the output.
|
||||
let tc = ToolCall {
|
||||
id: None,
|
||||
kind: None,
|
||||
function: None,
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some("{\"command\":\"ls\"}".to_string()),
|
||||
parameters: None,
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["name"], "shell");
|
||||
assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
|
||||
// None fields should be omitted
|
||||
assert!(!json.as_object().unwrap().contains_key("id"));
|
||||
assert!(!json.as_object().unwrap().contains_key("type"));
|
||||
assert!(!json.as_object().unwrap().contains_key("function"));
|
||||
assert!(!json.as_object().unwrap().contains_key("parameters"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::multimodal;
|
||||
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
|
||||
use crate::providers::ProviderRuntimeOptions;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -472,8 +473,24 @@ fn extract_stream_error_message(event: &Value) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Read the response body incrementally via `bytes_stream()` to avoid
|
||||
/// buffering the entire SSE payload in memory. The previous implementation
|
||||
/// used `response.text().await?` which holds the HTTP connection open until
|
||||
/// every byte has arrived — on high-latency links the long-lived connection
|
||||
/// often drops mid-read, producing the "error decoding response body" failure
|
||||
/// reported in #3544.
|
||||
async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<String> {
|
||||
let body = response.text().await?;
|
||||
let mut body = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let bytes = chunk
|
||||
.map_err(|err| anyhow::anyhow!("error reading OpenAI Codex response stream: {err}"))?;
|
||||
let text = std::str::from_utf8(&bytes).map_err(|err| {
|
||||
anyhow::anyhow!("OpenAI Codex response contained invalid UTF-8: {err}")
|
||||
})?;
|
||||
body.push_str(text);
|
||||
}
|
||||
|
||||
if let Some(text) = parse_sse_text(&body)? {
|
||||
return Ok(text);
|
||||
|
||||
+16
-1
@@ -67,7 +67,13 @@ pub fn redact(value: &str) -> String {
|
||||
if value.len() <= 4 {
|
||||
"***".to_string()
|
||||
} else {
|
||||
format!("{}***", &value[..4])
|
||||
// Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
|
||||
let prefix = value
|
||||
.char_indices()
|
||||
.nth(4)
|
||||
.map(|(byte_idx, _)| &value[..byte_idx])
|
||||
.unwrap_or(value);
|
||||
format!("{}***", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,4 +108,13 @@ mod tests {
|
||||
assert_eq!(redact(""), "***");
|
||||
assert_eq!(redact("12345"), "1234***");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn redact_handles_multibyte_utf8_without_panic() {
|
||||
// CJK characters are 3 bytes each; slicing at byte 4 would panic
|
||||
// without char-boundary-safe handling.
|
||||
let result = redact("密码是很长的秘密");
|
||||
assert!(result.ends_with("***"));
|
||||
assert!(result.is_char_boundary(result.len()));
|
||||
}
|
||||
}
|
||||
|
||||
+144
-3
@@ -793,6 +793,8 @@ impl SecurityPolicy {
|
||||
// 1. Allowlist check (is the base command permitted at all?)
|
||||
// 2. Risk classification (high / medium / low)
|
||||
// 3. Policy flags (block_high_risk_commands, require_approval_for_medium_risk)
|
||||
// — explicit allowlist entries exempt a command from the high-risk block,
|
||||
// but the wildcard "*" does NOT grant an exemption.
|
||||
// 4. Autonomy level × approval status (supervised requires explicit approval)
|
||||
// This ordering ensures deny-by-default: unknown commands are rejected
|
||||
// before any risk or autonomy logic runs.
|
||||
@@ -810,7 +812,7 @@ impl SecurityPolicy {
|
||||
let risk = self.command_risk_level(command);
|
||||
|
||||
if risk == CommandRiskLevel::High {
|
||||
if self.block_high_risk_commands {
|
||||
if self.block_high_risk_commands && !self.is_command_explicitly_allowed(command) {
|
||||
return Err("Command blocked: high-risk command is disallowed by policy".into());
|
||||
}
|
||||
if self.autonomy == AutonomyLevel::Supervised && !approved {
|
||||
@@ -834,6 +836,48 @@ impl SecurityPolicy {
|
||||
Ok(risk)
|
||||
}
|
||||
|
||||
/// Check whether **every** segment of a command is explicitly listed in
|
||||
/// `allowed_commands` — i.e., matched by a concrete entry rather than by
|
||||
/// the wildcard `"*"`.
|
||||
///
|
||||
/// This is used to exempt explicitly-allowlisted high-risk commands from
|
||||
/// the `block_high_risk_commands` gate. The wildcard entry intentionally
|
||||
/// does **not** qualify as an explicit allowlist match, so that operators
|
||||
/// who set `allowed_commands = ["*"]` still get the high-risk safety net.
|
||||
fn is_command_explicitly_allowed(&self, command: &str) -> bool {
|
||||
let segments = split_unquoted_segments(command);
|
||||
for segment in &segments {
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
|
||||
let base_cmd_owned = command_basename(executable).to_ascii_lowercase();
|
||||
let base_cmd = strip_windows_exe_suffix(&base_cmd_owned);
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let explicitly_listed = self.allowed_commands.iter().any(|allowed| {
|
||||
let allowed = strip_wrapping_quotes(allowed).trim();
|
||||
// Skip wildcard — it does not count as an explicit entry.
|
||||
if allowed.is_empty() || allowed == "*" {
|
||||
return false;
|
||||
}
|
||||
is_allowlist_entry_match(allowed, executable, base_cmd)
|
||||
});
|
||||
|
||||
if !explicitly_listed {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At least one real command must be present.
|
||||
segments.iter().any(|s| {
|
||||
let s = skip_env_assignments(s.trim());
|
||||
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
|
||||
})
|
||||
}
|
||||
|
||||
// ── Layered Command Allowlist ──────────────────────────────────────────
|
||||
// Defence-in-depth: five independent gates run in order before the
|
||||
// per-segment allowlist check. Each gate targets a specific bypass
|
||||
@@ -1503,10 +1547,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_high_risk_by_default() {
|
||||
fn validate_command_blocks_high_risk_via_wildcard() {
|
||||
// Wildcard allows the command through is_command_allowed, but
|
||||
// block_high_risk_commands still rejects it because "*" does not
|
||||
// count as an explicit allowlist entry.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
allowed_commands: vec!["*".into()],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
@@ -1515,6 +1562,100 @@ mod tests {
|
||||
assert!(result.unwrap_err().contains("high-risk"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_explicitly_listed_high_risk() {
|
||||
// When a high-risk command is explicitly in allowed_commands, the
|
||||
// block_high_risk_commands gate is bypassed — the operator has made
|
||||
// a deliberate decision to permit it.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com/data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_allows_wget_when_explicitly_listed() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["wget".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result =
|
||||
p.validate_command_execution("wget https://releases.example.com/v1.tar.gz", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_blocks_non_listed_high_risk_when_another_is_allowed() {
|
||||
// Allowing curl explicitly should not exempt wget.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("wget https://evil.com", true);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_explicit_rm_bypasses_high_risk_block() {
|
||||
// Operator explicitly listed "rm" — they accept the risk.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("rm -rf /tmp/test", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_high_risk_still_needs_approval_in_supervised() {
|
||||
// Even when explicitly allowed, supervised mode still requires
|
||||
// approval for high-risk commands (the approval gate is separate
|
||||
// from the block gate).
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let denied = p.validate_command_execution("curl https://api.example.com", false);
|
||||
assert!(denied.is_err());
|
||||
assert!(denied.unwrap_err().contains("requires explicit approval"));
|
||||
|
||||
let allowed = p.validate_command_execution("curl https://api.example.com", true);
|
||||
assert_eq!(allowed.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_pipe_needs_all_segments_explicitly_allowed() {
|
||||
// When a pipeline contains a high-risk command, every segment
|
||||
// must be explicitly allowed for the exemption to apply.
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into(), "grep".into()],
|
||||
block_high_risk_commands: true,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
let result = p.validate_command_execution("curl https://api.example.com | grep data", true);
|
||||
assert_eq!(result.unwrap(), CommandRiskLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_command_full_mode_skips_medium_risk_approval_gate() {
|
||||
let p = SecurityPolicy {
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
}
|
||||
|
||||
impl HttpRequestTool {
|
||||
@@ -20,12 +21,14 @@ impl HttpRequestTool {
|
||||
allowed_domains: Vec<String>,
|
||||
max_response_size: usize,
|
||||
timeout_secs: u64,
|
||||
allow_private_hosts: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
allowed_domains: normalize_allowed_domains(allowed_domains),
|
||||
max_response_size,
|
||||
timeout_secs,
|
||||
allow_private_hosts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +55,7 @@ impl HttpRequestTool {
|
||||
|
||||
let host = extract_host(url)?;
|
||||
|
||||
if is_private_or_local_host(&host) {
|
||||
if !self.allow_private_hosts && is_private_or_local_host(&host) {
|
||||
anyhow::bail!("Blocked local/private host: {host}");
|
||||
}
|
||||
|
||||
@@ -454,6 +457,13 @@ mod tests {
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
|
||||
test_tool_with_private(allowed_domains, false)
|
||||
}
|
||||
|
||||
fn test_tool_with_private(
|
||||
allowed_domains: Vec<&str>,
|
||||
allow_private_hosts: bool,
|
||||
) -> HttpRequestTool {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
..SecurityPolicy::default()
|
||||
@@ -463,6 +473,7 @@ mod tests {
|
||||
allowed_domains.into_iter().map(String::from).collect(),
|
||||
1_000_000,
|
||||
30,
|
||||
allow_private_hosts,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -570,7 +581,7 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_requires_allowlist() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30, false);
|
||||
let err = tool
|
||||
.validate_url("https://example.com")
|
||||
.unwrap_err()
|
||||
@@ -686,7 +697,7 @@ mod tests {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -701,7 +712,7 @@ mod tests {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30);
|
||||
let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false);
|
||||
let result = tool
|
||||
.execute(json!({"url": "https://example.com"}))
|
||||
.await
|
||||
@@ -724,6 +735,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
10,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world this is long";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -738,6 +750,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
0, // max_response_size = 0 means no limit
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "a".repeat(10_000_000);
|
||||
assert_eq!(tool.truncate_response(&text), text);
|
||||
@@ -750,6 +763,7 @@ mod tests {
|
||||
vec!["example.com".into()],
|
||||
5,
|
||||
30,
|
||||
false,
|
||||
);
|
||||
let text = "hello world";
|
||||
let truncated = tool.truncate_response(text);
|
||||
@@ -935,4 +949,70 @@ mod tests {
|
||||
.to_string();
|
||||
assert!(err.contains("IPv6"));
|
||||
}
|
||||
|
||||
// ── allow_private_hosts opt-in tests ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn default_blocks_private_hosts() {
|
||||
let tool = test_tool(vec!["localhost", "192.168.1.5", "*"]);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
assert!(tool
|
||||
.validate_url("https://10.0.0.1")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_localhost() {
|
||||
let tool = test_tool_with_private(vec!["localhost"], true);
|
||||
assert!(tool.validate_url("https://localhost:8080").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_private_ipv4() {
|
||||
let tool = test_tool_with_private(vec!["192.168.1.5"], true);
|
||||
assert!(tool.validate_url("https://192.168.1.5").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_permits_rfc1918_with_wildcard() {
|
||||
let tool = test_tool_with_private(vec!["*"], true);
|
||||
assert!(tool.validate_url("https://10.0.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://172.16.0.1").is_ok());
|
||||
assert!(tool.validate_url("https://192.168.1.1").is_ok());
|
||||
assert!(tool.validate_url("http://localhost:8123").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_still_requires_allowlist() {
|
||||
let tool = test_tool_with_private(vec!["example.com"], true);
|
||||
let err = tool
|
||||
.validate_url("https://192.168.1.5")
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
assert!(
|
||||
err.contains("allowed_domains"),
|
||||
"Private host should still need allowlist match, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_private_hosts_false_still_blocks() {
|
||||
let tool = test_tool_with_private(vec!["*"], false);
|
||||
assert!(tool
|
||||
.validate_url("https://localhost:8080")
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("local/private"));
|
||||
}
|
||||
}
|
||||
|
||||
+38
-18
@@ -57,6 +57,7 @@ pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod tool_search;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
@@ -103,6 +104,7 @@ pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
pub use traits::Tool;
|
||||
#[allow(unused_imports)]
|
||||
@@ -314,6 +316,7 @@ pub fn all_tools_with_runtime(
|
||||
http_config.allowed_domains.clone(),
|
||||
http_config.max_response_size,
|
||||
http_config.timeout_secs,
|
||||
http_config.allow_private_hosts,
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -357,6 +360,24 @@ pub fn all_tools_with_runtime(
|
||||
}
|
||||
|
||||
// Add delegation tool when agents are configured
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
};
|
||||
|
||||
let delegate_handle: Option<DelegateParentToolsHandle> = if agents.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -364,28 +385,12 @@ pub fn all_tools_with_runtime(
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
let parent_tools = Arc::new(RwLock::new(tool_arcs.clone()));
|
||||
let delegate_tool = DelegateTool::new_with_options(
|
||||
delegate_agents,
|
||||
delegate_fallback_credential,
|
||||
delegate_fallback_credential.clone(),
|
||||
security.clone(),
|
||||
crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
.map(std::path::PathBuf::from),
|
||||
secrets_encrypt: root_config.secrets.encrypt,
|
||||
reasoning_enabled: root_config.runtime.reasoning_enabled,
|
||||
provider_timeout_secs: Some(root_config.provider_timeout_secs),
|
||||
extra_headers: root_config.extra_headers.clone(),
|
||||
api_path: root_config.api_path.clone(),
|
||||
},
|
||||
provider_runtime_options.clone(),
|
||||
)
|
||||
.with_parent_tools(Arc::clone(&parent_tools))
|
||||
.with_multimodal_config(root_config.multimodal.clone());
|
||||
@@ -393,6 +398,21 @@ pub fn all_tools_with_runtime(
|
||||
Some(parent_tools)
|
||||
};
|
||||
|
||||
// Add swarm tool when swarms are configured
|
||||
if !root_config.swarms.is_empty() {
|
||||
let swarm_agents: HashMap<String, DelegateAgentConfig> = agents
|
||||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
tool_arcs.push(Arc::new(SwarmTool::new(
|
||||
root_config.swarms.clone(),
|
||||
swarm_agents,
|
||||
delegate_fallback_credential,
|
||||
security.clone(),
|
||||
provider_runtime_options,
|
||||
)));
|
||||
}
|
||||
|
||||
(boxed_registry_from_arcs(tool_arcs), delegate_handle)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,953 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::{DelegateAgentConfig, SwarmConfig, SwarmStrategy};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default timeout for individual agent calls within a swarm.
|
||||
const SWARM_AGENT_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
/// Tool that orchestrates multiple agents as a swarm. Supports sequential
|
||||
/// (pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies.
|
||||
pub struct SwarmTool {
|
||||
swarms: Arc<HashMap<String, SwarmConfig>>,
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
fallback_credential: Option<String>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
}
|
||||
|
||||
impl SwarmTool {
|
||||
pub fn new(
|
||||
swarms: HashMap<String, SwarmConfig>,
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
) -> Self {
|
||||
Self {
|
||||
swarms: Arc::new(swarms),
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
provider_runtime_options,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider_for_agent(
|
||||
&self,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
agent_name: &str,
|
||||
) -> Result<Box<dyn Provider>, ToolResult> {
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
)
|
||||
.map_err(|e| ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider '{}' for agent '{agent_name}': {e}",
|
||||
agent_config.provider
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn call_agent(
|
||||
&self,
|
||||
agent_name: &str,
|
||||
agent_config: &DelegateAgentConfig,
|
||||
prompt: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<String, String> {
|
||||
let provider = self
|
||||
.create_provider_for_agent(agent_config, agent_name)
|
||||
.map_err(|r| r.error.unwrap_or_default())?;
|
||||
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
provider.chat_with_system(
|
||||
agent_config.system_prompt.as_deref(),
|
||||
prompt,
|
||||
&agent_config.model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(response)) => {
|
||||
if response.trim().is_empty() {
|
||||
Ok("[Empty response]".to_string())
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("Agent '{agent_name}' failed: {e}")),
|
||||
Err(_) => Err(format!(
|
||||
"Agent '{agent_name}' timed out after {timeout_secs}s"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_sequential(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let mut current_input = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let per_agent_timeout = swarm_config.timeout_secs / swarm_config.agents.len().max(1) as u64;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (i, agent_name) in swarm_config.agents.iter().enumerate() {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let agent_prompt = if i == 0 {
|
||||
current_input.clone()
|
||||
} else {
|
||||
format!("[Previous agent output]\n{current_input}\n\n[Original task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(agent_name, agent_config, &agent_prompt, per_agent_timeout)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
results.push(format!(
|
||||
"[{agent_name} ({}/{})] {output}",
|
||||
agent_config.provider, agent_config.model
|
||||
));
|
||||
current_input = output;
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: results.join("\n\n"),
|
||||
error: Some(e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm sequential — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_parallel(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for agent_name in &swarm_config.agents {
|
||||
let agent_config = match self.agents.get(agent_name) {
|
||||
Some(cfg) => cfg.clone(),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm references unknown agent '{agent_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let credential = agent_config
|
||||
.api_key
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
|
||||
let provider = match providers::create_provider_with_options(
|
||||
&agent_config.provider,
|
||||
credential.as_deref(),
|
||||
&self.provider_runtime_options,
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to create provider for agent '{agent_name}': {e}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let name = agent_name.clone();
|
||||
let prompt_clone = full_prompt.clone();
|
||||
let timeout = swarm_config.timeout_secs;
|
||||
let model = agent_config.model.clone();
|
||||
let temperature = agent_config.temperature.unwrap_or(0.7);
|
||||
let system_prompt = agent_config.system_prompt.clone();
|
||||
let provider_name = agent_config.provider.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(timeout),
|
||||
provider.chat_with_system(
|
||||
system_prompt.as_deref(),
|
||||
&prompt_clone,
|
||||
&model,
|
||||
temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let output = match result {
|
||||
Ok(Ok(text)) => {
|
||||
if text.trim().is_empty() {
|
||||
"[Empty response]".to_string()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => format!("[Error] {e}"),
|
||||
Err(_) => format!("[Timed out after {timeout}s]"),
|
||||
};
|
||||
|
||||
(name, provider_name, model, output)
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(join_result) = join_set.join_next().await {
|
||||
match join_result {
|
||||
Ok((name, provider_name, model, output)) => {
|
||||
results.push(format!("[{name} ({provider_name}/{model})]\n{output}"));
|
||||
}
|
||||
Err(e) => {
|
||||
results.push(format!("[join error] {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm parallel — {} agents]\n\n{}",
|
||||
swarm_config.agents.len(),
|
||||
results.join("\n\n---\n\n")
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_router(
|
||||
&self,
|
||||
swarm_config: &SwarmConfig,
|
||||
prompt: &str,
|
||||
context: &str,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router swarm has no agents to choose from".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build agent descriptions for the router prompt
|
||||
let agent_descriptions: Vec<String> = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.filter_map(|name| {
|
||||
self.agents.get(name).map(|cfg| {
|
||||
let desc = cfg
|
||||
.system_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("General purpose agent");
|
||||
format!(
|
||||
"- {name}: {desc} (provider: {}, model: {})",
|
||||
cfg.provider, cfg.model
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Use the first agent's provider for routing
|
||||
let first_agent_name = &swarm_config.agents[0];
|
||||
let first_agent_config = match self.agents.get(first_agent_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Swarm references unknown agent '{first_agent_name}'"
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let router_provider = self
|
||||
.create_provider_for_agent(first_agent_config, first_agent_name)
|
||||
.map_err(|r| anyhow::anyhow!(r.error.unwrap_or_default()))?;
|
||||
|
||||
let base_router_prompt = swarm_config
|
||||
.router_prompt
|
||||
.as_deref()
|
||||
.unwrap_or("Pick the single best agent for this task.");
|
||||
|
||||
let routing_prompt = format!(
|
||||
"{base_router_prompt}\n\nAvailable agents:\n{}\n\nUser task: {prompt}\n\n\
|
||||
Respond with ONLY the agent name, nothing else.",
|
||||
agent_descriptions.join("\n")
|
||||
);
|
||||
|
||||
let chosen = tokio::time::timeout(
|
||||
Duration::from_secs(SWARM_AGENT_TIMEOUT_SECS),
|
||||
router_provider.chat_with_system(
|
||||
Some("You are a routing assistant. Respond with only the agent name."),
|
||||
&routing_prompt,
|
||||
&first_agent_config.model,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
let chosen_name = match chosen {
|
||||
Ok(Ok(name)) => name.trim().to_string(),
|
||||
Ok(Err(e)) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router LLM call failed: {e}")),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Router LLM call timed out".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Case-insensitive matching with fallback to first agent
|
||||
let matched_name = swarm_config
|
||||
.agents
|
||||
.iter()
|
||||
.find(|name| name.eq_ignore_ascii_case(&chosen_name))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| swarm_config.agents[0].clone());
|
||||
|
||||
let agent_config = match self.agents.get(&matched_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Router selected unknown agent '{matched_name}'")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let full_prompt = if context.is_empty() {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
format!("[Context]\n{context}\n\n[Task]\n{prompt}")
|
||||
};
|
||||
|
||||
match self
|
||||
.call_agent(
|
||||
&matched_name,
|
||||
agent_config,
|
||||
&full_prompt,
|
||||
swarm_config.timeout_secs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"[Swarm router — selected '{matched_name}' ({}/{})]\n{output}",
|
||||
agent_config.provider, agent_config.model
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SwarmTool {
|
||||
fn name(&self) -> &str {
|
||||
"swarm"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Orchestrate a swarm of agents to collaboratively handle a task. Supports sequential \
|
||||
(pipeline), parallel (fan-out/fan-in), and router (LLM-selected) strategies."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
let swarm_names: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
json!({
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"swarm": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": format!(
|
||||
"Name of the swarm to invoke. Available: {}",
|
||||
if swarm_names.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
swarm_names.join(", ")
|
||||
}
|
||||
)
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "The task/prompt to send to the swarm"
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Optional context to include (e.g. relevant code, prior findings)"
|
||||
}
|
||||
},
|
||||
"required": ["swarm", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let swarm_name = args
|
||||
.get("swarm")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'swarm' parameter"))?;
|
||||
|
||||
if swarm_name.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'swarm' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let prompt = args
|
||||
.get("prompt")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?;
|
||||
|
||||
if prompt.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'prompt' parameter must not be empty".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let context = args
|
||||
.get("context")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.unwrap_or("");
|
||||
|
||||
let swarm_config = match self.swarms.get(swarm_name) {
|
||||
Some(cfg) => cfg,
|
||||
None => {
|
||||
let available: Vec<&str> = self.swarms.keys().map(String::as_str).collect();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown swarm '{swarm_name}'. Available swarms: {}",
|
||||
if available.is_empty() {
|
||||
"(none configured)".to_string()
|
||||
} else {
|
||||
available.join(", ")
|
||||
}
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if swarm_config.agents.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Swarm '{swarm_name}' has no agents configured")),
|
||||
});
|
||||
}
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "swarm")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match swarm_config.strategy {
|
||||
SwarmStrategy::Sequential => {
|
||||
self.execute_sequential(swarm_config, prompt, context).await
|
||||
}
|
||||
SwarmStrategy::Parallel => self.execute_parallel(swarm_config, prompt, context).await,
|
||||
SwarmStrategy::Router => self.execute_router(swarm_config, prompt, context).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    // Default (permissive) security policy shared with the tool under test.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::default())
    }

    // Two sample agents used by the swarm fixtures: a local "researcher"
    // (ollama, no API key) and a remote "writer" (openrouter, keyed).
    fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
        let mut agents = HashMap::new();
        agents.insert(
            "researcher".to_string(),
            DelegateAgentConfig {
                provider: "ollama".to_string(),
                model: "llama3".to_string(),
                system_prompt: Some("You are a research assistant.".to_string()),
                api_key: None,
                temperature: Some(0.3),
                max_depth: 3,
                agentic: false,
                allowed_tools: Vec::new(),
                max_iterations: 10,
            },
        );
        agents.insert(
            "writer".to_string(),
            DelegateAgentConfig {
                provider: "openrouter".to_string(),
                model: "anthropic/claude-sonnet-4-20250514".to_string(),
                system_prompt: Some("You are a technical writer.".to_string()),
                api_key: Some("test-key".to_string()),
                temperature: Some(0.5),
                max_depth: 3,
                agentic: false,
                allowed_tools: Vec::new(),
                max_iterations: 10,
            },
        );
        agents
    }

    // One fixture per strategy: "pipeline" (sequential), "fanout" (parallel),
    // and "router" (LLM-selected with a custom router prompt).
    fn sample_swarms() -> HashMap<String, SwarmConfig> {
        let mut swarms = HashMap::new();
        swarms.insert(
            "pipeline".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: Some("Research then write".to_string()),
                timeout_secs: 300,
            },
        );
        swarms.insert(
            "fanout".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 300,
            },
        );
        swarms.insert(
            "router".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string(), "writer".to_string()],
                strategy: SwarmStrategy::Router,
                router_prompt: Some("Pick the best agent.".to_string()),
                description: None,
                timeout_secs: 300,
            },
        );
        swarms
    }

    // Schema exposes all three properties, requires swarm+prompt, and
    // forbids additional properties.
    #[test]
    fn name_and_schema() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        assert_eq!(tool.name(), "swarm");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["swarm"].is_object());
        assert!(schema["properties"]["prompt"].is_object());
        assert!(schema["properties"]["context"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("swarm")));
        assert!(required.contains(&json!("prompt")));
        assert_eq!(schema["additionalProperties"], json!(false));
    }

    #[test]
    fn description_not_empty() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        assert!(!tool.description().is_empty());
    }

    // The swarm parameter description embeds configured swarm names.
    #[test]
    fn schema_lists_swarm_names() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let schema = tool.parameters_schema();
        let desc = schema["properties"]["swarm"]["description"]
            .as_str()
            .unwrap();
        assert!(desc.contains("pipeline") || desc.contains("fanout") || desc.contains("router"));
    }

    // With no swarms configured, the description falls back to a placeholder.
    #[test]
    fn empty_swarms_schema() {
        let tool = SwarmTool::new(
            HashMap::new(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let schema = tool.parameters_schema();
        let desc = schema["properties"]["swarm"]["description"]
            .as_str()
            .unwrap();
        assert!(desc.contains("none configured"));
    }

    // A name that matches no configured swarm yields a soft failure result.
    #[tokio::test]
    async fn unknown_swarm_returns_error() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "nonexistent", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Unknown swarm"));
    }

    // Missing required parameters are hard errors (Err), not ToolResult failures.
    #[tokio::test]
    async fn missing_swarm_param() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool.execute(json!({"prompt": "test"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn missing_prompt_param() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool.execute(json!({"swarm": "pipeline"})).await;
        assert!(result.is_err());
    }

    // Present-but-blank parameters are soft failures (trimmed to empty).
    #[tokio::test]
    async fn blank_swarm_rejected() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "  ", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("must not be empty"));
    }

    #[tokio::test]
    async fn blank_prompt_rejected() {
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "  "}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("must not be empty"));
    }

    // A swarm referencing an agent absent from the agent registry fails.
    #[tokio::test]
    async fn swarm_with_missing_agent_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "broken".to_string(),
            SwarmConfig {
                agents: vec!["nonexistent_agent".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "broken", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("unknown agent"));
    }

    // A swarm with an empty agent list is rejected before any dispatch.
    #[tokio::test]
    async fn swarm_with_empty_agents_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "empty".to_string(),
            SwarmConfig {
                agents: Vec::new(),
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "empty", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("no agents configured"));
    }

    // Security policy: read-only autonomy blocks swarm execution.
    #[tokio::test]
    async fn swarm_blocked_in_readonly_mode() {
        let readonly = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            ..SecurityPolicy::default()
        });
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            readonly,
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("read-only mode"));
    }

    // Security policy: a zero hourly action budget blocks swarm execution.
    #[tokio::test]
    async fn swarm_blocked_when_rate_limited() {
        let limited = Arc::new(SecurityPolicy {
            max_actions_per_hour: 0,
            ..SecurityPolicy::default()
        });
        let tool = SwarmTool::new(
            sample_swarms(),
            sample_agents(),
            None,
            limited,
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "pipeline", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("Rate limit exceeded"));
    }

    #[tokio::test]
    async fn sequential_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "seq".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Sequential,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        // researcher uses "ollama" which won't be running in CI
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "seq", "prompt": "test"}))
            .await
            .unwrap();
        // Should fail at provider creation or call level
        assert!(!result.success);
    }

    #[tokio::test]
    async fn parallel_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "par".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Parallel,
                router_prompt: None,
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "par", "prompt": "test"}))
            .await
            .unwrap();
        // Parallel strategy returns success with error annotations in output
        assert!(result.success || result.error.is_some());
    }

    // Router strategy needs a live router LLM; with none available the call fails.
    #[tokio::test]
    async fn router_invalid_provider_returns_error() {
        let mut swarms = HashMap::new();
        swarms.insert(
            "rout".to_string(),
            SwarmConfig {
                agents: vec!["researcher".to_string()],
                strategy: SwarmStrategy::Router,
                router_prompt: Some("Pick.".to_string()),
                description: None,
                timeout_secs: 60,
            },
        );
        let tool = SwarmTool::new(
            swarms,
            sample_agents(),
            None,
            test_security(),
            providers::ProviderRuntimeOptions::default(),
        );
        let result = tool
            .execute(json!({"swarm": "rout", "prompt": "test"}))
            .await
            .unwrap();
        assert!(!result.success);
    }
}
|
||||
+3
-1
@@ -71,7 +71,9 @@ export class WebSocketClient {
|
||||
params.set('session_id', sessionId);
|
||||
const url = `${this.baseUrl}/ws/chat?${params.toString()}`;
|
||||
|
||||
this.ws = new WebSocket(url, ['zeroclaw.v1']);
|
||||
const protocols: string[] = ['zeroclaw.v1'];
|
||||
if (token) protocols.push(`bearer.${token}`);
|
||||
this.ws = new WebSocket(url, protocols);
|
||||
|
||||
this.ws.onopen = () => {
|
||||
this.currentDelay = this.reconnectDelay;
|
||||
|
||||
Reference in New Issue
Block a user