feat(cost): enforce preflight budget policy in agent loop

This commit is contained in:
argenis de la rosa 2026-02-28 23:16:27 -05:00 committed by Argenis
parent 9ef617289f
commit 4043056332
3 changed files with 501 additions and 64 deletions

View File

@ -1,5 +1,7 @@
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
use crate::config::schema::{CostEnforcementMode, ModelPricing};
use crate::config::{Config, ProgressMode};
use crate::cost::{BudgetCheck, CostTracker, UsagePeriod};
use crate::memory::{self, Memory, MemoryCategory};
use crate::multimodal;
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
@ -19,9 +21,11 @@ use rustyline::hint::Hinter;
use rustyline::validate::Validator;
use rustyline::{CompletionType, Config as RlConfig, Context, Editor, Helper};
use std::borrow::Cow;
use std::collections::{BTreeSet, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Write;
use std::future::Future;
use std::io::Write as _;
use std::path::Path;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@ -297,6 +301,7 @@ tokio::task_local! {
static LOOP_DETECTION_CONFIG: LoopDetectionConfig;
static SAFETY_HEARTBEAT_CONFIG: Option<SafetyHeartbeatConfig>;
static TOOL_LOOP_PROGRESS_MODE: ProgressMode;
static TOOL_LOOP_COST_ENFORCEMENT_CONTEXT: Option<CostEnforcementContext>;
}
/// Configuration for periodic safety-constraint re-injection (heartbeat).
@ -308,6 +313,56 @@ pub(crate) struct SafetyHeartbeatConfig {
pub interval: usize,
}
#[derive(Clone)]
pub(crate) struct CostEnforcementContext {
tracker: Arc<CostTracker>,
prices: HashMap<String, ModelPricing>,
mode: CostEnforcementMode,
route_down_model: Option<String>,
reserve_percent: u8,
}
pub(crate) fn create_cost_enforcement_context(
cost_config: &crate::config::CostConfig,
workspace_dir: &Path,
) -> Option<CostEnforcementContext> {
if !cost_config.enabled {
return None;
}
let tracker = match CostTracker::new(cost_config.clone(), workspace_dir) {
Ok(tracker) => Arc::new(tracker),
Err(error) => {
tracing::warn!("Cost budget preflight disabled: failed to initialize tracker: {error}");
return None;
}
};
let route_down_model = cost_config
.enforcement
.route_down_model
.clone()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty());
Some(CostEnforcementContext {
tracker,
prices: cost_config.prices.clone(),
mode: cost_config.enforcement.mode,
route_down_model,
reserve_percent: cost_config.enforcement.reserve_percent.min(100),
})
}
pub(crate) async fn scope_cost_enforcement_context<F>(
context: Option<CostEnforcementContext>,
future: F,
) -> F::Output
where
F: Future,
{
TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
.scope(context, future)
.await
}
fn should_inject_safety_heartbeat(counter: usize, interval: usize) -> bool {
interval > 0 && counter > 0 && counter % interval == 0
}
@ -320,6 +375,100 @@ fn should_emit_tool_progress(mode: ProgressMode) -> bool {
mode != ProgressMode::Off
}
fn estimate_prompt_tokens(
messages: &[ChatMessage],
tools: Option<&[crate::tools::ToolSpec]>,
) -> u64 {
let message_chars: usize = messages
.iter()
.map(|msg| {
msg.role
.len()
.saturating_add(msg.content.chars().count())
.saturating_add(16)
})
.sum();
let tool_chars: usize = tools
.map(|specs| {
specs
.iter()
.map(|spec| serde_json::to_string(spec).map_or(0, |value| value.chars().count()))
.sum()
})
.unwrap_or(0);
let total_chars = message_chars.saturating_add(tool_chars);
let char_estimate = (total_chars as f64 / 4.0).ceil() as u64;
let framing_overhead = (messages.len() as u64).saturating_mul(6).saturating_add(64);
char_estimate.saturating_add(framing_overhead)
}
fn lookup_model_pricing(
prices: &HashMap<String, ModelPricing>,
provider: &str,
model: &str,
) -> (f64, f64) {
let full_name = format!("{provider}/{model}");
if let Some(pricing) = prices.get(&full_name) {
return (pricing.input, pricing.output);
}
if let Some(pricing) = prices.get(model) {
return (pricing.input, pricing.output);
}
for (key, pricing) in prices {
let key_model = key.split('/').next_back().unwrap_or(key);
if model.starts_with(key_model) || key_model.starts_with(model) {
return (pricing.input, pricing.output);
}
let normalized_model = model.replace('-', ".");
let normalized_key = key_model.replace('-', ".");
if normalized_model.contains(&normalized_key) || normalized_key.contains(&normalized_model)
{
return (pricing.input, pricing.output);
}
}
(3.0, 15.0)
}
fn estimate_request_cost_usd(
context: &CostEnforcementContext,
provider: &str,
model: &str,
messages: &[ChatMessage],
tools: Option<&[crate::tools::ToolSpec]>,
) -> f64 {
let reserve_multiplier = 1.0 + (f64::from(context.reserve_percent) / 100.0);
let input_tokens = estimate_prompt_tokens(messages, tools);
let output_tokens = (input_tokens / 4).max(256);
let input_tokens = ((input_tokens as f64) * reserve_multiplier).ceil() as u64;
let output_tokens = ((output_tokens as f64) * reserve_multiplier).ceil() as u64;
let (input_price, output_price) = lookup_model_pricing(&context.prices, provider, model);
let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price.max(0.0);
let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price.max(0.0);
input_cost + output_cost
}
fn usage_period_label(period: UsagePeriod) -> &'static str {
match period {
UsagePeriod::Session => "session",
UsagePeriod::Day => "daily",
UsagePeriod::Month => "monthly",
}
}
fn budget_exceeded_message(
model: &str,
estimated_cost_usd: f64,
current_usd: f64,
limit_usd: f64,
period: UsagePeriod,
) -> String {
format!(
"Budget enforcement blocked request for model '{model}': projected cost (+${estimated_cost_usd:.4}) exceeds {period_label} limit (${limit_usd:.2}, current ${current_usd:.2}).",
period_label = usage_period_label(period)
)
}
#[derive(Debug, Clone)]
struct ProgressEntry {
name: String,
@ -894,7 +1043,12 @@ pub(crate) async fn run_tool_call_loop(
let progress_mode = TOOL_LOOP_PROGRESS_MODE
.try_with(|mode| *mode)
.unwrap_or(ProgressMode::Verbose);
let cost_enforcement_context = TOOL_LOOP_COST_ENFORCEMENT_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten();
let mut progress_tracker = ProgressTracker::default();
let mut active_model = model.to_string();
let bypass_non_cli_approval_for_turn =
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
if bypass_non_cli_approval_for_turn {
@ -902,7 +1056,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_one_time_all_tools_consumed",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("consumed one-time non-cli allow-all approval token"),
@ -954,6 +1108,13 @@ pub(crate) async fn run_tool_call_loop(
request_messages.push(ChatMessage::user(reminder));
}
}
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
Some(tool_specs.as_slice())
} else {
None
};
// ── Progress: LLM thinking ────────────────────────────
if should_emit_verbose_progress(progress_mode) {
@ -967,16 +1128,175 @@ pub(crate) async fn run_tool_call_loop(
}
}
if let Some(cost_ctx) = cost_enforcement_context.as_ref() {
let mut estimated_cost_usd = estimate_request_cost_usd(
cost_ctx,
provider_name,
active_model.as_str(),
&request_messages,
request_tools,
);
let mut budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
Ok(check) => Some(check),
Err(error) => {
tracing::warn!("Cost preflight check failed: {error}");
None
}
};
if matches!(cost_ctx.mode, CostEnforcementMode::RouteDown)
&& matches!(budget_check, Some(BudgetCheck::Exceeded { .. }))
{
if let Some(route_down_model) = cost_ctx
.route_down_model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
{
if route_down_model != active_model {
let previous_model = active_model.clone();
active_model = route_down_model.to_string();
estimated_cost_usd = estimate_request_cost_usd(
cost_ctx,
provider_name,
active_model.as_str(),
&request_messages,
request_tools,
);
budget_check = match cost_ctx.tracker.check_budget(estimated_cost_usd) {
Ok(check) => Some(check),
Err(error) => {
tracing::warn!(
"Cost preflight check failed after route-down: {error}"
);
None
}
};
runtime_trace::record_event(
"cost_budget_route_down",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget exceeded on primary model; route-down candidate applied"),
serde_json::json!({
"iteration": iteration + 1,
"from_model": previous_model,
"to_model": active_model,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
}
}
if let Some(check) = budget_check {
match check {
BudgetCheck::Allowed => {}
BudgetCheck::Warning {
current_usd,
limit_usd,
period,
} => {
tracing::warn!(
model = active_model.as_str(),
period = usage_period_label(period),
current_usd,
limit_usd,
estimated_cost_usd,
"Cost budget warning threshold reached"
);
runtime_trace::record_event(
"cost_budget_warning",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget warning threshold reached"),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
BudgetCheck::Exceeded {
current_usd,
limit_usd,
period,
} => match cost_ctx.mode {
CostEnforcementMode::Warn => {
tracing::warn!(
model = active_model.as_str(),
period = usage_period_label(period),
current_usd,
limit_usd,
estimated_cost_usd,
"Cost budget exceeded (warn mode): continuing request"
);
runtime_trace::record_event(
"cost_budget_exceeded_warn_mode",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("budget exceeded but proceeding due to warn mode"),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
}
CostEnforcementMode::RouteDown | CostEnforcementMode::Block => {
let message = budget_exceeded_message(
active_model.as_str(),
estimated_cost_usd,
current_usd,
limit_usd,
period,
);
runtime_trace::record_event(
"cost_budget_blocked",
Some(channel_name),
Some(provider_name),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&message),
serde_json::json!({
"iteration": iteration + 1,
"period": usage_period_label(period),
"current_usd": current_usd,
"limit_usd": limit_usd,
"estimated_cost_usd": estimated_cost_usd,
}),
);
return Err(anyhow::anyhow!(message));
}
},
}
}
}
observer.record_event(&ObserverEvent::LlmRequest {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
messages_count: history.len(),
});
runtime_trace::record_event(
"llm_request",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@ -990,23 +1310,15 @@ pub(crate) async fn run_tool_call_loop(
// Fire void hook before LLM call
if let Some(hooks) = hooks {
hooks.fire_llm_input(history, model).await;
hooks.fire_llm_input(history, active_model.as_str()).await;
}
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
Some(tool_specs.as_slice())
} else {
None
};
let chat_future = provider.chat(
ChatRequest {
messages: &request_messages,
tools: request_tools,
},
model,
active_model.as_str(),
temperature,
);
@ -1036,7 +1348,7 @@ pub(crate) async fn run_tool_call_loop(
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
@ -1066,7 +1378,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_parse_issue",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(parse_issue),
@ -1084,7 +1396,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@ -1135,7 +1447,7 @@ pub(crate) async fn run_tool_call_loop(
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
model: active_model.clone(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(safe_error.clone()),
@ -1146,7 +1458,7 @@ pub(crate) async fn run_tool_call_loop(
"llm_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&safe_error),
@ -1199,7 +1511,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_followthrough_retry",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("llm response implied follow-up action but emitted no tool call"),
@ -1227,7 +1539,7 @@ pub(crate) async fn run_tool_call_loop(
"turn_final_response",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
None,
@ -1303,7 +1615,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&cancelled),
@ -1345,7 +1657,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&blocked),
@ -1385,7 +1697,7 @@ pub(crate) async fn run_tool_call_loop(
"approval_bypass_non_cli_session_grant",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(true),
Some("using runtime non-cli session approval grant"),
@ -1442,7 +1754,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&denied),
@ -1476,7 +1788,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some(&duplicate),
@ -1504,7 +1816,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_start",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
None,
None,
@ -1564,7 +1876,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_call_result",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(outcome.success),
outcome.error_reason.as_deref(),
@ -1676,7 +1988,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_warning",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop pattern detected, injecting self-correction prompt"),
@ -1698,7 +2010,7 @@ pub(crate) async fn run_tool_call_loop(
"loop_detected_hard_stop",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("loop persisted after warning, stopping early"),
@ -1718,7 +2030,7 @@ pub(crate) async fn run_tool_call_loop(
"tool_loop_exhausted",
Some(channel_name),
Some(provider_name),
Some(model),
Some(active_model.as_str()),
Some(&turn_id),
Some(false),
Some("agent exceeded maximum tool iterations"),
@ -2151,6 +2463,8 @@ pub async fn run(
// ── Execute ──────────────────────────────────────────────────
let start = Instant::now();
let cost_enforcement_context =
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let mut final_output = String::new();
@ -2197,8 +2511,9 @@ pub async fn run(
} else {
None
};
let response = SAFETY_HEARTBEAT_CONFIG
.scope(
let response = scope_cost_enforcement_context(
cost_enforcement_context.clone(),
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@ -2221,8 +2536,9 @@ pub async fn run(
&[],
),
),
)
.await?;
),
)
.await?;
final_output = response.clone();
if config.memory.auto_save && response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
let assistant_key = autosave_memory_key("assistant_resp");
@ -2374,8 +2690,9 @@ pub async fn run(
} else {
None
};
let response = match SAFETY_HEARTBEAT_CONFIG
.scope(
let response = match scope_cost_enforcement_context(
cost_enforcement_context.clone(),
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
LOOP_DETECTION_CONFIG.scope(
ld_cfg,
@ -2398,8 +2715,9 @@ pub async fn run(
&[],
),
),
)
.await
),
)
.await
{
Ok(resp) => resp,
Err(e) => {
@ -2682,6 +3000,8 @@ pub async fn process_message_with_session(
ChatMessage::user(&enriched),
];
let cost_enforcement_context =
create_cost_enforcement_context(&config.cost, &config.workspace_dir);
let hb_cfg = if config.agent.safety_heartbeat_interval > 0 {
Some(SafetyHeartbeatConfig {
body: security.summary_for_heartbeat(),
@ -2690,8 +3010,9 @@ pub async fn process_message_with_session(
} else {
None
};
SAFETY_HEARTBEAT_CONFIG
.scope(
scope_cost_enforcement_context(
cost_enforcement_context,
SAFETY_HEARTBEAT_CONFIG.scope(
hb_cfg,
agent_turn(
provider.as_ref(),
@ -2705,8 +3026,9 @@ pub async fn process_message_with_session(
&config.multimodal,
config.agent.max_tool_iterations,
),
)
.await
),
)
.await
}
#[cfg(test)]

View File

@ -250,6 +250,7 @@ struct ChannelRuntimeDefaults {
api_key: Option<String>,
api_url: Option<String>,
reliability: crate::config::ReliabilityConfig,
cost: crate::config::CostConfig,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -1054,6 +1055,7 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
api_key: config.api_key.clone(),
api_url: config.api_url.clone(),
reliability: config.reliability.clone(),
cost: config.cost.clone(),
}
}
@ -1099,6 +1101,7 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
api_key: ctx.api_key.clone(),
api_url: ctx.api_url.clone(),
reliability: (*ctx.reliability).clone(),
cost: crate::config::CostConfig::default(),
}
}
@ -3665,6 +3668,10 @@ or tune thresholds in config.",
let timeout_budget_secs =
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
&runtime_defaults.cost,
ctx.workspace_dir.as_path(),
);
let (approval_prompt_tx, mut approval_prompt_rx) =
tokio::sync::mpsc::unbounded_channel::<NonCliApprovalPrompt>();
@ -3706,31 +3713,33 @@ or tune thresholds in config.",
} else {
None
};
let llm_result = tokio::select! {
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
result = tokio::time::timeout(
Duration::from_secs(timeout_budget_secs),
run_tool_call_loop_with_non_cli_approval_context(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
ctx.observer.as_ref(),
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(ctx.approval_manager.as_ref()),
msg.channel.as_str(),
non_cli_approval_context,
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
&excluded_tools_snapshot,
progress_mode,
ctx.safety_heartbeat.clone(),
crate::agent::loop_::scope_cost_enforcement_context(
cost_enforcement_context,
run_tool_call_loop_with_non_cli_approval_context(
active_provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
ctx.observer.as_ref(),
route.provider.as_str(),
route.model.as_str(),
runtime_defaults.temperature,
true,
Some(ctx.approval_manager.as_ref()),
msg.channel.as_str(),
non_cli_approval_context,
&ctx.multimodal,
ctx.max_tool_iterations,
Some(cancellation_token.clone()),
delta_tx,
ctx.hooks.as_deref(),
&excluded_tools_snapshot,
progress_mode,
ctx.safety_heartbeat.clone(),
),
),
) => LlmExecutionResult::Completed(result),
};
@ -9401,6 +9410,7 @@ BTC is currently around $65,000 based on latest tool output."#
api_key: None,
api_url: None,
reliability: crate::config::ReliabilityConfig::default(),
cost: crate::config::CostConfig::default(),
},
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),

View File

@ -1200,6 +1200,58 @@ pub struct CostConfig {
/// Per-model pricing (USD per 1M tokens)
#[serde(default)]
pub prices: std::collections::HashMap<String, ModelPricing>,
/// Runtime budget enforcement policy (`[cost.enforcement]`).
#[serde(default)]
pub enforcement: CostEnforcementConfig,
}
/// Budget enforcement behavior when projected spend approaches/exceeds limits.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CostEnforcementMode {
/// Log warnings only; never block the request.
Warn,
/// Attempt one downgrade to a cheaper route/model, then block if still over budget.
RouteDown,
/// Block immediately when projected spend exceeds configured limits.
Block,
}
fn default_cost_enforcement_mode() -> CostEnforcementMode {
CostEnforcementMode::Warn
}
/// Runtime budget enforcement controls (`[cost.enforcement]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct CostEnforcementConfig {
/// Enforcement behavior. Default: `warn`.
#[serde(default = "default_cost_enforcement_mode")]
pub mode: CostEnforcementMode,
/// Optional fallback model (or `hint:*`) when `mode = "route_down"`.
#[serde(default = "default_route_down_model")]
pub route_down_model: Option<String>,
/// Extra reserve added to token/cost estimates (percentage, 0-100). Default: `10`.
#[serde(default = "default_cost_reserve_percent")]
pub reserve_percent: u8,
}
fn default_route_down_model() -> Option<String> {
Some("hint:fast".to_string())
}
fn default_cost_reserve_percent() -> u8 {
10
}
impl Default for CostEnforcementConfig {
fn default() -> Self {
Self {
mode: default_cost_enforcement_mode(),
route_down_model: default_route_down_model(),
reserve_percent: default_cost_reserve_percent(),
}
}
}
/// Per-model pricing entry (USD per 1M tokens).
@ -1235,6 +1287,7 @@ impl Default for CostConfig {
warn_at_percent: default_warn_percent(),
allow_override: false,
prices: get_default_pricing(),
enforcement: CostEnforcementConfig::default(),
}
}
}
@ -7769,6 +7822,14 @@ impl Config {
anyhow::bail!("web_search.timeout_secs must be greater than 0");
}
// Cost
if self.cost.warn_at_percent > 100 {
anyhow::bail!("cost.warn_at_percent must be between 0 and 100");
}
if self.cost.enforcement.reserve_percent > 100 {
anyhow::bail!("cost.enforcement.reserve_percent must be between 0 and 100");
}
// Scheduler
if self.scheduler.max_concurrent == 0 {
anyhow::bail!("scheduler.max_concurrent must be greater than 0");
@ -13743,4 +13804,48 @@ sensitivity = 0.9
.validate()
.expect("disabled coordination should allow empty lead agent");
}
#[test]
async fn cost_enforcement_defaults_are_stable() {
let cost = CostConfig::default();
assert_eq!(cost.enforcement.mode, CostEnforcementMode::Warn);
assert_eq!(
cost.enforcement.route_down_model.as_deref(),
Some("hint:fast")
);
assert_eq!(cost.enforcement.reserve_percent, 10);
}
#[test]
async fn cost_enforcement_config_parses_route_down_mode() {
let parsed: CostConfig = toml::from_str(
r#"
enabled = true
[enforcement]
mode = "route_down"
route_down_model = "hint:fast"
reserve_percent = 15
"#,
)
.expect("cost enforcement should parse");
assert!(parsed.enabled);
assert_eq!(parsed.enforcement.mode, CostEnforcementMode::RouteDown);
assert_eq!(
parsed.enforcement.route_down_model.as_deref(),
Some("hint:fast")
);
assert_eq!(parsed.enforcement.reserve_percent, 15);
}
#[test]
async fn validation_rejects_cost_enforcement_reserve_over_100() {
let mut config = Config::default();
config.cost.enforcement.reserve_percent = 150;
let err = config
.validate()
.expect_err("expected cost.enforcement.reserve_percent validation failure");
assert!(err.to_string().contains("cost.enforcement.reserve_percent"));
}
}