Merge remote-tracking branch 'origin/main' into pr2093-mainmerge

This commit is contained in:
argenis de la rosa
2026-02-28 17:33:17 -05:00
145 changed files with 15627 additions and 1355 deletions
+17 -10
View File
@@ -218,9 +218,7 @@ impl AgentBuilder {
.memory_loader
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
config: self.config.unwrap_or_default(),
model_name: self
.model_name
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
model_name: crate::config::resolve_default_model_id(self.model_name.as_deref(), None),
temperature: self.temperature.unwrap_or(0.7),
workspace_dir: self
.workspace_dir
@@ -298,11 +296,10 @@ impl Agent {
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
let model_name = config
.default_model
.as_deref()
.unwrap_or("anthropic/claude-sonnet-4-20250514")
.to_string();
let model_name = crate::config::resolve_default_model_id(
config.default_model.as_deref(),
Some(provider_name),
);
let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
@@ -714,8 +711,12 @@ pub async fn run(
let model_name = effective_config
.default_model
.as_deref()
.unwrap_or("anthropic/claude-sonnet-4-20250514")
.to_string();
.map(str::trim)
.filter(|m| !m.is_empty())
.map(str::to_string)
.unwrap_or_else(|| {
crate::config::default_model_fallback_for_provider(Some(&provider_name)).to_string()
});
agent.observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.clone(),
@@ -776,6 +777,7 @@ mod tests {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
});
}
Ok(guard.remove(0))
@@ -813,6 +815,7 @@ mod tests {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
});
}
Ok(guard.remove(0))
@@ -852,6 +855,7 @@ mod tests {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}]),
});
@@ -892,12 +896,14 @@ mod tests {
}],
usage: None,
reasoning_content: None,
quota_metadata: None,
},
crate::providers::ChatResponse {
text: Some("done".into()),
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
},
]),
});
@@ -939,6 +945,7 @@ mod tests {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}]),
seen_models: seen_models.clone(),
});
+2
View File
@@ -263,6 +263,7 @@ mod tests {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let dispatcher = XmlToolDispatcher;
let (_, calls) = dispatcher.parse_response(&response);
@@ -281,6 +282,7 @@ mod tests {
}],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let dispatcher = NativeToolDispatcher;
let (_, calls) = dispatcher.parse_response(&response);
+301 -124
View File
@@ -290,6 +290,20 @@ pub(crate) struct NonCliApprovalContext {
tokio::task_local! {
static TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT: Option<NonCliApprovalContext>;
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
}
/// Configuration for periodic safety-constraint re-injection (heartbeat).
#[derive(Clone)]
pub(crate) struct SafetyHeartbeatConfig {
/// Pre-rendered security policy summary text.
pub body: String,
/// Inject a heartbeat every `interval` tool iterations (0 = disabled).
pub interval: usize,
}
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
interval > 0 && counter > 0 && counter % interval == 0
}
/// Extract a short hint from tool call arguments for progress display.
@@ -687,33 +701,37 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
hooks: Option<&crate::hooks::HookRunner>,
excluded_tools: &[String],
safety_heartbeat: Option<SafetyHeartbeatConfig>,
) -> Result<String> {
let reply_target = non_cli_approval_context
.as_ref()
.map(|ctx| ctx.reply_target.clone());
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT
SAFETY_HEARTBEAT_CONFIG
.scope(
non_cli_approval_context,
TOOL_LOOP_REPLY_TARGET.scope(
reply_target,
run_tool_call_loop(
provider,
history,
tools_registry,
observer,
provider_name,
model,
temperature,
silent,
approval,
channel_name,
multimodal_config,
max_tool_iterations,
cancellation_token,
on_delta,
hooks,
excluded_tools,
safety_heartbeat,
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
non_cli_approval_context,
TOOL_LOOP_REPLY_TARGET.scope(
reply_target,
run_tool_call_loop(
provider,
history,
tools_registry,
observer,
provider_name,
model,
temperature,
silent,
approval,
channel_name,
multimodal_config,
max_tool_iterations,
cancellation_token,
on_delta,
hooks,
excluded_tools,
),
),
),
)
@@ -788,6 +806,10 @@ pub(crate) async fn run_tool_call_loop(
.unwrap_or_default();
let mut loop_detector = LoopDetector::new(ld_config);
let mut loop_detection_prompt: Option<String> = None;
let heartbeat_config = SAFETY_HEARTBEAT_CONFIG
.try_with(Clone::clone)
.ok()
.flatten();
let bypass_non_cli_approval_for_turn =
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
if bypass_non_cli_approval_for_turn {
@@ -835,6 +857,19 @@ pub(crate) async fn run_tool_call_loop(
request_messages.push(ChatMessage::user(prompt));
}
// ── Safety heartbeat: periodic security-constraint re-injection ──
if let Some(ref hb) = heartbeat_config {
if should_inject_safety_heartbeat(iteration, hb.interval) {
let reminder = format!(
"[Safety Heartbeat — round {}/{}]\n{}",
iteration + 1,
max_iterations,
hb.body
);
request_messages.push(ChatMessage::user(reminder));
}
}
// ── Progress: LLM thinking ────────────────────────────
if let Some(ref tx) = on_delta {
let phase = if iteration == 0 {
@@ -1244,13 +1279,30 @@ pub(crate) async fn run_tool_call_loop(
// ── Approval hook ────────────────────────────────
if let Some(mgr) = approval {
if bypass_non_cli_approval_for_turn {
let non_cli_session_granted =
channel_name != "cli" && mgr.is_non_cli_session_granted(&tool_name);
if bypass_non_cli_approval_for_turn || non_cli_session_granted {
mgr.record_decision(
&tool_name,
&tool_args,
ApprovalResponse::Yes,
channel_name,
);
if non_cli_session_granted {
runtime_trace::record_event(
"approval_bypass_non_cli_session_grant",
Some(channel_name),
Some(provider_name),
Some(model),
Some(&turn_id),
Some(true),
Some("using runtime non-cli session approval grant"),
serde_json::json!({
"iteration": iteration + 1,
"tool": tool_name.clone(),
}),
);
}
} else if mgr.needs_approval(&tool_name) {
let request = ApprovalRequest {
tool_name: tool_name.clone(),
@@ -1765,10 +1817,12 @@ pub async fn run(
.or(config.default_provider.as_deref())
.unwrap_or("openrouter");
let model_name = model_override
.as_deref()
.or(config.default_model.as_deref())
.unwrap_or("anthropic/claude-sonnet-4");
let model_name = crate::config::resolve_default_model_id(
model_override
.as_deref()
.or(config.default_model.as_deref()),
Some(provider_name),
);
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
@@ -1789,7 +1843,7 @@ pub async fn run(
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
model_name,
&model_name,
&provider_runtime_options,
)?;
@@ -1837,6 +1891,10 @@ pub async fn run(
"memory_store",
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
),
(
"memory_observe",
"Store observation memory. Use when: capturing patterns/signals that should remain searchable over long horizons.",
),
(
"memory_recall",
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
@@ -1948,7 +2006,7 @@ pub async fn run(
let native_tools = provider.supports_native_tools();
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
&config.workspace_dir,
model_name,
&model_name,
&tool_descs,
&skills,
Some(&config.identity),
@@ -1987,7 +2045,7 @@ pub async fn run(
// Inject memory + hardware RAG context into user message
let mem_context =
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score, None).await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -2011,26 +2069,37 @@ pub async fn run(
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
failure_streak_threshold: config.agent.loop_detection_failure_streak,
};
let response = LOOP_DETECTION_CONFIG
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
interval: config.agent.safety_heartbeat_interval,
})
} else {
None
};
let response = SAFETY_HEARTBEAT_CONFIG
.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
),
),
)
.await?;
@@ -2044,6 +2113,7 @@ pub async fn run(
// Persistent conversation history across turns
let mut history = vec![ChatMessage::system(&system_prompt)];
let mut interactive_turn: usize = 0;
// Reusable readline editor for UTF-8 input support
let mut rl = Editor::with_config(
RlConfig::builder()
@@ -2094,6 +2164,7 @@ pub async fn run(
rl.clear_history()?;
history.clear();
history.push(ChatMessage::system(&system_prompt));
interactive_turn = 0;
// Clear conversation and daily memory
let mut cleared = 0;
for category in [MemoryCategory::Conversation, MemoryCategory::Daily] {
@@ -2123,8 +2194,13 @@ pub async fn run(
}
// Inject memory + hardware RAG context into user message
let mem_context =
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
let mem_context = build_context(
mem.as_ref(),
&user_input,
config.memory.min_relevance_score,
None,
)
.await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -2139,32 +2215,57 @@ pub async fn run(
};
history.push(ChatMessage::user(&enriched));
interactive_turn += 1;
// Inject interactive safety heartbeat at configured turn intervals
if should_inject_safety_heartbeat(
interactive_turn,
config.agent.safety_heartbeat_turn_interval,
) {
let reminder = format!(
"[Safety Heartbeat — turn {}]\n{}",
interactive_turn,
security.summary_for_heartbeat()
);
history.push(ChatMessage::user(reminder));
}
let ld_cfg = LoopDetectionConfig {
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
failure_streak_threshold: config.agent.loop_detection_failure_streak,
};
let response = match LOOP_DETECTION_CONFIG
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
interval: config.agent.safety_heartbeat_interval,
})
} else {
None
};
let response = match SAFETY_HEARTBEAT_CONFIG
.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
run_tool_call_loop(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
&model_name,
temperature,
false,
approval_manager.as_ref(),
channel_name,
&config.multimodal,
config.agent.max_tool_iterations,
None,
None,
None,
&[],
),
),
)
.await
@@ -2209,7 +2310,7 @@ pub async fn run(
if let Ok(compacted) = auto_compact_history(
&mut history,
provider.as_ref(),
model_name,
&model_name,
config.agent.max_history_messages,
)
.await
@@ -2238,13 +2339,15 @@ pub async fn run(
/// Process a single message through the full agent (with tools, peripherals, memory).
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
pub async fn process_message(
pub async fn process_message(config: Config, message: &str) -> Result<String> {
process_message_with_session(config, message, None).await
}
pub async fn process_message_with_session(
config: Config,
message: &str,
sender_id: &str,
channel_name: &str,
session_id: Option<&str>,
) -> Result<String> {
tracing::debug!(sender_id, channel_name, "process_message called");
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let runtime: Arc<dyn runtime::RuntimeAdapter> =
@@ -2288,10 +2391,10 @@ pub async fn process_message(
tools_registry.extend(peripheral_tools);
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
let model_name = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let model_name = crate::config::resolve_default_model_id(
config.default_model.as_deref(),
Some(provider_name),
);
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
provider_api_url: config.api_url.clone(),
@@ -2335,6 +2438,7 @@ pub async fn process_message(
("file_read", "Read file contents."),
("file_write", "Write file contents."),
("memory_store", "Save to memory."),
("memory_observe", "Store observation memory."),
("memory_recall", "Search memory."),
("memory_forget", "Delete a memory entry."),
(
@@ -2407,7 +2511,13 @@ pub async fn process_message(
}
system_prompt.push_str(&build_shell_policy_instructions(&config.autonomy));
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
let mem_context = build_context(
mem.as_ref(),
message,
config.memory.min_relevance_score,
session_id,
)
.await;
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
let hw_context = hardware_rag
.as_ref()
@@ -2433,53 +2543,31 @@ pub async fn process_message(
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
.collect();
let mut history = Vec::new();
history.push(ChatMessage::system(&system_prompt));
history.extend(filtered_history);
history.push(ChatMessage::user(&enriched));
let reply = agent_turn(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
&model_name,
config.default_temperature,
true,
&config.multimodal,
config.agent.max_tool_iterations,
)
.await?;
let persisted: Vec<ChatMessage> = history
.into_iter()
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
.collect();
let saved_len = persisted.len();
session
.update_history(persisted)
.await
.context("Failed to update session history")?;
tracing::debug!(saved_len, "session history saved");
Ok(reply)
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
interval: config.agent.safety_heartbeat_interval,
})
} else {
let mut history = vec![
ChatMessage::system(&system_prompt),
ChatMessage::user(&enriched),
];
agent_turn(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
&model_name,
config.default_temperature,
true,
&config.multimodal,
config.agent.max_tool_iterations,
None
};
SAFETY_HEARTBEAT_CONFIG
.scope(
hb_cfg,
agent_turn(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
provider_name,
&model_name,
config.default_temperature,
true,
&config.multimodal,
config.agent.max_tool_iterations,
),
)
.await
}
}
#[cfg(test)]
@@ -2574,6 +2662,36 @@ mod tests {
assert_eq!(feishu_args["delivery"]["to"], "oc_yyy");
}
#[test]
fn safety_heartbeat_interval_zero_disables_injection() {
for counter in [0, 1, 2, 10, 100] {
assert!(
!should_inject_safety_heartbeat(counter, 0),
"counter={counter} should not inject when interval=0"
);
}
}
#[test]
fn safety_heartbeat_interval_one_injects_every_non_initial_step() {
assert!(!should_inject_safety_heartbeat(0, 1));
for counter in 1..=6 {
assert!(
should_inject_safety_heartbeat(counter, 1),
"counter={counter} should inject when interval=1"
);
}
}
#[test]
fn safety_heartbeat_injects_only_on_exact_multiples() {
let interval = 3;
let injected: Vec<usize> = (0..=10)
.filter(|counter| should_inject_safety_heartbeat(*counter, interval))
.collect();
assert_eq!(injected, vec![3, 6, 9]);
}
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::observability::NoopObserver;
use crate::providers::traits::ProviderCapabilities;
@@ -2643,6 +2761,7 @@ mod tests {
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
quota_metadata: None,
})
}
}
@@ -2661,6 +2780,7 @@ mod tests {
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
quota_metadata: None,
})
.collect();
Self {
@@ -3183,6 +3303,62 @@ mod tests {
);
}
#[tokio::test]
async fn run_tool_call_loop_uses_non_cli_session_grant_without_waiting_for_prompt() {
let provider = ScriptedProvider::from_text_responses(vec![
r#"<tool_call>
{"name":"shell","arguments":{"command":"echo hi"}}
</tool_call>"#,
"done",
]);
let active = Arc::new(AtomicUsize::new(0));
let max_active = Arc::new(AtomicUsize::new(0));
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(DelayTool::new(
"shell",
50,
Arc::clone(&active),
Arc::clone(&max_active),
))];
let approval_mgr = ApprovalManager::from_config(&crate::config::AutonomyConfig::default());
approval_mgr.grant_non_cli_session("shell");
let mut history = vec![
ChatMessage::system("test-system"),
ChatMessage::user("run shell"),
];
let observer = NoopObserver;
let result = run_tool_call_loop(
&provider,
&mut history,
&tools_registry,
&observer,
"mock-provider",
"mock-model",
0.0,
true,
Some(&approval_mgr),
"telegram",
&crate::config::MultimodalConfig::default(),
4,
None,
None,
None,
&[],
)
.await
.expect("tool loop should consume non-cli session grants");
assert_eq!(result, "done");
assert_eq!(
max_active.load(Ordering::SeqCst),
1,
"shell tool should execute when runtime non-cli session grant exists"
);
}
#[tokio::test]
async fn run_tool_call_loop_waits_for_non_cli_approval_resolution() {
let provider = ScriptedProvider::from_text_responses(vec![
@@ -3252,6 +3428,7 @@ mod tests {
None,
None,
&[],
None,
)
.await
.expect("tool loop should continue after non-cli approval");
@@ -4397,7 +4574,7 @@ Tail"#;
.await
.unwrap();
let context = build_context(&mem, "status updates", 0.0).await;
let context = build_context(&mem, "status updates", 0.0, None).await;
assert!(context.contains("user_msg_real"));
assert!(!context.contains("assistant_resp_poisoned"));
assert!(!context.contains("fabricated event"));
+2 -1
View File
@@ -8,11 +8,12 @@ pub(super) async fn build_context(
mem: &dyn Memory,
user_msg: &str,
min_relevance_score: f64,
session_id: Option<&str>,
) -> String {
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
let relevant: Vec<_> = entries
.iter()
.filter(|e| match e.score {
+25 -1
View File
@@ -220,7 +220,9 @@ impl LoopDetector {
fn hash_output(output: &str) -> u64 {
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
&output[..OUTPUT_HASH_PREFIX_BYTES]
// Use floor_utf8_char_boundary to avoid panic on multi-byte UTF-8 characters
let boundary = crate::util::floor_utf8_char_boundary(output, OUTPUT_HASH_PREFIX_BYTES);
&output[..boundary]
} else {
output
};
@@ -386,4 +388,26 @@ mod tests {
det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
assert_eq!(det.check(), DetectionVerdict::Continue);
}
// 11. UTF-8 boundary safety: hash_output must not panic on CJK text
#[test]
fn hash_output_utf8_boundary_safe() {
// Create a string where byte 4096 lands inside a multi-byte char
// Chinese chars are 3 bytes each, so 1366 chars = 4098 bytes
let cjk_text: String = "".repeat(1366); // 4098 bytes
assert!(cjk_text.len() > super::OUTPUT_HASH_PREFIX_BYTES);
// This should NOT panic
let hash1 = super::hash_output(&cjk_text);
// Different content should produce different hash
let cjk_text2: String = "".repeat(1366);
let hash2 = super::hash_output(&cjk_text2);
assert_ne!(hash1, hash2);
// Mixed ASCII + CJK at boundary
let mixed = "a".repeat(4094) + "文文"; // 4094 + 6 = 4100 bytes, boundary at 4096
let hash3 = super::hash_output(&mixed);
assert!(hash3 != 0); // Just verify it runs
}
}
+106 -3
View File
@@ -28,8 +28,13 @@ pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
}
let start = if has_system { 1 } else { 0 };
let to_remove = non_system_count - max_history;
history.drain(start..start + to_remove);
let mut trim_end = start + (non_system_count - max_history);
// Never keep a leading `role=tool` at the trim boundary. Tool-message runs
// must remain attached to their preceding assistant(tool_calls) message.
while trim_end < history.len() && history[trim_end].role == "tool" {
trim_end += 1;
}
history.drain(start..trim_end);
}
pub(super) fn build_compaction_transcript(messages: &[ChatMessage]) -> String {
@@ -80,7 +85,11 @@ pub(super) async fn auto_compact_history(
return Ok(false);
}
let compact_end = start + compact_count;
let mut compact_end = start + compact_count;
// Do not split assistant(tool_calls) -> tool runs across compaction boundary.
while compact_end < history.len() && history[compact_end].role == "tool" {
compact_end += 1;
}
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
let transcript = build_compaction_transcript(&to_compact);
@@ -104,3 +113,97 @@ pub(super) async fn auto_compact_history(
Ok(true)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::{ChatRequest, ChatResponse, Provider};
use async_trait::async_trait;
struct StaticSummaryProvider;
#[async_trait]
impl Provider for StaticSummaryProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok("- summarized context".to_string())
}
async fn chat(
&self,
_request: ChatRequest<'_>,
_model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
Ok(ChatResponse {
text: Some("- summarized context".to_string()),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
quota_metadata: None,
})
}
}
fn assistant_with_tool_call(id: &str) -> ChatMessage {
ChatMessage::assistant(format!(
"{{\"content\":\"\",\"tool_calls\":[{{\"id\":\"{id}\",\"name\":\"shell\",\"arguments\":\"{{}}\"}}]}}"
))
}
fn tool_result(id: &str) -> ChatMessage {
ChatMessage::tool(format!("{{\"tool_call_id\":\"{id}\",\"content\":\"ok\"}}"))
}
#[test]
fn trim_history_avoids_orphan_tool_at_boundary() {
let mut history = vec![
ChatMessage::user("old"),
assistant_with_tool_call("call_1"),
tool_result("call_1"),
ChatMessage::user("recent"),
];
trim_history(&mut history, 2);
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "user");
assert_eq!(history[0].content, "recent");
}
#[tokio::test]
async fn auto_compact_history_does_not_split_tool_run_boundary() {
let mut history = vec![
ChatMessage::user("oldest"),
assistant_with_tool_call("call_2"),
tool_result("call_2"),
];
for idx in 0..19 {
history.push(ChatMessage::user(format!("recent-{idx}")));
}
// 22 non-system messages => compaction with max_history=21 would
// previously cut right before the tool result (index 2).
assert_eq!(history.len(), 22);
let compacted =
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21)
.await
.expect("compaction should succeed");
assert!(compacted);
assert_eq!(history[0].role, "assistant");
assert!(
history[0].content.contains("[Compaction summary]"),
"summary message should replace compacted range"
);
assert_ne!(
history[1].role, "tool",
"first retained message must not be an orphan tool result"
);
}
}
+2
View File
@@ -886,6 +886,7 @@ pub(super) fn map_tool_name_alias(tool_name: &str) -> &str {
// Memory variations
"memoryrecall" | "memory_recall" | "recall" | "memrecall" => "memory_recall",
"memorystore" | "memory_store" | "store" | "memstore" => "memory_store",
"memoryobserve" | "memory_observe" | "observe" | "memobserve" => "memory_observe",
"memoryforget" | "memory_forget" | "forget" | "memforget" => "memory_forget",
// HTTP variations
"http_request" | "http" | "fetch" | "curl" | "wget" => "http_request",
@@ -1026,6 +1027,7 @@ pub(super) fn default_param_for_tool(tool: &str) -> &'static str {
"memory_recall" | "memoryrecall" | "recall" | "memrecall" | "memory_forget"
| "memoryforget" | "forget" | "memforget" => "query",
"memory_store" | "memorystore" | "store" | "memstore" => "content",
"memory_observe" | "memoryobserve" | "observe" | "memobserve" => "observation",
// HTTP and browser tools default to "url"
"http_request" | "http" | "fetch" | "curl" | "wget" | "browser_open" | "browser"
| "web_search" => "url",
+2 -1
View File
@@ -5,6 +5,7 @@ pub mod dispatcher;
pub mod loop_;
pub mod memory_loader;
pub mod prompt;
pub mod quota_aware;
pub mod research;
pub mod session;
@@ -14,4 +15,4 @@ mod tests;
#[allow(unused_imports)]
pub use agent::{Agent, AgentBuilder};
#[allow(unused_imports)]
pub use loop_::{process_message, run};
pub use loop_::{process_message, process_message_with_session, run};
+3
View File
@@ -496,6 +496,7 @@ mod tests {
}],
prompts: vec!["Run smoke tests before deploy.".into()],
location: None,
always: false,
}];
let ctx = PromptContext {
@@ -534,6 +535,7 @@ mod tests {
}],
prompts: vec!["Run smoke tests before deploy.".into()],
location: Some(Path::new("/tmp/workspace/skills/deploy/SKILL.md").to_path_buf()),
always: false,
}];
let ctx = PromptContext {
@@ -594,6 +596,7 @@ mod tests {
}],
prompts: vec!["Use <tool_call> and & keep output \"safe\"".into()],
location: None,
always: false,
}];
let ctx = PromptContext {
workspace_dir: Path::new("/tmp/workspace"),
+233
View File
@@ -0,0 +1,233 @@
//! Quota-aware agent loop helpers.
//!
//! This module provides utilities for the agent loop to:
//! - Check provider quota status before expensive operations
//! - Warn users when quota is running low
//! - Switch providers mid-conversation when requested via tools
//! - Handle rate limit errors with automatic fallback
use crate::auth::profiles::AuthProfilesStore;
use crate::config::Config;
use crate::providers::health::ProviderHealthTracker;
use crate::providers::quota_types::QuotaStatus;
use anyhow::Result;
use std::time::Duration;
/// Check if we should warn about low quota before an operation.
///
/// Returns `Some(warning_message)` if quota is running low (< 10% remaining).
pub async fn check_quota_warning(
config: &Config,
provider_name: &str,
parallel_count: usize,
) -> Result<Option<String>> {
if parallel_count < 5 {
// Only warn for operations with 5+ parallel calls
return Ok(None);
}
let health_tracker = ProviderHealthTracker::new(
3, // failure_threshold
Duration::from_secs(60), // cooldown
100, // max tracked providers
);
let auth_store = AuthProfilesStore::new(&config.workspace_dir, config.secrets.encrypt);
let profiles_data = auth_store.load().await?;
let summary = crate::providers::quota_cli::build_quota_summary(
&health_tracker,
&profiles_data,
Some(provider_name),
)?;
// Find the provider in summary
if let Some(provider_info) = summary
.providers
.iter()
.find(|p| p.provider == provider_name)
{
// Check circuit breaker status
if provider_info.status == QuotaStatus::CircuitOpen {
let reset_str = if let Some(resets_at) = provider_info.circuit_resets_at {
format!(" (resets {})", format_relative_time(resets_at))
} else {
String::new()
};
return Ok(Some(format!(
"⚠️ **Provider Unavailable**: {} is circuit-open{}. \
Consider switching to an alternative provider using the `check_provider_quota` tool.",
provider_name, reset_str
)));
}
// Check rate limit status
if provider_info.status == QuotaStatus::RateLimited
|| provider_info.status == QuotaStatus::QuotaExhausted
{
return Ok(Some(format!(
"⚠️ **Rate Limit Warning**: {} is rate-limited. \
Your parallel operation ({} calls) may fail. \
Consider switching to another provider using `check_provider_quota` and `switch_provider` tools.",
provider_name, parallel_count
)));
}
// Check individual profile quotas
for profile in &provider_info.profiles {
if let (Some(remaining), Some(total)) =
(profile.rate_limit_remaining, profile.rate_limit_total)
{
let quota_pct = (remaining as f64 / total as f64) * 100.0;
if quota_pct < 10.0 && remaining < parallel_count as u64 {
let reset_str = if let Some(reset_at) = profile.rate_limit_reset_at {
format!(" (resets {})", format_relative_time(reset_at))
} else {
String::new()
};
return Ok(Some(format!(
"⚠️ **Low Quota Warning**: {} profile '{}' has only {}/{} requests remaining ({:.0}%){}. \
Your operation requires {} calls. \
Consider: (1) reducing parallel operations, (2) switching providers, or (3) waiting for quota reset.",
provider_name,
profile.profile_name,
remaining,
total,
quota_pct,
reset_str,
parallel_count
)));
}
}
}
}
Ok(None)
}
/// Parse switch_provider metadata from tool result output.
///
/// The `switch_provider` tool embeds JSON metadata in its output as:
/// `<!-- metadata: {...} -->`
///
/// Returns `Some((provider, model))` if a provider switch was requested.
pub fn parse_switch_provider_metadata(tool_output: &str) -> Option<(String, Option<String>)> {
// Look for <!-- metadata: {...} --> pattern
if let Some(start) = tool_output.find("<!-- metadata:") {
if let Some(end) = tool_output[start..].find("-->") {
let json_str = &tool_output[start + 14..start + end].trim();
if let Ok(metadata) = serde_json::from_str::<serde_json::Value>(json_str) {
if metadata.get("action").and_then(|v| v.as_str()) == Some("switch_provider") {
let provider = metadata
.get("provider")
.and_then(|v| v.as_str())
.map(String::from);
let model = metadata
.get("model")
.and_then(|v| v.as_str())
.map(String::from);
if let Some(p) = provider {
return Some((p, model));
}
}
}
}
}
None
}
/// Format relative time (e.g., "in 2h 30m" or "5 minutes ago").
fn format_relative_time(dt: chrono::DateTime<chrono::Utc>) -> String {
let now = chrono::Utc::now();
let diff = dt.signed_duration_since(now);
if diff.num_seconds() < 0 {
// In the past
let abs_diff = -diff;
if abs_diff.num_hours() > 0 {
format!("{}h ago", abs_diff.num_hours())
} else if abs_diff.num_minutes() > 0 {
format!("{}m ago", abs_diff.num_minutes())
} else {
format!("{}s ago", abs_diff.num_seconds())
}
} else {
// In the future
if diff.num_hours() > 0 {
format!("in {}h {}m", diff.num_hours(), diff.num_minutes() % 60)
} else if diff.num_minutes() > 0 {
format!("in {}m", diff.num_minutes())
} else {
format!("in {}s", diff.num_seconds())
}
}
}
/// Find an available alternative provider when current provider is unavailable.
///
/// Returns the name of a healthy provider with available quota, or None if all are unavailable.
pub async fn find_available_provider(
config: &Config,
current_provider: &str,
) -> Result<Option<String>> {
let health_tracker = ProviderHealthTracker::new(
3, // failure_threshold
Duration::from_secs(60), // cooldown
100, // max tracked providers
);
let auth_store = AuthProfilesStore::new(&config.workspace_dir, config.secrets.encrypt);
let profiles_data = auth_store.load().await?;
let summary =
crate::providers::quota_cli::build_quota_summary(&health_tracker, &profiles_data, None)?;
// Find providers with Ok status (not current provider)
for provider_info in &summary.providers {
if provider_info.provider != current_provider && provider_info.status == QuotaStatus::Ok {
return Ok(Some(provider_info.provider.clone()));
}
}
Ok(None)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_switch_provider_metadata() {
let output = "Switching to gemini.\n\n<!-- metadata: {\"action\":\"switch_provider\",\"provider\":\"gemini\",\"model\":null,\"reason\":\"user request\"} -->";
let result = parse_switch_provider_metadata(output);
assert_eq!(result, Some(("gemini".to_string(), None)));
let output_with_model = "Switching to openai.\n\n<!-- metadata: {\"action\":\"switch_provider\",\"provider\":\"openai\",\"model\":\"gpt-4\",\"reason\":\"rate limit\"} -->";
let result = parse_switch_provider_metadata(output_with_model);
assert_eq!(
result,
Some(("openai".to_string(), Some("gpt-4".to_string())))
);
let no_metadata = "Just some regular tool output";
assert_eq!(parse_switch_provider_metadata(no_metadata), None);
}
#[test]
fn test_format_relative_time() {
use chrono::{Duration, Utc};
let future = Utc::now() + Duration::seconds(3700);
let formatted = format_relative_time(future);
assert!(formatted.contains("in"));
assert!(formatted.contains('h'));
let past = Utc::now() - Duration::seconds(300);
let formatted = format_relative_time(past);
assert!(formatted.contains("ago"));
}
}
+11
View File
@@ -95,6 +95,7 @@ impl Provider for ScriptedProvider {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
});
}
Ok(guard.remove(0))
@@ -332,6 +333,7 @@ fn tool_response(calls: Vec<ToolCall>) -> ChatResponse {
tool_calls: calls,
usage: None,
reasoning_content: None,
quota_metadata: None,
}
}
@@ -342,6 +344,7 @@ fn text_response(text: &str) -> ChatResponse {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}
}
@@ -354,6 +357,7 @@ fn xml_tool_response(name: &str, args: &str) -> ChatResponse {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}
}
@@ -744,6 +748,7 @@ async fn turn_handles_empty_text_response() {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}]));
let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher));
@@ -759,6 +764,7 @@ async fn turn_handles_none_text_response() {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
}]));
let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher));
@@ -784,6 +790,7 @@ async fn turn_preserves_text_alongside_tool_calls() {
}],
usage: None,
reasoning_content: None,
quota_metadata: None,
},
text_response("Here are the results"),
]));
@@ -1022,6 +1029,7 @@ async fn native_dispatcher_handles_stringified_arguments() {
}],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let (_, calls) = dispatcher.parse_response(&response);
@@ -1049,6 +1057,7 @@ fn xml_dispatcher_handles_nested_json() {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let dispatcher = XmlToolDispatcher;
@@ -1068,6 +1077,7 @@ fn xml_dispatcher_handles_empty_tool_call_tag() {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let dispatcher = XmlToolDispatcher;
@@ -1083,6 +1093,7 @@ fn xml_dispatcher_handles_unclosed_tool_call() {
tool_calls: vec![],
usage: None,
reasoning_content: None,
quota_metadata: None,
};
let dispatcher = XmlToolDispatcher;