feat(cache): wire two-tier response cache, multi-provider token tracking, and cache analytics
- Two-tier response cache: in-memory LRU (hot) + SQLite (warm) with TTL-aware eviction
- Wire response cache into agent turn loop (temp==0.0, text-only responses only)
- Parse Anthropic cache_creation_input_tokens/cache_read_input_tokens
- Parse OpenAI prompt_tokens_details.cached_tokens
- Add cached_input_tokens to TokenUsage, prompt_caching to ProviderCapabilities
- Add CacheHit/CacheMiss observer events with Prometheus counters
- Add response_cache_hot_entries config field (default: 256)
This commit is contained in:
parent
f5bd557bda
commit
18cb38b09e
@ -38,6 +38,7 @@ pub struct Agent {
|
||||
available_hints: Vec<String>,
|
||||
route_model_by_hint: HashMap<String, String>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
@ -60,6 +61,7 @@ pub struct AgentBuilder {
|
||||
available_hints: Option<Vec<String>>,
|
||||
route_model_by_hint: Option<HashMap<String, String>>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
response_cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
@ -84,6 +86,7 @@ impl AgentBuilder {
|
||||
available_hints: None,
|
||||
route_model_by_hint: None,
|
||||
allowed_tools: None,
|
||||
response_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,6 +191,14 @@ impl AgentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_cache(
|
||||
mut self,
|
||||
cache: Option<Arc<crate::memory::response_cache::ResponseCache>>,
|
||||
) -> Self {
|
||||
self.response_cache = cache;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let mut tools = self
|
||||
.tools
|
||||
@ -236,6 +247,7 @@ impl AgentBuilder {
|
||||
available_hints: self.available_hints.unwrap_or_default(),
|
||||
route_model_by_hint: self.route_model_by_hint.unwrap_or_default(),
|
||||
allowed_tools: allowed,
|
||||
response_cache: self.response_cache,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -330,11 +342,25 @@ impl Agent {
|
||||
.collect();
|
||||
let available_hints: Vec<String> = route_model_by_hint.keys().cloned().collect();
|
||||
|
||||
let response_cache = if config.memory.response_cache_enabled {
|
||||
crate::memory::response_cache::ResponseCache::with_hot_cache(
|
||||
&config.workspace_dir,
|
||||
config.memory.response_cache_ttl_minutes,
|
||||
config.memory.response_cache_max_entries,
|
||||
config.memory.response_cache_hot_entries,
|
||||
)
|
||||
.ok()
|
||||
.map(Arc::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(tools)
|
||||
.memory(memory)
|
||||
.observer(observer)
|
||||
.response_cache(response_cache)
|
||||
.tool_dispatcher(tool_dispatcher)
|
||||
.memory_loader(Box::new(DefaultMemoryLoader::new(
|
||||
5,
|
||||
@ -513,6 +539,47 @@ impl Agent {
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
|
||||
// Response cache: check before LLM call (only for deterministic, text-only prompts)
|
||||
let cache_key = if self.temperature == 0.0 {
|
||||
self.response_cache.as_ref().map(|_| {
|
||||
let last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.as_str())
|
||||
.unwrap_or("");
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
crate::memory::response_cache::ResponseCache::cache_key(
|
||||
&effective_model,
|
||||
system,
|
||||
last_user,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
if let Ok(Some(cached)) = cache.get(key) {
|
||||
self.observer.record_event(&ObserverEvent::CacheHit {
|
||||
cache_type: "response".into(),
|
||||
tokens_saved: 0,
|
||||
});
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
cached.clone(),
|
||||
)));
|
||||
self.trim_history();
|
||||
return Ok(cached);
|
||||
}
|
||||
self.observer.record_event(&ObserverEvent::CacheMiss {
|
||||
cache_type: "response".into(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.provider
|
||||
.chat(
|
||||
@ -541,6 +608,17 @@ impl Agent {
|
||||
text
|
||||
};
|
||||
|
||||
// Store in response cache (text-only, no tool calls)
|
||||
if let (Some(ref cache), Some(ref key)) = (&self.response_cache, &cache_key) {
|
||||
let token_count = response
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|u| u.output_tokens)
|
||||
.unwrap_or(0);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let _ = cache.put(key, &effective_model, &final_text, token_count as u32);
|
||||
}
|
||||
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
final_text.clone(),
|
||||
|
||||
@ -3977,6 +3977,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2650,6 +2650,9 @@ pub struct MemoryConfig {
|
||||
/// Max number of cached responses before LRU eviction (default: 5000)
|
||||
#[serde(default = "default_response_cache_max")]
|
||||
pub response_cache_max_entries: usize,
|
||||
/// Max in-memory hot cache entries for the two-tier response cache (default: 256)
|
||||
#[serde(default = "default_response_cache_hot_entries")]
|
||||
pub response_cache_hot_entries: usize,
|
||||
|
||||
// ── Memory Snapshot (soul backup to Markdown) ─────────────
|
||||
/// Enable periodic export of core memories to MEMORY_SNAPSHOT.md
|
||||
@ -2718,6 +2721,10 @@ fn default_response_cache_max() -> usize {
|
||||
5_000
|
||||
}
|
||||
|
||||
/// Serde default for `MemoryConfig::response_cache_hot_entries` (in-memory
/// hot-tier capacity of the two-tier response cache).
fn default_response_cache_hot_entries() -> usize {
    256
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -2738,6 +2745,7 @@ impl Default for MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: default_response_cache_ttl(),
|
||||
response_cache_max_entries: default_response_cache_max(),
|
||||
response_cache_hot_entries: default_response_cache_hot_entries(),
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
|
||||
@ -323,7 +323,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
tracing::info!("💓 Heartbeat Phase 1: skip (nothing to do)");
|
||||
crate::health::mark_component_ok("heartbeat");
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
let elapsed = tick_start.elapsed().as_millis() as f64;
|
||||
metrics.lock().record_success(elapsed);
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -10,23 +10,45 @@ use chrono::{Duration, Local};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Response cache backed by a dedicated SQLite database.
|
||||
/// An in-memory hot cache entry for the two-tier response cache.
struct InMemoryEntry {
    /// Cached assistant response text served on a hot-tier hit.
    response: String,
    /// Output-token count recorded at insert time.
    /// NOTE(review): written in `promote_to_hot` but never read in the visible
    /// code — confirm it is consumed elsewhere before relying on it.
    token_count: u32,
    /// Insert/promotion time; `get` expires the entry once this exceeds the TTL.
    created_at: std::time::Instant,
    /// Last access time; used as the LRU eviction key in `promote_to_hot`.
    accessed_at: std::time::Instant,
}
|
||||
|
||||
/// Two-tier response cache: in-memory LRU (hot) + SQLite (warm).
///
/// Lives alongside `brain.db` as `response_cache.db` so it can be
/// independently wiped without touching memories.
/// The hot cache avoids SQLite round-trips for frequently repeated prompts.
/// On miss from hot cache, falls through to SQLite. On hit from SQLite,
/// the entry is promoted to the hot cache.
pub struct ResponseCache {
    /// Warm-tier store; the connection is mutex-guarded for exclusive access.
    conn: Mutex<Connection>,
    /// Path of the backing database file; retained but not read in the
    /// visible code (hence the `dead_code` allow).
    #[allow(dead_code)]
    db_path: PathBuf,
    /// Entry time-to-live in minutes, applied to both tiers.
    ttl_minutes: i64,
    /// Warm-tier capacity bound. NOTE(review): the eviction site is not
    /// visible here — confirm it is enforced in `put`.
    max_entries: usize,
    /// Hot tier: prompt-hash → entry, evicted least-recently-accessed first.
    hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
    /// Hot-tier capacity; `0` disables the hot tier entirely.
    hot_max_entries: usize,
}
|
||||
|
||||
impl ResponseCache {
|
||||
/// Open (or create) the response cache database.
|
||||
pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
|
||||
Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
|
||||
}
|
||||
|
||||
/// Open (or create) the response cache database with a custom hot cache size.
|
||||
pub fn with_hot_cache(
|
||||
workspace_dir: &Path,
|
||||
ttl_minutes: u32,
|
||||
max_entries: usize,
|
||||
hot_max_entries: usize,
|
||||
) -> Result<Self> {
|
||||
let db_dir = workspace_dir.join("memory");
|
||||
std::fs::create_dir_all(&db_dir)?;
|
||||
let db_path = db_dir.join("response_cache.db");
|
||||
@ -58,6 +80,8 @@ impl ResponseCache {
|
||||
db_path,
|
||||
ttl_minutes: i64::from(ttl_minutes),
|
||||
max_entries,
|
||||
hot_cache: Mutex::new(HashMap::new()),
|
||||
hot_max_entries,
|
||||
})
|
||||
}
|
||||
|
||||
@ -76,35 +100,77 @@ impl ResponseCache {
|
||||
}
|
||||
|
||||
/// Look up a cached response. Returns `None` on miss or expired entry.
|
||||
///
|
||||
/// Two-tier lookup: checks the in-memory hot cache first, then falls
|
||||
/// through to SQLite. On a SQLite hit the entry is promoted to hot cache.
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<String> = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok();
|
||||
|
||||
if result.is_some() {
|
||||
// Bump hit count and accessed_at
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
// Tier 1: hot cache (with TTL check)
|
||||
{
|
||||
let mut hot = self.hot_cache.lock();
|
||||
if let Some(entry) = hot.get_mut(key) {
|
||||
let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
|
||||
if entry.created_at.elapsed() > ttl {
|
||||
hot.remove(key);
|
||||
} else {
|
||||
entry.accessed_at = std::time::Instant::now();
|
||||
let response = entry.response.clone();
|
||||
drop(hot);
|
||||
// Still bump SQLite hit count for accurate stats
|
||||
let conn = self.conn.lock();
|
||||
let now_str = Local::now().to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
return Ok(Some(response));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
// Tier 2: SQLite (warm)
|
||||
let result: Option<(String, u32)> = {
|
||||
let conn = self.conn.lock();
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT response, token_count FROM response_cache
|
||||
WHERE prompt_hash = ?1 AND created_at > ?2",
|
||||
)?;
|
||||
|
||||
let result: Option<(String, u32)> = stmt
|
||||
.query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.ok();
|
||||
|
||||
if result.is_some() {
|
||||
let now_str = now.to_rfc3339();
|
||||
conn.execute(
|
||||
"UPDATE response_cache
|
||||
SET accessed_at = ?1, hit_count = hit_count + 1
|
||||
WHERE prompt_hash = ?2",
|
||||
params![now_str, key],
|
||||
)?;
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
if let Some((ref response, token_count)) = result {
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
}
|
||||
|
||||
Ok(result.map(|(r, _)| r))
|
||||
}
|
||||
|
||||
/// Store a response in the cache.
|
||||
/// Store a response in the cache (both hot and warm tiers).
|
||||
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
||||
// Write to hot cache
|
||||
self.promote_to_hot(key, response, token_count);
|
||||
|
||||
// Write to SQLite (warm)
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now().to_rfc3339();
|
||||
@ -138,6 +204,43 @@ impl ResponseCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Promote an entry to the in-memory hot cache, evicting the oldest if full.
|
||||
fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
|
||||
let mut hot = self.hot_cache.lock();
|
||||
|
||||
// If already present, just update (keep original created_at for TTL)
|
||||
if let Some(entry) = hot.get_mut(key) {
|
||||
entry.response = response.to_string();
|
||||
entry.token_count = token_count;
|
||||
entry.accessed_at = std::time::Instant::now();
|
||||
return;
|
||||
}
|
||||
|
||||
// Evict oldest entry if at capacity
|
||||
if self.hot_max_entries > 0 && hot.len() >= self.hot_max_entries {
|
||||
if let Some(oldest_key) = hot
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.accessed_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
hot.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
|
||||
if self.hot_max_entries > 0 {
|
||||
let now = std::time::Instant::now();
|
||||
hot.insert(
|
||||
key.to_string(),
|
||||
InMemoryEntry {
|
||||
response: response.to_string(),
|
||||
token_count,
|
||||
created_at: now,
|
||||
accessed_at: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
||||
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
||||
let conn = self.conn.lock();
|
||||
@ -163,8 +266,8 @@ impl ResponseCache {
|
||||
|
||||
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
|
||||
pub fn clear(&self) -> Result<usize> {
|
||||
self.hot_cache.lock().clear();
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let affected = conn.execute("DELETE FROM response_cache", [])?;
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
@ -47,6 +47,15 @@ impl Observer for LogObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
info!("heartbeat.tick");
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
info!(cache_type = %cache_type, tokens_saved = tokens_saved, "cache.hit");
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
info!(cache_type = %cache_type, "cache.miss");
|
||||
}
|
||||
ObserverEvent::Error { component, message } => {
|
||||
info!(component = %component, error = %message, "error");
|
||||
}
|
||||
|
||||
@ -16,6 +16,9 @@ pub struct PrometheusObserver {
|
||||
channel_messages: IntCounterVec,
|
||||
heartbeat_ticks: prometheus::IntCounter,
|
||||
errors: IntCounterVec,
|
||||
cache_hits: IntCounterVec,
|
||||
cache_misses: IntCounterVec,
|
||||
cache_tokens_saved: IntCounterVec,
|
||||
|
||||
// Histograms
|
||||
agent_duration: HistogramVec,
|
||||
@ -81,6 +84,27 @@ impl PrometheusObserver {
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_hits = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_hits_total", "Total response cache hits"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_misses = IntCounterVec::new(
|
||||
prometheus::Opts::new("zeroclaw_cache_misses_total", "Total response cache misses"),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let cache_tokens_saved = IntCounterVec::new(
|
||||
prometheus::Opts::new(
|
||||
"zeroclaw_cache_tokens_saved_total",
|
||||
"Total tokens saved by response cache",
|
||||
),
|
||||
&["cache_type"],
|
||||
)
|
||||
.expect("valid metric");
|
||||
|
||||
let agent_duration = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"zeroclaw_agent_duration_seconds",
|
||||
@ -139,6 +163,9 @@ impl PrometheusObserver {
|
||||
registry.register(Box::new(channel_messages.clone())).ok();
|
||||
registry.register(Box::new(heartbeat_ticks.clone())).ok();
|
||||
registry.register(Box::new(errors.clone())).ok();
|
||||
registry.register(Box::new(cache_hits.clone())).ok();
|
||||
registry.register(Box::new(cache_misses.clone())).ok();
|
||||
registry.register(Box::new(cache_tokens_saved.clone())).ok();
|
||||
registry.register(Box::new(agent_duration.clone())).ok();
|
||||
registry.register(Box::new(tool_duration.clone())).ok();
|
||||
registry.register(Box::new(request_latency.clone())).ok();
|
||||
@ -156,6 +183,9 @@ impl PrometheusObserver {
|
||||
channel_messages,
|
||||
heartbeat_ticks,
|
||||
errors,
|
||||
cache_hits,
|
||||
cache_misses,
|
||||
cache_tokens_saved,
|
||||
agent_duration,
|
||||
tool_duration,
|
||||
request_latency,
|
||||
@ -245,6 +275,18 @@ impl Observer for PrometheusObserver {
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
self.heartbeat_ticks.inc();
|
||||
}
|
||||
ObserverEvent::CacheHit {
|
||||
cache_type,
|
||||
tokens_saved,
|
||||
} => {
|
||||
self.cache_hits.with_label_values(&[cache_type]).inc();
|
||||
self.cache_tokens_saved
|
||||
.with_label_values(&[cache_type])
|
||||
.inc_by(*tokens_saved);
|
||||
}
|
||||
ObserverEvent::CacheMiss { cache_type } => {
|
||||
self.cache_misses.with_label_values(&[cache_type]).inc();
|
||||
}
|
||||
ObserverEvent::Error {
|
||||
component,
|
||||
message: _,
|
||||
|
||||
@ -61,6 +61,18 @@ pub enum ObserverEvent {
|
||||
},
|
||||
/// Periodic heartbeat tick from the runtime keep-alive loop.
|
||||
HeartbeatTick,
|
||||
/// Response cache hit — an LLM call was avoided.
|
||||
CacheHit {
|
||||
/// Cache layer that was hit (the visible call sites pass `"response"`;
/// not the tier — `"hot"`/`"warm"` is not emitted here).
|
||||
cache_type: String,
|
||||
/// Estimated tokens saved by this cache hit.
|
||||
tokens_saved: u64,
|
||||
},
|
||||
/// Response cache miss — the prompt was not found in cache.
|
||||
CacheMiss {
|
||||
/// `"response"` cache layer that was checked.
|
||||
cache_type: String,
|
||||
},
|
||||
/// An error occurred in a named component.
|
||||
Error {
|
||||
/// Subsystem where the error originated (e.g., `"provider"`, `"gateway"`).
|
||||
|
||||
@ -402,6 +402,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||
response_cache_enabled: false,
|
||||
response_cache_ttl_minutes: 60,
|
||||
response_cache_max_entries: 5_000,
|
||||
response_cache_hot_entries: 256,
|
||||
snapshot_enabled: false,
|
||||
snapshot_on_hygiene: false,
|
||||
auto_hydrate: true,
|
||||
|
||||
@ -149,6 +149,10 @@ struct AnthropicUsage {
|
||||
input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
output_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -475,6 +479,7 @@ impl AnthropicProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: u.cache_read_input_tokens,
|
||||
});
|
||||
|
||||
for block in response.content {
|
||||
@ -614,6 +619,7 @@ impl Provider for AnthropicProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -312,6 +312,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -431,6 +432,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -491,6 +493,7 @@ impl Provider for AzureOpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -832,6 +832,7 @@ impl BedrockProvider {
|
||||
let usage = response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
if let Some(output) = response.output {
|
||||
@ -967,6 +968,7 @@ impl Provider for BedrockProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1193,6 +1193,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
crate::providers::traits::ProviderCapabilities {
|
||||
native_tool_calling: self.native_tool_calling,
|
||||
vision: self.supports_vision,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -1514,6 +1515,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = chat_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = chat_response
|
||||
.choices
|
||||
@ -1657,6 +1659,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -353,6 +353,7 @@ impl CopilotProvider {
|
||||
let usage = api_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let choice = api_response
|
||||
.choices
|
||||
|
||||
@ -1128,6 +1128,7 @@ impl GeminiProvider {
|
||||
let usage = result.usage_metadata.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_token_count,
|
||||
output_tokens: u.candidates_token_count,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
|
||||
let text = result
|
||||
|
||||
@ -632,6 +632,7 @@ impl Provider for OllamaProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -764,6 +765,7 @@ impl Provider for OllamaProvider {
|
||||
Some(TokenUsage {
|
||||
input_tokens: response.prompt_eval_count,
|
||||
output_tokens: response.eval_count,
|
||||
cached_input_tokens: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
||||
@ -135,6 +135,14 @@ struct UsageInfo {
|
||||
prompt_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
completion_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
prompt_tokens_details: Option<PromptTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PromptTokensDetails {
|
||||
#[serde(default)]
|
||||
cached_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -385,6 +393,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -448,6 +457,7 @@ impl Provider for OpenAiProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -640,6 +640,7 @@ impl Provider for OpenAiCodexProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -306,6 +306,7 @@ impl Provider for OpenRouterProvider {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -463,6 +464,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
@ -554,6 +556,7 @@ impl Provider for OpenRouterProvider {
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
cached_input_tokens: None,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
|
||||
@ -54,6 +54,9 @@ pub struct ToolCall {
|
||||
pub struct TokenUsage {
|
||||
pub input_tokens: Option<u64>,
|
||||
pub output_tokens: Option<u64>,
|
||||
/// Tokens served from the provider's prompt cache (Anthropic `cache_read_input_tokens`,
|
||||
/// OpenAI `prompt_tokens_details.cached_tokens`).
|
||||
pub cached_input_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// An LLM response that may contain text, tool calls, or both.
|
||||
@ -233,6 +236,9 @@ pub struct ProviderCapabilities {
|
||||
pub native_tool_calling: bool,
|
||||
/// Whether the provider supports vision / image inputs.
|
||||
pub vision: bool,
|
||||
/// Whether the provider supports prompt caching (Anthropic cache_control,
|
||||
/// OpenAI automatic prompt caching).
|
||||
pub prompt_caching: bool,
|
||||
}
|
||||
|
||||
/// Provider-specific tool payload formats.
|
||||
@ -498,6 +504,7 @@ mod tests {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
prompt_caching: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -568,6 +575,7 @@ mod tests {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(100),
|
||||
output_tokens: Some(50),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
};
|
||||
@ -613,14 +621,17 @@ mod tests {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: false,
|
||||
prompt_caching: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
|
||||
@ -166,6 +166,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
}),
|
||||
@ -188,6 +189,7 @@ impl Provider for TraceLlmProvider {
|
||||
usage: Some(TokenUsage {
|
||||
input_tokens: Some(input_tokens),
|
||||
output_tokens: Some(output_tokens),
|
||||
cached_input_tokens: None,
|
||||
}),
|
||||
reasoning_content: None,
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user