feat(providers): Gemini OAuth credential rotation and token refresh

This commit is contained in:
Aleksandr Prilipko 2026-02-21 03:25:08 +07:00 committed by Chummy
parent 5571852b7b
commit 0646abfed9

View File

@ -15,6 +15,8 @@ use std::sync::Arc;
pub struct GeminiProvider {
auth: Option<GeminiAuth>,
oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
oauth_cred_paths: Vec<PathBuf>,
oauth_index: Arc<tokio::sync::Mutex<usize>>,
}
/// Mutable OAuth token state — supports runtime refresh for long-lived processes.
@ -379,19 +381,22 @@ impl GeminiProvider {
/// 3. `GOOGLE_API_KEY` environment variable
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
pub fn new(api_key: Option<&str>) -> Self {
let oauth_cred_paths = Self::discover_oauth_cred_paths();
let resolved_auth = api_key
.and_then(Self::normalize_non_empty)
.map(GeminiAuth::ExplicitKey)
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
.or_else(|| {
Self::try_load_gemini_cli_token()
Self::try_load_gemini_cli_token(oauth_cred_paths.first())
.map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
});
Self {
auth: resolved_auth,
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
oauth_cred_paths,
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
}
}
@ -410,22 +415,58 @@ impl GeminiProvider {
.and_then(|value| Self::normalize_non_empty(&value))
}
fn load_gemini_cli_creds() -> Option<GeminiCliOAuthCreds> {
let gemini_dir = Self::gemini_cli_dir()?;
let creds_path = gemini_dir.join("oauth_creds.json");
fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
if !creds_path.exists() {
return None;
}
let content = std::fs::read_to_string(&creds_path).ok()?;
let content = std::fs::read_to_string(creds_path).ok()?;
serde_json::from_str(&content).ok()
}
/// Discover all OAuth credential files from known Gemini CLI installations.
///
/// Looks in `~/.gemini/oauth_creds.json` (default) plus any
/// `~/.gemini-*-home/.gemini/oauth_creds.json` siblings.
fn discover_oauth_cred_paths() -> Vec<PathBuf> {
let home = match UserDirs::new() {
Some(u) => u.home_dir().to_path_buf(),
None => return Vec::new(),
};
let mut paths = Vec::new();
let primary = home.join(".gemini").join("oauth_creds.json");
if primary.exists() {
paths.push(primary);
}
if let Ok(entries) = std::fs::read_dir(&home) {
let mut extras: Vec<PathBuf> = entries
.filter_map(|e| e.ok())
.filter_map(|e| {
let name = e.file_name().to_string_lossy().to_string();
if name.starts_with(".gemini-") && name.ends_with("-home") {
let path = e.path().join(".gemini").join("oauth_creds.json");
if path.exists() {
return Some(path);
}
}
None
})
.collect();
extras.sort();
paths.extend(extras);
}
paths
}
/// Try to load OAuth credentials from Gemini CLI's cached credentials.
/// Location: `~/.gemini/oauth_creds.json`
///
/// Returns the full `OAuthTokenState` so the provider can refresh at runtime.
fn try_load_gemini_cli_token() -> Option<OAuthTokenState> {
let creds = Self::load_gemini_cli_creds()?;
fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
let creds = Self::load_gemini_cli_creds(path?)?;
// Determine expiry in millis: prefer expiry_date over expiry (RFC 3339)
let expiry_millis = creds.expiry_date.or_else(|| {
@ -469,14 +510,16 @@ impl GeminiProvider {
/// Check if Gemini CLI is configured and has valid credentials
pub fn has_cli_credentials() -> bool {
Self::load_gemini_cli_creds()
.and_then(|creds| {
creds
.access_token
.as_deref()
.and_then(Self::normalize_non_empty)
})
.is_some()
Self::discover_oauth_cred_paths().iter().any(|path| {
Self::load_gemini_cli_creds(path)
.and_then(|creds| {
creds
.access_token
.as_deref()
.and_then(Self::normalize_non_empty)
})
.is_some()
})
}
/// Check if any Gemini authentication is available
@ -537,6 +580,47 @@ impl GeminiProvider {
Ok(guard.access_token.clone())
}
/// Rotate to the next available OAuth credentials file and swap state.
/// Returns `true` when rotation succeeded.
async fn rotate_oauth_credential(
&self,
state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
) -> bool {
if self.oauth_cred_paths.len() <= 1 {
return false;
}
let mut idx = self.oauth_index.lock().await;
let start = *idx;
loop {
let next = (*idx + 1) % self.oauth_cred_paths.len();
*idx = next;
if next == start {
return false;
}
if let Some(next_state) =
Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
{
{
let mut guard = state.lock().await;
*guard = next_state;
}
{
let mut cached_project = self.oauth_project.lock().await;
*cached_project = None;
}
tracing::warn!(
"Gemini OAuth: rotated credential to {}",
self.oauth_cred_paths[next].display()
);
return true;
}
}
}
fn format_model_name(model: &str) -> String {
if model.starts_with("models/") {
model.to_string()
@ -693,6 +777,13 @@ impl GeminiProvider {
|| error_text.contains("Unknown name 'generationConfig'")
|| error_text.contains(r#"Unknown name \"generationConfig\""#)
}
fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
status == reqwest::StatusCode::TOO_MANY_REQUESTS
|| status == reqwest::StatusCode::SERVICE_UNAVAILABLE
|| status.is_server_error()
|| error_text.contains("RESOURCE_EXHAUSTED")
}
}
impl GeminiProvider {
@ -713,8 +804,13 @@ impl GeminiProvider {
)
})?;
let oauth_state = match auth {
GeminiAuth::OAuthToken(state) => Some(state.clone()),
_ => None,
};
// For OAuth: get a valid (potentially refreshed) token and resolve project
let (oauth_token, project) = if let GeminiAuth::OAuthToken(state) = auth {
let (mut oauth_token, mut project) = if let GeminiAuth::OAuthToken(state) = auth {
let token = Self::get_valid_oauth_token(state).await?;
let proj = self.resolve_oauth_project(&token).await?;
(Some(token), Some(proj))
@ -746,6 +842,58 @@ impl GeminiProvider {
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
if let Some(state) = oauth_state.as_ref() {
if self.rotate_oauth_credential(state).await {
let token = Self::get_valid_oauth_token(state).await?;
let proj = self.resolve_oauth_project(&token).await?;
oauth_token = Some(token);
project = Some(proj);
response = self
.build_generate_content_request(
auth,
&url,
&request,
model,
true,
project.as_deref(),
oauth_token.as_deref(),
)
.send()
.await?;
} else {
anyhow::bail!("Gemini API error ({status}): {error_text}");
}
} else {
anyhow::bail!("Gemini API error ({status}): {error_text}");
}
} else if auth.is_oauth()
&& Self::should_retry_oauth_without_generation_config(status, &error_text)
{
tracing::warn!(
"Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
);
response = self
.build_generate_content_request(
auth,
&url,
&request,
model,
false,
project.as_deref(),
oauth_token.as_deref(),
)
.send()
.await?;
} else {
anyhow::bail!("Gemini API error ({status}): {error_text}");
}
}
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
@ -977,6 +1125,15 @@ mod tests {
})))
}
fn test_provider(auth: Option<GeminiAuth>) -> GeminiProvider {
GeminiProvider {
auth,
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
oauth_cred_paths: Vec::new(),
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
}
}
#[test]
fn normalize_non_empty_trims_and_filters() {
assert_eq!(
@ -1021,28 +1178,19 @@ mod tests {
#[test]
fn auth_source_explicit_key() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::ExplicitKey("key".into())),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(GeminiAuth::ExplicitKey("key".into())));
assert_eq!(provider.auth_source(), "config");
}
#[test]
fn auth_source_none_without_credentials() {
let provider = GeminiProvider {
auth: None,
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(None);
assert_eq!(provider.auth_source(), "none");
}
#[test]
fn auth_source_oauth() {
let provider = GeminiProvider {
auth: Some(test_oauth_auth("ya29.mock")),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(test_oauth_auth("ya29.mock")));
assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
}
@ -1093,10 +1241,7 @@ mod tests {
#[test]
fn oauth_request_uses_bearer_auth_header() {
let provider = GeminiProvider {
auth: Some(test_oauth_auth("ya29.mock-token")),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
let auth = test_oauth_auth("ya29.mock-token");
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
let body = GenerateContentRequest {
@ -1137,10 +1282,7 @@ mod tests {
#[test]
fn oauth_request_wraps_payload_in_request_envelope() {
let provider = GeminiProvider {
auth: Some(test_oauth_auth("ya29.mock-token")),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
let auth = test_oauth_auth("ya29.mock-token");
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
let body = GenerateContentRequest {
@ -1184,10 +1326,7 @@ mod tests {
#[test]
fn api_key_request_does_not_set_bearer_header() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
let body = GenerateContentRequest {
@ -1610,24 +1749,36 @@ mod tests {
#[tokio::test]
async fn warmup_without_key_is_noop() {
let provider = GeminiProvider {
auth: None,
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(None);
let result = provider.warmup().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn warmup_oauth_is_noop() {
let provider = GeminiProvider {
auth: Some(test_oauth_auth("ya29.mock-token")),
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
};
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
let result = provider.warmup().await;
assert!(result.is_ok());
}
#[test]
fn discover_oauth_cred_paths_does_not_panic() {
let _paths = GeminiProvider::discover_oauth_cred_paths();
}
#[tokio::test]
async fn rotate_oauth_without_alternatives_returns_false() {
let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
access_token: "ya29.mock".to_string(),
refresh_token: None,
client_id: None,
client_secret: None,
expiry_millis: None,
}));
let provider = test_provider(Some(GeminiAuth::OAuthToken(state.clone())));
assert!(!provider.rotate_oauth_credential(&state).await);
}
#[test]
fn response_parses_usage_metadata() {
let json = r#"{