feat(tools): wire auth_profile + quota tools into agent loop and persist switch_provider

- Register 4 new tools (ManageAuthProfileTool, CheckProviderQuotaTool,
  SwitchProviderTool, EstimateQuotaCostTool) in all_tools_with_runtime
- SwitchProviderTool now loads config from disk and calls save() to
  persist default_provider/default_model to config.toml
- Inject Provider & Budget Context section into system prompt when
  Config is available
- Remove emoji from tool output for cleaner parsing
- Replace format! push_str with std::fmt::Write for consistency

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
ZeroClaw Bot 2026-02-27 10:07:23 +07:00 committed by Argenis
parent 8c0be20422
commit d5fe47acff
3 changed files with 881 additions and 0 deletions

310
src/tools/auth_profile.rs Normal file
View File

@ -0,0 +1,310 @@
//! Tool for managing auth profiles (list, switch, refresh).
//!
//! Allows the agent to:
//! - List all configured auth profiles with expiry status
//! - Switch active profile for a provider
//! - Refresh OAuth tokens that are expired or expiring
use crate::auth::{normalize_provider, AuthService};
use crate::config::Config;
use crate::tools::{Tool, ToolResult};
use anyhow::Result;
use async_trait::async_trait;
use serde_json::{json, Value};
use std::fmt::Write as _;
use std::sync::Arc;
pub struct ManageAuthProfileTool {
config: Arc<Config>,
}
impl ManageAuthProfileTool {
pub fn new(config: Arc<Config>) -> Self {
Self { config }
}
fn auth_service(&self) -> AuthService {
AuthService::from_config(&self.config)
}
async fn handle_list(&self, provider_filter: Option<&str>) -> Result<ToolResult> {
let auth = self.auth_service();
let data = auth.load_profiles().await?;
let mut output = String::new();
let _ = writeln!(output, "## Auth Profiles\n");
let mut count = 0u32;
for (id, profile) in &data.profiles {
if let Some(filter) = provider_filter {
let normalized =
normalize_provider(filter).unwrap_or_else(|_| filter.to_string());
if profile.provider != normalized {
continue;
}
}
count += 1;
let is_active = data
.active_profiles
.get(&profile.provider)
.map_or(false, |active| active == id);
let active_marker = if is_active { " [ACTIVE]" } else { "" };
let _ = writeln!(
output,
"- **{}** ({}){active_marker}",
profile.profile_name, profile.provider
);
if let Some(ref acct) = profile.account_id {
let _ = writeln!(output, " Account: {acct}");
}
let _ = writeln!(output, " Type: {:?}", profile.kind);
if let Some(ref ts) = profile.token_set {
if let Some(expires) = ts.expires_at {
let now = chrono::Utc::now();
if expires < now {
let ago = now.signed_duration_since(expires);
let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours());
} else {
let left = expires.signed_duration_since(now);
let _ = writeln!(
output,
" Token: valid (expires in {}h {}m)",
left.num_hours(),
left.num_minutes() % 60
);
}
} else {
let _ = writeln!(output, " Token: no expiry set");
}
let has_refresh = ts.refresh_token.is_some();
let _ = writeln!(
output,
" Refresh token: {}",
if has_refresh { "yes" } else { "no" }
);
} else if profile.token.is_some() {
let _ = writeln!(output, " Token: API key (no expiry)");
}
}
if count == 0 {
if provider_filter.is_some() {
let _ = writeln!(output, "No profiles found for the specified provider.");
} else {
let _ = writeln!(output, "No auth profiles configured.");
}
} else {
let _ = writeln!(output, "\nTotal: {count} profile(s)");
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
async fn handle_switch(&self, provider: &str, profile_name: &str) -> Result<ToolResult> {
let auth = self.auth_service();
let profile_id = auth.set_active_profile(provider, profile_name).await?;
Ok(ToolResult {
success: true,
output: format!("Switched active profile for {provider} to: {profile_id}"),
error: None,
})
}
async fn handle_refresh(&self, provider: &str) -> Result<ToolResult> {
let normalized = normalize_provider(provider)?;
let auth = self.auth_service();
let result = match normalized.as_str() {
"openai-codex" => match auth.get_valid_openai_access_token(None).await {
Ok(Some(_)) => "OpenAI Codex token refreshed successfully.".to_string(),
Ok(None) => "No OpenAI Codex profile found to refresh.".to_string(),
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("OpenAI token refresh failed: {e}")),
})
}
},
"gemini" => match auth.get_valid_gemini_access_token(None).await {
Ok(Some(_)) => "Gemini token refreshed successfully.".to_string(),
Ok(None) => "No Gemini profile found to refresh.".to_string(),
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Gemini token refresh failed: {e}")),
})
}
},
other => {
// For non-OAuth providers, just verify the token exists
match auth.get_provider_bearer_token(other, None).await {
Ok(Some(_)) => format!("Provider '{other}' uses API key auth (no refresh needed). Token is present."),
Ok(None) => format!("No profile found for provider '{other}'."),
Err(e) => return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Token check failed for '{other}': {e}")),
}),
}
}
};
Ok(ToolResult {
success: true,
output: result,
error: None,
})
}
}
#[async_trait]
impl Tool for ManageAuthProfileTool {
fn name(&self) -> &str {
"manage_auth_profile"
}
fn description(&self) -> &str {
"Manage auth profiles: list all profiles with token status, switch active profile \
for a provider, or refresh expired OAuth tokens. Use when user asks about accounts, \
tokens, or when you encounter expired/rate-limited credentials."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list", "switch", "refresh"],
"description": "Action to perform: 'list' shows all profiles, 'switch' changes active profile, 'refresh' renews OAuth tokens"
},
"provider": {
"type": "string",
"description": "Provider name (e.g., 'gemini', 'openai-codex', 'anthropic'). Required for switch and refresh."
},
"profile": {
"type": "string",
"description": "Profile name to switch to (for 'switch' action). E.g., 'default', 'work', 'personal'."
}
},
"required": ["action"]
})
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let action = args
.get("action")
.and_then(|v| v.as_str())
.unwrap_or("list");
let provider = args.get("provider").and_then(|v| v.as_str());
let result = match action {
"list" => self.handle_list(provider).await,
"switch" => {
let Some(provider) = provider else {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'provider' is required for switch action".into()),
});
};
let profile = args
.get("profile")
.and_then(|v| v.as_str())
.unwrap_or("default");
self.handle_switch(provider, profile).await
}
"refresh" => {
let Some(provider) = provider else {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'provider' is required for refresh action".into()),
});
};
self.handle_refresh(provider).await
}
other => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action '{other}'. Valid: list, switch, refresh"
)),
}),
};
match result {
Ok(outcome) => Ok(outcome),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_manage_auth_profile_schema() {
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
let schema = tool.parameters_schema();
assert!(schema["properties"]["action"]["enum"].is_array());
assert_eq!(tool.name(), "manage_auth_profile");
assert!(tool.description().contains("auth profiles"));
}
#[tokio::test]
async fn test_list_empty_profiles() {
let tmp = tempfile::TempDir::new().unwrap();
let config = Config {
workspace_dir: tmp.path().to_path_buf(),
config_path: tmp.path().join("config.toml"),
..Config::default()
};
let tool = ManageAuthProfileTool::new(Arc::new(config));
let result = tool.execute(json!({"action": "list"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("Auth Profiles"));
}
#[tokio::test]
async fn test_switch_missing_provider() {
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
let result = tool.execute(json!({"action": "switch"})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("provider"));
}
#[tokio::test]
async fn test_refresh_missing_provider() {
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
let result = tool.execute(json!({"action": "refresh"})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("provider"));
}
#[tokio::test]
async fn test_unknown_action() {
let tool = ManageAuthProfileTool::new(Arc::new(Config::default()));
let result = tool.execute(json!({"action": "delete"})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}
}

View File

@ -17,6 +17,7 @@
pub mod agents_ipc;
pub mod apply_patch;
pub mod auth_profile;
pub mod browser;
pub mod browser_open;
pub mod cli_discovery;
@ -58,6 +59,7 @@ pub mod pdf_read;
pub mod process;
pub mod proxy_config;
pub mod pushover;
pub mod quota_tools;
pub mod schedule;
pub mod schema;
pub mod screenshot;
@ -134,6 +136,9 @@ pub use web_fetch::WebFetchTool;
pub use web_search_config::WebSearchConfigTool;
pub use web_search_tool::WebSearchTool;
pub use auth_profile::ManageAuthProfileTool;
pub use quota_tools::{CheckProviderQuotaTool, EstimateQuotaCostTool, SwitchProviderTool};
use crate::config::{Config, DelegateAgentConfig};
use crate::memory::Memory;
use crate::runtime::{NativeRuntime, RuntimeAdapter};
@ -290,6 +295,10 @@ pub fn all_tools_with_runtime(
Arc::new(ProxyConfigTool::new(config.clone(), security.clone())),
Arc::new(WebAccessConfigTool::new(config.clone(), security.clone())),
Arc::new(WebSearchConfigTool::new(config.clone(), security.clone())),
Arc::new(ManageAuthProfileTool::new(config.clone())),
Arc::new(CheckProviderQuotaTool::new(config.clone())),
Arc::new(SwitchProviderTool::new(config.clone())),
Arc::new(EstimateQuotaCostTool),
Arc::new(PushoverTool::new(
security.clone(),
workspace_dir.to_path_buf(),

562
src/tools/quota_tools.rs Normal file
View File

@ -0,0 +1,562 @@
//! Built-in tools for quota monitoring and provider management.
//!
//! These tools allow the agent to:
//! - Check quota status conversationally
//! - Switch providers when rate limited
//! - Estimate quota costs before operations
//! - Report usage metrics to the user
use crate::auth::profiles::AuthProfilesStore;
use crate::config::Config;
use crate::cost::tracker::CostTracker;
use crate::providers::health::ProviderHealthTracker;
use crate::providers::quota_types::{QuotaStatus, QuotaSummary};
use crate::tools::{Tool, ToolResult};
use anyhow::Result;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write as _;
use std::sync::Arc;
use std::time::Duration;
/// Tool for checking provider quota status.
///
/// Allows agent to query: "какие модели доступны?" or "what providers have quota?"
pub struct CheckProviderQuotaTool {
config: Arc<Config>,
cost_tracker: Option<Arc<CostTracker>>,
}
impl CheckProviderQuotaTool {
pub fn new(config: Arc<Config>) -> Self {
Self {
config,
cost_tracker: None,
}
}
pub fn with_cost_tracker(mut self, tracker: Arc<CostTracker>) -> Self {
self.cost_tracker = Some(tracker);
self
}
async fn build_quota_summary(&self, provider_filter: Option<&str>) -> Result<QuotaSummary> {
// Initialize health tracker with same settings as reliable.rs
let health_tracker = ProviderHealthTracker::new(
3, // failure_threshold
Duration::from_secs(60), // cooldown
100, // max tracked providers
);
// Load OAuth profiles (state_dir = config dir parent, where auth-profiles.json lives)
let state_dir = crate::auth::state_dir_from_config(&self.config);
let auth_store = AuthProfilesStore::new(&state_dir, self.config.secrets.encrypt);
let profiles_data = auth_store.load().await?;
// Build quota summary using quota_cli logic
crate::providers::quota_cli::build_quota_summary(
&health_tracker,
&profiles_data,
provider_filter,
)
}
}
#[async_trait]
impl Tool for CheckProviderQuotaTool {
fn name(&self) -> &str {
"check_provider_quota"
}
fn description(&self) -> &str {
"Check current rate limit and quota status for AI providers. \
Returns available providers, rate-limited providers, quota remaining, \
and estimated reset time. Use this when user asks about model availability \
or when you encounter rate limit errors."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": "Specific provider to check (optional). Examples: openai, gemini, anthropic. If omitted, checks all providers."
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
use std::fmt::Write;
let provider_filter = args.get("provider").and_then(|v| v.as_str());
let summary = self.build_quota_summary(provider_filter).await?;
// Format result for agent
let available = summary.available_providers();
let rate_limited = summary.rate_limited_providers();
let circuit_open = summary.circuit_open_providers();
let mut output = String::new();
let _ = write!(
output,
"Quota Status ({})\n\n",
summary.timestamp.format("%Y-%m-%d %H:%M UTC")
);
if !available.is_empty() {
let _ = writeln!(output, "Available providers: {}", available.join(", "));
}
if !rate_limited.is_empty() {
let _ = writeln!(output, "Rate-limited providers: {}", rate_limited.join(", "));
}
if !circuit_open.is_empty() {
let _ = writeln!(output, "Circuit-open providers: {}", circuit_open.join(", "));
}
if available.is_empty() && rate_limited.is_empty() && circuit_open.is_empty() {
output.push_str(
"No quota information available. Quota is populated after API calls.\n",
);
}
// Always show per-provider and per-profile details
for provider_info in &summary.providers {
let status_label = match &provider_info.status {
QuotaStatus::Ok => "ok",
QuotaStatus::RateLimited => "rate-limited",
QuotaStatus::CircuitOpen => "circuit-open",
QuotaStatus::QuotaExhausted => "quota-exhausted",
};
let _ = write!(
output,
"\n{} (status: {})\n",
provider_info.provider, status_label
);
if provider_info.failure_count > 0 {
let _ = writeln!(output, " Failures: {}", provider_info.failure_count);
}
if let Some(retry_after) = provider_info.retry_after_seconds {
let _ = writeln!(output, " Retry after: {}s", retry_after);
}
if let Some(ref err) = provider_info.last_error {
let truncated = if err.len() > 120 { &err[..120] } else { err };
let _ = writeln!(output, " Last error: {}", truncated);
}
for profile in &provider_info.profiles {
let _ = write!(output, " - {}", profile.profile_name);
if let Some(ref acct) = profile.account_id {
let _ = write!(output, " ({})", acct);
}
output.push('\n');
if let Some(remaining) = profile.rate_limit_remaining {
if let Some(total) = profile.rate_limit_total {
let _ = writeln!(output, " Quota: {}/{} requests", remaining, total);
} else {
let _ = writeln!(output, " Quota: {} remaining", remaining);
}
}
if let Some(reset_at) = profile.rate_limit_reset_at {
let _ = writeln!(
output,
" Resets at: {}",
reset_at.format("%Y-%m-%d %H:%M UTC")
);
}
if let Some(expires) = profile.token_expires_at {
let now = chrono::Utc::now();
if expires < now {
let ago = now.signed_duration_since(expires);
let _ = writeln!(output, " Token: EXPIRED ({}h ago)", ago.num_hours());
} else {
let left = expires.signed_duration_since(now);
let _ = writeln!(
output,
" Token: valid (expires in {}h {}m)",
left.num_hours(),
left.num_minutes() % 60
);
}
}
if let Some(ref plan) = profile.plan_type {
let _ = writeln!(output, " Plan: {}", plan);
}
}
}
// Add cost tracking information if available
if let Some(tracker) = &self.cost_tracker {
if let Ok(cost_summary) = tracker.get_summary() {
let _ = writeln!(output, "\nCost & Usage Summary:");
let _ = writeln!(
output,
" Session: ${:.4} ({} tokens, {} requests)",
cost_summary.session_cost_usd,
cost_summary.total_tokens,
cost_summary.request_count
);
let _ = writeln!(output, " Today: ${:.4}", cost_summary.daily_cost_usd);
let _ = writeln!(output, " Month: ${:.4}", cost_summary.monthly_cost_usd);
if !cost_summary.by_model.is_empty() {
let _ = writeln!(output, "\n Per-model breakdown:");
for (model, stats) in &cost_summary.by_model {
let _ = writeln!(
output,
" {}: ${:.4} ({} tokens)",
model, stats.cost_usd, stats.total_tokens
);
}
}
}
}
// Add metadata as JSON at the end of output for programmatic parsing
let _ = write!(
output,
"\n\n<!-- metadata: {} -->",
json!({
"available_providers": available,
"rate_limited_providers": rate_limited,
"circuit_open_providers": circuit_open,
})
);
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
/// Tool for switching the default provider/model in config.toml.
///
/// Writes `default_provider` and `default_model` to config.toml so the
/// change persists across requests. Uses the same Config::save() pattern
/// as ModelRoutingConfigTool.
pub struct SwitchProviderTool {
config: Arc<Config>,
}
impl SwitchProviderTool {
pub fn new(config: Arc<Config>) -> Self {
Self { config }
}
fn load_config_without_env(&self) -> Result<Config> {
let contents = std::fs::read_to_string(&self.config.config_path).map_err(|error| {
anyhow::anyhow!(
"Failed to read config file {}: {error}",
self.config.config_path.display()
)
})?;
let mut parsed: Config = toml::from_str(&contents).map_err(|error| {
anyhow::anyhow!(
"Failed to parse config file {}: {error}",
self.config.config_path.display()
)
})?;
parsed.config_path.clone_from(&self.config.config_path);
parsed.workspace_dir.clone_from(&self.config.workspace_dir);
Ok(parsed)
}
}
#[async_trait]
impl Tool for SwitchProviderTool {
fn name(&self) -> &str {
"switch_provider"
}
fn description(&self) -> &str {
"Switch to a different AI provider/model by updating config.toml. \
Use when current provider is rate-limited or when user explicitly \
requests a specific provider for a task. The change persists across requests."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": "Provider name (e.g., 'gemini', 'openai', 'anthropic')",
},
"model": {
"type": "string",
"description": "Specific model (optional, e.g., 'gemini-2.5-flash', 'claude-opus-4')"
},
"reason": {
"type": "string",
"description": "Reason for switching (for logging and user notification)"
}
},
"required": ["provider"]
})
}
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
let provider = args["provider"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing provider"))?;
let model = args.get("model").and_then(|v| v.as_str());
let reason = args
.get("reason")
.and_then(|v| v.as_str())
.unwrap_or("user request");
// Load config from disk (without env overrides), update, and save
let save_result = async {
let mut cfg = self.load_config_without_env()?;
let previous_provider = cfg.default_provider.clone();
let previous_model = cfg.default_model.clone();
cfg.default_provider = Some(provider.to_string());
if let Some(m) = model {
cfg.default_model = Some(m.to_string());
}
cfg.save().await?;
Ok::<_, anyhow::Error>((previous_provider, previous_model))
}
.await;
match save_result {
Ok((prev_provider, prev_model)) => {
let mut output = format!(
"Switched provider to '{provider}'{}. Reason: {reason}",
model.map(|m| format!(" (model: {m})")).unwrap_or_default(),
);
if let Some(pp) = &prev_provider {
let _ = write!(output, "\nPrevious: {pp}");
if let Some(pm) = &prev_model {
let _ = write!(output, " ({pm})");
}
}
let _ = write!(
output,
"\n\n<!-- metadata: {} -->",
json!({
"action": "switch_provider",
"provider": provider,
"model": model,
"reason": reason,
"previous_provider": prev_provider,
"previous_model": prev_model,
"persisted": true,
})
);
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to update config: {e}")),
}),
}
}
}
/// Tool for estimating quota cost before expensive operations.
///
/// Allows agent to predict: "это займет ~100 токенов"
pub struct EstimateQuotaCostTool;
#[async_trait]
impl Tool for EstimateQuotaCostTool {
fn name(&self) -> &str {
"estimate_quota_cost"
}
fn description(&self) -> &str {
"Estimate quota cost (tokens, requests) for an operation before executing it. \
Useful for warning user if operation may exhaust quota or when planning \
parallel tool calls."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"operation": {
"type": "string",
"description": "Operation type",
"enum": ["tool_call", "chat_response", "parallel_tools", "file_analysis"]
},
"estimated_tokens": {
"type": "integer",
"description": "Estimated input+output tokens (optional, default: 1000)"
},
"parallel_count": {
"type": "integer",
"description": "Number of parallel operations (if applicable, default: 1)"
}
},
"required": ["operation"]
})
}
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
let operation = args["operation"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing operation"))?;
let estimated_tokens = args
.get("estimated_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(1000);
let parallel_count = args
.get("parallel_count")
.and_then(|v| v.as_u64())
.unwrap_or(1);
// Simple cost estimation (can be improved with provider-specific pricing)
let total_tokens = estimated_tokens * parallel_count;
let total_requests = parallel_count;
// Rough cost estimate (based on average pricing)
let cost_per_1k_tokens = 0.015; // Average across providers
let estimated_cost_usd = (total_tokens as f64 / 1000.0) * cost_per_1k_tokens;
let output = format!(
"Estimated cost for {operation}:\n\
- Requests: {total_requests}\n\
- Tokens: {total_tokens}\n\
- Cost: ${estimated_cost_usd:.4} USD (estimate)\n\
\n\
Note: Actual cost may vary by provider and model."
);
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_check_provider_quota_schema() {
let tool = CheckProviderQuotaTool::new(Arc::new(Config::default()));
let schema = tool.parameters_schema();
assert!(schema["properties"]["provider"].is_object());
}
#[test]
fn test_switch_provider_schema() {
let tool = SwitchProviderTool::new(Arc::new(Config::default()));
let schema = tool.parameters_schema();
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("provider")));
}
#[test]
fn test_estimate_quota_schema() {
let tool = EstimateQuotaCostTool;
let schema = tool.parameters_schema();
assert!(schema["properties"]["operation"]["enum"].is_array());
}
#[test]
fn test_check_provider_quota_name_and_description() {
let tool = CheckProviderQuotaTool::new(Arc::new(Config::default()));
assert_eq!(tool.name(), "check_provider_quota");
assert!(tool.description().contains("quota"));
assert!(tool.description().contains("rate limit"));
}
#[test]
fn test_switch_provider_name_and_description() {
let tool = SwitchProviderTool::new(Arc::new(Config::default()));
assert_eq!(tool.name(), "switch_provider");
assert!(tool.description().contains("Switch"));
}
#[test]
fn test_estimate_quota_cost_name_and_description() {
let tool = EstimateQuotaCostTool;
assert_eq!(tool.name(), "estimate_quota_cost");
assert!(tool.description().contains("cost"));
}
#[tokio::test]
async fn test_switch_provider_execute() {
let tmp = tempfile::TempDir::new().unwrap();
let config = Config {
workspace_dir: tmp.path().to_path_buf(),
config_path: tmp.path().join("config.toml"),
..Config::default()
};
config.save().await.unwrap();
let tool = SwitchProviderTool::new(Arc::new(config));
let result = tool
.execute(json!({"provider": "gemini", "model": "gemini-2.5-flash", "reason": "rate limited"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("gemini"));
assert!(result.output.contains("rate limited"));
// Verify config was actually updated
let saved = std::fs::read_to_string(tmp.path().join("config.toml")).unwrap();
assert!(saved.contains("gemini"));
}
#[tokio::test]
async fn test_estimate_quota_cost_execute() {
let tool = EstimateQuotaCostTool;
let result = tool
.execute(json!({"operation": "chat_response", "estimated_tokens": 5000}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("5000"));
assert!(result.output.contains('$'));
}
#[tokio::test]
async fn test_check_provider_quota_execute_no_profiles() {
// Test with default config (no real auth profiles)
let tmp = tempfile::TempDir::new().unwrap();
let config = Config {
workspace_dir: tmp.path().to_path_buf(),
config_path: tmp.path().join("config.toml"),
..Config::default()
};
let tool = CheckProviderQuotaTool::new(Arc::new(config));
let result = tool.execute(json!({})).await.unwrap();
assert!(result.success);
// Should contain quota status header
assert!(result.output.contains("Quota Status"));
}
#[tokio::test]
async fn test_check_provider_quota_with_filter() {
let tmp = tempfile::TempDir::new().unwrap();
let config = Config {
workspace_dir: tmp.path().to_path_buf(),
config_path: tmp.path().join("config.toml"),
..Config::default()
};
let tool = CheckProviderQuotaTool::new(Arc::new(config));
let result = tool.execute(json!({"provider": "gemini"})).await.unwrap();
assert!(result.success);
}
}