From d56c0618969b8f0ed838c99ef25063c3fb7ceae1 Mon Sep 17 00:00:00 2001 From: Aleksandr Prilipko Date: Sat, 21 Feb 2026 15:32:29 +0700 Subject: [PATCH] refactor(auth): add Gemini OAuth and consolidate OAuth utilities (DRY) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add src/auth/gemini_oauth.rs: Full Gemini/Google OAuth2 implementation - PKCE authorization code flow with loopback redirect - Device code flow for headless environments - Token refresh with automatic expiration handling - Stdin fallback for remote/headless OAuth callback capture - Add src/auth/oauth_common.rs: Shared OAuth utilities - PkceState struct and generate_pkce_state() - url_encode/url_decode (RFC 3986) - parse_query_params for URL parameter parsing - random_base64url for cryptographic random generation - Update src/auth/mod.rs: Add Gemini support to AuthService - store_gemini_tokens() for saving OAuth tokens - get_valid_gemini_access_token() with automatic refresh - get_gemini_profile() for provider initialization - Update src/main.rs: Generic PendingOAuthLogin - Consolidate PendingOpenAiLogin and PendingGeminiLogin into generic struct - Reduce 10 functions to 4 generic functions - Support both openai-codex and gemini providers in auth commands - Update src/providers/gemini.rs: ManagedOAuth authentication - GeminiAuth enum with ApiKey and ManagedOAuth variants - new_with_auth() constructor for OAuth-based authentication - Automatic token refresh via AuthService integration - Update src/providers/mod.rs: Wire GeminiProvider with AuthService Net reduction: ~290 lines of duplicated code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/auth/gemini_oauth.rs | 575 +++++++++++++++++++++++++++++++++++++++ src/auth/mod.rs | 119 ++++++++ src/auth/oauth_common.rs | 183 +++++++++++++ src/auth/openai_oauth.rs | 95 +------ src/main.rs | 432 +++++++++++++++++++---------- src/providers/gemini.rs | 204 +++++++++++--- src/providers/mod.rs | 17 +- 7 files changed, 1360 insertions(+), 265 deletions(-) create mode 100644 src/auth/gemini_oauth.rs create mode 100644 src/auth/oauth_common.rs diff --git a/src/auth/gemini_oauth.rs b/src/auth/gemini_oauth.rs new file mode 100644 index 000000000..6dfedaec5 --- /dev/null +++ b/src/auth/gemini_oauth.rs @@ -0,0 +1,575 @@ +//! Google/Gemini OAuth2 authentication flow. +//! +//! Supports: +//! - Authorization code flow with PKCE (loopback redirect) +//! - Device code flow for headless environments +//! +//! Uses the same client credentials as Gemini CLI for compatibility. + +use crate::auth::oauth_common::{parse_query_params, url_decode, url_encode}; +use crate::auth::profiles::TokenSet; +use anyhow::{Context, Result}; +use base64::Engine; +use chrono::Utc; +use reqwest::Client; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +// Re-export for external use (used by main.rs) +#[allow(unused_imports)] +pub use crate::auth::oauth_common::{generate_pkce_state, PkceState}; + +/// Get Gemini OAuth client ID from environment. +/// Required: set GEMINI_OAUTH_CLIENT_ID environment variable. +pub fn gemini_oauth_client_id() -> Option { + std::env::var("GEMINI_OAUTH_CLIENT_ID") + .ok() + .filter(|s| !s.is_empty()) +} + +/// Get Gemini OAuth client secret from environment. +/// Required: set GEMINI_OAUTH_CLIENT_SECRET environment variable. +pub fn gemini_oauth_client_secret() -> Option { + std::env::var("GEMINI_OAUTH_CLIENT_SECRET") + .ok() + .filter(|s| !s.is_empty()) +} + +/// Get required OAuth credentials or return error. +fn get_oauth_credentials() -> Result<(String, String)> { + let client_id = gemini_oauth_client_id().ok_or_else(|| { + anyhow::anyhow!("GEMINI_OAUTH_CLIENT_ID environment variable is required") + })?; + let client_secret = gemini_oauth_client_secret().ok_or_else(|| { + anyhow::anyhow!("GEMINI_OAUTH_CLIENT_SECRET environment variable is required") + })?; + Ok((client_id, client_secret)) +} + +pub const GOOGLE_OAUTH_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth"; +pub const GOOGLE_OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token"; +pub const GOOGLE_OAUTH_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code"; +pub const GEMINI_OAUTH_REDIRECT_URI: &str = "http://localhost:1456/auth/callback"; + +/// Scopes required for Gemini API access. +pub const GEMINI_OAUTH_SCOPES: &str = + "openid profile email https://www.googleapis.com/auth/cloud-platform"; + +#[derive(Debug, Clone)] +pub struct DeviceCodeStart { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + pub verification_uri_complete: Option, + pub expires_in: u64, + pub interval: u64, +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + id_token: Option, + #[serde(default)] + expires_in: Option, + #[serde(default)] + token_type: Option, + #[serde(default)] + scope: Option, +} + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_url: String, + #[serde(default)] + expires_in: Option, + #[serde(default)] + interval: Option, +} + +#[derive(Debug, Deserialize)] +struct OAuthErrorResponse { + error: String, + #[serde(default)] + error_description: Option, +} + +pub fn build_authorize_url(pkce: &PkceState) -> Result { + let (client_id, _) = get_oauth_credentials()?; + let mut params = BTreeMap::new(); + params.insert("response_type", "code"); + params.insert("client_id", client_id.as_str()); + params.insert("redirect_uri", GEMINI_OAUTH_REDIRECT_URI); + params.insert("scope", GEMINI_OAUTH_SCOPES); + params.insert("code_challenge", pkce.code_challenge.as_str()); + params.insert("code_challenge_method", "S256"); + params.insert("state", pkce.state.as_str()); + params.insert("access_type", "offline"); + params.insert("prompt", "consent"); + + let mut encoded: Vec = Vec::with_capacity(params.len()); + for (k, v) in params { + encoded.push(format!("{}={}", url_encode(k), url_encode(v))); + } + + Ok(format!( + "{}?{}", + GOOGLE_OAUTH_AUTHORIZE_URL, + encoded.join("&") + )) +} + +pub async fn exchange_code_for_tokens( + client: &Client, + code: &str, + pkce: &PkceState, +) -> Result { + let (client_id, client_secret) = get_oauth_credentials()?; + let form = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", GEMINI_OAUTH_REDIRECT_URI), + ("client_id", client_id.as_str()), + ("client_secret", client_secret.as_str()), + ("code_verifier", &pkce.code_verifier), + ]; + + let response = client + .post(GOOGLE_OAUTH_TOKEN_URL) + .form(&form) + .send() + .await + .context("Failed to send token exchange request")?; + + let status = response.status(); + let body = response + .text() + .await + .context("Failed to read token response body")?; + + if !status.is_success() { + if let Ok(err) = serde_json::from_str::(&body) { + anyhow::bail!( + "Google OAuth error: {} - {}", + err.error, + err.error_description.unwrap_or_default() + ); + } + anyhow::bail!("Google OAuth token exchange failed ({}): {}", status, body); + } + + let token_response: TokenResponse = + serde_json::from_str(&body).context("Failed to parse token response")?; + + let expires_at = token_response + .expires_in + .map(|secs| Utc::now() + chrono::Duration::seconds(secs)); + + Ok(TokenSet { + access_token: token_response.access_token, + refresh_token: token_response.refresh_token, + id_token: token_response.id_token, + expires_at, + token_type: token_response.token_type.or_else(|| Some("Bearer".into())), + scope: token_response.scope, + }) +} + +pub async fn refresh_access_token(client: &Client, refresh_token: &str) -> Result { + let (client_id, client_secret) = get_oauth_credentials()?; + let form = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ("client_id", client_id.as_str()), + ("client_secret", client_secret.as_str()), + ]; + + let response = client + .post(GOOGLE_OAUTH_TOKEN_URL) + .form(&form) + .send() + .await + .context("Failed to send refresh token request")?; + + let status = response.status(); + let body = response + .text() + .await + .context("Failed to read refresh response body")?; + + if !status.is_success() { + if let Ok(err) = serde_json::from_str::(&body) { + anyhow::bail!( + "Google OAuth refresh error: {} - {}", + err.error, + err.error_description.unwrap_or_default() + ); + } + anyhow::bail!("Google OAuth refresh failed ({}): {}", status, body); + } + + let token_response: TokenResponse = + serde_json::from_str(&body).context("Failed to parse refresh response")?; + + let expires_at = token_response + .expires_in + .map(|secs| Utc::now() + chrono::Duration::seconds(secs)); + + Ok(TokenSet { + access_token: token_response.access_token, + refresh_token: token_response.refresh_token, + id_token: token_response.id_token, + expires_at, + token_type: token_response.token_type.or_else(|| Some("Bearer".into())), + scope: token_response.scope, + }) +} + +pub async fn start_device_code_flow(client: &Client) -> Result { + let (client_id, _) = get_oauth_credentials()?; + let form = [ + ("client_id", client_id.as_str()), + ("scope", GEMINI_OAUTH_SCOPES), + ]; + + let response = client + .post(GOOGLE_OAUTH_DEVICE_CODE_URL) + .form(&form) + .send() + .await + .context("Failed to start device code flow")?; + + let status = response.status(); + let body = response + .text() + .await + .context("Failed to read device code response")?; + + if !status.is_success() { + if let Ok(err) = serde_json::from_str::(&body) { + anyhow::bail!( + "Google device code error: {} - {}", + err.error, + err.error_description.unwrap_or_default() + ); + } + anyhow::bail!("Google device code request failed ({}): {}", status, body); + } + + let device_response: DeviceCodeResponse = + serde_json::from_str(&body).context("Failed to parse device code response")?; + + let user_code = device_response.user_code; + let verification_url = device_response.verification_url; + + Ok(DeviceCodeStart { + device_code: device_response.device_code, + verification_uri_complete: Some(format!("{}?user_code={}", &verification_url, &user_code)), + user_code, + verification_uri: verification_url, + expires_in: device_response.expires_in.unwrap_or(1800), + interval: device_response.interval.unwrap_or(5), + }) +} + +pub async fn poll_device_code_tokens( + client: &Client, + device: &DeviceCodeStart, +) -> Result { + let (client_id, client_secret) = get_oauth_credentials()?; + let deadline = std::time::Instant::now() + Duration::from_secs(device.expires_in); + let interval = Duration::from_secs(device.interval.max(5)); + + loop { + if std::time::Instant::now() > deadline { + anyhow::bail!("Device code expired before authorization was completed"); + } + + tokio::time::sleep(interval).await; + + let form = [ + ("client_id", client_id.as_str()), + ("client_secret", client_secret.as_str()), + ("device_code", device.device_code.as_str()), + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ]; + + let response = client + .post(GOOGLE_OAUTH_TOKEN_URL) + .form(&form) + .send() + .await + .context("Failed to poll device code")?; + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if status.is_success() { + let token_response: TokenResponse = + serde_json::from_str(&body).context("Failed to parse token response")?; + + let expires_at = token_response + .expires_in + .map(|secs| Utc::now() + chrono::Duration::seconds(secs)); + + return Ok(TokenSet { + access_token: token_response.access_token, + refresh_token: token_response.refresh_token, + id_token: token_response.id_token, + expires_at, + token_type: token_response.token_type.or_else(|| Some("Bearer".into())), + scope: token_response.scope, + }); + } + + if let Ok(err) = serde_json::from_str::(&body) { + match err.error.as_str() { + "authorization_pending" => {} + "slow_down" => { + tokio::time::sleep(Duration::from_secs(5)).await; + } + "access_denied" => { + anyhow::bail!("User denied authorization"); + } + "expired_token" => { + anyhow::bail!("Device code expired"); + } + _ => { + anyhow::bail!( + "Google OAuth error: {} - {}", + err.error, + err.error_description.unwrap_or_default() + ); + } + } + } + } +} + +/// Receive OAuth code via loopback callback OR manual stdin input. +/// +/// If the callback server can't receive the redirect (e.g., remote/headless environment), +/// the user can paste the full callback URL or just the code. +pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result { + // Try to bind to the callback port + let listener = match TcpListener::bind("127.0.0.1:1456").await { + Ok(l) => l, + Err(e) => { + eprintln!("Could not bind to localhost:1456: {e}"); + eprintln!("Falling back to manual input."); + return receive_code_from_stdin(expected_state).await; + } + }; + + println!("Waiting for callback at http://localhost:1456/auth/callback ..."); + println!("(Or paste the full callback URL / authorization code here if running remotely)"); + + // Race between: callback arriving OR stdin input + tokio::select! { + accept_result = async { + tokio::time::timeout(timeout, listener.accept()).await + } => { + match accept_result { + Ok(Ok((mut stream, _))) => { + let mut buffer = vec![0u8; 4096]; + let n = stream + .read(&mut buffer) + .await + .context("Failed to read from callback connection")?; + + let request = String::from_utf8_lossy(&buffer[..n]); + let (code, state) = parse_callback_request(&request)?; + + if state != expected_state { + let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\ +

