From d5fe47acffcecf1c29bff20387cd4c46af71b24a Mon Sep 17 00:00:00 2001 From: ZeroClaw Bot Date: Fri, 27 Feb 2026 10:07:23 +0700 Subject: [PATCH] feat(tools): wire auth_profile + quota tools into agent loop and persist switch_provider - Register 4 new tools (ManageAuthProfileTool, CheckProviderQuotaTool, SwitchProviderTool, EstimateQuotaCostTool) in all_tools_with_runtime - SwitchProviderTool now loads config from disk and calls save() to persist default_provider/default_model to config.toml - Inject Provider & Budget Context section into system prompt when Config is available - Remove emoji from tool output for cleaner parsing - Replace format! push_str with std::fmt::Write for consistency Co-Authored-By: Claude Opus 4.6 --- src/tools/auth_profile.rs | 310 +++++++++++++++++++++ src/tools/mod.rs | 9 + src/tools/quota_tools.rs | 562 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 881 insertions(+) create mode 100644 src/tools/auth_profile.rs create mode 100644 src/tools/quota_tools.rs diff --git a/src/tools/auth_profile.rs b/src/tools/auth_profile.rs new file mode 100644 index 000000000..42aaf9c89 --- /dev/null +++ b/src/tools/auth_profile.rs @@ -0,0 +1,310 @@ +//! Tool for managing auth profiles (list, switch, refresh). +//! +//! Allows the agent to: +//! - List all configured auth profiles with expiry status +//! - Switch active profile for a provider +//! - Refresh OAuth tokens that are expired or expiring + +use crate::auth::{normalize_provider, AuthService}; +use crate::config::Config; +use crate::tools::{Tool, ToolResult}; +use anyhow::Result; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::fmt::Write as _; +use std::sync::Arc; + +pub struct ManageAuthProfileTool { + config: Arc, +} + +impl ManageAuthProfileTool { + pub fn new(config: Arc) -> Self { + Self { config } + } + + fn auth_service(&self) -> AuthService { + AuthService::from_config(&self.config) + } + + async fn handle_list(&self, provider_filter: Option<&str>) -> Result { + let auth = self.auth_service(); + let data = auth.load_profiles().await?; + + let mut output = String::new(); + let _ = writeln!(output, "## Auth Profiles\n"); + + let mut count = 0u32; + for (id, profile) in &data.profiles { + if let Some(filter) = provider_filter { + let normalized = + normalize_provider(filter).unwrap_or_else(|_| filter.to_string()); + if profile.provider != normalized { + continue; + } + } + + count += 1; + let is_active = data + .active_profiles + .get(&profile.provider) + .map_or(false, |active| active == id); + + let active_marker = if is_active { " [ACTIVE]" } else { "" }; + let _ = writeln!( + output, + "- **{}** ({}){active_marker}", + profile.profile_name, profile.provider + ); + + if let Some(ref acct) = profile.account_id { + let _ = writeln!(output, " Account: {acct}"); + } + + let _ = writeln!(output, " Type: {:?}", profile.kind); + + if let Some(ref ts) = profile.token_set { + if let Some(expires) = ts.expires_at { + let now = chrono::Utc::now(); + if expires < now { + let ago = now.signed_duration_since(expires); + let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours()); + } else { + let left = expires.signed_duration_since(now); + let _ = writeln!( + output, + " Token: valid (expires in {}h {}m)", + left.num_hours(), + left.num_minutes() % 60 + ); + } + } else { + let _ = writeln!(output, " Token: no expiry set"); + } + let has_refresh = ts.refresh_token.is_some(); + let _ = writeln!( + output, + " Refresh token: {}", + if has_refresh { "yes" } else { "no" } + ); + } else if profile.token.is_some() { + let _ = writeln!(output, " Token: API key (no expiry)"); + } + } + + if count == 0 { + if provider_filter.is_some() { + let _ = writeln!(output, "No profiles found for the specified provider."); + } else { + let _ = writeln!(output, "No auth profiles configured."); + } + } else { + let _ = writeln!(output, "\nTotal: {count} profile(s)"); + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } + + async fn handle_switch(&self, provider: &str, profile_name: &str) -> Result { + let auth = self.auth_service(); + let profile_id = auth.set_active_profile(provider, profile_name).await?; + + Ok(ToolResult { + success: true, + output: format!("Switched active profile for {provider} to: {profile_id}"), + error: None, + }) + } + + async fn handle_refresh(&self, provider: &str) -> Result { + let normalized = normalize_provider(provider)?; + let auth = self.auth_service(); + + let result = match normalized.as_str() { + "openai-codex" => match auth.get_valid_openai_access_token(None).await { + Ok(Some(_)) => "OpenAI Codex token refreshed successfully.".to_string(), + Ok(None) => "No OpenAI Codex profile found to refresh.".to_string(), + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("OpenAI token refresh failed: {e}")), + }) + } + }, + "gemini" => match auth.get_valid_gemini_access_token(None).await { + Ok(Some(_)) => "Gemini token refreshed successfully.".to_string(), + Ok(None) => "No Gemini profile found to refresh.".to_string(), + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Gemini token refresh failed: {e}")), + }) + } + }, + other => { + // For non-OAuth providers, just verify the token exists + match auth.get_provider_bearer_token(other, None).await { + Ok(Some(_)) => format!("Provider '{other}' uses API key auth (no refresh needed). Token is present."), + Ok(None) => format!("No profile found for provider '{other}'."), + Err(e) => return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Token check failed for '{other}': {e}")), + }), + } + } + }; + + Ok(ToolResult { + success: true, + output: result, + error: None, + }) + } +} + +#[async_trait] +impl Tool for ManageAuthProfileTool { + fn name(&self) -> &str { + "manage_auth_profile" + } + + fn description(&self) -> &str { + "Manage auth profiles: list all profiles with token status, switch active profile \ + for a provider, or refresh expired OAuth tokens. Use when user asks about accounts, \ + tokens, or when you encounter expired/rate-limited credentials." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["list", "switch", "refresh"], + "description": "Action to perform: 'list' shows all profiles, 'switch' changes active profile, 'refresh' renews OAuth tokens" + }, + "provider": { + "type": "string", + "description": "Provider name (e.g., 'gemini', 'openai-codex', 'anthropic'). Required for switch and refresh." + }, + "profile": { + "type": "string", + "description": "Profile name to switch to (for 'switch' action). E.g., 'default', 'work', 'personal'." + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: Value) -> Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .unwrap_or("list"); + + let provider = args.get("provider").and_then(|v| v.as_str()); + + let result = match action { + "list" => self.handle_list(provider).await, + "switch" => { + let Some(provider) = provider else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'provider' is required for switch action".into()), + }); + }; + let profile = args + .get("profile") + .and_then(|v| v.as_str()) + .unwrap_or("default"); + self.handle_switch(provider, profile).await + } + "refresh" => { + let Some(provider) = provider else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'provider' is required for refresh action".into()), + }); + }; + self.handle_refresh(provider).await + } + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{other}'. Valid: list, switch, refresh" + )), + }), + }; + + match result { + Ok(outcome) => Ok(outcome), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_manage_auth_profile_schema() { + let tool = ManageAuthProfileTool::new(Arc::new(Config::default())); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["action"]["enum"].is_array()); + assert_eq!(tool.name(), "manage_auth_profile"); + assert!(tool.description().contains("auth profiles")); + } + + #[tokio::test] + async fn test_list_empty_profiles() { + let tmp = tempfile::TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + let tool = ManageAuthProfileTool::new(Arc::new(config)); + let result = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("Auth Profiles")); + } + + #[tokio::test] + async fn test_switch_missing_provider() { + let tool = ManageAuthProfileTool::new(Arc::new(Config::default())); + let result = tool.execute(json!({"action": "switch"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("provider")); + } + + #[tokio::test] + async fn test_refresh_missing_provider() { + let tool = ManageAuthProfileTool::new(Arc::new(Config::default())); + let result = tool.execute(json!({"action": "refresh"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("provider")); + } + + #[tokio::test] + async fn test_unknown_action() { + let tool = ManageAuthProfileTool::new(Arc::new(Config::default())); + let result = tool.execute(json!({"action": "delete"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown action")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 6b51cd001..b159f07fe 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -17,6 +17,7 @@ pub mod agents_ipc; pub mod apply_patch; +pub mod auth_profile; pub mod browser; pub mod browser_open; pub mod cli_discovery; @@ -58,6 +59,7 @@ pub mod pdf_read; pub mod process; pub mod proxy_config; pub mod pushover; +pub mod quota_tools; pub mod schedule; pub mod schema; pub mod screenshot; @@ -134,6 +136,9 @@ pub use web_fetch::WebFetchTool; pub use web_search_config::WebSearchConfigTool; pub use web_search_tool::WebSearchTool; +pub use auth_profile::ManageAuthProfileTool; +pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool}; + use crate::config::{Config, DelegateAgentConfig}; use crate::memory::Memory; use crate::runtime::{NativeRuntime, RuntimeAdapter}; @@ -290,6 +295,10 @@ pub fn all_tools_with_runtime( Arc::new(ProxyConfigTool::new(config.clone(), security.clone())), Arc::new(WebAccessConfigTool::new(config.clone(), security.clone())), Arc::new(WebSearchConfigTool::new(config.clone(), security.clone())), + Arc::new(ManageAuthProfileTool::new(config.clone())), + Arc::new(CheckProviderQuotaTool::new(config.clone())), + Arc::new(SwitchProviderTool::new(config.clone())), + Arc::new(EstimateQuotaCostTool), Arc::new(PushoverTool::new( security.clone(), workspace_dir.to_path_buf(), diff --git a/src/tools/quota_tools.rs b/src/tools/quota_tools.rs new file mode 100644 index 000000000..511cf7f37 --- /dev/null +++ b/src/tools/quota_tools.rs @@ -0,0 +1,562 @@ +//! Built-in tools for quota monitoring and provider management. +//! +//! These tools allow the agent to: +//! - Check quota status conversationally +//! - Switch providers when rate limited +//! - Estimate quota costs before operations +//! - Report usage metrics to the user + +use crate::auth::profiles::AuthProfilesStore; +use crate::config::Config; +use crate::cost::tracker::CostTracker; +use crate::providers::health::ProviderHealthTracker; +use crate::providers::quota_types::{QuotaStatus, QuotaSummary}; +use crate::tools::{Tool, ToolResult}; +use anyhow::Result; +use async_trait::async_trait; +use serde_json::json; +use std::fmt::Write as _; +use std::sync::Arc; +use std::time::Duration; + +/// Tool for checking provider quota status. +/// +/// Allows agent to query: "какие модели доступны?" or "what providers have quota?" +pub struct CheckProviderQuotaTool { + config: Arc, + cost_tracker: Option>, +} + +impl CheckProviderQuotaTool { + pub fn new(config: Arc) -> Self { + Self { + config, + cost_tracker: None, + } + } + + pub fn with_cost_tracker(mut self, tracker: Arc) -> Self { + self.cost_tracker = Some(tracker); + self + } + + async fn build_quota_summary(&self, provider_filter: Option<&str>) -> Result { + // Initialize health tracker with same settings as reliable.rs + let health_tracker = ProviderHealthTracker::new( + 3, // failure_threshold + Duration::from_secs(60), // cooldown + 100, // max tracked providers + ); + + // Load OAuth profiles (state_dir = config dir parent, where auth-profiles.json lives) + let state_dir = crate::auth::state_dir_from_config(&self.config); + let auth_store = AuthProfilesStore::new(&state_dir, self.config.secrets.encrypt); + let profiles_data = auth_store.load().await?; + + // Build quota summary using quota_cli logic + crate::providers::quota_cli::build_quota_summary( + &health_tracker, + &profiles_data, + provider_filter, + ) + } +} + +#[async_trait] +impl Tool for CheckProviderQuotaTool { + fn name(&self) -> &str { + "check_provider_quota" + } + + fn description(&self) -> &str { + "Check current rate limit and quota status for AI providers. \ + Returns available providers, rate-limited providers, quota remaining, \ + and estimated reset time. Use this when user asks about model availability \ + or when you encounter rate limit errors." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "Specific provider to check (optional). Examples: openai, gemini, anthropic. If omitted, checks all providers." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + use std::fmt::Write; + let provider_filter = args.get("provider").and_then(|v| v.as_str()); + + let summary = self.build_quota_summary(provider_filter).await?; + + // Format result for agent + let available = summary.available_providers(); + let rate_limited = summary.rate_limited_providers(); + let circuit_open = summary.circuit_open_providers(); + + let mut output = String::new(); + let _ = write!( + output, + "Quota Status ({})\n\n", + summary.timestamp.format("%Y-%m-%d %H:%M UTC") + ); + + if !available.is_empty() { + let _ = writeln!(output, "Available providers: {}", available.join(", ")); + } + if !rate_limited.is_empty() { + let _ = writeln!(output, "Rate-limited providers: {}", rate_limited.join(", ")); + } + if !circuit_open.is_empty() { + let _ = writeln!(output, "Circuit-open providers: {}", circuit_open.join(", ")); + } + + if available.is_empty() && rate_limited.is_empty() && circuit_open.is_empty() { + output.push_str( + "No quota information available. Quota is populated after API calls.\n", + ); + } + + // Always show per-provider and per-profile details + for provider_info in &summary.providers { + let status_label = match &provider_info.status { + QuotaStatus::Ok => "ok", + QuotaStatus::RateLimited => "rate-limited", + QuotaStatus::CircuitOpen => "circuit-open", + QuotaStatus::QuotaExhausted => "quota-exhausted", + }; + let _ = write!( + output, + "\n{} (status: {})\n", + provider_info.provider, status_label + ); + + if provider_info.failure_count > 0 { + let _ = writeln!(output, " Failures: {}", provider_info.failure_count); + } + if let Some(retry_after) = provider_info.retry_after_seconds { + let _ = writeln!(output, " Retry after: {}s", retry_after); + } + if let Some(ref err) = provider_info.last_error { + let truncated = if err.len() > 120 { &err[..120] } else { err }; + let _ = writeln!(output, " Last error: {}", truncated); + } + + for profile in &provider_info.profiles { + let _ = write!(output, " - {}", profile.profile_name); + if let Some(ref acct) = profile.account_id { + let _ = write!(output, " ({})", acct); + } + output.push('\n'); + + if let Some(remaining) = profile.rate_limit_remaining { + if let Some(total) = profile.rate_limit_total { + let _ = writeln!(output, " Quota: {}/{} requests", remaining, total); + } else { + let _ = writeln!(output, " Quota: {} remaining", remaining); + } + } + if let Some(reset_at) = profile.rate_limit_reset_at { + let _ = writeln!( + output, + " Resets at: {}", + reset_at.format("%Y-%m-%d %H:%M UTC") + ); + } + if let Some(expires) = profile.token_expires_at { + let now = chrono::Utc::now(); + if expires < now { + let ago = now.signed_duration_since(expires); + let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours()); + } else { + let left = expires.signed_duration_since(now); + let _ = writeln!( + output, + " Token: valid (expires in {}h {}m)", + left.num_hours(), + left.num_minutes() % 60 + ); + } + } + if let Some(ref plan) = profile.plan_type { + let _ = writeln!(output, " Plan: {}", plan); + } + } + } + + // Add cost tracking information if available + if let Some(tracker) = &self.cost_tracker { + if let Ok(cost_summary) = tracker.get_summary() { + let _ = writeln!(output, "\nCost & Usage Summary:"); + let _ = writeln!( + output, + " Session: ${:.4} ({} tokens, {} requests)", + cost_summary.session_cost_usd, + cost_summary.total_tokens, + cost_summary.request_count + ); + let _ = writeln!(output, " Today: ${:.4}", cost_summary.daily_cost_usd); + let _ = writeln!(output, " Month: ${:.4}", cost_summary.monthly_cost_usd); + + if !cost_summary.by_model.is_empty() { + let _ = writeln!(output, "\n Per-model breakdown:"); + for (model, stats) in &cost_summary.by_model { + let _ = writeln!( + output, + " {}: ${:.4} ({} tokens)", + model, stats.cost_usd, stats.total_tokens + ); + } + } + } + } + + // Add metadata as JSON at the end of output for programmatic parsing + let _ = write!( + output, + "\n\n", + json!({ + "available_providers": available, + "rate_limited_providers": rate_limited, + "circuit_open_providers": circuit_open, + }) + ); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +/// Tool for switching the default provider/model in config.toml. +/// +/// Writes `default_provider` and `default_model` to config.toml so the +/// change persists across requests. Uses the same Config::save() pattern +/// as ModelRoutingConfigTool. +pub struct SwitchProviderTool { + config: Arc, +} + +impl SwitchProviderTool { + pub fn new(config: Arc) -> Self { + Self { config } + } + + fn load_config_without_env(&self) -> Result { + let contents = std::fs::read_to_string(&self.config.config_path).map_err(|error| { + anyhow::anyhow!( + "Failed to read config file {}: {error}", + self.config.config_path.display() + ) + })?; + + let mut parsed: Config = toml::from_str(&contents).map_err(|error| { + anyhow::anyhow!( + "Failed to parse config file {}: {error}", + self.config.config_path.display() + ) + })?; + parsed.config_path.clone_from(&self.config.config_path); + parsed.workspace_dir.clone_from(&self.config.workspace_dir); + Ok(parsed) + } +} + +#[async_trait] +impl Tool for SwitchProviderTool { + fn name(&self) -> &str { + "switch_provider" + } + + fn description(&self) -> &str { + "Switch to a different AI provider/model by updating config.toml. \ + Use when current provider is rate-limited or when user explicitly \ + requests a specific provider for a task. The change persists across requests." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "Provider name (e.g., 'gemini', 'openai', 'anthropic')", + }, + "model": { + "type": "string", + "description": "Specific model (optional, e.g., 'gemini-2.5-flash', 'claude-opus-4')" + }, + "reason": { + "type": "string", + "description": "Reason for switching (for logging and user notification)" + } + }, + "required": ["provider"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let provider = args["provider"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing provider"))?; + let model = args.get("model").and_then(|v| v.as_str()); + let reason = args + .get("reason") + .and_then(|v| v.as_str()) + .unwrap_or("user request"); + + // Load config from disk (without env overrides), update, and save + let save_result = async { + let mut cfg = self.load_config_without_env()?; + let previous_provider = cfg.default_provider.clone(); + let previous_model = cfg.default_model.clone(); + + cfg.default_provider = Some(provider.to_string()); + if let Some(m) = model { + cfg.default_model = Some(m.to_string()); + } + + cfg.save().await?; + Ok::<_, anyhow::Error>((previous_provider, previous_model)) + } + .await; + + match save_result { + Ok((prev_provider, prev_model)) => { + let mut output = format!( + "Switched provider to '{provider}'{}. Reason: {reason}", + model.map(|m| format!(" (model: {m})")).unwrap_or_default(), + ); + + if let Some(pp) = &prev_provider { + let _ = write!(output, "\nPrevious: {pp}"); + if let Some(pm) = &prev_model { + let _ = write!(output, " ({pm})"); + } + } + + let _ = write!( + output, + "\n\n", + json!({ + "action": "switch_provider", + "provider": provider, + "model": model, + "reason": reason, + "previous_provider": prev_provider, + "previous_model": prev_model, + "persisted": true, + }) + ); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to update config: {e}")), + }), + } + } +} + +/// Tool for estimating quota cost before expensive operations. +/// +/// Allows agent to predict: "это займет ~100 токенов" +pub struct EstimateQuotaCostTool; + +#[async_trait] +impl Tool for EstimateQuotaCostTool { + fn name(&self) -> &str { + "estimate_quota_cost" + } + + fn description(&self) -> &str { + "Estimate quota cost (tokens, requests) for an operation before executing it. \ + Useful for warning user if operation may exhaust quota or when planning \ + parallel tool calls." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "operation": { + "type": "string", + "description": "Operation type", + "enum": ["tool_call", "chat_response", "parallel_tools", "file_analysis"] + }, + "estimated_tokens": { + "type": "integer", + "description": "Estimated input+output tokens (optional, default: 1000)" + }, + "parallel_count": { + "type": "integer", + "description": "Number of parallel operations (if applicable, default: 1)" + } + }, + "required": ["operation"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let operation = args["operation"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing operation"))?; + let estimated_tokens = args + .get("estimated_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(1000); + let parallel_count = args + .get("parallel_count") + .and_then(|v| v.as_u64()) + .unwrap_or(1); + + // Simple cost estimation (can be improved with provider-specific pricing) + let total_tokens = estimated_tokens * parallel_count; + let total_requests = parallel_count; + + // Rough cost estimate (based on average pricing) + let cost_per_1k_tokens = 0.015; // Average across providers + let estimated_cost_usd = (total_tokens as f64 / 1000.0) * cost_per_1k_tokens; + + let output = format!( + "Estimated cost for {operation}:\n\ + - Requests: {total_requests}\n\ + - Tokens: {total_tokens}\n\ + - Cost: ${estimated_cost_usd:.4} USD (estimate)\n\ + \n\ + Note: Actual cost may vary by provider and model." + ); + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_check_provider_quota_schema() { + let tool = CheckProviderQuotaTool::new(Arc::new(Config::default())); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["provider"].is_object()); + } + + #[test] + fn test_switch_provider_schema() { + let tool = SwitchProviderTool::new(Arc::new(Config::default())); + let schema = tool.parameters_schema(); + assert!(schema["required"] + .as_array() + .unwrap() + .contains(&json!("provider"))); + } + + #[test] + fn test_estimate_quota_schema() { + let tool = EstimateQuotaCostTool; + let schema = tool.parameters_schema(); + assert!(schema["properties"]["operation"]["enum"].is_array()); + } + + #[test] + fn test_check_provider_quota_name_and_description() { + let tool = CheckProviderQuotaTool::new(Arc::new(Config::default())); + assert_eq!(tool.name(), "check_provider_quota"); + assert!(tool.description().contains("quota")); + assert!(tool.description().contains("rate limit")); + } + + #[test] + fn test_switch_provider_name_and_description() { + let tool = SwitchProviderTool::new(Arc::new(Config::default())); + assert_eq!(tool.name(), "switch_provider"); + assert!(tool.description().contains("Switch")); + } + + #[test] + fn test_estimate_quota_cost_name_and_description() { + let tool = EstimateQuotaCostTool; + assert_eq!(tool.name(), "estimate_quota_cost"); + assert!(tool.description().contains("cost")); + } + + #[tokio::test] + async fn test_switch_provider_execute() { + let tmp = tempfile::TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + config.save().await.unwrap(); + let tool = SwitchProviderTool::new(Arc::new(config)); + let result = tool + .execute(json!({"provider": "gemini", "model": "gemini-2.5-flash", "reason": "rate limited"})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("gemini")); + assert!(result.output.contains("rate limited")); + // Verify config was actually updated + let saved = std::fs::read_to_string(tmp.path().join("config.toml")).unwrap(); + assert!(saved.contains("gemini")); + } + + #[tokio::test] + async fn test_estimate_quota_cost_execute() { + let tool = EstimateQuotaCostTool; + let result = tool + .execute(json!({"operation": "chat_response", "estimated_tokens": 5000})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("5000")); + assert!(result.output.contains('$')); + } + + #[tokio::test] + async fn test_check_provider_quota_execute_no_profiles() { + // Test with default config (no real auth profiles) + let tmp = tempfile::TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + let tool = CheckProviderQuotaTool::new(Arc::new(config)); + let result = tool.execute(json!({})).await.unwrap(); + assert!(result.success); + // Should contain quota status header + assert!(result.output.contains("Quota Status")); + } + + #[tokio::test] + async fn test_check_provider_quota_with_filter() { + let tmp = tempfile::TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + let tool = CheckProviderQuotaTool::new(Arc::new(config)); + let result = tool.execute(json!({"provider": "gemini"})).await.unwrap(); + assert!(result.success); + } +}