Merge branch 'main' into feat/feishu-doc-tool

This commit is contained in:
Chum Yin 2026-03-02 04:14:16 +08:00 committed by GitHub
commit 87ae1c8ca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 304 additions and 19 deletions

View File

@ -243,6 +243,10 @@ impl Agent {
AgentBuilder::new()
}
pub fn tool_specs(&self) -> &[ToolSpec] {
&self.tool_specs
}
pub fn history(&self) -> &[ConversationMessage] {
&self.history
}

View File

@ -983,7 +983,7 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_tool_call_loop(
pub async fn run_tool_call_loop(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
tools_registry: &[Box<dyn Tool>],

View File

@ -1,9 +1,13 @@
use crate::memory::{self, Memory};
use crate::memory::{self, decay, Memory};
use std::fmt::Write;
/// Default half-life (days) for time decay in context building.
const CONTEXT_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
/// Build context preamble by searching memory for relevant entries.
/// Entries with a hybrid score below `min_relevance_score` are dropped to
/// prevent unrelated memories from bleeding into the conversation.
/// Core memories are exempt from time decay (evergreen).
pub(super) async fn build_context(
mem: &dyn Memory,
user_msg: &str,
@ -13,7 +17,10 @@ pub(super) async fn build_context(
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
if let Ok(mut entries) = mem.recall(user_msg, 5, session_id).await {
// Apply time decay: older non-Core memories score lower
decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS);
let relevant: Vec<_> = entries
.iter()
.filter(|e| match e.score {

View File

@ -1,7 +1,10 @@
use crate::memory::{self, Memory};
use crate::memory::{self, decay, Memory};
use async_trait::async_trait;
use std::fmt::Write;
/// Default half-life (days) for time decay in memory loading.
const LOADER_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
#[async_trait]
pub trait MemoryLoader: Send + Sync {
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
@ -38,11 +41,14 @@ impl MemoryLoader for DefaultMemoryLoader {
memory: &dyn Memory,
user_message: &str,
) -> anyhow::Result<String> {
let entries = memory.recall(user_message, self.limit, None).await?;
let mut entries = memory.recall(user_message, self.limit, None).await?;
if entries.is_empty() {
return Ok(String::new());
}
// Apply time decay: older non-Core memories score lower
decay::apply_time_decay(&mut entries, LOADER_DECAY_HALF_LIFE_DAYS);
let mut context = String::from("[Memory context]\n");
for entry in entries {
if memory::is_assistant_autosave_key(&entry.key) {

View File

@ -16,4 +16,4 @@ mod tests;
#[allow(unused_imports)]
pub use agent::{Agent, AgentBuilder};
#[allow(unused_imports)]
pub use loop_::{process_message, process_message_with_session, run};
pub use loop_::{process_message, process_message_with_session, run, run_tool_call_loop};

View File

@ -736,6 +736,20 @@ async fn native_dispatcher_sends_tool_specs() {
assert!(dispatcher.should_send_tool_specs());
}
#[test]
fn agent_tool_specs_accessor_exposes_registered_tools() {
let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")]));
let agent = build_agent_with(
provider,
vec![Box::new(EchoTool)],
Box::new(NativeToolDispatcher),
);
let specs = agent.tool_specs();
assert_eq!(specs.len(), 1);
assert_eq!(specs[0].name, "echo");
}
#[tokio::test]
async fn xml_dispatcher_does_not_send_tool_specs() {
let dispatcher = XmlToolDispatcher;

152
src/memory/decay.rs Normal file
View File

@ -0,0 +1,152 @@
use super::traits::{MemoryCategory, MemoryEntry};
use chrono::{DateTime, Utc};
/// Default half-life in days for time-decay scoring.
/// After this many days, a non-Core memory's score drops to 50%.
const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
/// Apply exponential time decay to memory entry scores.
///
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
/// - Entries without a score (`None`) are left unchanged.
///
/// Decay formula: `score * 2^(-age_days / half_life_days)`
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
let half_life = if half_life_days <= 0.0 {
DEFAULT_HALF_LIFE_DAYS
} else {
half_life_days
};
let now = Utc::now();
for entry in entries.iter_mut() {
// Core memories are evergreen — never decay
if entry.category == MemoryCategory::Core {
continue;
}
let score = match entry.score {
Some(s) => s,
None => continue,
};
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
Ok(dt) => dt.with_timezone(&Utc),
Err(_) => continue,
};
let age_days = now
.signed_duration_since(ts)
.num_seconds()
.max(0) as f64
/ 86_400.0;
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
entry.score = Some(score * decay_factor);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
MemoryEntry {
id: "1".into(),
key: "test".into(),
content: "value".into(),
category,
timestamp: timestamp.into(),
session_id: None,
score,
}
}
fn recent_rfc3339() -> String {
Utc::now().to_rfc3339()
}
fn days_ago_rfc3339(days: i64) -> String {
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
}
#[test]
fn core_memories_are_never_decayed() {
let mut entries = vec![make_entry(
MemoryCategory::Core,
Some(0.9),
&days_ago_rfc3339(30),
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, Some(0.9));
}
#[test]
fn recent_entry_score_barely_changes() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(0.8),
&recent_rfc3339(),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.8).abs() < 0.01,
"recent entry should barely decay, got {decayed}"
);
}
#[test]
fn one_half_life_halves_score() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(1.0),
&days_ago_rfc3339(7),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.5).abs() < 0.05,
"score after one half-life should be ~0.5, got {decayed}"
);
}
#[test]
fn two_half_lives_quarters_score() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(1.0),
&days_ago_rfc3339(14),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.25).abs() < 0.05,
"score after two half-lives should be ~0.25, got {decayed}"
);
}
#[test]
fn no_score_entry_is_unchanged() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
None,
&days_ago_rfc3339(30),
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, None);
}
#[test]
fn unparseable_timestamp_is_unchanged() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(0.9),
"not-a-date",
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, Some(0.9));
}
}

