Merge remote-tracking branch 'origin/main'

# Conflicts:
#	src/channels/mod.rs
#	src/config/mod.rs
#	src/config/schema.rs
This commit is contained in:
VirtualHotBar
2026-02-28 16:12:02 +08:00
91 changed files with 9374 additions and 895 deletions
+32 -1
View File
@@ -1,6 +1,7 @@
use crate::agent::dispatcher::{
NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher,
};
use crate::agent::loop_::detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader};
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
use crate::agent::research;
@@ -557,8 +558,13 @@ impl Agent {
.push(ConversationMessage::Chat(ChatMessage::user(enriched)));
let effective_model = self.classify_model(user_message);
let mut loop_detector = LoopDetector::new(LoopDetectionConfig {
no_progress_threshold: self.config.loop_detection_no_progress_threshold,
ping_pong_cycles: self.config.loop_detection_ping_pong_cycles,
failure_streak_threshold: self.config.loop_detection_failure_streak,
});
for _ in 0..self.config.max_tool_iterations {
for iteration in 0..self.config.max_tool_iterations {
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
let response = match self
.provider
@@ -613,9 +619,34 @@ impl Agent {
});
let results = self.execute_tools(&calls).await;
// ── Loop detection: record calls ─────────────────────
for (call, result) in calls.iter().zip(results.iter()) {
let args_sig =
serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".into());
loop_detector.record_call(&call.name, &args_sig, &result.output, result.success);
}
let formatted = self.tool_dispatcher.format_results(&results);
self.history.push(formatted);
self.trim_history();
// ── Loop detection: check verdict ────────────────────
match loop_detector.check() {
DetectionVerdict::Continue => {}
DetectionVerdict::InjectWarning(warning) => {
self.history
.push(ConversationMessage::Chat(ChatMessage::user(warning)));
}
DetectionVerdict::HardStop(reason) => {
anyhow::bail!(
"Agent stopped early due to detected loop pattern (iteration {}/{}): {}",
iteration + 1,
self.config.max_tool_iterations,
reason
);
}
}
}
anyhow::bail!(
+135 -38
View File
@@ -28,12 +28,14 @@ use tokio_util::sync::CancellationToken;
use uuid::Uuid;
mod context;
pub(crate) mod detection;
mod execution;
mod history;
mod parsing;
use crate::agent::session::{create_session_manager, resolve_session_id, SessionManager};
use context::{build_context, build_hardware_context};
use detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
use execution::{
execute_tools_parallel, execute_tools_sequential, should_execute_tools_in_parallel,
ToolExecutionOutcome,
@@ -314,6 +316,7 @@ pub(crate) struct NonCliApprovalContext {
tokio::task_local! {
static TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT: Option<NonCliApprovalContext>;
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
}
/// Extract a short hint from tool call arguments for progress display.
@@ -599,6 +602,14 @@ pub(crate) fn is_tool_iteration_limit_error(err: &anyhow::Error) -> bool {
})
}
/// Returns `true` when `err` (or any error in its source chain) is the
/// early-stop error raised by loop detection inside the tool-call loop.
pub(crate) fn is_loop_detection_error(err: &anyhow::Error) -> bool {
    // Matched against the literal text of the `anyhow::bail!` emitted on
    // `DetectionVerdict::HardStop`.
    const LOOP_STOP_MARKER: &str = "Agent stopped early due to detected loop pattern";
    err.chain()
        .any(|cause| cause.to_string().contains(LOOP_STOP_MARKER))
}
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use).
@@ -799,6 +810,11 @@ pub(crate) async fn run_tool_call_loop(
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
let mut missing_tool_call_retry_used = false;
let mut missing_tool_call_retry_prompt: Option<String> = None;
let ld_config = LOOP_DETECTION_CONFIG
.try_with(Clone::clone)
.unwrap_or_default();
let mut loop_detector = LoopDetector::new(ld_config);
let mut loop_detection_prompt: Option<String> = None;
let bypass_non_cli_approval_for_turn =
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
if bypass_non_cli_approval_for_turn {
@@ -842,6 +858,9 @@ pub(crate) async fn run_tool_call_loop(
if let Some(prompt) = missing_tool_call_retry_prompt.take() {
request_messages.push(ChatMessage::user(prompt));
}
if let Some(prompt) = loop_detection_prompt.take() {
request_messages.push(ChatMessage::user(prompt));
}
// ── Progress: LLM thinking ────────────────────────────
if let Some(ref tx) = on_delta {
@@ -1469,6 +1488,12 @@ pub(crate) async fn run_tool_call_loop(
.await;
}
// ── Loop detection: record call ──────────────────────
{
let sig = tool_call_signature(&call.name, &call.arguments);
loop_detector.record_call(&sig.0, &sig.1, &outcome.output, outcome.success);
}
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
}
@@ -1514,6 +1539,49 @@ pub(crate) async fn run_tool_call_loop(
history.push(ChatMessage::tool(tool_msg.to_string()));
}
}
// ── Loop detection: check verdict ────────────────────────
match loop_detector.check() {
DetectionVerdict::Continue => {}
DetectionVerdict::InjectWarning(warning) => {
runtime_trace::record_event(
"loop_detected_warning",
Some(channel_name),
Some(provider_name),
Some(model),
Some(&turn_id),
Some(false),
Some("loop pattern detected, injecting self-correction prompt"),
serde_json::json!({ "iteration": iteration + 1, "warning": &warning }),
);
if let Some(ref tx) = on_delta {
let _ = tx
.send(format!(
"{DRAFT_PROGRESS_SENTINEL}\u{26a0}\u{fe0f} Loop detected, attempting self-correction\n"
))
.await;
}
loop_detection_prompt = Some(warning);
}
DetectionVerdict::HardStop(reason) => {
runtime_trace::record_event(
"loop_detected_hard_stop",
Some(channel_name),
Some(provider_name),
Some(model),
Some(&turn_id),
Some(false),
Some("loop persisted after warning, stopping early"),
serde_json::json!({ "iteration": iteration + 1, "reason": &reason }),
);
anyhow::bail!(
"Agent stopped early due to detected loop pattern (iteration {}/{}): {}",
iteration + 1,
max_iterations,
reason
);
}
}
}
runtime_trace::record_event(
@@ -1732,6 +1800,7 @@ pub async fn run(
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
provider_transport: config.effective_provider_transport(),
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
reasoning_enabled: config.runtime.reasoning_enabled,
@@ -1956,25 +2025,34 @@ pub async fn run(
ChatMessage::user(&enriched),
];
let response = run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
)
.await?;
let ld_cfg = LoopDetectionConfig {
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
failure_streak_threshold: config.agent.loop_detection_failure_streak,
};
let response = LOOP_DETECTION_CONFIG
.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
),
)
.await?;
final_output = response.clone();
println!("{response}");
observer.record_event(&ObserverEvent::TurnComplete);
@@ -2081,25 +2159,34 @@ pub async fn run(
history.push(ChatMessage::user(&enriched));
let response = match run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
)
.await
let ld_cfg = LoopDetectionConfig {
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
failure_streak_threshold: config.agent.loop_detection_failure_streak,
};
let response = match LOOP_DETECTION_CONFIG
.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
),
)
.await
{
Ok(resp) => resp,
Err(e) => {
@@ -2113,6 +2200,15 @@ pub async fn run(
eprintln!("\n{pause_notice}\n");
continue;
}
if is_loop_detection_error(&e) {
let notice =
"\u{26a0}\u{fe0f} Loop pattern detected and agent stopped early. \
Context preserved. Reply \"continue\" to resume, or adjust \
loop_detection_* thresholds in config.";
history.push(ChatMessage::assistant(notice));
eprintln!("\n{notice}\n");
continue;
}
eprintln!("\nError: {e}\n");
continue;
}
@@ -2218,6 +2314,7 @@ pub async fn process_message(
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
provider_transport: config.effective_provider_transport(),
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
reasoning_enabled: config.runtime.reasoning_enabled,
+389
View File
@@ -0,0 +1,389 @@
//! Loop detection for the agent tool-call loop.
//!
//! Detects three patterns of unproductive looping:
//! 1. **No-progress repeat** — same tool + same args + same output hash.
//! 2. **Ping-pong** — two calls alternating (A→B→A→B) with no progress.
//! 3. **Consecutive failure streak** — same tool failing repeatedly.
//!
//! On first detection an `InjectWarning` verdict gives the LLM a chance to
//! self-correct. If the pattern persists the next check returns `HardStop`.
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
/// Maximum bytes of tool output considered when hashing results.
/// Keeps hashing fast and bounded for large outputs.
/// Measured in *bytes*, not characters — consumed only by `hash_output`.
const OUTPUT_HASH_PREFIX_BYTES: usize = 4096;
// ─── Configuration ───────────────────────────────────────────────────────────
/// Tuning knobs for each detection strategy.
#[derive(Debug, Clone)]
pub(crate) struct LoopDetectionConfig {
/// Identical (tool + args + output) repetitions before triggering.
/// `0` = disabled. Default: `3`.
pub no_progress_threshold: usize,
/// Full A-B cycles before triggering ping-pong detection.
/// `0` = disabled. Default: `2`.
pub ping_pong_cycles: usize,
/// Consecutive failures of the *same* tool before triggering.
/// `0` = disabled. Default: `3`.
pub failure_streak_threshold: usize,
}
impl Default for LoopDetectionConfig {
fn default() -> Self {
Self {
no_progress_threshold: 3,
ping_pong_cycles: 2,
failure_streak_threshold: 3,
}
}
}
// ─── Verdict ─────────────────────────────────────────────────────────────────
/// Action the caller should take after `LoopDetector::check()`.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum DetectionVerdict {
/// No loop detected — proceed normally.
Continue,
/// First detection — inject this self-correction prompt, then continue.
InjectWarning(String),
/// Pattern persisted after warning — terminate the loop.
HardStop(String),
}
// ─── Internal record ─────────────────────────────────────────────────────────
/// One recorded tool invocation, reduced to the fields loop detection compares.
struct CallRecord {
    /// Canonical tool name as passed to `record_call`.
    tool_name: String,
    /// Canonical JSON signature of the call's arguments.
    args_sig: String,
    /// Hash of a bounded prefix of the tool output — see `hash_output`.
    result_hash: u64,
    /// Whether the tool reported success; feeds the failure-streak counter.
    success: bool,
}

// ─── Detector ────────────────────────────────────────────────────────────────

/// Stateful loop detector: fed one `record_call` per tool invocation and
/// polled via `check` once per loop iteration.
pub(crate) struct LoopDetector {
    /// Thresholds for all three strategies (`0` disables a strategy).
    config: LoopDetectionConfig,
    /// Every recorded call, in order.
    history: Vec<CallRecord>,
    /// Per-tool count of consecutive failures; cleared when the tool succeeds.
    consecutive_failures: HashMap<String, usize>,
    /// Set once `InjectWarning` has been issued; any later detection hard-stops.
    warning_injected: bool,
}
impl LoopDetector {
    /// Create a detector with the given thresholds and empty history.
    pub fn new(config: LoopDetectionConfig) -> Self {
        Self {
            config,
            history: Vec::new(),
            consecutive_failures: HashMap::new(),
            warning_injected: false,
        }
    }

    /// Record a completed tool invocation.
    ///
    /// * `tool_name` — canonical tool name (lowercased by caller).
    /// * `args_sig` — canonical JSON args string from `tool_call_signature()`.
    /// * `output` — raw tool output text.
    /// * `success` — whether the tool reported success.
    pub fn record_call(&mut self, tool_name: &str, args_sig: &str, output: &str, success: bool) {
        let result_hash = hash_output(output);
        self.history.push(CallRecord {
            tool_name: tool_name.to_owned(),
            args_sig: args_sig.to_owned(),
            result_hash,
            success,
        });
        if success {
            // Any success ends the failure streak for this tool.
            self.consecutive_failures.remove(tool_name);
        } else {
            *self
                .consecutive_failures
                .entry(tool_name.to_owned())
                .or_insert(0) += 1;
        }
    }

    /// Evaluate the current history and return a verdict.
    ///
    /// First detection → `InjectWarning` (self-correction chance).
    /// Any detection after the warning was injected → `HardStop`.
    pub fn check(&mut self) -> DetectionVerdict {
        let reason = self
            .check_no_progress_repeat()
            .or_else(|| self.check_ping_pong())
            .or_else(|| self.check_failure_streak());
        match reason {
            None => DetectionVerdict::Continue,
            Some(msg) if self.warning_injected => DetectionVerdict::HardStop(msg),
            Some(msg) => {
                self.warning_injected = true;
                DetectionVerdict::InjectWarning(format_warning(&msg))
            }
        }
    }

    // ── Strategy 1: no-progress repeat ───────────────────────────────────

    /// Same tool + same args + same output hash, `threshold` times in a row.
    fn check_no_progress_repeat(&self) -> Option<String> {
        let threshold = self.config.no_progress_threshold;
        if threshold == 0 {
            return None; // disabled
        }
        // `let else` replaces the former `is_empty()` + `unwrap()` pair,
        // removing the panic path entirely.
        let Some(last) = self.history.last() else {
            return None;
        };
        let streak = self
            .history
            .iter()
            .rev()
            .take_while(|r| {
                r.tool_name == last.tool_name
                    && r.args_sig == last.args_sig
                    && r.result_hash == last.result_hash
            })
            .count();
        if streak >= threshold {
            Some(format!(
                "Tool '{}' called {} times with identical arguments and identical results \
— no progress detected",
                last.tool_name, streak
            ))
        } else {
            None
        }
    }

    // ── Strategy 2: ping-pong ────────────────────────────────────────────

    /// Two *different* calls alternating A→B→A→B for `cycles` full cycles,
    /// each side producing identical output every time (no progress).
    fn check_ping_pong(&self) -> Option<String> {
        let cycles = self.config.ping_pong_cycles;
        // Need at least two full cycles of history before calling it ping-pong.
        if cycles == 0 || self.history.len() < 4 {
            return None;
        }
        let len = self.history.len();
        let a = &self.history[len - 2];
        let b = &self.history[len - 1];
        // The two sides of the ping-pong must differ.
        if a.tool_name == b.tool_name && a.args_sig == b.args_sig {
            return None;
        }
        let min_entries = cycles * 2;
        if len < min_entries {
            return None;
        }
        let tail = &self.history[len - min_entries..];
        let is_ping_pong = tail.chunks(2).all(|pair| {
            pair.len() == 2
                && pair[0].tool_name == a.tool_name
                && pair[0].args_sig == a.args_sig
                && pair[0].result_hash == a.result_hash
                && pair[1].tool_name == b.tool_name
                && pair[1].args_sig == b.args_sig
                && pair[1].result_hash == b.result_hash
        });
        if is_ping_pong {
            Some(format!(
                "Ping-pong loop detected: '{}' and '{}' alternating {} times with no progress",
                a.tool_name, b.tool_name, cycles
            ))
        } else {
            None
        }
    }

    // ── Strategy 3: consecutive failure streak ───────────────────────────

    /// Same tool failing `threshold`+ times in a row (streaks reset on success).
    fn check_failure_streak(&self) -> Option<String> {
        let threshold = self.config.failure_streak_threshold;
        if threshold == 0 {
            return None; // disabled
        }
        // HashMap iteration order is randomized per process; when several
        // tools exceed the threshold, pick the worst offender (highest count,
        // lexicographically-smallest name as tie-break) so the reported
        // reason is deterministic.
        self.consecutive_failures
            .iter()
            .filter(|&(_, &count)| count >= threshold)
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(tool, count)| format!("Tool '{}' failed {} consecutive times", tool, count))
    }
}
// ─── Helpers ─────────────────────────────────────────────────────────────────
fn hash_output(output: &str) -> u64 {
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
&output[..OUTPUT_HASH_PREFIX_BYTES]
} else {
output
};
let mut hasher = DefaultHasher::new();
prefix.hash(&mut hasher);
hasher.finish()
}
/// Build the self-correction prompt injected into the conversation the first
/// time a loop pattern is detected. `reason` is the detector's diagnosis.
fn format_warning(reason: &str) -> String {
    // Assemble via push_str rather than one big format! call; the resulting
    // text is byte-identical.
    let mut prompt = String::with_capacity(reason.len() + 384);
    prompt.push_str("IMPORTANT: A loop pattern has been detected in your tool usage. ");
    prompt.push_str(reason);
    prompt.push_str(
        ". You must change your approach: \
         (1) Try a different tool or different arguments, \
         (2) If polling a process, increase wait time or check if it's stuck, \
         (3) If the task cannot be completed, explain why and stop. \
         Do NOT repeat the same tool call with the same arguments.",
    );
    prompt
}
// ─── Unit tests ──────────────────────────────────────────────────────────────
/// Unit tests covering all three detection strategies, the disabled-sentinel
/// behavior, and the warning → hard-stop escalation.
#[cfg(test)]
mod tests {
    use super::*;

    // Default thresholds: no-progress 3, ping-pong 2 cycles, failures 3.
    fn default_config() -> LoopDetectionConfig {
        LoopDetectionConfig::default()
    }

    // All strategies disabled via the `0` sentinel.
    fn disabled_config() -> LoopDetectionConfig {
        LoopDetectionConfig {
            no_progress_threshold: 0,
            ping_pong_cycles: 0,
            failure_streak_threshold: 0,
        }
    }

