feat(providers): Gemini OAuth credential rotation and token refresh
This commit is contained in:
parent
5571852b7b
commit
0646abfed9
@ -15,6 +15,8 @@ use std::sync::Arc;
|
||||
pub struct GeminiProvider {
|
||||
auth: Option<GeminiAuth>,
|
||||
oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
|
||||
oauth_cred_paths: Vec<PathBuf>,
|
||||
oauth_index: Arc<tokio::sync::Mutex<usize>>,
|
||||
}
|
||||
|
||||
/// Mutable OAuth token state — supports runtime refresh for long-lived processes.
|
||||
@ -379,19 +381,22 @@ impl GeminiProvider {
|
||||
/// 3. `GOOGLE_API_KEY` environment variable
|
||||
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
let oauth_cred_paths = Self::discover_oauth_cred_paths();
|
||||
let resolved_auth = api_key
|
||||
.and_then(Self::normalize_non_empty)
|
||||
.map(GeminiAuth::ExplicitKey)
|
||||
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
|
||||
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
|
||||
.or_else(|| {
|
||||
Self::try_load_gemini_cli_token()
|
||||
Self::try_load_gemini_cli_token(oauth_cred_paths.first())
|
||||
.map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
|
||||
});
|
||||
|
||||
Self {
|
||||
auth: resolved_auth,
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
oauth_cred_paths,
|
||||
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -410,22 +415,58 @@ impl GeminiProvider {
|
||||
.and_then(|value| Self::normalize_non_empty(&value))
|
||||
}
|
||||
|
||||
fn load_gemini_cli_creds() -> Option<GeminiCliOAuthCreds> {
|
||||
let gemini_dir = Self::gemini_cli_dir()?;
|
||||
let creds_path = gemini_dir.join("oauth_creds.json");
|
||||
fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
|
||||
if !creds_path.exists() {
|
||||
return None;
|
||||
}
|
||||
let content = std::fs::read_to_string(&creds_path).ok()?;
|
||||
let content = std::fs::read_to_string(creds_path).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
|
||||
/// Discover all OAuth credential files from known Gemini CLI installations.
|
||||
///
|
||||
/// Looks in `~/.gemini/oauth_creds.json` (default) plus any
|
||||
/// `~/.gemini-*-home/.gemini/oauth_creds.json` siblings.
|
||||
fn discover_oauth_cred_paths() -> Vec<PathBuf> {
|
||||
let home = match UserDirs::new() {
|
||||
Some(u) => u.home_dir().to_path_buf(),
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
let mut paths = Vec::new();
|
||||
|
||||
let primary = home.join(".gemini").join("oauth_creds.json");
|
||||
if primary.exists() {
|
||||
paths.push(primary);
|
||||
}
|
||||
|
||||
if let Ok(entries) = std::fs::read_dir(&home) {
|
||||
let mut extras: Vec<PathBuf> = entries
|
||||
.filter_map(|e| e.ok())
|
||||
.filter_map(|e| {
|
||||
let name = e.file_name().to_string_lossy().to_string();
|
||||
if name.starts_with(".gemini-") && name.ends_with("-home") {
|
||||
let path = e.path().join(".gemini").join("oauth_creds.json");
|
||||
if path.exists() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.collect();
|
||||
extras.sort();
|
||||
paths.extend(extras);
|
||||
}
|
||||
|
||||
paths
|
||||
}
|
||||
|
||||
/// Try to load OAuth credentials from Gemini CLI's cached credentials.
|
||||
/// Location: `~/.gemini/oauth_creds.json`
|
||||
///
|
||||
/// Returns the full `OAuthTokenState` so the provider can refresh at runtime.
|
||||
fn try_load_gemini_cli_token() -> Option<OAuthTokenState> {
|
||||
let creds = Self::load_gemini_cli_creds()?;
|
||||
fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
|
||||
let creds = Self::load_gemini_cli_creds(path?)?;
|
||||
|
||||
// Determine expiry in millis: prefer expiry_date over expiry (RFC 3339)
|
||||
let expiry_millis = creds.expiry_date.or_else(|| {
|
||||
@ -469,14 +510,16 @@ impl GeminiProvider {
|
||||
|
||||
/// Check if Gemini CLI is configured and has valid credentials
|
||||
pub fn has_cli_credentials() -> bool {
|
||||
Self::load_gemini_cli_creds()
|
||||
.and_then(|creds| {
|
||||
creds
|
||||
.access_token
|
||||
.as_deref()
|
||||
.and_then(Self::normalize_non_empty)
|
||||
})
|
||||
.is_some()
|
||||
Self::discover_oauth_cred_paths().iter().any(|path| {
|
||||
Self::load_gemini_cli_creds(path)
|
||||
.and_then(|creds| {
|
||||
creds
|
||||
.access_token
|
||||
.as_deref()
|
||||
.and_then(Self::normalize_non_empty)
|
||||
})
|
||||
.is_some()
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if any Gemini authentication is available
|
||||
@ -537,6 +580,47 @@ impl GeminiProvider {
|
||||
Ok(guard.access_token.clone())
|
||||
}
|
||||
|
||||
/// Rotate to the next available OAuth credentials file and swap state.
|
||||
/// Returns `true` when rotation succeeded.
|
||||
async fn rotate_oauth_credential(
|
||||
&self,
|
||||
state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
|
||||
) -> bool {
|
||||
if self.oauth_cred_paths.len() <= 1 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut idx = self.oauth_index.lock().await;
|
||||
let start = *idx;
|
||||
|
||||
loop {
|
||||
let next = (*idx + 1) % self.oauth_cred_paths.len();
|
||||
*idx = next;
|
||||
|
||||
if next == start {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(next_state) =
|
||||
Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
|
||||
{
|
||||
{
|
||||
let mut guard = state.lock().await;
|
||||
*guard = next_state;
|
||||
}
|
||||
{
|
||||
let mut cached_project = self.oauth_project.lock().await;
|
||||
*cached_project = None;
|
||||
}
|
||||
tracing::warn!(
|
||||
"Gemini OAuth: rotated credential to {}",
|
||||
self.oauth_cred_paths[next].display()
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_model_name(model: &str) -> String {
|
||||
if model.starts_with("models/") {
|
||||
model.to_string()
|
||||
@ -693,6 +777,13 @@ impl GeminiProvider {
|
||||
|| error_text.contains("Unknown name 'generationConfig'")
|
||||
|| error_text.contains(r#"Unknown name \"generationConfig\""#)
|
||||
}
|
||||
|
||||
fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
|
||||
status == reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == reqwest::StatusCode::SERVICE_UNAVAILABLE
|
||||
|| status.is_server_error()
|
||||
|| error_text.contains("RESOURCE_EXHAUSTED")
|
||||
}
|
||||
}
|
||||
|
||||
impl GeminiProvider {
|
||||
@ -713,8 +804,13 @@ impl GeminiProvider {
|
||||
)
|
||||
})?;
|
||||
|
||||
let oauth_state = match auth {
|
||||
GeminiAuth::OAuthToken(state) => Some(state.clone()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// For OAuth: get a valid (potentially refreshed) token and resolve project
|
||||
let (oauth_token, project) = if let GeminiAuth::OAuthToken(state) = auth {
|
||||
let (mut oauth_token, mut project) = if let GeminiAuth::OAuthToken(state) = auth {
|
||||
let token = Self::get_valid_oauth_token(state).await?;
|
||||
let proj = self.resolve_oauth_project(&token).await?;
|
||||
(Some(token), Some(proj))
|
||||
@ -746,6 +842,58 @@ impl GeminiProvider {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
|
||||
if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
|
||||
if let Some(state) = oauth_state.as_ref() {
|
||||
if self.rotate_oauth_credential(state).await {
|
||||
let token = Self::get_valid_oauth_token(state).await?;
|
||||
let proj = self.resolve_oauth_project(&token).await?;
|
||||
oauth_token = Some(token);
|
||||
project = Some(proj);
|
||||
response = self
|
||||
.build_generate_content_request(
|
||||
auth,
|
||||
&url,
|
||||
&request,
|
||||
model,
|
||||
true,
|
||||
project.as_deref(),
|
||||
oauth_token.as_deref(),
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
} else {
|
||||
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
||||
}
|
||||
} else {
|
||||
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
||||
}
|
||||
} else if auth.is_oauth()
|
||||
&& Self::should_retry_oauth_without_generation_config(status, &error_text)
|
||||
{
|
||||
tracing::warn!(
|
||||
"Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
|
||||
);
|
||||
response = self
|
||||
.build_generate_content_request(
|
||||
auth,
|
||||
&url,
|
||||
&request,
|
||||
model,
|
||||
false,
|
||||
project.as_deref(),
|
||||
oauth_token.as_deref(),
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
} else {
|
||||
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
||||
}
|
||||
}
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
@ -977,6 +1125,15 @@ mod tests {
|
||||
})))
|
||||
}
|
||||
|
||||
fn test_provider(auth: Option<GeminiAuth>) -> GeminiProvider {
|
||||
GeminiProvider {
|
||||
auth,
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
oauth_cred_paths: Vec::new(),
|
||||
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_non_empty_trims_and_filters() {
|
||||
assert_eq!(
|
||||
@ -1021,28 +1178,19 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn auth_source_explicit_key() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(GeminiAuth::ExplicitKey("key".into())),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(GeminiAuth::ExplicitKey("key".into())));
|
||||
assert_eq!(provider.auth_source(), "config");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_source_none_without_credentials() {
|
||||
let provider = GeminiProvider {
|
||||
auth: None,
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(None);
|
||||
assert_eq!(provider.auth_source(), "none");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_source_oauth() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(test_oauth_auth("ya29.mock")),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(test_oauth_auth("ya29.mock")));
|
||||
assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
|
||||
}
|
||||
|
||||
@ -1093,10 +1241,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn oauth_request_uses_bearer_auth_header() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(test_oauth_auth("ya29.mock-token")),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
||||
let auth = test_oauth_auth("ya29.mock-token");
|
||||
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||
let body = GenerateContentRequest {
|
||||
@ -1137,10 +1282,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn oauth_request_wraps_payload_in_request_envelope() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(test_oauth_auth("ya29.mock-token")),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
||||
let auth = test_oauth_auth("ya29.mock-token");
|
||||
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||
let body = GenerateContentRequest {
|
||||
@ -1184,10 +1326,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn api_key_request_does_not_set_bearer_header() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
|
||||
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
||||
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||
let body = GenerateContentRequest {
|
||||
@ -1610,24 +1749,36 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_without_key_is_noop() {
|
||||
let provider = GeminiProvider {
|
||||
auth: None,
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(None);
|
||||
let result = provider.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_oauth_is_noop() {
|
||||
let provider = GeminiProvider {
|
||||
auth: Some(test_oauth_auth("ya29.mock-token")),
|
||||
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
||||
let result = provider.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discover_oauth_cred_paths_does_not_panic() {
|
||||
let _paths = GeminiProvider::discover_oauth_cred_paths();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rotate_oauth_without_alternatives_returns_false() {
|
||||
let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
|
||||
access_token: "ya29.mock".to_string(),
|
||||
refresh_token: None,
|
||||
client_id: None,
|
||||
client_secret: None,
|
||||
expiry_millis: None,
|
||||
}));
|
||||
let provider = test_provider(Some(GeminiAuth::OAuthToken(state.clone())));
|
||||
assert!(!provider.rotate_oauth_credential(&state).await);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_parses_usage_metadata() {
|
||||
let json = r#"{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user