Merge remote-tracking branch 'origin/main' into pr2093-mainmerge
This commit is contained in:
+17
-10
@@ -218,9 +218,7 @@ impl AgentBuilder {
|
||||
.memory_loader
|
||||
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
|
||||
config: self.config.unwrap_or_default(),
|
||||
model_name: self
|
||||
.model_name
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
|
||||
model_name: crate::config::resolve_default_model_id(self.model_name.as_deref(), None),
|
||||
temperature: self.temperature.unwrap_or(0.7),
|
||||
workspace_dir: self
|
||||
.workspace_dir
|
||||
@@ -298,11 +296,10 @@ impl Agent {
|
||||
|
||||
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
|
||||
let model_name = config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
config.default_model.as_deref(),
|
||||
Some(provider_name),
|
||||
);
|
||||
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
@@ -714,8 +711,12 @@ pub async fn run(
|
||||
let model_name = effective_config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
.map(str::trim)
|
||||
.filter(|m| !m.is_empty())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| {
|
||||
crate::config::default_model_fallback_for_provider(Some(&provider_name)).to_string()
|
||||
});
|
||||
|
||||
agent.observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.clone(),
|
||||
@@ -776,6 +777,7 @@ mod tests {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
});
|
||||
}
|
||||
Ok(guard.remove(0))
|
||||
@@ -813,6 +815,7 @@ mod tests {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
});
|
||||
}
|
||||
Ok(guard.remove(0))
|
||||
@@ -852,6 +855,7 @@ mod tests {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}]),
|
||||
});
|
||||
|
||||
@@ -892,12 +896,14 @@ mod tests {
|
||||
}],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
},
|
||||
crate::providers::ChatResponse {
|
||||
text: Some("done".into()),
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
},
|
||||
]),
|
||||
});
|
||||
@@ -939,6 +945,7 @@ mod tests {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}]),
|
||||
seen_models: seen_models.clone(),
|
||||
});
|
||||
|
||||
@@ -263,6 +263,7 @@ mod tests {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
let (_, calls) = dispatcher.parse_response(&response);
|
||||
@@ -281,6 +282,7 @@ mod tests {
|
||||
}],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
let dispatcher = NativeToolDispatcher;
|
||||
let (_, calls) = dispatcher.parse_response(&response);
|
||||
|
||||
+301
-124
@@ -290,6 +290,20 @@ pub(crate) struct NonCliApprovalContext {
|
||||
tokio::task_local! {
|
||||
static TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT: Option<NonCliApprovalContext>;
|
||||
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
|
||||
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
|
||||
}
|
||||
|
||||
/// Configuration for periodic safety-constraint re-injection (heartbeat).
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SafetyHeartbeatConfig {
|
||||
/// Pre-rendered security policy summary text.
|
||||
pub body: String,
|
||||
/// Inject a heartbeat every `interval` tool iterations (0 = disabled).
|
||||
pub interval: usize,
|
||||
}
|
||||
|
||||
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
|
||||
interval > 0 && counter > 0 && counter % interval == 0
|
||||
}
|
||||
|
||||
/// Extract a short hint from tool call arguments for progress display.
|
||||
@@ -687,33 +701,37 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
safety_heartbeat: Option<SafetyHeartbeatConfig>,
|
||||
) -> Result<String> {
|
||||
let reply_target = non_cli_approval_context
|
||||
.as_ref()
|
||||
.map(|ctx| ctx.reply_target.clone());
|
||||
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT
|
||||
SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
safety_heartbeat,
|
||||
TOOL_LOOP_NON_CLI_APPROVAL_CONTEXT.scope(
|
||||
non_cli_approval_context,
|
||||
TOOL_LOOP_REPLY_TARGET.scope(
|
||||
reply_target,
|
||||
run_tool_call_loop(
|
||||
provider,
|
||||
history,
|
||||
tools_registry,
|
||||
observer,
|
||||
provider_name,
|
||||
model,
|
||||
temperature,
|
||||
silent,
|
||||
approval,
|
||||
channel_name,
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
cancellation_token,
|
||||
on_delta,
|
||||
hooks,
|
||||
excluded_tools,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -788,6 +806,10 @@ pub(crate) async fn run_tool_call_loop(
|
||||
.unwrap_or_default();
|
||||
let mut loop_detector = LoopDetector::new(ld_config);
|
||||
let mut loop_detection_prompt: Option<String> = None;
|
||||
let heartbeat_config = SAFETY_HEARTBEAT_CONFIG
|
||||
.try_with(Clone::clone)
|
||||
.ok()
|
||||
.flatten();
|
||||
let bypass_non_cli_approval_for_turn =
|
||||
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
|
||||
if bypass_non_cli_approval_for_turn {
|
||||
@@ -835,6 +857,19 @@ pub(crate) async fn run_tool_call_loop(
|
||||
request_messages.push(ChatMessage::user(prompt));
|
||||
}
|
||||
|
||||
// ── Safety heartbeat: periodic security-constraint re-injection ──
|
||||
if let Some(ref hb) = heartbeat_config {
|
||||
if should_inject_safety_heartbeat(iteration, hb.interval) {
|
||||
let reminder = format!(
|
||||
"[Safety Heartbeat — round {}/{}]\n{}",
|
||||
iteration + 1,
|
||||
max_iterations,
|
||||
hb.body
|
||||
);
|
||||
request_messages.push(ChatMessage::user(reminder));
|
||||
}
|
||||
}
|
||||
|
||||
// ── Progress: LLM thinking ────────────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
let phase = if iteration == 0 {
|
||||
@@ -1244,13 +1279,30 @@ pub(crate) async fn run_tool_call_loop(
|
||||
|
||||
// ── Approval hook ────────────────────────────────
|
||||
if let Some(mgr) = approval {
|
||||
if bypass_non_cli_approval_for_turn {
|
||||
let non_cli_session_granted =
|
||||
channel_name != "cli" && mgr.is_non_cli_session_granted(&tool_name);
|
||||
if bypass_non_cli_approval_for_turn || non_cli_session_granted {
|
||||
mgr.record_decision(
|
||||
&tool_name,
|
||||
&tool_args,
|
||||
ApprovalResponse::Yes,
|
||||
channel_name,
|
||||
);
|
||||
if non_cli_session_granted {
|
||||
runtime_trace::record_event(
|
||||
"approval_bypass_non_cli_session_grant",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("using runtime non-cli session approval grant"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"tool": tool_name.clone(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if mgr.needs_approval(&tool_name) {
|
||||
let request = ApprovalRequest {
|
||||
tool_name: tool_name.clone(),
|
||||
@@ -1765,10 +1817,12 @@ pub async fn run(
|
||||
.or(config.default_provider.as_deref())
|
||||
.unwrap_or("openrouter");
|
||||
|
||||
let model_name = model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref())
|
||||
.unwrap_or("anthropic/claude-sonnet-4");
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref()),
|
||||
Some(provider_name),
|
||||
);
|
||||
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
@@ -1789,7 +1843,7 @@ pub async fn run(
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
model_name,
|
||||
&model_name,
|
||||
&provider_runtime_options,
|
||||
)?;
|
||||
|
||||
@@ -1837,6 +1891,10 @@ pub async fn run(
|
||||
"memory_store",
|
||||
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
||||
),
|
||||
(
|
||||
"memory_observe",
|
||||
"Store observation memory. Use when: capturing patterns/signals that should remain searchable over long horizons.",
|
||||
),
|
||||
(
|
||||
"memory_recall",
|
||||
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
||||
@@ -1948,7 +2006,7 @@ pub async fn run(
|
||||
let native_tools = provider.supports_native_tools();
|
||||
let mut system_prompt = crate::channels::build_system_prompt_with_mode(
|
||||
&config.workspace_dir,
|
||||
model_name,
|
||||
&model_name,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
Some(&config.identity),
|
||||
@@ -1987,7 +2045,7 @@ pub async fn run(
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
|
||||
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score, None).await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -2011,26 +2069,37 @@ pub async fn run(
|
||||
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
|
||||
failure_streak_threshold: config.agent.loop_detection_failure_streak,
|
||||
};
|
||||
let response = LOOP_DETECTION_CONFIG
|
||||
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
|
||||
Some(SafetyHeartbeatConfig {
|
||||
body: security.summary_for_heartbeat(),
|
||||
interval: config.agent.safety_heartbeat_interval,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let response = SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
hb_cfg,
|
||||
LOOP_DETECTION_CONFIG.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
@@ -2044,6 +2113,7 @@ pub async fn run(
|
||||
|
||||
// Persistent conversation history across turns
|
||||
let mut history = vec![ChatMessage::system(&system_prompt)];
|
||||
let mut interactive_turn: usize = 0;
|
||||
// Reusable readline editor for UTF-8 input support
|
||||
let mut rl = Editor::with_config(
|
||||
RlConfig::builder()
|
||||
@@ -2094,6 +2164,7 @@ pub async fn run(
|
||||
rl.clear_history()?;
|
||||
history.clear();
|
||||
history.push(ChatMessage::system(&system_prompt));
|
||||
interactive_turn = 0;
|
||||
// Clear conversation and daily memory
|
||||
let mut cleared = 0;
|
||||
for category in [MemoryCategory::Conversation, MemoryCategory::Daily] {
|
||||
@@ -2123,8 +2194,13 @@ pub async fn run(
|
||||
}
|
||||
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context =
|
||||
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
config.memory.min_relevance_score,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -2139,32 +2215,57 @@ pub async fn run(
|
||||
};
|
||||
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
interactive_turn += 1;
|
||||
|
||||
// Inject interactive safety heartbeat at configured turn intervals
|
||||
if should_inject_safety_heartbeat(
|
||||
interactive_turn,
|
||||
config.agent.safety_heartbeat_turn_interval,
|
||||
) {
|
||||
let reminder = format!(
|
||||
"[Safety Heartbeat — turn {}]\n{}",
|
||||
interactive_turn,
|
||||
security.summary_for_heartbeat()
|
||||
);
|
||||
history.push(ChatMessage::user(reminder));
|
||||
}
|
||||
|
||||
let ld_cfg = LoopDetectionConfig {
|
||||
no_progress_threshold: config.agent.loop_detection_no_progress_threshold,
|
||||
ping_pong_cycles: config.agent.loop_detection_ping_pong_cycles,
|
||||
failure_streak_threshold: config.agent.loop_detection_failure_streak,
|
||||
};
|
||||
let response = match LOOP_DETECTION_CONFIG
|
||||
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
|
||||
Some(SafetyHeartbeatConfig {
|
||||
body: security.summary_for_heartbeat(),
|
||||
interval: config.agent.safety_heartbeat_interval,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let response = match SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
hb_cfg,
|
||||
LOOP_DETECTION_CONFIG.scope(
|
||||
ld_cfg,
|
||||
run_tool_call_loop(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
),
|
||||
),
|
||||
)
|
||||
.await
|
||||
@@ -2209,7 +2310,7 @@ pub async fn run(
|
||||
if let Ok(compacted) = auto_compact_history(
|
||||
&mut history,
|
||||
provider.as_ref(),
|
||||
model_name,
|
||||
&model_name,
|
||||
config.agent.max_history_messages,
|
||||
)
|
||||
.await
|
||||
@@ -2238,13 +2339,15 @@ pub async fn run(
|
||||
|
||||
/// Process a single message through the full agent (with tools, peripherals, memory).
|
||||
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
|
||||
pub async fn process_message(
|
||||
pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
process_message_with_session(config, message, None).await
|
||||
}
|
||||
|
||||
pub async fn process_message_with_session(
|
||||
config: Config,
|
||||
message: &str,
|
||||
sender_id: &str,
|
||||
channel_name: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
tracing::debug!(sender_id, channel_name, "process_message called");
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
@@ -2288,10 +2391,10 @@ pub async fn process_message(
|
||||
tools_registry.extend(peripheral_tools);
|
||||
|
||||
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
let model_name = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let model_name = crate::config::resolve_default_model_id(
|
||||
config.default_model.as_deref(),
|
||||
Some(provider_name),
|
||||
);
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
provider_api_url: config.api_url.clone(),
|
||||
@@ -2335,6 +2438,7 @@ pub async fn process_message(
|
||||
("file_read", "Read file contents."),
|
||||
("file_write", "Write file contents."),
|
||||
("memory_store", "Save to memory."),
|
||||
("memory_observe", "Store observation memory."),
|
||||
("memory_recall", "Search memory."),
|
||||
("memory_forget", "Delete a memory entry."),
|
||||
(
|
||||
@@ -2407,7 +2511,13 @@ pub async fn process_message(
|
||||
}
|
||||
system_prompt.push_str(&build_shell_policy_instructions(&config.autonomy));
|
||||
|
||||
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
.await;
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
@@ -2433,53 +2543,31 @@ pub async fn process_message(
|
||||
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
|
||||
.collect();
|
||||
|
||||
let mut history = Vec::new();
|
||||
history.push(ChatMessage::system(&system_prompt));
|
||||
history.extend(filtered_history);
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
let reply = agent_turn(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
true,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
)
|
||||
.await?;
|
||||
let persisted: Vec<ChatMessage> = history
|
||||
.into_iter()
|
||||
.filter(|m| crate::providers::is_user_or_assistant_role(m.role.as_str()))
|
||||
.collect();
|
||||
let saved_len = persisted.len();
|
||||
session
|
||||
.update_history(persisted)
|
||||
.await
|
||||
.context("Failed to update session history")?;
|
||||
tracing::debug!(saved_len, "session history saved");
|
||||
Ok(reply)
|
||||
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
|
||||
Some(SafetyHeartbeatConfig {
|
||||
body: security.summary_for_heartbeat(),
|
||||
interval: config.agent.safety_heartbeat_interval,
|
||||
})
|
||||
} else {
|
||||
let mut history = vec![
|
||||
ChatMessage::system(&system_prompt),
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
agent_turn(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
true,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None
|
||||
};
|
||||
SAFETY_HEARTBEAT_CONFIG
|
||||
.scope(
|
||||
hb_cfg,
|
||||
agent_turn(
|
||||
provider.as_ref(),
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
true,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -2574,6 +2662,36 @@ mod tests {
|
||||
assert_eq!(feishu_args["delivery"]["to"], "oc_yyy");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safety_heartbeat_interval_zero_disables_injection() {
|
||||
for counter in [0, 1, 2, 10, 100] {
|
||||
assert!(
|
||||
!should_inject_safety_heartbeat(counter, 0),
|
||||
"counter={counter} should not inject when interval=0"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safety_heartbeat_interval_one_injects_every_non_initial_step() {
|
||||
assert!(!should_inject_safety_heartbeat(0, 1));
|
||||
for counter in 1..=6 {
|
||||
assert!(
|
||||
should_inject_safety_heartbeat(counter, 1),
|
||||
"counter={counter} should inject when interval=1"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safety_heartbeat_injects_only_on_exact_multiples() {
|
||||
let interval = 3;
|
||||
let injected: Vec<usize> = (0..=10)
|
||||
.filter(|counter| should_inject_safety_heartbeat(*counter, interval))
|
||||
.collect();
|
||||
assert_eq!(injected, vec![3, 6, 9]);
|
||||
}
|
||||
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::traits::ProviderCapabilities;
|
||||
@@ -2643,6 +2761,7 @@ mod tests {
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2661,6 +2780,7 @@ mod tests {
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
@@ -3183,6 +3303,62 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_uses_non_cli_session_grant_without_waiting_for_prompt() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
r#"<tool_call>
|
||||
{"name":"shell","arguments":{"command":"echo hi"}}
|
||||
</tool_call>"#,
|
||||
"done",
|
||||
]);
|
||||
|
||||
let active = Arc::new(AtomicUsize::new(0));
|
||||
let max_active = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(DelayTool::new(
|
||||
"shell",
|
||||
50,
|
||||
Arc::clone(&active),
|
||||
Arc::clone(&max_active),
|
||||
))];
|
||||
|
||||
let approval_mgr = ApprovalManager::from_config(&crate::config::AutonomyConfig::default());
|
||||
approval_mgr.grant_non_cli_session("shell");
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("run shell"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
Some(&approval_mgr),
|
||||
"telegram",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
4,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should consume non-cli session grants");
|
||||
|
||||
assert_eq!(result, "done");
|
||||
assert_eq!(
|
||||
max_active.load(Ordering::SeqCst),
|
||||
1,
|
||||
"shell tool should execute when runtime non-cli session grant exists"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_waits_for_non_cli_approval_resolution() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
@@ -3252,6 +3428,7 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should continue after non-cli approval");
|
||||
@@ -4397,7 +4574,7 @@ Tail"#;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = build_context(&mem, "status updates", 0.0).await;
|
||||
let context = build_context(&mem, "status updates", 0.0, None).await;
|
||||
assert!(context.contains("user_msg_real"));
|
||||
assert!(!context.contains("assistant_resp_poisoned"));
|
||||
assert!(!context.contains("fabricated event"));
|
||||
|
||||
@@ -8,11 +8,12 @@ pub(super) async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
min_relevance_score: f64,
|
||||
session_id: Option<&str>,
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
|
||||
@@ -220,7 +220,9 @@ impl LoopDetector {
|
||||
|
||||
fn hash_output(output: &str) -> u64 {
|
||||
let prefix = if output.len() > OUTPUT_HASH_PREFIX_BYTES {
|
||||
&output[..OUTPUT_HASH_PREFIX_BYTES]
|
||||
// Use floor_utf8_char_boundary to avoid panic on multi-byte UTF-8 characters
|
||||
let boundary = crate::util::floor_utf8_char_boundary(output, OUTPUT_HASH_PREFIX_BYTES);
|
||||
&output[..boundary]
|
||||
} else {
|
||||
output
|
||||
};
|
||||
@@ -386,4 +388,26 @@ mod tests {
|
||||
det.record_call("shell", r#"{"cmd":"cargo test"}"#, "ok", true);
|
||||
assert_eq!(det.check(), DetectionVerdict::Continue);
|
||||
}
|
||||
|
||||
// 11. UTF-8 boundary safety: hash_output must not panic on CJK text
|
||||
#[test]
|
||||
fn hash_output_utf8_boundary_safe() {
|
||||
// Create a string where byte 4096 lands inside a multi-byte char
|
||||
// Chinese chars are 3 bytes each, so 1366 chars = 4098 bytes
|
||||
let cjk_text: String = "文".repeat(1366); // 4098 bytes
|
||||
assert!(cjk_text.len() > super::OUTPUT_HASH_PREFIX_BYTES);
|
||||
|
||||
// This should NOT panic
|
||||
let hash1 = super::hash_output(&cjk_text);
|
||||
|
||||
// Different content should produce different hash
|
||||
let cjk_text2: String = "字".repeat(1366);
|
||||
let hash2 = super::hash_output(&cjk_text2);
|
||||
assert_ne!(hash1, hash2);
|
||||
|
||||
// Mixed ASCII + CJK at boundary
|
||||
let mixed = "a".repeat(4094) + "文文"; // 4094 + 6 = 4100 bytes, boundary at 4096
|
||||
let hash3 = super::hash_output(&mixed);
|
||||
assert!(hash3 != 0); // Just verify it runs
|
||||
}
|
||||
}
|
||||
|
||||
+106
-3
@@ -28,8 +28,13 @@ pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
|
||||
}
|
||||
|
||||
let start = if has_system { 1 } else { 0 };
|
||||
let to_remove = non_system_count - max_history;
|
||||
history.drain(start..start + to_remove);
|
||||
let mut trim_end = start + (non_system_count - max_history);
|
||||
// Never keep a leading `role=tool` at the trim boundary. Tool-message runs
|
||||
// must remain attached to their preceding assistant(tool_calls) message.
|
||||
while trim_end < history.len() && history[trim_end].role == "tool" {
|
||||
trim_end += 1;
|
||||
}
|
||||
history.drain(start..trim_end);
|
||||
}
|
||||
|
||||
pub(super) fn build_compaction_transcript(messages: &[ChatMessage]) -> String {
|
||||
@@ -80,7 +85,11 @@ pub(super) async fn auto_compact_history(
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let compact_end = start + compact_count;
|
||||
let mut compact_end = start + compact_count;
|
||||
// Do not split assistant(tool_calls) -> tool runs across compaction boundary.
|
||||
while compact_end < history.len() && history[compact_end].role == "tool" {
|
||||
compact_end += 1;
|
||||
}
|
||||
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
@@ -104,3 +113,97 @@ pub(super) async fn auto_compact_history(
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::{ChatRequest, ChatResponse, Provider};
|
||||
use async_trait::async_trait;
|
||||
|
||||
struct StaticSummaryProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StaticSummaryProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("- summarized context".to_string())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
Ok(ChatResponse {
|
||||
text: Some("- summarized context".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_with_tool_call(id: &str) -> ChatMessage {
|
||||
ChatMessage::assistant(format!(
|
||||
"{{\"content\":\"\",\"tool_calls\":[{{\"id\":\"{id}\",\"name\":\"shell\",\"arguments\":\"{{}}\"}}]}}"
|
||||
))
|
||||
}
|
||||
|
||||
fn tool_result(id: &str) -> ChatMessage {
|
||||
ChatMessage::tool(format!("{{\"tool_call_id\":\"{id}\",\"content\":\"ok\"}}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trim_history_avoids_orphan_tool_at_boundary() {
|
||||
let mut history = vec![
|
||||
ChatMessage::user("old"),
|
||||
assistant_with_tool_call("call_1"),
|
||||
tool_result("call_1"),
|
||||
ChatMessage::user("recent"),
|
||||
];
|
||||
|
||||
trim_history(&mut history, 2);
|
||||
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "user");
|
||||
assert_eq!(history[0].content, "recent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auto_compact_history_does_not_split_tool_run_boundary() {
|
||||
let mut history = vec![
|
||||
ChatMessage::user("oldest"),
|
||||
assistant_with_tool_call("call_2"),
|
||||
tool_result("call_2"),
|
||||
];
|
||||
for idx in 0..19 {
|
||||
history.push(ChatMessage::user(format!("recent-{idx}")));
|
||||
}
|
||||
// 22 non-system messages => compaction with max_history=21 would
|
||||
// previously cut right before the tool result (index 2).
|
||||
assert_eq!(history.len(), 22);
|
||||
|
||||
let compacted =
|
||||
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
|
||||
assert!(compacted);
|
||||
assert_eq!(history[0].role, "assistant");
|
||||
assert!(
|
||||
history[0].content.contains("[Compaction summary]"),
|
||||
"summary message should replace compacted range"
|
||||
);
|
||||
assert_ne!(
|
||||
history[1].role, "tool",
|
||||
"first retained message must not be an orphan tool result"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -886,6 +886,7 @@ pub(super) fn map_tool_name_alias(tool_name: &str) -> &str {
|
||||
// Memory variations
|
||||
"memoryrecall" | "memory_recall" | "recall" | "memrecall" => "memory_recall",
|
||||
"memorystore" | "memory_store" | "store" | "memstore" => "memory_store",
|
||||
"memoryobserve" | "memory_observe" | "observe" | "memobserve" => "memory_observe",
|
||||
"memoryforget" | "memory_forget" | "forget" | "memforget" => "memory_forget",
|
||||
// HTTP variations
|
||||
"http_request" | "http" | "fetch" | "curl" | "wget" => "http_request",
|
||||
@@ -1026,6 +1027,7 @@ pub(super) fn default_param_for_tool(tool: &str) -> &'static str {
|
||||
"memory_recall" | "memoryrecall" | "recall" | "memrecall" | "memory_forget"
|
||||
| "memoryforget" | "forget" | "memforget" => "query",
|
||||
"memory_store" | "memorystore" | "store" | "memstore" => "content",
|
||||
"memory_observe" | "memoryobserve" | "observe" | "memobserve" => "observation",
|
||||
// HTTP and browser tools default to "url"
|
||||
"http_request" | "http" | "fetch" | "curl" | "wget" | "browser_open" | "browser"
|
||||
| "web_search" => "url",
|
||||
|
||||
+2
-1
@@ -5,6 +5,7 @@ pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
pub mod quota_aware;
|
||||
pub mod research;
|
||||
pub mod session;
|
||||
|
||||
@@ -14,4 +15,4 @@ mod tests;
|
||||
#[allow(unused_imports)]
|
||||
pub use agent::{Agent, AgentBuilder};
|
||||
#[allow(unused_imports)]
|
||||
pub use loop_::{process_message, run};
|
||||
pub use loop_::{process_message, process_message_with_session, run};
|
||||
|
||||
@@ -496,6 +496,7 @@ mod tests {
|
||||
}],
|
||||
prompts: vec!["Run smoke tests before deploy.".into()],
|
||||
location: None,
|
||||
always: false,
|
||||
}];
|
||||
|
||||
let ctx = PromptContext {
|
||||
@@ -534,6 +535,7 @@ mod tests {
|
||||
}],
|
||||
prompts: vec!["Run smoke tests before deploy.".into()],
|
||||
location: Some(Path::new("/tmp/workspace/skills/deploy/SKILL.md").to_path_buf()),
|
||||
always: false,
|
||||
}];
|
||||
|
||||
let ctx = PromptContext {
|
||||
@@ -594,6 +596,7 @@ mod tests {
|
||||
}],
|
||||
prompts: vec!["Use <tool_call> and & keep output \"safe\"".into()],
|
||||
location: None,
|
||||
always: false,
|
||||
}];
|
||||
let ctx = PromptContext {
|
||||
workspace_dir: Path::new("/tmp/workspace"),
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
//! Quota-aware agent loop helpers.
|
||||
//!
|
||||
//! This module provides utilities for the agent loop to:
|
||||
//! - Check provider quota status before expensive operations
|
||||
//! - Warn users when quota is running low
|
||||
//! - Switch providers mid-conversation when requested via tools
|
||||
//! - Handle rate limit errors with automatic fallback
|
||||
|
||||
use crate::auth::profiles::AuthProfilesStore;
|
||||
use crate::config::Config;
|
||||
use crate::providers::health::ProviderHealthTracker;
|
||||
use crate::providers::quota_types::QuotaStatus;
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Check if we should warn about low quota before an operation.
|
||||
///
|
||||
/// Returns `Some(warning_message)` if quota is running low (< 10% remaining).
|
||||
pub async fn check_quota_warning(
|
||||
config: &Config,
|
||||
provider_name: &str,
|
||||
parallel_count: usize,
|
||||
) -> Result<Option<String>> {
|
||||
if parallel_count < 5 {
|
||||
// Only warn for operations with 5+ parallel calls
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let health_tracker = ProviderHealthTracker::new(
|
||||
3, // failure_threshold
|
||||
Duration::from_secs(60), // cooldown
|
||||
100, // max tracked providers
|
||||
);
|
||||
|
||||
let auth_store = AuthProfilesStore::new(&config.workspace_dir, config.secrets.encrypt);
|
||||
let profiles_data = auth_store.load().await?;
|
||||
|
||||
let summary = crate::providers::quota_cli::build_quota_summary(
|
||||
&health_tracker,
|
||||
&profiles_data,
|
||||
Some(provider_name),
|
||||
)?;
|
||||
|
||||
// Find the provider in summary
|
||||
if let Some(provider_info) = summary
|
||||
.providers
|
||||
.iter()
|
||||
.find(|p| p.provider == provider_name)
|
||||
{
|
||||
// Check circuit breaker status
|
||||
if provider_info.status == QuotaStatus::CircuitOpen {
|
||||
let reset_str = if let Some(resets_at) = provider_info.circuit_resets_at {
|
||||
format!(" (resets {})", format_relative_time(resets_at))
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
return Ok(Some(format!(
|
||||
"⚠️ **Provider Unavailable**: {} is circuit-open{}. \
|
||||
Consider switching to an alternative provider using the `check_provider_quota` tool.",
|
||||
provider_name, reset_str
|
||||
)));
|
||||
}
|
||||
|
||||
// Check rate limit status
|
||||
if provider_info.status == QuotaStatus::RateLimited
|
||||
|| provider_info.status == QuotaStatus::QuotaExhausted
|
||||
{
|
||||
return Ok(Some(format!(
|
||||
"⚠️ **Rate Limit Warning**: {} is rate-limited. \
|
||||
Your parallel operation ({} calls) may fail. \
|
||||
Consider switching to another provider using `check_provider_quota` and `switch_provider` tools.",
|
||||
provider_name, parallel_count
|
||||
)));
|
||||
}
|
||||
|
||||
// Check individual profile quotas
|
||||
for profile in &provider_info.profiles {
|
||||
if let (Some(remaining), Some(total)) =
|
||||
(profile.rate_limit_remaining, profile.rate_limit_total)
|
||||
{
|
||||
let quota_pct = (remaining as f64 / total as f64) * 100.0;
|
||||
if quota_pct < 10.0 && remaining < parallel_count as u64 {
|
||||
let reset_str = if let Some(reset_at) = profile.rate_limit_reset_at {
|
||||
format!(" (resets {})", format_relative_time(reset_at))
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
return Ok(Some(format!(
|
||||
"⚠️ **Low Quota Warning**: {} profile '{}' has only {}/{} requests remaining ({:.0}%){}. \
|
||||
Your operation requires {} calls. \
|
||||
Consider: (1) reducing parallel operations, (2) switching providers, or (3) waiting for quota reset.",
|
||||
provider_name,
|
||||
profile.profile_name,
|
||||
remaining,
|
||||
total,
|
||||
quota_pct,
|
||||
reset_str,
|
||||
parallel_count
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Parse switch_provider metadata from tool result output.
|
||||
///
|
||||
/// The `switch_provider` tool embeds JSON metadata in its output as:
|
||||
/// `<!-- metadata: {...} -->`
|
||||
///
|
||||
/// Returns `Some((provider, model))` if a provider switch was requested.
|
||||
pub fn parse_switch_provider_metadata(tool_output: &str) -> Option<(String, Option<String>)> {
|
||||
// Look for <!-- metadata: {...} --> pattern
|
||||
if let Some(start) = tool_output.find("<!-- metadata:") {
|
||||
if let Some(end) = tool_output[start..].find("-->") {
|
||||
let json_str = &tool_output[start + 14..start + end].trim();
|
||||
if let Ok(metadata) = serde_json::from_str::<serde_json::Value>(json_str) {
|
||||
if metadata.get("action").and_then(|v| v.as_str()) == Some("switch_provider") {
|
||||
let provider = metadata
|
||||
.get("provider")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
let model = metadata
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
if let Some(p) = provider {
|
||||
return Some((p, model));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Format relative time (e.g., "in 2h 30m" or "5 minutes ago").
|
||||
fn format_relative_time(dt: chrono::DateTime<chrono::Utc>) -> String {
|
||||
let now = chrono::Utc::now();
|
||||
let diff = dt.signed_duration_since(now);
|
||||
|
||||
if diff.num_seconds() < 0 {
|
||||
// In the past
|
||||
let abs_diff = -diff;
|
||||
if abs_diff.num_hours() > 0 {
|
||||
format!("{}h ago", abs_diff.num_hours())
|
||||
} else if abs_diff.num_minutes() > 0 {
|
||||
format!("{}m ago", abs_diff.num_minutes())
|
||||
} else {
|
||||
format!("{}s ago", abs_diff.num_seconds())
|
||||
}
|
||||
} else {
|
||||
// In the future
|
||||
if diff.num_hours() > 0 {
|
||||
format!("in {}h {}m", diff.num_hours(), diff.num_minutes() % 60)
|
||||
} else if diff.num_minutes() > 0 {
|
||||
format!("in {}m", diff.num_minutes())
|
||||
} else {
|
||||
format!("in {}s", diff.num_seconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find an available alternative provider when current provider is unavailable.
|
||||
///
|
||||
/// Returns the name of a healthy provider with available quota, or None if all are unavailable.
|
||||
pub async fn find_available_provider(
|
||||
config: &Config,
|
||||
current_provider: &str,
|
||||
) -> Result<Option<String>> {
|
||||
let health_tracker = ProviderHealthTracker::new(
|
||||
3, // failure_threshold
|
||||
Duration::from_secs(60), // cooldown
|
||||
100, // max tracked providers
|
||||
);
|
||||
|
||||
let auth_store = AuthProfilesStore::new(&config.workspace_dir, config.secrets.encrypt);
|
||||
let profiles_data = auth_store.load().await?;
|
||||
|
||||
let summary =
|
||||
crate::providers::quota_cli::build_quota_summary(&health_tracker, &profiles_data, None)?;
|
||||
|
||||
// Find providers with Ok status (not current provider)
|
||||
for provider_info in &summary.providers {
|
||||
if provider_info.provider != current_provider && provider_info.status == QuotaStatus::Ok {
|
||||
return Ok(Some(provider_info.provider.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_switch_provider_metadata() {
|
||||
let output = "Switching to gemini.\n\n<!-- metadata: {\"action\":\"switch_provider\",\"provider\":\"gemini\",\"model\":null,\"reason\":\"user request\"} -->";
|
||||
let result = parse_switch_provider_metadata(output);
|
||||
assert_eq!(result, Some(("gemini".to_string(), None)));
|
||||
|
||||
let output_with_model = "Switching to openai.\n\n<!-- metadata: {\"action\":\"switch_provider\",\"provider\":\"openai\",\"model\":\"gpt-4\",\"reason\":\"rate limit\"} -->";
|
||||
let result = parse_switch_provider_metadata(output_with_model);
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("openai".to_string(), Some("gpt-4".to_string())))
|
||||
);
|
||||
|
||||
let no_metadata = "Just some regular tool output";
|
||||
assert_eq!(parse_switch_provider_metadata(no_metadata), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_relative_time() {
|
||||
use chrono::{Duration, Utc};
|
||||
|
||||
let future = Utc::now() + Duration::seconds(3700);
|
||||
let formatted = format_relative_time(future);
|
||||
assert!(formatted.contains("in"));
|
||||
assert!(formatted.contains('h'));
|
||||
|
||||
let past = Utc::now() - Duration::seconds(300);
|
||||
let formatted = format_relative_time(past);
|
||||
assert!(formatted.contains("ago"));
|
||||
}
|
||||
}
|
||||
@@ -95,6 +95,7 @@ impl Provider for ScriptedProvider {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
});
|
||||
}
|
||||
Ok(guard.remove(0))
|
||||
@@ -332,6 +333,7 @@ fn tool_response(calls: Vec<ToolCall>) -> ChatResponse {
|
||||
tool_calls: calls,
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,6 +344,7 @@ fn text_response(text: &str) -> ChatResponse {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,6 +357,7 @@ fn xml_tool_response(name: &str, args: &str) -> ChatResponse {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -744,6 +748,7 @@ async fn turn_handles_empty_text_response() {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}]));
|
||||
|
||||
let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher));
|
||||
@@ -759,6 +764,7 @@ async fn turn_handles_none_text_response() {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
}]));
|
||||
|
||||
let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher));
|
||||
@@ -784,6 +790,7 @@ async fn turn_preserves_text_alongside_tool_calls() {
|
||||
}],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
},
|
||||
text_response("Here are the results"),
|
||||
]));
|
||||
@@ -1022,6 +1029,7 @@ async fn native_dispatcher_handles_stringified_arguments() {
|
||||
}],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
|
||||
let (_, calls) = dispatcher.parse_response(&response);
|
||||
@@ -1049,6 +1057,7 @@ fn xml_dispatcher_handles_nested_json() {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
@@ -1068,6 +1077,7 @@ fn xml_dispatcher_handles_empty_tool_call_tag() {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
@@ -1083,6 +1093,7 @@ fn xml_dispatcher_handles_unclosed_tool_call() {
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
};
|
||||
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
|
||||
Reference in New Issue
Block a user