    // 1. Below threshold → Continue (two identical calls, threshold is 3)
    #[test]
    fn below_threshold_does_not_trigger() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
        det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }

    // 2. No-progress repeat triggers warning at threshold
    #[test]
    fn no_progress_repeat_triggers_warning() {
        let mut det = LoopDetector::new(default_config());
        for _ in 0..3 {
            det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
        }
        match det.check() {
            DetectionVerdict::InjectWarning(msg) => {
                assert!(msg.contains("no progress"), "msg: {msg}");
            }
            other => panic!("expected InjectWarning, got {other:?}"),
        }
    }

    // 3. Same input but different output → no trigger (progress detected)
    #[test]
    fn same_input_different_output_does_not_trigger() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("echo", r#"{"msg":"hi"}"#, "result_1", true);
        det.record_call("echo", r#"{"msg":"hi"}"#, "result_2", true);
        det.record_call("echo", r#"{"msg":"hi"}"#, "result_3", true);
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }

    // 4. Warning then continued loop → HardStop (escalation path)
    #[test]
    fn warning_then_continued_loop_triggers_hard_stop() {
        let mut det = LoopDetector::new(default_config());
        for _ in 0..3 {
            det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
        }
        assert!(matches!(det.check(), DetectionVerdict::InjectWarning(_)));
        // One more identical call — the pattern persisted after the warning.
        det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
        match det.check() {
            DetectionVerdict::HardStop(msg) => {
                assert!(msg.contains("no progress"), "msg: {msg}");
            }
            other => panic!("expected HardStop, got {other:?}"),
        }
    }

    // 5. Ping-pong detection (two different calls alternating, stale output)
    #[test]
    fn ping_pong_triggers_warning() {
        let mut det = LoopDetector::new(default_config());
        // 2 cycles: A-B-A-B
        det.record_call("tool_a", r#"{"x":1}"#, "out_a", true);
        det.record_call("tool_b", r#"{"y":2}"#, "out_b", true);
        det.record_call("tool_a", r#"{"x":1}"#, "out_a", true);
        det.record_call("tool_b", r#"{"y":2}"#, "out_b", true);
        match det.check() {
            DetectionVerdict::InjectWarning(msg) => {
                assert!(msg.contains("Ping-pong"), "msg: {msg}");
            }
            other => panic!("expected InjectWarning, got {other:?}"),
        }
    }

    // 6. Ping-pong with progress does not trigger (output changes each cycle)
    #[test]
    fn ping_pong_with_progress_does_not_trigger() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("tool_a", r#"{"x":1}"#, "out_a_1", true);
        det.record_call("tool_b", r#"{"y":2}"#, "out_b_1", true);
        det.record_call("tool_a", r#"{"x":1}"#, "out_a_2", true); // different output
        det.record_call("tool_b", r#"{"y":2}"#, "out_b_2", true); // different output
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }

    // 7. Consecutive failure streak (different args each time to avoid no-progress trigger)
    #[test]
    fn failure_streak_triggers_warning() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("shell", r#"{"cmd":"bad1"}"#, "error: not found 1", false);
        det.record_call("shell", r#"{"cmd":"bad2"}"#, "error: not found 2", false);
        det.record_call("shell", r#"{"cmd":"bad3"}"#, "error: not found 3", false);
        match det.check() {
            DetectionVerdict::InjectWarning(msg) => {
                assert!(msg.contains("failed 3 consecutive"), "msg: {msg}");
            }
            other => panic!("expected InjectWarning, got {other:?}"),
        }
    }

    // 8. Failure streak resets on success (counter cleared by the `ok` call)
    #[test]
    fn failure_streak_resets_on_success() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
        det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
        det.record_call("shell", r#"{"cmd":"good"}"#, "ok", true); // resets
        det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
        det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }

    // 9. All thresholds zero → disabled (even 20 identical calls pass)
    #[test]
    fn all_disabled_never_triggers() {
        let mut det = LoopDetector::new(disabled_config());
        for _ in 0..20 {
            det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
        }
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }

    // 10. Mixed tools → no false positive on a healthy, varied workload
    #[test]
    fn mixed_tools_no_false_positive() {
        let mut det = LoopDetector::new(default_config());
        det.record_call("file_read", r#"{"path":"a.rs"}"#, "content_a", true);
        det.record_call("shell", r#"{"cmd":"ls"}"#, "file_list", true);
        det.record_call("memory_store", r#"{"key":"x"}"#, "stored", true);
        det.record_call("file_read", r#"{"path":"b.rs"}"#, "content_b", true);
        det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
        assert_eq!(det.check(), DetectionVerdict::Continue);
    }
}
+76 -116
View File
@@ -999,6 +999,7 @@ fn runtime_perplexity_filter_snapshot(
return state.perplexity_filter.clone();
}
}
crate::config::PerplexityFilterConfig::default()
}
@@ -2168,54 +2169,6 @@ async fn handle_runtime_command_if_needed(
)
}
}
ChannelRuntimeCommand::ApprovePendingRequest(raw_request_id) => {
let request_id = raw_request_id.trim().to_string();
if request_id.is_empty() {
"Usage: `/approve-allow <request-id>`".to_string()
} else {
match ctx.approval_manager.confirm_non_cli_pending_request(
&request_id,
sender,
source_channel,
reply_target,
) {
Ok(req) => {
ctx.approval_manager
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::Yes);
runtime_trace::record_event(
"approval_request_allowed",
Some(source_channel),
None,
None,
None,
Some(true),
Some("pending request allowed for current tool invocation"),
serde_json::json!({
"request_id": request_id,
"tool_name": req.tool_name,
"sender": sender,
"channel": source_channel,
}),
);
format!(
"Approved pending request `{}` for this invocation of `{}`.",
req.request_id, req.tool_name
)
}
Err(PendingApprovalError::NotFound) => {
format!("Pending approval request `{request_id}` was not found.")
}
Err(PendingApprovalError::Expired) => {
format!("Pending approval request `{request_id}` has expired.")
}
Err(PendingApprovalError::RequesterMismatch) => {
format!(
"Pending approval request `{request_id}` can only be approved by the same sender in the same chat/channel that created it."
)
}
}
}
}
ChannelRuntimeCommand::ConfirmToolApproval(raw_request_id) => {
let request_id = raw_request_id.trim().to_string();
if request_id.is_empty() {
@@ -2228,8 +2181,6 @@ async fn handle_runtime_command_if_needed(
reply_target,
) {
Ok(req) => {
ctx.approval_manager
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::Yes);
let tool_name = req.tool_name;
let mut approval_message = if tool_name == APPROVAL_ALL_TOOLS_ONCE_TOKEN {
let remaining = ctx.approval_manager.grant_non_cli_allow_all_once();
@@ -2336,10 +2287,54 @@ async fn handle_runtime_command_if_needed(
}
}
}
ChannelRuntimeCommand::ApprovePendingRequest(raw_request_id) => {
let request_id = raw_request_id.trim().to_string();
if request_id.is_empty() {
"Usage: `/approve-allow <request-id>`".to_string()
} else if !ctx
.approval_manager
.is_non_cli_approval_actor_allowed(source_channel, sender)
{
"You are not allowed to approve pending non-CLI tool requests.".to_string()
} else {
match ctx.approval_manager.confirm_non_cli_pending_request(
&request_id,
sender,
source_channel,
reply_target,
) {
Ok(req) => {
ctx.approval_manager.record_non_cli_pending_resolution(
&request_id,
ApprovalResponse::Yes,
);
format!(
"Approved pending request `{}` for `{}`.",
request_id,
approval_target_label(&req.tool_name)
)
}
Err(PendingApprovalError::NotFound) => {
format!("Pending approval request `{request_id}` was not found.")
}
Err(PendingApprovalError::Expired) => {
format!("Pending approval request `{request_id}` has expired.")
}
Err(PendingApprovalError::RequesterMismatch) => format!(
"Pending approval request `{request_id}` can only be approved by the same sender in the same chat/channel that created it."
),
}
}
}
ChannelRuntimeCommand::DenyToolApproval(raw_request_id) => {
let request_id = raw_request_id.trim().to_string();
if request_id.is_empty() {
"Usage: `/approve-deny <request-id>`".to_string()
} else if !ctx
.approval_manager
.is_non_cli_approval_actor_allowed(source_channel, sender)
{
"You are not allowed to deny pending non-CLI tool requests.".to_string()
} else {
match ctx.approval_manager.reject_non_cli_pending_request(
&request_id,
@@ -2348,81 +2343,25 @@ async fn handle_runtime_command_if_needed(
reply_target,
) {
Ok(req) => {
ctx.approval_manager
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::No);
runtime_trace::record_event(
"approval_request_denied",
Some(source_channel),
None,
None,
None,
Some(true),
Some("pending request denied"),
serde_json::json!({
"request_id": request_id,
"tool_name": req.tool_name,
"sender": sender,
"channel": source_channel,
}),
ctx.approval_manager.record_non_cli_pending_resolution(
&request_id,
ApprovalResponse::No,
);
format!(
"Denied pending approval request `{}` for tool `{}`.",
req.request_id, req.tool_name
"Denied pending request `{}` for `{}`.",
request_id,
approval_target_label(&req.tool_name)
)
}
Err(PendingApprovalError::NotFound) => {
runtime_trace::record_event(
"approval_request_denied",
Some(source_channel),
None,
None,
None,
Some(false),
Some("pending request not found"),
serde_json::json!({
"request_id": request_id,
"sender": sender,
"channel": source_channel,
}),
);
format!("Pending approval request `{request_id}` was not found.")
}
Err(PendingApprovalError::Expired) => {
runtime_trace::record_event(
"approval_request_denied",
Some(source_channel),
None,
None,
None,
Some(false),
Some("pending request expired"),
serde_json::json!({
"request_id": request_id,
"sender": sender,
"channel": source_channel,
}),
);
format!("Pending approval request `{request_id}` has expired.")
}
Err(PendingApprovalError::RequesterMismatch) => {
runtime_trace::record_event(
"approval_request_denied",
Some(source_channel),
None,
None,
None,
Some(false),
Some("pending request denier mismatch"),
serde_json::json!({
"request_id": request_id,
"sender": sender,
"channel": source_channel,
}),
);
format!(
"Pending approval request `{request_id}` can only be denied by the same sender in the same chat/channel that created it."
)
}
Err(PendingApprovalError::RequesterMismatch) => format!(
"Pending approval request `{request_id}` can only be denied by the same sender in the same chat/channel that created it."
),
}
}
}
@@ -3909,7 +3848,7 @@ async fn run_message_dispatch_loop(
}
}
process_channel_message(worker_ctx, msg, cancellation_token).await;
Box::pin(process_channel_message(worker_ctx, msg, cancellation_token)).await;
if interrupt_enabled {
let mut active = in_flight.lock().await;
@@ -4894,6 +4833,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
provider_transport: config.effective_provider_transport(),
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
reasoning_enabled: config.runtime.reasoning_enabled,
@@ -5301,6 +5241,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
}
#[cfg(test)]
#[allow(clippy::large_futures)]
mod tests {
use super::*;
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
@@ -10411,6 +10352,25 @@ BTC is currently around $65,000 based on latest tool output."#;
.any(|entry| entry.channel.name() == "mattermost"));
}
#[test]
fn collect_configured_channels_includes_dingtalk_when_configured() {
    // Start from a default config and enable only the DingTalk channel.
    let mut config = Config::default();
    let dingtalk = crate::config::schema::DingTalkConfig {
        client_id: "ding-app-key".to_string(),
        client_secret: "ding-app-secret".to_string(),
        allowed_users: vec!["*".to_string()],
    };
    config.channels_config.dingtalk = Some(dingtalk);

    let configured = collect_configured_channels(&config, "test");

    // The channel must surface under both its display name and internal name.
    let has_display_name = configured
        .iter()
        .any(|entry| entry.display_name == "DingTalk");
    let has_internal_name = configured
        .iter()
        .any(|entry| entry.channel.name() == "dingtalk");
    assert!(has_display_name);
    assert!(has_internal_name);
}
struct AlwaysFailChannel {
name: &'static str,
calls: Arc<AtomicUsize>,
+15 -14
View File
@@ -8,20 +8,21 @@ pub use schema::{
AgentConfig, AgentSessionBackend, AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig,
AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig,
ChannelsConfig, ClassificationRule, ComposioConfig, Config, CoordinationConfig, CostConfig,
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig,
EstopConfig, FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig,
IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig,
NextcloudTalkConfig, NonCliNaturalLanguageApprovalMode, ObservabilityConfig,
OtpChallengeDelivery, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig,
PerplexityFilterConfig, PluginEntryConfig, PluginsConfig, ProviderConfig, ProxyConfig,
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig,
ResearchTrigger, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig,
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
StorageProviderSection, StreamMode, SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig,
TunnelConfig, UrlAccessConfig, WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy,
WasmRuntimeConfig, WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EconomicConfig,
EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
GroupReplyConfig, GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig,
HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig,
OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PerplexityFilterConfig, PluginEntryConfig,
PluginsConfig, ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig,
QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger,
ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig,
SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode,
SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode,
SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig,
WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy, WasmRuntimeConfig,
WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
};
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
+343
View File
@@ -237,6 +237,11 @@ pub struct Config {
#[serde(default)]
pub cost: CostConfig,
/// Economic agent survival tracking (`[economic]`).
/// Tracks balance, token costs, work income, and survival status.
#[serde(default)]
pub economic: EconomicConfig,
/// Peripheral board configuration for hardware integration (`[peripherals]`).
#[serde(default)]
pub peripherals: PeripheralsConfig,
@@ -309,6 +314,20 @@ pub struct ProviderConfig {
/// (e.g. OpenAI Codex `/responses` reasoning effort).
#[serde(default)]
pub reasoning_level: Option<String>,
/// Optional transport override for providers that support multiple transports.
/// Supported values: "auto", "websocket", "sse".
///
/// Resolution order:
/// 1) `model_routes[].transport` (route-specific)
/// 2) env overrides (`PROVIDER_TRANSPORT`, `ZEROCLAW_PROVIDER_TRANSPORT`, `ZEROCLAW_CODEX_TRANSPORT`)
/// 3) `provider.transport`
/// 4) runtime default (`auto`, WebSocket-first with SSE fallback for OpenAI Codex)
///
/// Note: env overrides replace configured `provider.transport` when set.
///
/// Existing configs that omit `provider.transport` remain valid and fall back to defaults.
#[serde(default)]
pub transport: Option<String>,
}
// ── Delegate Agents ──────────────────────────────────────────────
@@ -716,6 +735,21 @@ pub struct AgentConfig {
/// Tool dispatch strategy (e.g. `"auto"`). Default: `"auto"`.
#[serde(default = "default_agent_tool_dispatcher")]
pub tool_dispatcher: String,
/// Loop detection: no-progress repeat threshold.
/// Triggers when the same tool+args produces identical output this many times.
/// Set to `0` to disable. Default: `3`.
#[serde(default = "default_loop_detection_no_progress_threshold")]
pub loop_detection_no_progress_threshold: usize,
/// Loop detection: ping-pong cycle threshold.
/// Detects A→B→A→B alternating patterns with no progress.
/// Value is number of full cycles (A-B = 1 cycle). Set to `0` to disable. Default: `2`.
#[serde(default = "default_loop_detection_ping_pong_cycles")]
pub loop_detection_ping_pong_cycles: usize,
/// Loop detection: consecutive failure streak threshold.
/// Triggers when the same tool fails this many times in a row.
/// Set to `0` to disable. Default: `3`.
#[serde(default = "default_loop_detection_failure_streak")]
pub loop_detection_failure_streak: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
@@ -787,6 +821,18 @@ fn default_agent_session_max_messages() -> usize {
default_agent_max_history_messages()
}
/// Serde default for `AgentConfig::loop_detection_no_progress_threshold`:
/// identical tool+args+output repeats before the detector triggers (0 disables).
fn default_loop_detection_no_progress_threshold() -> usize {
    3
}
/// Serde default for `AgentConfig::loop_detection_ping_pong_cycles`:
/// full A-B alternation cycles before the detector triggers (0 disables).
fn default_loop_detection_ping_pong_cycles() -> usize {
    2
}
/// Serde default for `AgentConfig::loop_detection_failure_streak`:
/// consecutive failures of the same tool before the detector triggers (0 disables).
fn default_loop_detection_failure_streak() -> usize {
    3
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
@@ -796,6 +842,9 @@ impl Default for AgentConfig {
max_history_messages: default_agent_max_history_messages(),
parallel_tools: false,
tool_dispatcher: default_agent_tool_dispatcher(),
loop_detection_no_progress_threshold: default_loop_detection_no_progress_threshold(),
loop_detection_ping_pong_cycles: default_loop_detection_ping_pong_cycles(),
loop_detection_failure_streak: default_loop_detection_failure_streak(),
}
}
}
@@ -1160,6 +1209,83 @@ pub struct PeripheralBoardConfig {
pub baud: u32,
}
// ── Economic Agent Config ─────────────────────────────────────────
/// Token pricing configuration for economic tracking.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct EconomicTokenPricing {
    /// Price per million input tokens (USD)
    #[serde(default = "default_input_price")]
    pub input_price_per_million: f64,
    /// Price per million output tokens (USD)
    #[serde(default = "default_output_price")]
    pub output_price_per_million: f64,
}
/// Serde default for `EconomicTokenPricing::input_price_per_million`.
fn default_input_price() -> f64 {
    3.0 // Claude Sonnet 4 input price
}
/// Serde default for `EconomicTokenPricing::output_price_per_million`.
fn default_output_price() -> f64 {
    15.0 // Claude Sonnet 4 output price
}
impl Default for EconomicTokenPricing {
    /// Mirrors the per-field serde defaults above so programmatic construction
    /// matches deserialization of an empty section.
    fn default() -> Self {
        Self {
            input_price_per_million: default_input_price(),
            output_price_per_million: default_output_price(),
        }
    }
}
/// Economic agent survival tracking configuration (`[economic]` section).
///
/// Implements the ClawWork economic model for AI agents, tracking
/// balance, costs, income, and survival status.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct EconomicConfig {
    /// Enable economic tracking (default: false)
    #[serde(default)]
    pub enabled: bool,
    /// Starting balance in USD (default: 1000.0)
    #[serde(default = "default_initial_balance")]
    pub initial_balance: f64,
    /// Token pricing configuration
    #[serde(default)]
    pub token_pricing: EconomicTokenPricing,
    /// Minimum evaluation score (0.0-1.0) to receive payment (default: 0.6)
    #[serde(default = "default_min_evaluation_threshold")]
    pub min_evaluation_threshold: f64,
    /// Data directory for economic state persistence (relative to workspace)
    /// `None` means the implementation's default location is used.
    #[serde(default)]
    pub data_path: Option<String>,
}
/// Serde default for `EconomicConfig::initial_balance`.
fn default_initial_balance() -> f64 {
    1000.0
}
/// Serde default for `EconomicConfig::min_evaluation_threshold`.
fn default_min_evaluation_threshold() -> f64 {
    0.6
}
impl Default for EconomicConfig {
    /// Tracking is disabled by default; numeric fields mirror the serde
    /// defaults above.
    fn default() -> Self {
        Self {
            enabled: false,
            initial_balance: default_initial_balance(),
            token_pricing: EconomicTokenPricing::default(),
            min_evaluation_threshold: default_min_evaluation_threshold(),
            data_path: None,
        }
    }
}
/// Serde default for the peripheral board transport: plain serial.
fn default_peripheral_transport() -> String {
    String::from("serial")
}
@@ -3266,6 +3392,14 @@ pub struct ModelRouteConfig {
/// Optional API key override for this route's provider
#[serde(default)]
pub api_key: Option<String>,
/// Optional route-specific transport override for this route.
/// Supported values: "auto", "websocket", "sse".
///
/// When `model_routes[].transport` is unset, the route inherits `provider.transport`.
/// If both are unset, runtime defaults are used (`auto` for OpenAI Codex).
/// Existing configs without this field remain valid.
#[serde(default)]
pub transport: Option<String>,
}
// ── Embedding routing ───────────────────────────────────────────
@@ -5135,6 +5269,7 @@ impl Default for Config {
proxy: ProxyConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
economic: EconomicConfig::default(),
peripherals: PeripheralsConfig::default(),
agents: HashMap::new(),
coordination: CoordinationConfig::default(),
@@ -6115,6 +6250,28 @@ impl Config {
}
}
/// Normalize a user-supplied transport override to one of the canonical
/// values `auto`, `websocket`, or `sse`.
///
/// Matching ignores case and any `-`/`_` separators; `ws` maps to
/// `websocket` and `http` maps to `sse`. Empty/whitespace-only input yields
/// `None`; unrecognized values log a warning (tagged with `source`) and also
/// yield `None` so callers fall back to defaults.
fn normalize_provider_transport(raw: Option<&str>, source: &str) -> Option<String> {
    let candidate = raw.map(str::trim).filter(|s| !s.is_empty())?;
    let folded = candidate.to_ascii_lowercase().replace(['-', '_'], "");
    let resolved = match folded.as_str() {
        "auto" => "auto",
        "websocket" | "ws" => "websocket",
        "sse" | "http" => "sse",
        _ => {
            tracing::warn!(
                transport = %candidate,
                source,
                "Ignoring invalid provider transport override"
            );
            return None;
        }
    };
    Some(resolved.to_string())
}
/// Resolve provider reasoning level with backward-compatible runtime alias.
///
/// Priority:
@@ -6158,6 +6315,16 @@ impl Config {
}
}
/// Resolve provider transport mode (`provider.transport`).
///
/// Supported values:
/// - `auto`
/// - `websocket`
/// - `sse`
pub fn effective_provider_transport(&self) -> Option<String> {
Self::normalize_provider_transport(self.provider.transport.as_deref(), "provider.transport")
}
fn lookup_model_provider_profile(
&self,
provider_name: &str,
@@ -6521,6 +6688,32 @@ impl Config {
if route.max_tokens == Some(0) {
anyhow::bail!("model_routes[{i}].max_tokens must be greater than 0");
}
if route
.transport
.as_deref()
.is_some_and(|value| !value.trim().is_empty())
&& Self::normalize_provider_transport(
route.transport.as_deref(),
"model_routes[].transport",
)
.is_none()
{
anyhow::bail!("model_routes[{i}].transport must be one of: auto, websocket, sse");
}
}
if self
.provider
.transport
.as_deref()
.is_some_and(|value| !value.trim().is_empty())
&& Self::normalize_provider_transport(
self.provider.transport.as_deref(),
"provider.transport",
)
.is_none()
{
anyhow::bail!("provider.transport must be one of: auto, websocket, sse");
}
if self.provider_api.is_some()
@@ -6852,6 +7045,17 @@ impl Config {
}
}
// Provider transport override: ZEROCLAW_PROVIDER_TRANSPORT or PROVIDER_TRANSPORT
if let Ok(transport) = std::env::var("ZEROCLAW_PROVIDER_TRANSPORT")
.or_else(|_| std::env::var("PROVIDER_TRANSPORT"))
{
if let Some(normalized) =
Self::normalize_provider_transport(Some(&transport), "env:provider_transport")
{
self.provider.transport = Some(normalized);
}
}
// Vision support override: ZEROCLAW_MODEL_SUPPORT_VISION or MODEL_SUPPORT_VISION
if let Ok(flag) = std::env::var("ZEROCLAW_MODEL_SUPPORT_VISION")
.or_else(|_| std::env::var("MODEL_SUPPORT_VISION"))
@@ -7700,6 +7904,7 @@ default_temperature = 0.7
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
economic: EconomicConfig::default(),
peripherals: PeripheralsConfig::default(),
agents: HashMap::new(),
hooks: HooksConfig::default(),
@@ -8074,6 +8279,7 @@ tool_dispatcher = "xml"
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
economic: EconomicConfig::default(),
peripherals: PeripheralsConfig::default(),
agents: HashMap::new(),
hooks: HooksConfig::default(),
@@ -9453,6 +9659,7 @@ provider_api = "not-a-real-mode"
model: "anthropic/claude-sonnet-4.6".to_string(),
max_tokens: Some(0),
api_key: None,
transport: None,
}];
let err = config
@@ -9463,6 +9670,48 @@ provider_api = "not-a-real-mode"
.contains("model_routes[0].max_tokens must be greater than 0"));
}
#[test]
async fn provider_transport_normalizes_aliases() {
    // "WS" (any case) is an accepted alias for "websocket".
    let mut config = Config::default();
    config.provider.transport = Some("WS".to_string());
    assert_eq!(
        config.effective_provider_transport().as_deref(),
        Some("websocket")
    );
}
#[test]
async fn provider_transport_invalid_is_rejected() {
    // Unknown transports (e.g. "udp") must fail validation with a clear message.
    let mut config = Config::default();
    config.provider.transport = Some("udp".to_string());
    let err = config
        .validate()
        .expect_err("provider.transport should reject invalid values");
    assert!(err
        .to_string()
        .contains("provider.transport must be one of: auto, websocket, sse"));
}
#[test]
async fn model_route_transport_invalid_is_rejected() {
    // Route-level transport overrides get the same validation as
    // provider.transport, with the route index in the error message.
    let mut config = Config::default();
    config.model_routes = vec![ModelRouteConfig {
        hint: "reasoning".to_string(),
        provider: "openrouter".to_string(),
        model: "anthropic/claude-sonnet-4.6".to_string(),
        max_tokens: None,
        api_key: None,
        transport: Some("udp".to_string()),
    }];
    let err = config
        .validate()
        .expect_err("model_routes[].transport should reject invalid values");
    assert!(err
        .to_string()
        .contains("model_routes[0].transport must be one of: auto, websocket, sse"));
}
#[test]
async fn env_override_glm_api_key_for_regional_aliases() {
let _env_guard = env_override_lock().await;
@@ -10103,6 +10352,60 @@ default_model = "legacy-model"
std::env::remove_var("ZEROCLAW_REASONING_LEVEL");
}
#[test]
async fn env_override_provider_transport_normalizes_zeroclaw_alias() {
    // ZEROCLAW_PROVIDER_TRANSPORT takes effect and aliases are normalized.
    let _env_guard = env_override_lock().await;
    let mut config = Config::default();
    std::env::remove_var("PROVIDER_TRANSPORT");
    std::env::set_var("ZEROCLAW_PROVIDER_TRANSPORT", "WS");
    config.apply_env_overrides();
    assert_eq!(config.provider.transport.as_deref(), Some("websocket"));
    std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
}
#[test]
async fn env_override_provider_transport_normalizes_legacy_alias() {
    // The legacy PROVIDER_TRANSPORT variable is honored; "HTTP" maps to "sse".
    let _env_guard = env_override_lock().await;
    let mut config = Config::default();
    std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
    std::env::set_var("PROVIDER_TRANSPORT", "HTTP");
    config.apply_env_overrides();
    assert_eq!(config.provider.transport.as_deref(), Some("sse"));
    std::env::remove_var("PROVIDER_TRANSPORT");
}
#[test]
async fn env_override_provider_transport_invalid_zeroclaw_does_not_override_existing() {
    // An invalid env value is ignored and the configured value is kept.
    let _env_guard = env_override_lock().await;
    let mut config = Config::default();
    config.provider.transport = Some("sse".to_string());
    std::env::remove_var("PROVIDER_TRANSPORT");
    std::env::set_var("ZEROCLAW_PROVIDER_TRANSPORT", "udp");
    config.apply_env_overrides();
    assert_eq!(config.provider.transport.as_deref(), Some("sse"));
    std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
}
#[test]
async fn env_override_provider_transport_invalid_legacy_does_not_override_existing() {
    // Same guarantee for the legacy variable name.
    let _env_guard = env_override_lock().await;
    let mut config = Config::default();
    config.provider.transport = Some("auto".to_string());
    std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
    std::env::set_var("PROVIDER_TRANSPORT", "udp");
    config.apply_env_overrides();
    assert_eq!(config.provider.transport.as_deref(), Some("auto"));
    std::env::remove_var("PROVIDER_TRANSPORT");
}
#[test]
async fn env_override_model_support_vision() {
let _env_guard = env_override_lock().await;
@@ -10597,6 +10900,46 @@ default_model = "legacy-model"
assert_eq!(parsed.allowed_users, vec!["*"]);
}
#[test]
async fn dingtalk_config_defaults_allowed_users_to_empty() {
    // allowed_users is optional in JSON and defaults to an empty list.
    let json = r#"{"client_id":"ding-app-key","client_secret":"ding-app-secret"}"#;
    let parsed: DingTalkConfig = serde_json::from_str(json).unwrap();
    assert_eq!(parsed.client_id, "ding-app-key");
    assert_eq!(parsed.client_secret, "ding-app-secret");
    assert!(parsed.allowed_users.is_empty());
}
#[test]
async fn dingtalk_config_toml_roundtrip() {
    // Serialize to TOML and parse back without losing any field.
    let dc = DingTalkConfig {
        client_id: "ding-app-key".into(),
        client_secret: "ding-app-secret".into(),
        allowed_users: vec!["*".into(), "staff123".into()],
    };
    let toml_str = toml::to_string(&dc).unwrap();
    let parsed: DingTalkConfig = toml::from_str(&toml_str).unwrap();
    assert_eq!(parsed.client_id, "ding-app-key");
    assert_eq!(parsed.client_secret, "ding-app-secret");
    assert_eq!(parsed.allowed_users, vec!["*", "staff123"]);
}
#[test]
async fn channels_except_webhook_reports_dingtalk_as_enabled() {
    // A configured DingTalk channel must show up as enabled in the
    // non-webhook channel listing.
    let mut channels = ChannelsConfig::default();
    channels.dingtalk = Some(DingTalkConfig {
        client_id: "ding-app-key".into(),
        client_secret: "ding-app-secret".into(),
        allowed_users: vec!["*".into()],
    });
    let dingtalk_state = channels
        .channels_except_webhook()
        .iter()
        .find_map(|(handle, enabled)| (handle.name() == "DingTalk").then_some(*enabled));
    assert_eq!(dingtalk_state, Some(true));
}
#[test]
async fn nextcloud_talk_config_serde() {
let nc = NextcloudTalkConfig {
+36
View File
@@ -22,6 +22,10 @@ const MIN_POLL_SECONDS: u64 = 5;
const SHELL_JOB_TIMEOUT_SECS: u64 = 120;
const SCHEDULER_COMPONENT: &str = "scheduler";
/// Returns `true` when a job's output is exactly the `NO_REPLY` sentinel
/// (surrounding whitespace ignored, ASCII case-insensitive), which signals
/// that the result should not be announced.
pub(crate) fn is_no_reply_sentinel(output: &str) -> bool {
    let trimmed = output.trim();
    trimmed.eq_ignore_ascii_case("NO_REPLY")
}
pub async fn run(config: Config) -> Result<()> {
let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS);
let mut interval = time::interval(Duration::from_secs(poll_secs));
@@ -292,6 +296,13 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
if !delivery.mode.eq_ignore_ascii_case("announce") {
return Ok(());
}
if is_no_reply_sentinel(output) {
tracing::debug!(
"Cron job '{}' returned NO_REPLY sentinel; skipping announce delivery",
job.id
);
return Ok(());
}
let channel = delivery
.channel
@@ -1136,6 +1147,31 @@ mod tests {
assert!(err.to_string().contains("unsupported delivery channel"));
}
#[tokio::test]
async fn deliver_if_configured_skips_no_reply_sentinel() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp).await;
    let mut job = test_job("echo ok");
    job.delivery = DeliveryConfig {
        mode: "announce".into(),
        // Deliberately invalid channel: if the sentinel short-circuit did not
        // fire, delivery would be attempted and the call would error.
        channel: Some("invalid".into()),
        to: Some("target".into()),
        best_effort: true,
    };
    assert!(deliver_if_configured(&config, &job, " no_reply ")
        .await
        .is_ok())
}
#[test]
fn no_reply_sentinel_matching_is_trimmed_and_case_insensitive() {
    assert!(is_no_reply_sentinel("NO_REPLY"));
    assert!(is_no_reply_sentinel(" no_reply "));
    // The sentinel must be the entire (trimmed) output, not a prefix.
    assert!(!is_no_reply_sentinel("NO_REPLY please"));
    assert!(!is_no_reply_sentinel(""));
}
#[tokio::test]
async fn deliver_if_configured_whatsapp_web_requires_live_session_in_web_mode() {
let tmp = TempDir::new().unwrap();
+49 -19
View File
@@ -227,26 +227,25 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
{
Ok(output) => {
crate::health::mark_component_ok("heartbeat");
let announcement = if output.trim().is_empty() {
"heartbeat task executed".to_string()
} else {
output
};
if let Some((channel, target)) = &delivery {
if let Err(e) = crate::cron::scheduler::deliver_announcement(
&config,
channel,
target,
&announcement,
)
.await
{
crate::health::mark_component_error(
"heartbeat",
format!("delivery failed: {e}"),
);
tracing::warn!("Heartbeat delivery failed: {e}");
if let Some(announcement) = heartbeat_announcement_text(&output) {
if let Some((channel, target)) = &delivery {
if let Err(e) = crate::cron::scheduler::deliver_announcement(
&config,
channel,
target,
&announcement,
)
.await
{
crate::health::mark_component_error(
"heartbeat",
format!("delivery failed: {e}"),
);
tracing::warn!("Heartbeat delivery failed: {e}");
}
}
} else {
tracing::debug!("Heartbeat returned NO_REPLY sentinel; skipping delivery");
}
}
Err(e) => {
@@ -258,6 +257,16 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
}
}
/// Build the announcement text for a heartbeat run, or `None` when the
/// output is the NO_REPLY sentinel and delivery should be skipped.
///
/// Whitespace-only output is replaced with a generic confirmation message.
fn heartbeat_announcement_text(output: &str) -> Option<String> {
    // The sentinel check must come first: " NO_REPLY " is non-blank after trim
    // and would otherwise be announced verbatim.
    if crate::cron::scheduler::is_no_reply_sentinel(output) {
        None
    } else if output.trim().is_empty() {
        Some("heartbeat task executed".to_string())
    } else {
        Some(output.to_string())
    }
}
fn heartbeat_tasks_for_tick(
file_tasks: Vec<String>,
fallback_message: Option<&str>,
@@ -553,6 +562,27 @@ mod tests {
assert!(tasks.is_empty());
}
#[test]
fn heartbeat_announcement_text_skips_no_reply_sentinel() {
    // Mixed case plus padding still counts as the sentinel.
    assert!(heartbeat_announcement_text(" NO_reply ").is_none());
}
#[test]
fn heartbeat_announcement_text_uses_default_for_empty_output() {
    // Whitespace-only output falls back to the generic message.
    assert_eq!(
        heartbeat_announcement_text(" \n\t "),
        Some("heartbeat task executed".to_string())
    );
}
#[test]
fn heartbeat_announcement_text_keeps_regular_output() {
    assert_eq!(
        heartbeat_announcement_text("system nominal"),
        Some("system nominal".to_string())
    );
}
#[test]
fn heartbeat_delivery_target_none_when_unset() {
let config = Config::default();
+1
View File
@@ -1167,6 +1167,7 @@ mod tests {
model: String::new(),
max_tokens: None,
api_key: None,
transport: None,
}];
let mut items = Vec::new();
check_config_semantics(&config, &mut items);
+874
View File
@@ -0,0 +1,874 @@
//! Task Classifier for ZeroClaw Economic Agents
//!
//! Classifies work instructions into 44 BLS occupations with wage data
//! to estimate task value for agent economics.
//!
//! ## Overview
//!
//! The classifier matches task instructions to standardized occupation
//! categories using keyword matching and heuristics, then calculates
//! expected payment based on BLS hourly wage data.
//!
//! ## Example
//!
//! ```rust,ignore
//! use zeroclaw::economic::classifier::{TaskClassifier, OccupationCategory};
//!
//! let classifier = TaskClassifier::new();
//! let result = classifier.classify("Write a REST API in Rust");
//!
//! println!("Occupation: {}", result.occupation);
//! println!("Hourly wage: ${:.2}", result.hourly_wage);
//! println!("Estimated hours: {:.2}", result.estimated_hours);
//! println!("Max payment: ${:.2}", result.max_payment);
//! ```
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Occupation category groupings based on BLS major groups
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OccupationCategory {
/// Software, IT, engineering roles
TechnologyEngineering,
/// Finance, accounting, management, sales
BusinessFinance,
/// Medical, nursing, social work
HealthcareSocialServices,
/// Legal, media, operations, other professional
LegalMediaOperations,
}
impl OccupationCategory {
/// Returns a human-readable name for the category
pub fn display_name(&self) -> &'static str {
match self {
Self::TechnologyEngineering => "Technology & Engineering",
Self::BusinessFinance => "Business & Finance",
Self::HealthcareSocialServices => "Healthcare & Social Services",
Self::LegalMediaOperations => "Legal, Media & Operations",
}
}
}
/// A single occupation with BLS wage data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Occupation {
    /// Official BLS occupation name
    pub name: String,
    /// Hourly wage in USD (BLS median)
    pub hourly_wage: f64,
    /// Category grouping
    pub category: OccupationCategory,
    /// Keywords for matching
    /// (skipped during (de)serialization; a deserialized `Occupation` has an
    /// empty keyword list)
    #[serde(skip)]
    pub keywords: Vec<&'static str>,
}
/// Result of classifying a task instruction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationResult {
    /// Matched occupation name
    pub occupation: String,
    /// BLS hourly wage for this occupation
    pub hourly_wage: f64,
    /// Estimated hours to complete task
    pub estimated_hours: f64,
    /// Maximum payment (hours × wage), rounded to whole cents
    pub max_payment: f64,
    /// Classification confidence (0.0 - 1.0)
    pub confidence: f64,
    /// Category of the matched occupation
    pub category: OccupationCategory,
    /// Brief reasoning for the classification
    pub reasoning: String,
}
/// Task classifier that maps instructions to BLS occupations
#[derive(Debug)]
pub struct TaskClassifier {
    // Embedded occupation table (name, wage, category, keywords).
    occupations: Vec<Occupation>,
    // keyword -> indices into `occupations` of every occupation listing it.
    keyword_index: HashMap<&'static str, Vec<usize>>,
    // Occupation reported when no keyword matches.
    fallback_occupation: String,
    // Hourly wage paired with the fallback occupation.
    fallback_wage: f64,
}
impl Default for TaskClassifier {
    /// Equivalent to [`TaskClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl TaskClassifier {
/// Create a new `TaskClassifier` with embedded BLS occupation data.
///
/// Builds the keyword lookup index once from the static occupation table and
/// installs "General and Operations Managers" ($64/h) as the fallback used
/// when no keyword matches.
pub fn new() -> Self {
    let occupations = Self::load_occupations();
    Self {
        keyword_index: Self::build_keyword_index(&occupations),
        occupations,
        fallback_occupation: String::from("General and Operations Managers"),
        fallback_wage: 64.0,
    }
}
/// Load all 44 BLS occupations with wage data
///
/// Returns the embedded occupation table used by the keyword classifier.
/// Keywords are matched as lowercase substrings of the instruction (see
/// `classify`), so multi-word keywords like "market research" must appear
/// verbatim in the text.
///
/// NOTE(review): several hourly wages repeat across seemingly unrelated
/// occupations (e.g. 44.41 for Lawyers and Customer Service Representatives,
/// 66.22 for several distinct healthcare roles) — verify against the BLS
/// source tables.
fn load_occupations() -> Vec<Occupation> {
    use OccupationCategory::{
        BusinessFinance, HealthcareSocialServices, LegalMediaOperations, TechnologyEngineering,
    };
    vec![
        // Technology & Engineering
        Occupation {
            name: "Software Developers".into(),
            hourly_wage: 69.50,
            category: TechnologyEngineering,
            keywords: vec![
                "software", "code", "programming", "developer", "rust", "python",
                "javascript", "api", "backend", "frontend", "fullstack", "app",
                "application", "debug", "refactor", "implement", "algorithm",
            ],
        },
        Occupation {
            name: "Computer and Information Systems Managers".into(),
            hourly_wage: 90.38,
            category: TechnologyEngineering,
            keywords: vec![
                "it manager", "cto", "tech lead", "infrastructure", "systems",
                "devops", "cloud", "architecture", "platform", "enterprise",
            ],
        },
        Occupation {
            name: "Industrial Engineers".into(),
            hourly_wage: 51.87,
            category: TechnologyEngineering,
            keywords: vec![
                "industrial", "process", "optimization", "efficiency", "workflow",
                "manufacturing", "lean", "six sigma", "production",
            ],
        },
        Occupation {
            name: "Mechanical Engineers".into(),
            hourly_wage: 52.92,
            category: TechnologyEngineering,
            keywords: vec![
                "mechanical", "cad", "solidworks", "machinery", "thermal", "hvac",
                "automotive", "robotics",
            ],
        },
        // Business & Finance
        Occupation {
            name: "Accountants and Auditors".into(),
            hourly_wage: 44.96,
            category: BusinessFinance,
            keywords: vec![
                "accounting", "audit", "tax", "bookkeeping", "financial statements",
                "gaap", "ledger", "reconciliation", "cpa",
            ],
        },
        Occupation {
            name: "Administrative Services Managers".into(),
            hourly_wage: 60.59,
            category: BusinessFinance,
            keywords: vec![
                "administrative", "office manager", "facilities", "operations",
                "scheduling", "coordination",
            ],
        },
        Occupation {
            name: "Buyers and Purchasing Agents".into(),
            hourly_wage: 39.29,
            category: BusinessFinance,
            keywords: vec![
                "procurement", "purchasing", "vendor", "supplier", "sourcing",
                "negotiation", "contracts",
            ],
        },
        Occupation {
            name: "Compliance Officers".into(),
            hourly_wage: 40.86,
            category: BusinessFinance,
            keywords: vec![
                "compliance", "regulatory", "audit", "policy", "governance",
                "risk", "sox", "gdpr",
            ],
        },
        Occupation {
            name: "Financial Managers".into(),
            hourly_wage: 86.76,
            category: BusinessFinance,
            keywords: vec![
                "cfo", "finance director", "treasury", "budget",
                "financial planning", "investment management",
            ],
        },
        Occupation {
            name: "Financial and Investment Analysts".into(),
            hourly_wage: 56.01,
            category: BusinessFinance,
            keywords: vec![
                "financial analysis", "investment", "portfolio", "stock", "equity",
                "valuation", "modeling", "dcf", "market research",
            ],
        },
        Occupation {
            name: "General and Operations Managers".into(),
            hourly_wage: 64.00,
            category: BusinessFinance,
            keywords: vec![
                "operations", "general manager", "director", "oversee", "manage",
                "strategy", "leadership", "business",
            ],
        },
        Occupation {
            name: "Market Research Analysts and Marketing Specialists".into(),
            hourly_wage: 41.58,
            category: BusinessFinance,
            keywords: vec![
                "market research", "marketing", "campaign", "branding", "seo",
                "advertising", "analytics", "customer", "segment",
            ],
        },
        Occupation {
            name: "Personal Financial Advisors".into(),
            hourly_wage: 77.02,
            category: BusinessFinance,
            keywords: vec![
                "financial advisor", "wealth", "retirement", "401k", "ira",
                "estate planning", "insurance",
            ],
        },
        Occupation {
            name: "Project Management Specialists".into(),
            hourly_wage: 51.97,
            category: BusinessFinance,
            keywords: vec![
                "project manager", "pmp", "agile", "scrum", "sprint", "milestone",
                "timeline", "stakeholder", "deliverable",
            ],
        },
        Occupation {
            name: "Property, Real Estate, and Community Association Managers".into(),
            hourly_wage: 39.77,
            category: BusinessFinance,
            keywords: vec![
                "property", "real estate", "landlord", "tenant", "lease", "hoa",
                "community",
            ],
        },
        Occupation {
            name: "Sales Managers".into(),
            hourly_wage: 77.37,
            category: BusinessFinance,
            keywords: vec![
                "sales manager", "revenue", "quota", "pipeline", "crm",
                "account executive", "territory",
            ],
        },
        Occupation {
            name: "Marketing and Sales Managers".into(),
            hourly_wage: 79.35,
            category: BusinessFinance,
            keywords: vec!["vp sales", "cmo", "growth", "go-to-market", "demand gen"],
        },
        Occupation {
            name: "Financial Specialists".into(),
            hourly_wage: 48.12,
            category: BusinessFinance,
            keywords: vec!["financial specialist", "credit", "loan", "underwriting"],
        },
        Occupation {
            name: "Securities, Commodities, and Financial Services Sales Agents".into(),
            hourly_wage: 48.12,
            category: BusinessFinance,
            keywords: vec!["broker", "securities", "commodities", "trading", "series 7"],
        },
        Occupation {
            name: "Business Operations Specialists, All Other".into(),
            hourly_wage: 44.41,
            category: BusinessFinance,
            keywords: vec![
                "business analyst", "operations specialist", "process improvement",
            ],
        },
        Occupation {
            name: "Claims Adjusters, Examiners, and Investigators".into(),
            hourly_wage: 37.87,
            category: BusinessFinance,
            keywords: vec!["claims", "insurance", "adjuster", "investigator", "fraud"],
        },
        Occupation {
            name: "Transportation, Storage, and Distribution Managers".into(),
            hourly_wage: 55.77,
            category: BusinessFinance,
            keywords: vec![
                "logistics", "supply chain", "warehouse", "distribution",
                "shipping", "inventory", "fulfillment",
            ],
        },
        Occupation {
            name: "Industrial Production Managers".into(),
            hourly_wage: 62.11,
            category: BusinessFinance,
            keywords: vec![
                "production manager", "plant manager", "manufacturing operations",
            ],
        },
        Occupation {
            name: "Lodging Managers".into(),
            hourly_wage: 37.24,
            category: BusinessFinance,
            keywords: vec!["hotel", "hospitality", "lodging", "resort", "concierge"],
        },
        Occupation {
            name: "Real Estate Brokers".into(),
            hourly_wage: 39.77,
            category: BusinessFinance,
            keywords: vec!["real estate broker", "realtor", "mls", "listing"],
        },
        Occupation {
            name: "Managers, All Other".into(),
            hourly_wage: 72.06,
            category: BusinessFinance,
            keywords: vec!["manager", "supervisor", "team lead"],
        },
        // Healthcare & Social Services
        Occupation {
            name: "Medical and Health Services Managers".into(),
            hourly_wage: 66.22,
            category: HealthcareSocialServices,
            keywords: vec![
                "healthcare", "hospital", "clinic", "medical", "health services",
                "patient", "hipaa",
            ],
        },
        Occupation {
            name: "Social and Community Service Managers".into(),
            hourly_wage: 41.39,
            category: HealthcareSocialServices,
            keywords: vec![
                "social services", "community", "nonprofit", "outreach",
                "case management", "welfare",
            ],
        },
        Occupation {
            name: "Child, Family, and School Social Workers".into(),
            hourly_wage: 41.39,
            category: HealthcareSocialServices,
            keywords: vec![
                "social worker", "child welfare", "family services",
                "school counselor",
            ],
        },
        Occupation {
            name: "Registered Nurses".into(),
            hourly_wage: 66.22,
            category: HealthcareSocialServices,
            keywords: vec!["nurse", "rn", "nursing", "patient care", "clinical"],
        },
        Occupation {
            name: "Nurse Practitioners".into(),
            hourly_wage: 66.22,
            category: HealthcareSocialServices,
            keywords: vec!["np", "nurse practitioner", "aprn", "prescribe"],
        },
        Occupation {
            name: "Pharmacists".into(),
            hourly_wage: 66.22,
            category: HealthcareSocialServices,
            keywords: vec![
                "pharmacy", "pharmacist", "medication", "prescription", "drug",
            ],
        },
        Occupation {
            name: "Medical Secretaries and Administrative Assistants".into(),
            hourly_wage: 66.22,
            category: HealthcareSocialServices,
            keywords: vec![
                "medical secretary", "medical records", "ehr",
                "scheduling appointments",
            ],
        },
        // Legal, Media & Operations
        Occupation {
            name: "Lawyers".into(),
            hourly_wage: 44.41,
            category: LegalMediaOperations,
            keywords: vec![
                "lawyer", "attorney", "legal", "contract", "litigation", "counsel",
                "law", "paralegal",
            ],
        },
        Occupation {
            name: "Editors".into(),
            hourly_wage: 72.06,
            category: LegalMediaOperations,
            keywords: vec![
                "editor", "editing", "proofread", "copy edit", "manuscript",
                "publication",
            ],
        },
        Occupation {
            name: "Film and Video Editors".into(),
            hourly_wage: 68.15,
            category: LegalMediaOperations,
            keywords: vec![
                "video editor", "film", "premiere", "final cut", "davinci",
                "post-production",
            ],
        },
        Occupation {
            name: "Audio and Video Technicians".into(),
            hourly_wage: 41.86,
            category: LegalMediaOperations,
            keywords: vec![
                "audio", "video", "av", "broadcast", "streaming", "recording",
            ],
        },
        Occupation {
            name: "Producers and Directors".into(),
            hourly_wage: 41.86,
            category: LegalMediaOperations,
            keywords: vec![
                "producer", "director", "production", "creative director",
                "content", "show",
            ],
        },
        Occupation {
            name: "News Analysts, Reporters, and Journalists".into(),
            hourly_wage: 68.15,
            category: LegalMediaOperations,
            keywords: vec![
                "journalist", "reporter", "news", "article", "press", "interview",
                "story",
            ],
        },
        Occupation {
            name: "Entertainment and Recreation Managers, Except Gambling".into(),
            hourly_wage: 41.86,
            category: LegalMediaOperations,
            keywords: vec!["entertainment", "recreation", "event", "venue", "concert"],
        },
        Occupation {
            name: "Recreation Workers".into(),
            hourly_wage: 41.86,
            category: LegalMediaOperations,
            keywords: vec!["recreation", "activity", "fitness", "sports"],
        },
        Occupation {
            name: "Customer Service Representatives".into(),
            hourly_wage: 44.41,
            category: LegalMediaOperations,
            keywords: vec!["customer service", "support", "helpdesk", "ticket", "chat"],
        },
        Occupation {
            name: "Private Detectives and Investigators".into(),
            hourly_wage: 37.87,
            category: LegalMediaOperations,
            keywords: vec![
                "detective", "investigator", "background check", "surveillance",
            ],
        },
        Occupation {
            name: "First-Line Supervisors of Police and Detectives".into(),
            hourly_wage: 72.06,
            category: LegalMediaOperations,
            keywords: vec!["police", "law enforcement", "security supervisor"],
        },
    ]
}
/// Build keyword → occupation index for fast lookup.
///
/// Each keyword maps to every occupation index that lists it, preserving
/// the table order of `occupations`.
fn build_keyword_index(occupations: &[Occupation]) -> HashMap<&'static str, Vec<usize>> {
    let mut index = HashMap::<&'static str, Vec<usize>>::new();
    for (idx, occupation) in occupations.iter().enumerate() {
        for &keyword in occupation.keywords.iter() {
            index.entry(keyword).or_insert_with(Vec::new).push(idx);
        }
    }
    index
}
/// Classify a task instruction into an occupation with estimated value
///
/// Scores each occupation by the number of its keywords found in the
/// lowercased instruction and picks the highest score. Ties are broken
/// deterministically in favor of the lowest occupation index; the previous
/// implementation inherited `HashMap` iteration order, so tied occupations
/// could flip between runs.
///
/// This is a synchronous keyword-based classifier. For LLM-based
/// classification, use `classify_with_llm` instead.
pub fn classify(&self, instruction: &str) -> ClassificationResult {
    let lower = instruction.to_lowercase();
    let mut scores: HashMap<usize, f64> = HashMap::new();
    // Score each occupation by keyword matches (one point per keyword hit).
    for (keyword, occ_indices) in &self.keyword_index {
        if lower.contains(keyword) {
            for &idx in occ_indices {
                *scores.entry(idx).or_default() += 1.0;
            }
        }
    }
    // Find the best match. Scores are small non-negative counts (never NaN),
    // but fall back to Equal defensively; on equal scores the smaller index
    // compares greater, making the choice independent of map iteration order.
    let (best_idx, best_score) = scores
        .iter()
        .map(|(&idx, &score)| (idx, score))
        .max_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| b.0.cmp(&a.0))
        })
        .unwrap_or((usize::MAX, 0.0));
    let (occupation, hourly_wage, category, confidence, reasoning) =
        if best_idx < self.occupations.len() {
            let occ = &self.occupations[best_idx];
            // Normalize confidence: three or more keyword hits count as full.
            let confidence = (best_score / 3.0).min(1.0);
            (
                occ.name.clone(),
                occ.hourly_wage,
                occ.category,
                confidence,
                format!("Matched {:.0} keywords", best_score),
            )
        } else {
            // No keyword matched at all: use the configured fallback.
            (
                self.fallback_occupation.clone(),
                self.fallback_wage,
                OccupationCategory::BusinessFinance,
                0.3,
                "Fallback classification - no strong keyword match".to_string(),
            )
        };
    let estimated_hours = Self::estimate_hours(instruction);
    // Round the payment to whole cents.
    let max_payment = (estimated_hours * hourly_wage * 100.0).round() / 100.0;
    ClassificationResult {
        occupation,
        hourly_wage,
        estimated_hours,
        max_payment,
        confidence,
        category,
        reasoning,
    }
}
/// Estimate hours based on instruction complexity
///
/// Heuristic: verbs signalling new work ("implement", "build", "create",
/// "design", "develop") start from 2h; light-touch verbs ("fix", "update",
/// "change", "review") from 0.5h; anything else from 1h. The base is scaled
/// by instruction length (20 words ≈ nominal) and clamped to [0.25, 40.0].
///
/// All markers are matched case-insensitively. (The previous version only
/// lowercased the text for "implement" and "fix", so sentence-initial
/// "Build …" or "Update …" silently missed their markers.)
fn estimate_hours(instruction: &str) -> f64 {
    const COMPLEX_MARKERS: [&str; 5] = ["implement", "build", "create", "design", "develop"];
    const SIMPLE_MARKERS: [&str; 4] = ["fix", "update", "change", "review"];

    let lower = instruction.to_lowercase();
    let word_count = instruction.split_whitespace().count();

    let has_complex_markers = COMPLEX_MARKERS.iter().any(|m| lower.contains(m));
    let has_simple_markers = SIMPLE_MARKERS.iter().any(|m| lower.contains(m));

    // Complex markers take precedence over simple ones.
    let base_hours = if has_complex_markers {
        2.0
    } else if has_simple_markers {
        0.5
    } else {
        1.0
    };

    // Scale by instruction length, bounded to [0.5, 2.0].
    let length_factor = (word_count as f64 / 20.0).clamp(0.5, 2.0);

    // Clamp to a sane range: 15 minutes up to one working week.
    (base_hours * length_factor).clamp(0.25, 40.0)
}
/// Get all occupations
pub fn occupations(&self) -> &[Occupation] {
&self.occupations
}
/// Get occupations by category
pub fn occupations_by_category(&self, category: OccupationCategory) -> Vec<&Occupation> {
self.occupations
.iter()
.filter(|o| o.category == category)
.collect()
}
/// Get the fallback occupation name
pub fn fallback_occupation(&self) -> &str {
&self.fallback_occupation
}
    /// Get the fallback hourly wage
    ///
    /// Used by `classify` when no keyword match is found, paired with
    /// [`fallback_occupation`](Self::fallback_occupation).
    pub fn fallback_wage(&self) -> f64 {
        self.fallback_wage
    }
/// Look up an occupation by exact name
pub fn get_occupation(&self, name: &str) -> Option<&Occupation> {
self.occupations.iter().find(|o| o.name == name)
}
/// Fuzzy match an occupation name (case-insensitive, substring)
pub fn fuzzy_match(&self, name: &str) -> Option<&Occupation> {
let lower = name.to_lowercase();
// Exact match first
if let Some(occ) = self.occupations.iter().find(|o| o.name == name) {
return Some(occ);
}
// Case-insensitive match
if let Some(occ) = self
.occupations
.iter()
.find(|o| o.name.to_lowercase() == lower)
{
return Some(occ);
}
// Substring match
self.occupations.iter().find(|o| {
lower.contains(&o.name.to_lowercase()) || o.name.to_lowercase().contains(&lower)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the expected count must stay in sync with the built-in
    // occupation table.
    #[test]
    fn test_classifier_new() {
        let classifier = TaskClassifier::new();
        assert_eq!(classifier.occupations.len(), 44);
    }
    #[test]
    fn test_classify_software() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("Write a REST API in Rust with authentication");
        assert_eq!(result.occupation, "Software Developers");
        assert!((result.hourly_wage - 69.50).abs() < 0.01);
        assert!(result.confidence > 0.0);
    }
    #[test]
    fn test_classify_finance() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("Prepare quarterly financial statements and audit trail");
        assert!(
            result.occupation.contains("Account") || result.occupation.contains("Financial"),
            "Expected finance occupation, got: {}",
            result.occupation
        );
    }
    // No keyword matches → fallback occupation with fixed 0.3 confidence.
    #[test]
    fn test_classify_fallback() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("xyzzy foobar baz");
        assert_eq!(result.occupation, "General and Operations Managers");
        assert_eq!(result.confidence, 0.3);
    }
    #[test]
    fn test_estimate_hours_complex() {
        let hours = TaskClassifier::estimate_hours(
            "Implement a complete microservices architecture with event sourcing",
        );
        assert!(hours >= 1.0, "Complex task should estimate >= 1 hour");
    }
    #[test]
    fn test_estimate_hours_simple() {
        let hours = TaskClassifier::estimate_hours("Fix typo");
        assert!(hours <= 1.0, "Simple task should estimate <= 1 hour");
    }
    #[test]
    fn test_fuzzy_match() {
        let classifier = TaskClassifier::new();
        // Exact match
        assert!(classifier.fuzzy_match("Software Developers").is_some());
        // Case insensitive
        assert!(classifier.fuzzy_match("software developers").is_some());
        // Substring
        assert!(classifier.fuzzy_match("Software").is_some());
    }
    #[test]
    fn test_occupations_by_category() {
        let classifier = TaskClassifier::new();
        let tech = classifier.occupations_by_category(OccupationCategory::TechnologyEngineering);
        assert!(!tech.is_empty());
        assert!(tech.iter().any(|o| o.name == "Software Developers"));
    }
}
+369
View File
@@ -0,0 +1,369 @@
//! Token cost tracking types for economic agents.
//!
//! Separates costs by channel (LLM, search API, OCR, etc.) following
//! the ClawWork economic model.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Channel-separated cost breakdown for a task or session.
///
/// Each channel is a monetary amount (USD, matching the `cost` fields of the
/// record types below); `Default` yields an all-zero breakdown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostBreakdown {
    /// Cost from LLM token usage
    pub llm_tokens: f64,
    /// Cost from search API calls (Brave, JINA, Tavily, etc.)
    pub search_api: f64,
    /// Cost from OCR API calls
    pub ocr_api: f64,
    /// Cost from other API calls
    pub other_api: f64,
}
impl CostBreakdown {
    /// Construct a breakdown with every channel at zero.
    pub fn new() -> Self {
        Self::default()
    }
    /// Sum of all four channel costs.
    pub fn total(&self) -> f64 {
        [self.llm_tokens, self.search_api, self.ocr_api, self.other_api]
            .iter()
            .sum()
    }
    /// Accumulate `other` into `self`, channel by channel.
    pub fn add(&mut self, other: &CostBreakdown) {
        self.llm_tokens += other.llm_tokens;
        self.search_api += other.search_api;
        self.ocr_api += other.ocr_api;
        self.other_api += other.other_api;
    }
    /// Zero out every channel.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// Token pricing configuration.
///
/// Prices are quoted per million tokens; see `calculate_cost` for how they
/// are applied to a call's token counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPricing {
    /// Price per million input tokens (USD)
    pub input_price_per_million: f64,
    /// Price per million output tokens (USD)
    pub output_price_per_million: f64,
}
impl Default for TokenPricing {
fn default() -> Self {
// Default to Claude Sonnet 4 pricing via OpenRouter
Self {
input_price_per_million: 3.0,
output_price_per_million: 15.0,
}
}
}
impl TokenPricing {
    /// Cost in USD for the given input/output token counts,
    /// priced linearly per million tokens.
    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
        const MILLION: f64 = 1_000_000.0;
        self.input_price_per_million * (input_tokens as f64 / MILLION)
            + self.output_price_per_million * (output_tokens as f64 / MILLION)
    }
}
/// A single LLM call record with token details.
///
/// One record is appended per `track_tokens` call; `cost` is stored as
/// computed at call time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmCallRecord {
    /// Timestamp of the call
    pub timestamp: DateTime<Utc>,
    /// API name/source (e.g., "agent", "wrapup", "research")
    pub api_name: String,
    /// Number of input tokens
    pub input_tokens: u64,
    /// Number of output tokens
    pub output_tokens: u64,
    /// Cost in USD
    pub cost: f64,
}
/// A single API call record (non-LLM).
///
/// `tokens` and `price_per_million` are only present for per-token pricing;
/// they are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiCallRecord {
    /// Timestamp of the call
    pub timestamp: DateTime<Utc>,
    /// API name (e.g., "tavily_search", "jina_reader")
    pub api_name: String,
    /// Pricing model used
    pub pricing_model: PricingModel,
    /// Number of tokens (if token-based pricing)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<u64>,
    /// Price per million tokens (if token-based)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub price_per_million: Option<f64>,
    /// Cost in USD
    pub cost: f64,
}
/// Pricing model for API calls.
///
/// Serialized in snake_case (`per_token` / `flat_rate`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PricingModel {
    /// Token-based pricing (cost = tokens / 1M * price_per_million)
    PerToken,
    /// Flat rate per call
    FlatRate,
}
/// Comprehensive task cost record (one per task).
///
/// Written to the per-task cost JSONL file when a task ends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCostRecord {
    /// Task end timestamp
    pub timestamp_end: DateTime<Utc>,
    /// Task start timestamp
    pub timestamp_start: DateTime<Utc>,
    /// Date the task was assigned (YYYY-MM-DD)
    pub date: String,
    /// Unique task identifier
    pub task_id: String,
    /// LLM usage summary
    pub llm_usage: LlmUsageSummary,
    /// API usage summary
    pub api_usage: ApiUsageSummary,
    /// Cost summary by channel
    pub cost_summary: CostBreakdown,
    /// Balance after this task
    pub balance_after: f64,
    /// Session cost so far
    pub session_cost: f64,
    /// Daily cost so far
    pub daily_cost: f64,
}
/// Aggregated LLM usage for a task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmUsageSummary {
    /// Number of LLM calls made
    pub total_calls: usize,
    /// Total input tokens
    pub total_input_tokens: u64,
    /// Total output tokens
    pub total_output_tokens: u64,
    /// Total tokens (input + output)
    pub total_tokens: u64,
    /// Total cost in USD
    pub total_cost: f64,
    /// Pricing used: input price per million tokens (USD)
    pub input_price_per_million: f64,
    /// Pricing used: output price per million tokens (USD)
    pub output_price_per_million: f64,
    /// Detailed call records (defaults to empty when absent from input)
    #[serde(default)]
    pub calls_detail: Vec<LlmCallRecord>,
}
/// Aggregated API usage for a task.
///
/// Costs are split by channel; call counts are split by pricing model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApiUsageSummary {
    /// Number of API calls made
    pub total_calls: usize,
    /// Search API costs
    pub search_api_cost: f64,
    /// OCR API costs
    pub ocr_api_cost: f64,
    /// Other API costs
    pub other_api_cost: f64,
    /// Number of token-based calls
    pub token_based_calls: usize,
    /// Number of flat-rate calls
    pub flat_rate_calls: usize,
    /// Detailed call records (defaults to empty when absent from input)
    #[serde(default)]
    pub calls_detail: Vec<ApiCallRecord>,
}
/// Work income record.
///
/// `actual_payment` is 0 when `evaluation_score` fell below `threshold`
/// (see `payment_awarded`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkIncomeRecord {
    /// Timestamp
    pub timestamp: DateTime<Utc>,
    /// Date (YYYY-MM-DD)
    pub date: String,
    /// Task identifier
    pub task_id: String,
    /// Base payment amount offered
    pub base_amount: f64,
    /// Actual payment received (0 if below threshold)
    pub actual_payment: f64,
    /// Evaluation score (0.0-1.0)
    pub evaluation_score: f64,
    /// Minimum threshold required for payment
    pub threshold: f64,
    /// Whether payment was awarded
    pub payment_awarded: bool,
    /// Optional description (defaults to empty when absent)
    #[serde(default)]
    pub description: String,
    /// Balance after this income
    pub balance_after: f64,
}
/// Daily balance record for persistence.
///
/// One record per day (plus an initial "initialization" record) in the
/// balance JSONL file; deltas are per-period, `total_*` fields cumulative.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BalanceRecord {
    /// Date (YYYY-MM-DD or "initialization")
    pub date: String,
    /// Current balance
    pub balance: f64,
    /// Token cost delta for this period
    pub token_cost_delta: f64,
    /// Work income delta for this period
    pub work_income_delta: f64,
    /// Trading profit delta for this period
    pub trading_profit_delta: f64,
    /// Cumulative total token cost
    pub total_token_cost: f64,
    /// Cumulative total work income
    pub total_work_income: f64,
    /// Cumulative total trading profit
    pub total_trading_profit: f64,
    /// Net worth (balance + portfolio value)
    pub net_worth: f64,
    /// Current survival status
    pub survival_status: String,
    /// Tasks completed in this period (defaults to empty when absent)
    #[serde(default)]
    pub completed_tasks: Vec<String>,
    /// Primary task ID for the day
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_id: Option<String>,
    /// Time to complete tasks (seconds)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_completion_time_seconds: Option<f64>,
    /// Whether session was aborted by API error
    #[serde(default)]
    pub api_error: bool,
}
/// Task completion record for analytics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCompletionRecord {
    /// Task identifier
    pub task_id: String,
    /// Date (YYYY-MM-DD)
    pub date: String,
    /// Attempt number (1-based)
    pub attempt: u32,
    /// Whether work was submitted
    pub work_submitted: bool,
    /// Evaluation score (0.0-1.0)
    pub evaluation_score: f64,
    /// Money earned from this task
    pub money_earned: f64,
    /// Wall-clock time in seconds
    pub wall_clock_seconds: f64,
    /// Timestamp of completion
    pub timestamp: DateTime<Utc>,
}
/// Economic analytics summary.
///
/// Aggregates costs/income across the whole history, keyed by date and by
/// task for drill-down.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EconomicAnalytics {
    /// Total costs by channel
    pub total_costs: CostBreakdown,
    /// Costs broken down by date
    pub by_date: HashMap<String, DateCostSummary>,
    /// Costs broken down by task
    pub by_task: HashMap<String, TaskCostSummary>,
    /// Total number of tasks
    pub total_tasks: usize,
    /// Total income earned
    pub total_income: f64,
    /// Number of tasks that received payment
    pub tasks_paid: usize,
    /// Number of tasks rejected (below threshold)
    pub tasks_rejected: usize,
}
/// Cost summary for a single date.
///
/// `flatten` inlines the channel fields directly into this object's
/// serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DateCostSummary {
    /// Costs by channel
    #[serde(flatten)]
    pub costs: CostBreakdown,
    /// Total cost
    pub total: f64,
    /// Income earned
    pub income: f64,
}
/// Cost summary for a single task.
///
/// `flatten` inlines the channel fields directly into this object's
/// serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskCostSummary {
    /// Costs by channel
    #[serde(flatten)]
    pub costs: CostBreakdown,
    /// Total cost
    pub total: f64,
    /// Date of the task
    pub date: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Float results are compared with small tolerances rather than equality.
    #[test]
    fn cost_breakdown_total() {
        let breakdown = CostBreakdown {
            llm_tokens: 1.0,
            search_api: 0.5,
            ocr_api: 0.25,
            other_api: 0.1,
        };
        assert!((breakdown.total() - 1.85).abs() < f64::EPSILON);
    }
    #[test]
    fn cost_breakdown_add() {
        let mut a = CostBreakdown {
            llm_tokens: 1.0,
            search_api: 0.5,
            ocr_api: 0.0,
            other_api: 0.0,
        };
        let b = CostBreakdown {
            llm_tokens: 0.5,
            search_api: 0.25,
            ocr_api: 0.1,
            other_api: 0.05,
        };
        a.add(&b);
        assert!((a.llm_tokens - 1.5).abs() < f64::EPSILON);
        assert!((a.search_api - 0.75).abs() < f64::EPSILON);
        assert!((a.total() - 2.4).abs() < f64::EPSILON);
    }
    #[test]
    fn token_pricing_calculation() {
        let pricing = TokenPricing {
            input_price_per_million: 3.0,
            output_price_per_million: 15.0,
        };
        // 1000 input, 500 output
        // (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
        let cost = pricing.calculate_cost(1000, 500);
        assert!((cost - 0.0105).abs() < 0.0001);
    }
    // Defaults must stay in sync with `TokenPricing::default()`.
    #[test]
    fn default_token_pricing() {
        let pricing = TokenPricing::default();
        assert!((pricing.input_price_per_million - 3.0).abs() < f64::EPSILON);
        assert!((pricing.output_price_per_million - 15.0).abs() < f64::EPSILON);
    }
}
+83
View File
@@ -0,0 +1,83 @@
//! Economic tracking module for agent survival economics.
//!
//! This module implements the ClawWork economic model for AI agents,
//! tracking balance, costs, income, and survival status. Agents start
//! with initial capital and must manage their resources while completing
//! tasks.
//!
//! ## Overview
//!
//! The economic system models agent viability:
//! - **Balance**: Starting capital minus costs plus earned income
//! - **Costs**: LLM tokens, search APIs, OCR, and other service usage
//! - **Income**: Payments for completed tasks (with quality threshold)
//! - **Status**: Health indicator based on remaining capital percentage
//!
//! ## Example
//!
//! ```rust,ignore
//! use zeroclaw::economic::{EconomicTracker, EconomicConfig, SurvivalStatus};
//!
//! let config = EconomicConfig {
//! enabled: true,
//! initial_balance: 1000.0,
//! ..Default::default()
//! };
//!
//! let tracker = EconomicTracker::new("my-agent", config, None);
//! tracker.initialize()?;
//!
//! // Start a task
//! tracker.start_task("task-001", None);
//!
//! // Track LLM usage
//! let cost = tracker.track_tokens(1000, 500, "agent", None);
//!
//! // Complete task and earn income
//! tracker.end_task()?;
//! let payment = tracker.add_work_income(10.0, "task-001", 0.85, "Completed task")?;
//!
//! // Check survival status
//! match tracker.get_survival_status() {
//! SurvivalStatus::Thriving => println!("Agent is healthy!"),
//! SurvivalStatus::Bankrupt => println!("Agent needs intervention!"),
//! _ => {}
//! }
//! ```
//!
//! ## Persistence
//!
//! Economic state is persisted to JSONL files:
//! - `balance.jsonl`: Daily balance snapshots and cumulative totals
//! - `token_costs.jsonl`: Detailed per-task cost records
//! - `task_completions.jsonl`: Task completion statistics
//!
//! ## Configuration
//!
//! Add to `config.toml`:
//!
//! ```toml
//! [economic]
//! enabled = true
//! initial_balance = 1000.0
//! min_evaluation_threshold = 0.6
//!
//! [economic.token_pricing]
//! input_price_per_million = 3.0
//! output_price_per_million = 15.0
//! ```
pub mod classifier;
pub mod costs;
pub mod status;
pub mod tracker;
// Re-exports for convenient access
pub use classifier::{ClassificationResult, Occupation, OccupationCategory, TaskClassifier};
pub use costs::{
ApiCallRecord, ApiUsageSummary, BalanceRecord, CostBreakdown, DateCostSummary,
EconomicAnalytics, LlmCallRecord, LlmUsageSummary, PricingModel, TaskCompletionRecord,
TaskCostRecord, TaskCostSummary, TokenPricing, WorkIncomeRecord,
};
pub use status::SurvivalStatus;
pub use tracker::{EconomicConfig, EconomicSummary, EconomicTracker};
+207
View File
@@ -0,0 +1,207 @@
//! Survival status tracking for economic agents.
//!
//! Defines the health states an agent can be in based on remaining balance
//! as a percentage of initial capital.
use serde::{Deserialize, Serialize};
use std::fmt;
/// Survival status based on balance percentage relative to initial capital.
///
/// Mirrors the ClawWork LiveBench agent survival states. Thresholds below
/// match `from_balance`: each band's lower bound is inclusive and its upper
/// bound exclusive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SurvivalStatus {
    /// Balance >= 80% of initial - Agent is profitable and healthy
    Thriving,
    /// Balance 40-80% of initial - Agent is maintaining stability
    #[default]
    Stable,
    /// Balance 10-40% of initial - Agent is losing money, needs attention
    Struggling,
    /// Balance 0-10% of initial (0 exclusive) - Agent is near death, urgent intervention needed
    Critical,
    /// Balance <= 0 - Agent has exhausted resources and cannot operate
    Bankrupt,
}
impl SurvivalStatus {
    /// Calculate survival status from current and initial balance.
    ///
    /// # Arguments
    /// * `current_balance` - Current remaining balance
    /// * `initial_balance` - Starting balance
    ///
    /// # Returns
    /// The appropriate `SurvivalStatus` based on the percentage remaining.
    pub fn from_balance(current_balance: f64, initial_balance: f64) -> Self {
        if initial_balance <= 0.0 {
            // Edge case: if initial was zero or negative, can't calculate percentage
            return if current_balance <= 0.0 {
                Self::Bankrupt
            } else {
                Self::Thriving
            };
        }
        let percentage = (current_balance / initial_balance) * 100.0;
        // Bands: lower bound inclusive, upper bound exclusive.
        match percentage {
            p if p <= 0.0 => Self::Bankrupt,
            p if p < 10.0 => Self::Critical,
            p if p < 40.0 => Self::Struggling,
            p if p < 80.0 => Self::Stable,
            _ => Self::Thriving,
        }
    }
    /// Check if the agent can still operate (not bankrupt).
    pub fn is_operational(&self) -> bool {
        !matches!(self, Self::Bankrupt)
    }
    /// Check if the agent needs urgent attention.
    pub fn needs_intervention(&self) -> bool {
        matches!(self, Self::Critical | Self::Bankrupt)
    }
    /// Get a human-readable emoji indicator.
    pub fn emoji(&self) -> &'static str {
        match self {
            Self::Thriving => "🌟",
            // Fix: `Stable` previously mapped to an empty string — the only
            // variant without a visible indicator.
            Self::Stable => "✅",
            Self::Struggling => "⚠️",
            Self::Critical => "🚨",
            Self::Bankrupt => "💀",
        }
    }
    /// Get a color code for terminal output (ANSI).
    pub fn ansi_color(&self) -> &'static str {
        match self {
            Self::Thriving => "\x1b[32m",   // Green
            Self::Stable => "\x1b[34m",     // Blue
            Self::Struggling => "\x1b[33m", // Yellow
            Self::Critical => "\x1b[31m",   // Red
            Self::Bankrupt => "\x1b[35m",   // Magenta
        }
    }
}
impl fmt::Display for SurvivalStatus {
    /// Writes the variant's plain English name (e.g. "Thriving").
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Thriving => "Thriving",
            Self::Stable => "Stable",
            Self::Struggling => "Struggling",
            Self::Critical => "Critical",
            Self::Bankrupt => "Bankrupt",
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Boundary tests probe just above/below each threshold (80 / 40 / 10 / 0 %)
    // with an initial balance of 1000.
    #[test]
    fn thriving_above_80_percent() {
        assert_eq!(
            SurvivalStatus::from_balance(900.0, 1000.0),
            SurvivalStatus::Thriving
        );
        assert_eq!(
            SurvivalStatus::from_balance(1500.0, 1000.0), // Profit!
            SurvivalStatus::Thriving
        );
        assert_eq!(
            SurvivalStatus::from_balance(800.01, 1000.0),
            SurvivalStatus::Thriving
        );
    }
    #[test]
    fn stable_between_40_and_80_percent() {
        assert_eq!(
            SurvivalStatus::from_balance(799.99, 1000.0),
            SurvivalStatus::Stable
        );
        assert_eq!(
            SurvivalStatus::from_balance(500.0, 1000.0),
            SurvivalStatus::Stable
        );
        assert_eq!(
            SurvivalStatus::from_balance(400.01, 1000.0),
            SurvivalStatus::Stable
        );
    }
    #[test]
    fn struggling_between_10_and_40_percent() {
        assert_eq!(
            SurvivalStatus::from_balance(399.99, 1000.0),
            SurvivalStatus::Struggling
        );
        assert_eq!(
            SurvivalStatus::from_balance(200.0, 1000.0),
            SurvivalStatus::Struggling
        );
        assert_eq!(
            SurvivalStatus::from_balance(100.01, 1000.0),
            SurvivalStatus::Struggling
        );
    }
    #[test]
    fn critical_between_0_and_10_percent() {
        assert_eq!(
            SurvivalStatus::from_balance(99.99, 1000.0),
            SurvivalStatus::Critical
        );
        assert_eq!(
            SurvivalStatus::from_balance(50.0, 1000.0),
            SurvivalStatus::Critical
        );
        assert_eq!(
            SurvivalStatus::from_balance(0.01, 1000.0),
            SurvivalStatus::Critical
        );
    }
    #[test]
    fn bankrupt_at_zero_or_negative() {
        assert_eq!(
            SurvivalStatus::from_balance(0.0, 1000.0),
            SurvivalStatus::Bankrupt
        );
        assert_eq!(
            SurvivalStatus::from_balance(-100.0, 1000.0),
            SurvivalStatus::Bankrupt
        );
    }
    // Only Bankrupt is non-operational.
    #[test]
    fn is_operational() {
        assert!(SurvivalStatus::Thriving.is_operational());
        assert!(SurvivalStatus::Stable.is_operational());
        assert!(SurvivalStatus::Struggling.is_operational());
        assert!(SurvivalStatus::Critical.is_operational());
        assert!(!SurvivalStatus::Bankrupt.is_operational());
    }
    #[test]
    fn needs_intervention() {
        assert!(!SurvivalStatus::Thriving.needs_intervention());
        assert!(!SurvivalStatus::Stable.needs_intervention());
        assert!(!SurvivalStatus::Struggling.needs_intervention());
        assert!(SurvivalStatus::Critical.needs_intervention());
        assert!(SurvivalStatus::Bankrupt.needs_intervention());
    }
    #[test]
    fn display_format() {
        assert_eq!(format!("{}", SurvivalStatus::Thriving), "Thriving");
        assert_eq!(format!("{}", SurvivalStatus::Bankrupt), "Bankrupt");
    }
}
+992
View File
@@ -0,0 +1,992 @@
//! Economic tracker for agent survival economics.
//!
//! Tracks balance, token costs, work income, and survival status following
//! the ClawWork LiveBench economic model. Persists state to JSONL files.
use super::costs::{
ApiCallRecord, ApiUsageSummary, BalanceRecord, CostBreakdown, LlmCallRecord, LlmUsageSummary,
PricingModel, TaskCompletionRecord, TaskCostRecord, TokenPricing, WorkIncomeRecord,
};
use super::status::SurvivalStatus;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::PathBuf;
use std::sync::Arc;
/// Economic configuration options.
///
/// Deserialized from the `[economic]` config section; every field carries a
/// serde default so partial configs are accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EconomicConfig {
    /// Enable economic tracking
    #[serde(default)]
    pub enabled: bool,
    /// Starting balance in USD
    #[serde(default = "default_initial_balance")]
    pub initial_balance: f64,
    /// Token pricing configuration
    #[serde(default)]
    pub token_pricing: TokenPricing,
    /// Minimum evaluation score to receive payment (0.0-1.0)
    #[serde(default = "default_min_threshold")]
    pub min_evaluation_threshold: f64,
}
/// Serde default for `EconomicConfig::initial_balance` ($1000 starting capital).
fn default_initial_balance() -> f64 {
    1000.0
}
/// Serde default for `EconomicConfig::min_evaluation_threshold` (0.6).
fn default_min_threshold() -> f64 {
    0.6
}
impl Default for EconomicConfig {
    /// Tracking is disabled by default; numeric fields reuse the serde
    /// default helpers so `Default` and deserialization stay in sync.
    fn default() -> Self {
        Self {
            enabled: false,
            initial_balance: default_initial_balance(),
            token_pricing: TokenPricing::default(),
            min_evaluation_threshold: default_min_threshold(),
        }
    }
}
/// Task-level tracking state (in-memory during task execution).
///
/// Not serialized; consolidated into a `TaskCostRecord` when the task ends.
#[derive(Debug, Clone, Default)]
struct TaskState {
    /// Current task ID (`None` when no task is in flight)
    task_id: Option<String>,
    /// Date the task was assigned
    task_date: Option<String>,
    /// Task start timestamp
    start_time: Option<DateTime<Utc>>,
    /// Costs accumulated for this task
    costs: CostBreakdown,
    /// LLM call records
    llm_calls: Vec<LlmCallRecord>,
    /// API call records
    api_calls: Vec<ApiCallRecord>,
}
impl TaskState {
    /// Return every per-task field to its idle state. Uses `clear()` on the
    /// record vectors (rather than reconstruction) so buffer capacity is
    /// retained for the next task.
    fn reset(&mut self) {
        self.task_id = None;
        self.task_date = None;
        self.start_time = None;
        self.costs.reset();
        self.llm_calls.clear();
        self.api_calls.clear();
    }
}
/// Daily tracking state (accumulated across tasks).
///
/// Cleared by `save_daily_state` at the end-of-day rollover.
#[derive(Debug, Clone, Default)]
struct DailyState {
    /// Task IDs completed today
    task_ids: Vec<String>,
    /// First task start time
    first_task_start: Option<DateTime<Utc>>,
    /// Last task end time
    last_task_end: Option<DateTime<Utc>>,
    /// Daily cost accumulator
    cost: f64,
}
impl DailyState {
    /// Clear the daily accumulators; `clear()` keeps the task-id buffer's
    /// capacity across days.
    fn reset(&mut self) {
        self.task_ids.clear();
        self.first_task_start = None;
        self.last_task_end = None;
        self.cost = 0.0;
    }
}
/// Session tracking state.
///
/// Accumulates LLM token counts and total cost for the current session;
/// reset alongside the daily state.
#[derive(Debug, Clone, Default)]
struct SessionState {
    /// Input tokens this session
    input_tokens: u64,
    /// Output tokens this session
    output_tokens: u64,
    /// Cost this session
    cost: f64,
}
impl SessionState {
    /// Zero all session counters (performed at the end-of-day save).
    fn reset(&mut self) {
        self.input_tokens = 0;
        self.output_tokens = 0;
        self.cost = 0.0;
    }
}
/// Economic tracker for managing agent survival economics.
///
/// Tracks:
/// - Balance (starting capital minus costs plus income)
/// - Token costs separated by channel (LLM, search, OCR, etc.)
/// - Work income with evaluation threshold
/// - Trading profits/losses
/// - Survival status
///
/// Persists records to JSONL files for durability and analysis.
pub struct EconomicTracker {
    /// Configuration
    config: EconomicConfig,
    /// Agent signature/name
    signature: String,
    /// Data directory for persistence
    data_path: PathBuf,
    /// Mutable tracker state (balance, per-task/daily/session accumulators),
    /// behind a mutex for thread safety
    state: Arc<Mutex<TrackerState>>,
}
/// Internal mutable state, shared behind `EconomicTracker::state`.
struct TrackerState {
    /// Current balance
    balance: f64,
    /// Initial balance (for status calculation)
    initial_balance: f64,
    /// Cumulative totals
    total_token_cost: f64,
    total_work_income: f64,
    total_trading_profit: f64,
    /// Task-level tracking
    task: TaskState,
    /// Daily tracking
    daily: DailyState,
    /// Session tracking
    session: SessionState,
}
impl EconomicTracker {
/// Create a new economic tracker.
///
/// # Arguments
/// * `signature` - Agent signature/name for identification
/// * `config` - Economic configuration
/// * `data_path` - Optional custom data path (defaults to `./data/agent_data/{signature}/economic`)
pub fn new(
signature: impl Into<String>,
config: EconomicConfig,
data_path: Option<PathBuf>,
) -> Self {
let signature = signature.into();
let data_path = data_path
.unwrap_or_else(|| PathBuf::from(format!("./data/agent_data/{}/economic", signature)));
Self {
signature,
state: Arc::new(Mutex::new(TrackerState {
balance: config.initial_balance,
initial_balance: config.initial_balance,
total_token_cost: 0.0,
total_work_income: 0.0,
total_trading_profit: 0.0,
task: TaskState::default(),
daily: DailyState::default(),
session: SessionState::default(),
})),
config,
data_path,
}
}
    /// Initialize the tracker, loading existing state or creating new.
    ///
    /// Ensures the data directory exists. If a balance file is already
    /// present the latest persisted state is loaded; otherwise a fresh
    /// "initialization" balance record is written.
    pub fn initialize(&self) -> Result<()> {
        fs::create_dir_all(&self.data_path).with_context(|| {
            format!(
                "Failed to create data directory: {}",
                self.data_path.display()
            )
        })?;
        let balance_file = self.balance_file_path();
        if balance_file.exists() {
            self.load_latest_state()?;
            let state = self.state.lock();
            tracing::info!(
                "📊 Loaded economic state for {}: balance=${:.2}, status={}",
                self.signature,
                state.balance,
                self.get_survival_status_inner(&state)
            );
        } else {
            // First run: persist a zero-delta "initialization" record.
            self.save_balance_record("initialization", 0.0, 0.0, 0.0, Vec::new(), false)?;
            tracing::info!(
                "✅ Initialized economic tracker for {}: starting balance=${:.2}",
                self.signature,
                self.config.initial_balance
            );
        }
        Ok(())
    }
/// Start tracking costs for a new task.
pub fn start_task(&self, task_id: impl Into<String>, date: Option<String>) {
let task_id = task_id.into();
let date = date.unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string());
let now = Utc::now();
let mut state = self.state.lock();
state.task.task_id = Some(task_id.clone());
state.task.task_date = Some(date);
state.task.start_time = Some(now);
state.task.costs.reset();
state.task.llm_calls.clear();
state.task.api_calls.clear();
// Track daily window
if state.daily.first_task_start.is_none() {
state.daily.first_task_start = Some(now);
}
state.daily.task_ids.push(task_id);
}
/// End tracking for current task and save consolidated record.
pub fn end_task(&self) -> Result<()> {
let mut state = self.state.lock();
if state.task.task_id.is_some() {
self.save_task_record_inner(&state)?;
state.daily.last_task_end = Some(Utc::now());
state.task.reset();
}
Ok(())
}
/// Track LLM token usage.
///
/// # Arguments
/// * `input_tokens` - Number of input tokens
/// * `output_tokens` - Number of output tokens
/// * `api_name` - Origin of the call (e.g., "agent", "wrapup")
/// * `cost` - Pre-computed cost (if provided, skips local calculation)
///
/// # Returns
/// The cost in USD for this call.
pub fn track_tokens(
&self,
input_tokens: u64,
output_tokens: u64,
api_name: impl Into<String>,
cost: Option<f64>,
) -> f64 {
let api_name = api_name.into();
let cost = cost.unwrap_or_else(|| {
self.config
.token_pricing
.calculate_cost(input_tokens, output_tokens)
});
let mut state = self.state.lock();
// Update session tracking
state.session.input_tokens += input_tokens;
state.session.output_tokens += output_tokens;
state.session.cost += cost;
state.daily.cost += cost;
// Update task-level tracking
state.task.costs.llm_tokens += cost;
state.task.llm_calls.push(LlmCallRecord {
timestamp: Utc::now(),
api_name,
input_tokens,
output_tokens,
cost,
});
// Update totals
state.total_token_cost += cost;
state.balance -= cost;
cost
}
/// Track token-based API call cost.
///
/// # Arguments
/// * `tokens` - Number of tokens used
/// * `price_per_million` - Price per million tokens
/// * `api_name` - Name of the API
///
/// # Returns
/// The cost in USD for this call.
pub fn track_api_call(
&self,
tokens: u64,
price_per_million: f64,
api_name: impl Into<String>,
) -> f64 {
let api_name = api_name.into();
let cost = (tokens as f64 / 1_000_000.0) * price_per_million;
self.record_api_cost(
&api_name,
cost,
Some(tokens),
Some(price_per_million),
PricingModel::PerToken,
);
cost
}
/// Track flat-rate API call cost.
///
/// # Arguments
/// * `cost` - Flat cost in USD
/// * `api_name` - Name of the API
///
/// # Returns
/// The cost (same as input).
pub fn track_flat_api_call(&self, cost: f64, api_name: impl Into<String>) -> f64 {
let api_name = api_name.into();
self.record_api_cost(&api_name, cost, None, None, PricingModel::FlatRate);
cost
}
fn record_api_cost(
&self,
api_name: &str,
cost: f64,
tokens: Option<u64>,
price_per_million: Option<f64>,
pricing_model: PricingModel,
) {
let mut state = self.state.lock();
// Update session/daily
state.session.cost += cost;
state.daily.cost += cost;
// Categorize by API type
let api_lower = api_name.to_lowercase();
if api_lower.contains("search")
|| api_lower.contains("jina")
|| api_lower.contains("tavily")
{
state.task.costs.search_api += cost;
} else if api_lower.contains("ocr") {
state.task.costs.ocr_api += cost;
} else {
state.task.costs.other_api += cost;
}
// Record detailed call
state.task.api_calls.push(ApiCallRecord {
timestamp: Utc::now(),
api_name: api_name.to_string(),
pricing_model,
tokens,
price_per_million,
cost,
});
// Update totals
state.total_token_cost += cost;
state.balance -= cost;
}
    /// Add income from completed work with evaluation threshold.
    ///
    /// Payment is only awarded if `evaluation_score >= min_evaluation_threshold`.
    ///
    /// # Arguments
    /// * `amount` - Base payment amount in USD
    /// * `task_id` - Task identifier
    /// * `evaluation_score` - Score from 0.0 to 1.0
    /// * `description` - Optional description
    ///
    /// # Returns
    /// Actual payment received (0.0 if below threshold).
    pub fn add_work_income(
        &self,
        amount: f64,
        task_id: impl Into<String>,
        evaluation_score: f64,
        description: impl Into<String>,
    ) -> Result<f64> {
        let task_id = task_id.into();
        let description = description.into();
        let threshold = self.config.min_evaluation_threshold;
        // All-or-nothing payout: meeting the threshold pays the full amount.
        let actual_payment = if evaluation_score >= threshold {
            amount
        } else {
            0.0
        };
        // Block scope keeps the lock released before the JSONL append below.
        {
            let mut state = self.state.lock();
            if actual_payment > 0.0 {
                state.balance += actual_payment;
                state.total_work_income += actual_payment;
                tracing::info!(
                    "💰 Work income: +${:.2} (Task: {}, Score: {:.2})",
                    actual_payment,
                    task_id,
                    evaluation_score
                );
            } else {
                tracing::warn!(
                    "⚠️ Work below threshold (score: {:.2} < {:.2}), no payment for task: {}",
                    evaluation_score,
                    threshold,
                    task_id
                );
            }
        }
        self.log_work_income(
            &task_id,
            amount,
            actual_payment,
            evaluation_score,
            &description,
        )?;
        Ok(actual_payment)
    }