State mismatch

Please try again.

"; + let _ = stream.write_all(response.as_bytes()).await; + anyhow::bail!("OAuth state mismatch"); + } + + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\ +

Success!

You can close this window and return to the terminal.

"; + let _ = stream.write_all(response.as_bytes()).await; + + Ok(code) + } + Ok(Err(e)) => Err(anyhow::anyhow!("Failed to accept connection: {e}")), + Err(_) => { + eprintln!("\nCallback timeout. Falling back to manual input."); + receive_code_from_stdin(expected_state).await + } + } + } + stdin_result = receive_code_from_stdin(expected_state) => { + stdin_result + } + } +} + +/// Read authorization code from stdin (supports full URL or raw code). +async fn receive_code_from_stdin(expected_state: &str) -> Result { + use std::io::{self, BufRead}; + + let expected = expected_state.to_string(); + let input = tokio::task::spawn_blocking(move || { + let stdin = io::stdin(); + let mut line = String::new(); + stdin.lock().read_line(&mut line).ok(); + let trimmed = line.trim().to_string(); + if trimmed.is_empty() { + return Err(anyhow::anyhow!("No input received")); + } + parse_code_from_redirect(&trimmed, Some(&expected)) + }) + .await + .context("Failed to read from stdin")??; + + Ok(input) +} + +fn parse_callback_request(request: &str) -> Result<(String, String)> { + let first_line = request.lines().next().unwrap_or(""); + let path = first_line + .split_whitespace() + .nth(1) + .unwrap_or("") + .to_string(); + + let query_start = path.find('?').map(|i| i + 1).unwrap_or(path.len()); + let query = &path[query_start..]; + + let mut code = None; + let mut state = None; + + for pair in query.split('&') { + if let Some((key, value)) = pair.split_once('=') { + match key { + "code" => code = Some(url_decode(value)), + "state" => state = Some(url_decode(value)), + _ => {} + } + } + } + + let code = code.ok_or_else(|| anyhow::anyhow!("No 'code' parameter in callback"))?; + let state = state.ok_or_else(|| anyhow::anyhow!("No 'state' parameter in callback"))?; + + Ok((code, state)) +} + +pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result { + let trimmed = input.trim(); + if trimmed.is_empty() { + anyhow::bail!("No OAuth code provided"); + } + + // Extract query string + let query = if let Some((_, right)) = trimmed.split_once('?') { + right + } else { + trimmed + }; + + let params = parse_query_params(query); + + // If we have code param, extract it + if let Some(code) = params.get("code") { + // Validate state if expected + if let Some(expected) = expected_state { + if let Some(actual) = params.get("state") { + if actual != expected { + anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}"); + } + } + } + return Ok(code.clone()); + } + + // Otherwise, assume it's the raw code (if long enough and no spaces) + if trimmed.len() > 10 && !trimmed.contains(' ') && !trimmed.contains('&') { + return Ok(trimmed.to_string()); + } + + anyhow::bail!("Could not parse OAuth code from input") +} + +/// Extract account email from Google ID token. +pub fn extract_account_email_from_id_token(id_token: &str) -> Option { + let parts: Vec<&str> = id_token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .ok()?; + + #[derive(Deserialize)] + struct IdTokenPayload { + email: Option, + } + + let payload: IdTokenPayload = serde_json::from_slice(&payload).ok()?; + payload.email +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pkce_generates_valid_state() { + let pkce = generate_pkce_state(); + assert!(!pkce.code_verifier.is_empty()); + assert!(!pkce.code_challenge.is_empty()); + assert!(!pkce.state.is_empty()); + } + + #[test] + fn authorize_url_contains_required_params() { + // Set test credentials + std::env::set_var("GEMINI_OAUTH_CLIENT_ID", "test-client-id"); + std::env::set_var("GEMINI_OAUTH_CLIENT_SECRET", "test-client-secret"); + + let pkce = generate_pkce_state(); + let url = build_authorize_url(&pkce).expect("Failed to build authorize URL"); + assert!(url.contains("accounts.google.com")); + assert!(url.contains("client_id=")); + assert!(url.contains("redirect_uri=")); + assert!(url.contains("code_challenge=")); + assert!(url.contains("access_type=offline")); + } + + #[test] + fn parse_code_from_url() { + let url = "http://localhost:1456/auth/callback?code=4/0test&state=xyz"; + let code = parse_code_from_redirect(url, Some("xyz")).unwrap(); + assert_eq!(code, "4/0test"); + } + + #[test] + fn parse_code_from_raw() { + let raw = "4/0AcvDMrC1234567890abcdef"; + let code = parse_code_from_redirect(raw, None).unwrap(); + assert_eq!(code, raw); + } + + #[test] + fn extract_email_from_id_token() { + // Minimal test JWT with email claim + let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#); + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(r#"{"email":"test@example.com"}"#); + let token = format!("{}.{}.signature", header, payload); + + let email = extract_account_email_from_id_token(&token); + assert_eq!(email, Some("test@example.com".to_string())); + } +} diff --git a/src/auth/mod.rs b/src/auth/mod.rs index bf6749656..a9ebd9963 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,4 +1,6 @@ pub mod anthropic_token; +pub mod gemini_oauth; +pub mod oauth_common; pub mod openai_oauth; pub mod profiles; @@ -15,6 +17,7 @@ use std::time::{Duration, Instant}; const OPENAI_CODEX_PROVIDER: &str = "openai-codex"; const ANTHROPIC_PROVIDER: &str = "anthropic"; +const GEMINI_PROVIDER: &str = "gemini"; const DEFAULT_PROFILE_NAME: &str = "default"; const OPENAI_REFRESH_SKEW_SECS: u64 = 90; const OPENAI_REFRESH_FAILURE_BACKOFF_SECS: u64 = 10; @@ -58,6 +61,21 @@ impl AuthService { Ok(profile) } + pub async fn store_gemini_tokens( + &self, + profile_name: &str, + token_set: crate::auth::profiles::TokenSet, + account_id: Option, + set_active: bool, + ) -> Result { + let mut profile = AuthProfile::new_oauth(GEMINI_PROVIDER, profile_name, token_set); + profile.account_id = account_id; + self.store + .upsert_profile(profile.clone(), set_active) + .await?; + Ok(profile) + } + pub async fn store_provider_token( &self, provider: &str, @@ -224,6 +242,106 @@ impl AuthService { Ok(updated.token_set.map(|t| t.access_token)) } + + /// Get a valid Gemini OAuth access token, refreshing if necessary. + /// + /// Returns `None` if no Gemini profile exists. + pub async fn get_valid_gemini_access_token( + &self, + profile_override: Option<&str>, + ) -> Result> { + let data = self.store.load().await?; + let Some(profile_id) = select_profile_id(&data, GEMINI_PROVIDER, profile_override) else { + return Ok(None); + }; + + let Some(profile) = data.profiles.get(&profile_id) else { + return Ok(None); + }; + + let Some(token_set) = profile.token_set.as_ref() else { + anyhow::bail!("Gemini auth profile is not OAuth-based: {profile_id}"); + }; + + if !token_set.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) { + return Ok(Some(token_set.access_token.clone())); + } + + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Ok(Some(token_set.access_token.clone())); + }; + + let refresh_lock = refresh_lock_for_profile(&profile_id); + let _guard = refresh_lock.lock().await; + + // Re-load after waiting for lock to avoid duplicate refreshes. + let data = self.store.load().await?; + let Some(latest_profile) = data.profiles.get(&profile_id) else { + return Ok(None); + }; + + let Some(latest_tokens) = latest_profile.token_set.as_ref() else { + anyhow::bail!("Gemini auth profile is missing token set: {profile_id}"); + }; + + if !latest_tokens.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) { + return Ok(Some(latest_tokens.access_token.clone())); + } + + let refresh_token = latest_tokens.refresh_token.clone().unwrap_or(refresh_token); + + if let Some(remaining) = refresh_backoff_remaining(&profile_id) { + anyhow::bail!( + "Gemini token refresh is in backoff for {remaining}s due to previous failures" + ); + } + + let mut refreshed = + match gemini_oauth::refresh_access_token(&self.client, &refresh_token).await { + Ok(tokens) => { + clear_refresh_backoff(&profile_id); + tokens + } + Err(err) => { + set_refresh_backoff( + &profile_id, + Duration::from_secs(OPENAI_REFRESH_FAILURE_BACKOFF_SECS), + ); + return Err(err); + } + }; + if refreshed.refresh_token.is_none() { + refreshed + .refresh_token + .clone_from(&latest_tokens.refresh_token); + } + + let account_id = refreshed + .id_token + .as_deref() + .and_then(gemini_oauth::extract_account_email_from_id_token) + .or_else(|| latest_profile.account_id.clone()); + + let updated = self + .store + .update_profile(&profile_id, |profile| { + profile.kind = AuthProfileKind::OAuth; + profile.token_set = Some(refreshed.clone()); + profile.account_id.clone_from(&account_id); + Ok(()) + }) + .await?; + + Ok(updated.token_set.map(|t| t.access_token)) + } + + /// Get Gemini profile info (for provider initialization). + pub async fn get_gemini_profile( + &self, + profile_override: Option<&str>, + ) -> Result> { + self.get_profile(GEMINI_PROVIDER, profile_override).await + } } pub fn normalize_provider(provider: &str) -> Result { @@ -231,6 +349,7 @@ pub fn normalize_provider(provider: &str) -> Result { match normalized.as_str() { "openai-codex" | "openai_codex" | "codex" => Ok(OPENAI_CODEX_PROVIDER.to_string()), "anthropic" | "claude" | "claude-code" => Ok(ANTHROPIC_PROVIDER.to_string()), + "gemini" | "google" | "vertex" => Ok(GEMINI_PROVIDER.to_string()), other if !other.is_empty() => Ok(other.to_string()), _ => anyhow::bail!("Provider name cannot be empty"), } diff --git a/src/auth/oauth_common.rs b/src/auth/oauth_common.rs new file mode 100644 index 000000000..b279c800e --- /dev/null +++ b/src/auth/oauth_common.rs @@ -0,0 +1,183 @@ +//! Common OAuth2 utilities shared across providers. +//! +//! This module contains shared functionality for OAuth2 authentication: +//! - PKCE (Proof Key for Code Exchange) state generation +//! - URL encoding/decoding +//! - Query parameter parsing + +use base64::Engine; +use sha2::{Digest, Sha256}; +use std::collections::BTreeMap; + +/// PKCE state container for OAuth2 authorization code flow. +#[derive(Debug, Clone)] +pub struct PkceState { + pub code_verifier: String, + pub code_challenge: String, + pub state: String, +} + +/// Generate a new PKCE state with cryptographically random values. +/// +/// Creates a code verifier, derives the S256 code challenge, and generates +/// a random state parameter for CSRF protection. +pub fn generate_pkce_state() -> PkceState { + let code_verifier = random_base64url(64); + let digest = Sha256::digest(code_verifier.as_bytes()); + let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); + + PkceState { + code_verifier, + code_challenge, + state: random_base64url(24), + } +} + +/// Generate a cryptographically random base64url-encoded string. +pub fn random_base64url(byte_len: usize) -> String { + use chacha20poly1305::aead::{rand_core::RngCore, OsRng}; + + let mut bytes = vec![0_u8; byte_len]; + OsRng.fill_bytes(&mut bytes); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) +} + +/// URL-encode a string using percent encoding (RFC 3986). +pub fn url_encode(input: &str) -> String { + input + .bytes() + .map(|b| match b { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + (b as char).to_string() + } + _ => format!("%{b:02X}"), + }) + .collect::() +} + +/// URL-decode a percent-encoded string. +pub fn url_decode(input: &str) -> String { + let bytes = input.as_bytes(); + let mut out = Vec::with_capacity(bytes.len()); + let mut i = 0; + + while i < bytes.len() { + match bytes[i] { + b'%' if i + 2 < bytes.len() => { + let hi = bytes[i + 1] as char; + let lo = bytes[i + 2] as char; + if let (Some(h), Some(l)) = (hi.to_digit(16), lo.to_digit(16)) { + if let Ok(value) = u8::try_from(h * 16 + l) { + out.push(value); + i += 3; + continue; + } + } + out.push(bytes[i]); + i += 1; + } + b'+' => { + out.push(b' '); + i += 1; + } + b => { + out.push(b); + i += 1; + } + } + } + + String::from_utf8_lossy(&out).to_string() +} + +/// Parse URL query parameters into a BTreeMap. +/// +/// Handles URL-encoded keys and values. +pub fn parse_query_params(input: &str) -> BTreeMap { + let mut out = BTreeMap::new(); + for pair in input.split('&') { + if pair.is_empty() { + continue; + } + let (key, value) = match pair.split_once('=') { + Some((k, v)) => (k, v), + None => (pair, ""), + }; + out.insert(url_decode(key), url_decode(value)); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pkce_generation_is_valid() { + let pkce = generate_pkce_state(); + // Code verifier should be at least 43 chars (base64url of 32 bytes) + assert!(pkce.code_verifier.len() >= 43); + assert!(!pkce.code_challenge.is_empty()); + assert!(!pkce.state.is_empty()); + } + + #[test] + fn pkce_challenge_is_sha256_of_verifier() { + let pkce = generate_pkce_state(); + let expected = { + let digest = Sha256::digest(pkce.code_verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) + }; + assert_eq!(pkce.code_challenge, expected); + } + + #[test] + fn url_encode_basic() { + assert_eq!(url_encode("hello"), "hello"); + assert_eq!(url_encode("hello world"), "hello%20world"); + assert_eq!(url_encode("a=b&c=d"), "a%3Db%26c%3Dd"); + } + + #[test] + fn url_decode_basic() { + assert_eq!(url_decode("hello"), "hello"); + assert_eq!(url_decode("hello%20world"), "hello world"); + assert_eq!(url_decode("hello+world"), "hello world"); + assert_eq!(url_decode("a%3Db%26c%3Dd"), "a=b&c=d"); + } + + #[test] + fn url_encode_decode_roundtrip() { + let original = "hello world! @#$%^&*()"; + let encoded = url_encode(original); + let decoded = url_decode(&encoded); + assert_eq!(decoded, original); + } + + #[test] + fn parse_query_params_basic() { + let params = parse_query_params("code=abc123&state=xyz"); + assert_eq!(params.get("code"), Some(&"abc123".to_string())); + assert_eq!(params.get("state"), Some(&"xyz".to_string())); + } + + #[test] + fn parse_query_params_encoded() { + let params = parse_query_params("name=hello%20world&value=a%3Db"); + assert_eq!(params.get("name"), Some(&"hello world".to_string())); + assert_eq!(params.get("value"), Some(&"a=b".to_string())); + } + + #[test] + fn parse_query_params_empty() { + let params = parse_query_params(""); + assert!(params.is_empty()); + } + + #[test] + fn random_base64url_length() { + let s = random_base64url(32); + // base64url encodes 3 bytes to 4 chars, so 32 bytes = ~43 chars + assert!(s.len() >= 42); + } +} diff --git a/src/auth/openai_oauth.rs b/src/auth/openai_oauth.rs index 1acf4ab31..8e6442ddb 100644 --- a/src/auth/openai_oauth.rs +++ b/src/auth/openai_oauth.rs @@ -1,28 +1,26 @@ +use crate::auth::oauth_common::{parse_query_params, url_encode}; + use crate::auth::profiles::TokenSet; use anyhow::{Context, Result}; use base64::Engine; use chrono::Utc; use reqwest::Client; use serde::Deserialize; -use sha2::{Digest, Sha256}; use std::collections::BTreeMap; use std::time::{Duration, Instant}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; +// Re-export for external use (used by main.rs) +#[allow(unused_imports)] +pub use crate::auth::oauth_common::{generate_pkce_state, PkceState}; + pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_OAUTH_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize"; pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; pub const OPENAI_OAUTH_DEVICE_CODE_URL: &str = "https://auth.openai.com/oauth/device/code"; pub const OPENAI_OAUTH_REDIRECT_URI: &str = "http://localhost:1455/auth/callback"; -#[derive(Debug, Clone)] -pub struct PkceState { - pub code_verifier: String, - pub code_challenge: String, - pub state: String, -} - #[derive(Debug, Clone)] pub struct DeviceCodeStart { pub device_code: String, @@ -70,18 +68,6 @@ struct OAuthErrorResponse { error_description: Option, } -pub fn generate_pkce_state() -> PkceState { - let code_verifier = random_base64url(64); - let digest = Sha256::digest(code_verifier.as_bytes()); - let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); - - PkceState { - code_verifier, - code_challenge, - state: random_base64url(24), - } -} - pub fn build_authorize_url(pkce: &PkceState) -> String { let mut params = BTreeMap::new(); params.insert("response_type", "code"); @@ -382,75 +368,6 @@ async fn parse_token_response(response: reqwest::Response) -> Result { }) } -fn parse_query_params(input: &str) -> BTreeMap { - let mut out = BTreeMap::new(); - for pair in input.split('&') { - if pair.is_empty() { - continue; - } - let (key, value) = match pair.split_once('=') { - Some((k, v)) => (k, v), - None => (pair, ""), - }; - out.insert(url_decode(key), url_decode(value)); - } - out -} - -fn random_base64url(byte_len: usize) -> String { - use chacha20poly1305::aead::{rand_core::RngCore, OsRng}; - - let mut bytes = vec![0_u8; byte_len]; - OsRng.fill_bytes(&mut bytes); - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) -} - -fn url_encode(input: &str) -> String { - input - .bytes() - .map(|b| match b { - b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { - (b as char).to_string() - } - _ => format!("%{b:02X}"), - }) - .collect::() -} - -fn url_decode(input: &str) -> String { - let bytes = input.as_bytes(); - let mut out = Vec::with_capacity(bytes.len()); - let mut i = 0; - - while i < bytes.len() { - match bytes[i] { - b'%' if i + 2 < bytes.len() => { - let hi = bytes[i + 1] as char; - let lo = bytes[i + 2] as char; - if let (Some(h), Some(l)) = (hi.to_digit(16), lo.to_digit(16)) { - if let Ok(value) = u8::try_from(h * 16 + l) { - out.push(value); - i += 3; - continue; - } - } - out.push(bytes[i]); - i += 1; - } - b'+' => { - out.push(b' '); - i += 1; - } - b => { - out.push(b); - i += 1; - } - } - } - - String::from_utf8_lossy(&out).to_string() -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/main.rs b/src/main.rs index 1f1c48ec4..39832c45a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -425,9 +425,9 @@ enum ConfigCommands { #[derive(Subcommand, Debug)] enum AuthCommands { - /// Login with OpenAI Codex OAuth + /// Login with OAuth (OpenAI Codex or Gemini) Login { - /// Provider (`openai-codex`) + /// Provider (`openai-codex` or `gemini`) #[arg(long)] provider: String, /// Profile name (default: default) @@ -948,8 +948,12 @@ fn write_shell_completion(shell: CompletionShell, writer: &mut W) -> R Ok(()) } +// ─── Generic Pending OAuth Login ──────────────────────────────────────────── + +/// Generic pending OAuth login state, shared across providers. #[derive(Debug, Clone, Serialize, Deserialize)] -struct PendingOpenAiLogin { +struct PendingOAuthLogin { + provider: String, profile: String, code_verifier: String, state: String, @@ -957,7 +961,9 @@ struct PendingOpenAiLogin { } #[derive(Debug, Clone, Serialize, Deserialize)] -struct PendingOpenAiLoginFile { +struct PendingOAuthLoginFile { + #[serde(default)] + provider: Option, profile: String, #[serde(skip_serializing_if = "Option::is_none")] code_verifier: Option, @@ -967,11 +973,12 @@ struct PendingOpenAiLoginFile { created_at: String, } -fn pending_openai_login_path(config: &Config) -> std::path::PathBuf { - auth::state_dir_from_config(config).join("auth-openai-pending.json") +fn pending_oauth_login_path(config: &Config, provider: &str) -> std::path::PathBuf { + let filename = format!("auth-{}-pending.json", provider); + auth::state_dir_from_config(config).join(filename) } -fn pending_openai_secret_store(config: &Config) -> security::secrets::SecretStore { +fn pending_oauth_secret_store(config: &Config) -> security::secrets::SecretStore { security::secrets::SecretStore::new( &auth::state_dir_from_config(config), config.secrets.encrypt, @@ -990,14 +997,15 @@ fn set_owner_only_permissions(_path: &std::path::Path) -> Result<()> { Ok(()) } -fn save_pending_openai_login(config: &Config, pending: &PendingOpenAiLogin) -> Result<()> { - let path = pending_openai_login_path(config); +fn save_pending_oauth_login(config: &Config, pending: &PendingOAuthLogin) -> Result<()> { + let path = pending_oauth_login_path(config, &pending.provider); if let Some(parent) = path.parent() { std::fs::create_dir_all(parent)?; } - let secret_store = pending_openai_secret_store(config); + let secret_store = pending_oauth_secret_store(config); let encrypted_code_verifier = secret_store.encrypt(&pending.code_verifier)?; - let persisted = PendingOpenAiLoginFile { + let persisted = PendingOAuthLoginFile { + provider: Some(pending.provider.clone()), profile: pending.profile.clone(), code_verifier: None, encrypted_code_verifier: Some(encrypted_code_verifier), @@ -1017,25 +1025,26 @@ fn save_pending_openai_login(config: &Config, pending: &PendingOpenAiLogin) -> R Ok(()) } -fn load_pending_openai_login(config: &Config) -> Result> { - let path = pending_openai_login_path(config); +fn load_pending_oauth_login(config: &Config, provider: &str) -> Result> { + let path = pending_oauth_login_path(config, provider); if !path.exists() { return Ok(None); } - let bytes = std::fs::read(path)?; + let bytes = std::fs::read(&path)?; if bytes.is_empty() { return Ok(None); } - let persisted: PendingOpenAiLoginFile = serde_json::from_slice(&bytes)?; - let secret_store = pending_openai_secret_store(config); + let persisted: PendingOAuthLoginFile = serde_json::from_slice(&bytes)?; + let secret_store = pending_oauth_secret_store(config); let code_verifier = if let Some(encrypted) = persisted.encrypted_code_verifier { secret_store.decrypt(&encrypted)? } else if let Some(plaintext) = persisted.code_verifier { plaintext } else { - bail!("Pending OpenAI login is missing code verifier"); + bail!("Pending {} login is missing code verifier", provider); }; - Ok(Some(PendingOpenAiLogin { + Ok(Some(PendingOAuthLogin { + provider: persisted.provider.unwrap_or_else(|| provider.to_string()), profile: persisted.profile, code_verifier, state: persisted.state, @@ -1043,8 +1052,8 @@ fn load_pending_openai_login(config: &Config) -> Result Res device_code, } => { let provider = auth::normalize_provider(&provider)?; - if provider != "openai-codex" { - bail!("`auth login` currently supports only --provider openai-codex"); - } - let client = reqwest::Client::new(); - if device_code { - match auth::openai_oauth::start_device_code_flow(&client).await { - Ok(device) => { - println!("OpenAI device-code login started."); - println!("Visit: {}", device.verification_uri); - println!("Code: {}", device.user_code); - if let Some(uri_complete) = &device.verification_uri_complete { - println!("Fast link: {uri_complete}"); + match provider.as_str() { + "gemini" => { + // Gemini OAuth flow + if device_code { + match auth::gemini_oauth::start_device_code_flow(&client).await { + Ok(device) => { + println!("Google/Gemini device-code login started."); + println!("Visit: {}", device.verification_uri); + println!("Code: {}", device.user_code); + if let Some(uri_complete) = &device.verification_uri_complete { + println!("Fast link: {uri_complete}"); + } + + let token_set = + auth::gemini_oauth::poll_device_code_tokens(&client, &device) + .await?; + let account_id = token_set.id_token.as_deref().and_then( + auth::gemini_oauth::extract_account_email_from_id_token, + ); + + auth_service + .store_gemini_tokens(&profile, token_set, account_id, true) + .await?; + + println!("Saved profile {profile}"); + println!("Active profile for gemini: {profile}"); + return Ok(()); + } + Err(e) => { + println!( + "Device-code flow unavailable: {e}. Falling back to browser flow." + ); + } } - if let Some(message) = &device.message { - println!("{message}"); + } + + let pkce = auth::gemini_oauth::generate_pkce_state(); + let authorize_url = auth::gemini_oauth::build_authorize_url(&pkce)?; + + // Save pending login for paste-redirect fallback + let pending = PendingOAuthLogin { + provider: "gemini".to_string(), + profile: profile.clone(), + code_verifier: pkce.code_verifier.clone(), + state: pkce.state.clone(), + created_at: chrono::Utc::now().to_rfc3339(), + }; + save_pending_oauth_login(config, &pending)?; + + println!("Open this URL in your browser and authorize access:"); + println!("{authorize_url}"); + println!(); + + let code = match auth::gemini_oauth::receive_loopback_code( + &pkce.state, + std::time::Duration::from_secs(180), + ) + .await + { + Ok(code) => { + clear_pending_oauth_login(config, "gemini"); + code } + Err(e) => { + println!("Callback capture failed: {e}"); + println!( + "Run `zeroclaw auth paste-redirect --provider gemini --profile {profile}`" + ); + return Ok(()); + } + }; - let token_set = - auth::openai_oauth::poll_device_code_tokens(&client, &device).await?; - let account_id = - extract_openai_account_id_for_profile(&token_set.access_token); + let token_set = + auth::gemini_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; + let account_id = token_set + .id_token + .as_deref() + .and_then(auth::gemini_oauth::extract_account_email_from_id_token); - auth_service - .store_openai_tokens(&profile, token_set, account_id, true) - .await?; - clear_pending_openai_login(config); + auth_service + .store_gemini_tokens(&profile, token_set, account_id, true) + .await?; - println!("Saved profile {profile}"); - println!("Active profile for openai-codex: {profile}"); - return Ok(()); - } - Err(e) => { - println!( - "Device-code flow unavailable: {e}. Falling back to browser/paste flow." - ); + println!("Saved profile {profile}"); + println!("Active profile for gemini: {profile}"); + Ok(()) + } + "openai-codex" => { + // OpenAI Codex OAuth flow + if device_code { + match auth::openai_oauth::start_device_code_flow(&client).await { + Ok(device) => { + println!("OpenAI device-code login started."); + println!("Visit: {}", device.verification_uri); + println!("Code: {}", device.user_code); + if let Some(uri_complete) = &device.verification_uri_complete { + println!("Fast link: {uri_complete}"); + } + if let Some(message) = &device.message { + println!("{message}"); + } + + let token_set = + auth::openai_oauth::poll_device_code_tokens(&client, &device) + .await?; + let account_id = + extract_openai_account_id_for_profile(&token_set.access_token); + + auth_service + .store_openai_tokens(&profile, token_set, account_id, true) + .await?; + clear_pending_oauth_login(config, "openai"); + + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); + return Ok(()); + } + Err(e) => { + println!( + "Device-code flow unavailable: {e}. Falling back to browser/paste flow." + ); + } + } } + + let pkce = auth::openai_oauth::generate_pkce_state(); + let pending = PendingOAuthLogin { + provider: "openai".to_string(), + profile: profile.clone(), + code_verifier: pkce.code_verifier.clone(), + state: pkce.state.clone(), + created_at: chrono::Utc::now().to_rfc3339(), + }; + save_pending_oauth_login(config, &pending)?; + + let authorize_url = auth::openai_oauth::build_authorize_url(&pkce); + println!("Open this URL in your browser and authorize access:"); + println!("{authorize_url}"); + println!(); + println!("Waiting for callback at http://localhost:1455/auth/callback ..."); + + let code = match auth::openai_oauth::receive_loopback_code( + &pkce.state, + std::time::Duration::from_secs(180), + ) + .await + { + Ok(code) => code, + Err(e) => { + println!("Callback capture failed: {e}"); + println!( + "Run `zeroclaw auth paste-redirect --provider openai-codex --profile {profile}`" + ); + return Ok(()); + } + }; + + let token_set = + auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; + let account_id = extract_openai_account_id_for_profile(&token_set.access_token); + + auth_service + .store_openai_tokens(&profile, token_set, account_id, true) + .await?; + clear_pending_oauth_login(config, "openai"); + + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); + Ok(()) + } + _ => { + bail!( + "`auth login` supports --provider openai-codex or gemini, got: {provider}" + ); } } - - let pkce = auth::openai_oauth::generate_pkce_state(); - let pending = PendingOpenAiLogin { - profile: profile.clone(), - code_verifier: pkce.code_verifier.clone(), - state: pkce.state.clone(), - created_at: chrono::Utc::now().to_rfc3339(), - }; - save_pending_openai_login(config, &pending)?; - - let authorize_url = auth::openai_oauth::build_authorize_url(&pkce); - println!("Open this URL in your browser and authorize access:"); - println!("{authorize_url}"); - println!(); - println!("Waiting for callback at http://localhost:1455/auth/callback ..."); - - let code = match auth::openai_oauth::receive_loopback_code( - &pkce.state, - std::time::Duration::from_secs(180), - ) - .await - { - Ok(code) => code, - Err(e) => { - println!("Callback capture failed: {e}"); - println!( - "Run `zeroclaw auth paste-redirect --provider openai-codex --profile {profile}`" - ); - return Ok(()); - } - }; - - let token_set = - auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; - let account_id = extract_openai_account_id_for_profile(&token_set.access_token); - - auth_service - .store_openai_tokens(&profile, token_set, account_id, true) - .await?; - clear_pending_openai_login(config); - - println!("Saved profile {profile}"); - println!("Active profile for openai-codex: {profile}"); - Ok(()) } AuthCommands::PasteRedirect { @@ -1198,52 +1301,103 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res input, } => { let provider = auth::normalize_provider(&provider)?; - if provider != "openai-codex" { - bail!("`auth paste-redirect` currently supports only --provider openai-codex"); + + match provider.as_str() { + "openai-codex" => { + let pending = load_pending_oauth_login(config, "openai")?.ok_or_else(|| { + anyhow::anyhow!( + "No pending OpenAI login found. Run `zeroclaw auth login --provider openai-codex` first." + ) + })?; + + if pending.profile != profile { + bail!( + "Pending login profile mismatch: pending={}, requested={}", + pending.profile, + profile + ); + } + + let redirect_input = match input { + Some(value) => value, + None => read_plain_input("Paste redirect URL or OAuth code")?, + }; + + let code = auth::openai_oauth::parse_code_from_redirect( + &redirect_input, + Some(&pending.state), + )?; + + let pkce = auth::openai_oauth::PkceState { + code_verifier: pending.code_verifier.clone(), + code_challenge: String::new(), + state: pending.state.clone(), + }; + + let client = reqwest::Client::new(); + let token_set = + auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; + let account_id = extract_openai_account_id_for_profile(&token_set.access_token); + + auth_service + .store_openai_tokens(&profile, token_set, account_id, true) + .await?; + clear_pending_oauth_login(config, "openai"); + + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); + } + "gemini" => { + let pending = load_pending_oauth_login(config, "gemini")?.ok_or_else(|| { + anyhow::anyhow!( + "No pending Gemini login found. Run `zeroclaw auth login --provider gemini` first." + ) + })?; + + if pending.profile != profile { + bail!( + "Pending login profile mismatch: pending={}, requested={}", + pending.profile, + profile + ); + } + + let redirect_input = match input { + Some(value) => value, + None => read_plain_input("Paste redirect URL or OAuth code")?, + }; + + let code = auth::gemini_oauth::parse_code_from_redirect( + &redirect_input, + Some(&pending.state), + )?; + + let pkce = auth::gemini_oauth::PkceState { + code_verifier: pending.code_verifier.clone(), + code_challenge: String::new(), + state: pending.state.clone(), + }; + + let client = reqwest::Client::new(); + let token_set = + auth::gemini_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; + let account_id = token_set + .id_token + .as_deref() + .and_then(auth::gemini_oauth::extract_account_email_from_id_token); + + auth_service + .store_gemini_tokens(&profile, token_set, account_id, true) + .await?; + clear_pending_oauth_login(config, "gemini"); + + println!("Saved profile {profile}"); + println!("Active profile for gemini: {profile}"); + } + _ => { + bail!("`auth paste-redirect` supports --provider openai-codex or gemini"); + } } - - let pending = load_pending_openai_login(config)?.ok_or_else(|| { - anyhow::anyhow!( - "No pending OpenAI login found. Run `zeroclaw auth login --provider openai-codex` first." - ) - })?; - - if pending.profile != profile { - bail!( - "Pending login profile mismatch: pending={}, requested={}", - pending.profile, - profile - ); - } - - let redirect_input = match input { - Some(value) => value, - None => read_plain_input("Paste redirect URL or OAuth code")?, - }; - - let code = auth::openai_oauth::parse_code_from_redirect( - &redirect_input, - Some(&pending.state), - )?; - - let pkce = auth::openai_oauth::PkceState { - code_verifier: pending.code_verifier.clone(), - code_challenge: String::new(), - state: pending.state.clone(), - }; - - let client = reqwest::Client::new(); - let token_set = - auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; - let account_id = extract_openai_account_id_for_profile(&token_set.access_token); - - auth_service - .store_openai_tokens(&profile, token_set, account_id, true) - .await?; - clear_pending_openai_login(config); - - println!("Saved profile {profile}"); - println!("Active profile for openai-codex: {profile}"); Ok(()) } diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 786661c03..ae393f40c 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -1,8 +1,10 @@ //! Google Gemini provider with support for: //! - Direct API key (`GEMINI_API_KEY` env var or config) //! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) +//! - ZeroClaw auth-profiles OAuth tokens //! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) +use crate::auth::AuthService; use crate::providers::traits::{ChatMessage, ChatResponse, Provider, TokenUsage}; use async_trait::async_trait; use directories::UserDirs; @@ -17,6 +19,10 @@ pub struct GeminiProvider { oauth_project: Arc>>, oauth_cred_paths: Vec, oauth_index: Arc>, + /// AuthService for managed profiles (auth-profiles.json). + auth_service: Option, + /// Override profile name for managed auth. + auth_profile_override: Option, } /// Mutable OAuth token state — supports runtime refresh for long-lived processes. @@ -41,6 +47,9 @@ enum GeminiAuth { /// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`. /// Wrapped in a Mutex to allow runtime token refresh. OAuthToken(Arc>), + /// OAuth token managed by AuthService (auth-profiles.json). + /// Token refresh is handled by AuthService, not here. + ManagedOAuth, } impl GeminiAuth { @@ -52,9 +61,9 @@ impl GeminiAuth { ) } - /// Whether this credential is an OAuth token from Gemini CLI. + /// Whether this credential is an OAuth token (CLI or managed). fn is_oauth(&self) -> bool { - matches!(self, GeminiAuth::OAuthToken(_)) + matches!(self, GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth) } /// The raw credential string (for API key variants only). @@ -63,7 +72,7 @@ impl GeminiAuth { GeminiAuth::ExplicitKey(s) | GeminiAuth::EnvGeminiKey(s) | GeminiAuth::EnvGoogleKey(s) => s, - GeminiAuth::OAuthToken(_) => "", + GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => "", } } } @@ -397,6 +406,81 @@ impl GeminiProvider { oauth_project: Arc::new(tokio::sync::Mutex::new(None)), oauth_cred_paths, oauth_index: Arc::new(tokio::sync::Mutex::new(0)), + auth_service: None, + auth_profile_override: None, + } + } + + /// Create a new Gemini provider with managed OAuth from auth-profiles.json. + /// + /// Authentication priority: + /// 1. Explicit API key passed in + /// 2. `GEMINI_API_KEY` environment variable + /// 3. `GOOGLE_API_KEY` environment variable + /// 4. Managed OAuth from auth-profiles.json (if auth_service provided) + /// 5. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`) + pub fn new_with_auth( + api_key: Option<&str>, + auth_service: AuthService, + profile_override: Option, + ) -> Self { + let oauth_cred_paths = Self::discover_oauth_cred_paths(); + + // First check API keys + let resolved_auth = api_key + .and_then(Self::normalize_non_empty) + .map(GeminiAuth::ExplicitKey) + .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey)) + .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey)); + + // If no API key, we'll use managed OAuth (checked at runtime) + // or fall back to CLI OAuth + let (auth, use_managed) = if resolved_auth.is_some() { + (resolved_auth, false) + } else { + // Check if we have a managed profile - this is a blocking check + // but we need to know at construction time + let has_managed = std::thread::scope(|s| { + s.spawn(|| { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .ok()?; + rt.block_on(async { + auth_service + .get_gemini_profile(profile_override.as_deref()) + .await + .ok() + .flatten() + }) + }) + .join() + .ok() + .flatten() + .is_some() + }); + + if has_managed { + (Some(GeminiAuth::ManagedOAuth), true) + } else { + // Fall back to CLI OAuth + let cli_auth = Self::try_load_gemini_cli_token(oauth_cred_paths.first()) + .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state)))); + (cli_auth, false) + } + }; + + Self { + auth, + oauth_project: Arc::new(tokio::sync::Mutex::new(None)), + oauth_cred_paths, + oauth_index: Arc::new(tokio::sync::Mutex::new(0)), + auth_service: if use_managed { + Some(auth_service) + } else { + None + }, + auth_profile_override: profile_override, } } @@ -473,7 +557,7 @@ impl GeminiProvider { creds.expiry.as_deref().and_then(|expiry| { chrono::DateTime::parse_from_rfc3339(expiry) .ok() - .and_then(|dt| i64::try_from(dt.timestamp_millis()).ok()) + .map(|dt| dt.timestamp_millis()) }) }); @@ -537,6 +621,7 @@ impl GeminiProvider { Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var", Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var", Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth", + Some(GeminiAuth::ManagedOAuth) => "auth-profiles", None => "none", } } @@ -644,10 +729,9 @@ impl GeminiProvider { /// See: https://github.com/google-gemini/gemini-cli/issues/19200 fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String { match auth { - GeminiAuth::OAuthToken(_) => { - // OAuth tokens from Gemini CLI are scoped for the internal - // Code Assist API. The model is passed in the request body, - // not the URL path. + GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => { + // OAuth tokens are scoped for the internal Code Assist API. + // The model is passed in the request body, not the URL path. format!("{CLOUDCODE_PA_ENDPOINT}:generateContent") } _ => { @@ -811,8 +895,9 @@ impl GeminiProvider { "Gemini API key not found. Options:\n\ 1. Set GEMINI_API_KEY env var\n\ 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\ - 3. Get an API key from https://aistudio.google.com/app/apikey\n\ - 4. Run `zeroclaw onboard` to configure" + 3. Run `zeroclaw auth login --provider gemini`\n\ + 4. Get an API key from https://aistudio.google.com/app/apikey\n\ + 5. Run `zeroclaw onboard` to configure" ) })?; @@ -822,12 +907,29 @@ impl GeminiProvider { }; // For OAuth: get a valid (potentially refreshed) token and resolve project - let (mut oauth_token, mut project) = if let GeminiAuth::OAuthToken(state) = auth { - let token = Self::get_valid_oauth_token(state).await?; - let proj = self.resolve_oauth_project(&token).await?; - (Some(token), Some(proj)) - } else { - (None, None) + let (mut oauth_token, mut project) = match auth { + GeminiAuth::OAuthToken(state) => { + let token = Self::get_valid_oauth_token(state).await?; + let proj = self.resolve_oauth_project(&token).await?; + (Some(token), Some(proj)) + } + GeminiAuth::ManagedOAuth => { + let auth_service = self + .auth_service + .as_ref() + .ok_or_else(|| anyhow::anyhow!("ManagedOAuth requires auth_service"))?; + let token = auth_service + .get_valid_gemini_access_token(self.auth_profile_override.as_deref()) + .await? + .ok_or_else(|| { + anyhow::anyhow!( + "Gemini auth profile not found. Run `zeroclaw auth login --provider gemini`." + ) + })?; + let proj = self.resolve_oauth_project(&token).await?; + (Some(token), Some(proj)) + } + _ => (None, None), }; let request = GenerateContentRequest { @@ -859,27 +961,55 @@ impl GeminiProvider { let error_text = response.text().await.unwrap_or_default(); if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) { - if let Some(state) = oauth_state.as_ref() { - if self.rotate_oauth_credential(state).await { - let token = Self::get_valid_oauth_token(state).await?; - let proj = self.resolve_oauth_project(&token).await?; - oauth_token = Some(token); - project = Some(proj); - response = self - .build_generate_content_request( - auth, - &url, - &request, - model, - true, - project.as_deref(), - oauth_token.as_deref(), - ) - .send() - .await?; - } else { - anyhow::bail!("Gemini API error ({status}): {error_text}"); + // For CLI OAuth: rotate credentials + // For ManagedOAuth: AuthService handles refresh, just retry + let can_retry = match auth { + GeminiAuth::OAuthToken(_) => { + if let Some(state) = oauth_state.as_ref() { + self.rotate_oauth_credential(state).await + } else { + false + } } + GeminiAuth::ManagedOAuth => true, // AuthService refreshes automatically + _ => false, + }; + + if can_retry { + // Re-fetch token (may be refreshed) + let (new_token, new_project) = match auth { + GeminiAuth::OAuthToken(state) => { + let token = Self::get_valid_oauth_token(state).await?; + let proj = self.resolve_oauth_project(&token).await?; + (token, proj) + } + GeminiAuth::ManagedOAuth => { + let auth_service = self.auth_service.as_ref().unwrap(); + let token = auth_service + .get_valid_gemini_access_token( + self.auth_profile_override.as_deref(), + ) + .await? + .ok_or_else(|| anyhow::anyhow!("Gemini auth profile not found"))?; + let proj = self.resolve_oauth_project(&token).await?; + (token, proj) + } + _ => unreachable!(), + }; + oauth_token = Some(new_token); + project = Some(new_project); + response = self + .build_generate_content_request( + auth, + &url, + &request, + model, + true, + project.as_deref(), + oauth_token.as_deref(), + ) + .send() + .await?; } else { anyhow::bail!("Gemini API error ({status}): {error_text}"); } @@ -1143,6 +1273,8 @@ mod tests { oauth_project: Arc::new(tokio::sync::Mutex::new(None)), oauth_cred_paths: Vec::new(), oauth_index: Arc::new(tokio::sync::Mutex::new(0)), + auth_service: None, + auth_profile_override: None, } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index d11233103..eeb6e5624 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -36,6 +36,7 @@ pub use traits::{ ToolCall, ToolResultMessage, }; +use crate::auth::AuthService; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; use serde::Deserialize; @@ -960,7 +961,21 @@ fn create_provider_with_url_and_options( options.reasoning_enabled, ))), "gemini" | "google" | "google-gemini" => { - Ok(Box::new(gemini::GeminiProvider::new(key))) + let state_dir = options + .zeroclaw_dir + .clone() + .unwrap_or_else(|| { + directories::UserDirs::new().map_or_else( + || PathBuf::from(".zeroclaw"), + |dirs| dirs.home_dir().join(".zeroclaw"), + ) + }); + let auth_service = AuthService::new(&state_dir, options.secrets_encrypt); + Ok(Box::new(gemini::GeminiProvider::new_with_auth( + key, + auth_service, + options.auth_profile_override.clone(), + ))) } "telnyx" => Ok(Box::new(telnyx::TelnyxProvider::new(key))),