Merge remote-tracking branch 'origin/main'
# Conflicts: # src/channels/mod.rs # src/config/mod.rs # src/config/schema.rs
This commit is contained in:
+32
-1
@@ -1,6 +1,7 @@
|
||||
use crate::agent::dispatcher::{
|
||||
NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher,
|
||||
};
|
||||
use crate::agent::loop_::detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
|
||||
use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader};
|
||||
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
|
||||
use crate::agent::research;
|
||||
@@ -557,8 +558,13 @@ impl Agent {
|
||||
.push(ConversationMessage::Chat(ChatMessage::user(enriched)));
|
||||
|
||||
let effective_model = self.classify_model(user_message);
|
||||
let mut loop_detector = LoopDetector::new(LoopDetectionConfig {
|
||||
no_progress_threshold: self.config.loop_detection_no_progress_threshold,
|
||||
ping_pong_cycles: self.config.loop_detection_ping_pong_cycles,
|
||||
failure_streak_threshold: self.config.loop_detection_failure_streak,
|
||||
});
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
for iteration in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
let response = match self
|
||||
.provider
|
||||
@@ -613,9 +619,34 @@ impl Agent {
|
||||
});
|
||||
|
||||
let results = self.execute_tools(&calls).await;
|
||||
|
||||
// ── Loop detection: record calls ─────────────────────
|
||||
for (call, result) in calls.iter().zip(results.iter()) {
|
||||
let args_sig =
|
||||
serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".into());
|
||||
loop_detector.record_call(&call.name, &args_sig, &result.output, result.success);
|
||||
}
|
||||
|
||||
let formatted = self.tool_dispatcher.format_results(&results);
|
||||
self.history.push(formatted);
|
||||
self.trim_history();
|
||||
|
||||
// ── Loop detection: check verdict ────────────────────
|
||||
match loop_detector.check() {
|
||||
DetectionVerdict::Continue => {}
|
||||
DetectionVerdict::InjectWarning(warning) => {
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::user(warning)));
|
||||
}
|
||||
DetectionVerdict::HardStop(reason) => {
|
||||
anyhow::bail!(
|
||||
"Agent stopped early due to detected loop pattern (iteration {}/{}): {}",
|
||||
iteration + 1,
|
||||
self.config.max_tool_iterations,
|
||||
reason
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
|
||||
+135
-38
@@ -28,12 +28,14 @@ use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
mod context;
|
||||
pub(crate) mod detection;
|
||||
mod execution;
|
||||
mod history;
|
||||
mod parsing;
|
||||
|
||||
use crate::agent::session::{create_session_manager, resolve_session_id, SessionManager};
|
||||
use context::{build_context, build_hardware_context};
|
||||
use detection::{DetectionVerdict, LoopDetectionConfig, LoopDetector};
|
||||
use execution::{
|
||||
execute_tools_parallel, execute_tools_sequential, should_execute_tools_in_parallel,
|
||||
ToolExecutionOutcome,
|
||||
@@ -314,6 +316,7 @@ pub(crate) struct NonCliApprovalContext {
|
||||
|
||||
tokio::task_local! {
|
||||
static TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT: Option<NonCliApprovalContext>;
|
||||
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
|
||||
}
|
||||
|
||||
/// Extract a short hint from tool call arguments for progress display.
|
||||
@@ -599,6 +602,14 @@ pub(crate) fn is_tool_iteration_limit_error(err: &anyhow::Error) -> bool {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn is_loop_detection_error(err: &anyhow::Error) -> bool {
|
||||
err.chain().any(|source| {
|
||||
source
|
||||
.to_string()
|
||||
.contains("Agent stopped early due to detected loop pattern")
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
@@ -799,6 +810,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
let mut missing_tool_call_retry_used = false;
|
||||
let mut missing_tool_call_retry_prompt: Option<String> = None;
|
||||
let ld_config = LOOP_DETECTION_CONFIG
|
||||
.try_with(Clone::clone)
|
||||
.unwrap_or_default();
|
||||
let mut loop_detector = LoopDetector::new(ld_config);
|
||||
let mut loop_detection_prompt: Option<String> = None;
|
||||
let bypass_non_cli_approval_for_turn =
|
||||
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
|
||||
if bypass_non_cli_approval_for_turn {
|
||||
@@ -842,6 +858,9 @@ pub(crate) async fn run_tool_call_loop(
|
||||
if let Some(prompt) = missing_tool_call_retry_prompt.take() {
|
||||
request_messages.push(ChatMessage::user(prompt));
|
||||
}
|
||||
if let Some(prompt) = loop_detection_prompt.take() {
|
||||
request_messages.push(ChatMessage::user(prompt));
|
||||
}
|
||||
|
||||
// ── Progress: LLM thinking ────────────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
@@ -1469,6 +1488,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
.await;
|
||||
}
|
||||
|
||||
// ── Loop detection: record call ──────────────────────
|
||||
{
|
||||
let sig = tool_call_signature(&call.name, &call.arguments);
|
||||
loop_detector.record_call(&sig.0, &sig.1, &outcome.output, outcome.success);
|
||||
}
|
||||
|
||||
ordered_results[*idx] = Some((call.name.clone(), call.tool_call_id.clone(), outcome));
|
||||
}
|
||||
|
||||
@@ -1514,6 +1539,49 @@ pub(crate) async fn run_tool_call_loop(
|
||||
history.push(ChatMessage::tool(tool_msg.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// ── Loop detection: check verdict ────────────────────────
|
||||
match loop_detector.check() {
|
||||
DetectionVerdict::Continue => {}
|
||||
DetectionVerdict::InjectWarning(warning) => {
|
||||
runtime_trace::record_event(
|
||||
"loop_detected_warning",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("loop pattern detected, injecting self-correction prompt"),
|
||||
serde_json::json!({ "iteration": iteration + 1, "warning": &warning }),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!(
|
||||
"{DRAFT_PROGRESS_SENTINEL}\u{26a0}\u{fe0f} Loop detected, attempting self-correction\n"
|
||||
))
|
||||
.await;
|
||||
}
|
||||
loop_detection_prompt = Some(warning);
|
||||
}
|
||||
DetectionVerdict::HardStop(reason) => {
|
||||
runtime_trace::record_event(
|
||||
"loop_detected_hard_stop",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("loop persisted after warning, stopping early"),
|
||||
serde_json::json!({ "iteration": iteration + 1, "reason": &reason }),
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Agent stopped early due to detected loop pattern (iteration {}/{}): {}",
|
||||
iteration + 1,
|
||||
max_iterations,
|
||||
reason
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
runtime_trace::record_event(
|
||||
@@ -1732,6 +1800,7 @@ pub async fn run(
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
provider_transport: config.effective_provider_transport(),
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
@@ -1956,25 +2025,34 @@ pub async fn run(
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
|
||||
let response = run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
let ld_cfg = LoopDetectionConfig {
|
||||
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
|
||||
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
|
||||
failure_streak_threshold: config.agent.loop_detection_failure_streak,
|
||||
};
|
||||
let response = LOOP_DETECTION_CONFIG
|
||||
.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
println!("{response}");
|
||||
observer.record_event(&ObserverEvent::TurnComplete);
|
||||
@@ -2081,25 +2159,34 @@ pub async fn run(
|
||||
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
|
||||
let response = match run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
let ld_cfg = LoopDetectionConfig {
|
||||
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
|
||||
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
|
||||
failure_streak_threshold: config.agent.loop_detection_failure_streak,
|
||||
};
|
||||
let response = match LOOP_DETECTION_CONFIG
|
||||
.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
@@ -2113,6 +2200,15 @@ pub async fn run(
|
||||
eprintln!("\n{pause_notice}\n");
|
||||
continue;
|
||||
}
|
||||
if is_loop_detection_error(&e) {
|
||||
let notice =
|
||||
"\u{26a0}\u{fe0f} Loop pattern detected and agent stopped early. \
|
||||
Context preserved. Reply \"continue\" to resume, or adjust \
|
||||
loop_detection_* thresholds in config.";
|
||||
history.push(ChatMessage::assistant(notice));
|
||||
eprintln!("\n{notice}\n");
|
||||
continue;
|
||||
}
|
||||
eprintln!("\nError: {e}\n");
|
||||
continue;
|
||||
}
|
||||
@@ -2218,6 +2314,7 @@ pub async fn process_message(
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
provider_transport: config.effective_provider_transport(),
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
//! Loop detection for the agent tool-call loop.
|
||||
//!
|
||||
//! Detects three patterns of unproductive looping:
|
||||
//! 1. **No-progress repeat** — same tool + same args + same output hash.
|
||||
//! 2. **Ping-pong** — two calls alternating (A→B→A→B) with no progress.
|
||||
//! 3. **Consecutive failure streak** — same tool failing repeatedly.
|
||||
//!
|
||||
//! On first detection an `InjectWarning` verdict gives the LLM a chance to
|
||||
//! self-correct. If the pattern persists the next check returns `HardStop`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||
|
||||
/// Maximum bytes of tool output considered when hashing results.
|
||||
/// Keeps hashing fast and bounded for large outputs.
|
||||
const OUTPUT_HASH_PREFIX_BYTES: usize = 4096;
|
||||
|
||||
// ─── Configuration ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Tuning knobs for each detection strategy.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct LoopDetectionConfig {
|
||||
/// Identical (tool + args + output) repetitions before triggering.
|
||||
/// `0` = disabled. Default: `3`.
|
||||
pub no_progress_threshold: usize,
|
||||
/// Full A-B cycles before triggering ping-pong detection.
|
||||
/// `0` = disabled. Default: `2`.
|
||||
pub ping_pong_cycles: usize,
|
||||
/// Consecutive failures of the *same* tool before triggering.
|
||||
/// `0` = disabled. Default: `3`.
|
||||
pub failure_streak_threshold: usize,
|
||||
}
|
||||
|
||||
impl Default for LoopDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
no_progress_threshold: 3,
|
||||
ping_pong_cycles: 2,
|
||||
failure_streak_threshold: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Verdict ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Action the caller should take after `LoopDetector::check()`.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(crate) enum DetectionVerdict {
|
||||
/// No loop detected — proceed normally.
|
||||
Continue,
|
||||
/// First detection — inject this self-correction prompt, then continue.
|
||||
InjectWarning(String),
|
||||
/// Pattern persisted after warning — terminate the loop.
|
||||
HardStop(String),
|
||||
}
|
||||
|
||||
// ─── Internal record ─────────────────────────────────────────────────────────
|
||||
|
||||
struct CallRecord {
|
||||
tool_name: String,
|
||||
args_sig: String,
|
||||
result_hash: u64,
|
||||
success: bool,
|
||||
}
|
||||
|
||||
// ─── Detector ────────────────────────────────────────────────────────────────
|
||||
|
||||
pub(crate) struct LoopDetector {
|
||||
config: LoopDetectionConfig,
|
||||
history: Vec<CallRecord>,
|
||||
consecutive_failures: HashMap<String, usize>,
|
||||
warning_injected: bool,
|
||||
}
|
||||
|
||||
impl LoopDetector {
|
||||
pub fn new(config: LoopDetectionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
history: Vec::new(),
|
||||
consecutive_failures: HashMap::new(),
|
||||
warning_injected: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed tool invocation.
|
||||
///
|
||||
/// * `tool_name` — canonical tool name (lowercased by caller).
|
||||
/// * `args_sig` — canonical JSON args string from `tool_call_signature()`.
|
||||
/// * `output` — raw tool output text.
|
||||
/// * `success` — whether the tool reported success.
|
||||
pub fn record_call(&mut self, tool_name: &str, args_sig: &str, output: &str, success: bool) {
|
||||
let result_hash = hash_output(output);
|
||||
self.history.push(CallRecord {
|
||||
tool_name: tool_name.to_owned(),
|
||||
args_sig: args_sig.to_owned(),
|
||||
result_hash,
|
||||
success,
|
||||
});
|
||||
|
||||
if success {
|
||||
self.consecutive_failures.remove(tool_name);
|
||||
} else {
|
||||
*self
|
||||
.consecutive_failures
|
||||
.entry(tool_name.to_owned())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the current history and return a verdict.
|
||||
pub fn check(&mut self) -> DetectionVerdict {
|
||||
let reason = self
|
||||
.check_no_progress_repeat()
|
||||
.or_else(|| self.check_ping_pong())
|
||||
.or_else(|| self.check_failure_streak());
|
||||
|
||||
match reason {
|
||||
None => DetectionVerdict::Continue,
|
||||
Some(msg) => {
|
||||
if self.warning_injected {
|
||||
DetectionVerdict::HardStop(msg)
|
||||
} else {
|
||||
self.warning_injected = true;
|
||||
DetectionVerdict::InjectWarning(format_warning(&msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Strategy 1: no-progress repeat ───────────────────────────────────
|
||||
|
||||
fn check_no_progress_repeat(&self) -> Option<String> {
|
||||
let threshold = self.config.no_progress_threshold;
|
||||
if threshold == 0 || self.history.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let last = self.history.last().unwrap();
|
||||
let streak = self
|
||||
.history
|
||||
.iter()
|
||||
.rev()
|
||||
.take_while(|r| {
|
||||
r.tool_name == last.tool_name
|
||||
&& r.args_sig == last.args_sig
|
||||
&& r.result_hash == last.result_hash
|
||||
})
|
||||
.count();
|
||||
if streak >= threshold {
|
||||
Some(format!(
|
||||
"Tool '{}' called {} times with identical arguments and identical results \
|
||||
— no progress detected",
|
||||
last.tool_name, streak
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ── Strategy 2: ping-pong ────────────────────────────────────────────
|
||||
|
||||
fn check_ping_pong(&self) -> Option<String> {
|
||||
let cycles = self.config.ping_pong_cycles;
|
||||
if cycles == 0 || self.history.len() < 4 {
|
||||
return None;
|
||||
}
|
||||
let len = self.history.len();
|
||||
let a = &self.history[len - 2];
|
||||
let b = &self.history[len - 1];
|
||||
|
||||
// The two sides of the ping-pong must differ.
|
||||
if a.tool_name == b.tool_name && a.args_sig == b.args_sig {
|
||||
return None;
|
||||
}
|
||||
|
||||
let min_entries = cycles * 2;
|
||||
if len < min_entries {
|
||||
return None;
|
||||
}
|
||||
let tail = &self.history[len - min_entries..];
|
||||
let is_ping_pong = tail.chunks(2).all(|pair| {
|
||||
pair.len() == 2
|
||||
&& pair[0].tool_name == a.tool_name
|
||||
&& pair[0].args_sig == a.args_sig
|
||||
&& pair[0].result_hash == a.result_hash
|
||||
&& pair[1].tool_name == b.tool_name
|
||||
&& pair[1].args_sig == b.args_sig
|
||||
&& pair[1].result_hash == b.result_hash
|
||||
});
|
||||
|
||||
if is_ping_pong {
|
||||
Some(format!(
|
||||
"Ping-pong loop detected: '{}' and '{}' alternating {} times with no progress",
|
||||
a.tool_name, b.tool_name, cycles
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ── Strategy 3: consecutive failure streak ───────────────────────────
|
||||
|
||||
fn check_failure_streak(&self) -> Option<String> {
|
||||
let threshold = self.config.failure_streak_threshold;
|
||||
if threshold == 0 {
|
||||
return None;
|
||||
}
|
||||
for (tool, count) in &self.consecutive_failures {
|
||||
if *count >= threshold {
|
||||
return Some(format!(
|
||||
"Tool '{}' failed {} consecutive times",
|
||||
tool, count
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
fn hash_output(output: &str) -> u64 {
|
||||
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
|
||||
&output[..OUTPUT_HASH_PREFIX_BYTES]
|
||||
} else {
|
||||
output
|
||||
};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
prefix.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
fn format_warning(reason: &str) -> String {
|
||||
format!(
|
||||
"IMPORTANT: A loop pattern has been detected in your tool usage. {reason}. \
|
||||
You must change your approach: \
|
||||
(1) Try a different tool or different arguments, \
|
||||
(2) If polling a process, increase wait time or check if it's stuck, \
|
||||
(3) If the task cannot be completed, explain why and stop. \
|
||||
Do NOT repeat the same tool call with the same arguments."
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Unit tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_config() -> LoopDetectionConfig {
|
||||
LoopDetectionConfig::default()
|
||||
}
|
||||
|
||||
fn disabled_config() -> LoopDetectionConfig {
|
||||
LoopDetectionConfig {
|
||||
no_progress_threshold: 0,
|
||||
ping_pong_cycles: 0,
|
||||
failure_streak_threshold: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Below threshold → Continue
|
||||
#[test]
|
||||
fn below_threshold_does_not_trigger() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 2. No-progress repeat triggers warning at threshold
|
||||
#[test]
|
||||
fn no_progress_repeat_triggers_warning() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
for _ in 0..3 {
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "hello", true);
|
||||
}
|
||||
match det.check() {
|
||||
DetectionVerdict::InjectWarning(msg) => {
|
||||
assert!(msg.contains("no progress"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected InjectWarning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Same input but different output → no trigger (progress detected)
|
||||
#[test]
|
||||
fn same_input_different_output_does_not_trigger() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "result_1", true);
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "result_2", true);
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "result_3", true);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 4. Warning then continued loop → HardStop
|
||||
#[test]
|
||||
fn warning_then_continued_loop_triggers_hard_stop() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
for _ in 0..3 {
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
|
||||
}
|
||||
assert!(matches!(det.check(), DetectionVerdict::InjectWarning(_)));
|
||||
// One more identical call
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
|
||||
match det.check() {
|
||||
DetectionVerdict::HardStop(msg) => {
|
||||
assert!(msg.contains("no progress"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected HardStop, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Ping-pong detection
|
||||
#[test]
|
||||
fn ping_pong_triggers_warning() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
// 2 cycles: A-B-A-B
|
||||
det.record_call("tool_a", r#"{"x":1}"#, "out_a", true);
|
||||
det.record_call("tool_b", r#"{"y":2}"#, "out_b", true);
|
||||
det.record_call("tool_a", r#"{"x":1}"#, "out_a", true);
|
||||
det.record_call("tool_b", r#"{"y":2}"#, "out_b", true);
|
||||
match det.check() {
|
||||
DetectionVerdict::InjectWarning(msg) => {
|
||||
assert!(msg.contains("Ping-pong"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected InjectWarning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Ping-pong with progress does not trigger
|
||||
#[test]
|
||||
fn ping_pong_with_progress_does_not_trigger() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("tool_a", r#"{"x":1}"#, "out_a_1", true);
|
||||
det.record_call("tool_b", r#"{"y":2}"#, "out_b_1", true);
|
||||
det.record_call("tool_a", r#"{"x":1}"#, "out_a_2", true); // different output
|
||||
det.record_call("tool_b", r#"{"y":2}"#, "out_b_2", true); // different output
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 7. Consecutive failure streak (different args each time to avoid no-progress trigger)
|
||||
#[test]
|
||||
fn failure_streak_triggers_warning() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("shell", r#"{"cmd":"bad1"}"#, "error: not found 1", false);
|
||||
det.record_call("shell", r#"{"cmd":"bad2"}"#, "error: not found 2", false);
|
||||
det.record_call("shell", r#"{"cmd":"bad3"}"#, "error: not found 3", false);
|
||||
match det.check() {
|
||||
DetectionVerdict::InjectWarning(msg) => {
|
||||
assert!(msg.contains("failed 3 consecutive"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected InjectWarning, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Failure streak resets on success
|
||||
#[test]
|
||||
fn failure_streak_resets_on_success() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
|
||||
det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
|
||||
det.record_call("shell", r#"{"cmd":"good"}"#, "ok", true); // resets
|
||||
det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
|
||||
det.record_call("shell", r#"{"cmd":"bad"}"#, "err", false);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 9. All thresholds zero → disabled
|
||||
#[test]
|
||||
fn all_disabled_never_triggers() {
|
||||
let mut det = LoopDetector::new(disabled_config());
|
||||
for _ in 0..20 {
|
||||
det.record_call("echo", r#"{"msg":"hi"}"#, "same", true);
|
||||
}
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 10. Mixed tools → no false positive
|
||||
#[test]
|
||||
fn mixed_tools_no_false_positive() {
|
||||
let mut det = LoopDetector::new(default_config());
|
||||
det.record_call("file_read", r#"{"path":"a.rs"}"#, "content_a", true);
|
||||
det.record_call("shell", r#"{"cmd":"ls"}"#, "file_list", true);
|
||||
det.record_call("memory_store", r#"{"key":"x"}"#, "stored", true);
|
||||
det.record_call("file_read", r#"{"path":"b.rs"}"#, "content_b", true);
|
||||
det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
}
|
||||
+76
-116
@@ -999,6 +999,7 @@ fn runtime_perplexity_filter_snapshot(
|
||||
return state.perplexity_filter.clone();
|
||||
}
|
||||
}
|
||||
|
||||
crate::config::PerplexityFilterConfig::default()
|
||||
}
|
||||
|
||||
@@ -2168,54 +2169,6 @@ async fn handle_runtime_command_if_needed(
|
||||
)
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::ApprovePendingRequest(raw_request_id) => {
|
||||
let request_id = raw_request_id.trim().to_string();
|
||||
if request_id.is_empty() {
|
||||
"Usage: `/approve-allow <request-id>`".to_string()
|
||||
} else {
|
||||
match ctx.approval_manager.confirm_non_cli_pending_request(
|
||||
&request_id,
|
||||
sender,
|
||||
source_channel,
|
||||
reply_target,
|
||||
) {
|
||||
Ok(req) => {
|
||||
ctx.approval_manager
|
||||
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::Yes);
|
||||
runtime_trace::record_event(
|
||||
"approval_request_allowed",
|
||||
Some(source_channel),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(true),
|
||||
Some("pending request allowed for current tool invocation"),
|
||||
serde_json::json!({
|
||||
"request_id": request_id,
|
||||
"tool_name": req.tool_name,
|
||||
"sender": sender,
|
||||
"channel": source_channel,
|
||||
}),
|
||||
);
|
||||
format!(
|
||||
"Approved pending request `{}` for this invocation of `{}`.",
|
||||
req.request_id, req.tool_name
|
||||
)
|
||||
}
|
||||
Err(PendingApprovalError::NotFound) => {
|
||||
format!("Pending approval request `{request_id}` was not found.")
|
||||
}
|
||||
Err(PendingApprovalError::Expired) => {
|
||||
format!("Pending approval request `{request_id}` has expired.")
|
||||
}
|
||||
Err(PendingApprovalError::RequesterMismatch) => {
|
||||
format!(
|
||||
"Pending approval request `{request_id}` can only be approved by the same sender in the same chat/channel that created it."
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::ConfirmToolApproval(raw_request_id) => {
|
||||
let request_id = raw_request_id.trim().to_string();
|
||||
if request_id.is_empty() {
|
||||
@@ -2228,8 +2181,6 @@ async fn handle_runtime_command_if_needed(
|
||||
reply_target,
|
||||
) {
|
||||
Ok(req) => {
|
||||
ctx.approval_manager
|
||||
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::Yes);
|
||||
let tool_name = req.tool_name;
|
||||
let mut approval_message = if tool_name == APPROVAL_ALL_TOOLS_ONCE_TOKEN {
|
||||
let remaining = ctx.approval_manager.grant_non_cli_allow_all_once();
|
||||
@@ -2336,10 +2287,54 @@ async fn handle_runtime_command_if_needed(
|
||||
}
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::ApprovePendingRequest(raw_request_id) => {
|
||||
let request_id = raw_request_id.trim().to_string();
|
||||
if request_id.is_empty() {
|
||||
"Usage: `/approve-allow <request-id>`".to_string()
|
||||
} else if !ctx
|
||||
.approval_manager
|
||||
.is_non_cli_approval_actor_allowed(source_channel, sender)
|
||||
{
|
||||
"You are not allowed to approve pending non-CLI tool requests.".to_string()
|
||||
} else {
|
||||
match ctx.approval_manager.confirm_non_cli_pending_request(
|
||||
&request_id,
|
||||
sender,
|
||||
source_channel,
|
||||
reply_target,
|
||||
) {
|
||||
Ok(req) => {
|
||||
ctx.approval_manager.record_non_cli_pending_resolution(
|
||||
&request_id,
|
||||
ApprovalResponse::Yes,
|
||||
);
|
||||
format!(
|
||||
"Approved pending request `{}` for `{}`.",
|
||||
request_id,
|
||||
approval_target_label(&req.tool_name)
|
||||
)
|
||||
}
|
||||
Err(PendingApprovalError::NotFound) => {
|
||||
format!("Pending approval request `{request_id}` was not found.")
|
||||
}
|
||||
Err(PendingApprovalError::Expired) => {
|
||||
format!("Pending approval request `{request_id}` has expired.")
|
||||
}
|
||||
Err(PendingApprovalError::RequesterMismatch) => format!(
|
||||
"Pending approval request `{request_id}` can only be approved by the same sender in the same chat/channel that created it."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::DenyToolApproval(raw_request_id) => {
|
||||
let request_id = raw_request_id.trim().to_string();
|
||||
if request_id.is_empty() {
|
||||
"Usage: `/approve-deny <request-id>`".to_string()
|
||||
} else if !ctx
|
||||
.approval_manager
|
||||
.is_non_cli_approval_actor_allowed(source_channel, sender)
|
||||
{
|
||||
"You are not allowed to deny pending non-CLI tool requests.".to_string()
|
||||
} else {
|
||||
match ctx.approval_manager.reject_non_cli_pending_request(
|
||||
&request_id,
|
||||
@@ -2348,81 +2343,25 @@ async fn handle_runtime_command_if_needed(
|
||||
reply_target,
|
||||
) {
|
||||
Ok(req) => {
|
||||
ctx.approval_manager
|
||||
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::No);
|
||||
runtime_trace::record_event(
|
||||
"approval_request_denied",
|
||||
Some(source_channel),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(true),
|
||||
Some("pending request denied"),
|
||||
serde_json::json!({
|
||||
"request_id": request_id,
|
||||
"tool_name": req.tool_name,
|
||||
"sender": sender,
|
||||
"channel": source_channel,
|
||||
}),
|
||||
ctx.approval_manager.record_non_cli_pending_resolution(
|
||||
&request_id,
|
||||
ApprovalResponse::No,
|
||||
);
|
||||
format!(
|
||||
"Denied pending approval request `{}` for tool `{}`.",
|
||||
req.request_id, req.tool_name
|
||||
"Denied pending request `{}` for `{}`.",
|
||||
request_id,
|
||||
approval_target_label(&req.tool_name)
|
||||
)
|
||||
}
|
||||
Err(PendingApprovalError::NotFound) => {
|
||||
runtime_trace::record_event(
|
||||
"approval_request_denied",
|
||||
Some(source_channel),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(false),
|
||||
Some("pending request not found"),
|
||||
serde_json::json!({
|
||||
"request_id": request_id,
|
||||
"sender": sender,
|
||||
"channel": source_channel,
|
||||
}),
|
||||
);
|
||||
format!("Pending approval request `{request_id}` was not found.")
|
||||
}
|
||||
Err(PendingApprovalError::Expired) => {
|
||||
runtime_trace::record_event(
|
||||
"approval_request_denied",
|
||||
Some(source_channel),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(false),
|
||||
Some("pending request expired"),
|
||||
serde_json::json!({
|
||||
"request_id": request_id,
|
||||
"sender": sender,
|
||||
"channel": source_channel,
|
||||
}),
|
||||
);
|
||||
format!("Pending approval request `{request_id}` has expired.")
|
||||
}
|
||||
Err(PendingApprovalError::RequesterMismatch) => {
|
||||
runtime_trace::record_event(
|
||||
"approval_request_denied",
|
||||
Some(source_channel),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(false),
|
||||
Some("pending request denier mismatch"),
|
||||
serde_json::json!({
|
||||
"request_id": request_id,
|
||||
"sender": sender,
|
||||
"channel": source_channel,
|
||||
}),
|
||||
);
|
||||
format!(
|
||||
"Pending approval request `{request_id}` can only be denied by the same sender in the same chat/channel that created it."
|
||||
)
|
||||
}
|
||||
Err(PendingApprovalError::RequesterMismatch) => format!(
|
||||
"Pending approval request `{request_id}` can only be denied by the same sender in the same chat/channel that created it."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3909,7 +3848,7 @@ async fn run_message_dispatch_loop(
|
||||
}
|
||||
}
|
||||
|
||||
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||||
Box::pin(process_channel_message(worker_ctx, msg, cancellation_token)).await;
|
||||
|
||||
if interrupt_enabled {
|
||||
let mut active = in_flight.lock().await;
|
||||
@@ -4894,6 +4833,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
provider_transport: config.effective_provider_transport(),
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
@@ -5301,6 +5241,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::large_futures)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
@@ -10411,6 +10352,25 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
.any(|entry| entry.channel.name() == "mattermost"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_configured_channels_includes_dingtalk_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.dingtalk = Some(crate::config::schema::DingTalkConfig {
|
||||
client_id: "ding-app-key".to_string(),
|
||||
client_secret: "ding-app-secret".to_string(),
|
||||
allowed_users: vec!["*".to_string()],
|
||||
});
|
||||
|
||||
let channels = collect_configured_channels(&config, "test");
|
||||
|
||||
assert!(channels
|
||||
.iter()
|
||||
.any(|entry| entry.display_name == "DingTalk"));
|
||||
assert!(channels
|
||||
.iter()
|
||||
.any(|entry| entry.channel.name() == "dingtalk"));
|
||||
}
|
||||
|
||||
struct AlwaysFailChannel {
|
||||
name: &'static str,
|
||||
calls: Arc<AtomicUsize>,
|
||||
|
||||
+15
-14
@@ -8,20 +8,21 @@ pub use schema::{
|
||||
AgentConfig, AgentSessionBackend, AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig,
|
||||
AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig,
|
||||
ChannelsConfig, ClassificationRule, ComposioConfig, Config, CoordinationConfig, CostConfig,
|
||||
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig,
|
||||
EstopConfig, FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
|
||||
HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig,
|
||||
IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NonCliNaturalLanguageApprovalMode, ObservabilityConfig,
|
||||
OtpChallengeDelivery, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig,
|
||||
PerplexityFilterConfig, PluginEntryConfig, PluginsConfig, ProviderConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig,
|
||||
ResearchTrigger, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig,
|
||||
TunnelConfig, UrlAccessConfig, WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy,
|
||||
WasmRuntimeConfig, WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EconomicConfig,
|
||||
EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GroupReplyConfig, GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig,
|
||||
HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
|
||||
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig,
|
||||
OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PerplexityFilterConfig, PluginEntryConfig,
|
||||
PluginsConfig, ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger,
|
||||
ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig,
|
||||
SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode,
|
||||
SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode,
|
||||
SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig,
|
||||
WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy, WasmRuntimeConfig,
|
||||
WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
@@ -237,6 +237,11 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub cost: CostConfig,
|
||||
|
||||
/// Economic agent survival tracking (`[economic]`).
|
||||
/// Tracks balance, token costs, work income, and survival status.
|
||||
#[serde(default)]
|
||||
pub economic: EconomicConfig,
|
||||
|
||||
/// Peripheral board configuration for hardware integration (`[peripherals]`).
|
||||
#[serde(default)]
|
||||
pub peripherals: PeripheralsConfig,
|
||||
@@ -309,6 +314,20 @@ pub struct ProviderConfig {
|
||||
/// (e.g. OpenAI Codex `/responses` reasoning effort).
|
||||
#[serde(default)]
|
||||
pub reasoning_level: Option<String>,
|
||||
/// Optional transport override for providers that support multiple transports.
|
||||
/// Supported values: "auto", "websocket", "sse".
|
||||
///
|
||||
/// Resolution order:
|
||||
/// 1) `model_routes[].transport` (route-specific)
|
||||
/// 2) env overrides (`PROVIDER_TRANSPORT`, `ZEROCLAW_PROVIDER_TRANSPORT`, `ZEROCLAW_CODEX_TRANSPORT`)
|
||||
/// 3) `provider.transport`
|
||||
/// 4) runtime default (`auto`, WebSocket-first with SSE fallback for OpenAI Codex)
|
||||
///
|
||||
/// Note: env overrides replace configured `provider.transport` when set.
|
||||
///
|
||||
/// Existing configs that omit `provider.transport` remain valid and fall back to defaults.
|
||||
#[serde(default)]
|
||||
pub transport: Option<String>,
|
||||
}
|
||||
|
||||
// ── Delegate Agents ──────────────────────────────────────────────
|
||||
@@ -716,6 +735,21 @@ pub struct AgentConfig {
|
||||
/// Tool dispatch strategy (e.g. `"auto"`). Default: `"auto"`.
|
||||
#[serde(default = "default_agent_tool_dispatcher")]
|
||||
pub tool_dispatcher: String,
|
||||
/// Loop detection: no-progress repeat threshold.
|
||||
/// Triggers when the same tool+args produces identical output this many times.
|
||||
/// Set to `0` to disable. Default: `3`.
|
||||
#[serde(default = "default_loop_detection_no_progress_threshold")]
|
||||
pub loop_detection_no_progress_threshold: usize,
|
||||
/// Loop detection: ping-pong cycle threshold.
|
||||
/// Detects A→B→A→B alternating patterns with no progress.
|
||||
/// Value is number of full cycles (A-B = 1 cycle). Set to `0` to disable. Default: `2`.
|
||||
#[serde(default = "default_loop_detection_ping_pong_cycles")]
|
||||
pub loop_detection_ping_pong_cycles: usize,
|
||||
/// Loop detection: consecutive failure streak threshold.
|
||||
/// Triggers when the same tool fails this many times in a row.
|
||||
/// Set to `0` to disable. Default: `3`.
|
||||
#[serde(default = "default_loop_detection_failure_streak")]
|
||||
pub loop_detection_failure_streak: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -787,6 +821,18 @@ fn default_agent_session_max_messages() -> usize {
|
||||
default_agent_max_history_messages()
|
||||
}
|
||||
|
||||
fn default_loop_detection_no_progress_threshold() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn default_loop_detection_ping_pong_cycles() -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn default_loop_detection_failure_streak() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -796,6 +842,9 @@ impl Default for AgentConfig {
|
||||
max_history_messages: default_agent_max_history_messages(),
|
||||
parallel_tools: false,
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
loop_detection_no_progress_threshold: default_loop_detection_no_progress_threshold(),
|
||||
loop_detection_ping_pong_cycles: default_loop_detection_ping_pong_cycles(),
|
||||
loop_detection_failure_streak: default_loop_detection_failure_streak(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1160,6 +1209,83 @@ pub struct PeripheralBoardConfig {
|
||||
pub baud: u32,
|
||||
}
|
||||
|
||||
// ── Economic Agent Config ─────────────────────────────────────────
|
||||
|
||||
/// Token pricing configuration for economic tracking.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct EconomicTokenPricing {
|
||||
/// Price per million input tokens (USD)
|
||||
#[serde(default = "default_input_price")]
|
||||
pub input_price_per_million: f64,
|
||||
/// Price per million output tokens (USD)
|
||||
#[serde(default = "default_output_price")]
|
||||
pub output_price_per_million: f64,
|
||||
}
|
||||
|
||||
fn default_input_price() -> f64 {
|
||||
3.0 // Claude Sonnet 4 input price
|
||||
}
|
||||
|
||||
fn default_output_price() -> f64 {
|
||||
15.0 // Claude Sonnet 4 output price
|
||||
}
|
||||
|
||||
impl Default for EconomicTokenPricing {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input_price_per_million: default_input_price(),
|
||||
output_price_per_million: default_output_price(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Economic agent survival tracking configuration (`[economic]` section).
|
||||
///
|
||||
/// Implements the ClawWork economic model for AI agents, tracking
|
||||
/// balance, costs, income, and survival status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct EconomicConfig {
|
||||
/// Enable economic tracking (default: false)
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Starting balance in USD (default: 1000.0)
|
||||
#[serde(default = "default_initial_balance")]
|
||||
pub initial_balance: f64,
|
||||
|
||||
/// Token pricing configuration
|
||||
#[serde(default)]
|
||||
pub token_pricing: EconomicTokenPricing,
|
||||
|
||||
/// Minimum evaluation score (0.0-1.0) to receive payment (default: 0.6)
|
||||
#[serde(default = "default_min_evaluation_threshold")]
|
||||
pub min_evaluation_threshold: f64,
|
||||
|
||||
/// Data directory for economic state persistence (relative to workspace)
|
||||
#[serde(default)]
|
||||
pub data_path: Option<String>,
|
||||
}
|
||||
|
||||
fn default_initial_balance() -> f64 {
|
||||
1000.0
|
||||
}
|
||||
|
||||
fn default_min_evaluation_threshold() -> f64 {
|
||||
0.6
|
||||
}
|
||||
|
||||
impl Default for EconomicConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
initial_balance: default_initial_balance(),
|
||||
token_pricing: EconomicTokenPricing::default(),
|
||||
min_evaluation_threshold: default_min_evaluation_threshold(),
|
||||
data_path: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_peripheral_transport() -> String {
|
||||
"serial".into()
|
||||
}
|
||||
@@ -3266,6 +3392,14 @@ pub struct ModelRouteConfig {
|
||||
/// Optional API key override for this route's provider
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Optional route-specific transport override for this route.
|
||||
/// Supported values: "auto", "websocket", "sse".
|
||||
///
|
||||
/// When `model_routes[].transport` is unset, the route inherits `provider.transport`.
|
||||
/// If both are unset, runtime defaults are used (`auto` for OpenAI Codex).
|
||||
/// Existing configs without this field remain valid.
|
||||
#[serde(default)]
|
||||
pub transport: Option<String>,
|
||||
}
|
||||
|
||||
// ── Embedding routing ───────────────────────────────────────────
|
||||
@@ -5135,6 +5269,7 @@ impl Default for Config {
|
||||
proxy: ProxyConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
economic: EconomicConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
coordination: CoordinationConfig::default(),
|
||||
@@ -6115,6 +6250,28 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_provider_transport(raw: Option<&str>, source: &str) -> Option<String> {
|
||||
let value = raw?.trim();
|
||||
if value.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let normalized = value.to_ascii_lowercase().replace(['-', '_'], "");
|
||||
match normalized.as_str() {
|
||||
"auto" => Some("auto".to_string()),
|
||||
"websocket" | "ws" => Some("websocket".to_string()),
|
||||
"sse" | "http" => Some("sse".to_string()),
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
transport = %value,
|
||||
source,
|
||||
"Ignoring invalid provider transport override"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve provider reasoning level with backward-compatible runtime alias.
|
||||
///
|
||||
/// Priority:
|
||||
@@ -6158,6 +6315,16 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve provider transport mode (`provider.transport`).
|
||||
///
|
||||
/// Supported values:
|
||||
/// - `auto`
|
||||
/// - `websocket`
|
||||
/// - `sse`
|
||||
pub fn effective_provider_transport(&self) -> Option<String> {
|
||||
Self::normalize_provider_transport(self.provider.transport.as_deref(), "provider.transport")
|
||||
}
|
||||
|
||||
fn lookup_model_provider_profile(
|
||||
&self,
|
||||
provider_name: &str,
|
||||
@@ -6521,6 +6688,32 @@ impl Config {
|
||||
if route.max_tokens == Some(0) {
|
||||
anyhow::bail!("model_routes[{i}].max_tokens must be greater than 0");
|
||||
}
|
||||
if route
|
||||
.transport
|
||||
.as_deref()
|
||||
.is_some_and(|value| !value.trim().is_empty())
|
||||
&& Self::normalize_provider_transport(
|
||||
route.transport.as_deref(),
|
||||
"model_routes[].transport",
|
||||
)
|
||||
.is_none()
|
||||
{
|
||||
anyhow::bail!("model_routes[{i}].transport must be one of: auto, websocket, sse");
|
||||
}
|
||||
}
|
||||
|
||||
if self
|
||||
.provider
|
||||
.transport
|
||||
.as_deref()
|
||||
.is_some_and(|value| !value.trim().is_empty())
|
||||
&& Self::normalize_provider_transport(
|
||||
self.provider.transport.as_deref(),
|
||||
"provider.transport",
|
||||
)
|
||||
.is_none()
|
||||
{
|
||||
anyhow::bail!("provider.transport must be one of: auto, websocket, sse");
|
||||
}
|
||||
|
||||
if self.provider_api.is_some()
|
||||
@@ -6852,6 +7045,17 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Provider transport override: ZEROCLAW_PROVIDER_TRANSPORT or PROVIDER_TRANSPORT
|
||||
if let Ok(transport) = std::env::var("ZEROCLAW_PROVIDER_TRANSPORT")
|
||||
.or_else(|_| std::env::var("PROVIDER_TRANSPORT"))
|
||||
{
|
||||
if let Some(normalized) =
|
||||
Self::normalize_provider_transport(Some(&transport), "env:provider_transport")
|
||||
{
|
||||
self.provider.transport = Some(normalized);
|
||||
}
|
||||
}
|
||||
|
||||
// Vision support override: ZEROCLAW_MODEL_SUPPORT_VISION or MODEL_SUPPORT_VISION
|
||||
if let Ok(flag) = std::env::var("ZEROCLAW_MODEL_SUPPORT_VISION")
|
||||
.or_else(|_| std::env::var("MODEL_SUPPORT_VISION"))
|
||||
@@ -7700,6 +7904,7 @@ default_temperature = 0.7
|
||||
agent: AgentConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
economic: EconomicConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
@@ -8074,6 +8279,7 @@ tool_dispatcher = "xml"
|
||||
agent: AgentConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
economic: EconomicConfig::default(),
|
||||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
hooks: HooksConfig::default(),
|
||||
@@ -9453,6 +9659,7 @@ provider_api = "not-a-real-mode"
|
||||
model: "anthropic/claude-sonnet-4.6".to_string(),
|
||||
max_tokens: Some(0),
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let err = config
|
||||
@@ -9463,6 +9670,48 @@ provider_api = "not-a-real-mode"
|
||||
.contains("model_routes[0].max_tokens must be greater than 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_transport_normalizes_aliases() {
|
||||
let mut config = Config::default();
|
||||
config.provider.transport = Some("WS".to_string());
|
||||
assert_eq!(
|
||||
config.effective_provider_transport().as_deref(),
|
||||
Some("websocket")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_transport_invalid_is_rejected() {
|
||||
let mut config = Config::default();
|
||||
config.provider.transport = Some("udp".to_string());
|
||||
let err = config
|
||||
.validate()
|
||||
.expect_err("provider.transport should reject invalid values");
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("provider.transport must be one of: auto, websocket, sse"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn model_route_transport_invalid_is_rejected() {
|
||||
let mut config = Config::default();
|
||||
config.model_routes = vec![ModelRouteConfig {
|
||||
hint: "reasoning".to_string(),
|
||||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4.6".to_string(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: Some("udp".to_string()),
|
||||
}];
|
||||
|
||||
let err = config
|
||||
.validate()
|
||||
.expect_err("model_routes[].transport should reject invalid values");
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("model_routes[0].transport must be one of: auto, websocket, sse"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_glm_api_key_for_regional_aliases() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
@@ -10103,6 +10352,60 @@ default_model = "legacy-model"
|
||||
std::env::remove_var("ZEROCLAW_REASONING_LEVEL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_provider_transport_normalizes_zeroclaw_alias() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config::default();
|
||||
|
||||
std::env::remove_var("PROVIDER_TRANSPORT");
|
||||
std::env::set_var("ZEROCLAW_PROVIDER_TRANSPORT", "WS");
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.provider.transport.as_deref(), Some("websocket"));
|
||||
|
||||
std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_provider_transport_normalizes_legacy_alias() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config::default();
|
||||
|
||||
std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
|
||||
std::env::set_var("PROVIDER_TRANSPORT", "HTTP");
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.provider.transport.as_deref(), Some("sse"));
|
||||
|
||||
std::env::remove_var("PROVIDER_TRANSPORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_provider_transport_invalid_zeroclaw_does_not_override_existing() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config::default();
|
||||
config.provider.transport = Some("sse".to_string());
|
||||
|
||||
std::env::remove_var("PROVIDER_TRANSPORT");
|
||||
std::env::set_var("ZEROCLAW_PROVIDER_TRANSPORT", "udp");
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.provider.transport.as_deref(), Some("sse"));
|
||||
|
||||
std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_provider_transport_invalid_legacy_does_not_override_existing() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
let mut config = Config::default();
|
||||
config.provider.transport = Some("auto".to_string());
|
||||
|
||||
std::env::remove_var("ZEROCLAW_PROVIDER_TRANSPORT");
|
||||
std::env::set_var("PROVIDER_TRANSPORT", "udp");
|
||||
config.apply_env_overrides();
|
||||
assert_eq!(config.provider.transport.as_deref(), Some("auto"));
|
||||
|
||||
std::env::remove_var("PROVIDER_TRANSPORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn env_override_model_support_vision() {
|
||||
let _env_guard = env_override_lock().await;
|
||||
@@ -10597,6 +10900,46 @@ default_model = "legacy-model"
|
||||
assert_eq!(parsed.allowed_users, vec!["*"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn dingtalk_config_defaults_allowed_users_to_empty() {
|
||||
let json = r#"{"client_id":"ding-app-key","client_secret":"ding-app-secret"}"#;
|
||||
let parsed: DingTalkConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(parsed.client_id, "ding-app-key");
|
||||
assert_eq!(parsed.client_secret, "ding-app-secret");
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn dingtalk_config_toml_roundtrip() {
|
||||
let dc = DingTalkConfig {
|
||||
client_id: "ding-app-key".into(),
|
||||
client_secret: "ding-app-secret".into(),
|
||||
allowed_users: vec!["*".into(), "staff123".into()],
|
||||
};
|
||||
let toml_str = toml::to_string(&dc).unwrap();
|
||||
let parsed: DingTalkConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(parsed.client_id, "ding-app-key");
|
||||
assert_eq!(parsed.client_secret, "ding-app-secret");
|
||||
assert_eq!(parsed.allowed_users, vec!["*", "staff123"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn channels_except_webhook_reports_dingtalk_as_enabled() {
|
||||
let mut channels = ChannelsConfig::default();
|
||||
channels.dingtalk = Some(DingTalkConfig {
|
||||
client_id: "ding-app-key".into(),
|
||||
client_secret: "ding-app-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
});
|
||||
|
||||
let dingtalk_state = channels
|
||||
.channels_except_webhook()
|
||||
.iter()
|
||||
.find_map(|(handle, enabled)| (handle.name() == "DingTalk").then_some(*enabled));
|
||||
|
||||
assert_eq!(dingtalk_state, Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn nextcloud_talk_config_serde() {
|
||||
let nc = NextcloudTalkConfig {
|
||||
|
||||
@@ -22,6 +22,10 @@ const MIN_POLL_SECONDS: u64 = 5;
|
||||
const SHELL_JOB_TIMEOUT_SECS: u64 = 120;
|
||||
const SCHEDULER_COMPONENT: &str = "scheduler";
|
||||
|
||||
pub(crate) fn is_no_reply_sentinel(output: &str) -> bool {
|
||||
output.trim().eq_ignore_ascii_case("NO_REPLY")
|
||||
}
|
||||
|
||||
pub async fn run(config: Config) -> Result<()> {
|
||||
let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS);
|
||||
let mut interval = time::interval(Duration::from_secs(poll_secs));
|
||||
@@ -292,6 +296,13 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
||||
if !delivery.mode.eq_ignore_ascii_case("announce") {
|
||||
return Ok(());
|
||||
}
|
||||
if is_no_reply_sentinel(output) {
|
||||
tracing::debug!(
|
||||
"Cron job '{}' returned NO_REPLY sentinel; skipping announce delivery",
|
||||
job.id
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let channel = delivery
|
||||
.channel
|
||||
@@ -1136,6 +1147,31 @@ mod tests {
|
||||
assert!(err.to_string().contains("unsupported delivery channel"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deliver_if_configured_skips_no_reply_sentinel() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp).await;
|
||||
let mut job = test_job("echo ok");
|
||||
job.delivery = DeliveryConfig {
|
||||
mode: "announce".into(),
|
||||
channel: Some("invalid".into()),
|
||||
to: Some("target".into()),
|
||||
best_effort: true,
|
||||
};
|
||||
|
||||
assert!(deliver_if_configured(&config, &job, " no_reply ")
|
||||
.await
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_reply_sentinel_matching_is_trimmed_and_case_insensitive() {
|
||||
assert!(is_no_reply_sentinel("NO_REPLY"));
|
||||
assert!(is_no_reply_sentinel(" no_reply "));
|
||||
assert!(!is_no_reply_sentinel("NO_REPLY please"));
|
||||
assert!(!is_no_reply_sentinel(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deliver_if_configured_whatsapp_web_requires_live_session_in_web_mode() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
+49
-19
@@ -227,26 +227,25 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
{
|
||||
Ok(output) => {
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
let announcement = if output.trim().is_empty() {
|
||||
"heartbeat task executed".to_string()
|
||||
} else {
|
||||
output
|
||||
};
|
||||
if let Some((channel, target)) = &delivery {
|
||||
if let Err(e) = crate::cron::scheduler::deliver_announcement(
|
||||
&config,
|
||||
channel,
|
||||
target,
|
||||
&announcement,
|
||||
)
|
||||
.await
|
||||
{
|
||||
crate::health::mark_component_error(
|
||||
"heartbeat",
|
||||
format!("delivery failed: {e}"),
|
||||
);
|
||||
tracing::warn!("Heartbeat delivery failed: {e}");
|
||||
if let Some(announcement) = heartbeat_announcement_text(&output) {
|
||||
if let Some((channel, target)) = &delivery {
|
||||
if let Err(e) = crate::cron::scheduler::deliver_announcement(
|
||||
&config,
|
||||
channel,
|
||||
target,
|
||||
&announcement,
|
||||
)
|
||||
.await
|
||||
{
|
||||
crate::health::mark_component_error(
|
||||
"heartbeat",
|
||||
format!("delivery failed: {e}"),
|
||||
);
|
||||
tracing::warn!("Heartbeat delivery failed: {e}");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("Heartbeat returned NO_REPLY sentinel; skipping delivery");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -258,6 +257,16 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn heartbeat_announcement_text(output: &str) -> Option<String> {
|
||||
if crate::cron::scheduler::is_no_reply_sentinel(output) {
|
||||
return None;
|
||||
}
|
||||
if output.trim().is_empty() {
|
||||
return Some("heartbeat task executed".to_string());
|
||||
}
|
||||
Some(output.to_string())
|
||||
}
|
||||
|
||||
fn heartbeat_tasks_for_tick(
|
||||
file_tasks: Vec<String>,
|
||||
fallback_message: Option<&str>,
|
||||
@@ -553,6 +562,27 @@ mod tests {
|
||||
assert!(tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_announcement_text_skips_no_reply_sentinel() {
|
||||
assert!(heartbeat_announcement_text(" NO_reply ").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_announcement_text_uses_default_for_empty_output() {
|
||||
assert_eq!(
|
||||
heartbeat_announcement_text(" \n\t "),
|
||||
Some("heartbeat task executed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_announcement_text_keeps_regular_output() {
|
||||
assert_eq!(
|
||||
heartbeat_announcement_text("system nominal"),
|
||||
Some("system nominal".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_none_when_unset() {
|
||||
let config = Config::default();
|
||||
|
||||
@@ -1167,6 +1167,7 @@ mod tests {
|
||||
model: String::new(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
let mut items = Vec::new();
|
||||
check_config_semantics(&config, &mut items);
|
||||
|
||||
@@ -0,0 +1,874 @@
|
||||
//! Task Classifier for ZeroClaw Economic Agents
|
||||
//!
|
||||
//! Classifies work instructions into 44 BLS occupations with wage data
|
||||
//! to estimate task value for agent economics.
|
||||
//!
|
||||
//! ## Overview
|
||||
//!
|
||||
//! The classifier matches task instructions to standardized occupation
|
||||
//! categories using keyword matching and heuristics, then calculates
|
||||
//! expected payment based on BLS hourly wage data.
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use zeroclaw::economic::classifier::{TaskClassifier, OccupationCategory};
|
||||
//!
|
||||
//! let classifier = TaskClassifier::new();
|
||||
//! let result = classifier.classify("Write a REST API in Rust").await?;
|
||||
//!
|
||||
//! println!("Occupation: {}", result.occupation);
|
||||
//! println!("Hourly wage: ${:.2}", result.hourly_wage);
|
||||
//! println!("Estimated hours: {:.2}", result.estimated_hours);
|
||||
//! println!("Max payment: ${:.2}", result.max_payment);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Occupation category groupings based on BLS major groups
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum OccupationCategory {
|
||||
/// Software, IT, engineering roles
|
||||
TechnologyEngineering,
|
||||
/// Finance, accounting, management, sales
|
||||
BusinessFinance,
|
||||
/// Medical, nursing, social work
|
||||
HealthcareSocialServices,
|
||||
/// Legal, media, operations, other professional
|
||||
LegalMediaOperations,
|
||||
}
|
||||
|
||||
impl OccupationCategory {
|
||||
/// Returns a human-readable name for the category
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::TechnologyEngineering => "Technology & Engineering",
|
||||
Self::BusinessFinance => "Business & Finance",
|
||||
Self::HealthcareSocialServices => "Healthcare & Social Services",
|
||||
Self::LegalMediaOperations => "Legal, Media & Operations",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single occupation with BLS wage data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Occupation {
|
||||
/// Official BLS occupation name
|
||||
pub name: String,
|
||||
/// Hourly wage in USD (BLS median)
|
||||
pub hourly_wage: f64,
|
||||
/// Category grouping
|
||||
pub category: OccupationCategory,
|
||||
/// Keywords for matching
|
||||
#[serde(skip)]
|
||||
pub keywords: Vec<&'static str>,
|
||||
}
|
||||
|
||||
/// Result of classifying a task instruction
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationResult {
|
||||
/// Matched occupation name
|
||||
pub occupation: String,
|
||||
/// BLS hourly wage for this occupation
|
||||
pub hourly_wage: f64,
|
||||
/// Estimated hours to complete task
|
||||
pub estimated_hours: f64,
|
||||
/// Maximum payment (hours × wage)
|
||||
pub max_payment: f64,
|
||||
/// Classification confidence (0.0 - 1.0)
|
||||
pub confidence: f64,
|
||||
/// Category of the matched occupation
|
||||
pub category: OccupationCategory,
|
||||
/// Brief reasoning for the classification
|
||||
pub reasoning: String,
|
||||
}
|
||||
|
||||
/// Task classifier that maps instructions to BLS occupations
|
||||
#[derive(Debug)]
|
||||
pub struct TaskClassifier {
|
||||
occupations: Vec<Occupation>,
|
||||
keyword_index: HashMap<&'static str, Vec<usize>>,
|
||||
fallback_occupation: String,
|
||||
fallback_wage: f64,
|
||||
}
|
||||
|
||||
impl Default for TaskClassifier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskClassifier {
|
||||
/// Create a new TaskClassifier with embedded BLS occupation data
|
||||
pub fn new() -> Self {
|
||||
let occupations = Self::load_occupations();
|
||||
let keyword_index = Self::build_keyword_index(&occupations);
|
||||
|
||||
Self {
|
||||
occupations,
|
||||
keyword_index,
|
||||
fallback_occupation: "General and Operations Managers".to_string(),
|
||||
fallback_wage: 64.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all 44 BLS occupations with wage data
|
||||
fn load_occupations() -> Vec<Occupation> {
|
||||
use OccupationCategory::{
|
||||
BusinessFinance, HealthcareSocialServices, LegalMediaOperations, TechnologyEngineering,
|
||||
};
|
||||
|
||||
vec![
|
||||
// Technology & Engineering
|
||||
Occupation {
|
||||
name: "Software Developers".into(),
|
||||
hourly_wage: 69.50,
|
||||
category: TechnologyEngineering,
|
||||
keywords: vec![
|
||||
"software",
|
||||
"code",
|
||||
"programming",
|
||||
"developer",
|
||||
"rust",
|
||||
"python",
|
||||
"javascript",
|
||||
"api",
|
||||
"backend",
|
||||
"frontend",
|
||||
"fullstack",
|
||||
"app",
|
||||
"application",
|
||||
"debug",
|
||||
"refactor",
|
||||
"implement",
|
||||
"algorithm",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Computer and Information Systems Managers".into(),
|
||||
hourly_wage: 90.38,
|
||||
category: TechnologyEngineering,
|
||||
keywords: vec![
|
||||
"it manager",
|
||||
"cto",
|
||||
"tech lead",
|
||||
"infrastructure",
|
||||
"systems",
|
||||
"devops",
|
||||
"cloud",
|
||||
"architecture",
|
||||
"platform",
|
||||
"enterprise",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Industrial Engineers".into(),
|
||||
hourly_wage: 51.87,
|
||||
category: TechnologyEngineering,
|
||||
keywords: vec![
|
||||
"industrial",
|
||||
"process",
|
||||
"optimization",
|
||||
"efficiency",
|
||||
"workflow",
|
||||
"manufacturing",
|
||||
"lean",
|
||||
"six sigma",
|
||||
"production",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Mechanical Engineers".into(),
|
||||
hourly_wage: 52.92,
|
||||
category: TechnologyEngineering,
|
||||
keywords: vec![
|
||||
"mechanical",
|
||||
"cad",
|
||||
"solidworks",
|
||||
"machinery",
|
||||
"thermal",
|
||||
"hvac",
|
||||
"automotive",
|
||||
"robotics",
|
||||
],
|
||||
},
|
||||
// Business & Finance
|
||||
Occupation {
|
||||
name: "Accountants and Auditors".into(),
|
||||
hourly_wage: 44.96,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"accounting",
|
||||
"audit",
|
||||
"tax",
|
||||
"bookkeeping",
|
||||
"financial statements",
|
||||
"gaap",
|
||||
"ledger",
|
||||
"reconciliation",
|
||||
"cpa",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Administrative Services Managers".into(),
|
||||
hourly_wage: 60.59,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"administrative",
|
||||
"office manager",
|
||||
"facilities",
|
||||
"operations",
|
||||
"scheduling",
|
||||
"coordination",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Buyers and Purchasing Agents".into(),
|
||||
hourly_wage: 39.29,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"procurement",
|
||||
"purchasing",
|
||||
"vendor",
|
||||
"supplier",
|
||||
"sourcing",
|
||||
"negotiation",
|
||||
"contracts",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Compliance Officers".into(),
|
||||
hourly_wage: 40.86,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"compliance",
|
||||
"regulatory",
|
||||
"audit",
|
||||
"policy",
|
||||
"governance",
|
||||
"risk",
|
||||
"sox",
|
||||
"gdpr",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Financial Managers".into(),
|
||||
hourly_wage: 86.76,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"cfo",
|
||||
"finance director",
|
||||
"treasury",
|
||||
"budget",
|
||||
"financial planning",
|
||||
"investment management",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Financial and Investment Analysts".into(),
|
||||
hourly_wage: 56.01,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"financial analysis",
|
||||
"investment",
|
||||
"portfolio",
|
||||
"stock",
|
||||
"equity",
|
||||
"valuation",
|
||||
"modeling",
|
||||
"dcf",
|
||||
"market research",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "General and Operations Managers".into(),
|
||||
hourly_wage: 64.00,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"operations",
|
||||
"general manager",
|
||||
"director",
|
||||
"oversee",
|
||||
"manage",
|
||||
"strategy",
|
||||
"leadership",
|
||||
"business",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Market Research Analysts and Marketing Specialists".into(),
|
||||
hourly_wage: 41.58,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"market research",
|
||||
"marketing",
|
||||
"campaign",
|
||||
"branding",
|
||||
"seo",
|
||||
"advertising",
|
||||
"analytics",
|
||||
"customer",
|
||||
"segment",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Personal Financial Advisors".into(),
|
||||
hourly_wage: 77.02,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"financial advisor",
|
||||
"wealth",
|
||||
"retirement",
|
||||
"401k",
|
||||
"ira",
|
||||
"estate planning",
|
||||
"insurance",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Project Management Specialists".into(),
|
||||
hourly_wage: 51.97,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"project manager",
|
||||
"pmp",
|
||||
"agile",
|
||||
"scrum",
|
||||
"sprint",
|
||||
"milestone",
|
||||
"timeline",
|
||||
"stakeholder",
|
||||
"deliverable",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Property, Real Estate, and Community Association Managers".into(),
|
||||
hourly_wage: 39.77,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"property",
|
||||
"real estate",
|
||||
"landlord",
|
||||
"tenant",
|
||||
"lease",
|
||||
"hoa",
|
||||
"community",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Sales Managers".into(),
|
||||
hourly_wage: 77.37,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"sales manager",
|
||||
"revenue",
|
||||
"quota",
|
||||
"pipeline",
|
||||
"crm",
|
||||
"account executive",
|
||||
"territory",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Marketing and Sales Managers".into(),
|
||||
hourly_wage: 79.35,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["vp sales", "cmo", "growth", "go-to-market", "demand gen"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Financial Specialists".into(),
|
||||
hourly_wage: 48.12,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["financial specialist", "credit", "loan", "underwriting"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Securities, Commodities, and Financial Services Sales Agents".into(),
|
||||
hourly_wage: 48.12,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["broker", "securities", "commodities", "trading", "series 7"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Business Operations Specialists, All Other".into(),
|
||||
hourly_wage: 44.41,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"business analyst",
|
||||
"operations specialist",
|
||||
"process improvement",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Claims Adjusters, Examiners, and Investigators".into(),
|
||||
hourly_wage: 37.87,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["claims", "insurance", "adjuster", "investigator", "fraud"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Transportation, Storage, and Distribution Managers".into(),
|
||||
hourly_wage: 55.77,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"logistics",
|
||||
"supply chain",
|
||||
"warehouse",
|
||||
"distribution",
|
||||
"shipping",
|
||||
"inventory",
|
||||
"fulfillment",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Industrial Production Managers".into(),
|
||||
hourly_wage: 62.11,
|
||||
category: BusinessFinance,
|
||||
keywords: vec![
|
||||
"production manager",
|
||||
"plant manager",
|
||||
"manufacturing operations",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Lodging Managers".into(),
|
||||
hourly_wage: 37.24,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["hotel", "hospitality", "lodging", "resort", "concierge"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Real Estate Brokers".into(),
|
||||
hourly_wage: 39.77,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["real estate broker", "realtor", "mls", "listing"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Managers, All Other".into(),
|
||||
hourly_wage: 72.06,
|
||||
category: BusinessFinance,
|
||||
keywords: vec!["manager", "supervisor", "team lead"],
|
||||
},
|
||||
// Healthcare & Social Services
|
||||
Occupation {
|
||||
name: "Medical and Health Services Managers".into(),
|
||||
hourly_wage: 66.22,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec![
|
||||
"healthcare",
|
||||
"hospital",
|
||||
"clinic",
|
||||
"medical",
|
||||
"health services",
|
||||
"patient",
|
||||
"hipaa",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Social and Community Service Managers".into(),
|
||||
hourly_wage: 41.39,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec![
|
||||
"social services",
|
||||
"community",
|
||||
"nonprofit",
|
||||
"outreach",
|
||||
"case management",
|
||||
"welfare",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Child, Family, and School Social Workers".into(),
|
||||
hourly_wage: 41.39,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec![
|
||||
"social worker",
|
||||
"child welfare",
|
||||
"family services",
|
||||
"school counselor",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Registered Nurses".into(),
|
||||
hourly_wage: 66.22,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec!["nurse", "rn", "nursing", "patient care", "clinical"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Nurse Practitioners".into(),
|
||||
hourly_wage: 66.22,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec!["np", "nurse practitioner", "aprn", "prescribe"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Pharmacists".into(),
|
||||
hourly_wage: 66.22,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec![
|
||||
"pharmacy",
|
||||
"pharmacist",
|
||||
"medication",
|
||||
"prescription",
|
||||
"drug",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Medical Secretaries and Administrative Assistants".into(),
|
||||
hourly_wage: 66.22,
|
||||
category: HealthcareSocialServices,
|
||||
keywords: vec![
|
||||
"medical secretary",
|
||||
"medical records",
|
||||
"ehr",
|
||||
"scheduling appointments",
|
||||
],
|
||||
},
|
||||
// Legal, Media & Operations
|
||||
Occupation {
|
||||
name: "Lawyers".into(),
|
||||
hourly_wage: 44.41,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"lawyer",
|
||||
"attorney",
|
||||
"legal",
|
||||
"contract",
|
||||
"litigation",
|
||||
"counsel",
|
||||
"law",
|
||||
"paralegal",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Editors".into(),
|
||||
hourly_wage: 72.06,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"editor",
|
||||
"editing",
|
||||
"proofread",
|
||||
"copy edit",
|
||||
"manuscript",
|
||||
"publication",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Film and Video Editors".into(),
|
||||
hourly_wage: 68.15,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"video editor",
|
||||
"film",
|
||||
"premiere",
|
||||
"final cut",
|
||||
"davinci",
|
||||
"post-production",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Audio and Video Technicians".into(),
|
||||
hourly_wage: 41.86,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"audio",
|
||||
"video",
|
||||
"av",
|
||||
"broadcast",
|
||||
"streaming",
|
||||
"recording",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Producers and Directors".into(),
|
||||
hourly_wage: 41.86,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"producer",
|
||||
"director",
|
||||
"production",
|
||||
"creative director",
|
||||
"content",
|
||||
"show",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "News Analysts, Reporters, and Journalists".into(),
|
||||
hourly_wage: 68.15,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"journalist",
|
||||
"reporter",
|
||||
"news",
|
||||
"article",
|
||||
"press",
|
||||
"interview",
|
||||
"story",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "Entertainment and Recreation Managers, Except Gambling".into(),
|
||||
hourly_wage: 41.86,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec!["entertainment", "recreation", "event", "venue", "concert"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Recreation Workers".into(),
|
||||
hourly_wage: 41.86,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec!["recreation", "activity", "fitness", "sports"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Customer Service Representatives".into(),
|
||||
hourly_wage: 44.41,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec!["customer service", "support", "helpdesk", "ticket", "chat"],
|
||||
},
|
||||
Occupation {
|
||||
name: "Private Detectives and Investigators".into(),
|
||||
hourly_wage: 37.87,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec![
|
||||
"detective",
|
||||
"investigator",
|
||||
"background check",
|
||||
"surveillance",
|
||||
],
|
||||
},
|
||||
Occupation {
|
||||
name: "First-Line Supervisors of Police and Detectives".into(),
|
||||
hourly_wage: 72.06,
|
||||
category: LegalMediaOperations,
|
||||
keywords: vec!["police", "law enforcement", "security supervisor"],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Build keyword → occupation index for fast lookup
|
||||
fn build_keyword_index(occupations: &[Occupation]) -> HashMap<&'static str, Vec<usize>> {
|
||||
let mut index: HashMap<&'static str, Vec<usize>> = HashMap::new();
|
||||
for (i, occ) in occupations.iter().enumerate() {
|
||||
for &kw in &occ.keywords {
|
||||
index.entry(kw).or_default().push(i);
|
||||
}
|
||||
}
|
||||
index
|
||||
}
|
||||
|
||||
/// Classify a task instruction into an occupation with estimated value
|
||||
///
|
||||
/// This is a synchronous keyword-based classifier. For LLM-based
|
||||
/// classification, use `classify_with_llm` instead.
|
||||
pub fn classify(&self, instruction: &str) -> ClassificationResult {
|
||||
let lower = instruction.to_lowercase();
|
||||
let mut scores: HashMap<usize, f64> = HashMap::new();
|
||||
|
||||
// Score each occupation by keyword matches
|
||||
for (keyword, occ_indices) in &self.keyword_index {
|
||||
if lower.contains(keyword) {
|
||||
for &idx in occ_indices {
|
||||
*scores.entry(idx).or_default() += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find best match
|
||||
let (best_idx, best_score) = scores
|
||||
.iter()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(&idx, &score)| (idx, score))
|
||||
.unwrap_or((usize::MAX, 0.0));
|
||||
|
||||
let (occupation, hourly_wage, category, confidence, reasoning) =
|
||||
if best_idx < self.occupations.len() {
|
||||
let occ = &self.occupations[best_idx];
|
||||
let confidence = (best_score / 3.0).min(1.0); // Normalize confidence
|
||||
(
|
||||
occ.name.clone(),
|
||||
occ.hourly_wage,
|
||||
occ.category,
|
||||
confidence,
|
||||
format!("Matched {:.0} keywords", best_score),
|
||||
)
|
||||
} else {
|
||||
// Fallback
|
||||
(
|
||||
self.fallback_occupation.clone(),
|
||||
self.fallback_wage,
|
||||
OccupationCategory::BusinessFinance,
|
||||
0.3,
|
||||
"Fallback classification - no strong keyword match".to_string(),
|
||||
)
|
||||
};
|
||||
|
||||
let estimated_hours = Self::estimate_hours(instruction);
|
||||
let max_payment = (estimated_hours * hourly_wage * 100.0).round() / 100.0;
|
||||
|
||||
ClassificationResult {
|
||||
occupation,
|
||||
hourly_wage,
|
||||
estimated_hours,
|
||||
max_payment,
|
||||
confidence,
|
||||
category,
|
||||
reasoning,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate hours based on instruction complexity
|
||||
fn estimate_hours(instruction: &str) -> f64 {
|
||||
let word_count = instruction.split_whitespace().count();
|
||||
let has_complex_markers = instruction.to_lowercase().contains("implement")
|
||||
|| instruction.contains("build")
|
||||
|| instruction.contains("create")
|
||||
|| instruction.contains("design")
|
||||
|| instruction.contains("develop");
|
||||
|
||||
let has_simple_markers = instruction.to_lowercase().contains("fix")
|
||||
|| instruction.contains("update")
|
||||
|| instruction.contains("change")
|
||||
|| instruction.contains("review");
|
||||
|
||||
let base_hours = if has_complex_markers {
|
||||
2.0
|
||||
} else if has_simple_markers {
|
||||
0.5
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
// Scale by instruction length
|
||||
let length_factor = (word_count as f64 / 20.0).clamp(0.5, 2.0);
|
||||
let hours = base_hours * length_factor;
|
||||
|
||||
// Clamp to valid range
|
||||
hours.clamp(0.25, 40.0)
|
||||
}
|
||||
|
||||
/// Get all occupations
|
||||
pub fn occupations(&self) -> &[Occupation] {
|
||||
&self.occupations
|
||||
}
|
||||
|
||||
/// Get occupations by category
|
||||
pub fn occupations_by_category(&self, category: OccupationCategory) -> Vec<&Occupation> {
|
||||
self.occupations
|
||||
.iter()
|
||||
.filter(|o| o.category == category)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the fallback occupation name
|
||||
pub fn fallback_occupation(&self) -> &str {
|
||||
&self.fallback_occupation
|
||||
}
|
||||
|
||||
    /// Get the fallback hourly wage
    ///
    /// The wage reported alongside the fallback occupation when no
    /// keyword matches.
    pub fn fallback_wage(&self) -> f64 {
        self.fallback_wage
    }
|
||||
|
||||
/// Look up an occupation by exact name
|
||||
pub fn get_occupation(&self, name: &str) -> Option<&Occupation> {
|
||||
self.occupations.iter().find(|o| o.name == name)
|
||||
}
|
||||
|
||||
/// Fuzzy match an occupation name (case-insensitive, substring)
|
||||
pub fn fuzzy_match(&self, name: &str) -> Option<&Occupation> {
|
||||
let lower = name.to_lowercase();
|
||||
|
||||
// Exact match first
|
||||
if let Some(occ) = self.occupations.iter().find(|o| o.name == name) {
|
||||
return Some(occ);
|
||||
}
|
||||
|
||||
// Case-insensitive match
|
||||
if let Some(occ) = self
|
||||
.occupations
|
||||
.iter()
|
||||
.find(|o| o.name.to_lowercase() == lower)
|
||||
{
|
||||
return Some(occ);
|
||||
}
|
||||
|
||||
// Substring match
|
||||
self.occupations.iter().find(|o| {
|
||||
lower.contains(&o.name.to_lowercase()) || o.name.to_lowercase().contains(&lower)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The embedded table is documented as holding exactly 44 occupations.
    #[test]
    fn test_classifier_new() {
        let classifier = TaskClassifier::new();
        assert_eq!(classifier.occupations.len(), 44);
    }

    // "api" and "rust" keywords should land on Software Developers.
    #[test]
    fn test_classify_software() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("Write a REST API in Rust with authentication");

        assert_eq!(result.occupation, "Software Developers");
        assert!((result.hourly_wage - 69.50).abs() < 0.01);
        assert!(result.confidence > 0.0);
    }

    // Finance keywords overlap several occupations, so only assert the family.
    #[test]
    fn test_classify_finance() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("Prepare quarterly financial statements and audit trail");

        assert!(
            result.occupation.contains("Account") || result.occupation.contains("Financial"),
            "Expected finance occupation, got: {}",
            result.occupation
        );
    }

    // Nonsense input must hit the fallback occupation at fixed 0.3 confidence.
    #[test]
    fn test_classify_fallback() {
        let classifier = TaskClassifier::new();
        let result = classifier.classify("xyzzy foobar baz");

        assert_eq!(result.occupation, "General and Operations Managers");
        assert_eq!(result.confidence, 0.3);
    }

    #[test]
    fn test_estimate_hours_complex() {
        let hours = TaskClassifier::estimate_hours(
            "Implement a complete microservices architecture with event sourcing",
        );
        assert!(hours >= 1.0, "Complex task should estimate >= 1 hour");
    }

    #[test]
    fn test_estimate_hours_simple() {
        let hours = TaskClassifier::estimate_hours("Fix typo");
        assert!(hours <= 1.0, "Simple task should estimate <= 1 hour");
    }

    // Exercises all three fuzzy_match stages in order.
    #[test]
    fn test_fuzzy_match() {
        let classifier = TaskClassifier::new();

        // Exact match
        assert!(classifier.fuzzy_match("Software Developers").is_some());

        // Case insensitive
        assert!(classifier.fuzzy_match("software developers").is_some());

        // Substring
        assert!(classifier.fuzzy_match("Software").is_some());
    }

    #[test]
    fn test_occupations_by_category() {
        let classifier = TaskClassifier::new();
        let tech = classifier.occupations_by_category(OccupationCategory::TechnologyEngineering);

        assert!(!tech.is_empty());
        assert!(tech.iter().any(|o| o.name == "Software Developers"));
    }
}
|
||||
@@ -0,0 +1,369 @@
|
||||
//! Token cost tracking types for economic agents.
|
||||
//!
|
||||
//! Separates costs by channel (LLM, search API, OCR, etc.) following
|
||||
//! the ClawWork economic model.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Channel-separated cost breakdown for a task or session.
///
/// All amounts are in USD; `Default` yields an all-zero breakdown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostBreakdown {
    /// Cost from LLM token usage
    pub llm_tokens: f64,
    /// Cost from search API calls (Brave, JINA, Tavily, etc.)
    pub search_api: f64,
    /// Cost from OCR API calls
    pub ocr_api: f64,
    /// Cost from other API calls
    pub other_api: f64,
}
|
||||
|
||||
impl CostBreakdown {
    /// Create a new empty cost breakdown (all channels zero).
    pub fn new() -> Self {
        Self::default()
    }

    /// Get total cost across all channels.
    pub fn total(&self) -> f64 {
        [self.llm_tokens, self.search_api, self.ocr_api, self.other_api]
            .iter()
            .sum()
    }

    /// Add another breakdown to this one, channel by channel.
    pub fn add(&mut self, other: &CostBreakdown) {
        self.llm_tokens += other.llm_tokens;
        self.search_api += other.search_api;
        self.ocr_api += other.ocr_api;
        self.other_api += other.other_api;
    }

    /// Reset all costs to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
|
||||
|
||||
/// Token pricing configuration.
///
/// Input and output tokens are priced separately, both in USD per
/// million tokens; see [`TokenPricing::calculate_cost`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPricing {
    /// Price per million input tokens (USD)
    pub input_price_per_million: f64,
    /// Price per million output tokens (USD)
    pub output_price_per_million: f64,
}
|
||||
|
||||
impl Default for TokenPricing {
|
||||
fn default() -> Self {
|
||||
// Default to Claude Sonnet 4 pricing via OpenRouter
|
||||
Self {
|
||||
input_price_per_million: 3.0,
|
||||
output_price_per_million: 15.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenPricing {
|
||||
/// Calculate cost for given token counts.
|
||||
pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
|
||||
let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_price_per_million;
|
||||
let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_price_per_million;
|
||||
input_cost + output_cost
|
||||
}
|
||||
}
|
||||
|
||||
/// A single LLM call record with token details.
///
/// Collected into [`LlmUsageSummary::calls_detail`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmCallRecord {
    /// Timestamp of the call
    pub timestamp: DateTime<Utc>,
    /// API name/source (e.g., "agent", "wrapup", "research")
    pub api_name: String,
    /// Number of input tokens
    pub input_tokens: u64,
    /// Number of output tokens
    pub output_tokens: u64,
    /// Cost in USD
    pub cost: f64,
}
|
||||
|
||||
/// A single API call record (non-LLM).
///
/// `tokens` and `price_per_million` are populated only for
/// [`PricingModel::PerToken`] calls and are omitted from JSON otherwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiCallRecord {
    /// Timestamp of the call
    pub timestamp: DateTime<Utc>,
    /// API name (e.g., "tavily_search", "jina_reader")
    pub api_name: String,
    /// Pricing model used
    pub pricing_model: PricingModel,
    /// Number of tokens (if token-based pricing)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<u64>,
    /// Price per million tokens (if token-based)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub price_per_million: Option<f64>,
    /// Cost in USD
    pub cost: f64,
}
|
||||
|
||||
/// Pricing model for API calls.
///
/// Serialized in snake_case ("per_token" / "flat_rate").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PricingModel {
    /// Token-based pricing (cost = tokens / 1M * price_per_million)
    PerToken,
    /// Flat rate per call
    FlatRate,
}
|
||||
|
||||
/// Comprehensive task cost record (one per task).
///
/// NOTE(review): `timestamp_end` is declared before `timestamp_start`,
/// which fixes the serialized field order — confirm downstream consumers
/// expect this ordering before rearranging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCostRecord {
    /// Task end timestamp
    pub timestamp_end: DateTime<Utc>,
    /// Task start timestamp
    pub timestamp_start: DateTime<Utc>,
    /// Date the task was assigned (YYYY-MM-DD)
    pub date: String,
    /// Unique task identifier
    pub task_id: String,
    /// LLM usage summary
    pub llm_usage: LlmUsageSummary,
    /// API usage summary
    pub api_usage: ApiUsageSummary,
    /// Cost summary by channel
    pub cost_summary: CostBreakdown,
    /// Balance after this task
    pub balance_after: f64,
    /// Session cost so far
    pub session_cost: f64,
    /// Daily cost so far
    pub daily_cost: f64,
}
|
||||
|
||||
/// Aggregated LLM usage for a task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmUsageSummary {
    /// Number of LLM calls made
    pub total_calls: usize,
    /// Total input tokens
    pub total_input_tokens: u64,
    /// Total output tokens
    pub total_output_tokens: u64,
    /// Total tokens (input + output)
    pub total_tokens: u64,
    /// Total cost in USD
    pub total_cost: f64,
    /// Input price (USD per million tokens) used for these costs
    pub input_price_per_million: f64,
    /// Output price (USD per million tokens) used for these costs
    pub output_price_per_million: f64,
    /// Detailed call records (empty when deserialized from older data)
    #[serde(default)]
    pub calls_detail: Vec<LlmCallRecord>,
}
|
||||
|
||||
/// Aggregated API usage (non-LLM) for a task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApiUsageSummary {
    /// Number of API calls made
    pub total_calls: usize,
    /// Search API costs (USD)
    pub search_api_cost: f64,
    /// OCR API costs (USD)
    pub ocr_api_cost: f64,
    /// Other API costs (USD)
    pub other_api_cost: f64,
    /// Number of token-based calls (see `PricingModel::PerToken`)
    pub token_based_calls: usize,
    /// Number of flat-rate calls (see `PricingModel::FlatRate`)
    pub flat_rate_calls: usize,
    /// Detailed call records (empty when deserialized from older data)
    #[serde(default)]
    pub calls_detail: Vec<ApiCallRecord>,
}
|
||||
|
||||
/// Work income record.
///
/// One record per evaluated task; `actual_payment` is zero when the
/// evaluation score falls below `threshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkIncomeRecord {
    /// Timestamp
    pub timestamp: DateTime<Utc>,
    /// Date (YYYY-MM-DD)
    pub date: String,
    /// Task identifier
    pub task_id: String,
    /// Base payment amount offered
    pub base_amount: f64,
    /// Actual payment received (0 if below threshold)
    pub actual_payment: f64,
    /// Evaluation score (0.0-1.0)
    pub evaluation_score: f64,
    /// Minimum threshold required for payment
    pub threshold: f64,
    /// Whether payment was awarded
    pub payment_awarded: bool,
    /// Optional description (empty string when absent)
    #[serde(default)]
    pub description: String,
    /// Balance after this income
    pub balance_after: f64,
}
|
||||
|
||||
/// Daily balance record for persistence.
///
/// Optional fields are omitted from JSON when `None`; `#[serde(default)]`
/// fields tolerate older records that lack them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BalanceRecord {
    /// Date (YYYY-MM-DD or "initialization")
    pub date: String,
    /// Current balance
    pub balance: f64,
    /// Token cost delta for this period
    pub token_cost_delta: f64,
    /// Work income delta for this period
    pub work_income_delta: f64,
    /// Trading profit delta for this period
    pub trading_profit_delta: f64,
    /// Cumulative total token cost
    pub total_token_cost: f64,
    /// Cumulative total work income
    pub total_work_income: f64,
    /// Cumulative total trading profit
    pub total_trading_profit: f64,
    /// Net worth (balance + portfolio value)
    pub net_worth: f64,
    /// Current survival status
    pub survival_status: String,
    /// Tasks completed in this period
    #[serde(default)]
    pub completed_tasks: Vec<String>,
    /// Primary task ID for the day
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_id: Option<String>,
    /// Time to complete tasks (seconds)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_completion_time_seconds: Option<f64>,
    /// Whether session was aborted by API error
    #[serde(default)]
    pub api_error: bool,
}
|
||||
|
||||
/// Task completion record for analytics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCompletionRecord {
    /// Task identifier
    pub task_id: String,
    /// Date (YYYY-MM-DD)
    pub date: String,
    /// Attempt number (1-based)
    pub attempt: u32,
    /// Whether work was submitted
    pub work_submitted: bool,
    /// Evaluation score (0.0-1.0)
    pub evaluation_score: f64,
    /// Money earned from this task (USD)
    pub money_earned: f64,
    /// Wall-clock time in seconds
    pub wall_clock_seconds: f64,
    /// Timestamp of completion
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Economic analytics summary.
///
/// Aggregates cost and income totals with per-date and per-task rollups.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EconomicAnalytics {
    /// Total costs by channel
    pub total_costs: CostBreakdown,
    /// Costs broken down by date (key: YYYY-MM-DD)
    pub by_date: HashMap<String, DateCostSummary>,
    /// Costs broken down by task (key: task id)
    pub by_task: HashMap<String, TaskCostSummary>,
    /// Total number of tasks
    pub total_tasks: usize,
    /// Total income earned (USD)
    pub total_income: f64,
    /// Number of tasks that received payment
    pub tasks_paid: usize,
    /// Number of tasks rejected (below threshold)
    pub tasks_rejected: usize,
}
|
||||
|
||||
/// Cost summary for a single date.
///
/// The per-channel fields are flattened into this struct's JSON object
/// alongside `total` and `income`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DateCostSummary {
    /// Costs by channel
    #[serde(flatten)]
    pub costs: CostBreakdown,
    /// Total cost (USD)
    pub total: f64,
    /// Income earned (USD)
    pub income: f64,
}
|
||||
|
||||
/// Cost summary for a single task.
///
/// The per-channel fields are flattened into this struct's JSON object
/// alongside `total` and `date`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskCostSummary {
    /// Costs by channel
    #[serde(flatten)]
    pub costs: CostBreakdown,
    /// Total cost (USD)
    pub total: f64,
    /// Date of the task (YYYY-MM-DD)
    pub date: String,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for float comparisons. `f64::EPSILON` is machine epsilon
    /// (the gap between 1.0 and the next representable value), not a
    /// comparison tolerance: sums of decimal fractions such as
    /// 1.5 + 0.75 + 0.1 + 0.05 can land several ULPs away from the literal
    /// they are compared against, making `< f64::EPSILON` checks flaky.
    const TOL: f64 = 1e-9;

    #[test]
    fn cost_breakdown_total() {
        let breakdown = CostBreakdown {
            llm_tokens: 1.0,
            search_api: 0.5,
            ocr_api: 0.25,
            other_api: 0.1,
        };
        assert!((breakdown.total() - 1.85).abs() < TOL);
    }

    #[test]
    fn cost_breakdown_add() {
        let mut a = CostBreakdown {
            llm_tokens: 1.0,
            search_api: 0.5,
            ocr_api: 0.0,
            other_api: 0.0,
        };
        let b = CostBreakdown {
            llm_tokens: 0.5,
            search_api: 0.25,
            ocr_api: 0.1,
            other_api: 0.05,
        };
        a.add(&b);
        assert!((a.llm_tokens - 1.5).abs() < TOL);
        assert!((a.search_api - 0.75).abs() < TOL);
        assert!((a.total() - 2.4).abs() < TOL);
    }

    #[test]
    fn token_pricing_calculation() {
        let pricing = TokenPricing {
            input_price_per_million: 3.0,
            output_price_per_million: 15.0,
        };
        // 1000 input, 500 output
        // (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
        let cost = pricing.calculate_cost(1000, 500);
        assert!((cost - 0.0105).abs() < 0.0001);
    }

    #[test]
    fn default_token_pricing() {
        let pricing = TokenPricing::default();
        assert!((pricing.input_price_per_million - 3.0).abs() < TOL);
        assert!((pricing.output_price_per_million - 15.0).abs() < TOL);
    }
}
|
||||
@@ -0,0 +1,83 @@
|
||||
//! Economic tracking module for agent survival economics.
|
||||
//!
|
||||
//! This module implements the ClawWork economic model for AI agents,
|
||||
//! tracking balance, costs, income, and survival status. Agents start
|
||||
//! with initial capital and must manage their resources while completing
|
||||
//! tasks.
|
||||
//!
|
||||
//! ## Overview
|
||||
//!
|
||||
//! The economic system models agent viability:
|
||||
//! - **Balance**: Starting capital minus costs plus earned income
|
||||
//! - **Costs**: LLM tokens, search APIs, OCR, and other service usage
|
||||
//! - **Income**: Payments for completed tasks (with quality threshold)
|
||||
//! - **Status**: Health indicator based on remaining capital percentage
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use zeroclaw::economic::{EconomicTracker, EconomicConfig, SurvivalStatus};
|
||||
//!
|
||||
//! let config = EconomicConfig {
|
||||
//! enabled: true,
|
||||
//! initial_balance: 1000.0,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! let tracker = EconomicTracker::new("my-agent", config, None);
|
||||
//! tracker.initialize()?;
|
||||
//!
|
||||
//! // Start a task
|
||||
//! tracker.start_task("task-001", None);
|
||||
//!
|
||||
//! // Track LLM usage
|
||||
//! let cost = tracker.track_tokens(1000, 500, "agent", None);
|
||||
//!
|
||||
//! // Complete task and earn income
|
||||
//! tracker.end_task()?;
|
||||
//! let payment = tracker.add_work_income(10.0, "task-001", 0.85, "Completed task")?;
|
||||
//!
|
||||
//! // Check survival status
|
||||
//! match tracker.get_survival_status() {
|
||||
//! SurvivalStatus::Thriving => println!("Agent is healthy!"),
|
||||
//! SurvivalStatus::Bankrupt => println!("Agent needs intervention!"),
|
||||
//! _ => {}
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Persistence
|
||||
//!
|
||||
//! Economic state is persisted to JSONL files:
|
||||
//! - `balance.jsonl`: Daily balance snapshots and cumulative totals
|
||||
//! - `token_costs.jsonl`: Detailed per-task cost records
|
||||
//! - `task_completions.jsonl`: Task completion statistics
|
||||
//!
|
||||
//! ## Configuration
|
||||
//!
|
||||
//! Add to `config.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! [economic]
|
||||
//! enabled = true
|
||||
//! initial_balance = 1000.0
|
||||
//! min_evaluation_threshold = 0.6
|
||||
//!
|
||||
//! [economic.token_pricing]
|
||||
//! input_price_per_million = 3.0
|
||||
//! output_price_per_million = 15.0
|
||||
//! ```
|
||||
|
||||
// Submodules: task classification, cost record types, survival status,
// and the stateful tracker itself.
pub mod classifier;
pub mod costs;
pub mod status;
pub mod tracker;

// Re-exports for convenient access (callers use `crate::economic::X`
// instead of reaching into submodules).
pub use classifier::{ClassificationResult, Occupation, OccupationCategory, TaskClassifier};
pub use costs::{
    ApiCallRecord, ApiUsageSummary, BalanceRecord, CostBreakdown, DateCostSummary,
    EconomicAnalytics, LlmCallRecord, LlmUsageSummary, PricingModel, TaskCompletionRecord,
    TaskCostRecord, TaskCostSummary, TokenPricing, WorkIncomeRecord,
};
pub use status::SurvivalStatus;
pub use tracker::{EconomicConfig, EconomicSummary, EconomicTracker};
|
||||
@@ -0,0 +1,207 @@
|
||||
//! Survival status tracking for economic agents.
|
||||
//!
|
||||
//! Defines the health states an agent can be in based on remaining balance
|
||||
//! as a percentage of initial capital.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Survival status based on balance percentage relative to initial capital.
///
/// Mirrors the ClawWork LiveBench agent survival states. Serialized in
/// snake_case (e.g. `"thriving"`). Defaults to `Stable`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SurvivalStatus {
    /// Balance > 80% of initial - Agent is profitable and healthy
    Thriving,
    /// Balance 40-80% of initial - Agent is maintaining stability
    #[default]
    Stable,
    /// Balance 10-40% of initial - Agent is losing money, needs attention
    Struggling,
    /// Balance 1-10% of initial - Agent is near death, urgent intervention needed
    Critical,
    /// Balance <= 0 - Agent has exhausted resources and cannot operate
    Bankrupt,
}
|
||||
|
||||
impl SurvivalStatus {
|
||||
/// Calculate survival status from current and initial balance.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `current_balance` - Current remaining balance
|
||||
/// * `initial_balance` - Starting balance
|
||||
///
|
||||
/// # Returns
|
||||
/// The appropriate `SurvivalStatus` based on the percentage remaining.
|
||||
pub fn from_balance(current_balance: f64, initial_balance: f64) -> Self {
|
||||
if initial_balance <= 0.0 {
|
||||
// Edge case: if initial was zero or negative, can't calculate percentage
|
||||
return if current_balance <= 0.0 {
|
||||
Self::Bankrupt
|
||||
} else {
|
||||
Self::Thriving
|
||||
};
|
||||
}
|
||||
|
||||
let percentage = (current_balance / initial_balance) * 100.0;
|
||||
|
||||
match percentage {
|
||||
p if p <= 0.0 => Self::Bankrupt,
|
||||
p if p < 10.0 => Self::Critical,
|
||||
p if p < 40.0 => Self::Struggling,
|
||||
p if p < 80.0 => Self::Stable,
|
||||
_ => Self::Thriving,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the agent can still operate (not bankrupt).
|
||||
pub fn is_operational(&self) -> bool {
|
||||
!matches!(self, Self::Bankrupt)
|
||||
}
|
||||
|
||||
/// Check if the agent needs urgent attention.
|
||||
pub fn needs_intervention(&self) -> bool {
|
||||
matches!(self, Self::Critical | Self::Bankrupt)
|
||||
}
|
||||
|
||||
/// Get a human-readable emoji indicator.
|
||||
pub fn emoji(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Thriving => "🌟",
|
||||
Self::Stable => "✅",
|
||||
Self::Struggling => "⚠️",
|
||||
Self::Critical => "🚨",
|
||||
Self::Bankrupt => "💀",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a color code for terminal output (ANSI).
|
||||
pub fn ansi_color(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Thriving => "\x1b[32m", // Green
|
||||
Self::Stable => "\x1b[34m", // Blue
|
||||
Self::Struggling => "\x1b[33m", // Yellow
|
||||
Self::Critical => "\x1b[31m", // Red
|
||||
Self::Bankrupt => "\x1b[35m", // Magenta
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SurvivalStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let status = match self {
|
||||
Self::Thriving => "Thriving",
|
||||
Self::Stable => "Stable",
|
||||
Self::Struggling => "Struggling",
|
||||
Self::Critical => "Critical",
|
||||
Self::Bankrupt => "Bankrupt",
|
||||
};
|
||||
write!(f, "{}", status)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Boundary tests deliberately probe just above/below each threshold.

    #[test]
    fn thriving_above_80_percent() {
        assert_eq!(
            SurvivalStatus::from_balance(900.0, 1000.0),
            SurvivalStatus::Thriving
        );
        assert_eq!(
            SurvivalStatus::from_balance(1500.0, 1000.0), // Profit!
            SurvivalStatus::Thriving
        );
        // Just above the 80% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(800.01, 1000.0),
            SurvivalStatus::Thriving
        );
    }

    #[test]
    fn stable_between_40_and_80_percent() {
        // Just below the 80% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(799.99, 1000.0),
            SurvivalStatus::Stable
        );
        assert_eq!(
            SurvivalStatus::from_balance(500.0, 1000.0),
            SurvivalStatus::Stable
        );
        // Just above the 40% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(400.01, 1000.0),
            SurvivalStatus::Stable
        );
    }

    #[test]
    fn struggling_between_10_and_40_percent() {
        // Just below the 40% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(399.99, 1000.0),
            SurvivalStatus::Struggling
        );
        assert_eq!(
            SurvivalStatus::from_balance(200.0, 1000.0),
            SurvivalStatus::Struggling
        );
        // Just above the 10% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(100.01, 1000.0),
            SurvivalStatus::Struggling
        );
    }

    #[test]
    fn critical_between_0_and_10_percent() {
        // Just below the 10% boundary.
        assert_eq!(
            SurvivalStatus::from_balance(99.99, 1000.0),
            SurvivalStatus::Critical
        );
        assert_eq!(
            SurvivalStatus::from_balance(50.0, 1000.0),
            SurvivalStatus::Critical
        );
        // Smallest positive balance is still Critical, not Bankrupt.
        assert_eq!(
            SurvivalStatus::from_balance(0.01, 1000.0),
            SurvivalStatus::Critical
        );
    }

    #[test]
    fn bankrupt_at_zero_or_negative() {
        assert_eq!(
            SurvivalStatus::from_balance(0.0, 1000.0),
            SurvivalStatus::Bankrupt
        );
        assert_eq!(
            SurvivalStatus::from_balance(-100.0, 1000.0),
            SurvivalStatus::Bankrupt
        );
    }

    // Only Bankrupt blocks operation.
    #[test]
    fn is_operational() {
        assert!(SurvivalStatus::Thriving.is_operational());
        assert!(SurvivalStatus::Stable.is_operational());
        assert!(SurvivalStatus::Struggling.is_operational());
        assert!(SurvivalStatus::Critical.is_operational());
        assert!(!SurvivalStatus::Bankrupt.is_operational());
    }

    // Critical and Bankrupt both flag intervention.
    #[test]
    fn needs_intervention() {
        assert!(!SurvivalStatus::Thriving.needs_intervention());
        assert!(!SurvivalStatus::Stable.needs_intervention());
        assert!(!SurvivalStatus::Struggling.needs_intervention());
        assert!(SurvivalStatus::Critical.needs_intervention());
        assert!(SurvivalStatus::Bankrupt.needs_intervention());
    }

    #[test]
    fn display_format() {
        assert_eq!(format!("{}", SurvivalStatus::Thriving), "Thriving");
        assert_eq!(format!("{}", SurvivalStatus::Bankrupt), "Bankrupt");
    }
}
|
||||
@@ -0,0 +1,992 @@
|
||||
//! Economic tracker for agent survival economics.
|
||||
//!
|
||||
//! Tracks balance, token costs, work income, and survival status following
|
||||
//! the ClawWork LiveBench economic model. Persists state to JSONL files.
|
||||
|
||||
use super::costs::{
|
||||
ApiCallRecord, ApiUsageSummary, BalanceRecord, CostBreakdown, LlmCallRecord, LlmUsageSummary,
|
||||
PricingModel, TaskCompletionRecord, TaskCostRecord, TokenPricing, WorkIncomeRecord,
|
||||
};
|
||||
use super::status::SurvivalStatus;
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Economic configuration options.
///
/// All fields carry serde defaults, so a partial `[economic]` TOML table
/// deserializes cleanly; an absent table yields tracking disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EconomicConfig {
    /// Enable economic tracking (defaults to `false`)
    #[serde(default)]
    pub enabled: bool,
    /// Starting balance in USD (defaults to 1000.0)
    #[serde(default = "default_initial_balance")]
    pub initial_balance: f64,
    /// Token pricing configuration
    #[serde(default)]
    pub token_pricing: TokenPricing,
    /// Minimum evaluation score to receive payment (0.0-1.0, defaults to 0.6)
    #[serde(default = "default_min_threshold")]
    pub min_evaluation_threshold: f64,
}
|
||||
|
||||
/// Serde default: starting capital of $1000 USD.
fn default_initial_balance() -> f64 {
    1000.0
}

/// Serde default: minimum evaluation score (0.6) required for payment.
fn default_min_threshold() -> f64 {
    0.6
}
|
||||
|
||||
impl Default for EconomicConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
initial_balance: default_initial_balance(),
|
||||
token_pricing: TokenPricing::default(),
|
||||
min_evaluation_threshold: default_min_threshold(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Task-level tracking state (in-memory during task execution).
///
/// Populated by `start_task`, flushed to `token_costs.jsonl` and cleared
/// by `end_task`.
#[derive(Debug, Clone, Default)]
struct TaskState {
    /// Current task ID (`None` when no task is active)
    task_id: Option<String>,
    /// Date the task was assigned (`YYYY-MM-DD`)
    task_date: Option<String>,
    /// Task start timestamp
    start_time: Option<DateTime<Utc>>,
    /// Costs accumulated for this task, per channel
    costs: CostBreakdown,
    /// Detailed LLM call records for this task
    llm_calls: Vec<LlmCallRecord>,
    /// Detailed non-LLM API call records for this task
    api_calls: Vec<ApiCallRecord>,
}
|
||||
|
||||
impl TaskState {
|
||||
fn reset(&mut self) {
|
||||
self.task_id = None;
|
||||
self.task_date = None;
|
||||
self.start_time = None;
|
||||
self.costs.reset();
|
||||
self.llm_calls.clear();
|
||||
self.api_calls.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Daily tracking state (accumulated across tasks).
///
/// Reset by `save_daily_state` after the end-of-day balance record is
/// written.
#[derive(Debug, Clone, Default)]
struct DailyState {
    /// Task IDs completed today
    task_ids: Vec<String>,
    /// First task start time (used for the day's completion-time window)
    first_task_start: Option<DateTime<Utc>>,
    /// Last task end time (used for the day's completion-time window)
    last_task_end: Option<DateTime<Utc>>,
    /// Daily cost accumulator (all channels)
    cost: f64,
}
|
||||
|
||||
impl DailyState {
|
||||
fn reset(&mut self) {
|
||||
self.task_ids.clear();
|
||||
self.first_task_start = None;
|
||||
self.last_task_end = None;
|
||||
self.cost = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Session tracking state.
///
/// A "session" spans one decision/activity; reset via `reset_session`
/// or at day rollover.
#[derive(Debug, Clone, Default)]
struct SessionState {
    /// Input tokens this session
    input_tokens: u64,
    /// Output tokens this session
    output_tokens: u64,
    /// Cost this session (USD, all channels)
    cost: f64,
}
|
||||
|
||||
impl SessionState {
|
||||
fn reset(&mut self) {
|
||||
self.input_tokens = 0;
|
||||
self.output_tokens = 0;
|
||||
self.cost = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Economic tracker for managing agent survival economics.
///
/// Tracks:
/// - Balance (starting capital minus costs plus income)
/// - Token costs separated by channel (LLM, search, OCR, etc.)
/// - Work income with evaluation threshold
/// - Trading profits/losses
/// - Survival status
///
/// Persists records to JSONL files for durability and analysis.
/// Thread-safe: all mutable state lives behind a `parking_lot::Mutex`,
/// so `&self` methods can be called from multiple threads.
pub struct EconomicTracker {
    /// Configuration (immutable after construction)
    config: EconomicConfig,
    /// Agent signature/name, used in the default data path and log lines
    signature: String,
    /// Data directory holding the JSONL persistence files
    data_path: PathBuf,
    /// All mutable tracking state (protected by mutex for thread safety)
    state: Arc<Mutex<TrackerState>>,
}
|
||||
|
||||
/// Internal mutable state, always accessed under the tracker's mutex.
struct TrackerState {
    /// Current balance (USD)
    balance: f64,
    /// Initial balance (for survival-status percentage calculation;
    /// not overwritten when persisted state is loaded)
    initial_balance: f64,
    /// Cumulative token/API cost across the tracker's lifetime
    total_token_cost: f64,
    /// Cumulative work income across the tracker's lifetime
    total_work_income: f64,
    /// Cumulative trading profit/loss across the tracker's lifetime
    total_trading_profit: f64,
    /// Task-level tracking (active task only)
    task: TaskState,
    /// Daily tracking (reset at day rollover)
    daily: DailyState,
    /// Session tracking (reset per decision/activity)
    session: SessionState,
}
|
||||
|
||||
impl EconomicTracker {
|
||||
/// Create a new economic tracker.
///
/// Construction is cheap and performs no I/O; call `initialize` before
/// use to create the data directory and load any persisted state.
///
/// # Arguments
/// * `signature` - Agent signature/name for identification
/// * `config` - Economic configuration
/// * `data_path` - Optional custom data path (defaults to `./data/agent_data/{signature}/economic`)
pub fn new(
    signature: impl Into<String>,
    config: EconomicConfig,
    data_path: Option<PathBuf>,
) -> Self {
    let signature = signature.into();
    let data_path = data_path
        .unwrap_or_else(|| PathBuf::from(format!("./data/agent_data/{}/economic", signature)));

    Self {
        signature,
        // Balance starts at the configured initial capital; any persisted
        // balance overwrites it later in `initialize`.
        state: Arc::new(Mutex::new(TrackerState {
            balance: config.initial_balance,
            initial_balance: config.initial_balance,
            total_token_cost: 0.0,
            total_work_income: 0.0,
            total_trading_profit: 0.0,
            task: TaskState::default(),
            daily: DailyState::default(),
            session: SessionState::default(),
        })),
        config,
        data_path,
    }
}
|
||||
|
||||
/// Initialize the tracker, loading existing state or creating new.
///
/// Creates the data directory if needed. If `balance.jsonl` already
/// exists, the most recent record is loaded; otherwise a baseline
/// "initialization" record is written so the file is never empty.
///
/// # Errors
/// Returns an error if the data directory cannot be created or the
/// balance file cannot be read or written.
pub fn initialize(&self) -> Result<()> {
    fs::create_dir_all(&self.data_path).with_context(|| {
        format!(
            "Failed to create data directory: {}",
            self.data_path.display()
        )
    })?;

    let balance_file = self.balance_file_path();

    if balance_file.exists() {
        // Resume: restore balance and cumulative totals from disk.
        self.load_latest_state()?;
        let state = self.state.lock();
        tracing::info!(
            "📊 Loaded economic state for {}: balance=${:.2}, status={}",
            self.signature,
            state.balance,
            self.get_survival_status_inner(&state)
        );
    } else {
        // Fresh start: write the baseline record (all deltas zero).
        self.save_balance_record("initialization", 0.0, 0.0, 0.0, Vec::new(), false)?;
        tracing::info!(
            "✅ Initialized economic tracker for {}: starting balance=${:.2}",
            self.signature,
            self.config.initial_balance
        );
    }

    Ok(())
}
|
||||
|
||||
/// Start tracking costs for a new task.
///
/// Resets the per-task accumulators and records the task in today's
/// window. `date` defaults to today (UTC, `YYYY-MM-DD`) when `None`.
/// Starting a task while another is active silently discards the
/// previous task's unflushed per-task state.
pub fn start_task(&self, task_id: impl Into<String>, date: Option<String>) {
    let task_id = task_id.into();
    let date = date.unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string());
    let now = Utc::now();

    let mut state = self.state.lock();
    state.task.task_id = Some(task_id.clone());
    state.task.task_date = Some(date);
    state.task.start_time = Some(now);
    state.task.costs.reset();
    state.task.llm_calls.clear();
    state.task.api_calls.clear();

    // Track daily window
    if state.daily.first_task_start.is_none() {
        state.daily.first_task_start = Some(now);
    }
    state.daily.task_ids.push(task_id);
}
|
||||
|
||||
/// End tracking for current task and save consolidated record.
///
/// No-op when no task is active. On success the per-task state is
/// cleared and the daily "last task end" timestamp is updated.
///
/// # Errors
/// Returns an error if the consolidated task record cannot be written
/// to `token_costs.jsonl`.
pub fn end_task(&self) -> Result<()> {
    let mut state = self.state.lock();

    if state.task.task_id.is_some() {
        // Flush the task record while holding the lock so the snapshot
        // is consistent; the lock is not re-acquired inside.
        self.save_task_record_inner(&state)?;
        state.daily.last_task_end = Some(Utc::now());
        state.task.reset();
    }

    Ok(())
}
|
||||
|
||||
/// Track LLM token usage.
///
/// Updates session/daily/task accumulators, records the call detail,
/// and debits the balance immediately.
///
/// # Arguments
/// * `input_tokens` - Number of input tokens
/// * `output_tokens` - Number of output tokens
/// * `api_name` - Origin of the call (e.g., "agent", "wrapup")
/// * `cost` - Pre-computed cost (if provided, skips local calculation)
///
/// # Returns
/// The cost in USD for this call.
pub fn track_tokens(
    &self,
    input_tokens: u64,
    output_tokens: u64,
    api_name: impl Into<String>,
    cost: Option<f64>,
) -> f64 {
    let api_name = api_name.into();
    // Fall back to configured per-million pricing when no explicit cost
    // was supplied by the provider.
    let cost = cost.unwrap_or_else(|| {
        self.config
            .token_pricing
            .calculate_cost(input_tokens, output_tokens)
    });

    let mut state = self.state.lock();

    // Update session tracking
    state.session.input_tokens += input_tokens;
    state.session.output_tokens += output_tokens;
    state.session.cost += cost;
    state.daily.cost += cost;

    // Update task-level tracking
    state.task.costs.llm_tokens += cost;
    state.task.llm_calls.push(LlmCallRecord {
        timestamp: Utc::now(),
        api_name,
        input_tokens,
        output_tokens,
        cost,
    });

    // Update totals and debit the balance
    state.total_token_cost += cost;
    state.balance -= cost;

    cost
}
|
||||
|
||||
/// Track token-based API call cost.
///
/// Cost is `(tokens / 1M) * price_per_million`; recording and balance
/// debit are delegated to `record_api_cost`.
///
/// # Arguments
/// * `tokens` - Number of tokens used
/// * `price_per_million` - Price per million tokens
/// * `api_name` - Name of the API
///
/// # Returns
/// The cost in USD for this call.
pub fn track_api_call(
    &self,
    tokens: u64,
    price_per_million: f64,
    api_name: impl Into<String>,
) -> f64 {
    let api_name = api_name.into();
    let cost = (tokens as f64 / 1_000_000.0) * price_per_million;

    self.record_api_cost(
        &api_name,
        cost,
        Some(tokens),
        Some(price_per_million),
        PricingModel::PerToken,
    );

    cost
}
|
||||
|
||||
/// Track flat-rate API call cost.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cost` - Flat cost in USD
|
||||
/// * `api_name` - Name of the API
|
||||
///
|
||||
/// # Returns
|
||||
/// The cost (same as input).
|
||||
pub fn track_flat_api_call(&self, cost: f64, api_name: impl Into<String>) -> f64 {
|
||||
let api_name = api_name.into();
|
||||
self.record_api_cost(&api_name, cost, None, None, PricingModel::FlatRate);
|
||||
cost
|
||||
}
|
||||
|
||||
/// Record a non-LLM API cost: categorize by channel, log the call
/// detail, and debit the balance. Shared by the token-based and
/// flat-rate trackers.
fn record_api_cost(
    &self,
    api_name: &str,
    cost: f64,
    tokens: Option<u64>,
    price_per_million: Option<f64>,
    pricing_model: PricingModel,
) {
    let mut state = self.state.lock();

    // Update session/daily
    state.session.cost += cost;
    state.daily.cost += cost;

    // Categorize by API type — substring match on the lowercased name;
    // anything unrecognized falls into the "other" channel.
    let api_lower = api_name.to_lowercase();
    if api_lower.contains("search")
        || api_lower.contains("jina")
        || api_lower.contains("tavily")
    {
        state.task.costs.search_api += cost;
    } else if api_lower.contains("ocr") {
        state.task.costs.ocr_api += cost;
    } else {
        state.task.costs.other_api += cost;
    }

    // Record detailed call
    state.task.api_calls.push(ApiCallRecord {
        timestamp: Utc::now(),
        api_name: api_name.to_string(),
        pricing_model,
        tokens,
        price_per_million,
        cost,
    });

    // Update totals and debit the balance
    state.total_token_cost += cost;
    state.balance -= cost;
}
|
||||
|
||||
/// Add income from completed work with evaluation threshold.
///
/// Payment is only awarded if `evaluation_score >= min_evaluation_threshold`.
/// An income record is persisted via `log_work_income` in both cases
/// (payment and rejection), so the ledger reflects rejected work too.
///
/// # Arguments
/// * `amount` - Base payment amount in USD
/// * `task_id` - Task identifier
/// * `evaluation_score` - Score from 0.0 to 1.0
/// * `description` - Optional description
///
/// # Returns
/// Actual payment received (0.0 if below threshold).
///
/// # Errors
/// Returns an error if the income record cannot be persisted.
pub fn add_work_income(
    &self,
    amount: f64,
    task_id: impl Into<String>,
    evaluation_score: f64,
    description: impl Into<String>,
) -> Result<f64> {
    let task_id = task_id.into();
    let description = description.into();
    let threshold = self.config.min_evaluation_threshold;

    let actual_payment = if evaluation_score >= threshold {
        amount
    } else {
        0.0
    };

    // Scope the lock so it is released before the file write below.
    {
        let mut state = self.state.lock();
        if actual_payment > 0.0 {
            state.balance += actual_payment;
            state.total_work_income += actual_payment;
            tracing::info!(
                "💰 Work income: +${:.2} (Task: {}, Score: {:.2})",
                actual_payment,
                task_id,
                evaluation_score
            );
        } else {
            tracing::warn!(
                "⚠️ Work below threshold (score: {:.2} < {:.2}), no payment for task: {}",
                evaluation_score,
                threshold,
                task_id
            );
        }
    }

    self.log_work_income(
        &task_id,
        amount,
        actual_payment,
        evaluation_score,
        &description,
    )?;

    Ok(actual_payment)
}
|
||||
|
||||
/// Add profit/loss from trading.
|
||||
pub fn add_trading_profit(&self, profit: f64, _description: impl Into<String>) {
|
||||
let mut state = self.state.lock();
|
||||
state.balance += profit;
|
||||
state.total_trading_profit += profit;
|
||||
|
||||
let sign = if profit >= 0.0 { "+" } else { "" };
|
||||
tracing::info!(
|
||||
"📈 Trading P&L: {}${:.2}, new balance: ${:.2}",
|
||||
sign,
|
||||
profit,
|
||||
state.balance
|
||||
);
|
||||
}
|
||||
|
||||
/// Save end-of-day economic state.
///
/// Writes a balance record for `date`, then resets the daily and
/// session accumulators. The income/profit deltas are supplied by the
/// caller; the cost delta is taken from the daily accumulator.
///
/// # Errors
/// Returns an error if the balance record cannot be written.
pub fn save_daily_state(
    &self,
    date: &str,
    work_income: f64,
    trading_profit: f64,
    completed_tasks: Vec<String>,
    api_error: bool,
) -> Result<()> {
    // Snapshot the daily cost before writing; the lock is released so
    // save_balance_record can take it again.
    let daily_cost = {
        let state = self.state.lock();
        state.daily.cost
    };

    self.save_balance_record(
        date,
        daily_cost,
        work_income,
        trading_profit,
        completed_tasks,
        api_error,
    )?;

    // Reset daily tracking
    {
        let mut state = self.state.lock();
        state.daily.reset();
        state.session.reset();
    }

    tracing::info!("💾 Saved daily state for {}", date);

    Ok(())
}
|
||||
|
||||
/// Get current balance.
pub fn get_balance(&self) -> f64 {
    self.state.lock().balance
}

/// Get net worth (balance + portfolio value).
///
/// Currently identical to `get_balance` until trading portfolio
/// valuation is implemented.
pub fn get_net_worth(&self) -> f64 {
    // TODO: Add trading portfolio value
    self.get_balance()
}

/// Get current survival status, derived from balance vs. initial capital.
pub fn get_survival_status(&self) -> SurvivalStatus {
    let state = self.state.lock();
    self.get_survival_status_inner(&state)
}

// Lock-free variant for callers that already hold the state lock.
fn get_survival_status_inner(&self, state: &TrackerState) -> SurvivalStatus {
    SurvivalStatus::from_balance(state.balance, state.initial_balance)
}

/// Check if agent is bankrupt (balance at or below zero).
pub fn is_bankrupt(&self) -> bool {
    self.get_survival_status() == SurvivalStatus::Bankrupt
}

/// Get session cost so far.
pub fn get_session_cost(&self) -> f64 {
    self.state.lock().session.cost
}

/// Get daily cost so far.
pub fn get_daily_cost(&self) -> f64 {
    self.state.lock().daily.cost
}
|
||||
|
||||
/// Get comprehensive economic summary.
|
||||
pub fn get_summary(&self) -> EconomicSummary {
|
||||
let state = self.state.lock();
|
||||
EconomicSummary {
|
||||
signature: self.signature.clone(),
|
||||
balance: state.balance,
|
||||
initial_balance: state.initial_balance,
|
||||
net_worth: state.balance, // TODO: Add portfolio
|
||||
total_token_cost: state.total_token_cost,
|
||||
total_work_income: state.total_work_income,
|
||||
total_trading_profit: state.total_trading_profit,
|
||||
session_cost: state.session.cost,
|
||||
daily_cost: state.daily.cost,
|
||||
session_input_tokens: state.session.input_tokens,
|
||||
session_output_tokens: state.session.output_tokens,
|
||||
survival_status: self.get_survival_status_inner(&state),
|
||||
is_bankrupt: self.get_survival_status_inner(&state) == SurvivalStatus::Bankrupt,
|
||||
min_evaluation_threshold: self.config.min_evaluation_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset session tracking (for new decision/activity).
///
/// Clears session token counts and cost; daily and task state are
/// untouched.
pub fn reset_session(&self) {
    self.state.lock().session.reset();
}
|
||||
|
||||
/// Record task completion statistics.
|
||||
pub fn record_task_completion(
|
||||
&self,
|
||||
task_id: impl Into<String>,
|
||||
work_submitted: bool,
|
||||
wall_clock_seconds: f64,
|
||||
evaluation_score: f64,
|
||||
money_earned: f64,
|
||||
attempt: u32,
|
||||
date: Option<String>,
|
||||
) -> Result<()> {
|
||||
let task_id = task_id.into();
|
||||
let date = date
|
||||
.or_else(|| self.state.lock().task.task_date.clone())
|
||||
.unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string());
|
||||
|
||||
let record = TaskCompletionRecord {
|
||||
task_id: task_id.clone(),
|
||||
date,
|
||||
attempt,
|
||||
work_submitted,
|
||||
evaluation_score,
|
||||
money_earned,
|
||||
wall_clock_seconds,
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
|
||||
// Read existing records, filter out this task_id
|
||||
let completions_file = self.task_completions_file_path();
|
||||
let mut existing: Vec<String> = Vec::new();
|
||||
|
||||
if completions_file.exists() {
|
||||
let file = File::open(&completions_file)?;
|
||||
let reader = BufReader::new(file);
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(entry) = serde_json::from_str::<TaskCompletionRecord>(&line) {
|
||||
if entry.task_id != task_id {
|
||||
existing.push(line);
|
||||
}
|
||||
} else {
|
||||
existing.push(line);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite with updated record
|
||||
let mut file = File::create(&completions_file)?;
|
||||
for line in existing {
|
||||
writeln!(file, "{}", line)?;
|
||||
}
|
||||
writeln!(file, "{}", serde_json::to_string(&record)?)?;
|
||||
file.sync_all()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Private helpers ──

/// Path of the daily balance snapshot file.
fn balance_file_path(&self) -> PathBuf {
    self.data_path.join("balance.jsonl")
}

/// Path of the per-task detailed cost file.
fn token_costs_file_path(&self) -> PathBuf {
    self.data_path.join("token_costs.jsonl")
}

/// Path of the task completion statistics file.
fn task_completions_file_path(&self) -> PathBuf {
    self.data_path.join("task_completions.jsonl")
}
|
||||
|
||||
/// Restore balance and cumulative totals from the most recent record in
/// `balance.jsonl`.
///
/// Scans every line and keeps the last parseable record; malformed
/// lines are skipped silently (best-effort recovery). `initial_balance`
/// keeps the configured value so status percentages stay meaningful.
fn load_latest_state(&self) -> Result<()> {
    let balance_file = self.balance_file_path();
    let file = File::open(&balance_file)?;
    let reader = BufReader::new(file);

    let mut last_record: Option<BalanceRecord> = None;
    for line in reader.lines() {
        let line = line?;
        if let Ok(record) = serde_json::from_str::<BalanceRecord>(&line) {
            last_record = Some(record);
        }
    }

    if let Some(record) = last_record {
        let mut state = self.state.lock();
        state.balance = record.balance;
        state.total_token_cost = record.total_token_cost;
        state.total_work_income = record.total_work_income;
        state.total_trading_profit = record.total_trading_profit;
    }

    Ok(())
}
|
||||
|
||||
/// Append a consolidated cost record for the active task to
/// `token_costs.jsonl`.
///
/// Caller supplies the (already locked) state; returns `Ok(())`
/// immediately when no task is active.
fn save_task_record_inner(&self, state: &TrackerState) -> Result<()> {
    let Some(ref task_id) = state.task.task_id else {
        return Ok(());
    };

    // Aggregate LLM usage across the task's recorded calls.
    let total_input = state.task.llm_calls.iter().map(|c| c.input_tokens).sum();
    let total_output = state.task.llm_calls.iter().map(|c| c.output_tokens).sum();
    let llm_call_count = state.task.llm_calls.len();

    // Count API calls per pricing model for the usage summary.
    let token_based = state
        .task
        .api_calls
        .iter()
        .filter(|c| c.pricing_model == PricingModel::PerToken)
        .count();
    let flat_rate = state
        .task
        .api_calls
        .iter()
        .filter(|c| c.pricing_model == PricingModel::FlatRate)
        .count();

    let record = TaskCostRecord {
        timestamp_end: Utc::now(),
        // Missing start time falls back to "now" (zero-length task).
        timestamp_start: state.task.start_time.unwrap_or_else(Utc::now),
        date: state
            .task
            .task_date
            .clone()
            .unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string()),
        task_id: task_id.clone(),
        llm_usage: LlmUsageSummary {
            total_calls: llm_call_count,
            total_input_tokens: total_input,
            total_output_tokens: total_output,
            total_tokens: total_input + total_output,
            total_cost: state.task.costs.llm_tokens,
            input_price_per_million: self.config.token_pricing.input_price_per_million,
            output_price_per_million: self.config.token_pricing.output_price_per_million,
            calls_detail: state.task.llm_calls.clone(),
        },
        api_usage: ApiUsageSummary {
            total_calls: state.task.api_calls.len(),
            search_api_cost: state.task.costs.search_api,
            ocr_api_cost: state.task.costs.ocr_api,
            other_api_cost: state.task.costs.other_api,
            token_based_calls: token_based,
            flat_rate_calls: flat_rate,
            calls_detail: state.task.api_calls.clone(),
        },
        cost_summary: state.task.costs.clone(),
        balance_after: state.balance,
        session_cost: state.session.cost,
        daily_cost: state.daily.cost,
    };

    // Append-only JSONL write, fsynced for durability.
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(self.token_costs_file_path())?;
    writeln!(file, "{}", serde_json::to_string(&record)?)?;
    file.sync_all()?;

    Ok(())
}
|
||||
|
||||
fn save_balance_record(
|
||||
&self,
|
||||
date: &str,
|
||||
token_cost_delta: f64,
|
||||
work_income_delta: f64,
|
||||
trading_profit_delta: f64,
|
||||
completed_tasks: Vec<String>,
|
||||
api_error: bool,
|
||||
) -> Result<()> {
|
||||
let state = self.state.lock();
|
||||
|
||||
let task_completion_time = match (state.daily.first_task_start, state.daily.last_task_end) {
|
||||
(Some(start), Some(end)) => Some((end - start).num_seconds() as f64),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let record = BalanceRecord {
|
||||
date: date.to_string(),
|
||||
balance: state.balance,
|
||||
token_cost_delta,
|
||||
work_income_delta,
|
||||
trading_profit_delta,
|
||||
total_token_cost: state.total_token_cost,
|
||||
total_work_income: state.total_work_income,
|
||||
total_trading_profit: state.total_trading_profit,
|
||||
net_worth: state.balance,
|
||||
survival_status: self.get_survival_status_inner(&state).to_string(),
|
||||
completed_tasks,
|
||||
task_id: state.daily.task_ids.first().cloned(),
|
||||
task_completion_time_seconds: task_completion_time,
|
||||
api_error,
|
||||
};
|
||||
|
||||
drop(state); // Release lock before IO
|
||||
|
||||
let mut file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(self.balance_file_path())?;
|
||||
writeln!(file, "{}", serde_json::to_string(&record)?)?;
|
||||
file.sync_all()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn log_work_income(
|
||||
&self,
|
||||
task_id: &str,
|
||||
base_amount: f64,
|
||||
actual_payment: f64,
|
||||
evaluation_score: f64,
|
||||
description: &str,
|
||||
) -> Result<()> {
|
||||
let state = self.state.lock();
|
||||
|
||||
let record = WorkIncomeRecord {
|
||||
timestamp: Utc::now(),
|
||||
date: state
|
||||
.task
|
||||
.task_date
|
||||
.clone()
|
||||
.unwrap_or_else(|| Utc::now().format("%Y-%m-%d").to_string()),
|
||||
task_id: task_id.to_string(),
|
||||
base_amount,
|
||||
actual_payment,
|
||||
evaluation_score,
|
||||
threshold: self.config.min_evaluation_threshold,
|
||||
payment_awarded: actual_payment > 0.0,
|
||||
description: description.to_string(),
|
||||
balance_after: state.balance,
|
||||
};
|
||||
|
||||
drop(state);
|
||||
|
||||
let mut file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(self.token_costs_file_path())?;
|
||||
writeln!(file, "{}", serde_json::to_string(&record)?)?;
|
||||
file.sync_all()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EconomicTracker {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let state = self.state.lock();
|
||||
write!(
|
||||
f,
|
||||
"EconomicTracker(signature='{}', balance=${:.2}, status={})",
|
||||
self.signature,
|
||||
state.balance,
|
||||
self.get_survival_status_inner(&state)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Comprehensive economic summary.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EconomicSummary {
|
||||
pub signature: String,
|
||||
pub balance: f64,
|
||||
pub initial_balance: f64,
|
||||
pub net_worth: f64,
|
||||
pub total_token_cost: f64,
|
||||
pub total_work_income: f64,
|
||||
pub total_trading_profit: f64,
|
||||
pub session_cost: f64,
|
||||
pub daily_cost: f64,
|
||||
pub session_input_tokens: u64,
|
||||
pub session_output_tokens: u64,
|
||||
pub survival_status: SurvivalStatus,
|
||||
pub is_bankrupt: bool,
|
||||
pub min_evaluation_threshold: f64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_config() -> EconomicConfig {
|
||||
EconomicConfig {
|
||||
enabled: true,
|
||||
initial_balance: 1000.0,
|
||||
token_pricing: TokenPricing {
|
||||
input_price_per_million: 3.0,
|
||||
output_price_per_million: 15.0,
|
||||
},
|
||||
min_evaluation_threshold: 0.6,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_initialization() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config();
|
||||
let tracker = EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
|
||||
|
||||
tracker.initialize().unwrap();
|
||||
|
||||
assert!((tracker.get_balance() - 1000.0).abs() < f64::EPSILON);
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Thriving);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn track_tokens_reduces_balance() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tracker =
|
||||
EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
|
||||
tracker.start_task("task-1", None);
|
||||
let cost = tracker.track_tokens(1000, 500, "agent", None);
|
||||
tracker.end_task().unwrap();
|
||||
|
||||
// (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
|
||||
assert!((cost - 0.0105).abs() < 0.0001);
|
||||
assert!((tracker.get_balance() - (1000.0 - 0.0105)).abs() < 0.0001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn work_income_with_threshold() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tracker =
|
||||
EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
|
||||
// Below threshold - no payment
|
||||
let payment = tracker.add_work_income(100.0, "task-1", 0.5, "").unwrap();
|
||||
assert!((payment - 0.0).abs() < f64::EPSILON);
|
||||
assert!((tracker.get_balance() - 1000.0).abs() < f64::EPSILON);
|
||||
|
||||
// At threshold - payment awarded
|
||||
let payment = tracker.add_work_income(100.0, "task-2", 0.6, "").unwrap();
|
||||
assert!((payment - 100.0).abs() < f64::EPSILON);
|
||||
assert!((tracker.get_balance() - 1100.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn survival_status_changes() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut config = test_config();
|
||||
config.initial_balance = 100.0;
|
||||
|
||||
let tracker = EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Thriving);
|
||||
|
||||
// Spend 30% - should be stable
|
||||
tracker.track_tokens(10_000_000, 0, "agent", Some(30.0));
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Stable);
|
||||
|
||||
// Spend more to reach struggling
|
||||
tracker.track_tokens(10_000_000, 0, "agent", Some(35.0));
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Struggling);
|
||||
|
||||
// Spend more to reach critical
|
||||
tracker.track_tokens(10_000_000, 0, "agent", Some(25.0));
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Critical);
|
||||
|
||||
// Bankrupt
|
||||
tracker.track_tokens(10_000_000, 0, "agent", Some(20.0));
|
||||
assert_eq!(tracker.get_survival_status(), SurvivalStatus::Bankrupt);
|
||||
assert!(tracker.is_bankrupt());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_persistence() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config();
|
||||
|
||||
// Create tracker, do some work, save state
|
||||
{
|
||||
let tracker =
|
||||
EconomicTracker::new("test-agent", config.clone(), Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
tracker.track_tokens(1000, 500, "agent", Some(10.0));
|
||||
tracker
|
||||
.save_daily_state("2025-01-01", 0.0, 0.0, vec![], false)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Create new tracker, should load state
|
||||
{
|
||||
let tracker =
|
||||
EconomicTracker::new("test-agent", config, Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
assert!((tracker.get_balance() - 990.0).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_call_categorization() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tracker =
|
||||
EconomicTracker::new("test-agent", test_config(), Some(tmp.path().to_path_buf()));
|
||||
tracker.initialize().unwrap();
|
||||
|
||||
tracker.start_task("task-1", None);
|
||||
|
||||
// Search API
|
||||
tracker.track_flat_api_call(0.001, "tavily_search");
|
||||
|
||||
// OCR API
|
||||
tracker.track_api_call(1000, 1.0, "ocr_reader");
|
||||
|
||||
// Other API
|
||||
tracker.track_flat_api_call(0.01, "some_api");
|
||||
|
||||
tracker.end_task().unwrap();
|
||||
|
||||
// Balance should reflect all costs
|
||||
let expected_reduction = 0.001 + 0.001 + 0.01; // search + ocr + other
|
||||
assert!((tracker.get_balance() - (1000.0 - expected_reduction)).abs() < 0.0001);
|
||||
}
|
||||
}
|
||||
+8
-4
@@ -362,6 +362,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
&providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
provider_transport: config.effective_provider_transport(),
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||||
@@ -662,11 +663,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
}
|
||||
|
||||
// Wrap observer with broadcast capability for SSE
|
||||
// Use cost-tracking observer when cost tracking is enabled
|
||||
let base_observer = crate::observability::create_observer_with_cost_tracking(
|
||||
&config.observability,
|
||||
cost_tracker.clone(),
|
||||
&config.cost,
|
||||
);
|
||||
let broadcast_observer: Arc<dyn crate::observability::Observer> =
|
||||
Arc::new(sse::BroadcastObserver::new(
|
||||
crate::observability::create_observer(&config.observability),
|
||||
event_tx.clone(),
|
||||
));
|
||||
Arc::new(sse::BroadcastObserver::new(base_observer, event_tx.clone()));
|
||||
|
||||
let state = AppState {
|
||||
config: config_state,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{IntegrationCategory, IntegrationEntry, IntegrationStatus};
|
||||
use crate::providers::{
|
||||
is_glm_alias, is_minimax_alias, is_moonshot_alias, is_qianfan_alias, is_qwen_alias,
|
||||
is_zai_alias,
|
||||
is_doubao_alias, is_glm_alias, is_minimax_alias, is_moonshot_alias, is_qianfan_alias,
|
||||
is_qwen_alias, is_siliconflow_alias, is_zai_alias,
|
||||
};
|
||||
|
||||
/// Returns the full catalog of integrations
|
||||
@@ -436,6 +436,33 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Volcengine ARK",
|
||||
description: "Doubao and ARK model catalog",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref().is_some_and(is_doubao_alias) {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "SiliconFlow",
|
||||
description: "OpenAI-compatible hosted models and reasoning",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider
|
||||
.as_deref()
|
||||
.is_some_and(is_siliconflow_alias)
|
||||
{
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Groq",
|
||||
description: "Llama 3.3 70B Versatile and low-latency models",
|
||||
@@ -991,5 +1018,31 @@ mod tests {
|
||||
(qianfan.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
|
||||
config.default_provider = Some("ark".to_string());
|
||||
let volcengine = entries.iter().find(|e| e.name == "Volcengine ARK").unwrap();
|
||||
assert!(matches!(
|
||||
(volcengine.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
|
||||
config.default_provider = Some("volcengine".to_string());
|
||||
assert!(matches!(
|
||||
(volcengine.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
|
||||
config.default_provider = Some("siliconflow".to_string());
|
||||
let siliconflow = entries.iter().find(|e| e.name == "SiliconFlow").unwrap();
|
||||
assert!(matches!(
|
||||
(siliconflow.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
|
||||
config.default_provider = Some("silicon-cloud".to_string());
|
||||
assert!(matches!(
|
||||
(siliconflow.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ pub(crate) mod cost;
|
||||
pub(crate) mod cron;
|
||||
pub(crate) mod daemon;
|
||||
pub(crate) mod doctor;
|
||||
pub mod economic;
|
||||
pub mod gateway;
|
||||
pub mod goals;
|
||||
pub(crate) mod hardware;
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
//! Cost-tracking observer that wires provider token usage to the cost tracker.
|
||||
//!
|
||||
//! Intercepts `LlmResponse` events and records usage to the `CostTracker`,
|
||||
//! calculating costs based on model pricing configuration.
|
||||
|
||||
use super::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
use crate::config::schema::ModelPricing;
|
||||
use crate::cost::{CostTracker, TokenUsage};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Observer that records token usage to a CostTracker.
|
||||
///
|
||||
/// Listens for `LlmResponse` events and calculates costs using model pricing.
|
||||
pub struct CostObserver {
|
||||
tracker: Arc<CostTracker>,
|
||||
prices: HashMap<String, ModelPricing>,
|
||||
/// Default pricing for unknown models (USD per 1M tokens)
|
||||
default_input_price: f64,
|
||||
default_output_price: f64,
|
||||
}
|
||||
|
||||
impl CostObserver {
|
||||
/// Create a new cost observer with the given tracker and pricing config.
|
||||
pub fn new(tracker: Arc<CostTracker>, prices: HashMap<String, ModelPricing>) -> Self {
|
||||
Self {
|
||||
tracker,
|
||||
prices,
|
||||
// Conservative defaults for unknown models
|
||||
default_input_price: 3.0,
|
||||
default_output_price: 15.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up pricing for a model, trying various name formats.
|
||||
fn get_pricing(&self, provider: &str, model: &str) -> (f64, f64) {
|
||||
// Try exact match first: "provider/model"
|
||||
let full_name = format!("{provider}/{model}");
|
||||
if let Some(pricing) = self.prices.get(&full_name) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
|
||||
// Try just the model name
|
||||
if let Some(pricing) = self.prices.get(model) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
|
||||
// Try model family matching (e.g., "claude-sonnet-4" matches any claude-sonnet-4-*)
|
||||
for (key, pricing) in &self.prices {
|
||||
// Strip provider prefix if present
|
||||
let key_model = key.split('/').next_back().unwrap_or(key);
|
||||
|
||||
// Check if model starts with the key (family match)
|
||||
if model.starts_with(key_model) || key_model.starts_with(model) {
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
|
||||
// Check for common model name patterns
|
||||
// e.g., "claude-3-5-sonnet-20241022" should match "claude-3.5-sonnet"
|
||||
let normalized_model = model.replace('-', ".");
|
||||
let normalized_key = key_model.replace('-', ".");
|
||||
if normalized_model.contains(&normalized_key)
|
||||
|| normalized_key.contains(&normalized_model)
|
||||
{
|
||||
return (pricing.input, pricing.output);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to defaults
|
||||
tracing::debug!(
|
||||
"No pricing found for {}/{}, using defaults (${}/{} per 1M tokens)",
|
||||
provider,
|
||||
model,
|
||||
self.default_input_price,
|
||||
self.default_output_price
|
||||
);
|
||||
(self.default_input_price, self.default_output_price)
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer for CostObserver {
|
||||
fn record_event(&self, event: &ObserverEvent) {
|
||||
if let ObserverEvent::LlmResponse {
|
||||
provider,
|
||||
model,
|
||||
success: true,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
..
|
||||
} = event
|
||||
{
|
||||
// Only record if we have token counts
|
||||
let input = input_tokens.unwrap_or(0);
|
||||
let output = output_tokens.unwrap_or(0);
|
||||
|
||||
if input == 0 && output == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let (input_price, output_price) = self.get_pricing(provider, model);
|
||||
let full_model_name = format!("{provider}/{model}");
|
||||
|
||||
let usage = TokenUsage::new(full_model_name, input, output, input_price, output_price);
|
||||
|
||||
if let Err(e) = self.tracker.record_usage(usage) {
|
||||
tracing::warn!("Failed to record cost usage: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_metric(&self, _metric: &ObserverMetric) {
|
||||
// Cost observer doesn't handle metrics
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"cost"
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::CostConfig;
|
||||
use std::time::Duration;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn create_test_tracker() -> (TempDir, Arc<CostTracker>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = CostConfig {
|
||||
enabled: true,
|
||||
..Default::default()
|
||||
};
|
||||
let tracker = Arc::new(CostTracker::new(config, tmp.path()).unwrap());
|
||||
(tmp, tracker)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_observer_records_llm_response_usage() {
|
||||
let (_tmp, tracker) = create_test_tracker();
|
||||
let mut prices = HashMap::new();
|
||||
prices.insert(
|
||||
"anthropic/claude-sonnet-4-20250514".into(),
|
||||
ModelPricing {
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
},
|
||||
);
|
||||
|
||||
let observer = CostObserver::new(tracker.clone(), prices);
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: "anthropic".into(),
|
||||
model: "claude-sonnet-4-20250514".into(),
|
||||
duration: Duration::from_millis(100),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: Some(1000),
|
||||
output_tokens: Some(500),
|
||||
});
|
||||
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 1);
|
||||
// Cost: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105
|
||||
assert!((summary.session_cost_usd - 0.0105).abs() < 0.0001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_observer_ignores_failed_responses() {
|
||||
let (_tmp, tracker) = create_test_tracker();
|
||||
let observer = CostObserver::new(tracker.clone(), HashMap::new());
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: "anthropic".into(),
|
||||
model: "claude-sonnet-4".into(),
|
||||
duration: Duration::from_millis(100),
|
||||
success: false,
|
||||
error_message: Some("API error".into()),
|
||||
input_tokens: Some(1000),
|
||||
output_tokens: Some(500),
|
||||
});
|
||||
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_observer_ignores_zero_token_responses() {
|
||||
let (_tmp, tracker) = create_test_tracker();
|
||||
let observer = CostObserver::new(tracker.clone(), HashMap::new());
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: "anthropic".into(),
|
||||
model: "claude-sonnet-4".into(),
|
||||
duration: Duration::from_millis(100),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
});
|
||||
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_observer_uses_default_pricing_for_unknown_models() {
|
||||
let (_tmp, tracker) = create_test_tracker();
|
||||
let observer = CostObserver::new(tracker.clone(), HashMap::new());
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: "unknown".into(),
|
||||
model: "mystery-model".into(),
|
||||
duration: Duration::from_millis(100),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: Some(1_000_000), // 1M tokens
|
||||
output_tokens: Some(1_000_000),
|
||||
});
|
||||
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
assert_eq!(summary.request_count, 1);
|
||||
// Default: $3 input + $15 output = $18 for 1M each
|
||||
assert!((summary.session_cost_usd - 18.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cost_observer_matches_model_family() {
|
||||
let (_tmp, tracker) = create_test_tracker();
|
||||
let mut prices = HashMap::new();
|
||||
prices.insert(
|
||||
"openai/gpt-4o".into(),
|
||||
ModelPricing {
|
||||
input: 5.0,
|
||||
output: 15.0,
|
||||
},
|
||||
);
|
||||
|
||||
let observer = CostObserver::new(tracker.clone(), prices);
|
||||
|
||||
// Model name with version suffix should still match
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: "openai".into(),
|
||||
model: "gpt-4o-2024-05-13".into(),
|
||||
duration: Duration::from_millis(100),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: Some(1_000_000),
|
||||
output_tokens: Some(0),
|
||||
});
|
||||
|
||||
let summary = tracker.get_summary().unwrap();
|
||||
// Should use $5 input price, not default $3
|
||||
assert!((summary.session_cost_usd - 5.0).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod cost;
|
||||
pub mod log;
|
||||
pub mod multi;
|
||||
pub mod noop;
|
||||
@@ -12,6 +13,7 @@ pub mod verbose;
|
||||
pub use self::log::LogObserver;
|
||||
#[allow(unused_imports)]
|
||||
pub use self::multi::MultiObserver;
|
||||
pub use cost::CostObserver;
|
||||
pub use noop::NoopObserver;
|
||||
#[cfg(feature = "observability-otel")]
|
||||
pub use otel::OtelObserver;
|
||||
@@ -20,10 +22,40 @@ pub use traits::{Observer, ObserverEvent};
|
||||
#[allow(unused_imports)]
|
||||
pub use verbose::VerboseObserver;
|
||||
|
||||
use crate::config::schema::CostConfig;
|
||||
use crate::config::ObservabilityConfig;
|
||||
use crate::cost::CostTracker;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory: create the right observer from config
|
||||
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
|
||||
create_observer_internal(config)
|
||||
}
|
||||
|
||||
/// Create an observer stack with optional cost tracking.
|
||||
///
|
||||
/// When cost tracking is enabled, wraps the base observer in a MultiObserver
|
||||
/// that also includes a CostObserver for recording token usage.
|
||||
pub fn create_observer_with_cost_tracking(
|
||||
config: &ObservabilityConfig,
|
||||
cost_tracker: Option<Arc<CostTracker>>,
|
||||
cost_config: &CostConfig,
|
||||
) -> Box<dyn Observer> {
|
||||
let base_observer = create_observer_internal(config);
|
||||
|
||||
match cost_tracker {
|
||||
Some(tracker) if cost_config.enabled => {
|
||||
let cost_observer = CostObserver::new(tracker, cost_config.prices.clone());
|
||||
Box::new(MultiObserver::new(vec![
|
||||
base_observer,
|
||||
Box::new(cost_observer),
|
||||
]))
|
||||
}
|
||||
_ => base_observer,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_observer_internal(config: &ObservabilityConfig) -> Box<dyn Observer> {
|
||||
match config.backend.as_str() {
|
||||
"log" => Box::new(LogObserver::new()),
|
||||
"prometheus" => Box::new(PrometheusObserver::new()),
|
||||
|
||||
+139
-5
@@ -18,9 +18,9 @@ use crate::memory::{
|
||||
selectable_memory_backends, MemoryBackendKind,
|
||||
};
|
||||
use crate::providers::{
|
||||
canonical_china_provider_name, is_glm_alias, is_glm_cn_alias, is_minimax_alias,
|
||||
is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_qwen_oauth_alias, is_zai_alias,
|
||||
is_zai_cn_alias,
|
||||
canonical_china_provider_name, is_doubao_alias, is_glm_alias, is_glm_cn_alias,
|
||||
is_minimax_alias, is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_qwen_oauth_alias,
|
||||
is_siliconflow_alias, is_zai_alias, is_zai_cn_alias,
|
||||
};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use console::style;
|
||||
@@ -186,6 +186,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: identity_config,
|
||||
cost: crate::config::CostConfig::default(),
|
||||
economic: crate::config::EconomicConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
@@ -550,6 +551,7 @@ async fn run_quick_setup_with_home(
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
economic: crate::config::EconomicConfig::default(),
|
||||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
hooks: crate::config::HooksConfig::default(),
|
||||
@@ -710,6 +712,9 @@ fn canonical_provider_name(provider_name: &str) -> &str {
|
||||
}
|
||||
|
||||
if let Some(canonical) = canonical_china_provider_name(provider_name) {
|
||||
if canonical == "doubao" {
|
||||
return "volcengine";
|
||||
}
|
||||
return canonical;
|
||||
}
|
||||
|
||||
@@ -775,6 +780,8 @@ fn default_model_for_provider(provider: &str) -> String {
|
||||
"glm" | "zai" => "glm-5".into(),
|
||||
"minimax" => "MiniMax-M2.5".into(),
|
||||
"qwen" => "qwen-plus".into(),
|
||||
"volcengine" => "doubao-1-5-pro-32k-250115".into(),
|
||||
"siliconflow" => "Pro/zai-org/GLM-4.7".into(),
|
||||
"qwen-code" => "qwen3-coder-plus".into(),
|
||||
"ollama" => "llama3.2".into(),
|
||||
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF".into(),
|
||||
@@ -1088,6 +1095,31 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
||||
"Qwen Turbo (fast and cost-efficient)".to_string(),
|
||||
),
|
||||
],
|
||||
"volcengine" => vec![
|
||||
(
|
||||
"doubao-1-5-pro-32k-250115".to_string(),
|
||||
"Doubao 1.5 Pro 32K (official sample model)".to_string(),
|
||||
),
|
||||
(
|
||||
"doubao-seed-1-6-250615".to_string(),
|
||||
"Doubao Seed 1.6 (reasoning flagship)".to_string(),
|
||||
),
|
||||
(
|
||||
"deepseek-v3.2".to_string(),
|
||||
"DeepSeek V3.2 (available in ARK catalog)".to_string(),
|
||||
),
|
||||
],
|
||||
"siliconflow" => vec![
|
||||
(
|
||||
"Pro/zai-org/GLM-4.7".to_string(),
|
||||
"GLM-4.7 Pro (official API example)".to_string(),
|
||||
),
|
||||
(
|
||||
"Pro/deepseek-ai/DeepSeek-V3.2".to_string(),
|
||||
"DeepSeek V3.2 Pro".to_string(),
|
||||
),
|
||||
("Qwen/Qwen3-32B".to_string(), "Qwen3 32B".to_string()),
|
||||
],
|
||||
"qwen-code" => vec![
|
||||
(
|
||||
"qwen3-coder-plus".to_string(),
|
||||
@@ -1264,6 +1296,8 @@ fn supports_live_model_fetch(provider_name: &str) -> bool {
|
||||
| "glm"
|
||||
| "zai"
|
||||
| "qwen"
|
||||
| "volcengine"
|
||||
| "siliconflow"
|
||||
| "nvidia"
|
||||
)
|
||||
}
|
||||
@@ -1276,6 +1310,9 @@ fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
|
||||
"moonshot-cn" | "kimi-cn" => Some("https://api.moonshot.cn/v1/models"),
|
||||
"glm-cn" | "bigmodel" => Some("https://open.bigmodel.cn/api/paas/v4/models"),
|
||||
"zai-cn" | "z.ai-cn" => Some("https://open.bigmodel.cn/api/coding/paas/v4/models"),
|
||||
"volcengine" | "ark" | "doubao" | "doubao-cn" => {
|
||||
Some("https://ark.cn-beijing.volces.com/api/v3/models")
|
||||
}
|
||||
_ => match canonical_provider_name(provider_name) {
|
||||
"openai-codex" | "openai" => Some("https://api.openai.com/v1/models"),
|
||||
"venice" => Some("https://api.venice.ai/api/v1/models"),
|
||||
@@ -1291,6 +1328,7 @@ fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
|
||||
"glm" => Some("https://api.z.ai/api/paas/v4/models"),
|
||||
"zai" => Some("https://api.z.ai/api/coding/paas/v4/models"),
|
||||
"qwen" => Some("https://dashscope.aliyuncs.com/compatible-mode/v1/models"),
|
||||
"siliconflow" => Some("https://api.siliconflow.cn/v1/models"),
|
||||
"nvidia" => Some("https://integrate.api.nvidia.com/v1/models"),
|
||||
"astrai" => Some("https://as-trai.com/v1/models"),
|
||||
"llamacpp" => Some("http://localhost:8080/v1/models"),
|
||||
@@ -2303,6 +2341,11 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
|
||||
("qwen-us", "Qwen — DashScope US endpoint"),
|
||||
("hunyuan", "Hunyuan — Tencent large models (T1, Turbo, Pro)"),
|
||||
("qianfan", "Qianfan — Baidu AI models (China endpoint)"),
|
||||
("volcengine", "Volcengine ARK — Doubao model family"),
|
||||
(
|
||||
"siliconflow",
|
||||
"SiliconFlow — OpenAI-compatible hosted models",
|
||||
),
|
||||
("zai", "Z.AI — global coding endpoint"),
|
||||
("zai-cn", "Z.AI — China coding endpoint (open.bigmodel.cn)"),
|
||||
("synthetic", "Synthetic — Synthetic AI models"),
|
||||
@@ -2697,6 +2740,10 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
|
||||
"https://help.aliyun.com/zh/model-studio/developer-reference/get-api-key"
|
||||
} else if is_qianfan_alias(provider_name) {
|
||||
"https://cloud.baidu.com/doc/WENXINWORKSHOP/s/7lm0vxo78"
|
||||
} else if is_doubao_alias(provider_name) {
|
||||
"https://console.volcengine.com/ark/region:ark+cn-beijing/apiKey"
|
||||
} else if is_siliconflow_alias(provider_name) {
|
||||
"https://cloud.siliconflow.cn/account/ak"
|
||||
} else {
|
||||
match provider_name {
|
||||
"openrouter" => "https://openrouter.ai/keys",
|
||||
@@ -3005,6 +3052,8 @@ fn provider_env_var(name: &str) -> &'static str {
|
||||
"glm" => "GLM_API_KEY",
|
||||
"minimax" => "MINIMAX_API_KEY",
|
||||
"qwen" => "DASHSCOPE_API_KEY",
|
||||
"volcengine" => "ARK_API_KEY",
|
||||
"siliconflow" => "SILICONFLOW_API_KEY",
|
||||
"hunyuan" => "HUNYUAN_API_KEY",
|
||||
"qianfan" => "QIANFAN_API_KEY",
|
||||
"zai" => "ZAI_API_KEY",
|
||||
@@ -7305,6 +7354,14 @@ mod tests {
|
||||
assert_eq!(default_model_for_provider("moonshot"), "kimi-k2.5");
|
||||
assert_eq!(default_model_for_provider("hunyuan"), "hunyuan-t1-latest");
|
||||
assert_eq!(default_model_for_provider("tencent"), "hunyuan-t1-latest");
|
||||
assert_eq!(
|
||||
default_model_for_provider("siliconflow"),
|
||||
"Pro/zai-org/GLM-4.7"
|
||||
);
|
||||
assert_eq!(
|
||||
default_model_for_provider("volcengine"),
|
||||
"doubao-1-5-pro-32k-250115"
|
||||
);
|
||||
assert_eq!(
|
||||
default_model_for_provider("nvidia"),
|
||||
"meta/llama-3.3-70b-instruct"
|
||||
@@ -7343,6 +7400,10 @@ mod tests {
|
||||
assert_eq!(canonical_provider_name("minimax-cn"), "minimax");
|
||||
assert_eq!(canonical_provider_name("zai-cn"), "zai");
|
||||
assert_eq!(canonical_provider_name("z.ai-global"), "zai");
|
||||
assert_eq!(canonical_provider_name("doubao"), "volcengine");
|
||||
assert_eq!(canonical_provider_name("ark"), "volcengine");
|
||||
assert_eq!(canonical_provider_name("silicon-cloud"), "siliconflow");
|
||||
assert_eq!(canonical_provider_name("siliconcloud"), "siliconflow");
|
||||
assert_eq!(canonical_provider_name("nvidia-nim"), "nvidia");
|
||||
assert_eq!(canonical_provider_name("aws-bedrock"), "bedrock");
|
||||
assert_eq!(canonical_provider_name("build.nvidia.com"), "nvidia");
|
||||
@@ -7485,6 +7546,23 @@ mod tests {
|
||||
assert!(ids.contains(&"qwen3-max-2026-01-23".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_volcengine_and_siliconflow_include_expected_defaults() {
|
||||
let volcengine_ids: Vec<String> = curated_models_for_provider("volcengine")
|
||||
.into_iter()
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
assert!(volcengine_ids.contains(&"doubao-1-5-pro-32k-250115".to_string()));
|
||||
assert!(volcengine_ids.contains(&"doubao-seed-1-6-250615".to_string()));
|
||||
|
||||
let siliconflow_ids: Vec<String> = curated_models_for_provider("siliconflow")
|
||||
.into_iter()
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
assert!(siliconflow_ids.contains(&"Pro/zai-org/GLM-4.7".to_string()));
|
||||
assert!(siliconflow_ids.contains(&"Pro/deepseek-ai/DeepSeek-V3.2".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_live_model_fetch_for_supported_and_unsupported_providers() {
|
||||
assert!(supports_live_model_fetch("openai"));
|
||||
@@ -7506,6 +7584,11 @@ mod tests {
|
||||
assert!(supports_live_model_fetch("glm-cn"));
|
||||
assert!(supports_live_model_fetch("qwen-intl"));
|
||||
assert!(supports_live_model_fetch("qwen-coding-plan"));
|
||||
assert!(supports_live_model_fetch("siliconflow"));
|
||||
assert!(supports_live_model_fetch("silicon-cloud"));
|
||||
assert!(supports_live_model_fetch("volcengine"));
|
||||
assert!(supports_live_model_fetch("doubao"));
|
||||
assert!(supports_live_model_fetch("ark"));
|
||||
assert!(!supports_live_model_fetch("minimax-cn"));
|
||||
assert!(!supports_live_model_fetch("unknown-provider"));
|
||||
}
|
||||
@@ -7564,6 +7647,22 @@ mod tests {
|
||||
curated_models_for_provider("bedrock"),
|
||||
curated_models_for_provider("aws-bedrock")
|
||||
);
|
||||
assert_eq!(
|
||||
curated_models_for_provider("volcengine"),
|
||||
curated_models_for_provider("doubao")
|
||||
);
|
||||
assert_eq!(
|
||||
curated_models_for_provider("volcengine"),
|
||||
curated_models_for_provider("ark")
|
||||
);
|
||||
assert_eq!(
|
||||
curated_models_for_provider("siliconflow"),
|
||||
curated_models_for_provider("silicon-cloud")
|
||||
);
|
||||
assert_eq!(
|
||||
curated_models_for_provider("siliconflow"),
|
||||
curated_models_for_provider("siliconcloud")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -7596,6 +7695,18 @@ mod tests {
|
||||
models_endpoint_for_provider("qwen-coding-plan"),
|
||||
Some("https://coding.dashscope.aliyuncs.com/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("volcengine"),
|
||||
Some("https://ark.cn-beijing.volces.com/api/v3/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("doubao"),
|
||||
Some("https://ark.cn-beijing.volces.com/api/v3/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("ark"),
|
||||
Some("https://ark.cn-beijing.volces.com/api/v3/models")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -7616,6 +7727,14 @@ mod tests {
|
||||
models_endpoint_for_provider("moonshot"),
|
||||
Some("https://api.moonshot.ai/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("siliconflow"),
|
||||
Some("https://api.siliconflow.cn/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("silicon-cloud"),
|
||||
Some("https://api.siliconflow.cn/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("llamacpp"),
|
||||
Some("http://localhost:8080/v1/models")
|
||||
@@ -7914,6 +8033,12 @@ mod tests {
|
||||
assert_eq!(provider_env_var("minimax-oauth-cn"), "MINIMAX_API_KEY");
|
||||
assert_eq!(provider_env_var("moonshot-intl"), "MOONSHOT_API_KEY");
|
||||
assert_eq!(provider_env_var("zai-cn"), "ZAI_API_KEY");
|
||||
assert_eq!(provider_env_var("doubao"), "ARK_API_KEY");
|
||||
assert_eq!(provider_env_var("volcengine"), "ARK_API_KEY");
|
||||
assert_eq!(provider_env_var("ark"), "ARK_API_KEY");
|
||||
assert_eq!(provider_env_var("siliconflow"), "SILICONFLOW_API_KEY");
|
||||
assert_eq!(provider_env_var("silicon-cloud"), "SILICONFLOW_API_KEY");
|
||||
assert_eq!(provider_env_var("siliconcloud"), "SILICONFLOW_API_KEY");
|
||||
assert_eq!(provider_env_var("nvidia"), "NVIDIA_API_KEY");
|
||||
assert_eq!(provider_env_var("nvidia-nim"), "NVIDIA_API_KEY"); // alias
|
||||
assert_eq!(provider_env_var("build.nvidia.com"), "NVIDIA_API_KEY"); // alias
|
||||
@@ -8006,13 +8131,14 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_menu_choices_include_signal_and_nextcloud_talk() {
|
||||
fn channel_menu_choices_include_signal_nextcloud_and_dingtalk() {
|
||||
assert!(channel_menu_choices().contains(&ChannelMenuChoice::Signal));
|
||||
assert!(channel_menu_choices().contains(&ChannelMenuChoice::NextcloudTalk));
|
||||
assert!(channel_menu_choices().contains(&ChannelMenuChoice::DingTalk));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn launchable_channels_include_signal_mattermost_qq_and_nextcloud_talk() {
|
||||
fn launchable_channels_include_signal_mattermost_qq_nextcloud_and_dingtalk() {
|
||||
let mut channels = ChannelsConfig::default();
|
||||
assert!(!has_launchable_channels(&channels));
|
||||
|
||||
@@ -8056,5 +8182,13 @@ mod tests {
|
||||
allowed_users: vec!["*".into()],
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
|
||||
channels.nextcloud_talk = None;
|
||||
channels.dingtalk = Some(crate::config::schema::DingTalkConfig {
|
||||
client_id: "client-id".into(),
|
||||
client_secret: "client-secret".into(),
|
||||
allowed_users: vec!["*".into()],
|
||||
});
|
||||
assert!(has_launchable_channels(&channels));
|
||||
}
|
||||
}
|
||||
|
||||
+56
-14
@@ -74,6 +74,7 @@ const QWEN_OAUTH_DEFAULT_CLIENT_ID: &str = "f0304373b74a44d2b584a3fb70ca9e56";
|
||||
const QWEN_OAUTH_CREDENTIAL_FILE: &str = ".qwen/oauth_creds.json";
|
||||
const ZAI_GLOBAL_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
|
||||
const ZAI_CN_BASE_URL: &str = "https://open.bigmodel.cn/api/coding/paas/v4";
|
||||
const SILICONFLOW_BASE_URL: &str = "https://api.siliconflow.cn/v1";
|
||||
const VERCEL_AI_GATEWAY_BASE_URL: &str = "https://ai-gateway.vercel.sh/v1";
|
||||
|
||||
pub(crate) fn is_minimax_intl_alias(name: &str) -> bool {
|
||||
@@ -179,6 +180,10 @@ pub(crate) fn is_doubao_alias(name: &str) -> bool {
|
||||
matches!(name, "doubao" | "volcengine" | "ark" | "doubao-cn")
|
||||
}
|
||||
|
||||
pub(crate) fn is_siliconflow_alias(name: &str) -> bool {
|
||||
matches!(name, "siliconflow" | "silicon-cloud" | "siliconcloud")
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum MinimaxOauthRegion {
|
||||
Global,
|
||||
@@ -618,6 +623,8 @@ pub(crate) fn canonical_china_provider_name(name: &str) -> Option<&'static str>
|
||||
Some("qianfan")
|
||||
} else if is_doubao_alias(name) {
|
||||
Some("doubao")
|
||||
} else if is_siliconflow_alias(name) {
|
||||
Some("siliconflow")
|
||||
} else if matches!(name, "hunyuan" | "tencent") {
|
||||
Some("hunyuan")
|
||||
} else {
|
||||
@@ -683,6 +690,7 @@ fn zai_base_url(name: &str) -> Option<&'static str> {
|
||||
pub struct ProviderRuntimeOptions {
|
||||
pub auth_profile_override: Option<String>,
|
||||
pub provider_api_url: Option<String>,
|
||||
pub provider_transport: Option<String>,
|
||||
pub zeroclaw_dir: Option<PathBuf>,
|
||||
pub secrets_encrypt: bool,
|
||||
pub reasoning_enabled: Option<bool>,
|
||||
@@ -697,6 +705,7 @@ impl Default for ProviderRuntimeOptions {
|
||||
Self {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: None,
|
||||
provider_transport: None,
|
||||
zeroclaw_dir: None,
|
||||
secrets_encrypt: true,
|
||||
reasoning_enabled: None,
|
||||
@@ -872,6 +881,7 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
|
||||
"hunyuan" | "tencent" => vec!["HUNYUAN_API_KEY"],
|
||||
name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"],
|
||||
name if is_doubao_alias(name) => vec!["ARK_API_KEY", "DOUBAO_API_KEY"],
|
||||
name if is_siliconflow_alias(name) => vec!["SILICONFLOW_API_KEY"],
|
||||
name if is_qwen_alias(name) => vec!["DASHSCOPE_API_KEY"],
|
||||
name if is_zai_alias(name) => vec!["ZAI_API_KEY"],
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec!["NVIDIA_API_KEY"],
|
||||
@@ -1181,6 +1191,13 @@ fn create_provider_with_url_and_options(
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
name if is_siliconflow_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
|
||||
"SiliconFlow",
|
||||
SILICONFLOW_BASE_URL,
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
true,
|
||||
))),
|
||||
name if qwen_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new_with_vision(
|
||||
"Qwen",
|
||||
qwen_base_url(name).expect("checked in guard"),
|
||||
@@ -1512,7 +1529,15 @@ pub fn create_routed_provider_with_options(
|
||||
.then_some(api_url)
|
||||
.flatten();
|
||||
|
||||
let route_options = options.clone();
|
||||
let mut route_options = options.clone();
|
||||
if let Some(transport) = route
|
||||
.transport
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
route_options.provider_transport = Some(transport.to_string());
|
||||
}
|
||||
|
||||
match create_resilient_provider_with_options(
|
||||
&route.provider,
|
||||
@@ -1542,19 +1567,8 @@ pub fn create_routed_provider_with_options(
|
||||
}
|
||||
}
|
||||
|
||||
// Build route table
|
||||
let routes: Vec<(String, router::Route)> = model_routes
|
||||
.iter()
|
||||
.map(|r| {
|
||||
(
|
||||
r.hint.clone(),
|
||||
router::Route {
|
||||
provider_name: r.provider.clone(),
|
||||
model: r.model.clone(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
// Keep only successfully initialized routed providers and preserve
|
||||
// their provider-id bindings (e.g. "<provider>#<hint>").
|
||||
|
||||
Ok(Box::new(
|
||||
router::RouterProvider::new(providers, routes, default_model.to_string())
|
||||
@@ -1712,6 +1726,12 @@ pub fn list_providers() -> Vec<ProviderInfo> {
|
||||
aliases: &["volcengine", "ark", "doubao-cn"],
|
||||
local: false,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "siliconflow",
|
||||
display_name: "SiliconFlow",
|
||||
aliases: &["silicon-cloud", "siliconcloud"],
|
||||
local: false,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "qwen",
|
||||
display_name: "Qwen (DashScope / Qwen Code OAuth)",
|
||||
@@ -2069,6 +2089,9 @@ mod tests {
|
||||
assert!(is_doubao_alias("volcengine"));
|
||||
assert!(is_doubao_alias("ark"));
|
||||
assert!(is_doubao_alias("doubao-cn"));
|
||||
assert!(is_siliconflow_alias("siliconflow"));
|
||||
assert!(is_siliconflow_alias("silicon-cloud"));
|
||||
assert!(is_siliconflow_alias("siliconcloud"));
|
||||
|
||||
assert!(!is_moonshot_alias("openrouter"));
|
||||
assert!(!is_glm_alias("openai"));
|
||||
@@ -2076,6 +2099,7 @@ mod tests {
|
||||
assert!(!is_zai_alias("anthropic"));
|
||||
assert!(!is_qianfan_alias("cohere"));
|
||||
assert!(!is_doubao_alias("deepseek"));
|
||||
assert!(!is_siliconflow_alias("volcengine"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2099,6 +2123,14 @@ mod tests {
|
||||
assert_eq!(canonical_china_provider_name("baidu"), Some("qianfan"));
|
||||
assert_eq!(canonical_china_provider_name("doubao"), Some("doubao"));
|
||||
assert_eq!(canonical_china_provider_name("volcengine"), Some("doubao"));
|
||||
assert_eq!(
|
||||
canonical_china_provider_name("siliconflow"),
|
||||
Some("siliconflow")
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_china_provider_name("silicon-cloud"),
|
||||
Some("siliconflow")
|
||||
);
|
||||
assert_eq!(canonical_china_provider_name("hunyuan"), Some("hunyuan"));
|
||||
assert_eq!(canonical_china_provider_name("tencent"), Some("hunyuan"));
|
||||
assert_eq!(canonical_china_provider_name("openai"), None);
|
||||
@@ -2316,6 +2348,13 @@ mod tests {
|
||||
assert!(create_provider("doubao-cn", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_siliconflow() {
|
||||
assert!(create_provider("siliconflow", Some("key")).is_ok());
|
||||
assert!(create_provider("silicon-cloud", Some("key")).is_ok());
|
||||
assert!(create_provider("siliconcloud", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_qwen() {
|
||||
assert!(create_provider("qwen", Some("key")).is_ok());
|
||||
@@ -2776,6 +2815,8 @@ mod tests {
|
||||
"bedrock",
|
||||
"qianfan",
|
||||
"doubao",
|
||||
"volcengine",
|
||||
"siliconflow",
|
||||
"qwen",
|
||||
"qwen-intl",
|
||||
"qwen-cn",
|
||||
@@ -3049,6 +3090,7 @@ mod tests {
|
||||
model: "anthropic/claude-sonnet-4.6".to_string(),
|
||||
max_tokens: Some(4096),
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
|
||||
let provider = create_routed_provider_with_options(
|
||||
|
||||
+534
-26
@@ -4,21 +4,81 @@ use crate::multimodal;
|
||||
use crate::providers::traits::{ChatMessage, Provider, ProviderCapabilities};
|
||||
use crate::providers::ProviderRuntimeOptions;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
http::{
|
||||
header::{AUTHORIZATION, USER_AGENT},
|
||||
HeaderValue as WsHeaderValue,
|
||||
},
|
||||
Message as WsMessage,
|
||||
},
|
||||
};
|
||||
|
||||
const DEFAULT_CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
|
||||
const CODEX_RESPONSES_URL_ENV: &str = "ZEROCLAW_CODEX_RESPONSES_URL";
|
||||
const CODEX_BASE_URL_ENV: &str = "ZEROCLAW_CODEX_BASE_URL";
|
||||
const CODEX_TRANSPORT_ENV: &str = "ZEROCLAW_CODEX_TRANSPORT";
|
||||
const CODEX_PROVIDER_TRANSPORT_ENV: &str = "ZEROCLAW_PROVIDER_TRANSPORT";
|
||||
const CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY: &str = "ZEROCLAW_RESPONSES_WEBSOCKET";
|
||||
const DEFAULT_CODEX_INSTRUCTIONS: &str =
|
||||
"You are ZeroClaw, a concise and helpful coding assistant.";
|
||||
const CODEX_WS_CONNECT_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
const CODEX_WS_SEND_TIMEOUT: Duration = Duration::from_secs(15);
|
||||
const CODEX_WS_READ_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum CodexTransport {
|
||||
Auto,
|
||||
WebSocket,
|
||||
Sse,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WebsocketRequestError {
|
||||
TransportUnavailable(anyhow::Error),
|
||||
Stream(anyhow::Error),
|
||||
}
|
||||
|
||||
impl WebsocketRequestError {
|
||||
fn transport_unavailable<E>(error: E) -> Self
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
Self::TransportUnavailable(error.into())
|
||||
}
|
||||
|
||||
fn stream<E>(error: E) -> Self
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
Self::Stream(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebsocketRequestError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::TransportUnavailable(error) | Self::Stream(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WebsocketRequestError {}
|
||||
|
||||
pub struct OpenAiCodexProvider {
|
||||
auth: AuthService,
|
||||
auth_profile_override: Option<String>,
|
||||
responses_url: String,
|
||||
transport: CodexTransport,
|
||||
custom_endpoint: bool,
|
||||
gateway_api_key: Option<String>,
|
||||
reasoning_level: Option<String>,
|
||||
@@ -104,6 +164,7 @@ impl OpenAiCodexProvider {
|
||||
auth_profile_override: options.auth_profile_override.clone(),
|
||||
custom_endpoint: !is_default_responses_url(&responses_url),
|
||||
responses_url,
|
||||
transport: resolve_transport_mode(options)?,
|
||||
gateway_api_key: gateway_api_key.map(ToString::to_string),
|
||||
reasoning_level: normalize_reasoning_level(
|
||||
options.reasoning_level.as_deref(),
|
||||
@@ -204,6 +265,72 @@ fn first_nonempty(text: Option<&str>) -> Option<String> {
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_transport_override(
|
||||
raw: Option<&str>,
|
||||
source: &str,
|
||||
) -> anyhow::Result<Option<CodexTransport>> {
|
||||
let Some(raw_value) = raw else {
|
||||
return Ok(None);
|
||||
};
|
||||
let value = raw_value.trim();
|
||||
if value.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let normalized = value.to_ascii_lowercase().replace(['-', '_'], "");
|
||||
match normalized.as_str() {
|
||||
"auto" => Ok(Some(CodexTransport::Auto)),
|
||||
"websocket" | "ws" => Ok(Some(CodexTransport::WebSocket)),
|
||||
"sse" | "http" => Ok(Some(CodexTransport::Sse)),
|
||||
_ => anyhow::bail!(
|
||||
"Invalid OpenAI Codex transport override '{value}' from {source}; expected one of: auto, websocket, sse"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_legacy_websocket_flag(raw: &str) -> Option<CodexTransport> {
|
||||
let normalized = raw.trim().to_ascii_lowercase();
|
||||
match normalized.as_str() {
|
||||
"1" | "true" | "on" | "yes" => Some(CodexTransport::WebSocket),
|
||||
"0" | "false" | "off" | "no" => Some(CodexTransport::Sse),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_transport_mode(options: &ProviderRuntimeOptions) -> anyhow::Result<CodexTransport> {
|
||||
if let Some(mode) = parse_transport_override(
|
||||
options.provider_transport.as_deref(),
|
||||
"provider.transport runtime override",
|
||||
)? {
|
||||
return Ok(mode);
|
||||
}
|
||||
|
||||
if let Ok(value) = std::env::var(CODEX_TRANSPORT_ENV) {
|
||||
if let Some(mode) = parse_transport_override(Some(&value), CODEX_TRANSPORT_ENV)? {
|
||||
return Ok(mode);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(value) = std::env::var(CODEX_PROVIDER_TRANSPORT_ENV) {
|
||||
if let Some(mode) = parse_transport_override(Some(&value), CODEX_PROVIDER_TRANSPORT_ENV)? {
|
||||
return Ok(mode);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mode) = std::env::var(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY)
|
||||
.ok()
|
||||
.and_then(|value| parse_legacy_websocket_flag(&value))
|
||||
{
|
||||
tracing::warn!(
|
||||
env = CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY,
|
||||
"Using deprecated websocket toggle env for OpenAI Codex transport"
|
||||
);
|
||||
return Ok(mode);
|
||||
}
|
||||
|
||||
Ok(CodexTransport::Auto)
|
||||
}
|
||||
|
||||
fn resolve_instructions(system_prompt: Option<&str>) -> String {
|
||||
first_nonempty(system_prompt).unwrap_or_else(|| DEFAULT_CODEX_INSTRUCTIONS.to_string())
|
||||
}
|
||||
@@ -526,6 +653,283 @@ async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result<St
|
||||
}
|
||||
|
||||
impl OpenAiCodexProvider {
|
||||
fn responses_websocket_url(&self, model: &str) -> anyhow::Result<String> {
|
||||
let mut url = reqwest::Url::parse(&self.responses_url)?;
|
||||
let next_scheme: &'static str = match url.scheme() {
|
||||
"https" | "wss" => "wss",
|
||||
"http" | "ws" => "ws",
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"OpenAI Codex websocket transport does not support URL scheme: {}",
|
||||
other
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
url.set_scheme(next_scheme)
|
||||
.map_err(|()| anyhow::anyhow!("failed to set websocket URL scheme"))?;
|
||||
|
||||
if !url.query_pairs().any(|(k, _)| k == "model") {
|
||||
url.query_pairs_mut().append_pair("model", model);
|
||||
}
|
||||
|
||||
Ok(url.into())
|
||||
}
|
||||
|
||||
fn apply_auth_headers_ws(
|
||||
&self,
|
||||
request: &mut tokio_tungstenite::tungstenite::http::Request<()>,
|
||||
bearer_token: &str,
|
||||
account_id: Option<&str>,
|
||||
access_token: Option<&str>,
|
||||
use_gateway_api_key_auth: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let headers = request.headers_mut();
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
WsHeaderValue::from_str(&format!("Bearer {bearer_token}"))?,
|
||||
);
|
||||
headers.insert(
|
||||
"OpenAI-Beta",
|
||||
WsHeaderValue::from_static("responses=experimental"),
|
||||
);
|
||||
headers.insert("originator", WsHeaderValue::from_static("pi"));
|
||||
headers.insert("accept", WsHeaderValue::from_static("text/event-stream"));
|
||||
headers.insert(USER_AGENT, WsHeaderValue::from_static("zeroclaw"));
|
||||
|
||||
if let Some(account_id) = account_id {
|
||||
headers.insert("chatgpt-account-id", WsHeaderValue::from_str(account_id)?);
|
||||
}
|
||||
|
||||
if use_gateway_api_key_auth {
|
||||
if let Some(access_token) = access_token {
|
||||
headers.insert(
|
||||
"x-openai-access-token",
|
||||
WsHeaderValue::from_str(access_token)?,
|
||||
);
|
||||
}
|
||||
if let Some(account_id) = account_id {
|
||||
headers.insert("x-openai-account-id", WsHeaderValue::from_str(account_id)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_responses_websocket_request(
|
||||
&self,
|
||||
request: &ResponsesRequest,
|
||||
model: &str,
|
||||
bearer_token: &str,
|
||||
account_id: Option<&str>,
|
||||
access_token: Option<&str>,
|
||||
use_gateway_api_key_auth: bool,
|
||||
) -> Result<String, WebsocketRequestError> {
|
||||
let ws_url = self
|
||||
.responses_websocket_url(model)
|
||||
.map_err(WebsocketRequestError::transport_unavailable)?;
|
||||
let mut ws_request = ws_url.into_client_request().map_err(|error| {
|
||||
WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
|
||||
"invalid websocket request URL: {error}"
|
||||
))
|
||||
})?;
|
||||
self.apply_auth_headers_ws(
|
||||
&mut ws_request,
|
||||
bearer_token,
|
||||
account_id,
|
||||
access_token,
|
||||
use_gateway_api_key_auth,
|
||||
)
|
||||
.map_err(WebsocketRequestError::transport_unavailable)?;
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"type": "response.create",
|
||||
"model": &request.model,
|
||||
"input": &request.input,
|
||||
"instructions": &request.instructions,
|
||||
"store": request.store,
|
||||
"text": &request.text,
|
||||
"reasoning": &request.reasoning,
|
||||
"include": &request.include,
|
||||
"tool_choice": &request.tool_choice,
|
||||
"parallel_tool_calls": request.parallel_tool_calls,
|
||||
});
|
||||
|
||||
let (mut ws_stream, _) = timeout(CODEX_WS_CONNECT_TIMEOUT, connect_async(ws_request))
|
||||
.await
|
||||
.map_err(|_| {
|
||||
WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
|
||||
"OpenAI Codex websocket connect timed out after {}s",
|
||||
CODEX_WS_CONNECT_TIMEOUT.as_secs()
|
||||
))
|
||||
})?
|
||||
.map_err(WebsocketRequestError::transport_unavailable)?;
|
||||
timeout(
|
||||
CODEX_WS_SEND_TIMEOUT,
|
||||
ws_stream.send(WsMessage::Text(
|
||||
serde_json::to_string(&payload)
|
||||
.map_err(WebsocketRequestError::transport_unavailable)?
|
||||
.into(),
|
||||
)),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
WebsocketRequestError::transport_unavailable(anyhow::anyhow!(
|
||||
"OpenAI Codex websocket send timed out after {}s",
|
||||
CODEX_WS_SEND_TIMEOUT.as_secs()
|
||||
))
|
||||
})?
|
||||
.map_err(WebsocketRequestError::transport_unavailable)?;
|
||||
|
||||
let mut saw_delta = false;
|
||||
let mut delta_accumulator = String::new();
|
||||
let mut fallback_text: Option<String> = None;
|
||||
let mut timed_out = false;
|
||||
|
||||
loop {
|
||||
let frame = match timeout(CODEX_WS_READ_TIMEOUT, ws_stream.next()).await {
|
||||
Ok(frame) => frame,
|
||||
Err(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
if saw_delta || fallback_text.is_some() {
|
||||
timed_out = true;
|
||||
break;
|
||||
}
|
||||
return Err(WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"OpenAI Codex websocket stream timed out after {}s waiting for events",
|
||||
CODEX_WS_READ_TIMEOUT.as_secs()
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let Some(frame) = frame else {
|
||||
break;
|
||||
};
|
||||
let frame = frame.map_err(WebsocketRequestError::stream)?;
|
||||
let event: Value = match frame {
|
||||
WsMessage::Text(text) => {
|
||||
serde_json::from_str(text.as_ref()).map_err(WebsocketRequestError::stream)?
|
||||
}
|
||||
WsMessage::Binary(binary) => {
|
||||
let text = String::from_utf8(binary.to_vec()).map_err(|error| {
|
||||
WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"invalid UTF-8 websocket frame from OpenAI Codex: {error}"
|
||||
))
|
||||
})?;
|
||||
serde_json::from_str(&text).map_err(WebsocketRequestError::stream)?
|
||||
}
|
||||
WsMessage::Ping(payload) => {
|
||||
ws_stream
|
||||
.send(WsMessage::Pong(payload))
|
||||
.await
|
||||
.map_err(WebsocketRequestError::stream)?;
|
||||
continue;
|
||||
}
|
||||
WsMessage::Close(_) => break,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if let Some(message) = extract_stream_error_message(&event) {
|
||||
return Err(WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"OpenAI Codex websocket stream error: {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(text) = extract_stream_event_text(&event, saw_delta) {
|
||||
let event_type = event.get("type").and_then(Value::as_str);
|
||||
if event_type == Some("response.output_text.delta") {
|
||||
saw_delta = true;
|
||||
delta_accumulator.push_str(&text);
|
||||
} else if fallback_text.is_none() {
|
||||
fallback_text = Some(text);
|
||||
}
|
||||
}
|
||||
|
||||
let event_type = event.get("type").and_then(Value::as_str);
|
||||
if event_type == Some("response.completed") || event_type == Some("response.done") {
|
||||
if let Some(response_value) = event.get("response").cloned() {
|
||||
if let Ok(parsed) = serde_json::from_value::<ResponsesResponse>(response_value)
|
||||
{
|
||||
if let Some(text) = extract_responses_text(&parsed) {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Ok(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if saw_delta {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return nonempty_preserve(Some(&delta_accumulator)).ok_or_else(|| {
|
||||
WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"No response from OpenAI Codex"
|
||||
))
|
||||
});
|
||||
}
|
||||
if let Some(text) = fallback_text.clone() {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Ok(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if saw_delta {
|
||||
return nonempty_preserve(Some(&delta_accumulator)).ok_or_else(|| {
|
||||
WebsocketRequestError::stream(anyhow::anyhow!("No response from OpenAI Codex"))
|
||||
});
|
||||
}
|
||||
if let Some(text) = fallback_text {
|
||||
return Ok(text);
|
||||
}
|
||||
if timed_out {
|
||||
return Err(WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"No response from OpenAI Codex websocket stream before timeout"
|
||||
)));
|
||||
}
|
||||
|
||||
Err(WebsocketRequestError::stream(anyhow::anyhow!(
|
||||
"No response from OpenAI Codex websocket stream"
|
||||
)))
|
||||
}
|
||||
|
||||
async fn send_responses_sse_request(
|
||||
&self,
|
||||
request: &ResponsesRequest,
|
||||
bearer_token: &str,
|
||||
account_id: Option<&str>,
|
||||
access_token: Option<&str>,
|
||||
use_gateway_api_key_auth: bool,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut request_builder = self
|
||||
.client
|
||||
.post(&self.responses_url)
|
||||
.header("Authorization", format!("Bearer {bearer_token}"))
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("originator", "pi")
|
||||
.header("accept", "text/event-stream")
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(account_id) = account_id {
|
||||
request_builder = request_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
if use_gateway_api_key_auth {
|
||||
if let Some(access_token) = access_token {
|
||||
request_builder = request_builder.header("x-openai-access-token", access_token);
|
||||
}
|
||||
if let Some(account_id) = account_id {
|
||||
request_builder = request_builder.header("x-openai-account-id", account_id);
|
||||
}
|
||||
}
|
||||
|
||||
let response = request_builder.json(request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenAI Codex", response).await);
|
||||
}
|
||||
|
||||
decode_responses_body(response).await
|
||||
}
|
||||
|
||||
async fn send_responses_request(
|
||||
&self,
|
||||
input: Vec<ResponsesInput>,
|
||||
@@ -613,35 +1017,59 @@ impl OpenAiCodexProvider {
|
||||
access_token.as_deref().unwrap_or_default()
|
||||
};
|
||||
|
||||
let mut request_builder = self
|
||||
.client
|
||||
.post(&self.responses_url)
|
||||
.header("Authorization", format!("Bearer {bearer_token}"))
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("originator", "pi")
|
||||
.header("accept", "text/event-stream")
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(account_id) = account_id.as_deref() {
|
||||
request_builder = request_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
if use_gateway_api_key_auth {
|
||||
if let Some(access_token) = access_token.as_deref() {
|
||||
request_builder = request_builder.header("x-openai-access-token", access_token);
|
||||
match self.transport {
|
||||
CodexTransport::WebSocket => self
|
||||
.send_responses_websocket_request(
|
||||
&request,
|
||||
normalized_model,
|
||||
bearer_token,
|
||||
account_id.as_deref(),
|
||||
access_token.as_deref(),
|
||||
use_gateway_api_key_auth,
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into),
|
||||
CodexTransport::Sse => {
|
||||
self.send_responses_sse_request(
|
||||
&request,
|
||||
bearer_token,
|
||||
account_id.as_deref(),
|
||||
access_token.as_deref(),
|
||||
use_gateway_api_key_auth,
|
||||
)
|
||||
.await
|
||||
}
|
||||
if let Some(account_id) = account_id.as_deref() {
|
||||
request_builder = request_builder.header("x-openai-account-id", account_id);
|
||||
CodexTransport::Auto => {
|
||||
match self
|
||||
.send_responses_websocket_request(
|
||||
&request,
|
||||
normalized_model,
|
||||
bearer_token,
|
||||
account_id.as_deref(),
|
||||
access_token.as_deref(),
|
||||
use_gateway_api_key_auth,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(text) => Ok(text),
|
||||
Err(WebsocketRequestError::TransportUnavailable(error)) => {
|
||||
tracing::warn!(
|
||||
error = %error,
|
||||
"OpenAI Codex websocket request failed; falling back to SSE"
|
||||
);
|
||||
self.send_responses_sse_request(
|
||||
&request,
|
||||
bearer_token,
|
||||
account_id.as_deref(),
|
||||
access_token.as_deref(),
|
||||
use_gateway_api_key_auth,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(WebsocketRequestError::Stream(error)) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = request_builder.json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenAI Codex", response).await);
|
||||
}
|
||||
|
||||
decode_responses_body(response).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -809,6 +1237,85 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_transport_mode_defaults_to_auto() {
|
||||
let _env_lock = env_lock();
|
||||
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
|
||||
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, None);
|
||||
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
|
||||
|
||||
assert_eq!(
|
||||
resolve_transport_mode(&ProviderRuntimeOptions::default()).unwrap(),
|
||||
CodexTransport::Auto
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_transport_mode_accepts_runtime_override() {
|
||||
let _env_lock = env_lock();
|
||||
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, Some("sse"));
|
||||
|
||||
let options = ProviderRuntimeOptions {
|
||||
provider_transport: Some("websocket".to_string()),
|
||||
..ProviderRuntimeOptions::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
resolve_transport_mode(&options).unwrap(),
|
||||
CodexTransport::WebSocket
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_transport_mode_legacy_bool_env_is_supported() {
|
||||
let _env_lock = env_lock();
|
||||
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
|
||||
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
|
||||
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, Some("false"));
|
||||
|
||||
assert_eq!(
|
||||
resolve_transport_mode(&ProviderRuntimeOptions::default()).unwrap(),
|
||||
CodexTransport::Sse
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_transport_mode_rejects_invalid_runtime_override() {
|
||||
let _env_lock = env_lock();
|
||||
let _transport_guard = EnvGuard::set(CODEX_TRANSPORT_ENV, None);
|
||||
let _provider_guard = EnvGuard::set("ZEROCLAW_PROVIDER_TRANSPORT", None);
|
||||
let _legacy_guard = EnvGuard::set(CODEX_RESPONSES_WEBSOCKET_ENV_LEGACY, None);
|
||||
|
||||
let options = ProviderRuntimeOptions {
|
||||
provider_transport: Some("udp".to_string()),
|
||||
..ProviderRuntimeOptions::default()
|
||||
};
|
||||
|
||||
let err =
|
||||
resolve_transport_mode(&options).expect_err("invalid runtime transport must fail");
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("Invalid OpenAI Codex transport override 'udp'"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_uses_ws_scheme_and_model_query() {
|
||||
let _env_lock = env_lock();
|
||||
let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
|
||||
let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
|
||||
|
||||
let options = ProviderRuntimeOptions::default();
|
||||
let provider = OpenAiCodexProvider::new(&options, None).expect("provider should init");
|
||||
let ws_url = provider
|
||||
.responses_websocket_url("gpt-5.3-codex")
|
||||
.expect("websocket URL should be derived");
|
||||
|
||||
assert_eq!(
|
||||
ws_url,
|
||||
"wss://chatgpt.com/backend-api/codex/responses?model=gpt-5.3-codex"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_responses_url_detector_handles_equivalent_urls() {
|
||||
assert!(is_default_responses_url(DEFAULT_CODEX_RESPONSES_URL));
|
||||
@@ -1077,6 +1584,7 @@ data: [DONE]
|
||||
fn capabilities_includes_vision() {
|
||||
let options = ProviderRuntimeOptions {
|
||||
provider_api_url: None,
|
||||
provider_transport: None,
|
||||
zeroclaw_dir: None,
|
||||
secrets_encrypt: false,
|
||||
auth_profile_override: None,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+34
-26
@@ -31,6 +31,8 @@ pub mod cron_update;
|
||||
pub mod delegate;
|
||||
pub mod delegate_coordination_status;
|
||||
pub mod docx_read;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
pub mod feishu_doc;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
@@ -86,6 +88,8 @@ pub use cron_update::CronUpdateTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use delegate_coordination_status::DelegateCoordinationStatusTool;
|
||||
pub use docx_read::DocxReadTool;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
pub use feishu_doc::FeishuDocTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
@@ -428,6 +432,7 @@ pub fn all_tools_with_runtime(
|
||||
let provider_runtime_options = crate::providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: root_config.api_url.clone(),
|
||||
provider_transport: root_config.effective_provider_transport(),
|
||||
zeroclaw_dir: root_config
|
||||
.config_path
|
||||
.parent()
|
||||
@@ -512,38 +517,41 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Inter-process agent communication (opt-in)
|
||||
if root_config.agents_ipc.enabled {
|
||||
match agents_ipc::IpcDb::open(workspace_dir, &root_config.agents_ipc) {
|
||||
Ok(ipc_db) => {
|
||||
let ipc_db = Arc::new(ipc_db);
|
||||
tool_arcs.push(Arc::new(agents_ipc::AgentsListTool::new(ipc_db.clone())));
|
||||
tool_arcs.push(Arc::new(agents_ipc::AgentsSendTool::new(
|
||||
ipc_db.clone(),
|
||||
// Feishu document tools (enabled when channel-lark feature is active)
|
||||
#[cfg(feature = "channel-lark")]
|
||||
{
|
||||
let feishu_creds = root_config
|
||||
.channels_config
|
||||
.feishu
|
||||
.as_ref()
|
||||
.map(|fs| (fs.app_id.clone(), fs.app_secret.clone(), true))
|
||||
.or_else(|| {
|
||||
root_config
|
||||
.channels_config
|
||||
.lark
|
||||
.as_ref()
|
||||
.map(|lk| (lk.app_id.clone(), lk.app_secret.clone(), lk.use_feishu))
|
||||
});
|
||||
|
||||
if let Some((app_id, app_secret, use_feishu)) = feishu_creds {
|
||||
let app_id = app_id.trim().to_string();
|
||||
let app_secret = app_secret.trim().to_string();
|
||||
if app_id.is_empty() || app_secret.is_empty() {
|
||||
tracing::warn!(
|
||||
"feishu_doc: skipped registration because app credentials are empty"
|
||||
);
|
||||
} else {
|
||||
tool_arcs.push(Arc::new(FeishuDocTool::new(
|
||||
app_id,
|
||||
app_secret,
|
||||
use_feishu,
|
||||
security.clone(),
|
||||
)));
|
||||
tool_arcs.push(Arc::new(agents_ipc::AgentsInboxTool::new(ipc_db.clone())));
|
||||
tool_arcs.push(Arc::new(agents_ipc::StateGetTool::new(ipc_db.clone())));
|
||||
tool_arcs.push(Arc::new(agents_ipc::StateSetTool::new(
|
||||
ipc_db,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("agents_ipc: failed to open IPC database: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load WASM plugin tools from the skills directory.
|
||||
// Each installed skill package may ship one or more WASM tools under
|
||||
// `<skill-dir>/tools/<tool-name>/{tool.wasm, manifest.json}`.
|
||||
// Failures are logged and skipped — a broken plugin must not block startup.
|
||||
let skills_dir = workspace_dir.join("skills");
|
||||
let mut boxed = boxed_registry_from_arcs(tool_arcs);
|
||||
let wasm_tools = wasm_tool::load_wasm_tools_from_skills(&skills_dir);
|
||||
boxed.extend(wasm_tools);
|
||||
boxed
|
||||
boxed_registry_from_arcs(tool_arcs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -125,6 +125,42 @@ impl ModelRoutingConfigTool {
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn normalize_transport_value(raw: &str, field: &str) -> anyhow::Result<String> {
|
||||
let normalized = raw.trim().to_ascii_lowercase().replace(['-', '_'], "");
|
||||
match normalized.as_str() {
|
||||
"auto" => Ok("auto".to_string()),
|
||||
"websocket" | "ws" => Ok("websocket".to_string()),
|
||||
"sse" | "http" => Ok("sse".to_string()),
|
||||
_ => anyhow::bail!("'{field}' must be one of: auto, websocket, sse"),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_optional_transport_update(
|
||||
args: &Value,
|
||||
field: &str,
|
||||
) -> anyhow::Result<MaybeSet<String>> {
|
||||
let Some(raw) = args.get(field) else {
|
||||
return Ok(MaybeSet::Unset);
|
||||
};
|
||||
|
||||
if raw.is_null() {
|
||||
return Ok(MaybeSet::Null);
|
||||
}
|
||||
|
||||
let value = raw
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
|
||||
.trim();
|
||||
|
||||
if value.is_empty() {
|
||||
return Ok(MaybeSet::Null);
|
||||
}
|
||||
|
||||
Ok(MaybeSet::Set(Self::normalize_transport_value(
|
||||
value, field,
|
||||
)?))
|
||||
}
|
||||
|
||||
fn parse_optional_f64_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<f64>> {
|
||||
let Some(raw) = args.get(field) else {
|
||||
return Ok(MaybeSet::Unset);
|
||||
@@ -217,6 +253,7 @@ impl ModelRoutingConfigTool {
|
||||
"hint": route.hint,
|
||||
"provider": route.provider,
|
||||
"model": route.model,
|
||||
"transport": route.transport,
|
||||
"api_key_configured": has_provider_credential(&route.provider, route.api_key.as_deref()),
|
||||
"classification": classification,
|
||||
})
|
||||
@@ -429,6 +466,7 @@ impl ModelRoutingConfigTool {
|
||||
let provider = Self::parse_non_empty_string(args, "provider")?;
|
||||
let model = Self::parse_non_empty_string(args, "model")?;
|
||||
let api_key_update = Self::parse_optional_string_update(args, "api_key")?;
|
||||
let transport_update = Self::parse_optional_transport_update(args, "transport")?;
|
||||
|
||||
let keywords_update = if let Some(raw) = args.get("keywords") {
|
||||
Some(Self::parse_string_list(raw, "keywords")?)
|
||||
@@ -466,6 +504,7 @@ impl ModelRoutingConfigTool {
|
||||
model: model.clone(),
|
||||
max_tokens: None,
|
||||
api_key: None,
|
||||
transport: None,
|
||||
});
|
||||
|
||||
next_route.hint = hint.clone();
|
||||
@@ -478,6 +517,12 @@ impl ModelRoutingConfigTool {
|
||||
MaybeSet::Unset => {}
|
||||
}
|
||||
|
||||
match transport_update {
|
||||
MaybeSet::Set(transport) => next_route.transport = Some(transport),
|
||||
MaybeSet::Null => next_route.transport = None,
|
||||
MaybeSet::Unset => {}
|
||||
}
|
||||
|
||||
cfg.model_routes.retain(|route| route.hint != hint);
|
||||
cfg.model_routes.push(next_route);
|
||||
Self::normalize_and_sort_routes(&mut cfg.model_routes);
|
||||
@@ -782,6 +827,11 @@ impl Tool for ModelRoutingConfigTool {
|
||||
"type": ["string", "null"],
|
||||
"description": "Optional API key override for scenario route or delegate agent"
|
||||
},
|
||||
"transport": {
|
||||
"type": ["string", "null"],
|
||||
"enum": ["auto", "websocket", "sse", "ws", "http", null],
|
||||
"description": "Optional route transport override for upsert_scenario (auto, websocket, sse)"
|
||||
},
|
||||
"keywords": {
|
||||
"description": "Classification keywords for upsert_scenario (string or string array)",
|
||||
"oneOf": [
|
||||
@@ -1003,6 +1053,7 @@ mod tests {
|
||||
"hint": "coding",
|
||||
"provider": "openai",
|
||||
"model": "gpt-5.3-codex",
|
||||
"transport": "websocket",
|
||||
"classification_enabled": true,
|
||||
"keywords": ["code", "bug", "refactor"],
|
||||
"patterns": ["```"],
|
||||
@@ -1024,9 +1075,58 @@ mod tests {
|
||||
item["hint"] == json!("coding")
|
||||
&& item["provider"] == json!("openai")
|
||||
&& item["model"] == json!("gpt-5.3-codex")
|
||||
&& item["transport"] == json!("websocket")
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_scenario_transport_alias_is_canonicalized() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ModelRoutingConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "upsert_scenario",
|
||||
"hint": "analysis",
|
||||
"provider": "openai",
|
||||
"model": "gpt-5.3-codex",
|
||||
"transport": "WS"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success, "{:?}", result.error);
|
||||
|
||||
let get_result = tool.execute(json!({"action": "get"})).await.unwrap();
|
||||
assert!(get_result.success);
|
||||
let output: Value = serde_json::from_str(&get_result.output).unwrap();
|
||||
let scenarios = output["scenarios"].as_array().unwrap();
|
||||
assert!(scenarios.iter().any(|item| {
|
||||
item["hint"] == json!("analysis") && item["transport"] == json!("websocket")
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_scenario_rejects_invalid_transport() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ModelRoutingConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "upsert_scenario",
|
||||
"hint": "analysis",
|
||||
"provider": "openai",
|
||||
"model": "gpt-5.3-codex",
|
||||
"transport": "udp"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.unwrap_or_default()
|
||||
.contains("'transport' must be one of: auto, websocket, sse"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_scenario_also_removes_rule() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user