/// Add profit/loss from trading.
///
/// Positive `profit` increases the balance; a negative value records a loss.
pub fn add_trading_profit(&self, profit: f64, _description: impl Into<String>) {
    let mut state = self.state.lock();
    state.balance += profit;
    state.total_trading_profit += profit;
    // Losses already carry a '-' from the float formatting; only gains
    // need an explicit '+' prefix.
    let sign = match profit >= 0.0 {
        true => "+",
        false => "",
    };
    tracing::info!(
        "📈 Trading P&L: {}${:.2}, new balance: ${:.2}",
        sign,
        profit,
        state.balance
    );
}
/// Save end-of-day economic state.
///
/// Persists a balance record for `date`, then resets the daily and
/// session accumulators.
pub fn save_daily_state(
    &self,
    date: &str,
    work_income: f64,
    trading_profit: f64,
    completed_tasks: Vec<String>,
    api_error: bool,
) -> Result<()> {
    // Snapshot the daily cost under a transient lock; the record writer
    // below acquires the lock itself, so it must not be held here.
    let daily_cost = self.state.lock().daily.cost;
    self.save_balance_record(
        date,
        daily_cost,
        work_income,
        trading_profit,
        completed_tasks,
        api_error,
    )?;
    // Only clear the per-day / per-session counters once the record is
    // safely on disk.
    {
        let mut state = self.state.lock();
        state.daily.reset();
        state.session.reset();
    }
    tracing::info!("💾 Saved daily state for {}", date);
    Ok(())
}
/// Get current balance.
pub fn get_balance(&self) -> f64 {
    let state = self.state.lock();
    state.balance
}
/// Get net worth (balance + portfolio value).
///
/// Currently identical to [`get_balance`]; the portfolio component is
/// not yet implemented (see TODO below).
pub fn get_net_worth(&self) -> f64 {
    // TODO: Add trading portfolio value
    self.get_balance()
}
/// Get current survival status.
pub fn get_survival_status(&self) -> SurvivalStatus {
    let guard = self.state.lock();
    self.get_survival_status_inner(&guard)
}
/// Derive the survival status from an already-locked state snapshot.
///
/// Status is a function of current vs. initial balance only.
fn get_survival_status_inner(&self, state: &TrackerState) -> SurvivalStatus {
    SurvivalStatus::from_balance(state.balance, state.initial_balance)
}
/// Check if agent is bankrupt.
pub fn is_bankrupt(&self) -> bool {
    matches!(self.get_survival_status(), SurvivalStatus::Bankrupt)
}
/// Get session cost so far.
pub fn get_session_cost(&self) -> f64 {
    let state = self.state.lock();
    state.session.cost
}
/// Get daily cost so far.
pub fn get_daily_cost(&self) -> f64 {
    let state = self.state.lock();
    state.daily.cost
}
/// Get comprehensive economic summary.
///
/// Takes a single lock on the tracker state so every field in the
/// returned summary reflects one consistent snapshot.
pub fn get_summary(&self) -> EconomicSummary {
    let state = self.state.lock();
    // Compute the survival status once and reuse it for both the status
    // field and the bankruptcy flag (the original derived it twice from
    // the same snapshot).
    let survival_status = self.get_survival_status_inner(&state);
    let is_bankrupt = survival_status == SurvivalStatus::Bankrupt;
    EconomicSummary {
        signature: self.signature.clone(),
        balance: state.balance,
        initial_balance: state.initial_balance,
        net_worth: state.balance, // TODO: Add portfolio
        total_token_cost: state.total_token_cost,
        total_work_income: state.total_work_income,
        total_trading_profit: state.total_trading_profit,
        session_cost: state.session.cost,
        daily_cost: state.daily.cost,
        session_input_tokens: state.session.input_tokens,
        session_output_tokens: state.session.output_tokens,
        survival_status,
        is_bankrupt,
        min_evaluation_threshold: self.config.min_evaluation_threshold,
    }
}
/// Reset session tracking (for new decision/activity).
pub fn reset_session(&self) {
    let mut state = self.state.lock();
    state.session.reset();
}
/// Record task completion statistics.
///
/// Keeps at most one record per `task_id` in the completions JSONL file:
/// any previous record for the same task is replaced by the new one.
/// Lines that fail to parse are preserved verbatim.
///
/// The rewrite is atomic: records are written to a temporary sibling file,
/// fsynced, and renamed over the original, so a crash mid-write cannot
/// lose the existing history (the original truncated the live file first).
pub fn record_task_completion(
    &self,
    task_id: impl Into<String>,
    work_submitted: bool,
    wall_clock_seconds: f64,
    evaluation_score: f64,
    money_earned: f64,
    attempt: u32,
    date: Option<String>,
) -> Result<()> {
    let task_id = task_id.into();
    // Prefer explicit date, then the active task's date, then today (UTC).
    let date = date
        .or_else(|| self.state.lock().task.task_date.clone())
        .unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string());
    let record = TaskCompletionRecord {
        task_id: task_id.clone(),
        date,
        attempt,
        work_submitted,
        evaluation_score,
        money_earned,
        wall_clock_seconds,
        timestamp: Utc::now(),
    };
    // Read existing records, dropping any prior entry for this task_id.
    let completions_file = self.task_completions_file_path();
    let mut existing: Vec<String> = Vec::new();
    if completions_file.exists() {
        let file = File::open(&completions_file)?;
        let reader = BufReader::new(file);
        for line in reader.lines() {
            let line = line?;
            if line.trim().is_empty() {
                continue;
            }
            match serde_json::from_str::<TaskCompletionRecord>(&line) {
                // Drop the stale entry for this task; it gets rewritten below.
                Ok(entry) if entry.task_id == task_id => {}
                // Keep other tasks' records and unparseable lines untouched.
                _ => existing.push(line),
            }
        }
    }
    // Atomic rewrite: temp file + fsync + rename.
    let tmp_path = completions_file.with_extension("jsonl.tmp");
    {
        let mut file = File::create(&tmp_path)?;
        for line in &existing {
            writeln!(file, "{}", line)?;
        }
        writeln!(file, "{}", serde_json::to_string(&record)?)?;
        file.sync_all()?;
    }
    std::fs::rename(&tmp_path, &completions_file)?;
    Ok(())
}
// ── Private helpers ──
/// Path of the daily balance ledger (`balance.jsonl`).
fn balance_file_path(&self) -> PathBuf {
    self.data_path.join("balance.jsonl")
}
/// Path of the per-task cost ledger (`token_costs.jsonl`).
fn token_costs_file_path(&self) -> PathBuf {
    self.data_path.join("token_costs.jsonl")
}
/// Path of the task-completion ledger (`task_completions.jsonl`).
fn task_completions_file_path(&self) -> PathBuf {
    self.data_path.join("task_completions.jsonl")
}
/// Restore balance and lifetime totals from the last parseable line of
/// `balance.jsonl`.
///
/// NOTE(review): `File::open` propagates an error when the balance file
/// does not exist yet — presumably the caller only invokes this after
/// checking the file exists; confirm against `initialize()`.
fn load_latest_state(&self) -> Result<()> {
    let balance_file = self.balance_file_path();
    let file = File::open(&balance_file)?;
    let reader = BufReader::new(file);
    // Scan the whole JSONL file, remembering the last line that parses;
    // malformed lines are skipped silently.
    let mut last_record: Option<BalanceRecord> = None;
    for line in reader.lines() {
        let line = line?;
        if let Ok(record) = serde_json::from_str::<BalanceRecord>(&line) {
            last_record = Some(record);
        }
    }
    if let Some(record) = last_record {
        // Restore the persisted fields; everything else (daily/session
        // counters, task state) keeps its freshly-initialized values.
        let mut state = self.state.lock();
        state.balance = record.balance;
        state.total_token_cost = record.total_token_cost;
        state.total_work_income = record.total_work_income;
        state.total_trading_profit = record.total_trading_profit;
    }
    Ok(())
}
/// Persist a per-task cost record (LLM + API usage) to the token-costs
/// JSONL file. No-op when no task is currently active.
///
/// The caller already holds the state lock and passes the snapshot in,
/// so this function must not lock `self.state` itself.
fn save_task_record_inner(&self, state: &TrackerState) -> Result<()> {
    // Nothing to record unless a task was started.
    let Some(ref task_id) = state.task.task_id else {
        return Ok(());
    };
    // Aggregate LLM token usage across all calls in this task.
    let total_input = state.task.llm_calls.iter().map(|c| c.input_tokens).sum();
    let total_output = state.task.llm_calls.iter().map(|c| c.output_tokens).sum();
    let llm_call_count = state.task.llm_calls.len();
    // Split API calls by pricing model for the summary counts.
    let token_based = state
        .task
        .api_calls
        .iter()
        .filter(|c| c.pricing_model == PricingModel::PerToken)
        .count();
    let flat_rate = state
        .task
        .api_calls
        .iter()
        .filter(|c| c.pricing_model == PricingModel::FlatRate)
        .count();
    let record = TaskCostRecord {
        timestamp_end: Utc::now(),
        // Fall back to "now" if the task never recorded a start time.
        timestamp_start: state.task.start_time.unwrap_or_else(Utc::now),
        date: state
            .task
            .task_date
            .clone()
            .unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string()),
        task_id: task_id.clone(),
        llm_usage: LlmUsageSummary {
            total_calls: llm_call_count,
            total_input_tokens: total_input,
            total_output_tokens: total_output,
            total_tokens: total_input + total_output,
            total_cost: state.task.costs.llm_tokens,
            input_price_per_million: self.config.token_pricing.input_price_per_million,
            output_price_per_million: self.config.token_pricing.output_price_per_million,
            calls_detail: state.task.llm_calls.clone(),
        },
        api_usage: ApiUsageSummary {
            total_calls: state.task.api_calls.len(),
            search_api_cost: state.task.costs.search_api,
            ocr_api_cost: state.task.costs.ocr_api,
            other_api_cost: state.task.costs.other_api,
            token_based_calls: token_based,
            flat_rate_calls: flat_rate,
            calls_detail: state.task.api_calls.clone(),
        },
        cost_summary: state.task.costs.clone(),
        balance_after: state.balance,
        session_cost: state.session.cost,
        daily_cost: state.daily.cost,
    };
    // Append as one JSON line and fsync so the record survives a crash.
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(self.token_costs_file_path())?;
    writeln!(file, "{}", serde_json::to_string(&record)?)?;
    file.sync_all()?;
    Ok(())
}
/// Append one end-of-day `BalanceRecord` line to `balance.jsonl`.
///
/// Acquires the state lock itself; callers must not hold it.
fn save_balance_record(
    &self,
    date: &str,
    token_cost_delta: f64,
    work_income_delta: f64,
    trading_profit_delta: f64,
    completed_tasks: Vec<String>,
    api_error: bool,
) -> Result<()> {
    let state = self.state.lock();
    // Wall-clock span of the day's work, if both endpoints were recorded.
    let task_completion_time = match (state.daily.first_task_start, state.daily.last_task_end) {
        (Some(start), Some(end)) => Some((end - start).num_seconds() as f64),
        _ => None,
    };
    let record = BalanceRecord {
        date: date.to_string(),
        balance: state.balance,
        token_cost_delta,
        work_income_delta,
        trading_profit_delta,
        total_token_cost: state.total_token_cost,
        total_work_income: state.total_work_income,
        total_trading_profit: state.total_trading_profit,
        // NOTE(review): net worth mirrors the cash balance — presumably
        // portfolio value is intentionally excluded for now (see
        // `get_net_worth`); confirm.
        net_worth: state.balance,
        survival_status: self.get_survival_status_inner(&state).to_string(),
        completed_tasks,
        task_id: state.daily.task_ids.first().cloned(),
        task_completion_time_seconds: task_completion_time,
        api_error,
    };
    drop(state); // Release lock before IO
    // Append the record and fsync so it survives a crash.
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(self.balance_file_path())?;
    writeln!(file, "{}", serde_json::to_string(&record)?)?;
    file.sync_all()?;
    Ok(())
}
/// Append a `WorkIncomeRecord` line describing a (possibly zero) work
/// payment.
///
/// NOTE(review): this appends to the same file as `TaskCostRecord`
/// entries (`token_costs.jsonl`) — presumably an intentional mixed
/// ledger; confirm readers tolerate both record shapes, or whether a
/// dedicated income file was intended.
fn log_work_income(
    &self,
    task_id: &str,
    base_amount: f64,
    actual_payment: f64,
    evaluation_score: f64,
    description: &str,
) -> Result<()> {
    let state = self.state.lock();
    let record = WorkIncomeRecord {
        timestamp: Utc::now(),
        date: state
            .task
            .task_date
            .clone()
            .unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string()),
        task_id: task_id.to_string(),
        base_amount,
        actual_payment,
        evaluation_score,
        threshold: self.config.min_evaluation_threshold,
        payment_awarded: actual_payment > 0.0,
        description: description.to_string(),
        balance_after: state.balance,
    };
    // Release the lock before touching the filesystem.
    drop(state);
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(self.token_costs_file_path())?;
    writeln!(file, "{}", serde_json::to_string(&record)?)?;
    file.sync_all()?;
    Ok(())
}
}
impl std::fmt::Display for EconomicTracker {
    /// One-line human-readable summary: signature, balance, status.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.state.lock();
        let status = self.get_survival_status_inner(&state);
        write!(
            f,
            "EconomicTracker(signature='{}', balance=${:.2}, status={})",
            self.signature, state.balance, status
        )
    }
}
/// Comprehensive economic summary.
///
/// A point-in-time snapshot produced by `EconomicTracker::get_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EconomicSummary {
    /// Identifier of the agent this summary belongs to.
    pub signature: String,
    /// Current cash balance (USD).
    pub balance: f64,
    /// Balance the tracker started with.
    pub initial_balance: f64,
    /// Balance plus portfolio value (portfolio not yet included — see
    /// `get_net_worth`).
    pub net_worth: f64,
    /// Lifetime LLM token spend.
    pub total_token_cost: f64,
    /// Lifetime income from completed work.
    pub total_work_income: f64,
    /// Lifetime trading profit/loss (may be negative).
    pub total_trading_profit: f64,
    /// Spend accumulated in the current session.
    pub session_cost: f64,
    /// Spend accumulated in the current day.
    pub daily_cost: f64,
    /// Input tokens consumed this session.
    pub session_input_tokens: u64,
    /// Output tokens produced this session.
    pub session_output_tokens: u64,
    /// Current survival tier derived from balance vs. initial balance.
    pub survival_status: SurvivalStatus,
    /// Convenience flag: `survival_status == Bankrupt`.
    pub is_bankrupt: bool,
    /// Minimum evaluation score required for work to be paid.
    pub min_evaluation_threshold: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    /// Shared fixture: $1000 starting balance, $3/$15 per-1M token
    /// pricing, 0.6 evaluation threshold.
    fn test_config() -> EconomicConfig {
        EconomicConfig {
            enabled: true,
            initial_balance: 1000.0,
            token_pricing: TokenPricing {
                input_price_per_million: 3.0,
                output_price_per_million: 15.0,
            },
            min_evaluation_threshold: 0.6,
        }
    }
    // A fresh tracker starts at the configured balance and Thriving status.
    #[test]
    fn tracker_initialization() {
        let tmp = TempDir::new().unwrap();
        let config = test_config();
        let tracker = EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
        tracker.initialize().unwrap();
        assert!((tracker.get_balance() - 1000.0).abs() < f64::EPSILON);
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Thriving);
    }
    // Token tracking charges the balance at the configured per-1M rates.
    #[test]
    fn track_tokens_reduces_balance() {
        let tmp = TempDir::new().unwrap();
        let tracker =
            EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
        tracker.initialize().unwrap();
        tracker.start_task("task-1", None);
        let cost = tracker.track_tokens(1000, 500, "agent", None);
        tracker.end_task().unwrap();
        // (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
        assert!((cost - 0.0105).abs() < 0.0001);
        assert!((tracker.get_balance() - (1000.0 - 0.0105)).abs() < 0.0001);
    }
    // Payment is withheld below the threshold and granted at/above it.
    #[test]
    fn work_income_with_threshold() {
        let tmp = TempDir::new().unwrap();
        let tracker =
            EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
        tracker.initialize().unwrap();
        // Below threshold - no payment
        let payment = tracker.add_work_income(100.0, "task-1", 0.5, "").unwrap();
        assert!((payment - 0.0).abs() < f64::EPSILON);
        assert!((tracker.get_balance() - 1000.0).abs() < f64::EPSILON);
        // At threshold - payment awarded
        let payment = tracker.add_work_income(100.0, "task-2", 0.6, "").unwrap();
        assert!((payment - 100.0).abs() < f64::EPSILON);
        assert!((tracker.get_balance() - 1100.0).abs() < f64::EPSILON);
    }
    // Status degrades through every tier as spend eats the balance.
    #[test]
    fn survival_status_changes() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config();
        config.initial_balance = 100.0;
        let tracker = EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
        tracker.initialize().unwrap();
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Thriving);
        // Spend 30% - should be stable
        tracker.track_tokens(10_000_000, 0, "agent", Some(30.0));
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Stable);
        // Spend more to reach struggling
        tracker.track_tokens(10_000_000, 0, "agent", Some(35.0));
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Struggling);
        // Spend more to reach critical
        tracker.track_tokens(10_000_000, 0, "agent", Some(25.0));
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Critical);
        // Bankrupt
        tracker.track_tokens(10_000_000, 0, "agent", Some(20.0));
        assert_eq!(tracker.get_survival_status(), SurvivalStatus::Bankrupt);
        assert!(tracker.is_bankrupt());
    }
    // A second tracker over the same data dir resumes the saved balance.
    #[test]
    fn state_persistence() {
        let tmp = TempDir::new().unwrap();
        let config = test_config();
        // Create tracker, do some work, save state
        {
            let tracker =
                EconomicTracker::new("test-agent", config.clone(), Some(tmp.path().to_path_buf()));
            tracker.initialize().unwrap();
            tracker.track_tokens(1000, 500, "agent", Some(10.0));
            tracker
                .save_daily_state("2025-01-01", 0.0, 0.0, vec![], false)
                .unwrap();
        }
        // Create new tracker, should load state
        {
            let tracker =
                EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
            tracker.initialize().unwrap();
            assert!((tracker.get_balance() - 990.0).abs() < 0.01);
        }
    }
    // API costs are bucketed by name (search/ocr/other) and all hit balance.
    #[test]
    fn api_call_categorization() {
        let tmp = TempDir::new().unwrap();
        let tracker =
            EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
        tracker.initialize().unwrap();
        tracker.start_task("task-1", None);
        // Search API
        tracker.track_flat_api_call(0.001, "tavily_search");
        // OCR API
        tracker.track_api_call(1000, 1.0, "ocr_reader");
        // Other API
        tracker.track_flat_api_call(0.01, "some_api");
        tracker.end_task().unwrap();
        // Balance should reflect all costs
        let expected_reduction = 0.001 + 0.001 + 0.01; // search + ocr + other
        assert!((tracker.get_balance() - (1000.0 - expected_reduction)).abs() < 0.0001);
    }
}
+8 -4
View File
@@ -362,6 +362,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
provider_transport: config.effective_provider_transport(),
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
reasoning_enabled: config.runtime.reasoning_enabled,
@@ -662,11 +663,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
}
// Wrap observer with broadcast capability for SSE
// Use cost-tracking observer when cost tracking is enabled
let base_observer = crate::observability::create_observer_with_cost_tracking(
&config.observability,
cost_tracker.clone(),
&config.cost,
);
let broadcast_observer: Arc<dyn crate::observability::Observer> =
Arc::new(sse::BroadcastObserver::new(
crate::observability::create_observer(&config.observability),
event_tx.clone(),
));
Arc::new(sse::BroadcastObserver::new(base_observer, event_tx.clone()));
let state = AppState {
config: config_state,
+55 -2
View File
@@ -1,7 +1,7 @@
use super::{IntegrationCategory, IntegrationEntry, IntegrationStatus};
use crate::providers::{
is_glm_alias, is_minimax_alias, is_moonshot_alias, is_qianfan_alias, is_qwen_alias,
is_zai_alias,
is_doubao_alias, is_glm_alias, is_minimax_alias, is_moonshot_alias, is_qianfan_alias,
is_qwen_alias, is_siliconflow_alias, is_zai_alias,
};
/// Returns the full catalog of integrations
@@ -436,6 +436,33 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
}
},
},
IntegrationEntry {
name: "Volcengine ARK",
description: "Doubao and ARK model catalog",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref().is_some_and(is_doubao_alias) {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "SiliconFlow",
description: "OpenAI-compatible hosted models and reasoning",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider
.as_deref()
.is_some_and(is_siliconflow_alias)
{
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Groq",
description: "Llama 3.3 70B Versatile and low-latency models",
@@ -991,5 +1018,31 @@ mod tests {
(qianfan.status_fn)(&config),
IntegrationStatus::Active
));
config.default_provider = Some("ark".to_string());
let volcengine = entries.iter().find(|e| e.name == "Volcengine ARK").unwrap();
assert!(matches!(
(volcengine.status_fn)(&config),
IntegrationStatus::Active
));
config.default_provider = Some("volcengine".to_string());
assert!(matches!(
(volcengine.status_fn)(&config),
IntegrationStatus::Active
));
config.default_provider = Some("siliconflow".to_string());
let siliconflow = entries.iter().find(|e| e.name == "SiliconFlow").unwrap();
assert!(matches!(
(siliconflow.status_fn)(&config),
IntegrationStatus::Active
));
config.default_provider = Some("silicon-cloud".to_string());
assert!(matches!(
(siliconflow.status_fn)(&config),
IntegrationStatus::Active
));
}
}
+1
View File
@@ -49,6 +49,7 @@ pub(crate) mod cost;
pub(crate) mod cron;
pub(crate) mod daemon;
pub(crate) mod doctor;
pub mod economic;
pub mod gateway;
pub mod goals;
pub(crate) mod hardware;
+259
View File
@@ -0,0 +1,259 @@
//! Cost-tracking observer that wires provider token usage to the cost tracker.
//!
//! Intercepts `LlmResponse` events and records usage to the `CostTracker`,
//! calculating costs based on model pricing configuration.
use super::traits::{Observer, ObserverEvent, ObserverMetric};
use crate::config::schema::ModelPricing;
use crate::cost::{CostTracker, TokenUsage};
use std::collections::HashMap;
use std::sync::Arc;
/// Observer that records token usage to a CostTracker.
///
/// Listens for `LlmResponse` events and calculates costs using model pricing.
pub struct CostObserver {
    /// Shared cost ledger that usage records are appended to.
    tracker: Arc<CostTracker>,
    /// Pricing table keyed by `"provider/model"` or bare model name.
    prices: HashMap<String, ModelPricing>,
    /// Default pricing for unknown models (USD per 1M tokens)
    default_input_price: f64,
    default_output_price: f64,
}
impl CostObserver {
    /// Create a new cost observer with the given tracker and pricing config.
    pub fn new(tracker: Arc<CostTracker>, prices: HashMap<String, ModelPricing>) -> Self {
        Self {
            tracker,
            prices,
            // Conservative defaults for unknown models
            default_input_price: 3.0,
            default_output_price: 15.0,
        }
    }

