From 0646abfed9b95cd17ca1b1ee6cdbd22f67106b6f Mon Sep 17 00:00:00 2001 From: Aleksandr Prilipko Date: Sat, 21 Feb 2026 03:25:08 +0700 Subject: [PATCH] feat(providers): Gemini OAuth credential rotation and token refresh --- src/providers/gemini.rs | 247 ++++++++++++++++++++++++++++++++-------- 1 file changed, 199 insertions(+), 48 deletions(-) diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 3a242f5ca..b59f2f2c6 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -15,6 +15,8 @@ use std::sync::Arc; pub struct GeminiProvider { auth: Option, oauth_project: Arc>>, + oauth_cred_paths: Vec, + oauth_index: Arc>, } /// Mutable OAuth token state — supports runtime refresh for long-lived processes. @@ -379,19 +381,22 @@ impl GeminiProvider { /// 3. `GOOGLE_API_KEY` environment variable /// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`) pub fn new(api_key: Option<&str>) -> Self { + let oauth_cred_paths = Self::discover_oauth_cred_paths(); let resolved_auth = api_key .and_then(Self::normalize_non_empty) .map(GeminiAuth::ExplicitKey) .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey)) .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey)) .or_else(|| { - Self::try_load_gemini_cli_token() + Self::try_load_gemini_cli_token(oauth_cred_paths.first()) .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state)))) }); Self { auth: resolved_auth, oauth_project: Arc::new(tokio::sync::Mutex::new(None)), + oauth_cred_paths, + oauth_index: Arc::new(tokio::sync::Mutex::new(0)), } } @@ -410,22 +415,58 @@ impl GeminiProvider { .and_then(|value| Self::normalize_non_empty(&value)) } - fn load_gemini_cli_creds() -> Option { - let gemini_dir = Self::gemini_cli_dir()?; - let creds_path = gemini_dir.join("oauth_creds.json"); + fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option { if !creds_path.exists() { return None; } - let content = std::fs::read_to_string(&creds_path).ok()?; + let content = std::fs::read_to_string(creds_path).ok()?; serde_json::from_str(&content).ok() } + /// Discover all OAuth credential files from known Gemini CLI installations. + /// + /// Looks in `~/.gemini/oauth_creds.json` (default) plus any + /// `~/.gemini-*-home/.gemini/oauth_creds.json` siblings. + fn discover_oauth_cred_paths() -> Vec { + let home = match UserDirs::new() { + Some(u) => u.home_dir().to_path_buf(), + None => return Vec::new(), + }; + + let mut paths = Vec::new(); + + let primary = home.join(".gemini").join("oauth_creds.json"); + if primary.exists() { + paths.push(primary); + } + + if let Ok(entries) = std::fs::read_dir(&home) { + let mut extras: Vec = entries + .filter_map(|e| e.ok()) + .filter_map(|e| { + let name = e.file_name().to_string_lossy().to_string(); + if name.starts_with(".gemini-") && name.ends_with("-home") { + let path = e.path().join(".gemini").join("oauth_creds.json"); + if path.exists() { + return Some(path); + } + } + None + }) + .collect(); + extras.sort(); + paths.extend(extras); + } + + paths + } + /// Try to load OAuth credentials from Gemini CLI's cached credentials. /// Location: `~/.gemini/oauth_creds.json` /// /// Returns the full `OAuthTokenState` so the provider can refresh at runtime. - fn try_load_gemini_cli_token() -> Option { - let creds = Self::load_gemini_cli_creds()?; + fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option { + let creds = Self::load_gemini_cli_creds(path?)?; // Determine expiry in millis: prefer expiry_date over expiry (RFC 3339) let expiry_millis = creds.expiry_date.or_else(|| { @@ -469,14 +510,16 @@ impl GeminiProvider { /// Check if Gemini CLI is configured and has valid credentials pub fn has_cli_credentials() -> bool { - Self::load_gemini_cli_creds() - .and_then(|creds| { - creds - .access_token - .as_deref() - .and_then(Self::normalize_non_empty) - }) - .is_some() + Self::discover_oauth_cred_paths().iter().any(|path| { + Self::load_gemini_cli_creds(path) + .and_then(|creds| { + creds + .access_token + .as_deref() + .and_then(Self::normalize_non_empty) + }) + .is_some() + }) } /// Check if any Gemini authentication is available @@ -537,6 +580,47 @@ impl GeminiProvider { Ok(guard.access_token.clone()) } + /// Rotate to the next available OAuth credentials file and swap state. + /// Returns `true` when rotation succeeded. + async fn rotate_oauth_credential( + &self, + state: &Arc>, + ) -> bool { + if self.oauth_cred_paths.len() <= 1 { + return false; + } + + let mut idx = self.oauth_index.lock().await; + let start = *idx; + + loop { + let next = (*idx + 1) % self.oauth_cred_paths.len(); + *idx = next; + + if next == start { + return false; + } + + if let Some(next_state) = + Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next)) + { + { + let mut guard = state.lock().await; + *guard = next_state; + } + { + let mut cached_project = self.oauth_project.lock().await; + *cached_project = None; + } + tracing::warn!( + "Gemini OAuth: rotated credential to {}", + self.oauth_cred_paths[next].display() + ); + return true; + } + } + } + fn format_model_name(model: &str) -> String { if model.starts_with("models/") { model.to_string() @@ -693,6 +777,13 @@ impl GeminiProvider { || error_text.contains("Unknown name 'generationConfig'") || error_text.contains(r#"Unknown name \"generationConfig\""#) } + + fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool { + status == reqwest::StatusCode::TOO_MANY_REQUESTS + || status == reqwest::StatusCode::SERVICE_UNAVAILABLE + || status.is_server_error() + || error_text.contains("RESOURCE_EXHAUSTED") + } } impl GeminiProvider { @@ -713,8 +804,13 @@ impl GeminiProvider { ) })?; + let oauth_state = match auth { + GeminiAuth::OAuthToken(state) => Some(state.clone()), + _ => None, + }; + // For OAuth: get a valid (potentially refreshed) token and resolve project - let (oauth_token, project) = if let GeminiAuth::OAuthToken(state) = auth { + let (mut oauth_token, mut project) = if let GeminiAuth::OAuthToken(state) = auth { let token = Self::get_valid_oauth_token(state).await?; let proj = self.resolve_oauth_project(&token).await?; (Some(token), Some(proj)) @@ -746,6 +842,58 @@ impl GeminiProvider { .send() .await?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + + if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) { + if let Some(state) = oauth_state.as_ref() { + if self.rotate_oauth_credential(state).await { + let token = Self::get_valid_oauth_token(state).await?; + let proj = self.resolve_oauth_project(&token).await?; + oauth_token = Some(token); + project = Some(proj); + response = self + .build_generate_content_request( + auth, + &url, + &request, + model, + true, + project.as_deref(), + oauth_token.as_deref(), + ) + .send() + .await?; + } else { + anyhow::bail!("Gemini API error ({status}): {error_text}"); + } + } else { + anyhow::bail!("Gemini API error ({status}): {error_text}"); + } + } else if auth.is_oauth() + && Self::should_retry_oauth_without_generation_config(status, &error_text) + { + tracing::warn!( + "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig" + ); + response = self + .build_generate_content_request( + auth, + &url, + &request, + model, + false, + project.as_deref(), + oauth_token.as_deref(), + ) + .send() + .await?; + } else { + anyhow::bail!("Gemini API error ({status}): {error_text}"); + } + } + if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); @@ -977,6 +1125,15 @@ mod tests { }))) } + fn test_provider(auth: Option) -> GeminiProvider { + GeminiProvider { + auth, + oauth_project: Arc::new(tokio::sync::Mutex::new(None)), + oauth_cred_paths: Vec::new(), + oauth_index: Arc::new(tokio::sync::Mutex::new(0)), + } + } + #[test] fn normalize_non_empty_trims_and_filters() { assert_eq!( @@ -1021,28 +1178,19 @@ mod tests { #[test] fn auth_source_explicit_key() { - let provider = GeminiProvider { - auth: Some(GeminiAuth::ExplicitKey("key".into())), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(GeminiAuth::ExplicitKey("key".into()))); assert_eq!(provider.auth_source(), "config"); } #[test] fn auth_source_none_without_credentials() { - let provider = GeminiProvider { - auth: None, - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(None); assert_eq!(provider.auth_source(), "none"); } #[test] fn auth_source_oauth() { - let provider = GeminiProvider { - auth: Some(test_oauth_auth("ya29.mock")), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(test_oauth_auth("ya29.mock"))); assert_eq!(provider.auth_source(), "Gemini CLI OAuth"); } @@ -1093,10 +1241,7 @@ mod tests { #[test] fn oauth_request_uses_bearer_auth_header() { - let provider = GeminiProvider { - auth: Some(test_oauth_auth("ya29.mock-token")), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(test_oauth_auth("ya29.mock-token"))); let auth = test_oauth_auth("ya29.mock-token"); let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); let body = GenerateContentRequest { @@ -1137,10 +1282,7 @@ mod tests { #[test] fn oauth_request_wraps_payload_in_request_envelope() { - let provider = GeminiProvider { - auth: Some(test_oauth_auth("ya29.mock-token")), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(test_oauth_auth("ya29.mock-token"))); let auth = test_oauth_auth("ya29.mock-token"); let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); let body = GenerateContentRequest { @@ -1184,10 +1326,7 @@ mod tests { #[test] fn api_key_request_does_not_set_bearer_header() { - let provider = GeminiProvider { - auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into()))); let auth = GeminiAuth::ExplicitKey("api-key-123".into()); let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); let body = GenerateContentRequest { @@ -1610,24 +1749,36 @@ mod tests { #[tokio::test] async fn warmup_without_key_is_noop() { - let provider = GeminiProvider { - auth: None, - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(None); let result = provider.warmup().await; assert!(result.is_ok()); } #[tokio::test] async fn warmup_oauth_is_noop() { - let provider = GeminiProvider { - auth: Some(test_oauth_auth("ya29.mock-token")), - oauth_project: Arc::new(tokio::sync::Mutex::new(None)), - }; + let provider = test_provider(Some(test_oauth_auth("ya29.mock-token"))); let result = provider.warmup().await; assert!(result.is_ok()); } + #[test] + fn discover_oauth_cred_paths_does_not_panic() { + let _paths = GeminiProvider::discover_oauth_cred_paths(); + } + + #[tokio::test] + async fn rotate_oauth_without_alternatives_returns_false() { + let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState { + access_token: "ya29.mock".to_string(), + refresh_token: None, + client_id: None, + client_secret: None, + expiry_millis: None, + })); + let provider = test_provider(Some(GeminiAuth::OAuthToken(state.clone()))); + assert!(!provider.rotate_oauth_credential(&state).await); + } + #[test] fn response_parses_usage_metadata() { let json = r#"{