View File

@ -2,6 +2,7 @@ pub mod backend;
pub mod chunker;
pub mod cli;
pub mod cortex;
pub mod decay;
pub mod embeddings;
pub mod hybrid;
pub mod hygiene;

View File

@ -18,6 +18,12 @@ const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
/// Timeout for init/list operations.
const RECV_TIMEOUT_SECS: u64 = 30;
/// Streamable HTTP Accept header required by MCP HTTP transport.
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";
/// Default media type for MCP JSON-RPC request bodies.
const MCP_JSON_CONTENT_TYPE: &str = "application/json";
// ── Transport Trait ──────────────────────────────────────────────────────
/// Abstract transport for MCP communication.
@ -171,10 +177,25 @@ impl McpTransportConn for HttpTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
let body = serde_json::to_string(request)?;
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let has_content_type = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
let mut req = self.client.post(&self.url).body(body);
if !has_content_type {
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
}
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req
.send()
@ -194,11 +215,24 @@ impl McpTransportConn for HttpTransport {
});
}
let resp_text = resp.text().await.context("failed to read HTTP response")?;
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
let is_sse = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
if is_sse {
let maybe_resp = timeout(
Duration::from_secs(RECV_TIMEOUT_SECS),
read_first_jsonrpc_from_sse_response(resp),
)
.await
.context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
return maybe_resp
.ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
}
Ok(mcp_resp)
let resp_text = resp.text().await.context("failed to read HTTP response")?;
parse_jsonrpc_response_text(&resp_text)
}
async fn close(&mut self) -> Result<()> {
@ -264,14 +298,21 @@ impl SseTransport {
}
}
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let mut req = self
.client
.get(&self.sse_url)
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache");
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req.send().await.context("SSE GET to MCP server failed")?;
if resp.status() == reqwest::StatusCode::NOT_FOUND
@ -556,6 +597,30 @@ fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
Cow::Owned(joined.trim().to_string())
}
fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
let trimmed = resp_text.trim();
if trimmed.is_empty() {
bail!("MCP server returned no response");
}
let json_text = if looks_like_sse_text(trimmed) {
extract_json_from_sse_text(trimmed)
} else {
Cow::Borrowed(trimmed)
};
let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
Ok(mcp_resp)
}
fn looks_like_sse_text(text: &str) -> bool {
text.starts_with("data:")
|| text.starts_with("event:")
|| text.contains("\ndata:")
|| text.contains("\nevent:")
}
async fn read_first_jsonrpc_from_sse_response(
resp: reqwest::Response,
) -> Result<Option<JsonRpcResponse>> {
@ -673,21 +738,27 @@ impl McpTransportConn for SseTransport {
.chain(secondary_url.into_iter())
.enumerate()
{
let has_accept = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"));
let has_content_type = self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
let mut req = self
.client
.post(&url)
.timeout(Duration::from_secs(120))
.body(body.clone())
.header("Content-Type", "application/json");
.body(body.clone());
if !has_content_type {
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
}
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"))
{
req = req.header("Accept", "application/json, text/event-stream");
if !has_accept {
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
}
let resp = req.send().await.context("SSE POST to MCP server failed")?;
@ -887,4 +958,34 @@ mod tests {
let extracted = extract_json_from_sse_text(input);
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
}
#[test]
fn test_parse_jsonrpc_response_text_handles_plain_json() {
let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}")
.expect("plain JSON response should parse");
assert_eq!(parsed.id, Some(serde_json::json!(1)));
assert!(parsed.error.is_none());
}
#[test]
fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
let sse =
"event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
let parsed =
parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse");
assert_eq!(parsed.id, Some(serde_json::json!(2)));
assert_eq!(
parsed
.result
.as_ref()
.and_then(|v| v.get("ok"))
.and_then(|v| v.as_bool()),
Some(true)
);
}
#[test]
fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
}
}