    /// Look up pricing for a model, trying various name formats.
    ///
    /// Resolution order:
    /// 1. Exact `"provider/model"` match.
    /// 2. Exact bare model name match.
    /// 3. Family match (prefix or normalized-substring), e.g.
    ///    "gpt-4o-2024-05-13" matches key "gpt-4o". When several keys
    ///    match, the longest (most specific) key wins, with a lexical
    ///    tie-break, so the result is deterministic — the previous
    ///    first-match-wins scan depended on `HashMap` iteration order.
    /// 4. Built-in defaults.
    fn get_pricing(&self, provider: &str, model: &str) -> (f64, f64) {
        // Try exact match first: "provider/model"
        let full_name = format!("{provider}/{model}");
        if let Some(pricing) = self.prices.get(&full_name) {
            return (pricing.input, pricing.output);
        }
        // Try just the model name
        if let Some(pricing) = self.prices.get(model) {
            return (pricing.input, pricing.output);
        }
        // Family matching over all configured keys, keeping the best
        // (longest, then lexically smallest) matching key.
        let normalized_model = model.replace('-', ".");
        let mut best: Option<(&String, &ModelPricing)> = None;
        for (key, pricing) in &self.prices {
            // Strip provider prefix if present
            let key_model = key.split('/').next_back().unwrap_or(key);
            // Normalize '-' vs '.' so "claude-3-5-sonnet-20241022" can
            // match "claude-3.5-sonnet".
            let normalized_key = key_model.replace('-', ".");
            let is_family_match = model.starts_with(key_model)
                || key_model.starts_with(model)
                || normalized_model.contains(&normalized_key)
                || normalized_key.contains(&normalized_model);
            if !is_family_match {
                continue;
            }
            let better = match best {
                None => true,
                Some((best_key, _)) => {
                    let best_model = best_key.split('/').next_back().unwrap_or(best_key);
                    key_model.len() > best_model.len()
                        || (key_model.len() == best_model.len()
                            && key.as_str() < best_key.as_str())
                }
            };
            if better {
                best = Some((key, pricing));
            }
        }
        if let Some((_, pricing)) = best {
            return (pricing.input, pricing.output);
        }
        // Fall back to defaults
        tracing::debug!(
            "No pricing found for {}/{}, using defaults (${}/{} per 1M tokens)",
            provider,
            model,
            self.default_input_price,
            self.default_output_price
        );
        (self.default_input_price, self.default_output_price)
    }
}
impl Observer for CostObserver {
    /// Record successful `LlmResponse` events into the cost tracker.
    fn record_event(&self, event: &ObserverEvent) {
        // Only successful LLM responses are billable.
        let ObserverEvent::LlmResponse {
            provider,
            model,
            success: true,
            input_tokens,
            output_tokens,
            ..
        } = event
        else {
            return;
        };
        let input = input_tokens.unwrap_or(0);
        let output = output_tokens.unwrap_or(0);
        // Nothing to record without token counts.
        if input == 0 && output == 0 {
            return;
        }
        let (input_price, output_price) = self.get_pricing(provider, model);
        let full_model_name = format!("{provider}/{model}");
        let usage = TokenUsage::new(full_model_name, input, output, input_price, output_price);
        if let Err(e) = self.tracker.record_usage(usage) {
            tracing::warn!("Failed to record cost usage: {e}");
        }
    }

