Merge branch 'main' into fix/release-v0.1.8-build-errors

This commit is contained in:
Chum Yin
2026-03-02 04:34:58 +08:00
committed by GitHub
32 changed files with 3321 additions and 1339 deletions
+4
View File
@@ -243,6 +243,10 @@ impl Agent {
AgentBuilder::new()
}
/// Borrow the tool specifications registered with this agent.
pub fn tool_specs(&self) -> &[ToolSpec] {
    &self.tool_specs
}
/// Borrow the agent's accumulated conversation history.
pub fn history(&self) -> &[ConversationMessage] {
    &self.history
}
+2 -1
View File
@@ -983,7 +983,7 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_tool_call_loop(
pub async fn run_tool_call_loop(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
tools_registry: &[Box<dyn Tool>],
@@ -2773,6 +2773,7 @@ pub async fn run(
&model_name,
config.agent.max_history_messages,
effective_hooks,
Some(mem.as_ref()),
)
.await
{
+181 -17
View File
@@ -1,9 +1,29 @@
use crate::memory::{self, Memory};
use crate::memory::{self, decay, Memory, MemoryCategory};
use std::fmt::Write;
/// Default half-life (days) for time decay in context building.
const CONTEXT_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
/// Score boost applied to `Core` category memories so durable facts and
/// preferences surface even when keyword/semantic similarity is moderate.
const CORE_CATEGORY_SCORE_BOOST: f64 = 0.3;
/// Maximum number of memory entries included in the context preamble.
const CONTEXT_ENTRY_LIMIT: usize = 5;
/// Over-fetch factor: retrieve more candidates than the output limit so
/// that Core boost and re-ranking can select the best subset.
const RECALL_OVER_FETCH_FACTOR: usize = 2;
/// Build context preamble by searching memory for relevant entries.
/// Entries with a hybrid score below `min_relevance_score` are dropped to
/// prevent unrelated memories from bleeding into the conversation.
///
/// Core memories are exempt from time decay (evergreen).
///
/// `Core` category memories receive a score boost so that durable facts,
/// preferences, and project rules are more likely to appear in context
/// even when semantic similarity to the current message is moderate.
pub(super) async fn build_context(
mem: &dyn Memory,
user_msg: &str,
@@ -12,29 +32,41 @@ pub(super) async fn build_context(
) -> String {
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
let relevant: Vec<_> = entries
// Over-fetch so Core-boosted entries can compete fairly after re-ranking.
let fetch_limit = CONTEXT_ENTRY_LIMIT * RECALL_OVER_FETCH_FACTOR;
if let Ok(mut entries) = mem.recall(user_msg, fetch_limit, session_id).await {
// Apply time decay: older non-Core memories score lower.
decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS);
// Apply Core category boost and filter by minimum relevance.
let mut scored: Vec<_> = entries
.iter()
.filter(|e| match e.score {
Some(score) => score >= min_relevance_score,
None => true,
.filter(|e| !memory::is_assistant_autosave_key(&e.key))
.filter_map(|e| {
let base = e.score.unwrap_or(min_relevance_score);
let boosted = if e.category == MemoryCategory::Core {
(base + CORE_CATEGORY_SCORE_BOOST).min(1.0)
} else {
base
};
if boosted >= min_relevance_score {
Some((e, boosted))
} else {
None
}
})
.collect();
if !relevant.is_empty() {
// Sort by boosted score descending, then truncate to output limit.
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(CONTEXT_ENTRY_LIMIT);
if !scored.is_empty() {
context.push_str("[Memory context]\n");
for entry in &relevant {
if memory::is_assistant_autosave_key(&entry.key) {
continue;
}
for (entry, _) in &scored {
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
}
if context == "[Memory context]\n" {
context.clear();
} else {
context.push('\n');
}
context.push('\n');
}
}
@@ -80,3 +112,135 @@ pub(super) fn build_hardware_context(
context.push('\n');
context
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Memory, MemoryCategory, MemoryEntry};
    use async_trait::async_trait;
    use std::sync::Arc;

    /// `Memory` stub: `recall` returns a fixed, pre-canned entry list and
    /// every mutating call is a successful no-op.
    struct MockMemory {
        // Entries handed back verbatim (cloned) by `recall`.
        entries: Arc<Vec<MemoryEntry>>,
    }

    #[async_trait]
    impl Memory for MockMemory {
        // Discards the write; always succeeds.
        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
        // Ignores query/limit/session; returns a clone of the canned entries.
        async fn recall(
            &self,
            _query: &str,
            _limit: usize,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(self.entries.as_ref().clone())
        }
        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }
        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(vec![])
        }
        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(true)
        }
        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.entries.len())
        }
        async fn health_check(&self) -> bool {
            true
        }
        fn name(&self) -> &str {
            "mock-memory"
        }
    }

    /// A Core entry scoring 0.2 (below the 0.4 threshold) must still be
    /// included thanks to the Core category boost, while a 0.1 Conversation
    /// entry stays filtered out.
    #[tokio::test]
    async fn build_context_promotes_core_entries_with_score_boost() {
        let memory = MockMemory {
            entries: Arc::new(vec![
                MemoryEntry {
                    id: "1".into(),
                    key: "conv_note".into(),
                    content: "small talk".into(),
                    category: MemoryCategory::Conversation,
                    timestamp: "now".into(),
                    session_id: None,
                    score: Some(0.6),
                },
                MemoryEntry {
                    id: "2".into(),
                    key: "core_rule".into(),
                    content: "always provide tests".into(),
                    category: MemoryCategory::Core,
                    timestamp: "now".into(),
                    session_id: None,
                    // Below threshold on its own; boost must rescue it.
                    score: Some(0.2),
                },
                MemoryEntry {
                    id: "3".into(),
                    key: "conv_low".into(),
                    content: "irrelevant".into(),
                    category: MemoryCategory::Conversation,
                    timestamp: "now".into(),
                    session_id: None,
                    score: Some(0.1),
                },
            ]),
        };
        let context = build_context(&memory, "test query", 0.4, None).await;
        assert!(
            context.contains("core_rule"),
            "expected core boost to include core_rule"
        );
        assert!(
            !context.contains("conv_low"),
            "low-score non-core should be filtered"
        );
    }

    /// Even when recall returns 8 candidates, the rendered context must keep
    /// at most 5 "- key: value" lines (the output entry limit).
    #[tokio::test]
    async fn build_context_keeps_output_limit_at_five_entries() {
        let entries = (0..8)
            .map(|idx| MemoryEntry {
                id: idx.to_string(),
                key: format!("k{idx}"),
                content: format!("v{idx}"),
                category: MemoryCategory::Conversation,
                timestamp: "now".into(),
                session_id: None,
                // Strictly decreasing scores give a deterministic ranking.
                score: Some(0.9 - (idx as f64 * 0.01)),
            })
            .collect::<Vec<_>>();
        let memory = MockMemory {
            entries: Arc::new(entries),
        };
        let context = build_context(&memory, "limit", 0.0, None).await;
        let listed = context
            .lines()
            .filter(|line| line.starts_with("- "))
            .count();
        assert_eq!(listed, 5, "context output limit should remain 5 entries");
    }
}
+395 -4
View File
@@ -1,3 +1,4 @@
use crate::memory::{Memory, MemoryCategory};
use crate::providers::{ChatMessage, Provider};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
@@ -12,6 +13,9 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;
/// Max characters retained in stored compaction summary.
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
/// Safety cap for durable facts extracted during pre-compaction flush.
const COMPACTION_MAX_FLUSH_FACTS: usize = 8;
/// Trim conversation history to prevent unbounded growth.
/// Preserves the system prompt (first message if role=system) and the most recent messages.
pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
@@ -67,6 +71,7 @@ pub(super) async fn auto_compact_history(
model: &str,
max_history: usize,
hooks: Option<&crate::hooks::HookRunner>,
memory: Option<&dyn Memory>,
) -> Result<bool> {
let has_system = history.first().map_or(false, |m| m.role == "system");
let non_system_count = if has_system {
@@ -105,6 +110,13 @@ pub(super) async fn auto_compact_history(
};
let transcript = build_compaction_transcript(&to_compact);
// ── Pre-compaction memory flush ──────────────────────────────────
// Before discarding old messages, ask the LLM to extract durable
// facts and store them as Core memories so they survive compaction.
if let Some(mem) = memory {
flush_durable_facts(provider, model, &transcript, mem).await;
}
let summarizer_system = "You are a conversation compaction engine. Summarize older chat history into concise context for future turns. Preserve: user preferences, commitments, decisions, unresolved tasks, key facts. Omit: filler, repeated chit-chat, verbose tool logs. Output plain text bullet points only.";
let summarizer_user = format!(
@@ -137,6 +149,86 @@ pub(super) async fn auto_compact_history(
Ok(true)
}
/// Extract durable facts from a conversation transcript and store them as
/// `Core` memories. Called before compaction discards old messages.
///
/// Best-effort: failures are logged but never block compaction.
async fn flush_durable_facts(
    provider: &dyn Provider,
    model: &str,
    transcript: &str,
    memory: &dyn Memory,
) {
    const FLUSH_SYSTEM: &str = "\
You extract durable facts from a conversation that is about to be compacted. \
Output ONLY facts worth remembering long-term — user preferences, project decisions, \
technical constraints, commitments, or important discoveries. \
Output one fact per line, prefixed with a short key in brackets. \
Example:\n\
[preferred_language] User prefers Rust over Go\n\
[db_choice] Project uses PostgreSQL 16\n\
If there are no durable facts, output exactly: NONE";
    // Reference the shared cap constant instead of a hard-coded "8" so the
    // prompt can never drift out of sync with COMPACTION_MAX_FLUSH_FACTS.
    let flush_user = format!(
        "Extract durable facts from this conversation (max {COMPACTION_MAX_FLUSH_FACTS} facts):\n\n{transcript}"
    );
    let response = match provider
        .chat_with_system(Some(FLUSH_SYSTEM), &flush_user, model, 0.2)
        .await
    {
        Ok(r) => r,
        Err(e) => {
            // Best-effort: a failed flush must never abort compaction.
            tracing::warn!("Pre-compaction memory flush failed: {e}");
            return;
        }
    };
    // Trim once; the model signals "nothing durable" with the literal NONE.
    let trimmed = response.trim();
    if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("NONE") {
        return;
    }
    let mut stored = 0usize;
    for line in trimmed.lines() {
        // Safety cap: never store more than COMPACTION_MAX_FLUSH_FACTS facts.
        if stored >= COMPACTION_MAX_FLUSH_FACTS {
            break;
        }
        let line = line.trim();
        if line.is_empty() {
            continue;
        }
        // Parse "[key] content" format; malformed lines are skipped silently.
        if let Some((key, content)) = parse_fact_line(line) {
            // Namespace the key so flushed facts are identifiable in memory.
            let prefixed_key = format!("compaction_fact_{key}");
            if let Err(e) = memory
                .store(&prefixed_key, content, MemoryCategory::Core, None)
                .await
            {
                tracing::warn!("Failed to store compaction fact '{prefixed_key}': {e}");
            } else {
                stored += 1;
            }
        }
    }
    if stored > 0 {
        tracing::info!("Pre-compaction flush: stored {stored} durable fact(s) to Core memory");
    }
}
/// Parse a `[key] content` line produced by the fact-extraction prompt.
///
/// Leading bullet dashes and whitespace are ignored. Returns `None` when the
/// line lacks a complete `[key]` prefix or when either part is blank.
fn parse_fact_line(line: &str) -> Option<(&str, &str)> {
    // Drop any "- " style bullet prefix before the bracketed key.
    let stripped = line.trim_start_matches(|c: char| c == '-' || c.is_whitespace());
    // Expect "[key] content": peel the opening bracket, then split at the
    // first closing bracket.
    let (raw_key, raw_content) = stripped.strip_prefix('[')?.split_once(']')?;
    let key = raw_key.trim();
    let content = raw_content.trim();
    (!key.is_empty() && !content.is_empty()).then_some((key, content))
}
#[cfg(test)]
mod tests {
use super::*;
@@ -213,10 +305,16 @@ mod tests {
// previously cut right before the tool result (index 2).
assert_eq!(history.len(), 22);
let compacted =
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21, None)
.await
.expect("compaction should succeed");
let compacted = auto_compact_history(
&mut history,
&StaticSummaryProvider,
"test-model",
21,
None,
None,
)
.await
.expect("compaction should succeed");
assert!(compacted);
assert_eq!(history[0].role, "assistant");
@@ -229,4 +327,297 @@ mod tests {
"first retained message must not be an orphan tool result"
);
}
/// Happy path: "[key] content" splits into key and content.
#[test]
fn parse_fact_line_extracts_key_and_content() {
    assert_eq!(
        parse_fact_line("[preferred_language] User prefers Rust over Go"),
        Some(("preferred_language", "User prefers Rust over Go"))
    );
}
/// A leading "- " bullet is stripped before parsing.
#[test]
fn parse_fact_line_handles_leading_dash() {
    assert_eq!(
        parse_fact_line("- [db_choice] Project uses PostgreSQL 16"),
        Some(("db_choice", "Project uses PostgreSQL 16"))
    );
}
/// A blank key or blank/whitespace-only content is rejected.
#[test]
fn parse_fact_line_rejects_empty_key_or_content() {
    assert_eq!(parse_fact_line("[] some content"), None);
    assert_eq!(parse_fact_line("[key]"), None);
    assert_eq!(parse_fact_line("[key] "), None);
}
/// Inputs without a complete "[...]" prefix are rejected.
#[test]
fn parse_fact_line_rejects_malformed_input() {
    assert_eq!(parse_fact_line("no brackets here"), None);
    assert_eq!(parse_fact_line(""), None);
    assert_eq!(parse_fact_line("[unclosed bracket"), None);
}
/// End-to-end flush test: with a memory sink attached, `auto_compact_history`
/// first asks the provider for durable facts, stores each parsed fact under a
/// `compaction_fact_`-prefixed key, then proceeds with normal summarization.
#[tokio::test]
async fn auto_compact_with_memory_stores_durable_facts() {
    use crate::memory::{MemoryCategory, MemoryEntry};
    use std::sync::{Arc, Mutex};

    /// Memory stub that records every `store` call for later inspection.
    struct FactCapture {
        stored: Mutex<Vec<(String, String)>>,
    }
    #[async_trait]
    impl Memory for FactCapture {
        // Records (key, content); category/session are ignored.
        async fn store(
            &self,
            key: &str,
            content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            self.stored
                .lock()
                .unwrap()
                .push((key.to_string(), content.to_string()));
            Ok(())
        }
        async fn recall(
            &self,
            _q: &str,
            _l: usize,
            _s: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(vec![])
        }
        async fn get(&self, _k: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }
        async fn list(
            &self,
            _c: Option<&MemoryCategory>,
            _s: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(vec![])
        }
        async fn forget(&self, _k: &str) -> anyhow::Result<bool> {
            Ok(true)
        }
        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }
        async fn health_check(&self) -> bool {
            true
        }
        fn name(&self) -> &str {
            "fact-capture"
        }
    }

    /// Provider that returns facts for the first call (flush) and summary for the second (compaction).
    struct FlushThenSummaryProvider {
        call_count: Mutex<usize>,
    }
    #[async_trait]
    impl Provider for FlushThenSummaryProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            let mut count = self.call_count.lock().unwrap();
            *count += 1;
            if *count == 1 {
                // flush_durable_facts call
                Ok("[lang] User prefers Rust\n[db] PostgreSQL 16".to_string())
            } else {
                // summarizer call
                Ok("- summarized context".to_string())
            }
        }
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                text: Some("- summarized context".to_string()),
                tool_calls: Vec::new(),
                usage: None,
                reasoning_content: None,
                quota_metadata: None,
            })
        }
    }

    let mem = Arc::new(FactCapture {
        stored: Mutex::new(Vec::new()),
    });
    let provider = FlushThenSummaryProvider {
        call_count: Mutex::new(0),
    };
    // 25 messages against max_history=21 forces compaction on this call.
    let mut history: Vec<ChatMessage> = Vec::new();
    for i in 0..25 {
        history.push(ChatMessage::user(format!("msg-{i}")));
    }
    let compacted = auto_compact_history(
        &mut history,
        &provider,
        "test-model",
        21,
        None,
        Some(mem.as_ref()),
    )
    .await
    .expect("compaction should succeed");
    assert!(compacted);
    // Facts must be stored in response order, keys namespaced with compaction_fact_.
    let stored = mem.stored.lock().unwrap();
    assert_eq!(stored.len(), 2, "should store 2 durable facts");
    assert_eq!(stored[0].0, "compaction_fact_lang");
    assert_eq!(stored[0].1, "User prefers Rust");
    assert_eq!(stored[1].0, "compaction_fact_db");
    assert_eq!(stored[1].1, "PostgreSQL 16");
}
/// Cap test: when the provider emits 12 fact lines, the flush must stop after
/// `COMPACTION_MAX_FLUSH_FACTS` (8) stored entries, keeping the first eight.
#[tokio::test]
async fn auto_compact_with_memory_caps_fact_flush_at_eight_entries() {
    use crate::memory::{MemoryCategory, MemoryEntry};
    use std::sync::{Arc, Mutex};

    /// Memory stub that records every `store` call for later inspection.
    struct FactCapture {
        stored: Mutex<Vec<(String, String)>>,
    }
    #[async_trait]
    impl Memory for FactCapture {
        async fn store(
            &self,
            key: &str,
            content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            self.stored
                .lock()
                .expect("fact capture lock")
                .push((key.to_string(), content.to_string()));
            Ok(())
        }
        async fn recall(
            &self,
            _q: &str,
            _l: usize,
            _s: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(vec![])
        }
        async fn get(&self, _k: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }
        async fn list(
            &self,
            _c: Option<&MemoryCategory>,
            _s: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(vec![])
        }
        async fn forget(&self, _k: &str) -> anyhow::Result<bool> {
            Ok(true)
        }
        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }
        async fn health_check(&self) -> bool {
            true
        }
        fn name(&self) -> &str {
            "fact-capture-cap"
        }
    }

    /// Provider whose first reply (the flush) lists 12 facts — more than the
    /// cap — and whose second reply is the compaction summary.
    struct FlushManyFactsProvider {
        call_count: Mutex<usize>,
    }
    #[async_trait]
    impl Provider for FlushManyFactsProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            let mut count = self.call_count.lock().expect("provider lock");
            *count += 1;
            if *count == 1 {
                // 12 "[kN] fact-N" lines exceed the 8-fact cap.
                let lines = (0..12)
                    .map(|idx| format!("[k{idx}] fact-{idx}"))
                    .collect::<Vec<_>>()
                    .join("\n");
                Ok(lines)
            } else {
                Ok("- summarized context".to_string())
            }
        }
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                text: Some("- summarized context".to_string()),
                tool_calls: Vec::new(),
                usage: None,
                reasoning_content: None,
                quota_metadata: None,
            })
        }
    }

    let mem = Arc::new(FactCapture {
        stored: Mutex::new(Vec::new()),
    });
    let provider = FlushManyFactsProvider {
        call_count: Mutex::new(0),
    };
    // 30 messages against max_history=21 forces compaction on this call.
    let mut history = (0..30)
        .map(|idx| ChatMessage::user(format!("msg-{idx}")))
        .collect::<Vec<_>>();
    let compacted = auto_compact_history(
        &mut history,
        &provider,
        "test-model",
        21,
        None,
        Some(mem.as_ref()),
    )
    .await
    .expect("compaction should succeed");
    assert!(compacted);
    // Only the first COMPACTION_MAX_FLUSH_FACTS facts may be stored.
    let stored = mem.stored.lock().expect("fact capture lock");
    assert_eq!(stored.len(), COMPACTION_MAX_FLUSH_FACTS);
    assert_eq!(stored[0].0, "compaction_fact_k0");
    assert_eq!(stored[7].0, "compaction_fact_k7");
}
}
+134 -16
View File
@@ -1,7 +1,18 @@
use crate::memory::{self, Memory};
use crate::memory::{self, decay, Memory, MemoryCategory};
use async_trait::async_trait;
use std::fmt::Write;
/// Default half-life (days) for time decay in memory loading.
const LOADER_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
/// Score boost applied to `Core` category memories so durable facts and
/// preferences surface even when keyword/semantic similarity is moderate.
const CORE_CATEGORY_SCORE_BOOST: f64 = 0.3;
/// Over-fetch factor: retrieve more candidates than the output limit so
/// that Core boost and re-ranking can select the best subset.
const RECALL_OVER_FETCH_FACTOR: usize = 2;
#[async_trait]
pub trait MemoryLoader: Send + Sync {
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
@@ -38,29 +49,47 @@ impl MemoryLoader for DefaultMemoryLoader {
memory: &dyn Memory,
user_message: &str,
) -> anyhow::Result<String> {
let entries = memory.recall(user_message, self.limit, None).await?;
// Over-fetch so Core-boosted entries can compete fairly after re-ranking.
let fetch_limit = self.limit * RECALL_OVER_FETCH_FACTOR;
let mut entries = memory.recall(user_message, fetch_limit, None).await?;
if entries.is_empty() {
return Ok(String::new());
}
let mut context = String::from("[Memory context]\n");
for entry in entries {
if memory::is_assistant_autosave_key(&entry.key) {
continue;
}
if let Some(score) = entry.score {
if score < self.min_relevance_score {
continue;
}
}
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
}
// Apply time decay: older non-Core memories score lower.
decay::apply_time_decay(&mut entries, LOADER_DECAY_HALF_LIFE_DAYS);
// If all entries were below threshold, return empty
if context == "[Memory context]\n" {
// Apply Core category boost and filter by minimum relevance.
let mut scored: Vec<_> = entries
.iter()
.filter(|e| !memory::is_assistant_autosave_key(&e.key))
.filter_map(|e| {
let base = e.score.unwrap_or(self.min_relevance_score);
let boosted = if e.category == MemoryCategory::Core {
(base + CORE_CATEGORY_SCORE_BOOST).min(1.0)
} else {
base
};
if boosted >= self.min_relevance_score {
Some((e, boosted))
} else {
None
}
})
.collect();
// Sort by boosted score descending, then truncate to output limit.
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(self.limit);
if scored.is_empty() {
return Ok(String::new());
}
let mut context = String::from("[Memory context]\n");
for (entry, _) in &scored {
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
}
context.push('\n');
Ok(context)
}
@@ -227,4 +256,93 @@ mod tests {
assert!(!context.contains("assistant_resp_legacy"));
assert!(!context.contains("fabricated detail"));
}
/// A Core entry at 0.25 (below the 0.4 threshold) must be promoted by the
/// +0.3 Core boost, while a 0.2 Conversation entry stays filtered out.
#[tokio::test]
async fn core_category_boost_promotes_low_score_core_entry() {
    let loader = DefaultMemoryLoader::new(2, 0.4);
    let memory = MockMemoryWithEntries {
        entries: Arc::new(vec![
            MemoryEntry {
                id: "1".into(),
                key: "chat_detail".into(),
                content: "talked about weather".into(),
                category: MemoryCategory::Conversation,
                timestamp: "now".into(),
                session_id: None,
                score: Some(0.6),
            },
            MemoryEntry {
                id: "2".into(),
                key: "project_rule".into(),
                content: "always use async/await".into(),
                category: MemoryCategory::Core,
                timestamp: "now".into(),
                session_id: None,
                // Below threshold without boost (0.25 < 0.4),
                // but above with +0.3 boost (0.55 >= 0.4).
                score: Some(0.25),
            },
            MemoryEntry {
                id: "3".into(),
                key: "low_conv".into(),
                content: "irrelevant chatter".into(),
                category: MemoryCategory::Conversation,
                timestamp: "now".into(),
                session_id: None,
                score: Some(0.2),
            },
        ]),
    };
    let context = loader.load_context(&memory, "code style").await.unwrap();
    // Core entry should survive thanks to boost
    assert!(
        context.contains("project_rule"),
        "Core entry should be promoted by boost: {context}"
    );
    // Low-score Conversation entry should be filtered out
    assert!(
        !context.contains("low_conv"),
        "Low-score non-Core entry should be filtered: {context}"
    );
}
/// Re-ranking test: with limit=1, a Core entry boosted to 0.8 must outrank a
/// Conversation entry at 0.6 and be the single line rendered into context.
#[tokio::test]
async fn core_boost_reranks_above_conversation() {
    let loader = DefaultMemoryLoader::new(1, 0.0);
    let memory = MockMemoryWithEntries {
        entries: Arc::new(vec![
            MemoryEntry {
                id: "1".into(),
                key: "conv_high".into(),
                content: "recent conversation".into(),
                category: MemoryCategory::Conversation,
                timestamp: "now".into(),
                session_id: None,
                score: Some(0.6),
            },
            MemoryEntry {
                id: "2".into(),
                key: "core_pref".into(),
                content: "user prefers Rust".into(),
                category: MemoryCategory::Core,
                timestamp: "now".into(),
                session_id: None,
                // 0.5 + 0.3 boost = 0.8 > 0.6
                score: Some(0.5),
            },
        ]),
    };
    let context = loader.load_context(&memory, "language").await.unwrap();
    // With limit=1 and Core boost, Core entry (0.8) should win over Conversation (0.6)
    assert!(
        context.contains("core_pref"),
        "Boosted Core should rank above Conversation: {context}"
    );
    assert!(
        !context.contains("conv_high"),
        "Conversation should be truncated when limit=1: {context}"
    );
}
}
+1 -1
View File
@@ -16,4 +16,4 @@ mod tests;
#[allow(unused_imports)]
pub use agent::{Agent, AgentBuilder};
#[allow(unused_imports)]
pub use loop_::{process_message, process_message_with_session, run};
pub use loop_::{process_message, process_message_with_session, run, run_tool_call_loop};
+14
View File
@@ -736,6 +736,20 @@ async fn native_dispatcher_sends_tool_specs() {
assert!(dispatcher.should_send_tool_specs());
}
/// `Agent::tool_specs` must expose one spec per registered tool, with the
/// spec name matching the tool's name.
#[test]
fn agent_tool_specs_accessor_exposes_registered_tools() {
    let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")]));
    let agent = build_agent_with(
        provider,
        vec![Box::new(EchoTool)],
        Box::new(NativeToolDispatcher),
    );
    let specs = agent.tool_specs();
    assert_eq!(specs.len(), 1);
    assert_eq!(specs[0].name, "echo");
}
#[tokio::test]
async fn xml_dispatcher_does_not_send_tool_specs() {
let dispatcher = XmlToolDispatcher;
+163 -19
View File
@@ -251,6 +251,14 @@ struct ChannelRuntimeDefaults {
api_url: Option<String>,
reliability: crate::config::ReliabilityConfig,
cost: crate::config::CostConfig,
auto_save_memory: bool,
max_tool_iterations: usize,
min_relevance_score: f64,
message_timeout_secs: u64,
interrupt_on_new_message: bool,
multimodal: crate::config::MultimodalConfig,
query_classification: crate::config::QueryClassificationConfig,
model_routes: Vec<crate::config::ModelRouteConfig>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -1048,6 +1056,14 @@ fn resolved_default_model(config: &Config) -> String {
}
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
let message_timeout_secs =
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
let interrupt_on_new_message = config
.channels_config
.telegram
.as_ref()
.is_some_and(|tg| tg.interrupt_on_new_message);
ChannelRuntimeDefaults {
default_provider: resolved_default_provider(config),
model: resolved_default_model(config),
@@ -1056,6 +1072,14 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
api_url: config.api_url.clone(),
reliability: config.reliability.clone(),
cost: config.cost.clone(),
auto_save_memory: config.memory.auto_save,
max_tool_iterations: config.agent.max_tool_iterations,
min_relevance_score: config.memory.min_relevance_score,
message_timeout_secs,
interrupt_on_new_message,
multimodal: config.multimodal.clone(),
query_classification: config.query_classification.clone(),
model_routes: config.model_routes.clone(),
}
}
@@ -1102,6 +1126,14 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
api_url: ctx.api_url.clone(),
reliability: (*ctx.reliability).clone(),
cost: crate::config::CostConfig::default(),
auto_save_memory: ctx.auto_save_memory,
max_tool_iterations: ctx.max_tool_iterations,
min_relevance_score: ctx.min_relevance_score,
message_timeout_secs: ctx.message_timeout_secs,
interrupt_on_new_message: ctx.interrupt_on_new_message,
multimodal: ctx.multimodal.clone(),
query_classification: ctx.query_classification.clone(),
model_routes: ctx.model_routes.clone(),
}
}
@@ -1722,14 +1754,14 @@ fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> Channel
/// Classify a user message and return the appropriate route selection with logging.
/// Returns None if classification is disabled or no rules match.
fn classify_message_route(
ctx: &ChannelRuntimeContext,
query_classification: &crate::config::QueryClassificationConfig,
model_routes: &[crate::config::ModelRouteConfig],
message: &str,
) -> Option<ChannelRouteSelection> {
let decision =
crate::agent::classifier::classify_with_decision(&ctx.query_classification, message)?;
let decision = crate::agent::classifier::classify_with_decision(query_classification, message)?;
// Find the matching model route
let route = ctx.model_routes.iter().find(|r| r.hint == decision.hint)?;
let route = model_routes.iter().find(|r| r.hint == decision.hint)?;
tracing::info!(
target: "query_classification",
@@ -1956,9 +1988,9 @@ async fn get_or_create_provider(
let provider = create_resilient_provider_nonblocking(
provider_name,
ctx.api_key.clone(),
defaults.api_key.clone(),
api_url.map(ToString::to_string),
ctx.reliability.as_ref().clone(),
defaults.reliability.clone(),
ctx.provider_runtime_options.clone(),
)
.await?;
@@ -3446,10 +3478,14 @@ or tune thresholds in config.",
}
}
}
// Try classification first, fall back to sender/default route
let route = classify_message_route(ctx.as_ref(), &msg.content)
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
// Try classification first, fall back to sender/default route.
let route = classify_message_route(
&runtime_defaults.query_classification,
&runtime_defaults.model_routes,
&msg.content,
)
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
Ok(provider) => provider,
Err(err) => {
@@ -3469,7 +3505,9 @@ or tune thresholds in config.",
return;
}
};
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
if runtime_defaults.auto_save_memory
&& msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
{
let autosave_key = conversation_memory_key(&msg);
let _ = ctx
.memory
@@ -3532,7 +3570,7 @@ or tune thresholds in config.",
let memory_context = build_memory_context(
ctx.memory.as_ref(),
&msg.content,
ctx.min_relevance_score,
runtime_defaults.min_relevance_score,
Some(&history_key),
)
.await;
@@ -3686,8 +3724,10 @@ or tune thresholds in config.",
Cancelled,
}
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
let timeout_budget_secs = channel_message_timeout_budget_secs(
runtime_defaults.message_timeout_secs,
runtime_defaults.max_tool_iterations,
);
let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
&runtime_defaults.cost,
ctx.workspace_dir.as_path(),
@@ -3751,8 +3791,8 @@ or tune thresholds in config.",
Some(ctx.approval_manager.as_ref()),
msg.channel.as_str(),
non_cli_approval_context,
&ctx.multimodal,
ctx.max_tool_iterations,
&runtime_defaults.multimodal,
runtime_defaults.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
@@ -3931,7 +3971,7 @@ or tune thresholds in config.",
&history_key,
ChatMessage::assistant(&history_response),
);
if ctx.auto_save_memory
if runtime_defaults.auto_save_memory
&& delivered_response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
{
let assistant_key = assistant_memory_key(&msg);
@@ -4044,7 +4084,7 @@ or tune thresholds in config.",
}
}
} else if is_tool_iteration_limit_error(&e) {
let limit = ctx.max_tool_iterations.max(1);
let limit = runtime_defaults.max_tool_iterations.max(1);
let pause_text = format!(
"⚠️ Reached tool-iteration limit ({limit}) for this turn. Context and progress were preserved. Reply \"continue\" to resume, or increase `agent.max_tool_iterations`."
);
@@ -4140,7 +4180,9 @@ or tune thresholds in config.",
LlmExecutionResult::Completed(Err(_)) => {
let timeout_msg = format!(
"LLM response timed out after {}s (base={}s, max_tool_iterations={})",
timeout_budget_secs, ctx.message_timeout_secs, ctx.max_tool_iterations
timeout_budget_secs,
runtime_defaults.message_timeout_secs,
runtime_defaults.max_tool_iterations
);
runtime_trace::record_event(
"channel_message_timeout",
@@ -4221,8 +4263,9 @@ async fn run_message_dispatch_loop(
let task_sequence = Arc::clone(&task_sequence);
workers.spawn(async move {
let _permit = permit;
let runtime_defaults = runtime_defaults_snapshot(worker_ctx.as_ref());
let interrupt_enabled =
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
runtime_defaults.interrupt_on_new_message && msg.channel == "telegram";
let sender_scope_key = interruption_scope_key(&msg);
let cancellation_token = CancellationToken::new();
let completion = Arc::new(InFlightTaskCompletion::new());
@@ -6747,6 +6790,36 @@ BTC is currently around $65,000 based on latest tool output."#
}
}
/// Minimal `Tool` implementation named "process", used to exercise runtime
/// tool-visibility filtering (the default non-CLI policy excludes "process").
struct MockProcessTool;
#[async_trait::async_trait]
impl Tool for MockProcessTool {
    fn name(&self) -> &str {
        "process"
    }
    fn description(&self) -> &str {
        "Mock process tool for runtime visibility tests"
    }
    // Schema content is arbitrary for these tests; one optional string field.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "action": { "type": "string" }
            }
        })
    }
    // Always succeeds with empty output; tests only inspect visibility.
    async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
        Ok(ToolResult {
            success: true,
            output: String::new(),
            error: None,
        })
    }
}
#[test]
fn build_runtime_tool_visibility_prompt_respects_excluded_snapshot() {
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockPriceTool), Box::new(MockEchoTool)];
@@ -6765,6 +6838,23 @@ BTC is currently around $65,000 based on latest tool output."#
assert!(!native.contains("## Tool Use Protocol"));
}
/// With the default autonomy policy, "process" must appear under the
/// "Excluded by runtime policy:" section and must not be rendered as a
/// usable tool entry, while non-excluded tools stay listed.
#[test]
fn build_runtime_tool_visibility_prompt_excludes_process_with_default_policy() {
    let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockProcessTool), Box::new(MockEchoTool)];
    let excluded = crate::config::AutonomyConfig::default().non_cli_excluded_tools;
    // Guard: the default exclusion list itself must contain "process".
    assert!(
        excluded.contains(&"process".to_string()),
        "default non-CLI exclusion list must include process"
    );
    let prompt = build_runtime_tool_visibility_prompt(&tools, &excluded, false);
    assert!(prompt.contains("Excluded by runtime policy:"));
    assert!(prompt.contains("process"));
    // "**process**:" would be the normal (usable) tool rendering.
    assert!(!prompt.contains("**process**:"));
    assert!(prompt.contains("`mock_echo`"));
}
#[tokio::test]
async fn process_channel_message_injects_runtime_tool_visibility_prompt() {
let channel_impl = Arc::new(RecordingChannel::default());
@@ -9456,6 +9546,14 @@ BTC is currently around $65,000 based on latest tool output."#
api_url: None,
reliability: crate::config::ReliabilityConfig::default(),
cost: crate::config::CostConfig::default(),
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: false,
multimodal: crate::config::MultimodalConfig::default(),
query_classification: crate::config::QueryClassificationConfig::default(),
model_routes: Vec::new(),
},
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),
@@ -9638,6 +9736,13 @@ BTC is currently around $65,000 based on latest tool output."#
cfg.default_provider = Some("ollama".to_string());
cfg.default_model = Some("llama3.2".to_string());
cfg.api_key = Some("http://127.0.0.1:11434".to_string());
cfg.memory.auto_save = false;
cfg.memory.min_relevance_score = 0.15;
cfg.agent.max_tool_iterations = 5;
cfg.channels_config.message_timeout_secs = 45;
cfg.multimodal.allow_remote_fetch = false;
cfg.query_classification.enabled = false;
cfg.model_routes = vec![];
cfg.autonomy.non_cli_natural_language_approval_mode =
crate::config::NonCliNaturalLanguageApprovalMode::Direct;
cfg.autonomy.non_cli_excluded_tools = vec!["shell".to_string()];
@@ -9704,6 +9809,14 @@ BTC is currently around $65,000 based on latest tool output."#
runtime_outbound_leak_guard_snapshot(runtime_ctx.as_ref()).action,
crate::config::OutboundLeakGuardAction::Redact
);
let defaults = runtime_defaults_snapshot(runtime_ctx.as_ref());
assert!(!defaults.auto_save_memory);
assert_eq!(defaults.min_relevance_score, 0.15);
assert_eq!(defaults.max_tool_iterations, 5);
assert_eq!(defaults.message_timeout_secs, 45);
assert!(!defaults.multimodal.allow_remote_fetch);
assert!(!defaults.query_classification.enabled);
assert!(defaults.model_routes.is_empty());
cfg.autonomy.non_cli_natural_language_approval_mode =
crate::config::NonCliNaturalLanguageApprovalMode::Disabled;
@@ -9719,6 +9832,28 @@ BTC is currently around $65,000 based on latest tool output."#
cfg.security.perplexity_filter.perplexity_threshold = 12.5;
cfg.security.outbound_leak_guard.action = crate::config::OutboundLeakGuardAction::Block;
cfg.security.outbound_leak_guard.sensitivity = 0.92;
cfg.memory.auto_save = true;
cfg.memory.min_relevance_score = 0.65;
cfg.agent.max_tool_iterations = 11;
cfg.channels_config.message_timeout_secs = 120;
cfg.multimodal.allow_remote_fetch = true;
cfg.query_classification.enabled = true;
cfg.query_classification.rules = vec![crate::config::ClassificationRule {
hint: "reasoning".to_string(),
keywords: vec!["analyze".to_string()],
patterns: vec!["deep".to_string()],
min_length: None,
max_length: None,
priority: 10,
}];
cfg.model_routes = vec![crate::config::ModelRouteConfig {
hint: "reasoning".to_string(),
provider: "openrouter".to_string(),
model: "openai/gpt-5.2".to_string(),
max_tokens: Some(512),
api_key: None,
transport: None,
}];
cfg.save().await.expect("save updated config");
maybe_apply_runtime_config_update(runtime_ctx.as_ref())
@@ -9750,6 +9885,15 @@ BTC is currently around $65,000 based on latest tool output."#
crate::config::OutboundLeakGuardAction::Block
);
assert_eq!(leak_guard_cfg.sensitivity, 0.92);
let defaults = runtime_defaults_snapshot(runtime_ctx.as_ref());
assert!(defaults.auto_save_memory);
assert_eq!(defaults.min_relevance_score, 0.65);
assert_eq!(defaults.max_tool_iterations, 11);
assert_eq!(defaults.message_timeout_secs, 120);
assert!(defaults.multimodal.allow_remote_fetch);
assert!(defaults.query_classification.enabled);
assert_eq!(defaults.query_classification.rules.len(), 1);
assert_eq!(defaults.model_routes.len(), 1);
let mut store = runtime_config_store()
.lock()
+59 -22
View File
@@ -589,6 +589,40 @@ impl TelegramChannel {
body
}
/// Build the Telegram `sendMessage` payload for a tool-approval prompt:
/// Markdown text identifying the tool/request plus an inline keyboard with
/// Approve/Deny callback buttons. `message_thread_id` is attached only when
/// a thread is supplied.
fn build_approval_prompt_body(
    chat_id: &str,
    thread_id: Option<&str>,
    request_id: &str,
    tool_name: &str,
    args_preview: &str,
) -> serde_json::Value {
    let text = format!(
        "Approval required for tool `{tool_name}`.\nRequest ID: `{request_id}`\nArgs: `{args_preview}`",
    );
    // Callback data encodes the request id so the button handler can match it.
    let approve_button = serde_json::json!({
        "text": "Approve",
        "callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_APPROVE_PREFIX}{request_id}")
    });
    let deny_button = serde_json::json!({
        "text": "Deny",
        "callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_DENY_PREFIX}{request_id}")
    });
    let mut body = serde_json::json!({
        "chat_id": chat_id,
        "text": text,
        "parse_mode": "Markdown",
        "reply_markup": {
            "inline_keyboard": [[approve_button, deny_button]]
        }
    });
    if let Some(thread_id) = thread_id {
        body["message_thread_id"] = serde_json::Value::String(thread_id.to_string());
    }
    body
}
fn extract_update_message_ack_target(
update: &serde_json::Value,
) -> Option<(String, i64, AckReactionContextChatType, Option<String>)> {
@@ -3153,28 +3187,13 @@ impl Channel for TelegramChannel {
raw_args
};
let mut body = serde_json::json!({
"chat_id": chat_id,
"text": format!(
"Approval required for tool `{tool_name}`.\nRequest ID: `{request_id}`\nArgs: `{args_preview}`",
),
"reply_markup": {
"inline_keyboard": [[
{
"text": "Approve",
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_APPROVE_PREFIX}{request_id}")
},
{
"text": "Deny",
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_DENY_PREFIX}{request_id}")
}
]]
}
});
if let Some(thread_id) = thread_id {
body["message_thread_id"] = serde_json::Value::String(thread_id);
}
let body = Self::build_approval_prompt_body(
&chat_id,
thread_id.as_deref(),
request_id,
tool_name,
&args_preview,
);
let response = self
.http_client()
@@ -3654,6 +3673,24 @@ mod tests {
);
}
#[test]
fn approval_prompt_includes_markdown_parse_mode() {
    // The approval payload must carry Markdown parse mode, address the right
    // chat/thread, and quote the tool name in backticks.
    let chat_id = "12345";
    let body = TelegramChannel::build_approval_prompt_body(
        chat_id,
        Some("67890"),
        "apr-1234",
        "shell",
        "{\"command\":\"echo hello\"}",
    );
    assert_eq!(body["chat_id"], chat_id);
    assert_eq!(body["message_thread_id"], "67890");
    assert_eq!(body["parse_mode"], "Markdown");
    let text = body["text"].as_str();
    assert!(text.is_some_and(|text| text.contains("`shell`")));
}
#[test]
fn sanitize_telegram_error_redacts_bot_token_in_url() {
let input =
+7 -6
View File
@@ -10,12 +10,13 @@ pub use schema::{
AckReactionRuleConfig, AckReactionStrategy, AgentConfig, AgentSessionBackend,
AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig, AuditConfig, AutonomyConfig,
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
ClassificationRule, ComposioConfig, Config, CoordinationConfig, CostConfig, CronConfig,
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing,
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GroupReplyConfig,
GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
HttpRequestConfig, HttpRequestCredentialProfile, IMessageConfig, IdentityConfig, LarkConfig,
MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
ClassificationRule, CommandContextRuleAction, CommandContextRuleConfig, ComposioConfig, Config,
CoordinationConfig, CostConfig, CronConfig, DelegateAgentConfig, DiscordConfig,
DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig,
FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
HttpRequestCredentialProfile, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig,
OtpMethod, OutboundLeakGuardAction, OutboundLeakGuardConfig, PeripheralBoardConfig,
PeripheralsConfig, PerplexityFilterConfig, PluginEntryConfig, PluginsConfig, ProgressMode,
+169
View File
@@ -3134,6 +3134,67 @@ pub enum NonCliNaturalLanguageApprovalMode {
Direct,
}
/// Action to apply when a command-context rule matches.
///
/// Consumed by [`CommandContextRuleConfig`] during per-segment shell-command
/// evaluation; serialized in config files as lowercase `allow` / `deny`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum CommandContextRuleAction {
    /// Matching context is explicitly allowed. This is the default when the
    /// `action` field is omitted from a rule.
    #[default]
    Allow,
    /// Matching context is explicitly denied.
    Deny,
}
/// Context-aware allow/deny rule for shell commands.
///
/// Rules are evaluated per command segment. Command matching accepts command
/// names (`curl`), explicit paths (`/usr/bin/curl`), and wildcard (`*`).
///
/// Matching semantics:
/// - `action = "deny"`: if all constraints match, the segment is rejected.
/// - `action = "allow"`: if at least one allow rule exists for a command,
///   segments must match at least one of those allow rules.
///
/// Constraints are optional:
/// - `allowed_domains`: require URL arguments to match these hosts/patterns.
/// - `allowed_path_prefixes`: require path-like arguments to stay under these prefixes.
/// - `denied_path_prefixes`: for deny rules, match when any path-like argument
///   is under these prefixes; for allow rules, require path arguments not to hit
///   these prefixes.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct CommandContextRuleConfig {
    /// Command name/path pattern (`git`, `/usr/bin/curl`, or `*`).
    ///
    /// Required (no serde default). Validation rejects empty values and any
    /// characters outside alphanumerics plus `_ - / . *`.
    pub command: String,
    /// Rule action (`allow` | `deny`). Defaults to `allow`.
    #[serde(default)]
    pub action: CommandContextRuleAction,
    /// Allowed host patterns for URL arguments.
    ///
    /// Supports exact hosts (`api.example.com`) and wildcard suffixes (`*.example.com`).
    /// Entries must be non-empty and contain no whitespace.
    #[serde(default)]
    pub allowed_domains: Vec<String>,
    /// Allowed path prefixes for path-like arguments.
    ///
    /// Prefixes may be absolute, `~/...`, or workspace-relative.
    #[serde(default)]
    pub allowed_path_prefixes: Vec<String>,
    /// Denied path prefixes for path-like arguments.
    ///
    /// Prefixes may be absolute, `~/...`, or workspace-relative.
    #[serde(default)]
    pub denied_path_prefixes: Vec<String>,
    /// Permit high-risk commands when this allow rule matches.
    ///
    /// The command still requires explicit `approved=true` in supervised mode.
    #[serde(default)]
    pub allow_high_risk: bool,
}
/// Autonomy and security policy configuration (`[autonomy]` section).
///
/// Controls what the agent is allowed to do: shell commands, filesystem access,
@@ -3147,6 +3208,13 @@ pub struct AutonomyConfig {
pub workspace_only: bool,
/// Allowlist of executable names permitted for shell execution.
pub allowed_commands: Vec<String>,
/// Context-aware shell command allow/deny rules.
///
/// These rules are evaluated per command segment and can narrow or override
/// global `allowed_commands` behavior for matching commands.
#[serde(default)]
pub command_context_rules: Vec<CommandContextRuleConfig>,
/// Explicit path denylist. Default includes system-critical paths and sensitive dotdirs.
pub forbidden_paths: Vec<String>,
/// Maximum actions allowed per hour per policy. Default: `100`.
@@ -3251,6 +3319,7 @@ fn default_always_ask() -> Vec<String> {
fn default_non_cli_excluded_tools() -> Vec<String> {
[
"shell",
"process",
"file_write",
"file_edit",
"git_operations",
@@ -3309,6 +3378,7 @@ impl Default for AutonomyConfig {
"tail".into(),
"date".into(),
],
command_context_rules: Vec::new(),
forbidden_paths: vec![
"/etc".into(),
"/root".into(),
@@ -7514,6 +7584,61 @@ impl Config {
);
}
}
for (i, rule) in self.autonomy.command_context_rules.iter().enumerate() {
let command = rule.command.trim();
if command.is_empty() {
anyhow::bail!("autonomy.command_context_rules[{i}].command must not be empty");
}
if !command
.chars()
.all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '/' | '.' | '*'))
{
anyhow::bail!(
"autonomy.command_context_rules[{i}].command contains invalid characters: {command}"
);
}
for (j, domain) in rule.allowed_domains.iter().enumerate() {
let normalized = domain.trim();
if normalized.is_empty() {
anyhow::bail!(
"autonomy.command_context_rules[{i}].allowed_domains[{j}] must not be empty"
);
}
if normalized.chars().any(char::is_whitespace) {
anyhow::bail!(
"autonomy.command_context_rules[{i}].allowed_domains[{j}] must not contain whitespace"
);
}
}
for (j, prefix) in rule.allowed_path_prefixes.iter().enumerate() {
let normalized = prefix.trim();
if normalized.is_empty() {
anyhow::bail!(
"autonomy.command_context_rules[{i}].allowed_path_prefixes[{j}] must not be empty"
);
}
if normalized.contains('\0') {
anyhow::bail!(
"autonomy.command_context_rules[{i}].allowed_path_prefixes[{j}] must not contain null bytes"
);
}
}
for (j, prefix) in rule.denied_path_prefixes.iter().enumerate() {
let normalized = prefix.trim();
if normalized.is_empty() {
anyhow::bail!(
"autonomy.command_context_rules[{i}].denied_path_prefixes[{j}] must not be empty"
);
}
if normalized.contains('\0') {
anyhow::bail!(
"autonomy.command_context_rules[{i}].denied_path_prefixes[{j}] must not contain null bytes"
);
}
}
}
let mut seen_non_cli_excluded = std::collections::HashSet::new();
for (i, tool_name) in self.autonomy.non_cli_excluded_tools.iter().enumerate() {
let normalized = tool_name.trim();
@@ -9298,9 +9423,11 @@ mod tests {
assert!(a.require_approval_for_medium_risk);
assert!(a.block_high_risk_commands);
assert!(a.shell_env_passthrough.is_empty());
assert!(a.command_context_rules.is_empty());
assert!(!a.allow_sensitive_file_reads);
assert!(!a.allow_sensitive_file_writes);
assert!(a.non_cli_excluded_tools.contains(&"shell".to_string()));
assert!(a.non_cli_excluded_tools.contains(&"process".to_string()));
assert!(a.non_cli_excluded_tools.contains(&"delegate".to_string()));
}
@@ -9329,12 +9456,53 @@ allowed_roots = []
!parsed.allow_sensitive_file_writes,
"Missing allow_sensitive_file_writes must default to false"
);
assert!(
parsed.command_context_rules.is_empty(),
"Missing command_context_rules must default to empty"
);
assert!(parsed.non_cli_excluded_tools.contains(&"shell".to_string()));
assert!(parsed
.non_cli_excluded_tools
.contains(&"process".to_string()));
assert!(parsed
.non_cli_excluded_tools
.contains(&"browser".to_string()));
}
// BUG FIX: `#[test]` cannot be applied to an `async fn` — the default test
// harness has no executor, so this fails to compile. The body never awaits
// anything (`Config::validate` is synchronous here), so a plain fn is correct.
#[test]
fn config_validate_rejects_invalid_command_context_rule_command() {
    let mut cfg = Config::default();
    // A shell metacharacter (`;`) in the command pattern must be rejected.
    cfg.autonomy.command_context_rules = vec![CommandContextRuleConfig {
        command: "curl;rm".into(),
        action: CommandContextRuleAction::Allow,
        allowed_domains: vec![],
        allowed_path_prefixes: vec![],
        denied_path_prefixes: vec![],
        allow_high_risk: false,
    }];
    let err = cfg.validate().unwrap_err();
    // Error message must point at the offending rule index and field.
    assert!(err
        .to_string()
        .contains("autonomy.command_context_rules[0].command"));
}
// BUG FIX: same as above — `#[test]` on an `async fn` does not compile, and
// this body contains no `.await`; drop the spurious `async`.
#[test]
fn config_validate_rejects_empty_command_context_rule_domain() {
    let mut cfg = Config::default();
    // A whitespace-only domain entry must fail validation.
    cfg.autonomy.command_context_rules = vec![CommandContextRuleConfig {
        command: "curl".into(),
        action: CommandContextRuleAction::Allow,
        allowed_domains: vec![" ".into()],
        allowed_path_prefixes: vec![],
        denied_path_prefixes: vec![],
        allow_high_risk: true,
    }];
    let err = cfg.validate().unwrap_err();
    // Error message must point at the offending rule/domain indices.
    assert!(err
        .to_string()
        .contains("autonomy.command_context_rules[0].allowed_domains[0]"));
}
#[test]
async fn config_validate_rejects_duplicate_non_cli_excluded_tools() {
let mut cfg = Config::default();
@@ -9530,6 +9698,7 @@ ws_url = "ws://127.0.0.1:3002"
level: AutonomyLevel::Full,
workspace_only: false,
allowed_commands: vec!["docker".into()],
command_context_rules: vec![],
forbidden_paths: vec!["/secret".into()],
max_actions_per_hour: 50,
max_cost_per_day_cents: 1000,
+152
View File
@@ -0,0 +1,152 @@
use super::traits::{MemoryCategory, MemoryEntry};
use chrono::{DateTime, Utc};
/// Default half-life in days for time-decay scoring.
/// After this many days, a non-Core memory's score drops to 50%.
const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
/// Apply exponential time decay to memory entry scores.
///
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
/// - Entries without a score (`None`) are left unchanged.
///
/// Decay formula: `score * 2^(-age_days / half_life_days)`
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
    // A non-positive half-life would make the formula meaningless; fall back
    // to the module default.
    let half_life = if half_life_days <= 0.0 {
        DEFAULT_HALF_LIFE_DAYS
    } else {
        half_life_days
    };
    let now = Utc::now();
    // Core memories are evergreen — skip them entirely.
    for entry in entries.iter_mut().filter(|e| e.category != MemoryCategory::Core) {
        let Some(score) = entry.score else { continue };
        let Ok(parsed) = DateTime::parse_from_rfc3339(&entry.timestamp) else {
            continue;
        };
        // Clamp negative ages (future timestamps) to zero so scores never grow.
        let age_secs = now
            .signed_duration_since(parsed.with_timezone(&Utc))
            .num_seconds()
            .max(0) as f64;
        let age_days = age_secs / 86_400.0;
        // 2^(-age/half_life), expressed as e^(x · ln 2).
        entry.score = Some(score * f64::exp(-age_days / half_life * std::f64::consts::LN_2));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `MemoryEntry` fixture with the given category,
    /// optional score, and RFC3339 timestamp string.
    fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
        MemoryEntry {
            id: "1".into(),
            key: "test".into(),
            content: "value".into(),
            category,
            timestamp: timestamp.into(),
            session_id: None,
            score,
        }
    }

    /// RFC3339 timestamp for "now" (effectively zero age).
    fn recent_rfc3339() -> String {
        Utc::now().to_rfc3339()
    }

    /// RFC3339 timestamp `days` days in the past.
    fn days_ago_rfc3339(days: i64) -> String {
        (Utc::now() - chrono::Duration::days(days)).to_rfc3339()
    }

    #[test]
    fn core_memories_are_never_decayed() {
        // Core entries are evergreen: even a 30-day-old score is untouched.
        let mut entries = vec![make_entry(
            MemoryCategory::Core,
            Some(0.9),
            &days_ago_rfc3339(30),
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, Some(0.9));
    }

    #[test]
    fn recent_entry_score_barely_changes() {
        // Age ~0 days => decay factor ~1.0; tolerance covers test runtime.
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(0.8),
            &recent_rfc3339(),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.8).abs() < 0.01,
            "recent entry should barely decay, got {decayed}"
        );
    }

    #[test]
    fn one_half_life_halves_score() {
        // Age == half-life => factor ~0.5 (loose tolerance for clock drift).
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(1.0),
            &days_ago_rfc3339(7),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.5).abs() < 0.05,
            "score after one half-life should be ~0.5, got {decayed}"
        );
    }

    #[test]
    fn two_half_lives_quarters_score() {
        // Age == 2 × half-life => factor ~0.25.
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(1.0),
            &days_ago_rfc3339(14),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.25).abs() < 0.05,
            "score after two half-lives should be ~0.25, got {decayed}"
        );
    }

    #[test]
    fn no_score_entry_is_unchanged() {
        // `score: None` entries are skipped entirely by the decay pass.
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            None,
            &days_ago_rfc3339(30),
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, None);
    }

    #[test]
    fn unparseable_timestamp_is_unchanged() {
        // Bad timestamps are tolerated: the entry keeps its original score.
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(0.9),
            "not-a-date",
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, Some(0.9));
    }
}
+1
View File
@@ -2,6 +2,7 @@ pub mod backend;
pub mod chunker;
pub mod cli;
pub mod cortex;
pub mod decay;
pub mod embeddings;
pub mod hybrid;
pub mod hygiene;
+8 -2
View File
@@ -8615,8 +8615,14 @@ mod tests {
&["ANTHROPIC_OAUTH_TOKEN"]
);
assert_eq!(provider_env_var_fallbacks("gemini"), &["GOOGLE_API_KEY"]);
assert_eq!(provider_env_var_fallbacks("minimax"), &["MINIMAX_OAUTH_TOKEN"]);
assert_eq!(provider_env_var_fallbacks("volcengine"), &["DOUBAO_API_KEY"]);
assert_eq!(
provider_env_var_fallbacks("minimax"),
&["MINIMAX_OAUTH_TOKEN"]
);
assert_eq!(
provider_env_var_fallbacks("volcengine"),
&["DOUBAO_API_KEY"]
);
}
#[tokio::test]
+33 -2
View File
@@ -408,8 +408,9 @@ impl AnthropicProvider {
response
.content
.into_iter()
.find(|c| c.kind == "text")
.and_then(|c| c.text)
.filter(|c| c.kind == "text")
.filter_map(|c| c.text.map(|text| text.trim().to_string()))
.find(|text| !text.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
}
@@ -1413,6 +1414,36 @@ mod tests {
assert!(result.usage.is_none());
}
#[test]
fn parse_text_response_ignores_empty_and_whitespace_text_blocks() {
    // Empty and whitespace-only text blocks must be skipped; the first
    // non-empty block is returned with surrounding whitespace trimmed.
    let json = r#"{
        "content": [
            {"type": "text", "text": ""},
            {"type": "text", "text": " \n "},
            {"type": "text", "text": " final answer "}
        ]
    }"#;
    let response: ChatResponse = serde_json::from_str(json).unwrap();
    let parsed = AnthropicProvider::parse_text_response(response).unwrap();
    assert_eq!(parsed, "final answer");
}
#[test]
fn parse_text_response_rejects_empty_or_whitespace_only_text_blocks() {
    // When no text block has non-whitespace content (only a tool_use block
    // remains), parsing must fail with the standard "No response" error.
    let json = r#"{
        "content": [
            {"type": "text", "text": ""},
            {"type": "text", "text": " \n "},
            {"type": "tool_use", "id": "tool_1", "name": "shell"}
        ]
    }"#;
    let response: ChatResponse = serde_json::from_str(json).unwrap();
    let err = AnthropicProvider::parse_text_response(response).unwrap_err();
    assert!(err.to_string().contains("No response from Anthropic"));
}
#[test]
fn capabilities_reports_vision_and_native_tool_calling() {
let provider = AnthropicProvider::new(Some("test-key"));
+231 -88
View File
@@ -16,6 +16,7 @@ use reqwest::{
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use tokio_tungstenite::{
connect_async,
tungstenite::{
@@ -1618,90 +1619,173 @@ impl OpenAiCompatibleProvider {
messages: &[ChatMessage],
allow_user_image_parts: bool,
) -> Vec<NativeMessage> {
messages
.iter()
.map(|message| {
if message.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
{
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ProviderToolCall>>(
tool_calls_value.clone(),
)
{
let tool_calls = parsed_calls
.into_iter()
.map(|tc| ToolCall {
id: Some(tc.id),
kind: Some("function".to_string()),
function: Some(Function {
name: Some(tc.name),
arguments: Some(tc.arguments),
}),
name: None,
arguments: None,
parameters: None,
})
.collect::<Vec<_>>();
let mut native_messages = Vec::with_capacity(messages.len());
let mut assistant_tool_call_ids = HashSet::new();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(|value| MessageContent::Text(value.to_string()));
let reasoning_content = value
.get("reasoning_content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "assistant".to_string(),
content,
tool_call_id: None,
tool_calls: Some(tool_calls),
reasoning_content,
};
for message in messages {
if message.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
if let Some(tool_calls) = Self::parse_history_tool_calls(&value) {
for call in &tool_calls {
if let Some(id) = call.id.as_ref() {
assistant_tool_call_ids.insert(id.clone());
}
}
}
}
if message.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
let tool_call_id = value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
// Some OpenAI-compatible providers (including NVIDIA NIM models)
// reject assistant tool-call messages if `content` is omitted.
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(|value| MessageContent::Text(value.to_string()))
.or_else(|| Some(MessageContent::Text(message.content.clone())));
.map(ToString::to_string)
.unwrap_or_default();
return NativeMessage {
role: "tool".to_string(),
content,
tool_call_id,
tool_calls: None,
reasoning_content: None,
};
let reasoning_content = value
.get("reasoning_content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
native_messages.push(NativeMessage {
role: "assistant".to_string(),
content: Some(MessageContent::Text(content)),
tool_call_id: None,
tool_calls: Some(tool_calls),
reasoning_content,
});
continue;
}
}
}
NativeMessage {
role: message.role.clone(),
content: Some(Self::to_message_content(
&message.role,
&message.content,
allow_user_image_parts,
)),
tool_call_id: None,
tool_calls: None,
reasoning_content: None,
if message.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
let tool_call_id = value
.get("tool_call_id")
.or_else(|| value.get("tool_use_id"))
.or_else(|| value.get("toolUseId"))
.or_else(|| value.get("id"))
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let content_text = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string)
.unwrap_or_else(|| message.content.clone());
if let Some(id) = tool_call_id {
if assistant_tool_call_ids.contains(&id) {
native_messages.push(NativeMessage {
role: "tool".to_string(),
content: Some(MessageContent::Text(content_text)),
tool_call_id: Some(id),
tool_calls: None,
reasoning_content: None,
});
continue;
}
tracing::warn!(
tool_call_id = %id,
"Dropping orphan tool-role message; no matching assistant tool_call in history"
);
} else {
tracing::warn!(
"Dropping tool-role message missing tool_call_id; preserving as user text fallback"
);
}
native_messages.push(NativeMessage {
role: "user".to_string(),
content: Some(MessageContent::Text(format!(
"[Tool result]\n{}",
content_text
))),
tool_call_id: None,
tool_calls: None,
reasoning_content: None,
});
continue;
}
})
.collect()
}
native_messages.push(NativeMessage {
role: message.role.clone(),
content: Some(Self::to_message_content(
&message.role,
&message.content,
allow_user_image_parts,
)),
tool_call_id: None,
tool_calls: None,
reasoning_content: None,
});
}
native_messages
}
/// Parse the `tool_calls` array from a serialized assistant history message.
///
/// Tries two wire shapes in order: the internal `ProviderToolCall` form
/// (`{id, name, arguments}`), then the OpenAI-style `ToolCall` form
/// (`{id, type, function: {name, arguments}}`). Arguments are normalized to
/// valid JSON (invalid payloads become `"{}"`); calls missing an id get a
/// fresh UUID. Returns `None` when neither shape yields any usable call.
fn parse_history_tool_calls(value: &serde_json::Value) -> Option<Vec<ToolCall>> {
    let raw_calls = value.get("tool_calls")?;
    // Shared constructor for a normalized function-style tool call.
    let build = |id: String, name: String, arguments: String| ToolCall {
        id: Some(id),
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some(name),
            arguments: Some(Self::normalize_tool_arguments(arguments)),
        }),
        name: None,
        arguments: None,
        parameters: None,
    };
    // Shape 1: internal ProviderToolCall entries.
    if let Ok(provider_calls) =
        serde_json::from_value::<Vec<ProviderToolCall>>(raw_calls.clone())
    {
        let converted: Vec<ToolCall> = provider_calls
            .into_iter()
            .map(|tc| build(tc.id, tc.name, tc.arguments))
            .collect();
        if !converted.is_empty() {
            return Some(converted);
        }
    }
    // Shape 2: OpenAI-style ToolCall entries; skip entries without a name.
    if let Ok(generic_calls) = serde_json::from_value::<Vec<ToolCall>>(raw_calls.clone()) {
        let converted: Vec<ToolCall> = generic_calls
            .into_iter()
            .filter_map(|call| {
                let name = call.function_name()?;
                let arguments = call
                    .function_arguments()
                    .unwrap_or_else(|| "{}".to_string());
                Some(build(
                    call.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                    name,
                    arguments,
                ))
            })
            .collect();
        if !converted.is_empty() {
            return Some(converted);
        }
    }
    None
}
/// Return `arguments` unchanged when it is valid JSON; otherwise substitute
/// an empty JSON object so downstream providers always receive parseable
/// tool-call arguments.
fn normalize_tool_arguments(arguments: String) -> String {
    match serde_json::from_str::<serde_json::Value>(&arguments) {
        Ok(_) => arguments,
        Err(_) => "{}".to_string(),
    }
}
fn with_prompt_guided_tool_instructions(
@@ -1741,17 +1825,14 @@ impl OpenAiCompatibleProvider {
.filter_map(|tc| {
let name = tc.function_name()?;
let arguments = tc.function_arguments().unwrap_or_else(|| "{}".to_string());
let normalized_arguments =
if serde_json::from_str::<serde_json::Value>(&arguments).is_ok() {
arguments
} else {
tracing::warn!(
function = %name,
arguments = %arguments,
"Invalid JSON in native tool-call arguments, using empty object"
);
"{}".to_string()
};
let normalized_arguments = Self::normalize_tool_arguments(arguments.clone());
if normalized_arguments == "{}" && arguments != "{}" {
tracing::warn!(
function = %name,
arguments = %arguments,
"Invalid JSON in native tool-call arguments, using empty object"
);
}
Some(ProviderToolCall {
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
name,
@@ -3381,18 +3462,80 @@ mod tests {
#[test]
fn convert_messages_for_native_maps_tool_result_payload() {
let input = vec![ChatMessage::tool(
r#"{"tool_call_id":"call_abc","content":"done"}"#,
let input = vec![
ChatMessage::assistant(
r#"{"content":"","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{}"}]}"#,
),
ChatMessage::tool(r#"{"tool_call_id":"call_abc","content":"done"}"#),
];
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
assert_eq!(converted.len(), 2);
assert_eq!(converted[1].role, "tool");
assert_eq!(converted[1].tool_call_id.as_deref(), Some("call_abc"));
assert!(matches!(
converted[1].content.as_ref(),
Some(MessageContent::Text(value)) if value == "done"
));
}
#[test]
fn convert_messages_for_native_parses_openai_style_assistant_tool_calls() {
let input = vec![ChatMessage::assistant(
r#"{
"content": null,
"tool_calls": [{
"id": "call_openai_1",
"type": "function",
"function": {
"name": "shell",
"arguments": "{\"command\":\"pwd\"}"
}
}]
}"#,
)];
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "tool");
assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_abc"));
assert_eq!(converted[0].role, "assistant");
assert!(matches!(
converted[0].content.as_ref(),
Some(MessageContent::Text(value)) if value == "done"
Some(MessageContent::Text(value)) if value.is_empty()
));
let calls = converted[0]
.tool_calls
.as_ref()
.expect("assistant message should include tool_calls");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].id.as_deref(), Some("call_openai_1"));
assert!(matches!(
calls[0].function.as_ref().and_then(|f| f.name.as_deref()),
Some("shell")
));
assert!(matches!(
calls[0]
.function
.as_ref()
.and_then(|f| f.arguments.as_deref()),
Some("{\"command\":\"pwd\"}")
));
}
#[test]
fn convert_messages_for_native_rewrites_orphan_tool_message_as_user() {
let input = vec![ChatMessage::tool(
r#"{"tool_call_id":"call_missing","content":"done"}"#,
)];
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "user");
assert!(matches!(
converted[0].content.as_ref(),
Some(MessageContent::Text(value)) if value.contains("[Tool result]") && value.contains("done")
));
assert!(converted[0].tool_call_id.is_none());
}
#[test]
+152 -24
View File
@@ -5,6 +5,7 @@
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
use crate::auth::AuthService;
use crate::multimodal;
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use base64::Engine;
@@ -135,8 +136,22 @@ struct Content {
}
#[derive(Debug, Serialize, Clone)]
struct Part {
text: String,
#[serde(untagged)]
enum Part {
Text {
text: String,
},
InlineData {
#[serde(rename = "inlineData")]
inline_data: InlineDataPart,
},
}
#[derive(Debug, Serialize, Clone)]
struct InlineDataPart {
#[serde(rename = "mimeType")]
mime_type: String,
data: String,
}
#[derive(Debug, Serialize, Clone)]
@@ -930,6 +945,57 @@ impl GeminiProvider {
|| status.is_server_error()
|| error_text.contains("RESOURCE_EXHAUSTED")
}
/// Parse a `data:<mime>;base64,<payload>` image reference into an inline-data
/// part. Returns `None` for anything that is not a well-formed base64 data
/// URI with a non-empty MIME type and payload.
fn parse_inline_image_marker(image_ref: &str) -> Option<InlineDataPart> {
    let rest = image_ref.strip_prefix("data:")?;
    // Everything before the first ';' is the MIME type.
    let (mime_raw, after_semi) = rest.split_once(';')?;
    let mime_type = mime_raw.trim();
    if mime_type.is_empty() {
        return None;
    }
    // Only the base64 encoding marker is supported.
    let payload = after_semi.strip_prefix("base64,")?.trim();
    if payload.is_empty() {
        return None;
    }
    Some(InlineDataPart {
        mime_type: mime_type.to_string(),
        data: payload.to_string(),
    })
}
/// Convert user message text into Gemini content parts.
///
/// Text without image markers becomes a single text part (original content,
/// untouched). When `[IMAGE:...]` markers are present, the cleaned text (if
/// any) comes first, followed by one part per image: inline data for base64
/// data URIs, or the marker re-rendered as text for anything unparseable.
/// Always returns at least one part.
fn build_user_parts(content: &str) -> Vec<Part> {
    let (cleaned_text, image_refs) = multimodal::parse_image_markers(content);
    if image_refs.is_empty() {
        return vec![Part::Text {
            text: content.to_string(),
        }];
    }
    let mut parts = Vec::with_capacity(image_refs.len() + 1);
    if !cleaned_text.is_empty() {
        parts.push(Part::Text { text: cleaned_text });
    }
    for image_ref in image_refs {
        match Self::parse_inline_image_marker(&image_ref) {
            Some(inline_data) => parts.push(Part::InlineData { inline_data }),
            // Unparseable reference: keep it visible as plain text.
            None => parts.push(Part::Text {
                text: format!("[IMAGE:{image_ref}]"),
            }),
        }
    }
    // Defensive: never return an empty parts list.
    if parts.is_empty() {
        parts.push(Part::Text {
            text: String::new(),
        });
    }
    parts
}
}
impl GeminiProvider {
@@ -1154,16 +1220,14 @@ impl Provider for GeminiProvider {
) -> anyhow::Result<String> {
let system_instruction = system_prompt.map(|sys| Content {
role: None,
parts: vec![Part {
parts: vec![Part::Text {
text: sys.to_string(),
}],
});
let contents = vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
text: message.to_string(),
}],
parts: Self::build_user_parts(message),
}];
let (text, _usage) = self
@@ -1189,16 +1253,14 @@ impl Provider for GeminiProvider {
"user" => {
contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: msg.content.clone(),
}],
parts: Self::build_user_parts(&msg.content),
});
}
"assistant" => {
// Gemini API uses "model" role instead of "assistant"
contents.push(Content {
role: Some("model".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: msg.content.clone(),
}],
});
@@ -1212,7 +1274,7 @@ impl Provider for GeminiProvider {
} else {
Some(Content {
role: None,
parts: vec![Part {
parts: vec![Part::Text {
text: system_parts.join("\n\n"),
}],
})
@@ -1238,13 +1300,11 @@ impl Provider for GeminiProvider {
"system" => system_parts.push(&msg.content),
"user" => contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: msg.content.clone(),
}],
parts: Self::build_user_parts(&msg.content),
}),
"assistant" => contents.push(Content {
role: Some("model".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: msg.content.clone(),
}],
}),
@@ -1257,7 +1317,7 @@ impl Provider for GeminiProvider {
} else {
Some(Content {
role: None,
parts: vec![Part {
parts: vec![Part::Text {
text: system_parts.join("\n\n"),
}],
})
@@ -1545,7 +1605,7 @@ mod tests {
let body = GenerateContentRequest {
contents: vec![Content {
role: Some("user".into()),
parts: vec![Part {
parts: vec![Part::Text {
text: "hello".into(),
}],
}],
@@ -1586,7 +1646,7 @@ mod tests {
let body = GenerateContentRequest {
contents: vec![Content {
role: Some("user".into()),
parts: vec![Part {
parts: vec![Part::Text {
text: "hello".into(),
}],
}],
@@ -1630,7 +1690,7 @@ mod tests {
let body = GenerateContentRequest {
contents: vec![Content {
role: Some("user".into()),
parts: vec![Part {
parts: vec![Part::Text {
text: "hello".into(),
}],
}],
@@ -1662,13 +1722,13 @@ mod tests {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: "Hello".to_string(),
}],
}],
system_instruction: Some(Content {
role: None,
parts: vec![Part {
parts: vec![Part::Text {
text: "You are helpful".to_string(),
}],
}),
@@ -1687,6 +1747,74 @@ mod tests {
assert!(json.contains("\"maxOutputTokens\":8192"));
}
// Text without markers must remain a single text part (legacy behavior).
#[test]
fn build_user_parts_text_only_is_backward_compatible() {
    let content = "Plain text message without image markers.";
    let parts = GeminiProvider::build_user_parts(content);
    assert_eq!(parts.len(), 1);
    match &parts[0] {
        Part::Text { text } => assert_eq!(text, content),
        Part::InlineData { .. } => panic!("text-only message must stay text-only"),
    }
}

// A valid data: URI marker becomes a text part plus one inline image part.
#[test]
fn build_user_parts_single_image() {
    let parts = GeminiProvider::build_user_parts(
        "Describe this image [IMAGE:data:image/png;base64,aGVsbG8=]",
    );
    assert_eq!(parts.len(), 2);
    match &parts[0] {
        Part::Text { text } => assert_eq!(text, "Describe this image"),
        Part::InlineData { .. } => panic!("first part should be text"),
    }
    match &parts[1] {
        Part::InlineData { inline_data } => {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "aGVsbG8=");
        }
        Part::Text { .. } => panic!("second part should be inline image data"),
    }
}

// Multiple markers each produce their own inline image part.
#[test]
fn build_user_parts_multiple_images() {
    let parts = GeminiProvider::build_user_parts(
        "Compare [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/jpeg;base64,ag==]",
    );
    assert_eq!(parts.len(), 3);
    assert!(matches!(parts[0], Part::Text { .. }));
    assert!(matches!(parts[1], Part::InlineData { .. }));
    assert!(matches!(parts[2], Part::InlineData { .. }));
}

// A message that is only an image marker yields a single inline part.
#[test]
fn build_user_parts_image_only() {
    let parts = GeminiProvider::build_user_parts("[IMAGE:data:image/webp;base64,YWJjZA==]");
    assert_eq!(parts.len(), 1);
    match &parts[0] {
        Part::InlineData { inline_data } => {
            assert_eq!(inline_data.mime_type, "image/webp");
            assert_eq!(inline_data.data, "YWJjZA==");
        }
        Part::Text { .. } => panic!("image-only message should create inline image part"),
    }
}

// Markers that are not data: URIs degrade to literal [IMAGE:...] text.
#[test]
fn build_user_parts_fallback_for_non_data_uri_markers() {
    let parts = GeminiProvider::build_user_parts("Inspect [IMAGE:https://example.com/img.png]");
    assert_eq!(parts.len(), 2);
    match &parts[0] {
        Part::Text { text } => assert_eq!(text, "Inspect"),
        Part::InlineData { .. } => panic!("first part should be text"),
    }
    match &parts[1] {
        Part::Text { text } => assert_eq!(text, "[IMAGE:https://example.com/img.png]"),
        Part::InlineData { .. } => panic!("invalid markers should fall back to text"),
    }
}
#[test]
fn internal_request_includes_model() {
let request = InternalGenerateContentEnvelope {
@@ -1696,7 +1824,7 @@ mod tests {
request: InternalGenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: "Hello".to_string(),
}],
}],
@@ -1728,7 +1856,7 @@ mod tests {
request: InternalGenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: "Hello".to_string(),
}],
}],
@@ -1751,7 +1879,7 @@ mod tests {
request: InternalGenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
parts: vec![Part::Text {
text: "Hello".to_string(),
}],
}],
+507 -110
View File
@@ -1,4 +1,5 @@
use parking_lot::Mutex;
use reqwest::Url;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
@@ -47,6 +48,24 @@ pub enum ToolOperation {
Act,
}
/// Action applied when a command context rule matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandContextRuleAction {
Allow,
Deny,
}
/// Context-aware allow/deny rule for shell commands.
#[derive(Debug, Clone)]
pub struct CommandContextRule {
pub command: String,
pub action: CommandContextRuleAction,
pub allowed_domains: Vec<String>,
pub allowed_path_prefixes: Vec<String>,
pub denied_path_prefixes: Vec<String>,
pub allow_high_risk: bool,
}
/// Sliding-window action tracker for rate limiting.
#[derive(Debug)]
pub struct ActionTracker {
@@ -99,6 +118,7 @@ pub struct SecurityPolicy {
pub workspace_dir: PathBuf,
pub workspace_only: bool,
pub allowed_commands: Vec<String>,
pub command_context_rules: Vec<CommandContextRule>,
pub forbidden_paths: Vec<String>,
pub allowed_roots: Vec<PathBuf>,
pub max_actions_per_hour: u32,
@@ -132,6 +152,7 @@ impl Default for SecurityPolicy {
"tail".into(),
"date".into(),
],
command_context_rules: Vec::new(),
forbidden_paths: vec![
// System directories (blocked even when workspace_only=false)
"/etc".into(),
@@ -565,7 +586,366 @@ fn is_allowlist_entry_match(allowed: &str, executable: &str, executable_base: &s
allowed == executable_base
}
/// Per-segment verdict produced by context-rule evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SegmentRuleDecision {
    // No rule matched this segment; the global allowlist applies.
    NoMatch,
    Allow,
    Deny,
}

/// Result of evaluating context rules against one command segment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SegmentRuleOutcome {
    decision: SegmentRuleDecision,
    // True when a matching allow rule also permits high-risk commands.
    allow_high_risk: bool,
}

/// Aggregate result of allowlist evaluation across all command segments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct CommandAllowlistEvaluation {
    // True when every high-risk segment was explicitly permitted by a
    // context rule with `allow_high_risk` set.
    high_risk_overridden: bool,
}
/// Whether `base` (an executable basename) is considered high risk:
/// destructive (`rm`, `dd`, `mkfs`), privilege-escalating (`sudo`, `su`),
/// system-management, or network/exfiltration tools (`curl`, `ssh`, ...).
///
/// The comparison is an exact, case-sensitive match; callers are expected
/// to lowercase the basename first.
fn is_high_risk_base_command(base: &str) -> bool {
    const HIGH_RISK_COMMANDS: &[&str] = &[
        "rm",
        "mkfs",
        "dd",
        "shutdown",
        "reboot",
        "halt",
        "poweroff",
        "sudo",
        "su",
        "chown",
        "chmod",
        "useradd",
        "userdel",
        "usermod",
        "passwd",
        "mount",
        "umount",
        "iptables",
        "ufw",
        "firewall-cmd",
        "curl",
        "wget",
        "nc",
        "ncat",
        "netcat",
        "scp",
        "ssh",
        "ftp",
        "telnet",
    ];
    HIGH_RISK_COMMANDS.contains(&base)
}
impl SecurityPolicy {
/// True when `candidate` falls under `prefix` after user-path expansion
/// (via `expand_user_path`) and resolution of relative paths against the
/// workspace directory.
///
/// Purely lexical (`Path::starts_with`); no filesystem access or symlink
/// resolution is performed.
fn path_matches_rule_prefix(&self, candidate: &str, prefix: &str) -> bool {
    let candidate_path = expand_user_path(candidate);
    let prefix_path = expand_user_path(prefix);
    // Relative paths are interpreted relative to the workspace root.
    let normalized_candidate = if candidate_path.is_absolute() {
        candidate_path
    } else {
        self.workspace_dir.join(candidate_path)
    };
    let normalized_prefix = if prefix_path.is_absolute() {
        prefix_path
    } else {
        self.workspace_dir.join(prefix_path)
    };
    normalized_candidate.starts_with(&normalized_prefix)
}
/// Case-insensitive host/pattern comparison supporting a leading `*.`
/// wildcard: `*.example.com` matches `example.com` itself and any
/// subdomain of it. Empty hosts or patterns (after trimming) never match.
fn host_matches_pattern(host: &str, pattern: &str) -> bool {
    let normalized_host = host.trim().to_ascii_lowercase();
    let normalized_pattern = pattern.trim().to_ascii_lowercase();
    if normalized_host.is_empty() || normalized_pattern.is_empty() {
        return false;
    }
    match normalized_pattern.strip_prefix("*.") {
        // Wildcard: exact apex match, or any labels followed by ".suffix".
        Some(suffix) => {
            normalized_host == suffix || normalized_host.ends_with(&format!(".{suffix}"))
        }
        None => normalized_host == normalized_pattern,
    }
}
/// Collect lowercase hostnames from any argument tokens that parse as URLs.
///
/// Tokens are unquoted and stripped of trailing `,`/`;` separators before
/// parsing; tokens that do not parse as absolute URLs (or have no host)
/// are ignored.
fn extract_segment_url_hosts(args: &[&str]) -> Vec<String> {
    args.iter()
        .filter_map(|token| {
            let candidate = strip_wrapping_quotes(token)
                .trim()
                .trim_matches(|c: char| matches!(c, ',' | ';'));
            if candidate.is_empty() {
                return None;
            }
            // Url::parse only succeeds for absolute URLs, so plain words
            // and file paths fall through to None here.
            Url::parse(candidate)
                .ok()
                .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
        })
        .collect()
}
/// Collect filesystem-path-looking argument tokens from one command segment.
///
/// Considers plain positional arguments, redirection targets, `--opt=value`
/// attached values, and attached short-option values (e.g. `-ofile`).
/// Tokens containing `://` are treated as URLs and skipped; what counts as
/// a path is delegated to `looks_like_path`.
fn extract_segment_path_args(args: &[&str]) -> Vec<String> {
    let mut paths = Vec::new();
    for token in args {
        let candidate = strip_wrapping_quotes(token).trim();
        if candidate.is_empty() || candidate.contains("://") {
            continue;
        }
        // Redirection-style tokens contribute their target path.
        if let Some(target) = redirection_target(candidate) {
            let normalized = strip_wrapping_quotes(target).trim();
            if !normalized.is_empty() && looks_like_path(normalized) {
                paths.push(normalized.to_string());
            }
        }
        if candidate.starts_with('-') {
            // `--flag=value` form: inspect the attached value.
            if let Some((_, value)) = candidate.split_once('=') {
                let normalized = strip_wrapping_quotes(value).trim();
                if !normalized.is_empty()
                    && !normalized.contains("://")
                    && looks_like_path(normalized)
                {
                    paths.push(normalized.to_string());
                }
            }
            // Attached short-option value form (e.g. `-ofile`).
            if let Some(value) = attached_short_option_value(candidate) {
                let normalized = strip_wrapping_quotes(value).trim();
                if !normalized.is_empty()
                    && !normalized.contains("://")
                    && looks_like_path(normalized)
                {
                    paths.push(normalized.to_string());
                }
            }
            // Option tokens themselves are never treated as paths.
            continue;
        }
        if looks_like_path(candidate) {
            paths.push(candidate.to_string());
        }
    }
    paths
}
/// Check whether a rule's contextual conditions hold for one command
/// segment's arguments.
///
/// Semantics:
/// - `allowed_domains`: at least one URL must be present in the args, and
///   every URL host must match some pattern.
/// - `allowed_path_prefixes`: at least one path-like arg must be present,
///   and every path-like arg must fall under some prefix.
/// - `denied_path_prefixes`: for Allow rules the rule only matches when NO
///   arg hits a denied prefix; for Deny rules it only matches when at
///   least one arg does.
fn rule_conditions_match(&self, rule: &CommandContextRule, args: &[&str]) -> bool {
    if !rule.allowed_domains.is_empty() {
        let hosts = Self::extract_segment_url_hosts(args);
        if hosts.is_empty() {
            return false;
        }
        if !hosts.iter().all(|host| {
            rule.allowed_domains
                .iter()
                .any(|pattern| Self::host_matches_pattern(host, pattern))
        }) {
            return false;
        }
    }
    // Only pay for path extraction when some path condition is configured.
    let path_args =
        if rule.allowed_path_prefixes.is_empty() && rule.denied_path_prefixes.is_empty() {
            Vec::new()
        } else {
            Self::extract_segment_path_args(args)
        };
    if !rule.allowed_path_prefixes.is_empty() {
        if path_args.is_empty() {
            return false;
        }
        if !path_args.iter().all(|path| {
            rule.allowed_path_prefixes
                .iter()
                .any(|prefix| self.path_matches_rule_prefix(path, prefix))
        }) {
            return false;
        }
    }
    if !rule.denied_path_prefixes.is_empty() {
        if path_args.is_empty() {
            return false;
        }
        let has_denied_path = path_args.iter().any(|path| {
            rule.denied_path_prefixes
                .iter()
                .any(|prefix| self.path_matches_rule_prefix(path, prefix))
        });
        match rule.action {
            CommandContextRuleAction::Allow => {
                // An allow rule must not cover args touching denied prefixes.
                if has_denied_path {
                    return false;
                }
            }
            CommandContextRuleAction::Deny => {
                // A deny rule only fires when a denied prefix is actually hit.
                if !has_denied_path {
                    return false;
                }
            }
        }
    }
    true
}
/// Evaluate all configured context rules against one command segment.
///
/// Precedence: the first matching Deny rule wins immediately. Otherwise,
/// if any Allow rule exists for this command, the segment must match at
/// least one of them — an unmatched segment is denied, so allow rules act
/// as an exhaustive whitelist for that command. With no applicable allow
/// rules the result is `NoMatch` and the global allowlist applies.
fn evaluate_segment_context_rules(
    &self,
    executable: &str,
    base_cmd: &str,
    args: &[&str],
) -> SegmentRuleOutcome {
    let mut has_allow_rules = false;
    let mut allow_match = false;
    let mut allow_high_risk = false;
    for rule in &self.command_context_rules {
        if !is_allowlist_entry_match(&rule.command, executable, base_cmd) {
            continue;
        }
        // Record that an allow rule targets this command even when its
        // conditions don't match — see the whitelist semantics above.
        if matches!(rule.action, CommandContextRuleAction::Allow) {
            has_allow_rules = true;
        }
        if !self.rule_conditions_match(rule, args) {
            continue;
        }
        match rule.action {
            CommandContextRuleAction::Deny => {
                return SegmentRuleOutcome {
                    decision: SegmentRuleDecision::Deny,
                    allow_high_risk: false,
                };
            }
            CommandContextRuleAction::Allow => {
                allow_match = true;
                // Any matching allow rule may grant the high-risk override.
                allow_high_risk |= rule.allow_high_risk;
            }
        }
    }
    if has_allow_rules {
        if allow_match {
            SegmentRuleOutcome {
                decision: SegmentRuleDecision::Allow,
                allow_high_risk,
            }
        } else {
            SegmentRuleOutcome {
                decision: SegmentRuleDecision::Deny,
                allow_high_risk: false,
            }
        }
    } else {
        SegmentRuleOutcome {
            decision: SegmentRuleDecision::NoMatch,
            allow_high_risk: false,
        }
    }
}
/// Validate a full shell command against the policy's allowlist and
/// context rules, returning per-command evaluation metadata.
///
/// Rejects shell expansion, redirection, `tee`, and background chaining
/// outright, then splits the command on unquoted separators and checks
/// each segment: context rules first (Deny short-circuits; Allow can
/// substitute for global allowlist membership), then the global allowlist,
/// then argument safety.
///
/// # Errors
/// Returns a human-readable reason when any check fails.
fn evaluate_command_allowlist(
    &self,
    command: &str,
) -> Result<CommandAllowlistEvaluation, String> {
    if self.autonomy == AutonomyLevel::ReadOnly {
        return Err("readonly autonomy level blocks shell command execution".into());
    }
    // Subshells/expansions can hide arbitrary commands inside allowed ones.
    if command.contains('`')
        || contains_unquoted_shell_variable_expansion(command)
        || command.contains("<(")
        || command.contains(">(")
    {
        return Err("command contains disallowed shell expansion syntax".into());
    }
    // Redirections can read/write arbitrary paths, bypassing path checks.
    if contains_unquoted_char(command, '>') || contains_unquoted_char(command, '<') {
        return Err("command contains disallowed redirection syntax".into());
    }
    // `tee` writes to arbitrary files, bypassing the redirection check.
    if command
        .split_whitespace()
        .any(|w| w == "tee" || w.ends_with("/tee"))
    {
        return Err("command contains disallowed tee usage".into());
    }
    // Background chaining can hide sub-commands and outlive timeouts.
    if contains_unquoted_single_ampersand(command) {
        return Err("command contains disallowed background chaining operator '&'".into());
    }
    let segments = split_unquoted_segments(command);
    let mut has_cmd = false;
    let mut saw_high_risk_segment = false;
    let mut all_high_risk_segments_overridden = true;
    for segment in &segments {
        // Skip leading env-var assignments (e.g. `FOO=bar cmd`).
        let cmd_part = skip_env_assignments(segment);
        let mut words = cmd_part.split_whitespace();
        let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
        let base_cmd = executable.rsplit('/').next().unwrap_or("").trim();
        if base_cmd.is_empty() {
            continue;
        }
        has_cmd = true;
        let args_raw: Vec<&str> = words.collect();
        let args_lower: Vec<String> = args_raw.iter().map(|w| w.to_ascii_lowercase()).collect();
        let context_outcome =
            self.evaluate_segment_context_rules(executable, base_cmd, &args_raw);
        if context_outcome.decision == SegmentRuleDecision::Deny {
            return Err(format!("context rule denied command segment `{base_cmd}`"));
        }
        // A context Allow decision substitutes for allowlist membership.
        if context_outcome.decision != SegmentRuleDecision::Allow
            && !self
                .allowed_commands
                .iter()
                .any(|allowed| is_allowlist_entry_match(allowed, executable, base_cmd))
        {
            return Err(format!(
                "command segment `{base_cmd}` is not present in allowed_commands"
            ));
        }
        if !self.is_args_safe(base_cmd, &args_lower) {
            return Err(format!(
                "command segment `{base_cmd}` contains unsafe arguments"
            ));
        }
        let base_lower = base_cmd.to_ascii_lowercase();
        if is_high_risk_base_command(&base_lower) {
            saw_high_risk_segment = true;
            // The high-risk block is lifted only when EVERY high-risk
            // segment was explicitly allowed with `allow_high_risk`.
            if !(context_outcome.decision == SegmentRuleDecision::Allow
                && context_outcome.allow_high_risk)
            {
                all_high_risk_segments_overridden = false;
            }
        }
    }
    if !has_cmd {
        return Err("command is empty after parsing".into());
    }
    Ok(CommandAllowlistEvaluation {
        high_risk_overridden: saw_high_risk_segment && all_high_risk_segments_overridden,
    })
}
// ── Risk Classification ──────────────────────────────────────────────
// Risk is assessed per-segment (split on shell operators), and the
// highest risk across all segments wins. This prevents bypasses like
@@ -592,37 +972,7 @@ impl SecurityPolicy {
let joined_segment = cmd_part.to_ascii_lowercase();
// High-risk commands
if matches!(
base.as_str(),
"rm" | "mkfs"
| "dd"
| "shutdown"
| "reboot"
| "halt"
| "poweroff"
| "sudo"
| "su"
| "chown"
| "chmod"
| "useradd"
| "userdel"
| "usermod"
| "passwd"
| "mount"
| "umount"
| "iptables"
| "ufw"
| "firewall-cmd"
| "curl"
| "wget"
| "nc"
| "ncat"
| "netcat"
| "scp"
| "ssh"
| "ftp"
| "telnet"
) {
if is_high_risk_base_command(base.as_str()) {
return CommandRiskLevel::High;
}
@@ -693,9 +1043,9 @@ impl SecurityPolicy {
command: &str,
approved: bool,
) -> Result<CommandRiskLevel, String> {
if !self.is_command_allowed(command) {
return Err(format!("Command not allowed by security policy: {command}"));
}
let allowlist_eval = self
.evaluate_command_allowlist(command)
.map_err(|reason| format!("Command not allowed by security policy: {reason}"))?;
if let Some(path) = self.forbidden_path_argument(command) {
return Err(format!("Path blocked by security policy: {path}"));
@@ -704,7 +1054,7 @@ impl SecurityPolicy {
let risk = self.command_risk_level(command);
if risk == CommandRiskLevel::High {
if self.block_high_risk_commands {
if self.block_high_risk_commands && !allowlist_eval.high_risk_overridden {
let lower = command.to_ascii_lowercase();
if lower.contains("curl") || lower.contains("wget") {
return Err(
@@ -750,81 +1100,7 @@ impl SecurityPolicy {
/// - Blocks shell redirections (`<`, `>`, `>>`) that can bypass path policy
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
pub fn is_command_allowed(&self, command: &str) -> bool {
if self.autonomy == AutonomyLevel::ReadOnly {
return false;
}
// Block subshell/expansion operators — these allow hiding arbitrary
// commands inside an allowed command (e.g. `echo $(rm -rf /)`) and
// bypassing path checks through variable indirection. The helper below
// ignores escapes and literals inside single quotes, so `$(` or `${`
// literals are permitted there.
if command.contains('`')
|| contains_unquoted_shell_variable_expansion(command)
|| command.contains("<(")
|| command.contains(">(")
{
return false;
}
// Block shell redirections (`<`, `>`, `>>`) — they can read/write
// arbitrary paths and bypass path checks.
// Ignore quoted literals, e.g. `echo "a>b"` and `echo "a<b"`.
if contains_unquoted_char(command, '>') || contains_unquoted_char(command, '<') {
return false;
}
// Block `tee` — it can write to arbitrary files, bypassing the
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
if command
.split_whitespace()
.any(|w| w == "tee" || w.ends_with("/tee"))
{
return false;
}
// Block background command chaining (`&`), which can hide extra
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
if contains_unquoted_single_ampersand(command) {
return false;
}
// Split on unquoted command separators and validate each sub-command.
let segments = split_unquoted_segments(command);
for segment in &segments {
// Strip leading env var assignments (e.g. FOO=bar cmd)
let cmd_part = skip_env_assignments(segment);
let mut words = cmd_part.split_whitespace();
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
let base_cmd = executable.rsplit('/').next().unwrap_or("");
if base_cmd.is_empty() {
continue;
}
if !self
.allowed_commands
.iter()
.any(|allowed| is_allowlist_entry_match(allowed, executable, base_cmd))
{
return false;
}
// Validate arguments for the command
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
if !self.is_args_safe(base_cmd, &args) {
return false;
}
}
// At least one command must be present
let has_cmd = segments.iter().any(|s| {
let s = skip_env_assignments(s.trim());
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
});
has_cmd
self.evaluate_command_allowlist(command).is_ok()
}
/// Check for dangerous arguments that allow sub-command execution.
@@ -1214,6 +1490,11 @@ impl SecurityPolicy {
format!("{} (others rejected)", shown.join(", "))
}
};
let context_rules = if self.command_context_rules.is_empty() {
"none".to_string()
} else {
format!("{} configured", self.command_context_rules.len())
};
let high_risk = if self.block_high_risk_commands {
"blocked"
@@ -1226,6 +1507,7 @@ impl SecurityPolicy {
- Workspace: {workspace} (workspace_only: {ws_only})\n\
- Forbidden paths: {forbidden_preview}\n\
- Allowed commands: {commands_preview}\n\
- Command context rules: {context_rules}\n\
- High-risk commands: {high_risk}\n\
- Do not exfiltrate data, bypass approval, or run destructive commands without asking."
)
@@ -1240,6 +1522,25 @@ impl SecurityPolicy {
workspace_dir: workspace_dir.to_path_buf(),
workspace_only: autonomy_config.workspace_only,
allowed_commands: autonomy_config.allowed_commands.clone(),
command_context_rules: autonomy_config
.command_context_rules
.iter()
.map(|rule| CommandContextRule {
command: rule.command.clone(),
action: match rule.action {
crate::config::CommandContextRuleAction::Allow => {
CommandContextRuleAction::Allow
}
crate::config::CommandContextRuleAction::Deny => {
CommandContextRuleAction::Deny
}
},
allowed_domains: rule.allowed_domains.clone(),
allowed_path_prefixes: rule.allowed_path_prefixes.clone(),
denied_path_prefixes: rule.denied_path_prefixes.clone(),
allow_high_risk: rule.allow_high_risk,
})
.collect(),
forbidden_paths: autonomy_config.forbidden_paths.clone(),
allowed_roots: autonomy_config
.allowed_roots
@@ -1461,6 +1762,102 @@ mod tests {
assert!(!p.is_command_allowed("echo hello"));
}
// An Allow context rule lets a command run even with an empty global allowlist.
#[test]
fn context_allow_rule_overrides_global_allowlist_for_curl_domain() {
    let p = SecurityPolicy {
        autonomy: AutonomyLevel::Full,
        allowed_commands: vec![],
        command_context_rules: vec![CommandContextRule {
            command: "curl".into(),
            action: CommandContextRuleAction::Allow,
            allowed_domains: vec!["api.example.com".into()],
            allowed_path_prefixes: vec![],
            denied_path_prefixes: vec![],
            allow_high_risk: true,
        }],
        ..SecurityPolicy::default()
    };
    assert!(p.is_command_allowed("curl https://api.example.com/v1/health"));
    assert!(p
        .validate_command_execution("curl https://api.example.com/v1/health", true)
        .is_ok());
}

// Once an Allow rule exists for a command, non-matching domains are denied
// even though the command is in the global allowlist.
#[test]
fn context_allow_rule_restricts_curl_to_matching_domains() {
    let p = SecurityPolicy {
        autonomy: AutonomyLevel::Full,
        allowed_commands: vec!["curl".into()],
        command_context_rules: vec![CommandContextRule {
            command: "curl".into(),
            action: CommandContextRuleAction::Allow,
            allowed_domains: vec!["api.example.com".into()],
            allowed_path_prefixes: vec![],
            denied_path_prefixes: vec![],
            allow_high_risk: true,
        }],
        ..SecurityPolicy::default()
    };
    assert!(!p.is_command_allowed("curl https://evil.example.com/steal"));
    let err = p
        .validate_command_execution("curl https://evil.example.com/steal", true)
        .expect_err("non-matching domains should be denied by context rules");
    assert!(err.contains("context rule denied"));
}

// Path-prefix conditions confine a destructive command to a directory tree.
#[test]
fn context_allow_rule_restricts_rm_to_allowed_path_prefix() {
    let p = SecurityPolicy {
        autonomy: AutonomyLevel::Full,
        workspace_only: false,
        allowed_commands: vec!["rm".into()],
        forbidden_paths: vec![],
        command_context_rules: vec![CommandContextRule {
            command: "rm".into(),
            action: CommandContextRuleAction::Allow,
            allowed_domains: vec![],
            allowed_path_prefixes: vec!["/tmp".into()],
            denied_path_prefixes: vec![],
            allow_high_risk: true,
        }],
        ..SecurityPolicy::default()
    };
    assert!(p.is_command_allowed("rm -rf /tmp/cleanup"));
    assert!(p
        .validate_command_execution("rm -rf /tmp/cleanup", true)
        .is_ok());
    assert!(!p.is_command_allowed("rm -rf /var/log"));
    let err = p
        .validate_command_execution("rm -rf /var/log", true)
        .expect_err("paths outside /tmp should be denied");
    assert!(err.contains("context rule denied"));
}

// A Deny rule blocks a specific domain without affecting other domains,
// even when the command itself is globally allowlisted.
#[test]
fn context_deny_rule_can_block_specific_domain_even_when_allowlisted() {
    let p = SecurityPolicy {
        autonomy: AutonomyLevel::Full,
        block_high_risk_commands: false,
        allowed_commands: vec!["curl".into()],
        command_context_rules: vec![CommandContextRule {
            command: "curl".into(),
            action: CommandContextRuleAction::Deny,
            allowed_domains: vec!["evil.example.com".into()],
            allowed_path_prefixes: vec![],
            denied_path_prefixes: vec![],
            allow_high_risk: false,
        }],
        ..SecurityPolicy::default()
    };
    assert!(p.is_command_allowed("curl https://api.example.com/v1/health"));
    assert!(!p.is_command_allowed("curl https://evil.example.com/steal"));
}
#[test]
fn command_risk_low_for_read_commands() {
let p = default_policy();
+114 -13
View File
@@ -18,6 +18,12 @@ const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
/// Timeout for init/list operations.
const RECV_TIMEOUT_SECS: u64 = 30;
/// Streamable HTTP Accept header required by MCP HTTP transport.
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";
/// Default media type for MCP JSON-RPC request bodies.
const MCP_JSON_CONTENT_TYPE: &str = "application/json";
// ── Transport Trait ──────────────────────────────────────────────────────
/// Abstract transport for MCP communication.
@@ -171,10 +177,25 @@ impl McpTransportConn for HttpTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
let body = serde_json::to_string(request)?;
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let has_content_type = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
let mut req = self.client.post(&self.url).body(body);
if !has_content_type {
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
}
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req
.send()
@@ -194,11 +215,24 @@ impl McpTransportConn for HttpTransport {
});
}
let resp_text = resp.text().await.context("failed to read HTTP response")?;
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
let is_sse = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
if is_sse {
let maybe_resp = timeout(
Duration::from_secs(RECV_TIMEOUT_SECS),
read_first_jsonrpc_from_sse_response(resp),
)
.await
.context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
return maybe_resp
.ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
}
Ok(mcp_resp)
let resp_text = resp.text().await.context("failed to read HTTP response")?;
parse_jsonrpc_response_text(&resp_text)
}
async fn close(&mut self) -> Result<()> {
@@ -264,14 +298,21 @@ impl SseTransport {
}
}
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let mut req = self
.client
.get(&self.sse_url)
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache");
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req.send().await.context("SSE GET to MCP server failed")?;
if resp.status() == reqwest::StatusCode::NOT_FOUND
@@ -556,6 +597,30 @@ fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
Cow::Owned(joined.trim().to_string())
}
/// Parse an MCP HTTP response body into a JSON-RPC response, accepting both
/// plain JSON and SSE-framed (`data:`/`event:`) payloads.
///
/// # Errors
/// Fails on an empty body, or when the (possibly SSE-extracted) payload is
/// not valid JSON-RPC; the error includes the raw body for diagnostics.
fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
    let trimmed = resp_text.trim();
    if trimmed.is_empty() {
        bail!("MCP server returned no response");
    }
    // SSE framing must be stripped before JSON parsing.
    let json_text = if looks_like_sse_text(trimmed) {
        extract_json_from_sse_text(trimmed)
    } else {
        Cow::Borrowed(trimmed)
    };
    let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
        .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
    Ok(mcp_resp)
}
/// Heuristic: does this body look like an SSE stream rather than raw JSON?
/// True when the first line — or any subsequent line — begins with an SSE
/// `data:` or `event:` field name.
fn looks_like_sse_text(text: &str) -> bool {
    const SSE_FIELD_PREFIXES: [&str; 2] = ["data:", "event:"];
    SSE_FIELD_PREFIXES
        .iter()
        .any(|prefix| text.starts_with(prefix) || text.contains(&format!("\n{prefix}")))
}
async fn read_first_jsonrpc_from_sse_response(
resp: reqwest::Response,
) -> Result<Option<JsonRpcResponse>> {
@@ -673,21 +738,27 @@ impl McpTransportConn for SseTransport {
.chain(secondary_url.into_iter())
.enumerate()
{
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let has_content_type = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
let mut req = self
.client
.post(&url)
.timeout(Duration::from_secs(120))
.body(body.clone())
.header("Content-Type", "application/json");
.body(body.clone());
if !has_content_type {
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
}
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"))
{
req = req.header("Accept", "application/json, text/event-stream");
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req.send().await.context("SSE POST to MCP server failed")?;
@@ -887,4 +958,34 @@ mod tests {
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
// Plain (non-SSE) JSON bodies parse directly.
#[test]
fn test_parse_jsonrpc_response_text_handles_plain_json() {
    let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}")
        .expect("plain JSON response should parse");
    assert_eq!(parsed.id, Some(serde_json::json!(1)));
    assert!(parsed.error.is_none());
}

// SSE framing (`event:`/`data:` lines) is stripped before JSON parsing.
#[test]
fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
    let sse =
        "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
    let parsed =
        parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse");
    assert_eq!(parsed.id, Some(serde_json::json!(2)));
    assert_eq!(
        parsed
            .result
            .as_ref()
            .and_then(|v| v.get("ok"))
            .and_then(|v| v.as_bool()),
        Some(true)
    );
}

// Whitespace-only bodies are rejected rather than parsed as empty JSON.
#[test]
fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
    assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
}
}