feat(tools): wire auth_profile + quota tools into agent loop and persist switch_provider
- Register 4 new tools (ManageAuthProfileTool, CheckProviderQuotaTool, SwitchProviderTool, EstimateQuotaCostTool) in all_tools_with_runtime - SwitchProviderTool now loads config from disk and calls save() to persist default_provider/default_model to config.toml - Inject Provider & Budget Context section into system prompt when Config is available - Remove emoji from tool output for cleaner parsing - Replace format! push_str with std::fmt::Write for consistency Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8c0be20422
commit
d5fe47acff
310
src/tools/auth_profile.rs
Normal file
310
src/tools/auth_profile.rs
Normal file
@ -0,0 +1,310 @@
|
||||
//! Tool for managing auth profiles (list, switch, refresh).
|
||||
//!
|
||||
//! Allows the agent to:
|
||||
//! - List all configured auth profiles with expiry status
|
||||
//! - Switch active profile for a provider
|
||||
//! - Refresh OAuth tokens that are expired or expiring
|
||||
|
||||
use crate::auth::{normalize_provider, AuthService};
|
||||
use crate::config::Config;
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use std::fmt::Write as _;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ManageAuthProfileTool {
|
||||
config: Arc<Config>,
|
||||
}
|
||||
|
||||
impl ManageAuthProfileTool {
|
||||
pub fn new(config: Arc<Config>) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
fn auth_service(&self) -> AuthService {
|
||||
AuthService::from_config(&self.config)
|
||||
}
|
||||
|
||||
async fn handle_list(&self, provider_filter: Option<&str>) -> Result<ToolResult> {
|
||||
let auth = self.auth_service();
|
||||
let data = auth.load_profiles().await?;
|
||||
|
||||
let mut output = String::new();
|
||||
let _ = writeln!(output, "## Auth Profiles\n");
|
||||
|
||||
let mut count = 0u32;
|
||||
for (id, profile) in &data.profiles {
|
||||
if let Some(filter) = provider_filter {
|
||||
let normalized =
|
||||
normalize_provider(filter).unwrap_or_else(|_| filter.to_string());
|
||||
if profile.provider != normalized {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
count += 1;
|
||||
let is_active = data
|
||||
.active_profiles
|
||||
.get(&profile.provider)
|
||||
.map_or(false, |active| active == id);
|
||||
|
||||
let active_marker = if is_active { " [ACTIVE]" } else { "" };
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"- **{}** ({}){active_marker}",
|
||||
profile.profile_name, profile.provider
|
||||
);
|
||||
|
||||
if let Some(ref acct) = profile.account_id {
|
||||
let _ = writeln!(output, " Account: {acct}");
|
||||
}
|
||||
|
||||
let _ = writeln!(output, " Type: {:?}", profile.kind);
|
||||
|
||||
if let Some(ref ts) = profile.token_set {
|
||||
if let Some(expires) = ts.expires_at {
|
||||
let now = chrono::Utc::now();
|
||||
if expires < now {
|
||||
let ago = now.signed_duration_since(expires);
|
||||
let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours());
|
||||
} else {
|
||||
let left = expires.signed_duration_since(now);
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" Token: valid (expires in {}h {}m)",
|
||||
left.num_hours(),
|
||||
left.num_minutes() % 60
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let _ = writeln!(output, " Token: no expiry set");
|
||||
}
|
||||
let has_refresh = ts.refresh_token.is_some();
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" Refresh token: {}",
|
||||
if has_refresh { "yes" } else { "no" }
|
||||
);
|
||||
} else if profile.token.is_some() {
|
||||
let _ = writeln!(output, " Token: API key (no expiry)");
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
if provider_filter.is_some() {
|
||||
let _ = writeln!(output, "No profiles found for the specified provider.");
|
||||
} else {
|
||||
let _ = writeln!(output, "No auth profiles configured.");
|
||||
}
|
||||
} else {
|
||||
let _ = writeln!(output, "\nTotal: {count} profile(s)");
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_switch(&self, provider: &str, profile_name: &str) -> Result<ToolResult> {
|
||||
let auth = self.auth_service();
|
||||
let profile_id = auth.set_active_profile(provider, profile_name).await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Switched active profile for {provider} to: {profile_id}"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_refresh(&self, provider: &str) -> Result<ToolResult> {
|
||||
let normalized = normalize_provider(provider)?;
|
||||
let auth = self.auth_service();
|
||||
|
||||
let result = match normalized.as_str() {
|
||||
"openai-codex" => match auth.get_valid_openai_access_token(None).await {
|
||||
Ok(Some(_)) => "OpenAI Codex token refreshed successfully.".to_string(),
|
||||
Ok(None) => "No OpenAI Codex profile found to refresh.".to_string(),
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("OpenAI token refresh failed: {e}")),
|
||||
})
|
||||
}
|
||||
},
|
||||
"gemini" => match auth.get_valid_gemini_access_token(None).await {
|
||||
Ok(Some(_)) => "Gemini token refreshed successfully.".to_string(),
|
||||
Ok(None) => "No Gemini profile found to refresh.".to_string(),
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Gemini token refresh failed: {e}")),
|
||||
})
|
||||
}
|
||||
},
|
||||
other => {
|
||||
// For non-OAuth providers, just verify the token exists
|
||||
match auth.get_provider_bearer_token(other, None).await {
|
||||
Ok(Some(_)) => format!("Provider '{other}' uses API key auth (no refresh needed). Token is present."),
|
||||
Ok(None) => format!("No profile found for provider '{other}'."),
|
||||
Err(e) => return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Token check failed for '{other}': {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: result,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ManageAuthProfileTool {
|
||||
fn name(&self) -> &str {
|
||||
"manage_auth_profile"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Manage auth profiles: list all profiles with token status, switch active profile \
|
||||
for a provider, or refresh expired OAuth tokens. Use when user asks about accounts, \
|
||||
tokens, or when you encounter expired/rate-limited credentials."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list", "switch", "refresh"],
|
||||
"description": "Action to perform: 'list' shows all profiles, 'switch' changes active profile, 'refresh' renews OAuth tokens"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Provider name (e.g., 'gemini', 'openai-codex', 'anthropic'). Required for switch and refresh."
|
||||
},
|
||||
"profile": {
|
||||
"type": "string",
|
||||
"description": "Profile name to switch to (for 'switch' action). E.g., 'default', 'work', 'personal'."
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let action = args
|
||||
.get("action")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("list");
|
||||
|
||||
let provider = args.get("provider").and_then(|v| v.as_str());
|
||||
|
||||
let result = match action {
|
||||
"list" => self.handle_list(provider).await,
|
||||
"switch" => {
|
||||
let Some(provider) = provider else {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'provider' is required for switch action".into()),
|
||||
});
|
||||
};
|
||||
let profile = args
|
||||
.get("profile")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
self.handle_switch(provider, profile).await
|
||||
}
|
||||
"refresh" => {
|
||||
let Some(provider) = provider else {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'provider' is required for refresh action".into()),
|
||||
});
|
||||
};
|
||||
self.handle_refresh(provider).await
|
||||
}
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action '{other}'. Valid: list, switch, refresh"
|
||||
)),
|
||||
}),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(outcome) => Ok(outcome),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_manage_auth_profile_schema() {
|
||||
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["action"]["enum"].is_array());
|
||||
assert_eq!(tool.name(), "manage_auth_profile");
|
||||
assert!(tool.description().contains("auth profiles"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_empty_profiles() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config = Config {
|
||||
workspace_dir: tmp.path().to_path_buf(),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
};
|
||||
let tool = ManageAuthProfileTool::new(Arc::new(config));
|
||||
let result = tool.execute(json!({"action": "list"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Auth Profiles"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_switch_missing_provider() {
|
||||
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
|
||||
let result = tool.execute(json!({"action": "switch"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("provider"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_missing_provider() {
|
||||
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
|
||||
let result = tool.execute(json!({"action": "refresh"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("provider"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unknown_action() {
|
||||
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
|
||||
let result = tool.execute(json!({"action": "delete"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Unknown action"));
|
||||
}
|
||||
}
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
pub mod agents_ipc;
|
||||
pub mod apply_patch;
|
||||
pub mod auth_profile;
|
||||
pub mod browser;
|
||||
pub mod browser_open;
|
||||
pub mod cli_discovery;
|
||||
@ -58,6 +59,7 @@ pub mod pdf_read;
|
||||
pub mod process;
|
||||
pub mod proxy_config;
|
||||
pub mod pushover;
|
||||
pub mod quota_tools;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
@ -134,6 +136,9 @@ pub use web_fetch::WebFetchTool;
|
||||
pub use web_search_config::WebSearchConfigTool;
|
||||
pub use web_search_tool::WebSearchTool;
|
||||
|
||||
pub use auth_profile::ManageAuthProfileTool;
|
||||
pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool};
|
||||
|
||||
use crate::config::{Config, DelegateAgentConfig};
|
||||
use crate::memory::Memory;
|
||||
use crate::runtime::{NativeRuntime, RuntimeAdapter};
|
||||
@ -290,6 +295,10 @@ pub fn all_tools_with_runtime(
|
||||
Arc::new(ProxyConfigTool::new(config.clone(), security.clone())),
|
||||
Arc::new(WebAccessConfigTool::new(config.clone(), security.clone())),
|
||||
Arc::new(WebSearchConfigTool::new(config.clone(), security.clone())),
|
||||
Arc::new(ManageAuthProfileTool::new(config.clone())),
|
||||
Arc::new(CheckProviderQuotaTool::new(config.clone())),
|
||||
Arc::new(SwitchProviderTool::new(config.clone())),
|
||||
Arc::new(EstimateQuotaCostTool),
|
||||
Arc::new(PushoverTool::new(
|
||||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
|
||||
562
src/tools/quota_tools.rs
Normal file
562
src/tools/quota_tools.rs
Normal file
@ -0,0 +1,562 @@
|
||||
//! Built-in tools for quota monitoring and provider management.
|
||||
//!
|
||||
//! These tools allow the agent to:
|
||||
//! - Check quota status conversationally
|
||||
//! - Switch providers when rate limited
|
||||
//! - Estimate quota costs before operations
|
||||
//! - Report usage metrics to the user
|
||||
|
||||
use crate::auth::profiles::AuthProfilesStore;
|
||||
use crate::config::Config;
|
||||
use crate::cost::tracker::CostTracker;
|
||||
use crate::providers::health::ProviderHealthTracker;
|
||||
use crate::providers::quota_types::{QuotaStatus, QuotaSummary};
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write as _;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Tool for checking provider quota status.
|
||||
///
|
||||
/// Allows agent to query: "какие модели доступны?" or "what providers have quota?"
|
||||
pub struct CheckProviderQuotaTool {
|
||||
config: Arc<Config>,
|
||||
cost_tracker: Option<Arc<CostTracker>>,
|
||||
}
|
||||
|
||||
impl CheckProviderQuotaTool {
|
||||
pub fn new(config: Arc<Config>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
cost_tracker: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_cost_tracker(mut self, tracker: Arc<CostTracker>) -> Self {
|
||||
self.cost_tracker = Some(tracker);
|
||||
self
|
||||
}
|
||||
|
||||
async fn build_quota_summary(&self, provider_filter: Option<&str>) -> Result<QuotaSummary> {
|
||||
// Initialize health tracker with same settings as reliable.rs
|
||||
let health_tracker = ProviderHealthTracker::new(
|
||||
3, // failure_threshold
|
||||
Duration::from_secs(60), // cooldown
|
||||
100, // max tracked providers
|
||||
);
|
||||
|
||||
// Load OAuth profiles (state_dir = config dir parent, where auth-profiles.json lives)
|
||||
let state_dir = crate::auth::state_dir_from_config(&self.config);
|
||||
let auth_store = AuthProfilesStore::new(&state_dir, self.config.secrets.encrypt);
|
||||
let profiles_data = auth_store.load().await?;
|
||||
|
||||
// Build quota summary using quota_cli logic
|
||||
crate::providers::quota_cli::build_quota_summary(
|
||||
&health_tracker,
|
||||
&profiles_data,
|
||||
provider_filter,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CheckProviderQuotaTool {
|
||||
fn name(&self) -> &str {
|
||||
"check_provider_quota"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Check current rate limit and quota status for AI providers. \
|
||||
Returns available providers, rate-limited providers, quota remaining, \
|
||||
and estimated reset time. Use this when user asks about model availability \
|
||||
or when you encounter rate limit errors."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Specific provider to check (optional). Examples: openai, gemini, anthropic. If omitted, checks all providers."
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
|
||||
use std::fmt::Write;
|
||||
let provider_filter = args.get("provider").and_then(|v| v.as_str());
|
||||
|
||||
let summary = self.build_quota_summary(provider_filter).await?;
|
||||
|
||||
// Format result for agent
|
||||
let available = summary.available_providers();
|
||||
let rate_limited = summary.rate_limited_providers();
|
||||
let circuit_open = summary.circuit_open_providers();
|
||||
|
||||
let mut output = String::new();
|
||||
let _ = write!(
|
||||
output,
|
||||
"Quota Status ({})\n\n",
|
||||
summary.timestamp.format("%Y-%m-%d %H:%M UTC")
|
||||
);
|
||||
|
||||
if !available.is_empty() {
|
||||
let _ = writeln!(output, "Available providers: {}", available.join(", "));
|
||||
}
|
||||
if !rate_limited.is_empty() {
|
||||
let _ = writeln!(output, "Rate-limited providers: {}", rate_limited.join(", "));
|
||||
}
|
||||
if !circuit_open.is_empty() {
|
||||
let _ = writeln!(output, "Circuit-open providers: {}", circuit_open.join(", "));
|
||||
}
|
||||
|
||||
if available.is_empty() && rate_limited.is_empty() && circuit_open.is_empty() {
|
||||
output.push_str(
|
||||
"No quota information available. Quota is populated after API calls.\n",
|
||||
);
|
||||
}
|
||||
|
||||
// Always show per-provider and per-profile details
|
||||
for provider_info in &summary.providers {
|
||||
let status_label = match &provider_info.status {
|
||||
QuotaStatus::Ok => "ok",
|
||||
QuotaStatus::RateLimited => "rate-limited",
|
||||
QuotaStatus::CircuitOpen => "circuit-open",
|
||||
QuotaStatus::QuotaExhausted => "quota-exhausted",
|
||||
};
|
||||
let _ = write!(
|
||||
output,
|
||||
"\n{} (status: {})\n",
|
||||
provider_info.provider, status_label
|
||||
);
|
||||
|
||||
if provider_info.failure_count > 0 {
|
||||
let _ = writeln!(output, " Failures: {}", provider_info.failure_count);
|
||||
}
|
||||
if let Some(retry_after) = provider_info.retry_after_seconds {
|
||||
let _ = writeln!(output, " Retry after: {}s", retry_after);
|
||||
}
|
||||
if let Some(ref err) = provider_info.last_error {
|
||||
let truncated = if err.len() > 120 { &err[..120] } else { err };
|
||||
let _ = writeln!(output, " Last error: {}", truncated);
|
||||
}
|
||||
|
||||
for profile in &provider_info.profiles {
|
||||
let _ = write!(output, " - {}", profile.profile_name);
|
||||
if let Some(ref acct) = profile.account_id {
|
||||
let _ = write!(output, " ({})", acct);
|
||||
}
|
||||
output.push('\n');
|
||||
|
||||
if let Some(remaining) = profile.rate_limit_remaining {
|
||||
if let Some(total) = profile.rate_limit_total {
|
||||
let _ = writeln!(output, " Quota: {}/{} requests", remaining, total);
|
||||
} else {
|
||||
let _ = writeln!(output, " Quota: {} remaining", remaining);
|
||||
}
|
||||
}
|
||||
if let Some(reset_at) = profile.rate_limit_reset_at {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" Resets at: {}",
|
||||
reset_at.format("%Y-%m-%d %H:%M UTC")
|
||||
);
|
||||
}
|
||||
if let Some(expires) = profile.token_expires_at {
|
||||
let now = chrono::Utc::now();
|
||||
if expires < now {
|
||||
let ago = now.signed_duration_since(expires);
|
||||
let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours());
|
||||
} else {
|
||||
let left = expires.signed_duration_since(now);
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" Token: valid (expires in {}h {}m)",
|
||||
left.num_hours(),
|
||||
left.num_minutes() % 60
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(ref plan) = profile.plan_type {
|
||||
let _ = writeln!(output, " Plan: {}", plan);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add cost tracking information if available
|
||||
if let Some(tracker) = &self.cost_tracker {
|
||||
if let Ok(cost_summary) = tracker.get_summary() {
|
||||
let _ = writeln!(output, "\nCost & Usage Summary:");
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" Session: ${:.4} ({} tokens, {} requests)",
|
||||
cost_summary.session_cost_usd,
|
||||
cost_summary.total_tokens,
|
||||
cost_summary.request_count
|
||||
);
|
||||
let _ = writeln!(output, " Today: ${:.4}", cost_summary.daily_cost_usd);
|
||||
let _ = writeln!(output, " Month: ${:.4}", cost_summary.monthly_cost_usd);
|
||||
|
||||
if !cost_summary.by_model.is_empty() {
|
||||
let _ = writeln!(output, "\n Per-model breakdown:");
|
||||
for (model, stats) in &cost_summary.by_model {
|
||||
let _ = writeln!(
|
||||
output,
|
||||
" {}: ${:.4} ({} tokens)",
|
||||
model, stats.cost_usd, stats.total_tokens
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add metadata as JSON at the end of output for programmatic parsing
|
||||
let _ = write!(
|
||||
output,
|
||||
"\n\n<!-- metadata: {} -->",
|
||||
json!({
|
||||
"available_providers": available,
|
||||
"rate_limited_providers": rate_limited,
|
||||
"circuit_open_providers": circuit_open,
|
||||
})
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for switching the default provider/model in config.toml.
|
||||
///
|
||||
/// Writes `default_provider` and `default_model` to config.toml so the
|
||||
/// change persists across requests. Uses the same Config::save() pattern
|
||||
/// as ModelRoutingConfigTool.
|
||||
pub struct SwitchProviderTool {
|
||||
config: Arc<Config>,
|
||||
}
|
||||
|
||||
impl SwitchProviderTool {
|
||||
pub fn new(config: Arc<Config>) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
fn load_config_without_env(&self) -> Result<Config> {
|
||||
let contents = std::fs::read_to_string(&self.config.config_path).map_err(|error| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read config file {}: {error}",
|
||||
self.config.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut parsed: Config = toml::from_str(&contents).map_err(|error| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to parse config file {}: {error}",
|
||||
self.config.config_path.display()
|
||||
)
|
||||
})?;
|
||||
parsed.config_path.clone_from(&self.config.config_path);
|
||||
parsed.workspace_dir.clone_from(&self.config.workspace_dir);
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SwitchProviderTool {
|
||||
fn name(&self) -> &str {
|
||||
"switch_provider"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Switch to a different AI provider/model by updating config.toml. \
|
||||
Use when current provider is rate-limited or when user explicitly \
|
||||
requests a specific provider for a task. The change persists across requests."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Provider name (e.g., 'gemini', 'openai', 'anthropic')",
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Specific model (optional, e.g., 'gemini-2.5-flash', 'claude-opus-4')"
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Reason for switching (for logging and user notification)"
|
||||
}
|
||||
},
|
||||
"required": ["provider"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
|
||||
let provider = args["provider"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing provider"))?;
|
||||
let model = args.get("model").and_then(|v| v.as_str());
|
||||
let reason = args
|
||||
.get("reason")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("user request");
|
||||
|
||||
// Load config from disk (without env overrides), update, and save
|
||||
let save_result = async {
|
||||
let mut cfg = self.load_config_without_env()?;
|
||||
let previous_provider = cfg.default_provider.clone();
|
||||
let previous_model = cfg.default_model.clone();
|
||||
|
||||
cfg.default_provider = Some(provider.to_string());
|
||||
if let Some(m) = model {
|
||||
cfg.default_model = Some(m.to_string());
|
||||
}
|
||||
|
||||
cfg.save().await?;
|
||||
Ok::<_, anyhow::Error>((previous_provider, previous_model))
|
||||
}
|
||||
.await;
|
||||
|
||||
match save_result {
|
||||
Ok((prev_provider, prev_model)) => {
|
||||
let mut output = format!(
|
||||
"Switched provider to '{provider}'{}. Reason: {reason}",
|
||||
model.map(|m| format!(" (model: {m})")).unwrap_or_default(),
|
||||
);
|
||||
|
||||
if let Some(pp) = &prev_provider {
|
||||
let _ = write!(output, "\nPrevious: {pp}");
|
||||
if let Some(pm) = &prev_model {
|
||||
let _ = write!(output, " ({pm})");
|
||||
}
|
||||
}
|
||||
|
||||
let _ = write!(
|
||||
output,
|
||||
"\n\n<!-- metadata: {} -->",
|
||||
json!({
|
||||
"action": "switch_provider",
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"reason": reason,
|
||||
"previous_provider": prev_provider,
|
||||
"previous_model": prev_model,
|
||||
"persisted": true,
|
||||
})
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to update config: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for estimating quota cost before expensive operations.
|
||||
///
|
||||
/// Allows agent to predict: "это займет ~100 токенов"
|
||||
pub struct EstimateQuotaCostTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for EstimateQuotaCostTool {
|
||||
fn name(&self) -> &str {
|
||||
"estimate_quota_cost"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Estimate quota cost (tokens, requests) for an operation before executing it. \
|
||||
Useful for warning user if operation may exhaust quota or when planning \
|
||||
parallel tool calls."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"description": "Operation type",
|
||||
"enum": ["tool_call", "chat_response", "parallel_tools", "file_analysis"]
|
||||
},
|
||||
"estimated_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Estimated input+output tokens (optional, default: 1000)"
|
||||
},
|
||||
"parallel_count": {
|
||||
"type": "integer",
|
||||
"description": "Number of parallel operations (if applicable, default: 1)"
|
||||
}
|
||||
},
|
||||
"required": ["operation"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
|
||||
let operation = args["operation"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing operation"))?;
|
||||
let estimated_tokens = args
|
||||
.get("estimated_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(1000);
|
||||
let parallel_count = args
|
||||
.get("parallel_count")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(1);
|
||||
|
||||
// Simple cost estimation (can be improved with provider-specific pricing)
|
||||
let total_tokens = estimated_tokens * parallel_count;
|
||||
let total_requests = parallel_count;
|
||||
|
||||
// Rough cost estimate (based on average pricing)
|
||||
let cost_per_1k_tokens = 0.015; // Average across providers
|
||||
let estimated_cost_usd = (total_tokens as f64 / 1000.0) * cost_per_1k_tokens;
|
||||
|
||||
let output = format!(
|
||||
"Estimated cost for {operation}:\n\
|
||||
- Requests: {total_requests}\n\
|
||||
- Tokens: {total_tokens}\n\
|
||||
- Cost: ${estimated_cost_usd:.4} USD (estimate)\n\
|
||||
\n\
|
||||
Note: Actual cost may vary by provider and model."
|
||||
);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_check_provider_quota_schema() {
|
||||
let tool = CheckProviderQuotaTool::new(Arc::new(Config::default()));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["provider"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_switch_provider_schema() {
|
||||
let tool = SwitchProviderTool::new(Arc::new(Config::default()));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("provider")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_quota_schema() {
|
||||
let tool = EstimateQuotaCostTool;
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["operation"]["enum"].is_array());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_provider_quota_name_and_description() {
|
||||
let tool = CheckProviderQuotaTool::new(Arc::new(Config::default()));
|
||||
assert_eq!(tool.name(), "check_provider_quota");
|
||||
assert!(tool.description().contains("quota"));
|
||||
assert!(tool.description().contains("rate limit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_switch_provider_name_and_description() {
|
||||
let tool = SwitchProviderTool::new(Arc::new(Config::default()));
|
||||
assert_eq!(tool.name(), "switch_provider");
|
||||
assert!(tool.description().contains("Switch"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_quota_cost_name_and_description() {
|
||||
let tool = EstimateQuotaCostTool;
|
||||
assert_eq!(tool.name(), "estimate_quota_cost");
|
||||
assert!(tool.description().contains("cost"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_switch_provider_execute() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config = Config {
|
||||
workspace_dir: tmp.path().to_path_buf(),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
};
|
||||
config.save().await.unwrap();
|
||||
let tool = SwitchProviderTool::new(Arc::new(config));
|
||||
let result = tool
|
||||
.execute(json!({"provider": "gemini", "model": "gemini-2.5-flash", "reason": "rate limited"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("gemini"));
|
||||
assert!(result.output.contains("rate limited"));
|
||||
// Verify config was actually updated
|
||||
let saved = std::fs::read_to_string(tmp.path().join("config.toml")).unwrap();
|
||||
assert!(saved.contains("gemini"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_estimate_quota_cost_execute() {
|
||||
let tool = EstimateQuotaCostTool;
|
||||
let result = tool
|
||||
.execute(json!({"operation": "chat_response", "estimated_tokens": 5000}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("5000"));
|
||||
assert!(result.output.contains('$'));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_provider_quota_execute_no_profiles() {
|
||||
// Test with default config (no real auth profiles)
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config = Config {
|
||||
workspace_dir: tmp.path().to_path_buf(),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
};
|
||||
let tool = CheckProviderQuotaTool::new(Arc::new(config));
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
// Should contain quota status header
|
||||
assert!(result.output.contains("Quota Status"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_provider_quota_with_filter() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config = Config {
|
||||
workspace_dir: tmp.path().to_path_buf(),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
};
|
||||
let tool = CheckProviderQuotaTool::new(Arc::new(config));
|
||||
let result = tool.execute(json!({"provider": "gemini"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user