    fn record_metric(&self, _metric: &ObserverMetric) {
        // Cost observer doesn't handle metrics
    }

    fn name(&self) -> &str {
        "cost"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::CostConfig;
    use std::time::Duration;
    use tempfile::TempDir;
    /// Build an enabled CostTracker backed by a temp directory.
    fn create_test_tracker() -> (TempDir, Arc<CostTracker>) {
        let tmp = TempDir::new().unwrap();
        let config = CostConfig {
            enabled: true,
            ..Default::default()
        };
        let tracker = Arc::new(CostTracker::new(config, tmp.path()).unwrap());
        (tmp, tracker)
    }
    // A successful response with an exact pricing match is billed.
    #[test]
    fn cost_observer_records_llm_response_usage() {
        let (_tmp, tracker) = create_test_tracker();
        let mut prices = HashMap::new();
        prices.insert(
            "anthropic/claude-sonnet-4-20250514".into(),
            ModelPricing {
                input: 3.0,
                output: 15.0,
            },
        );
        let observer = CostObserver::new(tracker.clone(), prices);
        observer.record_event(&ObserverEvent::LlmResponse {
            provider: "anthropic".into(),
            model: "claude-sonnet-4-20250514".into(),
            duration: Duration::from_millis(100),
            success: true,
            error_message: None,
            input_tokens: Some(1000),
            output_tokens: Some(500),
        });
        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 1);
        // Cost: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
        assert!((summary.session_cost_usd - 0.0105).abs() < 0.0001);
    }
    // Failed responses must not be billed.
    #[test]
    fn cost_observer_ignores_failed_responses() {
        let (_tmp, tracker) = create_test_tracker();
        let observer = CostObserver::new(tracker.clone(), HashMap::new());
        observer.record_event(&ObserverEvent::LlmResponse {
            provider: "anthropic".into(),
            model: "claude-sonnet-4".into(),
            duration: Duration::from_millis(100),
            success: false,
            error_message: Some("API error".into()),
            input_tokens: Some(1000),
            output_tokens: Some(500),
        });
        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 0);
    }
    // Missing token counts mean there is nothing to record.
    #[test]
    fn cost_observer_ignores_zero_token_responses() {
        let (_tmp, tracker) = create_test_tracker();
        let observer = CostObserver::new(tracker.clone(), HashMap::new());
        observer.record_event(&ObserverEvent::LlmResponse {
            provider: "anthropic".into(),
            model: "claude-sonnet-4".into(),
            duration: Duration::from_millis(100),
            success: true,
            error_message: None,
            input_tokens: None,
            output_tokens: None,
        });
        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 0);
    }
    // Unknown models fall back to the built-in $3/$15 defaults.
    #[test]
    fn cost_observer_uses_default_pricing_for_unknown_models() {
        let (_tmp, tracker) = create_test_tracker();
        let observer = CostObserver::new(tracker.clone(), HashMap::new());
        observer.record_event(&ObserverEvent::LlmResponse {
            provider: "unknown".into(),
            model: "mystery-model".into(),
            duration: Duration::from_millis(100),
            success: true,
            error_message: None,
            input_tokens: Some(1_000_000), // 1M tokens
            output_tokens: Some(1_000_000),
        });
        let summary = tracker.get_summary().unwrap();
        assert_eq!(summary.request_count, 1);
        // Default: $3 input + $15 output = $18 for 1M each
        assert!((summary.session_cost_usd - 18.0).abs() < 0.01);
    }
    // A versioned model name resolves to its family pricing key.
    #[test]
    fn cost_observer_matches_model_family() {
        let (_tmp, tracker) = create_test_tracker();
        let mut prices = HashMap::new();
        prices.insert(
            "openai/gpt-4o".into(),
            ModelPricing {
                input: 5.0,
                output: 15.0,
            },
        );
        let observer = CostObserver::new(tracker.clone(), prices);
        // Model name with version suffix should still match
        observer.record_event(&ObserverEvent::LlmResponse {
            provider: "openai".into(),
            model: "gpt-4o-2024-05-13".into(),
            duration: Duration::from_millis(100),
            success: true,
            error_message: None,
            input_tokens: Some(1_000_000),
            output_tokens: Some(0),
        });
        let summary = tracker.get_summary().unwrap();
        // Should use $5 input price, not default $3
        assert!((summary.session_cost_usd - 5.0).abs() < 0.01);
    }
}
+32
View File
@@ -1,3 +1,4 @@
pub mod cost;
pub mod log;
pub mod multi;
pub mod noop;
@@ -12,6 +13,7 @@ pub mod verbose;
pub use self::log::LogObserver;
#[allow(unused_imports)]
pub use self::multi::MultiObserver;
pub use cost::CostObserver;
pub use noop::NoopObserver;
#[cfg(feature = "observability-otel")]
pub use otel::OtelObserver;
@@ -20,10 +22,40 @@ pub use traits::{Observer, ObserverEvent};
#[allow(unused_imports)]
pub use verbose::VerboseObserver;
use crate::config::schema::CostConfig;
use crate::config::ObservabilityConfig;
use crate::cost::CostTracker;
use std::sync::Arc;
/// Factory: create the right observer from config
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
create_observer_internal(config)
}
/// Create an observer stack with optional cost tracking.
///
/// When cost tracking is enabled, wraps the base observer in a MultiObserver
/// that also includes a CostObserver for recording token usage.
pub fn create_observer_with_cost_tracking(
config: &ObservabilityConfig,
cost_tracker: Option<Arc<CostTracker>>,
cost_config: &CostConfig,
) -> Box<dyn Observer> {
let base_observer = create_observer_internal(config);
match cost_tracker {
Some(tracker) if cost_config.enabled => {
let cost_observer = CostObserver::new(tracker, cost_config.prices.clone());
Box::new(MultiObserver::new(vec![
base_observer,
Box::new(cost_observer),
]))
}
_ => base_observer,
}
}
fn create_observer_internal(config: &ObservabilityConfig) -> Box<dyn Observer> {
match config.backend.as_str() {
"log" => Box::new(LogObserver::new()),
"prometheus" => Box::new(PrometheusObserver::new()),
+139 -5
View File
@@ -18,9 +18,9 @@ use crate::memory::{
selectable_memory_backends, MemoryBackendKind,
};
use crate::providers::{
canonical_china_provider_name, is_glm_alias, is_glm_cn_alias, is_minimax_alias,
is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_qwen_oauth_alias, is_zai_alias,
is_zai_cn_alias,
canonical_china_provider_name, is_doubao_alias, is_glm_alias, is_glm_cn_alias,
is_minimax_alias, is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_qwen_oauth_alias,
is_siliconflow_alias, is_zai_alias, is_zai_cn_alias,
};
use anyhow::{bail, Context, Result};
use console::style;
@@ -186,6 +186,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
proxy: crate::config::ProxyConfig::default(),
identity: identity_config,
cost: crate::config::CostConfig::default(),
economic: crate::config::EconomicConfig::default(),
peripherals: crate::config::PeripheralsConfig::default(),
agents: std::collections::HashMap::new(),
hooks: crate::config::HooksConfig::default(),
@@ -550,6 +551,7 @@ async fn run_quick_setup_with_home(
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
economic: crate::config::EconomicConfig::default(),
peripherals: crate::config::PeripheralsConfig::default(),
agents: std::collections::HashMap::new(),
hooks: crate::config::HooksConfig::default(),
@@ -710,6 +712,9 @@ fn canonical_provider_name(provider_name: &str) -> &str {
}
if let Some(canonical) = canonical_china_provider_name(provider_name) {
if canonical == "doubao" {
return "volcengine";
}
return canonical;
}
@@ -775,6 +780,8 @@ fn default_model_for_provider(provider: &str) -> String {
"glm" | "zai" => "glm-5".into(),
"minimax" => "MiniMax-M2.5".into(),
"qwen" => "qwen-plus".into(),
"volcengine" => "doubao-1-5-pro-32k-250115".into(),
"siliconflow" => "Pro/zai-org/GLM-4.7".into(),
"qwen-code" => "qwen3-coder-plus".into(),
"ollama" => "llama3.2".into(),
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF".into(),
@@ -1088,6 +1095,31 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
"Qwen Turbo (fast and cost-efficient)".to_string(),
),
],
"volcengine" => vec![
(
"doubao-1-5-pro-32k-250115".to_string(),
"Doubao 1.5 Pro 32K (official sample model)".to_string(),
),
(
"doubao-seed-1-6-250615".to_string(),
"Doubao Seed 1.6 (reasoning flagship)".to_string(),
),
(
"deepseek-v3.2".to_string(),
"DeepSeek V3.2 (available in ARK catalog)".to_string(),
),
],
"siliconflow" => vec![
(
"Pro/zai-org/GLM-4.7".to_string(),
"GLM-4.7 Pro (official API example)".to_string(),
),
(
"Pro/deepseek-ai/DeepSeek-V3.2".to_string(),
"DeepSeek V3.2 Pro".to_string(),
),
("Qwen/Qwen3-32B".to_string(), "Qwen3 32B".to_string()),
],
"qwen-code" => vec![
(
"qwen3-coder-plus".to_string(),
@@ -1264,6 +1296,8 @@ fn supports_live_model_fetch(provider_name: &str) -> bool {
| "glm"
| "zai"
| "qwen"
| "volcengine"
| "siliconflow"
| "nvidia"
)
}
@@ -1276,6 +1310,9 @@ fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
"moonshot-cn" | "kimi-cn" => Some("https://api.moonshot.cn/v1/models"),
"glm-cn" | "bigmodel" => Some("https://open.bigmodel.cn/api/paas/v4/models"),
"zai-cn" | "z.ai-cn" => Some("https://open.bigmodel.cn/api/coding/paas/v4/models"),
"volcengine" | "ark" | "doubao" | "doubao-cn" => {
Some("https://ark.cn-beijing.volces.com/api/v3/models")
}
_ => match canonical_provider_name(provider_name) {
"openai-codex" | "openai" => Some("https://api.openai.com/v1/models"),
"venice" => Some("https://api.venice.ai/api/v1/models"),
@@ -1291,6 +1328,7 @@ fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
"glm" => Some("https://api.z.ai/api/paas/v4/models"),
"zai" => Some("https://api.z.ai/api/coding/paas/v4/models"),
"qwen" => Some("https://dashscope.aliyuncs.com/compatible-mode/v1/models"),
"siliconflow" => Some("https://api.siliconflow.cn/v1/models"),
"nvidia" => Some("https://integrate.api.nvidia.com/v1/models"),
"astrai" => Some("https://as-trai.com/v1/models"),
"llamacpp" => Some("http://localhost:8080/v1/models"),
@@ -2303,6 +2341,11 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
("qwen-us", "Qwen — DashScope US endpoint"),
("hunyuan", "Hunyuan — Tencent large models (T1, Turbo, Pro)"),
("qianfan", "Qianfan — Baidu AI models (China endpoint)"),
("volcengine", "Volcengine ARK — Doubao model family"),
(
"siliconflow",
"SiliconFlow — OpenAI-compatible hosted models",
),
("zai", "Z.AI — global coding endpoint"),
("zai-cn", "Z.AI — China coding endpoint (open.bigmodel.cn)"),
("synthetic", "Synthetic — Synthetic AI models"),
@@ -2697,6 +2740,10 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
"https://help.aliyun.com/zh/model-studio/developer-reference/get-api-key"
} else if is_qianfan_alias(provider_name) {
"https://cloud.baidu.com/doc/WENXINWORKSHOP/s/7lm0vxo78"
} else if is_doubao_alias(provider_name) {
"https://console.volcengine.com/ark/region:ark+cn-beijing/apiKey"
} else if is_siliconflow_alias(provider_name) {
"https://cloud.siliconflow.cn/account/ak"
} else {
match provider_name {
"openrouter" => "https://openrouter.ai/keys",
@@ -3005,6 +3052,8 @@ fn provider_env_var(name: &str) -> &'static str {
"glm" => "GLM_API_KEY",
"minimax" => "MINIMAX_API_KEY",
"qwen" => "DASHSCOPE_API_KEY",
"volcengine" => "ARK_API_KEY",
"siliconflow" => "SILICONFLOW_API_KEY",
"hunyuan" => "HUNYUAN_API_KEY",
"qianfan" => "QIANFAN_API_KEY",
"zai" => "ZAI_API_KEY",
@@ -7305,6 +7354,14 @@ mod tests {
assert_eq!(default_model_for_provider("moonshot"), "kimi-k2.5");
assert_eq!(default_model_for_provider("hunyuan"), "hunyuan-t1-latest");
assert_eq!(default_model_for_provider("tencent"), "hunyuan-t1-latest");
assert_eq!(
default_model_for_provider("siliconflow"),
"Pro/zai-org/GLM-4.7"
);
assert_eq!(
default_model_for_provider("volcengine"),
"doubao-1-5-pro-32k-250115"
);
assert_eq!(
default_model_for_provider("nvidia"),
"meta/llama-3.3-70b-instruct"
@@ -7343,6 +7400,10 @@ mod tests {
assert_eq!(canonical_provider_name("minimax-cn"), "minimax");
assert_eq!(canonical_provider_name("zai-cn"), "zai");
assert_eq!(canonical_provider_name("z.ai-global"), "zai");
assert_eq!(canonical_provider_name("doubao"), "volcengine");
assert_eq!(canonical_provider_name("ark"), "volcengine");
assert_eq!(canonical_provider_name("silicon-cloud"), "siliconflow");
assert_eq!(canonical_provider_name("siliconcloud"), "siliconflow");
assert_eq!(canonical_provider_name("nvidia-nim"), "nvidia");
assert_eq!(canonical_provider_name("aws-bedrock"), "bedrock");
assert_eq!(canonical_provider_name("build.nvidia.com"), "nvidia");
@@ -7485,6 +7546,23 @@ mod tests {
assert!(ids.contains(&"qwen3-max-2026-01-23".to_string()));
}
#[test]
fn curated_models_for_volcengine_and_siliconflow_include_expected_defaults() {
let volcengine_ids: Vec<String> = curated_models_for_provider("volcengine")
.into_iter()
.map(|(id, _)| id)
.collect();
assert!(volcengine_ids.contains(&"doubao-1-5-pro-32k-250115".to_string()));
assert!(volcengine_ids.contains(&"doubao-seed-1-6-250615".to_string()));
let siliconflow_ids: Vec<String> = curated_models_for_provider("siliconflow")
.into_iter()
.map(|(id, _)| id)
.collect();
assert!(siliconflow_ids.contains(&"Pro/zai-org/GLM-4.7".to_string()));
assert!(siliconflow_ids.contains(&"Pro/deepseek-ai/DeepSeek-V3.2".to_string()));
}
#[test]
fn supports_live_model_fetch_for_supported_and_unsupported_providers() {
assert!(supports_live_model_fetch("openai"));
@@ -7506,6 +7584,11 @@ mod tests {
assert!(supports_live_model_fetch("glm-cn"));
assert!(supports_live_model_fetch("qwen-intl"));
assert!(supports_live_model_fetch("qwen-coding-plan"));
assert!(supports_live_model_fetch("siliconflow"));
assert!(supports_live_model_fetch("silicon-cloud"));
assert!(supports_live_model_fetch("volcengine"));
assert!(supports_live_model_fetch("doubao"));
assert!(supports_live_model_fetch("ark"));
assert!(!supports_live_model_fetch("minimax-cn"));
assert!(!supports_live_model_fetch("unknown-provider"));
}
@@ -7564,6 +7647,22 @@ mod tests {
curated_models_for_provider("bedrock"),
curated_models_for_provider("aws-bedrock")
);
assert_eq!(
curated_models_for_provider("volcengine"),
curated_models_for_provider("doubao")
);
assert_eq!(
curated_models_for_provider("volcengine"),
curated_models_for_provider("ark")
);
assert_eq!(
curated_models_for_provider("siliconflow"),
curated_models_for_provider("silicon-cloud")
);
assert_eq!(
curated_models_for_provider("siliconflow"),
curated_models_for_provider("siliconcloud")
);
}
#[test]
@@ -7596,6 +7695,18 @@ mod tests {
models_endpoint_for_provider("qwen-coding-plan"),
Some("https://coding.dashscope.aliyuncs.com/v1/models")
);
assert_eq!(
models_endpoint_for_provider("volcengine"),
Some("https://ark.cn-beijing.volces.com/api/v3/models")
);
assert_eq!(
models_endpoint_for_provider("doubao"),
Some("https://ark.cn-beijing.volces.com/api/v3/models")
);
assert_eq!(
models_endpoint_for_provider("ark"),
Some("https://ark.cn-beijing.volces.com/api/v3/models")
);
}
#[test]
@@ -7616,6 +7727,14 @@ mod tests {
models_endpoint_for_provider("moonshot"),
Some("https://api.moonshot.ai/v1/models")
);
assert_eq!(
models_endpoint_for_provider("siliconflow"),
Some("https://api.siliconflow.cn/v1/models")
);
assert_eq!(
models_endpoint_for_provider("silicon-cloud"),
Some("https://api.siliconflow.cn/v1/models")
);
assert_eq!(
models_endpoint_for_provider("llamacpp"),
Some("http://localhost:8080/v1/models")
@@ -7914,6 +8033,12 @@ mod tests {
assert_eq!(provider_env_var("minimax-oauth-cn"), "MINIMAX_API_KEY");
assert_eq!(provider_env_var("moonshot-intl"), "MOONSHOT_API_KEY");
assert_eq!(provider_env_var("zai-cn"), "ZAI_API_KEY");
assert_eq!(provider_env_var("doubao"), "ARK_API_KEY");
assert_eq!(provider_env_var("volcengine"), "ARK_API_KEY");
assert_eq!(provider_env_var("ark"), "ARK_API_KEY");
assert_eq!(provider_env_var("siliconflow"), "SILICONFLOW_API_KEY");
assert_eq!(provider_env_var("silicon-cloud"), "SILICONFLOW_API_KEY");
assert_eq!(provider_env_var("siliconcloud"), "SILICONFLOW_API_KEY");
assert_eq!(provider_env_var("nvidia"), "NVIDIA_API_KEY");
assert_eq!(provider_env_var("nvidia-nim"), "NVIDIA_API_KEY"); // alias
assert_eq!(provider_env_var("build.nvidia.com"), "NVIDIA_API_KEY"); // alias
@@ -8006,13 +8131,14 @@ mod tests {
}
#[test]
fn channel_menu_choices_include_signal_and_nextcloud_talk() {
fn channel_menu_choices_include_signal_nextcloud_and_dingtalk() {
assert!(channel_menu_choices().contains(&ChannelMenuChoice::Signal));
assert!(channel_menu_choices().contains(&ChannelMenuChoice::NextcloudTalk));
assert!(channel_menu_choices().contains(&ChannelMenuChoice::DingTalk));
}
#[test]
fn launchable_channels_include_signal_mattermost_qq_and_nextcloud_talk() {
fn launchable_channels_include_signal_mattermost_qq_nextcloud_and_dingtalk() {
let mut channels = ChannelsConfig::default();
assert!(!has_launchable_channels(&channels));
@@ -8056,5 +8182,13 @@ mod tests {
allowed_users: vec!["*".into()],
});
assert!(has_launchable_channels(&channels));
channels.nextcloud_talk = None;
channels.dingtalk = Some(crate::config::schema::DingTalkConfig {
client_id: "client-id".into(),
client_secret: "client-secret".into(),
allowed_users: vec!["*".into()],
});
assert!(has_launchable_channels(&channels));
}
}
+56 -14
View File
@@ -74,6 +74,7 @@ const QWEN_OAUTH_DEFAULT_CLIENT_ID: &str = "f0304373b74a44d2b584a3fb70ca9e56";
const QWEN_OAUTH_CREDENTIAL_FILE: &str = ".qwen/oauth_creds.json";
const ZAI_GLOBAL_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
const ZAI_CN_BASE_URL: &str = "https://open.bigmodel.cn/api/coding/paas/v4";
const SILICONFLOW_BASE_URL: &str = "https://api.siliconflow.cn/v1";
const VERCEL_AI_GATEWAY_BASE_URL: &str = "https://ai-gateway.vercel.sh/v1";
pub(crate) fn is_minimax_intl_alias(name: &str) -> bool {
@@ -179,6 +180,10 @@ pub(crate) fn is_doubao_alias(name: &str) -> bool {
matches!(name, "doubao" | "volcengine" | "ark" | "doubao-cn")
}
/// Returns true when `name` is one of the accepted SiliconFlow provider aliases.
pub(crate) fn is_siliconflow_alias(name: &str) -> bool {
    const ALIASES: [&str; 3] = ["siliconflow", "silicon-cloud", "siliconcloud"];
    ALIASES.contains(&name)
}
#[derive(Clone, Copy, Debug)]
enum MinimaxOauthRegion {
Global,
@@ -618,6 +623,8 @@ pub(crate) fn canonical_china_provider_name(name: &str) -> Option<&'static str>
Some("qianfan")
} else if is_doubao_alias(name) {
Some("doubao")
} else if is_siliconflow_alias(name) {
Some("siliconflow")
} else if matches!(name, "hunyuan" | "tencent") {
Some("hunyuan")
} else {
@@ -683,6 +690,7 @@ fn zai_base_url(name: &str) -> Option<&'static str> {
pub struct ProviderRuntimeOptions {
pub auth_profile_override: Option<String>,
pub provider_api_url: Option<String>,
pub provider_transport: Option<String>,
pub zeroclaw_dir: Option<PathBuf>,
pub secrets_encrypt: bool,
pub reasoning_enabled: Option<bool>,
@@ -697,6 +705,7 @@ impl Default for ProviderRuntimeOptions {
Self {
auth_profile_override: None,
provider_api_url: None,
provider_transport: None,
zeroclaw_dir: None,
secrets_encrypt: true,
reasoning_enabled: None,
@@ -872,6 +881,7 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
"hunyuan" | "tencent" => vec!["HUNYUAN_API_KEY"],
name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"],
name if is_doubao_alias(name) => vec!["ARK_API_KEY", "DOUBAO_API_KEY"],
name if is_siliconflow_alias(name) => vec!["SILICONFLOW_API_KEY"],
name if is_qwen_alias(name) => vec!["DASHSCOPE_API_KEY"],
name if is_zai_alias(name) => vec!["ZAI_API_KEY"],
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec!["NVIDIA_API_KEY"],
@@ -1181,6 +1191,13 @@ fn create_provider_with_url_and_options(
key,
AuthStyle::Bearer,
))),
name if is_siliconflow_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
"SiliconFlow",
SILICONFLOW_BASE_URL,
key,
AuthStyle::Bearer,
true,
))),
name if qwen_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
"Qwen",
qwen_base_url(name).expect("checked in guard"),
@@ -1512,7 +1529,15 @@ pub fn create_routed_provider_with_options(
.then_some(api_url)
.flatten();
let route_options = options.clone();
let mut route_options = options.clone();
if let Some(transport) = route
.transport
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
{
route_options.provider_transport = Some(transport.to_string());
}
match create_resilient_provider_with_options(
&route.provider,
@@ -1542,19 +1567,8 @@ pub fn create_routed_provider_with_options(
}
}
// Build route table
let routes: Vec<(String, router::Route)> = model_routes
.iter()
.map(|r| {
(
r.hint.clone(),
router::Route {
provider_name: r.provider.clone(),
model: r.model.clone(),
},
)
})
.collect();
// Keep only successfully initialized routed providers and preserve
// their provider-id bindings (e.g. "<provider>#<hint>").
Ok(Box::new(
router::RouterProvider::new(providers, routes, default_model.to_string())
@@ -1712,6 +1726,12 @@ pub fn list_providers() -> Vec<ProviderInfo> {
aliases: &["volcengine", "ark", "doubao-cn"],
local: false,
},
ProviderInfo {
name: "siliconflow",
display_name: "SiliconFlow",
aliases: &["silicon-cloud", "siliconcloud"],
local: false,
},
ProviderInfo {
name: "qwen",
display_name: "Qwen (DashScope / Qwen Code OAuth)",
@@ -2069,6 +2089,9 @@ mod tests {
assert!(is_doubao_alias("volcengine"));
assert!(is_doubao_alias("ark"));
assert!(is_doubao_alias("doubao-cn"));
assert!(is_siliconflow_alias("siliconflow"));
assert!(is_siliconflow_alias("silicon-cloud"));
assert!(is_siliconflow_alias("siliconcloud"));
assert!(!is_moonshot_alias("openrouter"));
assert!(!is_glm_alias("openai"));
@@ -2076,6 +2099,7 @@ mod tests {
assert!(!is_zai_alias("anthropic"));
assert!(!is_qianfan_alias("cohere"));
assert!(!is_doubao_alias("deepseek"));
assert!(!is_siliconflow_alias("volcengine"));
}
#[test]
@@ -2099,6 +2123,14 @@ mod tests {
assert_eq!(canonical_china_provider_name("baidu"), Some("qianfan"));
assert_eq!(canonical_china_provider_name("doubao"), Some("doubao"));
assert_eq!(canonical_china_provider_name("volcengine"), Some("doubao"));
assert_eq!(
canonical_china_provider_name("siliconflow"),
Some("siliconflow")
);
assert_eq!(
canonical_china_provider_name("silicon-cloud"),
Some("siliconflow")
);
assert_eq!(canonical_china_provider_name("hunyuan"), Some("hunyuan"));
assert_eq!(canonical_china_provider_name("tencent"), Some("hunyuan"));
assert_eq!(canonical_china_provider_name("openai"), None);
@@ -2316,6 +2348,13 @@ mod tests {
assert!(create_provider("doubao-cn", Some("key")).is_ok());
}
#[test]
fn factory_siliconflow() {
assert!(create_provider("siliconflow", Some("key")).is_ok());
assert!(create_provider("silicon-cloud", Some("key")).is_ok());
assert!(create_provider("siliconcloud", Some("key")).is_ok());
}
#[test]
fn factory_qwen() {
assert!(create_provider("qwen", Some("key")).is_ok());
@@ -2776,6 +2815,8 @@ mod tests {
"bedrock",
"qianfan",
"doubao",
"volcengine",
"siliconflow",
"qwen",
"qwen-intl",
"qwen-cn",
@@ -3049,6 +3090,7 @@ mod tests {
model: "anthropic/claude-sonnet-4.6".to_string(),
max_tokens: Some(4096),
api_key: None,
transport: None,
}];
let provider = create_routed_provider_with_options(
+534 -26
View File
@@ -4,21 +4,81 @@ use crate::multimodal;
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
use crate::providers::ProviderRuntimeOptions;
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::path::PathBuf;
use std::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::{
connect_async,
tungstenite::{
client::IntoClientRequest,
http::{
header::{AUTHORIZATION, USER_AGENT},
HeaderValue as WsHeaderValue,
},
Message as WsMessage,
},
};
const DEFAULT_CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
const CODEX_RESPONSES_URL_ENV: &str = "ZEROCLAW_CODEX_RESPONSES_URL";
const CODEX_BASE_URL_ENV: &str = "ZEROCLAW_CODEX_BASE_URL";
const CODEX_TRANSPORT_ENV: &str = "ZEROCLAW_CODEX_TRANSPORT";
const CODEX_PROVIDER_TRANSPORT_ENV: &str = "ZEROCLAW_PROVIDER_TRANSPORT";
const CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY: &str = "ZEROCLAW_RESPONSES_WEBSOCKET";
const DEFAULT_CODEX_INSTRUCTIONS: &str =
"You are ZeroClaw, a concise and helpful coding assistant.";
const CODEX_WS_CONNECT_TIMEOUT: Duration = Duration::from_secs(20);
const CODEX_WS_SEND_TIMEOUT: Duration = Duration::from_secs(15);
const CODEX_WS_READ_TIMEOUT: Duration = Duration::from_secs(60);
/// Transport used for requests to the OpenAI Codex `responses` endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CodexTransport {
    /// Try the websocket transport first; on a transport-availability
    /// failure, fall back to SSE (mid-stream failures are not retried).
    Auto,
    /// Force the websocket transport; no SSE fallback.
    WebSocket,
    /// Force the HTTP SSE transport.
    Sse,
}
/// Error classification for a websocket `responses` attempt.
#[derive(Debug)]
enum WebsocketRequestError {
    /// The websocket could not be established or the request could not be
    /// delivered (URL/header construction, connect, or send failure). In
    /// `Auto` mode this variant triggers the SSE fallback.
    TransportUnavailable(anyhow::Error),
    /// The stream was established but failed mid-flight; not retried over SSE.
    Stream(anyhow::Error),
}
impl WebsocketRequestError {
    /// Wraps `error` as a transport-availability failure (eligible for SSE
    /// fallback in `Auto` mode).
    fn transport_unavailable<E>(error: E) -> Self
    where
        E: Into<anyhow::Error>,
    {
        Self::TransportUnavailable(error.into())
    }

    /// Wraps `error` as a mid-stream failure (terminal; no SSE fallback).
    fn stream<E>(error: E) -> Self
    where
        E: Into<anyhow::Error>,
    {
        Self::Stream(error.into())
    }
}
impl std::fmt::Display for WebsocketRequestError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants render as their inner error; the variant only drives
        // fallback behavior, not the user-facing message.
        match self {
            Self::TransportUnavailable(error) | Self::Stream(error) => error.fmt(f),
        }
    }
}

impl std::error::Error for WebsocketRequestError {}
pub struct OpenAiCodexProvider {
auth: AuthService,
auth_profile_override: Option<String>,
responses_url: String,
transport: CodexTransport,
custom_endpoint: bool,
gateway_api_key: Option<String>,
reasoning_level: Option<String>,
@@ -104,6 +164,7 @@ impl OpenAiCodexProvider {
auth_profile_override: options.auth_profile_override.clone(),
custom_endpoint: !is_default_responses_url(&responses_url),
responses_url,
transport: resolve_transport_mode(options)?,
gateway_api_key: gateway_api_key.map(ToString::to_string),
reasoning_level: normalize_reasoning_level(
options.reasoning_level.as_deref(),
@@ -204,6 +265,72 @@ fn first_nonempty(text: Option<&str>) -> Option<String> {
})
}
/// Parses a transport override string into a [`CodexTransport`].
///
/// Absent or blank input yields `Ok(None)`. Matching is case-insensitive and
/// ignores `-`/`_` separators ("web-socket" == "websocket"). Unknown values
/// fail with an error message naming `source`.
fn parse_transport_override(
    raw: Option<&str>,
    source: &str,
) -> anyhow::Result<Option<CodexTransport>> {
    let value = match raw.map(str::trim) {
        None | Some("") => return Ok(None),
        Some(trimmed) => trimmed,
    };
    let canonical = value.to_ascii_lowercase().replace(['-', '_'], "");
    let transport = match canonical.as_str() {
        "auto" => CodexTransport::Auto,
        "websocket" | "ws" => CodexTransport::WebSocket,
        "sse" | "http" => CodexTransport::Sse,
        _ => anyhow::bail!(
            "Invalid OpenAI Codex transport override '{value}' from {source}; expected one of: auto, websocket, sse"
        ),
    };
    Ok(Some(transport))
}
/// Interprets the deprecated boolean websocket toggle.
///
/// Truthy values select the websocket transport, falsy values select SSE,
/// and anything else is ignored (`None`).
fn parse_legacy_websocket_flag(raw: &str) -> Option<CodexTransport> {
    let flag = raw.trim().to_ascii_lowercase();
    if ["1", "true", "on", "yes"].contains(&flag.as_str()) {
        Some(CodexTransport::WebSocket)
    } else if ["0", "false", "off", "no"].contains(&flag.as_str()) {
        Some(CodexTransport::Sse)
    } else {
        None
    }
}
/// Resolves the Codex transport mode from runtime options and environment.
///
/// Precedence (first recognized value wins):
/// 1. `options.provider_transport` (per-route runtime override)
/// 2. `ZEROCLAW_CODEX_TRANSPORT`
/// 3. `ZEROCLAW_PROVIDER_TRANSPORT`
/// 4. legacy boolean `ZEROCLAW_RESPONSES_WEBSOCKET` (deprecated; logs a warning)
/// 5. default: `CodexTransport::Auto`
///
/// # Errors
/// Fails when any of the first three sources supplies a non-empty value that
/// is not a valid transport name (see `parse_transport_override`).
fn resolve_transport_mode(options: &ProviderRuntimeOptions) -> anyhow::Result<CodexTransport> {
    if let Some(mode) = parse_transport_override(
        options.provider_transport.as_deref(),
        "provider.transport runtime override",
    )? {
        return Ok(mode);
    }
    if let Ok(value) = std::env::var(CODEX_TRANSPORT_ENV) {
        if let Some(mode) = parse_transport_override(Some(&value), CODEX_TRANSPORT_ENV)? {
            return Ok(mode);
        }
    }
    if let Ok(value) = std::env::var(CODEX_PROVIDER_TRANSPORT_ENV) {
        if let Some(mode) = parse_transport_override(Some(&value), CODEX_PROVIDER_TRANSPORT_ENV)? {
            return Ok(mode);
        }
    }
    // Legacy toggle: unrecognized values fall through to the Auto default
    // rather than erroring (parse_legacy_websocket_flag returns None).
    if let Some(mode) = std::env::var(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY)
        .ok()
        .and_then(|value| parse_legacy_websocket_flag(&value))
    {
        tracing::warn!(
            env = CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY,
            "Using deprecated websocket toggle env for OpenAI Codex transport"
        );
        return Ok(mode);
    }
    Ok(CodexTransport::Auto)
}
/// Returns the system instructions to send, falling back to the built-in
/// default when `system_prompt` is absent or yields no usable text.
fn resolve_instructions(system_prompt: Option<&str>) -> String {
    match first_nonempty(system_prompt) {
        Some(instructions) => instructions,
        None => DEFAULT_CODEX_INSTRUCTIONS.to_string(),
    }
}
@@ -526,6 +653,283 @@ async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<St
}
impl OpenAiCodexProvider {
    /// Derives the websocket endpoint from the configured `responses_url`.
    ///
    /// `https`/`wss` map to `wss`, `http`/`ws` map to `ws`; any other scheme
    /// is rejected. A `model` query parameter is appended unless the URL
    /// already carries one (an existing value is preserved as-is).
    fn responses_websocket_url(&self, model: &str) -> anyhow::Result<String> {
        let mut url = reqwest::Url::parse(&self.responses_url)?;
        let next_scheme: &'static str = match url.scheme() {
            "https" | "wss" => "wss",
            "http" | "ws" => "ws",
            other => {
                anyhow::bail!(
                    "OpenAI Codex websocket transport does not support URL scheme: {}",
                    other
                );
            }
        };
        // set_scheme is fallible (the url crate restricts some scheme
        // changes); surface that as an error rather than panicking.
        url.set_scheme(next_scheme)
            .map_err(|()| anyhow::anyhow!("failed to set websocket URL scheme"))?;
        if !url.query_pairs().any(|(k, _)| k == "model") {
            url.query_pairs_mut().append_pair("model", model);
        }
        Ok(url.into())
    }
fn apply_auth_headers_ws(
&self,
request: &mut tokio_tungstenite::tungstenite::http::Request<()>,
bearer_token: &str,
account_id: Option<&str>,
access_token: Option<&str>,
use_gateway_api_key_auth: bool,
) -> anyhow::Result<()> {
let headers = request.headers_mut();
headers.insert(
AUTHORIZATION,
WsHeaderValue::from_str(&format!("Bearer {bearer_token}"))?,
);
headers.insert(
"OpenAI-Beta",
WsHeaderValue::from_static("responses=experimental"),
);
headers.insert("originator", WsHeaderValue::from_static("pi"));
headers.insert("accept", WsHeaderValue::from_static("text/event-stream"));
headers.insert(USER_AGENT, WsHeaderValue::from_static("zeroclaw"));
if let Some(account_id) = account_id {
headers.insert("chatgpt-account-id", WsHeaderValue::from_str(account_id)?);
}
if use_gateway_api_key_auth {
if let Some(access_token) = access_token {
headers.insert(
"x-openai-access-token",
WsHeaderValue::from_str(access_token)?,
);
}
if let Some(account_id) = account_id {
headers.insert("x-openai-account-id", WsHeaderValue::from_str(account_id)?);
}
}
Ok(())
}
    /// Runs one Codex `responses` call over a websocket and returns the
    /// assembled output text.
    ///
    /// Failures before/while establishing the connection (URL derivation,
    /// header construction, connect, initial send) map to
    /// `WebsocketRequestError::TransportUnavailable` so `Auto` mode can fall
    /// back to SSE; failures after the stream is live map to `Stream`.
    async fn send_responses_websocket_request(
        &self,
        request: &ResponsesRequest,
        model: &str,
        bearer_token: &str,
        account_id: Option<&str>,
        access_token: Option<&str>,
        use_gateway_api_key_auth: bool,
    ) -> Result<String, WebsocketRequestError> {
        let ws_url = self
            .responses_websocket_url(model)
            .map_err(WebsocketRequestError::transport_unavailable)?;
        let mut ws_request = ws_url.into_client_request().map_err(|error| {
            WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
                "invalid websocket request URL: {error}"
            ))
        })?;
        self.apply_auth_headers_ws(
            &mut ws_request,
            bearer_token,
            account_id,
            access_token,
            use_gateway_api_key_auth,
        )
        .map_err(WebsocketRequestError::transport_unavailable)?;

        // Mirror the HTTP request body as a single `response.create` event.
        let payload = serde_json::json!({
            "type": "response.create",
            "model": &request.model,
            "input": &request.input,
            "instructions": &request.instructions,
            "store": request.store,
            "text": &request.text,
            "reasoning": &request.reasoning,
            "include": &request.include,
            "tool_choice": &request.tool_choice,
            "parallel_tool_calls": request.parallel_tool_calls,
        });

        let (mut ws_stream, _) = timeout(CODEX_WS_CONNECT_TIMEOUT, connect_async(ws_request))
            .await
            .map_err(|_| {
                WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
                    "OpenAI Codex websocket connect timed out after {}s",
                    CODEX_WS_CONNECT_TIMEOUT.as_secs()
                ))
            })?
            .map_err(WebsocketRequestError::transport_unavailable)?;
        timeout(
            CODEX_WS_SEND_TIMEOUT,
            ws_stream.send(WsMessage::Text(
                serde_json::to_string(&payload)
                    .map_err(WebsocketRequestError::transport_unavailable)?
                    .into(),
            )),
        )
        .await
        .map_err(|_| {
            WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
                "OpenAI Codex websocket send timed out after {}s",
                CODEX_WS_SEND_TIMEOUT.as_secs()
            ))
        })?
        .map_err(WebsocketRequestError::transport_unavailable)?;

        // Stream state: incremental output_text deltas are accumulated in
        // `delta_accumulator`; the first non-delta text seen is kept in
        // `fallback_text` in case no deltas arrive.
        let mut saw_delta = false;
        let mut delta_accumulator = String::new();
        let mut fallback_text: Option<String> = None;
        let mut timed_out = false;
        loop {
            let frame = match timeout(CODEX_WS_READ_TIMEOUT, ws_stream.next()).await {
                Ok(frame) => frame,
                Err(_) => {
                    // Read timeout: if partial output exists, return it after
                    // the loop; with nothing buffered, fail the stream.
                    let _ = ws_stream.close(None).await;
                    if saw_delta || fallback_text.is_some() {
                        timed_out = true;
                        break;
                    }
                    return Err(WebsocketRequestError::stream(anyhow::anyhow!(
                        "OpenAI Codex websocket stream timed out after {}s waiting for events",
                        CODEX_WS_READ_TIMEOUT.as_secs()
                    )));
                }
            };
            // Stream exhausted without an explicit Close frame.
            let Some(frame) = frame else {
                break;
            };
            let frame = frame.map_err(WebsocketRequestError::stream)?;
            // Text/Binary frames carry JSON events; Pings are answered,
            // Close ends the loop, anything else is skipped.
            let event: Value = match frame {
                WsMessage::Text(text) => {
                    serde_json::from_str(text.as_ref()).map_err(WebsocketRequestError::stream)?
                }
                WsMessage::Binary(binary) => {
                    let text = String::from_utf8(binary.to_vec()).map_err(|error| {
                        WebsocketRequestError::stream(anyhow::anyhow!(
                            "invalid UTF-8 websocket frame from OpenAI Codex: {error}"
                        ))
                    })?;
                    serde_json::from_str(&text).map_err(WebsocketRequestError::stream)?
                }
                WsMessage::Ping(payload) => {
                    ws_stream
                        .send(WsMessage::Pong(payload))
                        .await
                        .map_err(WebsocketRequestError::stream)?;
                    continue;
                }
                WsMessage::Close(_) => break,
                _ => continue,
            };
            // Server-reported errors abort the stream (no SSE fallback).
            if let Some(message) = extract_stream_error_message(&event) {
                return Err(WebsocketRequestError::stream(anyhow::anyhow!(
                    "OpenAI Codex websocket stream error: {message}"
                )));
            }
            if let Some(text) = extract_stream_event_text(&event, saw_delta) {
                let event_type = event.get("type").and_then(Value::as_str);
                if event_type == Some("response.output_text.delta") {
                    saw_delta = true;
                    delta_accumulator.push_str(&text);
                } else if fallback_text.is_none() {
                    fallback_text = Some(text);
                }
            }
            // Terminal events: prefer the fully-parsed response payload,
            // then accumulated deltas, then the fallback snapshot.
            let event_type = event.get("type").and_then(Value::as_str);
            if event_type == Some("response.completed") || event_type == Some("response.done") {
                if let Some(response_value) = event.get("response").cloned() {
                    if let Ok(parsed) = serde_json::from_value::<ResponsesResponse>(response_value)
                    {
                        if let Some(text) = extract_responses_text(&parsed) {
                            let _ = ws_stream.close(None).await;
                            return Ok(text);
                        }
                    }
                }
                if saw_delta {
                    let _ = ws_stream.close(None).await;
                    return nonempty_preserve(Some(&delta_accumulator)).ok_or_else(|| {
                        WebsocketRequestError::stream(anyhow::anyhow!(
                            "No response from OpenAI Codex"
                        ))
                    });
                }
                if let Some(text) = fallback_text.clone() {
                    let _ = ws_stream.close(None).await;
                    return Ok(text);
                }
            }
        }
        // Stream ended without a terminal event: salvage whatever arrived.
        if saw_delta {
            return nonempty_preserve(Some(&delta_accumulator)).ok_or_else(|| {
                WebsocketRequestError::stream(anyhow::anyhow!("No response from OpenAI Codex"))
            });
        }
        if let Some(text) = fallback_text {
            return Ok(text);
        }
        if timed_out {
            return Err(WebsocketRequestError::stream(anyhow::anyhow!(
                "No response from OpenAI Codex websocket stream before timeout"
            )));
        }
        Err(WebsocketRequestError::stream(anyhow::anyhow!(
            "No response from OpenAI Codex websocket stream"
        )))
    }
async fn send_responses_sse_request(
&self,
request: &ResponsesRequest,
bearer_token: &str,
account_id: Option<&str>,
access_token: Option<&str>,
use_gateway_api_key_auth: bool,
) -> anyhow::Result<String> {
let mut request_builder = self
.client
.post(&self.responses_url)
.header("Authorization", format!("Bearer {bearer_token}"))
.header("OpenAI-Beta", "responses=experimental")
.header("originator", "pi")
.header("accept", "text/event-stream")
.header("Content-Type", "application/json");
if let Some(account_id) = account_id {
request_builder = request_builder.header("chatgpt-account-id", account_id);
}
if use_gateway_api_key_auth {
if let Some(access_token) = access_token {
request_builder = request_builder.header("x-openai-access-token", access_token);
}
if let Some(account_id) = account_id {
request_builder = request_builder.header("x-openai-account-id", account_id);
}
}
let response = request_builder.json(request).send().await?;
if !response.status().is_success() {
return Err(super::api_error("OpenAI Codex", response).await);
}
decode_responses_body(response).await
}
async fn send_responses_request(
&self,
input: Vec<ResponsesInput>,
@@ -613,35 +1017,59 @@ impl OpenAiCodexProvider {
access_token.as_deref().unwrap_or_default()
};
let mut request_builder = self
.client
.post(&self.responses_url)
.header("Authorization", format!("Bearer {bearer_token}"))
.header("OpenAI-Beta", "responses=experimental")
.header("originator", "pi")
.header("accept", "text/event-stream")
.header("Content-Type", "application/json");
if let Some(account_id) = account_id.as_deref() {
request_builder = request_builder.header("chatgpt-account-id", account_id);
}
if use_gateway_api_key_auth {
if let Some(access_token) = access_token.as_deref() {
request_builder = request_builder.header("x-openai-access-token", access_token);
match self.transport {
CodexTransport::WebSocket => self
.send_responses_websocket_request(
&request,
normalized_model,
bearer_token,
account_id.as_deref(),
access_token.as_deref(),
use_gateway_api_key_auth,
)
.await
.map_err(Into::into),
CodexTransport::Sse => {
self.send_responses_sse_request(
&request,
bearer_token,
account_id.as_deref(),
access_token.as_deref(),
use_gateway_api_key_auth,
)
.await
}
if let Some(account_id) = account_id.as_deref() {
request_builder = request_builder.header("x-openai-account-id", account_id);
CodexTransport::Auto => {
match self
.send_responses_websocket_request(
&request,
normalized_model,
bearer_token,
account_id.as_deref(),
access_token.as_deref(),
use_gateway_api_key_auth,
)
.await
{
Ok(text) => Ok(text),
Err(WebsocketRequestError::TransportUnavailable(error)) => {
tracing::warn!(
error = %error,
"OpenAI Codex websocket request failed; falling back to SSE"
);
self.send_responses_sse_request(
&request,
bearer_token,
account_id.as_deref(),
access_token.as_deref(),
use_gateway_api_key_auth,
)
.await
}
Err(WebsocketRequestError::Stream(error)) => Err(error),
}
}
}
let response = request_builder.json(&request).send().await?;
if !response.status().is_success() {
return Err(super::api_error("OpenAI Codex", response).await);
}
decode_responses_body(response).await
}
}
@@ -809,6 +1237,85 @@ mod tests {
);
}
#[test]
fn resolve_transport_mode_defaults_to_auto() {
let _env_lock = env_lock();
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, None);
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
assert_eq!(
resolve_transport_mode(&ProviderRuntimeOptions::default()).unwrap(),
CodexTransport::Auto
);
}
#[test]
fn resolve_transport_mode_accepts_runtime_override() {
let _env_lock = env_lock();
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, Some("sse"));
let options = ProviderRuntimeOptions {
provider_transport: Some("websocket".to_string()),
..ProviderRuntimeOptions::default()
};
assert_eq!(
resolve_transport_mode(&options).unwrap(),
CodexTransport::WebSocket
);
}
#[test]
fn resolve_transport_mode_legacy_bool_env_is_supported() {
let _env_lock = env_lock();
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, Some("false"));
assert_eq!(
resolve_transport_mode(&ProviderRuntimeOptions::default()).unwrap(),
CodexTransport::Sse
);
}
#[test]
fn resolve_transport_mode_rejects_invalid_runtime_override() {
let _env_lock = env_lock();
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, None);
let options = ProviderRuntimeOptions {
provider_transport: Some("udp".to_string()),
..ProviderRuntimeOptions::default()
};
let err =
resolve_transport_mode(&options).expect_err("invalid runtime transport must fail");
assert!(err
.to_string()
.contains("Invalid OpenAI Codex transport override 'udp'"));
}
#[test]
fn websocket_url_uses_ws_scheme_and_model_query() {
let _env_lock = env_lock();
let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
let options = ProviderRuntimeOptions::default();
let provider = OpenAiCodexProvider::new(&options, None).expect("provider should init");
let ws_url = provider
.responses_websocket_url("gpt-5.3-codex")
.expect("websocket URL should be derived");
assert_eq!(
ws_url,
"wss://chatgpt.com/backend-api/codex/responses?model=gpt-5.3-codex"
);
}
#[test]
fn default_responses_url_detector_handles_equivalent_urls() {
assert!(is_default_responses_url(DEFAULT_CODEX_RESPONSES_URL));
@@ -1077,6 +1584,7 @@ data: [DONE]
fn capabilities_includes_vision() {
let options = ProviderRuntimeOptions {
provider_api_url: None,
provider_transport: None,
zeroclaw_dir: None,
secrets_encrypt: false,
auth_profile_override: None,
File diff suppressed because it is too large Load Diff
+34 -26
View File
@@ -31,6 +31,8 @@ pub mod cron_update;
pub mod delegate;
pub mod delegate_coordination_status;
pub mod docx_read;
#[cfg(feature = "channel-lark")]
pub mod feishu_doc;
pub mod file_edit;
pub mod file_read;
pub mod file_write;
@@ -86,6 +88,8 @@ pub use cron_update::CronUpdateTool;
pub use delegate::DelegateTool;
pub use delegate_coordination_status::DelegateCoordinationStatusTool;
pub use docx_read::DocxReadTool;
#[cfg(feature = "channel-lark")]
pub use feishu_doc::FeishuDocTool;
pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
@@ -428,6 +432,7 @@ pub fn all_tools_with_runtime(
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: root_config.api_url.clone(),
provider_transport: root_config.effective_provider_transport(),
zeroclaw_dir: root_config
.config_path
.parent()
@@ -512,38 +517,41 @@ pub fn all_tools_with_runtime(
)));
}
// Inter-process agent communication (opt-in)
if root_config.agents_ipc.enabled {
match agents_ipc::IpcDb::open(workspace_dir, &root_config.agents_ipc) {
Ok(ipc_db) => {
let ipc_db = Arc::new(ipc_db);
tool_arcs.push(Arc::new(agents_ipc::AgentsListTool::new(ipc_db.clone())));
tool_arcs.push(Arc::new(agents_ipc::AgentsSendTool::new(
ipc_db.clone(),
// Feishu document tools (enabled when channel-lark feature is active)
#[cfg(feature = "channel-lark")]
{
let feishu_creds = root_config
.channels_config
.feishu
.as_ref()
.map(|fs| (fs.app_id.clone(), fs.app_secret.clone(), true))
.or_else(|| {
root_config
.channels_config
.lark
.as_ref()
.map(|lk| (lk.app_id.clone(), lk.app_secret.clone(), lk.use_feishu))
});
if let Some((app_id, app_secret, use_feishu)) = feishu_creds {
let app_id = app_id.trim().to_string();
let app_secret = app_secret.trim().to_string();
if app_id.is_empty() || app_secret.is_empty() {
tracing::warn!(
"feishu_doc: skipped registration because app credentials are empty"
);
} else {
tool_arcs.push(Arc::new(FeishuDocTool::new(
app_id,
app_secret,
use_feishu,
security.clone(),
)));
tool_arcs.push(Arc::new(agents_ipc::AgentsInboxTool::new(ipc_db.clone())));
tool_arcs.push(Arc::new(agents_ipc::StateGetTool::new(ipc_db.clone())));
tool_arcs.push(Arc::new(agents_ipc::StateSetTool::new(
ipc_db,
security.clone(),
)));
}
Err(e) => {
tracing::warn!("agents_ipc: failed to open IPC database: {e}");
}
}
}
// Load WASM plugin tools from the skills directory.
// Each installed skill package may ship one or more WASM tools under
// `<skill-dir>/tools/<tool-name>/{tool.wasm, manifest.json}`.
// Failures are logged and skipped — a broken plugin must not block startup.
let skills_dir = workspace_dir.join("skills");
let mut boxed = boxed_registry_from_arcs(tool_arcs);
let wasm_tools = wasm_tool::load_wasm_tools_from_skills(&skills_dir);
boxed.extend(wasm_tools);
boxed
boxed_registry_from_arcs(tool_arcs)
}
#[cfg(test)]
+100
View File
@@ -125,6 +125,42 @@ impl ModelRoutingConfigTool {
Ok(output)
}
fn normalize_transport_value(raw: &str, field: &str) -> anyhow::Result<String> {
let normalized = raw.trim().to_ascii_lowercase().replace(['-', '_'], "");
match normalized.as_str() {
"auto" => Ok("auto".to_string()),
"websocket" | "ws" => Ok("websocket".to_string()),
"sse" | "http" => Ok("sse".to_string()),
_ => anyhow::bail!("'{field}' must be one of: auto, websocket, sse"),
}
}
fn parse_optional_transport_update(
args: &Value,
field: &str,
) -> anyhow::Result<MaybeSet<String>> {
let Some(raw) = args.get(field) else {
return Ok(MaybeSet::Unset);
};
if raw.is_null() {
return Ok(MaybeSet::Null);
}
let value = raw
.as_str()
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
.trim();
if value.is_empty() {
return Ok(MaybeSet::Null);
}
Ok(MaybeSet::Set(Self::normalize_transport_value(
value, field,
)?))
}
fn parse_optional_f64_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<f64>> {
let Some(raw) = args.get(field) else {
return Ok(MaybeSet::Unset);
@@ -217,6 +253,7 @@ impl ModelRoutingConfigTool {
"hint": route.hint,
"provider": route.provider,
"model": route.model,
"transport": route.transport,
"api_key_configured": has_provider_credential(&route.provider, route.api_key.as_deref()),
"classification": classification,
})
@@ -429,6 +466,7 @@ impl ModelRoutingConfigTool {
let provider = Self::parse_non_empty_string(args, "provider")?;
let model = Self::parse_non_empty_string(args, "model")?;
let api_key_update = Self::parse_optional_string_update(args, "api_key")?;
let transport_update = Self::parse_optional_transport_update(args, "transport")?;
let keywords_update = if let Some(raw) = args.get("keywords") {
Some(Self::parse_string_list(raw, "keywords")?)
@@ -466,6 +504,7 @@ impl ModelRoutingConfigTool {
model: model.clone(),
max_tokens: None,
api_key: None,
transport: None,
});
next_route.hint = hint.clone();
@@ -478,6 +517,12 @@ impl ModelRoutingConfigTool {
MaybeSet::Unset => {}
}
match transport_update {
MaybeSet::Set(transport) => next_route.transport = Some(transport),
MaybeSet::Null => next_route.transport = None,
MaybeSet::Unset => {}
}
cfg.model_routes.retain(|route| route.hint != hint);
cfg.model_routes.push(next_route);
Self::normalize_and_sort_routes(&mut cfg.model_routes);
@@ -782,6 +827,11 @@ impl Tool for ModelRoutingConfigTool {
"type": ["string", "null"],
"description": "Optional API key override for scenario route or delegate agent"
},
"transport": {
"type": ["string", "null"],
"enum": ["auto", "websocket", "sse", "ws", "http", null],
"description": "Optional route transport override for upsert_scenario (auto, websocket, sse)"
},
"keywords": {
"description": "Classification keywords for upsert_scenario (string or string array)",
"oneOf": [
@@ -1003,6 +1053,7 @@ mod tests {
"hint": "coding",
"provider": "openai",
"model": "gpt-5.3-codex",
"transport": "websocket",
"classification_enabled": true,
"keywords": ["code", "bug", "refactor"],
"patterns": ["```"],
@@ -1024,9 +1075,58 @@ mod tests {
item["hint"] == json!("coding")
&& item["provider"] == json!("openai")
&& item["model"] == json!("gpt-5.3-codex")
&& item["transport"] == json!("websocket")
}));
}
#[tokio::test]
async fn upsert_scenario_transport_alias_is_canonicalized() {
    // "WS" is an accepted alias and must be persisted as canonical "websocket".
    let workdir = TempDir::new().unwrap();
    let tool = ModelRoutingConfigTool::new(test_config(&workdir).await, test_security());
    let upsert = json!({
        "action": "upsert_scenario",
        "hint": "analysis",
        "provider": "openai",
        "model": "gpt-5.3-codex",
        "transport": "WS"
    });
    let upsert_result = tool.execute(upsert).await.unwrap();
    assert!(upsert_result.success, "{:?}", upsert_result.error);

    // Read the config back and confirm the stored transport is canonical.
    let fetched = tool.execute(json!({"action": "get"})).await.unwrap();
    assert!(fetched.success);
    let snapshot: Value = serde_json::from_str(&fetched.output).unwrap();
    let found = snapshot["scenarios"]
        .as_array()
        .unwrap()
        .iter()
        .any(|entry| {
            entry["hint"] == json!("analysis") && entry["transport"] == json!("websocket")
        });
    assert!(found);
}
#[tokio::test]
async fn upsert_scenario_rejects_invalid_transport() {
    // "udp" is not a recognized transport; the upsert must fail with a
    // message naming the offending field and the accepted values.
    let workdir = TempDir::new().unwrap();
    let tool = ModelRoutingConfigTool::new(test_config(&workdir).await, test_security());
    let request = json!({
        "action": "upsert_scenario",
        "hint": "analysis",
        "provider": "openai",
        "model": "gpt-5.3-codex",
        "transport": "udp"
    });
    let outcome = tool.execute(request).await.unwrap();
    assert!(!outcome.success);
    let message = outcome.error.unwrap_or_default();
    assert!(message.contains("'transport' must be one of: auto, websocket, sse"));
}
#[tokio::test]
async fn remove_scenario_also_removes_rule() {
let tmp = TempDir::new().unwrap();