Merge branch 'main' into feat/feishu-doc-tool
This commit is contained in:
commit
87ae1c8ca6
@ -243,6 +243,10 @@ impl Agent {
|
||||
AgentBuilder::new()
|
||||
}
|
||||
|
||||
pub fn tool_specs(&self) -> &[ToolSpec] {
|
||||
&self.tool_specs
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &[ConversationMessage] {
|
||||
&self.history
|
||||
}
|
||||
|
||||
@ -983,7 +983,7 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn run_tool_call_loop(
|
||||
pub async fn run_tool_call_loop(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// Default half-life (days) for time decay in context building.
|
||||
const CONTEXT_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
/// Core memories are exempt from time decay (evergreen).
|
||||
pub(super) async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
@ -13,7 +17,10 @@ pub(super) async fn build_context(
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
if let Ok(mut entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS);
|
||||
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory};
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
/// Default half-life (days) for time decay in memory loading.
|
||||
const LOADER_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
#[async_trait]
|
||||
pub trait MemoryLoader: Send + Sync {
|
||||
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
|
||||
@ -38,11 +41,14 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||
let mut entries = memory.recall(user_message, self.limit, None).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, LOADER_DECAY_HALF_LIFE_DAYS);
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
|
||||
@ -16,4 +16,4 @@ mod tests;
|
||||
#[allow(unused_imports)]
|
||||
pub use agent::{Agent, AgentBuilder};
|
||||
#[allow(unused_imports)]
|
||||
pub use loop_::{process_message, process_message_with_session, run};
|
||||
pub use loop_::{process_message, process_message_with_session, run, run_tool_call_loop};
|
||||
|
||||
@ -736,6 +736,20 @@ async fn native_dispatcher_sends_tool_specs() {
|
||||
assert!(dispatcher.should_send_tool_specs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_tool_specs_accessor_exposes_registered_tools() {
|
||||
let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")]));
|
||||
let agent = build_agent_with(
|
||||
provider,
|
||||
vec![Box::new(EchoTool)],
|
||||
Box::new(NativeToolDispatcher),
|
||||
);
|
||||
|
||||
let specs = agent.tool_specs();
|
||||
assert_eq!(specs.len(), 1);
|
||||
assert_eq!(specs[0].name, "echo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn xml_dispatcher_does_not_send_tool_specs() {
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
|
||||
152
src/memory/decay.rs
Normal file
152
src/memory/decay.rs
Normal file
@ -0,0 +1,152 @@
|
||||
use super::traits::{MemoryCategory, MemoryEntry};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Default half-life in days for time-decay scoring.
|
||||
/// After this many days, a non-Core memory's score drops to 50%.
|
||||
const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Apply exponential time decay to memory entry scores.
|
||||
///
|
||||
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
|
||||
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
|
||||
/// - Entries without a score (`None`) are left unchanged.
|
||||
///
|
||||
/// Decay formula: `score * 2^(-age_days / half_life_days)`
|
||||
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
|
||||
let half_life = if half_life_days <= 0.0 {
|
||||
DEFAULT_HALF_LIFE_DAYS
|
||||
} else {
|
||||
half_life_days
|
||||
};
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
// Core memories are evergreen — never decay
|
||||
if entry.category == MemoryCategory::Core {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = match entry.score {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
Ok(dt) => dt.with_timezone(&Utc),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let age_days = now
|
||||
.signed_duration_since(ts)
|
||||
.num_seconds()
|
||||
.max(0) as f64
|
||||
/ 86_400.0;
|
||||
|
||||
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
|
||||
entry.score = Some(score * decay_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "test".into(),
|
||||
content: "value".into(),
|
||||
category,
|
||||
timestamp: timestamp.into(),
|
||||
session_id: None,
|
||||
score,
|
||||
}
|
||||
}
|
||||
|
||||
fn recent_rfc3339() -> String {
|
||||
Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
fn days_ago_rfc3339(days: i64) -> String {
|
||||
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn core_memories_are_never_decayed() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Core,
|
||||
Some(0.9),
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recent_entry_score_barely_changes() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.8),
|
||||
&recent_rfc3339(),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.8).abs() < 0.01,
|
||||
"recent entry should barely decay, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_half_life_halves_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(7),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.5).abs() < 0.05,
|
||||
"score after one half-life should be ~0.5, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_half_lives_quarters_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(14),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.25).abs() < 0.05,
|
||||
"score after two half-lives should be ~0.25, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_score_entry_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_timestamp_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.9),
|
||||
"not-a-date",
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
}
|
||||
@ -2,6 +2,7 @@ pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod cortex;
|
||||
pub mod decay;
|
||||
pub mod embeddings;
|
||||
pub mod hybrid;
|
||||
pub mod hygiene;
|
||||
|
||||
@ -18,6 +18,12 @@ const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
||||
/// Timeout for init/list operations.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Streamable HTTP Accept header required by MCP HTTP transport.
|
||||
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";
|
||||
|
||||
/// Default media type for MCP JSON-RPC request bodies.
|
||||
const MCP_JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
// ── Transport Trait ──────────────────────────────────────────────────────
|
||||
|
||||
/// Abstract transport for MCP communication.
|
||||
@ -171,10 +177,25 @@ impl McpTransportConn for HttpTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
let has_content_type = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
|
||||
|
||||
let mut req = self.client.post(&self.url).body(body);
|
||||
if !has_content_type {
|
||||
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
|
||||
}
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
@ -194,11 +215,24 @@ impl McpTransportConn for HttpTransport {
|
||||
});
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
let is_sse = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
if is_sse {
|
||||
let maybe_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
read_first_jsonrpc_from_sse_response(resp),
|
||||
)
|
||||
.await
|
||||
.context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
|
||||
return maybe_resp
|
||||
.ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
|
||||
}
|
||||
|
||||
Ok(mcp_resp)
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
parse_jsonrpc_response_text(&resp_text)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
@ -264,14 +298,21 @@ impl SseTransport {
|
||||
}
|
||||
}
|
||||
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&self.sse_url)
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE GET to MCP server failed")?;
|
||||
if resp.status() == reqwest::StatusCode::NOT_FOUND
|
||||
@ -556,6 +597,30 @@ fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
|
||||
Cow::Owned(joined.trim().to_string())
|
||||
}
|
||||
|
||||
fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
|
||||
let trimmed = resp_text.trim();
|
||||
if trimmed.is_empty() {
|
||||
bail!("MCP server returned no response");
|
||||
}
|
||||
|
||||
let json_text = if looks_like_sse_text(trimmed) {
|
||||
extract_json_from_sse_text(trimmed)
|
||||
} else {
|
||||
Cow::Borrowed(trimmed)
|
||||
};
|
||||
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
fn looks_like_sse_text(text: &str) -> bool {
|
||||
text.starts_with("data:")
|
||||
|| text.starts_with("event:")
|
||||
|| text.contains("\ndata:")
|
||||
|| text.contains("\nevent:")
|
||||
}
|
||||
|
||||
async fn read_first_jsonrpc_from_sse_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<Option<JsonRpcResponse>> {
|
||||
@ -673,21 +738,27 @@ impl McpTransportConn for SseTransport {
|
||||
.chain(secondary_url.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
let has_content_type = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(120))
|
||||
.body(body.clone())
|
||||
.header("Content-Type", "application/json");
|
||||
.body(body.clone());
|
||||
if !has_content_type {
|
||||
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
|
||||
}
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"))
|
||||
{
|
||||
req = req.header("Accept", "application/json, text/event-stream");
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
@ -887,4 +958,34 @@ mod tests {
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_handles_plain_json() {
|
||||
let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}")
|
||||
.expect("plain JSON response should parse");
|
||||
assert_eq!(parsed.id, Some(serde_json::json!(1)));
|
||||
assert!(parsed.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
|
||||
let sse =
|
||||
"event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
|
||||
let parsed =
|
||||
parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse");
|
||||
assert_eq!(parsed.id, Some(serde_json::json!(2)));
|
||||
assert_eq!(
|
||||
parsed
|
||||
.result
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("ok"))
|
||||
.and_then(|v| v.as_bool()),
|
||||
Some(true)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
|
||||
assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user